mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
The RTTStats are used by the logging package. In order to instrument the congestion package, the RTTStats can't be part of that package any more (to avoid an import loop).
234 lines
9.3 KiB
Go
234 lines
9.3 KiB
Go
package handshake
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"unsafe"
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
|
|
"github.com/marten-seemann/qtls"
|
|
. "github.com/onsi/ginkgo"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
// mockExtensionHandler is a minimal tlsExtensionHandler used by the specs in
// this file. It only records which of its callbacks have been invoked.
type mockExtensionHandler struct {
	// get is set to true once GetExtensions has been called.
	get bool
	// received is set to true once ReceivedExtensions has been called.
	received bool
}
|
|
|
|
// Compile-time check that *mockExtensionHandler satisfies tlsExtensionHandler.
var _ tlsExtensionHandler = (*mockExtensionHandler)(nil)
|
|
|
|
func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
|
|
h.get = true
|
|
return nil
|
|
}
|
|
func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
|
|
h.received = true
|
|
}
|
|
// TransportParameters is not implemented by this mock; calling it panics.
func (*mockExtensionHandler) TransportParameters() <-chan []byte {
	panic("not implemented")
}
|
|
|
|
var _ = Describe("qtls.Config", func() {
|
|
It("sets MinVersion and MaxVersion", func() {
|
|
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
})
|
|
|
|
It("works when called with a nil config", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf).ToNot(BeNil())
|
|
})
|
|
|
|
It("sets the setter and getter function for TLS extensions", func() {
|
|
extHandler := &mockExtensionHandler{}
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(extHandler.get).To(BeFalse())
|
|
qtlsConf.GetExtensions(10)
|
|
Expect(extHandler.get).To(BeTrue())
|
|
Expect(extHandler.received).To(BeFalse())
|
|
qtlsConf.ReceivedExtensions(10, nil)
|
|
Expect(extHandler.received).To(BeTrue())
|
|
})
|
|
|
|
It("sets the Accept0RTT callback", func() {
|
|
accept0RTT := func([]byte) bool { return true }
|
|
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, accept0RTT, nil, false)
|
|
Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
|
|
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
|
|
})
|
|
|
|
It("sets the Accept0RTT callback", func() {
|
|
var called bool
|
|
rejected0RTT := func() { called = true }
|
|
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, rejected0RTT, false)
|
|
Expect(qtlsConf.Rejected0RTT).ToNot(BeNil())
|
|
qtlsConf.Rejected0RTT()
|
|
Expect(called).To(BeTrue())
|
|
})
|
|
|
|
It("enables 0-RTT", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.Enable0RTT).To(BeFalse())
|
|
Expect(qtlsConf.MaxEarlyData).To(BeZero())
|
|
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, true)
|
|
Expect(qtlsConf.Enable0RTT).To(BeTrue())
|
|
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
|
|
})
|
|
|
|
It("initializes such that the session ticket key remains constant", func() {
|
|
tlsConf := &tls.Config{}
|
|
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
|
|
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
|
|
})
|
|
|
|
Context("GetConfigForClient callback", func() {
|
|
It("doesn't set it if absent", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.GetConfigForClient).To(BeNil())
|
|
})
|
|
|
|
It("returns a qtls.Config", func() {
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return &tls.Config{ServerName: "foo.bar"}, nil
|
|
},
|
|
}
|
|
extHandler := &mockExtensionHandler{}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
|
|
confForClient, err := qtlsConf.GetConfigForClient(nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(confForClient.ServerName).To(Equal("foo.bar"))
|
|
Expect(confForClient).ToNot(BeNil())
|
|
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
Expect(extHandler.get).To(BeFalse())
|
|
confForClient.GetExtensions(10)
|
|
Expect(extHandler.get).To(BeTrue())
|
|
})
|
|
|
|
It("returns errors", func() {
|
|
testErr := errors.New("test")
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return nil, testErr
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
_, err := qtlsConf.GetConfigForClient(nil)
|
|
Expect(err).To(MatchError(testErr))
|
|
})
|
|
|
|
It("returns nil when the callback returns nil", func() {
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return nil, nil
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
|
})
|
|
})
|
|
|
|
Context("GetCertificate callback", func() {
|
|
It("returns a certificate", func() {
|
|
tlsConf := &tls.Config{
|
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
qtlsCert, err := qtlsConf.GetCertificate(nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(qtlsCert).ToNot(BeNil())
|
|
Expect(qtlsCert.Certificate).To(Equal([][]byte{[]byte("foo"), []byte("bar")}))
|
|
})
|
|
|
|
It("doesn't set it if absent", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.GetCertificate).To(BeNil())
|
|
})
|
|
|
|
It("returns errors", func() {
|
|
tlsConf := &tls.Config{
|
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return nil, errors.New("test")
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
_, err := qtlsConf.GetCertificate(nil)
|
|
Expect(err).To(MatchError("test"))
|
|
})
|
|
|
|
It("returns nil when the callback returns nil", func() {
|
|
tlsConf := &tls.Config{
|
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return nil, nil
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
|
|
})
|
|
})
|
|
|
|
Context("ClientSessionCache", func() {
|
|
It("doesn't set if absent", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
Expect(qtlsConf.ClientSessionCache).To(BeNil())
|
|
})
|
|
|
|
It("puts a nil session state", func() {
|
|
csc := NewMockClientSessionCache(mockCtrl)
|
|
tlsConf := &tls.Config{ClientSessionCache: csc}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
|
// put something
|
|
csc.EXPECT().Put("foobar", nil)
|
|
qtlsConf.ClientSessionCache.Put("foobar", nil)
|
|
})
|
|
})
|
|
})
|
|
|
|
var _ = Describe("qtls.Config generation", func() {
|
|
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo", func() {
|
|
chi := &qtlsClientHelloInfo{
|
|
CipherSuites: []uint16{1, 2, 3},
|
|
ServerName: "foo.bar",
|
|
SupportedCurves: []qtls.CurveID{4, 5, 6},
|
|
SupportedPoints: []uint8{7, 8, 9},
|
|
SignatureSchemes: []qtls.SignatureScheme{10, 11, 12},
|
|
SupportedProtos: []string{"foo", "bar"},
|
|
SupportedVersions: []uint16{13, 14, 15},
|
|
Conn: &net.UDPConn{},
|
|
config: &qtls.Config{
|
|
MinVersion: tls.VersionTLS10,
|
|
MaxVersion: tls.VersionTLS12,
|
|
CipherSuites: []uint16{16, 17, 18},
|
|
CurvePreferences: []qtls.CurveID{19, 20, 21},
|
|
},
|
|
}
|
|
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
|
|
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{1, 2, 3}))
|
|
Expect(tlsCHI.ServerName).To(Equal("foo.bar"))
|
|
Expect(tlsCHI.SupportedCurves).To(Equal([]tls.CurveID{4, 5, 6}))
|
|
Expect(tlsCHI.SupportedPoints).To(Equal([]uint8{7, 8, 9}))
|
|
Expect(tlsCHI.SignatureSchemes).To(Equal([]tls.SignatureScheme{10, 11, 12}))
|
|
Expect(tlsCHI.SupportedProtos).To(Equal([]string{"foo", "bar"}))
|
|
Expect(tlsCHI.SupportedVersions).To(Equal([]uint16{13, 14, 15}))
|
|
Expect(tlsCHI.Conn).To(Equal(&net.UDPConn{}))
|
|
c := (*clientHelloInfo)(unsafe.Pointer(tlsCHI))
|
|
Expect(c.config.CipherSuites).To(Equal([]uint16{16, 17, 18}))
|
|
Expect(c.config.MinVersion).To(BeEquivalentTo(tls.VersionTLS10))
|
|
Expect(c.config.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
|
|
Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
|
|
})
|
|
|
|
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
|
|
chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
|
|
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
|
|
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37}))
|
|
})
|
|
})
|