mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
make sure the TLS config contains the H3 ALPN token in server and client
This commit is contained in:
parent
8afed81c49
commit
afa4615c4d
4 changed files with 54 additions and 7 deletions
|
@ -55,7 +55,9 @@ func newClient(
|
|||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
}
|
||||
tlsConf.NextProtos = []string{"h3-19"}
|
||||
if !strSliceContains(tlsConf.NextProtos, nextProtoH3) {
|
||||
tlsConf.NextProtos = append(tlsConf.NextProtos, nextProtoH3)
|
||||
}
|
||||
if quicConfig == nil {
|
||||
quicConfig = defaultQuicConfig
|
||||
}
|
||||
|
|
|
@ -42,11 +42,12 @@ var _ = Describe("Client", func() {
|
|||
dialAddr = origDialAddr
|
||||
})
|
||||
|
||||
It("uses the default QUIC config if none is give", func() {
|
||||
It("uses the default QUIC and TLS config if none is give", func() {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
var dialAddrCalled bool
|
||||
dialAddr = func(_ string, _ *tls.Config, quicConf *quic.Config) (quic.Session, error) {
|
||||
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.Session, error) {
|
||||
Expect(quicConf).To(Equal(defaultQuicConfig))
|
||||
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||
dialAddrCalled = true
|
||||
return nil, errors.New("test done")
|
||||
}
|
||||
|
@ -69,7 +70,10 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("uses the TLS config and QUIC config", func() {
|
||||
tlsConf := &tls.Config{ServerName: "foo.bar"}
|
||||
tlsConf := &tls.Config{
|
||||
ServerName: "foo.bar",
|
||||
NextProtos: []string{"proto foo", "proto bar"},
|
||||
}
|
||||
quicConf := &quic.Config{IdleTimeout: time.Nanosecond}
|
||||
client = newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
|
||||
var dialAddrCalled bool
|
||||
|
@ -79,7 +83,8 @@ var _ = Describe("Client", func() {
|
|||
quicConfP *quic.Config,
|
||||
) (quic.Session, error) {
|
||||
Expect(hostname).To(Equal("localhost:1337"))
|
||||
Expect(tlsConfP).To(Equal(tlsConf))
|
||||
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
|
||||
Expect(tlsConfP.NextProtos).To(Equal([]string{"proto foo", "proto bar", nextProtoH3}))
|
||||
Expect(quicConfP.IdleTimeout).To(Equal(quicConf.IdleTimeout))
|
||||
dialAddrCalled = true
|
||||
return nil, errors.New("test done")
|
||||
|
|
|
@ -25,6 +25,8 @@ var (
|
|||
quicListenAddr = quic.ListenAddr
|
||||
)
|
||||
|
||||
const nextProtoH3 = "h3-20"
|
||||
|
||||
// Server is a HTTP2 server listening for QUIC connections.
|
||||
type Server struct {
|
||||
*http.Server
|
||||
|
@ -88,6 +90,14 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|||
return errors.New("ListenAndServe may only be called once")
|
||||
}
|
||||
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
if !strSliceContains(tlsConfig.NextProtos, nextProtoH3) {
|
||||
tlsConfig.NextProtos = append(tlsConfig.NextProtos, nextProtoH3)
|
||||
}
|
||||
|
||||
var ln quic.Listener
|
||||
var err error
|
||||
if conn == nil {
|
||||
|
@ -353,3 +363,12 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error
|
|||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func strSliceContains(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -349,10 +349,10 @@ var _ = Describe("Server", func() {
|
|||
Expect(s.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("uses the quic.Config to start the quic server", func() {
|
||||
It("uses the quic.Config to start the QUIC server", func() {
|
||||
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||
var receivedConf *quic.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
||||
quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.Listener, error) {
|
||||
receivedConf = config
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
|
@ -360,6 +360,27 @@ var _ = Describe("Server", func() {
|
|||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf).To(Equal(conf))
|
||||
})
|
||||
|
||||
It("adds the ALPN token to the tls.Config", func() {
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) {
|
||||
receivedConf = tlsConf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
s.TLSConfig = &tls.Config{NextProtos: []string{"foo", "bar"}}
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf.NextProtos).To(Equal([]string{"foo", "bar", nextProtoH3}))
|
||||
})
|
||||
|
||||
It("uses the ALPN token if no tls.Config is given", func() {
|
||||
var receivedConf *tls.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.Listener, error) {
|
||||
receivedConf = tlsConf
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
Expect(s.ListenAndServe()).To(HaveOccurred())
|
||||
Expect(receivedConf.NextProtos).To(Equal([]string{nextProtoH3}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ListenAndServeTLS", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue