use 0-RTT to open the H3 client's control stream

This commit is contained in:
Marten Seemann 2020-01-25 21:26:27 +07:00
parent 63c9272bf4
commit 1372e5dd5e
4 changed files with 44 additions and 34 deletions

View file

@ -23,7 +23,7 @@ var defaultQuicConfig = &quic.Config{
KeepAlive: true, KeepAlive: true,
} }
var dialAddr = quic.DialAddr var dialAddr = quic.DialAddrEarly
type roundTripperOpts struct { type roundTripperOpts struct {
DisableCompression bool DisableCompression bool
@ -37,7 +37,7 @@ type client struct {
opts *roundTripperOpts opts *roundTripperOpts
dialOnce sync.Once dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
handshakeErr error handshakeErr error
requestWriter *requestWriter requestWriter *requestWriter
@ -45,7 +45,7 @@ type client struct {
decoder *qpack.Decoder decoder *qpack.Decoder
hostname string hostname string
session quic.Session session quic.EarlySession
logger utils.Logger logger utils.Logger
} }
@ -55,7 +55,7 @@ func newClient(
tlsConf *tls.Config, tlsConf *tls.Config,
opts *roundTripperOpts, opts *roundTripperOpts,
quicConfig *quic.Config, quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error), dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
) *client { ) *client {
if tlsConf == nil { if tlsConf == nil {
tlsConf = &tls.Config{} tlsConf = &tls.Config{}
@ -93,6 +93,7 @@ func (c *client) dial() error {
return err return err
} }
// run the sesssion setup using 0-RTT data
go func() { go func() {
if err := c.setupSession(); err != nil { if err := c.setupSession(); err != nil {
c.logger.Debugf("Setting up session failed: %s", err) c.logger.Debugf("Setting up session failed: %s", err)
@ -100,6 +101,7 @@ func (c *client) dial() error {
} }
}() }()
<-c.session.HandshakeComplete().Done()
return nil return nil
} }

View file

@ -26,6 +26,7 @@ var _ = Describe("Client", func() {
client *client client *client
req *http.Request req *http.Request
origDialAddr = dialAddr origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
) )
BeforeEach(func() { BeforeEach(func() {
@ -37,6 +38,10 @@ var _ = Describe("Client", func() {
var err error var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil) req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
}) })
AfterEach(func() { AfterEach(func() {
@ -46,7 +51,7 @@ var _ = Describe("Client", func() {
It("uses the default QUIC and TLS config if none is give", func() { It("uses the default QUIC and TLS config if none is give", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool var dialAddrCalled bool
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) { dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
Expect(quicConf).To(Equal(defaultQuicConfig)) Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
dialAddrCalled = true dialAddrCalled = true
@ -59,7 +64,7 @@ var _ = Describe("Client", func() {
It("adds the port to the hostname, if none is given", func() { It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool var dialAddrCalled bool
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
Expect(hostname).To(Equal("quic.clemente.io:443")) Expect(hostname).To(Equal("quic.clemente.io:443"))
dialAddrCalled = true dialAddrCalled = true
return nil, errors.New("test done") return nil, errors.New("test done")
@ -82,7 +87,7 @@ var _ = Describe("Client", func() {
hostname string, hostname string,
tlsConfP *tls.Config, tlsConfP *tls.Config,
quicConfP *quic.Config, quicConfP *quic.Config,
) (quic.Session, error) { ) (quic.EarlySession, error) {
Expect(hostname).To(Equal("localhost:1337")) Expect(hostname).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
@ -101,7 +106,7 @@ var _ = Describe("Client", func() {
tlsConf := &tls.Config{ServerName: "foo.bar"} tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
var dialerCalled bool var dialerCalled bool
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) { dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
Expect(network).To(Equal("udp")) Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337")) Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar")) Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
@ -118,7 +123,7 @@ var _ = Describe("Client", func() {
It("errors when dialing fails", func() { It("errors when dialing fails", func() {
testErr := errors.New("handshake error") testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return nil, testErr return nil, testErr
} }
_, err := client.RoundTrip(req) _, err := client.RoundTrip(req)
@ -130,9 +135,10 @@ var _ = Describe("Client", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session := mockquic.NewMockEarlySession(mockCtrl) session := mockquic.NewMockEarlySession(mockCtrl)
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1) session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1) session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return session, nil return session, nil
} }
defer GinkgoRecover() defer GinkgoRecover()
@ -173,7 +179,7 @@ var _ = Describe("Client", func() {
str = mockquic.NewMockStream(mockCtrl) str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockEarlySession(mockCtrl) sess = mockquic.NewMockEarlySession(mockCtrl)
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1) sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return sess, nil return sess, nil
} }
var err error var err error
@ -186,7 +192,10 @@ var _ = Describe("Client", func() {
rw := newResponseWriter(rspBuf, utils.DefaultLogger) rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418) rw.WriteHeader(418)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
str.EXPECT().Write(gomock.Any()).AnyTimes() str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Close() str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
@ -220,7 +229,10 @@ var _ = Describe("Client", func() {
BeforeEach(func() { BeforeEach(func() {
strBuf = &bytes.Buffer{} strBuf = &bytes.Buffer{}
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{} body := &mockBody{}
body.SetData([]byte("request body")) body.SetData([]byte("request body"))
var err error var err error
@ -301,6 +313,7 @@ var _ = Describe("Client", func() {
It("cancels a request while the request is still in flight", func() { It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx) req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1) str.EXPECT().Close().MaxTimes(1)
@ -333,6 +346,7 @@ var _ = Describe("Client", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx) req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1) str.EXPECT().Close().MaxTimes(1)
@ -354,21 +368,8 @@ var _ = Describe("Client", func() {
}) })
Context("gzip compression", func() { Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() { BeforeEach(func() {
var b bytes.Buffer sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
_ = gzippedData
_ = response
}) })
It("adds the gzip header to requests", func() { It("adds the gzip header to requests", func() {

View file

@ -44,7 +44,7 @@ type RoundTripper struct {
// Dial specifies an optional dial function for creating QUIC // Dial specifies an optional dial function for creating QUIC
// connections for requests. // connections for requests.
// If Dial is nil, quic.DialAddr will be used. // If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are // MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header. // allowed in the server's response header.

View file

@ -58,9 +58,10 @@ func (m *mockBody) Close() error {
var _ = Describe("RoundTripper", func() { var _ = Describe("RoundTripper", func() {
var ( var (
rt *RoundTripper rt *RoundTripper
req1 *http.Request req1 *http.Request
session *mockquic.MockEarlySession session *mockquic.MockEarlySession
handshakeCtx context.Context // an already canceled context
) )
BeforeEach(func() { BeforeEach(func() {
@ -68,6 +69,10 @@ var _ = Describe("RoundTripper", func() {
var err error var err error
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
}) })
Context("dialing hosts", func() { Context("dialing hosts", func() {
@ -76,7 +81,7 @@ var _ = Describe("RoundTripper", func() {
BeforeEach(func() { BeforeEach(func() {
session = mockquic.NewMockEarlySession(mockCtrl) session = mockquic.NewMockEarlySession(mockCtrl)
origDialAddr = dialAddr origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
// return an error when trying to open a stream // return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all // we don't want to test all the dial logic here, just that dialing happens at all
return session, nil return session, nil
@ -93,6 +98,7 @@ var _ = Describe("RoundTripper", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) }) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
_, err = rt.RoundTrip(req) _, err = rt.RoundTrip(req)
@ -104,7 +110,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the quic.Config, if provided", func() { It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeTimeout: time.Millisecond} config := &quic.Config{HandshakeTimeout: time.Millisecond}
var receivedConfig *quic.Config var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
receivedConfig = config receivedConfig = config
return nil, errors.New("handshake error") return nil, errors.New("handshake error")
} }
@ -116,7 +122,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the custom dialer, if provided", func() { It("uses the custom dialer, if provided", func() {
var dialed bool var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) { dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialed = true dialed = true
return nil, errors.New("handshake error") return nil, errors.New("handshake error")
} }
@ -129,6 +135,7 @@ var _ = Describe("RoundTripper", func() {
It("reuses existing clients", func() { It("reuses existing clients", func() {
closed := make(chan struct{}) closed := make(chan struct{})
testErr := errors.New("test err") testErr := errors.New("test err")
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) }) session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })