make sure the TLS config contains the H3 ALPN token in server and client

This commit is contained in:
Marten Seemann 2019-05-20 17:55:59 +01:00
parent 8afed81c49
commit afa4615c4d
4 changed files with 54 additions and 7 deletions

View file

@ -55,7 +55,9 @@ func newClient(
if tlsConf == nil {
tlsConf = &tls.Config{}
}
tlsConf.NextProtos = []string{"h3-19"}
if !strSliceContains(tlsConf.NextProtos, nextProtoH3) {
tlsConf.NextProtos = append(tlsConf.NextProtos, nextProtoH3)
}
if quicConfig == nil {
quicConfig = defaultQuicConfig
}

View file

@ -42,11 +42,12 @@ var _ = Describe("Client", func() {
dialAddr = origDialAddr
})
It("uses the default QUIC config if none is give", func() {
It("uses the default QUIC and TLS config if none is give", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
var dialAddrCalled bool
dialAddr = func(_ string, _ *tls.Config, quicConf *quic.Config) (quic.Session, error) {
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
dialAddrCalled = true
return nil, errors.New("test done")
}
@ -69,7 +70,10 @@ var _ = Describe("Client", func() {
})
It("uses the TLS config and QUIC config", func() {
tlsConf := &tls.Config{ServerName: "foo.bar"}
tlsConf := &tls.Config{
ServerName: "foo.bar",
NextProtos: []string{"proto foo", "proto bar"},
}
quicConf := &quic.Config{IdleTimeout: time.Nanosecond}
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
var dialAddrCalled bool
@ -79,7 +83,8 @@ var _ = Describe("Client", func() {
quicConfP *quic.Config,
) (quic.Session, error) {
Expect(hostname).To(Equal("localhost:1337"))
Expect(tlsConfP).To(Equal(tlsConf))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{"proto foo", "proto bar", nextProtoH3}))
Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout))
dialAddrCalled = true
return nil, errors.New("test done")

View file

@ -25,6 +25,8 @@ var (
quicListenAddr = quic.ListenAddr
)
const nextProtoH3 = "h3-20"
// Server is a HTTP2 server listening for QUIC connections.
type Server struct {
*http.Server
@ -88,6 +90,14 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
return errors.New("ListenAndServe may only be called once")
}
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
if !strSliceContains(tlsConfig.NextProtos, nextProtoH3) {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, nextProtoH3)
}
var ln quic.Listener
var err error
if conn == nil {
@ -353,3 +363,12 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
return err
}
}
func strSliceContains(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}

View file

@ -349,10 +349,10 @@ var _ = Describe("Server", func() {
Expect(s.Close()).To(Succeed())
})
It("uses the quic.Config to start the quic server", func() {
It("uses the quic.Config to start the QUIC server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.Listener, error) {
receivedConf = config
return nil, errors.New("listen err")
}
@ -360,6 +360,27 @@ var _ = Describe("Server", func() {
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf).To(Equal(conf))
})
It("adds the ALPN token to the tls.Config", func() {
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
s.TLSConfig = &tls.Config{NextProtos: []string{"foo", "bar"}}
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(Equal([]string{"foo", "bar", nextProtoH3}))
})
It("uses the ALPN token if no tls.Config is given", func() {
var receivedConf *tls.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) {
receivedConf = tlsConf
return nil, errors.New("listen err")
}
Expect(s.ListenAndServe()).To(HaveOccurred())
Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3}))
})
})
Context("ListenAndServeTLS", func() {