mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
refactor header writing to append to a byte slice (#3646)
This avoids having to allocate a bytes.Buffer.
This commit is contained in:
parent
3d4bbc28ba
commit
c24fbb094c
18 changed files with 282 additions and 279 deletions
|
@ -571,8 +571,8 @@ var _ = Describe("Connection", func() {
|
|||
cryptoSetup.EXPECT().Close()
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(wire.WriteShortHeader(buf, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
|
||||
b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
|
||||
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
|
||||
|
@ -589,7 +589,7 @@ var _ = Describe("Connection", func() {
|
|||
rcvTime: time.Now(),
|
||||
remoteAddr: &net.UDPAddr{},
|
||||
buffer: getPacketBuffer(),
|
||||
data: buf.Bytes(),
|
||||
data: b,
|
||||
})
|
||||
// Consistently(pack).ShouldNot(Receive())
|
||||
Eventually(conn.Context().Done()).Should(BeClosed())
|
||||
|
@ -649,20 +649,20 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
|
||||
getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(wire.WriteShortHeader(buf, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
|
||||
b, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
data: append(b, data...),
|
||||
buffer: getPacketBuffer(),
|
||||
rcvTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(extHdr.Write(buf, conn.version)).To(Succeed())
|
||||
b, err := extHdr.Append(nil, conn.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
data: append(b, data...),
|
||||
buffer: getPacketBuffer(),
|
||||
rcvTime: time.Now(),
|
||||
}
|
||||
|
@ -2365,10 +2365,10 @@ var _ = Describe("Client Connection", func() {
|
|||
destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
|
||||
|
||||
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(hdr.Write(buf, conn.version)).To(Succeed())
|
||||
b, err := hdr.Append(nil, conn.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
data: append(b, data...),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
@ -2643,9 +2643,9 @@ var _ = Describe("Client Connection", func() {
|
|||
})
|
||||
|
||||
getRetryTag := func(hdr *wire.ExtendedHeader) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
hdr.Write(buf, conn.version)
|
||||
return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:]
|
||||
b, err := hdr.Append(nil, conn.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return handshake.GetRetryIntegrityTag(b, origDestConnID, hdr.Version)[:]
|
||||
}
|
||||
|
||||
It("handles Retry packets", func() {
|
||||
|
@ -2858,10 +2858,10 @@ var _ = Describe("Client Connection", func() {
|
|||
var unpacker *MockUnpacker
|
||||
|
||||
getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(extHdr.Write(buf, conn.version)).To(Succeed())
|
||||
b, err := extHdr.Append(nil, conn.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return &receivedPacket{
|
||||
data: append(buf.Bytes(), data...),
|
||||
data: append(b, data...),
|
||||
buffer: getPacketBuffer(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"math/rand"
|
||||
|
||||
|
@ -85,28 +84,28 @@ func main() {
|
|||
PacketNumberLen: protocol.PacketNumberLen(rand.Intn(4) + 1),
|
||||
PacketNumber: protocol.PacketNumber(rand.Uint64()),
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
if err := extHdr.Write(b, version); err != nil {
|
||||
b, err := extHdr.Append(nil, version)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
b.Write([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})
|
||||
b = append(b, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}...)
|
||||
}
|
||||
if h.Length > 0 {
|
||||
b.Write(make([]byte, h.Length))
|
||||
b = append(b, make([]byte, h.Length)...)
|
||||
}
|
||||
|
||||
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil {
|
||||
if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// short header
|
||||
b := &bytes.Buffer{}
|
||||
if err := wire.WriteShortHeader(b, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne); err != nil {
|
||||
b, err := wire.AppendShortHeader(nil, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil {
|
||||
if err := helper.WriteCorpusFileWithPrefix("corpus", b, header.PrefixLen); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -64,8 +64,8 @@ func Fuzz(data []byte) int {
|
|||
if hdr.Length > 16383 {
|
||||
return 1
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
if err := extHdr.Write(b, version); err != nil {
|
||||
b, err := extHdr.Append(nil, version)
|
||||
if err != nil {
|
||||
// We are able to parse packets with connection IDs longer than 20 bytes,
|
||||
// but in QUIC version 1, we don't write headers with longer connection IDs.
|
||||
if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen &&
|
||||
|
@ -76,8 +76,8 @@ func Fuzz(data []byte) int {
|
|||
}
|
||||
// GetLength is not implemented for Retry packets
|
||||
if hdr.Type != protocol.PacketTypeRetry {
|
||||
if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(b.Len()) {
|
||||
panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, b.Len()))
|
||||
if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(len(b)) {
|
||||
panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, len(b)))
|
||||
}
|
||||
}
|
||||
return 1
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -116,12 +115,12 @@ var _ = Describe("MITM test", func() {
|
|||
for i := 0; i < numPackets; i++ {
|
||||
payloadLen := mrand.Int31n(100)
|
||||
replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(replyHdr.Write(buf, version)).To(Succeed())
|
||||
b := make([]byte, payloadLen)
|
||||
mrand.Read(b)
|
||||
buf.Write(b)
|
||||
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
|
||||
b, err := replyHdr.Append(nil, version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
r := make([]byte, payloadLen)
|
||||
mrand.Read(r)
|
||||
b = append(b, r...)
|
||||
if _, err := conn.WriteTo(b, remoteAddr); err != nil {
|
||||
return
|
||||
}
|
||||
<-ticker.C
|
||||
|
@ -134,13 +133,13 @@ var _ = Describe("MITM test", func() {
|
|||
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
|
||||
}
|
||||
for i := 0; i < numPackets; i++ {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))).To(Succeed())
|
||||
b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
payloadLen := mrand.Int31n(100)
|
||||
b := make([]byte, payloadLen)
|
||||
mrand.Read(b)
|
||||
buf.Write(b)
|
||||
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
|
||||
r := make([]byte, payloadLen)
|
||||
mrand.Read(r)
|
||||
b = append(b, r...)
|
||||
if _, err := conn.WriteTo(b, remoteAddr); err != nil {
|
||||
return
|
||||
}
|
||||
<-ticker.C
|
||||
|
|
|
@ -28,7 +28,6 @@ func isProxyRunning() bool {
|
|||
|
||||
var _ = Describe("QUIC Proxy", func() {
|
||||
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -40,10 +39,10 @@ var _ = Describe("QUIC Proxy", func() {
|
|||
PacketNumber: p,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.VersionWhatever)).To(Succeed())
|
||||
raw := b.Bytes()
|
||||
raw = append(raw, payload...)
|
||||
return raw
|
||||
b, err := hdr.Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, payload...)
|
||||
return b
|
||||
}
|
||||
|
||||
readPacketNumber := func(b []byte) protocol.PacketNumber {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package testutils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -13,9 +13,11 @@ import (
|
|||
|
||||
// writePacket returns a new raw packet with the specified header and payload
|
||||
func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
hdr.Write(buf, hdr.Version)
|
||||
return append(buf.Bytes(), data...)
|
||||
b, err := hdr.Append(nil, hdr.Version)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to write header: %s", err))
|
||||
}
|
||||
return append(b, data...)
|
||||
}
|
||||
|
||||
// packRawPayload returns a new raw payload containing given frames
|
||||
|
|
|
@ -2,6 +2,7 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -93,20 +94,17 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Write writes the Header.
|
||||
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
|
||||
// Append appends the Header.
|
||||
func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) {
|
||||
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
|
||||
return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
|
||||
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
|
||||
}
|
||||
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
|
||||
return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
|
||||
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
|
||||
}
|
||||
return h.writeLongHeader(b, ver)
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
var packetType uint8
|
||||
if version == protocol.Version2 {
|
||||
if v == protocol.Version2 {
|
||||
//nolint:exhaustive
|
||||
switch h.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
|
@ -137,24 +135,25 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi
|
|||
firstByte |= uint8(h.PacketNumberLen - 1)
|
||||
}
|
||||
|
||||
b.WriteByte(firstByte)
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
||||
b.WriteByte(uint8(h.DestConnectionID.Len()))
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
b.WriteByte(uint8(h.SrcConnectionID.Len()))
|
||||
b.Write(h.SrcConnectionID.Bytes())
|
||||
b = append(b, firstByte)
|
||||
b = append(b, make([]byte, 4)...)
|
||||
binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
|
||||
b = append(b, uint8(h.DestConnectionID.Len()))
|
||||
b = append(b, h.DestConnectionID.Bytes()...)
|
||||
b = append(b, uint8(h.SrcConnectionID.Len()))
|
||||
b = append(b, h.SrcConnectionID.Bytes()...)
|
||||
|
||||
//nolint:exhaustive
|
||||
switch h.Type {
|
||||
case protocol.PacketTypeRetry:
|
||||
b.Write(h.Token)
|
||||
return nil
|
||||
b = append(b, h.Token...)
|
||||
return b, nil
|
||||
case protocol.PacketTypeInitial:
|
||||
quicvarint.Write(b, uint64(len(h.Token)))
|
||||
b.Write(h.Token)
|
||||
b = quicvarint.Append(b, uint64(len(h.Token)))
|
||||
b = append(b, h.Token...)
|
||||
}
|
||||
quicvarint.WriteWithLen(b, uint64(h.Length), 2)
|
||||
return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
|
||||
return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
|
||||
// ParsedLen returns the number of bytes that were consumed when parsing the header
|
||||
|
@ -188,18 +187,24 @@ func (h *ExtendedHeader) Log(logger utils.Logger) {
|
|||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
||||
}
|
||||
|
||||
func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error {
|
||||
func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
|
||||
switch pnLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(pn))
|
||||
b = append(b, uint8(pn))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(pn))
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(pn))
|
||||
b = append(b, buf...)
|
||||
case protocol.PacketNumberLen3:
|
||||
utils.BigEndian.WriteUint24(b, uint32(pn))
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, uint32(pn))
|
||||
b = append(b, buf[1:]...)
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(pn))
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, uint32(pn))
|
||||
b = append(b, buf...)
|
||||
default:
|
||||
return fmt.Errorf("invalid packet number length: %d", pnLen)
|
||||
return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
|
||||
}
|
||||
return nil
|
||||
return b, nil
|
||||
}
|
||||
|
|
|
@ -14,20 +14,12 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Header", func() {
|
||||
const versionIETFHeader = protocol.Version1
|
||||
|
||||
Context("Writing", func() {
|
||||
var buf *bytes.Buffer
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
})
|
||||
|
||||
Context("Long Header", func() {
|
||||
Context("Long Header, version 1", func() {
|
||||
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
|
||||
It("writes", func() {
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}),
|
||||
|
@ -37,7 +29,8 @@ var _ = Describe("Header", func() {
|
|||
},
|
||||
PacketNumber: 0xdecaf,
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, versionIETFHeader)).To(Succeed())
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{
|
||||
0xc0 | 0x2<<4 | 0x2,
|
||||
0x1, 0x2, 0x3, 0x4, // version number
|
||||
|
@ -48,11 +41,11 @@ var _ = Describe("Header", func() {
|
|||
}
|
||||
expected = append(expected, encodeVarInt(protocol.InitialPacketSizeIPv4)...) // length
|
||||
expected = append(expected, []byte{0xd, 0xec, 0xaf}...) // packet number
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
Expect(b).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("writes a header with a 20 byte connection ID", func() {
|
||||
err := (&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
SrcConnectionID: srcConnID,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long
|
||||
|
@ -61,14 +54,14 @@ var _ = Describe("Header", func() {
|
|||
},
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, versionIETFHeader)
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(buf.Bytes()).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})))
|
||||
Expect(b).To(ContainSubstring(string([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20})))
|
||||
})
|
||||
|
||||
It("writes an Initial containing a token", func() {
|
||||
token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Version: 0x1020304,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -76,14 +69,15 @@ var _ = Describe("Header", func() {
|
|||
},
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, versionIETFHeader)).To(Succeed())
|
||||
Expect(buf.Bytes()[0]>>4&0b11 == 0)
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b[0]>>4&0b11 == 0)
|
||||
expectedSubstring := append(encodeVarInt(uint64(len(token))), token...)
|
||||
Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring)))
|
||||
Expect(b).To(ContainSubstring(string(expectedSubstring)))
|
||||
})
|
||||
|
||||
It("uses a 2-byte encoding for the length on Initial packets", func() {
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Version: 0x1020304,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -91,89 +85,88 @@ var _ = Describe("Header", func() {
|
|||
},
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, versionIETFHeader)).To(Succeed())
|
||||
b := &bytes.Buffer{}
|
||||
quicvarint.WriteWithLen(b, 37, 2)
|
||||
Expect(buf.Bytes()[buf.Len()-6 : buf.Len()-4]).To(Equal(b.Bytes()))
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
lengthEncoded := quicvarint.AppendWithLen(nil, 37, 2)
|
||||
Expect(b[len(b)-6 : len(b)-4]).To(Equal(lengthEncoded))
|
||||
})
|
||||
|
||||
It("writes a Retry packet", func() {
|
||||
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
|
||||
Expect((&ExtendedHeader{Header: Header{
|
||||
b, err := (&ExtendedHeader{Header: Header{
|
||||
Version: protocol.Version1,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Token: token,
|
||||
}}).Write(buf, versionIETFHeader)).To(Succeed())
|
||||
}}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0xc0 | 0b11<<4}
|
||||
expected = appendVersion(expected, protocol.Version1)
|
||||
expected = append(expected, 0x0) // dest connection ID length
|
||||
expected = append(expected, 0x0) // src connection ID length
|
||||
expected = append(expected, token...)
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
Expect(b).To(Equal(expected))
|
||||
})
|
||||
})
|
||||
|
||||
Context("long header, version 2", func() {
|
||||
It("writes an Initial", func() {
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Version: protocol.Version2,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
},
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, protocol.Version2)).To(Succeed())
|
||||
Expect(buf.Bytes()[0]>>4&0b11 == 0b01)
|
||||
}).Append(nil, protocol.Version2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b[0]>>4&0b11 == 0b01)
|
||||
})
|
||||
|
||||
It("writes a Retry packet", func() {
|
||||
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
|
||||
Expect((&ExtendedHeader{Header: Header{
|
||||
b, err := (&ExtendedHeader{Header: Header{
|
||||
Version: protocol.Version2,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Token: token,
|
||||
}}).Write(buf, versionIETFHeader)).To(Succeed())
|
||||
expected := []byte{0xc0 | 0b11<<4}
|
||||
}}).Append(nil, protocol.Version2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0xc0 | 0b00<<4}
|
||||
expected = appendVersion(expected, protocol.Version2)
|
||||
expected = append(expected, 0x0) // dest connection ID length
|
||||
expected = append(expected, 0x0) // src connection ID length
|
||||
expected = append(expected, token...)
|
||||
Expect(buf.Bytes()).To(Equal(expected))
|
||||
Expect(b).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("writes a Handshake Packet", func() {
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Version: protocol.Version2,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
},
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, protocol.Version2)).To(Succeed())
|
||||
Expect(buf.Bytes()[0]>>4&0b11 == 0b11)
|
||||
}).Append(nil, protocol.Version2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b[0]>>4&0b11 == 0b11)
|
||||
})
|
||||
|
||||
It("writes a 0-RTT Packet", func() {
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Version: protocol.Version2,
|
||||
Type: protocol.PacketType0RTT,
|
||||
},
|
||||
PacketNumber: 0xdecafbad,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, protocol.Version2)).To(Succeed())
|
||||
Expect(buf.Bytes()[0]>>4&0b11 == 0b10)
|
||||
}).Append(nil, protocol.Version2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b[0]>>4&0b11 == 0b10)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("getting the length", func() {
|
||||
var buf *bytes.Buffer
|
||||
|
||||
BeforeEach(func() {
|
||||
buf = &bytes.Buffer{}
|
||||
})
|
||||
|
||||
It("has the right length for the Long Header, for a short length", func() {
|
||||
h := &ExtendedHeader{
|
||||
Header: Header{
|
||||
|
@ -185,9 +178,10 @@ var _ = Describe("Header", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* length */ + 1 /* packet number */
|
||||
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
|
||||
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
|
||||
Expect(buf.Len()).To(Equal(expectedLen))
|
||||
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
|
||||
b, err := h.Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(HaveLen(expectedLen))
|
||||
})
|
||||
|
||||
It("has the right length for the Long Header, for a long length", func() {
|
||||
|
@ -201,9 +195,10 @@ var _ = Describe("Header", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 8 /* src conn id */ + 2 /* long len */ + 2 /* packet number */
|
||||
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
|
||||
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
|
||||
Expect(buf.Len()).To(Equal(expectedLen))
|
||||
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
|
||||
b, err := h.Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(HaveLen(expectedLen))
|
||||
})
|
||||
|
||||
It("has the right length for an Initial that has a short length", func() {
|
||||
|
@ -217,9 +212,10 @@ var _ = Describe("Header", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */
|
||||
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
|
||||
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
|
||||
Expect(buf.Len()).To(Equal(expectedLen))
|
||||
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
|
||||
b, err := h.Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(HaveLen(expectedLen))
|
||||
})
|
||||
|
||||
It("has the right length for an Initial not containing a Token", func() {
|
||||
|
@ -233,9 +229,10 @@ var _ = Describe("Header", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn ID len */ + 4 /* src conn id */ + 1 /* token length */ + 2 /* length len */ + 2 /* packet number */
|
||||
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
|
||||
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
|
||||
Expect(buf.Len()).To(Equal(expectedLen))
|
||||
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
|
||||
b, err := h.Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(HaveLen(expectedLen))
|
||||
})
|
||||
|
||||
It("has the right length for an Initial containing a Token", func() {
|
||||
|
@ -250,9 +247,10 @@ var _ = Describe("Header", func() {
|
|||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}
|
||||
expectedLen := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn id len */ + 8 /* dest conn id */ + 1 /* src conn id len */ + 4 /* src conn id */ + 1 /* token length */ + 3 /* token */ + 2 /* long len */ + 2 /* packet number */
|
||||
Expect(h.GetLength(versionIETFHeader)).To(BeEquivalentTo(expectedLen))
|
||||
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
|
||||
Expect(buf.Len()).To(Equal(expectedLen))
|
||||
Expect(h.GetLength(protocol.Version1)).To(BeEquivalentTo(expectedLen))
|
||||
b, err := h.Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(HaveLen(expectedLen))
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -15,8 +15,7 @@ import (
|
|||
var _ = Describe("Header Parsing", func() {
|
||||
Context("Parsing the Connection ID", func() {
|
||||
It("parses the connection ID of a long header packet", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
|
||||
|
@ -24,15 +23,15 @@ var _ = Describe("Header Parsing", func() {
|
|||
Version: protocol.Version1,
|
||||
},
|
||||
PacketNumberLen: 2,
|
||||
}).Write(buf, protocol.Version1)).To(Succeed())
|
||||
connID, err := ParseConnectionID(buf.Bytes(), 8)
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
connID, err := ParseConnectionID(b, 8)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
|
||||
})
|
||||
|
||||
It("errors on EOF, for long header packets", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}),
|
||||
|
@ -40,9 +39,10 @@ var _ = Describe("Header Parsing", func() {
|
|||
Version: protocol.Version1,
|
||||
},
|
||||
PacketNumberLen: 2,
|
||||
}).Write(buf, protocol.Version1)).To(Succeed())
|
||||
data := buf.Bytes()[:buf.Len()-2] // cut the packet number
|
||||
_, err := ParseConnectionID(data, 8)
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := b[:len(b)-2] // cut the packet number
|
||||
_, err = ParseConnectionID(data, 8)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for i := 0; i < 1 /* first byte */ +4 /* version */ +1 /* conn ID lengths */ +6; /* dest conn ID */ i++ {
|
||||
b := make([]byte, i)
|
||||
|
@ -419,22 +419,22 @@ var _ = Describe("Header Parsing", func() {
|
|||
|
||||
Context("coalesced packets", func() {
|
||||
It("cuts packets", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
hdr := Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
Length: 2 + 6,
|
||||
Version: protocol.Version1,
|
||||
}
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: hdr,
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: 2,
|
||||
}).Write(buf, protocol.Version1)).To(Succeed())
|
||||
hdrRaw := append([]byte{}, buf.Bytes()...)
|
||||
buf.Write([]byte("foobar")) // payload of the first packet
|
||||
buf.Write([]byte("raboof")) // second packet
|
||||
parsedHdr, data, rest, err := ParsePacket(buf.Bytes())
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdrRaw := append([]byte{}, b...)
|
||||
b = append(b, []byte("foobar")...) // payload of the first packet
|
||||
b = append(b, []byte("raboof")...) // second packet
|
||||
parsedHdr, data, rest, err := ParsePacket(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(parsedHdr.Type).To(Equal(hdr.Type))
|
||||
Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID))
|
||||
|
@ -443,8 +443,7 @@ var _ = Describe("Header Parsing", func() {
|
|||
})
|
||||
|
||||
It("errors on packets that are smaller than the length in the packet header, for too small packet number", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
|
@ -453,15 +452,15 @@ var _ = Describe("Header Parsing", func() {
|
|||
},
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: 2,
|
||||
}).Write(buf, protocol.Version1)).To(Succeed())
|
||||
_, _, _, err := ParsePacket(buf.Bytes())
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, _, _, err = ParsePacket(b)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)"))
|
||||
})
|
||||
|
||||
It("errors on packets that are smaller than the length in the packet header, for too small payload", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&ExtendedHeader{
|
||||
b, err := (&ExtendedHeader{
|
||||
Header: Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
|
||||
|
@ -470,9 +469,10 @@ var _ = Describe("Header Parsing", func() {
|
|||
},
|
||||
PacketNumber: 0x1337,
|
||||
PacketNumberLen: 2,
|
||||
}).Write(buf, protocol.Version1)).To(Succeed())
|
||||
buf.Write(make([]byte, 500-2 /* for packet number length */))
|
||||
_, _, _, err := ParsePacket(buf.Bytes())
|
||||
}).Append(nil, protocol.Version1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b = append(b, make([]byte, 500-2 /* for packet number length */)...)
|
||||
_, _, _, err = ParsePacket(b)
|
||||
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -54,15 +53,15 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
|
|||
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
|
||||
}
|
||||
|
||||
// WriteShortHeader writes a short header.
|
||||
func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error {
|
||||
// AppendShortHeader writes a short header.
|
||||
func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) {
|
||||
typeByte := 0x40 | uint8(pnLen-1)
|
||||
if kp == protocol.KeyPhaseOne {
|
||||
typeByte |= byte(1 << 2)
|
||||
}
|
||||
b.WriteByte(typeByte)
|
||||
b.Write(connID.Bytes())
|
||||
return writePacketNumber(b, pn, pnLen)
|
||||
b = append(b, typeByte)
|
||||
b = append(b, connID.Bytes()...)
|
||||
return appendPacketNumber(b, pn, pnLen)
|
||||
}
|
||||
|
||||
func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
@ -77,15 +78,15 @@ var _ = Describe("Short Header", func() {
|
|||
|
||||
Context("writing", func() {
|
||||
It("writes a short header packet", func() {
|
||||
b := &bytes.Buffer{}
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
|
||||
Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed())
|
||||
l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4)
|
||||
b, err := AppendShortHeader(nil, connID, 1337, 4, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
l, pn, pnLen, kp, err := ParseShortHeader(b, 4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(1337)))
|
||||
Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
|
||||
Expect(kp).To(Equal(protocol.KeyPhaseOne))
|
||||
Expect(l).To(Equal(b.Len()))
|
||||
Expect(l).To(Equal(len(b)))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -113,3 +114,17 @@ var _ = Describe("Short Header", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
func BenchmarkWriteShortHeader(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
buf := make([]byte, 100)
|
||||
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
|
||||
for i := 0; i < b.N; i++ {
|
||||
var err error
|
||||
buf, err = AppendShortHeader(buf, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to write short header: %s", err)
|
||||
}
|
||||
buf = buf[:0]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -14,7 +13,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/logging"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -37,8 +35,7 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
)
|
||||
|
||||
getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
b, err := (&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
Type: t,
|
||||
DestConnectionID: connID,
|
||||
|
@ -46,8 +43,9 @@ var _ = Describe("Packet Handler Map", func() {
|
|||
Version: protocol.VersionTLS,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Write(buf, protocol.VersionWhatever)).To(Succeed())
|
||||
return buf.Bytes()
|
||||
}).Append(nil, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b
|
||||
}
|
||||
|
||||
getPacket := func(connID protocol.ConnectionID) []byte {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -800,13 +799,12 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
|
|||
paddingLen += padding
|
||||
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
|
||||
|
||||
raw := buffer.Data[len(buffer.Data):]
|
||||
buf := bytes.NewBuffer(buffer.Data)
|
||||
startLen := buf.Len()
|
||||
if err := header.Write(buf, p.version); err != nil {
|
||||
startLen := len(buffer.Data)
|
||||
raw := buffer.Data[startLen:]
|
||||
raw, err := header.Append(raw, p.version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw = raw[:buf.Len()-startLen]
|
||||
payloadOffset := protocol.ByteCount(len(raw))
|
||||
|
||||
pn := p.pnManager.PopPacketNumber(encLevel)
|
||||
|
@ -814,7 +812,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
|
|||
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
||||
}
|
||||
|
||||
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
|
||||
raw, err = p.appendPacketPayload(raw, payload, paddingLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -846,20 +844,19 @@ func (p *packetPacker) appendShortHeaderPacket(
|
|||
}
|
||||
paddingLen += padding
|
||||
|
||||
raw := buffer.Data[len(buffer.Data):]
|
||||
buf := bytes.NewBuffer(buffer.Data)
|
||||
startLen := buf.Len()
|
||||
if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil {
|
||||
startLen := len(buffer.Data)
|
||||
raw := buffer.Data[startLen:]
|
||||
raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw = raw[:buf.Len()-startLen]
|
||||
payloadOffset := protocol.ByteCount(len(raw))
|
||||
|
||||
if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) {
|
||||
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
||||
}
|
||||
|
||||
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
|
||||
raw, err = p.appendPacketPayload(raw, payload, paddingLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
|
@ -12,7 +11,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
@ -28,21 +26,22 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
)
|
||||
|
||||
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
|
||||
buf := &bytes.Buffer{}
|
||||
ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed())
|
||||
hdrLen := buf.Len()
|
||||
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
|
||||
buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen)))
|
||||
}
|
||||
hdr, _, _, err := wire.ParsePacket(buf.Bytes())
|
||||
b, err := extHdr.Append(nil, version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
return hdr, buf.Bytes()[:hdrLen]
|
||||
hdrLen := len(b)
|
||||
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
|
||||
b = append(b, make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))...)
|
||||
}
|
||||
hdr, _, _, err := wire.ParsePacket(b)
|
||||
ExpectWithOffset(1, err).ToNot(HaveOccurred())
|
||||
return hdr, b[:hdrLen]
|
||||
}
|
||||
|
||||
getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, kp)).To(Succeed())
|
||||
return buf.Bytes()
|
||||
b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, kp)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return b
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
|
@ -72,12 +71,12 @@ var _ = Describe("Packet Unpacker", func() {
|
|||
})
|
||||
|
||||
It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect(wire.WriteShortHeader(buf, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
|
||||
data := append(buf.Bytes(), make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
||||
b, err := wire.AppendShortHeader(nil, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := append(b, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
|
||||
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
|
||||
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
|
||||
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data)
|
||||
_, _, _, _, err = unpacker.UnpackShortHeader(time.Now(), data)
|
||||
Expect(err).To(BeAssignableToTypeOf(&headerParseError{}))
|
||||
Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19"))
|
||||
})
|
||||
|
|
|
@ -107,32 +107,32 @@ func Append(b []byte, i uint64) []byte {
|
|||
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
|
||||
}
|
||||
|
||||
// WriteWithLen writes i in the QUIC varint format with the desired length to w.
|
||||
func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) {
|
||||
// AppendWithLen append i in the QUIC varint format with the desired length.
|
||||
func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte {
|
||||
if length != 1 && length != 2 && length != 4 && length != 8 {
|
||||
panic("invalid varint length")
|
||||
}
|
||||
l := Len(i)
|
||||
if l == length {
|
||||
Write(w, i)
|
||||
return
|
||||
return Append(b, i)
|
||||
}
|
||||
if l > length {
|
||||
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
|
||||
}
|
||||
if length == 2 {
|
||||
w.WriteByte(0b01000000)
|
||||
b = append(b, 0b01000000)
|
||||
} else if length == 4 {
|
||||
w.WriteByte(0b10000000)
|
||||
b = append(b, 0b10000000)
|
||||
} else if length == 8 {
|
||||
w.WriteByte(0b11000000)
|
||||
b = append(b, 0b11000000)
|
||||
}
|
||||
for j := protocol.ByteCount(1); j < length-l; j++ {
|
||||
w.WriteByte(0)
|
||||
b = append(b, 0)
|
||||
}
|
||||
for j := protocol.ByteCount(0); j < l; j++ {
|
||||
w.WriteByte(uint8(i >> (8 * (l - 1 - j))))
|
||||
b = append(b, uint8(i>>(8*(l-1-j))))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Len determines the number of bytes that will be needed to write the number i.
|
||||
|
|
|
@ -142,54 +142,47 @@ var _ = Describe("Varint encoding / decoding", func() {
|
|||
|
||||
Context("with fixed length", func() {
|
||||
It("panics when given an invalid length", func() {
|
||||
Expect(func() { WriteWithLen(&bytes.Buffer{}, 25, 3) }).Should(Panic())
|
||||
Expect(func() { AppendWithLen(nil, 25, 3) }).Should(Panic())
|
||||
})
|
||||
|
||||
It("panics when given a too short length", func() {
|
||||
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt1+1, 1) }).Should(Panic())
|
||||
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt2+1, 2) }).Should(Panic())
|
||||
Expect(func() { WriteWithLen(&bytes.Buffer{}, maxVarInt4+1, 4) }).Should(Panic())
|
||||
Expect(func() { AppendWithLen(nil, maxVarInt1+1, 1) }).Should(Panic())
|
||||
Expect(func() { AppendWithLen(nil, maxVarInt2+1, 2) }).Should(Panic())
|
||||
Expect(func() { AppendWithLen(nil, maxVarInt4+1, 4) }).Should(Panic())
|
||||
})
|
||||
|
||||
It("writes a 1-byte number in minimal encoding", func() {
|
||||
b := &bytes.Buffer{}
|
||||
WriteWithLen(b, 37, 1)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x25}))
|
||||
Expect(AppendWithLen(nil, 37, 1)).To(Equal([]byte{0x25}))
|
||||
})
|
||||
|
||||
It("writes a 1-byte number in 2 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
WriteWithLen(b, 37, 2)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0b01000000, 0x25}))
|
||||
Expect(Read(b)).To(BeEquivalentTo(37))
|
||||
b := AppendWithLen(nil, 37, 2)
|
||||
Expect(b).To(Equal([]byte{0b01000000, 0x25}))
|
||||
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
|
||||
})
|
||||
|
||||
It("writes a 1-byte number in 4 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
WriteWithLen(b, 37, 4)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0, 0x25}))
|
||||
Expect(Read(b)).To(BeEquivalentTo(37))
|
||||
b := AppendWithLen(nil, 37, 4)
|
||||
Expect(b).To(Equal([]byte{0b10000000, 0, 0, 0x25}))
|
||||
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
|
||||
})
|
||||
|
||||
It("writes a 1-byte number in 8 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
WriteWithLen(b, 37, 8)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25}))
|
||||
Expect(Read(b)).To(BeEquivalentTo(37))
|
||||
b := AppendWithLen(nil, 37, 8)
|
||||
Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0, 0, 0, 0x25}))
|
||||
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(37))
|
||||
})
|
||||
|
||||
It("writes a 2-byte number in 4 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
WriteWithLen(b, 15293, 4)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd}))
|
||||
Expect(Read(b)).To(BeEquivalentTo(15293))
|
||||
b := AppendWithLen(nil, 15293, 4)
|
||||
Expect(b).To(Equal([]byte{0b10000000, 0, 0x3b, 0xbd}))
|
||||
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(15293))
|
||||
})
|
||||
|
||||
It("write a 4-byte number in 8 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
WriteWithLen(b, 494878333, 8)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d}))
|
||||
Expect(Read(b)).To(BeEquivalentTo(494878333))
|
||||
b := AppendWithLen(nil, 494878333, 8)
|
||||
Expect(b).To(Equal([]byte{0b11000000, 0, 0, 0, 0x1d, 0x7f, 0x3e, 0x7d}))
|
||||
Expect(Read(bytes.NewReader(b))).To(BeEquivalentTo(494878333))
|
||||
})
|
||||
})
|
||||
|
||||
|
|
45
server.go
45
server.go
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
|
@ -577,19 +576,19 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
|
|||
replyHdr.Log(s.logger)
|
||||
}
|
||||
|
||||
packetBuffer := getPacketBuffer()
|
||||
defer packetBuffer.Release()
|
||||
buf := bytes.NewBuffer(packetBuffer.Data)
|
||||
if err := replyHdr.Write(buf, hdr.Version); err != nil {
|
||||
buf := getPacketBuffer()
|
||||
defer buf.Release()
|
||||
buf.Data, err = replyHdr.Append(buf.Data, hdr.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// append the Retry integrity tag
|
||||
tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version)
|
||||
buf.Write(tag[:])
|
||||
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
|
||||
buf.Data = append(buf.Data, tag[:]...)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil)
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
|
||||
}
|
||||
_, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB())
|
||||
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -627,9 +626,8 @@ func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header
|
|||
|
||||
// sendError sends the error as a response to the packet received with header hdr
|
||||
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error {
|
||||
packetBuffer := getPacketBuffer()
|
||||
defer packetBuffer.Release()
|
||||
buf := bytes.NewBuffer(packetBuffer.Data)
|
||||
b := getPacketBuffer()
|
||||
defer b.Release()
|
||||
|
||||
ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)}
|
||||
|
||||
|
@ -640,33 +638,34 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
|
|||
replyHdr.DestConnectionID = hdr.SrcConnectionID
|
||||
replyHdr.PacketNumberLen = protocol.PacketNumberLen4
|
||||
replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead())
|
||||
if err := replyHdr.Write(buf, hdr.Version); err != nil {
|
||||
var err error
|
||||
b.Data, err = replyHdr.Append(b.Data, hdr.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payloadOffset := buf.Len()
|
||||
payloadOffset := len(b.Data)
|
||||
|
||||
raw := buf.Bytes()
|
||||
raw, err := ccf.Append(raw, hdr.Version)
|
||||
b.Data, err = ccf.Append(b.Data, hdr.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset])
|
||||
raw = raw[0 : len(raw)+sealer.Overhead()]
|
||||
_ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset])
|
||||
b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()]
|
||||
|
||||
pnOffset := payloadOffset - int(replyHdr.PacketNumberLen)
|
||||
sealer.EncryptHeader(
|
||||
raw[pnOffset+4:pnOffset+4+16],
|
||||
&raw[0],
|
||||
raw[pnOffset:payloadOffset],
|
||||
b.Data[pnOffset+4:pnOffset+4+16],
|
||||
&b.Data[0],
|
||||
b.Data[pnOffset:payloadOffset],
|
||||
)
|
||||
|
||||
replyHdr.Log(s.logger)
|
||||
wire.LogFrame(s.logger, ccf, true)
|
||||
if s.config.Tracer != nil {
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf})
|
||||
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
|
||||
}
|
||||
_, err = s.conn.WritePacket(raw, remoteAddr, info.OOB())
|
||||
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB())
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -42,17 +42,18 @@ var _ = Describe("Server", func() {
|
|||
)
|
||||
|
||||
getPacket := func(hdr *wire.Header, p []byte) *receivedPacket {
|
||||
buffer := getPacketBuffer()
|
||||
buf := bytes.NewBuffer(buffer.Data)
|
||||
buf := getPacketBuffer()
|
||||
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
|
||||
Expect((&wire.ExtendedHeader{
|
||||
var err error
|
||||
buf.Data, err = (&wire.ExtendedHeader{
|
||||
Header: *hdr,
|
||||
PacketNumber: 0x42,
|
||||
PacketNumberLen: protocol.PacketNumberLen4,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
n := buf.Len()
|
||||
buf.Write(p)
|
||||
data := buffer.Data[:buf.Len()]
|
||||
}).Append(buf.Data, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
n := len(buf.Data)
|
||||
buf.Data = append(buf.Data, p...)
|
||||
data := buf.Data
|
||||
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, hdr.Version)
|
||||
_ = sealer.Seal(data[n:n], data[n:], 0x42, data[:n])
|
||||
data = data[:len(data)+16]
|
||||
|
@ -60,7 +61,7 @@ var _ = Describe("Server", func() {
|
|||
return &receivedPacket{
|
||||
remoteAddr: &net.UDPAddr{IP: net.IPv4(4, 5, 6, 7), Port: 456},
|
||||
data: data,
|
||||
buffer: buffer,
|
||||
buffer: buf,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue