use the new qtls interface for (re)storing app data with a session state

Application data is now retrieved and restored via two callbacks on the
qtls.Config. This allows us the get rid of the rather complex wrapping
of the qtls.ClientSessionCache. Furthermore, it makes sure that we only
restore the application data when qtls decides to actually use the
ticket.
This commit is contained in:
Marten Seemann 2020-06-26 22:50:21 +07:00
parent f926945ae5
commit 07d4fd0991
13 changed files with 226 additions and 380 deletions

2
go.mod
View file

@ -8,7 +8,7 @@ require (
github.com/golang/mock v1.4.0 github.com/golang/mock v1.4.0
github.com/golang/protobuf v1.4.0 github.com/golang/protobuf v1.4.0
github.com/marten-seemann/qpack v0.1.0 github.com/marten-seemann/qpack v0.1.0
github.com/marten-seemann/qtls v0.9.2 github.com/marten-seemann/qtls v0.10.0
github.com/onsi/ginkgo v1.11.0 github.com/onsi/ginkgo v1.11.0
github.com/onsi/gomega v1.8.1 github.com/onsi/gomega v1.8.1
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5 golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5

4
go.sum
View file

@ -73,8 +73,8 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg= github.com/marten-seemann/qpack v0.1.0 h1:/0M7lkda/6mus9B8u34Asqm8ZhHAAt9Ho0vniNuVSVg=
github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI= github.com/marten-seemann/qpack v0.1.0/go.mod h1:LFt1NU/Ptjip0C2CPkhimBz5CGE3WGDAUWqna+CNTrI=
github.com/marten-seemann/qtls v0.9.2 h1:5/CTvBD0DlIOyoESU4J8CvooIZK//2sYK2I30Wou8Cs= github.com/marten-seemann/qtls v0.10.0 h1:ECsuYUKalRL240rRD4Ri33ISb7kAQ3qGDlrrl55b2pc=
github.com/marten-seemann/qtls v0.9.2/go.mod h1:UvMd1oaYDACI99/oZUYLzMCkBXQVT0aGm99sJhbT8hs= github.com/marten-seemann/qtls v0.10.0/go.mod h1:UvMd1oaYDACI99/oZUYLzMCkBXQVT0aGm99sJhbT8hs=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View file

@ -1,106 +0,0 @@
package handshake
import (
"bytes"
"crypto/tls"
"io"
"time"
"unsafe"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const clientSessionStateRevision = 2
type clientSessionCache struct {
tls.ClientSessionCache
rttStats *congestion.RTTStats
getAppData func() []byte
setAppData func([]byte)
}
func newClientSessionCache(
cache tls.ClientSessionCache,
rttStats *congestion.RTTStats,
get func() []byte,
set func([]byte),
) *clientSessionCache {
return &clientSessionCache{
ClientSessionCache: cache,
rttStats: rttStats,
getAppData: get,
setAppData: set,
}
}
var _ qtls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
session := (*clientSessionState)(unsafe.Pointer(sess))
r := bytes.NewReader(session.nonce)
rev, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
if rev != clientSessionStateRevision {
return nil, false
}
rtt, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
appDataLen, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
appData := make([]byte, appDataLen)
if _, err := io.ReadFull(r, appData); err != nil {
return nil, false
}
nonceLen, err := utils.ReadVarInt(r)
if err != nil {
return nil, false
}
nonce := make([]byte, nonceLen)
if _, err := io.ReadFull(r, nonce); err != nil {
return nil, false
}
c.setAppData(appData)
session.nonce = nonce
c.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
return (*qtls.ClientSessionState)(unsafe.Pointer(session)), ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
if cs == nil {
c.ClientSessionCache.Put(sessionKey, nil)
return
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
session := (*clientSessionState)(unsafe.Pointer(cs))
appData := c.getAppData()
buf := &bytes.Buffer{}
utils.WriteVarInt(buf, clientSessionStateRevision)
utils.WriteVarInt(buf, uint64(c.rttStats.SmoothedRTT().Microseconds()))
utils.WriteVarInt(buf, uint64(len(appData)))
buf.Write(appData)
utils.WriteVarInt(buf, uint64(len(session.nonce)))
buf.Write(session.nonce)
session.nonce = buf.Bytes()
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(session)))
}

View file

@ -1,127 +0,0 @@
package handshake
import (
"bytes"
"crypto/tls"
"time"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
"github.com/lucas-clemente/quic-go/internal/congestion"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ClientSessionCache", func() {
encodeIntoSessionTicket := func(data []byte) *tls.ClientSessionState {
session := &clientSessionState{nonce: data}
return (*tls.ClientSessionState)(unsafe.Pointer(session))
}
It("puts and gets", func() {
get := make(chan []byte, 100)
set := make(chan []byte, 100)
csc := newClientSessionCache(
tls.NewLRUClientSessionCache(100),
congestion.NewRTTStats(),
func() []byte { return <-get },
func(b []byte) { set <- b },
)
get <- []byte("foobar")
csc.Put("localhost", &qtls.ClientSessionState{})
Expect(set).To(BeEmpty())
state, ok := csc.Get("localhost")
Expect(ok).To(BeTrue())
Expect(state).ToNot(BeNil())
Expect(set).To(Receive(Equal([]byte("foobar"))))
})
It("saves the RTT", func() {
rttStatsOrig := congestion.NewRTTStats()
rttStatsOrig.UpdateRTT(10*time.Second, 0, time.Now())
Expect(rttStatsOrig.SmoothedRTT()).To(Equal(10 * time.Second))
cache := tls.NewLRUClientSessionCache(100)
csc1 := newClientSessionCache(
cache,
rttStatsOrig,
func() []byte { return nil },
func([]byte) {},
)
csc1.Put("localhost", &qtls.ClientSessionState{})
rttStats := congestion.NewRTTStats()
csc2 := newClientSessionCache(
cache,
rttStats,
func() []byte { return nil },
func([]byte) {},
)
Expect(rttStats.SmoothedRTT()).ToNot(Equal(10 * time.Second))
_, ok := csc2.Get("localhost")
Expect(ok).To(BeTrue())
Expect(rttStats.SmoothedRTT()).To(Equal(10 * time.Second))
})
It("refuses a session state that is too short for the revision", func() {
cache := tls.NewLRUClientSessionCache(1)
cache.Put("localhost", encodeIntoSessionTicket([]byte{}))
csc := newClientSessionCache(
cache,
congestion.NewRTTStats(),
func() []byte { return nil },
func([]byte) {},
)
_, ok := csc.Get("localhost")
Expect(ok).To(BeFalse())
})
It("refuses a session state with the wrong revision", func() {
cache := tls.NewLRUClientSessionCache(1)
b := &bytes.Buffer{}
utils.WriteVarInt(b, clientSessionStateRevision+1)
cache.Put("localhost", encodeIntoSessionTicket(b.Bytes()))
csc := newClientSessionCache(
cache,
congestion.NewRTTStats(),
func() []byte { return nil },
func([]byte) {},
)
_, ok := csc.Get("localhost")
Expect(ok).To(BeFalse())
})
It("refuses a session state when unmarshalling fails", func() {
rttStats := congestion.NewRTTStats()
rttStats.SetInitialRTT(10 * time.Second)
cache := tls.NewLRUClientSessionCache(1)
csc := newClientSessionCache(
cache,
rttStats,
func() []byte { return []byte("foobar") },
func(b []byte) {},
)
csc.Put("localhost", &qtls.ClientSessionState{})
state, ok := cache.Get("localhost")
Expect(ok).To(BeTrue())
session := (*clientSessionState)(unsafe.Pointer(state))
Expect(session.nonce).ToNot(BeEmpty())
_, ok = csc.Get("localhost")
Expect(ok).To(BeTrue())
nonce := session.nonce
for i := 0; i < len(nonce); i++ {
session.nonce = session.nonce[:i]
_, ok = csc.Get("localhost")
Expect(ok).To(BeFalse())
}
})
})

View file

@ -1,23 +0,0 @@
package handshake
import (
"crypto/x509"
"time"
)
// copied from crypto/tls
// Needs to be in a separate file so golangci-lint can skip it.
type clientSessionState struct {
sessionTicket []uint8 // Encrypted ticket used for session resumption with server
vers uint16 // SSL/TLS version negotiated for the session
cipherSuite uint16 // Ciphersuite negotiated for the session
masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret
serverCertificates []*x509.Certificate // Certificate chain presented by the server
verifiedChains [][]*x509.Certificate // Certificate chains we built for verification
receivedAt time.Time // When the session ticket was received from the server
// TLS 1.3 fields.
nonce []byte // Ticket nonce sent by the server, to derive PSK
useBy time.Time // Expiration of the ticket lifetime as set by the server
ageAdd uint32 // Random obfuscation factor for sending the ticket age
}

View file

@ -59,6 +59,8 @@ func (m messageType) String() string {
} }
} }
const clientSessionStateRevision = 3
type cryptoSetup struct { type cryptoSetup struct {
tlsConf *qtls.Config tlsConf *qtls.Config
conn *qtls.Conn conn *qtls.Conn
@ -230,7 +232,7 @@ func newCryptoSetup(
writeRecord: make(chan struct{}, 1), writeRecord: make(chan struct{}, 1),
closeChan: make(chan struct{}), closeChan: make(chan struct{}),
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalPeerParamsForSessionState, cs.handlePeerParamsFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT) qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, rttStats, cs.marshalDataForSessionState, cs.handleDataFromSessionState, cs.accept0RTT, cs.rejected0RTT, enable0RTT)
cs.tlsConf = qtlsConf cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan return cs, cs.clientHelloWrittenChan
} }
@ -456,14 +458,16 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) {
} }
// must be called after receiving the transport parameters // must be called after receiving the transport parameters
func (h *cryptoSetup) marshalPeerParamsForSessionState() []byte { func (h *cryptoSetup) marshalDataForSessionState() []byte {
b := &bytes.Buffer{} buf := &bytes.Buffer{}
h.peerParams.MarshalForSessionTicket(b) utils.WriteVarInt(buf, clientSessionStateRevision)
return b.Bytes() utils.WriteVarInt(buf, uint64(h.rttStats.SmoothedRTT().Microseconds()))
h.peerParams.MarshalForSessionTicket(buf)
return buf.Bytes()
} }
func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) { func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
tp, err := h.handlePeerParamsFromSessionStateImpl(data) tp, err := h.handleDataFromSessionStateImpl(data)
if err != nil { if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return return
@ -471,9 +475,22 @@ func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) {
h.zeroRTTParameters = tp h.zeroRTTParameters = tp
} }
func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rtt, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var tp wire.TransportParameters var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(data); err != nil { if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return nil, err return nil, err
} }
return &tp, nil return &tp, nil

View file

@ -287,6 +287,13 @@ var _ = Describe("Crypto Setup TLS", func() {
} }
} }
newRTTStatsWithRTT := func(rtt time.Duration) *congestion.RTTStats {
rttStats := &congestion.RTTStats{}
rttStats.UpdateRTT(rtt, 0, time.Now())
ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
return rttStats
}
handshake := func(client CryptoSetup, cChunkChan <-chan chunk, handshake := func(client CryptoSetup, cChunkChan <-chan chunk,
server CryptoSetup, sChunkChan <-chan chunk) { server CryptoSetup, sChunkChan <-chan chunk) {
done := make(chan struct{}) done := make(chan struct{})
@ -319,7 +326,12 @@ var _ = Describe("Crypto Setup TLS", func() {
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
} }
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { handshakeWithTLSConf := func(
clientConf, serverConf *tls.Config,
clientRTTStats, serverRTTStats *congestion.RTTStats,
clientTransportParameters, serverTransportParameters *wire.TransportParameters,
enable0RTT bool,
) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1) cErrChan := make(chan error, 1)
@ -327,17 +339,18 @@ var _ = Describe("Crypto Setup TLS", func() {
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1)
cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
client, _ := NewCryptoSetupClient( cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
client, clientHelloWrittenChan := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
nil, nil,
&wire.TransportParameters{}, clientTransportParameters,
cRunner, cRunner,
clientConf, clientConf,
enable0RTT, enable0RTT,
&congestion.RTTStats{}, clientRTTStats,
nil, nil,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -349,18 +362,21 @@ var _ = Describe("Crypto Setup TLS", func() {
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1)
sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
if serverTransportParameters.StatelessResetToken == nil {
var token [16]byte var token [16]byte
serverTransportParameters.StatelessResetToken = &token
}
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
nil, nil,
&wire.TransportParameters{StatelessResetToken: &token}, serverTransportParameters,
sRunner, sRunner,
serverConf, serverConf,
enable0RTT, enable0RTT,
&congestion.RTTStats{}, serverRTTStats,
nil, nil,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -377,18 +393,28 @@ var _ = Describe("Crypto Setup TLS", func() {
default: default:
Expect(cHandshakeComplete).To(BeTrue()) Expect(cHandshakeComplete).To(BeTrue())
} }
return client, cErr, server, sErr return clientHelloWrittenChan, client, cErr, server, sErr
} }
It("handshakes", func() { It("handshakes", func() {
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) _, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
}) })
It("performs a HelloRetryRequst", func() { It("performs a HelloRetryRequst", func() {
serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) _, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
}) })
@ -396,7 +422,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("handshakes with client auth", func() { It("handshakes with client auth", func() {
clientConf.Certificates = []tls.Certificate{generateCert()} clientConf.Certificates = []tls.Certificate{generateCert()}
serverConf.ClientAuth = qtls.RequireAnyClientCert serverConf.ClientAuth = qtls.RequireAnyClientCert
_, clientErr, _, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) _, _, clientErr, _, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
}) })
@ -626,21 +657,37 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket) close(receivedSessionTicket)
}) })
clientConf.ClientSessionCache = csc clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed()) Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse()) Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false) clientRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed()) Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
}) })
It("doesn't use session resumption if the server disabled it", func() { It("doesn't use session resumption if the server disabled it", func() {
@ -653,7 +700,12 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket) close(receivedSessionTicket)
}) })
clientConf.ClientSessionCache = csc clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, false) _, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed()) Eventually(receivedSessionTicket).Should(BeClosed())
@ -662,7 +714,12 @@ var _ = Describe("Crypto Setup TLS", func() {
serverConf.SessionTicketsDisabled = true serverConf.SessionTicketsDisabled = true
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, false) _, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
&congestion.RTTStats{}, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{},
false,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed()) Eventually(receivedSessionTicket).Should(BeClosed())
@ -680,68 +737,100 @@ var _ = Describe("Crypto Setup TLS", func() {
close(receivedSessionTicket) close(receivedSessionTicket)
}) })
clientConf.ClientSessionCache = csc clientConf.ClientSessionCache = csc
client, clientErr, server, serverErr := handshakeWithTLSConf(clientConf, serverConf, true) const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, serverOrigRTTStats,
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred()) Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred()) Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed()) Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse()) Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
cChunkChan, cInitialStream, cHandshakeStream := initStreams() clientRTTStats := &congestion.RTTStats{}
cRunner := NewMockHandshakeRunner(mockCtrl) serverRTTStats := &congestion.RTTStats{}
cRunner.EXPECT().OnReceivedParams(gomock.Any()) clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
cRunner.EXPECT().OnHandshakeComplete() clientConf, serverConf,
client, clientHelloChan := NewCryptoSetupClient( clientRTTStats, serverRTTStats,
cInitialStream, &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
cHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{},
cRunner,
clientConf,
true, true,
&congestion.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
) )
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
sChunkChan, sInitialStream, sHandshakeStream := initStreams() var tp *wire.TransportParameters
sRunner := NewMockHandshakeRunner(mockCtrl) Expect(clientHelloWrittenChan).To(Receive(&tp))
sRunner.EXPECT().OnReceivedParams(gomock.Any()) Expect(tp.InitialMaxData).To(Equal(initialMaxData))
sRunner.EXPECT().OnHandshakeComplete()
var token [16]byte
server = NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
nil,
nil,
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
true,
&congestion.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(clientHelloChan).To(Receive(Not(BeNil())))
Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeTrue())
Expect(client.ConnectionState().Used0RTT).To(BeTrue())
})
It("rejects 0-RTT, whent the transport parameters changed", func() {
csc := NewMockClientSessionCache(mockCtrl)
var state *tls.ClientSessionState
receivedSessionTicket := make(chan struct{})
csc.EXPECT().Get(gomock.Any())
csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
close(receivedSessionTicket)
})
clientConf.ClientSessionCache = csc
const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
const initialMaxData protocol.ByteCount = 1337
clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
clientConf, serverConf,
clientOrigRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Eventually(receivedSessionTicket).Should(BeClosed())
Expect(server.ConnectionState().DidResume).To(BeFalse())
Expect(client.ConnectionState().DidResume).To(BeFalse())
Expect(clientHelloWrittenChan).To(Receive(BeNil()))
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
clientRTTStats := &congestion.RTTStats{}
clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
clientConf, serverConf,
clientRTTStats, &congestion.RTTStats{},
&wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1},
true,
)
Expect(clientErr).ToNot(HaveOccurred())
Expect(serverErr).ToNot(HaveOccurred())
Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
var tp *wire.TransportParameters
Expect(clientHelloWrittenChan).To(Receive(&tp))
Expect(tp.InitialMaxData).To(Equal(initialMaxData))
Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue())
Expect(server.ConnectionState().Used0RTT).To(BeFalse())
Expect(client.ConnectionState().Used0RTT).To(BeFalse())
}) })
}) })
}) })

View file

@ -96,7 +96,7 @@ func tlsConfigToQtlsConfig(
} }
var csc qtls.ClientSessionCache var csc qtls.ClientSessionCache
if c.ClientSessionCache != nil { if c.ClientSessionCache != nil {
csc = newClientSessionCache(c.ClientSessionCache, rttStats, getDataForSessionState, setDataFromSessionState) csc = &clientSessionCache{c.ClientSessionCache}
} }
conf := &qtls.Config{ conf := &qtls.Config{
Rand: c.Rand, Rand: c.Rand,
@ -132,6 +132,8 @@ func tlsConfigToQtlsConfig(
ReceivedExtensions: extHandler.ReceivedExtensions, ReceivedExtensions: extHandler.ReceivedExtensions,
Accept0RTT: accept0RTT, Accept0RTT: accept0RTT,
Rejected0RTT: rejected0RTT, Rejected0RTT: rejected0RTT,
GetAppDataForSessionState: getDataForSessionState,
SetAppDataFromSessionState: setDataFromSessionState,
} }
if enable0RTT { if enable0RTT {
conf.Enable0RTT = true conf.Enable0RTT = true
@ -140,6 +142,36 @@ func tlsConfigToQtlsConfig(
return conf return conf
} }
type clientSessionCache struct {
tls.ClientSessionCache
}
var _ qtls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
return (*qtls.ClientSessionState)(unsafe.Pointer(sess)), ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
if cs == nil {
c.ClientSessionCache.Put(sessionKey, nil)
return
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
c.ClientSessionCache.Put(sessionKey, (*tls.ClientSessionState)(unsafe.Pointer(cs)))
}
type clientHelloInfo struct { type clientHelloInfo struct {
CipherSuites []uint16 CipherSuites []uint16
ServerName string ServerName string

View file

@ -6,7 +6,6 @@ import (
"net" "net"
"unsafe" "unsafe"
gomock "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -181,36 +180,6 @@ var _ = Describe("qtls.Config", func() {
Expect(qtlsConf.ClientSessionCache).To(BeNil()) Expect(qtlsConf.ClientSessionCache).To(BeNil())
}) })
It("sets it, and puts and gets session states", func() {
csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc}
var appData []byte
qtlsConf := tlsConfigToQtlsConfig(
tlsConf,
nil,
&mockExtensionHandler{},
congestion.NewRTTStats(),
func() []byte { return []byte("foobar") },
func(p []byte) { appData = p },
nil,
nil,
false,
)
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
var state *tls.ClientSessionState
// put something
csc.EXPECT().Put("localhost", gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
state = css
})
qtlsConf.ClientSessionCache.Put("localhost", &qtls.ClientSessionState{})
// get something
csc.EXPECT().Get("localhost").Return(state, true)
_, ok := qtlsConf.ClientSessionCache.Get("localhost")
Expect(ok).To(BeTrue())
Expect(appData).To(Equal([]byte("foobar")))
})
It("puts a nil session state", func() { It("puts a nil session state", func() {
csc := NewMockClientSessionCache(mockCtrl) csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc} tlsConf := &tls.Config{ClientSessionCache: csc}

View file

@ -39,7 +39,7 @@ func (t *sessionTicket) Unmarshal(b []byte) error {
return errors.New("failed to read RTT") return errors.New("failed to read RTT")
} }
var tp wire.TransportParameters var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(b[len(b)-r.Len():]); err != nil { if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
} }
t.Parameters = &tp t.Parameters = &tp

View file

@ -25,9 +25,6 @@ func init() {
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) { if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState") panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
} }
if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) {
panic("clientSessionState not compatible with tls.ClientSessionState")
}
} }
func structsEqual(a, b interface{}) bool { func structsEqual(a, b interface{}) bool {

View file

@ -422,7 +422,7 @@ var _ = Describe("Transport Parameters", func() {
b := &bytes.Buffer{} b := &bytes.Buffer{}
params.MarshalForSessionTicket(b) params.MarshalForSessionTicket(b)
var tp TransportParameters var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(b.Bytes())).To(Succeed()) Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
@ -434,7 +434,7 @@ var _ = Describe("Transport Parameters", func() {
It("rejects the parameters if it can't parse them", func() { It("rejects the parameters if it can't parse them", func() {
var p TransportParameters var p TransportParameters
Expect(p.UnmarshalFromSessionTicket([]byte("foobar"))).ToNot(Succeed()) Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed())
}) })
It("rejects the parameters if the version changed", func() { It("rejects the parameters if the version changed", func() {
@ -445,7 +445,7 @@ var _ = Describe("Transport Parameters", func() {
b := &bytes.Buffer{} b := &bytes.Buffer{}
utils.WriteVarInt(b, transportParameterMarshalingVersion+1) utils.WriteVarInt(b, transportParameterMarshalingVersion+1)
b.Write(data[utils.VarIntLen(transportParameterMarshalingVersion):]) b.Write(data[utils.VarIntLen(transportParameterMarshalingVersion):])
Expect(p.UnmarshalFromSessionTicket(b.Bytes())).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
}) })
Context("rejects the parameters if they changed", func() { Context("rejects the parameters if they changed", func() {

View file

@ -85,13 +85,13 @@ type TransportParameters struct {
// Unmarshal the transport parameters // Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(data, sentBy, false); err != nil { if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
return qerr.NewError(qerr.TransportParameterError, err.Error()) return qerr.NewError(qerr.TransportParameterError, err.Error())
} }
return nil return nil
} }
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective, fromSessionTicket bool) error { func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once // needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID var parameterIDs []transportParameterID
@ -102,7 +102,6 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective
readInitialSourceConnectionID bool readInitialSourceConnectionID bool
) )
r := bytes.NewReader(data)
for r.Len() > 0 { for r.Len() > 0 {
paramIDInt, err := utils.ReadVarInt(r) paramIDInt, err := utils.ReadVarInt(r)
if err != nil { if err != nil {
@ -429,8 +428,7 @@ func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) {
} }
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. // UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(data []byte) error { func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
r := bytes.NewReader(data)
version, err := utils.ReadVarInt(r) version, err := utils.ReadVarInt(r)
if err != nil { if err != nil {
return err return err
@ -438,7 +436,7 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(data []byte) error {
if version != transportParameterMarshalingVersion { if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version) return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
} }
return p.unmarshal(data[len(data)-r.Len():], protocol.PerspectiveServer, true) return p.unmarshal(r, protocol.PerspectiveServer, true)
} }
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket. // ValidFor0RTT checks if the transport parameters match those saved in the session ticket.