mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
refactor how session tickets are sent
Previously, RunHandshake() would send the session tickets. Now, the session initiates the sending of the session ticket. This simplifies the setup a bit, and it will make it possible to include the RTT estimate in the session ticket without accessing the RTTStats concurrently.
This commit is contained in:
parent
3e32a693ad
commit
8cde4ab638
9 changed files with 103 additions and 140 deletions
|
@ -22,35 +22,6 @@ type cryptoStream interface {
|
||||||
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
||||||
}
|
}
|
||||||
|
|
||||||
type postHandshakeCryptoStream struct {
|
|
||||||
cryptoStream
|
|
||||||
|
|
||||||
framer framer
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPostHandshakeCryptoStream(framer framer) cryptoStream {
|
|
||||||
return &postHandshakeCryptoStream{
|
|
||||||
cryptoStream: newCryptoStream(),
|
|
||||||
framer: framer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes post-handshake messages.
|
|
||||||
// For simplicity, post-handshake crypto messages are treated as control frames.
|
|
||||||
// The framer functions as a stack (LIFO), so if there are multiple writes,
|
|
||||||
// they will be returned in the opposite order.
|
|
||||||
// This is acceptable, since post-handshake crypto messages are very rare.
|
|
||||||
func (s *postHandshakeCryptoStream) Write(p []byte) (int, error) {
|
|
||||||
n, err := s.cryptoStream.Write(p)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
for s.cryptoStream.HasData() {
|
|
||||||
s.framer.QueueControlFrame(s.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type cryptoStreamImpl struct {
|
type cryptoStreamImpl struct {
|
||||||
queue *frameSorter
|
queue *frameSorter
|
||||||
msgBuf []byte
|
msgBuf []byte
|
||||||
|
|
|
@ -180,46 +180,3 @@ var _ = Describe("Crypto Stream", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
var _ = Describe("Post Handshake Crypto Stream", func() {
|
|
||||||
var (
|
|
||||||
cs cryptoStream
|
|
||||||
framer framer
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
framer = newFramer(NewMockStreamGetter(mockCtrl), protocol.VersionTLS)
|
|
||||||
cs = newPostHandshakeCryptoStream(framer)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("queues CRYPTO frames when writing data", func() {
|
|
||||||
n, err := cs.Write([]byte("foo"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
n, err = cs.Write([]byte("bar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
frames, _ := framer.AppendControlFrames(nil, 1000)
|
|
||||||
Expect(frames).To(HaveLen(2))
|
|
||||||
fs := []wire.Frame{frames[0].Frame, frames[1].Frame}
|
|
||||||
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("foo")}))
|
|
||||||
Expect(fs).To(ContainElement(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3}))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("splits large writes into multiple frames", func() {
|
|
||||||
size := 10 * protocol.MaxPostHandshakeCryptoFrameSize
|
|
||||||
n, err := cs.Write(make([]byte, size))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(BeEquivalentTo(size))
|
|
||||||
frames, _ := framer.AppendControlFrames(nil, protocol.MaxByteCount)
|
|
||||||
Expect(frames).To(HaveLen(11)) // one more for framing overhead
|
|
||||||
var dataLen int
|
|
||||||
for _, f := range frames {
|
|
||||||
Expect(f.Frame.Length(protocol.VersionTLS)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
|
|
||||||
Expect(f.Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{}))
|
|
||||||
dataLen += len(f.Frame.(*wire.CryptoFrame).Data)
|
|
||||||
}
|
|
||||||
Expect(dataLen).To(BeEquivalentTo(size))
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
|
@ -16,12 +16,8 @@ import (
|
||||||
"github.com/marten-seemann/qtls"
|
"github.com/marten-seemann/qtls"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// TLS unexpected_message alert
|
||||||
// TLS unexpected_message alert
|
const alertUnexpectedMessage uint8 = 10
|
||||||
alertUnexpectedMessage uint8 = 10
|
|
||||||
// TLS internal error
|
|
||||||
alertInternalError uint8 = 80
|
|
||||||
)
|
|
||||||
|
|
||||||
type messageType uint8
|
type messageType uint8
|
||||||
|
|
||||||
|
@ -111,7 +107,6 @@ type cryptoSetup struct {
|
||||||
handshakeOpener LongHeaderOpener
|
handshakeOpener LongHeaderOpener
|
||||||
handshakeSealer LongHeaderSealer
|
handshakeSealer LongHeaderSealer
|
||||||
|
|
||||||
oneRTTStream io.Writer
|
|
||||||
aead *updatableAEAD
|
aead *updatableAEAD
|
||||||
has1RTTSealer bool
|
has1RTTSealer bool
|
||||||
has1RTTOpener bool
|
has1RTTOpener bool
|
||||||
|
@ -124,7 +119,6 @@ var _ CryptoSetup = &cryptoSetup{}
|
||||||
func NewCryptoSetupClient(
|
func NewCryptoSetupClient(
|
||||||
initialStream io.Writer,
|
initialStream io.Writer,
|
||||||
handshakeStream io.Writer,
|
handshakeStream io.Writer,
|
||||||
oneRTTStream io.Writer,
|
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
remoteAddr net.Addr,
|
remoteAddr net.Addr,
|
||||||
tp *TransportParameters,
|
tp *TransportParameters,
|
||||||
|
@ -137,7 +131,6 @@ func NewCryptoSetupClient(
|
||||||
cs, clientHelloWritten := newCryptoSetup(
|
cs, clientHelloWritten := newCryptoSetup(
|
||||||
initialStream,
|
initialStream,
|
||||||
handshakeStream,
|
handshakeStream,
|
||||||
oneRTTStream,
|
|
||||||
connID,
|
connID,
|
||||||
tp,
|
tp,
|
||||||
runner,
|
runner,
|
||||||
|
@ -155,7 +148,6 @@ func NewCryptoSetupClient(
|
||||||
func NewCryptoSetupServer(
|
func NewCryptoSetupServer(
|
||||||
initialStream io.Writer,
|
initialStream io.Writer,
|
||||||
handshakeStream io.Writer,
|
handshakeStream io.Writer,
|
||||||
oneRTTStream io.Writer,
|
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
remoteAddr net.Addr,
|
remoteAddr net.Addr,
|
||||||
tp *TransportParameters,
|
tp *TransportParameters,
|
||||||
|
@ -168,7 +160,6 @@ func NewCryptoSetupServer(
|
||||||
cs, _ := newCryptoSetup(
|
cs, _ := newCryptoSetup(
|
||||||
initialStream,
|
initialStream,
|
||||||
handshakeStream,
|
handshakeStream,
|
||||||
oneRTTStream,
|
|
||||||
connID,
|
connID,
|
||||||
tp,
|
tp,
|
||||||
runner,
|
runner,
|
||||||
|
@ -185,7 +176,6 @@ func NewCryptoSetupServer(
|
||||||
func newCryptoSetup(
|
func newCryptoSetup(
|
||||||
initialStream io.Writer,
|
initialStream io.Writer,
|
||||||
handshakeStream io.Writer,
|
handshakeStream io.Writer,
|
||||||
oneRTTStream io.Writer,
|
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
tp *TransportParameters,
|
tp *TransportParameters,
|
||||||
runner handshakeRunner,
|
runner handshakeRunner,
|
||||||
|
@ -202,7 +192,6 @@ func newCryptoSetup(
|
||||||
initialSealer: initialSealer,
|
initialSealer: initialSealer,
|
||||||
initialOpener: initialOpener,
|
initialOpener: initialOpener,
|
||||||
handshakeStream: handshakeStream,
|
handshakeStream: handshakeStream,
|
||||||
oneRTTStream: oneRTTStream,
|
|
||||||
aead: newUpdatableAEAD(rttStats, logger),
|
aead: newUpdatableAEAD(rttStats, logger),
|
||||||
readEncLevel: protocol.EncryptionInitial,
|
readEncLevel: protocol.EncryptionInitial,
|
||||||
writeEncLevel: protocol.EncryptionInitial,
|
writeEncLevel: protocol.EncryptionInitial,
|
||||||
|
@ -251,10 +240,6 @@ func (h *cryptoSetup) RunHandshake() {
|
||||||
select {
|
select {
|
||||||
case <-handshakeComplete: // return when the handshake is done
|
case <-handshakeComplete: // return when the handshake is done
|
||||||
h.runner.OnHandshakeComplete()
|
h.runner.OnHandshakeComplete()
|
||||||
// send a session ticket
|
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
|
||||||
h.maybeSendSessionTicket()
|
|
||||||
}
|
|
||||||
case <-h.closeChan:
|
case <-h.closeChan:
|
||||||
close(h.messageChan)
|
close(h.messageChan)
|
||||||
// wait until the Handshake() go routine has returned
|
// wait until the Handshake() go routine has returned
|
||||||
|
@ -475,20 +460,13 @@ func (h *cryptoSetup) handlePeerParamsFromSessionStateImpl(data []byte) (*Transp
|
||||||
}
|
}
|
||||||
|
|
||||||
// only valid for the server
|
// only valid for the server
|
||||||
func (h *cryptoSetup) maybeSendSessionTicket() {
|
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||||
var appData []byte
|
var appData []byte
|
||||||
// Save transport parameters to the session ticket if we're allowing 0-RTT.
|
// Save transport parameters to the session ticket if we're allowing 0-RTT.
|
||||||
if h.tlsConf.MaxEarlyData > 0 {
|
if h.tlsConf.MaxEarlyData > 0 {
|
||||||
appData = (&sessionTicket{Parameters: h.ourParams}).Marshal()
|
appData = (&sessionTicket{Parameters: h.ourParams}).Marshal()
|
||||||
}
|
}
|
||||||
ticket, err := h.conn.GetSessionTicket(appData)
|
return h.conn.GetSessionTicket(appData)
|
||||||
if err != nil {
|
|
||||||
h.onError(alertInternalError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if ticket != nil {
|
|
||||||
h.oneRTTStream.Write(ticket)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// accept0RTT is called for the server when receiving the client's session ticket.
|
// accept0RTT is called for the server when receiving the client's session ticket.
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -55,12 +54,11 @@ func (s *stream) Write(b []byte) (int, error) {
|
||||||
var _ = Describe("Crypto Setup TLS", func() {
|
var _ = Describe("Crypto Setup TLS", func() {
|
||||||
var clientConf, serverConf *tls.Config
|
var clientConf, serverConf *tls.Config
|
||||||
|
|
||||||
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */, *stream /* 1-RTT */) {
|
initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
|
||||||
chunkChan := make(chan chunk, 100)
|
chunkChan := make(chan chunk, 100)
|
||||||
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
initialStream := newStream(chunkChan, protocol.EncryptionInitial)
|
||||||
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
|
||||||
oneRTTStream := newStream(chunkChan, protocol.Encryption1RTT)
|
return chunkChan, initialStream, handshakeStream
|
||||||
return chunkChan, initialStream, handshakeStream, oneRTTStream
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
@ -89,7 +87,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
&bytes.Buffer{},
|
&bytes.Buffer{},
|
||||||
&bytes.Buffer{},
|
&bytes.Buffer{},
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -117,11 +114,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
runner := NewMockHandshakeRunner(mockCtrl)
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -153,13 +149,12 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
It("errors when a message is received at the wrong encryption level", func() {
|
It("errors when a message is received at the wrong encryption level", func() {
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
runner := NewMockHandshakeRunner(mockCtrl)
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -194,13 +189,12 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
It("returns Handshake() when handling a message fails", func() {
|
It("returns Handshake() when handling a message fails", func() {
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
runner := NewMockHandshakeRunner(mockCtrl)
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e })
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -230,11 +224,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns Handshake() when it is closed", func() {
|
It("returns Handshake() when it is closed", func() {
|
||||||
_, sInitialStream, sHandshakeStream, _ := initStreams()
|
_, sInitialStream, sHandshakeStream := initStreams()
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -305,6 +298,11 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
server.RunHandshake()
|
server.RunHandshake()
|
||||||
|
ticket, err := server.GetSessionTicket()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
if ticket != nil {
|
||||||
|
client.HandleMessage(ticket, protocol.Encryption1RTT)
|
||||||
|
}
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -314,7 +312,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
|
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config, enable0RTT bool) (CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
|
||||||
var cHandshakeComplete bool
|
var cHandshakeComplete bool
|
||||||
cChunkChan, cInitialStream, cHandshakeStream, cOneRTTStream := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cErrChan := make(chan error, 1)
|
cErrChan := make(chan error, 1)
|
||||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
|
@ -323,7 +321,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
client, _ := NewCryptoSetupClient(
|
client, _ := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
cOneRTTStream,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -335,7 +332,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
var sHandshakeComplete bool
|
var sHandshakeComplete bool
|
||||||
sChunkChan, sInitialStream, sHandshakeStream, sOneRTTStream := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
sErrChan := make(chan error, 1)
|
sErrChan := make(chan error, 1)
|
||||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
|
@ -345,7 +342,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
sOneRTTStream,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{StatelessResetToken: &token},
|
&TransportParameters{StatelessResetToken: &token},
|
||||||
|
@ -394,11 +390,10 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
It("signals when it has written the ClientHello", func() {
|
It("signals when it has written the ClientHello", func() {
|
||||||
runner := NewMockHandshakeRunner(mockCtrl)
|
runner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
client, chChan := NewCryptoSetupClient(
|
client, chChan := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -431,7 +426,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
It("receives transport parameters", func() {
|
It("receives transport parameters", func() {
|
||||||
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
||||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
|
cTransportParameters := &TransportParameters{MaxIdleTimeout: 0x42 * time.Second}
|
||||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
|
cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { sTransportParametersRcvd = tp })
|
||||||
|
@ -439,7 +434,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
client, _ := NewCryptoSetupClient(
|
client, _ := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
cTransportParameters,
|
cTransportParameters,
|
||||||
|
@ -450,7 +444,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
var token [16]byte
|
var token [16]byte
|
||||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
|
sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *TransportParameters) { cTransportParametersRcvd = tp })
|
||||||
|
@ -462,7 +456,6 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
sTransportParameters,
|
sTransportParameters,
|
||||||
|
@ -487,14 +480,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
|
|
||||||
Context("with session tickets", func() {
|
Context("with session tickets", func() {
|
||||||
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
|
It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
|
||||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
cRunner.EXPECT().OnHandshakeComplete()
|
cRunner.EXPECT().OnHandshakeComplete()
|
||||||
client, _ := NewCryptoSetupClient(
|
client, _ := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -505,14 +497,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
sRunner.EXPECT().OnHandshakeComplete()
|
sRunner.EXPECT().OnHandshakeComplete()
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -544,14 +535,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when handling the NewSessionTicket fails", func() {
|
It("errors when handling the NewSessionTicket fails", func() {
|
||||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
cRunner.EXPECT().OnHandshakeComplete()
|
cRunner.EXPECT().OnHandshakeComplete()
|
||||||
client, _ := NewCryptoSetupClient(
|
client, _ := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -562,14 +552,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
sRunner.EXPECT().OnHandshakeComplete()
|
sRunner.EXPECT().OnHandshakeComplete()
|
||||||
server := NewCryptoSetupServer(
|
server := NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -670,16 +659,16 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
Expect(client.ConnectionState().DidResume).To(BeFalse())
|
||||||
|
|
||||||
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
csc.EXPECT().Get(gomock.Any()).Return(state, true)
|
||||||
|
csc.EXPECT().Put(gomock.Any(), nil)
|
||||||
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||||
|
|
||||||
cChunkChan, cInitialStream, cHandshakeStream, _ := initStreams()
|
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||||
cRunner := NewMockHandshakeRunner(mockCtrl)
|
cRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
cRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
cRunner.EXPECT().OnHandshakeComplete()
|
cRunner.EXPECT().OnHandshakeComplete()
|
||||||
client, clientHelloChan := NewCryptoSetupClient(
|
client, clientHelloChan := NewCryptoSetupClient(
|
||||||
cInitialStream,
|
cInitialStream,
|
||||||
cHandshakeStream,
|
cHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
@ -690,14 +679,13 @@ var _ = Describe("Crypto Setup TLS", func() {
|
||||||
utils.DefaultLogger.WithPrefix("client"),
|
utils.DefaultLogger.WithPrefix("client"),
|
||||||
)
|
)
|
||||||
|
|
||||||
sChunkChan, sInitialStream, sHandshakeStream, _ := initStreams()
|
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||||
sRunner := NewMockHandshakeRunner(mockCtrl)
|
sRunner := NewMockHandshakeRunner(mockCtrl)
|
||||||
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
sRunner.EXPECT().OnReceivedParams(gomock.Any())
|
||||||
sRunner.EXPECT().OnHandshakeComplete()
|
sRunner.EXPECT().OnHandshakeComplete()
|
||||||
server = NewCryptoSetupServer(
|
server = NewCryptoSetupServer(
|
||||||
sInitialStream,
|
sInitialStream,
|
||||||
sHandshakeStream,
|
sHandshakeStream,
|
||||||
ioutil.Discard,
|
|
||||||
protocol.ConnectionID{},
|
protocol.ConnectionID{},
|
||||||
nil,
|
nil,
|
||||||
&TransportParameters{},
|
&TransportParameters{},
|
||||||
|
|
|
@ -72,6 +72,7 @@ type CryptoSetup interface {
|
||||||
RunHandshake()
|
RunHandshake()
|
||||||
io.Closer
|
io.Closer
|
||||||
ChangeConnectionID(protocol.ConnectionID)
|
ChangeConnectionID(protocol.ConnectionID)
|
||||||
|
GetSessionTicket() ([]byte, error)
|
||||||
|
|
||||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||||
SetLargest1RTTAcked(protocol.PacketNumber)
|
SetLargest1RTTAcked(protocol.PacketNumber)
|
||||||
|
|
|
@ -208,6 +208,21 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSessionTicket mocks base method
|
||||||
|
func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetSessionTicket")
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionTicket indicates an expected call of GetSessionTicket
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket))
|
||||||
|
}
|
||||||
|
|
||||||
// HandleMessage mocks base method
|
// HandleMessage mocks base method
|
||||||
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -110,7 +110,7 @@ const MinStreamFrameSize ByteCount = 128
|
||||||
|
|
||||||
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
|
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
|
||||||
// we send after the handshake completes.
|
// we send after the handshake completes.
|
||||||
const MaxPostHandshakeCryptoFrameSize ByteCount = 1000
|
const MaxPostHandshakeCryptoFrameSize = 1000
|
||||||
|
|
||||||
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||||
|
|
21
session.go
21
session.go
|
@ -52,6 +52,7 @@ type cryptoStreamHandler interface {
|
||||||
ChangeConnectionID(protocol.ConnectionID)
|
ChangeConnectionID(protocol.ConnectionID)
|
||||||
SetLargest1RTTAcked(protocol.PacketNumber)
|
SetLargest1RTTAcked(protocol.PacketNumber)
|
||||||
DropHandshakeKeys()
|
DropHandshakeKeys()
|
||||||
|
GetSessionTicket() ([]byte, error)
|
||||||
io.Closer
|
io.Closer
|
||||||
ConnectionState() handshake.ConnectionState
|
ConnectionState() handshake.ConnectionState
|
||||||
}
|
}
|
||||||
|
@ -141,6 +142,7 @@ type session struct {
|
||||||
frameParser wire.FrameParser
|
frameParser wire.FrameParser
|
||||||
packer packer
|
packer packer
|
||||||
|
|
||||||
|
oneRTTStream cryptoStream // only set for the server
|
||||||
cryptoStreamHandler cryptoStreamHandler
|
cryptoStreamHandler cryptoStreamHandler
|
||||||
|
|
||||||
receivedPackets chan *receivedPacket
|
receivedPackets chan *receivedPacket
|
||||||
|
@ -213,6 +215,7 @@ var newSession = func(
|
||||||
handshakeDestConnID: destConnID,
|
handshakeDestConnID: destConnID,
|
||||||
srcConnIDLen: srcConnID.Len(),
|
srcConnIDLen: srcConnID.Len(),
|
||||||
tokenGenerator: tokenGenerator,
|
tokenGenerator: tokenGenerator,
|
||||||
|
oneRTTStream: newCryptoStream(),
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
handshakeCompleteChan: make(chan struct{}),
|
handshakeCompleteChan: make(chan struct{}),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
@ -244,7 +247,6 @@ var newSession = func(
|
||||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
|
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
|
||||||
initialStream := newCryptoStream()
|
initialStream := newCryptoStream()
|
||||||
handshakeStream := newCryptoStream()
|
handshakeStream := newCryptoStream()
|
||||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
|
||||||
params := &handshake.TransportParameters{
|
params := &handshake.TransportParameters{
|
||||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||||
|
@ -263,7 +265,6 @@ var newSession = func(
|
||||||
cs := handshake.NewCryptoSetupServer(
|
cs := handshake.NewCryptoSetupServer(
|
||||||
initialStream,
|
initialStream,
|
||||||
handshakeStream,
|
handshakeStream,
|
||||||
oneRTTStream,
|
|
||||||
clientDestConnID,
|
clientDestConnID,
|
||||||
conn.RemoteAddr(),
|
conn.RemoteAddr(),
|
||||||
params,
|
params,
|
||||||
|
@ -297,7 +298,7 @@ var newSession = func(
|
||||||
s.version,
|
s.version,
|
||||||
)
|
)
|
||||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
|
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,7 +349,6 @@ var newClientSession = func(
|
||||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
|
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
|
||||||
initialStream := newCryptoStream()
|
initialStream := newCryptoStream()
|
||||||
handshakeStream := newCryptoStream()
|
handshakeStream := newCryptoStream()
|
||||||
oneRTTStream := newPostHandshakeCryptoStream(s.framer)
|
|
||||||
params := &handshake.TransportParameters{
|
params := &handshake.TransportParameters{
|
||||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||||
|
@ -365,7 +365,6 @@ var newClientSession = func(
|
||||||
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
|
cs, clientHelloWritten := handshake.NewCryptoSetupClient(
|
||||||
initialStream,
|
initialStream,
|
||||||
handshakeStream,
|
handshakeStream,
|
||||||
oneRTTStream,
|
|
||||||
destConnID,
|
destConnID,
|
||||||
conn.RemoteAddr(),
|
conn.RemoteAddr(),
|
||||||
params,
|
params,
|
||||||
|
@ -382,7 +381,7 @@ var newClientSession = func(
|
||||||
)
|
)
|
||||||
s.clientHelloWritten = clientHelloWritten
|
s.clientHelloWritten = clientHelloWritten
|
||||||
s.cryptoStreamHandler = cs
|
s.cryptoStreamHandler = cs
|
||||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
|
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
|
||||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||||
s.packer = newPacketPacker(
|
s.packer = newPacketPacker(
|
||||||
srcConnID,
|
srcConnID,
|
||||||
|
@ -633,6 +632,16 @@ func (s *session) handleHandshakeComplete() {
|
||||||
s.sentPacketHandler.SetHandshakeComplete()
|
s.sentPacketHandler.SetHandshakeComplete()
|
||||||
|
|
||||||
if s.perspective == protocol.PerspectiveServer {
|
if s.perspective == protocol.PerspectiveServer {
|
||||||
|
ticket, err := s.cryptoStreamHandler.GetSessionTicket()
|
||||||
|
if err != nil {
|
||||||
|
s.closeLocal(err)
|
||||||
|
}
|
||||||
|
if ticket != nil {
|
||||||
|
s.oneRTTStream.Write(ticket)
|
||||||
|
for s.oneRTTStream.HasData() {
|
||||||
|
s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
|
||||||
|
}
|
||||||
|
}
|
||||||
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
|
token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.closeLocal(err)
|
s.closeLocal(err)
|
||||||
|
|
|
@ -1191,6 +1191,7 @@ var _ = Describe("Session", func() {
|
||||||
<-finishHandshake
|
<-finishHandshake
|
||||||
cryptoSetup.EXPECT().RunHandshake()
|
cryptoSetup.EXPECT().RunHandshake()
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||||
|
cryptoSetup.EXPECT().GetSessionTicket()
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
sess.run()
|
sess.run()
|
||||||
}()
|
}()
|
||||||
|
@ -1210,6 +1211,47 @@ var _ = Describe("Session", func() {
|
||||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("sends a session ticket when the handshake completes", func() {
|
||||||
|
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
|
||||||
|
packer.EXPECT().PackPacket().AnyTimes()
|
||||||
|
finishHandshake := make(chan struct{})
|
||||||
|
sessionRunner.EXPECT().Retire(clientDestConnID)
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
<-finishHandshake
|
||||||
|
cryptoSetup.EXPECT().RunHandshake()
|
||||||
|
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||||
|
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
|
||||||
|
close(sess.handshakeCompleteChan)
|
||||||
|
sess.run()
|
||||||
|
}()
|
||||||
|
|
||||||
|
handshakeCtx := sess.HandshakeComplete()
|
||||||
|
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
|
||||||
|
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||||
|
close(finishHandshake)
|
||||||
|
Eventually(handshakeCtx.Done()).Should(BeClosed())
|
||||||
|
frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount)
|
||||||
|
var count int
|
||||||
|
var s int
|
||||||
|
for _, f := range frames {
|
||||||
|
if cf, ok := f.Frame.(*wire.CryptoFrame); ok {
|
||||||
|
count++
|
||||||
|
s += len(cf.Data)
|
||||||
|
Expect(f.Length(sess.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Expect(size).To(BeEquivalentTo(s))
|
||||||
|
// make sure the go routine returns
|
||||||
|
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||||
|
expectReplaceWithClosed()
|
||||||
|
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||||
|
cryptoSetup.EXPECT().Close()
|
||||||
|
mconn.EXPECT().Write(gomock.Any())
|
||||||
|
sess.shutdown()
|
||||||
|
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
|
It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
|
||||||
packer.EXPECT().PackPacket().AnyTimes()
|
packer.EXPECT().PackPacket().AnyTimes()
|
||||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||||
|
@ -1247,6 +1289,7 @@ var _ = Describe("Session", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
cryptoSetup.EXPECT().RunHandshake()
|
cryptoSetup.EXPECT().RunHandshake()
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||||
|
cryptoSetup.EXPECT().GetSessionTicket()
|
||||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||||
mconn.EXPECT().Write(gomock.Any())
|
mconn.EXPECT().Write(gomock.Any())
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
|
@ -1492,6 +1535,7 @@ var _ = Describe("Session", func() {
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||||
|
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||||
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
|
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
|
||||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
|
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
|
||||||
close(sess.handshakeCompleteChan)
|
close(sess.handshakeCompleteChan)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue