mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
only allow the HTTP/3 client to dial with a single QUIC version
This commit is contained in:
parent
30c91149c2
commit
898f484d0a
3 changed files with 54 additions and 24 deletions
|
@ -11,6 +11,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/marten-seemann/qpack"
|
"github.com/marten-seemann/qpack"
|
||||||
|
@ -28,6 +29,7 @@ const (
|
||||||
var defaultQuicConfig = &quic.Config{
|
var defaultQuicConfig = &quic.Config{
|
||||||
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
|
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
|
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialAddr = quic.DialAddrEarly
|
var dialAddr = quic.DialAddrEarly
|
||||||
|
@ -63,19 +65,26 @@ func newClient(
|
||||||
opts *roundTripperOpts,
|
opts *roundTripperOpts,
|
||||||
quicConfig *quic.Config,
|
quicConfig *quic.Config,
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
|
||||||
) *client {
|
) (*client, error) {
|
||||||
|
if quicConfig == nil {
|
||||||
|
quicConfig = defaultQuicConfig
|
||||||
|
} else if len(quicConfig.Versions) == 0 {
|
||||||
|
quicConfig = quicConfig.Clone()
|
||||||
|
quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
|
||||||
|
}
|
||||||
|
if len(quicConfig.Versions) != 1 {
|
||||||
|
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||||
|
}
|
||||||
|
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||||
|
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
||||||
|
|
||||||
if tlsConf == nil {
|
if tlsConf == nil {
|
||||||
tlsConf = &tls.Config{}
|
tlsConf = &tls.Config{}
|
||||||
} else {
|
} else {
|
||||||
tlsConf = tlsConf.Clone()
|
tlsConf = tlsConf.Clone()
|
||||||
}
|
}
|
||||||
// Replace existing ALPNs by H3
|
// Replace existing ALPNs by H3
|
||||||
tlsConf.NextProtos = []string{nextProtoH3Draft29}
|
tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])}
|
||||||
if quicConfig == nil {
|
|
||||||
quicConfig = defaultQuicConfig
|
|
||||||
}
|
|
||||||
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
|
||||||
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
|
||||||
|
|
||||||
return &client{
|
return &client{
|
||||||
hostname: authorityAddr("https", hostname),
|
hostname: authorityAddr("https", hostname),
|
||||||
|
@ -86,7 +95,7 @@ func newClient(
|
||||||
opts: opts,
|
opts: opts,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) dial() error {
|
func (c *client) dial() error {
|
||||||
|
|
|
@ -35,10 +35,11 @@ var _ = Describe("Client", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
origDialAddr = dialAddr
|
origDialAddr = dialAddr
|
||||||
hostname := "quic.clemente.io:1337"
|
hostname := "quic.clemente.io:1337"
|
||||||
client = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
|
var err error
|
||||||
|
client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(client.hostname).To(Equal(hostname))
|
Expect(client.hostname).To(Equal(hostname))
|
||||||
|
|
||||||
var err error
|
|
||||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
@ -51,12 +52,22 @@ var _ = Describe("Client", func() {
|
||||||
dialAddr = origDialAddr
|
dialAddr = origDialAddr
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("rejects quic.Configs that allow multiple QUIC versions", func() {
|
||||||
|
qconf := &quic.Config{
|
||||||
|
Versions: []quic.VersionNumber{protocol.VersionDraft29, protocol.VersionDraft32},
|
||||||
|
}
|
||||||
|
_, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil)
|
||||||
|
Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection"))
|
||||||
|
})
|
||||||
|
|
||||||
It("uses the default QUIC and TLS config if none is give", func() {
|
It("uses the default QUIC and TLS config if none is give", func() {
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(quicConf).To(Equal(defaultQuicConfig))
|
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||||
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
|
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
|
||||||
|
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.VersionTLS}))
|
||||||
dialAddrCalled = true
|
dialAddrCalled = true
|
||||||
return nil, errors.New("test done")
|
return nil, errors.New("test done")
|
||||||
}
|
}
|
||||||
|
@ -65,7 +76,8 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("adds the port to the hostname, if none is given", func() {
|
It("adds the port to the hostname, if none is given", func() {
|
||||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
Expect(hostname).To(Equal("quic.clemente.io:443"))
|
||||||
|
@ -84,7 +96,8 @@ var _ = Describe("Client", func() {
|
||||||
NextProtos: []string{"proto foo", "proto bar"},
|
NextProtos: []string{"proto foo", "proto bar"},
|
||||||
}
|
}
|
||||||
quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
|
quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond}
|
||||||
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var dialAddrCalled bool
|
var dialAddrCalled bool
|
||||||
dialAddr = func(
|
dialAddr = func(
|
||||||
hostname string,
|
hostname string,
|
||||||
|
@ -117,25 +130,28 @@ var _ = Describe("Client", func() {
|
||||||
dialerCalled = true
|
dialerCalled = true
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
|
||||||
_, err := client.RoundTrip(req)
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = client.RoundTrip(req)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
Expect(dialerCalled).To(BeTrue())
|
Expect(dialerCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when dialing fails", func() {
|
It("errors when dialing fails", func() {
|
||||||
testErr := errors.New("handshake error")
|
testErr := errors.New("handshake error")
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
|
||||||
return nil, testErr
|
return nil, testErr
|
||||||
}
|
}
|
||||||
_, err := client.RoundTrip(req)
|
_, err = client.RoundTrip(req)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if it can't open a stream", func() {
|
It("errors if it can't open a stream", func() {
|
||||||
testErr := errors.New("stream open error")
|
testErr := errors.New("stream open error")
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
session := mockquic.NewMockEarlySession(mockCtrl)
|
session := mockquic.NewMockEarlySession(mockCtrl)
|
||||||
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
|
session.EXPECT().OpenUniStream().Return(nil, testErr).MaxTimes(1)
|
||||||
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
|
session.EXPECT().HandshakeComplete().Return(handshakeCtx).MaxTimes(1)
|
||||||
|
@ -145,14 +161,14 @@ var _ = Describe("Client", func() {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
_, err := client.RoundTrip(req)
|
_, err = client.RoundTrip(req)
|
||||||
Expect(err).To(MatchError(testErr))
|
Expect(err).To(MatchError(testErr))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes correctly if session was not created", func() {
|
It("closes correctly if session was not created", func() {
|
||||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||||
err := client.Close()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(client.Close()).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Doing requests", func() {
|
Context("Doing requests", func() {
|
||||||
|
@ -437,7 +453,8 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("doesn't add gzip if the header disable it", func() {
|
It("doesn't add gzip if the header disable it", func() {
|
||||||
client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
|
client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
|
||||||
|
@ -448,7 +465,7 @@ var _ = Describe("Client", func() {
|
||||||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
||||||
)
|
)
|
||||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||||
_, err := client.RoundTrip(request)
|
_, err = client.RoundTrip(request)
|
||||||
Expect(err).To(MatchError("test done"))
|
Expect(err).To(MatchError("test done"))
|
||||||
hfs := decodeHeader(buf)
|
hfs := decodeHeader(buf)
|
||||||
Expect(hfs).ToNot(HaveKey("accept-encoding"))
|
Expect(hfs).ToNot(HaveKey("accept-encoding"))
|
||||||
|
|
|
@ -130,7 +130,8 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
||||||
if onlyCached {
|
if onlyCached {
|
||||||
return nil, ErrNoCachedConn
|
return nil, ErrNoCachedConn
|
||||||
}
|
}
|
||||||
client = newClient(
|
var err error
|
||||||
|
client, err = newClient(
|
||||||
hostname,
|
hostname,
|
||||||
r.TLSClientConfig,
|
r.TLSClientConfig,
|
||||||
&roundTripperOpts{
|
&roundTripperOpts{
|
||||||
|
@ -140,6 +141,9 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
|
||||||
r.QuicConfig,
|
r.QuicConfig,
|
||||||
r.Dial,
|
r.Dial,
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
r.clients[hostname] = client
|
r.clients[hostname] = client
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue