From 3fbcec4cdbe0e53e42bbcaaf22ba2dd09d368499 Mon Sep 17 00:00:00 2001
From: Brian Smith <brian@briansmith.org>
Date: Sun, 26 Jan 2025 11:17:39 -0800
Subject: [PATCH] arithmetic: Clarify memory safety of some binary-ish ops.

---
 src/arithmetic/bigint.rs | 28 +++++++++++++++-------------
 src/ec/suite_b/ops.rs    |  3 ++-
 src/limb.rs              | 17 ++++++++++++-----
 3 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs
index f32ca8c928..5336b0bb86 100644
--- a/src/arithmetic/bigint.rs
+++ b/src/arithmetic/bigint.rs
@@ -42,7 +42,7 @@ pub(crate) use self::{
     modulusvalue::OwnedModulusValue,
     private_exponent::PrivateExponent,
 };
-use super::{montgomery::*, LimbSliceError, MAX_LIMBS};
+use super::{inout::AliasingSlices3, montgomery::*, LimbSliceError, MAX_LIMBS};
 use crate::{
     bits::BitLength,
     c,
@@ -50,7 +50,10 @@ use crate::{
     limb::{self, Limb, LIMB_BITS},
 };
 use alloc::vec;
-use core::{marker::PhantomData, num::NonZeroU64};
+use core::{
+    marker::PhantomData,
+    num::{NonZeroU64, NonZeroUsize},
+};
 
 mod boxed_limbs;
 mod modulus;
@@ -233,7 +236,8 @@ pub fn elem_widen<Larger, Smaller>(
 
 // TODO: Document why this works for all Montgomery factors.
 pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
-    limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, m.limbs());
+    limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, m.limbs())
+        .unwrap_or_else(unwrap_impossible_len_mismatch_error);
     a
 }
 
@@ -246,18 +250,16 @@ pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem
             a: *const Limb,
             b: *const Limb,
             m: *const Limb,
-            num_limbs: c::size_t,
-        );
-    }
-    unsafe {
-        LIMBS_sub_mod(
-            a.limbs.as_mut_ptr(),
-            a.limbs.as_ptr(),
-            b.limbs.as_ptr(),
-            m.limbs().as_ptr(),
-            m.limbs().len(),
+            num_limbs: c::NonZero_size_t,
         );
     }
+    let num_limbs = NonZeroUsize::new(m.limbs().len()).unwrap();
+    (a.limbs.as_mut(), b.limbs.as_ref())
+        .with_non_dangling_non_null_pointers_rab(num_limbs, |r, a, b| {
+            let m = m.limbs().as_ptr(); // Also non-dangling because num_limbs is non-zero.
+            unsafe { LIMBS_sub_mod(r, a, b, m, num_limbs) }
+        })
+        .unwrap_or_else(unwrap_impossible_len_mismatch_error);
     a
 }
 
diff --git a/src/ec/suite_b/ops.rs b/src/ec/suite_b/ops.rs
index baee45881a..73c948e9ca 100644
--- a/src/ec/suite_b/ops.rs
+++ b/src/ec/suite_b/ops.rs
@@ -137,7 +137,8 @@ impl<M> Modulus<M> {
             &mut a.limbs[..num_limbs],
             &b.limbs[..num_limbs],
             &self.limbs[..num_limbs],
-        );
+        )
+        .unwrap_or_else(unwrap_impossible_len_mismatch_error)
     }
 }
 
diff --git a/src/limb.rs b/src/limb.rs
index e1e448f792..2a82eb7cbc 100644
--- a/src/limb.rs
+++ b/src/limb.rs
@@ -19,6 +19,7 @@
 //! limbs use the native endianness.
 
 use crate::{
+    arithmetic::inout::AliasingSlices3,
     c, constant_time,
     error::{self, LenMismatchError},
     polyfill::{sliceutil, usize_from_u32, ArrayFlatMap},
@@ -325,9 +326,11 @@ pub fn fold_5_bit_windows<R, I: FnOnce(Window) -> R, F: Fn(R, Window) -> R>(
 }
 
 #[inline]
-pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
-    debug_assert_eq!(a.len(), m.len());
-    debug_assert_eq!(b.len(), m.len());
+pub(crate) fn limbs_add_assign_mod(
+    a: &mut [Limb],
+    b: &[Limb],
+    m: &[Limb],
+) -> Result<(), LenMismatchError> {
     prefixed_extern! {
         // `r` and `a` may alias.
         fn LIMBS_add_mod(
@@ -335,10 +338,14 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
             a: *const Limb,
             b: *const Limb,
             m: *const Limb,
-            num_limbs: c::size_t,
+            num_limbs: c::NonZero_size_t,
         );
     }
-    unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
+    let num_limbs = NonZeroUsize::new(m.len()).ok_or_else(|| LenMismatchError::new(m.len()))?;
+    (a, b).with_non_dangling_non_null_pointers_rab(num_limbs, |r, a, b| {
+        let m = m.as_ptr(); // Also non-dangling because `num_limbs` is non-zero.
+        unsafe { LIMBS_add_mod(r, a, b, m, num_limbs) }
+    })
 }
 
 // r *= 2 (mod m).