mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
use 0-RTT to open the H3 client's control stream
This commit is contained in:
parent
63c9272bf4
commit
1372e5dd5e
4 changed files with 44 additions and 34 deletions
|
@ -23,7 +23,7 @@ var defaultQuicConfig = &quic.Config{
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialAddr = quic.DialAddr
|
var dialAddr = quic.DialAddrEarly
|
||||||
|
|
||||||
type roundTripperOpts struct {
|
type roundTripperOpts struct {
|
||||||
DisableCompression bool
|
DisableCompression bool
|
||||||
|
@ -37,7 +37,7 @@ type client struct {
|
||||||
opts *roundTripperOpts
|
opts *roundTripperOpts
|
||||||
|
|
||||||
dialOnce sync.Once
|
dialOnce sync.Once
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||||
handshakeErr error
|
handshakeErr error
|
||||||
|
|
||||||
requestWriter *requestWriter
|
requestWriter *requestWriter
|
||||||
|
@ -45,7 +45,7 @@ type client struct {
|
||||||
decoder *qpack.Decoder
|
decoder *qpack.Decoder
|
||||||
|
|
||||||
hostname string
|
hostname string
|
||||||
session quic.Session
|
session quic.EarlySession
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func newClient(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
opts *roundTripperOpts,
|
opts *roundTripperOpts,
|
||||||
quicConfig *quic.Config,
|
quicConfig *quic.Config,
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
|
||||||
) *client {
|
) *client {
|
||||||
if tlsConf == nil {
|
if tlsConf == nil {
|
||||||
tlsConf = &tls.Config{}
|
tlsConf = &tls.Config{}
|
||||||
|
@ -93,6 +93,7 @@ func (c *client) dial() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// run the sesssion setup using 0-RTT data
|
||||||
go func() {
|
go func() {
|
||||||
if err := c.setupSession(); err != nil {
|
if err := c.setupSession(); err != nil {
|
||||||
c.logger.Debugf("Setting up session failed: %s", err)
|
c.logger.Debugf("Setting up session failed: %s", err)
|
||||||
|
@ -100,6 +101,7 @@ func (c *client) dial() error {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
<-c.session.HandshakeComplete().Done()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ var _ = Describe("Client", func() {
|
||||||
client *client
|
client *client
|
||||||
req *http.Request
|
req *http.Request
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
|
handshakeCtx context.Context // an already canceled context
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -37,6 +38,10 @@ var _ = Describe("Client", func() {
|
||||||
var err error
|
var err error
|
||||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
handshakeCtx = ctx
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
|
@ -46,7 +51,7 @@ var _ = Describe("Client", func() {
|
||||||
It("uses the default QUIC and TLS config if none is give", func() {
|
It("uses the default QUIC and TLS config if none is give", func() {
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) {
|
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(quicConf).To(Equal(defaultQuicConfig))
|
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||||
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
|
@ -59,7 +64,7 @@ var _ = Describe("Client", func() {
|
||||||
It("adds the port to the hostname, if none is given", func() {
|
It("adds the port to the hostname, if none is given", func() {
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
|
@ -82,7 +87,7 @@ var _ = Describe("Client", func() {
|
||||||
hostname string,
|
hostname string,
|
||||||
tlsConfP *tls.Config,
|
tlsConfP *tls.Config,
|
||||||
quicConfP *quic.Config,
|
quicConfP *quic.Config,
|
||||||
) (quic.Session, error) {
|
) (quic.EarlySession, error) {
|
||||||
Expect(hostname).To(Equal("localhost:1337"))
|
Expect(hostname).To(Equal("localhost:1337"))
|
||||||
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
||||||
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||||
|
@ -101,7 +106,7 @@ var _ = Describe("Client", func() {
|
||||||
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||||
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
||||||
var dialerCalled bool
|
var dialerCalled bool
|
||||||
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) {
|
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(network).To(Equal("udp"))
|
Expect(network).To(Equal("udp"))
|
||||||
Expect(address).To(Equal("localhost:1337"))
|
Expect(address).To(Equal("localhost:1337"))
|
||||||
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
||||||
|
@ -118,7 +123,7 @@ var _ = Describe("Client", func() {
|
||||||
It("errors when dialing fails", func() {
|
It("errors when dialing fails", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err := client.RoundTrip(req)
|
_, err := client.RoundTrip(req)
|
||||||
|
@ -130,9 +135,10 @@ var _ = Describe("Client", func() {
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
session := mockquic.NewMockEarlySession(mockCtrl)
|
session := mockquic.NewMockEarlySession(mockCtrl)
|
||||||
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
|
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
|
||||||
|
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
|
||||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
|
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
|
||||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -173,7 +179,7 @@ var _ = Describe("Client", func() {
|
||||||
str = mockquic.NewMockStream(mockCtrl)
|
str = mockquic.NewMockStream(mockCtrl)
|
||||||
sess = mockquic.NewMockEarlySession(mockCtrl)
|
sess = mockquic.NewMockEarlySession(mockCtrl)
|
||||||
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
|
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
|
@ -186,7 +192,10 @@ var _ = Describe("Client", func() {
|
||||||
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
||||||
rw.WriteHeader(418)
|
rw.WriteHeader(418)
|
||||||
|
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
gomock.InOrder(
|
||||||
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||||
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||||
|
)
|
||||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||||
str.EXPECT().Close()
|
str.EXPECT().Close()
|
||||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
@ -220,7 +229,10 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
strBuf = &bytes.Buffer{}
|
strBuf = &bytes.Buffer{}
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
gomock.InOrder(
|
||||||
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||||
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||||
|
)
|
||||||
body := &mockBody{}
|
body := &mockBody{}
|
||||||
body.SetData([]byte("request body"))
|
body.SetData([]byte("request body"))
|
||||||
var err error
|
var err error
|
||||||
|
@ -301,6 +313,7 @@ var _ = Describe("Client", func() {
|
||||||
It("cancels a request while the request is still in flight", func() {
|
It("cancels a request while the request is still in flight", func() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := request.WithContext(ctx)
|
||||||
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
str.EXPECT().Close().MaxTimes(1)
|
str.EXPECT().Close().MaxTimes(1)
|
||||||
|
@ -333,6 +346,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
req := request.WithContext(ctx)
|
req := request.WithContext(ctx)
|
||||||
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
str.EXPECT().Close().MaxTimes(1)
|
str.EXPECT().Close().MaxTimes(1)
|
||||||
|
@ -354,21 +368,8 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("gzip compression", func() {
|
Context("gzip compression", func() {
|
||||||
var gzippedData []byte // a gzipped foobar
|
|
||||||
var response *http.Response
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
var b bytes.Buffer
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
w := gzip.NewWriter(&b)
|
|
||||||
w.Write([]byte("foobar"))
|
|
||||||
w.Close()
|
|
||||||
gzippedData = b.Bytes()
|
|
||||||
response = &http.Response{
|
|
||||||
StatusCode: 200,
|
|
||||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
|
||||||
}
|
|
||||||
_ = gzippedData
|
|
||||||
_ = response
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("adds the gzip header to requests", func() {
|
It("adds the gzip header to requests", func() {
|
||||||
|
|
|
@ -44,7 +44,7 @@ type RoundTripper struct {
|
||||||
// Dial specifies an optional dial function for creating QUIC
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
// connections for requests.
|
// connections for requests.
|
||||||
// If Dial is nil, quic.DialAddr will be used.
|
// If Dial is nil, quic.DialAddr will be used.
|
||||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||||
|
|
||||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||||
// allowed in the server's response header.
|
// allowed in the server's response header.
|
||||||
|
|
|
@ -58,9 +58,10 @@ func (m *mockBody) Close() error {
|
||||||
|
|
||||||
var _ = Describe("RoundTripper", func() {
|
var _ = Describe("RoundTripper", func() {
|
||||||
var (
|
var (
|
||||||
rt *RoundTripper
|
rt *RoundTripper
|
||||||
req1 *http.Request
|
req1 *http.Request
|
||||||
session *mockquic.MockEarlySession
|
session *mockquic.MockEarlySession
|
||||||
|
handshakeCtx context.Context // an already canceled context
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -68,6 +69,10 @@ var _ = Describe("RoundTripper", func() {
|
||||||
var err error
|
var err error
|
||||||
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
handshakeCtx = ctx
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("dialing hosts", func() {
|
Context("dialing hosts", func() {
|
||||||
|
@ -76,7 +81,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
session = mockquic.NewMockEarlySession(mockCtrl)
|
session = mockquic.NewMockEarlySession(mockCtrl)
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
||||||
// return an error when trying to open a stream
|
// return an error when trying to open a stream
|
||||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||||
return session, nil
|
return session, nil
|
||||||
|
@ -93,6 +98,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||||
|
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||||
_, err = rt.RoundTrip(req)
|
_, err = rt.RoundTrip(req)
|
||||||
|
@ -104,7 +110,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
It("uses the quic.Config, if provided", func() {
|
It("uses the quic.Config, if provided", func() {
|
||||||
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
||||||
var receivedConfig *quic.Config
|
var receivedConfig *quic.Config
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
||||||
receivedConfig = config
|
receivedConfig = config
|
||||||
return nil, errors.New("handshake error")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
|
@ -116,7 +122,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
|
|
||||||
It("uses the custom dialer, if provided", func() {
|
It("uses the custom dialer, if provided", func() {
|
||||||
var dialed bool
|
var dialed bool
|
||||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||||
dialed = true
|
dialed = true
|
||||||
return nil, errors.New("handshake error")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
|
@ -129,6 +135,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
It("reuses existing clients", func() {
|
It("reuses existing clients", func() {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
testErr := errors.New("test err")
|
testErr := errors.New("test err")
|
||||||
|
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
||||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue