uquic/http3/client_test.go
Marten Seemann 497d3f58a5
http3: add a RoundTripOpt to check the server's SETTINGS frame (#4355)
For some requests, the client is required to check the server's HTTP/3
SETTINGS. For example, a client is only allowed to send HTTP/3 datagrams
if the server explicitly enabled support.

SETTINGS are sent asynchronously on a control stream (usually the first
unidirectional stream). This means that the SETTINGS might not be
available at the beginning of the connection. This is not expected to be
the common case, since the server can send the SETTINGS in 0.5-RTT data,
but we have to be able to deal with arbitrary delays.

For WebTransport, there are even more SETTINGS values that the client
needs to check. By making CheckSettings a callback on the RoundTripOpt,
this entire validation logic can live at the WebTransport layer.
2024-03-12 01:03:00 -07:00

1143 lines
46 KiB
Go

package http3
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
"github.com/quic-go/qpack"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
)
var _ = Describe("Client", func() {
var (
	// cl is the client under test and req the default request; both are
	// rebuilt before every spec.
	cl  *client
	req *http.Request
	// origDialAddr remembers the package-level dial function so that specs
	// may replace dialAddr and have it restored afterwards.
	origDialAddr = dialAddr
	// handshakeChan is an already-closed channel, handed out by the mocked
	// HandshakeComplete() calls further down.
	handshakeChan <-chan struct{}
)

// Build a fresh client and a default GET request before every spec.
BeforeEach(func() {
	origDialAddr = dialAddr

	hostname := "quic.clemente.io:1337"
	rt, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	cl = rt.(*client)
	Expect(cl.hostname).To(Equal(hostname))

	req, err = http.NewRequest("GET", "https://localhost:1337", nil)
	Expect(err).ToNot(HaveOccurred())

	// Receiving from a closed channel never blocks.
	closed := make(chan struct{})
	close(closed)
	handshakeChan = closed
})

// Undo any stubbing of the package-level dial function.
AfterEach(func() {
	dialAddr = origDialAddr
})
// newClient must refuse a quic.Config that lists more than one QUIC version.
It("rejects quic.Configs that allow multiple QUIC versions", func() {
	versions := []quic.Version{protocol.Version2, protocol.Version1}
	_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, &quic.Config{Versions: versions}, nil)
	Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
})
// Without a tls.Config or quic.Config, the dial function must receive the
// package defaults.
It("uses the default QUIC and TLS config if none is give", func() {
	// Returned by the stubbed dial function to abort the request right after dialing.
	testErr := errors.New("test done")
	client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	var dialAddrCalled bool
	dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
		Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
		Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
		Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1}))
		dialAddrCalled = true
		return nil, testErr
	}
	// Check the result instead of dropping it: the dial error has to reach the caller.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialAddrCalled).To(BeTrue())
})
// A hostname without a port must be dialed on port 443.
It("adds the port to the hostname, if none is given", func() {
	// Returned by the stubbed dial function to abort the request right after dialing.
	testErr := errors.New("test done")
	client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	var dialAddrCalled bool
	dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
		Expect(hostname).To(Equal("quic.clemente.io:443"))
		dialAddrCalled = true
		return nil, testErr
	}
	req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
	Expect(err).ToNot(HaveOccurred())
	// Check the result instead of dropping it: the dial error has to reach the caller.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialAddrCalled).To(BeTrue())
})
// When no tls.Config is supplied, the ServerName handed to the dialer must be
// the host the client was created for.
It("sets the ServerName in the tls.Config, if not set", func() {
	const host = "foo.bar"
	// Returned by the custom dialer to abort the request right after dialing.
	testErr := errors.New("test done")
	dialCalled := false
	dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
		Expect(tlsCfg.ServerName).To(Equal(host))
		dialCalled = true
		return nil, testErr
	}
	client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
	Expect(err).ToNot(HaveOccurred())
	req, err := http.NewRequest("GET", "https://foo.bar", nil)
	Expect(err).ToNot(HaveOccurred())
	// Check the result instead of dropping it: the dial error has to reach the caller.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialCalled).To(BeTrue())
})
// User-supplied configs must reach the dial function, with NextProtos
// replaced by the HTTP/3 ALPN and the caller's tls.Config left untouched.
It("uses the TLS config and QUIC config", func() {
	// Returned by the stubbed dial function to abort the request right after dialing.
	testErr := errors.New("test done")
	tlsConf := &tls.Config{
		ServerName: "foo.bar",
		NextProtos: []string{"proto foo", "proto bar"},
	}
	quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
	client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
	Expect(err).ToNot(HaveOccurred())
	var dialAddrCalled bool
	dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
		Expect(host).To(Equal("localhost:1337"))
		Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
		Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
		Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
		dialAddrCalled = true
		return nil, testErr
	}
	// Check the result instead of dropping it: the dial error has to reach the caller.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialAddrCalled).To(BeTrue())
	// make sure the original tls.Config was not modified
	Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
})
It("uses the custom dialer, if provided", func() {
testErr := errors.New("test done")
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
var dialerCalled bool
dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
Expect(ctxP).To(Equal(ctx))
Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
dialerCalled = true
return nil, testErr
}
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})
// EnableDatagram in the round tripper options must switch on datagrams in
// the quic.Config that is used for dialing.
It("enables HTTP/3 Datagrams", func() {
	dialErr := errors.New("handshake error")
	rt, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
	Expect(err).ToNot(HaveOccurred())

	dialAddr = func(_ context.Context, _ string, _ *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
		Expect(cfg.EnableDatagrams).To(BeTrue())
		return nil, dialErr
	}

	_, err = rt.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(dialErr))
})
// An error from the dial function is returned from RoundTripOpt.
It("errors when dialing fails", func() {
	dialErr := errors.New("handshake error")
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return nil, dialErr
	}

	rt, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	_, err = rt.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(dialErr))
})
// Close must succeed on a client that never dialed.
It("closes correctly if connection was not created", func() {
	rt, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	closeErr := rt.Close()
	Expect(closeErr).To(Succeed())
})
// Specs checking how the client reacts to the authority of a request.
Context("validating the address", func() {
	// cl was created for port 1337, so a request for port 1336 is refused.
	It("refuses to do requests for the wrong host", func() {
		wrongHostReq, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
		Expect(err).ToNot(HaveOccurred())
		_, err = cl.RoundTripOpt(wrongHostReq, RoundTripOpt{})
		Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
	})

	// A non-https scheme passes validation; the request only fails because
	// the stubbed dial function returns an error.
	It("allows requests using a different scheme", func() {
		dialErr := errors.New("handshake error")
		dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
			return nil, dialErr
		}
		masqueReq, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
		Expect(err).ToNot(HaveOccurred())
		_, err = cl.RoundTripOpt(masqueReq, RoundTripOpt{})
		Expect(err).To(MatchError(dialErr))
	})
})
Context("hijacking bidirectional streams", func() {
var (
	request *http.Request
	conn    *mockquic.MockEarlyConnection
	// settingsFrameWritten is closed once the client has written to its control stream.
	settingsFrameWritten chan struct{}
)
// testDone releases the blocking AcceptStream mock that each spec installs;
// it is re-created for every spec.
testDone := make(chan struct{})

// Wire up a mocked connection whose request stream can never be opened, so
// that RoundTripOpt fails with "done" once the connection is set up.
BeforeEach(func() {
	testDone = make(chan struct{})
	settingsFrameWritten = make(chan struct{})
	controlStr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn instead of Do: gomock discards the return values of a Do
	// action, so the mocked Write would report 0 bytes written.
	controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
		defer GinkgoRecover()
		close(settingsFrameWritten)
		return len(b), nil
	})
	conn = mockquic.NewMockEarlyConnection(mockCtrl)
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return conn, nil
	}
	var err error
	request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
	Expect(err).ToNot(HaveOccurred())
})

// Unblock the pending AcceptStream mock and check that the control stream was written to.
AfterEach(func() {
	testDone <- struct{}{}
	Eventually(settingsFrameWritten).Should(BeClosed())
})
// The hijacker takes over the stream, so no CloseWithError call is expected.
It("hijacks a bidirectional stream of unknown frame type", func() {
	hijackedTypes := make(chan FrameType, 1)
	cl.opts.StreamHijacker = func(ft FrameType, _ quic.Connection, _ quic.Stream, e error) (bool, error) {
		Expect(e).ToNot(HaveOccurred())
		hijackedTypes <- ft
		return true, nil
	}

	// The stream carries only the varint 0x41.
	frameData := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	hijackedStr := mockquic.NewMockStream(mockCtrl)
	hijackedStr.EXPECT().Read(gomock.Any()).DoAndReturn(frameData.Read).AnyTimes()

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptStream(gomock.Any()).Return(hijackedStr, nil)
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(hijackedTypes).Should(Receive(BeEquivalentTo(0x41)))
	// No CloseWithError expectation is registered: leave time for an
	// unexpected call to show up.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
// The hijacker declines the stream; CloseWithError(ErrCodeFrameUnexpected) is
// then allowed any number of times.
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
	seenTypes := make(chan FrameType, 1)
	cl.opts.StreamHijacker = func(ft FrameType, _ quic.Connection, _ quic.Stream, e error) (bool, error) {
		Expect(e).ToNot(HaveOccurred())
		seenTypes <- ft
		return false, nil
	}

	frameData := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	declinedStr := mockquic.NewMockStream(mockCtrl)
	declinedStr.EXPECT().Read(gomock.Any()).DoAndReturn(frameData.Read).AnyTimes()

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptStream(gomock.Any()).Return(declinedStr, nil)
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()

	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(seenTypes).Should(Receive(BeEquivalentTo(0x41)))
})
// The hijacker returns an error; CloseWithError(ErrCodeFrameUnexpected) is
// then allowed any number of times.
It("closes the connection when hijacker returned error", func() {
	seenTypes := make(chan FrameType, 1)
	cl.opts.StreamHijacker = func(ft FrameType, _ quic.Connection, _ quic.Stream, e error) (bool, error) {
		Expect(e).ToNot(HaveOccurred())
		seenTypes <- ft
		return false, errors.New("error in hijacker")
	}

	frameData := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	failingStr := mockquic.NewMockStream(mockCtrl)
	failingStr.EXPECT().Read(gomock.Any()).DoAndReturn(frameData.Read).AnyTimes()

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptStream(gomock.Any()).Return(failingStr, nil)
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()

	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(seenTypes).Should(Receive(BeEquivalentTo(0x41)))
})
// A read error while parsing the frame type must reach the hijacker together
// with a zero frame type and the stream it occurred on.
It("handles errors that occur when reading the frame type", func() {
	readErr := errors.New("test error")
	hijackerCalled := make(chan struct{})
	brokenStr := mockquic.NewMockStream(mockCtrl)
	brokenStr.EXPECT().Read(gomock.Any()).Return(0, readErr).AnyTimes()

	cl.opts.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, e error) (bool, error) {
		defer close(hijackerCalled)
		Expect(e).To(MatchError(readErr))
		Expect(ft).To(BeZero())
		Expect(str).To(Equal(brokenStr))
		return false, nil
	}

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptStream(gomock.Any()).Return(brokenStr, nil)
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()

	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(hijackerCalled).Should(BeClosed())
	// Leave time for background calls to arrive. CloseWithError with
	// ErrCodeFrameUnexpected is permitted above (AnyTimes).
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
})
Context("hijacking unidirectional streams", func() {
var (
	req  *http.Request
	conn *mockquic.MockEarlyConnection
	// settingsFrameWritten is closed once the client has written to its control stream.
	settingsFrameWritten chan struct{}
)
// testDone releases the blocking AcceptUniStream mock that each spec
// installs; it is re-created for every spec.
testDone := make(chan struct{})

// Wire up a mocked connection whose request stream can never be opened, so
// that RoundTripOpt fails with "done" once the connection is set up.
BeforeEach(func() {
	testDone = make(chan struct{})
	settingsFrameWritten = make(chan struct{})
	controlStr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn instead of Do: gomock discards the return values of a Do
	// action, so the mocked Write would report 0 bytes written.
	controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
		defer GinkgoRecover()
		close(settingsFrameWritten)
		return len(b), nil
	})
	conn = mockquic.NewMockEarlyConnection(mockCtrl)
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return conn, nil
	}
	var err error
	req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
	Expect(err).ToNot(HaveOccurred())
})

// Unblock the pending AcceptUniStream mock and check that the control stream was written to.
AfterEach(func() {
	testDone <- struct{}{}
	Eventually(settingsFrameWritten).Should(BeClosed())
})
// The hijacker accepts the stream, so no CancelRead / CloseWithError call is expected.
It("hijacks an unidirectional stream of unknown stream type", func() {
	hijackedTypes := make(chan StreamType, 1)
	cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, readErr error) bool {
		Expect(readErr).ToNot(HaveOccurred())
		hijackedTypes <- st
		return true
	}

	// The stream carries only the varint 0x54.
	typeData := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
	hijackedStr := mockquic.NewMockStream(mockCtrl)
	hijackedStr.EXPECT().Read(gomock.Any()).DoAndReturn(typeData.Read).AnyTimes()

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return hijackedStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(hijackedTypes).Should(Receive(BeEquivalentTo(0x54)))
	// No CloseWithError expectation is registered: leave time for an
	// unexpected call to show up.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
// A read error while parsing the stream type must reach the hijacker together
// with a zero stream type and the stream it occurred on.
It("handles errors that occur when reading the stream type", func() {
	readErr := errors.New("test error")
	hijackerCalled := make(chan struct{})
	brokenStr := mockquic.NewMockStream(mockCtrl)
	brokenStr.EXPECT().Read(gomock.Any()).Return(0, readErr)

	cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, e error) bool {
		defer close(hijackerCalled)
		Expect(st).To(BeZero())
		Expect(str).To(Equal(brokenStr))
		Expect(e).To(MatchError(readErr))
		return true
	}

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(brokenStr, nil)
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(hijackerCalled).Should(BeClosed())
	// No CloseWithError expectation is registered: leave time for an
	// unexpected call to show up.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
// When the hijacker declines the stream, reading from it must be canceled
// with ErrCodeStreamCreationError.
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
	seenTypes := make(chan StreamType, 1)
	cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, readErr error) bool {
		Expect(readErr).ToNot(HaveOccurred())
		seenTypes <- st
		return false
	}

	typeData := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
	declinedStr := mockquic.NewMockStream(mockCtrl)
	declinedStr.EXPECT().Read(gomock.Any()).DoAndReturn(typeData.Read).AnyTimes()
	declinedStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))

	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return declinedStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(seenTypes).Should(Receive(BeEquivalentTo(0x54)))
	// No CloseWithError expectation is registered: leave time for an
	// unexpected call to show up.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
})
Context("control stream handling", func() {
var (
	req  *http.Request
	conn *mockquic.MockEarlyConnection
	// settingsFrameWritten is closed once the client has written to its control stream.
	settingsFrameWritten chan struct{}
)
// testDone is shared by all specs of this Context and buffered, so the send
// in AfterEach does not need a receiver to be waiting (one spec below waits
// on its own done channel instead of testDone).
testDone := make(chan struct{}, 1)

// Wire up a mocked connection. Unlike in the hijacking Contexts, the
// OpenStreamSync expectation is left to the individual specs.
BeforeEach(func() {
	settingsFrameWritten = make(chan struct{})
	controlStr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn instead of Do: gomock discards the return values of a Do
	// action, so the mocked Write would report 0 bytes written.
	controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
		defer GinkgoRecover()
		close(settingsFrameWritten)
		return len(b), nil
	})
	conn = mockquic.NewMockEarlyConnection(mockCtrl)
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return conn, nil
	}
	var err error
	req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
	Expect(err).ToNot(HaveOccurred())
})

// Signal the end of the spec and check that the control stream was written to.
AfterEach(func() {
	testDone <- struct{}{}
	Eventually(settingsFrameWritten).Should(BeClosed())
})
// The SETTINGS sent on the server's control stream must be passed, fully
// parsed, to the CheckSettings callback.
It("parses the SETTINGS frame", func() {
	serverSettings := &settingsFrame{
		Datagram:        true,
		ExtendedConnect: true,
		Other:           map[uint64]uint64{1337: 42},
	}
	payload := serverSettings.Append(quicvarint.Append(nil, streamTypeControlStream))
	payloadReader := bytes.NewReader(payload)
	serverControlStr := mockquic.NewMockStream(mockCtrl)
	serverControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(payloadReader.Read).AnyTimes()

	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	// The first accept hands out the control stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return serverControlStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().Context().Return(context.Background())

	checkSettings := func(settings Settings) error {
		defer GinkgoRecover()
		Expect(settings.EnableDatagram).To(BeTrue())
		Expect(settings.EnableExtendedConnect).To(BeTrue())
		Expect(settings.Other).To(HaveLen(1))
		Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42)))
		return nil
	}
	_, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: checkSettings})
	Expect(err).To(MatchError("done"))
	// No CloseWithError expectation is registered: leave time for an
	// unexpected call to show up.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
// An error returned from CheckSettings must be returned from RoundTripOpt.
It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() {
	payload := (&settingsFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))
	payloadReader := bytes.NewReader(payload)
	serverControlStr := mockquic.NewMockStream(mockCtrl)
	serverControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(payloadReader.Read).AnyTimes()

	// Don't EXPECT any call to OpenStreamSync.
	// When the SETTINGS are rejected, we don't even open the request stream.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return serverControlStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().Context().Return(context.Background())

	rejectAll := func(Settings) error { return errors.New("wrong settings") }
	_, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: rejectAll})
	Expect(err).To(MatchError("wrong settings"))
	// No CloseWithError expectation is registered: leave time for an
	// unexpected call to show up.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
// A second control stream must close the connection with
// ErrCodeStreamCreationError.
It("rejects duplicate control streams", func() {
	payload := (&settingsFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))

	// Two independent streams, each delivering the same control stream bytes.
	firstReader := bytes.NewReader(payload)
	firstControlStr := mockquic.NewMockStream(mockCtrl)
	firstControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(firstReader.Read).AnyTimes()
	secondReader := bytes.NewReader(payload)
	secondControlStr := mockquic.NewMockStream(mockCtrl)
	secondControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(secondReader.Read).AnyTimes()

	connClosed := make(chan struct{})
	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error {
		close(connClosed)
		return nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return firstControlStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return secondControlStr, nil
	})
	// The third accept waits for the connection to be closed, not for testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-connClosed
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(HaveOccurred())
	Eventually(connClosed).Should(BeClosed())
})
// Generate one spec per QPACK stream type: neither stream may be reset.
for _, qpackStream := range []struct {
	name       string
	streamType uint64
}{
	{name: "encoder", streamType: streamTypeQPACKEncoderStream},
	{name: "decoder", streamType: streamTypeQPACKDecoderStream},
} {
	qpackStream := qpackStream // capture a per-iteration copy for the closure

	It(fmt.Sprintf("ignores the QPACK %s streams", qpackStream.name), func() {
		typeData := bytes.NewBuffer(quicvarint.Append(nil, qpackStream.streamType))
		qpackStr := mockquic.NewMockStream(mockCtrl)
		qpackStr.EXPECT().Read(gomock.Any()).DoAndReturn(typeData.Read).AnyTimes()

		conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
		// The first accept hands out the QPACK stream, the second blocks until AfterEach.
		conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
			return qpackStr, nil
		})
		conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
			<-testDone
			return nil, errors.New("test done")
		})

		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("done"))
		time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
	})
}
// Reading from a stream of an unknown type (0x1337) must be canceled with
// ErrCodeStreamCreationError.
It("resets streams other than the control stream and the QPACK streams", func() {
	typeData := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))
	readCanceled := make(chan struct{})
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(typeData.Read).AnyTimes()
	unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) {
		close(readCanceled)
	})

	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	// The first accept hands out the stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return unknownStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(readCanceled).Should(BeClosed())
})
// A control stream that starts with a DATA frame must close the connection
// with ErrCodeMissingSettings.
It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
	payload := (&dataFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))
	payloadReader := bytes.NewReader(payload)
	serverControlStr := mockquic.NewMockStream(mockCtrl)
	serverControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(payloadReader.Read).AnyTimes()

	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	// The first accept hands out the control stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return serverControlStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	connClosed := make(chan struct{})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
		close(connClosed)
		return nil
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(connClosed).Should(BeClosed())
})
It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() {
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&dataFrame{}).Append(b)
r := bytes.NewReader(b)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
// Don't EXPECT any calls to OpenStreamSync.
// We fail before we even get the chance to open the request stream.
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return controlStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
doneCtx, doneCancel := context.WithCancelCause(context.Background())
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
doneCancel(errors.New("done"))
return nil
})
conn.EXPECT().Context().Return(doneCtx).Times(2)
var checked bool
_, err := cl.RoundTripOpt(req, RoundTripOpt{
CheckSettings: func(Settings) error { checked = true; return nil },
})
Expect(checked).To(BeFalse())
Expect(err).To(MatchError("done"))
Eventually(doneCtx.Done()).Should(BeClosed())
})
// A SETTINGS frame cut short by one byte must close the connection with
// ErrCodeFrameError.
It("errors when parsing the frame on the control stream fails", func() {
	payload := (&settingsFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))
	truncated := payload[:len(payload)-1]
	payloadReader := bytes.NewReader(truncated)
	serverControlStr := mockquic.NewMockStream(mockCtrl)
	serverControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(payloadReader.Read).AnyTimes()

	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	// The first accept hands out the control stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return serverControlStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	connClosed := make(chan struct{})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
		close(connClosed)
		return nil
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(connClosed).Should(BeClosed())
})
// A unidirectional stream announcing itself as a push stream must close the
// connection with ErrCodeIDError.
It("errors when parsing the server opens a push stream", func() {
	typeData := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))
	pushStr := mockquic.NewMockStream(mockCtrl)
	pushStr.EXPECT().Read(gomock.Any()).DoAndReturn(typeData.Read).AnyTimes()

	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	// The first accept hands out the push stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return pushStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	connClosed := make(chan struct{})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error {
		close(connClosed)
		return nil
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(connClosed).Should(BeClosed())
})
// The SETTINGS advertise datagram support and the client has it enabled, but
// the mocked connection state reports SupportsDatagrams == false: the
// connection must be closed with ErrCodeSettingsError.
It("errors when the server advertises datagram support (and we enabled support for it)", func() {
	cl.opts.EnableDatagram = true

	payload := (&settingsFrame{Datagram: true}).Append(quicvarint.Append(nil, streamTypeControlStream))
	payloadReader := bytes.NewReader(payload)
	serverControlStr := mockquic.NewMockStream(mockCtrl)
	serverControlStr.EXPECT().Read(gomock.Any()).DoAndReturn(payloadReader.Read).AnyTimes()

	conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
	// The first accept hands out the control stream, the second blocks until AfterEach.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return serverControlStr, nil
	})
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})
	connClosed := make(chan struct{})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error {
		close(connClosed)
		return nil
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(connClosed).Should(BeClosed())
})
})
Context("Doing requests", func() {
// Fixtures shared by the specs of this Context.
var (
// req is the request sent by the specs; BeforeEach re-creates it.
req *http.Request
// str is the mocked request stream; BeforeEach re-creates it.
str *mockquic.MockStream
// conn is the mocked connection returned by the stubbed dialAddr.
conn *mockquic.MockEarlyConnection
// settingsFrameWritten is closed by the control stream's Write mock.
settingsFrameWritten chan struct{}
)
// testDone is unbuffered and created only once for the whole Context:
// AfterEach sends on it and the AcceptUniStream mock installed in BeforeEach
// receives from it.
testDone := make(chan struct{})
// decodeHeader parses one frame from str, requires it to be a HEADERS frame,
// and returns the QPACK-decoded header fields as a name -> value map.
// Assertion failures are reported at the caller's line (offset 1).
decodeHeader := func(str io.Reader) map[string]string {
	frame, err := parseNextFrame(str, nil)
	ExpectWithOffset(1, err).ToNot(HaveOccurred())
	ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
	hf := frame.(*headersFrame)

	// The frame only carries the length; the payload is still on the reader.
	encoded := make([]byte, hf.Length)
	_, err = io.ReadFull(str, encoded)
	ExpectWithOffset(1, err).ToNot(HaveOccurred())

	decoded, err := qpack.NewDecoder(nil).DecodeFull(encoded)
	ExpectWithOffset(1, err).ToNot(HaveOccurred())

	fields := make(map[string]string)
	for _, field := range decoded {
		fields[field.Name] = field.Value
	}
	return fields
}
// getResponse returns the bytes that a response writer emits for a response
// consisting only of the given status code.
getResponse := func(status int) []byte {
	buf := &bytes.Buffer{}
	rstr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn instead of Do: gomock discards the return values of a Do
	// action, so the mocked Write would report 0 bytes written with a nil
	// error even though the data went into buf.
	rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
	rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
	rw.WriteHeader(status)
	rw.Flush()
	return buf.Bytes()
}
// Wire up a mocked connection and request stream. The HandshakeComplete and
// OpenStreamSync expectations are left to the individual specs.
BeforeEach(func() {
	settingsFrameWritten = make(chan struct{})
	controlStr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn instead of Do: gomock discards the return values of a Do
	// action, so the mocked Write would report 0 bytes written.
	controlStr.EXPECT().Write(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
		defer GinkgoRecover()
		// The write must start with the control stream type.
		r := bytes.NewReader(b)
		streamType, err := quicvarint.Read(r)
		Expect(err).ToNot(HaveOccurred())
		Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
		close(settingsFrameWritten)
		return len(b), nil
	}) // SETTINGS frame
	str = mockquic.NewMockStream(mockCtrl)
	conn = mockquic.NewMockEarlyConnection(mockCtrl)
	conn.EXPECT().OpenUniStream().Return(controlStr, nil)
	// Blocks until AfterEach signals the end of the spec.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return conn, nil
	}
	var err error
	req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
	Expect(err).ToNot(HaveOccurred())
})

// Unblock the pending AcceptUniStream mock and check that the control stream was written to.
AfterEach(func() {
	testDone <- struct{}{}
	Eventually(settingsFrameWritten).Should(BeClosed())
})
// An error from OpenStreamSync is returned from RoundTripOpt.
It("errors if it can't open a stream", func() {
	openErr := errors.New("stream open error")
	conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, openErr)
	// Closing the connection is tolerated, but at most once.
	conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(openErr))
})
// A request with method MethodGet0RTT must be sent without a
// HandshakeComplete() call, and must go out as a plain GET.
It("performs a 0-RTT request", func() {
	readErr := errors.New("stream open error")
	req.Method = MethodGet0RTT

	// don't EXPECT any calls to HandshakeComplete()
	conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)

	// Everything written to the request stream is captured for inspection.
	written := &bytes.Buffer{}
	str.EXPECT().Write(gomock.Any()).DoAndReturn(written.Write).AnyTimes()
	str.EXPECT().Close()
	str.EXPECT().CancelWrite(gomock.Any())
	// Reading the response fails, which ends the round trip.
	str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
		return 0, readErr
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(readErr))
	headers := decodeHeader(written)
	Expect(headers).To(HaveKeyWithValue(":method", "GET"))
})
// Happy path: the serialized 418 response is read from the stream and
// surfaced as an HTTP/3 *http.Response.
It("returns a response", func() {
	rawRsp := bytes.NewBuffer(getResponse(418))
	// accept and drop everything the client writes on the request stream
	discard := func(p []byte) (int, error) { return len(p), nil }

	gomock.InOrder(
		conn.EXPECT().HandshakeComplete().Return(handshakeChan),
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
	)
	str.EXPECT().Write(gomock.Any()).DoAndReturn(discard).AnyTimes()
	str.EXPECT().Close()
	str.EXPECT().Read(gomock.Any()).DoAndReturn(rawRsp.Read).AnyTimes()

	rsp, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(rtErr).ToNot(HaveOccurred())
	Expect(rsp.Proto).To(Equal("HTTP/3.0"))
	Expect(rsp.ProtoMajor).To(Equal(3))
	Expect(rsp.StatusCode).To(Equal(418))
	Expect(rsp.Request).ToNot(BeNil())
})
// Same as the plain response spec, but with DontCloseRequestStream set.
// No Close() expectation is registered on the stream, so the mock would
// report a call to Close as unexpected.
It("doesn't close the request stream, with DontCloseRequestStream set", func() {
	rawRsp := bytes.NewBuffer(getResponse(418))
	// accept and drop everything the client writes on the request stream
	discard := func(p []byte) (int, error) { return len(p), nil }

	gomock.InOrder(
		conn.EXPECT().HandshakeComplete().Return(handshakeChan),
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
	)
	str.EXPECT().Write(gomock.Any()).DoAndReturn(discard).AnyTimes()
	str.EXPECT().Read(gomock.Any()).DoAndReturn(rawRsp.Read).AnyTimes()

	opt := RoundTripOpt{DontCloseRequestStream: true}
	rsp, rtErr := cl.RoundTripOpt(req, opt)
	Expect(rtErr).ToNot(HaveOccurred())
	Expect(rsp.Proto).To(Equal("HTTP/3.0"))
	Expect(rsp.ProtoMajor).To(Equal(3))
	Expect(rsp.StatusCode).To(Equal(418))
})
Context("requests containing a Body", func() {
var strBuf *bytes.Buffer
BeforeEach(func() {
strBuf = &bytes.Buffer{}
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
Expect(err).ToNot(HaveOccurred())
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
})
// The POST's header block must carry the right :method and :path.
It("sends a request", func() {
	streamClosed := make(chan struct{})
	gomock.InOrder(
		str.EXPECT().Close().Do(func() error {
			close(streamClosed)
			return nil
		}),
		// when reading the response errors
		str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1),
	)
	// Read only returns once Close was called on the stream, so the request
	// has been written in full by the time the round trip fails.
	str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
		<-streamClosed
		return 0, errors.New("test done")
	})

	_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(rtErr).To(MatchError("test done"))

	fields := decodeHeader(strBuf)
	Expect(fields).To(HaveKeyWithValue(":method", "POST"))
	Expect(fields).To(HaveKeyWithValue(":path", "/upload"))
})
// With ContentLength set to 7, only the first 7 bytes of the 12-byte body
// ("request") may reach the stream, and the write side has to be canceled
// with ErrCodeRequestCanceled.
It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
	req.ContentLength = 7
	var once sync.Once
	done := make(chan struct{})
	gomock.InOrder(
		str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
			// NOTE(review): this callback presumably runs on the goroutine
			// that sends the request body rather than on the spec goroutine.
			// A failing assertion there must be handed to Ginkgo instead of
			// crashing the test binary.
			defer GinkgoRecover()
			once.Do(func() {
				// Always release the blocked Read below, even if the
				// assertion fails, so that the spec cannot hang.
				defer close(done)
				Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
			})
		}).AnyTimes(),
		str.EXPECT().Close().MaxTimes(1),
		str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
	)
	// Reading the response blocks until the first CancelWrite was observed.
	str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
		<-done
		return 0, errors.New("done")
	})
	cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(strBuf.String()).To(ContainSubstring("request"))
	Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
})
// When reading the request body fails, the write side is canceled with
// ErrCodeRequestCanceled and the stream is closed.
It("returns the error that occurred when reading the body", func() {
	req.Body.(*mockBody).readErr = errors.New("testErr")

	writeCanceled := make(chan struct{})
	streamClosed := make(chan struct{})
	gomock.InOrder(
		str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
			close(writeCanceled)
		}),
		str.EXPECT().CancelWrite(gomock.Any()),
	)
	str.EXPECT().Close().Do(func() error {
		close(streamClosed)
		return nil
	})
	// Reading the response is held back until the first CancelWrite happened.
	str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
		<-writeCanceled
		return 0, errors.New("test done")
	})

	_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(rtErr).To(MatchError("test done"))
	Eventually(streamClosed).Should(BeClosed())
})
// A response that starts with a DATA frame is a protocol violation: the
// connection has to be closed with ErrCodeFrameUnexpected.
It("closes the connection when the first frame is not a HEADERS frame", func() {
	rspData := bytes.NewReader((&dataFrame{Length: 0x42}).Append(nil))
	streamClosed := make(chan struct{})

	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
	str.EXPECT().Close().Do(func() error {
		close(streamClosed)
		return nil
	})
	str.EXPECT().Read(gomock.Any()).DoAndReturn(rspData.Read).AnyTimes()

	_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(rtErr).To(MatchError("expected first frame to be a HEADERS frame"))
	Eventually(streamClosed).Should(BeClosed())
})
It("cancels the stream when parsing the headers fails", func() {
headerBuf := &bytes.Buffer{}
enc := qpack.NewEncoder(headerBuf)
Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
Expect(enc.Close()).To(Succeed())
b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
b = append(b, headerBuf.Bytes()...)
r := bytes.NewReader(b)
str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
closed := make(chan struct{})
str.EXPECT().Close().Do(func() error { close(closed); return nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(HaveOccurred())
Eventually(closed).Should(BeClosed())
})
// A HEADERS frame announcing 1338 bytes is rejected with ErrCodeFrameError;
// the expected error message names 1337 as the maximum.
It("cancels the stream when the HEADERS frame is too large", func() {
	rspData := bytes.NewReader((&headersFrame{Length: 1338}).Append(nil))
	streamClosed := make(chan struct{})

	str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
	str.EXPECT().Close().Do(func() error {
		close(streamClosed)
		return nil
	})
	str.EXPECT().Read(gomock.Any()).DoAndReturn(rspData.Read).AnyTimes()

	_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(rtErr).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
	Eventually(streamClosed).Should(BeClosed())
})
})
Context("request cancellations", func() {
// Both cancellation specs are run with and without DontCloseRequestStream.
for _, dontClose := range []bool{false, true} {
	dontClose := dontClose // capture the per-iteration value for the closures below
	Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
		roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}

		// The handshake channel is never closed, so the request can only
		// finish through cancellation of its context.
		It("cancels a request while waiting for the handshake to complete", func() {
			ctx, cancel := context.WithCancel(context.Background())
			req := req.WithContext(ctx)
			conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
			// Buffered, so that the goroutine can always deliver its result
			// and exit, even if the spec stopped waiting for it.
			errChan := make(chan error, 1)
			go func() {
				// Failures raised on this goroutine (e.g. by the mocks) have
				// to be routed to Ginkgo rather than crash the test binary.
				defer GinkgoRecover()
				_, err := cl.RoundTripOpt(req, roundTripOpt)
				errChan <- err
			}()
			Consistently(errChan).ShouldNot(Receive())
			cancel()
			Eventually(errChan).Should(Receive(MatchError("context canceled")))
		})

		// The context is canceled from within the first Read of the response.
		// The stream then has to be canceled in both directions with
		// ErrCodeRequestCanceled, write side first.
		It("cancels a request while the request is still in flight", func() {
			ctx, cancel := context.WithCancel(context.Background())
			req := req.WithContext(ctx)
			conn.EXPECT().HandshakeComplete().Return(handshakeChan)
			conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
			buf := &bytes.Buffer{}
			str.EXPECT().Close().MaxTimes(1)
			str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
			done := make(chan struct{})
			canceled := make(chan struct{})
			gomock.InOrder(
				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
			)
			// a second CancelWrite is tolerated
			str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
				cancel()
				// only fail the Read once the write side was canceled
				<-canceled
				return 0, errors.New("test done")
			})
			_, err := cl.RoundTripOpt(req, roundTripOpt)
			Expect(err).To(MatchError(context.Canceled))
			Eventually(done).Should(BeClosed())
		})
	})
}
// Canceling the context after RoundTripOpt already returned the response
// still has to cancel both directions of the stream.
It("cancels a request after the response arrived", func() {
	ctx, cancel := context.WithCancel(context.Background())
	ctxReq := req.WithContext(ctx)
	rawRsp := bytes.NewBuffer(getResponse(404))

	conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
	conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

	sent := &bytes.Buffer{}
	readCanceled := make(chan struct{})
	str.EXPECT().Close().MaxTimes(1)
	str.EXPECT().Write(gomock.Any()).DoAndReturn(sent.Write)
	str.EXPECT().Read(gomock.Any()).DoAndReturn(rawRsp.Read).AnyTimes()
	str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
	str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
		close(readCanceled)
	})

	_, rtErr := cl.RoundTripOpt(ctxReq, RoundTripOpt{})
	Expect(rtErr).ToNot(HaveOccurred())
	// the response is already there; cancel only now
	cancel()
	Eventually(readCanceled).Should(BeClosed())
})
})
Context("gzip compression", func() {
// Every gzip spec expects exactly one HandshakeComplete() call, answered
// with handshakeChan.
BeforeEach(func() {
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
})
// With the default client options, the request advertises gzip support.
It("adds the gzip header to requests", func() {
	conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)

	sent := &bytes.Buffer{}
	str.EXPECT().Write(gomock.Any()).DoAndReturn(sent.Write)
	gomock.InOrder(
		str.EXPECT().Close(),
		// when the Read errors
		str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1),
	)
	// the response is irrelevant here: fail the Read to end the round trip
	str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))

	_, rtErr := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(rtErr).To(MatchError("test done"))
	Expect(decodeHeader(sent)).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
// A client created with DisableCompression must not send accept-encoding.
It("doesn't add gzip if the header disable it", func() {
	opts := &roundTripperOpts{DisableCompression: true}
	rt, err := newClient("quic.clemente.io:1337", nil, opts, nil, nil)
	Expect(err).ToNot(HaveOccurred())

	conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
	sent := &bytes.Buffer{}
	str.EXPECT().Write(gomock.Any()).DoAndReturn(sent.Write)
	gomock.InOrder(
		str.EXPECT().Close(),
		// when the Read errors
		str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1),
	)
	// the response is irrelevant here: fail the Read to end the round trip
	str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))

	_, err = rt.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("test done"))
	Expect(decodeHeader(sent)).ToNot(HaveKey("accept-encoding"))
})
// A response carrying "Content-Encoding: gzip" is decompressed by the client.
// The header is removed, ContentLength is -1 and Uncompressed is set.
It("decompresses the response", func() {
	conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
	conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

	// Serialize a gzip-compressed response into buf via a responseWriter.
	buf := &bytes.Buffer{}
	rstr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn (instead of Do) makes the mock report what buf.Write
	// consumed. Do would drop the result and return (0, nil), which looks
	// like a short write to the writer.
	rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
	rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
	rw.Header().Set("Content-Encoding", "gzip")
	gz := gzip.NewWriter(rw)
	// A failure while building the fixture should fail the spec right here,
	// not surface later as a confusing decode error.
	_, err := gz.Write([]byte("gzipped response"))
	Expect(err).ToNot(HaveOccurred())
	Expect(gz.Close()).To(Succeed())
	rw.Flush()

	str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
	str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	str.EXPECT().Close()

	rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).ToNot(HaveOccurred())
	data, err := io.ReadAll(rsp.Body)
	Expect(err).ToNot(HaveOccurred())
	Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
	Expect(string(data)).To(Equal("gzipped response"))
	Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
	Expect(rsp.Uncompressed).To(BeTrue())
})
// A response without a Content-Encoding header is passed through untouched.
It("only decompresses the response if the response contains the right content-encoding header", func() {
	conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
	conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

	// Serialize a plain (uncompressed) response into buf via a responseWriter.
	buf := &bytes.Buffer{}
	rstr := mockquic.NewMockStream(mockCtrl)
	// DoAndReturn (instead of Do) makes the mock report what buf.Write
	// consumed. Do would drop the result and return (0, nil), which looks
	// like a short write to the writer.
	rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
	rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
	// A failure while building the fixture should fail the spec right here.
	_, err := rw.Write([]byte("not gzipped"))
	Expect(err).ToNot(HaveOccurred())
	rw.Flush()

	str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
	str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	str.EXPECT().Close()

	rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).ToNot(HaveOccurred())
	data, err := io.ReadAll(rsp.Body)
	Expect(err).ToNot(HaveOccurred())
	Expect(string(data)).To(Equal("not gzipped"))
	Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
})
})
})
})