drop Initial and Handshake keys when receiving the first 1-RTT ACK

This commit is contained in:
Marten Seemann 2019-05-30 02:23:07 +08:00
parent 4834962cbd
commit a4989c3d9c
8 changed files with 111 additions and 35 deletions

View file

@ -105,17 +105,12 @@ func NewSentPacketHandler(
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight // remove outstanding packets from bytes_in_flight
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
var packets []*Packet
pnSpace.history.Iterate(func(p *Packet) (bool, error) { pnSpace.history.Iterate(func(p *Packet) (bool, error) {
packets = append(packets, p)
if p.includedInBytesInFlight { if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length h.bytesInFlight -= p.Length
} }
return true, nil return true, nil
}) })
for _, p := range packets {
pnSpace.history.Remove(p.PacketNumber)
}
// remove packets from the retransmission queue // remove packets from the retransmission queue
var queue []*Packet var queue []*Packet
for _, packet := range h.retransmissionQueue { for _, packet := range h.retransmissionQueue {
@ -124,6 +119,15 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
} }
} }
h.retransmissionQueue = queue h.retransmissionQueue = queue
// drop the packet history
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
} }
func (h *sentPacketHandler) SetMaxAckDelay(mad time.Duration) { func (h *sentPacketHandler) SetMaxAckDelay(mad time.Duration) {
@ -312,7 +316,14 @@ func (h *sentPacketHandler) determineNewlyAckedPackets(
} }
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
return h.initialPackets.history.HasOutstandingPackets() || h.handshakePackets.history.HasOutstandingPackets() var hasInitial, hasHandshake bool
if h.initialPackets != nil {
hasInitial = h.initialPackets.history.HasOutstandingPackets()
}
if h.handshakePackets != nil {
hasHandshake = h.handshakePackets.history.HasOutstandingPackets()
}
return hasInitial || hasHandshake
} }
func (h *sentPacketHandler) hasOutstandingPackets() bool { func (h *sentPacketHandler) hasOutstandingPackets() bool {
@ -536,8 +547,13 @@ func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) p
} }
func (h *sentPacketHandler) SendMode() SendMode { func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.initialPackets.history.Len() + numTrackedPackets := len(h.retransmissionQueue) + h.oneRTTPackets.history.Len()
h.handshakePackets.history.Len() + h.oneRTTPackets.history.Len() if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len()
}
if h.handshakePackets != nil {
numTrackedPackets += h.handshakePackets.history.Len()
}
// Don't send any packets if we're keeping track of the maximum number of packets. // Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,

View file

@ -861,7 +861,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionInitial) handler.DropPackets(protocol.EncryptionInitial)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.initialPackets.history.Len()).To(BeZero()) Expect(handler.initialPackets).To(BeNil())
Expect(handler.handshakePackets.history.Len()).ToNot(BeZero()) Expect(handler.handshakePackets.history.Len()).ToNot(BeZero())
packet := handler.DequeuePacketForRetransmission() packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(Equal(lostPacket)) Expect(packet).To(Equal(lostPacket))
@ -882,7 +882,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake)) handler.queuePacketForRetransmission(lostPacket, handler.getPacketNumberSpace(protocol.EncryptionHandshake))
handler.DropPackets(protocol.EncryptionHandshake) handler.DropPackets(protocol.EncryptionHandshake)
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
Expect(handler.handshakePackets.history.Len()).To(BeZero()) Expect(handler.handshakePackets).To(BeNil())
packet := handler.DequeuePacketForRetransmission() packet := handler.DequeuePacketForRetransmission()
Expect(packet).To(Equal(lostPacket)) Expect(packet).To(Equal(lostPacket))
}) })

View file

@ -53,10 +53,15 @@ func (m messageType) String() string {
} }
} }
// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level, var (
// but the corresponding opener has not yet been initialized // ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level,
// This can happen when packets arrive out of order. // but the corresponding opener has not yet been initialized
var ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available") // This can happen when packets arrive out of order.
ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available")
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding keys have already been dropped.
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
)
type cryptoSetup struct { type cryptoSetup struct {
tlsConf *qtls.Config tlsConf *qtls.Config
@ -67,6 +72,8 @@ type cryptoSetup struct {
paramsChan <-chan []byte paramsChan <-chan []byte
handleParamsCallback func([]byte) handleParamsCallback func([]byte)
dropKeyCallback func(protocol.EncryptionLevel)
alertChan chan uint8 alertChan chan uint8
// HandleData() sends errors on the messageErrChan // HandleData() sends errors on the messageErrChan
messageErrChan chan error messageErrChan chan error
@ -121,6 +128,7 @@ func NewCryptoSetupClient(
remoteAddr net.Addr, remoteAddr net.Addr,
tp *TransportParameters, tp *TransportParameters,
handleParams func([]byte), handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
tlsConf *tls.Config, tlsConf *tls.Config,
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) { ) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
@ -131,6 +139,7 @@ func NewCryptoSetupClient(
connID, connID,
tp, tp,
handleParams, handleParams,
dropKeys,
tlsConf, tlsConf,
logger, logger,
protocol.PerspectiveClient, protocol.PerspectiveClient,
@ -151,6 +160,7 @@ func NewCryptoSetupServer(
remoteAddr net.Addr, remoteAddr net.Addr,
tp *TransportParameters, tp *TransportParameters,
handleParams func([]byte), handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
tlsConf *tls.Config, tlsConf *tls.Config,
logger utils.Logger, logger utils.Logger,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
@ -161,6 +171,7 @@ func NewCryptoSetupServer(
connID, connID,
tp, tp,
handleParams, handleParams,
dropKeys,
tlsConf, tlsConf,
logger, logger,
protocol.PerspectiveServer, protocol.PerspectiveServer,
@ -179,6 +190,7 @@ func newCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
tp *TransportParameters, tp *TransportParameters,
handleParams func([]byte), handleParams func([]byte),
dropKeys func(protocol.EncryptionLevel),
tlsConf *tls.Config, tlsConf *tls.Config,
logger utils.Logger, logger utils.Logger,
perspective protocol.Perspective, perspective protocol.Perspective,
@ -197,6 +209,7 @@ func newCryptoSetup(
readEncLevel: protocol.EncryptionInitial, readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams, handleParamsCallback: handleParams,
dropKeyCallback: dropKeys,
paramsChan: extHandler.TransportParameters(), paramsChan: extHandler.TransportParameters(),
logger: logger, logger: logger,
perspective: perspective, perspective: perspective,
@ -225,6 +238,24 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error {
return nil return nil
} }
func (h *cryptoSetup) Received1RTTAck() {
// drop initial keys
// TODO: do this earlier
if h.initialOpener != nil {
h.initialOpener = nil
h.initialSealer = nil
h.dropKeyCallback(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
}
// drop handshake keys
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
h.logger.Debugf("Dropping Handshake keys.")
h.dropKeyCallback(protocol.EncryptionHandshake)
}
}
func (h *cryptoSetup) RunHandshake() error { func (h *cryptoSetup) RunHandshake() error {
// Handle errors that might occur when HandleData() is called. // Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{}) handshakeComplete := make(chan struct{})
@ -554,10 +585,17 @@ func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error)
switch level { switch level {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
return h.initialOpener, nil return h.initialOpener, nil
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
if h.handshakeOpener == nil { if h.handshakeOpener == nil {
return nil, ErrOpenerNotYetAvailable if h.initialOpener != nil {
return nil, ErrOpenerNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
} }
return h.handshakeOpener, nil return h.handshakeOpener, nil
case protocol.Encryption1RTT: case protocol.Encryption1RTT:

View file

@ -87,6 +87,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
tlsConf, tlsConf,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -115,6 +116,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -149,6 +151,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -177,6 +180,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -256,6 +260,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
clientConf, clientConf,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -271,6 +276,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{StatelessResetToken: &token}, &TransportParameters{StatelessResetToken: &token},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
serverConf, serverConf,
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )
@ -313,6 +319,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
&TransportParameters{}, &TransportParameters{},
func([]byte) {}, func([]byte) {},
func(protocol.EncryptionLevel) {},
&tls.Config{InsecureSkipVerify: true}, &tls.Config{InsecureSkipVerify: true},
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -350,6 +357,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
cTransportParameters, cTransportParameters,
func(p []byte) { sTransportParametersRcvd = p }, func(p []byte) { sTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
clientConf, clientConf,
utils.DefaultLogger.WithPrefix("client"), utils.DefaultLogger.WithPrefix("client"),
) )
@ -369,6 +377,7 @@ var _ = Describe("Crypto Setup TLS", func() {
nil, nil,
sTransportParameters, sTransportParameters,
func(p []byte) { cTransportParametersRcvd = p }, func(p []byte) { cTransportParametersRcvd = p },
func(protocol.EncryptionLevel) {},
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
utils.DefaultLogger.WithPrefix("server"), utils.DefaultLogger.WithPrefix("server"),
) )

View file

@ -36,6 +36,7 @@ type CryptoSetup interface {
ChangeConnectionID(protocol.ConnectionID) error ChangeConnectionID(protocol.ConnectionID) error
HandleMessage([]byte, protocol.EncryptionLevel) bool HandleMessage([]byte, protocol.EncryptionLevel) bool
Received1RTTAck()
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)

View file

@ -137,6 +137,18 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
} }
// Received1RTTAck mocks base method
func (m *MockCryptoSetup) Received1RTTAck() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Received1RTTAck")
}
// Received1RTTAck indicates an expected call of Received1RTTAck
func (mr *MockCryptoSetupMockRecorder) Received1RTTAck() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Received1RTTAck", reflect.TypeOf((*MockCryptoSetup)(nil).Received1RTTAck))
}
// RunHandshake mocks base method // RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error { func (m *MockCryptoSetup) RunHandshake() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -49,6 +49,7 @@ type streamManager interface {
type cryptoStreamHandler interface { type cryptoStreamHandler interface {
RunHandshake() error RunHandshake() error
ChangeConnectionID(protocol.ConnectionID) error ChangeConnectionID(protocol.ConnectionID) error
Received1RTTAck()
io.Closer io.Closer
ConnectionState() tls.ConnectionState ConnectionState() tls.ConnectionState
} }
@ -129,9 +130,8 @@ type session struct {
handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeCompleteChan chan struct{} // is closed when the handshake completes
handshakeComplete bool handshakeComplete bool
receivedRetry bool receivedRetry bool
receivedFirstPacket bool receivedFirstPacket bool
receivedFirstForwardSecurePacket bool
sessionCreationTime time.Time sessionCreationTime time.Time
// The idle timeout is set based on the max of the time we received the last packet... // The idle timeout is set based on the max of the time we received the last packet...
@ -199,6 +199,7 @@ var newSession = func(
conn.RemoteAddr(), conn.RemoteAddr(),
params, params,
s.processTransportParameters, s.processTransportParameters,
s.dropEncryptionLevel,
tlsConf, tlsConf,
logger, logger,
) )
@ -267,6 +268,7 @@ var newClientSession = func(
conn.RemoteAddr(), conn.RemoteAddr(),
params, params,
s.processTransportParameters, s.processTransportParameters,
s.dropEncryptionLevel,
tlsConf, tlsConf,
logger, logger,
) )
@ -485,8 +487,6 @@ func (s *session) handleHandshakeComplete() {
// independent from the application protocol. // independent from the application protocol.
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
s.queueControlFrame(&wire.PingFrame{}) s.queueControlFrame(&wire.PingFrame{})
s.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake)
} }
} }
@ -560,16 +560,19 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
packet, err := s.unpacker.Unpack(hdr, p.data) packet, err := s.unpacker.Unpack(hdr, p.data)
if err != nil { if err != nil {
if err == handshake.ErrOpenerNotYetAvailable { switch err {
case handshake.ErrKeysDropped:
s.logger.Debugf("Dropping packet because we already dropped the keys.")
case handshake.ErrOpenerNotYetAvailable:
// Sealer for this encryption level not yet available. // Sealer for this encryption level not yet available.
// Try again later. // Try again later.
wasQueued = true wasQueued = true
s.tryQueueingUndecryptablePacket(p) s.tryQueueingUndecryptablePacket(p)
return false default:
// This might be a packet injected by an attacker.
// Drop it.
s.logger.Debugf("Dropping packet that could not be unpacked. Unpack error: %s", err)
} }
// This might be a packet injected by an attacker.
// Drop it.
s.logger.Debugf("Dropping packet that could not be unpacked. Unpack error: %s", err)
return false return false
} }
@ -642,16 +645,6 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false s.keepAlivePingSent = false
// The client completes the handshake first (after sending the CFIN).
// We know that the server completed the handshake as soon as we receive a forward-secure packet.
if s.perspective == protocol.PerspectiveClient {
if !s.receivedFirstForwardSecurePacket && packet.encryptionLevel == protocol.Encryption1RTT {
s.receivedFirstForwardSecurePacket = true
s.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
s.sentPacketHandler.DropPackets(protocol.EncryptionHandshake)
}
}
r := bytes.NewReader(packet.data) r := bytes.NewReader(packet.data)
var isAckEliciting bool var isAckEliciting bool
for { for {
@ -834,6 +827,7 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber,
} }
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
s.cryptoStreamHandler.Received1RTTAck()
} }
return nil return nil
} }
@ -924,6 +918,11 @@ func (s *session) handleCloseError(closeErr closeError) {
} }
} }
func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
s.sentPacketHandler.DropPackets(encLevel)
s.receivedPacketHandler.DropPackets(encLevel)
}
func (s *session) processTransportParameters(data []byte) { func (s *session) processTransportParameters(data []byte) {
var params *handshake.TransportParameters var params *handshake.TransportParameters
var err error var err error

View file

@ -161,6 +161,7 @@ var _ = Describe("Session", func() {
}) })
It("tells the ReceivedPacketHandler to ignore low ranges", func() { It("tells the ReceivedPacketHandler to ignore low ranges", func() {
cryptoSetup.EXPECT().Received1RTTAck()
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}} ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}}
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().ReceivedAck(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) sph.EXPECT().ReceivedAck(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())