mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-02 19:57:35 +03:00
uTLS has not yet been bumped to the new version, so this commit breaks the dependency relationship by removing the local replace.
1043 lines
41 KiB
Go
1043 lines
41 KiB
Go
package http3
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
quic "github.com/refraction-networking/uquic"
|
|
tls "github.com/refraction-networking/utls"
|
|
|
|
mockquic "github.com/refraction-networking/uquic/internal/mocks/quic"
|
|
"github.com/refraction-networking/uquic/internal/protocol"
|
|
"github.com/refraction-networking/uquic/internal/utils"
|
|
"github.com/refraction-networking/uquic/quicvarint"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/quic-go/qpack"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("Client", func() {
|
|
var (
|
|
cl *client
|
|
req *http.Request
|
|
origDialAddr = dialAddr
|
|
handshakeChan <-chan struct{} // a closed chan
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
origDialAddr = dialAddr
|
|
hostname := "quic.clemente.io:1337"
|
|
c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
cl = c.(*client)
|
|
Expect(cl.hostname).To(Equal(hostname))
|
|
|
|
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
ch := make(chan struct{})
|
|
close(ch)
|
|
handshakeChan = ch
|
|
})
|
|
|
|
AfterEach(func() {
|
|
dialAddr = origDialAddr
|
|
})
|
|
|
|
// newClient must refuse a quic.Config that lists more than one version.
It("rejects quic.Configs that allow multiple QUIC versions", func() {
	versions := []quic.VersionNumber{protocol.Version2, protocol.Version1}
	_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, &quic.Config{Versions: versions}, nil)
	Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
})
|
|
|
|
// With nil TLS and QUIC configs, the dialer must receive the package defaults.
It("uses the default QUIC and TLS config if none is give", func() {
	testErr := errors.New("test done")
	client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	var dialAddrCalled bool
	dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) {
		Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams))
		Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3}))
		Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
		dialAddrCalled = true
		// Abort the round trip right after the dial parameters were checked.
		return nil, testErr
	}
	// Check the returned error instead of dropping it, so the spec also
	// proves the round trip stopped because of the stubbed dialer.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialAddrCalled).To(BeTrue())
})
|
|
|
|
// A hostname without a port must be dialed on port 443.
It("adds the port to the hostname, if none is given", func() {
	testErr := errors.New("test done")
	client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	var dialAddrCalled bool
	dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) {
		Expect(hostname).To(Equal("quic.clemente.io:443"))
		dialAddrCalled = true
		// Abort the round trip right after the address was checked.
		return nil, testErr
	}
	req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
	Expect(err).ToNot(HaveOccurred())
	// Check the returned error instead of dropping it, so the spec also
	// proves the round trip stopped because of the stubbed dialer.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialAddrCalled).To(BeTrue())
})
|
|
|
|
// When no tls.Config is supplied, the dialer must see ServerName set to the host.
It("sets the ServerName in the tls.Config, if not set", func() {
	const host = "foo.bar"
	testErr := errors.New("test done")
	dialCalled := false
	dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
		Expect(tlsCfg.ServerName).To(Equal(host))
		dialCalled = true
		// Abort the round trip right after the TLS config was checked.
		return nil, testErr
	}
	client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc)
	Expect(err).ToNot(HaveOccurred())
	req, err := http.NewRequest("GET", "https://foo.bar", nil)
	Expect(err).ToNot(HaveOccurred())
	// Check the returned error instead of dropping it, so the spec also
	// proves the round trip stopped because of the stubbed dialer.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialCalled).To(BeTrue())
})
|
|
|
|
// Caller-supplied configs must reach the dialer, with NextProtos replaced by
// the HTTP/3 ALPN token and the caller's tls.Config left unmodified.
It("uses the TLS config and QUIC config", func() {
	testErr := errors.New("test done")
	tlsConf := &tls.Config{
		ServerName: "foo.bar",
		NextProtos: []string{"proto foo", "proto bar"},
	}
	quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
	client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
	Expect(err).ToNot(HaveOccurred())
	var dialAddrCalled bool
	dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
		Expect(host).To(Equal("localhost:1337"))
		Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
		Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3}))
		Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
		dialAddrCalled = true
		// Abort the round trip right after the configs were checked.
		return nil, testErr
	}
	// Check the returned error instead of dropping it, so the spec also
	// proves the round trip stopped because of the stubbed dialer.
	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
	Expect(dialAddrCalled).To(BeTrue())
	// make sure the original tls.Config was not modified
	Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
})
|
|
|
|
It("uses the custom dialer, if provided", func() {
|
|
testErr := errors.New("test done")
|
|
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
|
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
|
|
defer cancel()
|
|
var dialerCalled bool
|
|
dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) {
|
|
Expect(ctxP).To(Equal(ctx))
|
|
Expect(address).To(Equal("localhost:1337"))
|
|
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
|
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
|
|
dialerCalled = true
|
|
return nil, testErr
|
|
}
|
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{})
|
|
Expect(err).To(MatchError(testErr))
|
|
Expect(dialerCalled).To(BeTrue())
|
|
})
|
|
|
|
// EnableDatagram in the round tripper options must switch on EnableDatagrams
// in the quic.Config handed to the dialer.
It("enables HTTP/3 Datagrams", func() {
	dialErr := errors.New("handshake error")
	dialAddr = func(_ context.Context, _ string, _ *tls.Config, qc *quic.Config) (quic.EarlyConnection, error) {
		Expect(qc.EnableDatagrams).To(BeTrue())
		return nil, dialErr
	}

	client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
	Expect(err).ToNot(HaveOccurred())

	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(dialErr))
})
|
|
|
|
// An error from the dialer must be returned by RoundTripOpt.
It("errors when dialing fails", func() {
	dialErr := errors.New("handshake error")
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return nil, dialErr
	}

	client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())

	_, err = client.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(dialErr))
})
|
|
|
|
// Close must succeed on a client that never dialed.
It("closes correctly if connection was not created", func() {
	c, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
	Expect(err).ToNot(HaveOccurred())
	Expect(c.Close()).To(Succeed())
})
|
|
|
|
Context("validating the address", func() {
|
|
// The fixture client is bound to quic.clemente.io:1337, so a request for
// port 1336 must be rejected.
It("refuses to do requests for the wrong host", func() {
	// The first argument of http.NewRequest is the HTTP method, not the URL
	// scheme; use a real method so only the host mismatch is under test.
	req, err := http.NewRequest("GET", "https://quic.clemente.io:1336/foobar.html", nil)
	Expect(err).ToNot(HaveOccurred())
	_, err = cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
|
|
|
|
// A non-https URL scheme must not be rejected by the address check: the
// request reaches the dialer, whose error is returned.
It("allows requests using a different scheme", func() {
	testErr := errors.New("handshake error")
	// The first argument of http.NewRequest is the HTTP method; the scheme
	// under test is carried by the URL alone.
	req, err := http.NewRequest("GET", "masque://quic.clemente.io:1337/foobar.html", nil)
	Expect(err).ToNot(HaveOccurred())
	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return nil, testErr
	}
	_, err = cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(testErr))
})
|
|
})
|
|
|
|
Context("hijacking bidirectional streams", func() {
|
|
var (
|
|
request *http.Request
|
|
conn *mockquic.MockEarlyConnection
|
|
settingsFrameWritten chan struct{}
|
|
)
|
|
testDone := make(chan struct{})
|
|
|
|
BeforeEach(func() {
|
|
testDone = make(chan struct{})
|
|
settingsFrameWritten = make(chan struct{})
|
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
|
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
|
|
defer GinkgoRecover()
|
|
close(settingsFrameWritten)
|
|
})
|
|
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
|
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
|
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
|
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
|
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
|
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
|
return conn, nil
|
|
}
|
|
var err error
|
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
AfterEach(func() {
|
|
testDone <- struct{}{}
|
|
Eventually(settingsFrameWritten).Should(BeClosed())
|
|
})
|
|
|
|
// A stream starting with the unknown frame type 0x41 must be offered to the
// StreamHijacker; when it reports the stream as hijacked, the connection
// must not be closed.
It("hijacks a bidirectional stream of unknown frame type", func() {
	frameTypeChan := make(chan FrameType, 1)
	cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported.
		defer GinkgoRecover()
		Expect(e).ToNot(HaveOccurred())
		frameTypeChan <- ft
		return true, nil
	}

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
|
|
|
|
// When the StreamHijacker declines the stream, the client is allowed to
// close the connection with ErrCodeFrameUnexpected.
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
	frameTypeChan := make(chan FrameType, 1)
	cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported.
		defer GinkgoRecover()
		Expect(e).ToNot(HaveOccurred())
		frameTypeChan <- ft
		return false, nil
	}

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
|
|
|
|
// When the StreamHijacker returns an error, the client is allowed to close
// the connection with ErrCodeFrameUnexpected.
It("closes the connection when hijacker returned error", func() {
	frameTypeChan := make(chan FrameType, 1)
	cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported.
		defer GinkgoRecover()
		Expect(e).ToNot(HaveOccurred())
		frameTypeChan <- ft
		return false, errors.New("error in hijacker")
	}

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41))
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
|
|
|
|
// If reading the frame type fails, the StreamHijacker must be called with
// that error, a zero frame type and the affected stream.
It("handles errors that occur when reading the frame type", func() {
	testErr := errors.New("test error")
	unknownStr := mockquic.NewMockStream(mockCtrl)
	done := make(chan struct{})
	cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported. It is
		// deferred first so that done is still closed on failure.
		defer GinkgoRecover()
		defer close(done)
		Expect(e).To(MatchError(testErr))
		Expect(ft).To(BeZero())
		Expect(str).To(Equal(unknownStr))
		return false, nil
	}

	unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
	conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	// Closing the connection with ErrCodeFrameUnexpected is tolerated here.
	conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
	_, err := cl.RoundTripOpt(request, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(done).Should(BeClosed())
	// Leave a short window for calls that have no expectation on the mocks
	// to surface; CloseWithError(ErrCodeFrameUnexpected) is allowed above.
	time.Sleep(scaleDuration(20 * time.Millisecond))
})
|
|
})
|
|
|
|
Context("hijacking unidirectional streams", func() {
|
|
var (
|
|
req *http.Request
|
|
conn *mockquic.MockEarlyConnection
|
|
settingsFrameWritten chan struct{}
|
|
)
|
|
testDone := make(chan struct{})
|
|
|
|
BeforeEach(func() {
|
|
testDone = make(chan struct{})
|
|
settingsFrameWritten = make(chan struct{})
|
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
|
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
|
|
defer GinkgoRecover()
|
|
close(settingsFrameWritten)
|
|
})
|
|
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
|
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
|
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
|
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
|
return conn, nil
|
|
}
|
|
var err error
|
|
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
AfterEach(func() {
|
|
testDone <- struct{}{}
|
|
Eventually(settingsFrameWritten).Should(BeClosed())
|
|
})
|
|
|
|
// A unidirectional stream of the unknown type 0x54 must be offered to the
// UniStreamHijacker; when it takes the stream, nothing else may happen to it.
It("hijacks an unidirectional stream of unknown stream type", func() {
	streamTypeChan := make(chan StreamType, 1)
	cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported.
		defer GinkgoRecover()
		Expect(err).ToNot(HaveOccurred())
		streamTypeChan <- st
		return true
	}

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return unknownStr, nil
	})
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
|
|
|
|
// If reading the stream type fails, the UniStreamHijacker must be called
// with that error, a zero stream type and the affected stream.
It("handles errors that occur when reading the stream type", func() {
	testErr := errors.New("test error")
	done := make(chan struct{})
	unknownStr := mockquic.NewMockStream(mockCtrl)
	cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported. It is
		// deferred first so that done is still closed on failure.
		defer GinkgoRecover()
		defer close(done)
		Expect(st).To(BeZero())
		Expect(str).To(Equal(unknownStr))
		Expect(err).To(MatchError(testErr))
		return true
	}

	unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(done).Should(BeClosed())
	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
|
|
|
|
// When the UniStreamHijacker declines the stream, reading from it must be
// cancelled with ErrCodeStreamCreationError.
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
	streamTypeChan := make(chan StreamType, 1)
	cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
		// NOTE(review): the hijacker appears to run outside the spec
		// goroutine (the spec only observes it through a channel), so
		// assertion failures need GinkgoRecover to be reported.
		defer GinkgoRecover()
		Expect(err).ToNot(HaveOccurred())
		streamTypeChan <- st
		return false
	}

	buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54))
	unknownStr := mockquic.NewMockStream(mockCtrl)
	unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
	unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		return unknownStr, nil
	})
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
|
|
})
|
|
|
|
Context("control stream handling", func() {
|
|
var (
|
|
req *http.Request
|
|
conn *mockquic.MockEarlyConnection
|
|
settingsFrameWritten chan struct{}
|
|
)
|
|
testDone := make(chan struct{})
|
|
|
|
BeforeEach(func() {
|
|
settingsFrameWritten = make(chan struct{})
|
|
controlStr := mockquic.NewMockStream(mockCtrl)
|
|
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
|
|
defer GinkgoRecover()
|
|
close(settingsFrameWritten)
|
|
})
|
|
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
|
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
|
|
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
|
|
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
|
return conn, nil
|
|
}
|
|
var err error
|
|
req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
})
|
|
|
|
AfterEach(func() {
|
|
testDone <- struct{}{}
|
|
Eventually(settingsFrameWritten).Should(BeClosed())
|
|
})
|
|
|
|
// A control stream carrying an empty SETTINGS frame must be accepted
// without any further calls on the connection.
It("parses the SETTINGS frame", func() {
	payload := (&settingsFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))
	reader := bytes.NewReader(payload)

	peerControl := mockquic.NewMockStream(mockCtrl)
	peerControl.EXPECT().Read(gomock.Any()).DoAndReturn(reader.Read).AnyTimes()

	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(peerControl, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
|
|
|
|
for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} {
|
|
streamType := t
|
|
name := "encoder"
|
|
if streamType == streamTypeQPACKDecoderStream {
|
|
name = "decoder"
|
|
}
|
|
|
|
It(fmt.Sprintf("ignores the QPACK %s streams", name), func() {
|
|
buf := bytes.NewBuffer(quicvarint.Append(nil, streamType))
|
|
str := mockquic.NewMockStream(mockCtrl)
|
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
|
|
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
|
return str, nil
|
|
})
|
|
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
|
<-testDone
|
|
return nil, errors.New("test done")
|
|
})
|
|
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
|
Expect(err).To(MatchError("done"))
|
|
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
|
|
})
|
|
}
|
|
|
|
// A unidirectional stream of the unknown type 0x1337 must have its read
// side cancelled with ErrCodeStreamCreationError.
It("resets streams Other than the control stream and the QPACK streams", func() {
	cancelled := make(chan struct{})
	typeBytes := bytes.NewBuffer(quicvarint.Append(nil, 0x1337))

	unknown := mockquic.NewMockStream(mockCtrl)
	unknown.EXPECT().Read(gomock.Any()).DoAndReturn(typeBytes.Read).AnyTimes()
	unknown.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) {
		close(cancelled)
	})

	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknown, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(cancelled).Should(BeClosed())
})
|
|
|
|
// A DATA frame as the first frame on the control stream must close the
// connection with ErrCodeMissingSettings.
It("errors when the first frame on the control stream is not a SETTINGS frame", func() {
	payload := (&dataFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))
	reader := bytes.NewReader(payload)

	peerControl := mockquic.NewMockStream(mockCtrl)
	peerControl.EXPECT().Read(gomock.Any()).DoAndReturn(reader.Read).AnyTimes()

	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(peerControl, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	closed := make(chan struct{})
	conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
		defer GinkgoRecover()
		Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings))
		close(closed)
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(closed).Should(BeClosed())
})
|
|
|
|
// A SETTINGS frame with its last byte cut off must close the connection
// with ErrCodeFrameError.
It("errors when parsing the frame on the control stream fails", func() {
	payload := (&settingsFrame{}).Append(quicvarint.Append(nil, streamTypeControlStream))
	reader := bytes.NewReader(payload[:len(payload)-1])

	peerControl := mockquic.NewMockStream(mockCtrl)
	peerControl.EXPECT().Read(gomock.Any()).DoAndReturn(reader.Read).AnyTimes()

	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(peerControl, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	closed := make(chan struct{})
	conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
		defer GinkgoRecover()
		Expect(code).To(BeEquivalentTo(ErrCodeFrameError))
		close(closed)
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(closed).Should(BeClosed())
})
|
|
|
|
// A unidirectional stream of the push stream type must close the connection
// with ErrCodeIDError.
It("errors when parsing the server opens a push stream", func() {
	typeBytes := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream))

	pushStr := mockquic.NewMockStream(mockCtrl)
	pushStr.EXPECT().Read(gomock.Any()).DoAndReturn(typeBytes.Read).AnyTimes()

	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(pushStr, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})

	closed := make(chan struct{})
	conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) {
		defer GinkgoRecover()
		Expect(code).To(BeEquivalentTo(ErrCodeIDError))
		close(closed)
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(closed).Should(BeClosed())
})
|
|
|
|
// With datagrams enabled locally and advertised in the peer's SETTINGS, a
// connection state without datagram support must close the connection with
// ErrCodeSettingsError.
It("errors when the server advertises datagram support (and we enabled support for it)", func() {
	cl.opts.EnableDatagram = true

	payload := (&settingsFrame{Datagram: true}).Append(quicvarint.Append(nil, streamTypeControlStream))
	reader := bytes.NewReader(payload)

	peerControl := mockquic.NewMockStream(mockCtrl)
	peerControl.EXPECT().Read(gomock.Any()).DoAndReturn(reader.Read).AnyTimes()

	conn.EXPECT().AcceptUniStream(gomock.Any()).Return(peerControl, nil)
	// The second accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false})

	closed := make(chan struct{})
	conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) {
		defer GinkgoRecover()
		Expect(code).To(BeEquivalentTo(ErrCodeSettingsError))
		Expect(reason).To(Equal("missing QUIC Datagram support"))
		close(closed)
	})

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError("done"))
	Eventually(closed).Should(BeClosed())
})
|
|
})
|
|
|
|
Context("Doing requests", func() {
|
|
var (
|
|
req *http.Request
|
|
str *mockquic.MockStream
|
|
conn *mockquic.MockEarlyConnection
|
|
settingsFrameWritten chan struct{}
|
|
)
|
|
testDone := make(chan struct{})
|
|
|
|
decodeHeader := func(str io.Reader) map[string]string {
|
|
fields := make(map[string]string)
|
|
decoder := qpack.NewDecoder(nil)
|
|
|
|
frame, err := parseNextFrame(str, nil)
|
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
|
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
|
|
headersFrame := frame.(*headersFrame)
|
|
data := make([]byte, headersFrame.Length)
|
|
_, err = io.ReadFull(str, data)
|
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
|
hfs, err := decoder.DecodeFull(data)
|
|
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
|
for _, p := range hfs {
|
|
fields[p.Name] = p.Value
|
|
}
|
|
return fields
|
|
}
|
|
|
|
getResponse := func(status int) []byte {
|
|
buf := &bytes.Buffer{}
|
|
rstr := mockquic.NewMockStream(mockCtrl)
|
|
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
|
|
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
|
|
rw.WriteHeader(status)
|
|
rw.Flush()
|
|
return buf.Bytes()
|
|
}
|
|
|
|
// Installs a mocked connection and request stream. The control-stream write
// is checked to start with the control stream type; expectations for
// opening the request stream are left to the individual specs.
BeforeEach(func() {
	var err error
	req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
	Expect(err).ToNot(HaveOccurred())

	settingsFrameWritten = make(chan struct{})
	ctrlStream := mockquic.NewMockStream(mockCtrl)
	ctrlStream.EXPECT().Write(gomock.Any()).Do(func(written []byte) {
		defer GinkgoRecover()
		streamType, err := quicvarint.Read(bytes.NewReader(written))
		Expect(err).ToNot(HaveOccurred())
		Expect(streamType).To(BeEquivalentTo(streamTypeControlStream))
		close(settingsFrameWritten)
	}) // SETTINGS frame

	str = mockquic.NewMockStream(mockCtrl)
	conn = mockquic.NewMockEarlyConnection(mockCtrl)
	// The accept blocks until AfterEach signals testDone.
	conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
		<-testDone
		return nil, errors.New("test done")
	})
	conn.EXPECT().OpenUniStream().Return(ctrlStream, nil)

	dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
		return conn, nil
	}
})

// Unblocks the AcceptUniStream mock, then waits for the control-stream write
// to have happened.
AfterEach(func() {
	testDone <- struct{}{}
	Eventually(settingsFrameWritten).Should(BeClosed())
})
|
|
|
|
// An error from OpenStreamSync must be returned by RoundTripOpt.
It("errors if it can't open a stream", func() {
	openErr := errors.New("stream open error")
	conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, openErr)
	conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(openErr))
})
|
|
|
|
// A request using MethodGet0RTT must be sent as a GET without any
// HandshakeComplete() call on the connection.
It("performs a 0-RTT request", func() {
	readErr := errors.New("stream open error")
	req.Method = MethodGet0RTT
	// don't EXPECT any calls to HandshakeComplete()
	conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)

	written := &bytes.Buffer{}
	str.EXPECT().Write(gomock.Any()).DoAndReturn(written.Write).AnyTimes()
	str.EXPECT().Close()
	str.EXPECT().CancelWrite(gomock.Any())
	// Reading the response fails, which ends the round trip.
	str.EXPECT().Read(gomock.Any()).Return(0, readErr)

	_, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).To(MatchError(readErr))
	Expect(decodeHeader(written)).To(HaveKeyWithValue(":method", "GET"))
})
|
|
|
|
// A 418 response read from the request stream must come back as an HTTP/3
// response that references the request.
It("returns a response", func() {
	response := bytes.NewBuffer(getResponse(418))
	acceptAll := func(p []byte) (int, error) { return len(p), nil }

	gomock.InOrder(
		conn.EXPECT().HandshakeComplete().Return(handshakeChan),
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
	)
	str.EXPECT().Read(gomock.Any()).DoAndReturn(response.Read).AnyTimes()
	str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(acceptAll)
	str.EXPECT().Close()

	rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
	Expect(err).ToNot(HaveOccurred())
	Expect(rsp.StatusCode).To(Equal(418))
	Expect(rsp.Proto).To(Equal("HTTP/3.0"))
	Expect(rsp.ProtoMajor).To(Equal(3))
	Expect(rsp.Request).ToNot(BeNil())
})
|
|
|
|
// With DontCloseRequestStream set, no Close() is expected on the request
// stream.
It("doesn't close the request stream, with DontCloseRequestStream set", func() {
	response := bytes.NewBuffer(getResponse(418))
	acceptAll := func(p []byte) (int, error) { return len(p), nil }

	gomock.InOrder(
		conn.EXPECT().HandshakeComplete().Return(handshakeChan),
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
	)
	str.EXPECT().Read(gomock.Any()).DoAndReturn(response.Read).AnyTimes()
	str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(acceptAll)

	rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
	Expect(err).ToNot(HaveOccurred())
	Expect(rsp.StatusCode).To(Equal(418))
	Expect(rsp.Proto).To(Equal("HTTP/3.0"))
	Expect(rsp.ProtoMajor).To(Equal(3))
})
|
|
|
|
// Requests that carry a body. The request body is sent asynchronously while
// the caller is already waiting for the response, so several tests use
// channels to sequence the mocked Read against Close / CancelWrite.
Context("requests containing a Body", func() {
	// strBuf records every byte written to the request stream.
	var strBuf *bytes.Buffer

	BeforeEach(func() {
		strBuf = &bytes.Buffer{}
		gomock.InOrder(
			conn.EXPECT().HandshakeComplete().Return(handshakeChan),
			conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
		)
		body := &mockBody{}
		body.SetData([]byte("request body"))
		var err error
		req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
		Expect(err).ToNot(HaveOccurred())
		str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
	})

	It("sends a request", func() {
		done := make(chan struct{})
		gomock.InOrder(
			str.EXPECT().Close().Do(func() { close(done) }),
			str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
		)
		// the request body is sent asynchronously, while already reading the response;
		// hold back the Read until the stream has been closed
		str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
			<-done
			return 0, errors.New("test done")
		})
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("test done"))
		hfs := decodeHeader(strBuf)
		Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
		Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
	})

	It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
		// The body holds "request body" (12 bytes) but only 7 are announced,
		// so only "request" may end up on the stream.
		req.ContentLength = 7
		var once sync.Once
		done := make(chan struct{})
		gomock.InOrder(
			str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
				// Only the first cancellation is checked; it also releases the Read below.
				once.Do(func() {
					Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
					close(done)
				})
			}).AnyTimes(),
			str.EXPECT().Close().MaxTimes(1),
			str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
		)
		str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
			<-done
			return 0, errors.New("done")
		})
		// The mocked Read always fails, so the round trip must report that error
		// instead of having it silently dropped by the test.
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("done"))
		Expect(strBuf.String()).To(ContainSubstring("request"))
		Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
	})

	It("returns the error that occurred when reading the body", func() {
		req.Body.(*mockBody).readErr = errors.New("testErr")
		done := make(chan struct{})
		gomock.InOrder(
			str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
				close(done)
			}),
			str.EXPECT().CancelWrite(gomock.Any()),
		)

		// the request body is sent asynchronously, while already reading the response;
		// hold back the Read until the write side has been cancelled
		str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
			<-done
			return 0, errors.New("test done")
		})
		closed := make(chan struct{})
		str.EXPECT().Close().Do(func() { close(closed) })
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("test done"))
		Eventually(closed).Should(BeClosed())
	})

	It("closes the connection when the first frame is not a HEADERS frame", func() {
		// A DATA frame header where a HEADERS frame is required.
		b := (&dataFrame{Length: 0x42}).Append(nil)
		conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
		closed := make(chan struct{})
		r := bytes.NewReader(b)
		str.EXPECT().Close().Do(func() { close(closed) })
		str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
		Eventually(closed).Should(BeClosed())
	})

	It("cancels the stream when parsing the headers fails", func() {
		headerBuf := &bytes.Buffer{}
		enc := qpack.NewEncoder(headerBuf)
		Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
		Expect(enc.Close()).To(Succeed())
		// HEADERS frame header followed by the encoded (invalid) header block.
		b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
		b = append(b, headerBuf.Bytes()...)

		r := bytes.NewReader(b)
		str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
		closed := make(chan struct{})
		str.EXPECT().Close().Do(func() { close(closed) })
		str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(HaveOccurred())
		Eventually(closed).Should(BeClosed())
	})

	It("cancels the stream when the HEADERS frame is too large", func() {
		// The client under test was created with MaxHeaderBytes: 1337.
		b := (&headersFrame{Length: 1338}).Append(nil)
		r := bytes.NewReader(b)
		str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
		closed := make(chan struct{})
		str.EXPECT().Close().Do(func() { close(closed) })
		str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
		Eventually(closed).Should(BeClosed())
	})
})
|
|
|
|
// Cancelling the request context at different stages of a round trip. The
// first two scenarios are run both with and without DontCloseRequestStream.
Context("request cancellations", func() {
	for _, dontClose := range []bool{false, true} {
		dontClose := dontClose // per-iteration copy for the closures below (required before Go 1.22)

		Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
			roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}

			It("cancels a request while waiting for the handshake to complete", func() {
				ctx, cancel := context.WithCancel(context.Background())
				// Make sure the round trip is released even if an assertion
				// below fails before the explicit cancel() is reached.
				defer cancel()
				req := req.WithContext(ctx)
				// A channel that is never closed: the handshake never completes.
				conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))

				// Buffered so the goroutine can always deliver its result and
				// exit, even when nobody receives after a failed assertion.
				errChan := make(chan error, 1)
				go func() {
					_, err := cl.RoundTripOpt(req, roundTripOpt)
					errChan <- err
				}()
				Consistently(errChan).ShouldNot(Receive())
				cancel()
				Eventually(errChan).Should(Receive(MatchError("context canceled")))
			})

			It("cancels a request while the request is still in flight", func() {
				ctx, cancel := context.WithCancel(context.Background())
				req := req.WithContext(ctx)
				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Close().MaxTimes(1)

				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)

				done := make(chan struct{})
				canceled := make(chan struct{})
				// The write side has to be cancelled before the read side.
				gomock.InOrder(
					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
					str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
				)
				str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
				// Cancel from inside the first Read and only return once the
				// cancellation has reached the stream.
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					cancel()
					<-canceled
					return 0, errors.New("test done")
				})
				_, err := cl.RoundTripOpt(req, roundTripOpt)
				Expect(err).To(MatchError("test done"))
				Eventually(done).Should(BeClosed())
			})
		})
	}

	It("cancels a request after the response arrived", func() {
		rspBuf := bytes.NewBuffer(getResponse(404))

		ctx, cancel := context.WithCancel(context.Background())
		req := req.WithContext(ctx)
		conn.EXPECT().HandshakeComplete().Return(handshakeChan)
		conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
		buf := &bytes.Buffer{}
		str.EXPECT().Close().MaxTimes(1)

		done := make(chan struct{})
		str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
		str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
		str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
		str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).ToNot(HaveOccurred())
		// The response has been returned; cancelling now must still cancel
		// both directions of the stream.
		cancel()
		Eventually(done).Should(BeClosed())
	})
})
|
|
|
|
// Transparent gzip handling: the accept-encoding request header and the
// decompression of responses carrying Content-Encoding: gzip.
Context("gzip compression", func() {
	BeforeEach(func() {
		conn.EXPECT().HandshakeComplete().Return(handshakeChan)
	})

	It("adds the gzip header to requests", func() {
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
		buf := &bytes.Buffer{}
		str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
		gomock.InOrder(
			str.EXPECT().Close(),
			str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
		)
		str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
		_, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("test done"))
		hfs := decodeHeader(buf)
		Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
	})

	It("doesn't add gzip if the header disable it", func() {
		// Dedicated client with compression switched off. The local is not
		// called "client" so it does not shadow the package-level client type.
		noGzipClient, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
		Expect(err).ToNot(HaveOccurred())
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
		buf := &bytes.Buffer{}
		str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
		gomock.InOrder(
			str.EXPECT().Close(),
			str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
		)
		str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
		_, err = noGzipClient.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).To(MatchError("test done"))
		hfs := decodeHeader(buf)
		Expect(hfs).ToNot(HaveKey("accept-encoding"))
	})

	It("decompresses the response", func() {
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

		// Produce the response bytes with a response writer on a second mock
		// stream. DoAndReturn (rather than Do) makes the mocked Write report
		// the number of bytes actually stored, as io.Writer requires.
		buf := &bytes.Buffer{}
		rstr := mockquic.NewMockStream(mockCtrl)
		rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
		rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
		rw.Header().Set("Content-Encoding", "gzip")
		gz := gzip.NewWriter(rw)
		_, err := gz.Write([]byte("gzipped response"))
		Expect(err).ToNot(HaveOccurred())
		Expect(gz.Close()).To(Succeed())
		rw.Flush()

		// The client's stream discards writes and serves the bytes produced above.
		str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
		str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
		str.EXPECT().Close()

		rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).ToNot(HaveOccurred())
		data, err := io.ReadAll(rsp.Body)
		Expect(err).ToNot(HaveOccurred())
		Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
		Expect(string(data)).To(Equal("gzipped response"))
		Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
		Expect(rsp.Uncompressed).To(BeTrue())
	})

	It("only decompresses the response if the response contains the right content-encoding header", func() {
		conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
		conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

		// Plain response without a Content-Encoding header, produced the same
		// way as in the gzip case.
		buf := &bytes.Buffer{}
		rstr := mockquic.NewMockStream(mockCtrl)
		rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
		rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
		_, err := rw.Write([]byte("not gzipped"))
		Expect(err).ToNot(HaveOccurred())
		rw.Flush()

		str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
		str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
		str.EXPECT().Close()

		rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
		Expect(err).ToNot(HaveOccurred())
		data, err := io.ReadAll(rsp.Body)
		Expect(err).ToNot(HaveOccurred())
		Expect(string(data)).To(Equal("not gzipped"))
		Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
	})
})
|
|
})
|
|
})
|