diff --git a/crypto/AEAD.go b/crypto/AEAD.go index c7169792..df7d7b4a 100644 --- a/crypto/AEAD.go +++ b/crypto/AEAD.go @@ -3,10 +3,12 @@ package crypto import ( "bytes" "io" + + "github.com/lucas-clemente/quic-go/protocol" ) // An AEAD implements QUIC's authenticated encryption and associated data type AEAD interface { - Open(packetNumber uint64, associatedData []byte, ciphertext io.Reader) (*bytes.Reader, error) - Seal(packetNumber uint64, b *bytes.Buffer, associatedData []byte, plaintext []byte) + Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext io.Reader) (*bytes.Reader, error) + Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) } diff --git a/crypto/NullAEAD.go b/crypto/NullAEAD.go index 92df9494..42608e59 100644 --- a/crypto/NullAEAD.go +++ b/crypto/NullAEAD.go @@ -7,6 +7,7 @@ import ( "io" "io/ioutil" + "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" ) @@ -16,7 +17,7 @@ type NullAEAD struct{} var _ AEAD = &NullAEAD{} // Open and verify the ciphertext -func (*NullAEAD) Open(packetNumber uint64, associatedData []byte, r io.Reader) (*bytes.Reader, error) { +func (*NullAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byte, r io.Reader) (*bytes.Reader, error) { ciphertext, err := ioutil.ReadAll(r) if err != nil { return nil, err @@ -40,7 +41,7 @@ func (*NullAEAD) Open(packetNumber uint64, associatedData []byte, r io.Reader) ( } // Seal writes hash and ciphertext to the buffer -func (*NullAEAD) Seal(packetNumber uint64, b *bytes.Buffer, associatedData []byte, plaintext []byte) { +func (*NullAEAD) Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) { hash := New128a() hash.Write(associatedData) hash.Write(plaintext) diff --git a/crypto/aes_gcm_aead.go b/crypto/aes_gcm_aead.go index 1d5973bb..3082fcf2 100644 --- a/crypto/aes_gcm_aead.go +++ b/crypto/aes_gcm_aead.go @@ -8,6 +8,8 @@ import ( "errors" "io" "io/ioutil" + + "github.com/lucas-clemente/quic-go/protocol" ) type aeadAESGCM struct { @@ -46,7 +48,7 @@ func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) ( }, nil } -func (aead *aeadAESGCM) Open(packetNumber uint64, associatedData []byte, r io.Reader) (*bytes.Reader, error) { +func (aead *aeadAESGCM) Open(packetNumber protocol.PacketNumber, associatedData []byte, r io.Reader) (*bytes.Reader, error) { ciphertext, err := ioutil.ReadAll(r) if err != nil { return nil, err @@ -58,14 +60,14 @@ func (aead *aeadAESGCM) Open(packetNumber uint64, associatedData []byte, r io.Re return bytes.NewReader(plaintext), nil } -func (aead *aeadAESGCM) Seal(packetNumber uint64, b *bytes.Buffer, associatedData []byte, plaintext []byte) { +func (aead *aeadAESGCM) Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) { ciphertext := aead.encrypter.Seal(nil, makeNonce(aead.myIV, packetNumber), plaintext, associatedData) b.Write(ciphertext) } -func makeNonce(iv []byte, packetNumber uint64) []byte { +func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { res := make([]byte, 12) copy(res[0:4], iv) - binary.LittleEndian.PutUint64(res[4:12], packetNumber) + binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber)) return res } diff --git a/entropy_accumulator.go b/entropy_accumulator.go index cfa2fbdc..14594eef 100644 --- a/entropy_accumulator.go +++ b/entropy_accumulator.go @@ -1,10 +1,12 @@ package quic +import "github.com/lucas-clemente/quic-go/protocol" + // EntropyAccumulator accumulates the entropy according to the QUIC docs type EntropyAccumulator byte // Add the contribution of the entropy flag of a given packet number -func (e *EntropyAccumulator) Add(packetNumber uint64, entropyFlag bool) { +func (e *EntropyAccumulator) Add(packetNumber protocol.PacketNumber, entropyFlag bool) { if entropyFlag { (*e) ^= 0x01 << (packetNumber % 8) } diff --git a/protocol/protocol.go b/protocol/protocol.go new file mode 100644 index 00000000..e869b8be --- /dev/null +++ b/protocol/protocol.go @@ -0,0 +1,4 @@ +package protocol + +// A PacketNumber in QUIC +type PacketNumber uint64 diff --git a/public_header.go b/public_header.go index a77152ac..1a994c15 100644 --- a/public_header.go +++ b/public_header.go @@ -5,6 +5,7 @@ import ( "errors" "io" + "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/utils" ) @@ -14,7 +15,7 @@ type PublicHeader struct { ResetFlag bool ConnectionID uint64 QuicVersion uint32 - PacketNumber uint64 + PacketNumber protocol.PacketNumber // packetNumberLen uint8 } @@ -85,10 +86,11 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { } // Packet number - header.PacketNumber, err = utils.ReadUintN(b, packetNumberLen) + pcktNumber, err := utils.ReadUintN(b, packetNumberLen) if err != nil { return nil, err } + header.PacketNumber = protocol.PacketNumber(pcktNumber) return header, nil } diff --git a/public_header_test.go b/public_header_test.go index f5f6399f..0172eebf 100644 --- a/public_header_test.go +++ b/public_header_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" + "github.com/lucas-clemente/quic-go/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -18,7 +19,7 @@ var _ = Describe("Public Header", func() { Expect(publicHeader.ResetFlag).To(BeFalse()) Expect(publicHeader.ConnectionID).To(Equal(uint64(0x4cfa9f9b668619f6))) Expect(publicHeader.QuicVersion).To(Equal(binary.BigEndian.Uint32([]byte("Q030")))) - Expect(publicHeader.PacketNumber).To(Equal(uint64(1))) + Expect(publicHeader.PacketNumber).To(Equal(protocol.PacketNumber(1))) Expect(b.Len()).To(BeZero()) }) @@ -52,7 +53,7 @@ var _ = Describe("Public Header", func() { b := bytes.NewReader([]byte{0x10, 0xde, 0xca}) publicHeader, err := ParsePublicHeader(b) Expect(err).ToNot(HaveOccurred()) - Expect(publicHeader.PacketNumber).To(Equal(uint64(0xcade))) + Expect(publicHeader.PacketNumber).To(Equal(protocol.PacketNumber(0xcade))) Expect(b.Len()).To(BeZero()) }) @@ -60,7 +61,7 @@ var _ = Describe("Public Header", func() { b := bytes.NewReader([]byte{0x20, 0xad, 0xfb, 0xca, 0xde}) publicHeader, err := ParsePublicHeader(b) Expect(err).ToNot(HaveOccurred()) - Expect(publicHeader.PacketNumber).To(Equal(uint64(0xdecafbad))) + Expect(publicHeader.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad))) Expect(b.Len()).To(BeZero()) }) @@ -68,7 +69,7 @@ var _ = Describe("Public Header", func() { b := bytes.NewReader([]byte{0x30, 0x23, 0x42, 0xad, 0xfb, 0xca, 0xde}) publicHeader, err := ParsePublicHeader(b) Expect(err).ToNot(HaveOccurred()) - Expect(publicHeader.PacketNumber).To(Equal(uint64(0xdecafbad4223))) + Expect(publicHeader.PacketNumber).To(Equal(protocol.PacketNumber(0xdecafbad4223))) Expect(b.Len()).To(BeZero()) }) })