mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
move all dependencies on qtls to a separate package
This commit is contained in:
parent
524da2213c
commit
977dbc828c
29 changed files with 572 additions and 478 deletions
|
@ -5,7 +5,7 @@ import (
|
|||
"encoding/binary"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
|
|
@ -12,10 +12,10 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/logging"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
// TLS unexpected_message alert
|
||||
|
@ -60,9 +60,32 @@ func (m messageType) String() string {
|
|||
|
||||
const clientSessionStateRevision = 3
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func newConn(local, remote net.Addr) net.Conn {
|
||||
return &conn{
|
||||
localAddr: local,
|
||||
remoteAddr: remote,
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *qtls.Config
|
||||
conn *qtls.Conn
|
||||
tlsConf *tls.Config
|
||||
extraConf *qtls.ExtraConfig
|
||||
conn *qtls.Conn
|
||||
|
||||
messageChan chan []byte
|
||||
|
||||
|
@ -152,7 +175,7 @@ func NewCryptoSetupClient(
|
|||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf)
|
||||
cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
return cs, clientHelloWritten
|
||||
}
|
||||
|
||||
|
@ -184,7 +207,7 @@ func NewCryptoSetupServer(
|
|||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf)
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
|
||||
return cs
|
||||
}
|
||||
|
||||
|
@ -208,6 +231,7 @@ func newCryptoSetup(
|
|||
}
|
||||
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective)
|
||||
cs := &cryptoSetup{
|
||||
tlsConf: tlsConf,
|
||||
initialStream: initialStream,
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
|
@ -231,8 +255,22 @@ func newCryptoSetup(
|
|||
writeRecord: make(chan struct{}, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalDataForSessionState, cs.handleDataFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT)
|
||||
cs.tlsConf = qtlsConf
|
||||
var maxEarlyData uint32
|
||||
if enable0RTT {
|
||||
maxEarlyData = 0xffffffff
|
||||
}
|
||||
cs.extraConf = &qtls.ExtraConfig{
|
||||
GetExtensions: extHandler.GetExtensions,
|
||||
ReceivedExtensions: extHandler.ReceivedExtensions,
|
||||
AlternativeRecordLayer: cs,
|
||||
EnforceNextProtoSelection: true,
|
||||
MaxEarlyData: maxEarlyData,
|
||||
Accept0RTT: cs.accept0RTT,
|
||||
Rejected0RTT: cs.rejected0RTT,
|
||||
Enable0RTT: enable0RTT,
|
||||
GetAppDataForSessionState: cs.marshalDataForSessionState,
|
||||
SetAppDataFromSessionState: cs.handleDataFromSessionState,
|
||||
}
|
||||
return cs, cs.clientHelloWrittenChan
|
||||
}
|
||||
|
||||
|
@ -499,7 +537,7 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
|
|||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
var appData []byte
|
||||
// Save transport parameters to the session ticket if we're allowing 0-RTT.
|
||||
if h.tlsConf.MaxEarlyData > 0 {
|
||||
if h.extraConf.MaxEarlyData > 0 {
|
||||
appData = (&sessionTicket{
|
||||
Parameters: h.ourParams,
|
||||
RTT: h.rttStats.SmoothedRTT(),
|
||||
|
@ -819,5 +857,5 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
|||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
return h.conn.ConnectionState()
|
||||
return qtls.GetConnectionState(h.conn)
|
||||
}
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/marten-seemann/qtls"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
|
@ -74,48 +72,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
})
|
||||
|
||||
It("creates a qtls.Config", func() {
|
||||
tlsConf := &tls.Config{
|
||||
ServerName: "quic.clemente.io",
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, errors.New("GetCertificate")
|
||||
},
|
||||
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return nil, errors.New("GetClientCertificate")
|
||||
},
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return &tls.Config{ServerName: ch.ServerName}, nil
|
||||
},
|
||||
}
|
||||
var token protocol.StatelessResetToken
|
||||
server := NewCryptoSetupServer(
|
||||
&bytes.Buffer{},
|
||||
&bytes.Buffer{},
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
nil,
|
||||
&wire.TransportParameters{StatelessResetToken: &token},
|
||||
NewMockHandshakeRunner(mockCtrl),
|
||||
tlsConf,
|
||||
false,
|
||||
&utils.RTTStats{},
|
||||
nil,
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
qtlsConf := server.(*cryptoSetup).tlsConf
|
||||
Expect(qtlsConf.ServerName).To(Equal(tlsConf.ServerName))
|
||||
_, getCertificateErr := qtlsConf.GetCertificate(nil)
|
||||
Expect(getCertificateErr).To(MatchError("GetCertificate"))
|
||||
_, getClientCertificateErr := qtlsConf.GetClientCertificate(nil)
|
||||
Expect(getClientCertificateErr).To(MatchError("GetClientCertificate"))
|
||||
cconf, err := qtlsConf.GetConfigForClient(&qtls.ClientHelloInfo{ServerName: "foo.bar"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cconf.ServerName).To(Equal("foo.bar"))
|
||||
Expect(cconf.AlternativeRecordLayer).ToNot(BeNil())
|
||||
Expect(cconf.GetExtensions).ToNot(BeNil())
|
||||
Expect(cconf.ReceivedExtensions).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("returns Handshake() when an error occurs in qtls", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
|
@ -420,7 +376,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("handshakes with client auth", func() {
|
||||
clientConf.Certificates = []tls.Certificate{generateCert()}
|
||||
serverConf.ClientAuth = qtls.RequireAnyClientCert
|
||||
serverConf.ClientAuth = tls.RequireAnyClientCert
|
||||
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
|
||||
clientConf, serverConf,
|
||||
&utils.RTTStats{}, &utils.RTTStats{},
|
||||
|
@ -647,7 +603,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("uses session resumption", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
csc := mocktls.NewMockClientSessionCache(mockCtrl)
|
||||
var state *tls.ClientSessionState
|
||||
receivedSessionTicket := make(chan struct{})
|
||||
csc.EXPECT().Get(gomock.Any())
|
||||
|
@ -690,7 +646,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("doesn't use session resumption if the server disabled it", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
csc := mocktls.NewMockClientSessionCache(mockCtrl)
|
||||
var state *tls.ClientSessionState
|
||||
receivedSessionTicket := make(chan struct{})
|
||||
csc.EXPECT().Get(gomock.Any())
|
||||
|
@ -727,7 +683,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("uses 0-RTT", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
csc := mocktls.NewMockClientSessionCache(mockCtrl)
|
||||
var state *tls.ClientSessionState
|
||||
receivedSessionTicket := make(chan struct{})
|
||||
csc.EXPECT().Get(gomock.Any())
|
||||
|
@ -782,7 +738,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("rejects 0-RTT, whent the transport parameters changed", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
csc := mocktls.NewMockClientSessionCache(mockCtrl)
|
||||
var state *tls.ClientSessionState
|
||||
receivedSessionTicket := make(chan struct{})
|
||||
csc.EXPECT().Get(gomock.Any())
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/marten-seemann/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -31,27 +30,6 @@ var _ = AfterEach(func() {
|
|||
mockCtrl.Finish()
|
||||
})
|
||||
|
||||
var cipherSuites = []*qtls.CipherSuiteTLS13{
|
||||
&qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_AES_128_GCM_SHA256,
|
||||
KeyLen: 16,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
&qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_AES_256_GCM_SHA384,
|
||||
KeyLen: 32,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
&qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
KeyLen: 32,
|
||||
AEAD: nil, // will be set by init
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
func splitHexString(s string) (slice []byte) {
|
||||
for _, ss := range strings.Split(s, " ") {
|
||||
if ss[0:2] == "0x" {
|
||||
|
@ -64,25 +42,8 @@ func splitHexString(s string) (slice []byte) {
|
|||
return
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
func init() {
|
||||
val := cipherSuiteTLS13ByID(qtls.TLS_CHACHA20_POLY1305_SHA256)
|
||||
chacha := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
for _, s := range cipherSuites {
|
||||
if s.ID == qtls.TLS_CHACHA20_POLY1305_SHA256 {
|
||||
if s.KeyLen != chacha.KeyLen || s.Hash != chacha.Hash {
|
||||
panic("invalid parameters for ChaCha20")
|
||||
}
|
||||
s.AEAD = chacha.AEAD
|
||||
}
|
||||
}
|
||||
var cipherSuites = []*qtls.CipherSuiteTLS13{
|
||||
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256),
|
||||
qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384),
|
||||
qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256),
|
||||
}
|
||||
|
|
|
@ -3,12 +3,13 @@ package handshake
|
|||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
|
@ -18,9 +19,9 @@ type headerProtector interface {
|
|||
|
||||
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
|
||||
switch suite.ID {
|
||||
case qtls.TLS_AES_128_GCM_SHA256, qtls.TLS_AES_256_GCM_SHA384:
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
return newAESHeaderProtector(suite, trafficSecret, isLongHeader)
|
||||
case qtls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader)
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"crypto/rand"
|
||||
mrand "math/rand"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
|
|
@ -2,15 +2,16 @@ package handshake
|
|||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
var quicVersion1Salt = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
|
||||
|
||||
var initialSuite = &qtls.CipherSuiteTLS13{
|
||||
ID: qtls.TLS_AES_128_GCM_SHA256,
|
||||
ID: tls.TLS_AES_128_GCM_SHA256,
|
||||
KeyLen: 16,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA256,
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -1,7 +1,3 @@
|
|||
package handshake
|
||||
|
||||
//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner"
|
||||
|
||||
// The following command produces a warning message on OSX, however, it still generates the correct mock file.
|
||||
// See https://github.com/golang/mock/issues/339 for details.
|
||||
//go:generate sh -c "mockgen -package handshake -destination mock_client_session_cache_test.go crypto/tls ClientSessionCache && goimports -w mock_client_session_cache_test.go"
|
||||
|
|
|
@ -1,233 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
|
||||
panic("clientHelloInfo not compatible with tls.ClientHelloInfo")
|
||||
}
|
||||
if !structsEqual(&qtls.ClientHelloInfo{}, &qtlsClientHelloInfo{}) {
|
||||
panic("qtlsClientHelloInfo not compatible with qtls.ClientHelloInfo")
|
||||
}
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func newConn(local, remote net.Addr) net.Conn {
|
||||
return &conn{
|
||||
localAddr: local,
|
||||
remoteAddr: remote,
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
|
||||
func tlsConfigToQtlsConfig(
|
||||
c *tls.Config,
|
||||
recordLayer qtls.RecordLayer,
|
||||
extHandler tlsExtensionHandler,
|
||||
rttStats *utils.RTTStats,
|
||||
getDataForSessionState func() []byte,
|
||||
setDataFromSessionState func([]byte),
|
||||
accept0RTT func([]byte) bool,
|
||||
rejected0RTT func(),
|
||||
enable0RTT bool,
|
||||
) *qtls.Config {
|
||||
if c == nil {
|
||||
c = &tls.Config{}
|
||||
}
|
||||
// Clone the config first. This executes the tls.Config.serverInit().
|
||||
// This sets the SessionTicketKey, if the user didn't supply one.
|
||||
c = c.Clone()
|
||||
// QUIC requires TLS 1.3 or newer
|
||||
minVersion := c.MinVersion
|
||||
if minVersion < qtls.VersionTLS13 {
|
||||
minVersion = qtls.VersionTLS13
|
||||
}
|
||||
maxVersion := c.MaxVersion
|
||||
if maxVersion < qtls.VersionTLS13 {
|
||||
maxVersion = qtls.VersionTLS13
|
||||
}
|
||||
var getConfigForClient func(ch *qtls.ClientHelloInfo) (*qtls.Config, error)
|
||||
if c.GetConfigForClient != nil {
|
||||
getConfigForClient = func(ch *qtls.ClientHelloInfo) (*qtls.Config, error) {
|
||||
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, rttStats, getDataForSessionState, setDataFromSessionState, accept0RTT, rejected0RTT, enable0RTT), nil
|
||||
}
|
||||
}
|
||||
var getCertificate func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error)
|
||||
if c.GetCertificate != nil {
|
||||
getCertificate = func(ch *qtls.ClientHelloInfo) (*qtls.Certificate, error) {
|
||||
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return (*qtls.Certificate)(cert), nil
|
||||
}
|
||||
}
|
||||
var csc qtls.ClientSessionCache
|
||||
if c.ClientSessionCache != nil {
|
||||
csc = &clientSessionCache{c.ClientSessionCache}
|
||||
}
|
||||
conf := &qtls.Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
Certificates: *(*[]qtls.Certificate)(unsafe.Pointer(&c.Certificates)),
|
||||
// NameToCertificate is deprecated, but we still need to copy it if the user sets it.
|
||||
//nolint:staticcheck
|
||||
NameToCertificate: *(*map[string]*qtls.Certificate)(unsafe.Pointer(&c.NameToCertificate)),
|
||||
GetCertificate: getCertificate,
|
||||
GetClientCertificate: *(*func(*qtls.CertificateRequestInfo) (*qtls.Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
|
||||
GetConfigForClient: getConfigForClient,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
EnforceNextProtoSelection: true,
|
||||
ServerName: c.ServerName,
|
||||
ClientAuth: c.ClientAuth,
|
||||
ClientCAs: c.ClientCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
CipherSuites: c.CipherSuites,
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
ClientSessionCache: csc,
|
||||
MinVersion: minVersion,
|
||||
MaxVersion: maxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
// no need to copy Renegotiation, it's not supported by TLS 1.3
|
||||
KeyLogWriter: c.KeyLogWriter,
|
||||
AlternativeRecordLayer: recordLayer,
|
||||
GetExtensions: extHandler.GetExtensions,
|
||||
ReceivedExtensions: extHandler.ReceivedExtensions,
|
||||
Accept0RTT: accept0RTT,
|
||||
Rejected0RTT: rejected0RTT,
|
||||
GetAppDataForSessionState: getDataForSessionState,
|
||||
SetAppDataFromSessionState: setDataFromSessionState,
|
||||
}
|
||||
if enable0RTT {
|
||||
conf.Enable0RTT = true
|
||||
conf.MaxEarlyData = 0xffffffff
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
type clientSessionCache struct {
|
||||
tls.ClientSessionCache
|
||||
}
|
||||
|
||||
var _ qtls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
|
||||
sess, ok := c.ClientSessionCache.Get(sessionKey)
|
||||
if sess == nil {
|
||||
return nil, ok
|
||||
}
|
||||
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
|
||||
// In order to allow users of quic-go to use a tls.Config,
|
||||
// we need this workaround to use the ClientSessionCache.
|
||||
// In unsafe.go we check that the two structs are actually identical.
|
||||
return (*qtls.ClientSessionState)(unsafe.Pointer(sess)), ok
|
||||
}
|
||||
|
||||
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
|
||||
if cs == nil {
|
||||
c.ClientSessionCache.Put(sessionKey, nil)
|
||||
return
|
||||
}
|
||||
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
|
||||
// In order to allow users of quic-go to use a tls.Config,
|
||||
// we need this workaround to use the ClientSessionCache.
|
||||
// In unsafe.go we check that the two structs are actually identical.
|
||||
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs)))
|
||||
}
|
||||
|
||||
type clientHelloInfo struct {
|
||||
CipherSuites []uint16
|
||||
ServerName string
|
||||
SupportedCurves []tls.CurveID
|
||||
SupportedPoints []uint8
|
||||
SignatureSchemes []tls.SignatureScheme
|
||||
SupportedProtos []string
|
||||
SupportedVersions []uint16
|
||||
Conn net.Conn
|
||||
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
type qtlsClientHelloInfo struct {
|
||||
CipherSuites []uint16
|
||||
ServerName string
|
||||
SupportedCurves []tls.CurveID
|
||||
SupportedPoints []uint8
|
||||
SignatureSchemes []tls.SignatureScheme
|
||||
SupportedProtos []string
|
||||
SupportedVersions []uint16
|
||||
Conn net.Conn
|
||||
|
||||
config *qtls.Config
|
||||
}
|
||||
|
||||
func toTLSClientHelloInfo(chi *qtls.ClientHelloInfo) *tls.ClientHelloInfo {
|
||||
if chi == nil {
|
||||
return nil
|
||||
}
|
||||
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
|
||||
var config *tls.Config
|
||||
if qtlsCHI.config != nil {
|
||||
config = qtlsConfigToTLSConfig(qtlsCHI.config)
|
||||
}
|
||||
return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{
|
||||
CipherSuites: chi.CipherSuites,
|
||||
ServerName: chi.ServerName,
|
||||
SupportedCurves: chi.SupportedCurves,
|
||||
SupportedPoints: chi.SupportedPoints,
|
||||
SignatureSchemes: chi.SignatureSchemes,
|
||||
SupportedProtos: chi.SupportedProtos,
|
||||
SupportedVersions: chi.SupportedVersions,
|
||||
Conn: chi.Conn,
|
||||
config: config,
|
||||
}))
|
||||
}
|
||||
|
||||
// qtlsConfigToTLSConfig is used to transform a qtls.Config to a tls.Config.
|
||||
// It is used to create the tls.Config in the ClientHelloInfo.
|
||||
// It doesn't copy all values, but only those used by ClientHelloInfo.SupportsCertificate.
|
||||
func qtlsConfigToTLSConfig(config *qtls.Config) *tls.Config {
|
||||
return &tls.Config{
|
||||
MinVersion: config.MinVersion,
|
||||
MaxVersion: config.MaxVersion,
|
||||
CipherSuites: config.CipherSuites,
|
||||
CurvePreferences: config.CurvePreferences,
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@ package handshake
|
|||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
const quicTLSExtensionType = 0xffa5
|
||||
|
|
|
@ -2,7 +2,7 @@ package handshake
|
|||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
package handshake
|
||||
|
||||
// This package uses unsafe to convert between:
|
||||
// * qtls.Certificate and tls.Certificate
|
||||
// * qtls.CertificateRequestInfo and tls.CertificateRequestInfo
|
||||
// * qtls.ClientHelloInfo and tls.ClientHelloInfo
|
||||
// * qtls.ConnectionState and tls.ConnectionState
|
||||
// * qtls.ClientSessionState and tls.ClientSessionState
|
||||
// We check in init() that this conversion actually is safe.
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"reflect"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !structsEqual(&tls.Certificate{}, &qtls.Certificate{}) {
|
||||
panic("qtls.Certificate not compatible with tls.Certificate")
|
||||
}
|
||||
if !structsEqual(&tls.CertificateRequestInfo{}, &qtls.CertificateRequestInfo{}) {
|
||||
panic("qtls.CertificateRequestInfo not compatible with tls.CertificateRequestInfo")
|
||||
}
|
||||
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
|
||||
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
|
||||
}
|
||||
}
|
||||
|
||||
func structsEqual(a, b interface{}) bool {
|
||||
sa := reflect.ValueOf(a).Elem()
|
||||
sb := reflect.ValueOf(b).Elem()
|
||||
if sa.NumField() != sb.NumField() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < sa.NumField(); i++ {
|
||||
fa := sa.Type().Field(i)
|
||||
fb := sb.Type().Field(i)
|
||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -9,12 +9,11 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/logging"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
// By setting this environment variable, the key update interval can be adjusted.
|
||||
|
|
|
@ -2,13 +2,14 @@ package handshake
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/marten-seemann/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -19,7 +20,7 @@ var _ = Describe("Updatable AEAD", func() {
|
|||
secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
|
||||
aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil)
|
||||
chacha := cipherSuites[2]
|
||||
Expect(chacha.ID).To(Equal(qtls.TLS_CHACHA20_POLY1305_SHA256))
|
||||
Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256))
|
||||
aead.SetWriteKey(chacha, secret)
|
||||
header := splitHexString("4200bff4")
|
||||
const pnOffset = 1
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
qtls "github.com/marten-seemann/qtls"
|
||||
qtls "github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
// MockCryptoSetup is a mock of CryptoSetup interface
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
package mocks
|
||||
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go"
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/early_session.go github.com/lucas-clemente/quic-go EarlySession && goimports -w quic/early_session.go"
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/early_session_tmp.go github.com/lucas-clemente/quic-go EarlySession && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_session_tmp.go > quic/early_session.go && rm quic/early_session_tmp.go && goimports -w quic/early_session.go"
|
||||
//go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener && goimports -w quic/early_listener.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination tracer.go github.com/lucas-clemente/quic-go/logging Tracer && goimports -w tracer.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer && goimports -w connection_tracer.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer && goimports -w short_header_sealer.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener && goimports -w short_header_opener.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener && goimports -w long_header_opener.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup && goimports -w crypto_setup.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination crypto_setup_tmp.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup && sed 's/github.com\\/marten-seemann\\/qtls/github.com\\/lucas-clemente\\/quic-go\\/internal\\/qtls/g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && goimports -w crypto_setup.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController && goimports -w stream_flow_controller.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithmWithDebugInfos && goimports -w congestion.go"
|
||||
//go:generate sh -c "mockgen -package mocks -destination connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController && goimports -w connection_flow_controller.go"
|
||||
//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler && goimports -w ackhandler/sent_packet_handler.go"
|
||||
//go:generate sh -c "mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler && goimports -w ackhandler/received_packet_handler.go"
|
||||
|
||||
// The following command produces a warning message on OSX, however, it still generates the correct mock file.
|
||||
// See https://github.com/golang/mock/issues/339 for details.
|
||||
//go:generate sh -c "mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache && goimports -w tls/client_session_cache.go"
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
gomock "github.com/golang/mock/gomock"
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
qtls "github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
// MockEarlySession is a mock of EarlySession interface
|
||||
|
@ -83,10 +82,10 @@ func (mr *MockEarlySessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *
|
|||
}
|
||||
|
||||
// ConnectionState mocks base method
|
||||
func (m *MockEarlySession) ConnectionState() qtls.ConnectionState {
|
||||
func (m *MockEarlySession) ConnectionState() quic.ConnectionState {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(qtls.ConnectionState)
|
||||
ret0, _ := ret[0].(quic.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: crypto/tls (interfaces: ClientSessionCache)
|
||||
|
||||
// Package handshake is a generated GoMock package.
|
||||
package handshake
|
||||
// Package mocktls is a generated GoMock package.
|
||||
package mocktls
|
||||
|
||||
import (
|
||||
tls "crypto/tls"
|
|
@ -3,7 +3,7 @@ package qerr
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
"github.com/lucas-clemente/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
// ErrorCode can be used as a normal error without reason.
|
||||
|
|
164
internal/qtls/interface.go
Normal file
164
internal/qtls/interface.go
Normal file
|
@ -0,0 +1,164 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type (
|
||||
// Alert is a TLS alert
|
||||
Alert = qtls.Alert
|
||||
// A Certificate is qtls.Certificate.
|
||||
Certificate = qtls.Certificate
|
||||
// CertificateRequestInfo contains inforamtion about a certificate request.
|
||||
CertificateRequestInfo = qtls.CertificateRequestInfo
|
||||
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
|
||||
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
|
||||
// ClientHelloInfo contains information about a ClientHello.
|
||||
ClientHelloInfo = qtls.ClientHelloInfo
|
||||
// ClientSessionCache is a cache used for session resumption.
|
||||
ClientSessionCache = qtls.ClientSessionCache
|
||||
// ClientSessionState is a state needed for session resumption.
|
||||
ClientSessionState = qtls.ClientSessionState
|
||||
// A Config is a qtls.Config.
|
||||
Config = qtls.Config
|
||||
// A Conn is a qtls.Conn.
|
||||
Conn = qtls.Conn
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
ConnectionState = qtls.ConnectionState
|
||||
// EncryptionLevel is the encryption level of a message.
|
||||
EncryptionLevel = qtls.EncryptionLevel
|
||||
// Extension is a TLS extension
|
||||
Extension = qtls.Extension
|
||||
// RecordLayer is a qtls RecordLayer.
|
||||
RecordLayer = qtls.RecordLayer
|
||||
)
|
||||
|
||||
type ExtraConfig struct {
|
||||
// GetExtensions, if not nil, is called before a message that allows
|
||||
// sending of extensions is sent.
|
||||
// Currently only implemented for the ClientHello message (for the client)
|
||||
// and for the EncryptedExtensions message (for the server).
|
||||
// Only valid for TLS 1.3.
|
||||
GetExtensions func(handshakeMessageType uint8) []Extension
|
||||
|
||||
// ReceivedExtensions, if not nil, is called when a message that allows the
|
||||
// inclusion of extensions is received.
|
||||
// It is called with an empty slice of extensions, if the message didn't
|
||||
// contain any extensions.
|
||||
// Currently only implemented for the ClientHello message (sent by the
|
||||
// client) and for the EncryptedExtensions message (sent by the server).
|
||||
// Only valid for TLS 1.3.
|
||||
ReceivedExtensions func(handshakeMessageType uint8, exts []Extension)
|
||||
|
||||
// AlternativeRecordLayer is used by QUIC
|
||||
AlternativeRecordLayer RecordLayer
|
||||
|
||||
// Enforce the selection of a supported application protocol.
|
||||
// Only works for TLS 1.3.
|
||||
// If enabled, client and server have to agree on an application protocol.
|
||||
// Otherwise, connection establishment fails.
|
||||
EnforceNextProtoSelection bool
|
||||
|
||||
// If MaxEarlyData is greater than 0, the client will be allowed to send early
|
||||
// data when resuming a session.
|
||||
// Requires the AlternativeRecordLayer to be set.
|
||||
//
|
||||
// It has no meaning on the client.
|
||||
MaxEarlyData uint32
|
||||
|
||||
// The Accept0RTT callback is called when the client offers 0-RTT.
|
||||
// The server then has to decide if it wants to accept or reject 0-RTT.
|
||||
// It is only used for servers.
|
||||
Accept0RTT func(appData []byte) bool
|
||||
|
||||
// 0RTTRejected is called when the server rejectes 0-RTT.
|
||||
// It is only used for clients.
|
||||
Rejected0RTT func()
|
||||
|
||||
// If set, the client will export the 0-RTT key when resuming a session that
|
||||
// allows sending of early data.
|
||||
// Requires the AlternativeRecordLayer to be set.
|
||||
//
|
||||
// It has no meaning to the server.
|
||||
Enable0RTT bool
|
||||
|
||||
// Is called when the client saves a session ticket to the session ticket.
|
||||
// This gives the application the opportunity to save some data along with the ticket,
|
||||
// which can be restored when the session ticket is used.
|
||||
GetAppDataForSessionState func() []byte
|
||||
|
||||
// Is called when the client uses a session ticket.
|
||||
// Restores the application data that was saved earlier on GetAppDataForSessionTicket.
|
||||
SetAppDataFromSessionState func([]byte)
|
||||
}
|
||||
|
||||
const (
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake = qtls.EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT = qtls.Encryption0RTT
|
||||
// EncryptionApplication is the application data encryption level
|
||||
EncryptionApplication = qtls.EncryptionApplication
|
||||
)
|
||||
|
||||
// CipherSuiteName gets the name of a cipher suite.
|
||||
func CipherSuiteName(id uint16) string {
|
||||
return qtls.CipherSuiteName(id)
|
||||
}
|
||||
|
||||
// HkdfExtract generates a pseudorandom key for use with Expand from an input secret and an optional independent salt.
|
||||
func HkdfExtract(hash crypto.Hash, newSecret, currentSecret []byte) []byte {
|
||||
return qtls.HkdfExtract(hash, newSecret, currentSecret)
|
||||
}
|
||||
|
||||
// HkdfExpandLabel HKDF expands a label
|
||||
func HkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
|
||||
return qtls.HkdfExpandLabel(hash, secret, hashValue, label, L)
|
||||
}
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return qtls.AEADAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection.
|
||||
func Client(conn net.Conn, config *tls.Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Client(conn, tlsConfigToQtlsConfig(config, extraConfig))
|
||||
}
|
||||
|
||||
// Server returns a new TLS server side connection.
|
||||
func Server(conn net.Conn, config *tls.Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Server(conn, tlsConfigToQtlsConfig(config, extraConfig))
|
||||
}
|
||||
|
||||
func GetConnectionState(conn *Conn) ConnectionState {
|
||||
return conn.ConnectionState()
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
|
||||
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
|
||||
val := cipherSuiteTLS13ByID(id)
|
||||
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
return &qtls.CipherSuiteTLS13{
|
||||
ID: cs.ID,
|
||||
KeyLen: cs.KeyLen,
|
||||
AEAD: cs.AEAD,
|
||||
Hash: cs.Hash,
|
||||
}
|
||||
}
|
214
internal/qtls/qtls.go
Normal file
214
internal/qtls/qtls.go
Normal file
|
@ -0,0 +1,214 @@
|
|||
package qtls
|
||||
|
||||
// This package uses unsafe to convert between:
|
||||
// * Certificate and tls.Certificate
|
||||
// * CertificateRequestInfo and tls.CertificateRequestInfo
|
||||
// * ClientHelloInfo and tls.ClientHelloInfo
|
||||
// * ConnectionState and tls.ConnectionState
|
||||
// * ClientSessionState and tls.ClientSessionState
|
||||
// We check in init() that this conversion actually is safe.
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !structsEqual(&tls.Certificate{}, &Certificate{}) {
|
||||
panic("Certificate not compatible with tls.Certificate")
|
||||
}
|
||||
if !structsEqual(&tls.CertificateRequestInfo{}, &CertificateRequestInfo{}) {
|
||||
panic("CertificateRequestInfo not compatible with tls.CertificateRequestInfo")
|
||||
}
|
||||
if !structsEqual(&tls.ClientSessionState{}, &ClientSessionState{}) {
|
||||
panic("ClientSessionState not compatible with tls.ClientSessionState")
|
||||
}
|
||||
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
|
||||
panic("clientHelloInfo not compatible with tls.ClientHelloInfo")
|
||||
}
|
||||
if !structsEqual(&ClientHelloInfo{}, &qtlsClientHelloInfo{}) {
|
||||
panic("qtlsClientHelloInfo not compatible with ClientHelloInfo")
|
||||
}
|
||||
}
|
||||
|
||||
func tlsConfigToQtlsConfig(c *tls.Config, ec *ExtraConfig) *Config {
|
||||
if c == nil {
|
||||
c = &tls.Config{}
|
||||
}
|
||||
if ec == nil {
|
||||
ec = &ExtraConfig{}
|
||||
}
|
||||
// Clone the config first. This executes the tls.Config.serverInit().
|
||||
// This sets the SessionTicketKey, if the user didn't supply one.
|
||||
c = c.Clone()
|
||||
// QUIC requires TLS 1.3 or newer
|
||||
minVersion := c.MinVersion
|
||||
if minVersion < tls.VersionTLS13 {
|
||||
minVersion = tls.VersionTLS13
|
||||
}
|
||||
maxVersion := c.MaxVersion
|
||||
if maxVersion < tls.VersionTLS13 {
|
||||
maxVersion = tls.VersionTLS13
|
||||
}
|
||||
var getConfigForClient func(ch *ClientHelloInfo) (*Config, error)
|
||||
if c.GetConfigForClient != nil {
|
||||
getConfigForClient = func(ch *ClientHelloInfo) (*Config, error) {
|
||||
tlsConf, err := c.GetConfigForClient(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return tlsConfigToQtlsConfig(tlsConf, ec), nil
|
||||
}
|
||||
}
|
||||
var getCertificate func(ch *ClientHelloInfo) (*Certificate, error)
|
||||
if c.GetCertificate != nil {
|
||||
getCertificate = func(ch *ClientHelloInfo) (*Certificate, error) {
|
||||
cert, err := c.GetCertificate(toTLSClientHelloInfo(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return (*Certificate)(cert), nil
|
||||
}
|
||||
}
|
||||
var csc ClientSessionCache
|
||||
if c.ClientSessionCache != nil {
|
||||
csc = &clientSessionCache{c.ClientSessionCache}
|
||||
}
|
||||
conf := &Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
Certificates: *(*[]Certificate)(unsafe.Pointer(&c.Certificates)),
|
||||
//nolint:staticcheck // NameToCertificate is deprecated, but we still need to copy it if the user sets it.
|
||||
NameToCertificate: *(*map[string]*Certificate)(unsafe.Pointer(&c.NameToCertificate)),
|
||||
GetCertificate: getCertificate,
|
||||
GetClientCertificate: *(*func(*CertificateRequestInfo) (*Certificate, error))(unsafe.Pointer(&c.GetClientCertificate)),
|
||||
GetConfigForClient: getConfigForClient,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
EnforceNextProtoSelection: true,
|
||||
ServerName: c.ServerName,
|
||||
ClientAuth: c.ClientAuth,
|
||||
ClientCAs: c.ClientCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
CipherSuites: c.CipherSuites,
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
//nolint:staticcheck // SessionTicketKey is deprecated, but we still need to copy it if the user sets it.
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
ClientSessionCache: csc,
|
||||
MinVersion: minVersion,
|
||||
MaxVersion: maxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
// no need to copy Renegotiation, it's not supported by TLS 1.3
|
||||
KeyLogWriter: c.KeyLogWriter,
|
||||
AlternativeRecordLayer: ec.AlternativeRecordLayer,
|
||||
GetExtensions: ec.GetExtensions,
|
||||
ReceivedExtensions: ec.ReceivedExtensions,
|
||||
Accept0RTT: ec.Accept0RTT,
|
||||
Rejected0RTT: ec.Rejected0RTT,
|
||||
GetAppDataForSessionState: ec.GetAppDataForSessionState,
|
||||
SetAppDataFromSessionState: ec.SetAppDataFromSessionState,
|
||||
Enable0RTT: ec.Enable0RTT,
|
||||
MaxEarlyData: ec.MaxEarlyData,
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
type clientSessionCache struct {
|
||||
tls.ClientSessionCache
|
||||
}
|
||||
|
||||
var _ ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c *clientSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
|
||||
sess, ok := c.ClientSessionCache.Get(sessionKey)
|
||||
if sess == nil {
|
||||
return nil, ok
|
||||
}
|
||||
// ClientSessionState is identical to the tls.ClientSessionState.
|
||||
// In order to allow users of quic-go to use a tls.Config,
|
||||
// we need this workaround to use the ClientSessionCache.
|
||||
// In unsafe.go we check that the two structs are actually identical.
|
||||
return (*ClientSessionState)(unsafe.Pointer(sess)), ok
|
||||
}
|
||||
|
||||
func (c *clientSessionCache) Put(sessionKey string, cs *ClientSessionState) {
|
||||
if cs == nil {
|
||||
c.ClientSessionCache.Put(sessionKey, nil)
|
||||
return
|
||||
}
|
||||
// ClientSessionState is identical to the tls.ClientSessionState.
|
||||
// In order to allow users of quic-go to use a tls.Config,
|
||||
// we need this workaround to use the ClientSessionCache.
|
||||
// In unsafe.go we check that the two structs are actually identical.
|
||||
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs)))
|
||||
}
|
||||
|
||||
type clientHelloInfo struct {
|
||||
CipherSuites []uint16
|
||||
ServerName string
|
||||
SupportedCurves []tls.CurveID
|
||||
SupportedPoints []uint8
|
||||
SignatureSchemes []tls.SignatureScheme
|
||||
SupportedProtos []string
|
||||
SupportedVersions []uint16
|
||||
Conn net.Conn
|
||||
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
type qtlsClientHelloInfo struct {
|
||||
CipherSuites []uint16
|
||||
ServerName string
|
||||
SupportedCurves []tls.CurveID
|
||||
SupportedPoints []uint8
|
||||
SignatureSchemes []tls.SignatureScheme
|
||||
SupportedProtos []string
|
||||
SupportedVersions []uint16
|
||||
Conn net.Conn
|
||||
|
||||
config *Config
|
||||
}
|
||||
|
||||
func toTLSClientHelloInfo(chi *ClientHelloInfo) *tls.ClientHelloInfo {
|
||||
if chi == nil {
|
||||
return nil
|
||||
}
|
||||
qtlsCHI := (*qtlsClientHelloInfo)(unsafe.Pointer(chi))
|
||||
var config *tls.Config
|
||||
if qtlsCHI.config != nil {
|
||||
config = qtlsConfigToTLSConfig(qtlsCHI.config)
|
||||
}
|
||||
return (*tls.ClientHelloInfo)(unsafe.Pointer(&clientHelloInfo{
|
||||
CipherSuites: chi.CipherSuites,
|
||||
ServerName: chi.ServerName,
|
||||
SupportedCurves: chi.SupportedCurves,
|
||||
SupportedPoints: chi.SupportedPoints,
|
||||
SignatureSchemes: chi.SignatureSchemes,
|
||||
SupportedProtos: chi.SupportedProtos,
|
||||
SupportedVersions: chi.SupportedVersions,
|
||||
Conn: chi.Conn,
|
||||
config: config,
|
||||
}))
|
||||
}
|
||||
|
||||
// qtlsConfigToTLSConfig is used to transform a Config to a tls.Config.
|
||||
// It is used to create the tls.Config in the ClientHelloInfo.
|
||||
// It doesn't copy all values, but only those used by ClientHelloInfo.SupportsCertificate.
|
||||
func qtlsConfigToTLSConfig(config *Config) *tls.Config {
|
||||
return &tls.Config{
|
||||
MinVersion: config.MinVersion,
|
||||
MaxVersion: config.MaxVersion,
|
||||
CipherSuites: config.CipherSuites,
|
||||
CurvePreferences: config.CurvePreferences,
|
||||
}
|
||||
}
|
25
internal/qtls/qtls_suite_test.go
Normal file
25
internal/qtls/qtls_suite_test.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestQTLS(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "qtls Suite")
|
||||
}
|
||||
|
||||
var mockCtrl *gomock.Controller
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
mockCtrl = gomock.NewController(GinkgoT())
|
||||
})
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
mockCtrl.Finish()
|
||||
})
|
|
@ -1,4 +1,4 @@
|
|||
package handshake
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
@ -6,99 +6,91 @@ import (
|
|||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
mocktls "github.com/lucas-clemente/quic-go/internal/mocks/tls"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockExtensionHandler struct {
|
||||
get, received bool
|
||||
}
|
||||
|
||||
var _ tlsExtensionHandler = &mockExtensionHandler{}
|
||||
|
||||
func (h *mockExtensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
|
||||
h.get = true
|
||||
return nil
|
||||
}
|
||||
func (h *mockExtensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
|
||||
h.received = true
|
||||
}
|
||||
func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not implemented") }
|
||||
|
||||
var _ = Describe("qtls.Config", func() {
|
||||
var _ = Describe("Config", func() {
|
||||
It("sets MinVersion and MaxVersion", func() {
|
||||
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
})
|
||||
|
||||
It("works when called with a nil config", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
|
||||
Expect(qtlsConf).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets the setter and getter function for TLS extensions", func() {
|
||||
extHandler := &mockExtensionHandler{}
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
Expect(extHandler.get).To(BeFalse())
|
||||
var get, received bool
|
||||
extraConfig := &ExtraConfig{
|
||||
GetExtensions: func(handshakeMessageType uint8) []Extension { get = true; return nil },
|
||||
ReceivedExtensions: func(handshakeMessageType uint8, exts []qtls.Extension) { received = true },
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, extraConfig)
|
||||
qtlsConf.GetExtensions(10)
|
||||
Expect(extHandler.get).To(BeTrue())
|
||||
Expect(extHandler.received).To(BeFalse())
|
||||
Expect(get).To(BeTrue())
|
||||
Expect(received).To(BeFalse())
|
||||
qtlsConf.ReceivedExtensions(10, nil)
|
||||
Expect(extHandler.received).To(BeTrue())
|
||||
Expect(received).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the Accept0RTT callback", func() {
|
||||
accept0RTT := func([]byte) bool { return true }
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, accept0RTT, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Accept0RTT: func([]byte) bool { return true }})
|
||||
Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
|
||||
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the Accept0RTT callback", func() {
|
||||
It("sets the Rejected0RTT callback", func() {
|
||||
var called bool
|
||||
rejected0RTT := func() { called = true }
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, rejected0RTT, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, &ExtraConfig{Rejected0RTT: func() { called = true }})
|
||||
Expect(qtlsConf.Rejected0RTT).ToNot(BeNil())
|
||||
qtlsConf.Rejected0RTT()
|
||||
Expect(called).To(BeTrue())
|
||||
})
|
||||
|
||||
It("enables 0-RTT", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
Expect(qtlsConf.Enable0RTT).To(BeFalse())
|
||||
It("sets MaxEarlyData", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
|
||||
Expect(qtlsConf.MaxEarlyData).To(BeZero())
|
||||
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, true)
|
||||
qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{MaxEarlyData: 1337})
|
||||
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(1337)))
|
||||
})
|
||||
|
||||
It("enables 0-RTT", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
|
||||
Expect(qtlsConf.Enable0RTT).To(BeFalse())
|
||||
qtlsConf = tlsConfigToQtlsConfig(nil, &ExtraConfig{Enable0RTT: true})
|
||||
Expect(qtlsConf.Enable0RTT).To(BeTrue())
|
||||
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
|
||||
})
|
||||
|
||||
It("initializes such that the session ticket key remains constant", func() {
|
||||
tlsConf := &tls.Config{}
|
||||
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
|
||||
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
|
||||
})
|
||||
|
||||
Context("GetConfigForClient callback", func() {
|
||||
It("doesn't set it if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil)
|
||||
Expect(qtlsConf.GetConfigForClient).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns a qtls.Config", func() {
|
||||
It("returns a Config", func() {
|
||||
tlsConf := &tls.Config{
|
||||
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return &tls.Config{ServerName: "foo.bar"}, nil
|
||||
},
|
||||
}
|
||||
extHandler := &mockExtensionHandler{}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
var received bool
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, &ExtraConfig{ReceivedExtensions: func(uint8, []Extension) { received = true }})
|
||||
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
|
||||
confForClient, err := qtlsConf.GetConfigForClient(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -106,9 +98,10 @@ var _ = Describe("qtls.Config", func() {
|
|||
Expect(confForClient).ToNot(BeNil())
|
||||
Expect(confForClient.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
Expect(confForClient.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
Expect(extHandler.get).To(BeFalse())
|
||||
confForClient.GetExtensions(10)
|
||||
Expect(extHandler.get).To(BeTrue())
|
||||
Expect(received).To(BeFalse())
|
||||
Expect(confForClient.ReceivedExtensions).ToNot(BeNil())
|
||||
confForClient.ReceivedExtensions(10, nil)
|
||||
Expect(received).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns errors", func() {
|
||||
|
@ -118,7 +111,7 @@ var _ = Describe("qtls.Config", func() {
|
|||
return nil, testErr
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
_, err := qtlsConf.GetConfigForClient(nil)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
@ -129,7 +122,7 @@ var _ = Describe("qtls.Config", func() {
|
|||
return nil, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
@ -141,7 +134,7 @@ var _ = Describe("qtls.Config", func() {
|
|||
return &tls.Certificate{Certificate: [][]byte{[]byte("foo"), []byte("bar")}}, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
qtlsCert, err := qtlsConf.GetCertificate(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(qtlsCert).ToNot(BeNil())
|
||||
|
@ -149,7 +142,7 @@ var _ = Describe("qtls.Config", func() {
|
|||
})
|
||||
|
||||
It("doesn't set it if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil)
|
||||
Expect(qtlsConf.GetCertificate).To(BeNil())
|
||||
})
|
||||
|
||||
|
@ -159,7 +152,7 @@ var _ = Describe("qtls.Config", func() {
|
|||
return nil, errors.New("test")
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
_, err := qtlsConf.GetCertificate(nil)
|
||||
Expect(err).To(MatchError("test"))
|
||||
})
|
||||
|
@ -170,21 +163,21 @@ var _ = Describe("qtls.Config", func() {
|
|||
return nil, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
Expect(qtlsConf.GetCertificate(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ClientSessionCache", func() {
|
||||
It("doesn't set if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil)
|
||||
Expect(qtlsConf.ClientSessionCache).To(BeNil())
|
||||
})
|
||||
|
||||
It("puts a nil session state", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
csc := mocktls.NewMockClientSessionCache(mockCtrl)
|
||||
tlsConf := &tls.Config{ClientSessionCache: csc}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, utils.NewRTTStats(), nil, nil, nil, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil)
|
||||
// put something
|
||||
csc.EXPECT().Put("foobar", nil)
|
||||
qtlsConf.ClientSessionCache.Put("foobar", nil)
|
||||
|
@ -192,25 +185,25 @@ var _ = Describe("qtls.Config", func() {
|
|||
})
|
||||
})
|
||||
|
||||
var _ = Describe("qtls.Config generation", func() {
|
||||
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo", func() {
|
||||
var _ = Describe("Config generation", func() {
|
||||
It("converts a ClientHelloInfo to a tls.ClientHelloInfo", func() {
|
||||
chi := &qtlsClientHelloInfo{
|
||||
CipherSuites: []uint16{1, 2, 3},
|
||||
ServerName: "foo.bar",
|
||||
SupportedCurves: []qtls.CurveID{4, 5, 6},
|
||||
SupportedCurves: []tls.CurveID{4, 5, 6},
|
||||
SupportedPoints: []uint8{7, 8, 9},
|
||||
SignatureSchemes: []qtls.SignatureScheme{10, 11, 12},
|
||||
SignatureSchemes: []tls.SignatureScheme{10, 11, 12},
|
||||
SupportedProtos: []string{"foo", "bar"},
|
||||
SupportedVersions: []uint16{13, 14, 15},
|
||||
Conn: &net.UDPConn{},
|
||||
config: &qtls.Config{
|
||||
config: &Config{
|
||||
MinVersion: tls.VersionTLS10,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{16, 17, 18},
|
||||
CurvePreferences: []qtls.CurveID{19, 20, 21},
|
||||
CurvePreferences: []tls.CurveID{19, 20, 21},
|
||||
},
|
||||
}
|
||||
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
|
||||
tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi)))
|
||||
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{1, 2, 3}))
|
||||
Expect(tlsCHI.ServerName).To(Equal("foo.bar"))
|
||||
Expect(tlsCHI.SupportedCurves).To(Equal([]tls.CurveID{4, 5, 6}))
|
||||
|
@ -226,9 +219,9 @@ var _ = Describe("qtls.Config generation", func() {
|
|||
Expect(c.config.CurvePreferences).To(Equal([]tls.CurveID{19, 20, 21}))
|
||||
})
|
||||
|
||||
It("converts a qtls.ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
|
||||
It("converts a ClientHelloInfo to a tls.ClientHelloInfo, if no config is set", func() {
|
||||
chi := &qtlsClientHelloInfo{CipherSuites: []uint16{13, 37}}
|
||||
tlsCHI := toTLSClientHelloInfo((*qtls.ClientHelloInfo)(unsafe.Pointer(chi)))
|
||||
tlsCHI := toTLSClientHelloInfo((*ClientHelloInfo)(unsafe.Pointer(chi)))
|
||||
Expect(tlsCHI.CipherSuites).To(Equal([]uint16{13, 37}))
|
||||
})
|
||||
})
|
19
internal/qtls/structs_equal.go
Normal file
19
internal/qtls/structs_equal.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package qtls
|
||||
|
||||
import "reflect"
|
||||
|
||||
func structsEqual(a, b interface{}) bool {
|
||||
sa := reflect.ValueOf(a).Elem()
|
||||
sb := reflect.ValueOf(b).Elem()
|
||||
if sa.NumField() != sb.NumField() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < sa.NumField(); i++ {
|
||||
fa := sa.Type().Field(i)
|
||||
fb := sb.Type().Field(i)
|
||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package handshake
|
||||
package qtls
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
qtls "github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
// MockQuicSession is a mock of QuicSession interface
|
||||
|
@ -82,10 +81,10 @@ func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *g
|
|||
}
|
||||
|
||||
// ConnectionState mocks base method
|
||||
func (m *MockQuicSession) ConnectionState() qtls.ConnectionState {
|
||||
func (m *MockQuicSession) ConnectionState() ConnectionState {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(qtls.ConnectionState)
|
||||
ret0, _ := ret[0].(ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ echo -e "package $1\n" > $TMPFILE
|
|||
echo "type $INTERFACE_NAME = $4" >> $TMPFILE
|
||||
|
||||
mockgen -package $1 -self_package $PACKAGE -destination $2 $PACKAGE $INTERFACE_NAME
|
||||
mv $2 $TMPFILE && sed 's/qtls.ConnectionState/ConnectionState/g' $TMPFILE > $2
|
||||
goimports -w $2
|
||||
|
||||
rm $TMPFILE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue