use the connection ID manager to save the destination connection ID

This commit is contained in:
Marten Seemann 2019-10-25 17:58:53 +07:00
parent a321f9faa6
commit 772ffd3d20
7 changed files with 125 additions and 121 deletions

View file

@ -104,7 +104,6 @@ var errCloseForRecreating = errors.New("closing session in order to recreate it"
type session struct {
sessionRunner sessionRunner
destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // if the server sends a Retry, this is the connection ID we used initially
srcConnID protocol.ConnectionID
@ -201,13 +200,13 @@ var newSession = func(
sessionRunner: runner,
config: conf,
srcConnID: srcConnID,
destConnID: destConnID,
tokenGenerator: tokenGenerator,
perspective: protocol.PerspectiveServer,
handshakeCompleteChan: make(chan struct{}),
logger: logger,
version: v,
}
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream()
@ -231,9 +230,10 @@ var newSession = func(
logger,
)
s.cryptoStreamHandler = cs
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
s.packer = newPacketPacker(
s.destConnID,
s.srcConnID,
s.connIDManager.Get,
initialStream,
handshakeStream,
s.sentPacketHandler,
@ -269,13 +269,13 @@ var newClientSession = func(
sessionRunner: runner,
config: conf,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
handshakeCompleteChan: make(chan struct{}),
logger: logger,
initialVersion: initialVersion,
version: v,
}
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream()
@ -285,7 +285,7 @@ var newClientSession = func(
initialStream,
handshakeStream,
oneRTTStream,
s.destConnID,
destConnID,
conn.RemoteAddr(),
params,
&handshakeRunner{
@ -303,8 +303,8 @@ var newClientSession = func(
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.version)
s.packer = newPacketPacker(
s.destConnID,
s.srcConnID,
s.connIDManager.Get,
initialStream,
handshakeStream,
s.sentPacketHandler,
@ -333,7 +333,6 @@ func (s *session) preSetup() {
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue(s.version)
s.frameParser = wire.NewFrameParser(s.version)
s.connIDManager = newConnIDManager(s.queueControlFrame)
s.rttStats = &congestion.RTTStats{}
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
@ -601,8 +600,9 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID)
destConnID := s.connIDManager.Get()
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, destConnID)
return false
}
// drop 0-RTT packets
@ -652,11 +652,12 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R
return false
}
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
if !hdr.OrigDestConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID)
destConnID := s.connIDManager.Get()
if !hdr.OrigDestConnectionID.Equal(destConnID) {
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, destConnID)
return false
}
if hdr.SrcConnectionID.Equal(s.destConnID) {
if hdr.SrcConnectionID.Equal(destConnID) {
s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return false
}
@ -668,16 +669,16 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R
}
s.logger.Debugf("<- Received Retry")
s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
s.origDestConnID = s.destConnID
s.destConnID = hdr.SrcConnectionID
s.origDestConnID = destConnID
newDestConnID := hdr.SrcConnectionID
s.receivedRetry = true
if err := s.sentPacketHandler.ResetForRetry(); err != nil {
s.closeLocal(err)
return false
}
s.cryptoStreamHandler.ChangeConnectionID(s.destConnID)
s.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
s.packer.SetToken(hdr.Token)
s.packer.ChangeDestConnectionID(s.destConnID)
s.connIDManager.ChangeInitialConnID(newDestConnID)
s.scheduleSending()
return true
}
@ -688,10 +689,9 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.destConnID) {
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.connIDManager.Get()) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", packet.hdr.SrcConnectionID)
s.destConnID = packet.hdr.SrcConnectionID
s.packer.ChangeDestConnectionID(s.destConnID)
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
}
s.receivedFirstPacket = true
@ -927,9 +927,9 @@ func (s *session) destroy(e error) {
func (s *session) destroyImpl(e error) {
s.closeOnce.Do(func() {
if nerr, ok := e.(net.Error); ok && nerr.Timeout() {
s.logger.Errorf("Destroying session %s: %s", s.destConnID, e)
s.logger.Errorf("Destroying session %s: %s", s.connIDManager.Get(), e)
} else {
s.logger.Errorf("Destroying session %s with error: %s", s.destConnID, e)
s.logger.Errorf("Destroying session %s with error: %s", s.connIDManager.Get(), e)
}
s.sessionRunner.Remove(s.srcConnID)
s.closeChan <- closeError{err: e, sendClose: false, remote: false}