diff --git a/internal/utils/varint.go b/internal/utils/varint.go new file mode 100644 index 00000000..35e8674e --- /dev/null +++ b/internal/utils/varint.go @@ -0,0 +1,101 @@ +package utils + +import ( + "bytes" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// taken from the QUIC draft +const ( + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// ReadVarInt reads a number in the QUIC varint format +func ReadVarInt(b io.ByteReader) (uint64, error) { + firstByte, err := b.ReadByte() + if err != nil { + return 0, err + } + // the first two bits of the first byte encode the length + len := 1 << ((firstByte & 0xc0) >> 6) + b1 := firstByte & (0xff - 0xc0) + if len == 1 { + return uint64(b1), nil + } + b2, err := b.ReadByte() + if err != nil { + return 0, err + } + if len == 2 { + return uint64(b2) + uint64(b1)<<8, nil + } + b3, err := b.ReadByte() + if err != nil { + return 0, err + } + b4, err := b.ReadByte() + if err != nil { + return 0, err + } + if len == 4 { + return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil + } + b5, err := b.ReadByte() + if err != nil { + return 0, err + } + b6, err := b.ReadByte() + if err != nil { + return 0, err + } + b7, err := b.ReadByte() + if err != nil { + return 0, err + } + b8, err := b.ReadByte() + if err != nil { + return 0, err + } + return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil +} + +// WriteVarInt writes a number in the QUIC varint format +func WriteVarInt(b *bytes.Buffer, i uint64) { + if i <= maxVarInt1 { + b.WriteByte(uint8(i)) + } else if i <= maxVarInt2 { + b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) + } else if i <= maxVarInt4 { + b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + } else if i <= maxVarInt8 { + b.Write([]byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) + } else { + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) + } +} + +// VarIntLen determines the number of bytes that will be needed to write a number +func VarIntLen(i uint64) protocol.ByteCount { + if i <= maxVarInt1 { + return 1 + } + if i <= maxVarInt2 { + return 2 + } + if i <= maxVarInt4 { + return 4 + } + if i <= maxVarInt8 { + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/internal/utils/varint_test.go b/internal/utils/varint_test.go new file mode 100644 index 00000000..b49774f0 --- /dev/null +++ b/internal/utils/varint_test.go @@ -0,0 +1,157 @@ +package utils + +import ( + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Varint encoding / decoding", func() { + Context("decoding", func() { + It("reads a 1 byte number", func() { + b := bytes.NewReader([]byte{25}) // 00011001 + val, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(25))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a number that is encoded too long", func() { + b := bytes.NewReader([]byte{0x40, 0x25}) // first byte: 01000000 + val, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(37))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a 2 byte number", func() { + b := bytes.NewReader([]byte{0x7b, 0xbd}) // first byte: 01111011 + val, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(15293))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads a 4 byte number", func() { + b := bytes.NewReader([]byte{0x9d, 0x7f, 0x3e, 0x7d}) // first byte: 10011011 + val, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(494878333))) + Expect(b.Len()).To(BeZero()) + }) + + It("reads an 8 byte number", func() { + b := bytes.NewReader([]byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}) // first byte: 10000010 + val, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(uint64(151288809941952652))) + Expect(b.Len()).To(BeZero()) + }) + }) + + Context("encoding", func() { + It("writes a 1 byte number", func() { + b := &bytes.Buffer{} + WriteVarInt(b, 37) + Expect(b.Bytes()).To(Equal([]byte{0x25})) + }) + + It("writes the maximum 1 byte number in 1 byte", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt1) + Expect(b.Bytes()).To(Equal([]byte{0x3f /* 00111111 */})) + }) + + It("writes the minimum 2 byte number in 2 bytes", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt1+1) + Expect(b.Bytes()).To(Equal([]byte{0x40, maxVarInt1 + 1})) + }) + + It("writes a 2 byte number", func() { + b := &bytes.Buffer{} + WriteVarInt(b, 15293) + Expect(b.Bytes()).To(Equal([]byte{0x7b, 0xbd})) + }) + + It("writes the maximum 2 byte number in 2 bytes", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt2) + Expect(b.Bytes()).To(Equal([]byte{0x7f /* 01111111 */, 0xff})) + }) + + It("writes the minimum 4 byte number in 4 bytes", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt2+1) + Expect(b.Len()).To(Equal(4)) + num, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(num).To(Equal(uint64(maxVarInt2 + 1))) + }) + + It("writes a 4 byte number", func() { + b := &bytes.Buffer{} + WriteVarInt(b, 494878333) + Expect(b.Bytes()).To(Equal([]byte{0x9d, 0x7f, 0x3e, 0x7d})) + }) + + It("writes the maximum 4 byte number in 4 bytes", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt4) + Expect(b.Bytes()).To(Equal([]byte{0xbf /* 10111111 */, 0xff, 0xff, 0xff})) + }) + + It("writes the minimum 8 byte number in 8 bytes", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt4+1) + Expect(b.Len()).To(Equal(8)) + num, err := ReadVarInt(b) + Expect(err).ToNot(HaveOccurred()) + Expect(num).To(Equal(uint64(maxVarInt4 + 1))) + }) + + It("writes an 8 byte number", func() { + b := &bytes.Buffer{} + WriteVarInt(b, 151288809941952652) + Expect(b.Bytes()).To(Equal([]byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c})) + }) + + It("writes the maximum 8 byte number in 8 bytes", func() { + b := &bytes.Buffer{} + WriteVarInt(b, maxVarInt8) + Expect(b.Bytes()).To(Equal([]byte{0xff /* 11111111 */, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})) + }) + + It("panics when given a too large number (> 62 bit)", func() { + b := &bytes.Buffer{} + Expect(func() { WriteVarInt(b, maxVarInt8+1) }).Should(Panic()) + }) + }) + + Context("determining the length needed for encoding", func() { + It("for numbers that need 1 byte", func() { + Expect(VarIntLen(0)).To(BeEquivalentTo(1)) + Expect(VarIntLen(maxVarInt1)).To(BeEquivalentTo(1)) + }) + + It("for numbers that need 2 bytes", func() { + Expect(VarIntLen(maxVarInt1 + 1)).To(BeEquivalentTo(2)) + Expect(VarIntLen(maxVarInt2)).To(BeEquivalentTo(2)) + }) + + It("for numbers that need 4 bytes", func() { + Expect(VarIntLen(maxVarInt2 + 1)).To(BeEquivalentTo(4)) + Expect(VarIntLen(maxVarInt4)).To(BeEquivalentTo(4)) + }) + + It("for numbers that need 8 bytes", func() { + Expect(VarIntLen(maxVarInt4 + 1)).To(BeEquivalentTo(8)) + Expect(VarIntLen(maxVarInt8)).To(BeEquivalentTo(8)) + }) + + It("panics when given a too large number (> 62 bit)", func() { + Expect(func() { VarIntLen(maxVarInt8 + 1) }).Should(Panic()) + }) + }) +})