diff --git a/buffer_pool.go b/buffer_pool.go index e5646820..c890d32b 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -15,6 +15,13 @@ type packetBuffer struct { refCount int } +// Split increases the refCount. +// It must be called when a packet buffer is used for more than one packet, +// e.g. when splitting coalesced packets. +func (b *packetBuffer) Split() { + b.refCount++ +} + var bufferPool sync.Pool func getPacketBuffer() *packetBuffer { diff --git a/buffer_pool_test.go b/buffer_pool_test.go index 8b00b4d5..ef6b4852 100644 --- a/buffer_pool_test.go +++ b/buffer_pool_test.go @@ -29,4 +29,15 @@ var _ = Describe("Buffer Pool", func() { putPacketBuffer(buf) Expect(func() { putPacketBuffer(buf) }).To(Panic()) }) + + It("waits until all parts have been put back", func() { + buf := getPacketBuffer() + buf.Split() + buf.Split() + // now we have 3 parts + putPacketBuffer(buf) + putPacketBuffer(buf) + putPacketBuffer(buf) + Expect(func() { putPacketBuffer(buf) }).To(Panic()) + }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 114d2a18..5c2a1948 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -153,11 +153,7 @@ func (h *packetHandlerMap) listen() { h.close(err) return } - data = data[:n] - - if err := h.handlePacket(addr, buffer, data); err != nil { - h.logger.Debugf("error handling packet from %s: %s", addr, err) - } + h.handlePacket(addr, buffer, data[:n]) } } @@ -165,56 +161,102 @@ func (h *packetHandlerMap) handlePacket( addr net.Addr, buffer *packetBuffer, data []byte, -) error { - r := bytes.NewReader(data) - hdr, err := wire.ParseHeader(r, h.connIDLen) - // drop the packet if we can't parse the header +) { + packets, err := h.parsePacket(addr, buffer, data) if err != nil { - return fmt.Errorf("error parsing header: %s", err) + h.logger.Debugf("error parsing packets from %s: %s", addr, err) + // This is just the error from parsing the last packet. + // We still need to process the packets that were successfully parsed before. } + if len(packets) == 0 { + putPacketBuffer(buffer) + return + } + h.handleParsedPackets(packets) +} - if hdr.IsLongHeader { - if protocol.ByteCount(r.Len()) < hdr.Length { - return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) +func (h *packetHandlerMap) parsePacket( + addr net.Addr, + buffer *packetBuffer, + data []byte, +) ([]*receivedPacket, error) { + rcvTime := time.Now() + packets := make([]*receivedPacket, 0, 1) + + var counter int + var lastConnID protocol.ConnectionID + for len(data) > 0 { + if counter > 0 && h.logger.Debug() { + h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data)) } - data = data[:int(hdr.ParsedLen()+hdr.Length)] - // TODO(#1312): implement parsing of compound packets - } - p := &receivedPacket{ - remoteAddr: addr, - hdr: hdr, - rcvTime: time.Now(), - data: data, - buffer: buffer, - } + hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen) + // drop the packet if we can't parse the header + if err != nil { + return packets, fmt.Errorf("error parsing header: %s", err) + } + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + } + lastConnID = hdr.DestConnectionID + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.Length { + return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + + if counter > 0 { + buffer.Split() + } + counter++ + packets = append(packets, &receivedPacket{ + remoteAddr: addr, + hdr: hdr, + rcvTime: rcvTime, + data: data, + buffer: buffer, + }) + data = rest + } + return packets, nil +} + +func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) { h.mutex.RLock() defer h.mutex.RUnlock() - handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)] + // coalesced packets all have the same destination connection ID + handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)] - if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) - return nil - } - // No session found. - // This might be a stateless reset. - if !hdr.IsLongHeader { - if len(data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - return nil - } + for _, p := range packets { + if handlerFound { // existing session + handlerEntry.handler.handlePacket(p) + continue } - // TODO(#943): send a stateless reset - return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID) + // No session found. + // This might be a stateless reset. + if !p.hdr.IsLongHeader { + if len(p.data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], p.data[len(p.data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + continue + } + } + // TODO(#943): send a stateless reset + h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID) + break // a short header packet is always the last in a coalesced packet + + } + if h.server != nil { // no server set + h.server.handlePacket(p) + } + h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID) } - if h.server == nil { // no server set - return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID) - } - h.server.handlePacket(p) - return nil } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 8a0141d3..6a4e9ece 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "net" "time" "github.com/golang/mock/gomock" @@ -35,7 +36,7 @@ var _ = Describe("Packet Handler Map", func() { } getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithLength(connID, 1) + return getPacketWithLength(connID, 2) } BeforeEach(func() { @@ -85,7 +86,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops unparseable packets", func() { - err := handler.handlePacket(nil, nil, []byte{0, 1, 2, 3}) + _, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error parsing header:")) }) @@ -95,7 +96,8 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) - Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler }) It("deletes retired session entries after a wait time", func() { @@ -104,7 +106,8 @@ var _ = Describe("Packet Handler Map", func() { handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler }) It("passes packets arriving late for closed sessions to that session", func() { @@ -114,14 +117,12 @@ var _ = Describe("Packet Handler Map", func() { packetHandler.EXPECT().handlePacket(gomock.Any()) handler.Add(connID, packetHandler) handler.Retire(connID) - err := handler.handlePacket(nil, nil, getPacket(connID)) - Expect(err).ToNot(HaveOccurred()) + handler.handlePacket(nil, nil, getPacket(connID)) }) It("drops packets for unknown receivers", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - err := handler.handlePacket(nil, nil, getPacket(connID)) - Expect(err).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + handler.handlePacket(nil, nil, getPacket(connID)) }) It("closes the packet handlers when reading from the conn fails", func() { @@ -136,22 +137,73 @@ var _ = Describe("Packet Handler Map", func() { Eventually(done).Should(BeClosed()) }) - It("errors on packets that are smaller than the length in the packet header", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) - err := handler.handlePacket(nil, nil, data) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("cuts packets to the right length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + Context("coalesced packets", func() { + It("errors on packets that are smaller than the length in the packet header", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) + _, err := handler.parsePacket(nil, nil, data) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("cuts packets to the right length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + }) + handler.Add(connID, packetHandler) + handler.handlePacket(nil, nil, data) + }) + + It("handles coalesced packets", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + packetHandler := NewMockPacketHandler(mockCtrl) + handledPackets := make(chan *receivedPacket, 3) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + handledPackets <- p + }).Times(3) + handler.Add(connID, packetHandler) + + buffer := getPacketBuffer() + packet := buffer.Slice[:0] + packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...) + packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...) + packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...) + conn.dataToRead <- packet + + now := time.Now() + for i := 1; i <= 3; i++ { + var p *receivedPacket + Eventually(handledPackets).Should(Receive(&p)) + Expect(p.hdr.DestConnectionID).To(Equal(connID)) + Expect(p.hdr.Length).To(BeEquivalentTo(10 * i)) + Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length))) + Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond))) + Expect(p.buffer.refCount).To(Equal(3)) + } + + // makes the listen go routine return + packetHandler.EXPECT().destroy(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("ignores coalesced packet parts if the connection IDs don't match", func() { + connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + + buffer := getPacketBuffer() + packet := buffer.Slice[:0] + // var packet []byte + packet = append(packet, getPacket(connID1)...) + packet = append(packet, getPacket(connID2)...) + + packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet) + Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708")) + Expect(packets).To(HaveLen(1)) + Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1)) + Expect(packets[0].buffer.refCount).To(Equal(1)) }) - handler.Add(connID, packetHandler) - Expect(handler.handlePacket(nil, nil, data)).To(Succeed()) }) }) @@ -186,6 +238,24 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) + It("detects a stateless that is coalesced with another packet", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(connID, packetHandler, token) + fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5} + packet := getPacket(fakeConnID) + reset := append([]byte{0x40} /* short header packet */, fakeConnID...) + reset = append(reset, make([]byte, 50)...) // add some "random" data + reset = append(reset, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- append(packet, reset...) + Eventually(destroyed).Should(BeClosed()) + }) + It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} @@ -193,10 +263,12 @@ var _ = Describe("Packet Handler Map", func() { handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) - Expect(handler.handlePacket(nil, nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42")) + handler.handlePacket(nil, nil, getPacket(connID)) + // don't EXPECT any calls to handlePacket of the MockPacketHandler packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) - Expect(handler.handlePacket(nil, nil, packet)).To(MatchError("received a short header packet with an unexpected connection ID 0xdecafbad99")) + handler.handlePacket(nil, nil, packet) + // don't EXPECT any calls to handlePacket of the MockPacketHandler Expect(handler.resetTokens).To(BeEmpty()) }) }) @@ -210,7 +282,7 @@ var _ = Describe("Packet Handler Map", func() { Expect(p.hdr.DestConnectionID).To(Equal(connID)) }) handler.SetServer(server) - Expect(handler.handlePacket(nil, nil, p)).To(Succeed()) + handler.handlePacket(nil, nil, p) }) It("closes all server sessions", func() { @@ -229,9 +301,10 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket handler.SetServer(server) handler.CloseServer() - Expect(handler.handlePacket(nil, nil, p)).To(MatchError("received a packet with an unexpected connection ID 0x1122334455667788")) + handler.handlePacket(nil, nil, p) }) }) })