select the H3 ALPN based on the QUIC version in use (for the H3 server)

This commit is contained in:
Marten Seemann 2020-10-29 13:08:16 +07:00
parent b7652887d2
commit c968b18a21
10 changed files with 134 additions and 62 deletions

View file

@ -90,6 +90,7 @@ func main() {
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
@ -108,6 +109,7 @@ func main() {
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
serverHandshakeCompleted := make(chan struct{})

View file

@ -387,6 +387,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
@ -403,6 +404,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
if len(data) == 0 {

View file

@ -69,7 +69,7 @@ func newClient(
tlsConf = tlsConf.Clone()
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{nextProtoH3}
tlsConf.NextProtos = []string{nextProtoH3Draft29}
if quicConfig == nil {
quicConfig = defaultQuicConfig
}

View file

@ -53,7 +53,7 @@ var _ = Describe("Client", func() {
var dialAddrCalled bool
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
dialAddrCalled = true
return nil, errors.New("test done")
}
@ -90,7 +90,7 @@ var _ = Describe("Client", func() {
) (quic.EarlySession, error) {
Expect(hostname).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
dialAddrCalled = true
return nil, errors.New("test done")

View file

@ -10,11 +10,13 @@ import (
"net"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
)
@ -25,7 +27,12 @@ var (
quicListenAddr = quic.ListenAddrEarly
)
const nextProtoH3 = "h3-29"
const (
nextProtoH3Draft29 = "h3-29"
nextProtoH3Draft32 = "h3-32"
)
var supportedVersions = []string{nextProtoH3Draft29, nextProtoH3Draft32}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
@ -115,32 +122,36 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error {
s.logger = utils.DefaultLogger.WithPrefix("server")
})
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{nextProtoH3}
if tlsConf.GetConfigForClient != nil {
getConfigForClient := tlsConf.GetConfigForClient
tlsConf.GetConfigForClient = func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
conf, err := getConfigForClient(ch)
if err != nil || conf == nil {
return conf, err
// The tls.Config we pass to Listen needs to have the GetConfigForClient callback set.
// That way, we can get the QUIC version and set the correct ALPN value.
baseConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
// determine the ALPN from the QUIC version used
proto := nextProtoH3Draft29
if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok && qconn.GetQUICVersion() == quic.VersionDraft32 {
proto = nextProtoH3Draft32
}
conf := tlsConf
if tlsConf.GetConfigForClient != nil {
getConfigForClient := tlsConf.GetConfigForClient
var err error
conf, err = getConfigForClient(ch)
if err != nil {
return nil, err
}
}
conf = conf.Clone()
conf.NextProtos = []string{nextProtoH3}
conf.NextProtos = []string{proto}
return conf, nil
}
},
}
var ln quic.EarlyListener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, tlsConf, s.QuicConfig)
ln, err = quicListenAddr(s.Addr, baseConf, s.QuicConfig)
} else {
ln, err = quicListen(conn, tlsConf, s.QuicConfig)
ln, err = quicListen(conn, baseConf, s.QuicConfig)
}
if err != nil {
return err
@ -344,8 +355,11 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
atomic.StoreUint32(&s.port, port)
}
hdr.Add("Alt-Svc", fmt.Sprintf(`%s=":%d"; ma=2592000`, nextProtoH3, port))
altSvc := make([]string, len(supportedVersions))
for i, v := range supportedVersions {
altSvc[i] = fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port)
}
hdr.Add("Alt-Svc", strings.Join(altSvc, ","))
return nil
}

View file

@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
@ -14,6 +13,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
@ -22,6 +22,19 @@ import (
. "github.com/onsi/gomega"
)
type mockConn struct {
net.Conn
version protocol.VersionNumber
}
func newMockConn(version protocol.VersionNumber) net.Conn {
return &mockConn{version: version}
}
func (c *mockConn) GetQUICVersion() protocol.VersionNumber {
return c.version
}
var _ = Describe("Server", func() {
var (
s *Server
@ -339,19 +352,10 @@ var _ = Describe("Server", func() {
})
Context("setting http headers", func() {
var expected http.Header
getExpectedHeader := func() http.Header {
return http.Header{
"Alt-Svc": {fmt.Sprintf(`%s=":443"; ma=2592000`, nextProtoH3)},
}
expected := http.Header{
"Alt-Svc": {`h3-29=":443"; ma=2592000,h3-32=":443"; ma=2592000`},
}
BeforeEach(func() {
Expect(getExpectedHeader()).To(Equal(http.Header{"Alt-Svc": {nextProtoH3 + `=":443"; ma=2592000`}}))
expected = getExpectedHeader()
})
It("sets proper headers with numeric port", func() {
s.Server.Addr = ":443"
hdr := http.Header{}
@ -496,6 +500,15 @@ var _ = Describe("Server", func() {
Expect(s.Close()).To(Succeed())
})
checkGetConfigForClientVersions := func(conf *tls.Config) {
c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)})
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft32)})
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft32}))
}
It("uses the quic.Config to start the QUIC server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
@ -508,8 +521,11 @@ var _ = Describe("Server", func() {
Expect(receivedConf).To(Equal(conf))
})
It("replaces the ALPN token to the tls.Config", func() {
tlsConf := &tls.Config{NextProtos: []string{"foo", "bar"}}
It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() {
tlsConf := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"foo", "bar"},
}
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
@ -517,25 +533,35 @@ var _ = Describe("Server", func() {
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(receivedConf.NextProtos).To(BeEmpty())
Expect(receivedConf.ClientAuth).To(BeZero())
// make sure the original tls.Config was not modified
Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"}))
// make sure that the config returned from the GetConfigForClient callback sets the fields of the original config
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
checkGetConfigForClientVersions(receivedConf)
})
It("uses the ALPN token if no tls.Config is given", func() {
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(receivedConf).ToNot(BeNil())
checkGetConfigForClientVersions(receivedConf)
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{NextProtos: []string{"foo", "bar"}}, nil
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"foo", "bar"},
}, nil
},
}
@ -546,14 +572,15 @@ var _ = Describe("Server", func() {
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
// check that the config used by QUIC uses the h3 ALPN
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{nextProtoH3}))
// check that the original config was not modified
conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
// check that the config returned by the GetConfigForClient callback uses the returned config
conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
checkGetConfigForClientVersions(receivedConf)
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
@ -571,14 +598,11 @@ var _ = Describe("Server", func() {
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
// check that the config used by QUIC uses the h3 ALPN
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{nextProtoH3}))
// check that the original config was not modified
conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
checkGetConfigForClientVersions(receivedConf)
})
})

View file

@ -62,25 +62,30 @@ const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
version protocol.VersionNumber
}
func newConn(local, remote net.Addr) net.Conn {
var _ ConnWithVersion = &conn{}
func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
return &conn{
localAddr: local,
remoteAddr: remote,
version: version,
}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
type cryptoSetup struct {
tlsConf *tls.Config
@ -156,6 +161,7 @@ func NewCryptoSetupClient(
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
cs, clientHelloWritten := newCryptoSetup(
initialStream,
@ -170,7 +176,7 @@ func NewCryptoSetupClient(
logger,
protocol.PerspectiveClient,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs, clientHelloWritten
}
@ -188,6 +194,7 @@ func NewCryptoSetupServer(
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs, _ := newCryptoSetup(
initialStream,
@ -202,7 +209,7 @@ func NewCryptoSetupServer(
logger,
protocol.PerspectiveServer,
)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs
}

View file

@ -99,6 +99,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -139,6 +140,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -182,6 +184,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -218,6 +221,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -334,6 +338,7 @@ var _ = Describe("Crypto Setup TLS", func() {
clientRTTStats,
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
var sHandshakeComplete bool
@ -360,6 +365,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
handshake(client, cChunkChan, server, sChunkChan)
@ -429,6 +435,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -471,6 +478,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
@ -495,6 +503,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -528,6 +537,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
@ -548,6 +558,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})
@ -588,6 +599,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.VersionTLS,
)
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
@ -608,6 +620,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.VersionTLS,
)
done := make(chan struct{})

View file

@ -3,6 +3,7 @@ package handshake
import (
"errors"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -90,3 +91,10 @@ type CryptoSetup interface {
Get0RTTSealer() (LongHeaderSealer, error)
Get1RTTSealer() (ShortHeaderSealer, error)
}
// ConnWithVersion is the connection used in the ClientHelloInfo.
// It can be used to determine the QUIC version in use.
type ConnWithVersion interface {
net.Conn
GetQUICVersion() protocol.VersionNumber
}

View file

@ -325,6 +325,7 @@ var newSession = func(
s.rttStats,
tracer,
logger,
s.version,
)
s.cryptoStreamHandler = cs
s.packer = newPacketPacker(
@ -442,6 +443,7 @@ var newClientSession = func(
s.rttStats,
tracer,
logger,
s.version,
)
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs