package http3 import ( "bytes" "compress/gzip" "context" "errors" "fmt" "io" "net/http" "sync" "time" tls "github.com/refraction-networking/utls" quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" "github.com/quic-go/qpack" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "go.uber.org/mock/gomock" ) var _ = Describe("Client", func() { var ( cl *client req *http.Request origDialAddr = dialAddr handshakeChan <-chan struct{} // a closed chan ) BeforeEach(func() { origDialAddr = dialAddr hostname := "quic.clemente.io:1337" c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) Expect(err).ToNot(HaveOccurred()) cl = c.(*client) Expect(cl.hostname).To(Equal(hostname)) req, err = http.NewRequest("GET", "https://localhost:1337", nil) Expect(err).ToNot(HaveOccurred()) ch := make(chan struct{}) close(ch) handshakeChan = ch }) AfterEach(func() { dialAddr = origDialAddr }) It("rejects quic.Configs that allow multiple QUIC versions", func() { qconf := &quic.Config{ Versions: []quic.Version{protocol.Version2, protocol.Version1}, } _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) }) It("uses the default QUIC and TLS config if none is give", func() { client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) dialAddrCalled = true return nil, errors.New("test done") } client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialAddrCalled).To(BeTrue()) }) It("adds the port to the hostname, if none is given", func() { client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { Expect(hostname).To(Equal("quic.clemente.io:443")) dialAddrCalled = true return nil, errors.New("test done") } req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) Expect(err).ToNot(HaveOccurred()) client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialAddrCalled).To(BeTrue()) }) It("sets the ServerName in the tls.Config, if not set", func() { const host = "foo.bar" dialCalled := false dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { Expect(tlsCfg.ServerName).To(Equal(host)) dialCalled = true return nil, errors.New("test done") } client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("GET", "https://foo.bar", nil) Expect(err).ToNot(HaveOccurred()) client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialCalled).To(BeTrue()) }) It("uses the TLS config and QUIC config", func() { tlsConf := &tls.Config{ ServerName: "foo.bar", NextProtos: []string{"proto foo", "proto bar"}, } quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { Expect(host).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) dialAddrCalled = true return nil, errors.New("test done") } client.RoundTripOpt(req, RoundTripOpt{}) Expect(dialAddrCalled).To(BeTrue()) // make sure the original tls.Config was not modified Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) }) It("uses the custom dialer, if provided", func() { testErr := errors.New("test done") tlsConf := &tls.Config{ServerName: "foo.bar"} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} ctx, cancel := context.WithTimeout(context.Background(), time.Hour) defer cancel() var dialerCalled bool dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { Expect(ctxP).To(Equal(ctx)) Expect(address).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal("foo.bar")) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) dialerCalled = true return nil, testErr } client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) Expect(err).ToNot(HaveOccurred()) _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) Expect(err).To(MatchError(testErr)) Expect(dialerCalled).To(BeTrue()) }) It("enables HTTP/3 Datagrams", func() { testErr := errors.New("handshake error") client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) Expect(err).ToNot(HaveOccurred()) dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { Expect(quicConf.EnableDatagrams).To(BeTrue()) return nil, testErr } _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) It("errors when dialing fails", func() { testErr := errors.New("handshake error") client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) It("closes correctly if connection was not created", func() { client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(client.Close()).To(Succeed()) }) Context("validating the address", func() { It("refuses to do requests for the wrong host", func() { req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) _, err = cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) }) It("allows requests using a different scheme", func() { testErr := errors.New("handshake error") req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } _, err = cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) }) Context("hijacking bidirectional streams", func() { var ( request *http.Request conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) testDone := make(chan struct{}) BeforeEach(func() { testDone = make(chan struct{}) settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() close(settingsFrameWritten) return len(b), nil }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil } var err error request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { testDone <- struct{}{} Eventually(settingsFrameWritten).Should(BeClosed()) }) It("hijacks a bidirectional stream of unknown frame type", func() { frameTypeChan := make(chan FrameType, 1) cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { Expect(e).ToNot(HaveOccurred()) frameTypeChan <- ft return true, nil } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { <-testDone return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { frameTypeChan := make(chan FrameType, 1) cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { Expect(e).ToNot(HaveOccurred()) frameTypeChan <- ft return false, nil } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { <-testDone return nil, errors.New("test done") }) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) It("closes the connection when hijacker returned error", func() { frameTypeChan := make(chan FrameType, 1) cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { Expect(e).ToNot(HaveOccurred()) frameTypeChan <- ft return false, errors.New("error in hijacker") } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { <-testDone return nil, errors.New("test done") }) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) }) It("handles errors that occur when reading the frame type", func() { testErr := errors.New("test error") unknownStr := mockquic.NewMockStream(mockCtrl) done := make(chan struct{}) cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { defer close(done) Expect(e).To(MatchError(testErr)) Expect(ft).To(BeZero()) Expect(str).To(Equal(unknownStr)) return false, nil } unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { <-testDone return nil, errors.New("test done") }) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() _, err := cl.RoundTripOpt(request, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) }) Context("hijacking unidirectional streams", func() { var ( req *http.Request conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) testDone := make(chan struct{}) BeforeEach(func() { testDone = make(chan struct{}) settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() close(settingsFrameWritten) return len(b), nil }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil } var err error req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { testDone <- struct{}{} Eventually(settingsFrameWritten).Should(BeClosed()) }) It("hijacks an unidirectional stream of unknown stream type", func() { streamTypeChan := make(chan StreamType, 1) cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { Expect(err).ToNot(HaveOccurred()) streamTypeChan <- st return true } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return unknownStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) It("handles errors that occur when reading the stream type", func() { testErr := errors.New("test error") done := make(chan struct{}) unknownStr := mockquic.NewMockStream(mockCtrl) cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { defer close(done) Expect(st).To(BeZero()) Expect(str).To(Equal(unknownStr)) Expect(err).To(MatchError(testErr)) return true } unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { streamTypeChan := make(chan StreamType, 1) cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { Expect(err).ToNot(HaveOccurred()) streamTypeChan <- st return false } buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return unknownStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) }) Context("control stream handling", func() { var ( req *http.Request conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) testDone := make(chan struct{}, 1) BeforeEach(func() { settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() close(settingsFrameWritten) return len(b), nil }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil } var err error req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { testDone <- struct{}{} Eventually(settingsFrameWritten).Should(BeClosed()) }) It("parses the SETTINGS frame", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{ Datagram: true, ExtendedConnect: true, Other: map[uint64]uint64{1337: 42}, }).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) conn.EXPECT().Context().Return(context.Background()) _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { defer GinkgoRecover() Expect(settings.EnableDatagram).To(BeTrue()) Expect(settings.EnableExtendedConnect).To(BeTrue()) Expect(settings.Other).To(HaveLen(1)) Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42))) return nil }}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() // Don't EXPECT any call to OpenStreamSync. // When the SETTINGS are rejected, we don't even open the request stream. conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) conn.EXPECT().Context().Return(context.Background()) _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { return errors.New("wrong settings") }}) Expect(err).To(MatchError("wrong settings")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) It("rejects duplicate control streams", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) r1 := bytes.NewReader(b) controlStr1 := mockquic.NewMockStream(mockCtrl) controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() r2 := bytes.NewReader(b) controlStr2 := mockquic.NewMockStream(mockCtrl) controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() done := make(chan struct{}) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { close(done) return nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr1, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr2, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-done return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(HaveOccurred()) Eventually(done).Should(BeClosed()) }) for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { streamType := t name := "encoder" if streamType == streamTypeQPACKDecoderStream { name = "decoder" } It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead }) } It("resets streams other than the control stream and the QPACK streams", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() done := make(chan struct{}) str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) It("errors when the first frame on the control stream is not a SETTINGS frame", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&dataFrame{}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&dataFrame{}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() // Don't EXPECT any calls to OpenStreamSync. // We fail before we even get the chance to open the request stream. conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) doneCtx, doneCancel := context.WithCancelCause(context.Background()) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { doneCancel(errors.New("done")) return nil }) conn.EXPECT().Context().Return(doneCtx).Times(2) var checked bool _, err := cl.RoundTripOpt(req, RoundTripOpt{ CheckSettings: func(Settings) error { checked = true; return nil }, }) Expect(checked).To(BeFalse()) Expect(err).To(MatchError("done")) Eventually(doneCtx.Done()).Should(BeClosed()) }) It("errors when parsing the frame on the control stream fails", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) r := bytes.NewReader(b[:len(b)-1]) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error { close(done) return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) It("errors when parsing the server opens a push stream", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) It("errors when the server advertises datagram support (and we enabled support for it)", func() { cl.opts.EnableDatagram = true b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{Datagram: true}).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) done := make(chan struct{}) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { close(done) return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) }) Context("Doing requests", func() { var ( req *http.Request str *mockquic.MockStream conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) testDone := make(chan struct{}) decodeHeader := func(str io.Reader) map[string]string { fields := make(map[string]string) decoder := qpack.NewDecoder(nil) frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) data := make([]byte, headersFrame.Length) _, err = io.ReadFull(str, data) ExpectWithOffset(1, err).ToNot(HaveOccurred()) hfs, err := decoder.DecodeFull(data) ExpectWithOffset(1, err).ToNot(HaveOccurred()) for _, p := range hfs { fields[p.Name] = p.Value } return fields } getResponse := func(status int) []byte { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.WriteHeader(status) rw.Flush() return buf.Bytes() } BeforeEach(func() { settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() r := bytes.NewReader(b) streamType, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) close(settingsFrameWritten) return len(b), nil }) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil } var err error req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { testDone <- struct{}{} Eventually(settingsFrameWritten).Should(BeClosed()) }) It("errors if it can't open a stream", func() { testErr := errors.New("stream open error") conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) conn.EXPECT().HandshakeComplete().Return(handshakeChan) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) }) It("performs a 0-RTT request", func() { testErr := errors.New("stream open error") req.Method = MethodGet0RTT // don't EXPECT any calls to HandshakeComplete() conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() str.EXPECT().Close() str.EXPECT().CancelWrite(gomock.Any()) str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError(testErr)) Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) }) It("returns a response", func() { rspBuf := bytes.NewBuffer(getResponse(418)) gomock.InOrder( conn.EXPECT().HandshakeComplete().Return(handshakeChan), conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.ProtoMajor).To(Equal(3)) Expect(rsp.StatusCode).To(Equal(418)) Expect(rsp.Request).ToNot(BeNil()) }) It("doesn't close the request stream, with DontCloseRequestStream set", func() { rspBuf := bytes.NewBuffer(getResponse(418)) gomock.InOrder( conn.EXPECT().HandshakeComplete().Return(handshakeChan), conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) Expect(err).ToNot(HaveOccurred()) Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.ProtoMajor).To(Equal(3)) Expect(rsp.StatusCode).To(Equal(418)) }) Context("requests containing a Body", func() { var strBuf *bytes.Buffer BeforeEach(func() { strBuf = &bytes.Buffer{} gomock.InOrder( conn.EXPECT().HandshakeComplete().Return(handshakeChan), conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), ) body := &mockBody{} body.SetData([]byte("request body")) var err error req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) Expect(err).ToNot(HaveOccurred()) str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() }) It("sends a request", func() { done := make(chan struct{}) gomock.InOrder( str.EXPECT().Close().Do(func() error { close(done); return nil }), str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors ) // the response body is sent asynchronously, while already reading the response str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { <-done return 0, errors.New("test done") }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(strBuf) Expect(hfs).To(HaveKeyWithValue(":method", "POST")) Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) }) It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { req.ContentLength = 7 var once sync.Once done := make(chan struct{}) gomock.InOrder( str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { once.Do(func() { Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) close(done) }) }).AnyTimes(), str.EXPECT().Close().MaxTimes(1), str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), ) str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { <-done return 0, errors.New("done") }) cl.RoundTripOpt(req, RoundTripOpt{}) Expect(strBuf.String()).To(ContainSubstring("request")) Expect(strBuf.String()).ToNot(ContainSubstring("request body")) }) It("returns the error that occurred when reading the body", func() { req.Body.(*mockBody).readErr = errors.New("testErr") done := make(chan struct{}) gomock.InOrder( str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), str.EXPECT().CancelWrite(gomock.Any()), ) // the response body is sent asynchronously, while already reading the response str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { <-done return 0, errors.New("test done") }) closed := make(chan struct{}) str.EXPECT().Close().Do(func() error { close(closed); return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) Eventually(closed).Should(BeClosed()) }) It("closes the connection when the first frame is not a HEADERS frame", func() { b := (&dataFrame{Length: 0x42}).Append(nil) conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) closed := make(chan struct{}) r := bytes.NewReader(b) str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) Eventually(closed).Should(BeClosed()) }) It("cancels the stream when parsing the headers fails", func() { headerBuf := &bytes.Buffer{} enc := qpack.NewEncoder(headerBuf) Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header Expect(enc.Close()).To(Succeed()) b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) b = append(b, headerBuf.Bytes()...) r := bytes.NewReader(b) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) closed := make(chan struct{}) str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(HaveOccurred()) Eventually(closed).Should(BeClosed()) }) It("cancels the stream when the HEADERS frame is too large", func() { b := (&headersFrame{Length: 1338}).Append(nil) r := bytes.NewReader(b) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) closed := make(chan struct{}) str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) Eventually(closed).Should(BeClosed()) }) }) Context("request cancellations", func() { for _, dontClose := range []bool{false, true} { dontClose := dontClose Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() { roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose} It("cancels a request while waiting for the handshake to complete", func() { ctx, cancel := context.WithCancel(context.Background()) req := req.WithContext(ctx) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) errChan := make(chan error) go func() { _, err := cl.RoundTripOpt(req, roundTripOpt) errChan <- err }() Consistently(errChan).ShouldNot(Receive()) cancel() Eventually(errChan).Should(Receive(MatchError("context canceled"))) }) It("cancels a request while the request is still in flight", func() { ctx, cancel := context.WithCancel(context.Background()) req := req.WithContext(ctx) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) done := make(chan struct{}) canceled := make(chan struct{}) gomock.InOrder( str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), ) str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { cancel() <-canceled return 0, errors.New("test done") }) _, err := cl.RoundTripOpt(req, roundTripOpt) Expect(err).To(MatchError(context.Canceled)) Eventually(done).Should(BeClosed()) }) }) } It("cancels a request after the response arrived", func() { rspBuf := bytes.NewBuffer(getResponse(404)) ctx, cancel := context.WithCancel(context.Background()) req := req.WithContext(ctx) conn.EXPECT().HandshakeComplete().Return(handshakeChan) conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) done := make(chan struct{}) str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(done).Should(BeClosed()) }) }) Context("gzip compression", func() { BeforeEach(func() { conn.EXPECT().HandshakeComplete().Return(handshakeChan) }) It("adds the gzip header to requests", func() { conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) gomock.InOrder( str.EXPECT().Close(), str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors ) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(buf) Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) }) It("doesn't add gzip if the header disable it", func() { client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) Expect(err).ToNot(HaveOccurred()) conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) gomock.InOrder( str.EXPECT().Close(), str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors ) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) _, err = client.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) hfs := decodeHeader(buf) Expect(hfs).ToNot(HaveKey("accept-encoding")) }) It("decompresses the response", func() { conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) gz.Close() rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) Expect(string(data)).To(Equal("gzipped response")) Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) Expect(rsp.Uncompressed).To(BeTrue()) }) It("only decompresses the response if the response contains the right content-encoding header", func() { conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.Write([]byte("not gzipped")) rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Close() rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) Expect(string(data)).To(Equal("not gzipped")) Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) }) }) }) })