diff --git a/client.go b/client.go index d45e2edc..bd0eb77a 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ import ( "errors" "net" "strings" + "sync" "sync/atomic" "time" @@ -14,45 +15,35 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) -// A Client of QUIC -type Client struct { +type client struct { + mutex sync.Mutex + connStateChangeCond sync.Cond + conn connection hostname string + config *Config + connectionID protocol.ConnectionID version protocol.VersionNumber versionNegotiated bool closed uint32 // atomic bool - tlsConfig *tls.Config - cryptoChangeCallback CryptoChangeCallback - versionNegotiateCallback VersionNegotiateCallback + tlsConfig *tls.Config + cryptoChangeCallback CryptoChangeCallback session packetHandler } -// VersionNegotiateCallback is called once the client has a negotiated version -type VersionNegotiateCallback func() error - var errHostname = errors.New("Invalid hostname") var ( errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") ) -// NewClient makes a new client -func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) { - udpAddr, err := net.ResolveUDPAddr("udp", host) - if err != nil { - return nil, err - } - - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - - connectionID, err := utils.GenerateConnectionID() +// Dial establishes a new QUIC connection to a server +func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { + connID, err := utils.GenerateConnectionID() if err != nil { return nil, err } @@ -62,28 +53,65 @@ func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoCh return nil, err } - client := &Client{ - conn: &conn{pconn: udpConn, currentAddr: udpAddr}, - hostname: hostname, - version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default - connectionID: connectionID, - tlsConfig: tlsConfig, - cryptoChangeCallback: cryptoChangeCallback, - versionNegotiateCallback: versionNegotiateCallback, + c := &client{ + conn: &conn{pconn: pconn, currentAddr: remoteAddr}, + connectionID: connID, + hostname: hostname, + config: config, + version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default } - utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version) + c.connStateChangeCond.L = &c.mutex - err = client.createNewSession(nil) + c.cryptoChangeCallback = func(isForwardSecure bool) { + var state ConnState + if isForwardSecure { + state = ConnStateForwardSecure + } else { + state = ConnStateSecure + } + + if c.config.ConnState != nil { + go config.ConnState(c.session, state) + } + } + + err = c.createNewSession(nil) if err != nil { return nil, err } - return client, nil + utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) + + // TODO: handle errors + go c.Listen() + + c.mutex.Lock() + for !c.versionNegotiated { + c.connStateChangeCond.Wait() + } + c.mutex.Unlock() + + return c.session, nil +} + +// DialAddr establishes a new QUIC connection to a server +func DialAddr(hostname string, config *Config) (Session, error) { + udpAddr, err := net.ResolveUDPAddr("udp", hostname) + if err != nil { + return nil, err + } + + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + + return Dial(udpConn, udpAddr, hostname, config) } // Listen listens -func (c *Client) Listen() error { +func (c *client) Listen() error { for { data := getPacketBuffer() data = data[:protocol.MaxPacketSize] @@ -106,13 +134,8 @@ func (c *Client) Listen() error { } } -// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs) -func (c *Client) OpenStream() (Stream, error) { - return c.session.OpenStream() -} - // Close closes the connection -func (c *Client) Close(e error) error { +func (c *client) Close(e error) error { // Only close once if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return nil @@ -122,7 +145,7 @@ func (c *Client) Close(e error) error { return c.conn.Close() } -func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error { +func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize { return qerr.PacketTooLarge } @@ -145,10 +168,12 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error { // this is the first packet after the client sent a packet with the VersionFlag set // if the server doesn't send a version negotiation packet, it supports the suggested version if !hdr.VersionFlag && !c.versionNegotiated { + c.mutex.Lock() c.versionNegotiated = true - err = c.versionNegotiateCallback() - if err != nil { - return err + c.connStateChangeCond.Signal() + c.mutex.Unlock() + if c.config.ConnState != nil { + go c.config.ConnState(c.session, ConnStateVersionNegotiated) } } @@ -187,7 +212,9 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error { if err != nil { return err } - err = c.versionNegotiateCallback() + if c.config.ConnState != nil { + go c.config.ConnState(c.session, ConnStateVersionNegotiated) + } if err != nil { return err } @@ -204,14 +231,14 @@ func (c *Client) handlePacket(remoteAddr net.Addr, packet []byte) error { return nil } -func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { +func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) error { var err error c.session, err = newClientSession( c.conn, c.hostname, c.version, c.connectionID, - c.tlsConfig, + c.config.TLSConfig, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions) @@ -223,6 +250,6 @@ func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) e return nil } -func (c *Client) closeCallback(id protocol.ConnectionID) { +func (c *client) closeCallback(id protocol.ConnectionID) { utils.Infof("Connection %x closed.", id) } diff --git a/client_test.go b/client_test.go index 0b398b1d..f65a4de3 100644 --- a/client_test.go +++ b/client_test.go @@ -19,44 +19,51 @@ import ( var _ = Describe("Client", func() { var ( - client *Client - sess *mockSession - packetConn *mockPacketConn - versionNegotiateCallbackCalled bool + cl *client + config *Config + sess *mockSession + packetConn *mockPacketConn + addr net.Addr + versionNegotiateConnStateCalled bool ) BeforeEach(func() { + versionNegotiateConnStateCalled = false packetConn = &mockPacketConn{} - versionNegotiateCallbackCalled = false - client = &Client{ - versionNegotiateCallback: func() error { - versionNegotiateCallbackCalled = true - return nil + config = &Config{ + ConnState: func(_ Session, state ConnState) { + if state == ConnStateVersionNegotiated { + versionNegotiateConnStateCalled = true + } }, } - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} sess = &mockSession{connectionID: 0x1337} - client.connectionID = 0x1337 - client.session = sess - client.version = protocol.Version36 - client.conn = &conn{pconn: packetConn, currentAddr: addr} + cl = &client{ + config: config, + connectionID: 0x1337, + session: sess, + version: protocol.Version36, + conn: &conn{pconn: packetConn, currentAddr: addr}, + } }) It("creates a new client", func() { + packetConn.dataToRead = []byte{0x0, 0x1, 0x0} var err error - client, err = NewClient("quic.clemente.io:1337", nil, nil, nil) + sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) Expect(err).ToNot(HaveOccurred()) - Expect(client.hostname).To(Equal("quic.clemente.io")) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) + Expect(*(*string)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("hostname").UnsafeAddr()))).To(Equal("quic.clemente.io")) }) It("errors on invalid public header", func() { - err := client.handlePacket(nil, nil) + err := cl.handlePacket(nil, nil) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) }) It("errors on large packets", func() { - err := client.handlePacket(nil, bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1)) + err := cl.handlePacket(nil, bytes.Repeat([]byte{'a'}, int(protocol.MaxPacketSize)+1)) Expect(err).To(MatchError(qerr.PacketTooLarge)) }) @@ -68,51 +75,45 @@ var _ = Describe("Client", func() { var stoppedListening bool go func() { defer GinkgoRecover() - err := client.Listen() + err := cl.Listen() Expect(err).ToNot(HaveOccurred()) stoppedListening = true }() - err := client.Close(testErr) + err := cl.Close(testErr) Expect(err).ToNot(HaveOccurred()) Eventually(sess.closed).Should(BeTrue()) Expect(sess.closeReason).To(MatchError(testErr)) - Expect(client.closed).To(Equal(uint32(1))) + Expect(cl.closed).To(Equal(uint32(1))) Eventually(func() bool { return stoppedListening }).Should(BeTrue()) Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines)) close(done) }, 10) It("only closes the client once", func() { - client.closed = 1 - err := client.Close(errors.New("test error")) + cl.closed = 1 + err := cl.Close(errors.New("test error")) Expect(err).ToNot(HaveOccurred()) Eventually(sess.closed).Should(BeFalse()) Expect(sess.closeReason).ToNot(HaveOccurred()) }) It("creates new sessions with the right parameters", func() { - client.session = nil - client.hostname = "hostname" - err := client.createNewSession(nil) + cl.session = nil + cl.hostname = "hostname" + err := cl.createNewSession(nil) Expect(err).ToNot(HaveOccurred()) - Expect(client.session).ToNot(BeNil()) - Expect(client.session.(*session).connectionID).To(Equal(client.connectionID)) - Expect(client.session.(*session).version).To(Equal(client.version)) + Expect(cl.session).ToNot(BeNil()) + Expect(cl.session.(*session).connectionID).To(Equal(cl.connectionID)) + Expect(cl.session.(*session).version).To(Equal(cl.version)) - err = client.Close(nil) + err = cl.Close(nil) Expect(err).ToNot(HaveOccurred()) }) - It("opens a stream", func() { - stream, err := client.OpenStream() - Expect(err).ToNot(HaveOccurred()) - Expect(stream).ToNot(BeNil()) - }) - Context("handling packets", func() { It("errors on too large packets", func() { - err := client.handlePacket(nil, bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1))) + err := cl.handlePacket(nil, bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize+1))) Expect(err).To(MatchError(qerr.PacketTooLarge)) }) @@ -130,7 +131,7 @@ var _ = Describe("Client", func() { Expect(sess.packetCount).To(BeZero()) var stoppedListening bool go func() { - _ = client.Listen() + _ = cl.Listen() // it should continue listening when receiving valid packets stoppedListening = true }() @@ -142,7 +143,7 @@ var _ = Describe("Client", func() { It("closes the session when encountering an error while handling a packet", func() { packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100) - listenErr := client.Listen() + listenErr := cl.Listen() Expect(listenErr).To(HaveOccurred()) Expect(sess.closed).To(BeTrue()) Expect(sess.closeReason).To(MatchError(listenErr)) @@ -160,7 +161,7 @@ var _ = Describe("Client", func() { b.Write(s) } protocol.SupportedVersionsAsTags = b.Bytes() - packet := composeVersionNegotiation(client.connectionID) + packet := composeVersionNegotiation(cl.connectionID) protocol.SupportedVersionsAsTags = oldSupportVersionTags Expect(composeVersionNegotiation(0x1337)).To(Equal(oldVersionNegotiationPacket)) return packet @@ -175,51 +176,51 @@ var _ = Describe("Client", func() { b := &bytes.Buffer{} err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) - err = client.handlePacket(nil, b.Bytes()) + err = cl.handlePacket(nil, b.Bytes()) Expect(err).ToNot(HaveOccurred()) - Expect(client.versionNegotiated).To(BeTrue()) - Expect(versionNegotiateCallbackCalled).To(BeTrue()) + Expect(cl.versionNegotiated).To(BeTrue()) + Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue()) }) It("changes the version after receiving a version negotiation packet", func() { newVersion := protocol.Version35 - Expect(newVersion).ToNot(Equal(client.version)) + Expect(newVersion).ToNot(Equal(cl.version)) Expect(sess.packetCount).To(BeZero()) - client.connectionID = 0x1337 - err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion})) - Expect(client.version).To(Equal(newVersion)) - Expect(client.versionNegotiated).To(BeTrue()) - Expect(versionNegotiateCallbackCalled).To(BeTrue()) + cl.connectionID = 0x1337 + err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion})) + Expect(cl.version).To(Equal(newVersion)) + Expect(cl.versionNegotiated).To(BeTrue()) + Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue()) // it swapped the sessions - Expect(client.session).ToNot(Equal(sess)) - Expect(client.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID + Expect(cl.session).ToNot(Equal(sess)) + Expect(cl.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID Expect(err).ToNot(HaveOccurred()) // it didn't pass the version negoation packet to the session (since it has no payload) Expect(sess.packetCount).To(BeZero()) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) - err = client.Close(nil) + err = cl.Close(nil) Expect(err).ToNot(HaveOccurred()) }) It("errors if no matching version is found", func() { - err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1})) + err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1})) Expect(err).To(MatchError(qerr.InvalidVersion)) }) It("ignores delayed version negotiation packets", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test - client.versionNegotiated = true + cl.versionNegotiated = true Expect(sess.packetCount).To(BeZero()) - err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1})) + err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1})) Expect(err).ToNot(HaveOccurred()) - Expect(client.versionNegotiated).To(BeTrue()) + Expect(cl.versionNegotiated).To(BeTrue()) Expect(sess.packetCount).To(BeZero()) - Expect(versionNegotiateCallbackCalled).To(BeFalse()) + Consistently(func() bool { return versionNegotiateConnStateCalled }).Should(BeFalse()) }) It("errors if the server should have accepted the offered version", func() { - err := client.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{client.version})) + err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{cl.version})) Expect(err).To(MatchError(qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection."))) }) }) diff --git a/h2quic/client.go b/h2quic/client.go index 1e270421..009b6eb4 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -20,23 +20,19 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) -type quicClient interface { - OpenStream() (quic.Stream, error) - Close(error) error - Listen() error -} - // Client is a HTTP2 client doing QUIC requests type Client struct { mutex sync.RWMutex cryptoChangedCond sync.Cond + config *quic.Config + t *QuicRoundTripper hostname string encryptionLevel protocol.EncryptionLevel - client quicClient + session quic.Session headerStream quic.Stream headerErr *qerr.QuicError requestWriter *requestWriter @@ -47,42 +43,50 @@ type Client struct { var _ h2quicClient = &Client{} // NewClient creates a new client -func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) { +func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client { c := &Client{ t: t, hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), } c.cryptoChangedCond = sync.Cond{L: &c.mutex} - - var err error - c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback) - if err != nil { - return nil, err + c.config = &quic.Config{ + ConnState: c.connStateCallback, } - - go c.client.Listen() - return c, nil + return c } -func (c *Client) cryptoChangeCallback(isForwardSecure bool) { - c.cryptoChangedCond.L.Lock() - defer c.cryptoChangedCond.L.Unlock() +// Dial dials the connection +func (c *Client) Dial() error { + _, err := quic.DialAddr(c.hostname, c.config) + return err +} - if isForwardSecure { - c.encryptionLevel = protocol.EncryptionForwardSecure - utils.Debugf("is forward secure") - } else { +func (c *Client) connStateCallback(sess quic.Session, state quic.ConnState) { + c.mutex.Lock() + if c.session == nil { + c.session = sess + } + switch state { + case quic.ConnStateVersionNegotiated: + // TODO: handle errors + c.versionNegotiateCallback() + case quic.ConnStateSecure: c.encryptionLevel = protocol.EncryptionSecure utils.Debugf("is secure") + c.cryptoChangedCond.Broadcast() + case quic.ConnStateForwardSecure: + c.encryptionLevel = protocol.EncryptionForwardSecure + utils.Debugf("is forward secure") + c.cryptoChangedCond.Broadcast() } - c.cryptoChangedCond.Broadcast() + c.mutex.Unlock() } func (c *Client) versionNegotiateCallback() error { var err error // once the version has been negotiated, open the header stream - c.headerStream, err = c.client.OpenStream() + c.headerStream, err = c.session.OpenStream() if err != nil { return err } @@ -162,7 +166,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { } hdrChan := make(chan *http.Response) // TODO: think about what to do with a TooManyOpenStreams error. Wait and retry? - dataStream, err := c.client.OpenStream() + dataStream, err := c.session.OpenStream() if err != nil { c.Close(err) return nil, err @@ -260,7 +264,7 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e // Close closes the client func (c *Client) Close(e error) { - _ = c.client.Close(e) + _ = c.session.Close(e) } // copied from net/transport.go diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 38d973ed..34ed1979 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "errors" + "net" "net/http" "golang.org/x/net/http2" @@ -17,85 +18,73 @@ import ( . "github.com/onsi/gomega" ) -type mockQuicClient struct { - nextStream protocol.StreamID - streams map[protocol.StreamID]*mockStream - closeErr error -} - -func (m *mockQuicClient) Close(e error) error { m.closeErr = e; return nil } -func (m *mockQuicClient) Listen() error { panic("not implemented") } -func (m *mockQuicClient) OpenStream() (quic.Stream, error) { - id := m.nextStream - ms := &mockStream{id: id} - m.streams[id] = ms - m.nextStream += 2 - return ms, nil -} - -func newMockQuicClient() *mockQuicClient { - return &mockQuicClient{ - streams: make(map[protocol.StreamID]*mockStream), - nextStream: 5, - } -} - -var _ quicClient = &mockQuicClient{} - var _ = Describe("Client", func() { var ( client *Client - qClient *mockQuicClient + session *mockSession headerStream *mockStream quicTransport *QuicRoundTripper ) BeforeEach(func() { - var err error quicTransport = &QuicRoundTripper{} hostname := "quic.clemente.io:1337" - client, err = NewClient(quicTransport, nil, hostname) - Expect(err).ToNot(HaveOccurred()) + client = NewClient(quicTransport, nil, hostname) Expect(client.hostname).To(Equal(hostname)) - qClient = newMockQuicClient() - client.client = qClient + session = &mockSession{} + client.session = session - headerStream = &mockStream{} - qClient.streams[3] = headerStream + headerStream = &mockStream{id: 3} client.headerStream = headerStream client.requestWriter = newRequestWriter(headerStream) }) It("adds the port to the hostname, if none is given", func() { - var err error - client, err = NewClient(quicTransport, nil, "quic.clemente.io") - Expect(err).ToNot(HaveOccurred()) + client = NewClient(quicTransport, nil, "quic.clemente.io") Expect(client.hostname).To(Equal("quic.clemente.io:443")) }) + It("dials", func() { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + Expect(err).ToNot(HaveOccurred()) + client = NewClient(quicTransport, nil, udpConn.LocalAddr().String()) + go client.Dial() + data := make([]byte, 100) + _, err = udpConn.Read(data) + hdr, err := quic.ParsePublicHeader(bytes.NewReader(data), protocol.PerspectiveClient) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.VersionFlag).To(BeTrue()) + Expect(hdr.ConnectionID).ToNot(BeNil()) + }) + + It("saves the session when the ConnState callback is called", func() { + client.session = nil // unset the session set in BeforeEach + client.config.ConnState(session, quic.ConnStateForwardSecure) + Expect(client.session).To(Equal(session)) + }) + It("opens the header stream only after the version has been negotiated", func() { // delete the headerStream openend in the BeforeEach client.headerStream = nil - delete(qClient.streams, 3) - qClient.nextStream = 3 + session.streamToOpen = headerStream Expect(client.headerStream).To(BeNil()) // header stream not yet opened // now start the actual test - err := client.versionNegotiateCallback() - Expect(err).ToNot(HaveOccurred()) + client.config.ConnState(session, quic.ConnStateVersionNegotiated) Expect(client.headerStream).ToNot(BeNil()) Expect(client.headerStream.StreamID()).To(Equal(protocol.StreamID(3))) }) It("sets the correct crypto level", func() { Expect(client.encryptionLevel).To(Equal(protocol.Unencrypted)) - client.cryptoChangeCallback(false) + client.config.ConnState(session, quic.ConnStateSecure) Expect(client.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - client.cryptoChangeCallback(true) + client.config.ConnState(session, quic.ConnStateForwardSecure) Expect(client.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) }) Context("Doing requests", func() { var request *http.Request + var dataStream *mockStream getRequest := func(data []byte) *http2.MetaHeadersFrame { r := bytes.NewReader(data) @@ -122,6 +111,9 @@ var _ = Describe("Client", func() { client.encryptionLevel = protocol.EncryptionForwardSecure request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) + + dataStream = &mockStream{id: 5} + session.streamToOpen = dataStream }) It("does a request", func(done Done) { @@ -134,7 +126,6 @@ var _ = Describe("Client", func() { }() Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) - Expect(qClient.streams).Should(HaveKey(protocol.StreamID(5))) Expect(client.responses).To(HaveKey(protocol.StreamID(5))) rsp := &http.Response{ Status: "418 I'm a teapot", @@ -144,7 +135,7 @@ var _ = Describe("Client", func() { Eventually(func() bool { return doReturned }).Should(BeTrue()) Expect(doErr).ToNot(HaveOccurred()) Expect(doRsp).To(Equal(rsp)) - Expect(doRsp.Body).ToNot(BeNil()) + Expect(doRsp.Body).To(Equal(dataStream)) Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) Expect(doRsp.Request).To(Equal(request)) close(done) @@ -172,7 +163,7 @@ var _ = Describe("Client", func() { Expect(client.headerErr).To(HaveOccurred()) Expect(doErr).To(MatchError(client.headerErr)) Expect(doRsp).To(BeNil()) - Expect(client.client.(*mockQuicClient).closeErr).To(MatchError(client.headerErr)) + Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr)) }) Context("validating the address", func() { @@ -192,8 +183,7 @@ var _ = Describe("Client", func() { It("adds the port for request URLs without one", func(done Done) { var err error - client, err = NewClient(quicTransport, nil, "quic.clemente.io") - Expect(err).ToNot(HaveOccurred()) + client = NewClient(quicTransport, nil, "quic.clemente.io") req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) @@ -251,7 +241,6 @@ var _ = Describe("Client", func() { }() Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) client.responses[5] <- response - dataStream := qClient.streams[5] Eventually(func() bool { return doReturned }).Should(BeTrue()) Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody)) Expect(dataStream.closed).To(BeTrue()) @@ -317,7 +306,7 @@ var _ = Describe("Client", func() { go func() { doRsp, doErr = client.Do(request) }() Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) - qClient.streams[5].dataToRead.Write(gzippedData) + dataStream.dataToRead.Write(gzippedData) response.Header.Add("Content-Encoding", "gzip") client.responses[5] <- response Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) @@ -350,7 +339,7 @@ var _ = Describe("Client", func() { go func() { doRsp, doErr = client.Do(request) }() Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) - qClient.streams[5].dataToRead.Write([]byte("not gzipped")) + dataStream.dataToRead.Write([]byte("not gzipped")) client.responses[5] <- response Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) Expect(doErr).ToNot(HaveOccurred()) @@ -369,7 +358,7 @@ var _ = Describe("Client", func() { go func() { doRsp, doErr = client.Do(request) }() Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) - qClient.streams[5].dataToRead.Write([]byte("gzipped data")) + dataStream.dataToRead.Write([]byte("gzipped data")) client.responses[5] <- response Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) Expect(doErr).ToNot(HaveOccurred()) diff --git a/h2quic/roundtrip.go b/h2quic/roundtrip.go index 85faf8e2..5b935366 100644 --- a/h2quic/roundtrip.go +++ b/h2quic/roundtrip.go @@ -12,6 +12,7 @@ import ( ) type h2quicClient interface { + Dial() error Do(*http.Request) (*http.Response, error) } @@ -92,8 +93,8 @@ func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { client, ok := r.clients[hostname] if !ok { - var err error - client, err = NewClient(r, r.TLSClientConfig, hostname) + client = NewClient(r, r.TLSClientConfig, hostname) + err := client.Dial() if err != nil { return nil, err } diff --git a/h2quic/roundtrip_test.go b/h2quic/roundtrip_test.go index e0350f50..3a48fa2f 100644 --- a/h2quic/roundtrip_test.go +++ b/h2quic/roundtrip_test.go @@ -11,6 +11,9 @@ import ( type mockQuicRoundTripper struct{} +func (m *mockQuicRoundTripper) Dial() error { + return nil +} func (m *mockQuicRoundTripper) Do(req *http.Request) (*http.Response, error) { return &http.Response{Request: req}, nil } diff --git a/h2quic/server_test.go b/h2quic/server_test.go index c9b7a167..681e75a2 100644 --- a/h2quic/server_test.go +++ b/h2quic/server_test.go @@ -27,6 +27,7 @@ type mockSession struct { closedWithError error dataStream quic.Stream streamToAccept quic.Stream + streamToOpen quic.Stream } func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) { @@ -36,7 +37,7 @@ func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil } func (s *mockSession) OpenStream() (quic.Stream, error) { - panic("not implemented") + return s.streamToOpen, nil } func (s *mockSession) OpenStreamSync() (quic.Stream, error) { panic("not implemented")