use the new qtls interface for (re)storing app data with a session state

Application data is now retrieved and restored via two callbacks on the
qtls.Config. This allows us the get rid of the rather complex wrapping
of the qtls.ClientSessionCache. Furthermore, it makes sure that we only
restore the application data when qtls decides to actually use the
ticket.
This commit is contained in:
Marten Seemann 2020-06-26 22:50:21 +07:00
parent f926945ae5
commit 07d4fd0991
13 changed files with 226 additions and 380 deletions

2
go.mod
View file

@ -8,7 +8,7 @@ require (
github.com/golang/mock v1.4.0
github.com/golang/protobuf v1.4.0
github.com/marten-seemann/qpack v0.1.0
github.com/marten-seemann/qtls v0.9.2
github.com/marten-seemann/qtls v0.10.0
github.com/onsi/ginkgo v1.11.0
github.com/onsi/gomega v1.8.1
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5

4
go.sum
View file

@ -73,8 +73,8 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg=
github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI=
github.com/marten-seemann/qtls v0.9.2 h1:5/CTvBD0DlIOyoESU4J8CvooIZK//2sYK2I30Wou8Cs=
github.com/marten-seemann/qtls v0.9.2/go.mod h1:UvMd1oaYDACI99/oZUYLzMCkBXQVT0aGm99sJhbT8hs=
github.com/marten-seemann/qtls v0.10.0 h1:ECsuYUKalRL240rRD4Ri33ISb7kAQ3qGDlrrl55b2pc=
github.com/marten-seemann/qtls v0.10.0/go.mod h1:UvMd1oaYDACI99/oZUYLzMCkBXQVT0aGm99sJhbT8hs=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View file

@ -1,106 +0,0 @@
package handshake
import (
"bytes"
"crypto/tls"
"io"
"time"
"unsafe"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const clientSessionStateRevision = 2
type clientSessionCache struct {
tls.ClientSessionCache
rttStats *congestion.RTTStats
getAppData func() []byte
setAppData func([]byte)
}
func newClientSessionCache(
cache tls.ClientSessionCache,
rttStats *congestion.RTTStats,
get func() []byte,
set func([]byte),
) *clientSessionCache {
return &clientSessionCache{
ClientSessionCache: cache,
rttStats: rttStats,
getAppData: get,
setAppData: set,
}
}
var _ qtls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
session := (*clientSessionState)(unsafe.Pointer(sess))
r := bytes.NewReader(session.nonce)
rev, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
if rev != clientSessionStateRevision {
return nil, false
}
rtt, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
appDataLen, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
appData := make([]byte, appDataLen)
if _, err := io.ReadFull(r, appData); err != nil {
return nil, false
}
nonceLen, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
nonce := make([]byte, nonceLen)
if _, err := io.ReadFull(r, nonce); err != nil {
return nil, false
}
c.setAppData(appData)
session.nonce = nonce
c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
return (*qtls.ClientSessionState)(unsafe.Pointer(session)), ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
if cs == nil {
c.ClientSessionCache.Put(sessionKey, nil)
return
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
session := (*clientSessionState)(unsafe.Pointer(cs))
appData := c.getAppData()
buf := &bytes.Buffer{}
utils.WriteVarInt(buf, clientSessionStateRevision)
utils.WriteVarInt(buf, uint64(c.rttStats.SmoothedRTT().Microseconds()))
utils.WriteVarInt(buf, uint64(len(appData)))
buf.Write(appData)
utils.WriteVarInt(buf, uint64(len(session.nonce)))
buf.Write(session.nonce)
session.nonce = buf.Bytes()
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(session)))
}

View file

@ -1,127 +0,0 @@
package handshake
import (
"bytes"
"crypto/tls"
"time"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/congestion"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ClientSessionCache", func() {
encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState {
session := &clientSessionState{nonce: data}
return (*tls.ClientSessionState)(unsafe.Pointer(session))
}
It("puts and gets", func() {
get := make(chan []byte, 100)
set := make(chan []byte, 100)
csc := newClientSessionCache(
tls.NewLRUClientSessionCache(100),
congestion.NewRTTStats(),
func() []byte { return <-get },
func(b []byte) { set <- b },
)
get <- []byte("foobar")
csc.Put("localhost", &qtls.ClientSessionState{})
Expect(set).To(BeEmpty())
state, ok := csc.Get("localhost")
Expect(ok).To(BeTrue())
Expect(state).ToNot(BeNil())
Expect(set).To(Receive(Equal([]byte("foobar"))))
})
It("saves the RTT", func() {
rttStatsOrig := congestion.NewRTTStats()
rttStatsOrig.UpdateRTT(10*time.Second, 0, time.Now())
Expect(rttStatsOrig.SmoothedRTT()).To(Equal(10 * time.Second))
cache := tls.NewLRUClientSessionCache(100)
csc1 := newClientSessionCache(
cache,
rttStatsOrig,
func() []byte { return nil },
func([]byte) {},
)
csc1.Put("localhost", &qtls.ClientSessionState{})
rttStats := congestion.NewRTTStats()
csc2 := newClientSessionCache(
cache,
rttStats,
func() []byte { return nil },
func([]byte) {},
)
Expect(rttStats.SmoothedRTT()).ToNot(Equal(10 * time.Second))
_, ok := csc2.Get("localhost")
Expect(ok).To(BeTrue())
Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second))
})
It("refuses a session state that is too short for the revision", func() {
cache := tls.NewLRUClientSessionCache(1)
cache.Put("localhost", encodeIntoSessionTicket([]byte{}))
csc := newClientSessionCache(
cache,
congestion.NewRTTStats(),
func() []byte { return nil },
func([]byte) {},
)
_, ok := csc.Get("localhost")
Expect(ok).To(BeFalse())
})
It("refuses a session state with the wrong revision", func() {
cache := tls.NewLRUClientSessionCache(1)
b := &bytes.Buffer{}
utils.WriteVarInt(b, clientSessionStateRevision+1)
cache.Put("localhost", encodeIntoSessionTicket(b.Bytes()))
csc := newClientSessionCache(
cache,
congestion.NewRTTStats(),
func() []byte { return nil },
func([]byte) {},
)
_, ok := csc.Get("localhost")
Expect(ok).To(BeFalse())
})
It("refuses a session state when unmarshalling fails", func() {
rttStats := congestion.NewRTTStats()
rttStats.SetInitialRTT(10 * time.Second)
cache := tls.NewLRUClientSessionCache(1)
csc := newClientSessionCache(
cache,
rttStats,
func() []byte { return []byte("foobar") },
func(b []byte) {},
)
csc.Put("localhost", &qtls.ClientSessionState{})
state, ok := cache.Get("localhost")
Expect(ok).To(BeTrue())
session := (*clientSessionState)(unsafe.Pointer(state))
Expect(session.nonce).ToNot(BeEmpty())
_, ok = csc.Get("localhost")
Expect(ok).To(BeTrue())
nonce := session.nonce
for i := 0; i < len(nonce); i++ {
session.nonce = session.nonce[:i]
_, ok = csc.Get("localhost")
Expect(ok).To(BeFalse())
}
})
})

View file

@ -1,23 +0,0 @@
package handshake
import (
"crypto/x509"
"time"
)
// copied from crypto/tls
// Needs to be in a separate file so golangci-lint can skip it.
type clientSessionState struct {
sessionTicket []uint8 // Encrypted ticket used for session resumption with server
vers uint16 // SSL/TLS version negotiated for the session
cipherSuite uint16 // Ciphersuite negotiated for the session
masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret
serverCertificates []*x509.Certificate // Certificate chain presented by the server
verifiedChains [][]*x509.Certificate // Certificate chains we built for verification
receivedAt time.Time // When the session ticket was received from the server
// TLS 1.3 fields.
nonce []byte // Ticket nonce sent by the server, to derive PSK
useBy time.Time // Expiration of the ticket lifetime as set by the server
ageAdd uint32 // Random obfuscation factor for sending the ticket age
}

View file

@ -59,6 +59,8 @@ func (m messageType) String() string {
}
}
const clientSessionStateRevision = 3
type cryptoSetup struct {
tlsConf *qtls.Config
conn *qtls.Conn
@ -230,7 +232,7 @@ func newCryptoSetup(
writeRecord: make(chan struct{}, 1),
closeChan: make(chan struct{}),
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalPeerParamsForSessionState, cs.handlePeerParamsFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT)
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalDataForSessionState, cs.handleDataFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT)
cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan
}
@ -456,14 +458,16 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) {
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalPeerParamsForSessionState() []byte {
b := &bytes.Buffer{}
h.peerParams.MarshalForSessionTicket(b)
return b.Bytes()
func (h *cryptoSetup) marshalDataForSessionState() []byte {
buf := &bytes.Buffer{}
utils.WriteVarInt(buf, clientSessionStateRevision)
utils.WriteVarInt(buf, uint64(h.rttStats.SmoothedRTT().Microseconds()))
h.peerParams.MarshalForSessionTicket(buf)
return buf.Bytes()
}
func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) {
tp, err := h.handlePeerParamsFromSessionStateImpl(data)
func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
tp, err := h.handleDataFromSessionStateImpl(data)
if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
@ -471,9 +475,22 @@ func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) {
h.zeroRTTParameters = tp
}
func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rtt, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(data); err != nil {
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return nil, err
}
return &tp, nil

View file

@ -287,6 +287,13 @@ var _ = Describe("Crypto Setup TLS", func() {
}
}
newRTTStatsWithRTT := func(rtt time.Duration) *congestion.RTTStats {
rttStats := &congestion.RTTStats{}
rttStats.UpdateRTT(rtt, 0, time.Now())
ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
return rttStats
}
handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk) {
done := make(chan struct{})
@ -319,7 +326,12 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed())
}
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
handshakeWithTLSConf := func(
clientConf, serverConf *tls.Config,
clientRTTStats, serverRTTStats *congestion.RTTStats,
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
enable0RTT bool,
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
@ -327,17 +339,18 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
client, _ := NewCryptoSetupClient(
cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
clientTransportParameters,
cRunner,
clientConf,
enable0RTT,
&congestion.RTTStats{},
clientRTTStats,
nil,
utils.DefaultLogger.WithPrefix("client"),
)
@ -349,18 +362,21 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
var token [16]byte
if serverTransportParameters.StatelessResetToken == nil {
var token [16]byte
serverTransportParameters.StatelessResetToken = &token
}
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
serverTransportParameters,
sRunner,
serverConf,
enable0RTT,
&congestion.RTTStats{},
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
)
@ -377,18 +393,28 @@ var _ = Describe("Crypto Setup TLS", func() {
default:
Expect(cHandshakeComplete).To(BeTrue())
}
return client, cErr, server, sErr
return clientHelloWrittenChan, client, cErr, server, sErr
}
It("handshakes", func() {
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -396,7 +422,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = qtls.RequireAnyClientCert
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
})
@ -626,21 +657,37 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
clientRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
})
It("doesn't use session resumption if the server disabled it", func() {
@ -653,7 +700,12 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false)
_, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -662,7 +714,12 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false)
_, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
@ -680,68 +737,100 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, true)
const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, clientHelloChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
cRunner,
clientConf,
clientRTTStats := &congestion.RTTStats{}
serverRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, serverRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
&congestion.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
var token [16]byte
server = NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
true,
&congestion.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(clientHelloChan).To(Receive(Not(BeNil())))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeTrue())
Expect(client.ConnectionState().Used0RTT).To(BeTrue())
})
It("rejects 0-RTT, whent the transport parameters changed", func() {
csc := NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeFalse())
Expect(client.ConnectionState().Used0RTT).To(BeFalse())
})
})
})

View file

@ -96,7 +96,7 @@ func tlsConfigToQtlsConfig(
}
var csc qtls.ClientSessionCache
if c.ClientSessionCache != nil {
csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState)
csc = &clientSessionCache{c.ClientSessionCache}
}
conf := &qtls.Config{
Rand: c.Rand,
@ -126,12 +126,14 @@ func tlsConfigToQtlsConfig(
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
// no need to copy Renegotiation, it's not supported by TLS 1.3
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: recordLayer,
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
Accept0RTT: accept0RTT,
Rejected0RTT: rejected0RTT,
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: recordLayer,
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
Accept0RTT: accept0RTT,
Rejected0RTT: rejected0RTT,
GetAppDataForSessionState: getDataForSessionState,
SetAppDataFromSessionState: setDataFromSessionState,
}
if enable0RTT {
conf.Enable0RTT = true
@ -140,6 +142,36 @@ func tlsConfigToQtlsConfig(
return conf
}
type clientSessionCache struct {
tls.ClientSessionCache
}
var _ qtls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
return (*qtls.ClientSessionState)(unsafe.Pointer(sess)), ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
if cs == nil {
c.ClientSessionCache.Put(sessionKey, nil)
return
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs)))
}
type clientHelloInfo struct {
CipherSuites []uint16
ServerName string

View file

@ -6,7 +6,6 @@ import (
"net"
"unsafe"
gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo"
@ -181,36 +180,6 @@ var _ = Describe("qtls.Config", func() {
Expect(qtlsConf.ClientSessionCache).To(BeNil())
})
It("sets it, and puts and gets session states", func() {
csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc}
var appData []byte
qtlsConf := tlsConfigToQtlsConfig(
tlsConf,
nil,
&mockExtensionHandler{},
congestion.NewRTTStats(),
func() []byte { return []byte("foobar") },
func(p []byte) { appData = p },
nil,
nil,
false,
)
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
var state *tls.ClientSessionState
// put something
csc.EXPECT().Put("localhost", gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
})
qtlsConf.ClientSessionCache.Put("localhost", &qtls.ClientSessionState{})
// get something
csc.EXPECT().Get("localhost").Return(state, true)
_, ok := qtlsConf.ClientSessionCache.Get("localhost")
Expect(ok).To(BeTrue())
Expect(appData).To(Equal([]byte("foobar")))
})
It("puts a nil session state", func() {
csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc}

View file

@ -39,7 +39,7 @@ func (t *sessionTicket) Unmarshal(b []byte) error {
return errors.New("failed to read RTT")
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(b[len(b)-r.Len():]); err != nil {
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp

View file

@ -25,9 +25,6 @@ func init() {
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
}
if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) {
panic("clientSessionState not compatible with tls.ClientSessionState")
}
}
func structsEqual(a, b interface{}) bool {

View file

@ -422,7 +422,7 @@ var _ = Describe("Transport Parameters", func() {
b := &bytes.Buffer{}
params.MarshalForSessionTicket(b)
var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(b.Bytes())).To(Succeed())
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
@ -434,7 +434,7 @@ var _ = Describe("Transport Parameters", func() {
It("rejects the parameters if it can't parse them", func() {
var p TransportParameters
Expect(p.UnmarshalFromSessionTicket([]byte("foobar"))).ToNot(Succeed())
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed())
})
It("rejects the parameters if the version changed", func() {
@ -445,7 +445,7 @@ var _ = Describe("Transport Parameters", func() {
b := &bytes.Buffer{}
utils.WriteVarInt(b, transportParameterMarshalingVersion+1)
b.Write(data[utils.VarIntLen(transportParameterMarshalingVersion):])
Expect(p.UnmarshalFromSessionTicket(b.Bytes())).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
})
Context("rejects the parameters if they changed", func() {

View file

@ -85,13 +85,13 @@ type TransportParameters struct {
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(data, sentBy, false); err != nil {
if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
return qerr.NewError(qerr.TransportParameterError, err.Error())
}
return nil
}
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
@ -102,7 +102,6 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective
readInitialSourceConnectionID bool
)
r := bytes.NewReader(data)
for r.Len() > 0 {
paramIDInt, err := utils.ReadVarInt(r)
if err != nil {
@ -429,8 +428,7 @@ func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) {
}
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(data []byte) error {
r := bytes.NewReader(data)
func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
version, err := utils.ReadVarInt(r)
if err != nil {
return err
@ -438,7 +436,7 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(data []byte) error {
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.unmarshal(data[len(data)-r.Len():], protocol.PerspectiveServer, true)
return p.unmarshal(r, protocol.PerspectiveServer, true)
}
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.