From fc57bf6c35bafcdc473c3f815c96e3a1561dcf78 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 24 Jun 2018 16:47:59 +0700 Subject: [PATCH] only set the destination connection ID for the Public Header --- client_multiplexer_test.go | 1 - client_test.go | 2 -- internal/wire/header_test.go | 11 ++---- internal/wire/public_header.go | 7 ++-- internal/wire/public_header_test.go | 41 +++++------------------ internal/wire/version_negotiation_test.go | 2 +- packet_packer.go | 2 +- packet_packer_test.go | 5 +++ server.go | 2 +- server_test.go | 11 ++---- session.go | 6 ++-- session_test.go | 3 ++ 12 files changed, 31 insertions(+), 62 deletions(-) diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index b848c7d4..d477a183 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -17,7 +17,6 @@ var _ = Describe("Client Multiplexer", func() { getPacket := func(connID protocol.ConnectionID) []byte { buf := &bytes.Buffer{} err := (&wire.Header{ - SrcConnectionID: connID, DestConnectionID: connID, PacketNumberLen: protocol.PacketNumberLen1, }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) diff --git a/client_test.go b/client_test.go index 37b76496..6584599f 100644 --- a/client_test.go +++ b/client_test.go @@ -36,7 +36,6 @@ var _ = Describe("Client", func() { b := &bytes.Buffer{} err := (&wire.Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: 1, }).Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) @@ -755,7 +754,6 @@ var _ = Describe("Client", func() { PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, DestConnectionID: connID, - SrcConnectionID: connID, } b := &bytes.Buffer{} err := ph.Write(b, protocol.PerspectiveServer, cl.version) diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 1b1b74e4..a9c0c37d 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -81,7 +81,6 @@ var _ = Describe("Header", func() { VersionFlag: true, Version: versionPublicHeader, DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen4, }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) @@ -89,7 +88,7 @@ var _ = Describe("Header", func() { hdr, err := ParseHeaderSentByClient(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(connID)) - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(hdr.Version).To(Equal(versionPublicHeader)) Expect(hdr.IsPublicHeader).To(BeTrue()) @@ -100,7 +99,6 @@ var _ = Describe("Header", func() { buf := &bytes.Buffer{} err := (&Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen4, DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), @@ -109,7 +107,7 @@ var _ = Describe("Header", func() { hdr, err := ParseHeaderSentByServer(bytes.NewReader(buf.Bytes())) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(connID)) - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(hdr.DiversificationNonce).To(HaveLen(32)) Expect(hdr.IsPublicHeader).To(BeTrue()) @@ -121,7 +119,6 @@ var _ = Describe("Header", func() { VersionFlag: true, Version: versionPublicHeader, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, }).writePublicHeader(buf, protocol.PerspectiveClient, versionPublicHeader) @@ -145,7 +142,7 @@ var _ = Describe("Header", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsPublicHeader).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(connID)) - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) // in addition to the versions, the supported versions might contain a reserved version number for _, version := range versions { Expect(hdr.SupportedVersions).To(ContainElement(version)) @@ -177,7 +174,6 @@ var _ = Describe("Header", func() { buf := &bytes.Buffer{} hdr := &Header{ DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen2, } @@ -211,7 +207,6 @@ var _ = Describe("Header", func() { buf := &bytes.Buffer{} hdr := &Header{ DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen2, DiversificationNonce: bytes.Repeat([]byte{'f'}, 32), diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index ceaeba70..8ec8e699 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -27,10 +27,10 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ if h.VersionFlag && h.ResetFlag { return errResetAndVersionFlagSet } - if !h.DestConnectionID.Equal(h.SrcConnectionID) { - return fmt.Errorf("PublicHeader: SrcConnectionID must be equal to DestConnectionID") + if h.SrcConnectionID.Len() != 0 { + return errors.New("PublicHeader: SrcConnectionID must not be set") } - if len(h.DestConnectionID) != 8 { + if h.DestConnectionID.Len() != 8 { return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID)) } @@ -142,7 +142,6 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea return nil, errInvalidConnectionID } header.DestConnectionID = connID - header.SrcConnectionID = connID } // Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server. diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index 0e506e9f..34339f4d 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -28,7 +28,7 @@ var _ = Describe("Public Header", func() { Expect(hdr.ResetFlag).To(BeFalse()) connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} Expect(hdr.DestConnectionID).To(Equal(connID)) - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.Version).To(Equal(protocol.SupportedVersions[0])) Expect(hdr.SupportedVersions).To(BeEmpty()) Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(1))) @@ -65,7 +65,7 @@ var _ = Describe("Public Header", func() { Expect(hdr.VersionFlag).To(BeFalse()) Expect(hdr.IsVersionNegotiation).To(BeFalse()) connID := protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8} - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.DestConnectionID).To(Equal(connID)) }) @@ -76,7 +76,7 @@ var _ = Describe("Public Header", func() { hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).ToNot(BeEmpty()) - Expect(hdr.SrcConnectionID).ToNot(BeEmpty()) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.DiversificationNonce).To(Equal(divNonce)) Expect(b.Len()).To(BeZero()) }) @@ -95,7 +95,7 @@ var _ = Describe("Public Header", func() { hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(connID)) - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.Version).To(BeZero()) // unitialized Expect(hdr.IsVersionNegotiation).To(BeTrue()) @@ -176,7 +176,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen4, } @@ -189,7 +188,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, } @@ -198,7 +196,7 @@ var _ = Describe("Public Header", func() { Expect(b.Bytes()).To(Equal([]byte{0x18, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x13, 0x37})) }) - It("refuses to write a Public Header if the source and destination connection IDs are not matching", func() { + It("refuses to write a Public Header with a source connection ID", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, @@ -207,14 +205,13 @@ var _ = Describe("Public Header", func() { PacketNumberLen: protocol.PacketNumberLen4, } err := hdr.writePublicHeader(b, protocol.PerspectiveClient, versionBigEndian) - Expect(err).To(MatchError("PublicHeader: SrcConnectionID must be equal to DestConnectionID")) + Expect(err).To(MatchError("PublicHeader: SrcConnectionID must not be set")) }) It("refuses to write a Public Header if the connection ID has the wrong length", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen2, } @@ -227,7 +224,6 @@ var _ = Describe("Public Header", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 2, } b := &bytes.Buffer{} @@ -240,7 +236,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, OmitConnectionID: true, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, @@ -254,7 +249,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen1, DiversificationNonce: bytes.Repeat([]byte{1}, 32), @@ -283,7 +277,6 @@ var _ = Describe("Public Header", func() { hdr := Header{ VersionFlag: true, DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } @@ -297,7 +290,6 @@ var _ = Describe("Public Header", func() { VersionFlag: true, Version: protocol.Version39, DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen1, } @@ -315,11 +307,9 @@ var _ = Describe("Public Header", func() { Context("PublicReset packets", func() { It("sets the Reset Flag", func() { b := &bytes.Buffer{} - connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} hdr := Header{ ResetFlag: true, - DestConnectionID: connID, - SrcConnectionID: connID, + DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) @@ -331,11 +321,9 @@ var _ = Describe("Public Header", func() { It("doesn't add a packet number for headers with Reset Flag sent as a client", func() { b := &bytes.Buffer{} - connID := protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} hdr := Header{ ResetFlag: true, - DestConnectionID: connID, - SrcConnectionID: connID, + DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}, PacketNumber: 2, PacketNumberLen: protocol.PacketNumberLen6, } @@ -365,7 +353,6 @@ var _ = Describe("Public Header", func() { It("errors when PacketNumberLen is not set", func() { hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xdecafbad, } _, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) @@ -375,7 +362,6 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet with longest packet number length and connectionID", func() { hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, } @@ -387,7 +373,6 @@ var _ = Describe("Public Header", func() { It("gets the lengths of a packet sent by the client with the VersionFlag set", func() { hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, OmitConnectionID: true, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, @@ -402,7 +387,6 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet with longest packet number length and omitted connectionID", func() { hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, OmitConnectionID: true, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen4, @@ -415,7 +399,6 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet 2 byte packet number length ", func() { hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xDECAFBAD, PacketNumberLen: protocol.PacketNumberLen2, } @@ -438,7 +421,6 @@ var _ = Describe("Public Header", func() { hdr := Header{ ResetFlag: true, DestConnectionID: connID, - SrcConnectionID: connID, } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).NotTo(HaveOccurred()) @@ -451,7 +433,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xDECAFBAD, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) @@ -465,7 +446,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen1, } @@ -478,7 +458,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen2, } @@ -491,7 +470,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0x13decafbad, PacketNumberLen: protocol.PacketNumberLen4, } @@ -504,7 +482,6 @@ var _ = Describe("Public Header", func() { b := &bytes.Buffer{} hdr := Header{ DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 0xbe1337decafbad, PacketNumberLen: protocol.PacketNumberLen6, } @@ -535,7 +512,6 @@ var _ = Describe("Public Header", func() { It("logs a Public Header containing a connection ID", func() { (&Header{ DestConnectionID: protocol.ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}, - SrcConnectionID: protocol.ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}, PacketNumber: 0x1337, PacketNumberLen: 6, Version: protocol.Version39, @@ -565,7 +541,6 @@ var _ = Describe("Public Header", func() { It("logs diversification nonces", func() { (&Header{ DestConnectionID: []byte{0x13, 0x13, 0, 0, 0xde, 0xca, 0xfb, 0xad}, - SrcConnectionID: []byte{0x13, 0x13, 0, 0, 0xde, 0xca, 0xfb, 0xad}, DiversificationNonce: []byte{0xba, 0xdf, 0x00, 0x0d}, }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("DiversificationNonce: []byte{0xba, 0xdf, 0x0, 0xd}")) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index a635f02e..cb722d71 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -17,7 +17,7 @@ var _ = Describe("Version Negotiation Packets", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) Expect(hdr.DestConnectionID).To(Equal(connID)) - Expect(hdr.SrcConnectionID).To(Equal(connID)) + Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(hdr.SupportedVersions).To(Equal(versions)) }) diff --git a/packet_packer.go b/packet_packer.go index 61a3b863..e44e4b6c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -452,13 +452,13 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header header := &wire.Header{ DestConnectionID: p.destConnID, - SrcConnectionID: p.srcConnID, PacketNumber: pnum, PacketNumberLen: packetNumberLen, } if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { header.IsLongHeader = true + header.SrcConnectionID = p.srcConnID // Set the payload len to maximum size. // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. header.PayloadLen = p.maxPacketSize diff --git a/packet_packer_test.go b/packet_packer_test.go index bc1745d8..0bc1fc2b 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -158,6 +158,11 @@ var _ = Describe("Packet packer", func() { packer.version = versionPublicHeader }) + It("doesn't set the source connection ID", func() { + ph := packer.getHeader(protocol.EncryptionForwardSecure) + Expect(ph.SrcConnectionID).To(BeEmpty()) + }) + It("it omits the connection ID for forward-secure packets", func() { ph := packer.getHeader(protocol.EncryptionForwardSecure) Expect(ph.OmitConnectionID).To(BeFalse()) diff --git a/server.go b/server.go index 426a7df8..ea1e2da0 100644 --- a/server.go +++ b/server.go @@ -393,7 +393,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd return errors.New("dropping small packet with unknown version") } s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version) - _, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.SrcConnectionID, s.config.Versions), remoteAddr) + _, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr) return err } diff --git a/server_test.go b/server_test.go index 78f4f226..209af04f 100644 --- a/server_test.go +++ b/server_test.go @@ -371,11 +371,10 @@ var _ = Describe("Server", func() { hdr := wire.Header{ VersionFlag: true, DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, } - hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) + Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed()) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO serv.conn = conn sessionHandler.EXPECT().Get(connID) @@ -389,11 +388,10 @@ var _ = Describe("Server", func() { hdr := wire.Header{ VersionFlag: true, DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, } - hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) + Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed()) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small serv.conn = conn sessionHandler.EXPECT().Get(connID) @@ -468,11 +466,10 @@ var _ = Describe("Server", func() { hdr := wire.Header{ VersionFlag: true, DestConnectionID: connID, - SrcConnectionID: connID, PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, } - hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) + Expect(hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)).To(Succeed()) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO conn.dataToRead <- b.Bytes() conn.dataReadFrom = udpAddr @@ -485,7 +482,6 @@ var _ = Describe("Server", func() { ln.Accept() close(done) }() - Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) r := bytes.NewReader(conn.dataWritten.Bytes()) @@ -493,7 +489,6 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Expect(packet.VersionFlag).To(BeTrue()) Expect(packet.DestConnectionID).To(Equal(connID)) - Expect(packet.SrcConnectionID).To(Equal(connID)) Expect(r.Len()).To(BeZero()) Consistently(done).ShouldNot(BeClosed()) // make the go routine return diff --git a/session.go b/session.go index 6c5857dc..c73e9953 100644 --- a/session.go +++ b/session.go @@ -205,7 +205,7 @@ func newSession( s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker( connectionID, - connectionID, + nil, // no src connection ID 1, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), @@ -275,7 +275,7 @@ var newClientSession = func( s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version) s.packer = newPacketPacker( connectionID, - connectionID, + nil, // no src connection ID 1, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), @@ -643,7 +643,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && !hdr.SrcConnectionID.Equal(s.destConnID) { + if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID) s.destConnID = hdr.SrcConnectionID s.packer.ChangeDestConnectionID(s.destConnID) diff --git a/session_test.go b/session_test.go index 58fff6ce..1294363e 100644 --- a/session_test.go +++ b/session_test.go @@ -1124,6 +1124,7 @@ var _ = Describe("Session", func() { It("retransmits an unencrypted packet, and doesn't add a STOP_WAITING frame (for IETF QUIC)", func() { sess.version = versionIETFFrames sess.packer.version = versionIETFFrames + sess.packer.srcConnID = sess.destConnID sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ PacketNumber: 1337, @@ -1799,6 +1800,7 @@ var _ = Describe("Client Session", func() { It("changes the connection ID when receiving the first packet from the server", func() { sess.version = protocol.VersionTLS sess.packer.version = protocol.VersionTLS + sess.packer.srcConnID = sess.destConnID unpacker := NewMockUnpacker(mockCtrl) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) sess.unpacker = unpacker @@ -1808,6 +1810,7 @@ var _ = Describe("Client Session", func() { }() err := sess.handlePacketImpl(&receivedPacket{ header: &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeHandshake, SrcConnectionID: protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}, DestConnectionID: sess.srcConnID,