diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 130cb1bf..9f19bef9 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -34,6 +34,27 @@ func (t PacketType) String() string { } } +// KeyPhase is the key phase +type KeyPhase bool + +const ( + // KeyPhaseZero is key phase 0 + KeyPhaseZero KeyPhase = false + // KeyPhaseOne is key phase 1 + KeyPhaseOne KeyPhase = true +) + +func (p KeyPhase) String() string { + if p == KeyPhaseZero { + return "0" + } + return "1" +} + +func (p KeyPhase) Next() KeyPhase { + return !p +} + // A ByteCount in QUIC type ByteCount uint64 diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index a89f732f..057e87b7 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -15,4 +15,16 @@ var _ = Describe("Protocol", func() { Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) }) }) + + Context("Key Phases", func() { + It("has the correct string representation", func() { + Expect(KeyPhaseZero.String()).To(Equal("0")) + Expect(KeyPhaseOne.String()).To(Equal("1")) + }) + + It("returns the next key phase", func() { + Expect(KeyPhaseZero.Next()).To(Equal(KeyPhaseOne)) + Expect(KeyPhaseOne.Next()).To(Equal(KeyPhaseZero)) + }) + }) }) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 19a0b064..4c50e050 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -19,7 +19,7 @@ type ExtendedHeader struct { PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber - KeyPhase int + KeyPhase protocol.KeyPhase } func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { @@ -53,7 +53,10 @@ func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNum return nil, errors.New("4th and 5th bit must be 0") } - h.KeyPhase = int(h.typeByte&0x4) >> 2 + h.KeyPhase = protocol.KeyPhaseZero + if h.typeByte&0x4 > 0 { + h.KeyPhase = protocol.KeyPhaseOne + } if err := h.readPacketNumber(b); err != nil { return nil, err @@ -129,7 +132,9 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumb // TODO: add support for the key phase func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error { typeByte := 0x40 | uint8(h.PacketNumberLen-1) - typeByte |= byte(h.KeyPhase << 2) + if h.KeyPhase == protocol.KeyPhaseOne { + typeByte |= byte(1 << 2) + } b.WriteByte(typeByte) b.Write(h.DestConnectionID.Bytes()) @@ -176,7 +181,7 @@ func (h *ExtendedHeader) Log(logger utils.Logger) { } logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) } else { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) } } diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index e669fd09..86f85d84 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -199,7 +199,7 @@ var _ = Describe("Header", func() { It("writes the Key Phase Bit", func() { Expect((&ExtendedHeader{ - KeyPhase: 1, + KeyPhase: protocol.KeyPhaseOne, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 0x42, }).Write(buf, versionIETFHeader)).To(Succeed()) @@ -407,7 +407,7 @@ var _ = Describe("Header", func() { Header: Header{ DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, }, - KeyPhase: 1, + KeyPhase: protocol.KeyPhaseOne, PacketNumber: 0x1337, PacketNumberLen: 4, }).Log(logger) diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 67a3fb4f..e9ef160d 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -425,7 +425,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(0)) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) Expect(extHdr.SrcConnectionID).To(BeEmpty()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) @@ -462,7 +462,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(0)) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) Expect(extHdr.SrcConnectionID).To(BeEmpty()) Expect(rest).To(BeEmpty()) @@ -480,7 +480,7 @@ var _ = Describe("Header Parsing", func() { b := bytes.NewReader(data) extHdr, err := hdr.ParseExtended(b, versionIETFFrames) Expect(err).ToNot(HaveOccurred()) - Expect(extHdr.KeyPhase).To(Equal(1)) + Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) Expect(b.Len()).To(BeZero()) })