respect the request context when dialing

This commit is contained in:
Marten Seemann 2022-03-25 09:23:48 +01:00
parent d4293fb274
commit 137491916b
4 changed files with 30 additions and 30 deletions

View file

@ -34,6 +34,8 @@ var defaultQuicConfig = &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
var dialAddr = quic.DialAddrEarly
type roundTripperOpts struct {
@ -49,7 +51,7 @@ type client struct {
opts *roundTripperOpts
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
dialer dialFunc
handshakeErr error
requestWriter *requestWriter
@ -62,24 +64,18 @@ type client struct {
logger utils.Logger
}
func newClient(
hostname string,
tlsConf *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
) (*client, error) {
if quicConfig == nil {
quicConfig = defaultQuicConfig.Clone()
} else if len(quicConfig.Versions) == 0 {
quicConfig = quicConfig.Clone()
quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) {
if conf == nil {
conf = defaultQuicConfig.Clone()
} else if len(conf.Versions) == 0 {
conf = conf.Clone()
conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
}
if len(quicConfig.Versions) != 1 {
if len(conf.Versions) != 1 {
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
quicConfig.EnableDatagrams = opts.EnableDatagram
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
conf.EnableDatagrams = opts.EnableDatagram
logger := utils.DefaultLogger.WithPrefix("h3 client")
if tlsConf == nil {
@ -88,24 +84,24 @@ func newClient(
tlsConf = tlsConf.Clone()
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])}
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
return &client{
hostname: authorityAddr("https", hostname),
tlsConf: tlsConf,
requestWriter: newRequestWriter(logger),
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
config: quicConfig,
config: conf,
opts: opts,
dialer: dialer,
logger: logger,
}, nil
}
func (c *client) dial() error {
func (c *client) dial(ctx context.Context) error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
@ -212,7 +208,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
c.handshakeErr = c.dial(req.Context())
})
if c.handshakeErr != nil {

View file

@ -12,13 +12,13 @@ import (
"net/http"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/quicvarint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"
"github.com/golang/mock/gomock"
"github.com/marten-seemann/qpack"
. "github.com/onsi/ginkgo"
@ -122,8 +122,11 @@ var _ = Describe("Client", func() {
testErr := errors.New("test done")
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
var dialerCalled bool
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
Expect(ctxP).To(Equal(ctx))
Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
@ -133,7 +136,7 @@ var _ = Describe("Client", func() {
}
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
_, err = client.RoundTrip(req.WithContext(ctx))
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})

View file

@ -1,6 +1,7 @@
package http3
import (
"context"
"crypto/tls"
"errors"
"fmt"
@ -9,7 +10,7 @@ import (
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go"
"golang.org/x/net/http/httpguts"
)
@ -48,8 +49,8 @@ type RoundTripper struct {
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarly will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
// If Dial is nil, quic.DialAddrEarlyContext will be used.
Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.

View file

@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialed = true
return nil, errors.New("handshake error")
}