diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 55bc035d..34cf7855 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -34,7 +34,7 @@ type LongHeaderOpener interface { // ShortHeaderOpener opens a short header packet type ShortHeaderOpener interface { headerDecryptor - Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, associatedData []byte) ([]byte, error) + Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) } // LongHeaderSealer seals a long header packet @@ -47,7 +47,7 @@ type LongHeaderSealer interface { // ShortHeaderSealer seals a short header packet type ShortHeaderSealer interface { LongHeaderSealer - KeyPhase() protocol.KeyPhase + KeyPhase() protocol.KeyPhaseBit } // A tlsExtensionHandler sends and received the QUIC TLS extension. diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 4fe8767b..e086846c 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -43,7 +43,7 @@ func setKeyUpdateInterval() { type updatableAEAD struct { suite cipherSuite - keyPhase protocol.KeyPhase + keyPhase protocol.KeyPhaseBit largestAcked protocol.PacketNumber keyUpdateInterval uint64 @@ -134,7 +134,7 @@ func (a *updatableAEAD) SetWriteKey(suite cipherSuite, trafficSecret []byte) { a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret) } -func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhase, ad []byte) ([]byte, error) { +func (a *updatableAEAD) Open(dst, src []byte, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) if kp != a.keyPhase { if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { @@ -215,7 +215,7 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { return false } -func (a *updatableAEAD) KeyPhase() protocol.KeyPhase { +func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() } diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index a51f9eb5..3d834b3b 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -47,7 +47,7 @@ func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 inte } // Open mocks base method -func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhase, arg4 []byte) ([]byte, error) { +func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 protocol.KeyPhaseBit, arg4 []byte) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([]byte) diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go index 1f269b51..81825543 100644 --- a/internal/mocks/short_header_sealer.go +++ b/internal/mocks/short_header_sealer.go @@ -47,10 +47,10 @@ func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 inte } // KeyPhase mocks base method -func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhase { +func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeyPhase") - ret0, _ := ret[0].(protocol.KeyPhase) + ret0, _ := ret[0].(protocol.KeyPhaseBit) return ret0 } diff --git a/internal/protocol/key_phase.go b/internal/protocol/key_phase.go new file mode 100644 index 00000000..693aeb15 --- /dev/null +++ b/internal/protocol/key_phase.go @@ -0,0 +1,22 @@ +package protocol + +// KeyPhaseBit is the key phase bit +type KeyPhaseBit bool + +const ( + // KeyPhaseZero is key phase 0 + KeyPhaseZero KeyPhaseBit = false + // KeyPhaseOne is key phase 1 + KeyPhaseOne KeyPhaseBit = true +) + +func (p KeyPhaseBit) String() string { + if p == KeyPhaseZero { + return "0" + } + return "1" +} + +func (p KeyPhaseBit) Next() KeyPhaseBit { + return !p +} diff --git a/internal/protocol/key_phase_test.go b/internal/protocol/key_phase_test.go new file mode 100644 index 00000000..67e3aabf --- /dev/null +++ b/internal/protocol/key_phase_test.go @@ -0,0 +1,18 @@ +package protocol + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Key Phases", func() { + It("has the correct string representation", func() { + Expect(KeyPhaseZero.String()).To(Equal("0")) + Expect(KeyPhaseOne.String()).To(Equal("1")) + }) + + It("returns the next key phase", func() { + Expect(KeyPhaseZero.Next()).To(Equal(KeyPhaseOne)) + Expect(KeyPhaseOne.Next()).To(Equal(KeyPhaseZero)) + }) +}) diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 9f19bef9..130cb1bf 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -34,27 +34,6 @@ func (t PacketType) String() string { } } -// KeyPhase is the key phase -type KeyPhase bool - -const ( - // KeyPhaseZero is key phase 0 - KeyPhaseZero KeyPhase = false - // KeyPhaseOne is key phase 1 - KeyPhaseOne KeyPhase = true -) - -func (p KeyPhase) String() string { - if p == KeyPhaseZero { - return "0" - } - return "1" -} - -func (p KeyPhase) Next() KeyPhase { - return !p -} - // A ByteCount in QUIC type ByteCount uint64 diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index 057e87b7..a89f732f 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -15,16 +15,4 @@ var _ = Describe("Protocol", func() { Expect(PacketType(10).String()).To(Equal("unknown packet type: 10")) }) }) - - Context("Key Phases", func() { - It("has the correct string representation", func() { - Expect(KeyPhaseZero.String()).To(Equal("0")) - Expect(KeyPhaseOne.String()).To(Equal("1")) - }) - - It("returns the next key phase", func() { - Expect(KeyPhaseZero.Next()).To(Equal(KeyPhaseOne)) - Expect(KeyPhaseOne.Next()).To(Equal(KeyPhaseZero)) - }) - }) }) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 081ed01a..4ab3951d 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -25,7 +25,7 @@ type ExtendedHeader struct { PacketNumberLen protocol.PacketNumberLen PacketNumber protocol.PacketNumber - KeyPhase protocol.KeyPhase + KeyPhase protocol.KeyPhaseBit } func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { diff --git a/packet_packer.go b/packet_packer.go index 9c437445..419d90c1 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -432,7 +432,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo return payload, nil } -func (p *packetPacker) getShortHeader(kp protocol.KeyPhase) *wire.ExtendedHeader { +func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdr := &wire.ExtendedHeader{} hdr.PacketNumber = pn