From 4556ad01e56a674c8ec566bccbabee8ce9b4efac Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 23 Oct 2017 12:28:52 +0700 Subject: [PATCH] use the new packet header for connections using TLS --- client.go | 14 +- client_test.go | 20 +- integrationtests/tools/proxy/proxy_test.go | 4 +- internal/wire/header.go | 96 ++++++++++ internal/wire/header_test.go | 178 ++++++++++++++++++ internal/wire/ietf_header.go | 28 +-- internal/wire/ietf_header_test.go | 62 +++---- internal/wire/public_header.go | 84 ++------- internal/wire/public_header_test.go | 206 +++++++++------------ internal/wire/version_negotiation.go | 5 +- packet_packer.go | 92 +++++---- packet_packer_test.go | 156 +++++++++------- packet_unpacker.go | 4 +- packet_unpacker_test.go | 4 +- server.go | 21 ++- server_test.go | 8 +- session.go | 22 +-- session_test.go | 48 ++--- 18 files changed, 631 insertions(+), 421 deletions(-) create mode 100644 internal/wire/header.go create mode 100644 internal/wire/header_test.go diff --git a/client.go b/client.go index f74b9ab5..ca56ce43 100644 --- a/client.go +++ b/client.go @@ -249,10 +249,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := wire.ParsePublicHeader(r, protocol.PerspectiveServer, c.version) + hdr, err := wire.ParseHeader(r, protocol.PerspectiveServer, c.version) if err != nil { utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) - // drop this packet if we can't parse the Public Header + // drop this packet if we can't parse the header return } // reject packets with truncated connection id if we didn't request truncation @@ -307,14 +307,14 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { } c.session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - publicHeader: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, + remoteAddr: remoteAddr, + header: hdr, + data: packet[len(packet)-r.Len():], + rcvTime: rcvTime, }) } -func (c *client) handlePacketWithVersionFlag(hdr *wire.PublicHeader) error { +func (c *client) handlePacketWithVersionFlag(hdr *wire.Header) error { for _, v := range hdr.SupportedVersions { if v == c.version { // the version negotiation packet contains the version that we offered diff --git a/client_test.go b/client_test.go index 9bf60d21..1d8e1aa5 100644 --- a/client_test.go +++ b/client_test.go @@ -30,11 +30,11 @@ var _ = Describe("Client", func() { // generate a packet sent by the server that accepts the QUIC version suggested by the client acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte { b := &bytes.Buffer{} - err := (&wire.PublicHeader{ + err := (&wire.Header{ ConnectionID: connID, PacketNumber: 1, PacketNumberLen: 1, - }).Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + }).Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) return b.Bytes() } @@ -302,13 +302,13 @@ var _ = Describe("Client", func() { Context("version negotiation", func() { It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() { - ph := wire.PublicHeader{ + ph := wire.Header{ PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, ConnectionID: 0x1337, } b := &bytes.Buffer{} - err := ph.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) cl.handlePacket(nil, b.Bytes()) Expect(cl.versionNegotiated).To(BeTrue()) @@ -450,11 +450,11 @@ var _ = Describe("Client", func() { It("ignores packets without connection id, if it didn't request connection id trunctation", func() { cl.config.RequestConnectionIDOmission = false buf := &bytes.Buffer{} - (&wire.PublicHeader{ + (&wire.Header{ OmitConnectionID: true, PacketNumber: 1, PacketNumberLen: 1, - }).Write(buf, protocol.VersionWhatever, protocol.PerspectiveServer) + }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) cl.handlePacket(addr, buf.Bytes()) Expect(sess.packetCount).To(BeZero()) Expect(sess.closed).To(BeFalse()) @@ -462,11 +462,11 @@ var _ = Describe("Client", func() { It("ignores packets with the wrong connection ID", func() { buf := &bytes.Buffer{} - (&wire.PublicHeader{ + (&wire.Header{ ConnectionID: cl.connectionID + 1, PacketNumber: 1, PacketNumberLen: 1, - }).Write(buf, protocol.VersionWhatever, protocol.PerspectiveServer) + }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) cl.handlePacket(addr, buf.Bytes()) Expect(sess.packetCount).To(BeZero()) Expect(sess.closed).To(BeFalse()) @@ -513,13 +513,13 @@ var _ = Describe("Client", func() { Context("handling packets", func() { It("handles packets", func() { - ph := wire.PublicHeader{ + ph := wire.Header{ PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, ConnectionID: 0x1337, } b := &bytes.Buffer{} - err := ph.Write(b, cl.version, protocol.PerspectiveServer) + err := ph.Write(b, protocol.PerspectiveServer, cl.version) Expect(err).ToNot(HaveOccurred()) packetConn.dataToRead = b.Bytes() diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index faf3219e..13fc89c1 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -20,13 +20,13 @@ type packetData []byte var _ = Describe("QUIC Proxy", func() { makePacket := func(p protocol.PacketNumber, payload []byte) []byte { b := &bytes.Buffer{} - hdr := wire.PublicHeader{ + hdr := wire.Header{ PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen6, ConnectionID: 1337, OmitConnectionID: false, } - hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) raw := b.Bytes() raw = append(raw, payload...) return raw diff --git a/internal/wire/header.go b/internal/wire/header.go new file mode 100644 index 00000000..b45198c7 --- /dev/null +++ b/internal/wire/header.go @@ -0,0 +1,96 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// Header is the header of a QUIC packet. +// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header. +type Header struct { + Raw []byte + ConnectionID protocol.ConnectionID + OmitConnectionID bool + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + Version protocol.VersionNumber // VersionNumber sent by the client + SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server + + // only needed for the gQUIC Public Header + VersionFlag bool + ResetFlag bool + DiversificationNonce []byte + + // only needed for the IETF Header + Type uint8 + IsLongHeader bool + KeyPhase int +} + +// ParseHeader parses the header. +func ParseHeader(b *bytes.Reader, sentBy protocol.Perspective, version protocol.VersionNumber) (*Header, error) { + var typeByte uint8 + if version == protocol.VersionUnknown { + var err error + typeByte, err = b.ReadByte() + if err != nil { + return nil, err + } + _ = b.UnreadByte() // unread the type byte + } + + // There are two conditions this is a header in the IETF Header format: + // 1. We already know the version (because this is a packet that belongs to an exisitng session). + // 2. If this is a new packet, it must have the Long Format, which has the 0x80 bit set (which is always 0 in gQUIC). + // There's a third option: This could be a packet with Short Format that arrives after a server lost state. + // In that case, we'll try parsing the header as a gQUIC Public Header. + if version.UsesTLS() || (version == protocol.VersionUnknown && typeByte&0x80 > 0) { + return parseHeader(b, sentBy) + } + + // This is a gQUIC Public Header. + return parsePublicHeader(b, sentBy, version) +} + +// PeekConnectionID parses the connection ID from a QUIC packet's public header, sent by the client. +// This function should not be called for packets sent by the server, since on these packets the Connection ID could be omitted. +// If no error occurs, it restores the read position in the bytes.Reader. +func PeekConnectionID(b *bytes.Reader) (protocol.ConnectionID, error) { + var connectionID protocol.ConnectionID + if _, err := b.ReadByte(); err != nil { + return 0, err + } + // unread the public flag byte + defer b.UnreadByte() + + // Assume that the packet contains the Connection ID. + // This is a valid assumption for all packets sent by the client, because the server doesn't allow the ommision of the Connection ID. + connID, err := utils.BigEndian.ReadUint64(b) + if err != nil { + return 0, err + } + connectionID = protocol.ConnectionID(connID) + // unread the connection ID + for i := 0; i < 8; i++ { + b.UnreadByte() + } + return connectionID, nil +} + +// Write writes the Header. +func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { + if !version.UsesTLS() { + return h.writePublicHeader(b, pers, version) + } + return h.writeHeader(b) +} + +// GetLength determines the length of the Header. +func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNumber) (protocol.ByteCount, error) { + if !version.UsesTLS() { + return h.getPublicHeaderLength(pers) + } + return h.getHeaderLength() +} diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go new file mode 100644 index 00000000..126ef7a9 --- /dev/null +++ b/internal/wire/header_test.go @@ -0,0 +1,178 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Header", func() { + const ( + versionPublicHeader = protocol.Version39 // a QUIC version that uses the Public Header format + versionIETFHeader = protocol.VersionTLS // a QUIC version taht uses the IETF Header format + ) + + Context("peeking the connection ID", func() { + It("gets the connection ID", func() { + b := bytes.NewReader([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x51, 0x30, 0x33, 0x34, 0x01}) + len := b.Len() + connID, err := PeekConnectionID(b) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) + Expect(b.Len()).To(Equal(len)) + }) + + It("errors if the header is too short", func() { + b := bytes.NewReader([]byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b}) + _, err := PeekConnectionID(b) + Expect(err).To(HaveOccurred()) + }) + + It("errors if the header is empty", func() { + b := bytes.NewReader([]byte{}) + _, err := PeekConnectionID(b) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("parsing", func() { + It("parses an IETF draft header, when the QUIC version supports TLS", func() { + buf := &bytes.Buffer{} + // use a short header, which isn't distinguishable from the gQUIC Public Header when looking at the type byte + err := (&Header{ + IsLongHeader: false, + KeyPhase: 1, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + }).writeHeader(buf) + Expect(err).ToNot(HaveOccurred()) + hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.KeyPhase).To(BeEquivalentTo(1)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + }) + + It("parses an IETF draft header, when the version is not known, but it has Long Header format", func() { + buf := &bytes.Buffer{} + err := (&Header{ + IsLongHeader: true, + Type: 3, + PacketNumber: 0x42, + }).writeHeader(buf) + Expect(err).ToNot(HaveOccurred()) + hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(BeEquivalentTo(3)) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + }) + + It("parses a gQUIC Public Header, when the version is not known", func() { + buf := &bytes.Buffer{} + err := (&Header{ + VersionFlag: true, + Version: versionPublicHeader, + ConnectionID: 0x42, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen6, + }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveClient, protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) + Expect(hdr.Version).To(Equal(versionPublicHeader)) + }) + + It("parses a gQUIC Public Header, when the version is known", func() { + buf := &bytes.Buffer{} + err := (&Header{ + ConnectionID: 0x42, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen6, + DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), + }).writePublicHeader(buf, protocol.PerspectiveServer, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + hdr, err := ParseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) + Expect(hdr.DiversificationNonce).To(HaveLen(32)) + }) + + It("errors when given no data", func() { + _, err := ParseHeader(bytes.NewReader([]byte{}), protocol.PerspectiveClient, protocol.VersionUnknown) + Expect(err).To(MatchError(io.EOF)) + }) + }) + + Context("writing", func() { + It("writes a gQUIC Public Header", func() { + buf := &bytes.Buffer{} + err := (&Header{ + ConnectionID: 0x1337, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + }).Write(buf, protocol.PerspectiveServer, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + _, err = parsePublicHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + }) + + It("writes a IETF draft header", func() { + buf := &bytes.Buffer{} + err := (&Header{ + ConnectionID: 0x1337, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + KeyPhase: 1, + }).Write(buf, protocol.PerspectiveServer, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + _, err = parseHeader(bytes.NewReader(buf.Bytes()), protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("getting the length", func() { + It("get the length of a gQUIC Public Header", func() { + buf := &bytes.Buffer{} + hdr := &Header{ + ConnectionID: 0x1337, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), + } + err := hdr.Write(buf, protocol.PerspectiveServer, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + publicHeaderLen, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + ietfHeaderLen, err := hdr.getHeaderLength() + Expect(err).ToNot(HaveOccurred()) + Expect(publicHeaderLen).ToNot(Equal(ietfHeaderLen)) // make sure we can distinguish between the two header types + len, err := hdr.GetLength(protocol.PerspectiveServer, versionPublicHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(len).To(Equal(publicHeaderLen)) + }) + + It("get the length of a a IETF draft header", func() { + buf := &bytes.Buffer{} + hdr := &Header{ + IsLongHeader: true, + ConnectionID: 0x1337, + PacketNumber: 0x42, + PacketNumberLen: protocol.PacketNumberLen2, + KeyPhase: 1, + } + err := hdr.Write(buf, protocol.PerspectiveServer, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + publicHeaderLen, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + ietfHeaderLen, err := hdr.getHeaderLength() + Expect(err).ToNot(HaveOccurred()) + Expect(publicHeaderLen).ToNot(Equal(ietfHeaderLen)) // make sure we can distinguish between the two header types + len, err := hdr.GetLength(protocol.PerspectiveServer, versionIETFHeader) + Expect(err).ToNot(HaveOccurred()) + Expect(len).To(Equal(ietfHeaderLen)) + }) + }) +}) diff --git a/internal/wire/ietf_header.go b/internal/wire/ietf_header.go index 9bc23d27..5e35dca2 100644 --- a/internal/wire/ietf_header.go +++ b/internal/wire/ietf_header.go @@ -9,22 +9,8 @@ import ( "github.com/lucas-clemente/quic-go/qerr" ) -// The Header is the header of a QUIC Packet. -// TODO: add support for the key phase -type Header struct { - Type uint8 - IsLongHeader bool - KeyPhase int - OmitConnectionID bool - ConnectionID protocol.ConnectionID - PacketNumber protocol.PacketNumber - PacketNumberLen protocol.PacketNumberLen - Version protocol.VersionNumber - SupportedVersions []protocol.VersionNumber // version number sent in a Version Negotiation Packet by the server -} - -// ParseHeader parses a header -func ParseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { +// parseHeader parses the header. +func parseHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) { typeByte, err := b.ReadByte() if err != nil { return nil, err @@ -99,13 +85,15 @@ func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) { }, nil } -func (h *Header) Write(b *bytes.Buffer) error { +// writeHeader writes the Header. +func (h *Header) writeHeader(b *bytes.Buffer) error { if h.IsLongHeader { return h.writeLongHeader(b) } return h.writeShortHeader(b) } +// TODO: add support for the key phase func (h *Header) writeLongHeader(b *bytes.Buffer) error { b.WriteByte(byte(0x80 ^ h.Type)) utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) @@ -129,8 +117,8 @@ func (h *Header) writeShortHeader(b *bytes.Buffer) error { default: return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) } - b.WriteByte(typeByte) + if !h.OmitConnectionID { utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) } @@ -145,8 +133,8 @@ func (h *Header) writeShortHeader(b *bytes.Buffer) error { return nil } -// GetLength gets the length of the Header in bytes. -func (h *Header) GetLength() (protocol.ByteCount, error) { +// getHeaderLength gets the length of the Header in bytes. +func (h *Header) getHeaderLength() (protocol.ByteCount, error) { if h.IsLongHeader { return 1 + 8 + 4 + 4, nil } diff --git a/internal/wire/ietf_header_test.go b/internal/wire/ietf_header_test.go index ecd96bab..335d96cf 100644 --- a/internal/wire/ietf_header_test.go +++ b/internal/wire/ietf_header_test.go @@ -11,7 +11,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Header", func() { +var _ = Describe("IETF draft Header", func() { Context("parsing", func() { Context("long headers", func() { var data []byte @@ -27,7 +27,7 @@ var _ = Describe("Header", func() { It("parses a long header", func() { b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveClient) + h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(h.Type).To(BeEquivalentTo(3)) Expect(h.IsLongHeader).To(BeTrue()) @@ -41,7 +41,7 @@ var _ = Describe("Header", func() { It("errors on EOF", func() { for i := 0; i < len(data); i++ { - _, err := ParseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient) + _, err := parseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient) Expect(err).To(Equal(io.EOF)) } }) @@ -57,7 +57,7 @@ var _ = Describe("Header", func() { 0x33, 0x44, 0x55, 0x66}..., ) b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveServer) + h, err := parseHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(h.SupportedVersions).To(Equal([]protocol.VersionNumber{ 0x22334455, @@ -68,20 +68,20 @@ var _ = Describe("Header", func() { It("errors if it contains versions of the wrong length", func() { data = append(data, []byte{0x22, 0x33}...) // too short. Should be 4 bytes. b := bytes.NewReader(data) - _, err := ParseHeader(b, protocol.PerspectiveServer) + _, err := parseHeader(b, protocol.PerspectiveServer) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) It("errors if it was sent by the client", func() { data = append(data, []byte{0x22, 0x33, 0x44, 0x55}...) b := bytes.NewReader(data) - _, err := ParseHeader(b, protocol.PerspectiveClient) + _, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: sent by the client")) }) It("errors if the version list is emtpy", func() { b := bytes.NewReader(data) - _, err := ParseHeader(b, protocol.PerspectiveServer) + _, err := parseHeader(b, protocol.PerspectiveServer) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) }) }) @@ -95,7 +95,7 @@ var _ = Describe("Header", func() { 0x42, // packet number } b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveClient) + h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.KeyPhase).To(Equal(0)) @@ -111,7 +111,7 @@ var _ = Describe("Header", func() { 0x11, } b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveClient) + h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.KeyPhase).To(Equal(1)) @@ -124,7 +124,7 @@ var _ = Describe("Header", func() { 0x21, // packet number } b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveClient) + h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.OmitConnectionID).To(BeTrue()) @@ -139,7 +139,7 @@ var _ = Describe("Header", func() { 0x13, 0x37, // packet number } b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveClient) + h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) @@ -153,7 +153,7 @@ var _ = Describe("Header", func() { 0xde, 0xad, 0xbe, 0xef, // packet number } b := bytes.NewReader(data) - h, err := ParseHeader(b, protocol.PerspectiveClient) + h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0xdeadbeef))) @@ -168,7 +168,7 @@ var _ = Describe("Header", func() { 0xde, 0xca, 0xfb, 0xad, // packet number } for i := 0; i < len(data); i++ { - _, err := ParseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient) + _, err := parseHeader(bytes.NewReader(data[:i]), protocol.PerspectiveClient) Expect(err).To(Equal(io.EOF)) } }) @@ -190,7 +190,7 @@ var _ = Describe("Header", func() { ConnectionID: 0xdeadbeefcafe1337, PacketNumber: 0xdecafbad, Version: 0x1020304, - }).Write(buf) + }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x80 ^ 0x5, @@ -207,7 +207,7 @@ var _ = Describe("Header", func() { ConnectionID: 0xdeadbeefcafe1337, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 0x42, - }).Write(buf) + }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x40 ^ 0x1, @@ -221,7 +221,7 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 0x42, - }).Write(buf) + }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x1, @@ -234,7 +234,7 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen2, PacketNumber: 0x1337, - }).Write(buf) + }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x2, @@ -247,7 +247,7 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen4, PacketNumber: 0xdecafbad, - }).Write(buf) + }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x3, @@ -260,7 +260,7 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: 3, PacketNumber: 0xdecafbad, - }).Write(buf) + }).writeHeader(buf) Expect(err).To(MatchError("invalid packet number length: 3")) }) @@ -270,7 +270,7 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 0x42, - }).Write(buf) + }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ 0x20 ^ 0x1, @@ -289,8 +289,8 @@ var _ = Describe("Header", func() { It("has the right length for the long header", func() { h := &Header{IsLongHeader: true} - Expect(h.GetLength()).To(Equal(protocol.ByteCount(17))) - err := h.Write(buf) + Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(17))) + err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Len()).To(Equal(17)) }) @@ -299,8 +299,8 @@ var _ = Describe("Header", func() { h := &Header{ PacketNumberLen: protocol.PacketNumberLen1, } - Expect(h.GetLength()).To(Equal(protocol.ByteCount(1 + 8 + 1))) - err := h.Write(buf) + Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 8 + 1))) + err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Len()).To(Equal(10)) }) @@ -310,8 +310,8 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen1, } - Expect(h.GetLength()).To(Equal(protocol.ByteCount(1 + 1))) - err := h.Write(buf) + Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 1))) + err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Len()).To(Equal(2)) }) @@ -321,8 +321,8 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen2, } - Expect(h.GetLength()).To(Equal(protocol.ByteCount(1 + 2))) - err := h.Write(buf) + Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 2))) + err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Len()).To(Equal(3)) }) @@ -332,15 +332,15 @@ var _ = Describe("Header", func() { OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen4, } - Expect(h.GetLength()).To(Equal(protocol.ByteCount(1 + 4))) - err := h.Write(buf) + Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 4))) + err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Len()).To(Equal(5)) }) It("errors when given an invalid packet number length", func() { h := &Header{PacketNumberLen: 5} - _, err := h.GetLength() + _, err := h.getHeaderLength() Expect(err).To(MatchError("invalid packet number length: 5")) }) }) diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index 0e9d9b7d..bb778eee 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -20,28 +20,13 @@ var ( errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") ) -// The PublicHeader is the header of a gQUIC packet. -type PublicHeader struct { - Raw []byte - ConnectionID protocol.ConnectionID - VersionFlag bool - ResetFlag bool - OmitConnectionID bool - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - VersionNumber protocol.VersionNumber // VersionNumber sent by the client - SupportedVersions []protocol.VersionNumber // VersionNumbers sent by the server - DiversificationNonce []byte -} - -// Write writes a Public Header. -func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error { - publicFlagByte := uint8(0x00) - +// writePublicHeader writes a Public Header. +func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error { if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } + publicFlagByte := uint8(0x00) if h.VersionFlag { publicFlagByte |= 0x01 } @@ -51,14 +36,12 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe if !h.OmitConnectionID { publicFlagByte |= 0x08 } - if len(h.DiversificationNonce) > 0 { if len(h.DiversificationNonce) != 32 { return errors.New("invalid diversification nonce length") } publicFlagByte |= 0x04 } - // only set PacketNumberLen bits if a packet number will be written if h.hasPacketNumber(pers) { switch h.PacketNumberLen { @@ -72,21 +55,17 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe publicFlagByte |= 0x30 } } - b.WriteByte(publicFlagByte) if !h.OmitConnectionID { utils.BigEndian.WriteUint64(b, uint64(h.ConnectionID)) } - if h.VersionFlag && pers == protocol.PerspectiveClient { - utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(h.VersionNumber)) + utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(h.Version)) } - if len(h.DiversificationNonce) > 0 { b.Write(h.DiversificationNonce) } - // if we're a server, and the VersionFlag is set, we must not include anything else in the packet if !h.hasPacketNumber(pers) { return nil @@ -108,39 +87,10 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe return nil } -// PeekConnectionID parses the connection ID from a QUIC packet's public header. -// If no error occurs, it restores the read position in the bytes.Reader. -func PeekConnectionID(b *bytes.Reader, packetSentBy protocol.Perspective) (protocol.ConnectionID, error) { - var connectionID protocol.ConnectionID - publicFlagByte, err := b.ReadByte() - if err != nil { - return 0, err - } - // unread the public flag byte - defer b.UnreadByte() - - omitConnectionID := publicFlagByte&0x08 == 0 - if omitConnectionID && packetSentBy == protocol.PerspectiveClient { - return 0, errReceivedOmittedConnectionID - } - if !omitConnectionID { - connID, err := utils.BigEndian.ReadUint64(b) - if err != nil { - return 0, err - } - connectionID = protocol.ConnectionID(connID) - // unread the connection ID - for i := 0; i < 8; i++ { - b.UnreadByte() - } - } - return connectionID, nil -} - -// ParsePublicHeader parses a QUIC packet's Public Header. +// parsePublicHeader parses a QUIC packet's Public Header. // The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient. -func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, version protocol.VersionNumber) (*PublicHeader, error) { - header := &PublicHeader{} +func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, version protocol.VersionNumber) (*Header, error) { + header := &Header{} // First byte publicFlagByte, err := b.ReadByte() @@ -163,7 +113,6 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, versi if header.OmitConnectionID && packetSentBy == protocol.PerspectiveClient { return nil, errReceivedOmittedConnectionID } - if header.hasPacketNumber(packetSentBy) { switch publicFlagByte & 0x30 { case 0x30: @@ -230,8 +179,8 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, versi if err != nil { return nil, err } - header.VersionNumber = protocol.VersionTagToNumber(versionTag) - version = header.VersionNumber + header.Version = protocol.VersionTagToNumber(versionTag) + version = header.Version } // Packet number @@ -242,47 +191,40 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, versi } header.PacketNumber = protocol.PacketNumber(packetNumber) } - return header, nil } -// GetLength gets the length of the publicHeader in bytes. +// getPublicHeaderLength gets the length of the publicHeader in bytes. // It can only be called for regular packets. -func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) { +func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.ByteCount, error) { if h.VersionFlag && h.ResetFlag { return 0, errResetAndVersionFlagSet } - if h.VersionFlag && pers == protocol.PerspectiveServer { return 0, errGetLengthNotForVersionNegotiation } length := protocol.ByteCount(1) // 1 byte for public flags - if h.hasPacketNumber(pers) { if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 { return 0, errPacketNumberLenNotSet } length += protocol.ByteCount(h.PacketNumberLen) } - if !h.OmitConnectionID { length += 8 // 8 bytes for the connection ID } - // Version Number in packets sent by the client if h.VersionFlag { length += 4 } - length += protocol.ByteCount(len(h.DiversificationNonce)) - return length, nil } -// hasPacketNumber determines if this PublicHeader will contain a packet number +// hasPacketNumber determines if this Public Header will contain a packet number // this depends on the ResetFlag, the VersionFlag and who sent the packet -func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool { +func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool { if h.ResetFlag { return false } diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index 98414cdf..0e5af10f 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -12,53 +12,15 @@ import ( ) var _ = Describe("Public Header", func() { - Context("parsing the connection ID", func() { - It("does not accept an omitted connection ID as a server", func() { - b := bytes.NewReader([]byte{0x00, 0x01}) - _, err := PeekConnectionID(b, protocol.PerspectiveClient) - Expect(err).To(MatchError(errReceivedOmittedConnectionID)) - }) - - It("gets the connection ID", func() { - b := bytes.NewReader([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x51, 0x30, 0x33, 0x34, 0x01}) - len := b.Len() - connID, err := PeekConnectionID(b, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) - Expect(b.Len()).To(Equal(len)) - }) - - It("errors if the Public Header is too short", func() { - b := bytes.NewReader([]byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b}) - _, err := PeekConnectionID(b, protocol.PerspectiveClient) - Expect(err).To(HaveOccurred()) - }) - - It("errors if the Public Header is empty", func() { - b := bytes.NewReader([]byte{}) - _, err := PeekConnectionID(b, protocol.PerspectiveClient) - Expect(err).To(HaveOccurred()) - }) - - It("accepts an ommitted connection ID as a client", func() { - b := bytes.NewReader([]byte{0x00, 0x01}) - len := b.Len() - connID, err := PeekConnectionID(b, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(BeZero()) - Expect(b.Len()).To(Equal(len)) - }) - }) - Context("when parsing", func() { It("accepts a sample client header", func() { b := bytes.NewReader([]byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x51, 0x30, 0x33, 0x34, 0x01}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.ResetFlag).To(BeFalse()) Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) - Expect(hdr.VersionNumber).To(Equal(protocol.VersionNumber(34))) + Expect(hdr.Version).To(Equal(protocol.VersionNumber(34))) Expect(hdr.SupportedVersions).To(BeEmpty()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(b.Len()).To(BeZero()) @@ -66,13 +28,13 @@ var _ = Describe("Public Header", func() { It("does not accept an omittedd connection ID as a server", func() { b := bytes.NewReader([]byte{0x00, 0x01}) - _, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) + _, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).To(MatchError(errReceivedOmittedConnectionID)) }) It("accepts aan d connection ID as a client", func() { b := bytes.NewReader([]byte{0x00, 0x01}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(hdr.OmitConnectionID).To(BeTrue()) Expect(hdr.ConnectionID).To(BeZero()) @@ -81,13 +43,13 @@ var _ = Describe("Public Header", func() { It("rejects 0 as a connection ID", func() { b := bytes.NewReader([]byte{0x09, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x51, 0x30, 0x33, 0x30, 0x01}) - _, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) Expect(err).To(MatchError(errInvalidConnectionID)) }) It("reads a PublicReset packet", func() { b := bytes.NewReader([]byte{0xa, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.ConnectionID).ToNot(BeZero()) @@ -95,7 +57,7 @@ var _ = Describe("Public Header", func() { It("parses a public reset packet", func() { b := bytes.NewReader([]byte{0xa, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.VersionFlag).To(BeFalse()) @@ -106,7 +68,7 @@ var _ = Describe("Public Header", func() { divNonce := []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} Expect(divNonce).To(HaveLen(32)) b := bytes.NewReader(append(append([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}, divNonce...), 0x37)) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ConnectionID).To(Not(BeZero())) Expect(hdr.DiversificationNonce).To(Equal(divNonce)) @@ -115,7 +77,7 @@ var _ = Describe("Public Header", func() { It("returns an unknown version error when receiving a packet without a version for which the version is not given", func() { b := bytes.NewReader([]byte{0x10, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0xef}) - _, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).To(MatchError(ErrPacketWithUnknownVersion)) }) @@ -124,7 +86,7 @@ var _ = Describe("Public Header", func() { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0x01, }) - _, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) + _, err := parsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).To(MatchError("diversification nonces should only be sent by servers")) }) @@ -137,17 +99,17 @@ var _ = Describe("Public Header", func() { It("parses version negotiation packets sent by the server", func() { b := bytes.NewReader(ComposeVersionNegotiation(0x1337, protocol.SupportedVersions)) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) - Expect(hdr.VersionNumber).To(BeZero()) // unitialized + Expect(hdr.Version).To(BeZero()) // unitialized Expect(hdr.SupportedVersions).To(Equal(protocol.SupportedVersions)) Expect(b.Len()).To(BeZero()) }) It("errors if it doesn't contain any versions", func() { b := bytes.NewReader([]byte{0x9, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}) - _, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).To(MatchError("InvalidVersionNegotiationPacket: empty version list")) }) @@ -157,7 +119,7 @@ var _ = Describe("Public Header", func() { data = appendVersion(data, protocol.SupportedVersions[0]) data = appendVersion(data, 99) // unsupported version b := bytes.NewReader(data) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.SupportedVersions).To(Equal([]protocol.VersionNumber{1, protocol.SupportedVersions[0], 99})) @@ -168,7 +130,7 @@ var _ = Describe("Public Header", func() { data := ComposeVersionNegotiation(0x1337, protocol.SupportedVersions) data = append(data, []byte{0x13, 0x37}...) b := bytes.NewReader(data) - _, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) }) @@ -183,7 +145,7 @@ var _ = Describe("Public Header", func() { It("accepts 1-byte packet numbers", func() { b := bytes.NewReader([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) @@ -192,7 +154,7 @@ var _ = Describe("Public Header", func() { It("accepts 2-byte packet numbers", func() { b := bytes.NewReader([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xcade))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) @@ -201,7 +163,7 @@ var _ = Describe("Public Header", func() { It("accepts 4-byte packet numbers", func() { b := bytes.NewReader([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) @@ -210,7 +172,7 @@ var _ = Describe("Public Header", func() { It("accepts 6-byte packet numbers", func() { b := bytes.NewReader([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad4223))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen6)) @@ -227,7 +189,7 @@ var _ = Describe("Public Header", func() { It("accepts 1-byte packet numbers", func() { b := bytes.NewReader([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) @@ -236,7 +198,7 @@ var _ = Describe("Public Header", func() { It("accepts 2-byte packet numbers", func() { b := bytes.NewReader([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeca))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) @@ -245,7 +207,7 @@ var _ = Describe("Public Header", func() { It("accepts 4-byte packet numbers", func() { b := bytes.NewReader([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xadfbcade))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) @@ -254,7 +216,7 @@ var _ = Describe("Public Header", func() { It("accepts 6-byte packet numbers", func() { b := bytes.NewReader([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + hdr, err := parsePublicHeader(b, protocol.PerspectiveClient, version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x2342adfbcade))) Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen6)) @@ -267,60 +229,60 @@ var _ = Describe("Public Header", func() { Context("when writing", func() { It("writes a sample header as a server", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, versionLittleEndian, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, versionLittleEndian) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 2, 0, 0, 0, 0, 0})) }) It("writes a sample header as a client", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, versionLittleEndian, protocol.PerspectiveClient) + err := hdr.writePublicHeader(b, protocol.PerspectiveClient, versionLittleEndian) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x37, 0x13, 0, 0, 0, 0})) }) It("refuses to write a Public Header if the PacketNumberLen is not set", func() { - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 1, PacketNumber: 2, } b := &bytes.Buffer{} - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).To(MatchError("PublicHeader: PacketNumberLen not set")) }) It("omits the connection ID", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen6, PacketNumber: 1, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x30, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0})) }) It("writes diversification nonces", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen1, DiversificationNonce: bytes.Repeat([]byte{1}, 32), } - err := hdr.Write(b, versionLittleEndian, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, versionLittleEndian) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{ 0x0c, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, @@ -331,24 +293,24 @@ var _ = Describe("Public Header", func() { It("throws an error if both Reset Flag and Version Flag are set", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ VersionFlag: true, ResetFlag: true, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).To(MatchError(errResetAndVersionFlagSet)) }) Context("Version Negotiation packets", func() { It("sets the Version Flag for packets sent as a server", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ VersionFlag: true, ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID @@ -359,14 +321,14 @@ var _ = Describe("Public Header", func() { It("sets the Version Flag for packets sent as a client, and adds a packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ VersionFlag: true, - VersionNumber: protocol.Version38, + Version: protocol.Version38, ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveClient) + err := hdr.writePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8 + 4 + 6)) // 1 FlagByte + 8 ConnectionID + 4 version number + 6 PacketNumber @@ -381,11 +343,11 @@ var _ = Describe("Public Header", func() { Context("PublicReset packets", func() { It("sets the Reset Flag", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ResetFlag: true, ConnectionID: 0x4cfa9f9b668619f6, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID @@ -395,108 +357,108 @@ var _ = Describe("Public Header", func() { It("doesn't add a packet number for headers with Reset Flag sent as a client", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ResetFlag: true, ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveClient) + err := hdr.writePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) // must be the first assertion Expect(b.Len()).To(Equal(1 + 8)) // 1 FlagByte + 8 ConnectionID }) }) - Context("GetLength", func() { - It("errors when calling GetLength for Version Negotiation packets", func() { - hdr := PublicHeader{VersionFlag: true} - _, err := hdr.GetLength(protocol.PerspectiveServer) + Context("getting the length", func() { + It("errors when calling getPublicHeaderLength for Version Negotiation packets", func() { + hdr := Header{VersionFlag: true} + _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).To(MatchError(errGetLengthNotForVersionNegotiation)) }) - It("errors when calling GetLength for packets that have the VersionFlag and the ResetFlag set", func() { - hdr := PublicHeader{ + It("errors when calling getPublicHeaderLength for packets that have the VersionFlag and the ResetFlag set", func() { + hdr := Header{ ResetFlag: true, VersionFlag: true, } - _, err := hdr.GetLength(protocol.PerspectiveServer) + _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).To(MatchError(errResetAndVersionFlagSet)) }) It("errors when PacketNumberLen is not set", func() { - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, } - _, err := hdr.GetLength(protocol.PerspectiveServer) + _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).To(MatchError(errPacketNumberLenNotSet)) }) It("gets the length of a packet with longest packet number length and connectionID", func() { - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, } - length, err := hdr.GetLength(protocol.PerspectiveServer) + length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 6))) // 1 byte public flag, 8 bytes connectionID, and packet number }) It("gets the lengths of a packet sent by the client with the VersionFlag set", func() { - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, OmitConnectionID: true, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, VersionFlag: true, - VersionNumber: versionLittleEndian, + Version: versionLittleEndian, } - length, err := hdr.GetLength(protocol.PerspectiveClient) + length, err := hdr.getPublicHeaderLength(protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 4 + 6))) // 1 byte public flag, 4 version number, and packet number }) It("gets the length of a packet with longest packet number length and omitted connectionID", func() { - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, OmitConnectionID: true, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, } - length, err := hdr.GetLength(protocol.PerspectiveServer) + length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 6))) // 1 byte public flag, and packet number }) It("gets the length of a packet 2 byte packet number length ", func() { - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen2, } - length, err := hdr.GetLength(protocol.PerspectiveServer) + length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 2))) // 1 byte public flag, 8 byte connectionID, and packet number }) It("works with diversification nonce", func() { - hdr := PublicHeader{ + hdr := Header{ DiversificationNonce: []byte("foo"), PacketNumberLen: protocol.PacketNumberLen1, } - length, err := hdr.GetLength(protocol.PerspectiveServer) + length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).NotTo(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1))) // 1 byte public flag, 8 byte connectionID, 3 byte DiversificationNonce, 1 byte PacketNumber }) It("gets the length of a PublicReset", func() { - hdr := PublicHeader{ + hdr := Header{ ResetFlag: true, ConnectionID: 0x4cfa9f9b668619f6, } - length, err := hdr.GetLength(protocol.PerspectiveServer) + length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).NotTo(HaveOccurred()) Expect(length).To(Equal(protocol.ByteCount(1 + 8))) // 1 byte public flag, 8 byte connectionID }) @@ -505,11 +467,11 @@ var _ = Describe("Public Header", func() { Context("packet number length", func() { It("doesn't write a header if the packet number length is not set", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).To(MatchError("PublicHeader: PacketNumberLen not set")) }) @@ -522,48 +484,48 @@ var _ = Describe("Public Header", func() { It("writes a header with a 1-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen1, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xAD})) }) It("writes a header with a 2-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen2, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb})) }) It("writes a header with a 4-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0x13DECAFBAD, PacketNumberLen: protocol.PacketNumberLen4, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xAD, 0xfb, 0xca, 0xde})) }) It("writes a header with a 6-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xBE1337DECAFBAD, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad, 0xfb, 0xca, 0xde, 0x37, 0x13})) }) @@ -578,48 +540,48 @@ var _ = Describe("Public Header", func() { It("writes a header with a 1-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen1, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xad})) }) It("writes a header with a 2-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen2, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xfb, 0xad})) }) It("writes a header with a 4-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0x13decafbad, PacketNumberLen: protocol.PacketNumberLen4, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x28, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0xde, 0xca, 0xfb, 0xad})) }) It("writes a header with a 6-byte packet number", func() { b := &bytes.Buffer{} - hdr := PublicHeader{ + hdr := Header{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 0xbe1337decafbad, PacketNumberLen: protocol.PacketNumberLen6, } - err := hdr.Write(b, version, protocol.PerspectiveServer) + err := hdr.writePublicHeader(b, protocol.PerspectiveServer, version) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal([]byte{0x38, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x13, 0x37, 0xde, 0xca, 0xfb, 0xad})) }) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index 9caa7fb4..eaa13210 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -8,14 +8,15 @@ import ( ) // ComposeVersionNegotiation composes a Version Negotiation Packet +// TODO(894): implement the IETF draft format of Version Negotiation Packets func ComposeVersionNegotiation(connectionID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { fullReply := &bytes.Buffer{} - responsePublicHeader := PublicHeader{ + ph := Header{ ConnectionID: connectionID, PacketNumber: 1, VersionFlag: true, } - err := responsePublicHeader.Write(fullReply, protocol.VersionWhatever, protocol.PerspectiveServer) + err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever) if err != nil { utils.Errorf("error composing version negotiation packet: %s", err.Error()) } diff --git a/packet_packer.go b/packet_packer.go index 57ff3532..5cc2e08a 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -54,10 +54,10 @@ func newPacketPacker(connectionID protocol.ConnectionID, func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) { frames := []wire.Frame{ccf} encLevel, sealer := p.cryptoSetup.GetSealer() - ph := p.getPublicHeader(encLevel) - raw, err := p.writeAndSealPacket(ph, frames, sealer) + header := p.getHeader(encLevel) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + number: header.PacketNumber, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -69,18 +69,18 @@ func (p *packetPacker) PackAckPacket() (*packedPacket, error) { return nil, errors.New("packet packer BUG: no ack frame queued") } encLevel, sealer := p.cryptoSetup.GetSealer() - ph := p.getPublicHeader(encLevel) + header := p.getHeader(encLevel) frames := []wire.Frame{p.ackFrame} if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = ph.PacketNumber - p.stopWaiting.PacketNumberLen = ph.PacketNumberLen + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames = append(frames, p.stopWaiting) p.stopWaiting = nil } p.ackFrame = nil - raw, err := p.writeAndSealPacket(ph, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + number: header.PacketNumber, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -99,14 +99,14 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (* if p.stopWaiting == nil { return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") } - ph := p.getPublicHeader(packet.EncryptionLevel) - p.stopWaiting.PacketNumber = ph.PacketNumber - p.stopWaiting.PacketNumberLen = ph.PacketNumberLen + header := p.getHeader(packet.EncryptionLevel) + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen frames := append([]wire.Frame{p.stopWaiting}, packet.Frames...) p.stopWaiting = nil - raw, err := p.writeAndSealPacket(ph, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - number: ph.PacketNumber, + number: header.PacketNumber, raw: raw, frames: frames, encryptionLevel: packet.EncryptionLevel, @@ -122,17 +122,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealer() - publicHeader := p.getPublicHeader(encLevel) - publicHeaderLength, err := publicHeader.GetLength(p.perspective) + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) if err != nil { return nil, err } if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = publicHeader.PacketNumber - p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen + p.stopWaiting.PacketNumber = header.PacketNumber + p.stopWaiting.PacketNumberLen = header.PacketNumberLen } - maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - publicHeaderLength + maxSize := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) if err != nil { return nil, err @@ -149,12 +149,12 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { p.stopWaiting = nil p.ackFrame = nil - raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer) + raw, err := p.writeAndSealPacket(header, payloadFrames, sealer) if err != nil { return nil, err } return &packedPacket{ - number: publicHeader.PacketNumber, + number: header.PacketNumber, raw: raw, frames: payloadFrames, encryptionLevel: encLevel, @@ -163,19 +163,19 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream() - publicHeader := p.getPublicHeader(encLevel) - publicHeaderLength, err := publicHeader.GetLength(p.perspective) + header := p.getHeader(encLevel) + headerLength, err := header.GetLength(p.perspective, p.version) if err != nil { return nil, err } - maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength + maxLen := protocol.MaxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength frames := []wire.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} - raw, err := p.writeAndSealPacket(publicHeader, frames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err } return &packedPacket{ - number: publicHeader.PacketNumber, + number: header.PacketNumber, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -262,38 +262,50 @@ func (p *packetPacker) QueueControlFrame(frame wire.Frame) { } } -func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *wire.PublicHeader { +func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { pnum := p.packetNumberGenerator.Peek() packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) - publicHeader := &wire.PublicHeader{ + + var isLongHeader bool + if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { + // TODO: set the Long Header type + packetNumberLen = protocol.PacketNumberLen4 + isLongHeader = true + } + + header := &wire.Header{ ConnectionID: p.connectionID, PacketNumber: pnum, PacketNumberLen: packetNumberLen, + IsLongHeader: isLongHeader, } if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { - publicHeader.OmitConnectionID = true + header.OmitConnectionID = true } - if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { - publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce() + if !p.version.UsesTLS() { + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { + header.DiversificationNonce = p.cryptoSetup.DiversificationNonce() + } + if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { + header.VersionFlag = true + header.Version = p.version + } + } else if encLevel != protocol.EncryptionForwardSecure { + header.Version = p.version } - if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { - publicHeader.VersionFlag = true - publicHeader.VersionNumber = p.version - } - - return publicHeader + return header } func (p *packetPacker) writeAndSealPacket( - publicHeader *wire.PublicHeader, + header *wire.Header, payloadFrames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) - if err := publicHeader.Write(buffer, p.version, p.perspective); err != nil { + if err := header.Write(buffer, p.perspective, p.version); err != nil { return nil, err } payloadStartIndex := buffer.Len() @@ -308,11 +320,11 @@ func (p *packetPacker) writeAndSealPacket( } raw = raw[0:buffer.Len()] - _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex]) + _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex]) raw = raw[0 : buffer.Len()+sealer.Overhead()] num := p.packetNumberGenerator.Pop() - if num != publicHeader.PacketNumber { + if num != header.PacketNumber { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } diff --git a/packet_packer_test.go b/packet_packer_test.go index 2ae8aa9f..694b0fb3 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -109,33 +109,105 @@ var _ = Describe("Packet packer", func() { Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) }) - Context("diversificaton nonces", func() { - var nonce []byte + Context("generating a packet header", func() { + const ( + versionPublicHeader = protocol.Version39 // a QUIC version that uses the Public Header format + versionIETFHeader = protocol.VersionTLS // a QUIC version taht uses the IETF Header format + ) - BeforeEach(func() { - nonce = bytes.Repeat([]byte{'e'}, 32) - packer.cryptoSetup.(*mockCryptoSetup).divNonce = nonce + Context("Public Header (for gQUIC)", func() { + BeforeEach(func() { + packer.version = versionPublicHeader + }) + + It("it omits the connection ID for forward-secure packets", func() { + ph := packer.getHeader(protocol.EncryptionForwardSecure) + Expect(ph.OmitConnectionID).To(BeFalse()) + packer.SetOmitConnectionID() + ph = packer.getHeader(protocol.EncryptionForwardSecure) + Expect(ph.OmitConnectionID).To(BeTrue()) + }) + + It("doesn't omit the connection ID for non-forward-secure packets", func() { + packer.SetOmitConnectionID() + ph := packer.getHeader(protocol.EncryptionSecure) + Expect(ph.OmitConnectionID).To(BeFalse()) + }) + + It("adds the Version Flag to the Public Header before the crypto handshake is finished", func() { + packer.perspective = protocol.PerspectiveClient + ph := packer.getHeader(protocol.EncryptionSecure) + Expect(ph.VersionFlag).To(BeTrue()) + }) + + It("doesn't add the Version Flag to the Public Header for forward-secure packets", func() { + packer.perspective = protocol.PerspectiveClient + ph := packer.getHeader(protocol.EncryptionForwardSecure) + Expect(ph.VersionFlag).To(BeFalse()) + }) + + Context("diversificaton nonces", func() { + var nonce []byte + + BeforeEach(func() { + nonce = bytes.Repeat([]byte{'e'}, 32) + packer.cryptoSetup.(*mockCryptoSetup).divNonce = nonce + }) + + It("doesn't include a div nonce, when sending a packet with initial encryption", func() { + ph := packer.getHeader(protocol.EncryptionUnencrypted) + Expect(ph.DiversificationNonce).To(BeEmpty()) + }) + + It("includes a div nonce, when sending a packet with secure encryption", func() { + ph := packer.getHeader(protocol.EncryptionSecure) + Expect(ph.DiversificationNonce).To(Equal(nonce)) + }) + + It("doesn't include a div nonce, when sending a packet with forward-secure encryption", func() { + ph := packer.getHeader(protocol.EncryptionForwardSecure) + Expect(ph.DiversificationNonce).To(BeEmpty()) + }) + + It("doesn't send a div nonce as a client", func() { + packer.perspective = protocol.PerspectiveClient + ph := packer.getHeader(protocol.EncryptionSecure) + Expect(ph.DiversificationNonce).To(BeEmpty()) + }) + }) }) - It("doesn't include a div nonce, when sending a packet with initial encryption", func() { - ph := packer.getPublicHeader(protocol.EncryptionUnencrypted) - Expect(ph.DiversificationNonce).To(BeEmpty()) - }) + Context("Header (for IETF draft QUIC)", func() { + BeforeEach(func() { + packer.version = versionIETFHeader + }) - It("includes a div nonce, when sending a packet with secure encryption", func() { - ph := packer.getPublicHeader(protocol.EncryptionSecure) - Expect(ph.DiversificationNonce).To(Equal(nonce)) - }) + It("uses the Long Header format for non-forward-secure packets", func() { + h := packer.getHeader(protocol.EncryptionSecure) + Expect(h.IsLongHeader).To(BeTrue()) + Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(h.Version).To(Equal(versionIETFHeader)) + }) - It("doesn't include a div nonce, when sending a packet with forward-secure encryption", func() { - ph := packer.getPublicHeader(protocol.EncryptionForwardSecure) - Expect(ph.DiversificationNonce).To(BeEmpty()) - }) + It("uses the Short Header format for forward-secure packets", func() { + h := packer.getHeader(protocol.EncryptionForwardSecure) + Expect(h.IsLongHeader).To(BeFalse()) + Expect(h.PacketNumberLen).To(BeNumerically(">", 0)) + }) - It("doesn't send a div nonce as a client", func() { - packer.perspective = protocol.PerspectiveClient - ph := packer.getPublicHeader(protocol.EncryptionSecure) - Expect(ph.DiversificationNonce).To(BeEmpty()) + It("it omits the connection ID for forward-secure packets", func() { + h := packer.getHeader(protocol.EncryptionForwardSecure) + Expect(h.OmitConnectionID).To(BeFalse()) + packer.SetOmitConnectionID() + h = packer.getHeader(protocol.EncryptionForwardSecure) + Expect(h.OmitConnectionID).To(BeTrue()) + }) + + It("doesn't omit the connection ID for non-forward-secure packets", func() { + packer.SetOmitConnectionID() + h := packer.getHeader(protocol.EncryptionSecure) + Expect(h.OmitConnectionID).To(BeFalse()) + }) }) }) @@ -226,48 +298,6 @@ var _ = Describe("Packet packer", func() { Expect(p).ToNot(BeNil()) }) - It("it omits the connection ID for forward-secure packets", func() { - ph := packer.getPublicHeader(protocol.EncryptionForwardSecure) - Expect(ph.OmitConnectionID).To(BeFalse()) - packer.SetOmitConnectionID() - ph = packer.getPublicHeader(protocol.EncryptionForwardSecure) - Expect(ph.OmitConnectionID).To(BeTrue()) - }) - - It("doesn't omit the connection ID for non-forware-secure packets", func() { - packer.SetOmitConnectionID() - ph := packer.getPublicHeader(protocol.EncryptionSecure) - Expect(ph.OmitConnectionID).To(BeFalse()) - }) - - It("adds the version flag to the public header before the crypto handshake is finished", func() { - packer.perspective = protocol.PerspectiveClient - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure - packer.controlFrames = []wire.Frame{&wire.BlockedFrame{StreamID: 0}} - packer.connectionID = 0x1337 - packer.version = 123 - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - hdr, err := wire.ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient, packer.version) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.VersionFlag).To(BeTrue()) - Expect(hdr.VersionNumber).To(Equal(packer.version)) - }) - - It("doesn't add the version flag to the public header for forward-secure packets", func() { - packer.perspective = protocol.PerspectiveClient - packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure - packer.controlFrames = []wire.Frame{&wire.BlockedFrame{StreamID: 0}} - packer.connectionID = 0x1337 - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - hdr, err := wire.ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient, packer.version) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.VersionFlag).To(BeFalse()) - }) - It("packs many control frames into 1 packets", func() { f := &wire.AckFrame{LargestAcked: 1} b := &bytes.Buffer{} diff --git a/packet_unpacker.go b/packet_unpacker.go index 055f6754..bf1e0cfe 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -24,10 +24,10 @@ type packetUnpacker struct { aead quicAEAD } -func (u *packetUnpacker) Unpack(publicHeaderBinary []byte, hdr *wire.PublicHeader, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { buf := getPacketBuffer() defer putPacketBuffer(buf) - decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, publicHeaderBinary) + decrypted, encryptionLevel, err := u.aead.Open(buf, data, hdr.PacketNumber, headerBinary) if err != nil { // Wrap err in quicError so that public reset is sent by session return nil, qerr.Error(qerr.DecryptionFailure, err.Error()) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 238e0a3f..e7150ac0 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -31,14 +31,14 @@ var _ quicAEAD = &mockAEAD{} var _ = Describe("Packet unpacker", func() { var ( unpacker *packetUnpacker - hdr *wire.PublicHeader + hdr *wire.Header hdrBin []byte data []byte buf *bytes.Buffer ) BeforeEach(func() { - hdr = &wire.PublicHeader{ + hdr = &wire.Header{ PacketNumber: 10, PacketNumberLen: 1, } diff --git a/server.go b/server.go index b4be3cb3..f8ea5946 100644 --- a/server.go +++ b/server.go @@ -214,7 +214,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet rcvTime := time.Now() r := bytes.NewReader(packet) - connID, err := wire.PeekConnectionID(r, protocol.PerspectiveClient) + connID, err := wire.PeekConnectionID(r) if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } @@ -233,7 +233,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet version = session.GetVersion() } - hdr, err := wire.ParsePublicHeader(r, protocol.PerspectiveClient, version) + hdr, err := wire.ParseHeader(r, protocol.PerspectiveClient, version) if err == wire.ErrPacketWithUnknownVersion { _, err = pconn.WriteTo(wire.WritePublicReset(connID, 0, 0), remoteAddr) return err @@ -262,23 +262,24 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet // a session is only created once the client sent a supported version // if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated // it is safe to drop it - if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) { + if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { return nil } // Send Version Negotiation Packet if the client is speaking a different protocol version - if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) { + if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { // drop packets that are too small to be valid first packets if len(packet) < protocol.ClientHelloMinimumSize+len(hdr.Raw) { return errors.New("dropping small packet with unknown version") } - utils.Infof("Client offered version %s, sending VersionNegotiationPacket", hdr.VersionNumber) + // TODO(894): send a IETF draft style Version Negotiation Packets + utils.Infof("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) _, err = pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) return err } if !ok { - version := hdr.VersionNumber + version := hdr.Version if !protocol.IsSupportedVersion(s.config.Versions, version) { return errors.New("Server BUG: negotiated version not supported") } @@ -320,10 +321,10 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet }() } session.handlePacket(&receivedPacket{ - remoteAddr: remoteAddr, - publicHeader: hdr, - data: packet[len(packet)-r.Len():], - rcvTime: rcvTime, + remoteAddr: remoteAddr, + header: hdr, + data: packet[len(packet)-r.Len():], + rcvTime: rcvTime, }) return nil } diff --git a/server_test.go b/server_test.go index 54b516b1..abd60a20 100644 --- a/server_test.go +++ b/server_test.go @@ -327,13 +327,13 @@ var _ = Describe("Server", func() { It("doesn't respond with a version negotiation packet if the first packet is too small", func() { b := &bytes.Buffer{} - hdr := wire.PublicHeader{ + hdr := wire.Header{ VersionFlag: true, ConnectionID: 0x1337, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, } - hdr.Write(b, 13 /* not a valid QUIC version */, protocol.PerspectiveClient) + hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize-1)) // this packet is 1 byte too small err := serv.handlePacket(conn, udpAddr, b.Bytes()) Expect(err).To(MatchError("dropping small packet with unknown version")) @@ -398,13 +398,13 @@ var _ = Describe("Server", func() { It("setups and responds with version negotiation", func() { config.Versions = []protocol.VersionNumber{99} b := &bytes.Buffer{} - hdr := wire.PublicHeader{ + hdr := wire.Header{ VersionFlag: true, ConnectionID: 0x1337, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, } - hdr.Write(b, 13 /* not a valid QUIC version */, protocol.PerspectiveClient) + hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO conn.dataToRead = b.Bytes() conn.dataReadFrom = udpAddr diff --git a/session.go b/session.go index ca7769ef..11998fe5 100644 --- a/session.go +++ b/session.go @@ -20,14 +20,14 @@ import ( ) type unpacker interface { - Unpack(publicHeaderBinary []byte, hdr *wire.PublicHeader, data []byte) (*unpackedPacket, error) + Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) } type receivedPacket struct { - remoteAddr net.Addr - publicHeader *wire.PublicHeader - data []byte - rcvTime time.Time + remoteAddr net.Addr + header *wire.Header + data []byte + rcvTime time.Time } var ( @@ -322,7 +322,7 @@ runLoop: } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. - putPacketBuffer(p.publicHeader.Raw) + putPacketBuffer(p.header.Raw) case p := <-s.paramsChan: s.processTransportParameters(&p) case l, ok := <-aeadChanged: @@ -410,7 +410,7 @@ func (s *session) maybeResetTimer() { func (s *session) handlePacketImpl(p *receivedPacket) error { if s.perspective == protocol.PerspectiveClient { - diversificationNonce := p.publicHeader.DiversificationNonce + diversificationNonce := p.header.DiversificationNonce if len(diversificationNonce) > 0 { s.cryptoSetup.SetDiversificationNonce(diversificationNonce) } @@ -423,7 +423,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.lastNetworkActivityTime = p.rcvTime s.keepAlivePingSent = false - hdr := p.publicHeader + hdr := p.header data := p.data // Calculate packet number @@ -847,7 +847,7 @@ func (s *session) scheduleSending() { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.handshakeComplete { - utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.publicHeader, len(p.data)) + utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { @@ -856,10 +856,10 @@ func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { s.receivedTooManyUndecrytablePacketsTime = time.Now() s.maybeResetTimer() } - utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.publicHeader.PacketNumber) + utils.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) return } - utils.Infof("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber) + utils.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) s.undecryptablePackets = append(s.undecryptablePackets, p) } diff --git a/session_test.go b/session_test.go index 45b8d1a4..ba051538 100644 --- a/session_test.go +++ b/session_test.go @@ -61,7 +61,7 @@ type mockUnpacker struct { unpackErr error } -func (m *mockUnpacker) Unpack(publicHeaderBinary []byte, hdr *wire.PublicHeader, data []byte) (*unpackedPacket, error) { +func (m *mockUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { if m.unpackErr != nil { return nil, m.unpackErr } @@ -688,16 +688,16 @@ var _ = Describe("Session", func() { }) Context("receiving packets", func() { - var hdr *wire.PublicHeader + var hdr *wire.Header BeforeEach(func() { sess.unpacker = &mockUnpacker{} - hdr = &wire.PublicHeader{PacketNumberLen: protocol.PacketNumberLen6} + hdr = &wire.Header{PacketNumberLen: protocol.PacketNumberLen6} }) It("sets the {last,largest}RcvdPacketNumber", func() { hdr.PacketNumber = 5 - err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) @@ -711,7 +711,7 @@ var _ = Describe("Session", func() { runErr = sess.run() }() sess.unpacker.(*mockUnpacker).unpackErr = testErr - sess.handlePacket(&receivedPacket{publicHeader: hdr}) + sess.handlePacket(&receivedPacket{header: hdr}) Eventually(func() error { return runErr }).Should(MatchError(testErr)) Expect(sess.Context().Done()).To(BeClosed()) close(done) @@ -719,12 +719,12 @@ var _ = Describe("Session", func() { It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { hdr.PacketNumber = 5 - err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) hdr.PacketNumber = 3 - err = sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err = sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3))) Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) @@ -732,9 +732,9 @@ var _ = Describe("Session", func() { It("handles duplicate packets", func() { hdr.PacketNumber = 5 - err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) - err = sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err = sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) }) @@ -743,8 +743,8 @@ var _ = Describe("Session", func() { remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(sess.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ - remoteAddr: remoteIP, - publicHeader: &wire.PublicHeader{PacketNumber: 1337}, + remoteAddr: remoteIP, + header: &wire.Header{PacketNumber: 1337}, } err := sess.handlePacketImpl(&p) Expect(err).ToNot(HaveOccurred()) @@ -759,8 +759,8 @@ var _ = Describe("Session", func() { sess.unpacker = &packetUnpacker{} sess.unpacker.(*packetUnpacker).aead = &mockAEAD{} p := receivedPacket{ - remoteAddr: attackerIP, - publicHeader: &wire.PublicHeader{PacketNumber: 1337}, + remoteAddr: attackerIP, + header: &wire.Header{PacketNumber: 1337}, } err := sess.handlePacketImpl(&p) quicErr := err.(*qerr.QuicError) @@ -773,8 +773,8 @@ var _ = Describe("Session", func() { remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(sess.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ - remoteAddr: remoteIP, - publicHeader: &wire.PublicHeader{PacketNumber: 1337}, + remoteAddr: remoteIP, + header: &wire.Header{PacketNumber: 1337}, } sess.unpacker.(*mockUnpacker).unpackErr = testErr err := sess.handlePacketImpl(&p) @@ -1140,13 +1140,13 @@ var _ = Describe("Session", func() { // this completely fills up the undecryptable packets queue and triggers the public reset timer sendUndecryptablePackets := func() { for i := 0; i < protocol.MaxUndecryptablePackets+1; i++ { - hdr := &wire.PublicHeader{ + hdr := &wire.Header{ PacketNumber: protocol.PacketNumber(i + 1), } sess.handlePacket(&receivedPacket{ - publicHeader: hdr, - remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}, - data: []byte("foobar"), + header: hdr, + remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}, + data: []byte("foobar"), }) } } @@ -1175,7 +1175,7 @@ var _ = Describe("Session", func() { sendUndecryptablePackets() Eventually(func() []*receivedPacket { return sess.undecryptablePackets }).Should(HaveLen(protocol.MaxUndecryptablePackets)) // check that old packets are kept, and the new packets are dropped - Expect(sess.undecryptablePackets[0].publicHeader.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) sess.Close(nil) }) @@ -1213,7 +1213,7 @@ var _ = Describe("Session", func() { It("unqueues undecryptable packets for later decryption", func() { sess.undecryptablePackets = []*receivedPacket{{ - publicHeader: &wire.PublicHeader{PacketNumber: protocol.PacketNumber(42)}, + header: &wire.Header{PacketNumber: protocol.PacketNumber(42)}, }} Expect(sess.receivedPackets).NotTo(Receive()) sess.tryDecryptingQueuedPackets() @@ -1549,10 +1549,10 @@ var _ = Describe("Client Session", func() { }) Context("receiving packets", func() { - var hdr *wire.PublicHeader + var hdr *wire.Header BeforeEach(func() { - hdr = &wire.PublicHeader{PacketNumberLen: protocol.PacketNumberLen6} + hdr = &wire.Header{PacketNumberLen: protocol.PacketNumberLen6} sess.unpacker = &mockUnpacker{} }) @@ -1560,7 +1560,7 @@ var _ = Describe("Client Session", func() { go sess.run() hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") - err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{header: hdr}) Expect(err).ToNot(HaveOccurred()) Eventually(func() []byte { return cryptoSetup.divNonce }).Should(Equal(hdr.DiversificationNonce)) Expect(sess.Close(nil)).To(Succeed())