use an errors.As comparison to detect stateless resets in the session

This commit is contained in:
Marten Seemann 2020-07-10 11:13:36 +07:00
parent e7fa420e26
commit 1f676c2a6c
4 changed files with 29 additions and 9 deletions

View file

@ -19,8 +19,6 @@ type statelessResetErr struct {
token *[16]byte token *[16]byte
} }
func (e statelessResetErr) StatelessResetToken() *[16]byte { return e.token }
func (e statelessResetErr) Error() string { func (e statelessResetErr) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", *e.token) return fmt.Sprintf("received a stateless reset with token %x", *e.token)
} }
@ -319,7 +317,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
copy(token[:], data[len(data)-16:]) copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok { if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token)
go sess.destroy(&statelessResetErr{token: &token}) go sess.destroy(statelessResetErr{token: &token})
return true return true
} }
return false return false

View file

@ -242,10 +242,12 @@ var _ = Describe("Packet Handler Map", func() {
packet = append(packet, token[:]...) packet = append(packet, token[:]...)
destroyed := make(chan struct{}) destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&statelessResetErr{})) var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset")) Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(*err.(*statelessResetErr).StatelessResetToken()).To(Equal(token)) Expect(resetErr.token).To(Equal(&token))
close(destroyed) close(destroyed)
}) })
conn.dataToRead <- packet conn.dataToRead <- packet
@ -261,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() {
packet = append(packet, token[:]...) packet = append(packet, token[:]...)
destroyed := make(chan struct{}) destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err).To(BeAssignableToTypeOf(&statelessResetErr{})) var resetErr statelessResetErr
Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset")) Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(*err.(*statelessResetErr).StatelessResetToken()).To(Equal(token)) Expect(resetErr.token).To(Equal(&token))
close(destroyed) close(destroyed)
}) })
conn.dataToRead <- packet conn.dataToRead <- packet

View file

@ -1299,8 +1299,9 @@ func (s *session) handleCloseError(closeErr closeError) {
if s.tracer != nil { if s.tracer != nil {
// timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts) // timeout errors are logged as soon as they occur (to distinguish between handshake and idle timeouts)
if nerr, ok := closeErr.err.(net.Error); !ok || !nerr.Timeout() { if nerr, ok := closeErr.err.(net.Error); !ok || !nerr.Timeout() {
if statelessReset, ok := closeErr.err.(interface{ StatelessResetToken() *[16]byte }); ok && s.tracer != nil { var resetErr statelessResetErr
s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(statelessReset.StatelessResetToken())) if errors.As(closeErr.err, &resetErr) {
s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(resetErr.token))
} else if quicErr.IsApplicationError() { } else if quicErr.IsApplicationError() {
s.tracer.ClosedConnection(logging.NewApplicationCloseReason(quicErr.ErrorCode, closeErr.remote)) s.tracer.ClosedConnection(logging.NewApplicationCloseReason(quicErr.ErrorCode, closeErr.remote))
} else { } else {

View file

@ -639,6 +639,23 @@ var _ = Describe("Session", func() {
sess.scheduleSending() sess.scheduleSending()
Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(sess.Context().Done()).Should(BeClosed())
}) })
It("closes due to a stateless reset", func() {
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
runSession()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(reason logging.CloseReason) {
t, ok := reason.StatelessReset()
Expect(ok).To(BeTrue())
Expect(t).To(Equal(token))
}),
tracer.EXPECT().Close(),
)
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close()
sess.destroy(statelessResetErr{token: &token})
})
}) })
Context("receiving packets", func() { Context("receiving packets", func() {