expose a StatelessResetError

This commit is contained in:
Marten Seemann 2021-04-25 19:10:15 +07:00
parent 1ce572228b
commit 93cfef57ca
8 changed files with 53 additions and 24 deletions

View file

@ -8,6 +8,7 @@ type (
TransportError = qerr.TransportError TransportError = qerr.TransportError
ApplicationError = qerr.ApplicationError ApplicationError = qerr.ApplicationError
VersionNegotiationError = qerr.VersionNegotiationError VersionNegotiationError = qerr.VersionNegotiationError
StatelessResetError = qerr.StatelessResetError
) )
type ( type (

View file

@ -99,8 +99,7 @@ var _ = Describe("Stateless Resets", func() {
_, serr = str.Read([]byte{0}) _, serr = str.Read([]byte{0})
} }
Expect(serr).To(HaveOccurred()) Expect(serr).To(HaveOccurred())
Expect(serr.Error()).To(ContainSubstring("received a stateless reset")) Expect(serr).To(MatchError(&quic.StatelessResetError{}))
Expect(ln2.Close()).To(Succeed()) Expect(ln2.Close()).To(Succeed())
Eventually(acceptStopped).Should(BeClosed()) Eventually(acceptStopped).Should(BeClosed())
}) })

View file

@ -2,6 +2,7 @@ package qerr
import ( import (
"fmt" "fmt"
"net"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -109,3 +110,22 @@ func (e *VersionNegotiationError) Is(target error) bool {
_, ok := target.(*VersionNegotiationError) _, ok := target.(*VersionNegotiationError)
return ok return ok
} }
// A StatelessResetError occurs when we receive a stateless reset.
type StatelessResetError struct {
Token protocol.StatelessResetToken
}
var _ net.Error = &StatelessResetError{}
func (e *StatelessResetError) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", e.Token)
}
func (e *StatelessResetError) Is(target error) bool {
_, ok := target.(*StatelessResetError)
return ok
}
func (e *StatelessResetError) Timeout() bool { return false }
func (e *StatelessResetError) Temporary() bool { return true }

View file

@ -115,4 +115,26 @@ var _ = Describe("QUIC Errors", func() {
}).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])")) }).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])"))
}) })
}) })
Context("Stateless Reset errors", func() {
token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
It("is a Stateless Reset error", func() {
Expect(errors.Is(&StatelessResetError{Token: token}, &StatelessResetError{})).To(BeTrue())
})
It("has a string representation", func() {
Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f"))
})
It("is a net.Error", func() {
//nolint:gosimple // we need to assign to an interface here
var err error
err = &StatelessResetError{}
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
Expect(nerr.Timeout()).To(BeFalse())
Expect(nerr.Temporary()).To(BeTrue())
})
})
}) })

View file

@ -18,19 +18,6 @@ import (
"github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/logging"
) )
type statelessResetErr struct {
token protocol.StatelessResetToken
}
func (e *statelessResetErr) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", e.token)
}
func (e *statelessResetErr) Is(target error) bool {
_, ok := target.(*statelessResetErr)
return ok
}
type zeroRTTQueue struct { type zeroRTTQueue struct {
queue []*receivedPacket queue []*receivedPacket
retireTimer *time.Timer retireTimer *time.Timer
@ -435,7 +422,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
copy(token[:], data[len(data)-16:]) copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok { if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token)
go sess.destroy(&statelessResetErr{token: token}) go sess.destroy(&StatelessResetError{Token: token})
return true return true
} }
return false return false

View file

@ -373,10 +373,10 @@ var _ = Describe("Packet Handler Map", func() {
defer GinkgoRecover() defer GinkgoRecover()
defer close(destroyed) defer close(destroyed)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
var resetErr *statelessResetErr var resetErr *StatelessResetError
Expect(errors.As(err, &resetErr)).To(BeTrue()) Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset")) Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.token).To(Equal(token)) Expect(resetErr.Token).To(Equal(token))
}) })
packetChan <- packetToRead{data: packet} packetChan <- packetToRead{data: packet}
Eventually(destroyed).Should(BeClosed()) Eventually(destroyed).Should(BeClosed())
@ -393,10 +393,10 @@ var _ = Describe("Packet Handler Map", func() {
packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover() defer GinkgoRecover()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
var resetErr *statelessResetErr var resetErr *StatelessResetError
Expect(errors.As(err, &resetErr)).To(BeTrue()) Expect(errors.As(err, &resetErr)).To(BeTrue())
Expect(err.Error()).To(ContainSubstring("received a stateless reset")) Expect(err.Error()).To(ContainSubstring("received a stateless reset"))
Expect(resetErr.token).To(Equal(token)) Expect(resetErr.Token).To(Equal(token))
close(destroyed) close(destroyed)
}) })
packetChan <- packetToRead{data: packet} packetChan <- packetToRead{data: packet}

View file

@ -1483,7 +1483,7 @@ func (s *session) handleCloseError(closeErr *closeError) {
switch { switch {
case errors.Is(e, qerr.ErrIdleTimeout), case errors.Is(e, qerr.ErrIdleTimeout),
errors.Is(e, qerr.ErrHandshakeTimeout), errors.Is(e, qerr.ErrHandshakeTimeout),
errors.Is(e, &statelessResetErr{}), errors.Is(e, &StatelessResetError{}),
errors.Is(e, &VersionNegotiationError{}), errors.Is(e, &VersionNegotiationError{}),
errors.Is(e, &errCloseForRecreating{}), errors.Is(e, &errCloseForRecreating{}),
errors.Is(e, &qerr.ApplicationError{}), errors.Is(e, &qerr.ApplicationError{}),
@ -1503,7 +1503,7 @@ func (s *session) handleCloseError(closeErr *closeError) {
if s.tracer != nil && !errors.Is(e, &errCloseForRecreating{}) { if s.tracer != nil && !errors.Is(e, &errCloseForRecreating{}) {
var ( var (
resetErr *statelessResetErr resetErr *StatelessResetError
vnErr *VersionNegotiationError vnErr *VersionNegotiationError
transportErr *qerr.TransportError transportErr *qerr.TransportError
applicationErr *qerr.ApplicationError applicationErr *qerr.ApplicationError
@ -1514,7 +1514,7 @@ func (s *session) handleCloseError(closeErr *closeError) {
case errors.Is(e, qerr.ErrHandshakeTimeout): case errors.Is(e, qerr.ErrHandshakeTimeout):
s.tracer.ClosedConnection(logging.NewTimeoutCloseReason(logging.TimeoutReasonHandshake)) s.tracer.ClosedConnection(logging.NewTimeoutCloseReason(logging.TimeoutReasonHandshake))
case errors.As(e, &resetErr): case errors.As(e, &resetErr):
s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(resetErr.token)) s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(resetErr.Token))
case errors.As(e, &vnErr): case errors.As(e, &vnErr):
s.tracer.ClosedConnection(logging.NewVersionNegotiationError(vnErr.Theirs)) s.tracer.ClosedConnection(logging.NewVersionNegotiationError(vnErr.Theirs))
case errors.As(e, &applicationErr): case errors.As(e, &applicationErr):

View file

@ -658,7 +658,7 @@ var _ = Describe("Session", func() {
streamManager.EXPECT().CloseWithError(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close() cryptoSetup.EXPECT().Close()
sess.destroy(&statelessResetErr{token: token}) sess.destroy(&StatelessResetError{Token: token})
}) })
}) })