cut packets at the payload length when receiving

This commit is contained in:
Marten Seemann 2018-04-21 19:31:36 +09:00
parent a7f550ae0f
commit cc536fb895
4 changed files with 109 additions and 24 deletions

View file

@ -22,17 +22,17 @@ import (
)
type mockSession struct {
connectionID protocol.ConnectionID
packetCount int
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
connectionID protocol.ConnectionID
handledPackets []*receivedPacket
closed bool
closeReason error
closedRemote bool
stopRunLoop chan struct{} // run returns as soon as this channel receives a value
handshakeChan chan error
}
func (s *mockSession) handlePacket(*receivedPacket) {
s.packetCount++
func (s *mockSession) handlePacket(p *receivedPacket) {
s.handledPackets = append(s.handledPackets, p)
}
func (s *mockSession) run() error {
@ -176,7 +176,7 @@ var _ = Describe("Server", func() {
Expect(serv.sessions).To(HaveLen(1))
sess := serv.sessions[string(connID)].(*mockSession)
Expect(sess.connectionID).To(Equal(connID))
Expect(sess.packetCount).To(Equal(1))
Expect(sess.handledPackets).To(HaveLen(1))
})
It("accepts new TLS sessions", func() {
@ -265,7 +265,7 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).connectionID).To(Equal(connID))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(2))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
})
It("closes and deletes sessions", func() {
@ -366,7 +366,7 @@ var _ = Describe("Server", func() {
It("ignores delayed packets with mismatching versions", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
b := &bytes.Buffer{}
// add an unsupported version
data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
@ -377,7 +377,7 @@ var _ = Describe("Server", func() {
// if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn
Expect(conn.dataWritten.Bytes()).To(BeEmpty())
// make sure the packet was *not* passed to session.handlePacket()
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
})
It("errors on invalid public header", func() {
@ -385,6 +385,40 @@ var _ = Describe("Server", func() {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidPacketHeader))
})
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 1000,
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("packet payload (456 bytes) is smaller than the expected payload length (1000 bytes)"))
})
It("cuts packets at the payload length", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err = serv.handlePacket(nil, nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(2))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets[1].data).To(HaveLen(123))
})
It("ignores public resets for unknown connections", func() {
err := serv.handlePacket(nil, nil, wire.WritePublicReset([]byte{9, 9, 9, 9, 9, 9, 9, 9}, 1, 1337))
Expect(err).ToNot(HaveOccurred())
@ -395,23 +429,23 @@ var _ = Describe("Server", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
err = serv.handlePacket(nil, nil, wire.WritePublicReset(connID, 1, 1337))
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
})
It("ignores invalid public resets for known connections", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
data := wire.WritePublicReset(connID, 1, 1337)
err = serv.handlePacket(nil, nil, data[:len(data)-2])
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[string(connID)].(*mockSession).packetCount).To(Equal(1))
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets).To(HaveLen(1))
})
It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
@ -558,6 +592,7 @@ var _ = Describe("Server", func() {
SrcConnectionID: connID,
PacketNumber: 0x55,
Version: 0x1234,
PayloadLen: protocol.MinInitialPacketSize,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())