mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
refactor how session tickets are sent
Previously, RunHandshake() would send the session tickets. Now, the session initiates the sending of the session ticket. This simplifies the setup a bit, and it will make it possible to include the RTT estimate in the session ticket without accessing the RTTStats concurrently.
This commit is contained in:
parent
3e32a693ad
commit
8cde4ab638
9 changed files with 103 additions and 140 deletions
|
@ -22,35 +22,6 @@ type cryptoStream interface {
|
|||
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
||||
}
|
||||
|
||||
type postHandshakeCryptoStream struct {
|
||||
cryptoStream
|
||||
|
||||
framer framer
|
||||
}
|
||||
|
||||
func newPostHandshakeCryptoStream(framer framer) cryptoStream {
|
||||
return &postHandshakeCryptoStream{
|
||||
cryptoStream: newCryptoStream(),
|
||||
framer: framer,
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes post-handshake messages.
|
||||
// For simplicity, post-handshake crypto messages are treated as control frames.
|
||||
// The framer functions as a stack (LIFO), so if there are multiple writes,
|
||||
// they will be returned in the opposite order.
|
||||
// This is acceptable, since post-handshake crypto messages are very rare.
|
||||
func (s *postHandshakeCryptoStream) Write(p []byte) (int, error) {
|
||||
n, err := s.cryptoStream.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
for s.cryptoStream.HasData() {
|
||||
s.framer.QueueControlFrame(s.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type cryptoStreamImpl struct {
|
||||
queue *frameSorter
|
||||
msgBuf []byte
|
||||
|
|
|
@ -180,46 +180,3 @@ var _ = Describe("Crypto Stream", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Post Handshake Crypto Stream", func() {
|
||||
var (
|
||||
cs cryptoStream
|
||||
framer framer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
framer = newFramer(NewMockStreamGetter(mockCtrl), protocol.VersionTLS)
|
||||
cs = newPostHandshakeCryptoStream(framer)
|
||||
})
|
||||
|
||||
It("queues CRYPTO frames when writing data", func() {
|
||||
n, err := cs.Write([]byte("foo"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
n, err = cs.Write([]byte("bar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
frames, _ := framer.AppendControlFrames(nil, 1000)
|
||||
Expect(frames).To(HaveLen(2))
|
||||
fs := []wire.Frame{frames[0].Frame, frames[1].Frame}
|
||||
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")}))
|
||||
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3}))
|
||||
})
|
||||
|
||||
It("splits large writes into multiple frames", func() {
|
||||
size := 10 * protocol.MaxPostHandshakeCryptoFrameSize
|
||||
n, err := cs.Write(make([]byte, size))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(BeEquivalentTo(size))
|
||||
frames, _ := framer.AppendControlFrames(nil, protocol.MaxByteCount)
|
||||
Expect(frames).To(HaveLen(11)) // one more for framing overhead
|
||||
var dataLen int
|
||||
for _, f := range frames {
|
||||
Expect(f.Frame.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
|
||||
Expect(f.Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
|
||||
dataLen += len(f.Frame.(*wire.CryptoFrame).Data)
|
||||
}
|
||||
Expect(dataLen).To(BeEquivalentTo(size))
|
||||
})
|
||||
|
||||
})
|
||||
|
|
|
@ -16,12 +16,8 @@ import (
|
|||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
const (
|
||||
// TLS unexpected_message alert
|
||||
alertUnexpectedMessage uint8 = 10
|
||||
// TLS internal error
|
||||
alertInternalError uint8 = 80
|
||||
)
|
||||
// TLS unexpected_message alert
|
||||
const alertUnexpectedMessage uint8 = 10
|
||||
|
||||
type messageType uint8
|
||||
|
||||
|
@ -111,7 +107,6 @@ type cryptoSetup struct {
|
|||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
oneRTTStream io.Writer
|
||||
aead *updatableAEAD
|
||||
has1RTTSealer bool
|
||||
has1RTTOpener bool
|
||||
|
@ -124,7 +119,6 @@ var _ CryptoSetup = &cryptoSetup{}
|
|||
func NewCryptoSetupClient(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
tp *TransportParameters,
|
||||
|
@ -137,7 +131,6 @@ func NewCryptoSetupClient(
|
|||
cs, clientHelloWritten := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
|
@ -155,7 +148,6 @@ func NewCryptoSetupClient(
|
|||
func NewCryptoSetupServer(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
tp *TransportParameters,
|
||||
|
@ -168,7 +160,6 @@ func NewCryptoSetupServer(
|
|||
cs, _ := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
|
@ -185,7 +176,6 @@ func NewCryptoSetupServer(
|
|||
func newCryptoSetup(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
oneRTTStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
tp *TransportParameters,
|
||||
runner handshakeRunner,
|
||||
|
@ -202,7 +192,6 @@ func newCryptoSetup(
|
|||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
aead: newUpdatableAEAD(rttStats, logger),
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
|
@ -251,10 +240,6 @@ func (h *cryptoSetup) RunHandshake() {
|
|||
select {
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
h.runner.OnHandshakeComplete()
|
||||
// send a session ticket
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
h.maybeSendSessionTicket()
|
||||
}
|
||||
case <-h.closeChan:
|
||||
close(h.messageChan)
|
||||
// wait until the Handshake() go routine has returned
|
||||
|
@ -475,20 +460,13 @@ func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*Transp
|
|||
}
|
||||
|
||||
// only valid for the server
|
||||
func (h *cryptoSetup) maybeSendSessionTicket() {
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
var appData []byte
|
||||
// Save transport parameters to the session ticket if we're allowing 0-RTT.
|
||||
if h.tlsConf.MaxEarlyData > 0 {
|
||||
appData = (&sessionTicket{Parameters: h.ourParams}).Marshal()
|
||||
}
|
||||
ticket, err := h.conn.GetSessionTicket(appData)
|
||||
if err != nil {
|
||||
h.onError(alertInternalError, err.Error())
|
||||
return
|
||||
}
|
||||
if ticket != nil {
|
||||
h.oneRTTStream.Write(ticket)
|
||||
}
|
||||
return h.conn.GetSessionTicket(appData)
|
||||
}
|
||||
|
||||
// accept0RTT is called for the server when receiving the client's session ticket.
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
|
@ -55,12 +54,11 @@ func (s *stream) Write(b []byte) (int, error) {
|
|||
var _ = Describe("Crypto Setup TLS", func() {
|
||||
var clientConf, serverConf *tls.Config
|
||||
|
||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) {
|
||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||
chunkChan := make(chan chunk, 100)
|
||||
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
||||
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
||||
oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT)
|
||||
return chunkChan, initialStream, handshakeStream, oneRTTStream
|
||||
return chunkChan, initialStream, handshakeStream
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -89,7 +87,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
server := NewCryptoSetupServer(
|
||||
&bytes.Buffer{},
|
||||
&bytes.Buffer{},
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -117,11 +114,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
sErrChan := make(chan error, 1)
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -153,13 +149,12 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("errors when a message is received at the wrong encryption level", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -194,13 +189,12 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("returns Handshake() when handling a message fails", func() {
|
||||
sErrChan := make(chan error, 1)
|
||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -230,11 +224,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("returns Handshake() when it is closed", func() {
|
||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -305,6 +298,11 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.RunHandshake()
|
||||
ticket, err := server.GetSessionTicket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if ticket != nil {
|
||||
client.HandleMessage(ticket, protocol.Encryption1RTT)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -314,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
|
||||
var cHandshakeComplete bool
|
||||
cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams()
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cErrChan := make(chan error, 1)
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
|
@ -323,7 +321,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
cOneRTTStream,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -335,7 +332,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
)
|
||||
|
||||
var sHandshakeComplete bool
|
||||
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams()
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sErrChan := make(chan error, 1)
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
|
@ -345,7 +342,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
sOneRTTStream,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{StatelessResetToken: &token},
|
||||
|
@ -394,11 +390,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("signals when it has written the ClientHello", func() {
|
||||
runner := NewMockHandshakeRunner(mockCtrl)
|
||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, chChan := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -431,7 +426,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("receives transport parameters", func() {
|
||||
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
|
||||
|
@ -439,7 +434,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
cTransportParameters,
|
||||
|
@ -450,7 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
var token [16]byte
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
|
||||
|
@ -462,7 +456,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
sTransportParameters,
|
||||
|
@ -487,14 +480,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
Context("with session tickets", func() {
|
||||
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -505,14 +497,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -544,14 +535,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
})
|
||||
|
||||
It("errors when handling the NewSessionTicket fails", func() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, _ := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -562,14 +552,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
server := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -670,16 +659,16 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
||||
|
||||
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
||||
csc.EXPECT().Put(gomock.Any(), nil)
|
||||
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
|
||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
cRunner.EXPECT().OnHandshakeComplete()
|
||||
client, clientHelloChan := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
@ -690,14 +679,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
utils.DefaultLogger.WithPrefix("client"),
|
||||
)
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||
sRunner.EXPECT().OnHandshakeComplete()
|
||||
server = NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
ioutil.Discard,
|
||||
protocol.ConnectionID{},
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
|
|
|
@ -72,6 +72,7 @@ type CryptoSetup interface {
|
|||
RunHandshake()
|
||||
io.Closer
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
GetSessionTicket() ([]byte, error)
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
SetLargest1RTTAcked(protocol.PacketNumber)
|
||||
|
|
|
@ -208,6 +208,21 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer))
|
||||
}
|
||||
|
||||
// GetSessionTicket mocks base method
|
||||
func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSessionTicket")
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSessionTicket indicates an expected call of GetSessionTicket
|
||||
func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket))
|
||||
}
|
||||
|
||||
// HandleMessage mocks base method
|
||||
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -110,7 +110,7 @@ const MinStreamFrameSize ByteCount = 128
|
|||
|
||||
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
|
||||
// we send after the handshake completes.
|
||||
const MaxPostHandshakeCryptoFrameSize ByteCount = 1000
|
||||
const MaxPostHandshakeCryptoFrameSize = 1000
|
||||
|
||||
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||
|
|
21
session.go
21
session.go
|
@ -52,6 +52,7 @@ type cryptoStreamHandler interface {
|
|||
ChangeConnectionID(protocol.ConnectionID)
|
||||
SetLargest1RTTAcked(protocol.PacketNumber)
|
||||
DropHandshakeKeys()
|
||||
GetSessionTicket() ([]byte, error)
|
||||
io.Closer
|
||||
ConnectionState() handshake.ConnectionState
|
||||
}
|
||||
|
@ -141,6 +142,7 @@ type session struct {
|
|||
frameParser wire.FrameParser
|
||||
packer packer
|
||||
|
||||
oneRTTStream cryptoStream // only set for the server
|
||||
cryptoStreamHandler cryptoStreamHandler
|
||||
|
||||
receivedPackets chan *receivedPacket
|
||||
|
@ -213,6 +215,7 @@ var newSession = func(
|
|||
handshakeDestConnID: destConnID,
|
||||
srcConnIDLen: srcConnID.Len(),
|
||||
tokenGenerator: tokenGenerator,
|
||||
oneRTTStream: newCryptoStream(),
|
||||
perspective: protocol.PerspectiveServer,
|
||||
handshakeCompleteChan: make(chan struct{}),
|
||||
logger: logger,
|
||||
|
@ -244,7 +247,6 @@ var newSession = func(
|
|||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
|
||||
initialStream := newCryptoStream()
|
||||
handshakeStream := newCryptoStream()
|
||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
||||
params := &handshake.TransportParameters{
|
||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||
|
@ -263,7 +265,6 @@ var newSession = func(
|
|||
cs := handshake.NewCryptoSetupServer(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
clientDestConnID,
|
||||
conn.RemoteAddr(),
|
||||
params,
|
||||
|
@ -297,7 +298,7 @@ var newSession = func(
|
|||
s.version,
|
||||
)
|
||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -348,7 +349,6 @@ var newClientSession = func(
|
|||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
|
||||
initialStream := newCryptoStream()
|
||||
handshakeStream := newCryptoStream()
|
||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
||||
params := &handshake.TransportParameters{
|
||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||
|
@ -365,7 +365,6 @@ var newClientSession = func(
|
|||
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
destConnID,
|
||||
conn.RemoteAddr(),
|
||||
params,
|
||||
|
@ -382,7 +381,7 @@ var newClientSession = func(
|
|||
)
|
||||
s.clientHelloWritten = clientHelloWritten
|
||||
s.cryptoStreamHandler = cs
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
|
||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||
s.packer = newPacketPacker(
|
||||
srcConnID,
|
||||
|
@ -633,6 +632,16 @@ func (s *session) handleHandshakeComplete() {
|
|||
s.sentPacketHandler.SetHandshakeComplete()
|
||||
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
|
||||
if err != nil {
|
||||
s.closeLocal(err)
|
||||
}
|
||||
if ticket != nil {
|
||||
s.oneRTTStream.Write(ticket)
|
||||
for s.oneRTTStream.HasData() {
|
||||
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
|
||||
}
|
||||
}
|
||||
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
|
||||
if err != nil {
|
||||
s.closeLocal(err)
|
||||
|
|
|
@ -1191,6 +1191,7 @@ var _ = Describe("Session", func() {
|
|||
<-finishHandshake
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||
cryptoSetup.EXPECT().GetSessionTicket()
|
||||
close(sess.handshakeCompleteChan)
|
||||
sess.run()
|
||||
}()
|
||||
|
@ -1210,6 +1211,47 @@ var _ = Describe("Session", func() {
|
|||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sends a session ticket when the handshake completes", func() {
|
||||
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
|
||||
packer.EXPECT().PackPacket().AnyTimes()
|
||||
finishHandshake := make(chan struct{})
|
||||
sessionRunner.EXPECT().Retire(clientDestConnID)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
<-finishHandshake
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
|
||||
close(sess.handshakeCompleteChan)
|
||||
sess.run()
|
||||
}()
|
||||
|
||||
handshakeCtx := sess.HandshakeComplete()
|
||||
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||
close(finishHandshake)
|
||||
Eventually(handshakeCtx.Done()).Should(BeClosed())
|
||||
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
|
||||
var count int
|
||||
var s int
|
||||
for _, f := range frames {
|
||||
if cf, ok := f.Frame.(*wire.CryptoFrame); ok {
|
||||
count++
|
||||
s += len(cf.Data)
|
||||
Expect(f.Length(sess.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
|
||||
}
|
||||
}
|
||||
Expect(size).To(BeEquivalentTo(s))
|
||||
// make sure the go routine returns
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
expectReplaceWithClosed()
|
||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||
cryptoSetup.EXPECT().Close()
|
||||
mconn.EXPECT().Write(gomock.Any())
|
||||
sess.shutdown()
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
|
||||
packer.EXPECT().PackPacket().AnyTimes()
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
|
@ -1247,6 +1289,7 @@ var _ = Describe("Session", func() {
|
|||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||
cryptoSetup.EXPECT().GetSessionTicket()
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||
mconn.EXPECT().Write(gomock.Any())
|
||||
close(sess.handshakeCompleteChan)
|
||||
|
@ -1492,6 +1535,7 @@ var _ = Describe("Session", func() {
|
|||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
|
||||
close(sess.handshakeCompleteChan)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue