attach the QUIC version to context returned by ClientHelloInfo.Context (#3721)

This commit is contained in:
Marten Seemann 2023-03-27 00:26:14 +11:00 committed by GitHub
parent 11f493381f
commit 41ddaa0262
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 106 additions and 176 deletions

View file

@ -14,7 +14,6 @@ import (
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
@ -66,8 +65,9 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
// determine the ALPN from the QUIC version used
proto := NextProtoH3
if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok {
proto = versionToALPN(qconn.GetQUICVersion())
val := ch.Context().Value(quic.QUICVersionContextKey)
if v, ok := val.(quic.VersionNumber); ok {
proto = versionToALPN(v)
}
config := tlsConf
if tlsConf.GetConfigForClient != nil {

View file

@ -28,19 +28,6 @@ import (
gmtypes "github.com/onsi/gomega/types"
)
type mockConn struct {
net.Conn
version protocol.VersionNumber
}
func newMockConn(version protocol.VersionNumber) net.Conn {
return &mockConn{version: version}
}
func (c *mockConn) GetQUICVersion() protocol.VersionNumber {
return c.version
}
type mockAddr struct {
addr string
}
@ -940,31 +927,87 @@ var _ = Describe("Server", func() {
})
Context("ConfigureTLSConfig", func() {
var tlsConf *tls.Config
var ch *tls.ClientHelloInfo
BeforeEach(func() {
tlsConf = &tls.Config{}
ch = &tls.ClientHelloInfo{}
})
It("advertises v1 by default", func() {
tlsConf = ConfigureTLSConfig(tlsConf)
Expect(tlsConf.GetConfigForClient).NotTo(BeNil())
config, err := tlsConf.GetConfigForClient(ch)
Expect(err).NotTo(HaveOccurred())
Expect(config.NextProtos).To(Equal([]string{NextProtoH3}))
conf := ConfigureTLSConfig(testdata.GetTLSConfig())
ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
})
It("advertises h3-29 for draft-29", func() {
tlsConf = ConfigureTLSConfig(tlsConf)
Expect(tlsConf.GetConfigForClient).NotTo(BeNil())
conf := ConfigureTLSConfig(testdata.GetTLSConfig())
ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.VersionDraft29}})
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3Draft29}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3Draft29))
})
ch.Conn = newMockConn(protocol.VersionDraft29)
config, err := tlsConf.GetConfigForClient(ch)
Expect(err).NotTo(HaveOccurred())
Expect(config.NextProtos).To(Equal([]string{NextProtoH3Draft29}))
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf).ToNot(BeNil())
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
c := testdata.GetTLSConfig()
c.NextProtos = []string{"foo", "bar"}
return c, nil
},
}
ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
})
It("works if GetConfigForClient returns a nil tls.Config", func() {
tlsConf := testdata.GetTLSConfig()
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }
ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
tlsClientConf := testdata.GetTLSConfig()
tlsClientConf.NextProtos = []string{"foo", "bar"}
tlsConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
return tlsClientConf, nil
},
}
ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3))
// check that the original config was not modified
Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"}))
})
})
@ -1179,15 +1222,6 @@ var _ = Describe("Server", func() {
Expect(s.Close()).To(Succeed())
})
checkGetConfigForClientVersions := func(conf *tls.Config) {
c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)})
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{NextProtoH3Draft29}))
c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.Version1)})
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{NextProtoH3}))
}
It("uses the quic.Config to start the QUIC server", func() {
conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
var receivedConf *quic.Config
@ -1199,106 +1233,6 @@ var _ = Describe("Server", func() {
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf).To(Equal(conf))
})
It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() {
tlsConf := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"foo", "bar"},
}
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(BeEmpty())
Expect(receivedConf.ClientAuth).To(BeZero())
// make sure the original tls.Config was not modified
Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"}))
// make sure that the config returned from the GetConfigForClient callback sets the fields of the original config
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
checkGetConfigForClientVersions(receivedConf)
})
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf).ToNot(BeNil())
checkGetConfigForClientVersions(receivedConf)
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"foo", "bar"},
}, nil
},
}
var receivedConf *tls.Config
quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = conf
return nil, errors.New("listen err")
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
// check that the original config was not modified
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
// check that the config returned by the GetConfigForClient callback uses the returned config
conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
checkGetConfigForClientVersions(receivedConf)
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
tlsClientConf := &tls.Config{NextProtos: []string{"foo", "bar"}}
tlsConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
return tlsClientConf, nil
},
}
var receivedConf *tls.Config
quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = conf
return nil, errors.New("listen err")
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
// check that the original config was not modified
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
checkGetConfigForClientVersions(receivedConf)
})
It("works if GetConfigForClient returns a nil tls.Config", func() {
tlsConf := &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }}
var receivedConf *tls.Config
quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = conf
return nil, errors.New("listen err")
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf).ToNot(BeNil())
checkGetConfigForClientVersions(receivedConf)
})
})
It("closes gracefully", func() {

View file

@ -57,6 +57,10 @@ var ConnectionTracingKey = connTracingCtxKey{}
type connTracingCtxKey struct{}
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
// context returned by tls.Config.ClientHelloInfo.Context.
var QUICVersionContextKey = handshake.QUICVersionContextKey
// Stream is the interface implemented by QUIC streams
// In addition to the errors listed on the Connection,
// calls to stream functions can return a StreamError if the stream is canceled.

View file

@ -2,6 +2,7 @@ package handshake
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@ -20,6 +21,10 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
type quicVersionContextKey struct{}
var QUICVersionContextKey = &quicVersionContextKey{}
// TLS unexpected_message alert
const alertUnexpectedMessage uint8 = 10
@ -64,30 +69,25 @@ const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
version protocol.VersionNumber
}
var _ ConnWithVersion = &conn{}
func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
return &conn{
localAddr: local,
remoteAddr: remote,
version: version,
}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
func newConn(local, remote net.Addr) net.Conn {
return &conn{
localAddr: local,
remoteAddr: remote,
}
}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
type cryptoSetup struct {
tlsConf *tls.Config
@ -183,7 +183,7 @@ func NewCryptoSetupClient(
protocol.PerspectiveClient,
version,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs, clientHelloWritten
}
@ -218,7 +218,7 @@ func NewCryptoSetupServer(
version,
)
cs.allow0RTT = allow0RTT
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs
}
@ -307,7 +307,7 @@ func (h *cryptoSetup) RunHandshake() {
handshakeErrChan := make(chan error, 1)
go func() {
defer close(h.handshakeDone)
if err := h.conn.Handshake(); err != nil {
if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil {
handshakeErrChan <- err
return
}

View file

@ -3,7 +3,6 @@ package handshake
import (
"errors"
"io"
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol"
@ -93,10 +92,3 @@ type CryptoSetup interface {
Get0RTTSealer() (LongHeaderSealer, error)
Get1RTTSealer() (ShortHeaderSealer, error)
}
// ConnWithVersion is the connection used in the ClientHelloInfo.
// It can be used to determine the QUIC version in use.
type ConnWithVersion interface {
net.Conn
GetQUICVersion() protocol.VersionNumber
}