[WIP] age: Migrate to aead crate's STREAM implementation

This commit is contained in:
Jack Grigg 2020-12-28 18:17:20 +00:00
parent c5c23d67bb
commit f1c7af6d9a
4 changed files with 184 additions and 152 deletions

4
Cargo.lock generated
View file

@ -9,8 +9,7 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234"
[[package]]
name = "aead"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fc95d1bdb8e6666b2b217308eeeb09f2d6728d104be3e31916cc74d15420331"
source = "git+https://github.com/RustCrypto/traits.git?branch=aead/stream#1c1bfb9504b1191118676dbe6ebb6c0f3a3d7675"
dependencies = [
"generic-array",
]
@ -62,6 +61,7 @@ dependencies = [
name = "age"
version = "0.5.0"
dependencies = [
"aead",
"aes",
"aes-ctr",
"age-core",

View file

@ -4,3 +4,6 @@ members = [
"age-core",
"rage",
]
[patch.crates-io]
aead = { git = "https://github.com/RustCrypto/traits.git", branch = "aead/stream", features = ["stream"] }

View file

@ -21,6 +21,7 @@ age-core = { version = "0.5.0", path = "../age-core" }
base64 = "0.12"
# - ChaCha20-Poly1305 from RFC 7539
aead = { version = "0.3", features = ["stream"] }
c2-chacha = "0.3"
chacha20poly1305 = { version = "0.7", default-features = false, features = ["alloc"] }

View file

@ -1,13 +1,17 @@
//! I/O helper structs for age file encryption and decryption.
use chacha20poly1305::{
aead::{generic_array::GenericArray, Aead, NewAead},
aead::{
self,
generic_array::{typenum::U12, GenericArray},
stream::{Decryptor, Encryptor, StreamPrimitive},
Aead, AeadInPlace, NewAead,
},
ChaChaPoly1305,
};
use pin_project::pin_project;
use secrecy::{ExposeSecret, SecretVec};
use std::cmp;
use std::convert::TryInto;
use std::io::{self, Read, Seek, SeekFrom, Write};
use zeroize::Zeroize;
@ -24,6 +28,9 @@ const CHUNK_SIZE: usize = 64 * 1024;
const TAG_SIZE: usize = 16;
const ENCRYPTED_CHUNK_SIZE: usize = CHUNK_SIZE + TAG_SIZE;
type AgeEncryptor = Encryptor<ChaChaPoly1305<c2_chacha::Ietf>, Stream>;
type AgeDecryptor = Decryptor<ChaChaPoly1305<c2_chacha::Ietf>, Stream>;
pub(crate) struct PayloadKey(
pub(crate) GenericArray<u8, <ChaChaPoly1305<c2_chacha::Ietf> as NewAead>::KeySize>,
);
@ -34,47 +41,6 @@ impl Drop for PayloadKey {
}
}
/// The nonce used in age's STREAM encryption.
///
/// Structured as an 11 bytes of big endian counter, and 1 byte of last block flag
/// (`0x00 / 0x01`). We store this in the lower 12 bytes of a `u128`.
#[derive(Clone, Copy, Default)]
struct Nonce(u128);
impl Nonce {
/// Unsets last-chunk flag.
fn set_counter(&mut self, val: u64) {
self.0 = u128::from(val) << 8;
}
fn increment_counter(&mut self) {
// Increment the 11-byte counter
self.0 += 1 << 8;
if self.0 >> (8 * 12) != 0 {
panic!("We overflowed the nonce!");
}
}
fn is_last(&self) -> bool {
self.0 & 1 != 0
}
fn set_last(&mut self, last: bool) -> Result<(), ()> {
if !self.is_last() {
self.0 |= if last { 1 } else { 0 };
Ok(())
} else {
Err(())
}
}
fn to_bytes(&self) -> [u8; 12] {
self.0.to_be_bytes()[4..]
.try_into()
.expect("slice is correct length")
}
}
#[cfg(feature = "async")]
struct EncryptedChunk {
bytes: Vec<u8>,
@ -90,14 +56,12 @@ struct EncryptedChunk {
/// [STREAM]: https://eprint.iacr.org/2015/189.pdf
pub(crate) struct Stream {
aead: ChaChaPoly1305<c2_chacha::Ietf>,
nonce: Nonce,
}
impl Stream {
fn new(key: PayloadKey) -> Self {
Stream {
aead: ChaChaPoly1305::new(&key.0),
nonce: Nonce::default(),
}
}
@ -110,7 +74,7 @@ impl Stream {
/// [`HKDF`]: age_core::primitives::hkdf
pub(crate) fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
StreamWriter {
stream: Self::new(key),
stream: Self::new(key).encryptor(),
inner,
chunk: Vec::with_capacity(CHUNK_SIZE),
#[cfg(feature = "async")]
@ -128,7 +92,7 @@ impl Stream {
#[cfg(feature = "async")]
pub(crate) fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
StreamWriter {
stream: Self::new(key),
stream: Self::new(key).encryptor(),
inner,
chunk: Vec::with_capacity(CHUNK_SIZE),
encrypted_chunk: None,
@ -144,7 +108,7 @@ impl Stream {
/// [`HKDF`]: age_core::primitives::hkdf
pub(crate) fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
StreamReader {
stream: Self::new(key),
stream: Self::new(key).decryptor(),
inner,
encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
encrypted_pos: 0,
@ -164,7 +128,7 @@ impl Stream {
#[cfg(feature = "async")]
pub(crate) fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
StreamReader {
stream: Self::new(key),
stream: Self::new(key).decryptor(),
inner,
encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
encrypted_pos: 0,
@ -174,51 +138,62 @@ impl Stream {
}
}
fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
assert!(chunk.len() <= CHUNK_SIZE);
/// Computes the nonce used in age's STREAM encryption.
///
/// Structured as an 11 bytes of big endian counter, and 1 byte of last block flag
/// (`0x00 / 0x01`). We store this in the lower 12 bytes of a `u128`.
fn aead_nonce(
&self,
position: u128,
last_block: bool,
) -> Result<aead::Nonce<<ChaChaPoly1305<c2_chacha::Ietf> as AeadInPlace>::NonceSize>, aead::Error>
{
if position > Self::COUNTER_MAX {
return Err(aead::Error);
}
self.nonce.set_last(last).map_err(|_| {
io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
})?;
let position_with_flag = position | (last_block as u128);
let encrypted = self
.aead
.encrypt(&self.nonce.to_bytes().into(), chunk)
.expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size");
self.nonce.increment_counter();
let mut result = GenericArray::default();
result.copy_from_slice(&position_with_flag.to_be_bytes()[4..]);
Ok(encrypted)
Ok(result)
}
}
impl StreamPrimitive<ChaChaPoly1305<c2_chacha::Ietf>> for Stream {
type NonceOverhead = U12;
type Counter = u128;
const COUNTER_INCR: u128 = 1 << 8;
const COUNTER_MAX: u128 = 0xffffffff_ffffffff_ffffff00;
fn encrypt_in_place(
&self,
position: Self::Counter,
last_block: bool,
associated_data: &[u8],
buffer: &mut dyn aead::Buffer,
) -> Result<(), aead::Error> {
let nonce = self.aead_nonce(position, last_block)?;
self.aead.encrypt_in_place(&nonce, associated_data, buffer)
}
fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretVec<u8>> {
assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);
self.nonce.set_last(last).map_err(|_| {
io::Error::new(
io::ErrorKind::UnexpectedEof,
"last chunk has been processed",
)
})?;
let decrypted = self
.aead
.decrypt(&self.nonce.to_bytes().into(), chunk)
.map(SecretVec::new)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
self.nonce.increment_counter();
Ok(decrypted)
}
fn is_complete(&self) -> bool {
self.nonce.is_last()
fn decrypt_in_place(
&self,
position: Self::Counter,
last_block: bool,
associated_data: &[u8],
buffer: &mut dyn aead::Buffer,
) -> Result<(), aead::Error> {
let nonce = self.aead_nonce(position, last_block)?;
self.aead.decrypt_in_place(&nonce, associated_data, buffer)
}
}
/// Writes an encrypted age file.
#[pin_project(project = StreamWriterProj)]
pub struct StreamWriter<W> {
stream: Stream,
stream: AgeEncryptor,
#[pin]
inner: W,
chunk: Vec<u8>,
@ -233,8 +208,14 @@ impl<W: Write> StreamWriter<W> {
/// encryption process. Failing to call `finish` will result in a truncated file that
/// that will fail to decrypt.
pub fn finish(mut self) -> io::Result<W> {
let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
self.inner.write_all(&encrypted)?;
self.stream
.encrypt_last_in_place(&[], &mut self.chunk)
.map_err(|_| {
// We will never hit chacha20::MAX_BLOCKS because of the chunk
// size, so this is the only possible error.
io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
})?;
self.inner.write_all(&self.chunk)?;
Ok(self.inner)
}
}
@ -255,8 +236,14 @@ impl<W: Write> Write for StreamWriter<W> {
// Only encrypt the chunk if we have more data to write, as the last
// chunk must be written in finish().
if !buf.is_empty() {
let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
self.inner.write_all(&encrypted)?;
self.stream
.encrypt_next_in_place(&[], &mut self.chunk)
.map_err(|_| {
// We will never hit chacha20::MAX_BLOCKS because of the chunk
// size, so this is the only possible error.
io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
})?;
self.inner.write_all(&self.chunk)?;
self.chunk.clear();
}
}
@ -317,10 +304,15 @@ impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
// chunk must be written in poll_close().
if !buf.is_empty() {
let this = self.as_mut().project();
*this.encrypted_chunk = Some(EncryptedChunk {
bytes: this.stream.encrypt_chunk(&this.chunk, false)?,
offset: 0,
});
let mut bytes = this.chunk.clone();
this.stream
.encrypt_next_in_place(&[], &mut bytes)
.map_err(|_| {
// We will never hit chacha20::MAX_BLOCKS because of the chunk
// size, so this is the only possible error.
io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
})?;
*this.encrypted_chunk = Some(EncryptedChunk { bytes, offset: 0 });
this.chunk.clear();
}
@ -336,13 +328,19 @@ impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
// Flush any remaining encrypted chunk bytes.
ready!(self.as_mut().poll_flush_chunk(cx))?;
if !self.stream.is_complete() {
if !self.chunk.is_empty() {
// Finish the stream.
let this = self.as_mut().project();
*this.encrypted_chunk = Some(EncryptedChunk {
bytes: this.stream.encrypt_chunk(&this.chunk, true)?,
offset: 0,
});
let mut bytes = this.chunk.clone();
this.stream
.encrypt_last_in_place(&[], &mut bytes)
.map_err(|_| {
// We will never hit chacha20::MAX_BLOCKS because of the chunk
// size, so this is the only possible error.
io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
})?;
*this.encrypted_chunk = Some(EncryptedChunk { bytes, offset: 0 });
this.chunk.clear();
}
// Flush the final chunk (if we didn't in the first call).
@ -369,7 +367,7 @@ enum StartPos {
/// Provides access to a decrypted age file.
#[pin_project]
pub struct StreamReader<R> {
stream: Stream,
stream: AgeDecryptor,
#[pin]
inner: R,
encrypted_chunk: Vec<u8>,
@ -392,23 +390,49 @@ impl<R> StreamReader<R> {
let chunk = &self.encrypted_chunk[..self.encrypted_pos];
if chunk.is_empty() {
if !self.stream.is_complete() {
// Stream has ended before seeing the last chunk.
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"age file is truncated",
));
}
// TODO
// if !self.stream.is_complete() {
// // Stream has ended before seeing the last chunk.
// return Err(io::Error::new(
// io::ErrorKind::UnexpectedEof,
// "age file is truncated",
// ));
// }
} else {
// This check works for all cases except when the age file is an integer
// multiple of the chunk size. In that case, we try decrypting twice on a
// decryption failure.
let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;
self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
(Ok(chunk), _) => Some(chunk),
(Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
(Err(e), true) => return Err(e),
let mut buffer = chunk.to_owned();
let res = if last {
self.stream.decrypt_last_in_place(&[], &mut buffer)
} else {
self.stream.decrypt_next_in_place(&[], &mut buffer)
};
self.chunk = match (res, last) {
(Ok(()), _) => Some(SecretVec::new(buffer)),
(Err(_), false) => {
// We need to re-clone the encrypted bytes, because the buffer is
// clobbered in case of an error.
let mut buffer = chunk.to_owned();
self.stream
.decrypt_last_in_place(&[], &mut buffer)
.map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"last chunk has been processed",
)
})?;
Some(SecretVec::new(buffer))
}
(Err(_), true) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"last chunk has been processed",
))
}
};
}
@ -563,7 +587,8 @@ impl<R: Read + Seek> Seek for StreamReader<R> {
self.inner.seek(SeekFrom::Start(
start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
))?;
self.stream.nonce.set_counter(target_chunk_index);
// TODO: Fix once aead::stream is seekable
// self.stream.nonce.set_counter(target_chunk_index);
self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;
// Read and drop bytes from the chunk to reach the target position.
@ -580,6 +605,7 @@ impl<R: Read + Seek> Seek for StreamReader<R> {
#[cfg(test)]
mod tests {
use chacha20poly1305::aead::stream::StreamPrimitive;
use secrecy::ExposeSecret;
use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};
@ -598,59 +624,61 @@ mod tests {
fn chunk_round_trip() {
let data = vec![42; CHUNK_SIZE];
let encrypted = {
let mut s = Stream::new(PayloadKey([7; 32].into()));
s.encrypt_chunk(&data, false).unwrap()
let mut encrypted = data.clone();
{
let mut s = Stream::new(PayloadKey([7; 32].into())).encryptor();
s.encrypt_next_in_place(&[], &mut encrypted).unwrap()
};
let decrypted = {
let mut s = Stream::new(PayloadKey([7; 32].into()));
s.decrypt_chunk(&encrypted, false).unwrap()
};
let decrypted = encrypted.clone();
{
let mut s = Stream::new(PayloadKey([7; 32].into())).decryptor();
s.decrypt_next_in_place(&[], &mut decrypted).unwrap();
}
assert_eq!(decrypted.expose_secret(), &data);
assert_eq!(&decrypted, &data);
}
#[test]
fn last_chunk_round_trip() {
let data = vec![42; CHUNK_SIZE];
// #[test]
// fn last_chunk_round_trip() {
// let data = vec![42; CHUNK_SIZE];
let encrypted = {
let mut s = Stream::new(PayloadKey([7; 32].into()));
let res = s.encrypt_chunk(&data, true).unwrap();
// let encrypted = {
// let mut s = Stream::new(PayloadKey([7; 32].into()));
// let res = s.encrypt_chunk(&data, true).unwrap();
// Further calls return an error
assert_eq!(
s.encrypt_chunk(&data, false).unwrap_err().kind(),
io::ErrorKind::WriteZero
);
assert_eq!(
s.encrypt_chunk(&data, true).unwrap_err().kind(),
io::ErrorKind::WriteZero
);
// // Further calls return an error
// assert_eq!(
// s.encrypt_chunk(&data, false).unwrap_err().kind(),
// io::ErrorKind::WriteZero
// );
// assert_eq!(
// s.encrypt_chunk(&data, true).unwrap_err().kind(),
// io::ErrorKind::WriteZero
// );
res
};
// res
// };
let decrypted = {
let mut s = Stream::new(PayloadKey([7; 32].into()));
let res = s.decrypt_chunk(&encrypted, true).unwrap();
// let decrypted = {
// let mut s = Stream::new(PayloadKey([7; 32].into()));
// let res = s.decrypt_chunk(&encrypted, true).unwrap();
// Further calls return an error
match s.decrypt_chunk(&encrypted, false) {
Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof),
_ => panic!("Expected error"),
}
match s.decrypt_chunk(&encrypted, true) {
Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof),
_ => panic!("Expected error"),
}
// // Further calls return an error
// match s.decrypt_chunk(&encrypted, false) {
// Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof),
// _ => panic!("Expected error"),
// }
// match s.decrypt_chunk(&encrypted, true) {
// Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof),
// _ => panic!("Expected error"),
// }
res
};
// res
// };
assert_eq!(decrypted.expose_secret(), &data);
}
// assert_eq!(decrypted.expose_secret(), &data);
// }
fn stream_round_trip(data: &[u8]) {
let mut encrypted = vec![];