cut packets at the payload length when receiving

This commit is contained in:
Marten Seemann 2018-04-21 19:31:36 +09:00
parent a7f550ae0f
commit cc536fb895
4 changed files with 109 additions and 24 deletions

View file

@ -314,6 +314,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
if hdr.IsLongHeader {
c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen)
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
}
c.mutex.Lock()
defer c.mutex.Unlock()
@ -366,7 +376,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
data: packetData,
rcvTime: rcvTime,
})
return nil

View file

@ -435,7 +435,7 @@ var _ = Describe("Client", func() {
// it didn't pass the version negoation packet to the old session (since it has no payload)
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
Expect(firstSession.packetCount).To(BeZero())
Expect(firstSession.handledPackets).To(BeEmpty())
Eventually(sessionChan).Should(Receive(&secondSession))
// make the server accept the new version
packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID)
@ -516,10 +516,42 @@ var _ = Describe("Client", func() {
err := cl.handlePacket(addr, []byte("invalid packet"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error parsing packet from"))
Expect(sess.packetCount).To(BeZero())
Expect(sess.handledPackets).To(BeEmpty())
Expect(sess.closed).To(BeFalse())
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
Expect(sess.handledPackets).To(BeEmpty())
Expect(sess.closed).To(BeFalse())
})
It("cuts packets at the payload length", func() {
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
Expect(sess.handledPackets).To(HaveLen(1))
Expect(sess.handledPackets[0].data).To(HaveLen(123))
})
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
cl.config = &Config{RequestConnectionIDOmission: false}
buf := &bytes.Buffer{}
@ -533,7 +565,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(addr, buf.Bytes())
Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation"))
Expect(sess.packetCount).To(BeZero())
Expect(sess.handledPackets).To(BeEmpty())
Expect(sess.closed).To(BeFalse())
})
@ -552,7 +584,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
err = cl.handlePacket(addr, buf.Bytes())
Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID)))
Expect(sess.packetCount).To(BeZero())
Expect(sess.handledPackets).To(BeEmpty())
Expect(sess.closed).To(BeFalse())
})
@ -647,7 +679,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
packetConn.dataToRead <- b.Bytes()
Expect(sess.packetCount).To(BeZero())
Expect(sess.handledPackets).To(BeEmpty())
stoppedListening := make(chan struct{})
go func() {
cl.listen()
@ -655,7 +687,7 @@ var _ = Describe("Client", func() {
close(stoppedListening)
}()
Eventually(func() int { return sess.packetCount }).Should(Equal(1))
Eventually(func() []*receivedPacket { return sess.handledPackets }).Should(HaveLen(1))
Expect(sess.closed).To(BeFalse())
Consistently(stoppedListening).ShouldNot(BeClosed())
})

View file

@ -308,6 +308,14 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
if hdr.IsLongHeader {
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
}
if hdr.Type == protocol.PacketTypeInitial {
if s.supportsTLS {
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)

View file

@ -22,17 +22,17 @@ import (
)
type mockSession struct {
connectionID protocol.ConnectionID
packetCount int
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
connectionID protocol.ConnectionID
handledPackets []*receivedPacket
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
}
func (s *mockSession) handlePacket(*receivedPacket) {
s.packetCount++
func (s *mockSession) handlePacket(p *receivedPacket) {
s.handledPackets = append(s.handledPackets, p)
}
func (s *mockSession) run() error {
@ -176,7 +176,7 @@ var _ = Describe("Server", func() {
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
Expect(sess.connectionID).To(Equal(connID))
Expect(sess.packetCount).To(Equal(1))
Expect(sess.handledPackets).To(HaveLen(1))
})
It("accepts new TLS sessions", func() {
@ -265,7 +265,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).connectionID).To(Equal(connID))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(2))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
})
It("closes and deletes sessions", func() {
@ -366,7 +366,7 @@ var _ = Describe("Server", func() {
It("ignores delayed packets with mismatching versions", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
b := &bytes.Buffer{}
// add an unsupported version
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
@ -377,7 +377,7 @@ var _ = Describe("Server", func() {
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
// make sure the packet was *not* passed to session.handlePacket()
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
})
It("errors on invalid public header", func() {
@ -385,6 +385,40 @@ var _ = Describe("Server", func() {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
})
It("cuts packets at the payload length", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err = serv.handlePacket(nil, nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets[1].data).To(HaveLen(123))
})
It("ignores public resets for unknown connections", func() {
err := serv.handlePacket(nil, nil, wire.WritePublicReset([]byte{9, 9, 9, 9, 9, 9, 9, 9}, 1, 1337))
Expect(err).ToNot(HaveOccurred())
@ -395,23 +429,23 @@ var _ = Describe("Server", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
err = serv.handlePacket(nil, nil, wire.WritePublicReset(connID, 1, 1337))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
})
It("ignores invalid public resets for known connections", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
data := wire.WritePublicReset(connID, 1, 1337)
err = serv.handlePacket(nil, nil, data[:len(data)-2])
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
})
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
@ -558,6 +592,7 @@ var _ = Describe("Server", func() {
SrcConnectionID: connID,
PacketNumber: 0x55,
Version: 0x1234,
PayloadLen: protocol.MinInitialPacketSize,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())