pass around receivedPacket as struct instead of as pointer (#3823)

This commit is contained in:
Marten Seemann 2023-06-03 10:08:58 +03:00 committed by GitHub
parent 591ab1ab5e
commit 072a602cc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 96 additions and 94 deletions

View file

@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe
} }
} }
func (c *closedLocalConn) handlePacket(p *receivedPacket) { func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.counter++ c.counter++
// exponential backoff // exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers} return &closedRemoteConn{perspective: pers}
} }
func (s *closedRemoteConn) handlePacket(*receivedPacket) {} func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {} func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {} func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }

View file

@ -27,7 +27,7 @@ var _ = Describe("Closed local connection", func() {
) )
addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337} addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337}
for i := 1; i <= 20; i++ { for i := 1; i <= 20; i++ {
conn.handlePacket(&receivedPacket{remoteAddr: addr}) conn.handlePacket(receivedPacket{remoteAddr: addr})
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE
} else { } else {

View file

@ -168,7 +168,7 @@ type connection struct {
oneRTTStream cryptoStream // only set for the server oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler
receivedPackets chan *receivedPacket receivedPackets chan receivedPacket
sendingScheduled chan struct{} sendingScheduled chan struct{}
closeOnce sync.Once closeOnce sync.Once
@ -180,8 +180,8 @@ type connection struct {
handshakeCtx context.Context handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc handshakeCtxCancel context.CancelFunc
undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []*receivedPacket undecryptablePacketsToProcess []receivedPacket
clientHelloWritten <-chan *wire.TransportParameters clientHelloWritten <-chan *wire.TransportParameters
earlyConnReadyChan chan struct{} earlyConnReadyChan chan struct{}
@ -509,7 +509,7 @@ func (s *connection) preSetup() {
s.perspective, s.perspective,
) )
s.framer = newFramer(s.streamsMap) s.framer = newFramer(s.streamsMap)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
@ -806,7 +806,7 @@ func (s *connection) handleHandshakeConfirmed() {
} }
} }
func (s *connection) handlePacketImpl(rp *receivedPacket) bool { func (s *connection) handlePacketImpl(rp receivedPacket) bool {
s.sentPacketHandler.ReceivedBytes(rp.Size()) s.sentPacketHandler.ReceivedBytes(rp.Size())
if wire.IsVersionNegotiationPacket(rp.data) { if wire.IsVersionNegotiationPacket(rp.data) {
@ -822,7 +822,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
for len(data) > 0 { for len(data) > 0 {
var destConnID protocol.ConnectionID var destConnID protocol.ConnectionID
if counter > 0 { if counter > 0 {
p = p.Clone() p = *(p.Clone())
p.data = data p.data = data
var err error var err error
@ -895,7 +895,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
return processed return processed
} }
func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protocol.ConnectionID) bool {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -946,7 +946,7 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID proto
return true return true
} }
func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -1003,7 +1003,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header)
return true return true
} }
func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) {
switch err { switch err {
case handshake.ErrKeysDropped: case handshake.ErrKeysDropped:
if s.tracer != nil { if s.tracer != nil {
@ -1105,7 +1105,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
return true return true
} }
func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if s.tracer != nil { if s.tracer != nil {
@ -1340,7 +1340,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel
} }
// handlePacket is called by the server with a new packet // handlePacket is called by the server with a new packet
func (s *connection) handlePacket(p *receivedPacket) { func (s *connection) handlePacket(p receivedPacket) {
// Discard packets once the amount of queued packets is larger than // Discard packets once the amount of queued packets is larger than
// the channel size, protocol.MaxConnUnprocessedPackets // the channel size, protocol.MaxConnUnprocessedPackets
select { select {
@ -2230,7 +2230,7 @@ func (s *connection) scheduleSending() {
// tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. // tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys.
// The logging.PacketType is only used for logging purposes. // The logging.PacketType is only used for logging purposes.
func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, pt logging.PacketType) { func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) {
if s.handshakeComplete { if s.handshakeComplete {
panic("shouldn't queue undecryptable packets after handshake completion") panic("shouldn't queue undecryptable packets after handshake completion")
} }

View file

@ -592,7 +592,7 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().Close(), tracer.EXPECT().Close(),
) )
// don't EXPECT any calls to packer.PackPacket() // don't EXPECT any calls to packer.PackPacket()
conn.handlePacket(&receivedPacket{ conn.handlePacket(receivedPacket{
rcvTime: time.Now(), rcvTime: time.Now(),
remoteAddr: &net.UDPAddr{}, remoteAddr: &net.UDPAddr{},
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
@ -654,20 +654,20 @@ var _ = Describe("Connection", func() {
conn.unpacker = unpacker conn.unpacker = unpacker
}) })
getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket { getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket {
b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne) b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return receivedPacket{
data: append(b, data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
rcvTime: time.Now(), rcvTime: time.Now(),
} }
} }
getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
b, err := extHdr.Append(nil, conn.version) b, err := extHdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return receivedPacket{
data: append(b, data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
rcvTime: time.Now(), rcvTime: time.Now(),
@ -693,7 +693,7 @@ var _ = Describe("Connection", func() {
conn.config.Versions, conn.config.Versions,
) )
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
Expect(conn.handlePacketImpl(&receivedPacket{ Expect(conn.handlePacketImpl(receivedPacket{
data: b, data: b,
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
})).To(BeFalse()) })).To(BeFalse())
@ -1036,7 +1036,7 @@ var _ = Describe("Connection", func() {
packet := getLongHeaderPacket(hdr, nil) packet := getLongHeaderPacket(hdr, nil)
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size()) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size())
Expect(conn.handlePacketImpl(packet)).To(BeFalse()) Expect(conn.handlePacketImpl(packet)).To(BeFalse())
Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet})) Expect(conn.undecryptablePackets).To(Equal([]receivedPacket{packet}))
}) })
Context("updating the remote address", func() { Context("updating the remote address", func() {
@ -1053,7 +1053,7 @@ var _ = Describe("Connection", func() {
BeforeEach(func() { BeforeEach(func() {
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
}) })
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, receivedPacket) {
hdr := &wire.ExtendedHeader{ hdr := &wire.ExtendedHeader{
Header: wire.Header{ Header: wire.Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
@ -1612,7 +1612,7 @@ var _ = Describe("Connection", func() {
sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().WouldBlock().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any()).Do(func(*ackhandler.Packet) { sph.EXPECT().SentPacket(gomock.Any()).Do(func(*ackhandler.Packet) {
sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().ReceivedBytes(gomock.Any())
conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) conn.handlePacket(receivedPacket{buffer: getPacketBuffer()})
}) })
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, []byte("packet10")) expectAppendPacket(packer, shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: 10}}, []byte("packet10"))
@ -2316,7 +2316,7 @@ var _ = Describe("Connection", func() {
}) })
// Nothing here should block // Nothing here should block
for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets+1; i++ { for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets+1; i++ {
conn.handlePacket(&receivedPacket{data: []byte("foobar")}) conn.handlePacket(receivedPacket{data: []byte("foobar")})
} }
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -2398,10 +2398,10 @@ var _ = Describe("Client Connection", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { getPacket := func(hdr *wire.ExtendedHeader, data []byte) receivedPacket {
b, err := hdr.Append(nil, conn.version) b, err := hdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return receivedPacket{
data: append(b, data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
} }
@ -2519,7 +2519,7 @@ var _ = Describe("Client Connection", func() {
SrcConnectionID: destConnID, SrcConnectionID: destConnID,
} }
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any())
Expect(conn.handleLongHeaderPacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) Expect(conn.handleLongHeaderPacket(receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue())
}) })
It("handles HANDSHAKE_DONE frames", func() { It("handles HANDSHAKE_DONE frames", func() {
@ -2580,13 +2580,13 @@ var _ = Describe("Client Connection", func() {
}) })
Context("handling Version Negotiation", func() { Context("handling Version Negotiation", func() {
getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { getVNP := func(versions ...protocol.VersionNumber) receivedPacket {
b := wire.ComposeVersionNegotiation( b := wire.ComposeVersionNegotiation(
protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), protocol.ArbitraryLenConnectionID(srcConnID.Bytes()),
protocol.ArbitraryLenConnectionID(destConnID.Bytes()), protocol.ArbitraryLenConnectionID(destConnID.Bytes()),
versions, versions,
) )
return &receivedPacket{ return receivedPacket{
data: b, data: b,
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
} }
@ -2892,18 +2892,18 @@ var _ = Describe("Client Connection", func() {
Context("handling potentially injected packets", func() { Context("handling potentially injected packets", func() {
var unpacker *MockUnpacker var unpacker *MockUnpacker
getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { getPacket := func(extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
b, err := extHdr.Append(nil, conn.version) b, err := extHdr.Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
return &receivedPacket{ return receivedPacket{
data: append(b, data...), data: append(b, data...),
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
} }
} }
// Convert an already packed raw packet into a receivedPacket // Convert an already packed raw packet into a receivedPacket
wrapPacket := func(packet []byte) *receivedPacket { wrapPacket := func(packet []byte) receivedPacket {
return &receivedPacket{ return receivedPacket{
data: packet, data: packet,
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
} }

View file

@ -61,7 +61,7 @@ func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call {
} }
// handlePacket mocks base method. // handlePacket mocks base method.
func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { func (m *MockPacketHandler) handlePacket(arg0 receivedPacket) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "handlePacket", arg0) m.ctrl.Call(m, "handlePacket", arg0)
} }

View file

@ -309,7 +309,7 @@ func (mr *MockQUICConnMockRecorder) getPerspective() *gomock.Call {
} }
// handlePacket mocks base method. // handlePacket mocks base method.
func (m *MockQUICConn) handlePacket(arg0 *receivedPacket) { func (m *MockQUICConn) handlePacket(arg0 receivedPacket) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "handlePacket", arg0) m.ctrl.Call(m, "handlePacket", arg0)
} }

View file

@ -34,7 +34,7 @@ func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorde
} }
// handlePacket mocks base method. // handlePacket mocks base method.
func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) { func (m *MockUnknownPacketHandler) handlePacket(arg0 receivedPacket) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "handlePacket", arg0) m.ctrl.Call(m, "handlePacket", arg0)
} }

View file

@ -25,7 +25,7 @@ type connCapabilities struct {
// rawConn is a connection that allow reading of a receivedPackeh. // rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface { type rawConn interface {
ReadPacket() (*receivedPacket, error) ReadPacket() (receivedPacket, error)
// The size parameter is used for GSO. // The size parameter is used for GSO.
// If GSO is not support, len(b) must be equal to size. // If GSO is not support, len(b) must be equal to size.
WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error)
@ -43,7 +43,7 @@ type closePacket struct {
} }
type unknownPacketHandler interface { type unknownPacketHandler interface {
handlePacket(*receivedPacket) handlePacket(receivedPacket)
setCloseError(error) setCloseError(error)
} }

View file

@ -129,7 +129,7 @@ var _ = Describe("Packet Handler Map", func() {
Expect(ok).To(BeTrue()) Expect(ok).To(BeTrue())
Expect(h).ToNot(Equal(handler)) Expect(h).ToNot(Equal(handler))
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(&receivedPacket{remoteAddr: addr}) h.handlePacket(receivedPacket{remoteAddr: addr})
Expect(closePackets).To(HaveLen(1)) Expect(closePackets).To(HaveLen(1))
Expect(closePackets[0].addr).To(Equal(addr)) Expect(closePackets[0].addr).To(Equal(addr))
Expect(closePackets[0].payload).To(Equal([]byte("foobar"))) Expect(closePackets[0].payload).To(Equal([]byte("foobar")))
@ -152,7 +152,7 @@ var _ = Describe("Packet Handler Map", func() {
Expect(ok).To(BeTrue()) Expect(ok).To(BeTrue())
Expect(h).ToNot(Equal(handler)) Expect(h).ToNot(Equal(handler))
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
h.handlePacket(&receivedPacket{remoteAddr: addr}) h.handlePacket(receivedPacket{remoteAddr: addr})
Expect(closePackets).To(BeEmpty()) Expect(closePackets).To(BeEmpty())
time.Sleep(dur) time.Sleep(dur)

View file

@ -24,7 +24,7 @@ var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(*receivedPacket) handlePacket(receivedPacket)
shutdown() shutdown()
destroy(error) destroy(error)
getPerspective() protocol.Perspective getPerspective() protocol.Perspective
@ -42,7 +42,7 @@ type packetHandlerManager interface {
type quicConn interface { type quicConn interface {
EarlyConnection EarlyConnection
earlyConnReady() <-chan struct{} earlyConnReady() <-chan struct{}
handlePacket(*receivedPacket) handlePacket(receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
getPerspective() protocol.Perspective getPerspective() protocol.Perspective
run() error run() error
@ -51,7 +51,7 @@ type quicConn interface {
} }
type zeroRTTQueue struct { type zeroRTTQueue struct {
packets []*receivedPacket packets []receivedPacket
expiration time.Time expiration time.Time
} }
@ -72,7 +72,7 @@ type baseServer struct {
connHandler packetHandlerManager connHandler packetHandlerManager
onClose func() onClose func()
receivedPackets chan *receivedPacket receivedPackets chan receivedPacket
nextZeroRTTCleanup time.Time nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
@ -102,8 +102,8 @@ type baseServer struct {
errorChan chan struct{} errorChan chan struct{}
closed bool closed bool
running chan struct{} // closed as soon as run() returns running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan *receivedPacket versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan *receivedPacket invalidTokenQueue chan receivedPacket
connQueue chan quicConn connQueue chan quicConn
connQueueLen int32 // to be used as an atomic connQueueLen int32 // to be used as an atomic
@ -242,9 +242,9 @@ func newServer(
connQueue: make(chan quicConn), connQueue: make(chan quicConn),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
running: make(chan struct{}), running: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan *receivedPacket, 4), versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan *receivedPacket, 4), invalidTokenQueue: make(chan receivedPacket, 4),
newConn: newConnection, newConn: newConnection,
tracer: tracer, tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"), logger: utils.DefaultLogger.WithPrefix("server"),
@ -345,7 +345,7 @@ func (s *baseServer) Addr() net.Addr {
return s.conn.LocalAddr() return s.conn.LocalAddr()
} }
func (s *baseServer) handlePacket(p *receivedPacket) { func (s *baseServer) handlePacket(p receivedPacket) {
select { select {
case s.receivedPackets <- p: case s.receivedPackets <- p:
default: default:
@ -356,7 +356,7 @@ func (s *baseServer) handlePacket(p *receivedPacket) {
} }
} }
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ {
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
defer s.cleanupZeroRTTQueues(p.rcvTime) defer s.cleanupZeroRTTQueues(p.rcvTime)
} }
@ -446,7 +446,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return true return true
} }
func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0) connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil {
@ -478,7 +478,7 @@ func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool {
} }
return false return false
} }
queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)}
queue.packets[0] = p queue.packets[0] = p
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
queue.expiration = expiration queue.expiration = expiration
@ -534,7 +534,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
return true return true
} }
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release() p.buffer.Release()
if s.tracer != nil { if s.tracer != nil {
@ -746,7 +746,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
return err return err
} }
func (s *baseServer) enqueueInvalidToken(p *receivedPacket) { func (s *baseServer) enqueueInvalidToken(p receivedPacket) {
select { select {
case s.invalidTokenQueue <- p: case s.invalidTokenQueue <- p:
default: default:
@ -755,7 +755,7 @@ func (s *baseServer) enqueueInvalidToken(p *receivedPacket) {
} }
} }
func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) { func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
defer p.buffer.Release() defer p.buffer.Release()
hdr, _, _, err := wire.ParsePacket(p.data) hdr, _, _, err := wire.ParsePacket(p.data)
@ -772,6 +772,8 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) {
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length] data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
@ -843,7 +845,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
return err return err
} }
func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferInUse bool) { func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) {
select { select {
case s.versionNegotiationQueue <- p: case s.versionNegotiationQueue <- p:
return true return true
@ -853,7 +855,7 @@ func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferI
return false return false
} }
func (s *baseServer) maybeSendVersionNegotiationPacket(p *receivedPacket) { func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
defer p.buffer.Release() defer p.buffer.Release()
v, err := wire.ParseVersion(p.data) v, err := wire.ParseVersion(p.data)

View file

@ -31,7 +31,7 @@ var _ = Describe("Server", func() {
tlsConf *tls.Config tlsConf *tls.Config
) )
getPacket := func(hdr *wire.Header, p []byte) *receivedPacket { getPacket := func(hdr *wire.Header, p []byte) receivedPacket {
buf := getPacketBuffer() buf := getPacketBuffer()
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16 hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
var err error var err error
@ -48,14 +48,14 @@ var _ = Describe("Server", func() {
_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n]) _ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
data = data[:len(data)+16] data = data[:len(data)+16]
sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n]) sealer.EncryptHeader(data[n:n+16], &data[0], data[n-4:n])
return &receivedPacket{ return receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456}, remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
data: data, data: data,
buffer: buf, buffer: buf,
} }
} }
getInitial := func(destConnID protocol.ConnectionID) *receivedPacket { getInitial := func(destConnID protocol.ConnectionID) receivedPacket {
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{ hdr := &wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -69,7 +69,7 @@ var _ = Describe("Server", func() {
return p return p
} }
getInitialWithRandomDestConnID := func() *receivedPacket { getInitialWithRandomDestConnID := func() receivedPacket {
b := make([]byte, 10) b := make([]byte, 10)
_, err := rand.Read(b) _, err := rand.Read(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -236,7 +236,7 @@ var _ = Describe("Server", func() {
conn := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
phm.EXPECT().Get(connID).Return(conn, true) phm.EXPECT().Get(connID).Return(conn, true)
handled := make(chan struct{}) handled := make(chan struct{})
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) }) conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
serv.handlePacket(p) serv.handlePacket(p)
Eventually(handled).Should(BeClosed()) Eventually(handled).Should(BeClosed())
}) })
@ -385,7 +385,7 @@ var _ = Describe("Server", func() {
tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(done) close(done)
}) })
serv.handlePacket(&receivedPacket{ serv.handlePacket(receivedPacket{
remoteAddr: raddr, remoteAddr: raddr,
data: data, data: data,
buffer: getPacketBuffer(), buffer: getPacketBuffer(),
@ -1040,7 +1040,7 @@ var _ = Describe("Server", func() {
return ok return ok
}) })
serv.handleInitialImpl( serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
) )
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
@ -1065,7 +1065,7 @@ var _ = Describe("Server", func() {
return len(b), nil return len(b), nil
}) })
serv.handleInitialImpl( serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: protocol.Version1},
) )
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -1116,7 +1116,7 @@ var _ = Describe("Server", func() {
return ok return ok
}) })
serv.handleInitialImpl( serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
) )
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
@ -1189,7 +1189,7 @@ var _ = Describe("Server", func() {
return ok return ok
}) })
serv.baseServer.handleInitialImpl( serv.baseServer.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()}, receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
) )
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
@ -1352,7 +1352,7 @@ var _ = Describe("Server", func() {
conn := NewMockPacketHandler(mockCtrl) conn := NewMockPacketHandler(mockCtrl)
phm.EXPECT().Get(connID).Return(conn, true) phm.EXPECT().Get(connID).Return(conn, true)
handled := make(chan struct{}) handled := make(chan struct{})
conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) }) conn.EXPECT().handlePacket(p).Do(func(receivedPacket) { close(handled) })
serv.handlePacket(p) serv.handlePacket(p)
Eventually(handled).Should(BeClosed()) Eventually(handled).Should(BeClosed())
}) })
@ -1360,7 +1360,7 @@ var _ = Describe("Server", func() {
It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { It("queues 0-RTT packets, up to Max0RTTQueueSize", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
var zeroRTTPackets []*receivedPacket var zeroRTTPackets []receivedPacket
for i := 0; i < protocol.Max0RTTQueueLen; i++ { for i := 0; i < protocol.Max0RTTQueueLen; i++ {
p := getPacket(&wire.Header{ p := getPacket(&wire.Header{

View file

@ -79,16 +79,16 @@ type basicConn struct {
var _ rawConn = &basicConn{} var _ rawConn = &basicConn{}
func (c *basicConn) ReadPacket() (*receivedPacket, error) { func (c *basicConn) ReadPacket() (receivedPacket, error) {
buffer := getPacketBuffer() buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes // The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
n, addr, err := c.PacketConn.ReadFrom(buffer.Data) n, addr, err := c.PacketConn.ReadFrom(buffer.Data)
if err != nil { if err != nil {
return nil, err return receivedPacket{}, err
} }
return &receivedPacket{ return receivedPacket{
remoteAddr: addr, remoteAddr: addr,
rcvTime: time.Now(), rcvTime: time.Now(),
data: buffer.Data[:n], data: buffer.Data[:n],

View file

@ -148,7 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
return oobConn, nil return oobConn, nil
} }
func (c *oobConn) ReadPacket() (*receivedPacket, error) { func (c *oobConn) ReadPacket() (receivedPacket, error) {
if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
c.messages = c.messages[:batchSize] c.messages = c.messages[:batchSize]
// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
@ -162,7 +162,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
n, err := c.batchConn.ReadBatch(c.messages, 0) n, err := c.batchConn.ReadBatch(c.messages, 0)
if n == 0 || err != nil { if n == 0 || err != nil {
return nil, err return receivedPacket{}, err
} }
c.messages = c.messages[:n] c.messages = c.messages[:n]
} }
@ -178,7 +178,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
for len(data) > 0 { for len(data) > 0 {
hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
if err != nil { if err != nil {
return nil, err return receivedPacket{}, err
} }
if hdr.Level == unix.IPPROTO_IP { if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type { switch hdr.Type {
@ -228,7 +228,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
ifIndex: ifIndex, ifIndex: ifIndex,
} }
} }
return &receivedPacket{ return receivedPacket{
remoteAddr: msg.Addr, remoteAddr: msg.Addr,
rcvTime: time.Now(), rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N], data: msg.Buffers[0][:msg.N],

View file

@ -19,7 +19,7 @@ import (
) )
var _ = Describe("OOB Conn Test", func() { var _ = Describe("OOB Conn Test", func() {
runServer := func(network, address string) (*net.UDPConn, <-chan *receivedPacket) { runServer := func(network, address string) (*net.UDPConn, <-chan receivedPacket) {
addr, err := net.ResolveUDPAddr(network, address) addr, err := net.ResolveUDPAddr(network, address)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP(network, addr) udpConn, err := net.ListenUDP(network, addr)
@ -28,7 +28,7 @@ var _ = Describe("OOB Conn Test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(oobConn.capabilities().DF).To(BeTrue()) Expect(oobConn.capabilities().DF).To(BeTrue())
packetChan := make(chan *receivedPacket) packetChan := make(chan receivedPacket)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
for { for {
@ -69,7 +69,7 @@ var _ = Describe("OOB Conn Test", func() {
}, },
) )
var p *receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.data).To(Equal([]byte("foobar")))
@ -89,7 +89,7 @@ var _ = Describe("OOB Conn Test", func() {
}, },
) )
var p *receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.data).To(Equal([]byte("foobar")))
@ -111,7 +111,7 @@ var _ = Describe("OOB Conn Test", func() {
}, },
) )
var p *receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
Expect(p.ecn).To(Equal(protocol.ECNCE)) Expect(p.ecn).To(Equal(protocol.ECNCE))
@ -149,7 +149,7 @@ var _ = Describe("OOB Conn Test", func() {
addr.IP = ip addr.IP = ip
sentFrom := sendPacket("udp4", addr) sentFrom := sendPacket("udp4", addr)
var p *receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.data).To(Equal([]byte("foobar")))
@ -167,7 +167,7 @@ var _ = Describe("OOB Conn Test", func() {
addr.IP = ip addr.IP = ip
sentFrom := sendPacket("udp6", addr) sentFrom := sendPacket("udp6", addr)
var p *receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond))) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
Expect(p.data).To(Equal([]byte("foobar"))) Expect(p.data).To(Equal([]byte("foobar")))
@ -185,7 +185,7 @@ var _ = Describe("OOB Conn Test", func() {
ip4 := net.ParseIP("127.0.0.1").To4() ip4 := net.ParseIP("127.0.0.1").To4()
sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port}) sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port})
var p *receivedPacket var p receivedPacket
Eventually(packetChan).Should(Receive(&p)) Eventually(packetChan).Should(Receive(&p))
Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue()) Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
Expect(p.info).To(Not(BeNil())) Expect(p.info).To(Not(BeNil()))

View file

@ -74,7 +74,7 @@ type Transport struct {
conn rawConn conn rawConn
closeQueue chan closePacket closeQueue chan closePacket
statelessResetQueue chan *receivedPacket statelessResetQueue chan receivedPacket
listening chan struct{} // is closed when listen returns listening chan struct{} // is closed when listen returns
closed bool closed bool
@ -197,7 +197,7 @@ func (t *Transport) init(isServer bool) error {
t.listening = make(chan struct{}) t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4) t.closeQueue = make(chan closePacket, 4)
t.statelessResetQueue = make(chan *receivedPacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4)
if t.ConnectionIDGenerator != nil { if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator t.connIDGenerator = t.ConnectionIDGenerator
@ -339,7 +339,7 @@ func (t *Transport) listen(conn rawConn) {
} }
} }
func (t *Transport) handlePacket(p *receivedPacket) { func (t *Transport) handlePacket(p receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, t.connIDLen) connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil { if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
@ -371,7 +371,7 @@ func (t *Transport) handlePacket(p *receivedPacket) {
t.server.handlePacket(p) t.server.handlePacket(p)
} }
func (t *Transport) maybeSendStatelessReset(p *receivedPacket) { func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
if t.StatelessResetKey == nil { if t.StatelessResetKey == nil {
p.buffer.Release() p.buffer.Release()
return return
@ -392,7 +392,7 @@ func (t *Transport) maybeSendStatelessReset(p *receivedPacket) {
} }
} }
func (t *Transport) sendStatelessReset(p *receivedPacket) { func (t *Transport) sendStatelessReset(p receivedPacket) {
defer p.buffer.Release() defer p.buffer.Release()
connID, err := wire.ParseConnectionID(p.data, t.connIDLen) connID, err := wire.ParseConnectionID(p.data, t.connIDLen)

View file

@ -70,7 +70,7 @@ var _ = Describe("Transport", func() {
handled := make(chan struct{}, 2) handled := make(chan struct{}, 2)
phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { phm.EXPECT().Get(connID1).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
h := NewMockPacketHandler(mockCtrl) h := NewMockPacketHandler(mockCtrl)
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
defer GinkgoRecover() defer GinkgoRecover()
connID, err := wire.ParseConnectionID(p.data, 0) connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -81,7 +81,7 @@ var _ = Describe("Transport", func() {
}) })
phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) { phm.EXPECT().Get(connID2).DoAndReturn(func(protocol.ConnectionID) (packetHandler, bool) {
h := NewMockPacketHandler(mockCtrl) h := NewMockPacketHandler(mockCtrl)
h.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { h.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
defer GinkgoRecover() defer GinkgoRecover()
connID, err := wire.ParseConnectionID(p.data, 0) connID, err := wire.ParseConnectionID(p.data, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -205,7 +205,7 @@ var _ = Describe("Transport", func() {
gomock.InOrder( gomock.InOrder(
phm.EXPECT().GetByResetToken(token), phm.EXPECT().GetByResetToken(token),
phm.EXPECT().Get(connID).Return(conn, true), phm.EXPECT().Get(connID).Return(conn, true),
conn.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { conn.EXPECT().handlePacket(gomock.Any()).Do(func(p receivedPacket) {
Expect(p.data).To(Equal(b)) Expect(p.data).To(Equal(b))
Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second)) Expect(p.rcvTime).To(BeTemporally("~", time.Now(), time.Second))
}), }),