refactor how session tickets are sent

Previously, RunHandshake() would send the session tickets. Now, the
session initiates the sending of the session ticket. This simplifies the
setup a bit, and it will make it possible to include the RTT estimate in
the session ticket without accessing the RTTStats concurrently.
This commit is contained in:
Marten Seemann 2020-02-02 13:45:52 +07:00
parent 3e32a693ad
commit 8cde4ab638
9 changed files with 103 additions and 140 deletions

View file

@ -22,35 +22,6 @@ type cryptoStream interface {
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
} }
type postHandshakeCryptoStream struct {
cryptoStream
framer framer
}
func newPostHandshakeCryptoStream(framer framer) cryptoStream {
return &postHandshakeCryptoStream{
cryptoStream: newCryptoStream(),
framer: framer,
}
}
// Write writes post-handshake messages.
// For simplicity, post-handshake crypto messages are treated as control frames.
// The framer functions as a stack (LIFO), so if there are multiple writes,
// they will be returned in the opposite order.
// This is acceptable, since post-handshake crypto messages are very rare.
func (s *postHandshakeCryptoStream) Write(p []byte) (int, error) {
n, err := s.cryptoStream.Write(p)
if err != nil {
return n, err
}
for s.cryptoStream.HasData() {
s.framer.QueueControlFrame(s.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
}
return n, nil
}
type cryptoStreamImpl struct { type cryptoStreamImpl struct {
queue *frameSorter queue *frameSorter
msgBuf []byte msgBuf []byte

View file

@ -180,46 +180,3 @@ var _ = Describe("Crypto Stream", func() {
}) })
}) })
}) })
var _ = Describe("Post Handshake Crypto Stream", func() {
var (
cs cryptoStream
framer framer
)
BeforeEach(func() {
framer = newFramer(NewMockStreamGetter(mockCtrl), protocol.VersionTLS)
cs = newPostHandshakeCryptoStream(framer)
})
It("queues CRYPTO frames when writing data", func() {
n, err := cs.Write([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
n, err = cs.Write([]byte("bar"))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
frames, _ := framer.AppendControlFrames(nil, 1000)
Expect(frames).To(HaveLen(2))
fs := []wire.Frame{frames[0].Frame, frames[1].Frame}
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")}))
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3}))
})
It("splits large writes into multiple frames", func() {
size := 10 * protocol.MaxPostHandshakeCryptoFrameSize
n, err := cs.Write(make([]byte, size))
Expect(err).ToNot(HaveOccurred())
Expect(n).To(BeEquivalentTo(size))
frames, _ := framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).To(HaveLen(11)) // one more for framing overhead
var dataLen int
for _, f := range frames {
Expect(f.Frame.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
Expect(f.Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
dataLen += len(f.Frame.(*wire.CryptoFrame).Data)
}
Expect(dataLen).To(BeEquivalentTo(size))
})
})

View file

@ -16,12 +16,8 @@ import (
"github.com/marten-seemann/qtls" "github.com/marten-seemann/qtls"
) )
const ( // TLS unexpected_message alert
// TLS unexpected_message alert const alertUnexpectedMessage uint8 = 10
alertUnexpectedMessage uint8 = 10
// TLS internal error
alertInternalError uint8 = 80
)
type messageType uint8 type messageType uint8
@ -111,7 +107,6 @@ type cryptoSetup struct {
handshakeOpener LongHeaderOpener handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer handshakeSealer LongHeaderSealer
oneRTTStream io.Writer
aead *updatableAEAD aead *updatableAEAD
has1RTTSealer bool has1RTTSealer bool
has1RTTOpener bool has1RTTOpener bool
@ -124,7 +119,6 @@ var _ CryptoSetup = &cryptoSetup{}
func NewCryptoSetupClient( func NewCryptoSetupClient(
initialStream io.Writer, initialStream io.Writer,
handshakeStream io.Writer, handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr, remoteAddr net.Addr,
tp *TransportParameters, tp *TransportParameters,
@ -137,7 +131,6 @@ func NewCryptoSetupClient(
cs, clientHelloWritten := newCryptoSetup( cs, clientHelloWritten := newCryptoSetup(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream,
connID, connID,
tp, tp,
runner, runner,
@ -155,7 +148,6 @@ func NewCryptoSetupClient(
func NewCryptoSetupServer( func NewCryptoSetupServer(
initialStream io.Writer, initialStream io.Writer,
handshakeStream io.Writer, handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr, remoteAddr net.Addr,
tp *TransportParameters, tp *TransportParameters,
@ -168,7 +160,6 @@ func NewCryptoSetupServer(
cs, _ := newCryptoSetup( cs, _ := newCryptoSetup(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream,
connID, connID,
tp, tp,
runner, runner,
@ -185,7 +176,6 @@ func NewCryptoSetupServer(
func newCryptoSetup( func newCryptoSetup(
initialStream io.Writer, initialStream io.Writer,
handshakeStream io.Writer, handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID, connID protocol.ConnectionID,
tp *TransportParameters, tp *TransportParameters,
runner handshakeRunner, runner handshakeRunner,
@ -202,7 +192,6 @@ func newCryptoSetup(
initialSealer: initialSealer, initialSealer: initialSealer,
initialOpener: initialOpener, initialOpener: initialOpener,
handshakeStream: handshakeStream, handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
aead: newUpdatableAEAD(rttStats, logger), aead: newUpdatableAEAD(rttStats, logger),
readEncLevel: protocol.EncryptionInitial, readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial,
@ -251,10 +240,6 @@ func (h *cryptoSetup) RunHandshake() {
select { select {
case <-handshakeComplete: // return when the handshake is done case <-handshakeComplete: // return when the handshake is done
h.runner.OnHandshakeComplete() h.runner.OnHandshakeComplete()
// send a session ticket
if h.perspective == protocol.PerspectiveServer {
h.maybeSendSessionTicket()
}
case <-h.closeChan: case <-h.closeChan:
close(h.messageChan) close(h.messageChan)
// wait until the Handshake() go routine has returned // wait until the Handshake() go routine has returned
@ -475,20 +460,13 @@ func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*Transp
} }
// only valid for the server // only valid for the server
func (h *cryptoSetup) maybeSendSessionTicket() { func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT. // Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.tlsConf.MaxEarlyData > 0 { if h.tlsConf.MaxEarlyData > 0 {
appData = (&sessionTicket{Parameters: h.ourParams}).Marshal() appData = (&sessionTicket{Parameters: h.ourParams}).Marshal()
} }
ticket, err := h.conn.GetSessionTicket(appData) return h.conn.GetSessionTicket(appData)
if err != nil {
h.onError(alertInternalError, err.Error())
return
}
if ticket != nil {
h.oneRTTStream.Write(ticket)
}
} }
// accept0RTT is called for the server when receiving the client's session ticket. // accept0RTT is called for the server when receiving the client's session ticket.

View file

@ -8,7 +8,6 @@ import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"errors" "errors"
"io/ioutil"
"math/big" "math/big"
"time" "time"
@ -55,12 +54,11 @@ func (s *stream) Write(b []byte) (int, error) {
var _ = Describe("Crypto Setup TLS", func() { var _ = Describe("Crypto Setup TLS", func() {
var clientConf, serverConf *tls.Config var clientConf, serverConf *tls.Config
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) { initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
chunkChan := make(chan chunk, 100) chunkChan := make(chan chunk, 100)
initialStream := newStream(chunkChan, protocol.EncryptionInitial) initialStream := newStream(chunkChan, protocol.EncryptionInitial)
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT) return chunkChan, initialStream, handshakeStream
return chunkChan, initialStream, handshakeStream, oneRTTStream
} }
BeforeEach(func() { BeforeEach(func() {
@ -89,7 +87,6 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
&bytes.Buffer{}, &bytes.Buffer{},
&bytes.Buffer{}, &bytes.Buffer{},
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -117,11 +114,10 @@ var _ = Describe("Crypto Setup TLS", func() {
sErrChan := make(chan error, 1) sErrChan := make(chan error, 1)
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream := initStreams()
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -153,13 +149,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("errors when a message is received at the wrong encryption level", func() { It("errors when a message is received at the wrong encryption level", func() {
sErrChan := make(chan error, 1) sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -194,13 +189,12 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when handling a message fails", func() { It("returns Handshake() when handling a message fails", func() {
sErrChan := make(chan error, 1) sErrChan := make(chan error, 1)
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream := initStreams()
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -230,11 +224,10 @@ var _ = Describe("Crypto Setup TLS", func() {
}) })
It("returns Handshake() when it is closed", func() { It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream, _ := initStreams() _, sInitialStream, sHandshakeStream := initStreams()
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -305,6 +298,11 @@ var _ = Describe("Crypto Setup TLS", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
server.RunHandshake() server.RunHandshake()
ticket, err := server.GetSessionTicket()
Expect(err).ToNot(HaveOccurred())
if ticket != nil {
client.HandleMessage(ticket, protocol.Encryption1RTT)
}
close(done) close(done)
}() }()
@ -314,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() {
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
var cHandshakeComplete bool var cHandshakeComplete bool
cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cErrChan := make(chan error, 1) cErrChan := make(chan error, 1)
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
@ -323,7 +321,6 @@ var _ = Describe("Crypto Setup TLS", func() {
client, _ := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
cOneRTTStream,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -335,7 +332,7 @@ var _ = Describe("Crypto Setup TLS", func() {
) )
var sHandshakeComplete bool var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams() sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1) sErrChan := make(chan error, 1)
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
@ -345,7 +342,6 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
sOneRTTStream,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{StatelessResetToken: &token}, &TransportParameters{StatelessResetToken: &token},
@ -394,11 +390,10 @@ var _ = Describe("Crypto Setup TLS", func() {
It("signals when it has written the ClientHello", func() { It("signals when it has written the ClientHello", func() {
runner := NewMockHandshakeRunner(mockCtrl) runner := NewMockHandshakeRunner(mockCtrl)
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan := NewCryptoSetupClient( client, chChan := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -431,7 +426,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("receives transport parameters", func() { It("receives transport parameters", func() {
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second} cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp }) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
@ -439,7 +434,6 @@ var _ = Describe("Crypto Setup TLS", func() {
client, _ := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
cTransportParameters, cTransportParameters,
@ -450,7 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream := initStreams()
var token [16]byte var token [16]byte
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp }) sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
@ -462,7 +456,6 @@ var _ = Describe("Crypto Setup TLS", func() {
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
sTransportParameters, sTransportParameters,
@ -487,14 +480,13 @@ var _ = Describe("Crypto Setup TLS", func() {
Context("with session tickets", func() { Context("with session tickets", func() {
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete() cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -505,14 +497,13 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete() sRunner.EXPECT().OnHandshakeComplete()
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -544,14 +535,13 @@ var _ = Describe("Crypto Setup TLS", func() {
}) })
It("errors when handling the NewSessionTicket fails", func() { It("errors when handling the NewSessionTicket fails", func() {
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete() cRunner.EXPECT().OnHandshakeComplete()
client, _ := NewCryptoSetupClient( client, _ := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -562,14 +552,13 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete() sRunner.EXPECT().OnHandshakeComplete()
server := NewCryptoSetupServer( server := NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -670,16 +659,16 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(client.ConnectionState().DidResume).To(BeFalse()) Expect(client.ConnectionState().DidResume).To(BeFalse())
csc.EXPECT().Get(gomock.Any()).Return(state, true) csc.EXPECT().Get(gomock.Any()).Return(state, true)
csc.EXPECT().Put(gomock.Any(), nil)
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams() cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cRunner := NewMockHandshakeRunner(mockCtrl) cRunner := NewMockHandshakeRunner(mockCtrl)
cRunner.EXPECT().OnReceivedParams(gomock.Any()) cRunner.EXPECT().OnReceivedParams(gomock.Any())
cRunner.EXPECT().OnHandshakeComplete() cRunner.EXPECT().OnHandshakeComplete()
client, clientHelloChan := NewCryptoSetupClient( client, clientHelloChan := NewCryptoSetupClient(
cInitialStream, cInitialStream,
cHandshakeStream, cHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},
@ -690,14 +679,13 @@ var _ = Describe("Crypto Setup TLS", func() {
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams() sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sRunner := NewMockHandshakeRunner(mockCtrl) sRunner := NewMockHandshakeRunner(mockCtrl)
sRunner.EXPECT().OnReceivedParams(gomock.Any()) sRunner.EXPECT().OnReceivedParams(gomock.Any())
sRunner.EXPECT().OnHandshakeComplete() sRunner.EXPECT().OnHandshakeComplete()
server = NewCryptoSetupServer( server = NewCryptoSetupServer(
sInitialStream, sInitialStream,
sHandshakeStream, sHandshakeStream,
ioutil.Discard,
protocol.ConnectionID{}, protocol.ConnectionID{},
nil, nil,
&TransportParameters{}, &TransportParameters{},

View file

@ -72,6 +72,7 @@ type CryptoSetup interface {
RunHandshake() RunHandshake()
io.Closer io.Closer
ChangeConnectionID(protocol.ConnectionID) ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) bool HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber) SetLargest1RTTAcked(protocol.PacketNumber)

View file

@ -208,6 +208,21 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer))
} }
// GetSessionTicket mocks base method
func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSessionTicket")
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSessionTicket indicates an expected call of GetSessionTicket
func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket))
}
// HandleMessage mocks base method // HandleMessage mocks base method
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -110,7 +110,7 @@ const MinStreamFrameSize ByteCount = 128
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames // MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
// we send after the handshake completes. // we send after the handshake completes.
const MaxPostHandshakeCryptoFrameSize ByteCount = 1000 const MaxPostHandshakeCryptoFrameSize = 1000
// MaxAckFrameSize is the maximum size for an ACK frame that we write // MaxAckFrameSize is the maximum size for an ACK frame that we write
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. // Due to the varint encoding, ACK frames can grow (almost) indefinitely large.

View file

@ -52,6 +52,7 @@ type cryptoStreamHandler interface {
ChangeConnectionID(protocol.ConnectionID) ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) SetLargest1RTTAcked(protocol.PacketNumber)
DropHandshakeKeys() DropHandshakeKeys()
GetSessionTicket() ([]byte, error)
io.Closer io.Closer
ConnectionState() handshake.ConnectionState ConnectionState() handshake.ConnectionState
} }
@ -141,6 +142,7 @@ type session struct {
frameParser wire.FrameParser frameParser wire.FrameParser
packer packer packer packer
oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler
receivedPackets chan *receivedPacket receivedPackets chan *receivedPacket
@ -213,6 +215,7 @@ var newSession = func(
handshakeDestConnID: destConnID, handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(), srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
handshakeCompleteChan: make(chan struct{}), handshakeCompleteChan: make(chan struct{}),
logger: logger, logger: logger,
@ -244,7 +247,6 @@ var newSession = func(
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger) s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
params := &handshake.TransportParameters{ params := &handshake.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
@ -263,7 +265,6 @@ var newSession = func(
cs := handshake.NewCryptoSetupServer( cs := handshake.NewCryptoSetupServer(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream,
clientDestConnID, clientDestConnID,
conn.RemoteAddr(), conn.RemoteAddr(),
params, params,
@ -297,7 +298,7 @@ var newSession = func(
s.version, s.version,
) )
s.unpacker = newPacketUnpacker(cs, s.version) s.unpacker = newPacketUnpacker(cs, s.version)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s return s
} }
@ -348,7 +349,6 @@ var newClientSession = func(
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger) s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
params := &handshake.TransportParameters{ params := &handshake.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
@ -365,7 +365,6 @@ var newClientSession = func(
cs, clientHelloWritten := handshake.NewCryptoSetupClient( cs, clientHelloWritten := handshake.NewCryptoSetupClient(
initialStream, initialStream,
handshakeStream, handshakeStream,
oneRTTStream,
destConnID, destConnID,
conn.RemoteAddr(), conn.RemoteAddr(),
params, params,
@ -382,7 +381,7 @@ var newClientSession = func(
) )
s.clientHelloWritten = clientHelloWritten s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.version) s.unpacker = newPacketUnpacker(cs, s.version)
s.packer = newPacketPacker( s.packer = newPacketPacker(
srcConnID, srcConnID,
@ -633,6 +632,16 @@ func (s *session) handleHandshakeComplete() {
s.sentPacketHandler.SetHandshakeComplete() s.sentPacketHandler.SetHandshakeComplete()
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
if err != nil {
s.closeLocal(err)
}
if ticket != nil {
s.oneRTTStream.Write(ticket)
for s.oneRTTStream.HasData() {
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
}
}
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
if err != nil { if err != nil {
s.closeLocal(err) s.closeLocal(err)

View file

@ -1191,6 +1191,7 @@ var _ = Describe("Session", func() {
<-finishHandshake <-finishHandshake
cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys() cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket()
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)
sess.run() sess.run()
}() }()
@ -1210,6 +1211,47 @@ var _ = Describe("Session", func() {
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("sends a session ticket when the handshake completes", func() {
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
packer.EXPECT().PackPacket().AnyTimes()
finishHandshake := make(chan struct{})
sessionRunner.EXPECT().Retire(clientDestConnID)
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
close(sess.handshakeCompleteChan)
sess.run()
}()
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed())
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
var count int
var s int
for _, f := range frames {
if cf, ok := f.Frame.(*wire.CryptoFrame); ok {
count++
s += len(cf.Data)
Expect(f.Length(sess.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
}
}
Expect(size).To(BeEquivalentTo(s))
// make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any())
sess.shutdown()
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
packer.EXPECT().PackPacket().AnyTimes() packer.EXPECT().PackPacket().AnyTimes()
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
@ -1247,6 +1289,7 @@ var _ = Describe("Session", func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys() cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
mconn.EXPECT().Write(gomock.Any()) mconn.EXPECT().Write(gomock.Any())
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)
@ -1492,6 +1535,7 @@ var _ = Describe("Session", func() {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
close(sess.handshakeCompleteChan) close(sess.handshakeCompleteChan)