mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
restore the server's transport parameters from the session ticket
This commit is contained in:
parent
1f8a47af02
commit
44aa12850e
5 changed files with 124 additions and 41 deletions
|
@ -67,6 +67,7 @@ type cryptoSetup struct {
|
|||
messageChan chan []byte
|
||||
|
||||
ourParams *TransportParameters
|
||||
peerParams *TransportParameters
|
||||
paramsChan <-chan []byte
|
||||
|
||||
runner handshakeRunner
|
||||
|
@ -77,8 +78,9 @@ type cryptoSetup struct {
|
|||
// is closed when Close() is called
|
||||
closeChan chan struct{}
|
||||
|
||||
zeroRTTParameters *TransportParameters
|
||||
clientHelloWritten bool
|
||||
clientHelloWrittenChan chan struct{}
|
||||
clientHelloWrittenChan chan *TransportParameters
|
||||
|
||||
receivedWriteKey chan struct{}
|
||||
receivedReadKey chan struct{}
|
||||
|
@ -131,7 +133,7 @@ func NewCryptoSetupClient(
|
|||
enable0RTT bool,
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */) {
|
||||
) (CryptoSetup, <-chan *TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
|
||||
cs, clientHelloWritten := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
|
@ -192,7 +194,7 @@ func newCryptoSetup(
|
|||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (*cryptoSetup, <-chan struct{} /* ClientHello written */) {
|
||||
) (*cryptoSetup, <-chan *TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
|
||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective)
|
||||
extHandler := newExtensionHandler(tp.Marshal(), perspective)
|
||||
cs := &cryptoSetup{
|
||||
|
@ -211,14 +213,14 @@ func newCryptoSetup(
|
|||
perspective: perspective,
|
||||
handshakeDone: make(chan struct{}),
|
||||
alertChan: make(chan uint8),
|
||||
clientHelloWrittenChan: make(chan struct{}),
|
||||
clientHelloWrittenChan: make(chan *TransportParameters, 1),
|
||||
messageChan: make(chan []byte, 100),
|
||||
receivedReadKey: make(chan struct{}),
|
||||
receivedWriteKey: make(chan struct{}),
|
||||
writeRecord: make(chan struct{}, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, cs.accept0RTT, enable0RTT)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler, cs.marshalPeerParamsForSessionState, cs.handlePeerParamsFromSessionState, cs.accept0RTT, enable0RTT)
|
||||
cs.tlsConf = qtlsConf
|
||||
return cs, cs.clientHelloWrittenChan
|
||||
}
|
||||
|
@ -436,7 +438,22 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) {
|
|||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
h.runner.OnError(qerr.Error(qerr.TransportParameterError, err.Error()))
|
||||
}
|
||||
h.runner.OnReceivedParams(&tp)
|
||||
h.peerParams = &tp
|
||||
h.runner.OnReceivedParams(h.peerParams)
|
||||
}
|
||||
|
||||
// must be called after receiving the transport parameters
|
||||
func (h *cryptoSetup) marshalPeerParamsForSessionState() []byte {
|
||||
return h.peerParams.MarshalForSessionTicket()
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handlePeerParamsFromSessionState(data []byte) {
|
||||
var tp TransportParameters
|
||||
if err := tp.Unmarshal(data, protocol.PerspectiveServer); err != nil {
|
||||
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
|
||||
return
|
||||
}
|
||||
h.zeroRTTParameters = &tp
|
||||
}
|
||||
|
||||
// only valid for the server
|
||||
|
@ -569,7 +586,13 @@ func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
|||
n, err := h.initialStream.Write(p)
|
||||
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
|
||||
h.clientHelloWritten = true
|
||||
close(h.clientHelloWrittenChan)
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.clientHelloWrittenChan <- h.zeroRTTParameters
|
||||
} else {
|
||||
h.logger.Debugf("Not doing 0-RTT. Has Sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
|
||||
h.clientHelloWrittenChan <- nil
|
||||
}
|
||||
} else {
|
||||
// We need additional signaling to properly detect HelloRetryRequests.
|
||||
// For servers: when the ServerHello is written.
|
||||
|
|
|
@ -417,7 +417,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}()
|
||||
var ch chunk
|
||||
Eventually(cChunkChan).Should(Receive(&ch))
|
||||
Eventually(chChan).Should(BeClosed())
|
||||
Eventually(chChan).Should(Receive(BeNil()))
|
||||
// make sure the whole ClientHello was written
|
||||
Expect(len(ch.data)).To(BeNumerically(">=", 4))
|
||||
Expect(messageType(ch.data[0])).To(Equal(typeClientHello))
|
||||
|
@ -671,10 +671,53 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
||||
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
client, clientErr, server, serverErr = handshakeWithTLSConf(clientConf, serverConf, true)
|
||||
Expect(clientErr).ToNot(HaveOccurred())
|
||||
Expect(serverErr).ToNot(HaveOccurred())
|
||||
Eventually(receivedSessionTicket).Should(BeClosed())
|
||||
|
||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, clientHelloChan := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
cRunner,
|
||||
clientConf,
|
||||
true,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
server = NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
sRunner,
|
||||
serverConf,
|
||||
true,
|
||||
&congestion.RTTStats{},
|
||||
utils.DefaultLogger.WithPrefix("server"),
|
||||
)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
handshake(client, cChunkChan, server, sChunkChan)
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
||||
Expect(clientHelloChan).To(Receive(Not(BeNil())))
|
||||
|
||||
Expect(server.ConnectionState().DidResume).To(BeTrue())
|
||||
Expect(client.ConnectionState().DidResume).To(BeTrue())
|
||||
opener, err := server.Get0RTTOpener()
|
||||
|
|
|
@ -31,6 +31,8 @@ func tlsConfigToQtlsConfig(
|
|||
c *tls.Config,
|
||||
recordLayer qtls.RecordLayer,
|
||||
extHandler tlsExtensionHandler,
|
||||
getDataForSessionState func() []byte,
|
||||
setDataFromSessionState func([]byte),
|
||||
accept0RTT func([]byte) bool,
|
||||
enable0RTT bool,
|
||||
) *qtls.Config {
|
||||
|
@ -59,16 +61,12 @@ func tlsConfigToQtlsConfig(
|
|||
if tlsConf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, accept0RTT, enable0RTT), nil
|
||||
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler, getDataForSessionState, setDataFromSessionState, accept0RTT, enable0RTT), nil
|
||||
}
|
||||
}
|
||||
var csc qtls.ClientSessionCache
|
||||
if c.ClientSessionCache != nil {
|
||||
csc = newClientSessionCache(
|
||||
c.ClientSessionCache,
|
||||
func() []byte { return nil },
|
||||
func([]byte) {},
|
||||
)
|
||||
csc = newClientSessionCache(c.ClientSessionCache, getDataForSessionState, setDataFromSessionState)
|
||||
}
|
||||
conf := &qtls.Config{
|
||||
Rand: c.Rand,
|
||||
|
|
|
@ -28,19 +28,19 @@ func (*mockExtensionHandler) TransportParameters() <-chan []byte { panic("not im
|
|||
var _ = Describe("qtls.Config generation", func() {
|
||||
It("sets MinVersion and MaxVersion", func() {
|
||||
tlsConf := &tls.Config{MinVersion: tls.VersionTLS11, MaxVersion: tls.VersionTLS12}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
Expect(qtlsConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
})
|
||||
|
||||
It("works when called with a nil config", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("sets the setter and getter function for TLS extensions", func() {
|
||||
extHandler := &mockExtensionHandler{}
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, extHandler, nil, nil, nil, false)
|
||||
Expect(extHandler.get).To(BeFalse())
|
||||
qtlsConf.GetExtensions(10)
|
||||
Expect(extHandler.get).To(BeTrue())
|
||||
|
@ -51,31 +51,31 @@ var _ = Describe("qtls.Config generation", func() {
|
|||
|
||||
It("sets the Accept0RTT callback", func() {
|
||||
accept0RTT := func([]byte) bool { return true }
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, accept0RTT, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, accept0RTT, false)
|
||||
Expect(qtlsConf.Accept0RTT).ToNot(BeNil())
|
||||
Expect(qtlsConf.Accept0RTT(nil)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("enables 0-RTT", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf.Enable0RTT).To(BeFalse())
|
||||
Expect(qtlsConf.MaxEarlyData).To(BeZero())
|
||||
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, true)
|
||||
qtlsConf = tlsConfigToQtlsConfig(nil, nil, &mockExtensionHandler{}, nil, nil, nil, true)
|
||||
Expect(qtlsConf.Enable0RTT).To(BeTrue())
|
||||
Expect(qtlsConf.MaxEarlyData).To(Equal(uint32(0xffffffff)))
|
||||
})
|
||||
|
||||
It("initializes such that the session ticket key remains constant", func() {
|
||||
tlsConf := &tls.Config{}
|
||||
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf1 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
qtlsConf2 := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf1.SessionTicketKey).ToNot(BeZero()) // should now contain a random value
|
||||
Expect(qtlsConf1.SessionTicketKey).To(Equal(qtlsConf2.SessionTicketKey))
|
||||
})
|
||||
|
||||
Context("GetConfigForClient callback", func() {
|
||||
It("doesn't set it if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetConfigForClient).To(BeNil())
|
||||
})
|
||||
|
||||
|
@ -86,7 +86,7 @@ var _ = Describe("qtls.Config generation", func() {
|
|||
},
|
||||
}
|
||||
extHandler := &mockExtensionHandler{}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, extHandler, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetConfigForClient).ToNot(BeNil())
|
||||
confForClient, err := qtlsConf.GetConfigForClient(nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -106,7 +106,7 @@ var _ = Describe("qtls.Config generation", func() {
|
|||
return nil, testErr
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
_, err := qtlsConf.GetConfigForClient(nil)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
@ -117,35 +117,49 @@ var _ = Describe("qtls.Config generation", func() {
|
|||
return nil, nil
|
||||
},
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf.GetConfigForClient(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ClientSessionCache", func() {
|
||||
It("doesn't set if absent", func() {
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(&tls.Config{}, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
Expect(qtlsConf.ClientSessionCache).To(BeNil())
|
||||
})
|
||||
|
||||
It("sets it, and puts and gets session states", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
tlsConf := &tls.Config{ClientSessionCache: csc}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
var appData []byte
|
||||
qtlsConf := tlsConfigToQtlsConfig(
|
||||
tlsConf,
|
||||
nil,
|
||||
&mockExtensionHandler{},
|
||||
func() []byte { return []byte("foobar") },
|
||||
func(p []byte) { appData = p },
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
Expect(qtlsConf.ClientSessionCache).ToNot(BeNil())
|
||||
|
||||
var state *tls.ClientSessionState
|
||||
// put something
|
||||
csc.EXPECT().Put("foobar", gomock.Any())
|
||||
qtlsConf.ClientSessionCache.Put("foobar", &qtls.ClientSessionState{})
|
||||
csc.EXPECT().Put("localhost", gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
|
||||
state = css
|
||||
})
|
||||
qtlsConf.ClientSessionCache.Put("localhost", &qtls.ClientSessionState{})
|
||||
// get something
|
||||
csc.EXPECT().Get("raboof").Return(nil, true)
|
||||
_, ok := qtlsConf.ClientSessionCache.Get("raboof")
|
||||
csc.EXPECT().Get("localhost").Return(state, true)
|
||||
_, ok := qtlsConf.ClientSessionCache.Get("localhost")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(appData).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("puts a nil session state", func() {
|
||||
csc := NewMockClientSessionCache(mockCtrl)
|
||||
tlsConf := &tls.Config{ClientSessionCache: csc}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, false)
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf, nil, &mockExtensionHandler{}, nil, nil, nil, false)
|
||||
// put something
|
||||
csc.EXPECT().Put("foobar", nil)
|
||||
qtlsConf.ClientSessionCache.Put("foobar", nil)
|
||||
|
|
13
session.go
13
session.go
|
@ -156,7 +156,7 @@ type session struct {
|
|||
|
||||
undecryptablePackets []*receivedPacket
|
||||
|
||||
clientHelloWritten <-chan struct{}
|
||||
clientHelloWritten <-chan *handshake.TransportParameters
|
||||
earlySessionReadyChan chan struct{}
|
||||
handshakeCompleteChan chan struct{} // is closed when the handshake completes
|
||||
handshakeComplete bool
|
||||
|
@ -463,8 +463,11 @@ func (s *session) run() error {
|
|||
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
select {
|
||||
case <-s.clientHelloWritten:
|
||||
case zeroRTTParams := <-s.clientHelloWritten:
|
||||
s.scheduleSending()
|
||||
if zeroRTTParams != nil {
|
||||
s.processTransportParameters(zeroRTTParams)
|
||||
}
|
||||
case closeErr := <-s.closeChan:
|
||||
// put the close error back into the channel, so that the run loop can receive it
|
||||
s.closeChan <- closeErr
|
||||
|
@ -1099,7 +1102,7 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete
|
|||
return
|
||||
}
|
||||
|
||||
s.logger.Debugf("Received Transport Parameters: %s", params)
|
||||
s.logger.Debugf("Processed Transport Parameters: %s", params)
|
||||
s.peerParams = params
|
||||
// Our local idle timeout will always be > 0.
|
||||
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
|
||||
|
@ -1124,7 +1127,9 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete
|
|||
}
|
||||
// On the server side, the early session is ready as soon as we processed
|
||||
// the client's transport parameters.
|
||||
close(s.earlySessionReadyChan)
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
close(s.earlySessionReadyChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) sendPackets() error {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue