parse IETF headers independent of the sender's perspective

The IETF header format allows parsing of the header without knowing
which peer sent the packet.
This commit is contained in:
Marten Seemann 2018-05-05 17:59:54 +09:00
parent 70f6e3651e
commit 8f2fed1b10
9 changed files with 74 additions and 74 deletions

View file

@ -346,6 +346,9 @@ func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
} }
if hdr.IsLongHeader { if hdr.IsLongHeader {
if hdr.Type != protocol.PacketTypeRetry && hdr.Type != protocol.PacketTypeHandshake {
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
}
c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen) c.logger.Debugf("len(packet data): %d, payloadLen: %d", len(packetData), hdr.PayloadLen)
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)

View file

@ -547,11 +547,27 @@ var _ = Describe("Client", func() {
Version: versionIETFFrames, Version: versionIETFFrames,
} }
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...)) err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
Expect(err).ToNot(HaveOccurred())
Expect(sess.handledPackets).To(HaveLen(1)) Expect(sess.handledPackets).To(HaveLen(1))
Expect(sess.handledPackets[0].data).To(HaveLen(123)) Expect(sess.handledPackets[0].data).To(HaveLen(123))
}) })
It("ignores packets with the wrong Long Header Type", func() {
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("Received unsupported packet type: Initial"))
})
It("ignores packets without connection id, if it didn't request connection id trunctation", func() { It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
cl.config = &Config{RequestConnectionIDOmission: false} cl.config = &Config{RequestConnectionIDOmission: false}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}

View file

@ -87,7 +87,7 @@ func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHea
hdr.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later hdr.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
return hdr, nil return hdr, nil
} }
return parseHeader(b, sentBy) return parseHeader(b)
} }
// Write writes the Header. // Write writes the Header.

View file

@ -190,6 +190,7 @@ var _ = Describe("Header", func() {
It("writes a IETF draft header", func() { It("writes a IETF draft header", func() {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
hdr := &Header{ hdr := &Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
PacketNumber: 0x42, PacketNumber: 0x42,
@ -198,7 +199,7 @@ var _ = Describe("Header", func() {
} }
err := hdr.Write(buf, protocol.PerspectiveServer, versionIETFHeader) err := hdr.Write(buf, protocol.PerspectiveServer, versionIETFHeader)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = parseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer) _, err = ParseHeaderSentByServer(bytes.NewReader(buf.Bytes()), versionIETFFrames)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsPublicHeader).To(BeFalse()) Expect(hdr.IsPublicHeader).To(BeFalse())
}) })

View file

@ -12,19 +12,19 @@ import (
) )
// parseHeader parses the header. // parseHeader parses the header.
func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { func parseHeader(b *bytes.Reader) (*Header, error) {
typeByte, err := b.ReadByte() typeByte, err := b.ReadByte()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if typeByte&0x80 > 0 { if typeByte&0x80 > 0 {
return parseLongHeader(b, packetSentBy, typeByte) return parseLongHeader(b, typeByte)
} }
return parseShortHeader(b, typeByte) return parseShortHeader(b, typeByte)
} }
// parse long header and version negotiation packets // parse long header and version negotiation packets
func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte) (*Header, error) { func parseLongHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
v, err := utils.BigEndian.ReadUint32(b) v, err := utils.BigEndian.ReadUint32(b)
if err != nil { if err != nil {
return nil, err return nil, err
@ -52,9 +52,6 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte
} }
if v == 0 { // version negotiation packet if v == 0 { // version negotiation packet
if sentBy == protocol.PerspectiveClient {
return nil, qerr.InvalidVersion
}
if b.Len() == 0 { if b.Len() == 0 {
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
} }
@ -82,10 +79,8 @@ func parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, typeByte byte
h.PacketNumber = protocol.PacketNumber(pn) h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumberLen = protocol.PacketNumberLen4 h.PacketNumberLen = protocol.PacketNumberLen4
h.Type = protocol.PacketType(typeByte & 0x7f) h.Type = protocol.PacketType(typeByte & 0x7f)
if sentBy == protocol.PerspectiveClient && (h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeHandshake && h.Type != protocol.PacketType0RTT) {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
}
if sentBy == protocol.PerspectiveServer && (h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketTypeHandshake) {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type)) return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
} }
return h, nil return h, nil

View file

@ -2,7 +2,6 @@ package wire
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"log" "log"
"os" "os"
@ -26,7 +25,7 @@ var _ = Describe("IETF QUIC Header", func() {
data, err := ComposeVersionNegotiation(connID, connID, versions) data, err := ComposeVersionNegotiation(connID, connID, versions)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveServer) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.IsVersionNegotiation).To(BeTrue()) Expect(h.IsVersionNegotiation).To(BeTrue())
Expect(h.Version).To(BeZero()) Expect(h.Version).To(BeZero())
@ -43,7 +42,7 @@ var _ = Describe("IETF QUIC Header", func() {
data, err := ComposeVersionNegotiation(connID, connID, versions) data, err := ComposeVersionNegotiation(connID, connID, versions)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data[:len(data)-2]) b := bytes.NewReader(data[:len(data)-2])
_, err = parseHeader(b, protocol.PerspectiveServer) _, err = parseHeader(b)
Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket))
}) })
@ -53,7 +52,7 @@ var _ = Describe("IETF QUIC Header", func() {
data, err := ComposeVersionNegotiation(connID, connID, versions) data, err := ComposeVersionNegotiation(connID, connID, versions)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number
_, err = parseHeader(bytes.NewReader(data[:len(data)-8]), protocol.PerspectiveServer) _, err = parseHeader(bytes.NewReader(data[:len(data)-8]))
Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list"))
}) })
}) })
@ -74,7 +73,7 @@ var _ = Describe("IETF QUIC Header", func() {
It("parses a long header", func() { It("parses a long header", func() {
b := bytes.NewReader(generatePacket(protocol.PacketTypeInitial)) b := bytes.NewReader(generatePacket(protocol.PacketTypeInitial))
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.Type).To(Equal(protocol.PacketTypeInitial)) Expect(h.Type).To(Equal(protocol.PacketTypeInitial))
Expect(h.IsLongHeader).To(BeTrue()) Expect(h.IsLongHeader).To(BeTrue())
@ -99,7 +98,7 @@ var _ = Describe("IETF QUIC Header", func() {
data = append(data, encodeVarInt(0x42)...) // payload length data = append(data, encodeVarInt(0x42)...) // payload length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}))
Expect(h.DestConnectionID).To(BeEmpty()) Expect(h.DestConnectionID).To(BeEmpty())
@ -115,40 +114,12 @@ var _ = Describe("IETF QUIC Header", func() {
data = append(data, encodeVarInt(0x42)...) // payload length data = append(data, encodeVarInt(0x42)...) // payload length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.SrcConnectionID).To(BeEmpty()) Expect(h.SrcConnectionID).To(BeEmpty())
Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))
}) })
It("rejects packets sent by the client that use packet types for packets sent by the server", func() {
buf := &bytes.Buffer{}
err := (&Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
SrcConnectionID: srcConnID,
Version: 0x10203040,
}).Write(buf, protocol.PerspectiveServer, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(buf.Bytes())
_, err = parseHeader(b, protocol.PerspectiveClient)
Expect(err).To(MatchError(fmt.Sprintf("InvalidPacketHeader: Received packet with invalid packet type: %d", protocol.PacketTypeRetry)))
})
It("rejects packets sent by the client that use packet types for packets sent by the server", func() {
buf := &bytes.Buffer{}
err := (&Header{
IsLongHeader: true,
Type: protocol.PacketType0RTT,
SrcConnectionID: srcConnID,
Version: 0x10203040,
}).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(buf.Bytes())
_, err = parseHeader(b, protocol.PerspectiveServer)
Expect(err).To(MatchError(fmt.Sprintf("InvalidPacketHeader: Received packet with invalid packet type: %d", protocol.PacketType0RTT)))
})
It("rejects packets sent with an unknown packet type", func() { It("rejects packets sent with an unknown packet type", func() {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := (&Header{ err := (&Header{
@ -159,21 +130,10 @@ var _ = Describe("IETF QUIC Header", func() {
}).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS) }).Write(buf, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(buf.Bytes()) b := bytes.NewReader(buf.Bytes())
_, err = parseHeader(b, protocol.PerspectiveServer) _, err = parseHeader(b)
Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42")) Expect(err).To(MatchError("InvalidPacketHeader: Received packet with invalid packet type: 42"))
}) })
It("rejects version 0 for packets sent by the client", func() {
data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial),
0x0, 0x0, 0x0, 0x0, // version number
0x0, // no connection IDs
0xde, 0xca, 0xfb, 0xad, // packet number
}
_, err := parseHeader(bytes.NewReader(data), protocol.PerspectiveClient)
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("errors on EOF", func() { It("errors on EOF", func() {
data := []byte{ data := []byte{
0x80 ^ uint8(protocol.PacketTypeInitial), 0x80 ^ uint8(protocol.PacketTypeInitial),
@ -184,7 +144,7 @@ var _ = Describe("IETF QUIC Header", func() {
0xde, 0xca, 0xfb, 0xad, // packet number 0xde, 0xca, 0xfb, 0xad, // packet number
} }
for i := 0; i < len(data); i++ { for i := 0; i < len(data); i++ {
_, err := parseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient) _, err := parseHeader(bytes.NewReader(data[:i]))
Expect(err).To(Equal(io.EOF)) Expect(err).To(Equal(io.EOF))
} }
}) })
@ -198,7 +158,7 @@ var _ = Describe("IETF QUIC Header", func() {
0x42, // packet number 0x42, // packet number
} }
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.IsLongHeader).To(BeFalse()) Expect(h.IsLongHeader).To(BeFalse())
Expect(h.KeyPhase).To(Equal(0)) Expect(h.KeyPhase).To(Equal(0))
@ -217,7 +177,7 @@ var _ = Describe("IETF QUIC Header", func() {
0x11, 0x11,
} }
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.IsLongHeader).To(BeFalse()) Expect(h.IsLongHeader).To(BeFalse())
Expect(h.KeyPhase).To(Equal(1)) Expect(h.KeyPhase).To(Equal(1))
@ -231,7 +191,7 @@ var _ = Describe("IETF QUIC Header", func() {
0x13, 0x37, // packet number 0x13, 0x37, // packet number
} }
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.IsLongHeader).To(BeFalse()) Expect(h.IsLongHeader).To(BeFalse())
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
@ -246,7 +206,7 @@ var _ = Describe("IETF QUIC Header", func() {
0xde, 0xad, 0xbe, 0xef, // packet number 0xde, 0xad, 0xbe, 0xef, // packet number
} }
b := bytes.NewReader(data) b := bytes.NewReader(data)
h, err := parseHeader(b, protocol.PerspectiveClient) h, err := parseHeader(b)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(h.IsLongHeader).To(BeFalse()) Expect(h.IsLongHeader).To(BeFalse())
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef))) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef)))
@ -261,7 +221,7 @@ var _ = Describe("IETF QUIC Header", func() {
0xde, 0xad, 0xbe, 0xef, // packet number 0xde, 0xad, 0xbe, 0xef, // packet number
} }
b := bytes.NewReader(data) b := bytes.NewReader(data)
_, err := parseHeader(b, protocol.PerspectiveClient) _, err := parseHeader(b)
Expect(err).To(MatchError("invalid short header type")) Expect(err).To(MatchError("invalid short header type"))
}) })
@ -272,7 +232,7 @@ var _ = Describe("IETF QUIC Header", func() {
0xde, 0xca, 0xfb, 0xad, // packet number 0xde, 0xca, 0xfb, 0xad, // packet number
} }
b := bytes.NewReader(data) b := bytes.NewReader(data)
_, err := parseHeader(b, protocol.PerspectiveClient) _, err := parseHeader(b)
Expect(err).To(MatchError("invalid bits 3, 4 and 5")) Expect(err).To(MatchError("invalid bits 3, 4 and 5"))
}) })
@ -283,7 +243,7 @@ var _ = Describe("IETF QUIC Header", func() {
0xde, 0xca, 0xfb, 0xad, // packet number 0xde, 0xca, 0xfb, 0xad, // packet number
} }
for i := 0; i < len(data); i++ { for i := 0; i < len(data); i++ {
_, err := parseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient) _, err := parseHeader(bytes.NewReader(data[:i]))
Expect(err).To(Equal(io.EOF)) Expect(err).To(Equal(io.EOF))
} }
}) })
@ -545,7 +505,7 @@ var _ = Describe("IETF QUIC Header", func() {
srcConnID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37} srcConnID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}
data, err := ComposeVersionNegotiation(destConnID, srcConnID, []protocol.VersionNumber{0x12345678, 0x87654321}) data, err := ComposeVersionNegotiation(destConnID, srcConnID, []protocol.VersionNumber{0x12345678, 0x87654321})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
hdr, err := parseLongHeader(bytes.NewReader(data[1:]), protocol.PerspectiveServer, data[0]) hdr, err := parseLongHeader(bytes.NewReader(data[1:]), data[0])
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
hdr.logHeader(logger) hdr.logHeader(logger)
Expect(buf.String()).To(ContainSubstring("VersionNegotiationPacket{DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337")) Expect(buf.String()).To(ContainSubstring("VersionNegotiationPacket{DestConnectionID: 0xdeadbeefcafe1337, SrcConnectionID: 0xdecafbad13371337"))

View file

@ -28,7 +28,7 @@ var _ = Describe("Version Negotiation Packets", func() {
data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(data[0] & 0x80).ToNot(BeZero()) Expect(data[0] & 0x80).ToNot(BeZero())
hdr, err := parseHeader(bytes.NewReader(data), protocol.PerspectiveServer) hdr, err := parseHeader(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsVersionNegotiation).To(BeTrue()) Expect(hdr.IsVersionNegotiation).To(BeTrue())
Expect(hdr.DestConnectionID).To(Equal(destConnID)) Expect(hdr.DestConnectionID).To(Equal(destConnID))

View file

@ -314,18 +314,25 @@ func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
if hdr.IsLongHeader { if hdr.IsLongHeader {
if !s.supportsTLS {
return errors.New("Received an IETF QUIC Long Header")
}
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen) return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
} }
packetData = packetData[:int(hdr.PayloadLen)] packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets // TODO(#1312): implement parsing of compound packets
}
if hdr.Type == protocol.PacketTypeInitial { switch hdr.Type {
if s.supportsTLS { case protocol.PacketTypeInitial:
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData) go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
return nil
case protocol.PacketTypeHandshake:
// nothing to do here. Packet will be passed to the session.
default:
// Note that this also drops 0-RTT packets.
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
} }
return nil
} }
s.sessionsMutex.RLock() s.sessionsMutex.RLock()

View file

@ -386,6 +386,7 @@ var _ = Describe("Server", func() {
}) })
It("errors on packets that are smaller than the Payload Length in the packet header", func() { It("errors on packets that are smaller than the Payload Length in the packet header", func() {
serv.supportsTLS = true
b := &bytes.Buffer{} b := &bytes.Buffer{}
hdr := &wire.Header{ hdr := &wire.Header{
IsLongHeader: true, IsLongHeader: true,
@ -401,6 +402,7 @@ var _ = Describe("Server", func() {
}) })
It("cuts packets at the payload length", func() { It("cuts packets at the payload length", func() {
serv.supportsTLS = true
err := serv.handlePacket(nil, firstPacket) err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -419,6 +421,22 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[string(connID)].(*mockSession).handledPackets[1].data).To(HaveLen(123)) Expect(serv.sessions[string(connID)].(*mockSession).handledPackets[1].data).To(HaveLen(123))
}) })
It("drops packets with invalid packet types", func() {
serv.supportsTLS = true
b := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
PayloadLen: 123,
SrcConnectionID: connID,
DestConnectionID: connID,
Version: versionIETFFrames,
}
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...))
Expect(err).To(MatchError("Received unsupported packet type: Retry"))
})
It("ignores Public Resets", func() { It("ignores Public Resets", func() {
err := serv.handlePacket(nil, firstPacket) err := serv.handlePacket(nil, firstPacket)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())