diff --git a/connection.go b/connection.go index c02b7af6..25563ffc 100644 --- a/connection.go +++ b/connection.go @@ -1072,7 +1072,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { return } - hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) + src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) @@ -1094,11 +1094,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) if s.tracer != nil { - s.tracer.ReceivedVersionNegotiationPacket( - protocol.ArbitraryLenConnectionID(hdr.DestConnectionID), - protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID), - supportedVersions, - ) + s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) if !ok { diff --git a/connection_test.go b/connection_test.go index 375f168f..609ebff4 100644 --- a/connection_test.go +++ b/connection_test.go @@ -663,7 +663,11 @@ var _ = Describe("Connection", func() { }) It("drops Version Negotiation packets", func() { - b := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) + b := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(srcConnID), + protocol.ArbitraryLenConnectionID(destConnID), + conn.config.Versions, + ) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(&receivedPacket{ data: b, @@ -2593,7 +2597,11 @@ var _ = Describe("Client Connection", func() { Context("handling Version Negotiation", func() { getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { - b := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) + b := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(srcConnID), + protocol.ArbitraryLenConnectionID(destConnID), + versions, + ) return &receivedPacket{ data: b, buffer: getPacketBuffer(), diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index eeb880ce..1cccc4c6 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -19,7 +19,7 @@ func getRandomData(l int) []byte { return b } -func getVNP(src, dest protocol.ConnectionID, numVersions int) []byte { +func getVNP(src, dest protocol.ArbitraryLenConnectionID, numVersions int) []byte { versions := make([]protocol.VersionNumber, numVersions) for i := 0; i < numVersions; i++ { versions[i] = protocol.VersionNumber(rand.Uint32()) @@ -113,28 +113,28 @@ func main() { vnps := [][]byte{ getVNP( - protocol.ConnectionID(getRandomData(8)), - protocol.ConnectionID(getRandomData(10)), + protocol.ArbitraryLenConnectionID(getRandomData(8)), + protocol.ArbitraryLenConnectionID(getRandomData(10)), 4, ), getVNP( - protocol.ConnectionID(getRandomData(10)), - protocol.ConnectionID(getRandomData(5)), + protocol.ArbitraryLenConnectionID(getRandomData(10)), + protocol.ArbitraryLenConnectionID(getRandomData(5)), 0, ), getVNP( - protocol.ConnectionID(getRandomData(3)), - protocol.ConnectionID(getRandomData(19)), + protocol.ArbitraryLenConnectionID(getRandomData(3)), + protocol.ArbitraryLenConnectionID(getRandomData(19)), 100, ), getVNP( - protocol.ConnectionID(getRandomData(3)), + protocol.ArbitraryLenConnectionID(getRandomData(3)), nil, 20, ), getVNP( nil, - protocol.ConnectionID(getRandomData(10)), + protocol.ArbitraryLenConnectionID(getRandomData(10)), 5, ), } diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index ba37172e..87c72e6c 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -82,16 +82,16 @@ func fuzzVNP(data []byte) int { if err != nil { return 0 } - hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(data)) + dest, src, versions, err := wire.ParseVersionNegotiationPacket(data) if err != nil { return 0 } - if !hdr.DestConnectionID.Equal(connID) { + if !bytes.Equal(dest, connID.Bytes()) { panic("connection IDs don't match") } if len(versions) == 0 { panic("no versions") } - wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + wire.ComposeVersionNegotiation(src, dest, versions) return 1 } diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index f7c6e6c7..ff543fd4 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -326,7 +326,11 @@ var _ = Describe("MITM test", func() { // Create fake version negotiation packet with no supported versions versions := []protocol.VersionNumber{} - packet := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + packet := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()), + protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()), + versions, + ) // Send the packet _, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr()) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index 196853e0..2cfa2ca3 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "crypto/rand" + "encoding/binary" "errors" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,32 +11,30 @@ import ( ) // ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { - hdr, err := parseHeader(b, 0) +func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) { + n, dest, src, err := ParseArbitraryLenConnectionIDs(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - if b.Len() == 0 { + b = b[n:] + if len(b) == 0 { //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has empty version list") + return nil, nil, nil, errors.New("Version Negotiation packet has empty version list") } - if b.Len()%4 != 0 { + if len(b)%4 != 0 { //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") } - versions := make([]protocol.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, nil, err - } - versions[i] = protocol.VersionNumber(v) + versions := make([]protocol.VersionNumber, len(b)/4) + for i := 0; len(b) > 0; i++ { + versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4])) + b = b[4:] } - return hdr, versions, nil + return dest, src, versions, nil } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) @@ -44,9 +43,9 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, vers buf.WriteByte(r[0] | 0x80) utils.BigEndian.WriteUint32(buf, 0) // version 0 buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID) + buf.Write(destConnID.Bytes()) buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID) + buf.Write(srcConnID.Bytes()) for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 2783cb17..cd9bb1c2 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -1,8 +1,10 @@ package wire import ( - "bytes" "encoding/binary" + mrand "math/rand" + + "golang.org/x/exp/rand" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -10,9 +12,16 @@ import ( ) var _ = Describe("Version Negotiation Packets", func() { + randConnID := func(l int) protocol.ArbitraryLenConnectionID { + b := make(protocol.ArbitraryLenConnectionID, l) + _, err := mrand.Read(b) + Expect(err).ToNot(HaveOccurred()) + return b + } + It("parses a Version Negotiation packet", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := randConnID(rand.Intn(255) + 1) + destConnID := randConnID(rand.Intn(255) + 1) versions := []protocol.VersionNumber{0x22334455, 0x33445566} data := []byte{0x80, 0, 0, 0, 0} data = append(data, uint8(len(destConnID))) @@ -24,44 +33,44 @@ var _ = Describe("Version Negotiation Packets", func() { binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) } Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + dest, src, supportedVersions, err := ParseVersionNegotiationPacket(data) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(BeZero()) + Expect(dest).To(Equal(destConnID)) + Expect(src).To(Equal(srcConnID)) Expect(supportedVersions).To(Equal(versions)) }) It("errors if it contains versions of the wrong length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455, 0x33445566} data := ComposeVersionNegotiation(connID, connID, versions) - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + _, _, _, err := ParseVersionNegotiationPacket(data[:len(data)-2]) Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) }) It("errors if the version list is empty", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455} data := ComposeVersionNegotiation(connID, connID, versions) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + _, _, _, err := ParseVersionNegotiationPacket(data) Expect(err).To(MatchError("Version Negotiation packet has empty version list")) }) It("adds a reserved version", func() { - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + srcConnID := protocol.ArbitraryLenConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + destConnID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(IsLongHeaderPacket(data[0])).To(BeTrue()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + v, err := ParseVersion(data) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.Version).To(BeZero()) + Expect(v).To(BeZero()) + dest, src, supportedVersions, err := ParseVersionNegotiationPacket(data) + Expect(err).ToNot(HaveOccurred()) + Expect(dest).To(Equal(destConnID)) + Expect(src).To(Equal(srcConnID)) // the supported versions should include one reserved version number Expect(supportedVersions).To(HaveLen(len(versions) + 1)) for _, v := range versions { diff --git a/server.go b/server.go index 218d5825..cacd0358 100644 --- a/server.go +++ b/server.go @@ -320,20 +320,43 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } return false } + // Short header packets should never end up here in the first place + if !wire.IsLongHeaderPacket(p.data[0]) { + panic(fmt.Sprintf("misrouted packet: %#v", p.data)) + } + v, err := wire.ParseVersion(p.data) + // send a Version Negotiation Packet if the client is speaking a different protocol version + if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { + if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) + if err != nil { // should never happen + s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + if !s.config.DisableVersionNegotiationPackets { + go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB()) + } + return false + } // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen()) - if err != nil && err != wire.ErrUnsupportedVersion { + if err != nil { if s.config.Tracer != nil { s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) return false } - // Short header packets should never end up here in the first place - if !hdr.IsLongHeader { - panic(fmt.Sprintf("misrouted packet: %#v", hdr)) - } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) if s.config.Tracer != nil { @@ -341,20 +364,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } return false } - // send a Version Negotiation Packet if the client is speaking a different protocol version - if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - if p.Size() < protocol.MinUnknownVersionPacketSize { - s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - if !s.config.DisableVersionNegotiationPackets { - go s.sendVersionNegotiationPacket(p, hdr) - } - return false - } + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { // Drop long header packets. // There's little point in sending a Stateless Reset, since the client @@ -664,22 +674,14 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han return err } -func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { - s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) - data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) +func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte) { + s.logger.Debugf("Client offered version %s, sending Version Negotiation") + + data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) if s.config.Tracer != nil { - s.config.Tracer.SentPacket( - p.remoteAddr, - &wire.Header{ - IsLongHeader: true, - DestConnectionID: hdr.SrcConnectionID, - SrcConnectionID: hdr.DestConnectionID, - }, - protocol.ByteCount(len(data)), - nil, - ) + s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, remote, oob); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/server_test.go b/server_test.go index eafba0cc..672f17d3 100644 --- a/server_test.go +++ b/server_test.go @@ -336,20 +336,18 @@ var _ = Describe("Server", func() { }, make([]byte, protocol.MinUnknownVersionPacketSize)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { - Expect(replyHdr.IsLongHeader).To(BeTrue()) - Expect(replyHdr.Version).To(BeZero()) - Expect(replyHdr.SrcConnectionID).To(Equal(destConnID)) - Expect(replyHdr.DestConnectionID).To(Equal(srcConnID)) + tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber) { + Expect(src).To(BeEquivalentTo(destConnID)) + Expect(dest).To(BeEquivalentTo(srcConnID)) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) - hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b)) + dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(srcConnID)) - Expect(hdr.SrcConnectionID).To(Equal(destConnID)) + Expect(dest).To(BeEquivalentTo(srcConnID)) + Expect(src).To(BeEquivalentTo(destConnID)) Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) return len(b), nil }) @@ -378,8 +376,8 @@ var _ = Describe("Server", func() { It("ignores Version Negotiation packets", func() { data := wire.ComposeVersionNegotiation( - protocol.ConnectionID{1, 2, 3, 4}, - protocol.ConnectionID{4, 3, 2, 1}, + protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, + protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, []protocol.VersionNumber{1, 2, 3}, ) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}