use 0-RTT to open the H3 client's control stream

This commit is contained in:
Marten Seemann 2020-01-25 21:26:27 +07:00
parent 63c9272bf4
commit 1372e5dd5e
4 changed files with 44 additions and 34 deletions

View file

@ -23,7 +23,7 @@ var defaultQuicConfig = &quic.Config{
KeepAlive: true,
}
var dialAddr = quic.DialAddr
var dialAddr = quic.DialAddrEarly
type roundTripperOpts struct {
DisableCompression bool
@ -37,7 +37,7 @@ type client struct {
opts *roundTripperOpts
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
handshakeErr error
requestWriter *requestWriter
@ -45,7 +45,7 @@ type client struct {
decoder *qpack.Decoder
hostname string
session quic.Session
session quic.EarlySession
logger utils.Logger
}
@ -55,7 +55,7 @@ func newClient(
tlsConf *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
) *client {
if tlsConf == nil {
tlsConf = &tls.Config{}
@ -93,6 +93,7 @@ func (c *client) dial() error {
return err
}
// run the sesssion setup using 0-RTT data
go func() {
if err := c.setupSession(); err != nil {
c.logger.Debugf("Setting up session failed: %s", err)
@ -100,6 +101,7 @@ func (c *client) dial() error {
}
}()
<-c.session.HandshakeComplete().Done()
return nil
}

View file

@ -26,6 +26,7 @@ var _ = Describe("Client", func() {
client *client
req *http.Request
origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
)
BeforeEach(func() {
@ -37,6 +38,10 @@ var _ = Describe("Client", func() {
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
})
AfterEach(func() {
@ -46,7 +51,7 @@ var _ = Describe("Client", func() {
It("uses the default QUIC and TLS config if none is give", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) {
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
dialAddrCalled = true
@ -59,7 +64,7 @@ var _ = Describe("Client", func() {
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
Expect(hostname).To(Equal("quic.clemente.io:443"))
dialAddrCalled = true
return nil, errors.New("test done")
@ -82,7 +87,7 @@ var _ = Describe("Client", func() {
hostname string,
tlsConfP *tls.Config,
quicConfP *quic.Config,
) (quic.Session, error) {
) (quic.EarlySession, error) {
Expect(hostname).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
@ -101,7 +106,7 @@ var _ = Describe("Client", func() {
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
var dialerCalled bool
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) {
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
@ -118,7 +123,7 @@ var _ = Describe("Client", func() {
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
@ -130,9 +135,10 @@ var _ = Describe("Client", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session := mockquic.NewMockEarlySession(mockCtrl)
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return session, nil
}
defer GinkgoRecover()
@ -173,7 +179,7 @@ var _ = Describe("Client", func() {
str = mockquic.NewMockStream(mockCtrl)
sess = mockquic.NewMockEarlySession(mockCtrl)
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
return sess, nil
}
var err error
@ -186,7 +192,10 @@ var _ = Describe("Client", func() {
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
rw.WriteHeader(418)
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
@ -220,7 +229,10 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
strBuf = &bytes.Buffer{}
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
gomock.InOrder(
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{}
body.SetData([]byte("request body"))
var err error
@ -301,6 +313,7 @@ var _ = Describe("Client", func() {
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -333,6 +346,7 @@ var _ = Describe("Client", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -354,21 +368,8 @@ var _ = Describe("Client", func() {
})
Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
_ = gzippedData
_ = response
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
})
It("adds the gzip header to requests", func() {

View file

@ -44,7 +44,7 @@ type RoundTripper struct {
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.

View file

@ -58,9 +58,10 @@ func (m *mockBody) Close() error {
var _ = Describe("RoundTripper", func() {
var (
rt *RoundTripper
req1 *http.Request
session *mockquic.MockEarlySession
rt *RoundTripper
req1 *http.Request
session *mockquic.MockEarlySession
handshakeCtx context.Context // an already canceled context
)
BeforeEach(func() {
@ -68,6 +69,10 @@ var _ = Describe("RoundTripper", func() {
var err error
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
})
Context("dialing hosts", func() {
@ -76,7 +81,7 @@ var _ = Describe("RoundTripper", func() {
BeforeEach(func() {
session = mockquic.NewMockEarlySession(mockCtrl)
origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return session, nil
@ -93,6 +98,7 @@ var _ = Describe("RoundTripper", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
_, err = rt.RoundTrip(req)
@ -104,7 +110,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
receivedConfig = config
return nil, errors.New("handshake error")
}
@ -116,7 +122,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialed = true
return nil, errors.New("handshake error")
}
@ -129,6 +135,7 @@ var _ = Describe("RoundTripper", func() {
It("reuses existing clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })