From 592fb9cad9d0d136c619131b46f4d89ffb0d8205 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Apr 2021 22:11:06 +0700 Subject: [PATCH] introduce a dedicated qerr.TransportError and qerr.ApplicationError --- client.go | 2 +- conn_id_generator.go | 10 +- conn_id_generator_test.go | 11 +- conn_id_manager.go | 2 +- conn_id_manager_test.go | 4 +- crypto_stream.go | 15 +- crypto_stream_test.go | 23 ++- fuzzing/frames/cmd/corpus.go | 9 +- integrationtests/self/handshake_test.go | 43 +++- integrationtests/self/mitm_test.go | 10 +- integrationtests/self/stateless_reset_test.go | 4 +- integrationtests/self/timeout_test.go | 5 +- internal/ackhandler/sent_packet_handler.go | 10 +- .../ackhandler/sent_packet_handler_test.go | 17 +- .../flowcontrol/connection_flow_controller.go | 5 +- .../flowcontrol/stream_flow_controller.go | 20 +- .../stream_flow_controller_test.go | 22 ++- internal/handshake/crypto_setup.go | 5 +- internal/handshake/crypto_setup_test.go | 36 ++-- internal/handshake/updatable_aead.go | 12 +- internal/handshake/updatable_aead_test.go | 14 +- internal/qerr/error_codes.go | 57 +++--- internal/qerr/errorcodes_test.go | 16 +- internal/qerr/quic_error.go | 143 ++++++-------- internal/qerr/quic_error_test.go | 154 ++++++++------- internal/wire/connection_close_frame.go | 11 +- internal/wire/connection_close_frame_test.go | 7 +- internal/wire/frame_parser.go | 8 +- internal/wire/frame_parser_test.go | 29 ++- internal/wire/stream_frame.go | 3 +- internal/wire/stream_frame_test.go | 2 +- internal/wire/transport_parameter_test.go | 101 +++++++--- internal/wire/transport_parameters.go | 7 +- logging/interface.go | 4 +- mock_packer_test.go | 17 +- packet_handler_map.go | 9 +- packet_handler_map_test.go | 4 +- packet_packer.go | 47 +++-- packet_packer_test.go | 30 ++- packet_unpacker_test.go | 5 +- qlog/frame.go | 4 +- qlog/frame_test.go | 2 +- qlog/types.go | 4 +- server.go | 4 +- server_test.go | 4 +- session.go | 160 ++++++++++----- session_test.go | 186 +++++++++++------- streams_map.go | 10 +- streams_map_incoming_bidi.go | 4 +- streams_map_incoming_generic.go | 4 +- streams_map_incoming_generic_test.go | 2 +- streams_map_incoming_uni.go | 4 +- streams_map_outgoing_bidi.go | 2 +- streams_map_outgoing_generic.go | 2 +- streams_map_outgoing_generic_test.go | 4 +- streams_map_outgoing_uni.go | 2 +- streams_map_test.go | 35 +++- 57 files changed, 845 insertions(+), 521 deletions(-) diff --git a/client.go b/client.go index 267cdbc6..bd899fe8 100644 --- a/client.go +++ b/client.go @@ -300,7 +300,7 @@ func (c *client) dial(ctx context.Context) error { errorChan := make(chan error, 1) go func() { err := c.session.run() // returns as soon as the session is closed - if !errors.Is(err, errCloseForRecreating{}) && c.createdPacketConn { + if !errors.Is(err, &errCloseForRecreating{}) && c.createdPacketConn { c.packetHandlers.Destroy() } errorChan <- err diff --git a/conn_id_generator.go b/conn_id_generator.go index 11fb1ab4..2904a57c 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -73,7 +73,10 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { if seq > m.highestSeq { - return qerr.NewError(qerr.ProtocolViolation, fmt.Sprintf("tried to retire connection ID %d. Highest issued: %d", seq, m.highestSeq)) + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), + } } connID, ok := m.activeSrcConnIDs[seq] // We might already have deleted this connection ID, if this is a duplicate frame. @@ -81,7 +84,10 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect return nil } if connID.Equal(sentWithDestConnID) && !protocol.UseRetireBugBackwardsCompatibilityMode(RetireBugBackwardsCompatibilityMode, m.version) { - return qerr.NewError(qerr.ProtocolViolation, fmt.Sprintf("tried to retire connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID)) + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), + } } m.retireConnectionID(connID) delete(m.activeSrcConnIDs, seq) diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 39302a52..0f201142 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -111,7 +112,10 @@ var _ = Describe("Connection ID Generator", func() { }) It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() { - Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError("PROTOCOL_VIOLATION: tried to retire connection ID 1. Highest issued: 0")) + Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "retired connection ID 1 (highest issued: 0)", + })) }) It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() { @@ -119,7 +123,10 @@ var _ = Describe("Connection ID Generator", func() { Expect(queuedFrames).ToNot(BeEmpty()) Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{})) f := queuedFrames[0].(*wire.NewConnectionIDFrame) - Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(fmt.Sprintf("PROTOCOL_VIOLATION: tried to retire connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID))) + Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID), + })) }) It("doesn't error if the peers tries to retire a connection ID in a packet with that connection ID in RetireBugBackwardsCompatibilityMode", func() { diff --git a/conn_id_manager.go b/conn_id_manager.go index 99462376..e1b025a9 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -53,7 +53,7 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { return err } if h.queue.Len() >= protocol.MaxActiveConnectionIDs { - return qerr.ConnectionIDLimitError + return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} } return nil } diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 224b83ab..6c849059 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -2,7 +2,9 @@ package quic import ( "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -231,7 +233,7 @@ var _ = Describe("Connection ID Manager", func() { SequenceNumber: uint64(9999), ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - })).To(MatchError("CONNECTION_ID_LIMIT_ERROR")) + })).To(MatchError(&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})) }) It("initiates the first connection ID update as soon as possible", func() { diff --git a/crypto_stream.go b/crypto_stream.go index 292dca44..36e21d33 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -39,12 +39,18 @@ func newCryptoStream() cryptoStream { func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { - return qerr.NewError(qerr.CryptoBufferExceeded, fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)) + return &qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset), + } } if s.finished { if highestOffset > s.highestOffset { // reject crypto data received after this stream was already finished - return qerr.NewError(qerr.ProtocolViolation, "received crypto data after change of encryption level") + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received crypto data after change of encryption level", + } } // ignore data with a smaller offset than the highest received // could e.g. be a retransmission @@ -80,7 +86,10 @@ func (s *cryptoStreamImpl) GetCryptoData() []byte { func (s *cryptoStreamImpl) Finish() error { if s.queue.HasMoreData() { - return qerr.NewError(qerr.ProtocolViolation, "encryption level changed, but crypto stream has more data to read") + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "encryption level changed, but crypto stream has more data to read", + } } s.finished = true return nil diff --git a/crypto_stream_test.go b/crypto_stream_test.go index a4f92b65..a1628dd1 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -49,11 +50,13 @@ var _ = Describe("Crypto Stream", func() { }) It("errors if the frame exceeds the maximum offset", func() { - err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ Offset: protocol.MaxCryptoStreamOffset - 5, Data: []byte("foobar"), - }) - Expect(err).To(MatchError(fmt.Sprintf("CRYPTO_BUFFER_EXCEEDED: received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset))) + })).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset), + })) }) It("handles messages split over multiple CRYPTO frames", func() { @@ -94,8 +97,10 @@ var _ = Describe("Crypto Stream", func() { Data: createHandshakeMessage(5), Offset: 10, })).To(Succeed()) - err := str.Finish() - Expect(err).To(MatchError("PROTOCOL_VIOLATION: encryption level changed, but crypto stream has more data to read")) + Expect(str.Finish()).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "encryption level changed, but crypto stream has more data to read", + })) }) It("works with reordered data", func() { @@ -114,10 +119,12 @@ var _ = Describe("Crypto Stream", func() { It("rejects new crypto data after finishing", func() { Expect(str.Finish()).To(Succeed()) - err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ Data: createHandshakeMessage(5), - }) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: received crypto data after change of encryption level")) + })).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received crypto data after change of encryption level", + })) }) It("ignores crypto data below the maximum offset received before finishing", func() { diff --git a/fuzzing/frames/cmd/corpus.go b/fuzzing/frames/cmd/corpus.go index 8ea956ed..9773995f 100644 --- a/fuzzing/frames/cmd/corpus.go +++ b/fuzzing/frames/cmd/corpus.go @@ -9,7 +9,6 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/fuzzing/internal/helper" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -196,23 +195,23 @@ func getFrames() []wire.Frame { }, &wire.ConnectionCloseFrame{ // QUIC error with empty reason IsApplicationError: false, - ErrorCode: qerr.ErrorCode(getRandomNumber()), + ErrorCode: getRandomNumber(), ReasonPhrase: "", }, &wire.ConnectionCloseFrame{ // QUIC error with reason IsApplicationError: false, // TODO: add frame type - ErrorCode: qerr.ErrorCode(getRandomNumber()), + ErrorCode: getRandomNumber(), ReasonPhrase: string(getRandomData(100)), }, &wire.ConnectionCloseFrame{ // application error with empty reason IsApplicationError: true, - ErrorCode: qerr.ErrorCode(getRandomNumber()), + ErrorCode: getRandomNumber(), ReasonPhrase: "", }, &wire.ConnectionCloseFrame{ // application error with reason IsApplicationError: true, - ErrorCode: qerr.ErrorCode(getRandomNumber()), + ErrorCode: getRandomNumber(), ReasonPhrase: string(getRandomData(100)), }, } diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index cac5f12e..c48d3b10 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -3,6 +3,7 @@ package self_test import ( "context" "crypto/tls" + "errors" "fmt" "io/ioutil" "net" @@ -267,7 +268,11 @@ var _ = Describe("Handshake tests", func() { getTLSClientConfig(), clientConfig, ) - Expect(err).To(MatchError("CRYPTO_ERROR (0x12a): x509: certificate is valid for localhost, not foo.bar")) + Expect(err).To(HaveOccurred()) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) }) It("fails the handshake if the client fails to provide the requested client cert", func() { @@ -292,7 +297,11 @@ var _ = Describe("Handshake tests", func() { }() Eventually(errChan).Should(Receive(&err)) } - Expect(err).To(MatchError("CRYPTO_ERROR (0x12a): tls: bad certificate")) + Expect(err).To(HaveOccurred()) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("tls: bad certificate")) }) It("uses the ServerName in the tls.Config", func() { @@ -304,7 +313,11 @@ var _ = Describe("Handshake tests", func() { tlsConf, clientConfig, ) - Expect(err).To(MatchError("CRYPTO_ERROR (0x12a): x509: certificate is valid for localhost, not foo.bar")) + Expect(err).To(HaveOccurred()) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) }) }) } @@ -363,7 +376,9 @@ var _ = Describe("Handshake tests", func() { _, err := dial() Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ConnectionRefused)) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) // now accept one session, freeing one spot in the queue _, err = server.Accept(context.Background()) @@ -376,7 +391,8 @@ var _ = Describe("Handshake tests", func() { _, err = dial() Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ConnectionRefused)) + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) }) It("removes closed connections from the accept queue", func() { @@ -392,7 +408,9 @@ var _ = Describe("Handshake tests", func() { _, err = dial() Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ConnectionRefused)) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) // Now close the one of the session that are waiting to be accepted. // This should free one spot in the queue. @@ -407,7 +425,8 @@ var _ = Describe("Handshake tests", func() { _, err = dial() Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ConnectionRefused)) + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) }) }) @@ -450,8 +469,10 @@ var _ = Describe("Handshake tests", func() { nil, ) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR")) - Expect(err.Error()).To(ContainSubstring("no application protocol")) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("no application protocol")) }) }) @@ -529,7 +550,9 @@ var _ = Describe("Handshake tests", func() { nil, ) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("INVALID_TOKEN")) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.InvalidToken)) // Receiving a Retry might lead the client to measure a very small RTT. // Then, it sometimes would retransmit the ClientHello before receiving the ServerHello. Expect(len(tokenChan)).To(BeNumerically(">=", 2)) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 46031edc..03f027a0 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -3,6 +3,7 @@ package self_test import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "math" @@ -11,7 +12,7 @@ import ( "sync/atomic" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -440,9 +441,10 @@ var _ = Describe("MITM test", func() { } err := runTest(delayCb) Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ProtocolViolation)) - Expect(err.Error()).To(ContainSubstring("Received ACK for an unsent packet")) + var transportErr *qerr.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(qerr.ProtocolViolation)) + Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet")) }) }) }) diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index d485000f..86f3c893 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -7,7 +7,7 @@ import ( "net" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/utils" @@ -99,7 +99,7 @@ var _ = Describe("Stateless Resets", func() { _, serr = str.Read([]byte{0}) } Expect(serr).To(HaveOccurred()) - Expect(serr.Error()).To(ContainSubstring("INTERNAL_ERROR: received a stateless reset")) + Expect(serr.Error()).To(ContainSubstring("received a stateless reset")) Expect(ln2.Close()).To(Succeed()) Eventually(acceptStopped).Should(BeClosed()) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 6c660e10..4dd81efb 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -3,6 +3,7 @@ package self_test import ( "bytes" "context" + "errors" "fmt" "io" "io/ioutil" @@ -13,7 +14,9 @@ import ( "sync/atomic" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/qerr" + + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/logging" diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 57aac995..dd8988f0 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -282,7 +282,10 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En largestAcked := ack.LargestAcked() if largestAcked > pnSpace.largestSent { - return false, qerr.NewError(qerr.ProtocolViolation, "Received ACK for an unsent packet") + return false, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + } } pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) @@ -385,7 +388,10 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } } if p.skippedPacket { - return false, fmt.Errorf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel) + return false, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel), + } } h.ackedPackets = append(h.ackedPackets, p) return true, nil diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 26875f00..7282444d 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -5,8 +5,10 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -199,13 +201,19 @@ var _ = Describe("SentPacketHandler", func() { handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 102})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 100, Largest: 102}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).To(MatchError("received an ACK for skipped packet number: 101 (1-RTT)")) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received an ACK for skipped packet number: 101 (1-RTT)", + })) }) It("rejects ACKs with a too high LargestAcked packet number", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 9999}}} _, err := handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now()) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + })) Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10))) }) @@ -1131,7 +1139,10 @@ var _ = Describe("SentPacketHandler", func() { })) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}} _, err := handler.ReceivedAck(ack, protocol.EncryptionHandshake, time.Now()) - Expect(err).To(MatchError("PROTOCOL_VIOLATION: Received ACK for an unsent packet")) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + })) }) It("deletes Initial packets, as a server", func() { diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 3519817a..90e7ceab 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -50,7 +50,10 @@ func (c *connectionFlowController) IncrementHighestReceived(increment protocol.B c.highestReceived += increment if c.checkFlowControlViolation() { - return qerr.NewError(qerr.FlowControlError, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow)) + return &qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow), + } } return nil } diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 46a911cf..4cb94d90 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -54,11 +54,17 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, if c.receivedFinalOffset { // If we receive another final offset, check that it's the same. if final && offset != c.highestReceived { - return qerr.NewError(qerr.FinalSizeError, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset)) + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset), + } } // Check that the offset is below the final offset. if offset > c.highestReceived { - return qerr.NewError(qerr.FinalSizeError, fmt.Sprintf("Received offset %d for stream %d. Final offset was already received at %d", offset, c.streamID, c.highestReceived)) + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived), + } } } @@ -72,7 +78,10 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, // This can happen due to reordering. if offset <= c.highestReceived { if final { - return qerr.NewError(qerr.FinalSizeError, fmt.Sprintf("Received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived)) + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived), + } } return nil } @@ -80,7 +89,10 @@ func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, increment := offset - c.highestReceived c.highestReceived = offset if c.checkFlowControlViolation() { - return qerr.NewError(qerr.FlowControlError, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow)) + return &qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow), + } } return c.connection.IncrementHighestReceived(increment) } diff --git a/internal/flowcontrol/stream_flow_controller_test.go b/internal/flowcontrol/stream_flow_controller_test.go index 9f1d65d2..26f7a9cc 100644 --- a/internal/flowcontrol/stream_flow_controller_test.go +++ b/internal/flowcontrol/stream_flow_controller_test.go @@ -4,7 +4,9 @@ import ( "time" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -97,7 +99,10 @@ var _ = Describe("Stream Flow controller", func() { }) It("detects a flow control violation", func() { - Expect(controller.UpdateHighestReceived(receiveWindow+1, false)).To(MatchError("FLOW_CONTROL_ERROR: Received 10001 bytes on stream 10, allowed 10000 bytes")) + Expect(controller.UpdateHighestReceived(receiveWindow+1, false)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: "received 10001 bytes on stream 10, allowed 10000 bytes", + })) }) It("accepts a final offset higher than the highest received", func() { @@ -108,7 +113,10 @@ var _ = Describe("Stream Flow controller", func() { It("errors when receiving a final offset smaller than the highest offset received so far", func() { controller.UpdateHighestReceived(100, false) - Expect(controller.UpdateHighestReceived(50, true)).To(MatchError("FINAL_SIZE_ERROR: Received final offset 50 for stream 10, but already received offset 100 before")) + Expect(controller.UpdateHighestReceived(50, true)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: "received final offset 50 for stream 10, but already received offset 100 before", + })) }) It("accepts delayed data after receiving a final offset", func() { @@ -118,7 +126,10 @@ var _ = Describe("Stream Flow controller", func() { It("errors when receiving a higher offset after receiving a final offset", func() { Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) - Expect(controller.UpdateHighestReceived(250, false)).To(MatchError("FINAL_SIZE_ERROR: Received offset 250 for stream 10. Final offset was already received at 200")) + Expect(controller.UpdateHighestReceived(250, false)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: "received offset 250 for stream 10, but final offset was already received at 200", + })) }) It("accepts duplicate final offsets", func() { @@ -129,7 +140,10 @@ var _ = Describe("Stream Flow controller", func() { It("errors when receiving inconsistent final offsets", func() { Expect(controller.UpdateHighestReceived(200, true)).To(Succeed()) - Expect(controller.UpdateHighestReceived(201, true)).To(MatchError("FINAL_SIZE_ERROR: Received inconsistent final offset for stream 10 (old: 200, new: 201 bytes)")) + Expect(controller.UpdateHighestReceived(201, true)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: "received inconsistent final offset for stream 10 (old: 200, new: 201 bytes)", + })) }) It("tells the connection flow controller when a stream is abandoned", func() { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 699807ac..319d3318 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -403,7 +403,10 @@ func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protoco func (h *cryptoSetup) handleTransportParameters(data []byte) { var tp wire.TransportParameters if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { - h.runner.OnError(qerr.NewError(qerr.TransportParameterError, err.Error())) + h.runner.OnError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) } h.peerParams = &tp h.runner.OnReceivedParams(h.peerParams) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index d09673e9..ec0e4e1e 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -106,7 +106,10 @@ var _ = Describe("Crypto Setup TLS", func() { go func() { defer GinkgoRecover() server.RunHandshake() - Expect(sErrChan).To(Receive(MatchError("CRYPTO_ERROR (0x10a): local error: tls: unexpected message"))) + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: 0x10a, + ErrorMessage: "local error: tls: unexpected message", + }))) close(done) }() @@ -152,13 +155,10 @@ var _ = Describe("Crypto Setup TLS", func() { fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level - var err error - Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) - Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - Expect(err.Error()).To(ContainSubstring("expected handshake message ClientHello to have encryption level Initial, has Handshake")) + Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", + }))) // make the go routine return Expect(server.Close()).To(Succeed()) @@ -193,10 +193,8 @@ var _ = Describe("Crypto Setup TLS", func() { server.RunHandshake() var err error Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) - Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) close(done) }() @@ -570,12 +568,9 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) - Expect(qerr.ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - Expect(qerr.Error()).To(ContainSubstring("expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake")) + cRunner.EXPECT().OnError(&qerr.TransportError{ + ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), + ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", }) b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) client.HandleMessage(b, protocol.EncryptionHandshake) @@ -633,9 +628,8 @@ var _ = Describe("Crypto Setup TLS", func() { // inject an invalid session ticket cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - qerr := err.(*qerr.QuicError) - Expect(qerr.IsCryptoError()).To(BeTrue()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) }) b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) client.HandleMessage(b, protocol.Encryption1RTT) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index f28f3e03..094e6504 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -163,7 +163,7 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac if err == ErrDecryptionFailed { a.invalidPacketCount++ if a.invalidPacketCount >= a.invalidPacketLimit { - return nil, qerr.AEADLimitReached + return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} } } if err == nil { @@ -201,7 +201,10 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac } // Opening succeeded. Check if the peer was allowed to update. if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { - return nil, qerr.NewError(qerr.KeyUpdateError, "keys updated too quickly") + return nil, &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + } } a.rollKeys() a.logger.Debugf("Peer updated keys to %d", a.keyPhase) @@ -250,7 +253,10 @@ func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byt func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { - return qerr.NewError(qerr.KeyUpdateError, fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase)) + return &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), + } } a.largestAcked = pn return nil diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 8726c3b0..df49a0dc 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -138,7 +138,9 @@ var _ = Describe("Updatable AEAD", func() { Expect(err).To(MatchError(ErrDecryptionFailed)) } _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(MatchError(qerr.AEADLimitReached)) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) }) Context("key updates", func() { @@ -257,7 +259,10 @@ var _ = Describe("Updatable AEAD", func() { client.rollKeys() encrypted1 := client.Seal(nil, msg, 0x42, ad) _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError("KEY_UPDATE_ERROR: keys updated too quickly")) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + })) }) }) @@ -315,7 +320,10 @@ var _ = Describe("Updatable AEAD", func() { server.Seal(nil, msg, nextPN, ad) // We haven't decrypted any packet in the new key phase yet. // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(nextPN)).To(MatchError("KEY_UPDATE_ERROR: received ACK for key phase 1, but peer didn't update keys")) + Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", + })) }) It("doesn't error before actually sending a packet in the new key phase", func() { diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index 7ac11603..bee42d51 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -6,51 +6,44 @@ import ( "github.com/lucas-clemente/quic-go/internal/qtls" ) -// ErrorCode can be used as a normal error without reason. -type ErrorCode uint64 +// TransportErrorCode is a QUIC transport error. +type TransportErrorCode uint64 // The error codes defined by QUIC const ( - NoError ErrorCode = 0x0 - InternalError ErrorCode = 0x1 - ConnectionRefused ErrorCode = 0x2 - FlowControlError ErrorCode = 0x3 - StreamLimitError ErrorCode = 0x4 - StreamStateError ErrorCode = 0x5 - FinalSizeError ErrorCode = 0x6 - FrameEncodingError ErrorCode = 0x7 - TransportParameterError ErrorCode = 0x8 - ConnectionIDLimitError ErrorCode = 0x9 - ProtocolViolation ErrorCode = 0xa - InvalidToken ErrorCode = 0xb - ApplicationError ErrorCode = 0xc - CryptoBufferExceeded ErrorCode = 0xd - KeyUpdateError ErrorCode = 0xe - AEADLimitReached ErrorCode = 0xf - NoViablePathError ErrorCode = 0x10 + NoError TransportErrorCode = 0x0 + InternalError TransportErrorCode = 0x1 + ConnectionRefused TransportErrorCode = 0x2 + FlowControlError TransportErrorCode = 0x3 + StreamLimitError TransportErrorCode = 0x4 + StreamStateError TransportErrorCode = 0x5 + FinalSizeError TransportErrorCode = 0x6 + FrameEncodingError TransportErrorCode = 0x7 + TransportParameterError TransportErrorCode = 0x8 + ConnectionIDLimitError TransportErrorCode = 0x9 + ProtocolViolation TransportErrorCode = 0xa + InvalidToken TransportErrorCode = 0xb + ApplicationErrorErrorCode TransportErrorCode = 0xc + CryptoBufferExceeded TransportErrorCode = 0xd + KeyUpdateError TransportErrorCode = 0xe + AEADLimitReached TransportErrorCode = 0xf + NoViablePathError TransportErrorCode = 0x10 ) -func (e ErrorCode) isCryptoError() bool { +func (e TransportErrorCode) IsCryptoError() bool { return e >= 0x100 && e < 0x200 } -func (e ErrorCode) Error() string { - if e.isCryptoError() { - return fmt.Sprintf("%s: %s", e.String(), e.Message()) - } - return e.String() -} - // Message is a description of the error. // It only returns a non-empty string for crypto errors. -func (e ErrorCode) Message() string { - if !e.isCryptoError() { +func (e TransportErrorCode) Message() string { + if !e.IsCryptoError() { return "" } return qtls.Alert(e - 0x100).Error() } -func (e ErrorCode) String() string { +func (e TransportErrorCode) String() string { switch e { case NoError: return "NO_ERROR" @@ -76,7 +69,7 @@ func (e ErrorCode) String() string { return "PROTOCOL_VIOLATION" case InvalidToken: return "INVALID_TOKEN" - case ApplicationError: + case ApplicationErrorErrorCode: return "APPLICATION_ERROR" case CryptoBufferExceeded: return "CRYPTO_BUFFER_EXCEEDED" @@ -87,7 +80,7 @@ func (e ErrorCode) String() string { case NoViablePathError: return "NO_VIABLE_PATH" default: - if e.isCryptoError() { + if e.IsCryptoError() { return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) } return fmt.Sprintf("unknown error code: %#x", uint16(e)) diff --git a/internal/qerr/errorcodes_test.go b/internal/qerr/errorcodes_test.go index 99cc4c95..cfc6cd85 100644 --- a/internal/qerr/errorcodes_test.go +++ b/internal/qerr/errorcodes_test.go @@ -30,11 +30,23 @@ var _ = Describe("error codes", func() { valString := c.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value val, err := strconv.ParseInt(valString, 0, 64) Expect(err).NotTo(HaveOccurred()) - Expect(ErrorCode(val).String()).ToNot(Equal("unknown error code")) + Expect(TransportErrorCode(val).String()).ToNot(Equal("unknown error code")) } }) It("has a string representation for unknown error codes", func() { - Expect(ErrorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) + Expect(TransportErrorCode(0x1337).String()).To(Equal("unknown error code: 0x1337")) + }) + + It("says if an error is a crypto error", func() { + for i := 0; i < 0x100; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) + } + for i := 0x100; i < 0x200; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeTrue()) + } + for i := 0x200; i < 0x300; i++ { + Expect(TransportErrorCode(i).IsCryptoError()).To(BeFalse()) + } }) }) diff --git a/internal/qerr/quic_error.go b/internal/qerr/quic_error.go index b5f56957..979e6a19 100644 --- a/internal/qerr/quic_error.go +++ b/internal/qerr/quic_error.go @@ -2,82 +2,39 @@ package qerr import ( "fmt" - "net" ) var ( - ErrIdleTimeout = newTimeoutError("No recent network activity") - ErrHandshakeTimeout = newTimeoutError("Handshake did not complete in time") + ErrHandshakeTimeout = &HandshakeTimeoutError{} + ErrIdleTimeout = &IdleTimeoutError{} ) -// A QuicError consists of an error code plus a error reason -type QuicError struct { - ErrorCode ErrorCode - FrameType uint64 // only valid if this not an application error - ErrorMessage string - isTimeout bool - isApplicationError bool +type TransportError struct { + Remote bool + FrameType uint64 + ErrorCode TransportErrorCode + ErrorMessage string } -var _ net.Error = &QuicError{} +var _ error = &TransportError{} -// NewError creates a new QuicError instance -func NewError(errorCode ErrorCode, errorMessage string) *QuicError { - return &QuicError{ - ErrorCode: errorCode, +// NewCryptoError create a new TransportError instance for a crypto error +func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { + return &TransportError{ + ErrorCode: 0x100 + TransportErrorCode(tlsAlert), ErrorMessage: errorMessage, } } -// NewErrorWithFrameType creates a new QuicError instance for a specific frame type -func NewErrorWithFrameType(errorCode ErrorCode, frameType uint64, errorMessage string) *QuicError { - return &QuicError{ - ErrorCode: errorCode, - FrameType: frameType, - ErrorMessage: errorMessage, - } +func (e *TransportError) Is(target error) bool { + _, ok := target.(*TransportError) + return ok } -// newTimeoutError creates a new QuicError instance for a timeout error -func newTimeoutError(errorMessage string) *QuicError { - return &QuicError{ - ErrorMessage: errorMessage, - isTimeout: true, - } -} - -// NewCryptoError create a new QuicError instance for a crypto error -func NewCryptoError(tlsAlert uint8, errorMessage string) *QuicError { - return &QuicError{ - ErrorCode: 0x100 + ErrorCode(tlsAlert), - ErrorMessage: errorMessage, - } -} - -// NewApplicationError creates a new QuicError instance for an application error -func NewApplicationError(errorCode ErrorCode, errorMessage string) *QuicError { - return &QuicError{ - ErrorCode: errorCode, - ErrorMessage: errorMessage, - isApplicationError: true, - } -} - -func (e *QuicError) Error() string { - if e.isApplicationError { - if len(e.ErrorMessage) == 0 { - return fmt.Sprintf("Application error %#x", uint64(e.ErrorCode)) - } - return fmt.Sprintf("Application error %#x: %s", uint64(e.ErrorCode), e.ErrorMessage) - } - var str string - if e.isTimeout { - str = "Timeout" - } else { - str = e.ErrorCode.String() - if e.FrameType != 0 { - str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) - } +func (e *TransportError) Error() string { + str := e.ErrorCode.String() + if e.FrameType != 0 { + str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) } msg := e.ErrorMessage if len(msg) == 0 { @@ -89,34 +46,46 @@ func (e *QuicError) Error() string { return str + ": " + msg } -// IsCryptoError says if this error is a crypto error -func (e *QuicError) IsCryptoError() bool { - return e.ErrorCode.isCryptoError() +type ApplicationError struct { + Remote bool + ErrorCode uint64 + ErrorMessage string } -// IsApplicationError says if this error is an application error -func (e *QuicError) IsApplicationError() bool { - return e.isApplicationError +var _ error = &ApplicationError{} + +func (e *ApplicationError) Is(target error) bool { + _, ok := target.(*ApplicationError) + return ok } -// Temporary says if the error is temporary. -func (e *QuicError) Temporary() bool { - return false -} - -// Timeout says if this error is a timeout. -func (e *QuicError) Timeout() bool { - return e.isTimeout -} - -// ToQuicError converts an arbitrary error to a QuicError. It leaves QuicErrors -// unchanged, and properly handles `ErrorCode`s. -func ToQuicError(err error) *QuicError { - switch e := err.(type) { - case *QuicError: - return e - case ErrorCode: - return NewError(e, "") +func (e *ApplicationError) Error() string { + if len(e.ErrorMessage) == 0 { + return fmt.Sprintf("Application error %#x", e.ErrorCode) } - return NewError(InternalError, err.Error()) + return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) +} + +type IdleTimeoutError struct{} + +var _ error = &IdleTimeoutError{} + +func (e *IdleTimeoutError) Timeout() bool { return true } +func (e *IdleTimeoutError) Temporary() bool { return false } +func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } +func (e *IdleTimeoutError) Is(target error) bool { + _, ok := target.(*IdleTimeoutError) + return ok +} + +type HandshakeTimeoutError struct{} + +var _ error = &HandshakeTimeoutError{} + +func (e *HandshakeTimeoutError) Timeout() bool { return true } +func (e *HandshakeTimeoutError) Temporary() bool { return false } +func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } +func (e *HandshakeTimeoutError) Is(target error) bool { + _, ok := target.(*HandshakeTimeoutError) + return ok } diff --git a/internal/qerr/quic_error_test.go b/internal/qerr/quic_error_test.go index 206831b7..362e0edf 100644 --- a/internal/qerr/quic_error_test.go +++ b/internal/qerr/quic_error_test.go @@ -1,98 +1,104 @@ package qerr import ( - "io" + "errors" + "net" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("QUIC Transport Errors", func() { - It("has a string representation", func() { - err := NewError(FlowControlError, "foobar") - Expect(err.Timeout()).To(BeFalse()) - Expect(err.IsApplicationError()).To(BeFalse()) - Expect(err.Error()).To(Equal("FLOW_CONTROL_ERROR: foobar")) +var _ = Describe("QUIC Errors", func() { + Context("Transport Errors", func() { + It("has a string representation", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + ErrorMessage: "foobar", + }).Error()).To(Equal("FLOW_CONTROL_ERROR: foobar")) + }) + + It("has a string representation for empty error phrases", func() { + Expect((&TransportError{ErrorCode: FlowControlError}).Error()).To(Equal("FLOW_CONTROL_ERROR")) + }) + + It("includes the frame type, for errors without a message", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + FrameType: 0x1337, + }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337)")) + }) + + It("includes the frame type, for errors with a message", func() { + Expect((&TransportError{ + ErrorCode: FlowControlError, + FrameType: 0x1337, + ErrorMessage: "foobar", + }).Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) + }) + + It("works with error assertions", func() { + Expect(errors.Is(&TransportError{ErrorCode: FlowControlError}, &TransportError{})).To(BeTrue()) + Expect(errors.Is(&TransportError{ErrorCode: FlowControlError}, &ApplicationError{})).To(BeFalse()) + }) + + Context("crypto errors", func() { + It("has a string representation for errors with a message", func() { + err := NewCryptoError(0x42, "foobar") + Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x142): foobar")) + }) + + It("has a string representation for errors without a message", func() { + err := NewCryptoError(0x2a, "") + Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) + }) + }) }) - It("has a string representation for empty error phrases", func() { - err := NewError(FlowControlError, "") - Expect(err.Error()).To(Equal("FLOW_CONTROL_ERROR")) - }) - - It("includes the frame type, for errors without a message", func() { - err := NewErrorWithFrameType(FlowControlError, 0x1337, "") - Expect(err.Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337)")) - }) - - It("includes the frame type, for errors with a message", func() { - err := NewErrorWithFrameType(FlowControlError, 0x1337, "foobar") - Expect(err.Error()).To(Equal("FLOW_CONTROL_ERROR (frame type: 0x1337): foobar")) - }) - - It("has a string representation for timeout errors", func() { - err := newTimeoutError("foobar") - Expect(err.Timeout()).To(BeTrue()) - Expect(err.Error()).To(Equal("Timeout: foobar")) - }) - - Context("crypto errors", func() { + Context("Application Errors", func() { It("has a string representation for errors with a message", func() { - err := NewCryptoError(0x42, "foobar") - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x142): foobar")) + Expect((&ApplicationError{ + ErrorCode: 0x42, + ErrorMessage: "foobar", + }).Error()).To(Equal("Application error 0x42: foobar")) }) It("has a string representation for errors without a message", func() { - err := NewCryptoError(0x2a, "") - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) + Expect((&ApplicationError{ + ErrorCode: 0x42, + }).Error()).To(Equal("Application error 0x42")) }) - It("says if an error is a crypto error", func() { - Expect(NewError(FlowControlError, "").IsCryptoError()).To(BeFalse()) - err := NewCryptoError(42, "") - Expect(err.IsCryptoError()).To(BeTrue()) - Expect(err.IsApplicationError()).To(BeFalse()) + It("works with error assertions", func() { + Expect(errors.Is(&ApplicationError{ErrorCode: 0x1234}, &ApplicationError{})).To(BeTrue()) + Expect(errors.Is(&ApplicationError{ErrorCode: 0x1234}, &TransportError{})).To(BeFalse()) }) }) - Context("application errors", func() { - It("has a string representation for errors with a message", func() { - err := NewApplicationError(0x42, "foobar") - Expect(err.IsApplicationError()).To(BeTrue()) - Expect(err.Error()).To(Equal("Application error 0x42: foobar")) + Context("timeout errors", func() { + It("handshake timeouts", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &HandshakeTimeoutError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(nerr.Temporary()).To(BeFalse()) + Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) + Expect(errors.Is(err, &HandshakeTimeoutError{})).To(BeTrue()) + Expect(errors.Is(err, &IdleTimeoutError{})).To(BeFalse()) }) - It("has a string representation for errors without a message", func() { - err := NewApplicationError(0x42, "") - Expect(err.Error()).To(Equal("Application error 0x42")) - }) - }) - - Context("ErrorCode", func() { - It("works as error", func() { - var err error = StreamStateError - Expect(err).To(MatchError("STREAM_STATE_ERROR")) - }) - - It("recognizes crypto errors", func() { - err := ErrorCode(0x100 + 0x2a) - Expect(err.Error()).To(Equal("CRYPTO_ERROR (0x12a): tls: bad certificate")) - }) - }) - - Context("ToQuicError", func() { - It("leaves QuicError unchanged", func() { - err := NewError(TransportParameterError, "foo") - Expect(ToQuicError(err)).To(Equal(err)) - }) - - It("wraps ErrorCode properly", func() { - var err error = FinalSizeError - Expect(ToQuicError(err)).To(Equal(NewError(FinalSizeError, ""))) - }) - - It("changes default errors to InternalError", func() { - Expect(ToQuicError(io.EOF)).To(Equal(NewError(InternalError, "EOF"))) + It("idle timeouts", func() { + //nolint:gosimple // we need to assign to an interface here + var err error + err = &IdleTimeoutError{} + nerr, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(nerr.Timeout()).To(BeTrue()) + Expect(nerr.Temporary()).To(BeFalse()) + Expect(err.Error()).To(Equal("timeout: no recent network activity")) + Expect(errors.Is(err, &HandshakeTimeoutError{})).To(BeFalse()) + Expect(errors.Is(err, &IdleTimeoutError{})).To(BeTrue()) }) }) }) diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go index 9a77a276..4ce49af6 100644 --- a/internal/wire/connection_close_frame.go +++ b/internal/wire/connection_close_frame.go @@ -5,14 +5,13 @@ import ( "io" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/quicvarint" ) // A ConnectionCloseFrame is a CONNECTION_CLOSE frame type ConnectionCloseFrame struct { IsApplicationError bool - ErrorCode qerr.ErrorCode + ErrorCode uint64 FrameType uint64 ReasonPhrase string } @@ -28,7 +27,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Conn if err != nil { return nil, err } - f.ErrorCode = qerr.ErrorCode(ec) + f.ErrorCode = ec // read the Frame Type, if this is not an application error if !f.IsApplicationError { ft, err := quicvarint.Read(r) @@ -59,8 +58,8 @@ func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Conn } // Length of a written frame -func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - length := 1 + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) +func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) if !f.IsApplicationError { length += quicvarint.Len(f.FrameType) // for the frame type } @@ -74,7 +73,7 @@ func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNu b.WriteByte(0x1c) } - quicvarint.Write(b, uint64(f.ErrorCode)) + quicvarint.Write(b, f.ErrorCode) if !f.IsApplicationError { quicvarint.Write(b, f.FrameType) } diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go index 517b2f84..c116454a 100644 --- a/internal/wire/connection_close_frame_test.go +++ b/internal/wire/connection_close_frame_test.go @@ -5,7 +5,6 @@ import ( "io" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -24,8 +23,8 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { frame, err := parseConnectionCloseFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.IsApplicationError).To(BeFalse()) - Expect(frame.ErrorCode).To(Equal(qerr.ErrorCode(0x19))) - Expect(frame.FrameType).To(Equal(uint64(0x1337))) + Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) + Expect(frame.FrameType).To(BeEquivalentTo(0x1337)) Expect(frame.ReasonPhrase).To(Equal(reason)) Expect(b.Len()).To(BeZero()) }) @@ -40,7 +39,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { frame, err := parseConnectionCloseFrame(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) Expect(frame.IsApplicationError).To(BeTrue()) - Expect(frame.ErrorCode).To(Equal(qerr.ErrorCode(0xcafe))) + Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) Expect(frame.ReasonPhrase).To(Equal(reason)) Expect(b.Len()).To(BeZero()) }) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index a858989e..f3a51ecb 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -26,7 +26,7 @@ func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParse } } -// ParseNextFrame parses the next frame +// ParseNext parses the next frame. // It skips PADDING frames. func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { for r.Len() != 0 { @@ -38,7 +38,11 @@ func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLev f, err := p.parseFrame(r, typeByte, encLevel) if err != nil { - return nil, qerr.NewErrorWithFrameType(qerr.FrameEncodingError, uint64(typeByte), err.Error()) + return nil, &qerr.TransportError{ + FrameType: uint64(typeByte), + ErrorCode: qerr.FrameEncodingError, + ErrorMessage: err.Error(), + } } return f, nil } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 800af60f..86e970a1 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -295,12 +295,20 @@ var _ = Describe("Frame parsing", func() { buf := &bytes.Buffer{} Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) - Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x30): unknown frame type")) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: 0x30, + ErrorMessage: "unknown frame type", + })) }) It("errors on invalid type", func() { _, err := parser.ParseNext(bytes.NewReader([]byte{0x42}), protocol.Encryption1RTT) - Expect(err).To(MatchError("FRAME_ENCODING_ERROR (frame type: 0x42): unknown frame type")) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + FrameType: 0x42, + ErrorMessage: "unknown frame type", + })) }) It("errors on invalid frames", func() { @@ -312,7 +320,7 @@ var _ = Describe("Frame parsing", func() { f.Write(b, versionIETFFrames) _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) }) Context("encryption level check", func() { @@ -357,8 +365,9 @@ var _ = Describe("Frame parsing", func() { case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: Expect(err).ToNot(HaveOccurred()) default: - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("not allowed at encryption level Initial")) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Initial")) } } }) @@ -370,8 +379,9 @@ var _ = Describe("Frame parsing", func() { case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *PingFrame: Expect(err).ToNot(HaveOccurred()) default: - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("not allowed at encryption level Handshake")) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level Handshake")) } } }) @@ -381,8 +391,9 @@ var _ = Describe("Frame parsing", func() { _, err := parser.ParseNext(bytes.NewReader(b), protocol.Encryption0RTT) switch frames[i].(type) { case *AckFrame, *ConnectionCloseFrame, *CryptoFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("not allowed at encryption level 0-RTT")) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) + Expect(err.(*qerr.TransportError).ErrorMessage).To(ContainSubstring("not allowed at encryption level 0-RTT")) default: Expect(err).ToNot(HaveOccurred()) } diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go index cf1f3fcb..66340d16 100644 --- a/internal/wire/stream_frame.go +++ b/internal/wire/stream_frame.go @@ -6,7 +6,6 @@ import ( "io" "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/quicvarint" ) @@ -79,7 +78,7 @@ func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, } } if frame.Offset+frame.DataLen() > protocol.MaxByteCount { - return nil, qerr.NewError(qerr.FrameEncodingError, "stream data overflows maximum offset") + return nil, errors.New("stream data overflows maximum offset") } return frame, nil } diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go index 6d4fef68..08a421b6 100644 --- a/internal/wire/stream_frame_test.go +++ b/internal/wire/stream_frame_test.go @@ -77,7 +77,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, []byte("foobar")...) r := bytes.NewReader(data) _, err := parseStreamFrame(r, versionIETFFrames) - Expect(err).To(MatchError("FRAME_ENCODING_ERROR: stream data overflows maximum offset")) + Expect(err).To(MatchError("stream data overflows maximum offset")) }) It("rejects frames that claim to be longer than the packet size", func() { diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 283fdab6..81d33c60 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/quicvarint" . "github.com/onsi/ginkgo" @@ -147,7 +148,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(statelessResetTokenParameterID)) quicvarint.Write(b, 15) b.Write(make([]byte, 15)) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for stateless_reset_token: 15 (expected 16)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "wrong length for stateless_reset_token: 15 (expected 16)", + })) }) It("errors when the max_packet_size is too small", func() { @@ -155,7 +159,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(maxUDPPayloadSizeParameterID)) quicvarint.Write(b, uint64(quicvarint.Len(1199))) quicvarint.Write(b, 1199) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_packet_size: 1199 (minimum 1200)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_packet_size: 1199 (minimum 1200)", + })) }) It("errors when disable_active_migration has content", func() { @@ -163,7 +170,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) quicvarint.Write(b, 6) b.Write([]byte("foobar")) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: wrong length for disable_active_migration: 6 (expected empty)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "wrong length for disable_active_migration: 6 (expected empty)", + })) }) It("errors when the server doesn't set the original_destination_connection_id", func() { @@ -172,11 +182,17 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, 16) b.Write(make([]byte, 16)) addInitialSourceConnectionID(b) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: missing original_destination_connection_id")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing original_destination_connection_id", + })) }) It("errors when the initial_source_connection_id is missing", func() { - Expect((&TransportParameters{}).Unmarshal([]byte{}, protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: missing initial_source_connection_id")) + Expect((&TransportParameters{}).Unmarshal([]byte{}, protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing initial_source_connection_id", + })) }) It("errors when the max_ack_delay is too large", func() { @@ -185,7 +201,10 @@ var _ = Describe("Transport Parameters", func() { StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for max_ack_delay: 16384ms (maximum 16383ms)")) + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_ack_delay: 16384ms (maximum 16383ms)", + })) }) It("doesn't send the max_ack_delay, if it has the default value", func() { @@ -215,7 +234,10 @@ var _ = Describe("Transport Parameters", func() { StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid value for ack_delay_exponent: 21 (maximum 20)")) + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for ack_delay_exponent: 21 (maximum 20)", + })) }) It("doesn't send the ack_delay_exponent, if it has the default value", func() { @@ -256,9 +278,10 @@ var _ = Describe("Transport Parameters", func() { Expect(quicvarint.Len(val)).ToNot(BeEquivalentTo(2)) quicvarint.Write(b, val) addInitialSourceConnectionID(b) - err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: inconsistent transport parameter length")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: fmt.Sprintf("inconsistent transport parameter length for transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), + })) }) It("errors if initial_max_streams_bidi is too large", func() { @@ -267,9 +290,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) addInitialSourceConnectionID(b) - err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "initial_max_streams_bidi too large: 1152921504606846977 (maximum 1152921504606846976)", + })) }) It("errors if initial_max_streams_uni is too large", func() { @@ -278,9 +302,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(quicvarint.Len(uint64(protocol.MaxStreamCount+1)))) quicvarint.Write(b, uint64(protocol.MaxStreamCount+1)) addInitialSourceConnectionID(b) - err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("TRANSPORT_PARAMETER_ERROR: initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "initial_max_streams_uni too large: 1152921504606846977 (maximum 1152921504606846976)", + })) }) It("handles huge max_ack_delay values", func() { @@ -290,9 +315,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(quicvarint.Len(val))) quicvarint.Write(b, val) addInitialSourceConnectionID(b) - err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("invalid value for max_ack_delay")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid value for max_ack_delay: 3689348814741910323ms (maximum 16383ms)", + })) }) It("skips unknown parameters", func() { @@ -331,9 +357,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(quicvarint.Len(0x1337))) quicvarint.Write(b, 0x1337) addInitialSourceConnectionID(b) - err := (&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("received duplicate transport parameter")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: fmt.Sprintf("received duplicate transport parameter %#x", initialMaxStreamDataBidiLocalParameterID), + })) }) It("errors if there's not enough data to read", func() { @@ -342,7 +369,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, 7) b.Write([]byte("foobar")) p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: remaining length (6) smaller than parameter length (7)")) + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "remaining length (6) smaller than parameter length (7)", + })) }) It("errors if the client sent a stateless_reset_token", func() { @@ -350,7 +380,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(statelessResetTokenParameterID)) quicvarint.Write(b, uint64(quicvarint.Len(16))) b.Write(make([]byte, 16)) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a stateless_reset_token")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent a stateless_reset_token", + })) }) It("errors if the client sent the original_destination_connection_id", func() { @@ -358,7 +391,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) quicvarint.Write(b, 6) b.Write([]byte("foobar")) - Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent an original_destination_connection_id")) + Expect((&TransportParameters{}).Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent an original_destination_connection_id", + })) }) Context("preferred address", func() { @@ -396,7 +432,10 @@ var _ = Describe("Transport Parameters", func() { quicvarint.Write(b, 6) b.Write([]byte("foobar")) p := &TransportParameters{} - Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("TRANSPORT_PARAMETER_ERROR: client sent a preferred_address")) + Expect(p.Unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "client sent a preferred_address", + })) }) It("errors on zero-length connection IDs", func() { @@ -406,7 +445,10 @@ var _ = Describe("Transport Parameters", func() { StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid connection ID length: 0")) + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid connection ID length: 0", + })) }) It("errors on too long connection IDs", func() { @@ -417,7 +459,10 @@ var _ = Describe("Transport Parameters", func() { StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError("TRANSPORT_PARAMETER_ERROR: invalid connection ID length: 21")) + Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "invalid connection ID length: 21", + })) }) It("errors on EOF", func() { diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 1f1085bc..b7e0a8c9 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -90,7 +90,10 @@ type TransportParameters struct { // Unmarshal the transport parameters func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { - return qerr.NewError(qerr.TransportParameterError, err.Error()) + return &qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + } } return nil } @@ -259,7 +262,7 @@ func (p *TransportParameters) readNumericTransportParameter( return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) } if remainingLen-r.Len() != expectedLen { - return fmt.Errorf("inconsistent transport parameter length for %d", paramID) + return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) } //nolint:exhaustive // This only covers the numeric transport parameters. switch paramID { diff --git a/logging/interface.go b/logging/interface.go index 30e91899..8f689a59 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -50,9 +50,9 @@ type ( PreferredAddress = wire.PreferredAddress // A TransportError is a transport-level error code. - TransportError = qerr.ErrorCode + TransportError = qerr.TransportErrorCode // An ApplicationError is an application-defined error code. - ApplicationError = qerr.ErrorCode + ApplicationError = qerr.TransportErrorCode // The RTTStats contain statistics used by the congestion controller. RTTStats = utils.RTTStats diff --git a/mock_packer_test.go b/mock_packer_test.go index b63f0277..54b7e482 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -79,6 +79,21 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0) } +// PackApplicationClose mocks base method. +func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError) (*coalescedPacket, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PackApplicationClose", arg0) + ret0, _ := ret[0].(*coalescedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackApplicationClose indicates an expected call of PackApplicationClose. +func (mr *MockPackerMockRecorder) PackApplicationClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0) +} + // PackCoalescedPacket mocks base method. func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) { m.ctrl.T.Helper() @@ -95,7 +110,7 @@ func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call { } // PackConnectionClose mocks base method. -func (m *MockPacker) PackConnectionClose(arg0 *qerr.QuicError) (*coalescedPacket, error) { +func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackConnectionClose", arg0) ret0, _ := ret[0].(*coalescedPacket) diff --git a/packet_handler_map.go b/packet_handler_map.go index 63d33703..0d31bd83 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -22,10 +22,15 @@ type statelessResetErr struct { token protocol.StatelessResetToken } -func (e statelessResetErr) Error() string { +func (e *statelessResetErr) Error() string { return fmt.Sprintf("received a stateless reset with token %x", e.token) } +func (e *statelessResetErr) Is(target error) bool { + _, ok := target.(*statelessResetErr) + return ok +} + type zeroRTTQueue struct { queue []*receivedPacket retireTimer *time.Timer @@ -430,7 +435,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { copy(token[:], data[len(data)-16:]) if sess, ok := h.resetTokens[token]; ok { h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) - go sess.destroy(statelessResetErr{token: token}) + go sess.destroy(&statelessResetErr{token: token}) return true } return false diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 76de8c16..66882f55 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -373,7 +373,7 @@ var _ = Describe("Packet Handler Map", func() { defer GinkgoRecover() defer close(destroyed) Expect(err).To(HaveOccurred()) - var resetErr statelessResetErr + var resetErr *statelessResetErr Expect(errors.As(err, &resetErr)).To(BeTrue()) Expect(err.Error()).To(ContainSubstring("received a stateless reset")) Expect(resetErr.token).To(Equal(token)) @@ -393,7 +393,7 @@ var _ = Describe("Packet Handler Map", func() { packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { defer GinkgoRecover() Expect(err).To(HaveOccurred()) - var resetErr statelessResetErr + var resetErr *statelessResetErr Expect(errors.As(err, &resetErr)).To(BeTrue()) Expect(err.Error()).To(ContainSubstring("received a stateless reset")) Expect(resetErr.token).To(Equal(token)) diff --git a/packet_packer.go b/packet_packer.go index 20f58ba1..b02525db 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -20,7 +20,8 @@ type packer interface { PackPacket() (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) - PackConnectionClose(*qerr.QuicError) (*coalescedPacket, error) + PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) SetMaxPacketSize(protocol.ByteCount) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) @@ -203,14 +204,27 @@ func newPacketPacker( } } -// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame -func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedPacket, error) { +// PackConnectionClose packs a packet that closes the connection with a transport error. +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError) (*coalescedPacket, error) { var reason string // don't send details of crypto errors - if !quicErr.IsCryptoError() { - reason = quicErr.ErrorMessage + if !e.ErrorCode.IsCryptoError() { + reason = e.ErrorMessage } + return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason) +} +// PackApplicationClose packs a packet that closes the connection with an application error. +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError) (*coalescedPacket, error) { + return p.packConnectionClose(true, e.ErrorCode, 0, e.ErrorMessage) +} + +func (p *packetPacker) packConnectionClose( + isApplicationError bool, + errorCode uint64, + frameType uint64, + reason string, +) (*coalescedPacket, error) { var sealers [4]sealer var hdrs [4]*wire.ExtendedHeader var payloads [4]*payload @@ -221,20 +235,17 @@ func (p *packetPacker) PackConnectionClose(quicErr *qerr.QuicError) (*coalescedP if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { continue } - quicErrToSend := quicErr - reasonPhrase := reason - if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { - // don't send application errors in Initial or Handshake packets - if quicErr.IsApplicationError() { - quicErrToSend = qerr.NewError(qerr.ApplicationError, "") - reasonPhrase = "" - } - } ccf := &wire.ConnectionCloseFrame{ - IsApplicationError: quicErrToSend.IsApplicationError(), - ErrorCode: quicErrToSend.ErrorCode, - FrameType: quicErrToSend.FrameType, - ReasonPhrase: reasonPhrase, + IsApplicationError: isApplicationError, + ErrorCode: errorCode, + FrameType: frameType, + ReasonPhrase: reason, + } + // don't send application errors in Initial or Handshake packets + if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { + ccf.IsApplicationError = false + ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) + ccf.ReasonPhrase = "" } payload := &payload{ frames: []ackhandler.Frame{{Frame: ccf}}, diff --git a/packet_packer_test.go b/packet_packer_test.go index 3911d8aa..6fb7d030 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -339,7 +339,10 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) // expect no framer.PopStreamFrames - p, err := packer.PackConnectionClose(qerr.NewError(qerr.CryptoBufferExceeded, "test error")) + p, err := packer.PackConnectionClose(&qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: "test error", + }) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].header.IsLongHeader).To(BeFalse()) @@ -347,7 +350,7 @@ var _ = Describe("Packet packer", func() { Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(Equal(qerr.CryptoBufferExceeded)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.CryptoBufferExceeded)) Expect(ccf.ReasonPhrase).To(Equal("test error")) }) @@ -361,7 +364,10 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackConnectionClose(qerr.NewApplicationError(0x1337, "test error")) + p, err := packer.PackApplicationClose(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + }) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(3)) Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) @@ -370,7 +376,7 @@ var _ = Describe("Packet packer", func() { Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(Equal(qerr.ApplicationError)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) Expect(p.packets[1].header.Type).To(Equal(protocol.PacketTypeHandshake)) Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) @@ -378,7 +384,7 @@ var _ = Describe("Packet packer", func() { Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(Equal(qerr.ApplicationError)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) Expect(p.packets[2].header.IsLongHeader).To(BeFalse()) Expect(p.packets[2].header.PacketNumber).To(Equal(protocol.PacketNumber(3))) @@ -400,7 +406,10 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get0RTTSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackConnectionClose(qerr.NewApplicationError(0x1337, "test error")) + p, err := packer.PackApplicationClose(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + }) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(2)) Expect(p.buffer.Len()).To(BeNumerically("<", protocol.MinInitialPacketSize)) @@ -410,7 +419,7 @@ var _ = Describe("Packet packer", func() { Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(Equal(qerr.ApplicationError)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) Expect(p.packets[1].header.IsLongHeader).To(BeFalse()) Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) @@ -432,7 +441,10 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get0RTTSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackConnectionClose(qerr.NewApplicationError(0x1337, "test error")) + p, err := packer.PackApplicationClose(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + }) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(2)) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) @@ -443,7 +455,7 @@ var _ = Describe("Packet packer", func() { Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(Equal(qerr.ApplicationError)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) Expect(p.packets[1].header.Type).To(Equal(protocol.PacketType0RTT)) Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 51eabf85..5787da0a 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -186,9 +186,10 @@ var _ = Describe("Packet Unpacker", func() { cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded) + unpackErr := &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded} + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, unpackErr) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(qerr.CryptoBufferExceeded)) + Expect(err).To(MatchError(unpackErr)) }) It("defends against the timing side-channel when the reserved bits are wrong, for long header packets", func() { diff --git a/qlog/frame.go b/qlog/frame.go index 498d3982..4530f0fb 100644 --- a/qlog/frame.go +++ b/qlog/frame.go @@ -211,9 +211,9 @@ func marshalConnectionCloseFrame(enc *gojay.Encoder, f *logging.ConnectionCloseF if errName := transportError(f.ErrorCode).String(); len(errName) > 0 { enc.StringKey("error_code", errName) } else { - enc.Uint64Key("error_code", uint64(f.ErrorCode)) + enc.Uint64Key("error_code", f.ErrorCode) } - enc.Uint64Key("raw_error_code", uint64(f.ErrorCode)) + enc.Uint64Key("raw_error_code", f.ErrorCode) enc.StringKey("reason", f.ReasonPhrase) } diff --git a/qlog/frame_test.go b/qlog/frame_test.go index bab01f01..b5e553e8 100644 --- a/qlog/frame_test.go +++ b/qlog/frame_test.go @@ -343,7 +343,7 @@ var _ = Describe("Frames", func() { It("marshals CONNECTION_CLOSE frames, for transport error codes", func() { check( &logging.ConnectionCloseFrame{ - ErrorCode: qerr.FlowControlError, + ErrorCode: uint64(qerr.FlowControlError), ReasonPhrase: "lorem ipsum", }, map[string]interface{}{ diff --git a/qlog/types.go b/qlog/types.go index 50202c6c..dd9c70f2 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -180,7 +180,7 @@ func (t keyUpdateTrigger) String() string { type transportError uint64 func (e transportError) String() string { - switch qerr.ErrorCode(e) { + switch qerr.TransportErrorCode(e) { case qerr.NoError: return "no_error" case qerr.InternalError: @@ -205,7 +205,7 @@ func (e transportError) String() string { return "protocol_violation" case qerr.InvalidToken: return "invalid_token" - case qerr.ApplicationError: + case qerr.ApplicationErrorErrorCode: return "application_error" case qerr.CryptoBufferExceeded: return "crypto_buffer_exceeded" diff --git a/server.go b/server.go index c75324d2..f2b94c42 100644 --- a/server.go +++ b/server.go @@ -600,12 +600,12 @@ func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header } // sendError sends the error as a response to the packet received with header hdr -func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.ErrorCode, info *packetInfo) error { +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { packetBuffer := getPacketBuffer() defer packetBuffer.Release() buf := bytes.NewBuffer(packetBuffer.Data) - ccf := &wire.ConnectionCloseFrame{ErrorCode: errorCode} + ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} replyHdr := &wire.ExtendedHeader{} replyHdr.IsLongHeader = true diff --git a/server_test.go b/server_test.go index 2fd21a85..3abb5341 100644 --- a/server_test.go +++ b/server_test.go @@ -508,7 +508,7 @@ var _ = Describe("Server", func() { Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := frames[0].(*logging.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { @@ -526,7 +526,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken)) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) Expect(ccf.ReasonPhrase).To(BeEmpty()) return len(b), nil }) diff --git a/session.go b/session.go index 49146ca1..0b728433 100644 --- a/session.go +++ b/session.go @@ -123,12 +123,12 @@ type errCloseForRecreating struct { nextVersion protocol.VersionNumber } -func (errCloseForRecreating) Error() string { +func (e *errCloseForRecreating) Error() string { return "closing session in order to recreate it" } -func (errCloseForRecreating) Is(target error) bool { - _, ok := target.(errCloseForRecreating) +func (e *errCloseForRecreating) Is(target error) bool { + _, ok := target.(*errCloseForRecreating) return ok } @@ -137,10 +137,15 @@ type errVersionNegotiation struct { theirVersions []protocol.VersionNumber } -func (e errVersionNegotiation) Error() string { +func (e *errVersionNegotiation) Error() string { return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.ourVersions, e.theirVersions) } +func (e *errVersionNegotiation) Is(target error) bool { + _, ok := target.(*errVersionNegotiation) + return ok +} + var sessionTracingID uint64 // to be accessed atomically func nextSessionTracingID() uint64 { return atomic.AddUint64(&sessionTracingID, 1) } @@ -699,8 +704,8 @@ runLoop: } } - s.handleCloseError(closeErr) - if !errors.Is(closeErr.err, errCloseForRecreating{}) && s.tracer != nil { + s.handleCloseError(&closeErr) + if !errors.Is(closeErr.err, &errCloseForRecreating{}) && s.tracer != nil { s.tracer.Close() } s.logger.Infof("Connection %s closed.", s.logID) @@ -952,7 +957,10 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / wasQueued = true s.tryQueueingUndecryptablePacket(p, hdr) case wire.ErrInvalidReservedBits: - s.closeLocal(qerr.NewError(qerr.ProtocolViolation, err.Error())) + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: err.Error(), + }) case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. if s.tracer != nil { @@ -1093,7 +1101,7 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) if !ok { - s.destroyImpl(errVersionNegotiation{ + s.destroyImpl(&errVersionNegotiation{ ourVersions: s.config.Versions, theirVersions: supportedVersions, }) @@ -1119,7 +1127,10 @@ func (s *session) handleUnpackedPacket( packetSize protocol.ByteCount, // only for logging ) error { if len(packet.data) == 0 { - return qerr.NewError(qerr.ProtocolViolation, "empty packet") + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + } } if !s.receivedFirstPacket { @@ -1270,13 +1281,20 @@ func (s *session) handlePacket(p *receivedPacket) { } func (s *session) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { - var e error if frame.IsApplicationError { - e = qerr.NewApplicationError(frame.ErrorCode, frame.ReasonPhrase) - } else { - e = qerr.NewError(frame.ErrorCode, frame.ReasonPhrase) + s.closeRemote(&qerr.ApplicationError{ + Remote: true, + ErrorCode: frame.ErrorCode, + ErrorMessage: frame.ReasonPhrase, + }) + return } - s.closeRemote(e) + s.closeRemote(&qerr.TransportError{ + Remote: true, + ErrorCode: qerr.TransportErrorCode(frame.ErrorCode), + FrameType: frame.FrameType, + ErrorMessage: frame.ReasonPhrase, + }) } func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { @@ -1357,7 +1375,10 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { func (s *session) handleNewTokenFrame(frame *wire.NewTokenFrame) error { if s.perspective == protocol.PerspectiveServer { - return qerr.NewError(qerr.ProtocolViolation, "Received NEW_TOKEN frame from the client.") + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received NEW_TOKEN frame from the client", + } } if s.config.TokenStore != nil { s.config.TokenStore.Put(s.tokenStoreKey, &ClientToken{data: frame.Token}) @@ -1375,7 +1396,10 @@ func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, func (s *session) handleHandshakeDoneFrame() error { if s.perspective == protocol.PerspectiveServer { - return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame") + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received a HANDSHAKE_DONE frame", + } } if !s.handshakeConfirmed { s.handleHandshakeConfirmed() @@ -1399,7 +1423,10 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt func (s *session) handleDatagramFrame(f *wire.DatagramFrame) error { if f.Length(s.version) > protocol.MaxDatagramFrameSize { - return qerr.NewError(qerr.ProtocolViolation, "DATAGRAM frame too large") + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "DATAGRAM frame too large", + } } s.datagramQueue.HandleDatagramFrame(f) return nil @@ -1448,45 +1475,66 @@ func (s *session) shutdown() { <-s.ctx.Done() } -func (s *session) CloseWithError(code protocol.ApplicationErrorCode, desc string) error { - s.closeLocal(qerr.NewApplicationError(qerr.ErrorCode(code), desc)) +func (s *session) CloseWithError(code ErrorCode, desc string) error { + s.closeLocal(&qerr.ApplicationError{ + ErrorCode: uint64(code), + ErrorMessage: desc, + }) <-s.ctx.Done() return nil } -func (s *session) handleCloseError(closeErr closeError) { - if closeErr.err == nil { - closeErr.err = qerr.NewApplicationError(0, "") +func (s *session) handleCloseError(closeErr *closeError) { + e := closeErr.err + if e == nil { + e = &qerr.ApplicationError{} + } else { + defer func() { + closeErr.err = e + }() } - var quicErr *qerr.QuicError - var ok bool - if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok { - quicErr = qerr.ToQuicError(closeErr.err) + switch { + case errors.Is(e, qerr.ErrIdleTimeout), + errors.Is(e, qerr.ErrHandshakeTimeout), + errors.Is(e, &statelessResetErr{}), + errors.Is(e, &errVersionNegotiation{}), + errors.Is(e, &errCloseForRecreating{}), + errors.Is(e, &qerr.ApplicationError{}), + errors.Is(e, &qerr.TransportError{}): + default: + e = &qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: e.Error(), + } } - s.streamsMap.CloseWithError(quicErr) + s.streamsMap.CloseWithError(e) s.connIDManager.Close() if s.datagramQueue != nil { - s.datagramQueue.CloseWithError(quicErr) + s.datagramQueue.CloseWithError(e) } - if s.tracer != nil && !errors.Is(closeErr.err, errCloseForRecreating{}) { - var resetErr statelessResetErr - var vnErr errVersionNegotiation + if s.tracer != nil && !errors.Is(e, &errCloseForRecreating{}) { + var ( + resetErr *statelessResetErr + vnErr *errVersionNegotiation + transportErr *qerr.TransportError + applicationErr *qerr.ApplicationError + ) switch { - case errors.Is(closeErr.err, qerr.ErrIdleTimeout): + case errors.Is(e, qerr.ErrIdleTimeout): s.tracer.ClosedConnection(logging.NewTimeoutCloseReason(logging.TimeoutReasonIdle)) - case errors.Is(closeErr.err, qerr.ErrHandshakeTimeout): + case errors.Is(e, qerr.ErrHandshakeTimeout): s.tracer.ClosedConnection(logging.NewTimeoutCloseReason(logging.TimeoutReasonHandshake)) - case errors.As(closeErr.err, &resetErr): + case errors.As(e, &resetErr): s.tracer.ClosedConnection(logging.NewStatelessResetCloseReason(resetErr.token)) - case errors.As(closeErr.err, &vnErr): + case errors.As(e, &vnErr): s.tracer.ClosedConnection(logging.NewVersionNegotiationError(vnErr.theirVersions)) - case quicErr.IsApplicationError(): - s.tracer.ClosedConnection(logging.NewApplicationCloseReason(quicErr.ErrorCode, closeErr.remote)) - default: - s.tracer.ClosedConnection(logging.NewTransportCloseReason(quicErr.ErrorCode, closeErr.remote)) + case errors.As(e, &applicationErr): + s.tracer.ClosedConnection(logging.NewApplicationCloseReason(logging.ApplicationError(applicationErr.ErrorCode), closeErr.remote)) + case errors.As(e, &transportErr): + s.tracer.ClosedConnection(logging.NewTransportCloseReason(transportErr.ErrorCode, closeErr.remote)) } } @@ -1499,7 +1547,7 @@ func (s *session) handleCloseError(closeErr closeError) { s.connIDGenerator.RemoveAll() return } - connClosePacket, err := s.sendConnectionClose(quicErr) + connClosePacket, err := s.sendConnectionClose(e) if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } @@ -1538,7 +1586,10 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) { func (s *session) handleTransportParameters(params *wire.TransportParameters) { if err := s.checkTransportParameters(params); err != nil { - s.closeLocal(err) + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) } s.peerParams = params // On the client side we have to wait for handshake completion. @@ -1561,7 +1612,7 @@ func (s *session) checkTransportParameters(params *wire.TransportParameters) err // check the initial_source_connection_id if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { - return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID)) + return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) } if s.perspective == protocol.PerspectiveServer { @@ -1569,17 +1620,17 @@ func (s *session) checkTransportParameters(params *wire.TransportParameters) err } // check the original_destination_connection_id if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { - return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)) + return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) } if s.retrySrcConnID != nil { // a Retry was performed if params.RetrySourceConnectionID == nil { - return qerr.NewError(qerr.TransportParameterError, "missing retry_source_connection_id") + return errors.New("missing retry_source_connection_id") } if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { - return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID)) + return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) } } else if params.RetrySourceConnectionID != nil { - return qerr.NewError(qerr.TransportParameterError, "received retry_source_connection_id, although no Retry was performed") + return errors.New("received retry_source_connection_id, although no Retry was performed") } return nil } @@ -1774,8 +1825,21 @@ func (s *session) sendPackedPacket(packet *packedPacket, now time.Time) { s.sendQueue.Send(packet.buffer) } -func (s *session) sendConnectionClose(quicErr *qerr.QuicError) ([]byte, error) { - packet, err := s.packer.PackConnectionClose(quicErr) +func (s *session) sendConnectionClose(e error) ([]byte, error) { + var packet *coalescedPacket + var err error + var transportErr *qerr.TransportError + var applicationErr *qerr.ApplicationError + if errors.As(e, &transportErr) { + packet, err = s.packer.PackConnectionClose(transportErr) + } else if errors.As(e, &applicationErr) { + packet, err = s.packer.PackApplicationClose(applicationErr) + } else { + packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: fmt.Sprintf("session BUG: unspecified error type (msg: %s)", e.Error()), + }) + } if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 0e405a1e..7000d810 100644 --- a/session_test.go +++ b/session_test.go @@ -303,8 +303,8 @@ var _ = Describe("Session", func() { It("rejects NEW_TOKEN frames", func() { err := sess.handleNewTokenFrame(&wire.NewTokenFrame{}) Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{})) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ProtocolViolation)) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) }) It("handles BLOCKED frames", func() { @@ -323,8 +323,12 @@ var _ = Describe("Session", func() { }) It("handles CONNECTION_CLOSE frames, with a transport error code", func() { - testErr := qerr.NewError(qerr.StreamLimitError, "foobar") - streamManager.EXPECT().CloseWithError(testErr) + expectedErr := &qerr.TransportError{ + Remote: true, + ErrorCode: qerr.StreamLimitError, + ErrorMessage: "foobar", + } + streamManager.EXPECT().CloseWithError(expectedErr) sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) }) @@ -345,17 +349,21 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(sess.run()).To(MatchError(testErr)) + Expect(sess.run()).To(MatchError(expectedErr)) }() Expect(sess.handleFrame(&wire.ConnectionCloseFrame{ - ErrorCode: qerr.StreamLimitError, + ErrorCode: uint64(qerr.StreamLimitError), ReasonPhrase: "foobar", }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) It("handles CONNECTION_CLOSE frames, with an application error code", func() { - testErr := qerr.NewApplicationError(0x1337, "foobar") + testErr := &qerr.ApplicationError{ + Remote: true, + ErrorCode: 0x1337, + ErrorMessage: "foobar", + } streamManager.EXPECT().CloseWithError(testErr) sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) @@ -389,7 +397,10 @@ var _ = Describe("Session", func() { }) It("errors on HANDSHAKE_DONE frames", func() { - Expect(sess.handleHandshakeDoneFrame()).To(MatchError("PROTOCOL_VIOLATION: received a HANDSHAKE_DONE frame")) + Expect(sess.handleHandshakeDoneFrame()).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received a HANDSHAKE_DONE frame", + })) }) }) @@ -429,14 +440,14 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { sess.handshakeComplete = true runSession() - streamManager.EXPECT().CloseWithError(qerr.NewApplicationError(0, "")) + streamManager.EXPECT().CloseWithError(&qerr.ApplicationError{}) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("connection close")...) - packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(quicErr *qerr.QuicError) (*coalescedPacket, error) { - Expect(quicErr.ErrorCode).To(BeEquivalentTo(qerr.NoError)) - Expect(quicErr.ErrorMessage).To(BeEmpty()) + packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { + Expect(e.ErrorCode).To(BeEquivalentTo(qerr.NoError)) + Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil }) mconn.EXPECT().Write([]byte("connection close")) @@ -459,7 +470,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -471,15 +482,14 @@ var _ = Describe("Session", func() { It("closes with an error", func() { runSession() - streamManager.EXPECT().CloseWithError(qerr.NewApplicationError(0x1337, "test error")) + expectedErr := &qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: "test error", + } + streamManager.EXPECT().CloseWithError(expectedErr) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(quicErr *qerr.QuicError) (*coalescedPacket, error) { - Expect(quicErr.IsApplicationError()).To(BeTrue()) - Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x1337)) - Expect(quicErr.ErrorMessage).To(Equal("test error")) - return &coalescedPacket{buffer: getPacketBuffer()}, nil - }) + packer.EXPECT().PackApplicationClose(expectedErr).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(reason logging.CloseReason) { @@ -497,17 +507,15 @@ var _ = Describe("Session", func() { It("includes the frame type in transport-level close frames", func() { runSession() - testErr := qerr.NewErrorWithFrameType(0x1337, 0x42, "test error") - streamManager.EXPECT().CloseWithError(testErr) + expectedErr := &qerr.TransportError{ + ErrorCode: 0x1337, + FrameType: 0x42, + ErrorMessage: "test error", + } + streamManager.EXPECT().CloseWithError(expectedErr) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(quicErr *qerr.QuicError) (*coalescedPacket, error) { - Expect(quicErr.IsApplicationError()).To(BeFalse()) - Expect(quicErr.FrameType).To(BeEquivalentTo(0x42)) - Expect(quicErr.ErrorCode).To(BeEquivalentTo(0x1337)) - Expect(quicErr.ErrorMessage).To(Equal("test error")) - return &coalescedPacket{buffer: getPacketBuffer()}, nil - }) + packer.EXPECT().PackConnectionClose(expectedErr).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) mconn.EXPECT().Write(gomock.Any()) gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(reason logging.CloseReason) { @@ -518,7 +526,7 @@ var _ = Describe("Session", func() { }), tracer.EXPECT().Close(), ) - sess.closeLocal(testErr) + sess.closeLocal(expectedErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) }) @@ -541,7 +549,10 @@ var _ = Describe("Session", func() { ) sess.destroy(testErr) Eventually(areSessionsRunning).Should(BeFalse()) - expectedRunErr = testErr + expectedRunErr = &qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: testErr.Error(), + } }) It("cancels the context when the run loop exists", func() { @@ -549,7 +560,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -582,7 +593,7 @@ var _ = Describe("Session", func() { Expect(hdr.Write(buf, sess.version)).To(Succeed()) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { buf := &bytes.Buffer{} - Expect((&wire.ConnectionCloseFrame{ErrorCode: qerr.StreamLimitError}).Write(buf, sess.version)).To(Succeed()) + Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, sess.version)).To(Succeed()) return &unpackedPacket{ hdr: hdr, data: buf.Bytes(), @@ -647,7 +658,7 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() - sess.destroy(statelessResetErr{token: token}) + sess.destroy(&statelessResetErr{token: token}) }) }) @@ -950,7 +961,8 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() Expect(err).To(HaveOccurred()) - Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.ProtocolViolation)) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) close(done) }() expectReplaceWithClosed() @@ -984,7 +996,7 @@ var _ = Describe("Session", func() { }, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -993,7 +1005,7 @@ var _ = Describe("Session", func() { }) It("closes the session when unpacking fails because of an error other than a decryption error", func() { - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.ConnectionIDLimitError) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -1003,7 +1015,8 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() Expect(err).To(HaveOccurred()) - Expect(err.(qerr.ErrorCode)).To(Equal(qerr.ConnectionIDLimitError)) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ConnectionIDLimitError)) close(done) }() expectReplaceWithClosed() @@ -1031,7 +1044,10 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(sess.run()).To(MatchError("PROTOCOL_VIOLATION: empty packet")) + Expect(sess.run()).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + })) close(done) }() expectReplaceWithClosed() @@ -1259,7 +1275,7 @@ var _ = Describe("Session", func() { AfterEach(func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1451,7 +1467,7 @@ var _ = Describe("Session", func() { AfterEach(func() { // make the go routine return - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -1729,7 +1745,7 @@ var _ = Describe("Session", func() { // make the go routine return expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) sender.EXPECT().Close() @@ -1867,7 +1883,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -1902,7 +1918,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -1947,7 +1963,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -2016,7 +2032,7 @@ var _ = Describe("Session", func() { // make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() @@ -2034,7 +2050,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -2050,12 +2066,15 @@ var _ = Describe("Session", func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) err := sess.run() - Expect(err).To(MatchError(qerr.NewApplicationError(0x1337, testErr.Error()))) + Expect(err).To(MatchError(&qerr.ApplicationError{ + ErrorCode: 0x1337, + ErrorMessage: testErr.Error(), + })) close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -2116,7 +2135,7 @@ var _ = Describe("Session", func() { // make the go routine return expectReplaceWithClosed() streamManager.EXPECT().CloseWithError(gomock.Any()) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) @@ -2196,7 +2215,7 @@ var _ = Describe("Session", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("No recent network activity")) + Expect(err).To(MatchError(qerr.ErrIdleTimeout)) close(done) }() Eventually(done).Should(BeClosed()) @@ -2223,7 +2242,7 @@ var _ = Describe("Session", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("Handshake did not complete in time")) + Expect(err).To(MatchError(qerr.ErrHandshakeTimeout)) close(done) }() Eventually(done).Should(BeClosed()) @@ -2234,8 +2253,8 @@ var _ = Describe("Session", func() { sess.config.HandshakeIdleTimeout = 9999 * time.Second sess.config.MaxIdleTimeout = 9999 * time.Second sess.lastPacketReceivedTime = time.Now().Add(-time.Minute) - packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(quicErr *qerr.QuicError) (*coalescedPacket, error) { - Expect(quicErr.ErrorCode).To(Equal(qerr.NoError)) + packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { + Expect(e.ErrorCode).To(BeZero()) return &coalescedPacket{buffer: getPacketBuffer()}, nil }) gomock.InOrder( @@ -2284,7 +2303,7 @@ var _ = Describe("Session", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("No recent network activity")) + Expect(err).To(MatchError(qerr.ErrIdleTimeout)) close(done) }() Eventually(done).Should(BeClosed()) @@ -2317,7 +2336,7 @@ var _ = Describe("Session", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("No recent network activity")) + Expect(err).To(MatchError(qerr.ErrIdleTimeout)) close(done) }() Eventually(done).Should(BeClosed()) @@ -2334,7 +2353,7 @@ var _ = Describe("Session", func() { }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -2521,7 +2540,7 @@ var _ = Describe("Client Session", func() { tracer.EXPECT().ReceivedPacket(gomock.Any(), p.Size(), []logging.Frame{}) Expect(sess.handlePacketImpl(p)).To(BeTrue()) // make sure the go routine returns - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -2659,7 +2678,7 @@ var _ = Describe("Client Session", func() { var err error Eventually(errChan).Should(Receive(&err)) Expect(err).To(HaveOccurred()) - Expect(err).ToNot(BeAssignableToTypeOf(&errCloseForRecreating{})) + Expect(err).ToNot(BeAssignableToTypeOf(errCloseForRecreating{})) Expect(err.Error()).To(ContainSubstring("no compatible QUIC version found")) vns, ok := closeReason.VersionNegotiation() Expect(ok).To(BeTrue()) @@ -2759,13 +2778,17 @@ var _ = Describe("Client Session", func() { }() }) - expectClose := func() { + expectClose := func(applicationClose bool) { if !closed { sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) s.shutdown() }) - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + if applicationClose { + packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + } else { + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) + } cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) gomock.InOrder( @@ -2777,7 +2800,6 @@ var _ = Describe("Client Session", func() { } AfterEach(func() { - expectClose() sess.shutdown() Eventually(sess.Context().Done()).Should(BeClosed()) Eventually(errChan).Should(BeClosed()) @@ -2806,7 +2828,7 @@ var _ = Describe("Client Session", func() { Expect(sess.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) // shut down sessionRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) - expectClose() + expectClose(true) }) It("uses the minimum of the peers' idle timeouts", func() { @@ -2821,19 +2843,23 @@ var _ = Describe("Client Session", func() { sess.handleTransportParameters(params) sess.handleHandshakeComplete() Expect(sess.idleTimeout).To(Equal(18 * time.Second)) + expectClose(true) }) - It("errors if the TransportParameters contain a wrong initial_source_connection_id", func() { + It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { sess.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose() + expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) sess.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected initial_source_connection_id to equal deadbeef, is decafbad"))) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "expected initial_source_connection_id to equal deadbeef, is decafbad", + }))) }) It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { @@ -2843,10 +2869,13 @@ var _ = Describe("Client Session", func() { InitialSourceConnectionID: destConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose() + expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) sess.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: missing retry_source_connection_id"))) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "missing retry_source_connection_id", + }))) }) It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { @@ -2857,10 +2886,13 @@ var _ = Describe("Client Session", func() { RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose() + expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) sess.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected retry_source_connection_id to equal deadbeef, is deadc0de"))) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "expected retry_source_connection_id to equal deadbeef, is deadc0de", + }))) }) It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() { @@ -2870,10 +2902,13 @@ var _ = Describe("Client Session", func() { RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose() + expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) sess.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: received retry_source_connection_id, although no Retry was performed"))) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "received retry_source_connection_id, although no Retry was performed", + }))) }) It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { @@ -2883,10 +2918,13 @@ var _ = Describe("Client Session", func() { InitialSourceConnectionID: sess.handshakeDestConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } - expectClose() + expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) sess.handleTransportParameters(params) - Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected original_destination_connection_id to equal deadbeef, is decafbad"))) + Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: "expected original_destination_connection_id to equal deadbeef, is decafbad", + }))) }) }) diff --git a/streams_map.go b/streams_map.go index 4ad53a13..79c1ee91 100644 --- a/streams_map.go +++ b/streams_map.go @@ -209,7 +209,10 @@ func (m *streamsMap) DeleteStream(id protocol.StreamID) error { func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { str, err := m.getOrOpenReceiveStream(id) if err != nil { - return nil, qerr.NewError(qerr.StreamStateError, err.Error()) + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } } return str, nil } @@ -240,7 +243,10 @@ func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStream func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { str, err := m.getOrOpenSendStream(id) if err != nil { - return nil, qerr.NewError(qerr.StreamStateError, err.Error()) + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } } return str, nil } diff --git a/streams_map_incoming_bidi.go b/streams_map_incoming_bidi.go index 16319be0..46c8c73a 100644 --- a/streams_map_incoming_bidi.go +++ b/streams_map_incoming_bidi.go @@ -145,7 +145,7 @@ func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ - message: "Tried to delete unknown incoming stream %d", + message: "tried to delete unknown incoming stream %d", nums: []protocol.StreamNum{num}, } } @@ -156,7 +156,7 @@ func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { entry, ok := m.streams[num] if ok && entry.shouldDelete { return streamError{ - message: "Tried to delete incoming stream %d multiple times", + message: "tried to delete incoming stream %d multiple times", nums: []protocol.StreamNum{num}, } } diff --git a/streams_map_incoming_generic.go b/streams_map_incoming_generic.go index 52076732..4c7696a0 100644 --- a/streams_map_incoming_generic.go +++ b/streams_map_incoming_generic.go @@ -143,7 +143,7 @@ func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error { func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ - message: "Tried to delete unknown incoming stream %d", + message: "tried to delete unknown incoming stream %d", nums: []protocol.StreamNum{num}, } } @@ -154,7 +154,7 @@ func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { entry, ok := m.streams[num] if ok && entry.shouldDelete { return streamError{ - message: "Tried to delete incoming stream %d multiple times", + message: "tried to delete incoming stream %d multiple times", nums: []protocol.StreamNum{num}, } } diff --git a/streams_map_incoming_generic_test.go b/streams_map_incoming_generic_test.go index 6d83a6b2..d84aa047 100644 --- a/streams_map_incoming_generic_test.go +++ b/streams_map_incoming_generic_test.go @@ -213,7 +213,7 @@ var _ = Describe("Streams Map (incoming)", func() { It("errors when deleting a non-existing stream", func() { err := m.DeleteStream(1337) Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown incoming stream 1337")) + Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown incoming stream 1337")) }) It("sends MAX_STREAMS frames when streams are deleted", func() { diff --git a/streams_map_incoming_uni.go b/streams_map_incoming_uni.go index 0dff6f2c..5bddec00 100644 --- a/streams_map_incoming_uni.go +++ b/streams_map_incoming_uni.go @@ -145,7 +145,7 @@ func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ - message: "Tried to delete unknown incoming stream %d", + message: "tried to delete unknown incoming stream %d", nums: []protocol.StreamNum{num}, } } @@ -156,7 +156,7 @@ func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { entry, ok := m.streams[num] if ok && entry.shouldDelete { return streamError{ - message: "Tried to delete incoming stream %d multiple times", + message: "tried to delete incoming stream %d multiple times", nums: []protocol.StreamNum{num}, } } diff --git a/streams_map_outgoing_bidi.go b/streams_map_outgoing_bidi.go index c14ab981..3f7ec166 100644 --- a/streams_map_outgoing_bidi.go +++ b/streams_map_outgoing_bidi.go @@ -157,7 +157,7 @@ func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ - message: "Tried to delete unknown outgoing stream %d", + message: "tried to delete unknown outgoing stream %d", nums: []protocol.StreamNum{num}, } } diff --git a/streams_map_outgoing_generic.go b/streams_map_outgoing_generic.go index 2fe89936..dde75043 100644 --- a/streams_map_outgoing_generic.go +++ b/streams_map_outgoing_generic.go @@ -155,7 +155,7 @@ func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ - message: "Tried to delete unknown outgoing stream %d", + message: "tried to delete unknown outgoing stream %d", nums: []protocol.StreamNum{num}, } } diff --git a/streams_map_outgoing_generic_test.go b/streams_map_outgoing_generic_test.go index fb07dd99..421fb4ae 100644 --- a/streams_map_outgoing_generic_test.go +++ b/streams_map_outgoing_generic_test.go @@ -88,7 +88,7 @@ var _ = Describe("Streams Map (outgoing)", func() { It("errors when deleting a non-existing stream", func() { err := m.DeleteStream(1337) Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown outgoing stream 1337")) + Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1337")) }) It("errors when deleting a stream twice", func() { @@ -97,7 +97,7 @@ var _ = Describe("Streams Map (outgoing)", func() { Expect(m.DeleteStream(1)).To(Succeed()) err = m.DeleteStream(1) Expect(err).To(HaveOccurred()) - Expect(err.(streamError).TestError()).To(MatchError("Tried to delete unknown outgoing stream 1")) + Expect(err.(streamError).TestError()).To(MatchError("tried to delete unknown outgoing stream 1")) }) It("closes all streams when CloseWithError is called", func() { diff --git a/streams_map_outgoing_uni.go b/streams_map_outgoing_uni.go index 9177c143..8782364a 100644 --- a/streams_map_outgoing_uni.go +++ b/streams_map_outgoing_uni.go @@ -157,7 +157,7 @@ func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ - message: "Tried to delete unknown outgoing stream %d", + message: "tried to delete unknown outgoing stream %d", nums: []protocol.StreamNum{num}, } } diff --git a/streams_map_test.go b/streams_map_test.go index 8ef91fc2..ffce136b 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -7,9 +7,11 @@ import ( "net" "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/mocks" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -210,22 +212,22 @@ var _ = Describe("Streams Map", func() { It("errors when deleting unknown incoming unidirectional streams", func() { id := ids.firstIncomingUniStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("Tried to delete unknown incoming stream %d", id))) + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) }) It("errors when deleting unknown outgoing unidirectional streams", func() { id := ids.firstOutgoingUniStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("Tried to delete unknown outgoing stream %d", id))) + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) }) It("errors when deleting unknown incoming bidirectional streams", func() { id := ids.firstIncomingBidiStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("Tried to delete unknown incoming stream %d", id))) + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown incoming stream %d", id))) }) It("errors when deleting unknown outgoing bidirectional streams", func() { id := ids.firstOutgoingBidiStream + 4 - Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("Tried to delete unknown outgoing stream %d", id))) + Expect(m.DeleteStream(id)).To(MatchError(fmt.Sprintf("tried to delete unknown outgoing stream %d", id))) }) }) @@ -248,7 +250,10 @@ var _ = Describe("Streams Map", func() { It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { id := ids.firstOutgoingBidiStream + 5*4 _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id))) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + })) }) It("gets an outgoing unidirectional stream", func() { @@ -264,7 +269,10 @@ var _ = Describe("Streams Map", func() { It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { id := ids.firstOutgoingUniStream + 5*4 _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id))) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + })) }) It("gets an incoming bidirectional stream", func() { @@ -277,7 +285,10 @@ var _ = Describe("Streams Map", func() { It("errors when trying to get an incoming unidirectional stream", func() { id := ids.firstIncomingUniStream _, err := m.GetOrOpenSendStream(id) - Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open send stream %d", id))) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open send stream %d", id), + })) }) }) @@ -295,7 +306,10 @@ var _ = Describe("Streams Map", func() { It("errors when the peer tries to open a higher outgoing bidirectional stream", func() { id := ids.firstOutgoingBidiStream + 5*4 _, err := m.GetOrOpenReceiveStream(id) - Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open stream %d", id))) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id), + })) }) It("gets an incoming bidirectional stream", func() { @@ -315,7 +329,10 @@ var _ = Describe("Streams Map", func() { It("errors when trying to get an outgoing unidirectional stream", func() { id := ids.firstOutgoingUniStream _, err := m.GetOrOpenReceiveStream(id) - Expect(err).To(MatchError(fmt.Sprintf("STREAM_STATE_ERROR: peer attempted to open receive stream %d", id))) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: fmt.Sprintf("peer attempted to open receive stream %d", id), + })) }) }) })