diff --git a/client.go b/client.go index e5504f68..e03f822b 100644 --- a/client.go +++ b/client.go @@ -314,6 +314,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { return errors.New("received packet with truncated connection ID, but didn't request truncation") } hdr.Raw = packet[:len(packet)-r.Len()] + packetData := packet[len(packet)-r.Len():] + + if hdr.IsLongHeader { + c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen) + if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { + return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) + } + packetData = packetData[:int(hdr.PayloadLen)] + // TODO(#1312): implement parsing of compound packets + } c.mutex.Lock() defer c.mutex.Unlock() @@ -366,7 +376,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { c.session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, header: hdr, - data: packet[len(packet)-r.Len():], + data: packetData, rcvTime: rcvTime, }) return nil diff --git a/client_test.go b/client_test.go index a225da6a..6b5ad873 100644 --- a/client_test.go +++ b/client_test.go @@ -435,7 +435,7 @@ var _ = Describe("Client", func() { // it didn't pass the version negoation packet to the old session (since it has no payload) Eventually(func() bool { return firstSession.closed }).Should(BeTrue()) Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion)) - Expect(firstSession.packetCount).To(BeZero()) + Expect(firstSession.handledPackets).To(BeEmpty()) Eventually(sessionChan).Should(Receive(&secondSession)) // make the server accept the new version packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID) @@ -516,10 +516,42 @@ var _ = Describe("Client", func() { err := cl.handlePacket(addr, []byte("invalid packet")) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error parsing packet from")) - Expect(sess.packetCount).To(BeZero()) + Expect(sess.handledPackets).To(BeEmpty()) Expect(sess.closed).To(BeFalse()) }) + It("errors on packets that are smaller than the Payload Length in the packet header", func() { + b := &bytes.Buffer{} + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + PayloadLen: 1000, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: versionIETFFrames, + } + Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) + cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) + Expect(sess.handledPackets).To(BeEmpty()) + Expect(sess.closed).To(BeFalse()) + }) + + It("cuts packets at the payload length", func() { + b := &bytes.Buffer{} + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + PayloadLen: 123, + SrcConnectionID: connID, + DestConnectionID: connID, + Version: versionIETFFrames, + } + Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) + cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) + Expect(sess.handledPackets).To(HaveLen(1)) + Expect(sess.handledPackets[0].data).To(HaveLen(123)) + }) + It("ignores packets without connection id, if it didn't request connection id trunctation", func() { cl.config = &Config{RequestConnectionIDOmission: false} buf := &bytes.Buffer{} @@ -533,7 +565,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) err = cl.handlePacket(addr, buf.Bytes()) Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation")) - Expect(sess.packetCount).To(BeZero()) + Expect(sess.handledPackets).To(BeEmpty()) Expect(sess.closed).To(BeFalse()) }) @@ -552,7 +584,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) err = cl.handlePacket(addr, buf.Bytes()) Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID))) - Expect(sess.packetCount).To(BeZero()) + Expect(sess.handledPackets).To(BeEmpty()) Expect(sess.closed).To(BeFalse()) }) @@ -647,7 +679,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) packetConn.dataToRead <- b.Bytes() - Expect(sess.packetCount).To(BeZero()) + Expect(sess.handledPackets).To(BeEmpty()) stoppedListening := make(chan struct{}) go func() { cl.listen() @@ -655,7 +687,7 @@ var _ = Describe("Client", func() { close(stoppedListening) }() - Eventually(func() int { return sess.packetCount }).Should(Equal(1)) + Eventually(func() []*receivedPacket { return sess.handledPackets }).Should(HaveLen(1)) Expect(sess.closed).To(BeFalse()) Consistently(stoppedListening).ShouldNot(BeClosed()) }) diff --git a/server.go b/server.go index b72fd2b4..23e06b6d 100644 --- a/server.go +++ b/server.go @@ -308,6 +308,14 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet hdr.Raw = packet[:len(packet)-r.Len()] packetData := packet[len(packet)-r.Len():] + if hdr.IsLongHeader { + if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { + return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) + } + packetData = packetData[:int(hdr.PayloadLen)] + // TODO(#1312): implement parsing of compound packets + } + if hdr.Type == protocol.PacketTypeInitial { if s.supportsTLS { go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) diff --git a/server_test.go b/server_test.go index 75331015..d2bae1c0 100644 --- a/server_test.go +++ b/server_test.go @@ -22,17 +22,17 @@ import ( ) type mockSession struct { - connectionID protocol.ConnectionID - packetCount int - closed bool - closeReason error - closedRemote bool - stopRunLoop chan struct{} // run returns as soon as this channel receives a value - handshakeChan chan error + connectionID protocol.ConnectionID + handledPackets []*receivedPacket + closed bool + closeReason error + closedRemote bool + stopRunLoop chan struct{} // run returns as soon as this channel receives a value + handshakeChan chan error } -func (s *mockSession) handlePacket(*receivedPacket) { - s.packetCount++ +func (s *mockSession) handlePacket(p *receivedPacket) { + s.handledPackets = append(s.handledPackets, p) } func (s *mockSession) run() error { @@ -176,7 +176,7 @@ var _ = Describe("Server", func() { Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[string(connID)].(*mockSession) Expect(sess.connectionID).To(Equal(connID)) - Expect(sess.packetCount).To(Equal(1)) + Expect(sess.handledPackets).To(HaveLen(1)) }) It("accepts new TLS sessions", func() { @@ -265,7 +265,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) Expect(serv.sessions[string(connID)].(*mockSession).connectionID).To(Equal(connID)) - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(2)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2)) }) It("closes and deletes sessions", func() { @@ -366,7 +366,7 @@ var _ = Describe("Server", func() { It("ignores delayed packets with mismatching versions", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) b := &bytes.Buffer{} // add an unsupported version data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} @@ -377,7 +377,7 @@ var _ = Describe("Server", func() { // if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn Expect(conn.dataWritten.Bytes()).To(BeEmpty()) // make sure the packet was *not* passed to session.handlePacket() - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) }) It("errors on invalid public header", func() { @@ -385,6 +385,40 @@ var _ = Describe("Server", func() { Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader)) }) + It("errors on packets that are smaller than the Payload Length in the packet header", func() { + b := &bytes.Buffer{} + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + PayloadLen: 1000, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: versionIETFFrames, + } + Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) + err := serv.handlePacket(nil, nil, append(b.Bytes(), make([]byte, 456)...)) + Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)")) + }) + + It("cuts packets at the payload length", func() { + err := serv.handlePacket(nil, nil, firstPacket) + Expect(err).ToNot(HaveOccurred()) + b := &bytes.Buffer{} + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + PayloadLen: 123, + SrcConnectionID: connID, + DestConnectionID: connID, + Version: versionIETFFrames, + } + Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) + err = serv.handlePacket(nil, nil, append(b.Bytes(), make([]byte, 456)...)) + Expect(err).ToNot(HaveOccurred()) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets[1].data).To(HaveLen(123)) + }) + It("ignores public resets for unknown connections", func() { err := serv.handlePacket(nil, nil, wire.WritePublicReset([]byte{9, 9, 9, 9, 9, 9, 9, 9}, 1, 1337)) Expect(err).ToNot(HaveOccurred()) @@ -395,23 +429,23 @@ var _ = Describe("Server", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) err = serv.handlePacket(nil, nil, wire.WritePublicReset(connID, 1, 1337)) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) }) It("ignores invalid public resets for known connections", func() { err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) data := wire.WritePublicReset(connID, 1, 1337) err = serv.handlePacket(nil, nil, data[:len(data)-2]) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1)) + Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1)) }) It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() { @@ -558,6 +592,7 @@ var _ = Describe("Server", func() { SrcConnectionID: connID, PacketNumber: 0x55, Version: 0x1234, + PayloadLen: protocol.MinInitialPacketSize, } err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) Expect(err).ToNot(HaveOccurred())