restore the server's transport parameters from the session ticket

This commit is contained in:
Marten Seemann 2019-08-09 15:32:57 +07:00
parent 1f8a47af02
commit 44aa12850e
5 changed files with 124 additions and 41 deletions

View file

@ -67,6 +67,7 @@ type cryptoSetup struct {
messageChan chan []byte messageChan chan []byte
ourParams *TransportParameters ourParams *TransportParameters
peerParams *TransportParameters
paramsChan <-chan []byte paramsChan <-chan []byte
runner handshakeRunner runner handshakeRunner
@ -77,8 +78,9 @@ type cryptoSetup struct {
// is closed when Close() is called // is closed when Close() is called
closeChan chan struct{} closeChan chan struct{}
zeroRTTParameters *TransportParameters
clientHelloWritten bool clientHelloWritten bool
clientHelloWrittenChan chan struct{} clientHelloWrittenChan chan *TransportParameters
receivedWriteKey chan struct{} receivedWriteKey chan struct{}
receivedReadKey chan struct{} receivedReadKey chan struct{}
@ -131,7 +133,7 @@ func NewCryptoSetupClient(
enable0RTT bool, enable0RTT bool,
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */) { ) (CryptoSetup, <-chan *TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
cs, clientHelloWritten := newCryptoSetup( cs, clientHelloWritten := newCryptoSetup(
initialStream, initialStream,
handshakeStream, handshakeStream,
@ -192,7 +194,7 @@ func newCryptoSetup(
rttStats *congestion.RTTStats, rttStats *congestion.RTTStats,
logger utils.Logger, logger utils.Logger,
perspective protocol.Perspective, perspective protocol.Perspective,
) (*cryptoSetup, <-chan struct{} /* ClientHello written */) { ) (*cryptoSetup, <-chan *TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective) initialSealer, initialOpener := NewInitialAEAD(connID, perspective)
extHandler := newExtensionHandler(tp.Marshal(), perspective) extHandler := newExtensionHandler(tp.Marshal(), perspective)
cs := &cryptoSetup{ cs := &cryptoSetup{
@ -211,14 +213,14 @@ func newCryptoSetup(
perspective: perspective, perspective: perspective,
handshakeDone: make(chan struct{}), handshakeDone: make(chan struct{}),
alertChan: make(chan uint8), alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan struct{}), clientHelloWrittenChan: make(chan *TransportParameters, 1),
messageChan: make(chan []byte, 100), messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}), receivedReadKey: make(chan struct{}),
receivedWriteKey: make(chan struct{}), receivedWriteKey: make(chan struct{}),
writeRecord: make(chan struct{}, 1), writeRecord: make(chan struct{}, 1),
closeChan: make(chan struct{}), closeChan: make(chan struct{}),
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, cs.accept0RTT, enable0RTT) qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, cs.marshalPeerParamsForSessionState, cs.handlePeerParamsFromSessionState, cs.accept0RTT, enable0RTT)
cs.tlsConf = qtlsConf cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan return cs, cs.clientHelloWrittenChan
} }
@ -436,7 +438,22 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) {
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
h.runner.OnError(qerr.Error(qerr.TransportParameterError, err.Error())) h.runner.OnError(qerr.Error(qerr.TransportParameterError, err.Error()))
} }
h.runner.OnReceivedParams(&tp) h.peerParams = &tp
h.runner.OnReceivedParams(h.peerParams)
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalPeerParamsForSessionState() []byte {
return h.peerParams.MarshalForSessionTicket()
}
func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) {
var tp TransportParameters
if err := tp.Unmarshal(data, protocol.PerspectiveServer); err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
}
h.zeroRTTParameters = &tp
} }
// only valid for the server // only valid for the server
@ -569,7 +586,13 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
n, err := h.initialStream.Write(p) n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true h.clientHelloWritten = true
close(h.clientHelloWrittenChan) if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.clientHelloWrittenChan <- h.zeroRTTParameters
} else {
h.logger.Debugf("Not doing 0-RTT. Has Sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
h.clientHelloWrittenChan <- nil
}
} else { } else {
// We need additional signaling to properly detect HelloRetryRequests. // We need additional signaling to properly detect HelloRetryRequests.
// For servers: when the ServerHello is written. // For servers: when the ServerHello is written.

View file

@ -417,7 +417,7 @@ var _ = Describe("Crypto Setup TLS", func() {
}() }()
var ch chunk var ch chunk
Eventually(cChunkChan).Should(Receive(&ch)) Eventually(cChunkChan).Should(Receive(&ch))
Eventually(chChan).Should(BeClosed()) Eventually(chChan).Should(Receive(BeNil()))
// make sure the whole ClientHello was written // make sure the whole ClientHello was written
Expect(len(ch.data)).To(BeNumerically(">=", 4)) Expect(len(ch.data)).To(BeNumerically(">=", 4))
Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
@ -671,10 +671,53 @@ var _ = Describe("Crypto Setup TLS", func() {
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, true)
Expect(clientErr).ToNot(HaveOccurred()) cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
Expect(serverErr).ToNot(HaveOccurred()) cRunner := NewMockHandshakeRunner(mockCtrl)
Eventually(receivedSessionTicket).Should(BeClosed()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, clientHelloChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
cRunner,
clientConf,
true,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("client"),
)
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server = NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
sRunner,
serverConf,
true,
&congestion.RTTStats{},
utils.DefaultLogger.WithPrefix("server"),
)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
handshake(client, cChunkChan, server, sChunkChan)
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(clientHelloChan).To(Receive(Not(BeNil())))
Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(server.ConnectionState().DidResume).To(BeTrue())
Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue())
opener, err := server.Get0RTTOpener() opener, err := server.Get0RTTOpener()

View file

@ -31,6 +31,8 @@ func tlsConfigToQtlsConfig(
c *tls.Config, c *tls.Config,
recordLayer qtls.RecordLayer, recordLayer qtls.RecordLayer,
extHandler tlsExtensionHandler, extHandler tlsExtensionHandler,
getDataForSessionState func() []byte,
setDataFromSessionState func([]byte),
accept0RTT func([]byte) bool, accept0RTT func([]byte) bool,
enable0RTT bool, enable0RTT bool,
) *qtls.Config { ) *qtls.Config {
@ -59,16 +61,12 @@ func tlsConfigToQtlsConfig(
if tlsConf == nil { if tlsConf == nil {
return nil, nil return nil, nil
} }
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, accept0RTT, enable0RTT), nil return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, getDataForSessionState, setDataFromSessionState, accept0RTT, enable0RTT), nil
} }
} }
var csc qtls.ClientSessionCache var csc qtls.ClientSessionCache
if c.ClientSessionCache != nil { if c.ClientSessionCache != nil {
csc = newClientSessionCache( csc = newClientSessionCache(c.ClientSessionCache, getDataForSessionState, setDataFromSessionState)
c.ClientSessionCache,
func() []byte { return nil },
func([]byte) {},
)
} }
conf := &qtls.Config{ conf := &qtls.Config{
Rand: c.Rand, Rand: c.Rand,

View file

@ -28,19 +28,19 @@ func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not im
var _ = Describe("qtls.Config generation", func() { var _ = Describe("qtls.Config generation", func() {
It("sets MinVersion and MaxVersion", func() { It("sets MinVersion and MaxVersion", func() {
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12} tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13)) Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
}) })
It("works when called with a nil config", func() { It("works when called with a nil config", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf).ToNot(BeNil()) Expect(qtlsConf).ToNot(BeNil())
}) })
It("sets the setter and getter function for TLS extensions", func() { It("sets the setter and getter function for TLS extensions", func() {
extHandler := &mockExtensionHandler{} extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, nil, nil, nil, false)
Expect(extHandler.get).To(BeFalse()) Expect(extHandler.get).To(BeFalse())
qtlsConf.GetExtensions(10) qtlsConf.GetExtensions(10)
Expect(extHandler.get).To(BeTrue()) Expect(extHandler.get).To(BeTrue())
@ -51,31 +51,31 @@ var _ = Describe("qtls.Config generation", func() {
It("sets the Accept0RTT callback", func() { It("sets the Accept0RTT callback", func() {
accept0RTT := func([]byte) bool { return true } accept0RTT := func([]byte) bool { return true }
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, accept0RTT, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, accept0RTT, false)
Expect(qtlsConf.Accept0RTT).ToNot(BeNil()) Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue()) Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
}) })
It("enables 0-RTT", func() { It("enables 0-RTT", func() {
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf.Enable0RTT).To(BeFalse()) Expect(qtlsConf.Enable0RTT).To(BeFalse())
Expect(qtlsConf.MaxEarlyData).To(BeZero()) Expect(qtlsConf.MaxEarlyData).To(BeZero())
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, true) qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, true)
Expect(qtlsConf.Enable0RTT).To(BeTrue()) Expect(qtlsConf.Enable0RTT).To(BeTrue())
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff))) Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
}) })
It("initializes such that the session ticket key remains constant", func() { It("initializes such that the session ticket key remains constant", func() {
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey)) Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
}) })
Context("GetConfigForClient callback", func() { Context("GetConfigForClient callback", func() {
It("doesn't set it if absent", func() { It("doesn't set it if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf.GetConfigForClient).To(BeNil()) Expect(qtlsConf.GetConfigForClient).To(BeNil())
}) })
@ -86,7 +86,7 @@ var _ = Describe("qtls.Config generation", func() {
}, },
} }
extHandler := &mockExtensionHandler{} extHandler := &mockExtensionHandler{}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, nil, nil, nil, false)
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil()) Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
confForClient, err := qtlsConf.GetConfigForClient(nil) confForClient, err := qtlsConf.GetConfigForClient(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -106,7 +106,7 @@ var _ = Describe("qtls.Config generation", func() {
return nil, testErr return nil, testErr
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
_, err := qtlsConf.GetConfigForClient(nil) _, err := qtlsConf.GetConfigForClient(nil)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -117,35 +117,49 @@ var _ = Describe("qtls.Config generation", func() {
return nil, nil return nil, nil
}, },
} }
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil()) Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
}) })
}) })
Context("ClientSessionCache", func() { Context("ClientSessionCache", func() {
It("doesn't set if absent", func() { It("doesn't set if absent", func() {
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, nil, nil, false)
Expect(qtlsConf.ClientSessionCache).To(BeNil()) Expect(qtlsConf.ClientSessionCache).To(BeNil())
}) })
It("sets it, and puts and gets session states", func() { It("sets it, and puts and gets session states", func() {
csc := NewMockClientSessionCache(mockCtrl) csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc} tlsConf := &tls.Config{ClientSessionCache: csc}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) var appData []byte
qtlsConf := tlsConfigToQtlsConfig(
tlsConf,
nil,
&mockExtensionHandler{},
func() []byte { return []byte("foobar") },
func(p []byte) { appData = p },
nil,
false,
)
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil()) Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
var state *tls.ClientSessionState
// put something // put something
csc.EXPECT().Put("foobar", gomock.Any()) csc.EXPECT().Put("localhost", gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
qtlsConf.ClientSessionCache.Put("foobar", &qtls.ClientSessionState{}) state = css
})
qtlsConf.ClientSessionCache.Put("localhost", &qtls.ClientSessionState{})
// get something // get something
csc.EXPECT().Get("raboof").Return(nil, true) csc.EXPECT().Get("localhost").Return(state, true)
_, ok := qtlsConf.ClientSessionCache.Get("raboof") _, ok := qtlsConf.ClientSessionCache.Get("localhost")
Expect(ok).To(BeTrue()) Expect(ok).To(BeTrue())
Expect(appData).To(Equal([]byte("foobar")))
}) })
It("puts a nil session state", func() { It("puts a nil session state", func() {
csc := NewMockClientSessionCache(mockCtrl) csc := NewMockClientSessionCache(mockCtrl)
tlsConf := &tls.Config{ClientSessionCache: csc} tlsConf := &tls.Config{ClientSessionCache: csc}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false) qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
// put something // put something
csc.EXPECT().Put("foobar", nil) csc.EXPECT().Put("foobar", nil)
qtlsConf.ClientSessionCache.Put("foobar", nil) qtlsConf.ClientSessionCache.Put("foobar", nil)

View file

@ -156,7 +156,7 @@ type session struct {
undecryptablePackets []*receivedPacket undecryptablePackets []*receivedPacket
clientHelloWritten <-chan struct{} clientHelloWritten <-chan *handshake.TransportParameters
earlySessionReadyChan chan struct{} earlySessionReadyChan chan struct{}
handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeCompleteChan chan struct{} // is closed when the handshake completes
handshakeComplete bool handshakeComplete bool
@ -463,8 +463,11 @@ func (s *session) run() error {
if s.perspective == protocol.PerspectiveClient { if s.perspective == protocol.PerspectiveClient {
select { select {
case <-s.clientHelloWritten: case zeroRTTParams := <-s.clientHelloWritten:
s.scheduleSending() s.scheduleSending()
if zeroRTTParams != nil {
s.processTransportParameters(zeroRTTParams)
}
case closeErr := <-s.closeChan: case closeErr := <-s.closeChan:
// put the close error back into the channel, so that the run loop can receive it // put the close error back into the channel, so that the run loop can receive it
s.closeChan <- closeErr s.closeChan <- closeErr
@ -1099,7 +1102,7 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete
return return
} }
s.logger.Debugf("Received Transport Parameters: %s", params) s.logger.Debugf("Processed Transport Parameters: %s", params)
s.peerParams = params s.peerParams = params
// Our local idle timeout will always be > 0. // Our local idle timeout will always be > 0.
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
@ -1124,7 +1127,9 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete
} }
// On the server side, the early session is ready as soon as we processed // On the server side, the early session is ready as soon as we processed
// the client's transport parameters. // the client's transport parameters.
close(s.earlySessionReadyChan) if s.perspective == protocol.PerspectiveServer {
close(s.earlySessionReadyChan)
}
} }
func (s *session) sendPackets() error { func (s *session) sendPackets() error {