From d2e401ff0b5bf0d056e553075e16458b0d8a882b Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 8 Jan 2025 15:53:08 -0800 Subject: [PATCH] chacha: Use `Overlapping` in the implementation of the fallback impl. Eliminate all of the `unsafe` in the fallback implementation. --- src/aead/chacha/fallback.rs | 31 ++++++++------- src/aead/overlapping/array.rs | 72 +++++++++++++++++++++++++++++++++++ src/aead/overlapping/base.rs | 57 ++++++++++++++++++++++++--- src/aead/overlapping/mod.rs | 4 ++ 4 files changed, 144 insertions(+), 20 deletions(-) create mode 100644 src/aead/overlapping/array.rs diff --git a/src/aead/chacha/fallback.rs b/src/aead/chacha/fallback.rs index 48f46d94b..3d57eb076 100644 --- a/src/aead/chacha/fallback.rs +++ b/src/aead/chacha/fallback.rs @@ -15,11 +15,11 @@ // Adapted from the public domain, estream code by D. Bernstein. // Adapted from the BoringSSL crypto/chacha/chacha.c. -use super::{Counter, Key, Overlapping, BLOCK_LEN}; +use super::{super::overlapping::IndexError, Counter, Key, Overlapping, BLOCK_LEN}; use crate::{constant_time, polyfill::sliceutil}; -use core::{mem::size_of, slice}; +use core::mem::size_of; -pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, in_out: Overlapping<'_>) { +pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, mut in_out: Overlapping<'_>) { const SIGMA: [u32; 4] = [ u32::from_le_bytes(*b"expa"), u32::from_le_bytes(*b"nd 3"), @@ -35,31 +35,34 @@ pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, in_out: Overlapping<'_ key[6], key[7], counter[0], counter[1], counter[2], counter[3], ]; - let (mut input, mut output, mut in_out_len) = in_out.into_input_output_len(); + let mut in_out_len = in_out.len(); let mut buf = [0u8; BLOCK_LEN]; while in_out_len > 0 { chacha_core(&mut buf, &state); state[12] += 1; + debug_assert_eq!(in_out_len, in_out.len()); + // Both branches do the same thing, but the duplication helps the // compiler optimize (vectorize) the `BLOCK_LEN` case. if in_out_len >= BLOCK_LEN { - let input = unsafe { slice::from_raw_parts(input, BLOCK_LEN) }; - constant_time::xor_assign_at_start(&mut buf, input); - let output = unsafe { slice::from_raw_parts_mut(output, BLOCK_LEN) }; - sliceutil::overwrite_at_start(output, &buf); + in_out = in_out + .split_first_chunk::(|in_out| { + constant_time::xor_assign_at_start(&mut buf, in_out.input()); + sliceutil::overwrite_at_start(in_out.into_unwritten_output(), &buf); + }) + .unwrap_or_else(|IndexError { .. }| { + // Since `in_out_len == in_out.len() && in_out_len >= BLOCK_LEN`. + unreachable!() + }); } else { - let input = unsafe { slice::from_raw_parts(input, in_out_len) }; - constant_time::xor_assign_at_start(&mut buf, input); - let output = unsafe { slice::from_raw_parts_mut(output, in_out_len) }; - sliceutil::overwrite_at_start(output, &buf); + constant_time::xor_assign_at_start(&mut buf, in_out.input()); + sliceutil::overwrite_at_start(in_out.into_unwritten_output(), &buf); break; } in_out_len -= BLOCK_LEN; - input = unsafe { input.add(BLOCK_LEN) }; - output = unsafe { output.add(BLOCK_LEN) }; } } diff --git a/src/aead/overlapping/array.rs b/src/aead/overlapping/array.rs new file mode 100644 index 000000000..eb047aec0 --- /dev/null +++ b/src/aead/overlapping/array.rs @@ -0,0 +1,72 @@ +// Copyright 2024 Brian Smith. +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +#![cfg_attr(not(test), allow(dead_code))] + +use super::Overlapping; +use core::array::TryFromSliceError; + +pub struct Array<'o, T, const N: usize> { + // Invariant: N != 0. + // Invariant: `self.in_out.len() == N`. + in_out: Overlapping<'o, T>, +} + +impl<'o, T, const N: usize> Array<'o, T, N> { + pub(super) fn new(in_out: Overlapping<'o, T>) -> Result { + if N == 0 || in_out.len() != N { + return Err(LenMismatchError::new(N)); + } + Ok(Self { in_out }) + } + + pub fn into_unwritten_output(self) -> &'o mut [T; N] + where + &'o mut [T]: TryInto<&'o mut [T; N], Error = TryFromSliceError>, + { + self.in_out + .into_unwritten_output() + .try_into() + .unwrap_or_else(|TryFromSliceError { .. }| { + unreachable!() // Due to invariant + }) + } +} + +impl Array<'_, T, N> { + pub fn input<'s>(&'s self) -> &'s [T; N] + where + &'s [T]: TryInto<&'s [T; N], Error = TryFromSliceError>, + { + self.in_out + .input() + .try_into() + .unwrap_or_else(|TryFromSliceError { .. }| { + unreachable!() // Due to invariant + }) + } +} + +pub struct LenMismatchError { + #[allow(dead_code)] + len: usize, +} + +impl LenMismatchError { + #[cold] + #[inline(never)] + fn new(len: usize) -> Self { + Self { len } + } +} diff --git a/src/aead/overlapping/base.rs b/src/aead/overlapping/base.rs index fb641306f..a3357f649 100644 --- a/src/aead/overlapping/base.rs +++ b/src/aead/overlapping/base.rs @@ -12,7 +12,8 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use core::ops::RangeFrom; +use super::{Array, LenMismatchError}; +use core::{mem, ops::RangeFrom}; pub struct Overlapping<'o, T> { // Invariant: self.src.start <= in_out.len(). @@ -28,7 +29,7 @@ impl<'o, T> Overlapping<'o, T> { pub fn new(in_out: &'o mut [T], src: RangeFrom) -> Result { match in_out.get(src.clone()) { Some(_) => Ok(Self { in_out, src }), - None => Err(IndexError::new(src)), + None => Err(IndexError::new(src.start)), } } @@ -51,7 +52,7 @@ impl<'o, T> Overlapping<'o, T> { (self.in_out, self.src) } - pub(super) fn into_unwritten_output(self) -> &'o mut [T] { + pub fn into_unwritten_output(self) -> &'o mut [T] { let len = self.len(); self.in_out.get_mut(..len).unwrap_or_else(|| { // The invariant ensures this succeeds. @@ -83,14 +84,58 @@ impl Overlapping<'_, T> { let input = unsafe { output_const.add(self.src.start) }; (input, output, len) } + + // Perhaps unlike `slice::split_first_chunk_mut`, this is biased, + // performance-wise, against the case where `N > self.len()`, so callers + // should be structured to avoid that. + // + // If the result is `Err` then nothing was written to `self`; if anything + // was written then the result will not be `Err`. + #[cfg_attr(not(test), allow(dead_code))] + pub fn split_first_chunk( + mut self, + f: impl for<'a> FnOnce(Array<'a, T, N>), + ) -> Result { + let src = self.src.clone(); + let end = self + .src + .start + .checked_add(N) + .ok_or_else(|| IndexError::new(N))?; + let first = self + .in_out + .get_mut(..end) + .ok_or_else(|| IndexError::new(N))?; + let first = Overlapping::new(first, src).unwrap_or_else(|IndexError { .. }| { + // Since `end == src.start + N`. + unreachable!() + }); + let first = Array::new(first).unwrap_or_else(|LenMismatchError { .. }| { + // Since `end == src.start + N`. + unreachable!() + }); + // Once we call `f`, we must return `Ok` because `f` may have written + // over (part of) the input. + Ok({ + f(first); + let tail = mem::take(&mut self.in_out).get_mut(N..).unwrap_or_else(|| { + // There are at least `N` elements since `end == src.start + N`. + unreachable!() + }); + Self::new(tail, self.src).unwrap_or_else(|IndexError { .. }| { + // Follows from `end == src.start + N`. + unreachable!() + }) + }) + } } -pub struct IndexError(#[allow(dead_code)] RangeFrom); +pub struct IndexError(#[allow(dead_code)] usize); impl IndexError { #[cold] #[inline(never)] - fn new(src: RangeFrom) -> Self { - Self(src) + fn new(index: usize) -> Self { + Self(index) } } diff --git a/src/aead/overlapping/mod.rs b/src/aead/overlapping/mod.rs index 908b8b92c..22d60c8b5 100644 --- a/src/aead/overlapping/mod.rs +++ b/src/aead/overlapping/mod.rs @@ -13,9 +13,13 @@ // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. pub use self::{ + array::Array, base::{IndexError, Overlapping}, partial_block::PartialBlock, }; +use self::array::LenMismatchError; + +mod array; mod base; mod partial_block;