uquic/internal/qtls/qtls_test.go
2020-08-18 14:26:23 +07:00

227 lines
8 KiB
Go

package qtls
import (
"crypto/tls"
"errors"
"net"
"unsafe"
mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Config", func() {
It("sets MinVersion and MaxVersion", func() {
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
})
It("works when called with a nil config", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf).ToNot(BeNil())
})
It("sets the setter and getter function for TLS extensions", func() {
var get, received bool
extraConfig := &ExtraConfig{
GetExtensions: func(handshakeMessageType uint8) []Extension { get = true; return nil },
ReceivedExtensions: func(handshakeMessageType uint8, exts []qtls.Extension) { received = true },
}
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, extraConfig)
qtlsConf.GetExtensions(10)
Expect(get).To(BeTrue())
Expect(received).To(BeFalse())
qtlsConf.ReceivedExtensions(10, nil)
Expect(received).To(BeTrue())
})
It("sets the Accept0RTT callback", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Accept0RTT: func([]byte) bool { return true }})
Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
})
It("sets the Rejected0RTT callback", func() {
var called bool
qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Rejected0RTT: func() { called = true }})
Expect(qtlsConf.Rejected0RTT).ToNot(BeNil())
qtlsConf.Rejected0RTT()
Expect(called).To(BeTrue())
})
It("sets MaxEarlyData", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf.MaxEarlyData).To(BeZero())
qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{MaxEarlyData: 1337})
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(1337)))
})
It("enables 0-RTT", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf.Enable0RTT).To(BeFalse())
qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{Enable0RTT: true})
Expect(qtlsConf.Enable0RTT).To(BeTrue())
})
It("initializes such that the session ticket key remains constant", func() {
tlsConf := &tls.Config{}
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil)
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
})
Context("GetConfigForClient callback", func() {
It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf.GetConfigForClient).To(BeNil())
})
It("returns a Config", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{ServerName: "foo.bar"}, nil
},
}
var received bool
qtlsConf := tlsConfigToQtlsConfig(tlsConf, &ExtraConfig{ReceivedExtensions: func(uint8, []Extension) { received = true }})
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
confForClient, err := qtlsConf.GetConfigForClient(nil)
Expect(err).ToNot(HaveOccurred())
Expect(confForClient.ServerName).To(Equal("foo.bar"))
Expect(confForClient).ToNot(BeNil())
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(received).To(BeFalse())
Expect(confForClient.ReceivedExtensions).ToNot(BeNil())
confForClient.ReceivedExtensions(10, nil)
Expect(received).To(BeTrue())
})
It("returns errors", func() {
testErr := errors.New("test")
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return nil, testErr
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
_, err := qtlsConf.GetConfigForClient(nil)
Expect(err).To(MatchError(testErr))
})
It("returns nil when the callback returns nil", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
})
})
Context("GetCertificate callback", func() {
It("returns a certificate", func() {
tlsConf := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
qtlsCert, err := qtlsConf.GetCertificate(nil)
Expect(err).ToNot(HaveOccurred())
Expect(qtlsCert).ToNot(BeNil())
Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")}))
})
It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil)
Expect(qtlsConf.GetCertificate).To(BeNil())
})
It("returns errors", func() {
tlsConf := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("test")
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
_, err := qtlsConf.GetCertificate(nil)
Expect(err).To(MatchError("test"))
})
It("returns nil when the callback returns nil", func() {
tlsConf := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
})
})
Context("ClientSessionCache", func() {
It("doesn't set if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil)
Expect(qtlsConf.ClientSessionCache).To(BeNil())
})
It("puts a nil session state", func() {
csc := mocktls.NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
// put something
csc.EXPECT().Put("foobar", nil)
qtlsConf.ClientSessionCache.Put("foobar", nil)
})
})
})
var _ = Describe("Config generation", func() {
It("converts a ClientHelloInfo to a tls.ClientHelloInfo", func() {
chi := &qtlsClientHelloInfo{
CipherSuites: []uint16{1, 2, 3},
ServerName: "foo.bar",
SupportedCurves: []tls.CurveID{4, 5, 6},
SupportedPoints: []uint8{7, 8, 9},
SignatureSchemes: []tls.SignatureScheme{10, 11, 12},
SupportedProtos: []string{"foo", "bar"},
SupportedVersions: []uint16{13, 14, 15},
Conn: &net.UDPConn{},
config: &Config{
MinVersion: tls.VersionTLS10,
MaxVersion: tls.VersionTLS12,
CipherSuites: []uint16{16, 17, 18},
CurvePreferences: []tls.CurveID{19, 20, 21},
},
}
tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi)))
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{1, 2, 3}))
Expect(tlsCHI.ServerName).To(Equal("foo.bar"))
Expect(tlsCHI.SupportedCurves).To(Equal([]tls.CurveID{4, 5, 6}))
Expect(tlsCHI.SupportedPoints).To(Equal([]uint8{7, 8, 9}))
Expect(tlsCHI.SignatureSchemes).To(Equal([]tls.SignatureScheme{10, 11, 12}))
Expect(tlsCHI.SupportedProtos).To(Equal([]string{"foo", "bar"}))
Expect(tlsCHI.SupportedVersions).To(Equal([]uint16{13, 14, 15}))
Expect(tlsCHI.Conn).To(Equal(&net.UDPConn{}))
c := (*clientHelloInfo)(unsafe.Pointer(tlsCHI))
Expect(c.config.CipherSuites).To(Equal([]uint16{16, 17, 18}))
Expect(c.config.MinVersion).To(BeEquivalentTo(tls.VersionTLS10))
Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
})
It("converts a ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi)))
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37}))
})
})