// Mirror of https://github.com/refraction-networking/uquic.git,
// synced 2025-04-05 13:17:36 +03:00 (source file: 149 lines, 4.9 KiB, Go).

package handshake
import (
"crypto/tls"
"errors"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockExtensionHandler struct {
|
|
get, received bool
|
|
}
|
|
|
|
var _ tlsExtensionHandler = &mockExtensionHandler{}
|
|
|
|
func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
|
|
h.get = true
|
|
return nil
|
|
}
|
|
func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
|
|
h.received = true
|
|
}
|
|
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
|
|
|
|
// mockClientSessionCache is a minimal tls.ClientSessionCache test double.
// It records the keys it was called with and always serves the single
// session stored in its session field, regardless of the key.
type mockClientSessionCache struct {
	// get and put hold the keys most recently passed to Get and Put.
	get, put string

	// session is what every Get call returns, whatever the key.
	session *tls.ClientSessionState
}

var _ tls.ClientSessionCache = &mockClientSessionCache{}

// Get records sessionKey and returns the configured session.
// Per the tls.ClientSessionCache contract, ok reports whether a session was
// found, so it is true exactly when a session has been configured. The
// previous version always returned false, even alongside a non-nil session.
func (c *mockClientSessionCache) Get(sessionKey string) (session *tls.ClientSessionState, ok bool) {
	c.get = sessionKey
	return c.session, c.session != nil
}

// Put records sessionKey. The session argument is not stored, so later Get
// calls keep returning the session field unchanged.
func (c *mockClientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
	c.put = sessionKey
}

var _ = Describe("qtls.Config generation", func() {
|
|
It("sets MinVersion and MaxVersion", func() {
|
|
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
})
|
|
|
|
It("works when called with a nil config", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf).ToNot(BeNil())
|
|
})
|
|
|
|
It("sets the setter and getter function for TLS extensions", func() {
|
|
extHandler := &mockExtensionHandler{}
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler)
|
|
Expect(extHandler.get).To(BeFalse())
|
|
qtlsConf.GetExtensions(10)
|
|
Expect(extHandler.get).To(BeTrue())
|
|
Expect(extHandler.received).To(BeFalse())
|
|
qtlsConf.ReceivedExtensions(10, nil)
|
|
Expect(extHandler.received).To(BeTrue())
|
|
})
|
|
|
|
Context("GetConfigForClient callback", func() {
|
|
It("doesn't set it if absent", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf.GetConfigForClient).To(BeNil())
|
|
})
|
|
|
|
It("returns a qtls.Config", func() {
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return &tls.Config{ServerName: "foo.bar"}, nil
|
|
},
|
|
}
|
|
extHandler := &mockExtensionHandler{}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler)
|
|
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
|
|
confForClient, err := qtlsConf.GetConfigForClient(nil)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(confForClient.ServerName).To(Equal("foo.bar"))
|
|
Expect(confForClient).ToNot(BeNil())
|
|
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
|
Expect(extHandler.get).To(BeFalse())
|
|
confForClient.GetExtensions(10)
|
|
Expect(extHandler.get).To(BeTrue())
|
|
})
|
|
|
|
It("returns errors", func() {
|
|
testErr := errors.New("test")
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return nil, testErr
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
|
_, err := qtlsConf.GetConfigForClient(nil)
|
|
Expect(err).To(MatchError(testErr))
|
|
})
|
|
|
|
It("returns nil when the callback returns nil", func() {
|
|
tlsConf := &tls.Config{
|
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return nil, nil
|
|
},
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
|
})
|
|
})
|
|
|
|
Context("ClientSessionCache", func() {
|
|
It("doesn't set if absent", func() {
|
|
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf.ClientSessionCache).To(BeNil())
|
|
})
|
|
|
|
It("sets it, and puts and gets session states", func() {
|
|
state := &qtls.ClientSessionState{}
|
|
csc := &mockClientSessionCache{session: &tls.ClientSessionState{}}
|
|
tlsConf := &tls.Config{ClientSessionCache: csc}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
|
Expect(csc.put).To(BeEmpty())
|
|
qtlsConf.ClientSessionCache.Put("foobar", state)
|
|
Expect(csc.put).To(Equal("foobar"))
|
|
Expect(csc.get).To(BeEmpty())
|
|
sess, _ := qtlsConf.ClientSessionCache.Get("raboof")
|
|
Expect(csc.get).To(Equal("raboof"))
|
|
Expect(sess).To(Equal(state))
|
|
})
|
|
|
|
It("sets it, and gets nil session states for unknown keys", func() {
|
|
csc := &mockClientSessionCache{}
|
|
tlsConf := &tls.Config{
|
|
ClientSessionCache: csc,
|
|
}
|
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{})
|
|
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
|
Expect(csc.get).To(BeEmpty())
|
|
sess, _ := qtlsConf.ClientSessionCache.Get("raboof")
|
|
Expect(csc.get).To(Equal("raboof"))
|
|
Expect(sess).To(BeNil())
|
|
})
|
|
})
|
|
})
|