remove the OmitConnectionID field from the Header struct

Instead, just look at the length of the destination connection ID.
This commit is contained in:
Marten Seemann 2018-06-24 17:11:25 +07:00
parent fc57bf6c35
commit 7b96d158d2
9 changed files with 60 additions and 105 deletions

View file

@ -369,11 +369,6 @@ func (c *client) handlePacket(p *receivedPacket) {
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
// reject packets with truncated connection id if we didn't request truncation
if p.header.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
c.mutex.Lock()
defer c.mutex.Unlock()
@ -430,16 +425,21 @@ func (c *client) handleIETFQUICPacket(p *receivedPacket) error {
}
func (c *client) handleGQUICPacket(p *receivedPacket) error {
connID := p.header.DestConnectionID
// reject packets with truncated connection id if we didn't request truncation
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if !p.header.OmitConnectionID && !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
}
if p.header.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !connID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))

View file

@ -591,11 +591,9 @@ var _ = Describe("Client", func() {
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
cl.config = &Config{RequestConnectionIDOmission: false}
hdr := &wire.Header{
OmitConnectionID: true,
SrcConnectionID: connID,
DestConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: 1,
IsPublicHeader: true,
PacketNumber: 1,
PacketNumberLen: 1,
}
err := cl.handlePacketImpl(&receivedPacket{
remoteAddr: addr,
@ -747,6 +745,7 @@ var _ = Describe("Client", func() {
Context("handling packets", func() {
It("handles packets", func() {
cl.config = &Config{}
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
@ -786,6 +785,10 @@ var _ = Describe("Client", func() {
})
Context("Public Reset handling", func() {
BeforeEach(func() {
cl.config = &Config{}
})
It("closes the session when receiving a Public Reset", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().closeRemote(gomock.Any()).Do(func(err error) {

View file

@ -27,7 +27,6 @@ var _ = Describe("QUIC Proxy", func() {
PacketNumberLen: protocol.PacketNumberLen6,
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
OmitConnectionID: false,
}
hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
raw := b.Bytes()

View file

@ -18,7 +18,6 @@ type Header struct {
DestConnectionID protocol.ConnectionID
SrcConnectionID protocol.ConnectionID
OmitConnectionID bool
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber

View file

@ -81,7 +81,6 @@ var _ = Describe("IETF QUIC Header", func() {
Expect(err).ToNot(HaveOccurred())
Expect(h.Type).To(Equal(protocol.PacketTypeInitial))
Expect(h.IsLongHeader).To(BeTrue())
Expect(h.OmitConnectionID).To(BeFalse())
Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
Expect(h.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
Expect(h.PayloadLen).To(Equal(protocol.ByteCount(0x1337)))
@ -183,7 +182,6 @@ var _ = Describe("IETF QUIC Header", func() {
Expect(err).ToNot(HaveOccurred())
Expect(h.IsLongHeader).To(BeFalse())
Expect(h.KeyPhase).To(Equal(0))
Expect(h.OmitConnectionID).To(BeFalse())
Expect(h.DestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}))
Expect(h.SrcConnectionID).To(BeEmpty())
Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
@ -362,9 +360,8 @@ var _ = Describe("IETF QUIC Header", func() {
It("writes a header with a 2 byte packet number", func() {
err := (&Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen2,
PacketNumber: 0x765,
PacketNumberLen: protocol.PacketNumberLen2,
PacketNumber: 0x765,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x30}
@ -374,9 +371,8 @@ var _ = Describe("IETF QUIC Header", func() {
It("writes a header with a 4 byte packet number", func() {
err := (&Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen4,
PacketNumber: 0x123456,
PacketNumberLen: protocol.PacketNumberLen4,
PacketNumber: 0x123456,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x30}
@ -386,19 +382,17 @@ var _ = Describe("IETF QUIC Header", func() {
It("errors when given an invalid packet number length", func() {
err := (&Header{
OmitConnectionID: true,
PacketNumberLen: 3,
PacketNumber: 0xdecafbad,
PacketNumberLen: 3,
PacketNumber: 0xdecafbad,
}).writeHeader(buf)
Expect(err).To(MatchError("invalid packet number length: 3"))
})
It("writes the Key Phase Bit", func() {
err := (&Header{
KeyPhase: 1,
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 0x42,
KeyPhase: 1,
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 0x42,
}).writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).To(Equal([]byte{
@ -458,10 +452,7 @@ var _ = Describe("IETF QUIC Header", func() {
})
It("has the right length for a short header without a connection ID", func() {
h := &Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen1,
}
h := &Header{PacketNumberLen: protocol.PacketNumberLen1}
Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 1)))
err := h.writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
@ -469,10 +460,7 @@ var _ = Describe("IETF QUIC Header", func() {
})
It("has the right length for a short header with a 2 byte packet number", func() {
h := &Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen2,
}
h := &Header{PacketNumberLen: protocol.PacketNumberLen2}
Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 2)))
err := h.writeHeader(buf)
Expect(err).ToNot(HaveOccurred())
@ -480,10 +468,7 @@ var _ = Describe("IETF QUIC Header", func() {
})
It("has the right length for a short header with a 5 byte packet number", func() {
h := &Header{
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen4,
}
h := &Header{PacketNumberLen: protocol.PacketNumberLen4}
Expect(h.getHeaderLength()).To(Equal(protocol.ByteCount(1 + 4)))
err := h.writeHeader(buf)
Expect(err).ToNot(HaveOccurred())

View file

@ -13,7 +13,6 @@ import (
var (
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
errReceivedOmittedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported")
errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets")
errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes")
@ -30,7 +29,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _
if h.SrcConnectionID.Len() != 0 {
return errors.New("PublicHeader: SrcConnectionID must not be set")
}
if h.DestConnectionID.Len() != 8 {
if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
}
@ -41,7 +40,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _
if h.ResetFlag {
publicFlagByte |= 0x02
}
if !h.OmitConnectionID {
if h.DestConnectionID.Len() > 0 {
publicFlagByte |= 0x08
}
if len(h.DiversificationNonce) > 0 {
@ -63,7 +62,7 @@ func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _
}
b.WriteByte(publicFlagByte)
if !h.OmitConnectionID {
if h.DestConnectionID.Len() > 0 {
b.Write(h.DestConnectionID)
}
if h.VersionFlag && pers == protocol.PerspectiveClient {
@ -112,9 +111,9 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea
// return nil, errors.New("diversification nonces should only be sent by servers")
// }
header.OmitConnectionID = publicFlagByte&0x08 == 0
if header.OmitConnectionID && packetSentBy == protocol.PerspectiveClient {
return nil, errReceivedOmittedConnectionID
hasConnectionID := publicFlagByte&0x08 > 0
if !hasConnectionID && packetSentBy == protocol.PerspectiveClient {
return nil, qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported")
}
if header.hasPacketNumber(packetSentBy) {
switch publicFlagByte & 0x30 {
@ -130,12 +129,9 @@ func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Hea
}
// Connection ID
if !header.OmitConnectionID {
connID := make(protocol.ConnectionID, 8)
if _, err := io.ReadFull(b, connID); err != nil {
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
if hasConnectionID {
connID, err := protocol.ReadConnectionID(b, 8)
if err != nil {
return nil, err
}
if connID[0] == 0 && connID[1] == 0 && connID[2] == 0 && connID[3] == 0 && connID[4] == 0 && connID[5] == 0 && connID[6] == 0 && connID[7] == 0 {
@ -218,9 +214,7 @@ func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.Byte
}
length += protocol.ByteCount(h.PacketNumberLen)
}
if !h.OmitConnectionID {
length += 8 // 8 bytes for the connection ID
}
length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes
// Version Number in packets sent by the client
if h.VersionFlag {
length += 4

View file

@ -38,14 +38,13 @@ var _ = Describe("Public Header", func() {
It("does not accept an omittedd connection ID as a server", func() {
b := bytes.NewReader([]byte{0x00, 0x01})
_, err := parsePublicHeader(b, protocol.PerspectiveClient)
Expect(err).To(MatchError(errReceivedOmittedConnectionID))
Expect(err).To(MatchError("InvalidPacketHeader: receiving packets with omitted ConnectionID is not supported"))
})
It("accepts an omitted connection ID as a client", func() {
b := bytes.NewReader([]byte{0x00, 0x01})
hdr, err := parsePublicHeader(b, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.OmitConnectionID).To(BeTrue())
Expect(hdr.DestConnectionID).To(BeEmpty())
Expect(hdr.SrcConnectionID).To(BeEmpty())
Expect(b.Len()).To(BeZero())
@ -232,13 +231,10 @@ var _ = Describe("Public Header", func() {
})
It("omits the connection ID", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
b := &bytes.Buffer{}
hdr := Header{
DestConnectionID: connID,
OmitConnectionID: true,
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 1,
}
err := hdr.writePublicHeader(b, protocol.PerspectiveServer, protocol.VersionWhatever)
Expect(err).ToNot(HaveOccurred())
@ -372,12 +368,10 @@ var _ = Describe("Public Header", func() {
It("gets the lengths of a packet sent by the client with the VersionFlag set", func() {
hdr := Header{
DestConnectionID: connID,
OmitConnectionID: true,
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
VersionFlag: true,
Version: versionBigEndian,
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
VersionFlag: true,
Version: versionBigEndian,
}
length, err := hdr.getPublicHeaderLength(protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
@ -386,10 +380,8 @@ var _ = Describe("Public Header", func() {
It("gets the length of a packet with longest packet number length and omitted connectionID", func() {
hdr := Header{
DestConnectionID: connID,
OmitConnectionID: true,
PacketNumber: 0xDECAFBAD,
PacketNumberLen: protocol.PacketNumberLen4,
PacketNumber: 0xDECAFBAD,
PacketNumberLen: protocol.PacketNumberLen4,
}
length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
@ -414,7 +406,7 @@ var _ = Describe("Public Header", func() {
}
length, err := hdr.getPublicHeaderLength(protocol.PerspectiveServer)
Expect(err).NotTo(HaveOccurred())
Expect(length).To(Equal(protocol.ByteCount(1 + 8 + 3 + 1))) // 1 byte public flag, 8 byte connectionID, 3 byte DiversificationNonce, 1 byte PacketNumber
Expect(length).To(Equal(protocol.ByteCount(1 + 3 + 1))) // 1 byte public flag, 3 byte DiversificationNonce, 1 byte PacketNumber
})
It("gets the length of a PublicReset", func() {
@ -521,19 +513,17 @@ var _ = Describe("Public Header", func() {
It("logs a Public Header with omitted connection ID", func() {
(&Header{
OmitConnectionID: true,
PacketNumber: 0x1337,
PacketNumberLen: 6,
Version: protocol.Version39,
PacketNumber: 0x1337,
PacketNumberLen: 6,
Version: protocol.Version39,
}).logPublicHeader(logger)
Expect(buf.String()).To(ContainSubstring("Public Header{ConnectionID: (empty)"))
})
It("logs a Public Header without a version", func() {
(&Header{
OmitConnectionID: true,
PacketNumber: 0x1337,
PacketNumberLen: 6,
PacketNumber: 0x1337,
PacketNumberLen: 6,
}).logPublicHeader(logger)
Expect(buf.String()).To(ContainSubstring("Version: (unset)"))
})

View file

@ -451,9 +451,8 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
packetNumberLen := p.getPacketNumberLen(pnum)
header := &wire.Header{
DestConnectionID: p.destConnID,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
}
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
@ -469,8 +468,8 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
}
}
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
header.OmitConnectionID = true
if !p.omitConnectionID || encLevel != protocol.EncryptionForwardSecure {
header.DestConnectionID = p.destConnID
}
if !p.version.UsesTLS() {
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {

View file

@ -165,16 +165,16 @@ var _ = Describe("Packet packer", func() {
It("it omits the connection ID for forward-secure packets", func() {
ph := packer.getHeader(protocol.EncryptionForwardSecure)
Expect(ph.OmitConnectionID).To(BeFalse())
Expect(ph.DestConnectionID.Len()).ToNot(BeZero())
packer.SetOmitConnectionID()
ph = packer.getHeader(protocol.EncryptionForwardSecure)
Expect(ph.OmitConnectionID).To(BeTrue())
Expect(ph.DestConnectionID.Len()).To(BeZero())
})
It("doesn't omit the connection ID for non-forward-secure packets", func() {
packer.SetOmitConnectionID()
ph := packer.getHeader(protocol.EncryptionSecure)
Expect(ph.OmitConnectionID).To(BeFalse())
Expect(ph.DestConnectionID.Len()).ToNot(BeZero())
})
It("adds the Version Flag to the Public Header before the crypto handshake is finished", func() {
@ -254,20 +254,6 @@ var _ = Describe("Packet packer", func() {
Expect(h.IsLongHeader).To(BeFalse())
Expect(h.PacketNumberLen).To(BeNumerically(">", 0))
})
It("it omits the connection ID for forward-secure packets", func() {
h := packer.getHeader(protocol.EncryptionForwardSecure)
Expect(h.OmitConnectionID).To(BeFalse())
packer.SetOmitConnectionID()
h = packer.getHeader(protocol.EncryptionForwardSecure)
Expect(h.OmitConnectionID).To(BeTrue())
})
It("doesn't omit the connection ID for non-forward-secure packets", func() {
packer.SetOmitConnectionID()
h := packer.getHeader(protocol.EncryptionSecure)
Expect(h.OmitConnectionID).To(BeFalse())
})
})
})