From dbb517858e34e14a79fe5843788cba8a3c87191b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 26 Jun 2021 15:18:54 -0700 Subject: [PATCH] fix incorrect usage of errors.Is errors.Is is supposed to used for equality of errors, not for type assertions. That's what errors.As is there for. --- client.go | 3 +- integrationtests/self/mitm_test.go | 3 +- integrationtests/self/stateless_reset_test.go | 4 +- internal/qerr/errors_go116.go | 53 +++---------------- internal/qerr/errors_not_go116.go | 33 ------------ internal/qerr/errors_test.go | 23 -------- packet_unpacker.go | 5 -- packet_unpacker_test.go | 6 +-- session.go | 26 ++++----- session_test.go | 14 ++--- 10 files changed, 37 insertions(+), 133 deletions(-) delete mode 100644 internal/qerr/errors_not_go116.go diff --git a/client.go b/client.go index bd899fe8..9dbe4ac5 100644 --- a/client.go +++ b/client.go @@ -300,7 +300,8 @@ func (c *client) dial(ctx context.Context) error { errorChan := make(chan error, 1) go func() { err := c.session.run() // returns as soon as the session is closed - if !errors.Is(err, &errCloseForRecreating{}) && c.createdPacketConn { + + if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { c.packetHandlers.Destroy() } errorChan <- err diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index ebd57df9..b36b0b2d 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -371,7 +371,8 @@ var _ = Describe("MITM test", func() { } err := runTest(delayCb) Expect(err).To(HaveOccurred()) - Expect(err).To(MatchError(&quic.VersionNegotiationError{})) + vnErr := &quic.VersionNegotiationError{} + Expect(errors.As(err, &vnErr)).To(BeTrue()) }) // times out, because client doesn't accept subsequent real retry packets from server diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 790dc9bf..2f08fc7b 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -2,6 +2,7 @@ package self_test import ( "context" + "errors" "fmt" "math/rand" "net" @@ -99,7 +100,8 @@ var _ = Describe("Stateless Resets", func() { _, serr = str.Read([]byte{0}) } Expect(serr).To(HaveOccurred()) - Expect(serr).To(MatchError(&quic.StatelessResetError{})) + statelessResetErr := &quic.StatelessResetError{} + Expect(errors.As(serr, &statelessResetErr)).To(BeTrue()) Expect(ln2.Close()).To(Succeed()) Eventually(acceptStopped).Should(BeClosed()) }) diff --git a/internal/qerr/errors_go116.go b/internal/qerr/errors_go116.go index 173cd91d..c57cbf3d 100644 --- a/internal/qerr/errors_go116.go +++ b/internal/qerr/errors_go116.go @@ -6,50 +6,9 @@ import ( "net" ) -func (e *TransportError) Is(target error) bool { - _, ok := target.(*TransportError) - if ok { - return true - } - return target == net.ErrClosed -} - -func (e *ApplicationError) Is(target error) bool { - _, ok := target.(*ApplicationError) - if ok { - return true - } - return target == net.ErrClosed -} - -func (e *IdleTimeoutError) Is(target error) bool { - _, ok := target.(*IdleTimeoutError) - if ok { - return true - } - return target == net.ErrClosed -} - -func (e *HandshakeTimeoutError) Is(target error) bool { - _, ok := target.(*HandshakeTimeoutError) - if ok { - return true - } - return target == net.ErrClosed -} - -func (e *VersionNegotiationError) Is(target error) bool { - _, ok := target.(*VersionNegotiationError) - if ok { - return true - } - return target == net.ErrClosed -} - -func (e *StatelessResetError) Is(target error) bool { - _, ok := target.(*StatelessResetError) - if ok { - return true - } - return target == net.ErrClosed -} +func (e *TransportError) Is(target error) bool { return target == net.ErrClosed } +func (e *ApplicationError) Is(target error) bool { return target == net.ErrClosed } +func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } +func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } +func (e *VersionNegotiationError) Is(target error) bool { return target == net.ErrClosed } +func (e *StatelessResetError) Is(target error) bool { return target == net.ErrClosed } diff --git a/internal/qerr/errors_not_go116.go b/internal/qerr/errors_not_go116.go deleted file mode 100644 index 747022b7..00000000 --- a/internal/qerr/errors_not_go116.go +++ /dev/null @@ -1,33 +0,0 @@ -// +build !go1.16 - -package qerr - -func (e *TransportError) Is(target error) bool { - _, ok := target.(*TransportError) - return ok -} - -func (e *ApplicationError) Is(target error) bool { - _, ok := target.(*ApplicationError) - return ok -} - -func (e *IdleTimeoutError) Is(target error) bool { - _, ok := target.(*IdleTimeoutError) - return ok -} - -func (e *HandshakeTimeoutError) Is(target error) bool { - _, ok := target.(*HandshakeTimeoutError) - return ok -} - -func (e *VersionNegotiationError) Is(target error) bool { - _, ok := target.(*VersionNegotiationError) - return ok -} - -func (e *StatelessResetError) Is(target error) bool { - _, ok := target.(*StatelessResetError) - return ok -} diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index 325032f8..434cd848 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -1,7 +1,6 @@ package qerr import ( - "errors" "net" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -37,11 +36,6 @@ var _ = Describe("QUIC Errors", func() { }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) }) - It("works with error assertions", func() { - Expect(errors.Is(&TransportError{ErrorCode: FlowControlError}, &TransportError{})).To(BeTrue()) - Expect(errors.Is(&TransportError{ErrorCode: FlowControlError}, &ApplicationError{})).To(BeFalse()) - }) - Context("crypto errors", func() { It("has a string representation for errors with a message", func() { err := NewCryptoError(0x42, "foobar") @@ -68,11 +62,6 @@ var _ = Describe("QUIC Errors", func() { ErrorCode: 0x42, }).Error()).To(Equal("Application error 0x42")) }) - - It("works with error assertions", func() { - Expect(errors.Is(&ApplicationError{ErrorCode: 0x1234}, &ApplicationError{})).To(BeTrue()) - Expect(errors.Is(&ApplicationError{ErrorCode: 0x1234}, &TransportError{})).To(BeFalse()) - }) }) Context("timeout errors", func() { @@ -85,8 +74,6 @@ var _ = Describe("QUIC Errors", func() { Expect(nerr.Timeout()).To(BeTrue()) Expect(nerr.Temporary()).To(BeFalse()) Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) - Expect(errors.Is(err, &HandshakeTimeoutError{})).To(BeTrue()) - Expect(errors.Is(err, &IdleTimeoutError{})).To(BeFalse()) }) It("idle timeouts", func() { @@ -98,16 +85,10 @@ var _ = Describe("QUIC Errors", func() { Expect(nerr.Timeout()).To(BeTrue()) Expect(nerr.Temporary()).To(BeFalse()) Expect(err.Error()).To(Equal("timeout: no recent network activity")) - Expect(errors.Is(err, &HandshakeTimeoutError{})).To(BeFalse()) - Expect(errors.Is(err, &IdleTimeoutError{})).To(BeTrue()) }) }) Context("Version Negotiation errors", func() { - It("is a Version Negotiation error", func() { - Expect(errors.Is(&VersionNegotiationError{Ours: []protocol.VersionNumber{2, 3}}, &VersionNegotiationError{})).To(BeTrue()) - }) - It("has a string representation", func() { Expect((&VersionNegotiationError{ Ours: []protocol.VersionNumber{2, 3}, @@ -119,10 +100,6 @@ var _ = Describe("QUIC Errors", func() { Context("Stateless Reset errors", func() { token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} - It("is a Stateless Reset error", func() { - Expect(errors.Is(&StatelessResetError{Token: token}, &StatelessResetError{})).To(BeTrue()) - }) - It("has a string representation", func() { Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f")) }) diff --git a/packet_unpacker.go b/packet_unpacker.go index f78c6a63..f70d8d07 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -18,11 +18,6 @@ type headerParseError struct { err error } -func (e *headerParseError) Is(err error) bool { - _, ok := err.(*headerParseError) - return ok -} - func (e *headerParseError) Unwrap() error { return e.err } diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 5787da0a..16c708e6 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -60,7 +60,7 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) _, err := unpacker.Unpack(hdr, time.Now(), data) - Expect(errors.Is(err, &headerParseError{})).To(BeTrue()) + Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) var headerErr *headerParseError Expect(errors.As(err, &headerErr)).To(BeTrue()) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) @@ -77,9 +77,7 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) _, err := unpacker.Unpack(hdr, time.Now(), data) - Expect(errors.Is(err, &headerParseError{})).To(BeTrue()) - var headerErr *headerParseError - Expect(errors.As(err, &headerErr)).To(BeTrue()) + Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) }) diff --git a/session.go b/session.go index 97ed9df6..f38c30d0 100644 --- a/session.go +++ b/session.go @@ -127,11 +127,6 @@ func (e *errCloseForRecreating) Error() string { return "closing session in order to recreate it" } -func (e *errCloseForRecreating) Is(target error) bool { - _, ok := target.(*errCloseForRecreating) - return ok -} - var sessionTracingID uint64 // to be accessed atomically func nextSessionTracingID() uint64 { return atomic.AddUint64(&sessionTracingID, 1) } @@ -691,7 +686,7 @@ runLoop: } s.handleCloseError(&closeErr) - if !errors.Is(closeErr.err, &errCloseForRecreating{}) && s.tracer != nil { + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { s.tracer.Close() } s.logger.Infof("Connection %s closed.", s.logID) @@ -1480,14 +1475,21 @@ func (s *session) handleCloseError(closeErr *closeError) { }() } + var ( + statelessResetErr *StatelessResetError + versionNegotiationErr *VersionNegotiationError + recreateErr *errCloseForRecreating + applicationErr *ApplicationError + transportErr *TransportError + ) switch { case errors.Is(e, qerr.ErrIdleTimeout), errors.Is(e, qerr.ErrHandshakeTimeout), - errors.Is(e, &StatelessResetError{}), - errors.Is(e, &VersionNegotiationError{}), - errors.Is(e, &errCloseForRecreating{}), - errors.Is(e, &qerr.ApplicationError{}), - errors.Is(e, &qerr.TransportError{}): + errors.As(e, &statelessResetErr), + errors.As(e, &versionNegotiationErr), + errors.As(e, &recreateErr), + errors.As(e, &applicationErr), + errors.As(e, &transportErr): default: e = &qerr.TransportError{ ErrorCode: qerr.InternalError, @@ -1501,7 +1503,7 @@ func (s *session) handleCloseError(closeErr *closeError) { s.datagramQueue.CloseWithError(e) } - if s.tracer != nil && !errors.Is(e, &errCloseForRecreating{}) { + if s.tracer != nil && !errors.As(e, &recreateErr) { s.tracer.ClosedConnection(e) } diff --git a/session_test.go b/session_test.go index f57dcb0a..531828a0 100644 --- a/session_test.go +++ b/session_test.go @@ -2182,7 +2182,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().Close() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(errors.Is(e, &IdleTimeoutError{})).To(BeTrue()) + Expect(e).To(MatchError(&qerr.IdleTimeoutError{})) }), tracer.EXPECT().Close(), ) @@ -2206,7 +2206,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().Close() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(errors.Is(e, &HandshakeTimeoutError{})).To(BeTrue()) + Expect(e).To(MatchError(&HandshakeTimeoutError{})) }), tracer.EXPECT().Close(), ) @@ -2235,8 +2235,10 @@ var _ = Describe("Session", func() { }) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(errors.Is(e, &IdleTimeoutError{})).To(BeFalse()) - Expect(errors.Is(e, &HandshakeTimeoutError{})).To(BeFalse()) + idleTimeout := &IdleTimeoutError{} + handshakeTimeout := &HandshakeTimeoutError{} + Expect(errors.As(e, &idleTimeout)).To(BeFalse()) + Expect(errors.As(e, &handshakeTimeout)).To(BeFalse()) }), tracer.EXPECT().Close(), ) @@ -2263,7 +2265,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().Close() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(errors.Is(e, &IdleTimeoutError{})).To(BeTrue()) + Expect(e).To(MatchError(&IdleTimeoutError{})) }), tracer.EXPECT().Close(), ) @@ -2292,7 +2294,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().Close() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { - Expect(errors.Is(e, &IdleTimeoutError{})).To(BeTrue()) + Expect(e).To(MatchError(&IdleTimeoutError{})) }), tracer.EXPECT().Close(), )