implement the new format of the Retry packet

This commit is contained in:
Marten Seemann 2018-11-27 08:29:31 +07:00
parent c791145cba
commit 0d30e43c32
4 changed files with 23 additions and 32 deletions

View file

@ -2,7 +2,6 @@ package wire
import ( import (
"bytes" "bytes"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -94,9 +93,16 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumb
packetType = 0x3 packetType = 0x3
} }
firstByte := 0xc0 | packetType<<4 firstByte := 0xc0 | packetType<<4
if h.Type != protocol.PacketTypeRetry { // Retry packets don't have a packet number if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
firstByte |= odcil
} else { // Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1) firstByte |= uint8(h.PacketNumberLen - 1)
} }
b.WriteByte(firstByte) b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version)) utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
@ -107,24 +113,14 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumb
b.Write(h.DestConnectionID.Bytes()) b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes()) b.Write(h.SrcConnectionID.Bytes())
if h.Type == protocol.PacketTypeInitial { switch h.Type {
utils.WriteVarInt(b, uint64(len(h.Token))) case protocol.PacketTypeRetry:
b.Write(h.Token)
}
if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
// randomize the first 4 bits
odcilByte := make([]byte, 1)
_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
b.Write(odcilByte)
b.Write(h.OrigDestConnectionID.Bytes()) b.Write(h.OrigDestConnectionID.Bytes())
b.Write(h.Token) b.Write(h.Token)
return nil return nil
case protocol.PacketTypeInitial:
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
} }
utils.WriteVarInt(b, uint64(h.Length)) utils.WriteVarInt(b, uint64(h.Length))

View file

@ -120,14 +120,14 @@ var _ = Describe("Header", func() {
Token: token, Token: token,
OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9},
}}).Write(buf, versionIETFHeader)).To(Succeed()) }}).Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Bytes()[:6]).To(Equal([]byte{ expected := []byte{
0xc0 | 0x3<<4, 0xc0 | 0x3<<4 | 9 - 3, /* orig dest connection ID length */
0x1, 0x2, 0x3, 0x4, // version number 0x1, 0x2, 0x3, 0x4, // version number
0x0, // connection ID lengths)) 0x0, // connection ID lengths))
})) 1, 2, 3, 4, 5, 6, 7, 8, 9, // Orig Dest Connection ID
Expect(buf.Bytes()[6] & 0xf).To(Equal(uint8(6))) }
Expect(buf.Bytes()[7 : 7+9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) // Orig Dest Connection ID expected = append(expected, token...)
Expect(buf.Bytes()[7+9:]).To(Equal(token)) Expect(buf.Bytes()).To(Equal(expected))
}) })
It("refuses to write a Retry packet with an invalid Orig Destination Connection ID length", func() { It("refuses to write a Retry packet with an invalid Orig Destination Connection ID length", func() {

View file

@ -117,11 +117,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
} }
if h.Type == protocol.PacketTypeRetry { if h.Type == protocol.PacketTypeRetry {
odcilByte, err := b.ReadByte() odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
if err != nil {
return err
}
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil) h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil { if err != nil {
return err return err

View file

@ -171,10 +171,9 @@ var _ = Describe("Header Parsing", func() {
}) })
It("parses a Retry packet", func() { It("parses a Retry packet", func() {
data := []byte{0xc0 ^ 0x3<<4} data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */}
data = appendVersion(data, versionIETFFrames) data = appendVersion(data, versionIETFFrames)
data = append(data, 0x0) // connection ID lengths data = append(data, 0x0) // connection ID lengths
data = append(data, 0x97) // Orig Destination Connection ID length
data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID
data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token
b := bytes.NewReader(data) b := bytes.NewReader(data)