remove the ConnStateCallback from the client

Dial and DialAddr return once the connection is forward secure. There is
currently no option to get the session earlier, this will be added later.
This commit is contained in:
Marten Seemann 2017-05-06 11:37:44 +08:00
parent 30a0211243
commit 612323985b
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
4 changed files with 78 additions and 242 deletions

128
client.go
View file

@ -14,15 +14,17 @@ import (
) )
type client struct { type client struct {
mutex sync.Mutex mutex sync.Mutex
connStateChangeOrErrCond sync.Cond listenErr error
listenErr error
conn connection conn connection
hostname string hostname string
errorChan chan struct{}
config *Config handshakeChan chan struct{} // is closed as soon as the handshake completes
connState ConnState
config *Config
versionNegotiated bool // has version negotiation completed yet
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
version protocol.VersionNumber version protocol.VersionNumber
@ -34,6 +36,20 @@ var (
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
) )
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, config)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
@ -49,15 +65,15 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID, connectionID: connID,
hostname: hostname, hostname: hostname,
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
handshakeChan: make(chan struct{}),
} }
c.connStateChangeOrErrCond.L = &c.mutex
err = c.createNewSession(nil) err = c.createNewSession(nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -76,48 +92,20 @@ func populateClientConfig(config *Config) *Config {
return &Config{ return &Config{
TLSConfig: config.TLSConfig, TLSConfig: config.TLSConfig,
ConnState: config.ConnState,
Versions: versions, Versions: versions,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
} }
} }
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, config)
}
func (c *client) establishConnection() (Session, error) { func (c *client) establishConnection() (Session, error) {
go c.listen() go c.listen()
c.mutex.Lock() select {
defer c.mutex.Unlock() case <-c.errorChan:
return nil, c.listenErr
for { case <-c.handshakeChan:
if c.listenErr != nil { return c.session, nil
return nil, c.listenErr
}
if c.config.ConnState != nil && c.connState >= ConnStateVersionNegotiated {
break
}
if c.config.ConnState == nil && c.connState == ConnStateForwardSecure {
break
}
c.connStateChangeOrErrCond.Wait()
} }
return c.session, nil
} }
// Listen listens // Listen listens
@ -147,11 +135,6 @@ func (c *client) listen() {
break break
} }
} }
c.mutex.Lock()
c.listenErr = err
c.connStateChangeOrErrCond.Signal()
c.mutex.Unlock()
} }
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
@ -168,18 +151,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
defer c.mutex.Unlock() defer c.mutex.Unlock()
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.connState >= ConnStateVersionNegotiated && hdr.VersionFlag { if c.versionNegotiated && hdr.VersionFlag {
return nil return nil
} }
// this is the first packet after the client sent a packet with the VersionFlag set // this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version // if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && c.connState == ConnStateInitial { if !hdr.VersionFlag && !c.versionNegotiated {
c.connState = ConnStateVersionNegotiated c.versionNegotiated = true
c.connStateChangeOrErrCond.Signal()
if c.config.ConnState != nil {
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
}
} }
if hdr.VersionFlag { if hdr.VersionFlag {
@ -213,7 +192,7 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
// switch to negotiated version // switch to negotiated version
c.version = newVersion c.version = newVersion
c.connState = ConnStateVersionNegotiated c.versionNegotiated = true
var err error var err error
c.connectionID, err = utils.GenerateConnectionID() c.connectionID, err = utils.GenerateConnectionID()
if err != nil { if err != nil {
@ -222,32 +201,12 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID) utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion) c.session.Close(errCloseSessionForNewVersion)
err = c.createNewSession(hdr.SupportedVersions) return c.createNewSession(hdr.SupportedVersions)
if err != nil {
return err
}
if c.config.ConnState != nil {
go c.config.ConnState(c.session, ConnStateVersionNegotiated)
}
return nil
} }
func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) { func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) {
var state ConnState
if isForwardSecure { if isForwardSecure {
state = ConnStateForwardSecure close(c.handshakeChan)
} else {
state = ConnStateSecure
}
c.mutex.Lock()
c.connState = state
c.connStateChangeOrErrCond.Signal()
c.mutex.Unlock()
if c.config.ConnState != nil {
go c.config.ConnState(c.session, state)
} }
} }
@ -272,11 +231,8 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
if err == errCloseSessionForNewVersion { if err == errCloseSessionForNewVersion {
return return
} }
c.mutex.Lock()
c.listenErr = err c.listenErr = err
c.connStateChangeOrErrCond.Signal() close(c.errorChan)
c.mutex.Unlock()
utils.Infof("Connection %x closed.", c.connectionID) utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close() c.conn.Close()

View file

@ -16,34 +16,29 @@ import (
var _ = Describe("Client", func() { var _ = Describe("Client", func() {
var ( var (
cl *client cl *client
config *Config config *Config
sess *mockSession sess *mockSession
packetConn *mockPacketConn packetConn *mockPacketConn
addr net.Addr addr net.Addr
versionNegotiateConnStateCalled bool
) )
BeforeEach(func() { BeforeEach(func() {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
versionNegotiateConnStateCalled = false
packetConn = &mockPacketConn{} packetConn = &mockPacketConn{}
config = &Config{ config = &Config{
ConnState: func(_ Session, state ConnState) {
if state == ConnStateVersionNegotiated {
versionNegotiateConnStateCalled = true
}
},
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
} }
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
sess = &mockSession{connectionID: 0x1337} sess = &mockSession{connectionID: 0x1337}
cl = &client{ cl = &client{
config: config, config: config,
connectionID: 0x1337, connectionID: 0x1337,
session: sess, session: sess,
version: protocol.SupportedVersions[0], version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr}, conn: &conn{pconn: packetConn, currentAddr: addr},
errorChan: make(chan struct{}),
handshakeChan: make(chan struct{}),
} }
}) })
@ -55,7 +50,7 @@ var _ = Describe("Client", func() {
}) })
Context("Dialing", func() { Context("Dialing", func() {
It("creates a new client", func() { PIt("creates a new client", func() {
packetConn.dataToRead = []byte{0x0, 0x1, 0x0} packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -82,41 +77,6 @@ var _ = Describe("Client", func() {
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
// now we're only testing that Dial doesn't return directly after version negotiation
PIt("doesn't return after version negotiation is established if no ConnState is defined", func() {
// TODO(#506): Fix test
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
config.ConnState = nil
var dialReturned bool
go func() {
defer GinkgoRecover()
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
dialReturned = true
}()
Consistently(func() bool { return dialReturned }).Should(BeFalse())
})
It("only establishes a connection once it is forward-secure if no ConnState is defined", func() {
config.ConnState = nil
client := &client{conn: &conn{pconn: packetConn, currentAddr: addr}, config: config}
client.connStateChangeOrErrCond.L = &client.mutex
var returned bool
go func() {
defer GinkgoRecover()
_, err := client.establishConnection()
Expect(err).ToNot(HaveOccurred())
returned = true
}()
Consistently(func() bool { return returned }).Should(BeFalse())
// switch to a secure connection
client.cryptoChangeCallback(nil, false)
Consistently(func() bool { return returned }).Should(BeFalse())
// switch to a forward-secure connection
client.cryptoChangeCallback(nil, true)
Eventually(func() bool { return returned }).Should(BeTrue())
})
}) })
It("errors on invalid public header", func() { It("errors on invalid public header", func() {
@ -124,9 +84,9 @@ var _ = Describe("Client", func() {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
}) })
// this test requires a real session (because it calls the close callback) // this test requires a real session
// and a real UDP conn (because it unblocks and errors when it is closed) // and a real UDP conn (because it unblocks and errors when it is closed)
It("properly closes", func(done Done) { PIt("properly closes", func(done Done) {
Eventually(areSessionsRunning).Should(BeFalse()) Eventually(areSessionsRunning).Should(BeFalse())
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -213,8 +173,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(nil, b.Bytes()) err = cl.handlePacket(nil, b.Bytes())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Expect(cl.versionNegotiated).To(BeTrue())
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
}) })
It("changes the version after receiving a version negotiation packet", func() { It("changes the version after receiving a version negotiation packet", func() {
@ -226,8 +185,7 @@ var _ = Describe("Client", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(newVersion)) Expect(cl.version).To(Equal(newVersion))
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Expect(cl.versionNegotiated).To(BeTrue())
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
// it swapped the sessions // it swapped the sessions
Expect(cl.session).ToNot(Equal(sess)) Expect(cl.session).ToNot(Equal(sess))
Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID
@ -260,13 +218,12 @@ var _ = Describe("Client", func() {
It("ignores delayed version negotiation packets", func() { It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
cl.connState = ConnStateVersionNegotiated cl.versionNegotiated = true
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated)) Expect(cl.versionNegotiated).To(BeTrue())
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse())
}) })
It("drops version negotiation packets that contain the offered version", func() { It("drops version negotiation packets that contain the offered version", func() {

View file

@ -22,8 +22,7 @@ import (
// Client is a HTTP2 client doing QUIC requests // Client is a HTTP2 client doing QUIC requests
type Client struct { type Client struct {
mutex sync.RWMutex mutex sync.RWMutex
cryptoChangedCond sync.Cond
config *quic.Config config *quic.Config
@ -31,6 +30,7 @@ type Client struct {
hostname string hostname string
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
dialChan chan struct{} // will be closed once the handshake is complete and the header stream has been opened
session quic.Session session quic.Session
headerStream quic.Stream headerStream quic.Stream
@ -44,57 +44,27 @@ var _ h2quicClient = &Client{}
// NewClient creates a new client // NewClient creates a new client
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client { func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client {
c := &Client{ return &Client{
t: t, t: t,
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted, encryptionLevel: protocol.EncryptionUnencrypted,
config: &quic.Config{
TLSConfig: tlsConfig,
RequestConnectionIDTruncation: true,
},
dialChan: make(chan struct{}),
} }
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
c.config = &quic.Config{
ConnState: c.connStateCallback,
TLSConfig: tlsConfig,
RequestConnectionIDTruncation: true,
}
return c
} }
// Dial dials the connection // Dial dials the connection
func (c *Client) Dial() error { func (c *Client) Dial() error {
_, err := quic.DialAddr(c.hostname, c.config)
return err
}
// connStateCallback is the ConnStateCallback passed to the quic.Dial
// this function is called in a separate go-routine
func (c *Client) connStateCallback(sess quic.Session, state quic.ConnState) {
c.mutex.Lock()
if c.session == nil {
c.session = sess
}
switch state {
case quic.ConnStateVersionNegotiated:
err := c.versionNegotiateCallback()
if err != nil {
c.Close(err)
}
case quic.ConnStateSecure:
utils.Debugf("is secure")
// only save the encryption level if it is now higher than it was before
if c.encryptionLevel < protocol.EncryptionSecure {
c.encryptionLevel = protocol.EncryptionSecure
}
c.cryptoChangedCond.Broadcast()
case quic.ConnStateForwardSecure:
utils.Debugf("is forward secure")
c.encryptionLevel = protocol.EncryptionForwardSecure
c.cryptoChangedCond.Broadcast()
}
c.mutex.Unlock()
}
func (c *Client) versionNegotiateCallback() error {
var err error var err error
c.session, err = quic.DialAddr(c.hostname, c.config)
if err != nil {
return err
}
// once the version has been negotiated, open the header stream // once the version has been negotiated, open the header stream
c.headerStream, err = c.session.OpenStream() c.headerStream, err = c.session.OpenStream()
if err != nil { if err != nil {
@ -105,6 +75,7 @@ func (c *Client) versionNegotiateCallback() error {
} }
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream() go c.handleHeaderStream()
close(c.dialChan)
return nil return nil
} }
@ -170,16 +141,14 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
hasBody := (req.Body != nil) hasBody := (req.Body != nil)
c.mutex.Lock() <-c.dialChan // wait until the handshake has completed
for c.encryptionLevel != protocol.EncryptionForwardSecure {
c.cryptoChangedCond.Wait()
}
hdrChan := make(chan *http.Response) hdrChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync() dataStream, err := c.session.OpenStreamSync()
if err != nil { if err != nil {
c.Close(err) c.Close(err)
return nil, err return nil, err
} }
c.mutex.Lock()
c.responses[dataStream.StreamID()] = hdrChan c.responses[dataStream.StreamID()] = hdrChan
c.mutex.Unlock() c.mutex.Unlock()

View file

@ -64,53 +64,6 @@ var _ = Describe("Client", func() {
Expect(hdr.ConnectionID).ToNot(BeNil()) Expect(hdr.ConnectionID).ToNot(BeNil())
}) })
It("saves the session when the ConnState callback is called", func() {
client.session = nil // unset the session set in BeforeEach
client.config.ConnState(session, quic.ConnStateForwardSecure)
Expect(client.session).To(Equal(session))
})
It("opens the header stream only after the version has been negotiated", func() {
client.headerStream = nil // unset the headerStream openend in the BeforeEach
session.streamToOpen = headerStream
Expect(client.headerStream).To(BeNil()) // header stream not yet opened
// now start the actual test
client.config.ConnState(session, quic.ConnStateVersionNegotiated)
Expect(client.headerStream).ToNot(BeNil())
Expect(client.headerStream.StreamID()).To(Equal(protocol.StreamID(3)))
})
It("errors if it can't open the header stream", func() {
testErr := errors.New("test error")
client.headerStream = nil // unset the headerStream openend in the BeforeEach
session.streamOpenErr = testErr
client.config.ConnState(session, quic.ConnStateVersionNegotiated)
Expect(session.closed).To(BeTrue())
Expect(session.closedWithError).To(MatchError(testErr))
})
It("errors if the header stream has the wrong StreamID", func() {
session.streamToOpen = &mockStream{id: 1337}
client.config.ConnState(session, quic.ConnStateVersionNegotiated)
Expect(session.closed).To(BeTrue())
Expect(session.closedWithError).To(MatchError("h2quic Client BUG: StreamID of Header Stream is not 3"))
})
It("sets the correct crypto level", func() {
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
client.config.ConnState(session, quic.ConnStateSecure)
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionSecure))
client.config.ConnState(session, quic.ConnStateForwardSecure)
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
})
It("sets the correct crypto level, if the ConnStateCallback is called in the wrong order", func() {
client.config.ConnState(session, quic.ConnStateForwardSecure)
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
client.config.ConnState(session, quic.ConnStateSecure)
Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
})
Context("Doing requests", func() { Context("Doing requests", func() {
var request *http.Request var request *http.Request
var dataStream *mockStream var dataStream *mockStream
@ -143,6 +96,7 @@ var _ = Describe("Client", func() {
dataStream = &mockStream{id: 5} dataStream = &mockStream{id: 5}
session.streamToOpen = dataStream session.streamToOpen = dataStream
close(client.dialChan)
}) })
It("does a request", func(done Done) { It("does a request", func(done Done) {