select the H3 ALPN based on the QUIC version in use (for the H3 server)

This commit is contained in:
Marten Seemann 2020-10-29 13:08:16 +07:00
parent b7652887d2
commit c968b18a21
10 changed files with 134 additions and 62 deletions

View file

@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
@ -14,6 +13,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qpack"
@ -22,6 +22,19 @@ import (
. "github.com/onsi/gomega"
)
type mockConn struct {
net.Conn
version protocol.VersionNumber
}
func newMockConn(version protocol.VersionNumber) net.Conn {
return &mockConn{version: version}
}
func (c *mockConn) GetQUICVersion() protocol.VersionNumber {
return c.version
}
var _ = Describe("Server", func() {
var (
s *Server
@ -339,19 +352,10 @@ var _ = Describe("Server", func() {
})
Context("setting http headers", func() {
var expected http.Header
getExpectedHeader := func() http.Header {
return http.Header{
"Alt-Svc": {fmt.Sprintf(`%s=":443"; ma=2592000`, nextProtoH3)},
}
expected := http.Header{
"Alt-Svc": {`h3-29=":443"; ma=2592000,h3-32=":443"; ma=2592000`},
}
BeforeEach(func() {
Expect(getExpectedHeader()).To(Equal(http.Header{"Alt-Svc": {nextProtoH3 + `=":443"; ma=2592000`}}))
expected = getExpectedHeader()
})
It("sets proper headers with numeric port", func() {
s.Server.Addr = ":443"
hdr := http.Header{}
@ -496,6 +500,15 @@ var _ = Describe("Server", func() {
Expect(s.Close()).To(Succeed())
})
checkGetConfigForClientVersions := func(conf *tls.Config) {
c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)})
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft29}))
c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft32)})
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft32}))
}
It("uses the quic.Config to start the QUIC server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
@ -508,8 +521,11 @@ var _ = Describe("Server", func() {
Expect(receivedConf).To(Equal(conf))
})
It("replaces the ALPN token to the tls.Config", func() {
tlsConf := &tls.Config{NextProtos: []string{"foo", "bar"}}
It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() {
tlsConf := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"foo", "bar"},
}
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
@ -517,25 +533,35 @@ var _ = Describe("Server", func() {
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(receivedConf.NextProtos).To(BeEmpty())
Expect(receivedConf.ClientAuth).To(BeZero())
// make sure the original tls.Config was not modified
Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"}))
// make sure that the config returned from the GetConfigForClient callback sets the fields of the original config
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
checkGetConfigForClientVersions(receivedConf)
})
It("uses the ALPN token if no tls.Config is given", func() {
It("sets the GetConfigForClient callback if no tls.Config is given", func() {
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(receivedConf).ToNot(BeNil())
checkGetConfigForClientVersions(receivedConf)
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{NextProtos: []string{"foo", "bar"}}, nil
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"foo", "bar"},
}, nil
},
}
@ -546,14 +572,15 @@ var _ = Describe("Server", func() {
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
// check that the config used by QUIC uses the h3 ALPN
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{nextProtoH3}))
// check that the original config was not modified
conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
// check that the config returned by the GetConfigForClient callback uses the returned config
conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert))
checkGetConfigForClientVersions(receivedConf)
})
It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
@ -571,14 +598,11 @@ var _ = Describe("Server", func() {
}
s.TLSConfig = tlsConf
Expect(s.ListenAndServe()).To(HaveOccurred())
// check that the config used by QUIC uses the h3 ALPN
conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{nextProtoH3}))
// check that the original config was not modified
conf, err = tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"}))
checkGetConfigForClientVersions(receivedConf)
})
})