diff --git a/internal/handshake/params_negotiator.go b/internal/handshake/params_negotiator.go index 10283e84..8f886ccf 100644 --- a/internal/handshake/params_negotiator.go +++ b/internal/handshake/params_negotiator.go @@ -62,8 +62,7 @@ func (h *paramsNegotiator) SetFromTransportParameters(params []transportParamete if len(p.Value) != 2 { return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value)) } - val := time.Duration(binary.BigEndian.Uint16(p.Value)) * time.Second - h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(val) + h.remoteIdleTimeout = time.Duration(binary.BigEndian.Uint16(p.Value)) * time.Second case omitConnectionIDParameterID: if len(p.Value) != 0 { return fmt.Errorf("wrong length for omit_connection_id: %d (expected empty)", len(p.Value)) @@ -87,7 +86,7 @@ func (h *paramsNegotiator) GetTransportParameters() []transportParameter { // TODO: use a reasonable value here binary.BigEndian.PutUint32(initialMaxStreamID, math.MaxUint32) idleTimeout := make([]byte, 2) - binary.BigEndian.PutUint16(idleTimeout, uint16(h.GetIdleConnectionStateLifetime().Seconds())) + binary.BigEndian.PutUint16(idleTimeout, uint16(h.idleTimeout)) maxPacketSize := make([]byte, 2) binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize)) params := []transportParameter{ diff --git a/internal/handshake/params_negotiator_base.go b/internal/handshake/params_negotiator_base.go index a6f1d20d..3540e288 100644 --- a/internal/handshake/params_negotiator_base.go +++ b/internal/handshake/params_negotiator_base.go @@ -19,6 +19,7 @@ type ParamsNegotiator interface { GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount GetMaxOutgoingStreams() uint32 GetMaxIncomingStreams() uint32 + // get the idle timeout that was sent by the peer GetIdleConnectionStateLifetime() time.Duration // determines if the client requests omission of connection IDs. OmitConnectionID() bool @@ -43,7 +44,8 @@ type paramsNegotiatorBase struct { maxStreamsPerConnection uint32 maxIncomingDynamicStreamsPerConnection uint32 - idleConnectionStateLifetime time.Duration + idleTimeout time.Duration + remoteIdleTimeout time.Duration sendStreamFlowControlWindow protocol.ByteCount sendConnectionFlowControlWindow protocol.ByteCount receiveStreamFlowControlWindow protocol.ByteCount @@ -61,7 +63,7 @@ func (h *paramsNegotiatorBase) init(params *TransportParameters) { h.maxReceiveConnectionFlowControlWindow = params.MaxReceiveConnectionFlowControlWindow h.requestConnectionIDOmission = params.RequestConnectionIDOmission - h.idleConnectionStateLifetime = params.IdleTimeout + h.idleTimeout = params.IdleTimeout if h.perspective == protocol.PerspectiveServer { h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective @@ -79,10 +81,6 @@ func (h *paramsNegotiatorBase) negotiateMaxIncomingDynamicStreamsPerConnection(c return utils.MinUint32(clientValue, protocol.MaxIncomingDynamicStreamsPerConnection) } -func (h *paramsNegotiatorBase) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { - return utils.MinDuration(clientValue, h.idleConnectionStateLifetime) -} - // GetSendStreamFlowControlWindow gets the size of the stream-level flow control window for sending data func (h *paramsNegotiatorBase) GetSendStreamFlowControlWindow() protocol.ByteCount { h.mutex.RLock() @@ -137,5 +135,5 @@ func (h *paramsNegotiatorBase) GetMaxIncomingStreams() uint32 { func (h *paramsNegotiatorBase) GetIdleConnectionStateLifetime() time.Duration { h.mutex.RLock() defer h.mutex.RUnlock() - return h.idleConnectionStateLifetime + return h.remoteIdleTimeout } diff --git a/internal/handshake/params_negotiator_gquic.go b/internal/handshake/params_negotiator_gquic.go index 0c4133ae..30350a16 100644 --- a/internal/handshake/params_negotiator_gquic.go +++ b/internal/handshake/params_negotiator_gquic.go @@ -61,7 +61,7 @@ func (h *paramsNegotiatorGQUIC) SetFromMap(params map[Tag][]byte) error { if err != nil { return errMalformedTag } - h.idleConnectionStateLifetime = h.negotiateIdleConnectionStateLifetime(time.Duration(clientValue) * time.Second) + h.remoteIdleTimeout = time.Duration(clientValue) * time.Second } if value, ok := params[TagSFCW]; ok { if h.flowControlNegotiated { @@ -104,7 +104,7 @@ func (h *paramsNegotiatorGQUIC) GetHelloMap() (map[Tag][]byte, error) { mids := bytes.NewBuffer([]byte{}) utils.LittleEndian.WriteUint32(mids, protocol.MaxIncomingDynamicStreamsPerConnection) icsl := bytes.NewBuffer([]byte{}) - utils.LittleEndian.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) + utils.LittleEndian.WriteUint32(icsl, uint32(h.idleTimeout/time.Second)) return map[Tag][]byte{ TagICSL: icsl.Bytes(), diff --git a/internal/handshake/params_negotiator_gquic_test.go b/internal/handshake/params_negotiator_gquic_test.go index 1d468fb3..f150efba 100644 --- a/internal/handshake/params_negotiator_gquic_test.go +++ b/internal/handshake/params_negotiator_gquic_test.go @@ -71,11 +71,11 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { }) It("sets the connection-level flow control windows in SHLO", func() { - pn.idleConnectionStateLifetime = 0xDECAFBAD * time.Second + pn.idleTimeout = 0xdecafbad * time.Second entryMap, err := pn.GetHelloMap() Expect(err).ToNot(HaveOccurred()) Expect(entryMap).To(HaveKey(TagICSL)) - Expect(entryMap[TagICSL]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) + Expect(entryMap[TagICSL]).To(Equal([]byte{0xad, 0xfb, 0xca, 0xde})) }) It("sets the negotiated value for maximum streams in the SHLO", func() { @@ -211,21 +211,8 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { }) }) - Context("idle connection state lifetime", func() { - It("has initial idle connection state lifetime", func() { - Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(idleTimeout)) - }) - - It("negotiates correctly when the peer wants a longer lifetime", func() { - Expect(pn.negotiateIdleConnectionStateLifetime(idleTimeout + 10*time.Second)).To(Equal(idleTimeout)) - }) - - It("negotiates correctly when the peer wants a shorter lifetime", func() { - Expect(pn.negotiateIdleConnectionStateLifetime(idleTimeout - 3*time.Second)).To(Equal(idleTimeout - 3*time.Second)) - }) - + Context("idle timeout", func() { It("sets the negotiated lifetime", func() { - // this test only works if the value given here is smaller than protocol.MaxIdleConnectionStateLifetime values := map[Tag][]byte{ TagICSL: {10, 0, 0, 0}, } @@ -234,21 +221,6 @@ var _ = Describe("Params Negotiator (for gQUIC)", func() { Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(10 * time.Second)) }) - It("does not change the idle connection state lifetime when given an invalid value", func() { - values := map[Tag][]byte{ - TagSFCW: {0xDE, 0xAD, 0xBE}, // 1 byte too short - } - err := pn.SetFromMap(values) - Expect(err).To(MatchError(errMalformedTag)) - Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(idleTimeout)) - }) - - It("gets idle connection state lifetime", func() { - value := 0xDECAFBAD * time.Second - pn.idleConnectionStateLifetime = value - Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(value)) - }) - It("errors when given an invalid value", func() { values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short err := pn.SetFromMap(values) diff --git a/internal/handshake/params_negotiator_test.go b/internal/handshake/params_negotiator_test.go index b0d1a877..cce62063 100644 --- a/internal/handshake/params_negotiator_test.go +++ b/internal/handshake/params_negotiator_test.go @@ -33,9 +33,7 @@ var _ = Describe("Params Negotiator (for TLS)", func() { pn = newParamsNegotiator( protocol.PerspectiveServer, protocol.VersionWhatever, - &TransportParameters{ - IdleTimeout: 0x5000 * time.Second, - }, + &TransportParameters{}, ) params = map[transportParameterID][]byte{ initialMaxStreamDataParameterID: []byte{0x11, 0x22, 0x33, 0x44}, @@ -47,6 +45,7 @@ var _ = Describe("Params Negotiator (for TLS)", func() { Context("getting", func() { It("creates the parameters list", func() { + pn.idleTimeout = 0xcafe buf := make([]byte, 4) values := paramsListToMap(pn.GetTransportParameters()) Expect(values).To(HaveLen(5)) @@ -55,7 +54,7 @@ var _ = Describe("Params Negotiator (for TLS)", func() { binary.BigEndian.PutUint32(buf, uint32(protocol.ReceiveConnectionFlowControlWindow)) Expect(values).To(HaveKeyWithValue(initialMaxDataParameterID, buf)) Expect(values).To(HaveKeyWithValue(initialMaxStreamIDParameterID, []byte{0xff, 0xff, 0xff, 0xff})) - Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0x50, 0x0})) + Expect(values).To(HaveKeyWithValue(idleTimeoutParameterID, []byte{0xca, 0xfe})) Expect(values).To(HaveKeyWithValue(maxPacketSizeParameterID, []byte{0x5, 0xac})) // 1452 = 0x5ac }) @@ -76,13 +75,6 @@ var _ = Describe("Params Negotiator (for TLS)", func() { Expect(pn.OmitConnectionID()).To(BeFalse()) }) - It("negotiates a smaller idle timeout, if the peer suggest a higher value than configured", func() { - params[idleTimeoutParameterID] = []byte{0xff, 0xff} - err := pn.SetFromTransportParameters(paramsMapToList(params)) - Expect(err).ToNot(HaveOccurred()) - Expect(pn.GetIdleConnectionStateLifetime()).To(Equal(0x5000 * time.Second)) - }) - It("saves if it should omit the connection ID", func() { params[omitConnectionIDParameterID] = []byte{} err := pn.SetFromTransportParameters(paramsMapToList(params)) diff --git a/session.go b/session.go index b1875d77..ccef802b 100644 --- a/session.go +++ b/session.go @@ -329,7 +329,7 @@ runLoop: s.sentPacketHandler.OnAlarm() } - if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.idleTimeout()/2 { + if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.connParams.GetIdleConnectionStateLifetime()/2 { // send the PING frame since there is no activity in the session s.packer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true @@ -344,7 +344,7 @@ runLoop: if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout { s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time.")) } - if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { + if s.handshakeComplete && now.Sub(s.lastNetworkActivityTime) >= s.config.IdleTimeout { s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) } s.garbageCollectStreams() @@ -368,9 +368,9 @@ func (s *session) Context() context.Context { func (s *session) maybeResetTimer() { var deadline time.Time if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent { - deadline = s.lastNetworkActivityTime.Add(s.idleTimeout() / 2) + deadline = s.lastNetworkActivityTime.Add(s.connParams.GetIdleConnectionStateLifetime() / 2) } else { - deadline = s.lastNetworkActivityTime.Add(s.idleTimeout()) + deadline = s.lastNetworkActivityTime.Add(s.config.IdleTimeout) } if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { @@ -390,10 +390,6 @@ func (s *session) maybeResetTimer() { s.timer.Reset(deadline) } -func (s *session) idleTimeout() time.Duration { - return s.connParams.GetIdleConnectionStateLifetime() -} - func (s *session) handlePacketImpl(p *receivedPacket) error { if s.perspective == protocol.PerspectiveClient { diversificationNonce := p.publicHeader.DiversificationNonce diff --git a/session_test.go b/session_test.go index f02b3cda..7b19167a 100644 --- a/session_test.go +++ b/session_test.go @@ -1501,36 +1501,45 @@ var _ = Describe("Session", func() { }) Context("keep-alives", func() { + var mockPn *mocks.MockParamsNegotiator + // should be shorter than the local timeout for these tests + // otherwise we'd send a CONNECTION_CLOSE in the tests where we're testing that no PING is sent + remoteIdleTimeout := 20 * time.Second + + BeforeEach(func() { + mockPn = mocks.NewMockParamsNegotiator(mockCtrl) + mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(remoteIdleTimeout).AnyTimes() + sess.connParams = mockPn + }) + It("sends a PING", func() { sess.handshakeComplete = true sess.config.KeepAlive = true - sess.lastNetworkActivityTime = time.Now().Add(-(sess.idleTimeout() / 2)) + sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) go sess.run() defer sess.Close(nil) - time.Sleep(60 * time.Millisecond) - Eventually(mconn.written).ShouldNot(BeEmpty()) - Eventually(func() byte { - // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). - s := <-mconn.written - return s[len(s)-12-1] - }).Should(Equal(byte(0x07))) + var data []byte + Eventually(mconn.written).Should(Receive(&data)) + // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). + Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07})) }) It("doesn't send a PING packet if keep-alive is disabled", func() { sess.handshakeComplete = true - sess.lastNetworkActivityTime = time.Now().Add(-(sess.idleTimeout() / 2)) + sess.config.KeepAlive = false + sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) go sess.run() defer sess.Close(nil) - Consistently(mconn.written).Should(BeEmpty()) + Consistently(mconn.written).ShouldNot(Receive()) }) It("doesn't send a PING if the handshake isn't completed yet", func() { sess.handshakeComplete = false sess.config.KeepAlive = true - sess.lastNetworkActivityTime = time.Now().Add(-(sess.idleTimeout() / 2)) + sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) go sess.run() defer sess.Close(nil) - Consistently(mconn.written).Should(BeEmpty()) + Consistently(mconn.written).ShouldNot(Receive()) }) }) @@ -1545,7 +1554,7 @@ var _ = Describe("Session", func() { close(done) }) - It("times out due to non-completed crypto handshake", func(done Done) { + It("times out due to non-completed handshake", func(done Done) { sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) err := sess.run() // Would normally not return Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout)) @@ -1554,37 +1563,34 @@ var _ = Describe("Session", func() { close(done) }) - It("does not use ICSL before handshake", func() { + It("does not use the idle timeout before the handshake complete", func() { + sess.config.IdleTimeout = 9999 * time.Second defer sess.Close(nil) sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(9999 * time.Second).AnyTimes() - mockPn.EXPECT().OmitConnectionID().Return(false).AnyTimes() - sess.connParams = mockPn - sess.packer.connParams = mockPn // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity done := make(chan struct{}) go func() { + defer GinkgoRecover() _ = sess.run() close(done) }() Consistently(done).ShouldNot(BeClosed()) }) - It("uses ICSL after handshake", func(done Done) { + It("closes the session due to the idle timeout after handshake", func() { + sess.config.IdleTimeout = 0 close(aeadChanged) - mockPn := mocks.NewMockParamsNegotiator(mockCtrl) - mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second) - mockPn.EXPECT().OmitConnectionID().Return(false).AnyTimes() - sess.connParams = mockPn - sess.packer.connParams = mockPn - mockPn.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second).AnyTimes() - err := sess.run() // Would normally not return + errChan := make(chan error) + go func() { + defer GinkgoRecover() + errChan <- sess.run() // Would normally not return + }() + var err error + Eventually(errChan).Should(Receive(&err)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) Expect(sess.Context().Done()).To(BeClosed()) - close(done) }) })