mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
put back packet buffers after processing a packet
This introduces a reference counter in the packet buffer, which will be used to process coalesced packets.
This commit is contained in:
parent
413844d0bc
commit
767dbdd545
9 changed files with 135 additions and 71 deletions
|
@ -6,22 +6,42 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool sync.Pool
|
type packetBuffer struct {
|
||||||
|
Slice []byte
|
||||||
|
|
||||||
func getPacketBuffer() *[]byte {
|
// refCount counts how many packets the Slice is used in.
|
||||||
return bufferPool.Get().(*[]byte)
|
// It doesn't support concurrent use.
|
||||||
|
// It is > 1 when used for coalesced packet.
|
||||||
|
refCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func putPacketBuffer(buf *[]byte) {
|
var bufferPool sync.Pool
|
||||||
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
|
|
||||||
|
func getPacketBuffer() *packetBuffer {
|
||||||
|
buf := bufferPool.Get().(*packetBuffer)
|
||||||
|
buf.refCount = 1
|
||||||
|
buf.Slice = buf.Slice[:protocol.MaxReceivePacketSize]
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func putPacketBuffer(buf *packetBuffer) {
|
||||||
|
if cap(buf.Slice) != int(protocol.MaxReceivePacketSize) {
|
||||||
panic("putPacketBuffer called with packet of wrong size!")
|
panic("putPacketBuffer called with packet of wrong size!")
|
||||||
}
|
}
|
||||||
bufferPool.Put(buf)
|
buf.refCount--
|
||||||
|
if buf.refCount < 0 {
|
||||||
|
panic("negative packetBuffer refCount")
|
||||||
|
}
|
||||||
|
// only put the packetBuffer back if it's not used any more
|
||||||
|
if buf.refCount == 0 {
|
||||||
|
bufferPool.Put(buf)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
bufferPool.New = func() interface{} {
|
bufferPool.New = func() interface{} {
|
||||||
b := make([]byte, 0, protocol.MaxReceivePacketSize)
|
return &packetBuffer{
|
||||||
return &b
|
Slice: make([]byte, 0, protocol.MaxReceivePacketSize),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,13 +9,24 @@ import (
|
||||||
|
|
||||||
var _ = Describe("Buffer Pool", func() {
|
var _ = Describe("Buffer Pool", func() {
|
||||||
It("returns buffers of cap", func() {
|
It("returns buffers of cap", func() {
|
||||||
buf := *getPacketBuffer()
|
buf := getPacketBuffer()
|
||||||
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
|
Expect(buf.Slice).To(HaveCap(int(protocol.MaxReceivePacketSize)))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("puts buffers back", func() {
|
||||||
|
buf := getPacketBuffer()
|
||||||
|
putPacketBuffer(buf)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("panics if wrong-sized buffers are passed", func() {
|
It("panics if wrong-sized buffers are passed", func() {
|
||||||
Expect(func() {
|
buf := getPacketBuffer()
|
||||||
putPacketBuffer(&[]byte{0})
|
buf.Slice = make([]byte, 10)
|
||||||
}).To(Panic())
|
Expect(func() { putPacketBuffer(buf) }).To(Panic())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("panics if it is put pack twice", func() {
|
||||||
|
buf := getPacketBuffer()
|
||||||
|
putPacketBuffer(buf)
|
||||||
|
Expect(func() { putPacketBuffer(buf) }).To(Panic())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error {
|
||||||
|
|
||||||
func (h *packetHandlerMap) listen() {
|
func (h *packetHandlerMap) listen() {
|
||||||
for {
|
for {
|
||||||
data := *getPacketBuffer()
|
buffer := getPacketBuffer()
|
||||||
data = data[:protocol.MaxReceivePacketSize]
|
data := buffer.Slice
|
||||||
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||||||
// If it does, we only read a truncated packet, which will then end up undecryptable
|
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||||||
n, addr, err := h.conn.ReadFrom(data)
|
n, addr, err := h.conn.ReadFrom(data)
|
||||||
|
@ -155,13 +155,17 @@ func (h *packetHandlerMap) listen() {
|
||||||
}
|
}
|
||||||
data = data[:n]
|
data = data[:n]
|
||||||
|
|
||||||
if err := h.handlePacket(addr, data); err != nil {
|
if err := h.handlePacket(addr, buffer, data); err != nil {
|
||||||
h.logger.Debugf("error handling packet from %s: %s", addr, err)
|
h.logger.Debugf("error handling packet from %s: %s", addr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
func (h *packetHandlerMap) handlePacket(
|
||||||
|
addr net.Addr,
|
||||||
|
buffer *packetBuffer,
|
||||||
|
data []byte,
|
||||||
|
) error {
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
hdr, err := wire.ParseHeader(r, h.connIDLen)
|
hdr, err := wire.ParseHeader(r, h.connIDLen)
|
||||||
// drop the packet if we can't parse the header
|
// drop the packet if we can't parse the header
|
||||||
|
@ -172,8 +176,9 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
||||||
p := &receivedPacket{
|
p := &receivedPacket{
|
||||||
remoteAddr: addr,
|
remoteAddr: addr,
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
data: data,
|
|
||||||
rcvTime: time.Now(),
|
rcvTime: time.Now(),
|
||||||
|
data: data,
|
||||||
|
buffer: buffer,
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
|
|
|
@ -81,7 +81,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops unparseable packets", func() {
|
It("drops unparseable packets", func() {
|
||||||
err := handler.handlePacket(nil, []byte{0, 1, 2, 3})
|
err := handler.handlePacket(nil, nil, []byte{0, 1, 2, 3})
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(err.Error()).To(ContainSubstring("error parsing header:"))
|
Expect(err.Error()).To(ContainSubstring("error parsing header:"))
|
||||||
})
|
})
|
||||||
|
@ -91,7 +91,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||||
handler.Remove(connID)
|
handler.Remove(connID)
|
||||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("deletes retired session entries after a wait time", func() {
|
It("deletes retired session entries after a wait time", func() {
|
||||||
|
@ -100,7 +100,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
handler.Add(connID, NewMockPacketHandler(mockCtrl))
|
||||||
handler.Retire(connID)
|
handler.Retire(connID)
|
||||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("passes packets arriving late for closed sessions to that session", func() {
|
It("passes packets arriving late for closed sessions to that session", func() {
|
||||||
|
@ -110,13 +110,13 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
packetHandler.EXPECT().handlePacket(gomock.Any())
|
packetHandler.EXPECT().handlePacket(gomock.Any())
|
||||||
handler.Add(connID, packetHandler)
|
handler.Add(connID, packetHandler)
|
||||||
handler.Retire(connID)
|
handler.Retire(connID)
|
||||||
err := handler.handlePacket(nil, getPacket(connID))
|
err := handler.handlePacket(nil, nil, getPacket(connID))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets for unknown receivers", func() {
|
It("drops packets for unknown receivers", func() {
|
||||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
err := handler.handlePacket(nil, getPacket(connID))
|
err := handler.handlePacket(nil, nil, getPacket(connID))
|
||||||
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -171,10 +171,10 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
|
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
|
||||||
handler.Retire(connID)
|
handler.Retire(connID)
|
||||||
time.Sleep(scaleDuration(30 * time.Millisecond))
|
time.Sleep(scaleDuration(30 * time.Millisecond))
|
||||||
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42"))
|
Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42"))
|
||||||
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)
|
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)
|
||||||
packet = append(packet, token[:]...)
|
packet = append(packet, token[:]...)
|
||||||
Expect(handler.handlePacket(nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99"))
|
Expect(handler.handlePacket(nil, nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99"))
|
||||||
Expect(handler.resetTokens).To(BeEmpty())
|
Expect(handler.resetTokens).To(BeEmpty())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -188,7 +188,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
Expect(p.hdr.DestConnectionID).To(Equal(connID))
|
||||||
})
|
})
|
||||||
handler.SetServer(server)
|
handler.SetServer(server)
|
||||||
Expect(handler.handlePacket(nil, p)).To(Succeed())
|
Expect(handler.handlePacket(nil, nil, p)).To(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes all server sessions", func() {
|
It("closes all server sessions", func() {
|
||||||
|
@ -209,7 +209,7 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
server := NewMockUnknownPacketHandler(mockCtrl)
|
server := NewMockUnknownPacketHandler(mockCtrl)
|
||||||
handler.SetServer(server)
|
handler.SetServer(server)
|
||||||
handler.CloseServer()
|
handler.CloseServer()
|
||||||
Expect(handler.handlePacket(nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
|
Expect(handler.handlePacket(nil, nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -28,6 +28,8 @@ type packedPacket struct {
|
||||||
header *wire.ExtendedHeader
|
header *wire.ExtendedHeader
|
||||||
raw []byte
|
raw []byte
|
||||||
frames []wire.Frame
|
frames []wire.Frame
|
||||||
|
|
||||||
|
buffer *packetBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
|
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
|
||||||
|
@ -374,8 +376,8 @@ func (p *packetPacker) writeAndSealPacket(
|
||||||
frames []wire.Frame,
|
frames []wire.Frame,
|
||||||
sealer handshake.Sealer,
|
sealer handshake.Sealer,
|
||||||
) (*packedPacket, error) {
|
) (*packedPacket, error) {
|
||||||
raw := *getPacketBuffer()
|
packetBuffer := getPacketBuffer()
|
||||||
buffer := bytes.NewBuffer(raw[:0])
|
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
|
||||||
|
|
||||||
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
|
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
|
||||||
|
|
||||||
|
@ -436,7 +438,7 @@ func (p *packetPacker) writeAndSealPacket(
|
||||||
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
|
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
raw = raw[0:buffer.Len()]
|
raw := buffer.Bytes()
|
||||||
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
|
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
|
||||||
raw = raw[0 : buffer.Len()+sealer.Overhead()]
|
raw = raw[0 : buffer.Len()+sealer.Overhead()]
|
||||||
|
|
||||||
|
@ -455,6 +457,7 @@ func (p *packetPacker) writeAndSealPacket(
|
||||||
header: header,
|
header: header,
|
||||||
raw: raw,
|
raw: raw,
|
||||||
frames: frames,
|
frames: frames,
|
||||||
|
buffer: packetBuffer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -318,21 +318,27 @@ func (s *server) handlePacket(p *receivedPacket) {
|
||||||
}
|
}
|
||||||
if hdr.Type == protocol.PacketTypeInitial {
|
if hdr.Type == protocol.PacketTypeInitial {
|
||||||
go s.handleInitial(p)
|
go s.handleInitial(p)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
putPacketBuffer(p.buffer)
|
||||||
// TODO(#943): send Stateless Reset
|
// TODO(#943): send Stateless Reset
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handleInitial(p *receivedPacket) {
|
func (s *server) handleInitial(p *receivedPacket) {
|
||||||
// TODO: add a check that DestConnID == SrcConnID
|
|
||||||
s.logger.Debugf("<- Received Initial packet.")
|
s.logger.Debugf("<- Received Initial packet.")
|
||||||
sess, connID, err := s.handleInitialImpl(p)
|
sess, connID, err := s.handleInitialImpl(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
putPacketBuffer(p.buffer)
|
||||||
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
s.logger.Errorf("Error occurred handling initial packet: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if sess == nil { // a retry was done
|
if sess == nil { // a retry was done
|
||||||
|
putPacketBuffer(p.buffer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Don't put the packet buffer back if a new session was created.
|
||||||
|
// The session will handle the packet and take of that.
|
||||||
serverSession := newServerSession(sess, s.config, s.logger)
|
serverSession := newServerSession(sess, s.config, s.logger)
|
||||||
s.sessionHandler.Add(connID, serverSession)
|
s.sessionHandler.Add(connID, serverSession)
|
||||||
}
|
}
|
||||||
|
@ -455,6 +461,7 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
||||||
|
defer putPacketBuffer(p.buffer)
|
||||||
hdr := p.hdr
|
hdr := p.hdr
|
||||||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
||||||
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
||||||
|
|
|
@ -122,19 +122,19 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
It("drops Initial packets with a too short connection ID", func() {
|
It("drops Initial packets with a too short connection ID", func() {
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
})
|
}))
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops too small Initial", func() {
|
It("drops too small Initial", func() {
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -142,12 +142,12 @@ var _ = Describe("Server", func() {
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100),
|
||||||
})
|
}))
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops packets with a too short connection ID", func() {
|
It("drops packets with a too short connection ID", func() {
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -156,19 +156,19 @@ var _ = Describe("Server", func() {
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
})
|
}))
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("drops non-Initial packets", func() {
|
It("drops non-Initial packets", func() {
|
||||||
serv.logger.SetLogLevel(utils.LogLevelDebug)
|
serv.logger.SetLogLevel(utils.LogLevelDebug)
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
data: []byte("invalid"),
|
data: []byte("invalid"),
|
||||||
})
|
}))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("decodes the cookie from the Token field", func() {
|
It("decodes the cookie from the Token field", func() {
|
||||||
|
@ -185,7 +185,7 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
token, err := serv.cookieGenerator.NewToken(raddr, nil)
|
token, err := serv.cookieGenerator.NewToken(raddr, nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
remoteAddr: raddr,
|
remoteAddr: raddr,
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -193,7 +193,7 @@ var _ = Describe("Server", func() {
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
})
|
}))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -209,7 +209,7 @@ var _ = Describe("Server", func() {
|
||||||
close(done)
|
close(done)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
remoteAddr: raddr,
|
remoteAddr: raddr,
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -217,14 +217,14 @@ var _ = Describe("Server", func() {
|
||||||
Version: serv.config.Versions[0],
|
Version: serv.config.Versions[0],
|
||||||
},
|
},
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
})
|
}))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
It("sends a Version Negotiation Packet for unsupported versions", func() {
|
||||||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5}
|
||||||
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
|
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6}
|
||||||
serv.handlePacket(&receivedPacket{
|
serv.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
|
@ -233,7 +233,7 @@ var _ = Describe("Server", func() {
|
||||||
DestConnectionID: destConnID,
|
DestConnectionID: destConnID,
|
||||||
Version: 0x42,
|
Version: 0x42,
|
||||||
},
|
},
|
||||||
})
|
}))
|
||||||
var write mockPacketConnWrite
|
var write mockPacketConnWrite
|
||||||
Eventually(conn.dataWritten).Should(Receive(&write))
|
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||||
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||||
|
@ -253,11 +253,11 @@ var _ = Describe("Server", func() {
|
||||||
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
Version: protocol.VersionTLS,
|
Version: protocol.VersionTLS,
|
||||||
}
|
}
|
||||||
serv.handleInitial(&receivedPacket{
|
serv.handleInitial(insertPacketBuffer(&receivedPacket{
|
||||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337},
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize),
|
||||||
})
|
}))
|
||||||
var write mockPacketConnWrite
|
var write mockPacketConnWrite
|
||||||
Eventually(conn.dataWritten).Should(Receive(&write))
|
Eventually(conn.dataWritten).Should(Receive(&write))
|
||||||
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
Expect(write.to.String()).To(Equal("127.0.0.1:1337"))
|
||||||
|
@ -308,7 +308,7 @@ var _ = Describe("Server", func() {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(insertPacketBuffer(p))
|
||||||
// the Handshake packet is written by the session
|
// the Handshake packet is written by the session
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
close(done)
|
close(done)
|
||||||
|
|
19
session.go
19
session.go
|
@ -53,8 +53,10 @@ type cryptoStreamHandler interface {
|
||||||
type receivedPacket struct {
|
type receivedPacket struct {
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
hdr *wire.Header
|
hdr *wire.Header
|
||||||
data []byte
|
|
||||||
rcvTime time.Time
|
rcvTime time.Time
|
||||||
|
data []byte
|
||||||
|
|
||||||
|
buffer *packetBuffer
|
||||||
}
|
}
|
||||||
|
|
||||||
type closeError struct {
|
type closeError struct {
|
||||||
|
@ -368,9 +370,6 @@ runLoop:
|
||||||
if wasProcessed := s.handlePacketImpl(p); !wasProcessed {
|
if wasProcessed := s.handlePacketImpl(p); !wasProcessed {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// This is a bit unclean, but works properly, since the packet always
|
|
||||||
// begins with the public header and we never copy it.
|
|
||||||
// TODO: putPacketBuffer(&p.extHdr.Raw)
|
|
||||||
case <-s.handshakeCompleteChan:
|
case <-s.handshakeCompleteChan:
|
||||||
s.handleHandshakeComplete()
|
s.handleHandshakeComplete()
|
||||||
}
|
}
|
||||||
|
@ -475,6 +474,15 @@ func (s *session) handleHandshakeComplete() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
|
func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ {
|
||||||
|
var wasQueued bool
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Put back the packet buffer if the packet wasn't queued for later decryption.
|
||||||
|
if !wasQueued {
|
||||||
|
putPacketBuffer(p.buffer)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// The server can change the source connection ID with the first Handshake packet.
|
// The server can change the source connection ID with the first Handshake packet.
|
||||||
// After this, all packets with a different source connection have to be ignored.
|
// After this, all packets with a different source connection have to be ignored.
|
||||||
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
|
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||||
|
@ -490,6 +498,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc
|
||||||
// if the decryption failed, this might be a packet sent by an attacker
|
// if the decryption failed, this might be a packet sent by an attacker
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == handshake.ErrOpenerNotYetAvailable {
|
if err == handshake.ErrOpenerNotYetAvailable {
|
||||||
|
wasQueued = true
|
||||||
s.tryQueueingUndecryptablePacket(p)
|
s.tryQueueingUndecryptablePacket(p)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -953,7 +962,7 @@ func (s *session) sendPacket() (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) sendPackedPacket(packet *packedPacket) error {
|
func (s *session) sendPackedPacket(packet *packedPacket) error {
|
||||||
defer putPacketBuffer(&packet.raw)
|
defer putPacketBuffer(packet.buffer)
|
||||||
s.logPacket(packet)
|
s.logPacket(packet)
|
||||||
return s.conn.Write(packet.raw)
|
return s.conn.Write(packet.raw)
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,6 +61,11 @@ func areSessionsRunning() bool {
|
||||||
return strings.Contains(b.String(), "quic-go.(*session).run")
|
return strings.Contains(b.String(), "quic-go.(*session).run")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func insertPacketBuffer(p *receivedPacket) *receivedPacket {
|
||||||
|
p.buffer = getPacketBuffer()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("Session", func() {
|
var _ = Describe("Session", func() {
|
||||||
var (
|
var (
|
||||||
sess *session
|
sess *session
|
||||||
|
@ -496,11 +501,11 @@ var _ = Describe("Session", func() {
|
||||||
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
|
||||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
|
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), rcvTime, false)
|
||||||
sess.receivedPacketHandler = rph
|
sess.receivedPacketHandler = rph
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||||
rcvTime: rcvTime,
|
rcvTime: rcvTime,
|
||||||
hdr: &hdr.Header,
|
hdr: &hdr.Header,
|
||||||
data: getData(hdr),
|
data: getData(hdr),
|
||||||
})).To(BeTrue())
|
}))).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("closes when handling a packet fails", func() {
|
It("closes when handling a packet fails", func() {
|
||||||
|
@ -518,7 +523,10 @@ var _ = Describe("Session", func() {
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||||
sess.handlePacket(&receivedPacket{hdr: &wire.Header{}, data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1})})
|
sess.handlePacket(insertPacketBuffer(&receivedPacket{
|
||||||
|
hdr: &wire.Header{},
|
||||||
|
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||||
|
}))
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -528,18 +536,18 @@ var _ = Describe("Session", func() {
|
||||||
PacketNumberLen: protocol.PacketNumberLen1,
|
PacketNumberLen: protocol.PacketNumberLen1,
|
||||||
}
|
}
|
||||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2)
|
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: hdr}, nil).Times(2)
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue())
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)})).To(BeTrue())
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores 0-RTT packets", func() {
|
It("ignores 0-RTT packets", func() {
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketType0RTT,
|
Type: protocol.PacketType0RTT,
|
||||||
DestConnectionID: sess.srcConnID,
|
DestConnectionID: sess.srcConnID,
|
||||||
},
|
},
|
||||||
})).To(BeFalse())
|
}))).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("ignores packets with a different source connection ID", func() {
|
It("ignores packets with a different source connection ID", func() {
|
||||||
|
@ -552,12 +560,12 @@ var _ = Describe("Session", func() {
|
||||||
// Send one packet, which might change the connection ID.
|
// Send one packet, which might change the connection ID.
|
||||||
// only EXPECT one call to the unpacker
|
// only EXPECT one call to the unpacker
|
||||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil)
|
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{hdr: &wire.ExtendedHeader{Header: *hdr}}, nil)
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: hdr,
|
hdr: hdr,
|
||||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||||
})).To(BeTrue())
|
}))).To(BeTrue())
|
||||||
// The next packet has to be ignored, since the source connection ID doesn't match.
|
// The next packet has to be ignored, since the source connection ID doesn't match.
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
DestConnectionID: sess.destConnID,
|
DestConnectionID: sess.destConnID,
|
||||||
|
@ -565,7 +573,7 @@ var _ = Describe("Session", func() {
|
||||||
Length: 1,
|
Length: 1,
|
||||||
},
|
},
|
||||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||||
})).To(BeFalse())
|
}))).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("updating the remote address", func() {
|
Context("updating the remote address", func() {
|
||||||
|
@ -574,12 +582,11 @@ var _ = Describe("Session", func() {
|
||||||
origAddr := sess.conn.(*mockConnection).remoteAddr
|
origAddr := sess.conn.(*mockConnection).remoteAddr
|
||||||
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
|
remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
|
||||||
Expect(origAddr).ToNot(Equal(remoteIP))
|
Expect(origAddr).ToNot(Equal(remoteIP))
|
||||||
p := receivedPacket{
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||||
remoteAddr: remoteIP,
|
remoteAddr: remoteIP,
|
||||||
hdr: &wire.Header{},
|
hdr: &wire.Header{},
|
||||||
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}),
|
||||||
}
|
}))).To(BeTrue())
|
||||||
Expect(sess.handlePacketImpl(&p)).To(BeTrue())
|
|
||||||
Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr))
|
Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -587,10 +594,12 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
Context("sending packets", func() {
|
Context("sending packets", func() {
|
||||||
getPacket := func(pn protocol.PacketNumber) *packedPacket {
|
getPacket := func(pn protocol.PacketNumber) *packedPacket {
|
||||||
data := *getPacketBuffer()
|
buffer := getPacketBuffer()
|
||||||
|
data := buffer.Slice[:0]
|
||||||
data = append(data, []byte("foobar")...)
|
data = append(data, []byte("foobar")...)
|
||||||
return &packedPacket{
|
return &packedPacket{
|
||||||
raw: data,
|
raw: data,
|
||||||
|
buffer: buffer,
|
||||||
header: &wire.ExtendedHeader{PacketNumber: pn},
|
header: &wire.ExtendedHeader{PacketNumber: pn},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -963,7 +972,7 @@ var _ = Describe("Session", func() {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
return &packedPacket{
|
return &packedPacket{
|
||||||
header: &wire.ExtendedHeader{},
|
header: &wire.ExtendedHeader{},
|
||||||
raw: *getPacketBuffer(),
|
buffer: getPacketBuffer(),
|
||||||
}, nil
|
}, nil
|
||||||
}),
|
}),
|
||||||
packer.EXPECT().PackPacket().AnyTimes(),
|
packer.EXPECT().PackPacket().AnyTimes(),
|
||||||
|
@ -1352,7 +1361,7 @@ var _ = Describe("Client Session", func() {
|
||||||
}()
|
}()
|
||||||
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
|
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
|
||||||
packer.EXPECT().ChangeDestConnectionID(newConnID)
|
packer.EXPECT().ChangeDestConnectionID(newConnID)
|
||||||
Expect(sess.handlePacketImpl(&receivedPacket{
|
Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{
|
||||||
hdr: &wire.Header{
|
hdr: &wire.Header{
|
||||||
IsLongHeader: true,
|
IsLongHeader: true,
|
||||||
Type: protocol.PacketTypeHandshake,
|
Type: protocol.PacketTypeHandshake,
|
||||||
|
@ -1361,7 +1370,7 @@ var _ = Describe("Client Session", func() {
|
||||||
Length: 1,
|
Length: 1,
|
||||||
},
|
},
|
||||||
data: []byte{0},
|
data: []byte{0},
|
||||||
})).To(BeTrue())
|
}))).To(BeTrue())
|
||||||
// make sure the go routine returns
|
// make sure the go routine returns
|
||||||
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil)
|
||||||
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
sessionRunner.EXPECT().retireConnectionID(gomock.Any())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue