use an errors.As comparison to detect stateless resets in the session

This commit is contained in:
Marten Seemann 2020-07-10 11:13:36 +07:00
parent e7fa420e26
commit 1f676c2a6c
4 changed files with 29 additions and 9 deletions

View file

@ -19,8 +19,6 @@ type statelessResetErr struct {
token *[16]byte
}
func (e statelessResetErr) StatelessResetToken() *[16]byte { return e.token }
func (e statelessResetErr) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", *e.token)
}
@ -319,7 +317,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token)
go sess.destroy(&statelessResetErr{token: &token})
go sess.destroy(statelessResetErr{token: &token})
return true
}
return false

View file

@ -242,10 +242,12 @@ var _ = Describe("Packet Handler Map", func() {
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&statelessResetErr{}))
var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(*err.(*statelessResetErr).StatelessResetToken()).To(Equal(token))
Expect(resetErr.token).To(Equal(&token))
close(destroyed)
})
conn.dataToRead <- packet
@ -261,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() {
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&statelessResetErr{}))
var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(*err.(*statelessResetErr).StatelessResetToken()).To(Equal(token))
Expect(resetErr.token).To(Equal(&token))
close(destroyed)
})
conn.dataToRead <- packet

View file

@ -1299,8 +1299,9 @@ func (s *session) handleCloseError(closeErr closeError) {
if s.tracer != nil {
// timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts)
if nerr, ok := closeErr.err.(net.Error); !ok || !nerr.Timeout() {
if statelessReset, ok := closeErr.err.(interface{ StatelessResetToken() *[16]byte }); ok && s.tracer != nil {
s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(statelessReset.StatelessResetToken()))
var resetErr statelessResetErr
if errors.As(closeErr.err, &resetErr) {
s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(resetErr.token))
} else if quicErr.IsApplicationError() {
s.tracer.ClosedConnection(logging.NewApplicationCloseReason(quicErr.ErrorCode, closeErr.remote))
} else {

View file

@ -639,6 +639,23 @@ var _ = Describe("Session", func() {
sess.scheduleSending()
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("closes due to a stateless reset", func() {
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
runSession()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(reason logging.CloseReason) {
t, ok := reason.StatelessReset()
Expect(ok).To(BeTrue())
Expect(t).To(Equal(token))
}),
tracer.EXPECT().Close(),
)
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close()
sess.destroy(statelessResetErr{token: &token})
})
})
Context("receiving packets", func() {