diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 4e4bc00c..726cd18e 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -99,6 +99,10 @@ const DefaultIdleTimeout = 30 * time.Second // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. const DefaultHandshakeTimeout = 10 * time.Second +// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. +// It should be shorter than the time that NATs clear their mapping. +const MaxKeepAliveInterval = 20 * time.Second + // RetiredConnectionIDDeleteTimeout is the time we keep closed sessions around in order to retransmit the CONNECTION_CLOSE. // after this time all information about the old connection will be deleted const RetiredConnectionIDDeleteTimeout = 5 * time.Second diff --git a/session.go b/session.go index 8fbc4061..b57723e7 100644 --- a/session.go +++ b/session.go @@ -170,9 +170,10 @@ type session struct { peerParams *handshake.TransportParameters timer *utils.Timer - // keepAlivePingSent stores whether a Ping frame was sent to the peer or not - // it is reset as soon as we receive a packet from the peer + // keepAlivePingSent stores whether a keep alive PING is in flight. + // It is reset as soon as we receive a packet from the peer. keepAlivePingSent bool + keepAliveInterval time.Duration traceCallback func(quictrace.Event) @@ -504,7 +505,7 @@ runLoop: if s.pacingDeadline.IsZero() { // the timer didn't have a pacing deadline set pacingDeadline = s.sentPacketHandler.TimeUntilSend() } - if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && time.Since(s.lastPacketReceivedTime) >= s.peerParams.IdleTimeout/2 { + if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && time.Since(s.lastPacketReceivedTime) >= s.keepAliveInterval/2 { // send a PING frame since there is no activity in the session s.logger.Debugf("Sending a keep-alive ping to keep the connection alive.") s.framer.QueueControlFrame(&wire.PingFrame{}) @@ -558,7 +559,7 @@ func (s *session) ConnectionState() tls.ConnectionState { func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { - deadline = s.idleTimeoutStartTime().Add(s.peerParams.IdleTimeout / 2) + deadline = s.idleTimeoutStartTime().Add(s.keepAliveInterval / 2) } else { deadline = s.idleTimeoutStartTime().Add(s.config.IdleTimeout) } @@ -1079,6 +1080,7 @@ func (s *session) processTransportParameters(data []byte) { } s.logger.Debugf("Received Transport Parameters: %s", params) s.peerParams = params + s.keepAliveInterval = utils.MinDuration(params.IdleTimeout/2, protocol.MaxKeepAliveInterval) if err := s.streamsMap.UpdateLimits(params); err != nil { s.closeLocal(err) return diff --git a/session_test.go b/session_test.go index 06105c80..fbe5e3c4 100644 --- a/session_test.go +++ b/session_test.go @@ -1295,12 +1295,24 @@ var _ = Describe("Session", func() { }) Context("keep-alives", func() { - // should be shorter than the local timeout for these tests - // otherwise we'd send a CONNECTION_CLOSE in the tests where we're testing that no PING is sent - remoteIdleTimeout := 20 * time.Second + setRemoteIdleTimeout := func(t time.Duration) { + tp := &handshake.TransportParameters{IdleTimeout: t} + streamManager.EXPECT().UpdateLimits(gomock.Any()) + packer.EXPECT().HandleTransportParameters(gomock.Any()) + sess.processTransportParameters(tp.Marshal()) + } + + runSession := func() { + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + }() + } BeforeEach(func() { - sess.peerParams = &handshake.TransportParameters{IdleTimeout: remoteIdleTimeout} + sess.config.KeepAlive = true + sess.handshakeComplete = true }) AfterEach(func() { @@ -1313,44 +1325,44 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sends a PING as a keep-alive", func() { - sess.handshakeComplete = true - sess.config.KeepAlive = true - sess.lastPacketReceivedTime = time.Now().Add(-remoteIdleTimeout / 2) + It("sends a PING as a keep-alive after half the idle timeout", func() { + setRemoteIdleTimeout(5 * time.Second) + sess.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval) sent := make(chan struct{}) packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { close(sent) return nil, nil }) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - }() + runSession() + Eventually(sent).Should(BeClosed()) + }) + + It("sends a PING after a maximum of protocol.MaxKeepAliveInterval", func() { + setRemoteIdleTimeout(time.Hour) + sess.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) + sent := make(chan struct{}) + packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { + close(sent) + return nil, nil + }) + runSession() Eventually(sent).Should(BeClosed()) }) It("doesn't send a PING packet if keep-alive is disabled", func() { - sess.handshakeComplete = true + setRemoteIdleTimeout(5 * time.Second) sess.config.KeepAlive = false - sess.lastPacketReceivedTime = time.Now().Add(-remoteIdleTimeout / 2) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - }() + sess.lastPacketReceivedTime = time.Now().Add(-time.Second * 5 / 2) + runSession() Consistently(mconn.written).ShouldNot(Receive()) }) It("doesn't send a PING if the handshake isn't completed yet", func() { sess.handshakeComplete = false - sess.config.KeepAlive = true - sess.lastPacketReceivedTime = time.Now().Add(-remoteIdleTimeout / 2) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - }() + // Needs to be shorter than our idle timeout. + // Otherwise we'll try to send a CONNECTION_CLOSE. + sess.lastPacketReceivedTime = time.Now().Add(-20 * time.Second) + runSession() Consistently(mconn.written).ShouldNot(Receive()) }) })