fix incorrect usage of errors.Is

errors.Is is supposed to used for equality of errors, not for type
assertions. That's what errors.As is there for.
This commit is contained in:
Marten Seemann 2021-06-26 15:18:54 -07:00
parent a887f8f436
commit dbb517858e
10 changed files with 37 additions and 133 deletions

View file

@ -300,7 +300,8 @@ func (c *client) dial(ctx context.Context) error {
errorChan := make(chan error, 1)
go func() {
err := c.session.run() // returns as soon as the session is closed
if !errors.Is(err, &errCloseForRecreating{}) && c.createdPacketConn {
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
c.packetHandlers.Destroy()
}
errorChan <- err

View file

@ -371,7 +371,8 @@ var _ = Describe("MITM test", func() {
}
err := runTest(delayCb)
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(&quic.VersionNegotiationError{}))
vnErr := &quic.VersionNegotiationError{}
Expect(errors.As(err, &vnErr)).To(BeTrue())
})
// times out, because client doesn't accept subsequent real retry packets from server

View file

@ -2,6 +2,7 @@ package self_test
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
@ -99,7 +100,8 @@ var _ = Describe("Stateless Resets", func() {
_, serr = str.Read([]byte{0})
}
Expect(serr).To(HaveOccurred())
Expect(serr).To(MatchError(&quic.StatelessResetError{}))
statelessResetErr := &quic.StatelessResetError{}
Expect(errors.As(serr, &statelessResetErr)).To(BeTrue())
Expect(ln2.Close()).To(Succeed())
Eventually(acceptStopped).Should(BeClosed())
})

View file

@ -6,50 +6,9 @@ import (
"net"
)
func (e *TransportError) Is(target error) bool {
_, ok := target.(*TransportError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *ApplicationError) Is(target error) bool {
_, ok := target.(*ApplicationError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *IdleTimeoutError) Is(target error) bool {
_, ok := target.(*IdleTimeoutError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *HandshakeTimeoutError) Is(target error) bool {
_, ok := target.(*HandshakeTimeoutError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *VersionNegotiationError) Is(target error) bool {
_, ok := target.(*VersionNegotiationError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *StatelessResetError) Is(target error) bool {
_, ok := target.(*StatelessResetError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *TransportError) Is(target error) bool { return target == net.ErrClosed }
func (e *ApplicationError) Is(target error) bool { return target == net.ErrClosed }
func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed }
func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed }
func (e *VersionNegotiationError) Is(target error) bool { return target == net.ErrClosed }
func (e *StatelessResetError) Is(target error) bool { return target == net.ErrClosed }

View file

@ -1,33 +0,0 @@
// +build !go1.16
package qerr
func (e *TransportError) Is(target error) bool {
_, ok := target.(*TransportError)
return ok
}
func (e *ApplicationError) Is(target error) bool {
_, ok := target.(*ApplicationError)
return ok
}
func (e *IdleTimeoutError) Is(target error) bool {
_, ok := target.(*IdleTimeoutError)
return ok
}
func (e *HandshakeTimeoutError) Is(target error) bool {
_, ok := target.(*HandshakeTimeoutError)
return ok
}
func (e *VersionNegotiationError) Is(target error) bool {
_, ok := target.(*VersionNegotiationError)
return ok
}
func (e *StatelessResetError) Is(target error) bool {
_, ok := target.(*StatelessResetError)
return ok
}

View file

@ -1,7 +1,6 @@
package qerr
import (
"errors"
"net"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -37,11 +36,6 @@ var _ = Describe("QUIC Errors", func() {
}).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar"))
})
It("works with error assertions", func() {
Expect(errors.Is(&TransportError{ErrorCode: FlowControlError}, &TransportError{})).To(BeTrue())
Expect(errors.Is(&TransportError{ErrorCode: FlowControlError}, &ApplicationError{})).To(BeFalse())
})
Context("crypto errors", func() {
It("has a string representation for errors with a message", func() {
err := NewCryptoError(0x42, "foobar")
@ -68,11 +62,6 @@ var _ = Describe("QUIC Errors", func() {
ErrorCode: 0x42,
}).Error()).To(Equal("Application error 0x42"))
})
It("works with error assertions", func() {
Expect(errors.Is(&ApplicationError{ErrorCode: 0x1234}, &ApplicationError{})).To(BeTrue())
Expect(errors.Is(&ApplicationError{ErrorCode: 0x1234}, &TransportError{})).To(BeFalse())
})
})
Context("timeout errors", func() {
@ -85,8 +74,6 @@ var _ = Describe("QUIC Errors", func() {
Expect(nerr.Timeout()).To(BeTrue())
Expect(nerr.Temporary()).To(BeFalse())
Expect(err.Error()).To(Equal("timeout: handshake did not complete in time"))
Expect(errors.Is(err, &HandshakeTimeoutError{})).To(BeTrue())
Expect(errors.Is(err, &IdleTimeoutError{})).To(BeFalse())
})
It("idle timeouts", func() {
@ -98,16 +85,10 @@ var _ = Describe("QUIC Errors", func() {
Expect(nerr.Timeout()).To(BeTrue())
Expect(nerr.Temporary()).To(BeFalse())
Expect(err.Error()).To(Equal("timeout: no recent network activity"))
Expect(errors.Is(err, &HandshakeTimeoutError{})).To(BeFalse())
Expect(errors.Is(err, &IdleTimeoutError{})).To(BeTrue())
})
})
Context("Version Negotiation errors", func() {
It("is a Version Negotiation error", func() {
Expect(errors.Is(&VersionNegotiationError{Ours: []protocol.VersionNumber{2, 3}}, &VersionNegotiationError{})).To(BeTrue())
})
It("has a string representation", func() {
Expect((&VersionNegotiationError{
Ours: []protocol.VersionNumber{2, 3},
@ -119,10 +100,6 @@ var _ = Describe("QUIC Errors", func() {
Context("Stateless Reset errors", func() {
token := protocol.StatelessResetToken{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
It("is a Stateless Reset error", func() {
Expect(errors.Is(&StatelessResetError{Token: token}, &StatelessResetError{})).To(BeTrue())
})
It("has a string representation", func() {
Expect((&StatelessResetError{Token: token}).Error()).To(Equal("received a stateless reset with token 000102030405060708090a0b0c0d0e0f"))
})

View file

@ -18,11 +18,6 @@ type headerParseError struct {
err error
}
func (e *headerParseError) Is(err error) bool {
_, ok := err.(*headerParseError)
return ok
}
func (e *headerParseError) Unwrap() error {
return e.err
}

View file

@ -60,7 +60,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(errors.Is(err, &headerParseError{})).To(BeTrue())
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
var headerErr *headerParseError
Expect(errors.As(err, &headerErr)).To(BeTrue())
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
@ -77,9 +77,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(errors.Is(err, &headerParseError{})).To(BeTrue())
var headerErr *headerParseError
Expect(errors.As(err, &headerErr)).To(BeTrue())
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19"))
})

View file

@ -127,11 +127,6 @@ func (e *errCloseForRecreating) Error() string {
return "closing session in order to recreate it"
}
func (e *errCloseForRecreating) Is(target error) bool {
_, ok := target.(*errCloseForRecreating)
return ok
}
var sessionTracingID uint64 // to be accessed atomically
func nextSessionTracingID() uint64 { return atomic.AddUint64(&sessionTracingID, 1) }
@ -691,7 +686,7 @@ runLoop:
}
s.handleCloseError(&closeErr)
if !errors.Is(closeErr.err, &errCloseForRecreating{}) && s.tracer != nil {
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
s.tracer.Close()
}
s.logger.Infof("Connection %s closed.", s.logID)
@ -1480,14 +1475,21 @@ func (s *session) handleCloseError(closeErr *closeError) {
}()
}
var (
statelessResetErr *StatelessResetError
versionNegotiationErr *VersionNegotiationError
recreateErr *errCloseForRecreating
applicationErr *ApplicationError
transportErr *TransportError
)
switch {
case errors.Is(e, qerr.ErrIdleTimeout),
errors.Is(e, qerr.ErrHandshakeTimeout),
errors.Is(e, &StatelessResetError{}),
errors.Is(e, &VersionNegotiationError{}),
errors.Is(e, &errCloseForRecreating{}),
errors.Is(e, &qerr.ApplicationError{}),
errors.Is(e, &qerr.TransportError{}):
errors.As(e, &statelessResetErr),
errors.As(e, &versionNegotiationErr),
errors.As(e, &recreateErr),
errors.As(e, &applicationErr),
errors.As(e, &transportErr):
default:
e = &qerr.TransportError{
ErrorCode: qerr.InternalError,
@ -1501,7 +1503,7 @@ func (s *session) handleCloseError(closeErr *closeError) {
s.datagramQueue.CloseWithError(e)
}
if s.tracer != nil && !errors.Is(e, &errCloseForRecreating{}) {
if s.tracer != nil && !errors.As(e, &recreateErr) {
s.tracer.ClosedConnection(e)
}

View file

@ -2182,7 +2182,7 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().Close()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
Expect(errors.Is(e, &IdleTimeoutError{})).To(BeTrue())
Expect(e).To(MatchError(&qerr.IdleTimeoutError{}))
}),
tracer.EXPECT().Close(),
)
@ -2206,7 +2206,7 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().Close()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
Expect(errors.Is(e, &HandshakeTimeoutError{})).To(BeTrue())
Expect(e).To(MatchError(&HandshakeTimeoutError{}))
}),
tracer.EXPECT().Close(),
)
@ -2235,8 +2235,10 @@ var _ = Describe("Session", func() {
})
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
Expect(errors.Is(e, &IdleTimeoutError{})).To(BeFalse())
Expect(errors.Is(e, &HandshakeTimeoutError{})).To(BeFalse())
idleTimeout := &IdleTimeoutError{}
handshakeTimeout := &HandshakeTimeoutError{}
Expect(errors.As(e, &idleTimeout)).To(BeFalse())
Expect(errors.As(e, &handshakeTimeout)).To(BeFalse())
}),
tracer.EXPECT().Close(),
)
@ -2263,7 +2265,7 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().Close()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
Expect(errors.Is(e, &IdleTimeoutError{})).To(BeTrue())
Expect(e).To(MatchError(&IdleTimeoutError{}))
}),
tracer.EXPECT().Close(),
)
@ -2292,7 +2294,7 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().Close()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) {
Expect(errors.Is(e, &IdleTimeoutError{})).To(BeTrue())
Expect(e).To(MatchError(&IdleTimeoutError{}))
}),
tracer.EXPECT().Close(),
)