From eec74b14c6b9f60b2eb31f19075d9cd2359846a2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 29 Dec 2024 17:56:38 +0800 Subject: [PATCH] fix errors.Is for StreamError and DatagramTooLargeError (#4825) --- errors.go | 8 ++--- errors_test.go | 37 +++++++++++++++++++++++ integrationtests/self/cancelation_test.go | 24 +++++++++------ 3 files changed, 55 insertions(+), 14 deletions(-) create mode 100644 errors_test.go diff --git a/errors.go b/errors.go index 3fe1e0a9..4a69a7f1 100644 --- a/errors.go +++ b/errors.go @@ -50,8 +50,8 @@ type StreamError struct { } func (e *StreamError) Is(target error) bool { - _, ok := target.(*StreamError) - return ok + t, ok := target.(*StreamError) + return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote } func (e *StreamError) Error() string { @@ -68,8 +68,8 @@ type DatagramTooLargeError struct { } func (e *DatagramTooLargeError) Is(target error) bool { - _, ok := target.(*DatagramTooLargeError) - return ok + t, ok := target.(*DatagramTooLargeError) + return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize } func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" } diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..58fea644 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,37 @@ +package quic + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStreamError(t *testing.T) { + require.True(t, errors.Is( + &StreamError{StreamID: 1, ErrorCode: 2, Remote: true}, + &StreamError{StreamID: 1, ErrorCode: 2, Remote: true}, + )) + require.False(t, errors.Is(&StreamError{StreamID: 1}, &StreamError{StreamID: 2})) + require.False(t, errors.Is(&StreamError{StreamID: 1}, &StreamError{StreamID: 2})) + require.Equal(t, + "stream 1 canceled by remote with error code 2", + (&StreamError{StreamID: 1, ErrorCode: 2, Remote: true}).Error(), + ) + require.Equal(t, + "stream 42 canceled by local with error code 1337", + (&StreamError{StreamID: 42, ErrorCode: 1337, Remote: false}).Error(), + ) +} + +func TestDatagramTooLargeError(t *testing.T) { + require.True(t, errors.Is( + &DatagramTooLargeError{MaxDatagramPayloadSize: 1024}, + &DatagramTooLargeError{MaxDatagramPayloadSize: 1024}, + )) + require.False(t, errors.Is( + &DatagramTooLargeError{MaxDatagramPayloadSize: 1024}, + &DatagramTooLargeError{MaxDatagramPayloadSize: 1025}, + )) + require.Equal(t, "DATAGRAM frame too large", (&DatagramTooLargeError{MaxDatagramPayloadSize: 1024}).Error()) +} diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 8f14e1c3..7577575c 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -187,14 +187,6 @@ func testStreamCancellation( ) { const numStreams = 80 - expectedCancellationError := func(streamID quic.StreamID, remote bool) error { - return &quic.StreamError{ - StreamID: streamID, - ErrorCode: quic.StreamErrorCode(streamID), - Remote: remote, - } - } - server, err := quic.Listen(newUPDConnLocalhost(t), getTLSConfig(), getQuicConfig(nil)) require.NoError(t, err) defer server.Close() @@ -291,7 +283,13 @@ func testStreamCancellation( require.NoError(t, err.Err) continue } - assert.ErrorIs(t, err.Err, expectedCancellationError(err.StreamID, true)) + var streamErr *quic.StreamError + require.ErrorAs(t, err.Err, &streamErr) + assert.Equal(t, streamErr.StreamID, err.StreamID) + assert.Equal(t, streamErr.ErrorCode, quic.StreamErrorCode(err.StreamID)) + if readFunc != nil && writeFunc == nil { + assert.Equal(t, streamErr.Remote, readFunc != nil) + } serverErrs++ } case <-timeout: @@ -304,7 +302,13 @@ func testStreamCancellation( require.NoError(t, err.Err) continue } - assert.ErrorIs(t, err.Err, expectedCancellationError(err.StreamID, true)) + var streamErr *quic.StreamError + require.ErrorAs(t, err.Err, &streamErr) + assert.Equal(t, streamErr.StreamID, err.StreamID) + assert.Equal(t, streamErr.ErrorCode, quic.StreamErrorCode(err.StreamID)) + if readFunc != nil && writeFunc == nil { + assert.Equal(t, streamErr.Remote, writeFunc != nil) + } clientErrs++ } case <-timeout: