diff --git a/example/main.go b/example/main.go index 1265d3b9..bd36b59f 100644 --- a/example/main.go +++ b/example/main.go @@ -118,10 +118,12 @@ func main() { }) fullReply := &bytes.Buffer{} - quic.WritePublicHeader(fullReply, &quic.PublicHeader{ - ConnectionID: publicHeader.ConnectionID, - PacketNumber: 1, - }) + responsePublicHeader := quic.PublicHeader{ConnectionID: publicHeader.ConnectionID, PacketNumber: 1} + fmt.Println(responsePublicHeader) + err = responsePublicHeader.WritePublicHeader(fullReply) + if err != nil { + panic(err) + } nullAEAD.Seal(fullReply, fullReply.Bytes(), replyFrame.Bytes()) diff --git a/public_header.go b/public_header.go index 0cd11c05..a77152ac 100644 --- a/public_header.go +++ b/public_header.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "io" "github.com/lucas-clemente/quic-go/utils" @@ -14,6 +15,26 @@ type PublicHeader struct { ConnectionID uint64 QuicVersion uint32 PacketNumber uint64 + // packetNumberLen uint8 +} + +// WritePublicHeader writes a public header +func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer) error { + publicFlagByte := uint8(0x0C | 0x20) + if h.VersionFlag && h.ResetFlag { + return errors.New("Reset Flag and Version Flag should not be set at the same time") + } + if h.VersionFlag { + publicFlagByte |= 0x01 + } + if h.ResetFlag { + publicFlagByte |= 0x02 + } + + b.WriteByte(publicFlagByte) + utils.WriteUint64(b, h.ConnectionID) // TODO: Send shorter connection id if possible + utils.WriteUint32(b, uint32(h.PacketNumber)) // TODO: Send shorter packet number if possible + return nil } // ParsePublicHeader parses a QUIC packet's public header @@ -71,11 +92,3 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { return header, nil } - -// WritePublicHeader writes a public header -func WritePublicHeader(b *bytes.Buffer, h *PublicHeader) { - publicFlagByte := uint8(0x0C | 0x20) - b.WriteByte(publicFlagByte) - utils.WriteUint64(b, h.ConnectionID) // TODO: Send shorter connection id if possible - utils.WriteUint32(b, uint32(h.PacketNumber)) // TODO: Send shorter packet number if possible -} diff --git a/public_header_test.go b/public_header_test.go index e30013d1..f5f6399f 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -76,11 +76,48 @@ var _ = Describe("Public Header", func() { Context("when writing", func() { It("writes a sample header", func() { b := &bytes.Buffer{} - WritePublicHeader(b, &PublicHeader{ + publicHeader := PublicHeader{ ConnectionID: 0x4cfa9f9b668619f6, PacketNumber: 2, - }) + } + publicHeader.WritePublicHeader(b) Expect(b.Bytes()).To(Equal([]byte{0x2c, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 2, 0, 0, 0})) }) + + It("sets the Version Flag", func() { + b := &bytes.Buffer{} + publicHeader := PublicHeader{ + VersionFlag: true, + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 2, + } + publicHeader.WritePublicHeader(b) + firstByte, _ := b.ReadByte() + Expect(firstByte & 0x01).To(Equal(uint8(1))) + }) + + It("sets the Reset Flag", func() { + b := &bytes.Buffer{} + publicHeader := PublicHeader{ + ResetFlag: true, + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 2, + } + publicHeader.WritePublicHeader(b) + firstByte, _ := b.ReadByte() + Expect((firstByte & 0x02) >> 1).To(Equal(uint8(1))) + }) + + It("throws an error if both Reset Flag and Version Flag are set", func() { + b := &bytes.Buffer{} + publicHeader := PublicHeader{ + VersionFlag: true, + ResetFlag: true, + ConnectionID: 0x4cfa9f9b668619f6, + PacketNumber: 2, + } + err := publicHeader.WritePublicHeader(b) + Expect(err).To(HaveOccurred()) + }) }) })