mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
Merge pull request #3359 from lucas-clemente/http3-dial-context
respect the request context when dialing
This commit is contained in:
commit
42f3159497
4 changed files with 42 additions and 46 deletions
|
@ -34,7 +34,9 @@ var defaultQuicConfig = &quic.Config{
|
||||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialAddr = quic.DialAddrEarly
|
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||||
|
|
||||||
|
var dialAddr = quic.DialAddrEarlyContext
|
||||||
|
|
||||||
type roundTripperOpts struct {
|
type roundTripperOpts struct {
|
||||||
DisableCompression bool
|
DisableCompression bool
|
||||||
|
@ -49,7 +51,7 @@ type client struct {
|
||||||
opts *roundTripperOpts
|
opts *roundTripperOpts
|
||||||
|
|
||||||
dialOnce sync.Once
|
dialOnce sync.Once
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
dialer dialFunc
|
||||||
handshakeErr error
|
handshakeErr error
|
||||||
|
|
||||||
requestWriter *requestWriter
|
requestWriter *requestWriter
|
||||||
|
@ -62,24 +64,18 @@ type client struct {
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(
|
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) {
|
||||||
hostname string,
|
if conf == nil {
|
||||||
tlsConf *tls.Config,
|
conf = defaultQuicConfig.Clone()
|
||||||
opts *roundTripperOpts,
|
} else if len(conf.Versions) == 0 {
|
||||||
quicConfig *quic.Config,
|
conf = conf.Clone()
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
|
conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
|
||||||
) (*client, error) {
|
|
||||||
if quicConfig == nil {
|
|
||||||
quicConfig = defaultQuicConfig.Clone()
|
|
||||||
} else if len(quicConfig.Versions) == 0 {
|
|
||||||
quicConfig = quicConfig.Clone()
|
|
||||||
quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
|
|
||||||
}
|
}
|
||||||
if len(quicConfig.Versions) != 1 {
|
if len(conf.Versions) != 1 {
|
||||||
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||||
}
|
}
|
||||||
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||||
quicConfig.EnableDatagrams = opts.EnableDatagram
|
conf.EnableDatagrams = opts.EnableDatagram
|
||||||
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
||||||
|
|
||||||
if tlsConf == nil {
|
if tlsConf == nil {
|
||||||
|
@ -88,26 +84,26 @@ func newClient(
|
||||||
tlsConf = tlsConf.Clone()
|
tlsConf = tlsConf.Clone()
|
||||||
}
|
}
|
||||||
// Replace existing ALPNs by H3
|
// Replace existing ALPNs by H3
|
||||||
tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])}
|
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
|
||||||
|
|
||||||
return &client{
|
return &client{
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
requestWriter: newRequestWriter(logger),
|
requestWriter: newRequestWriter(logger),
|
||||||
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
||||||
config: quicConfig,
|
config: conf,
|
||||||
opts: opts,
|
opts: opts,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) dial() error {
|
func (c *client) dial(ctx context.Context) error {
|
||||||
var err error
|
var err error
|
||||||
if c.dialer != nil {
|
if c.dialer != nil {
|
||||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config)
|
||||||
} else {
|
} else {
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
c.session, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -212,7 +208,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.dialOnce.Do(func() {
|
c.dialOnce.Do(func() {
|
||||||
c.handshakeErr = c.dial()
|
c.handshakeErr = c.dial(req.Context())
|
||||||
})
|
})
|
||||||
|
|
||||||
if c.handshakeErr != nil {
|
if c.handshakeErr != nil {
|
||||||
|
|
|
@ -12,13 +12,13 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
|
||||||
"github.com/lucas-clemente/quic-go/quicvarint"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/quicvarint"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/marten-seemann/qpack"
|
"github.com/marten-seemann/qpack"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
|
@ -65,7 +65,7 @@ var _ = Describe("Client", func() {
|
||||||
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(quicConf).To(Equal(defaultQuicConfig))
|
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||||
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||||
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
|
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
|
||||||
|
@ -80,7 +80,7 @@ var _ = Describe("Client", func() {
|
||||||
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
|
@ -100,12 +100,8 @@ var _ = Describe("Client", func() {
|
||||||
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(
|
dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
|
||||||
hostname string,
|
Expect(host).To(Equal("localhost:1337"))
|
||||||
tlsConfP *tls.Config,
|
|
||||||
quicConfP *quic.Config,
|
|
||||||
) (quic.EarlySession, error) {
|
|
||||||
Expect(hostname).To(Equal("localhost:1337"))
|
|
||||||
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
||||||
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||||
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
|
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
|
||||||
|
@ -122,8 +118,11 @@ var _ = Describe("Client", func() {
|
||||||
testErr := errors.New("test done")
|
testErr := errors.New("test done")
|
||||||
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||||
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
|
||||||
|
defer cancel()
|
||||||
var dialerCalled bool
|
var dialerCalled bool
|
||||||
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
|
dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
|
||||||
|
Expect(ctxP).To(Equal(ctx))
|
||||||
Expect(network).To(Equal("udp"))
|
Expect(network).To(Equal("udp"))
|
||||||
Expect(address).To(Equal("localhost:1337"))
|
Expect(address).To(Equal("localhost:1337"))
|
||||||
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
|
||||||
|
@ -133,7 +132,7 @@ var _ = Describe("Client", func() {
|
||||||
}
|
}
|
||||||
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTrip(req.WithContext(ctx))
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(dialerCalled).To(BeTrue())
|
Expect(dialerCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
@ -142,7 +141,7 @@ var _ = Describe("Client", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
Expect(quicConf.EnableDatagrams).To(BeTrue())
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
|
@ -154,7 +153,7 @@ var _ = Describe("Client", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTrip(req)
|
||||||
|
@ -179,7 +178,7 @@ var _ = Describe("Client", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
|
req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err = client.RoundTrip(req)
|
_, err = client.RoundTrip(req)
|
||||||
|
@ -206,7 +205,7 @@ var _ = Describe("Client", func() {
|
||||||
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
|
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
|
||||||
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||||
sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -453,7 +452,7 @@ var _ = Describe("Client", func() {
|
||||||
<-testDone
|
<-testDone
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
})
|
})
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
|
||||||
var err error
|
var err error
|
||||||
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package http3
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -48,8 +49,8 @@ type RoundTripper struct {
|
||||||
|
|
||||||
// Dial specifies an optional dial function for creating QUIC
|
// Dial specifies an optional dial function for creating QUIC
|
||||||
// connections for requests.
|
// connections for requests.
|
||||||
// If Dial is nil, quic.DialAddrEarly will be used.
|
// If Dial is nil, quic.DialAddrEarlyContext will be used.
|
||||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||||
|
|
||||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||||
// allowed in the server's response header.
|
// allowed in the server's response header.
|
||||||
|
|
|
@ -82,7 +82,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
session = mockquic.NewMockEarlySession(mockCtrl)
|
session = mockquic.NewMockEarlySession(mockCtrl)
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
|
||||||
// return an error when trying to open a stream
|
// return an error when trying to open a stream
|
||||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||||
return session, nil
|
return session, nil
|
||||||
|
@ -115,7 +115,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
It("uses the quic.Config, if provided", func() {
|
It("uses the quic.Config, if provided", func() {
|
||||||
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
|
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
|
||||||
var receivedConfig *quic.Config
|
var receivedConfig *quic.Config
|
||||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlySession, error) {
|
||||||
receivedConfig = config
|
receivedConfig = config
|
||||||
return nil, errors.New("handshake error")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() {
|
||||||
|
|
||||||
It("uses the custom dialer, if provided", func() {
|
It("uses the custom dialer, if provided", func() {
|
||||||
var dialed bool
|
var dialed bool
|
||||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||||
dialed = true
|
dialed = true
|
||||||
return nil, errors.New("handshake error")
|
return nil, errors.New("handshake error")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue