diff --git a/src/bigint.rs b/src/bigint.rs index 891eeb46..9e918bec 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -1095,6 +1095,19 @@ impl BigInt { // The top bit may have been cleared, so normalize self.normalize(); } + + /// Sets the sign of the `BigInt`. + /// Does not change the sign if the value is 0. + /// If the sign is `Sign::NoSign`, the value will be set to 0. + #[inline] + pub fn set_sign(&mut self, sign: Sign) { + if sign == Sign::NoSign { + self.set_zero(); + } + if !self.is_zero() { + self.sign = sign; + } + } } #[test] diff --git a/src/bigint/multiplication.rs b/src/bigint/multiplication.rs index a2d97081..0d7a0eef 100644 --- a/src/bigint/multiplication.rs +++ b/src/bigint/multiplication.rs @@ -214,4 +214,23 @@ impl CheckedMul for BigInt { } } +/// Equivalent to `self.set_sign(self.sign() * rhs)`. +impl Mul for BigInt { + type Output = BigInt; + + #[inline] + fn mul(mut self, rhs: Sign) -> Self::Output { + self *= rhs; + self + } +} + +/// Equivalent to `self.set_sign(self.sign() * rhs)`. +impl MulAssign for BigInt { + #[inline] + fn mul_assign(&mut self, rhs: Sign) { + self.set_sign(self.sign() * rhs); + } +} + impl_product_iter_type!(BigInt); diff --git a/tests/bigint.rs b/tests/bigint.rs index f244bc4b..0c325756 100644 --- a/tests/bigint.rs +++ b/tests/bigint.rs @@ -746,6 +746,21 @@ fn test_mul() { } } +#[test] +fn test_mul_sign() { + assert_eq!(BigInt::zero() * Plus, BigInt::zero()); + assert_eq!(BigInt::zero() * Minus, BigInt::zero()); + assert_eq!(BigInt::zero() * NoSign, BigInt::zero()); + + assert_eq!(BigInt::one() * Plus, BigInt::one()); + assert_eq!(BigInt::one() * Minus, -BigInt::one()); + assert_eq!(BigInt::one() * NoSign, BigInt::zero()); + + assert_eq!((-BigInt::one()) * Plus, -BigInt::one()); + assert_eq!((-BigInt::one()) * Minus, BigInt::one()); + assert_eq!((-BigInt::one()) * NoSign, BigInt::zero()); +} + #[test] fn test_div_mod_floor() { fn check_sub(a: &BigInt, b: &BigInt, ans_d: &BigInt, ans_m: &BigInt) { @@ -1404,3 +1419,31 @@ fn test_set_bit() { x.set_bit(0, false); assert_eq!(x, BigInt::from_biguint(Minus, BigUint::one() << 200)); } + +#[test] +fn test_set_sign() { + // Zero should be unaffected. + for sign in &[Plus, Minus, NoSign] { + let mut x = BigInt::zero(); + x.set_sign(*sign); + assert!(x.is_zero()); + } + + // Since the only thing different about the two numbers is their signs, + // the `set_sign` operation should behave the same. + for orig_x in &[BigInt::one(), -BigInt::one()] { + let mut x = orig_x.clone(); + x.set_sign(Plus); + assert_eq!(x.sign(), Plus); + assert_eq!(x, BigInt::one()); + + x = orig_x.clone(); + x.set_sign(Minus); + assert_eq!(x.sign(), Minus); + assert_eq!(x, -BigInt::one()); + + x = orig_x.clone(); + x.set_sign(NoSign); + assert!(x.is_zero()); + } +}