mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
use 0-RTT to open the H3 client's control stream
This commit is contained in:
parent
63c9272bf4
commit
1372e5dd5e
4 changed files with 44 additions and 34 deletions
|
@ -23,7 +23,7 @@ var defaultQuicConfig = &quic.Config{
|
|||
KeepAlive: true,
|
||||
}
|
||||
|
||||
var dialAddr = quic.DialAddr
|
||||
var dialAddr = quic.DialAddrEarly
|
||||
|
||||
type roundTripperOpts struct {
|
||||
DisableCompression bool
|
||||
|
@ -37,7 +37,7 @@ type client struct {
|
|||
opts *roundTripperOpts
|
||||
|
||||
dialOnce sync.Once
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||
handshakeErr error
|
||||
|
||||
requestWriter *requestWriter
|
||||
|
@ -45,7 +45,7 @@ type client struct {
|
|||
decoder *qpack.Decoder
|
||||
|
||||
hostname string
|
||||
session quic.Session
|
||||
session quic.EarlySession
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func newClient(
|
|||
tlsConf *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
|
||||
) *client {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
|
@ -93,6 +93,7 @@ func (c *client) dial() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// run the sesssion setup using 0-RTT data
|
||||
go func() {
|
||||
if err := c.setupSession(); err != nil {
|
||||
c.logger.Debugf("Setting up session failed: %s", err)
|
||||
|
@ -100,6 +101,7 @@ func (c *client) dial() error {
|
|||
}
|
||||
}()
|
||||
|
||||
<-c.session.HandshakeComplete().Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ var _ = Describe("Client", func() {
|
|||
client *client
|
||||
req *http.Request
|
||||
origDialAddr = dialAddr
|
||||
handshakeCtx context.Context // an already canceled context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -37,6 +38,10 @@ var _ = Describe("Client", func() {
|
|||
var err error
|
||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
handshakeCtx = ctx
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -46,7 +51,7 @@ var _ = Describe("Client", func() {
|
|||
It("uses the default QUIC and TLS config if none is give", func() {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
var dialAddrCalled bool
|
||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||
dialAddrCalled = true
|
||||
|
@ -59,7 +64,7 @@ var _ = Describe("Client", func() {
|
|||
It("adds the port to the hostname, if none is given", func() {
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||
var dialAddrCalled bool
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||
dialAddrCalled = true
|
||||
return nil, errors.New("test done")
|
||||
|
@ -82,7 +87,7 @@ var _ = Describe("Client", func() {
|
|||
hostname string,
|
||||
tlsConfP *tls.Config,
|
||||
quicConfP *quic.Config,
|
||||
) (quic.Session, error) {
|
||||
) (quic.EarlySession, error) {
|
||||
Expect(hostname).To(Equal("localhost:1337"))
|
||||
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
||||
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||
|
@ -101,7 +106,7 @@ var _ = Describe("Client", func() {
|
|||
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
||||
var dialerCalled bool
|
||||
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.Session, error) {
|
||||
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
|
||||
Expect(network).To(Equal("udp"))
|
||||
Expect(address).To(Equal("localhost:1337"))
|
||||
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
||||
|
@ -118,7 +123,7 @@ var _ = Describe("Client", func() {
|
|||
It("errors when dialing fails", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
|
@ -130,9 +135,10 @@ var _ = Describe("Client", func() {
|
|||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
session := mockquic.NewMockEarlySession(mockCtrl)
|
||||
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
|
||||
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
|
||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).MaxTimes(1)
|
||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||
return session, nil
|
||||
}
|
||||
defer GinkgoRecover()
|
||||
|
@ -173,7 +179,7 @@ var _ = Describe("Client", func() {
|
|||
str = mockquic.NewMockStream(mockCtrl)
|
||||
sess = mockquic.NewMockEarlySession(mockCtrl)
|
||||
sess.EXPECT().OpenUniStream().Return(controlStr, nil).MaxTimes(1)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||
return sess, nil
|
||||
}
|
||||
var err error
|
||||
|
@ -186,7 +192,10 @@ var _ = Describe("Client", func() {
|
|||
rw := newResponseWriter(rspBuf, utils.DefaultLogger)
|
||||
rw.WriteHeader(418)
|
||||
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
gomock.InOrder(
|
||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||
)
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||
|
@ -220,7 +229,10 @@ var _ = Describe("Client", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
strBuf = &bytes.Buffer{}
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||
gomock.InOrder(
|
||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx),
|
||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
|
||||
)
|
||||
body := &mockBody{}
|
||||
body.SetData([]byte("request body"))
|
||||
var err error
|
||||
|
@ -301,6 +313,7 @@ var _ = Describe("Client", func() {
|
|||
It("cancels a request while the request is still in flight", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := request.WithContext(ctx)
|
||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
str.EXPECT().Close().MaxTimes(1)
|
||||
|
@ -333,6 +346,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := request.WithContext(ctx)
|
||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
sess.EXPECT().OpenStreamSync(ctx).Return(str, nil)
|
||||
buf := &bytes.Buffer{}
|
||||
str.EXPECT().Close().MaxTimes(1)
|
||||
|
@ -354,21 +368,8 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
Context("gzip compression", func() {
|
||||
var gzippedData []byte // a gzipped foobar
|
||||
var response *http.Response
|
||||
|
||||
BeforeEach(func() {
|
||||
var b bytes.Buffer
|
||||
w := gzip.NewWriter(&b)
|
||||
w.Write([]byte("foobar"))
|
||||
w.Close()
|
||||
gzippedData = b.Bytes()
|
||||
response = &http.Response{
|
||||
StatusCode: 200,
|
||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
||||
}
|
||||
_ = gzippedData
|
||||
_ = response
|
||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
})
|
||||
|
||||
It("adds the gzip header to requests", func() {
|
||||
|
|
|
@ -44,7 +44,7 @@ type RoundTripper struct {
|
|||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddr will be used.
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// allowed in the server's response header.
|
||||
|
|
|
@ -61,6 +61,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
rt *RoundTripper
|
||||
req1 *http.Request
|
||||
session *mockquic.MockEarlySession
|
||||
handshakeCtx context.Context // an already canceled context
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -68,6 +69,10 @@ var _ = Describe("RoundTripper", func() {
|
|||
var err error
|
||||
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
handshakeCtx = ctx
|
||||
})
|
||||
|
||||
Context("dialing hosts", func() {
|
||||
|
@ -76,7 +81,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
BeforeEach(func() {
|
||||
session = mockquic.NewMockEarlySession(mockCtrl)
|
||||
origDialAddr = dialAddr
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
||||
// return an error when trying to open a stream
|
||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||
return session, nil
|
||||
|
@ -93,6 +98,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||
_, err = rt.RoundTrip(req)
|
||||
|
@ -104,7 +110,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
It("uses the quic.Config, if provided", func() {
|
||||
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
||||
var receivedConfig *quic.Config
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
||||
receivedConfig = config
|
||||
return nil, errors.New("handshake error")
|
||||
}
|
||||
|
@ -116,7 +122,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
|
||||
It("uses the custom dialer, if provided", func() {
|
||||
var dialed bool
|
||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
dialed = true
|
||||
return nil, errors.New("handshake error")
|
||||
}
|
||||
|
@ -129,6 +135,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
It("reuses existing clients", func() {
|
||||
closed := make(chan struct{})
|
||||
testErr := errors.New("test err")
|
||||
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
||||
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue