uquic/http3/client_test.go

499 lines
18 KiB
Go

package http3
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"io/ioutil"
"net/http"
"time"
"github.com/golang/mock/gomock"
quic "github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// The "Client" specs exercise the HTTP/3 client: dialing (default and custom
// QUIC/TLS configs, custom dialers, dial errors), request/response handling on
// mocked QUIC sessions and streams, request cancellation, and gzip handling.
var _ = Describe("Client", func() {
	var (
		client *client
		req    *http.Request
		// origDialAddr saves the package-level dialAddr hook, which many specs
		// replace; AfterEach restores it.
		origDialAddr = dialAddr
		// handshakeCtx is an already canceled context. Mocks return it from
		// HandshakeComplete() so the handshake looks finished (contrast with the
		// spec that returns context.Background() and waits on the handshake).
		handshakeCtx context.Context
	)

	BeforeEach(func() {
		origDialAddr = dialAddr
		hostname := "quic.clemente.io:1337"
		client = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
		Expect(client.hostname).To(Equal(hostname))
		var err error
		req, err = http.NewRequest("GET", "https://localhost:1337", nil)
		Expect(err).ToNot(HaveOccurred())
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		handshakeCtx = ctx
	})

	AfterEach(func() {
		dialAddr = origDialAddr
	})

	It("uses the default QUIC and TLS config if none is given", func() {
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
		var dialAddrCalled bool
		dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
			Expect(quicConf).To(Equal(defaultQuicConfig))
			Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
			dialAddrCalled = true
			return nil, errors.New("test done")
		}
		// The dial error must be surfaced by RoundTrip.
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError("test done"))
		Expect(dialAddrCalled).To(BeTrue())
	})

	It("adds the port to the hostname, if none is given", func() {
		client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
		var dialAddrCalled bool
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
			Expect(hostname).To(Equal("quic.clemente.io:443"))
			dialAddrCalled = true
			return nil, errors.New("test done")
		}
		req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil)
		Expect(err).ToNot(HaveOccurred())
		_, err = client.RoundTrip(req)
		Expect(err).To(MatchError("test done"))
		Expect(dialAddrCalled).To(BeTrue())
	})

	It("uses the TLS config and QUIC config", func() {
		tlsConf := &tls.Config{
			ServerName: "foo.bar",
			NextProtos: []string{"proto foo", "proto bar"},
		}
		quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
		client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
		var dialAddrCalled bool
		dialAddr = func(
			hostname string,
			tlsConfP *tls.Config,
			quicConfP *quic.Config,
		) (quic.EarlySession, error) {
			Expect(hostname).To(Equal("localhost:1337"))
			Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
			Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
			dialAddrCalled = true
			return nil, errors.New("test done")
		}
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError("test done"))
		Expect(dialAddrCalled).To(BeTrue())
		// make sure the original tls.Config was not modified
		Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"}))
	})

	It("uses the custom dialer, if provided", func() {
		testErr := errors.New("test done")
		tlsConf := &tls.Config{ServerName: "foo.bar"}
		quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
		var dialerCalled bool
		dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
			Expect(network).To(Equal("udp"))
			Expect(address).To(Equal("localhost:1337"))
			Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
			Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
			dialerCalled = true
			return nil, testErr
		}
		client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError(testErr))
		Expect(dialerCalled).To(BeTrue())
	})

	It("errors when dialing fails", func() {
		testErr := errors.New("handshake error")
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
			return nil, testErr
		}
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError(testErr))
	})

	It("errors if it can't open a stream", func() {
		testErr := errors.New("stream open error")
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
		session := mockquic.NewMockEarlySession(mockCtrl)
		// MaxTimes(1): these may or may not be reached depending on timing.
		session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
		session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
		session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
		session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
		dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
			return session, nil
		}
		_, err := client.RoundTrip(req)
		Expect(err).To(MatchError(testErr))
	})

	It("closes correctly if session was not created", func() {
		client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
		err := client.Close()
		Expect(err).ToNot(HaveOccurred())
	})

	Context("Doing requests", func() {
		var (
			request *http.Request
			str     *mockquic.MockStream
			sess    *mockquic.MockEarlySession
		)

		// decodeHeader reads the next frame from r, expects it to be a HEADERS
		// frame, reads its payload and returns the QPACK-decoded fields as a
		// name -> value map (a later duplicate name overwrites an earlier one).
		decodeHeader := func(r io.Reader) map[string]string {
			fields := make(map[string]string)
			decoder := qpack.NewDecoder(nil)
			frame, err := parseNextFrame(r)
			ExpectWithOffset(1, err).ToNot(HaveOccurred())
			ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
			hf := frame.(*headersFrame)
			data := make([]byte, hf.Length)
			_, err = io.ReadFull(r, data)
			ExpectWithOffset(1, err).ToNot(HaveOccurred())
			hfs, err := decoder.DecodeFull(data)
			ExpectWithOffset(1, err).ToNot(HaveOccurred())
			for _, p := range hfs {
				fields[p.Name] = p.Value
			}
			return fields
		}

		BeforeEach(func() {
			// The control stream may receive the stream type byte (0x0)
			// followed by one more write (the SETTINGS frame).
			controlStr := mockquic.NewMockStream(mockCtrl)
			controlStr.EXPECT().Write([]byte{0x0}).Return(1, nil).MaxTimes(1)
			controlStr.EXPECT().Write(gomock.Any()).MaxTimes(1) // SETTINGS frame
			str = mockquic.NewMockStream(mockCtrl)
			sess = mockquic.NewMockEarlySession(mockCtrl)
			sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
			dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
				return sess, nil
			}
			var err error
			request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
			Expect(err).ToNot(HaveOccurred())
		})

		It("performs a 0-RTT request", func() {
			testErr := errors.New("stream open error")
			request.Method = MethodGet0RTT
			// don't EXPECT any calls to HandshakeComplete()
			sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
			buf := &bytes.Buffer{}
			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
				return buf.Write(p)
			}).AnyTimes()
			str.EXPECT().Close()
			str.EXPECT().CancelWrite(gomock.Any())
			str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
				return 0, testErr
			})
			_, err := client.RoundTrip(request)
			Expect(err).To(MatchError(testErr))
			// The 0-RTT method must go on the wire as a plain GET.
			Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
		})

		It("returns a response", func() {
			rspBuf := &bytes.Buffer{}
			rw := newResponseWriter(rspBuf, utils.DefaultLogger)
			rw.WriteHeader(418)
			rw.Flush()
			gomock.InOrder(
				sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
				sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
			)
			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
			str.EXPECT().Close()
			str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
				return rspBuf.Read(p)
			}).AnyTimes()
			rsp, err := client.RoundTrip(request)
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Proto).To(Equal("HTTP/3"))
			Expect(rsp.ProtoMajor).To(Equal(3))
			Expect(rsp.StatusCode).To(Equal(418))
		})

		Context("validating the address", func() {
			It("refuses to do requests for the wrong host", func() {
				req, err := http.NewRequest("GET", "https://quic.clemente.io:1336/foobar.html", nil)
				Expect(err).ToNot(HaveOccurred())
				_, err = client.RoundTrip(req)
				Expect(err).To(MatchError("http3 client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
			})

			It("refuses to do plain HTTP requests", func() {
				req, err := http.NewRequest("GET", "http://quic.clemente.io:1337/foobar.html", nil)
				Expect(err).ToNot(HaveOccurred())
				_, err = client.RoundTrip(req)
				Expect(err).To(MatchError("http3: unsupported scheme"))
			})
		})

		Context("requests containing a Body", func() {
			// strBuf captures everything the client writes to the request stream.
			var strBuf *bytes.Buffer

			BeforeEach(func() {
				strBuf = &bytes.Buffer{}
				gomock.InOrder(
					sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
					sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
				)
				body := &mockBody{}
				body.SetData([]byte("request body"))
				var err error
				request, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
				Expect(err).ToNot(HaveOccurred())
				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return strBuf.Write(p)
				}).AnyTimes()
			})

			It("sends a request", func() {
				done := make(chan struct{})
				gomock.InOrder(
					str.EXPECT().Close().Do(func() { close(done) }),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
				)
				// the response body is sent asynchronously, while already reading the response
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					<-done
					return 0, errors.New("test done")
				})
				_, err := client.RoundTrip(request)
				Expect(err).To(MatchError("test done"))
				hfs := decodeHeader(strBuf)
				Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
				Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
			})

			It("returns the error that occurred when reading the body", func() {
				request.Body.(*mockBody).readErr = errors.New("testErr")
				done := make(chan struct{})
				gomock.InOrder(
					str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) {
						close(done)
					}),
					str.EXPECT().CancelWrite(gomock.Any()),
				)
				// the response body is sent asynchronously, while already reading the response
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					<-done
					return 0, errors.New("test done")
				})
				_, err := client.RoundTrip(request)
				Expect(err).To(MatchError("test done"))
			})

			It("closes the connection when the first frame is not a HEADERS frame", func() {
				buf := &bytes.Buffer{}
				(&dataFrame{Length: 0x42}).Write(buf)
				sess.EXPECT().CloseWithError(quic.ErrorCode(errorFrameUnexpected), gomock.Any())
				closed := make(chan struct{})
				str.EXPECT().Close().Do(func() { close(closed) })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
					return buf.Read(b)
				}).AnyTimes()
				_, err := client.RoundTrip(request)
				Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
				Eventually(closed).Should(BeClosed())
			})

			It("cancels the stream when the HEADERS frame is too large", func() {
				buf := &bytes.Buffer{}
				// One byte above the MaxHeaderBytes (1337) configured in the outer BeforeEach.
				(&headersFrame{Length: 1338}).Write(buf)
				str.EXPECT().CancelWrite(quic.ErrorCode(errorFrameError))
				closed := make(chan struct{})
				str.EXPECT().Close().Do(func() { close(closed) })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
					return buf.Read(b)
				}).AnyTimes()
				_, err := client.RoundTrip(request)
				Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
				Eventually(closed).Should(BeClosed())
			})
		})

		Context("request cancellations", func() {
			It("cancels a request while waiting for the handshake to complete", func() {
				ctx, cancel := context.WithCancel(context.Background())
				req := request.WithContext(ctx)
				// A context that is never done: the handshake never completes.
				sess.EXPECT().HandshakeComplete().Return(context.Background())
				errChan := make(chan error)
				go func() {
					_, err := client.RoundTrip(req)
					errChan <- err
				}()
				Consistently(errChan).ShouldNot(Receive())
				cancel()
				Eventually(errChan).Should(Receive(MatchError("context canceled")))
			})

			It("cancels a request while the request is still in flight", func() {
				ctx, cancel := context.WithCancel(context.Background())
				req := request.WithContext(ctx)
				sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
				sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Close().MaxTimes(1)
				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return buf.Write(p)
				})
				done := make(chan struct{})
				canceled := make(chan struct{})
				gomock.InOrder(
					str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(canceled) }),
					str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) }),
				)
				str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
				// Cancel the request context while RoundTrip is blocked reading
				// the response, then fail the read once CancelWrite was observed.
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					cancel()
					<-canceled
					return 0, errors.New("test done")
				})
				_, err := client.RoundTrip(req)
				Expect(err).To(MatchError("test done"))
				Eventually(done).Should(BeClosed())
			})

			It("cancels a request after the response arrived", func() {
				rspBuf := &bytes.Buffer{}
				rw := newResponseWriter(rspBuf, utils.DefaultLogger)
				rw.WriteHeader(418)
				rw.Flush()
				ctx, cancel := context.WithCancel(context.Background())
				req := request.WithContext(ctx)
				sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
				sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Close().MaxTimes(1)
				done := make(chan struct{})
				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return buf.Write(p)
				})
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) {
					return rspBuf.Read(b)
				}).AnyTimes()
				str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled))
				str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) })
				_, err := client.RoundTrip(req)
				Expect(err).ToNot(HaveOccurred())
				cancel()
				Eventually(done).Should(BeClosed())
			})
		})

		Context("gzip compression", func() {
			BeforeEach(func() {
				sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
			})

			It("adds the gzip header to requests", func() {
				sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return buf.Write(p)
				})
				gomock.InOrder(
					str.EXPECT().Close(),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
				)
				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
				_, err := client.RoundTrip(request)
				Expect(err).To(MatchError("test done"))
				hfs := decodeHeader(buf)
				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
			})

			It("doesn't add gzip if compression is disabled", func() {
				client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
				sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return buf.Write(p)
				})
				gomock.InOrder(
					str.EXPECT().Close(),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
				)
				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
				_, err := client.RoundTrip(request)
				Expect(err).To(MatchError("test done"))
				hfs := decodeHeader(buf)
				Expect(hfs).ToNot(HaveKey("accept-encoding"))
			})

			It("decompresses the response", func() {
				sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				buf := &bytes.Buffer{}
				rw := newResponseWriter(buf, utils.DefaultLogger)
				rw.Header().Set("Content-Encoding", "gzip")
				gz := gzip.NewWriter(rw)
				gz.Write([]byte("gzipped response"))
				gz.Close()
				rw.Flush()
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return buf.Read(p)
				}).AnyTimes()
				str.EXPECT().Close()
				rsp, err := client.RoundTrip(request)
				Expect(err).ToNot(HaveOccurred())
				data, err := ioutil.ReadAll(rsp.Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
				Expect(string(data)).To(Equal("gzipped response"))
				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
				Expect(rsp.Uncompressed).To(BeTrue())
			})

			It("only decompresses the response if the response contains the right content-encoding header", func() {
				sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				buf := &bytes.Buffer{}
				rw := newResponseWriter(buf, utils.DefaultLogger)
				rw.Write([]byte("not gzipped"))
				rw.Flush()
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
					return buf.Read(p)
				}).AnyTimes()
				str.EXPECT().Close()
				rsp, err := client.RoundTrip(request)
				Expect(err).ToNot(HaveOccurred())
				data, err := ioutil.ReadAll(rsp.Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
				Expect(string(data)).To(Equal("not gzipped"))
				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
			})
		})
	})
})