refactor how session tickets are sent

Previously, RunHandshake() would send the session tickets. Now, the
session initiates the sending of the session ticket. This simplifies the
setup a bit, and it will make it possible to include the RTT estimate in
the session ticket without accessing the RTTStats concurrently.
This commit is contained in:
Marten Seemann 2020-02-02 13:45:52 +07:00
parent 3e32a693ad
commit 8cde4ab638
9 changed files with 103 additions and 140 deletions

View file

@ -22,35 +22,6 @@ type cryptoStream interface {
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
}
type postHandshakeCryptoStream struct {
cryptoStream
framer framer
}
func newPostHandshakeCryptoStream(framer framer) cryptoStream {
return &postHandshakeCryptoStream{
cryptoStream: newCryptoStream(),
framer: framer,
}
}
// Write writes post-handshake messages.
// For simplicity, post-handshake crypto messages are treated as control frames.
// The framer functions as a stack (LIFO), so if there are multiple writes,
// they will be returned in the opposite order.
// This is acceptable, since post-handshake crypto messages are very rare.
func (s *postHandshakeCryptoStream) Write(p []byte) (int, error) {
n, err := s.cryptoStream.Write(p)
if err != nil {
return n, err
}
for s.cryptoStream.HasData() {
s.framer.QueueControlFrame(s.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
}
return n, nil
}
type cryptoStreamImpl struct {
queue *frameSorter
msgBuf []byte

View file

@ -180,46 +180,3 @@ var _ = Describe("Crypto Stream", func() {
})
})
})
var _ = Describe("Post Handshake Crypto Stream", func() {
var (
cs cryptoStream
framer framer
)
BeforeEach(func() {
framer = newFramer(NewMockStreamGetter(mockCtrl), protocol.VersionTLS)
cs = newPostHandshakeCryptoStream(framer)
})
It("queues CRYPTO frames when writing data", func() {
n, err := cs.Write([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
n, err = cs.Write([]byte("bar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
frames, _ := framer.AppendControlFrames(nil, 1000)
Expect(frames).To(HaveLen(2))
fs := []wire.Frame{frames[0].Frame, frames[1].Frame}
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")}))
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3}))
})
It("splits large writes into multiple frames", func() {
size := 10 * protocol.MaxPostHandshakeCryptoFrameSize
n, err := cs.Write(make([]byte, size))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(BeEquivalentTo(size))
frames, _ := framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).To(HaveLen(11)) // one more for framing overhead
var dataLen int
for _, f := range frames {
Expect(f.Frame.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
Expect(f.Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
dataLen += len(f.Frame.(*wire.CryptoFrame).Data)
}
Expect(dataLen).To(BeEquivalentTo(size))
})
})

View file

@ -16,12 +16,8 @@ import (
"github.com/marten-seemann/qtls"
)
const (
// TLS unexpected_message alert
alertUnexpectedMessage uint8 = 10
// TLS internal error
alertInternalError uint8 = 80
)
// TLS unexpected_message alert
const alertUnexpectedMessage uint8 = 10
type messageType uint8
@ -111,7 +107,6 @@ type cryptoSetup struct {
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
oneRTTStream io.Writer
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
@ -124,7 +119,6 @@ var _ CryptoSetup = &cryptoSetup{}
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters,
@ -137,7 +131,6 @@ func NewCryptoSetupClient(
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
connID,
tp,
runner,
@ -155,7 +148,6 @@ func NewCryptoSetupClient(
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters,
@ -168,7 +160,6 @@ func NewCryptoSetupServer(
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
connID,
tp,
runner,
@ -185,7 +176,6 @@ func NewCryptoSetupServer(
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *TransportParameters,
runner handshakeRunner,
@ -202,7 +192,6 @@ func newCryptoSetup(
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
aead: newUpdatableAEAD(rttStats, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
@ -251,10 +240,6 @@ func (h *cryptoSetup) RunHandshake() {
select {
case <-handshakeComplete: // return when the handshake is done
h.runner.OnHandshakeComplete()
// send a session ticket
if h.perspective == protocol.PerspectiveServer {
h.maybeSendSessionTicket()
}
case <-h.closeChan:
close(h.messageChan)
// wait until the Handshake() go routine has returned
@ -475,20 +460,13 @@ func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*Transp
}
// only valid for the server
func (h *cryptoSetup) maybeSendSessionTicket() {
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.tlsConf.MaxEarlyData > 0 {
appData = (&sessionTicket{Parameters: h.ourParams}).Marshal()
}
ticket, err := h.conn.GetSessionTicket(appData)
if err != nil {
h.onError(alertInternalError, err.Error())
return
}
if ticket != nil {
h.oneRTTStream.Write(ticket)
}
return h.conn.GetSessionTicket(appData)
}
// accept0RTT is called for the server when receiving the client's session ticket.

View file

@ -8,7 +8,6 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"errors"
"io/ioutil"
"math/big"
"time"
@ -55,12 +54,11 @@ func (s *stream) Write(b []byte) (int, error) {
var _ = Describe("Crypto Setup TLS", func() {
var clientConf, serverConf *tls.Config
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) {
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT)
return chunkChan, initialStream, handshakeStream, oneRTTStream
return chunkChan, initialStream, handshakeStream
}
BeforeEach(func() {
@ -89,7 +87,6 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
&bytes.Buffer{},
&bytes.Buffer{},
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -117,11 +114,10 @@ var _ = Describe("Crypto Setup TLS", func() {
sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream, _ := initStreams()
_, sInitialStream, sHandshakeStream := initStreams()
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -153,13 +149,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("errors when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream, _ := initStreams()
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -194,13 +189,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream, _ := initStreams()
_, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -230,11 +224,10 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams()
_, sInitialStream, sHandshakeStream := initStreams()
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -305,6 +298,11 @@ var _ = Describe("Crypto Setup TLS", func() {
go func() {
defer GinkgoRecover()
server.RunHandshake()
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
client.HandleMessage(ticket, protocol.Encryption1RTT)
}
close(done)
}()
@ -314,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() {
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams()
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
@ -323,7 +321,6 @@ var _ = Describe("Crypto Setup TLS", func() {
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
cOneRTTStream,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -335,7 +332,7 @@ var _ = Describe("Crypto Setup TLS", func() {
)
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams()
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
@ -345,7 +342,6 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
sOneRTTStream,
protocol.ConnectionID{},
nil,
&TransportParameters{StatelessResetToken: &token},
@ -394,11 +390,10 @@ var _ = Describe("Crypto Setup TLS", func() {
It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -431,7 +426,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
@ -439,7 +434,6 @@ var _ = Describe("Crypto Setup TLS", func() {
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
cTransportParameters,
@ -450,7 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"),
)
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
var token [16]byte
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
@ -462,7 +456,6 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
sTransportParameters,
@ -487,14 +480,13 @@ var _ = Describe("Crypto Setup TLS", func() {
Context("with session tickets", func() {
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -505,14 +497,13 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"),
)
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -544,14 +535,13 @@ var _ = Describe("Crypto Setup TLS", func() {
})
It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -562,14 +552,13 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"),
)
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -670,16 +659,16 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(client.ConnectionState().DidResume).To(BeFalse())
csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete()
client, clientHelloChan := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},
@ -690,14 +679,13 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"),
)
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete()
server = NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{},
nil,
&TransportParameters{},

View file

@ -72,6 +72,7 @@ type CryptoSetup interface {
RunHandshake()
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber)

View file

@ -208,6 +208,21 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer))
}
// GetSessionTicket mocks base method
func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSessionTicket")
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSessionTicket indicates an expected call of GetSessionTicket
func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket))
}
// HandleMessage mocks base method
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
m.ctrl.T.Helper()

View file

@ -110,7 +110,7 @@ const MinStreamFrameSize ByteCount = 128
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
// we send after the handshake completes.
const MaxPostHandshakeCryptoFrameSize ByteCount = 1000
const MaxPostHandshakeCryptoFrameSize = 1000
// MaxAckFrameSize is the maximum size for an ACK frame that we write
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.

View file

@ -52,6 +52,7 @@ type cryptoStreamHandler interface {
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber)
DropHandshakeKeys()
GetSessionTicket() ([]byte, error)
io.Closer
ConnectionState() handshake.ConnectionState
}
@ -141,6 +142,7 @@ type session struct {
frameParser wire.FrameParser
packer packer
oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler
receivedPackets chan *receivedPacket
@ -213,6 +215,7 @@ var newSession = func(
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer,
handshakeCompleteChan: make(chan struct{}),
logger: logger,
@ -244,7 +247,6 @@ var newSession = func(
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
@ -263,7 +265,6 @@ var newSession = func(
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
oneRTTStream,
clientDestConnID,
conn.RemoteAddr(),
params,
@ -297,7 +298,7 @@ var newSession = func(
s.version,
)
s.unpacker = newPacketUnpacker(cs, s.version)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s
}
@ -348,7 +349,6 @@ var newClientSession = func(
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream()
handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -365,7 +365,6 @@ var newClientSession = func(
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream,
handshakeStream,
oneRTTStream,
destConnID,
conn.RemoteAddr(),
params,
@ -382,7 +381,7 @@ var newClientSession = func(
)
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.version)
s.packer = newPacketPacker(
srcConnID,
@ -633,6 +632,16 @@ func (s *session) handleHandshakeComplete() {
s.sentPacketHandler.SetHandshakeComplete()
if s.perspective == protocol.PerspectiveServer {
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
if err != nil {
s.closeLocal(err)
}
if ticket != nil {
s.oneRTTStream.Write(ticket)
for s.oneRTTStream.HasData() {
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
}
}
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
if err != nil {
s.closeLocal(err)

View file

@ -1191,6 +1191,7 @@ var _ = Describe("Session", func() {
<-finishHandshake
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket()
close(sess.handshakeCompleteChan)
sess.run()
}()
@ -1210,6 +1211,47 @@ var _ = Describe("Session", func() {
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("sends a session ticket when the handshake completes", func() {
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
packer.EXPECT().PackPacket().AnyTimes()
finishHandshake := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID)
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
close(sess.handshakeCompleteChan)
sess.run()
}()
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed())
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
var count int
var s int
for _, f := range frames {
if cf, ok := f.Frame.(*wire.CryptoFrame); ok {
count++
s += len(cf.Data)
Expect(f.Length(sess.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
}
}
Expect(size).To(BeEquivalentTo(s))
// make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any())
sess.shutdown()
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
packer.EXPECT().PackPacket().AnyTimes()
streamManager.EXPECT().CloseWithError(gomock.Any())
@ -1247,6 +1289,7 @@ var _ = Describe("Session", func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
mconn.EXPECT().Write(gomock.Any())
close(sess.handshakeCompleteChan)
@ -1492,6 +1535,7 @@ var _ = Describe("Session", func() {
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
close(sess.handshakeCompleteChan)