From dd0daaaf1edce5dd8b45482cefe109adced7a80e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 3 Aug 2017 15:13:23 +0700 Subject: [PATCH] implement version-dependent parsing of the Public Header --- client.go | 2 +- integrationtests/gquic/drop_test.go | 2 +- integrationtests/gquic/random_rtt_test.go | 2 +- integrationtests/gquic/rtt_test.go | 2 +- integrationtests/self/handshake_rtt_test.go | 2 +- integrationtests/tools/proxy/proxy.go | 7 +- integrationtests/tools/proxy/proxy_test.go | 8 +- internal/utils/byteorder.go | 11 + internal/utils/byteorder_test.go | 18 ++ packet_packer_test.go | 4 +- protocol/version.go | 2 + public_header.go | 20 +- public_header_test.go | 295 ++++++++++++++------ server.go | 40 ++- server_test.go | 11 +- session.go | 4 + session_test.go | 5 + 17 files changed, 311 insertions(+), 124 deletions(-) create mode 100644 internal/utils/byteorder_test.go diff --git a/client.go b/client.go index 2e18de80..f076745c 100644 --- a/client.go +++ b/client.go @@ -219,7 +219,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) { rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer, c.version) if err != nil { utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error()) // drop this packet if we can't parse the Public Header diff --git a/integrationtests/gquic/drop_test.go b/integrationtests/gquic/drop_test.go index a3f710b0..6ee43de3 100644 --- a/integrationtests/gquic/drop_test.go +++ b/integrationtests/gquic/drop_test.go @@ -21,7 +21,7 @@ var _ = Describe("Drop tests", func() { runDropTest := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) { var err error - proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{ + proxy, err = quicproxy.NewQuicProxy("localhost:0", version, quicproxy.Opts{ RemoteAddr: "localhost:" + testserver.Port(), DropPacket: dropCallback, }) diff --git a/integrationtests/gquic/random_rtt_test.go b/integrationtests/gquic/random_rtt_test.go index e458272e..2a525d57 100644 --- a/integrationtests/gquic/random_rtt_test.go +++ b/integrationtests/gquic/random_rtt_test.go @@ -56,7 +56,7 @@ var _ = Describe("Random RTT", func() { runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) { rand.Seed(time.Now().UnixNano()) var err error - proxy, err = quicproxy.NewQuicProxy("localhost:", quicproxy.Opts{ + proxy, err = quicproxy.NewQuicProxy("localhost:", version, quicproxy.Opts{ RemoteAddr: "localhost:" + testserver.Port(), DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return getRandomDuration(minRtt, maxRtt) diff --git a/integrationtests/gquic/rtt_test.go b/integrationtests/gquic/rtt_test.go index 333d07bb..1cce41ad 100644 --- a/integrationtests/gquic/rtt_test.go +++ b/integrationtests/gquic/rtt_test.go @@ -22,7 +22,7 @@ var _ = Describe("non-zero RTT", func() { runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) { var err error - proxy, err = quicproxy.NewQuicProxy("localhost:", quicproxy.Opts{ + proxy, err = quicproxy.NewQuicProxy("localhost:", version, quicproxy.Opts{ RemoteAddr: "localhost:" + testserver.Port(), DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 66122ab0..47ff6dfb 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -42,7 +42,7 @@ var _ = Describe("Handshake RTT tests", func() { server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) // start the proxy - proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{ + proxy, err = quicproxy.NewQuicProxy("localhost:0", protocol.VersionWhatever, quicproxy.Opts{ RemoteAddr: server.Addr().String(), DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 }, }) diff --git a/integrationtests/tools/proxy/proxy.go b/integrationtests/tools/proxy/proxy.go index 70be8889..bc72411d 100644 --- a/integrationtests/tools/proxy/proxy.go +++ b/integrationtests/tools/proxy/proxy.go @@ -62,6 +62,8 @@ type Opts struct { type QuicProxy struct { mutex sync.Mutex + version protocol.VersionNumber + conn *net.UDPConn serverAddr *net.UDPAddr @@ -73,7 +75,7 @@ type QuicProxy struct { } // NewQuicProxy creates a new UDP proxy -func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) { +func NewQuicProxy(local string, version protocol.VersionNumber, opts Opts) (*QuicProxy, error) { laddr, err := net.ResolveUDPAddr("udp", local) if err != nil { return nil, err @@ -103,6 +105,7 @@ func NewQuicProxy(local string, opts Opts) (*QuicProxy, error) { serverAddr: raddr, dropPacket: packetDropper, delayPacket: packetDelayer, + version: version, } go p.runProxy() @@ -162,7 +165,7 @@ func (p *QuicProxy) runProxy() error { atomic.AddUint64(&conn.incomingPacketCounter, 1) r := bytes.NewReader(raw) - hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient) + hdr, err := quic.ParsePublicHeader(r, protocol.PerspectiveClient, protocol.VersionWhatever) if err != nil { return err } diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index a4c1f53d..565a8911 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -40,7 +40,7 @@ var _ = Describe("QUIC Proxy", func() { Context("Proxy setup and teardown", func() { It("sets up the UDPProxy", func() { - proxy, err := NewQuicProxy("localhost:0", Opts{RemoteAddr: serverAddr}) + proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, Opts{RemoteAddr: serverAddr}) Expect(err).ToNot(HaveOccurred()) Expect(proxy.clientDict).To(HaveLen(0)) @@ -53,7 +53,7 @@ var _ = Describe("QUIC Proxy", func() { }) It("stops the UDPProxy", func() { - proxy, err := NewQuicProxy("localhost:0", Opts{RemoteAddr: serverAddr}) + proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, Opts{RemoteAddr: serverAddr}) Expect(err).ToNot(HaveOccurred()) port := proxy.LocalPort() err = proxy.Close() @@ -71,7 +71,7 @@ var _ = Describe("QUIC Proxy", func() { }) It("has the correct LocalAddr and LocalPort", func() { - proxy, err := NewQuicProxy("localhost:0", Opts{RemoteAddr: serverAddr}) + proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, Opts{RemoteAddr: serverAddr}) Expect(err).ToNot(HaveOccurred()) Expect(proxy.LocalAddr().String()).To(Equal("127.0.0.1:" + strconv.Itoa(proxy.LocalPort()))) @@ -92,7 +92,7 @@ var _ = Describe("QUIC Proxy", func() { startProxy := func(opts Opts) { var err error - proxy, err = NewQuicProxy("localhost:0", opts) + proxy, err = NewQuicProxy("localhost:0", protocol.VersionWhatever, opts) Expect(err).ToNot(HaveOccurred()) clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr)) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/utils/byteorder.go b/internal/utils/byteorder.go index b45800a3..bfba0738 100644 --- a/internal/utils/byteorder.go +++ b/internal/utils/byteorder.go @@ -3,6 +3,8 @@ package utils import ( "bytes" "io" + + "github.com/lucas-clemente/quic-go/protocol" ) // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. @@ -23,3 +25,12 @@ type ByteOrder interface { ReadUfloat16(io.ByteReader) (uint64, error) WriteUfloat16(*bytes.Buffer, uint64) } + +// GetByteOrder gets the ByteOrder (little endian or big endian) used to represent values on the wire +// from QUIC 39, values are encoded in big endian, before that in little endian +func GetByteOrder(v protocol.VersionNumber) ByteOrder { + if v < protocol.Version39 { + return LittleEndian + } + return BigEndian +} diff --git a/internal/utils/byteorder_test.go b/internal/utils/byteorder_test.go new file mode 100644 index 00000000..a1ddb8b0 --- /dev/null +++ b/internal/utils/byteorder_test.go @@ -0,0 +1,18 @@ +package utils + +import ( + "github.com/lucas-clemente/quic-go/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Byte Order", func() { + It("says little Little Endian before QUIC 39", func() { + Expect(GetByteOrder(protocol.Version36)).To(Equal(LittleEndian)) + Expect(GetByteOrder(protocol.Version37)).To(Equal(LittleEndian)) + }) + + It("says little Little Endian for QUIC 39", func() { + Expect(GetByteOrder(protocol.Version39)).To(Equal(BigEndian)) + }) +}) diff --git a/packet_packer_test.go b/packet_packer_test.go index 8cf0de4c..e7e1226b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -238,7 +238,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient) + hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient, packer.version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.VersionNumber).To(Equal(packer.version)) @@ -252,7 +252,7 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient) + hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient, packer.version) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeFalse()) }) diff --git a/protocol/version.go b/protocol/version.go index c250cab6..1eb0ffe3 100644 --- a/protocol/version.go +++ b/protocol/version.go @@ -9,8 +9,10 @@ const ( Version36 Version37 Version38 + Version39 VersionWhatever VersionNumber = 0 // for when the version doesn't matter VersionUnsupported VersionNumber = -1 + VersionUnknown VersionNumber = -2 ) // SupportedVersions lists the versions that the server supports diff --git a/public_header.go b/public_header.go index 09c5dac6..8d440b80 100644 --- a/public_header.go +++ b/public_header.go @@ -16,6 +16,8 @@ var ( errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") + // this can happen when the server is restarted. The client will send a packet without a version number + errPacketWithUnknownVersion = errors.New("PublicHeader: Received a packet without version number, that we don't know the version for") ) // The PublicHeader of a QUIC packet. Warning: This struct should not be considered stable and will change soon. @@ -74,6 +76,7 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe b.WriteByte(publicFlagByte) if !h.TruncateConnectionID { + // always read the connection ID in little endian utils.LittleEndian.WriteUint64(b, uint64(h.ConnectionID)) } @@ -98,11 +101,11 @@ func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pe case protocol.PacketNumberLen1: b.WriteByte(uint8(h.PacketNumber)) case protocol.PacketNumberLen2: - utils.LittleEndian.WriteUint16(b, uint16(h.PacketNumber)) + utils.GetByteOrder(version).WriteUint16(b, uint16(h.PacketNumber)) case protocol.PacketNumberLen4: - utils.LittleEndian.WriteUint32(b, uint32(h.PacketNumber)) + utils.GetByteOrder(version).WriteUint32(b, uint32(h.PacketNumber)) case protocol.PacketNumberLen6: - utils.LittleEndian.WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) + utils.GetByteOrder(version).WriteUint48(b, uint64(h.PacketNumber)&(1<<48-1)) default: return errPacketNumberLenNotSet } @@ -142,7 +145,7 @@ func PeekConnectionID(b *bytes.Reader, packetSentBy protocol.Perspective) (proto // ParsePublicHeader parses a QUIC packet's public header. // The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient. // Warning: This API should not be considered stable and will change soon. -func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) { +func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective, version protocol.VersionNumber) (*PublicHeader, error) { header := &PublicHeader{} // First byte @@ -150,8 +153,11 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub if err != nil { return nil, err } - header.VersionFlag = publicFlagByte&0x01 > 0 header.ResetFlag = publicFlagByte&0x02 > 0 + header.VersionFlag = publicFlagByte&0x01 > 0 + if version == protocol.VersionUnknown && !(header.VersionFlag || header.ResetFlag) { + return nil, errPacketWithUnknownVersion + } // TODO: activate this check once Chrome sends the correct value // see https://github.com/lucas-clemente/quic-go/issues/232 @@ -180,6 +186,7 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub // Connection ID if !header.TruncateConnectionID { var connID uint64 + // always write the connection ID in little endian connID, err = utils.LittleEndian.ReadUint64(b) if err != nil { return nil, err @@ -228,11 +235,12 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub return nil, err } header.VersionNumber = protocol.VersionTagToNumber(versionTag) + version = header.VersionNumber } // Packet number if header.hasPacketNumber(packetSentBy) { - packetNumber, err := utils.LittleEndian.ReadUintN(b, uint8(header.PacketNumberLen)) + packetNumber, err := utils.GetByteOrder(version).ReadUintN(b, uint8(header.PacketNumberLen)) if err != nil { return nil, err } diff --git a/public_header_test.go b/public_header_test.go index 67510865..e9b9ce40 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -52,7 +53,7 @@ var _ = Describe("Public Header", func() { Context("when parsing", func() { It("accepts a sample client header", func() { b := bytes.NewReader([]byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x51, 0x30, 0x33, 0x34, 0x01}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.ResetFlag).To(BeFalse()) @@ -65,13 +66,13 @@ var _ = Describe("Public Header", func() { It("does not accept truncated connection ID as a server", func() { b := bytes.NewReader([]byte{0x00, 0x01}) - _, err := ParsePublicHeader(b, protocol.PerspectiveClient) + _, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).To(MatchError(errReceivedTruncatedConnectionID)) }) It("accepts a truncated connection ID as a client", func() { b := bytes.NewReader([]byte{0x00, 0x01}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(hdr.TruncateConnectionID).To(BeTrue()) Expect(hdr.ConnectionID).To(BeZero()) @@ -80,13 +81,13 @@ var _ = Describe("Public Header", func() { It("rejects 0 as a connection ID", func() { b := bytes.NewReader([]byte{0x09, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x51, 0x30, 0x33, 0x30, 0x01}) - _, err := ParsePublicHeader(b, protocol.PerspectiveClient) + _, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionUnknown) Expect(err).To(MatchError(errInvalidConnectionID)) }) It("reads a PublicReset packet", func() { b := bytes.NewReader([]byte{0xa, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.ConnectionID).ToNot(BeZero()) @@ -94,7 +95,7 @@ var _ = Describe("Public Header", func() { It("parses a public reset packet", func() { b := bytes.NewReader([]byte{0xa, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ResetFlag).To(BeTrue()) Expect(hdr.VersionFlag).To(BeFalse()) @@ -105,20 +106,25 @@ var _ = Describe("Public Header", func() { divNonce := []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} Expect(divNonce).To(HaveLen(32)) b := bytes.NewReader(append(append([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}, divNonce...), 0x37)) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) Expect(hdr.ConnectionID).To(Not(BeZero())) Expect(hdr.DiversificationNonce).To(Equal(divNonce)) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x37))) Expect(b.Len()).To(BeZero()) }) + It("returns an unknown version error when receiving a packet without a version for which the version is not given", func() { + b := bytes.NewReader([]byte{0x10, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0xef}) + _, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) + Expect(err).To(MatchError(errPacketWithUnknownVersion)) + }) + PIt("rejects diversification nonces sent by the client", func() { b := bytes.NewReader([]byte{0x0c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0x01, }) - _, err := ParsePublicHeader(b, protocol.PerspectiveClient) + _, err := ParsePublicHeader(b, protocol.PerspectiveClient, protocol.VersionWhatever) Expect(err).To(MatchError("diversification nonces should only be sent by servers")) }) @@ -131,7 +137,7 @@ var _ = Describe("Public Header", func() { It("parses version negotiation packets sent by the server", func() { b := bytes.NewReader(composeVersionNegotiation(0x1337, protocol.SupportedVersions)) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.VersionNumber).To(BeZero()) // unitialized @@ -141,7 +147,7 @@ var _ = Describe("Public Header", func() { It("parses a version negotiation packet that contains 0 versions", func() { b := bytes.NewReader([]byte{0x9, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.VersionNumber).To(BeZero()) // unitialized @@ -155,7 +161,7 @@ var _ = Describe("Public Header", func() { data = appendVersion(data, protocol.SupportedVersions[0]) data = appendVersion(data, 99) // unsupported version b := bytes.NewReader(data) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.SupportedVersions).To(Equal([]protocol.VersionNumber{1, protocol.SupportedVersions[0], 99})) @@ -166,42 +172,98 @@ var _ = Describe("Public Header", func() { data := composeVersionNegotiation(0x1337, protocol.SupportedVersions) data = append(data, []byte{0x13, 0x37}...) b := bytes.NewReader(data) - _, err := ParsePublicHeader(b, protocol.PerspectiveServer) + _, err := ParsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).To(MatchError(qerr.InvalidVersionNegotiationPacket)) }) }) Context("Packet Number lengths", func() { - It("accepts 1-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) - Expect(b.Len()).To(BeZero()) + Context("in little endian encoding", func() { + version := protocol.Version37 + + BeforeEach(func() { + Expect(utils.GetByteOrder(version)).To(Equal(utils.LittleEndian)) + }) + + It("accepts 1-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts 2-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x18, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde, 0xca}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xcade))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts 4-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x28, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xad, 0xfb, 0xca, 0xde}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts 6-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad4223))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen6)) + Expect(b.Len()).To(BeZero()) + }) }) - It("accepts 2-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x18, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde, 0xca}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xcade))) - Expect(b.Len()).To(BeZero()) - }) + Context("in big endian encoding", func() { + version := protocol.Version39 - It("accepts 4-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x28, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad))) - Expect(b.Len()).To(BeZero()) - }) + BeforeEach(func() { + Expect(utils.GetByteOrder(version)).To(Equal(utils.BigEndian)) + }) - It("accepts 6-byte packet numbers", func() { - b := bytes.NewReader([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) - hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad4223))) - Expect(b.Len()).To(BeZero()) + It("accepts 1-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xde))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen1)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts 2-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x18, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde, 0xca}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xdeca))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts 4-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x28, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xad, 0xfb, 0xca, 0xde}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0xadfbcade))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(b.Len()).To(BeZero()) + }) + + It("accepts 6-byte packet numbers", func() { + b := bytes.NewReader([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) + hdr, err := ParsePublicHeader(b, protocol.PerspectiveClient, version) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x2342adfbcade))) + Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen6)) + Expect(b.Len()).To(BeZero()) + }) }) }) }) @@ -286,11 +348,8 @@ var _ = Describe("Public Header", func() { It("throws an error if both Reset Flag and Version Flag are set", func() { b := &bytes.Buffer{} hdr := PublicHeader{ - VersionFlag: true, - ResetFlag: true, - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 2, - PacketNumberLen: protocol.PacketNumberLen6, + VersionFlag: true, + ResetFlag: true, } err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) Expect(err).To(MatchError(errResetAndVersionFlagSet)) @@ -470,52 +529,116 @@ var _ = Describe("Public Header", func() { Expect(err).To(MatchError(errPacketNumberLenNotSet)) }) - It("writes a header with a 1-byte packet number", func() { - b := &bytes.Buffer{} - hdr := PublicHeader{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xDECAFBAD, - PacketNumberLen: protocol.PacketNumberLen1, - } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD})) + Context("in little endian", func() { + version := protocol.Version37 + + BeforeEach(func() { + Expect(utils.GetByteOrder(version)).To(Equal(utils.LittleEndian)) + }) + + It("writes a header with a 1-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen1, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD})) + }) + + It("writes a header with a 2-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen2, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x18, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xad, 0xfb})) + }) + + It("writes a header with a 4-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0x13DECAFBAD, + PacketNumberLen: protocol.PacketNumberLen4, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x28, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xfb, 0xca, 0xde})) + }) + + It("writes a header with a 6-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xBE1337DECAFBAD, + PacketNumberLen: protocol.PacketNumberLen6, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xad, 0xfb, 0xca, 0xde, 0x37, 0x13})) + }) }) - It("writes a header with a 2-byte packet number", func() { - b := &bytes.Buffer{} - hdr := PublicHeader{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xDECAFBAD, - PacketNumberLen: protocol.PacketNumberLen2, - } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x18, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB})) - }) + Context("in big endian", func() { + version := protocol.Version39 - It("writes a header with a 4-byte packet number", func() { - b := &bytes.Buffer{} - hdr := PublicHeader{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0x13DECAFBAD, - PacketNumberLen: protocol.PacketNumberLen4, - } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x28, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB, 0xCA, 0xDE})) - }) + BeforeEach(func() { + Expect(utils.GetByteOrder(version)).To(Equal(utils.BigEndian)) + }) - It("writes a header with a 6-byte packet number", func() { - b := &bytes.Buffer{} - hdr := PublicHeader{ - ConnectionID: 0x4cfa9f9b668619f6, - PacketNumber: 0xBE1337DECAFBAD, - PacketNumberLen: protocol.PacketNumberLen6, - } - err := hdr.Write(b, protocol.VersionWhatever, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - Expect(b.Bytes()).To(Equal([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xAD, 0xFB, 0xCA, 0xDE, 0x37, 0x13})) + It("writes a header with a 1-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen1, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xad})) + }) + + It("writes a header with a 2-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen2, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x18, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xfb, 0xad})) + }) + + It("writes a header with a 4-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0x13decafbad, + PacketNumberLen: protocol.PacketNumberLen4, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x28, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0xde, 0xca, 0xfb, 0xad})) + }) + + It("writes a header with a 6-byte packet number", func() { + b := &bytes.Buffer{} + hdr := PublicHeader{ + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 0xbe1337decafbad, + PacketNumberLen: protocol.PacketNumberLen6, + } + err := hdr.Write(b, version, protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(b.Bytes()).To(Equal([]byte{0x38, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x13, 0x37, 0xde, 0xca, 0xfb, 0xad})) + }) }) }) }) diff --git a/server.go b/server.go index d1ba95eb..9ec3fc64 100644 --- a/server.go +++ b/server.go @@ -19,6 +19,7 @@ import ( type packetHandler interface { Session handlePacket(*receivedPacket) + GetVersion() protocol.VersionNumber run() error closeRemote(error) } @@ -205,16 +206,35 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet rcvTime := time.Now() r := bytes.NewReader(packet) - hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient) + connID, err := PeekConnectionID(r, protocol.PerspectiveClient) + if err != nil { + return qerr.Error(qerr.InvalidPacketHeader, err.Error()) + } + + s.sessionsMutex.RLock() + session, ok := s.sessions[connID] + s.sessionsMutex.RUnlock() + + if ok && session == nil { + // Late packet for closed session + return nil + } + + version := protocol.VersionUnknown + if ok { + version = session.GetVersion() + } + + hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient, version) + if err == errPacketWithUnknownVersion { + _, err = pconn.WriteTo(writePublicReset(connID, 0, 0), remoteAddr) + return err + } if err != nil { return qerr.Error(qerr.InvalidPacketHeader, err.Error()) } hdr.Raw = packet[:len(packet)-r.Len()] - s.sessionsMutex.RLock() - session, ok := s.sessions[hdr.ConnectionID] - s.sessionsMutex.RUnlock() - // ignore all Public Reset packets if hdr.ResetFlag { if ok { @@ -250,10 +270,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet } if !ok { - if !hdr.VersionFlag { - _, err = pconn.WriteTo(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) - return err - } version := hdr.VersionNumber if !protocol.IsSupportedVersion(s.config.Versions, version) { return errors.New("Server BUG: negotiated version not supported") @@ -273,7 +289,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return err } s.sessionsMutex.Lock() - s.sessions[hdr.ConnectionID] = session + s.sessions[connID] = session s.sessionsMutex.Unlock() go func() { @@ -295,10 +311,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet s.sessionQueue <- session }() } - if session == nil { - // Late packet for closed session - return nil - } session.handlePacket(&receivedPacket{ remoteAddr: remoteAddr, publicHeader: hdr, diff --git a/server_test.go b/server_test.go index a74931ef..3bb85001 100644 --- a/server_test.go +++ b/server_test.go @@ -59,11 +59,12 @@ func (s *mockSession) closeRemote(e error) { func (s *mockSession) OpenStream() (Stream, error) { return &stream{streamID: 1337}, nil } -func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } -func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } -func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } -func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } -func (*mockSession) Context() context.Context { panic("not implemented") } +func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } +func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") } +func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") } +func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") } +func (*mockSession) Context() context.Context { panic("not implemented") } +func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } var _ Session = &mockSession{} var _ NonFWSession = &mockSession{} diff --git a/session.go b/session.go index 88d8ba56..b299beac 100644 --- a/session.go +++ b/session.go @@ -841,3 +841,7 @@ func (s *session) LocalAddr() net.Addr { func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } + +func (s *session) GetVersion() protocol.VersionNumber { + return s.version +} diff --git a/session_test.go b/session_test.go index a6644c1f..3556c54d 100644 --- a/session_test.go +++ b/session_test.go @@ -655,6 +655,11 @@ var _ = Describe("Session", func() { close(done) }) + It("tells its versions", func() { + sess.version = 4242 + Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) + }) + Context("waiting until the handshake completes", func() { It("waits until the handshake is complete", func(done Done) { go sess.run()