From 072a602cc1ec55191ccf5c39afb1b352f93f15b1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 Jun 2023 10:08:58 +0300 Subject: [PATCH] pass around receivedPacket as struct instead of as pointer (#3823) --- closed_conn.go | 4 +-- closed_conn_test.go | 2 +- connection.go | 24 +++++++++--------- connection_test.go | 38 ++++++++++++++--------------- mock_packet_handler_test.go | 2 +- mock_quic_conn_test.go | 2 +- mock_unknown_packet_handler_test.go | 2 +- packet_handler_map.go | 4 +-- packet_handler_map_test.go | 4 +-- server.go | 38 +++++++++++++++-------------- server_test.go | 24 +++++++++--------- sys_conn.go | 6 ++--- sys_conn_oob.go | 8 +++--- sys_conn_oob_test.go | 16 ++++++------ transport.go | 10 ++++---- transport_test.go | 6 ++--- 16 files changed, 96 insertions(+), 94 deletions(-) diff --git a/closed_conn.go b/closed_conn.go index 73904b84..901bb8ae 100644 --- a/closed_conn.go +++ b/closed_conn.go @@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe } } -func (c *closedLocalConn) handlePacket(p *receivedPacket) { +func (c *closedLocalConn) handlePacket(p receivedPacket) { c.counter++ // exponential backoff // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving @@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler { return &closedRemoteConn{perspective: pers} } -func (s *closedRemoteConn) handlePacket(*receivedPacket) {} +func (s *closedRemoteConn) handlePacket(receivedPacket) {} func (s *closedRemoteConn) shutdown() {} func (s *closedRemoteConn) destroy(error) {} func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/closed_conn_test.go b/closed_conn_test.go index 2e32b70a..21fddab4 100644 --- a/closed_conn_test.go +++ b/closed_conn_test.go @@ -27,7 +27,7 @@ var _ = Describe("Closed local connection", func() { ) addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337} for i := 1; i <= 20; i++ { - conn.handlePacket(&receivedPacket{remoteAddr: addr}) + conn.handlePacket(receivedPacket{remoteAddr: addr}) if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE } else { diff --git a/connection.go b/connection.go index 51cd5fb7..d94c4bf6 100644 --- a/connection.go +++ b/connection.go @@ -168,7 +168,7 @@ type connection struct { oneRTTStream cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler - receivedPackets chan *receivedPacket + receivedPackets chan receivedPacket sendingScheduled chan struct{} closeOnce sync.Once @@ -180,8 +180,8 @@ type connection struct { handshakeCtx context.Context handshakeCtxCancel context.CancelFunc - undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level - undecryptablePacketsToProcess []*receivedPacket + undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level + undecryptablePacketsToProcess []receivedPacket clientHelloWritten <-chan *wire.TransportParameters earlyConnReadyChan chan struct{} @@ -509,7 +509,7 @@ func (s *connection) preSetup() { s.perspective, ) s.framer = newFramer(s.streamsMap) - s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) + s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) @@ -806,7 +806,7 @@ func (s *connection) handleHandshakeConfirmed() { } } -func (s *connection) handlePacketImpl(rp *receivedPacket) bool { +func (s *connection) handlePacketImpl(rp receivedPacket) bool { s.sentPacketHandler.ReceivedBytes(rp.Size()) if wire.IsVersionNegotiationPacket(rp.data) { @@ -822,7 +822,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { for len(data) > 0 { var destConnID protocol.ConnectionID if counter > 0 { - p = p.Clone() + p = *(p.Clone()) p.data = data var err error @@ -895,7 +895,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { return processed } -func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { +func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protocol.ConnectionID) bool { var wasQueued bool defer func() { @@ -946,7 +946,7 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID proto return true } -func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { +func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -1003,7 +1003,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) return true } -func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { +func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { switch err { case handshake.ErrKeysDropped: if s.tracer != nil { @@ -1105,7 +1105,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa return true } -func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { +func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets if s.tracer != nil { @@ -1340,7 +1340,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel } // handlePacket is called by the server with a new packet -func (s *connection) handlePacket(p *receivedPacket) { +func (s *connection) handlePacket(p receivedPacket) { // Discard packets once the amount of queued packets is larger than // the channel size, protocol.MaxConnUnprocessedPackets select { @@ -2230,7 +2230,7 @@ func (s *connection) scheduleSending() { // tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. // The logging.PacketType is only used for logging purposes. -func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, pt logging.PacketType) { +func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) { if s.handshakeComplete { panic("shouldn't queue undecryptable packets after handshake completion") } diff --git a/connection_test.go b/connection_test.go index 50812d0b..c20284e3 100644 --- a/connection_test.go +++ b/connection_test.go @@ -592,7 +592,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().Close(), ) // don't EXPECT any calls to packer.PackPacket() - conn.handlePacket(&receivedPacket{ + conn.handlePacket(receivedPacket{ rcvTime: time.Now(), remoteAddr: &net.UDPAddr{}, buffer: getPacketBuffer(), @@ -654,20 +654,20 @@ var _ = Describe("Connection", func() { conn.unpacker = unpacker }) - getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket { + getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket { b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne) Expect(err).ToNot(HaveOccurred()) - return &receivedPacket{ + return receivedPacket{ data: append(b, data...), buffer: getPacketBuffer(), rcvTime: time.Now(), } } - getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) receivedPacket { b, err := extHdr.Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) - return &receivedPacket{ + return receivedPacket{ data: append(b, data...), buffer: getPacketBuffer(), rcvTime: time.Now(), @@ -693,7 +693,7 @@ var _ = Describe("Connection", func() { conn.config.Versions, ) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) - Expect(conn.handlePacketImpl(&receivedPacket{ + Expect(conn.handlePacketImpl(receivedPacket{ data: b, buffer: getPacketBuffer(), })).To(BeFalse()) @@ -1036,7 +1036,7 @@ var _ = Describe("Connection", func() { packet := getLongHeaderPacket(hdr, nil) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size()) Expect(conn.handlePacketImpl(packet)).To(BeFalse()) - Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet})) + Expect(conn.undecryptablePackets).To(Equal([]receivedPacket{packet})) }) Context("updating the remote address", func() { @@ -1053,7 +1053,7 @@ var _ = Describe("Connection", func() { BeforeEach(func() { tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) }) - getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, receivedPacket) { hdr := &wire.ExtendedHeader{ Header: wire.Header{ Type: protocol.PacketTypeHandshake, @@ -1612,7 +1612,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() sph.EXPECT().SentPacket(gomock.Any()).Do(func(*ackhandler.Packet) { sph.EXPECT().ReceivedBytes(gomock.Any()) - conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) + conn.handlePacket(receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, []byte("packet10")) @@ -2316,7 +2316,7 @@ var _ = Describe("Connection", func() { }) // Nothing here should block for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets+1; i++ { - conn.handlePacket(&receivedPacket{data: []byte("foobar")}) + conn.handlePacket(receivedPacket{data: []byte("foobar")}) } Eventually(done).Should(BeClosed()) }) @@ -2398,10 +2398,10 @@ var _ = Describe("Client Connection", func() { srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) - getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { + getPacket := func(hdr *wire.ExtendedHeader, data []byte) receivedPacket { b, err := hdr.Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) - return &receivedPacket{ + return receivedPacket{ data: append(b, data...), buffer: getPacketBuffer(), } @@ -2519,7 +2519,7 @@ var _ = Describe("Client Connection", func() { SrcConnectionID: destConnID, } tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(conn.handleLongHeaderPacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) + Expect(conn.handleLongHeaderPacket(receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) It("handles HANDSHAKE_DONE frames", func() { @@ -2580,13 +2580,13 @@ var _ = Describe("Client Connection", func() { }) Context("handling Version Negotiation", func() { - getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { + getVNP := func(versions ...protocol.VersionNumber) receivedPacket { b := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), protocol.ArbitraryLenConnectionID(destConnID.Bytes()), versions, ) - return &receivedPacket{ + return receivedPacket{ data: b, buffer: getPacketBuffer(), } @@ -2892,18 +2892,18 @@ var _ = Describe("Client Connection", func() { Context("handling potentially injected packets", func() { var unpacker *MockUnpacker - getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { + getPacket := func(extHdr *wire.ExtendedHeader, data []byte) receivedPacket { b, err := extHdr.Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) - return &receivedPacket{ + return receivedPacket{ data: append(b, data...), buffer: getPacketBuffer(), } } // Convert an already packed raw packet into a receivedPacket - wrapPacket := func(packet []byte) *receivedPacket { - return &receivedPacket{ + wrapPacket := func(packet []byte) receivedPacket { + return receivedPacket{ data: packet, buffer: getPacketBuffer(), } diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index aabc1760..529d1b84 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -61,7 +61,7 @@ func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call { } // handlePacket mocks base method. -func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { +func (m *MockPacketHandler) handlePacket(arg0 receivedPacket) { m.ctrl.T.Helper() m.ctrl.Call(m, "handlePacket", arg0) } diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index c1d867c3..bebc1c27 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -309,7 +309,7 @@ func (mr *MockQUICConnMockRecorder) getPerspective() *gomock.Call { } // handlePacket mocks base method. -func (m *MockQUICConn) handlePacket(arg0 *receivedPacket) { +func (m *MockQUICConn) handlePacket(arg0 receivedPacket) { m.ctrl.T.Helper() m.ctrl.Call(m, "handlePacket", arg0) } diff --git a/mock_unknown_packet_handler_test.go b/mock_unknown_packet_handler_test.go index f8d63ef0..f7489782 100644 --- a/mock_unknown_packet_handler_test.go +++ b/mock_unknown_packet_handler_test.go @@ -34,7 +34,7 @@ func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorde } // handlePacket mocks base method. -func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { +func (m *MockUnknownPacketHandler) handlePacket(arg0 receivedPacket) { m.ctrl.T.Helper() m.ctrl.Call(m, "handlePacket", arg0) } diff --git a/packet_handler_map.go b/packet_handler_map.go index 823c6836..47f0dcc2 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -25,7 +25,7 @@ type connCapabilities struct { // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { - ReadPacket() (*receivedPacket, error) + ReadPacket() (receivedPacket, error) // The size parameter is used for GSO. // If GSO is not support, len(b) must be equal to size. WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) @@ -43,7 +43,7 @@ type closePacket struct { } type unknownPacketHandler interface { - handlePacket(*receivedPacket) + handlePacket(receivedPacket) setCloseError(error) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 2969bb5b..24cef871 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -129,7 +129,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(ok).To(BeTrue()) Expect(h).ToNot(Equal(handler)) addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - h.handlePacket(&receivedPacket{remoteAddr: addr}) + h.handlePacket(receivedPacket{remoteAddr: addr}) Expect(closePackets).To(HaveLen(1)) Expect(closePackets[0].addr).To(Equal(addr)) Expect(closePackets[0].payload).To(Equal([]byte("foobar"))) @@ -152,7 +152,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(ok).To(BeTrue()) Expect(h).ToNot(Equal(handler)) addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - h.handlePacket(&receivedPacket{remoteAddr: addr}) + h.handlePacket(receivedPacket{remoteAddr: addr}) Expect(closePackets).To(BeEmpty()) time.Sleep(dur) diff --git a/server.go b/server.go index 352a83d3..8337a978 100644 --- a/server.go +++ b/server.go @@ -24,7 +24,7 @@ var ErrServerClosed = errors.New("quic: server closed") // packetHandler handles packets type packetHandler interface { - handlePacket(*receivedPacket) + handlePacket(receivedPacket) shutdown() destroy(error) getPerspective() protocol.Perspective @@ -42,7 +42,7 @@ type packetHandlerManager interface { type quicConn interface { EarlyConnection earlyConnReady() <-chan struct{} - handlePacket(*receivedPacket) + handlePacket(receivedPacket) GetVersion() protocol.VersionNumber getPerspective() protocol.Perspective run() error @@ -51,7 +51,7 @@ type quicConn interface { } type zeroRTTQueue struct { - packets []*receivedPacket + packets []receivedPacket expiration time.Time } @@ -72,7 +72,7 @@ type baseServer struct { connHandler packetHandlerManager onClose func() - receivedPackets chan *receivedPacket + receivedPackets chan receivedPacket nextZeroRTTCleanup time.Time zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true @@ -102,8 +102,8 @@ type baseServer struct { errorChan chan struct{} closed bool running chan struct{} // closed as soon as run() returns - versionNegotiationQueue chan *receivedPacket - invalidTokenQueue chan *receivedPacket + versionNegotiationQueue chan receivedPacket + invalidTokenQueue chan receivedPacket connQueue chan quicConn connQueueLen int32 // to be used as an atomic @@ -242,9 +242,9 @@ func newServer( connQueue: make(chan quicConn), errorChan: make(chan struct{}), running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - versionNegotiationQueue: make(chan *receivedPacket, 4), - invalidTokenQueue: make(chan *receivedPacket, 4), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan receivedPacket, 4), newConn: newConnection, tracer: tracer, logger: utils.DefaultLogger.WithPrefix("server"), @@ -345,7 +345,7 @@ func (s *baseServer) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *baseServer) handlePacket(p *receivedPacket) { +func (s *baseServer) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: default: @@ -356,7 +356,7 @@ func (s *baseServer) handlePacket(p *receivedPacket) { } } -func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { +func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ { if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { defer s.cleanupZeroRTTQueues(p.rcvTime) } @@ -446,7 +446,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s return true } -func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { +func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { if s.tracer != nil { @@ -478,7 +478,7 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { } return false } - queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} + queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)} queue.packets[0] = p expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) queue.expiration = expiration @@ -534,7 +534,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { return true } -func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { +func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() if s.tracer != nil { @@ -746,7 +746,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack return err } -func (s *baseServer) enqueueInvalidToken(p *receivedPacket) { +func (s *baseServer) enqueueInvalidToken(p receivedPacket) { select { case s.invalidTokenQueue <- p: default: @@ -755,7 +755,7 @@ func (s *baseServer) enqueueInvalidToken(p *receivedPacket) { } } -func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) { +func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { defer p.buffer.Release() hdr, _, _, err := wire.ParsePacket(p.data) @@ -772,6 +772,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) { sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) data := p.data[:hdr.ParsedLen()+hdr.Length] extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) + // Only send INVALID_TOKEN if we can unprotect the packet. + // This makes sure that we won't send it for packets that were corrupted. if err != nil { if s.tracer != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) @@ -843,7 +845,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han return err } -func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferInUse bool) { +func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) { select { case s.versionNegotiationQueue <- p: return true @@ -853,7 +855,7 @@ func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferI return false } -func (s *baseServer) maybeSendVersionNegotiationPacket(p *receivedPacket) { +func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { defer p.buffer.Release() v, err := wire.ParseVersion(p.data) diff --git a/server_test.go b/server_test.go index bf0adfbb..2ba39cf5 100644 --- a/server_test.go +++ b/server_test.go @@ -31,7 +31,7 @@ var _ = Describe("Server", func() { tlsConf *tls.Config ) - getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { + getPacket := func(hdr *wire.Header, p []byte) receivedPacket { buf := getPacketBuffer() hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 var err error @@ -48,14 +48,14 @@ var _ = Describe("Server", func() { _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) data = data[:len(data)+16] sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) - return &receivedPacket{ + return receivedPacket{ remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, data: data, buffer: buf, } } - getInitial := func(destConnID protocol.ConnectionID) *receivedPacket { + getInitial := func(destConnID protocol.ConnectionID) receivedPacket { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} hdr := &wire.Header{ Type: protocol.PacketTypeInitial, @@ -69,7 +69,7 @@ var _ = Describe("Server", func() { return p } - getInitialWithRandomDestConnID := func() *receivedPacket { + getInitialWithRandomDestConnID := func() receivedPacket { b := make([]byte, 10) _, err := rand.Read(b) Expect(err).ToNot(HaveOccurred()) @@ -236,7 +236,7 @@ var _ = Describe("Server", func() { conn := NewMockPacketHandler(mockCtrl) phm.EXPECT().Get(connID).Return(conn, true) handled := make(chan struct{}) - conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) }) + conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) serv.handlePacket(p) Eventually(handled).Should(BeClosed()) }) @@ -385,7 +385,7 @@ var _ = Describe("Server", func() { tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) - serv.handlePacket(&receivedPacket{ + serv.handlePacket(receivedPacket{ remoteAddr: raddr, data: data, buffer: getPacketBuffer(), @@ -1040,7 +1040,7 @@ var _ = Describe("Server", func() { return ok }) serv.handleInitialImpl( - &receivedPacket{buffer: getPacketBuffer()}, + receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) @@ -1065,7 +1065,7 @@ var _ = Describe("Server", func() { return len(b), nil }) serv.handleInitialImpl( - &receivedPacket{buffer: getPacketBuffer()}, + receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, ) Eventually(done).Should(BeClosed()) @@ -1116,7 +1116,7 @@ var _ = Describe("Server", func() { return ok }) serv.handleInitialImpl( - &receivedPacket{buffer: getPacketBuffer()}, + receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) @@ -1189,7 +1189,7 @@ var _ = Describe("Server", func() { return ok }) serv.baseServer.handleInitialImpl( - &receivedPacket{buffer: getPacketBuffer()}, + receivedPacket{buffer: getPacketBuffer()}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) @@ -1352,7 +1352,7 @@ var _ = Describe("Server", func() { conn := NewMockPacketHandler(mockCtrl) phm.EXPECT().Get(connID).Return(conn, true) handled := make(chan struct{}) - conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) }) + conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) }) serv.handlePacket(p) Eventually(handled).Should(BeClosed()) }) @@ -1360,7 +1360,7 @@ var _ = Describe("Server", func() { It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - var zeroRTTPackets []*receivedPacket + var zeroRTTPackets []receivedPacket for i := 0; i < protocol.Max0RTTQueueLen; i++ { p := getPacket(&wire.Header{ diff --git a/sys_conn.go b/sys_conn.go index f7feabae..463a4564 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -79,16 +79,16 @@ type basicConn struct { var _ rawConn = &basicConn{} -func (c *basicConn) ReadPacket() (*receivedPacket, error) { +func (c *basicConn) ReadPacket() (receivedPacket, error) { buffer := getPacketBuffer() // The packet size should not exceed protocol.MaxPacketBufferSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] n, addr, err := c.PacketConn.ReadFrom(buffer.Data) if err != nil { - return nil, err + return receivedPacket{}, err } - return &receivedPacket{ + return receivedPacket{ remoteAddr: addr, rcvTime: time.Now(), data: buffer.Data[:n], diff --git a/sys_conn_oob.go b/sys_conn_oob.go index 1ce6a95a..5e6213b1 100644 --- a/sys_conn_oob.go +++ b/sys_conn_oob.go @@ -148,7 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { return oobConn, nil } -func (c *oobConn) ReadPacket() (*receivedPacket, error) { +func (c *oobConn) ReadPacket() (receivedPacket, error) { if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. c.messages = c.messages[:batchSize] // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call @@ -162,7 +162,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { n, err := c.batchConn.ReadBatch(c.messages, 0) if n == 0 || err != nil { - return nil, err + return receivedPacket{}, err } c.messages = c.messages[:n] } @@ -178,7 +178,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { for len(data) > 0 { hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) if err != nil { - return nil, err + return receivedPacket{}, err } if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { @@ -228,7 +228,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { ifIndex: ifIndex, } } - return &receivedPacket{ + return receivedPacket{ remoteAddr: msg.Addr, rcvTime: time.Now(), data: msg.Buffers[0][:msg.N], diff --git a/sys_conn_oob_test.go b/sys_conn_oob_test.go index 3e92bf32..57df623c 100644 --- a/sys_conn_oob_test.go +++ b/sys_conn_oob_test.go @@ -19,7 +19,7 @@ import ( ) var _ = Describe("OOB Conn Test", func() { - runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) { + runServer := func(network, address string) (*net.UDPConn, <-chan receivedPacket) { addr, err := net.ResolveUDPAddr(network, address) Expect(err).ToNot(HaveOccurred()) udpConn, err := net.ListenUDP(network, addr) @@ -28,7 +28,7 @@ var _ = Describe("OOB Conn Test", func() { Expect(err).ToNot(HaveOccurred()) Expect(oobConn.capabilities().DF).To(BeTrue()) - packetChan := make(chan *receivedPacket) + packetChan := make(chan receivedPacket) go func() { defer GinkgoRecover() for { @@ -69,7 +69,7 @@ var _ = Describe("OOB Conn Test", func() { }, ) - var p *receivedPacket + var p receivedPacket Eventually(packetChan).Should(Receive(&p)) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.data).To(Equal([]byte("foobar"))) @@ -89,7 +89,7 @@ var _ = Describe("OOB Conn Test", func() { }, ) - var p *receivedPacket + var p receivedPacket Eventually(packetChan).Should(Receive(&p)) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.data).To(Equal([]byte("foobar"))) @@ -111,7 +111,7 @@ var _ = Describe("OOB Conn Test", func() { }, ) - var p *receivedPacket + var p receivedPacket Eventually(packetChan).Should(Receive(&p)) Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) Expect(p.ecn).To(Equal(protocol.ECNCE)) @@ -149,7 +149,7 @@ var _ = Describe("OOB Conn Test", func() { addr.IP = ip sentFrom := sendPacket("udp4", addr) - var p *receivedPacket + var p receivedPacket Eventually(packetChan).Should(Receive(&p)) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.data).To(Equal([]byte("foobar"))) @@ -167,7 +167,7 @@ var _ = Describe("OOB Conn Test", func() { addr.IP = ip sentFrom := sendPacket("udp6", addr) - var p *receivedPacket + var p receivedPacket Eventually(packetChan).Should(Receive(&p)) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.data).To(Equal([]byte("foobar"))) @@ -185,7 +185,7 @@ var _ = Describe("OOB Conn Test", func() { ip4 := net.ParseIP("127.0.0.1").To4() sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port}) - var p *receivedPacket + var p receivedPacket Eventually(packetChan).Should(Receive(&p)) Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) Expect(p.info).To(Not(BeNil())) diff --git a/transport.go b/transport.go index b6d402b4..cc3d294f 100644 --- a/transport.go +++ b/transport.go @@ -74,7 +74,7 @@ type Transport struct { conn rawConn closeQueue chan closePacket - statelessResetQueue chan *receivedPacket + statelessResetQueue chan receivedPacket listening chan struct{} // is closed when listen returns closed bool @@ -197,7 +197,7 @@ func (t *Transport) init(isServer bool) error { t.listening = make(chan struct{}) t.closeQueue = make(chan closePacket, 4) - t.statelessResetQueue = make(chan *receivedPacket, 4) + t.statelessResetQueue = make(chan receivedPacket, 4) if t.ConnectionIDGenerator != nil { t.connIDGenerator = t.ConnectionIDGenerator @@ -339,7 +339,7 @@ func (t *Transport) listen(conn rawConn) { } } -func (t *Transport) handlePacket(p *receivedPacket) { +func (t *Transport) handlePacket(p receivedPacket) { connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) @@ -371,7 +371,7 @@ func (t *Transport) handlePacket(p *receivedPacket) { t.server.handlePacket(p) } -func (t *Transport) maybeSendStatelessReset(p *receivedPacket) { +func (t *Transport) maybeSendStatelessReset(p receivedPacket) { if t.StatelessResetKey == nil { p.buffer.Release() return @@ -392,7 +392,7 @@ func (t *Transport) maybeSendStatelessReset(p *receivedPacket) { } } -func (t *Transport) sendStatelessReset(p *receivedPacket) { +func (t *Transport) sendStatelessReset(p receivedPacket) { defer p.buffer.Release() connID, err := wire.ParseConnectionID(p.data, t.connIDLen) diff --git a/transport_test.go b/transport_test.go index f60db3b2..9426f350 100644 --- a/transport_test.go +++ b/transport_test.go @@ -70,7 +70,7 @@ var _ = Describe("Transport", func() { handled := make(chan struct{}, 2) phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { h := NewMockPacketHandler(mockCtrl) - h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { defer GinkgoRecover() connID, err := wire.ParseConnectionID(p.data, 0) Expect(err).ToNot(HaveOccurred()) @@ -81,7 +81,7 @@ var _ = Describe("Transport", func() { }) phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { h := NewMockPacketHandler(mockCtrl) - h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { defer GinkgoRecover() connID, err := wire.ParseConnectionID(p.data, 0) Expect(err).ToNot(HaveOccurred()) @@ -205,7 +205,7 @@ var _ = Describe("Transport", func() { gomock.InOrder( phm.EXPECT().GetByResetToken(token), phm.EXPECT().Get(connID).Return(conn, true), - conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) { Expect(p.data).To(Equal(b)) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) }),