diff --git a/session.go b/session.go index 933fc11d..c7e89eb1 100644 --- a/session.go +++ b/session.go @@ -535,6 +535,12 @@ runLoop: if wasProcessed := s.handlePacketImpl(p); !wasProcessed { continue } + // Don't set timers and send packets if the packet made us close the session. + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } case <-s.handshakeCompleteChan: s.handleHandshakeComplete() } @@ -1105,6 +1111,7 @@ func (s *session) closeForRecreating() protocol.PacketNumber { func (s *session) closeRemote(e error) { s.closeOnce.Do(func() { s.logger.Errorf("Peer closed session with error: %s", e) + s.logger.Debugf("sending to close chan") s.closeChan <- closeError{err: e, immediate: true, remote: true} }) } diff --git a/session_test.go b/session_test.go index abc25a33..90da9ea4 100644 --- a/session_test.go +++ b/session_test.go @@ -502,6 +502,34 @@ var _ = Describe("Session", func() { sess.shutdown() Eventually(returned).Should(BeClosed()) }) + + It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() { + unpacker := NewMockUnpacker(mockCtrl) + sess.handshakeConfirmed = true + sess.unpacker = unpacker + cryptoSetup.EXPECT().Close() + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen2, + }).Write(buf, sess.version)).To(Succeed()) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { + buf := &bytes.Buffer{} + Expect((&wire.ConnectionCloseFrame{ErrorCode: qerr.StreamLimitError}).Write(buf, sess.version)).To(Succeed()) + return &unpackedPacket{data: buf.Bytes(), encryptionLevel: protocol.Encryption1RTT}, nil + }) + // don't EXPECT any calls to packer.PackPacket() + sess.handlePacket(&receivedPacket{ + rcvTime: time.Now(), + remoteAddr: &net.UDPAddr{}, + buffer: getPacketBuffer(), + data: buf.Bytes(), + }) + // Consistently(pack).ShouldNot(Receive()) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) }) Context("receiving packets", func() {