uquic/internal/handshake/qtls_test.go
Marten Seemann 741dc28d74 move the RTTStats to the utils package
The RTTStats are used by the logging package. In order to instrument the
congestion package, the RTTStats can't be part of that package any more
(to avoid an import loop).
2020-07-23 11:53:08 +07:00

234 lines
9.3 KiB
Go

package handshake
import (
"crypto/tls"
"errors"
"net"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// mockExtensionHandler is a tlsExtensionHandler test double that only records
// which of its extension hooks have been invoked.
type mockExtensionHandler struct {
	get      bool // set to true by GetExtensions
	received bool // set to true by ReceivedExtensions
}

// Compile-time check that mockExtensionHandler satisfies tlsExtensionHandler.
var _ tlsExtensionHandler = (*mockExtensionHandler)(nil)
// GetExtensions records that it was called and contributes no extensions
// (it returns a nil slice regardless of the message type).
func (h *mockExtensionHandler) GetExtensions(_ uint8) []qtls.Extension {
	h.get = true
	var none []qtls.Extension
	return none
}
// ReceivedExtensions records that it was called; the message type and the
// received extensions themselves are ignored.
func (h *mockExtensionHandler) ReceivedExtensions(_ uint8, _ []qtls.Extension) {
	h.received = true
}
// TransportParameters is not needed by these tests and panics if called.
func (*mockExtensionHandler) TransportParameters() <-chan []byte {
	panic("not implemented")
}
// Specs for tlsConfigToQtlsConfig. The positional arguments exercised here are:
// #1 the *tls.Config (may be nil), #3 the extension handler, #4 the RTT stats,
// #7 the Accept0RTT callback, #8 the Rejected0RTT callback and #9 the flag that
// enables 0-RTT. The remaining arguments are always nil in these specs.
var _ = Describe("qtls.Config", func() {
	It("sets MinVersion and MaxVersion", func() {
		tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
		qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
		// The versions requested in the tls.Config are overridden: only TLS 1.3 is allowed.
		Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
		Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
	})

	It("works when called with a nil config", func() {
		qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
		Expect(qtlsConf).ToNot(BeNil())
	})

	It("sets the setter and getter function for TLS extensions", func() {
		extHandler := &mockExtensionHandler{}
		qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
		Expect(extHandler.get).To(BeFalse())
		qtlsConf.GetExtensions(10)
		Expect(extHandler.get).To(BeTrue())
		Expect(extHandler.received).To(BeFalse())
		qtlsConf.ReceivedExtensions(10, nil)
		Expect(extHandler.received).To(BeTrue())
	})

	It("sets the Accept0RTT callback", func() {
		accept0RTT := func([]byte) bool { return true }
		qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, accept0RTT, nil, false)
		Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
		Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
	})

	// Previously this spec reused the title "sets the Accept0RTT callback",
	// which made failures of the two specs indistinguishable.
	It("sets the Rejected0RTT callback", func() {
		var called bool
		rejected0RTT := func() { called = true }
		qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, rejected0RTT, false)
		Expect(qtlsConf.Rejected0RTT).ToNot(BeNil())
		qtlsConf.Rejected0RTT()
		Expect(called).To(BeTrue())
	})

	It("enables 0-RTT", func() {
		qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
		Expect(qtlsConf.Enable0RTT).To(BeFalse())
		Expect(qtlsConf.MaxEarlyData).To(BeZero())
		qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, true)
		Expect(qtlsConf.Enable0RTT).To(BeTrue())
		Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
	})

	It("initializes such that the session ticket key remains constant", func() {
		tlsConf := &tls.Config{}
		qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
		qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
		Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
		// Converting the same tls.Config twice must yield the same key.
		Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
	})

	Context("GetConfigForClient callback", func() {
		It("doesn't set it if absent", func() {
			qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			Expect(qtlsConf.GetConfigForClient).To(BeNil())
		})

		It("returns a qtls.Config", func() {
			tlsConf := &tls.Config{
				GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
					return &tls.Config{ServerName: "foo.bar"}, nil
				},
			}
			extHandler := &mockExtensionHandler{}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
			Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
			confForClient, err := qtlsConf.GetConfigForClient(nil)
			Expect(err).ToNot(HaveOccurred())
			// Assert non-nil before dereferencing, so a nil result fails this
			// assertion instead of panicking on the field access below.
			Expect(confForClient).ToNot(BeNil())
			Expect(confForClient.ServerName).To(Equal("foo.bar"))
			Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
			Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
			// The per-client config must still be wired to the extension handler.
			Expect(extHandler.get).To(BeFalse())
			confForClient.GetExtensions(10)
			Expect(extHandler.get).To(BeTrue())
		})

		It("returns errors", func() {
			testErr := errors.New("test")
			tlsConf := &tls.Config{
				GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
					return nil, testErr
				},
			}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			_, err := qtlsConf.GetConfigForClient(nil)
			Expect(err).To(MatchError(testErr))
		})

		It("returns nil when the callback returns nil", func() {
			tlsConf := &tls.Config{
				GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
					return nil, nil
				},
			}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
		})
	})

	Context("GetCertificate callback", func() {
		It("returns a certificate", func() {
			tlsConf := &tls.Config{
				GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
					return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
				},
			}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			qtlsCert, err := qtlsConf.GetCertificate(nil)
			Expect(err).ToNot(HaveOccurred())
			Expect(qtlsCert).ToNot(BeNil())
			Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")}))
		})

		It("doesn't set it if absent", func() {
			qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			Expect(qtlsConf.GetCertificate).To(BeNil())
		})

		It("returns errors", func() {
			tlsConf := &tls.Config{
				GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
					return nil, errors.New("test")
				},
			}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			_, err := qtlsConf.GetCertificate(nil)
			Expect(err).To(MatchError("test"))
		})

		It("returns nil when the callback returns nil", func() {
			tlsConf := &tls.Config{
				GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
					return nil, nil
				},
			}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
		})
	})

	Context("ClientSessionCache", func() {
		It("doesn't set if absent", func() {
			qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			Expect(qtlsConf.ClientSessionCache).To(BeNil())
		})

		It("puts a nil session state", func() {
			csc := NewMockClientSessionCache(mockCtrl)
			tlsConf := &tls.Config{ClientSessionCache: csc}
			qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
			// A nil session state put through the qtls cache must reach the
			// underlying tls.ClientSessionCache as nil.
			csc.EXPECT().Put("foobar", nil)
			qtlsConf.ClientSessionCache.Put("foobar", nil)
		})
	})
})
// Specs for toTLSClientHelloInfo. Each input is built as a qtlsClientHelloInfo
// and cast to *qtls.ClientHelloInfo with unsafe.Pointer. The unexported config
// of the result is read back by casting it to *clientHelloInfo the same way.
var _ = Describe("qtls.Config generation", func() {
	It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo", func() {
		// Every field gets distinct values, so a field copied to the wrong
		// place shows up as a mismatch.
		qconf := &qtls.Config{
			MinVersion:       tls.VersionTLS10,
			MaxVersion:       tls.VersionTLS12,
			CipherSuites:     []uint16{16, 17, 18},
			CurvePreferences: []qtls.CurveID{19, 20, 21},
		}
		src := &qtlsClientHelloInfo{
			CipherSuites:      []uint16{1, 2, 3},
			ServerName:        "foo.bar",
			SupportedCurves:   []qtls.CurveID{4, 5, 6},
			SupportedPoints:   []uint8{7, 8, 9},
			SignatureSchemes:  []qtls.SignatureScheme{10, 11, 12},
			SupportedProtos:   []string{"foo", "bar"},
			SupportedVersions: []uint16{13, 14, 15},
			Conn:              &net.UDPConn{},
			config:            qconf,
		}
		converted := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(src)))

		// Exported fields.
		Expect(converted.CipherSuites).To(Equal([]uint16{1, 2, 3}))
		Expect(converted.ServerName).To(Equal("foo.bar"))
		Expect(converted.SupportedCurves).To(Equal([]tls.CurveID{4, 5, 6}))
		Expect(converted.SupportedPoints).To(Equal([]uint8{7, 8, 9}))
		Expect(converted.SignatureSchemes).To(Equal([]tls.SignatureScheme{10, 11, 12}))
		Expect(converted.SupportedProtos).To(Equal([]string{"foo", "bar"}))
		Expect(converted.SupportedVersions).To(Equal([]uint16{13, 14, 15}))
		Expect(converted.Conn).To(Equal(&net.UDPConn{}))

		// The unexported config, read back through clientHelloInfo.
		internal := (*clientHelloInfo)(unsafe.Pointer(converted))
		Expect(internal.config.CipherSuites).To(Equal([]uint16{16, 17, 18}))
		Expect(internal.config.MinVersion).To(BeEquivalentTo(tls.VersionTLS10))
		Expect(internal.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
		Expect(internal.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
	})

	It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
		// config is deliberately left nil here.
		src := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
		converted := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(src)))
		Expect(converted.CipherSuites).To(Equal([]uint16{13, 37}))
	})
})