From f1c7af6d9a854d47b5299c3b19c10cf48b513c94 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Mon, 28 Dec 2020 18:17:20 +0000 Subject: [PATCH] [WIP] age: Migrate to aead crate's STREAM implementation --- Cargo.lock | 4 +- Cargo.toml | 3 + age/Cargo.toml | 1 + age/src/primitives/stream.rs | 328 +++++++++++++++++++---------------- 4 files changed, 184 insertions(+), 152 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cdcadef..79969eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,8 +9,7 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "aead" version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fc95d1bdb8e6666b2b217308eeeb09f2d6728d104be3e31916cc74d15420331" +source = "git+https://github.com/RustCrypto/traits.git?branch=aead/stream#1c1bfb9504b1191118676dbe6ebb6c0f3a3d7675" dependencies = [ "generic-array", ] @@ -62,6 +61,7 @@ dependencies = [ name = "age" version = "0.5.0" dependencies = [ + "aead", "aes", "aes-ctr", "age-core", diff --git a/Cargo.toml b/Cargo.toml index 4bc7a34..4a8dc68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,6 @@ members = [ "age-core", "rage", ] + +[patch.crates-io] +aead = { git = "https://github.com/RustCrypto/traits.git", branch = "aead/stream", features = ["stream"] } \ No newline at end of file diff --git a/age/Cargo.toml b/age/Cargo.toml index b6d0add..5185b6c 100644 --- a/age/Cargo.toml +++ b/age/Cargo.toml @@ -21,6 +21,7 @@ age-core = { version = "0.5.0", path = "../age-core" } base64 = "0.12" # - ChaCha20-Poly1305 from RFC 7539 +aead = { version = "0.3", features = ["stream"] } c2-chacha = "0.3" chacha20poly1305 = { version = "0.7", default-features = false, features = ["alloc"] } diff --git a/age/src/primitives/stream.rs b/age/src/primitives/stream.rs index b58d129..6cb8e31 100644 --- a/age/src/primitives/stream.rs +++ b/age/src/primitives/stream.rs @@ -1,13 +1,17 @@ //! I/O helper structs for age file encryption and decryption. use chacha20poly1305::{ - aead::{generic_array::GenericArray, Aead, NewAead}, + aead::{ + self, + generic_array::{typenum::U12, GenericArray}, + stream::{Decryptor, Encryptor, StreamPrimitive}, + Aead, AeadInPlace, NewAead, + }, ChaChaPoly1305, }; use pin_project::pin_project; use secrecy::{ExposeSecret, SecretVec}; use std::cmp; -use std::convert::TryInto; use std::io::{self, Read, Seek, SeekFrom, Write}; use zeroize::Zeroize; @@ -24,6 +28,9 @@ const CHUNK_SIZE: usize = 64 * 1024; const TAG_SIZE: usize = 16; const ENCRYPTED_CHUNK_SIZE: usize = CHUNK_SIZE + TAG_SIZE; +type AgeEncryptor = Encryptor, Stream>; +type AgeDecryptor = Decryptor, Stream>; + pub(crate) struct PayloadKey( pub(crate) GenericArray as NewAead>::KeySize>, ); @@ -34,47 +41,6 @@ impl Drop for PayloadKey { } } -/// The nonce used in age's STREAM encryption. -/// -/// Structured as an 11 bytes of big endian counter, and 1 byte of last block flag -/// (`0x00 / 0x01`). We store this in the lower 12 bytes of a `u128`. -#[derive(Clone, Copy, Default)] -struct Nonce(u128); - -impl Nonce { - /// Unsets last-chunk flag. - fn set_counter(&mut self, val: u64) { - self.0 = u128::from(val) << 8; - } - - fn increment_counter(&mut self) { - // Increment the 11-byte counter - self.0 += 1 << 8; - if self.0 >> (8 * 12) != 0 { - panic!("We overflowed the nonce!"); - } - } - - fn is_last(&self) -> bool { - self.0 & 1 != 0 - } - - fn set_last(&mut self, last: bool) -> Result<(), ()> { - if !self.is_last() { - self.0 |= if last { 1 } else { 0 }; - Ok(()) - } else { - Err(()) - } - } - - fn to_bytes(&self) -> [u8; 12] { - self.0.to_be_bytes()[4..] - .try_into() - .expect("slice is correct length") - } -} - #[cfg(feature = "async")] struct EncryptedChunk { bytes: Vec, @@ -90,14 +56,12 @@ struct EncryptedChunk { /// [STREAM]: https://eprint.iacr.org/2015/189.pdf pub(crate) struct Stream { aead: ChaChaPoly1305, - nonce: Nonce, } impl Stream { fn new(key: PayloadKey) -> Self { Stream { aead: ChaChaPoly1305::new(&key.0), - nonce: Nonce::default(), } } @@ -110,7 +74,7 @@ impl Stream { /// [`HKDF`]: age_core::primitives::hkdf pub(crate) fn encrypt(key: PayloadKey, inner: W) -> StreamWriter { StreamWriter { - stream: Self::new(key), + stream: Self::new(key).encryptor(), inner, chunk: Vec::with_capacity(CHUNK_SIZE), #[cfg(feature = "async")] @@ -128,7 +92,7 @@ impl Stream { #[cfg(feature = "async")] pub(crate) fn encrypt_async(key: PayloadKey, inner: W) -> StreamWriter { StreamWriter { - stream: Self::new(key), + stream: Self::new(key).encryptor(), inner, chunk: Vec::with_capacity(CHUNK_SIZE), encrypted_chunk: None, @@ -144,7 +108,7 @@ impl Stream { /// [`HKDF`]: age_core::primitives::hkdf pub(crate) fn decrypt(key: PayloadKey, inner: R) -> StreamReader { StreamReader { - stream: Self::new(key), + stream: Self::new(key).decryptor(), inner, encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE], encrypted_pos: 0, @@ -164,7 +128,7 @@ impl Stream { #[cfg(feature = "async")] pub(crate) fn decrypt_async(key: PayloadKey, inner: R) -> StreamReader { StreamReader { - stream: Self::new(key), + stream: Self::new(key).decryptor(), inner, encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE], encrypted_pos: 0, @@ -174,51 +138,62 @@ impl Stream { } } - fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result> { - assert!(chunk.len() <= CHUNK_SIZE); + /// Computes the nonce used in age's STREAM encryption. + /// + /// Structured as an 11 bytes of big endian counter, and 1 byte of last block flag + /// (`0x00 / 0x01`). We store this in the lower 12 bytes of a `u128`. + fn aead_nonce( + &self, + position: u128, + last_block: bool, + ) -> Result as AeadInPlace>::NonceSize>, aead::Error> + { + if position > Self::COUNTER_MAX { + return Err(aead::Error); + } - self.nonce.set_last(last).map_err(|_| { - io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed") - })?; + let position_with_flag = position | (last_block as u128); - let encrypted = self - .aead - .encrypt(&self.nonce.to_bytes().into(), chunk) - .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size"); - self.nonce.increment_counter(); + let mut result = GenericArray::default(); + result.copy_from_slice(&position_with_flag.to_be_bytes()[4..]); - Ok(encrypted) + Ok(result) + } +} + +impl StreamPrimitive> for Stream { + type NonceOverhead = U12; + type Counter = u128; + const COUNTER_INCR: u128 = 1 << 8; + const COUNTER_MAX: u128 = 0xffffffff_ffffffff_ffffff00; + + fn encrypt_in_place( + &self, + position: Self::Counter, + last_block: bool, + associated_data: &[u8], + buffer: &mut dyn aead::Buffer, + ) -> Result<(), aead::Error> { + let nonce = self.aead_nonce(position, last_block)?; + self.aead.encrypt_in_place(&nonce, associated_data, buffer) } - fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result> { - assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE); - - self.nonce.set_last(last).map_err(|_| { - io::Error::new( - io::ErrorKind::UnexpectedEof, - "last chunk has been processed", - ) - })?; - - let decrypted = self - .aead - .decrypt(&self.nonce.to_bytes().into(), chunk) - .map(SecretVec::new) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?; - self.nonce.increment_counter(); - - Ok(decrypted) - } - - fn is_complete(&self) -> bool { - self.nonce.is_last() + fn decrypt_in_place( + &self, + position: Self::Counter, + last_block: bool, + associated_data: &[u8], + buffer: &mut dyn aead::Buffer, + ) -> Result<(), aead::Error> { + let nonce = self.aead_nonce(position, last_block)?; + self.aead.decrypt_in_place(&nonce, associated_data, buffer) } } /// Writes an encrypted age file. #[pin_project(project = StreamWriterProj)] pub struct StreamWriter { - stream: Stream, + stream: AgeEncryptor, #[pin] inner: W, chunk: Vec, @@ -233,8 +208,14 @@ impl StreamWriter { /// encryption process. Failing to call `finish` will result in a truncated file that /// that will fail to decrypt. pub fn finish(mut self) -> io::Result { - let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?; - self.inner.write_all(&encrypted)?; + self.stream + .encrypt_last_in_place(&[], &mut self.chunk) + .map_err(|_| { + // We will never hit chacha20::MAX_BLOCKS because of the chunk + // size, so this is the only possible error. + io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed") + })?; + self.inner.write_all(&self.chunk)?; Ok(self.inner) } } @@ -255,8 +236,14 @@ impl Write for StreamWriter { // Only encrypt the chunk if we have more data to write, as the last // chunk must be written in finish(). if !buf.is_empty() { - let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?; - self.inner.write_all(&encrypted)?; + self.stream + .encrypt_next_in_place(&[], &mut self.chunk) + .map_err(|_| { + // We will never hit chacha20::MAX_BLOCKS because of the chunk + // size, so this is the only possible error. + io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed") + })?; + self.inner.write_all(&self.chunk)?; self.chunk.clear(); } } @@ -317,10 +304,15 @@ impl AsyncWrite for StreamWriter { // chunk must be written in poll_close(). if !buf.is_empty() { let this = self.as_mut().project(); - *this.encrypted_chunk = Some(EncryptedChunk { - bytes: this.stream.encrypt_chunk(&this.chunk, false)?, - offset: 0, - }); + let mut bytes = this.chunk.clone(); + this.stream + .encrypt_next_in_place(&[], &mut bytes) + .map_err(|_| { + // We will never hit chacha20::MAX_BLOCKS because of the chunk + // size, so this is the only possible error. + io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed") + })?; + *this.encrypted_chunk = Some(EncryptedChunk { bytes, offset: 0 }); this.chunk.clear(); } @@ -336,13 +328,19 @@ impl AsyncWrite for StreamWriter { // Flush any remaining encrypted chunk bytes. ready!(self.as_mut().poll_flush_chunk(cx))?; - if !self.stream.is_complete() { + if !self.chunk.is_empty() { // Finish the stream. let this = self.as_mut().project(); - *this.encrypted_chunk = Some(EncryptedChunk { - bytes: this.stream.encrypt_chunk(&this.chunk, true)?, - offset: 0, - }); + let mut bytes = this.chunk.clone(); + this.stream + .encrypt_last_in_place(&[], &mut bytes) + .map_err(|_| { + // We will never hit chacha20::MAX_BLOCKS because of the chunk + // size, so this is the only possible error. + io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed") + })?; + *this.encrypted_chunk = Some(EncryptedChunk { bytes, offset: 0 }); + this.chunk.clear(); } // Flush the final chunk (if we didn't in the first call). @@ -369,7 +367,7 @@ enum StartPos { /// Provides access to a decrypted age file. #[pin_project] pub struct StreamReader { - stream: Stream, + stream: AgeDecryptor, #[pin] inner: R, encrypted_chunk: Vec, @@ -392,23 +390,49 @@ impl StreamReader { let chunk = &self.encrypted_chunk[..self.encrypted_pos]; if chunk.is_empty() { - if !self.stream.is_complete() { - // Stream has ended before seeing the last chunk. - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "age file is truncated", - )); - } + // TODO + // if !self.stream.is_complete() { + // // Stream has ended before seeing the last chunk. + // return Err(io::Error::new( + // io::ErrorKind::UnexpectedEof, + // "age file is truncated", + // )); + // } } else { // This check works for all cases except when the age file is an integer // multiple of the chunk size. In that case, we try decrypting twice on a // decryption failure. let last = chunk.len() < ENCRYPTED_CHUNK_SIZE; - self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) { - (Ok(chunk), _) => Some(chunk), - (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?), - (Err(e), true) => return Err(e), + let mut buffer = chunk.to_owned(); + let res = if last { + self.stream.decrypt_last_in_place(&[], &mut buffer) + } else { + self.stream.decrypt_next_in_place(&[], &mut buffer) + }; + + self.chunk = match (res, last) { + (Ok(()), _) => Some(SecretVec::new(buffer)), + (Err(_), false) => { + // We need to re-clone the encrypted bytes, because the buffer is + // clobbered in case of an error. + let mut buffer = chunk.to_owned(); + self.stream + .decrypt_last_in_place(&[], &mut buffer) + .map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "last chunk has been processed", + ) + })?; + Some(SecretVec::new(buffer)) + } + (Err(_), true) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "last chunk has been processed", + )) + } }; } @@ -563,7 +587,8 @@ impl Seek for StreamReader { self.inner.seek(SeekFrom::Start( start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64), ))?; - self.stream.nonce.set_counter(target_chunk_index); + // TODO: Fix once aead::stream is seekable + // self.stream.nonce.set_counter(target_chunk_index); self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64; // Read and drop bytes from the chunk to reach the target position. @@ -580,6 +605,7 @@ impl Seek for StreamReader { #[cfg(test)] mod tests { + use chacha20poly1305::aead::stream::StreamPrimitive; use secrecy::ExposeSecret; use std::io::{self, Cursor, Read, Seek, SeekFrom, Write}; @@ -598,59 +624,61 @@ mod tests { fn chunk_round_trip() { let data = vec![42; CHUNK_SIZE]; - let encrypted = { - let mut s = Stream::new(PayloadKey([7; 32].into())); - s.encrypt_chunk(&data, false).unwrap() + let mut encrypted = data.clone(); + { + let mut s = Stream::new(PayloadKey([7; 32].into())).encryptor(); + s.encrypt_next_in_place(&[], &mut encrypted).unwrap() }; - let decrypted = { - let mut s = Stream::new(PayloadKey([7; 32].into())); - s.decrypt_chunk(&encrypted, false).unwrap() - }; + let decrypted = encrypted.clone(); + { + let mut s = Stream::new(PayloadKey([7; 32].into())).decryptor(); + s.decrypt_next_in_place(&[], &mut decrypted).unwrap(); + } - assert_eq!(decrypted.expose_secret(), &data); + assert_eq!(&decrypted, &data); } - #[test] - fn last_chunk_round_trip() { - let data = vec![42; CHUNK_SIZE]; + // #[test] + // fn last_chunk_round_trip() { + // let data = vec![42; CHUNK_SIZE]; - let encrypted = { - let mut s = Stream::new(PayloadKey([7; 32].into())); - let res = s.encrypt_chunk(&data, true).unwrap(); + // let encrypted = { + // let mut s = Stream::new(PayloadKey([7; 32].into())); + // let res = s.encrypt_chunk(&data, true).unwrap(); - // Further calls return an error - assert_eq!( - s.encrypt_chunk(&data, false).unwrap_err().kind(), - io::ErrorKind::WriteZero - ); - assert_eq!( - s.encrypt_chunk(&data, true).unwrap_err().kind(), - io::ErrorKind::WriteZero - ); + // // Further calls return an error + // assert_eq!( + // s.encrypt_chunk(&data, false).unwrap_err().kind(), + // io::ErrorKind::WriteZero + // ); + // assert_eq!( + // s.encrypt_chunk(&data, true).unwrap_err().kind(), + // io::ErrorKind::WriteZero + // ); - res - }; + // res + // }; - let decrypted = { - let mut s = Stream::new(PayloadKey([7; 32].into())); - let res = s.decrypt_chunk(&encrypted, true).unwrap(); + // let decrypted = { + // let mut s = Stream::new(PayloadKey([7; 32].into())); + // let res = s.decrypt_chunk(&encrypted, true).unwrap(); - // Further calls return an error - match s.decrypt_chunk(&encrypted, false) { - Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), - _ => panic!("Expected error"), - } - match s.decrypt_chunk(&encrypted, true) { - Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), - _ => panic!("Expected error"), - } + // // Further calls return an error + // match s.decrypt_chunk(&encrypted, false) { + // Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), + // _ => panic!("Expected error"), + // } + // match s.decrypt_chunk(&encrypted, true) { + // Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), + // _ => panic!("Expected error"), + // } - res - }; + // res + // }; - assert_eq!(decrypted.expose_secret(), &data); - } + // assert_eq!(decrypted.expose_secret(), &data); + // } fn stream_round_trip(data: &[u8]) { let mut encrypted = vec![];