diff --git a/client.go b/client.go index 65da2541..7482eeaf 100644 --- a/client.go +++ b/client.go @@ -369,11 +369,6 @@ func (c *client) handlePacket(p *receivedPacket) { } func (c *client) handlePacketImpl(p *receivedPacket) error { - // reject packets with truncated connection id if we didn't request truncation - if p.header.OmitConnectionID && !c.config.RequestConnectionIDOmission { - return errors.New("received packet with truncated connection ID, but didn't request truncation") - } - c.mutex.Lock() defer c.mutex.Unlock() @@ -430,16 +425,21 @@ func (c *client) handleIETFQUICPacket(p *receivedPacket) error { } func (c *client) handleGQUICPacket(p *receivedPacket) error { + connID := p.header.DestConnectionID + // reject packets with truncated connection id if we didn't request truncation + if !c.config.RequestConnectionIDOmission && connID.Len() == 0 { + return errors.New("received packet with truncated connection ID, but didn't request truncation") + } // reject packets with the wrong connection ID - if !p.header.OmitConnectionID && !p.header.DestConnectionID.Equal(c.srcConnID) { - return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID) + if connID.Len() > 0 && !connID.Equal(c.srcConnID) { + return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID) } if p.header.ResetFlag { cr := c.conn.RemoteAddr() // check if the remote address and the connection ID match // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection - if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) { + if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !connID.Equal(c.srcConnID) { return errors.New("Received a spoofed Public Reset") } pr, err := wire.ParsePublicReset(bytes.NewReader(p.data)) diff --git a/client_test.go b/client_test.go index 6584599f..94524646 100644 --- a/client_test.go +++ b/client_test.go @@ -591,11 +591,9 @@ var _ = Describe("Client", func() { cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls cl.config = &Config{RequestConnectionIDOmission: false} hdr := &wire.Header{ - OmitConnectionID: true, - SrcConnectionID: connID, - DestConnectionID: connID, - PacketNumber: 1, - PacketNumberLen: 1, + IsPublicHeader: true, + PacketNumber: 1, + PacketNumberLen: 1, } err := cl.handlePacketImpl(&receivedPacket{ remoteAddr: addr, @@ -747,6 +745,7 @@ var _ = Describe("Client", func() { Context("handling packets", func() { It("handles packets", func() { + cl.config = &Config{} sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(gomock.Any()) cl.session = sess @@ -786,6 +785,10 @@ var _ = Describe("Client", func() { }) Context("Public Reset handling", func() { + BeforeEach(func() { + cl.config = &Config{} + }) + It("closes the session when receiving a Public Reset", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().closeRemote(gomock.Any()).Do(func(err error) { diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index f1a010ba..01cd612d 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -27,7 +27,6 @@ var _ = Describe("QUIC Proxy", func() { PacketNumberLen: protocol.PacketNumberLen6, DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, - OmitConnectionID: false, } hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever) raw := b.Bytes() diff --git a/internal/wire/header.go b/internal/wire/header.go index 076e4d43..3d8e912f 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -18,7 +18,6 @@ type Header struct { DestConnectionID protocol.ConnectionID SrcConnectionID protocol.ConnectionID - OmitConnectionID bool PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber diff --git a/internal/wire/ietf_header_test.go b/internal/wire/ietf_header_test.go index f7ccc951..e2026539 100644 --- a/internal/wire/ietf_header_test.go +++ b/internal/wire/ietf_header_test.go @@ -81,7 +81,6 @@ var _ = Describe("IETF QUIC Header", func() { Expect(err).ToNot(HaveOccurred()) Expect(h.Type).To(Equal(protocol.PacketTypeInitial)) Expect(h.IsLongHeader).To(BeTrue()) - Expect(h.OmitConnectionID).To(BeFalse()) Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) Expect(h.PayloadLen).To(Equal(protocol.ByteCount(0x1337))) @@ -183,7 +182,6 @@ var _ = Describe("IETF QUIC Header", func() { Expect(err).ToNot(HaveOccurred()) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.KeyPhase).To(Equal(0)) - Expect(h.OmitConnectionID).To(BeFalse()) Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})) Expect(h.SrcConnectionID).To(BeEmpty()) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) @@ -362,9 +360,8 @@ var _ = Describe("IETF QUIC Header", func() { It("writes a header with a 2 byte packet number", func() { err := (&Header{ - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen2, - PacketNumber: 0x765, + PacketNumberLen: protocol.PacketNumberLen2, + PacketNumber: 0x765, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x30} @@ -374,9 +371,8 @@ var _ = Describe("IETF QUIC Header", func() { It("writes a header with a 4 byte packet number", func() { err := (&Header{ - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen4, - PacketNumber: 0x123456, + PacketNumberLen: protocol.PacketNumberLen4, + PacketNumber: 0x123456, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x30} @@ -386,19 +382,17 @@ var _ = Describe("IETF QUIC Header", func() { It("errors when given an invalid packet number length", func() { err := (&Header{ - OmitConnectionID: true, - PacketNumberLen: 3, - PacketNumber: 0xdecafbad, + PacketNumberLen: 3, + PacketNumber: 0xdecafbad, }).writeHeader(buf) Expect(err).To(MatchError("invalid packet number length: 3")) }) It("writes the Key Phase Bit", func() { err := (&Header{ - KeyPhase: 1, - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 0x42, + KeyPhase: 1, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 0x42, }).writeHeader(buf) Expect(err).ToNot(HaveOccurred()) Expect(buf.Bytes()).To(Equal([]byte{ @@ -458,10 +452,7 @@ var _ = Describe("IETF QUIC Header", func() { }) It("has the right length for a short header without a connection ID", func() { - h := &Header{ - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen1, - } + h := &Header{PacketNumberLen: protocol.PacketNumberLen1} Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 1))) err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) @@ -469,10 +460,7 @@ var _ = Describe("IETF QUIC Header", func() { }) It("has the right length for a short header with a 2 byte packet number", func() { - h := &Header{ - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen2, - } + h := &Header{PacketNumberLen: protocol.PacketNumberLen2} Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 2))) err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) @@ -480,10 +468,7 @@ var _ = Describe("IETF QUIC Header", func() { }) It("has the right length for a short header with a 5 byte packet number", func() { - h := &Header{ - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen4, - } + h := &Header{PacketNumberLen: protocol.PacketNumberLen4} Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 4))) err := h.writeHeader(buf) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/wire/public_header.go b/internal/wire/public_header.go index 8ec8e699..1ece84c7 100644 --- a/internal/wire/public_header.go +++ b/internal/wire/public_header.go @@ -13,7 +13,6 @@ import ( var ( errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") - errReceivedOmittedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported") errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets") errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes") @@ -30,7 +29,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ if h.SrcConnectionID.Len() != 0 { return errors.New("PublicHeader: SrcConnectionID must not be set") } - if h.DestConnectionID.Len() != 8 { + if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 { return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID)) } @@ -41,7 +40,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ if h.ResetFlag { publicFlagByte |= 0x02 } - if !h.OmitConnectionID { + if h.DestConnectionID.Len() > 0 { publicFlagByte |= 0x08 } if len(h.DiversificationNonce) > 0 { @@ -63,7 +62,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ } b.WriteByte(publicFlagByte) - if !h.OmitConnectionID { + if h.DestConnectionID.Len() > 0 { b.Write(h.DestConnectionID) } if h.VersionFlag && pers == protocol.PerspectiveClient { @@ -112,9 +111,9 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea // return nil, errors.New("diversification nonces should only be sent by servers") // } - header.OmitConnectionID = publicFlagByte&0x08 == 0 - if header.OmitConnectionID && packetSentBy == protocol.PerspectiveClient { - return nil, errReceivedOmittedConnectionID + hasConnectionID := publicFlagByte&0x08 > 0 + if !hasConnectionID && packetSentBy == protocol.PerspectiveClient { + return nil, qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported") } if header.hasPacketNumber(packetSentBy) { switch publicFlagByte & 0x30 { @@ -130,12 +129,9 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea } // Connection ID - if !header.OmitConnectionID { - connID := make(protocol.ConnectionID, 8) - if _, err := io.ReadFull(b, connID); err != nil { - if err == io.ErrUnexpectedEOF { - err = io.EOF - } + if hasConnectionID { + connID, err := protocol.ReadConnectionID(b, 8) + if err != nil { return nil, err } if connID[0] == 0 && connID[1] == 0 && connID[2] == 0 && connID[3] == 0 && connID[4] == 0 && connID[5] == 0 && connID[6] == 0 && connID[7] == 0 { @@ -218,9 +214,7 @@ func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.Byte } length += protocol.ByteCount(h.PacketNumberLen) } - if !h.OmitConnectionID { - length += 8 // 8 bytes for the connection ID - } + length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes // Version Number in packets sent by the client if h.VersionFlag { length += 4 diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index 34339f4d..84809629 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -38,14 +38,13 @@ var _ = Describe("Public Header", func() { It("does not accept an omittedd connection ID as a server", func() { b := bytes.NewReader([]byte{0x00, 0x01}) _, err := parsePublicHeader(b, protocol.PerspectiveClient) - Expect(err).To(MatchError(errReceivedOmittedConnectionID)) + Expect(err).To(MatchError("InvalidPacketHeader: receiving packets with omitted ConnectionID is not supported")) }) It("accepts an omitted connection ID as a client", func() { b := bytes.NewReader([]byte{0x00, 0x01}) hdr, err := parsePublicHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.OmitConnectionID).To(BeTrue()) Expect(hdr.DestConnectionID).To(BeEmpty()) Expect(hdr.SrcConnectionID).To(BeEmpty()) Expect(b.Len()).To(BeZero()) @@ -232,13 +231,10 @@ var _ = Describe("Public Header", func() { }) It("omits the connection ID", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} b := &bytes.Buffer{} hdr := Header{ - DestConnectionID: connID, - OmitConnectionID: true, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 1, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, } err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever) Expect(err).ToNot(HaveOccurred()) @@ -372,12 +368,10 @@ var _ = Describe("Public Header", func() { It("gets the lengths of a packet sent by the client with the VersionFlag set", func() { hdr := Header{ - DestConnectionID: connID, - OmitConnectionID: true, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - VersionFlag: true, - Version: versionBigEndian, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + VersionFlag: true, + Version: versionBigEndian, } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) @@ -386,10 +380,8 @@ var _ = Describe("Public Header", func() { It("gets the length of a packet with longest packet number length and omitted connectionID", func() { hdr := Header{ - DestConnectionID: connID, - OmitConnectionID: true, - PacketNumber: 0xDECAFBAD, - PacketNumberLen: protocol.PacketNumberLen4, + PacketNumber: 0xDECAFBAD, + PacketNumberLen: protocol.PacketNumberLen4, } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) @@ -414,7 +406,7 @@ var _ = Describe("Public Header", func() { } length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer) Expect(err).NotTo(HaveOccurred()) - Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1))) // 1 byte public flag, 8 byte connectionID, 3 byte DiversificationNonce, 1 byte PacketNumber + Expect(length).To(Equal(protocol.ByteCount(1 + 3 + 1))) // 1 byte public flag, 3 byte DiversificationNonce, 1 byte PacketNumber }) It("gets the length of a PublicReset", func() { @@ -521,19 +513,17 @@ var _ = Describe("Public Header", func() { It("logs a Public Header with omitted connection ID", func() { (&Header{ - OmitConnectionID: true, - PacketNumber: 0x1337, - PacketNumberLen: 6, - Version: protocol.Version39, + PacketNumber: 0x1337, + PacketNumberLen: 6, + Version: protocol.Version39, }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (empty)")) }) It("logs a Public Header without a version", func() { (&Header{ - OmitConnectionID: true, - PacketNumber: 0x1337, - PacketNumberLen: 6, + PacketNumber: 0x1337, + PacketNumberLen: 6, }).logPublicHeader(logger) Expect(buf.String()).To(ContainSubstring("Version: (unset)")) }) diff --git a/packet_packer.go b/packet_packer.go index e44e4b6c..dcf25222 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -451,9 +451,8 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header packetNumberLen := p.getPacketNumberLen(pnum) header := &wire.Header{ - DestConnectionID: p.destConnID, - PacketNumber: pnum, - PacketNumberLen: packetNumberLen, + PacketNumber: pnum, + PacketNumberLen: packetNumberLen, } if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure { @@ -469,8 +468,8 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header } } - if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { - header.OmitConnectionID = true + if !p.omitConnectionID || encLevel != protocol.EncryptionForwardSecure { + header.DestConnectionID = p.destConnID } if !p.version.UsesTLS() { if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { diff --git a/packet_packer_test.go b/packet_packer_test.go index 0bc1fc2b..ec68e023 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -165,16 +165,16 @@ var _ = Describe("Packet packer", func() { It("it omits the connection ID for forward-secure packets", func() { ph := packer.getHeader(protocol.EncryptionForwardSecure) - Expect(ph.OmitConnectionID).To(BeFalse()) + Expect(ph.DestConnectionID.Len()).ToNot(BeZero()) packer.SetOmitConnectionID() ph = packer.getHeader(protocol.EncryptionForwardSecure) - Expect(ph.OmitConnectionID).To(BeTrue()) + Expect(ph.DestConnectionID.Len()).To(BeZero()) }) It("doesn't omit the connection ID for non-forward-secure packets", func() { packer.SetOmitConnectionID() ph := packer.getHeader(protocol.EncryptionSecure) - Expect(ph.OmitConnectionID).To(BeFalse()) + Expect(ph.DestConnectionID.Len()).ToNot(BeZero()) }) It("adds the Version Flag to the Public Header before the crypto handshake is finished", func() { @@ -254,20 +254,6 @@ var _ = Describe("Packet packer", func() { Expect(h.IsLongHeader).To(BeFalse()) Expect(h.PacketNumberLen).To(BeNumerically(">", 0)) }) - - It("it omits the connection ID for forward-secure packets", func() { - h := packer.getHeader(protocol.EncryptionForwardSecure) - Expect(h.OmitConnectionID).To(BeFalse()) - packer.SetOmitConnectionID() - h = packer.getHeader(protocol.EncryptionForwardSecure) - Expect(h.OmitConnectionID).To(BeTrue()) - }) - - It("doesn't omit the connection ID for non-forward-secure packets", func() { - packer.SetOmitConnectionID() - h := packer.getHeader(protocol.EncryptionSecure) - Expect(h.OmitConnectionID).To(BeFalse()) - }) }) })