move all dependencies on qtls to a separate package

This commit is contained in:
Marten Seemann 2020-08-13 10:23:33 +07:00
parent 524da2213c
commit 977dbc828c
29 changed files with 572 additions and 478 deletions

View file

@ -5,7 +5,7 @@ import (
"encoding/binary"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {

View file

@ -8,7 +8,7 @@ import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

View file

@ -12,10 +12,10 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
"github.com/marten-seemann/qtls"
)
// TLS unexpected_message alert
@ -60,9 +60,32 @@ func (m messageType) String() string {
const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
}
func newConn(local, remote net.Addr) net.Conn {
return &conn{
localAddr: local,
remoteAddr: remote,
}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
type cryptoSetup struct {
tlsConf *qtls.Config
conn *qtls.Conn
tlsConf *tls.Config
extraConf *qtls.ExtraConfig
conn *qtls.Conn
messageChan chan []byte
@ -152,7 +175,7 @@ func NewCryptoSetupClient(
logger,
protocol.PerspectiveClient,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs, clientHelloWritten
}
@ -184,7 +207,7 @@ func NewCryptoSetupServer(
logger,
protocol.PerspectiveServer,
)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs
}
@ -208,6 +231,7 @@ func newCryptoSetup(
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
cs := &cryptoSetup{
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
@ -231,8 +255,22 @@ func newCryptoSetup(
writeRecord: make(chan struct{}, 1),
closeChan: make(chan struct{}),
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalDataForSessionState, cs.handleDataFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT)
cs.tlsConf = qtlsConf
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = 0xffffffff
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
AlternativeRecordLayer: cs,
EnforceNextProtoSelection: true,
MaxEarlyData: maxEarlyData,
Accept0RTT: cs.accept0RTT,
Rejected0RTT: cs.rejected0RTT,
Enable0RTT: enable0RTT,
GetAppDataForSessionState: cs.marshalDataForSessionState,
SetAppDataFromSessionState: cs.handleDataFromSessionState,
}
return cs, cs.clientHelloWrittenChan
}
@ -499,7 +537,7 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.tlsConf.MaxEarlyData > 0 {
if h.extraConf.MaxEarlyData > 0 {
appData = (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
@ -819,5 +857,5 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return h.conn.ConnectionState()
return qtls.GetConnectionState(h.conn)
}

View file

@ -1,22 +1,20 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"math/big"
"time"
mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/marten-seemann/qtls"
"github.com/golang/mock/gomock"
@ -74,48 +72,6 @@ var _ = Describe("Crypto Setup TLS", func() {
}
})
It("creates a qtls.Config", func() {
tlsConf := &tls.Config{
ServerName: "quic.clemente.io",
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("GetCertificate")
},
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, errors.New("GetClientCertificate")
},
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{ServerName: ch.ServerName}, nil
},
}
var token protocol.StatelessResetToken
server := NewCryptoSetupServer(
&bytes.Buffer{},
&bytes.Buffer{},
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
tlsConf,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
)
qtlsConf := server.(*cryptoSetup).tlsConf
Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))
_, getCertificateErr := qtlsConf.GetCertificate(nil)
Expect(getCertificateErr).To(MatchError("GetCertificate"))
_, getClientCertificateErr := qtlsConf.GetClientCertificate(nil)
Expect(getClientCertificateErr).To(MatchError("GetClientCertificate"))
cconf, err := qtlsConf.GetConfigForClient(&qtls.ClientHelloInfo{ServerName: "foo.bar"})
Expect(err).ToNot(HaveOccurred())
Expect(cconf.ServerName).To(Equal("foo.bar"))
Expect(cconf.AlternativeRecordLayer).ToNot(BeNil())
Expect(cconf.GetExtensions).ToNot(BeNil())
Expect(cconf.ReceivedExtensions).ToNot(BeNil())
})
It("returns Handshake() when an error occurs in qtls", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
@ -420,7 +376,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = qtls.RequireAnyClientCert
serverConf.ClientAuth = tls.RequireAnyClientCert
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&utils.RTTStats{}, &utils.RTTStats{},
@ -647,7 +603,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("uses session resumption", func() {
csc := NewMockClientSessionCache(mockCtrl)
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
@ -690,7 +646,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("doesn't use session resumption if the server disabled it", func() {
csc := NewMockClientSessionCache(mockCtrl)
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
@ -727,7 +683,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("uses 0-RTT", func() {
csc := NewMockClientSessionCache(mockCtrl)
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
@ -782,7 +738,7 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("rejects 0-RTT, whent the transport parameters changed", func() {
csc := NewMockClientSessionCache(mockCtrl)
csc := mocktls.NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())

View file

@ -1,14 +1,13 @@
package handshake
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"encoding/hex"
"strings"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/golang/mock/gomock"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -31,27 +30,6 @@ var _ = AfterEach(func() {
mockCtrl.Finish()
})
var cipherSuites = []*qtls.CipherSuiteTLS13{
&qtls.CipherSuiteTLS13{
ID: qtls.TLS_AES_128_GCM_SHA256,
KeyLen: 16,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA256,
},
&qtls.CipherSuiteTLS13{
ID: qtls.TLS_AES_256_GCM_SHA384,
KeyLen: 32,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA384,
},
&qtls.CipherSuiteTLS13{
ID: qtls.TLS_CHACHA20_POLY1305_SHA256,
KeyLen: 32,
AEAD: nil, // will be set by init
Hash: crypto.SHA256,
},
}
func splitHexString(s string) (slice []byte) {
for _, ss := range strings.Split(s, " ") {
if ss[0:2] == "0x" {
@ -64,25 +42,8 @@ func splitHexString(s string) (slice []byte) {
return
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
func init() {
val := cipherSuiteTLS13ByID(qtls.TLS_CHACHA20_POLY1305_SHA256)
chacha := (*cipherSuiteTLS13)(unsafe.Pointer(val))
for _, s := range cipherSuites {
if s.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 {
if s.KeyLen != chacha.KeyLen || s.Hash != chacha.Hash {
panic("invalid parameters for ChaCha20")
}
s.AEAD = chacha.AEAD
}
}
var cipherSuites = []*qtls.CipherSuiteTLS13{
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256),
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384),
qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256),
}

View file

@ -3,12 +3,13 @@ package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"golang.org/x/crypto/chacha20"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
type headerProtector interface {
@ -18,9 +19,9 @@ type headerProtector interface {
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
switch suite.ID {
case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384:
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
return newAESHeaderProtector(suite, trafficSecret, isLongHeader)
case qtls.TLS_CHACHA20_POLY1305_SHA256:
case tls.TLS_CHACHA20_POLY1305_SHA256:
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader)
default:
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))

View file

@ -5,7 +5,7 @@ import (
"crypto/rand"
mrand "math/rand"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

View file

@ -2,15 +2,16 @@ package handshake
import (
"crypto"
"crypto/tls"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
var quicVersion1Salt = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
var initialSuite = &qtls.CipherSuiteTLS13{
ID: qtls.TLS_AES_128_GCM_SHA256,
ID: tls.TLS_AES_128_GCM_SHA256,
KeyLen: 16,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA256,

View file

@ -6,8 +6,8 @@ import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/marten-seemann/qtls"
)
var (

View file

@ -1,7 +1,3 @@
package handshake
//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner"
// The following command produces a warning message on OSX, however, it still generates the correct mock file.
// See https://github.com/golang/mock/issues/339 for details.
//go:generate sh -c "mockgen -package handshake -destination mock_client_session_cache_test.go crypto/tls ClientSessionCache && goimports -w mock_client_session_cache_test.go"

View file

@ -1,233 +0,0 @@
package handshake
import (
"crypto/tls"
"net"
"time"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
)
func init() {
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
panic("clientHelloInfo not compatible with tls.ClientHelloInfo")
}
if !structsEqual(&qtls.ClientHelloInfo{}, &qtlsClientHelloInfo{}) {
panic("qtlsClientHelloInfo not compatible with qtls.ClientHelloInfo")
}
}
type conn struct {
localAddr, remoteAddr net.Addr
}
func newConn(local, remote net.Addr) net.Conn {
return &conn{
localAddr: local,
remoteAddr: remote,
}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
func tlsConfigToQtlsConfig(
c *tls.Config,
recordLayer qtls.RecordLayer,
extHandler tlsExtensionHandler,
rttStats *utils.RTTStats,
getDataForSessionState func() []byte,
setDataFromSessionState func([]byte),
accept0RTT func([]byte) bool,
rejected0RTT func(),
enable0RTT bool,
) *qtls.Config {
if c == nil {
c = &tls.Config{}
}
// Clone the config first. This executes the tls.Config.serverInit().
// This sets the SessionTicketKey, if the user didn't supply one.
c = c.Clone()
// QUIC requires TLS 1.3 or newer
minVersion := c.MinVersion
if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
}
maxVersion := c.MaxVersion
if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
}
var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error)
if c.GetConfigForClient != nil {
getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) {
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
if err != nil {
return nil, err
}
if tlsConf == nil {
return nil, nil
}
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil
}
}
var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error)
if c.GetCertificate != nil {
getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) {
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
if err != nil {
return nil, err
}
if cert == nil {
return nil, nil
}
return (*qtls.Certificate)(cert), nil
}
}
var csc qtls.ClientSessionCache
if c.ClientSessionCache != nil {
csc = &clientSessionCache{c.ClientSessionCache}
}
conf := &qtls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: *(*[]qtls.Certificate)(unsafe.Pointer(&c.Certificates)),
// NameToCertificate is deprecated, but we still need to copy it if the user sets it.
//nolint:staticcheck
NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)),
GetCertificate: getCertificate,
GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
GetConfigForClient: getConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
EnforceNextProtoSelection: true,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: csc,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
// no need to copy Renegotiation, it's not supported by TLS 1.3
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: recordLayer,
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
Accept0RTT: accept0RTT,
Rejected0RTT: rejected0RTT,
GetAppDataForSessionState: getDataForSessionState,
SetAppDataFromSessionState: setDataFromSessionState,
}
if enable0RTT {
conf.Enable0RTT = true
conf.MaxEarlyData = 0xffffffff
}
return conf
}
type clientSessionCache struct {
tls.ClientSessionCache
}
var _ qtls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
return (*qtls.ClientSessionState)(unsafe.Pointer(sess)), ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
if cs == nil {
c.ClientSessionCache.Put(sessionKey, nil)
return
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs)))
}
type clientHelloInfo struct {
CipherSuites []uint16
ServerName string
SupportedCurves []tls.CurveID
SupportedPoints []uint8
SignatureSchemes []tls.SignatureScheme
SupportedProtos []string
SupportedVersions []uint16
Conn net.Conn
config *tls.Config
}
type qtlsClientHelloInfo struct {
CipherSuites []uint16
ServerName string
SupportedCurves []tls.CurveID
SupportedPoints []uint8
SignatureSchemes []tls.SignatureScheme
SupportedProtos []string
SupportedVersions []uint16
Conn net.Conn
config *qtls.Config
}
func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
if chi == nil {
return nil
}
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
var config *tls.Config
if qtlsCHI.config != nil {
config = qtlsConfigToTLSConfig(qtlsCHI.config)
}
return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{
CipherSuites: chi.CipherSuites,
ServerName: chi.ServerName,
SupportedCurves: chi.SupportedCurves,
SupportedPoints: chi.SupportedPoints,
SignatureSchemes: chi.SignatureSchemes,
SupportedProtos: chi.SupportedProtos,
SupportedVersions: chi.SupportedVersions,
Conn: chi.Conn,
config: config,
}))
}
// qtlsConfigToTLSConfig is used to transform a qtls.Config to a tls.Config.
// It is used to create the tls.Config in the ClientHelloInfo.
// It doesn't copy all values, but only those used by ClientHelloInfo.SupportsCertificate.
func qtlsConfigToTLSConfig(config *qtls.Config) *tls.Config {
return &tls.Config{
MinVersion: config.MinVersion,
MaxVersion: config.MaxVersion,
CipherSuites: config.CipherSuites,
CurvePreferences: config.CurvePreferences,
}
}

View file

@ -2,7 +2,7 @@ package handshake
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
const quicTLSExtensionType = 0xffa5

View file

@ -2,7 +2,7 @@ package handshake
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

View file

@ -1,44 +0,0 @@
package handshake
// This package uses unsafe to convert between:
// * qtls.Certificate and tls.Certificate
// * qtls.CertificateRequestInfo and tls.CertificateRequestInfo
// * qtls.ClientHelloInfo and tls.ClientHelloInfo
// * qtls.ConnectionState and tls.ConnectionState
// * qtls.ClientSessionState and tls.ClientSessionState
// We check in init() that this conversion actually is safe.
import (
"crypto/tls"
"reflect"
"github.com/marten-seemann/qtls"
)
func init() {
if !structsEqual(&tls.Certificate{}, &qtls.Certificate{}) {
panic("qtls.Certificate not compatible with tls.Certificate")
}
if !structsEqual(&tls.CertificateRequestInfo{}, &qtls.CertificateRequestInfo{}) {
panic("qtls.CertificateRequestInfo not compatible with tls.CertificateRequestInfo")
}
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
}
}
func structsEqual(a, b interface{}) bool {
sa := reflect.ValueOf(a).Elem()
sb := reflect.ValueOf(b).Elem()
if sa.NumField() != sb.NumField() {
return false
}
for i := 0; i < sa.NumField(); i++ {
fa := sa.Type().Field(i)
fb := sb.Type().Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
return false
}
}
return true
}

View file

@ -9,12 +9,11 @@ import (
"strconv"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
// By setting this environment variable, the key update interval can be adjusted.

View file

@ -2,13 +2,14 @@ package handshake
import (
"crypto/rand"
"crypto/tls"
"fmt"
"os"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -19,7 +20,7 @@ var _ = Describe("Updatable AEAD", func() {
secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil)
chacha := cipherSuites[2]
Expect(chacha.ID).To(Equal(qtls.TLS_CHACHA20_POLY1305_SHA256))
Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256))
aead.SetWriteKey(chacha, secret)
header := splitHexString("4200bff4")
const pnOffset = 1

View file

@ -10,7 +10,7 @@ import (
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
qtls "github.com/marten-seemann/qtls"
qtls "github.com/lucas-clemente/quic-go/internal/qtls"
)
// MockCryptoSetup is a mock of CryptoSetup interface

View file

@ -1,16 +1,20 @@
package mocks
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
//go:generate sh -c "mockgen -package mockquic -destination quic/early_session.go github.com/lucas-clemente/quic-go EarlySession && goimports -w quic/early_session.go"
//go:generate sh -c "mockgen -package mockquic -destination quic/early_session_tmp.go github.com/lucas-clemente/quic-go EarlySession && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_session_tmp.go > quic/early_session.go && rm quic/early_session_tmp.go && goimports -w quic/early_session.go"
//go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener && goimports -w quic/early_listener.go"
//go:generate sh -c "mockgen -package mocks -destination tracer.go github.com/lucas-clemente/quic-go/logging Tracer && goimports -w tracer.go"
//go:generate sh -c "mockgen -package mocks -destination connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer && goimports -w connection_tracer.go"
//go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer && goimports -w short_header_sealer.go"
//go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener && goimports -w short_header_opener.go"
//go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener && goimports -w long_header_opener.go"
//go:generate sh -c "mockgen -package mocks -destination crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup && goimports -w crypto_setup.go"
//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup && sed 's/github.com\\/marten-seemann\\/qtls/github.com\\/lucas-clemente\\/quic-go\\/internal\\/qtls/g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go"
//go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController && goimports -w stream_flow_controller.go"
//go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithmWithDebugInfos && goimports -w congestion.go"
//go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController && goimports -w connection_flow_controller.go"
//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler && goimports -w ackhandler/sent_packet_handler.go"
//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler && goimports -w ackhandler/received_packet_handler.go"
// The following command produces a warning message on OSX, however, it still generates the correct mock file.
// See https://github.com/golang/mock/issues/339 for details.
//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache && goimports -w tls/client_session_cache.go"

View file

@ -12,7 +12,6 @@ import (
gomock "github.com/golang/mock/gomock"
quic "github.com/lucas-clemente/quic-go"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
qtls "github.com/marten-seemann/qtls"
)
// MockEarlySession is a mock of EarlySession interface
@ -83,10 +82,10 @@ func (mr *MockEarlySessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *
}
// ConnectionState mocks base method
func (m *MockEarlySession) ConnectionState() qtls.ConnectionState {
func (m *MockEarlySession) ConnectionState() quic.ConnectionState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(qtls.ConnectionState)
ret0, _ := ret[0].(quic.ConnectionState)
return ret0
}

View file

@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: crypto/tls (interfaces: ClientSessionCache)
// Package handshake is a generated GoMock package.
package handshake
// Package mocktls is a generated GoMock package.
package mocktls
import (
tls "crypto/tls"

View file

@ -3,7 +3,7 @@ package qerr
import (
"fmt"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
// ErrorCode can be used as a normal error without reason.

164
internal/qtls/interface.go Normal file
View file

@ -0,0 +1,164 @@
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"net"
"unsafe"
"github.com/marten-seemann/qtls"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains inforamtion about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionState
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
type ExtraConfig struct {
// GetExtensions, if not nil, is called before a message that allows
// sending of extensions is sent.
// Currently only implemented for the ClientHello message (for the client)
// and for the EncryptedExtensions message (for the server).
// Only valid for TLS 1.3.
GetExtensions func(handshakeMessageType uint8) []Extension
// ReceivedExtensions, if not nil, is called when a message that allows the
// inclusion of extensions is received.
// It is called with an empty slice of extensions, if the message didn't
// contain any extensions.
// Currently only implemented for the ClientHello message (sent by the
// client) and for the EncryptedExtensions message (sent by the server).
// Only valid for TLS 1.3.
ReceivedExtensions func(handshakeMessageType uint8, exts []Extension)
// AlternativeRecordLayer is used by QUIC
AlternativeRecordLayer RecordLayer
// Enforce the selection of a supported application protocol.
// Only works for TLS 1.3.
// If enabled, client and server have to agree on an application protocol.
// Otherwise, connection establishment fails.
EnforceNextProtoSelection bool
// If MaxEarlyData is greater than 0, the client will be allowed to send early
// data when resuming a session.
// Requires the AlternativeRecordLayer to be set.
//
// It has no meaning on the client.
MaxEarlyData uint32
// The Accept0RTT callback is called when the client offers 0-RTT.
// The server then has to decide if it wants to accept or reject 0-RTT.
// It is only used for servers.
Accept0RTT func(appData []byte) bool
// 0RTTRejected is called when the server rejectes 0-RTT.
// It is only used for clients.
Rejected0RTT func()
// If set, the client will export the 0-RTT key when resuming a session that
// allows sending of early data.
// Requires the AlternativeRecordLayer to be set.
//
// It has no meaning to the server.
Enable0RTT bool
// Is called when the client saves a session ticket to the session ticket.
// This gives the application the opportunity to save some data along with the ticket,
// which can be restored when the session ticket is used.
GetAppDataForSessionState func() []byte
// Is called when the client uses a session ticket.
// Restores the application data that was saved earlier on GetAppDataForSessionTicket.
SetAppDataFromSessionState func([]byte)
}
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// CipherSuiteName gets the name of a cipher suite.
func CipherSuiteName(id uint16) string {
return qtls.CipherSuiteName(id)
}
// HkdfExtract generates a pseudorandom key for use with Expand from an input secret and an optional independent salt.
func HkdfExtract(hash crypto.Hash, newSecret, currentSecret []byte) []byte {
return qtls.HkdfExtract(hash, newSecret, currentSecret)
}
// HkdfExpandLabel HKDF expands a label
func HkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
return qtls.HkdfExpandLabel(hash, secret, hashValue, label, L)
}
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *tls.Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, tlsConfigToQtlsConfig(config, extraConfig))
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *tls.Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, tlsConfigToQtlsConfig(config, extraConfig))
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionState()
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}

214
internal/qtls/qtls.go Normal file
View file

@ -0,0 +1,214 @@
package qtls
// This package uses unsafe to convert between:
// * Certificate and tls.Certificate
// * CertificateRequestInfo and tls.CertificateRequestInfo
// * ClientHelloInfo and tls.ClientHelloInfo
// * ConnectionState and tls.ConnectionState
// * ClientSessionState and tls.ClientSessionState
// We check in init() that this conversion actually is safe.
import (
"crypto/tls"
"net"
"unsafe"
)
func init() {
if !structsEqual(&tls.Certificate{}, &Certificate{}) {
panic("Certificate not compatible with tls.Certificate")
}
if !structsEqual(&tls.CertificateRequestInfo{}, &CertificateRequestInfo{}) {
panic("CertificateRequestInfo not compatible with tls.CertificateRequestInfo")
}
if !structsEqual(&tls.ClientSessionState{}, &ClientSessionState{}) {
panic("ClientSessionState not compatible with tls.ClientSessionState")
}
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
panic("clientHelloInfo not compatible with tls.ClientHelloInfo")
}
if !structsEqual(&ClientHelloInfo{}, &qtlsClientHelloInfo{}) {
panic("qtlsClientHelloInfo not compatible with ClientHelloInfo")
}
}
func tlsConfigToQtlsConfig(c *tls.Config, ec *ExtraConfig) *Config {
if c == nil {
c = &tls.Config{}
}
if ec == nil {
ec = &ExtraConfig{}
}
// Clone the config first. This executes the tls.Config.serverInit().
// This sets the SessionTicketKey, if the user didn't supply one.
c = c.Clone()
// QUIC requires TLS 1.3 or newer
minVersion := c.MinVersion
if minVersion < tls.VersionTLS13 {
minVersion = tls.VersionTLS13
}
maxVersion := c.MaxVersion
if maxVersion < tls.VersionTLS13 {
maxVersion = tls.VersionTLS13
}
var getConfigForClient func(ch *ClientHelloInfo) (*Config, error)
if c.GetConfigForClient != nil {
getConfigForClient = func(ch *ClientHelloInfo) (*Config, error) {
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
if err != nil {
return nil, err
}
if tlsConf == nil {
return nil, nil
}
return tlsConfigToQtlsConfig(tlsConf, ec), nil
}
}
var getCertificate func(ch *ClientHelloInfo) (*Certificate, error)
if c.GetCertificate != nil {
getCertificate = func(ch *ClientHelloInfo) (*Certificate, error) {
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
if err != nil {
return nil, err
}
if cert == nil {
return nil, nil
}
return (*Certificate)(cert), nil
}
}
var csc ClientSessionCache
if c.ClientSessionCache != nil {
csc = &clientSessionCache{c.ClientSessionCache}
}
conf := &Config{
Rand: c.Rand,
Time: c.Time,
Certificates: *(*[]Certificate)(unsafe.Pointer(&c.Certificates)),
//nolint:staticcheck // NameToCertificate is deprecated, but we still need to copy it if the user sets it.
NameToCertificate: *(*map[string]*Certificate)(unsafe.Pointer(&c.NameToCertificate)),
GetCertificate: getCertificate,
GetClientCertificate: *(*func(*CertificateRequestInfo) (*Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
GetConfigForClient: getConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
EnforceNextProtoSelection: true,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
//nolint:staticcheck // SessionTicketKey is deprecated, but we still need to copy it if the user sets it.
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: csc,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
// no need to copy Renegotiation, it's not supported by TLS 1.3
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: ec.AlternativeRecordLayer,
GetExtensions: ec.GetExtensions,
ReceivedExtensions: ec.ReceivedExtensions,
Accept0RTT: ec.Accept0RTT,
Rejected0RTT: ec.Rejected0RTT,
GetAppDataForSessionState: ec.GetAppDataForSessionState,
SetAppDataFromSessionState: ec.SetAppDataFromSessionState,
Enable0RTT: ec.Enable0RTT,
MaxEarlyData: ec.MaxEarlyData,
}
return conf
}
type clientSessionCache struct {
tls.ClientSessionCache
}
var _ ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
return (*ClientSessionState)(unsafe.Pointer(sess)), ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *ClientSessionState) {
if cs == nil {
c.ClientSessionCache.Put(sessionKey, nil)
return
}
// ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs)))
}
type clientHelloInfo struct {
CipherSuites []uint16
ServerName string
SupportedCurves []tls.CurveID
SupportedPoints []uint8
SignatureSchemes []tls.SignatureScheme
SupportedProtos []string
SupportedVersions []uint16
Conn net.Conn
config *tls.Config
}
type qtlsClientHelloInfo struct {
CipherSuites []uint16
ServerName string
SupportedCurves []tls.CurveID
SupportedPoints []uint8
SignatureSchemes []tls.SignatureScheme
SupportedProtos []string
SupportedVersions []uint16
Conn net.Conn
config *Config
}
func toTLSClientHelloInfo(chi *ClientHelloInfo) *tls.ClientHelloInfo {
if chi == nil {
return nil
}
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
var config *tls.Config
if qtlsCHI.config != nil {
config = qtlsConfigToTLSConfig(qtlsCHI.config)
}
return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{
CipherSuites: chi.CipherSuites,
ServerName: chi.ServerName,
SupportedCurves: chi.SupportedCurves,
SupportedPoints: chi.SupportedPoints,
SignatureSchemes: chi.SignatureSchemes,
SupportedProtos: chi.SupportedProtos,
SupportedVersions: chi.SupportedVersions,
Conn: chi.Conn,
config: config,
}))
}
// qtlsConfigToTLSConfig is used to transform a Config to a tls.Config.
// It is used to create the tls.Config in the ClientHelloInfo.
// It doesn't copy all values, but only those used by ClientHelloInfo.SupportsCertificate.
func qtlsConfigToTLSConfig(config *Config) *tls.Config {
return &tls.Config{
MinVersion: config.MinVersion,
MaxVersion: config.MaxVersion,
CipherSuites: config.CipherSuites,
CurvePreferences: config.CurvePreferences,
}
}

View file

@ -0,0 +1,25 @@
package qtls
import (
"testing"
gomock "github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func TestQTLS(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "qtls Suite")
}
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View file

@ -1,4 +1,4 @@
package handshake
package qtls
import (
"crypto/tls"
@ -6,99 +6,91 @@ import (
"net"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/utils"
mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockExtensionHandler struct {
get, received bool
}
var _ tlsExtensionHandler = &mockExtensionHandler{}
func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
h.get = true
return nil
}
func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
h.received = true
}
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
var _ = Describe("qtls.Config", func() {
var _ = Describe("Config", func() {
It("sets MinVersion and MaxVersion", func() {
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
})
It("works when called with a nil config", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf).ToNot(BeNil())
})
It("sets the setter and getter function for TLS extensions", func() {
extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(extHandler.get).To(BeFalse())
var get, received bool
extraConfig := &ExtraConfig{
GetExtensions: func(handshakeMessageType uint8) []Extension { get = true; return nil },
ReceivedExtensions: func(handshakeMessageType uint8, exts []qtls.Extension) { received = true },
}
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, extraConfig)
qtlsConf.GetExtensions(10)
Expect(extHandler.get).To(BeTrue())
Expect(extHandler.received).To(BeFalse())
Expect(get).To(BeTrue())
Expect(received).To(BeFalse())
qtlsConf.ReceivedExtensions(10, nil)
Expect(extHandler.received).To(BeTrue())
Expect(received).To(BeTrue())
})
It("sets the Accept0RTT callback", func() {
accept0RTT := func([]byte) bool { return true }
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, accept0RTT, nil, false)
qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Accept0RTT: func([]byte) bool { return true }})
Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
})
It("sets the Accept0RTT callback", func() {
It("sets the Rejected0RTT callback", func() {
var called bool
rejected0RTT := func() { called = true }
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, rejected0RTT, false)
qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Rejected0RTT: func() { called = true }})
Expect(qtlsConf.Rejected0RTT).ToNot(BeNil())
qtlsConf.Rejected0RTT()
Expect(called).To(BeTrue())
})
It("enables 0-RTT", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
Expect(qtlsConf.Enable0RTT).To(BeFalse())
It("sets MaxEarlyData", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf.MaxEarlyData).To(BeZero())
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, true)
qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{MaxEarlyData: 1337})
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(1337)))
})
It("enables 0-RTT", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf.Enable0RTT).To(BeFalse())
qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{Enable0RTT: true})
Expect(qtlsConf.Enable0RTT).To(BeTrue())
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
})
It("initializes such that the session ticket key remains constant", func() {
tlsConf := &tls.Config{}
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil)
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
})
Context("GetConfigForClient callback", func() {
It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
Expect(qtlsConf.GetConfigForClient).To(BeNil())
})
It("returns a qtls.Config", func() {
It("returns a Config", func() {
tlsConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return &tls.Config{ServerName: "foo.bar"}, nil
},
}
extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
var received bool
qtlsConf := tlsConfigToQtlsConfig(tlsConf, &ExtraConfig{ReceivedExtensions: func(uint8, []Extension) { received = true }})
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
confForClient, err := qtlsConf.GetConfigForClient(nil)
Expect(err).ToNot(HaveOccurred())
@ -106,9 +98,10 @@ var _ = Describe("qtls.Config", func() {
Expect(confForClient).ToNot(BeNil())
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(extHandler.get).To(BeFalse())
confForClient.GetExtensions(10)
Expect(extHandler.get).To(BeTrue())
Expect(received).To(BeFalse())
Expect(confForClient.ReceivedExtensions).ToNot(BeNil())
confForClient.ReceivedExtensions(10, nil)
Expect(received).To(BeTrue())
})
It("returns errors", func() {
@ -118,7 +111,7 @@ var _ = Describe("qtls.Config", func() {
return nil, testErr
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
_, err := qtlsConf.GetConfigForClient(nil)
Expect(err).To(MatchError(testErr))
})
@ -129,7 +122,7 @@ var _ = Describe("qtls.Config", func() {
return nil, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
})
})
@ -141,7 +134,7 @@ var _ = Describe("qtls.Config", func() {
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
qtlsCert, err := qtlsConf.GetCertificate(nil)
Expect(err).ToNot(HaveOccurred())
Expect(qtlsCert).ToNot(BeNil())
@ -149,7 +142,7 @@ var _ = Describe("qtls.Config", func() {
})
It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil)
Expect(qtlsConf.GetCertificate).To(BeNil())
})
@ -159,7 +152,7 @@ var _ = Describe("qtls.Config", func() {
return nil, errors.New("test")
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
_, err := qtlsConf.GetCertificate(nil)
Expect(err).To(MatchError("test"))
})
@ -170,21 +163,21 @@ var _ = Describe("qtls.Config", func() {
return nil, nil
},
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
})
})
Context("ClientSessionCache", func() {
It("doesn't set if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil)
Expect(qtlsConf.ClientSessionCache).To(BeNil())
})
It("puts a nil session state", func() {
csc := NewMockClientSessionCache(mockCtrl)
csc := mocktls.NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
// put something
csc.EXPECT().Put("foobar", nil)
qtlsConf.ClientSessionCache.Put("foobar", nil)
@ -192,25 +185,25 @@ var _ = Describe("qtls.Config", func() {
})
})
var _ = Describe("qtls.Config generation", func() {
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo", func() {
var _ = Describe("Config generation", func() {
It("converts a ClientHelloInfo to a tls.ClientHelloInfo", func() {
chi := &qtlsClientHelloInfo{
CipherSuites: []uint16{1, 2, 3},
ServerName: "foo.bar",
SupportedCurves: []qtls.CurveID{4, 5, 6},
SupportedCurves: []tls.CurveID{4, 5, 6},
SupportedPoints: []uint8{7, 8, 9},
SignatureSchemes: []qtls.SignatureScheme{10, 11, 12},
SignatureSchemes: []tls.SignatureScheme{10, 11, 12},
SupportedProtos: []string{"foo", "bar"},
SupportedVersions: []uint16{13, 14, 15},
Conn: &net.UDPConn{},
config: &qtls.Config{
config: &Config{
MinVersion: tls.VersionTLS10,
MaxVersion: tls.VersionTLS12,
CipherSuites: []uint16{16, 17, 18},
CurvePreferences: []qtls.CurveID{19, 20, 21},
CurvePreferences: []tls.CurveID{19, 20, 21},
},
}
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi)))
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{1, 2, 3}))
Expect(tlsCHI.ServerName).To(Equal("foo.bar"))
Expect(tlsCHI.SupportedCurves).To(Equal([]tls.CurveID{4, 5, 6}))
@ -226,9 +219,9 @@ var _ = Describe("qtls.Config generation", func() {
Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
})
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
It("converts a ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi)))
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37}))
})
})

View file

@ -0,0 +1,19 @@
package qtls
import "reflect"
func structsEqual(a, b interface{}) bool {
sa := reflect.ValueOf(a).Elem()
sb := reflect.ValueOf(b).Elem()
if sa.NumField() != sb.NumField() {
return false
}
for i := 0; i < sa.NumField(); i++ {
fa := sa.Type().Field(i)
fb := sb.Type().Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
return false
}
}
return true
}

View file

@ -1,4 +1,4 @@
package handshake
package qtls
import (
. "github.com/onsi/ginkgo"

View file

@ -11,7 +11,6 @@ import (
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
qtls "github.com/marten-seemann/qtls"
)
// MockQuicSession is a mock of QuicSession interface
@ -82,10 +81,10 @@ func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *g
}
// ConnectionState mocks base method
func (m *MockQuicSession) ConnectionState() qtls.ConnectionState {
func (m *MockQuicSession) ConnectionState() ConnectionState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(qtls.ConnectionState)
ret0, _ := ret[0].(ConnectionState)
return ret0
}

View file

@ -10,6 +10,7 @@ echo -e "package $1\n" > $TMPFILE
echo "type $INTERFACE_NAME = $4" >> $TMPFILE
mockgen -package $1 -self_package $PACKAGE -destination $2 $PACKAGE $INTERFACE_NAME
mv $2 $TMPFILE && sed 's/qtls.ConnectionState/ConnectionState/g' $TMPFILE > $2
goimports -w $2
rm $TMPFILE