diff --git a/packet_handler_map.go b/packet_handler_map.go index 6831ecfa..49195c3a 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -19,8 +19,6 @@ type statelessResetErr struct { token *[16]byte } -func (e statelessResetErr) StatelessResetToken() *[16]byte { return e.token } - func (e statelessResetErr) Error() string { return fmt.Sprintf("received a stateless reset with token %x", *e.token) } @@ -319,7 +317,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { copy(token[:], data[len(data)-16:]) if sess, ok := h.resetTokens[token]; ok { h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) - go sess.destroy(&statelessResetErr{token: &token}) + go sess.destroy(statelessResetErr{token: &token}) return true } return false diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index a110b8d9..c08c75ac 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -242,10 +242,12 @@ var _ = Describe("Packet Handler Map", func() { packet = append(packet, token[:]...) destroyed := make(chan struct{}) packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { + defer GinkgoRecover() Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&statelessResetErr{})) + var resetErr statelessResetErr + Expect(errors.As(err, &resetErr)).To(BeTrue()) Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(*err.(*statelessResetErr).StatelessResetToken()).To(Equal(token)) + Expect(resetErr.token).To(Equal(&token)) close(destroyed) }) conn.dataToRead <- packet @@ -261,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() { packet = append(packet, token[:]...) destroyed := make(chan struct{}) packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { + defer GinkgoRecover() Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&statelessResetErr{})) + var resetErr statelessResetErr + Expect(errors.As(err, &resetErr)).To(BeTrue()) Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(*err.(*statelessResetErr).StatelessResetToken()).To(Equal(token)) + Expect(resetErr.token).To(Equal(&token)) close(destroyed) }) conn.dataToRead <- packet diff --git a/session.go b/session.go index 2af7425c..09fe0e10 100644 --- a/session.go +++ b/session.go @@ -1299,8 +1299,9 @@ func (s *session) handleCloseError(closeErr closeError) { if s.tracer != nil { // timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts) if nerr, ok := closeErr.err.(net.Error); !ok || !nerr.Timeout() { - if statelessReset, ok := closeErr.err.(interface{ StatelessResetToken() *[16]byte }); ok && s.tracer != nil { - s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(statelessReset.StatelessResetToken())) + var resetErr statelessResetErr + if errors.As(closeErr.err, &resetErr) { + s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(resetErr.token)) } else if quicErr.IsApplicationError() { s.tracer.ClosedConnection(logging.NewApplicationCloseReason(quicErr.ErrorCode, closeErr.remote)) } else { diff --git a/session_test.go b/session_test.go index b10050d1..38b50f7c 100644 --- a/session_test.go +++ b/session_test.go @@ -639,6 +639,23 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(sess.Context().Done()).Should(BeClosed()) }) + + It("closes due to a stateless reset", func() { + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + runSession() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(reason logging.CloseReason) { + t, ok := reason.StatelessReset() + Expect(ok).To(BeTrue()) + Expect(t).To(Equal(token)) + }), + tracer.EXPECT().Close(), + ) + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + cryptoSetup.EXPECT().Close() + sess.destroy(statelessResetErr{token: &token}) + }) }) Context("receiving packets", func() {