put back packet buffers after processing a packet

This introduces a reference counter in the packet buffer, which will be
used to process coalesced packets.
This commit is contained in:
Marten Seemann 2018-12-26 20:57:30 +06:30
parent 413844d0bc
commit 767dbdd545
9 changed files with 135 additions and 71 deletions

View file

@ -6,22 +6,42 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
var bufferPool sync.Pool type packetBuffer struct {
Slice []byte
func getPacketBuffer() *[]byte { // refCount counts how many packets the Slice is used in.
return bufferPool.Get().(*[]byte) // It doesn't support concurrent use.
// It is > 1 when used for coalesced packet.
refCount int
} }
func putPacketBuffer(buf *[]byte) { var bufferPool sync.Pool
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize]
return buf
}
func putPacketBuffer(buf *packetBuffer) {
if cap(buf.Slice) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!") panic("putPacketBuffer called with packet of wrong size!")
} }
bufferPool.Put(buf) buf.refCount--
if buf.refCount < 0 {
panic("negative packetBuffer refCount")
}
// only put the packetBuffer back if it's not used any more
if buf.refCount == 0 {
bufferPool.Put(buf)
}
} }
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() interface{} {
b := make([]byte, 0, protocol.MaxReceivePacketSize) return &packetBuffer{
return &b Slice: make([]byte, 0, protocol.MaxReceivePacketSize),
}
} }
} }

View file

@ -9,13 +9,24 @@ import (
var _ = Describe("Buffer Pool", func() { var _ = Describe("Buffer Pool", func() {
It("returns buffers of cap", func() { It("returns buffers of cap", func() {
buf := *getPacketBuffer() buf := getPacketBuffer()
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize))) Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize)))
})
It("puts buffers back", func() {
buf := getPacketBuffer()
putPacketBuffer(buf)
}) })
It("panics if wrong-sized buffers are passed", func() { It("panics if wrong-sized buffers are passed", func() {
Expect(func() { buf := getPacketBuffer()
putPacketBuffer(&[]byte{0}) buf.Slice = make([]byte, 10)
}).To(Panic()) Expect(func() { putPacketBuffer(buf) }).To(Panic())
})
It("panics if it is put pack twice", func() {
buf := getPacketBuffer()
putPacketBuffer(buf)
Expect(func() { putPacketBuffer(buf) }).To(Panic())
}) })
}) })

View file

@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error {
func (h *packetHandlerMap) listen() { func (h *packetHandlerMap) listen() {
for { for {
data := *getPacketBuffer() buffer := getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize] data := buffer.Slice
// The packet size should not exceed protocol.MaxReceivePacketSize bytes // The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data) n, addr, err := h.conn.ReadFrom(data)
@ -155,13 +155,17 @@ func (h *packetHandlerMap) listen() {
} }
data = data[:n] data = data[:n]
if err := h.handlePacket(addr, data); err != nil { if err := h.handlePacket(addr, buffer, data); err != nil {
h.logger.Debugf("error handling packet from %s: %s", addr, err) h.logger.Debugf("error handling packet from %s: %s", addr, err)
} }
} }
} }
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { func (h *packetHandlerMap) handlePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) error {
r := bytes.NewReader(data) r := bytes.NewReader(data)
hdr, err := wire.ParseHeader(r, h.connIDLen) hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header // drop the packet if we can't parse the header
@ -172,8 +176,9 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
p := &receivedPacket{ p := &receivedPacket{
remoteAddr: addr, remoteAddr: addr,
hdr: hdr, hdr: hdr,
data: data,
rcvTime: time.Now(), rcvTime: time.Now(),
data: data,
buffer: buffer,
} }
h.mutex.RLock() h.mutex.RLock()

View file

@ -81,7 +81,7 @@ var _ = Describe("Packet Handler Map", func() {
}) })
It("drops unparseable packets", func() { It("drops unparseable packets", func() {
err := handler.handlePacket(nil, []byte{0, 1, 2, 3}) err := handler.handlePacket(nil, nil, []byte{0, 1, 2, 3})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error parsing header:")) Expect(err.Error()).To(ContainSubstring("error parsing header:"))
}) })
@ -91,7 +91,7 @@ var _ = Describe("Packet Handler Map", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Remove(connID) handler.Remove(connID)
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
}) })
It("deletes retired session entries after a wait time", func() { It("deletes retired session entries after a wait time", func() {
@ -100,7 +100,7 @@ var _ = Describe("Packet Handler Map", func() {
handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Add(connID, NewMockPacketHandler(mockCtrl))
handler.Retire(connID) handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond)) time.Sleep(scaleDuration(30 * time.Millisecond))
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
}) })
It("passes packets arriving late for closed sessions to that session", func() { It("passes packets arriving late for closed sessions to that session", func() {
@ -110,13 +110,13 @@ var _ = Describe("Packet Handler Map", func() {
packetHandler.EXPECT().handlePacket(gomock.Any()) packetHandler.EXPECT().handlePacket(gomock.Any())
handler.Add(connID, packetHandler) handler.Add(connID, packetHandler)
handler.Retire(connID) handler.Retire(connID)
err := handler.handlePacket(nil, getPacket(connID)) err := handler.handlePacket(nil, nil, getPacket(connID))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("drops packets for unknown receivers", func() { It("drops packets for unknown receivers", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
err := handler.handlePacket(nil, getPacket(connID)) err := handler.handlePacket(nil, nil, getPacket(connID))
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
}) })
@ -171,10 +171,10 @@ var _ = Describe("Packet Handler Map", func() {
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
handler.Retire(connID) handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond)) time.Sleep(scaleDuration(30 * time.Millisecond))
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42"))
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...) packet = append(packet, token[:]...)
Expect(handler.handlePacket(nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99")) Expect(handler.handlePacket(nil, nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99"))
Expect(handler.resetTokens).To(BeEmpty()) Expect(handler.resetTokens).To(BeEmpty())
}) })
}) })
@ -188,7 +188,7 @@ var _ = Describe("Packet Handler Map", func() {
Expect(p.hdr.DestConnectionID).To(Equal(connID)) Expect(p.hdr.DestConnectionID).To(Equal(connID))
}) })
handler.SetServer(server) handler.SetServer(server)
Expect(handler.handlePacket(nil, p)).To(Succeed()) Expect(handler.handlePacket(nil, nil, p)).To(Succeed())
}) })
It("closes all server sessions", func() { It("closes all server sessions", func() {
@ -209,7 +209,7 @@ var _ = Describe("Packet Handler Map", func() {
server := NewMockUnknownPacketHandler(mockCtrl) server := NewMockUnknownPacketHandler(mockCtrl)
handler.SetServer(server) handler.SetServer(server)
handler.CloseServer() handler.CloseServer()
Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) Expect(handler.handlePacket(nil, nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
}) })
}) })
}) })

View file

@ -28,6 +28,8 @@ type packedPacket struct {
header *wire.ExtendedHeader header *wire.ExtendedHeader
raw []byte raw []byte
frames []wire.Frame frames []wire.Frame
buffer *packetBuffer
} }
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
@ -374,8 +376,8 @@ func (p *packetPacker) writeAndSealPacket(
frames []wire.Frame, frames []wire.Frame,
sealer handshake.Sealer, sealer handshake.Sealer,
) (*packedPacket, error) { ) (*packedPacket, error) {
raw := *getPacketBuffer() packetBuffer := getPacketBuffer()
buffer := bytes.NewBuffer(raw[:0]) buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
@ -436,7 +438,7 @@ func (p *packetPacker) writeAndSealPacket(
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
} }
raw = raw[0:buffer.Len()] raw := buffer.Bytes()
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset]) _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
raw = raw[0 : buffer.Len()+sealer.Overhead()] raw = raw[0 : buffer.Len()+sealer.Overhead()]
@ -455,6 +457,7 @@ func (p *packetPacker) writeAndSealPacket(
header: header, header: header,
raw: raw, raw: raw,
frames: frames, frames: frames,
buffer: packetBuffer,
}, nil }, nil
} }

View file

@ -318,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) {
} }
if hdr.Type == protocol.PacketTypeInitial { if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p) go s.handleInitial(p)
return
} }
putPacketBuffer(p.buffer)
// TODO(#943): send Stateless Reset // TODO(#943): send Stateless Reset
} }
func (s *server) handleInitial(p *receivedPacket) { func (s *server) handleInitial(p *receivedPacket) {
// TODO: add a check that DestConnID == SrcConnID
s.logger.Debugf("<- Received Initial packet.") s.logger.Debugf("<- Received Initial packet.")
sess, connID, err := s.handleInitialImpl(p) sess, connID, err := s.handleInitialImpl(p)
if err != nil { if err != nil {
putPacketBuffer(p.buffer)
s.logger.Errorf("Error occurred handling initial packet: %s", err) s.logger.Errorf("Error occurred handling initial packet: %s", err)
return return
} }
if sess == nil { // a retry was done if sess == nil { // a retry was done
putPacketBuffer(p.buffer)
return return
} }
// Don't put the packet buffer back if a new session was created.
// The session will handle the packet and take of that.
serverSession := newServerSession(sess, s.config, s.logger) serverSession := newServerSession(sess, s.config, s.logger)
s.sessionHandler.Add(connID, serverSession) s.sessionHandler.Add(connID, serverSession)
} }
@ -455,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
} }
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
defer putPacketBuffer(p.buffer)
hdr := p.hdr hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)

View file

@ -122,19 +122,19 @@ var _ = Describe("Server", func() {
} }
It("drops Initial packets with a too short connection ID", func() { It("drops Initial packets with a too short connection ID", func() {
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
}) }))
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
}) })
It("drops too small Initial", func() { It("drops too small Initial", func() {
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -142,12 +142,12 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
}) }))
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
}) })
It("drops packets with a too short connection ID", func() { It("drops packets with a too short connection ID", func() {
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -156,19 +156,19 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}) }))
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
}) })
It("drops non-Initial packets", func() { It("drops non-Initial packets", func() {
serv.logger.SetLogLevel(utils.LogLevelDebug) serv.logger.SetLogLevel(utils.LogLevelDebug)
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
data: []byte("invalid"), data: []byte("invalid"),
}) }))
}) })
It("decodes the cookie from the Token field", func() { It("decodes the cookie from the Token field", func() {
@ -185,7 +185,7 @@ var _ = Describe("Server", func() {
} }
token, err := serv.cookieGenerator.NewToken(raddr, nil) token, err := serv.cookieGenerator.NewToken(raddr, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
remoteAddr: raddr, remoteAddr: raddr,
hdr: &wire.Header{ hdr: &wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -193,7 +193,7 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}) }))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -209,7 +209,7 @@ var _ = Describe("Server", func() {
close(done) close(done)
return false return false
} }
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
remoteAddr: raddr, remoteAddr: raddr,
hdr: &wire.Header{ hdr: &wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -217,14 +217,14 @@ var _ = Describe("Server", func() {
Version: serv.config.Versions[0], Version: serv.config.Versions[0],
}, },
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}) }))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("sends a Version Negotiation Packet for unsupported versions", func() { It("sends a Version Negotiation Packet for unsupported versions", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
serv.handlePacket(&receivedPacket{ serv.handlePacket(insertPacketBuffer(&receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
@ -233,7 +233,7 @@ var _ = Describe("Server", func() {
DestConnectionID: destConnID, DestConnectionID: destConnID,
Version: 0x42, Version: 0x42,
}, },
}) }))
var write mockPacketConnWrite var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write)) Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337")) Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
@ -253,11 +253,11 @@ var _ = Describe("Server", func() {
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
Version: protocol.VersionTLS, Version: protocol.VersionTLS,
} }
serv.handleInitial(&receivedPacket{ serv.handleInitial(insertPacketBuffer(&receivedPacket{
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
hdr: hdr, hdr: hdr,
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
}) }))
var write mockPacketConnWrite var write mockPacketConnWrite
Eventually(conn.dataWritten).Should(Receive(&write)) Eventually(conn.dataWritten).Should(Receive(&write))
Expect(write.to.String()).To(Equal("127.0.0.1:1337")) Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
@ -308,7 +308,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
serv.handlePacket(p) serv.handlePacket(insertPacketBuffer(p))
// the Handshake packet is written by the session // the Handshake packet is written by the session
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
close(done) close(done)

View file

@ -53,8 +53,10 @@ type cryptoStreamHandler interface {
type receivedPacket struct { type receivedPacket struct {
remoteAddr net.Addr remoteAddr net.Addr
hdr *wire.Header hdr *wire.Header
data []byte
rcvTime time.Time rcvTime time.Time
data []byte
buffer *packetBuffer
} }
type closeError struct { type closeError struct {
@ -368,9 +370,6 @@ runLoop:
if wasProcessed := s.handlePacketImpl(p); !wasProcessed { if wasProcessed := s.handlePacketImpl(p); !wasProcessed {
continue continue
} }
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
// TODO: putPacketBuffer(&p.extHdr.Raw)
case <-s.handshakeCompleteChan: case <-s.handshakeCompleteChan:
s.handleHandshakeComplete() s.handleHandshakeComplete()
} }
@ -475,6 +474,15 @@ func (s *session) handleHandshakeComplete() {
} }
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ { func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
putPacketBuffer(p.buffer)
}
}()
// The server can change the source connection ID with the first Handshake packet. // The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored. // After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
@ -490,6 +498,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
// if the decryption failed, this might be a packet sent by an attacker // if the decryption failed, this might be a packet sent by an attacker
if err != nil { if err != nil {
if err == handshake.ErrOpenerNotYetAvailable { if err == handshake.ErrOpenerNotYetAvailable {
wasQueued = true
s.tryQueueingUndecryptablePacket(p) s.tryQueueingUndecryptablePacket(p)
return false return false
} }
@ -953,7 +962,7 @@ func (s *session) sendPacket() (bool, error) {
} }
func (s *session) sendPackedPacket(packet *packedPacket) error { func (s *session) sendPackedPacket(packet *packedPacket) error {
defer putPacketBuffer(&packet.raw) defer putPacketBuffer(packet.buffer)
s.logPacket(packet) s.logPacket(packet)
return s.conn.Write(packet.raw) return s.conn.Write(packet.raw)
} }

View file

@ -61,6 +61,11 @@ func areSessionsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*session).run") return strings.Contains(b.String(), "quic-go.(*session).run")
} }
func insertPacketBuffer(p *receivedPacket) *receivedPacket {
p.buffer = getPacketBuffer()
return p
}
var _ = Describe("Session", func() { var _ = Describe("Session", func() {
var ( var (
sess *session sess *session
@ -496,11 +501,11 @@ var _ = Describe("Session", func() {
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
sess.receivedPacketHandler = rph sess.receivedPacketHandler = rph
Expect(sess.handlePacketImpl(&receivedPacket{ Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
rcvTime: rcvTime, rcvTime: rcvTime,
hdr: &hdr.Header, hdr: &hdr.Header,
data: getData(hdr), data: getData(hdr),
})).To(BeTrue()) }))).To(BeTrue())
}) })
It("closes when handling a packet fails", func() { It("closes when handling a packet fails", func() {
@ -518,7 +523,10 @@ var _ = Describe("Session", func() {
close(done) close(done)
}() }()
sessionRunner.EXPECT().retireConnectionID(gomock.Any()) sessionRunner.EXPECT().retireConnectionID(gomock.Any())
sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})}) sess.handlePacket(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{},
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
}))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -528,18 +536,18 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1, PacketNumberLen: protocol.PacketNumberLen1,
} }
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2)
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue()) Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue()) Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
}) })
It("ignores 0-RTT packets", func() { It("ignores 0-RTT packets", func() {
Expect(sess.handlePacketImpl(&receivedPacket{ Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketType0RTT, Type: protocol.PacketType0RTT,
DestConnectionID: sess.srcConnID, DestConnectionID: sess.srcConnID,
}, },
})).To(BeFalse()) }))).To(BeFalse())
}) })
It("ignores packets with a different source connection ID", func() { It("ignores packets with a different source connection ID", func() {
@ -552,12 +560,12 @@ var _ = Describe("Session", func() {
// Send one packet, which might change the connection ID. // Send one packet, which might change the connection ID.
// only EXPECT one call to the unpacker // only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil)
Expect(sess.handlePacketImpl(&receivedPacket{ Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: hdr, hdr: hdr,
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
})).To(BeTrue()) }))).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match. // The next packet has to be ignored, since the source connection ID doesn't match.
Expect(sess.handlePacketImpl(&receivedPacket{ Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
DestConnectionID: sess.destConnID, DestConnectionID: sess.destConnID,
@ -565,7 +573,7 @@ var _ = Describe("Session", func() {
Length: 1, Length: 1,
}, },
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
})).To(BeFalse()) }))).To(BeFalse())
}) })
Context("updating the remote address", func() { Context("updating the remote address", func() {
@ -574,12 +582,11 @@ var _ = Describe("Session", func() {
origAddr := sess.conn.(*mockConnection).remoteAddr origAddr := sess.conn.(*mockConnection).remoteAddr
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
Expect(origAddr).ToNot(Equal(remoteIP)) Expect(origAddr).ToNot(Equal(remoteIP))
p := receivedPacket{ Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
remoteAddr: remoteIP, remoteAddr: remoteIP,
hdr: &wire.Header{}, hdr: &wire.Header{},
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
} }))).To(BeTrue())
Expect(sess.handlePacketImpl(&p)).To(BeTrue())
Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr)) Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr))
}) })
}) })
@ -587,10 +594,12 @@ var _ = Describe("Session", func() {
Context("sending packets", func() { Context("sending packets", func() {
getPacket := func(pn protocol.PacketNumber) *packedPacket { getPacket := func(pn protocol.PacketNumber) *packedPacket {
data := *getPacketBuffer() buffer := getPacketBuffer()
data := buffer.Slice[:0]
data = append(data, []byte("foobar")...) data = append(data, []byte("foobar")...)
return &packedPacket{ return &packedPacket{
raw: data, raw: data,
buffer: buffer,
header: &wire.ExtendedHeader{PacketNumber: pn}, header: &wire.ExtendedHeader{PacketNumber: pn},
} }
} }
@ -963,7 +972,7 @@ var _ = Describe("Session", func() {
defer close(done) defer close(done)
return &packedPacket{ return &packedPacket{
header: &wire.ExtendedHeader{}, header: &wire.ExtendedHeader{},
raw: *getPacketBuffer(), buffer: getPacketBuffer(),
}, nil }, nil
}), }),
packer.EXPECT().PackPacket().AnyTimes(), packer.EXPECT().PackPacket().AnyTimes(),
@ -1352,7 +1361,7 @@ var _ = Describe("Client Session", func() {
}() }()
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
packer.EXPECT().ChangeDestConnectionID(newConnID) packer.EXPECT().ChangeDestConnectionID(newConnID)
Expect(sess.handlePacketImpl(&receivedPacket{ Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
hdr: &wire.Header{ hdr: &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketTypeHandshake, Type: protocol.PacketTypeHandshake,
@ -1361,7 +1370,7 @@ var _ = Describe("Client Session", func() {
Length: 1, Length: 1,
}, },
data: []byte{0}, data: []byte{0},
})).To(BeTrue()) }))).To(BeTrue())
// make sure the go routine returns // make sure the go routine returns
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
sessionRunner.EXPECT().retireConnectionID(gomock.Any()) sessionRunner.EXPECT().retireConnectionID(gomock.Any())