diff --git a/crates/cdk/src/mint/mod.rs b/crates/cdk/src/mint/mod.rs index 3a1fc751b..613964307 100644 --- a/crates/cdk/src/mint/mod.rs +++ b/crates/cdk/src/mint/mod.rs @@ -590,10 +590,7 @@ fn create_new_keyset( } fn derivation_path_from_unit(unit: CurrencyUnit, index: u32) -> Option { - let unit_index = match unit.derivation_index() { - Some(index) => index, - None => return None, - }; + let unit_index = unit.derivation_index(); Some(DerivationPath::from(vec![ ChildNumber::from_hardened_idx(0).expect("0 is a valid index"), @@ -613,7 +610,7 @@ mod tests { use crate::types::LnKey; #[test] - fn mint_mod_generate_keyset_from_seed() { + fn test_generate_keyset_from_seed() { let seed = "test_seed".as_bytes(); let keyset = MintKeySet::generate_from_seed( &Secp256k1::new(), @@ -655,7 +652,7 @@ mod tests { } #[test] - fn mint_mod_generate_keyset_from_xpriv() { + fn test_generate_keyset_from_xpriv() { let seed = "test_seed".as_bytes(); let network = Network::Bitcoin; let xpriv = Xpriv::new_master(network, seed).expect("Failed to create xpriv"); @@ -698,6 +695,28 @@ mod tests { assert_eq!(amounts_and_pubkeys, expected_amounts_and_pubkeys); } + #[test] + fn test_derivation_path_from_unit() { + let test_cases = vec![ + // min value for a hardened derivation path index + (CurrencyUnit::Sat, 0, "0'/0'/0'"), + // max value for a hardened derivation path index + (CurrencyUnit::Msat, i32::MAX as u32, "0'/1'/2147483647'"), + (CurrencyUnit::Usd, 21, "0'/2'/21'"), + (CurrencyUnit::Eur, 1337, "0'/3'/1337'"), + ( + CurrencyUnit::Custom("DOGE".to_string(), 69), + 420, + "0'/69'/420'", + ), + ]; + + for (unit, index, expected) in test_cases { + let path = derivation_path_from_unit(unit, index).unwrap(); + assert_eq!(path.to_string(), expected); + } + } + use cdk_database::mint_memory::MintMemoryDatabase; #[derive(Default)] diff --git a/crates/cdk/src/nuts/nut00/mod.rs b/crates/cdk/src/nuts/nut00/mod.rs index 929b0c058..0435cb671 100644 --- a/crates/cdk/src/nuts/nut00/mod.rs +++ b/crates/cdk/src/nuts/nut00/mod.rs @@ -374,19 +374,37 @@ pub enum CurrencyUnit { /// Euro Eur, /// Custom currency unit - Custom(String), + Custom(String, u32), +} + +impl CurrencyUnit { + /// Constructor for `CurrencyUnit::Custom` + pub fn new_custom(name: String, index: u32) -> Result { + match index { + 0..=3 => Err(format!( + "Index {} is reserved and cannot be used for custom currency units.", + index + )), + i if i > i32::MAX as u32 => Err(format!( + "Index {} exceeds maximum allowed value of {}.", + index, + i32::MAX + )), + _ => Ok(Self::Custom(name, index)), + } + } } #[cfg(feature = "mint")] impl CurrencyUnit { /// Derivation index mint will use for unit - pub fn derivation_index(&self) -> Option { + pub fn derivation_index(&self) -> u32 { match self { - Self::Sat => Some(0), - Self::Msat => Some(1), - Self::Usd => Some(2), - Self::Eur => Some(3), - _ => None, + Self::Sat => 0, + Self::Msat => 1, + Self::Usd => 2, + Self::Eur => 3, + Self::Custom(_, index) => *index, } } } @@ -394,13 +412,23 @@ impl CurrencyUnit { impl FromStr for CurrencyUnit { type Err = Error; fn from_str(value: &str) -> Result { - let value = &value.to_uppercase(); - match value.as_str() { + // Split on ':' to check for derivation index + let parts: Vec<&str> = value.split(':').collect(); + let currency = parts[0].to_uppercase(); + + match currency.as_str() { "SAT" => Ok(Self::Sat), "MSAT" => Ok(Self::Msat), "USD" => Ok(Self::Usd), "EUR" => Ok(Self::Eur), - c => Ok(Self::Custom(c.to_string())), + c => { + // Require explicit index for custom currencies + if parts.len() != 2 { + return Err(Error::UnsupportedUnit); + } + let index = parts[1].parse().map_err(|_| Error::UnsupportedUnit)?; + Ok(Self::new_custom(c.to_string(), index).map_err(|_| Error::UnsupportedUnit)?) + } } } } @@ -408,11 +436,11 @@ impl FromStr for CurrencyUnit { impl fmt::Display for CurrencyUnit { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { - CurrencyUnit::Sat => "SAT", - CurrencyUnit::Msat => "MSAT", - CurrencyUnit::Usd => "USD", - CurrencyUnit::Eur => "EUR", - CurrencyUnit::Custom(unit) => unit, + CurrencyUnit::Sat => "SAT".to_string(), + CurrencyUnit::Msat => "MSAT".to_string(), + CurrencyUnit::Usd => "USD".to_string(), + CurrencyUnit::Eur => "EUR".to_string(), + CurrencyUnit::Custom(unit, index) => format!("{}:{}", unit, index), }; if let Some(width) = f.width() { write!(f, "{:width$}", s.to_lowercase(), width = width) @@ -763,4 +791,86 @@ mod tests { .unwrap(); assert_eq!(b.len(), 1); } + + #[test] + fn test_currency_unit_from_str() { + // valid cases + let standard_cases = [ + ("SAT", CurrencyUnit::Sat), + ("MSAT", CurrencyUnit::Msat), + ("USD", CurrencyUnit::Usd), + ("EUR", CurrencyUnit::Eur), + ("GBP:1001", CurrencyUnit::Custom("GBP".to_string(), 1001)), + ]; + + for (input, expected) in standard_cases { + assert_eq!(CurrencyUnit::from_str(input).unwrap(), expected); + assert_eq!( + CurrencyUnit::from_str(&input.to_lowercase()).unwrap(), + expected + ); + } + + // invalid cases + let invalid_cases = [ + "GBP", + "GBP:", + "GBP:abc", + "", + // one more than max index + "GBP:2147483648", + ]; + + for invalid in invalid_cases { + match CurrencyUnit::from_str(invalid) { + Err(Error::UnsupportedUnit) => {} + other => panic!("Expected UnsupportedUnit error, got {:?}", other), + } + match CurrencyUnit::from_str(&invalid.to_lowercase()) { + Err(Error::UnsupportedUnit) => {} + other => panic!("Expected UnsupportedUnit error, got {:?}", other), + } + } + } + + #[test] + fn test_currency_unit_display() { + // Standard currencies + assert_eq!(CurrencyUnit::Sat.to_string(), "sat"); + assert_eq!(CurrencyUnit::Msat.to_string(), "msat"); + assert_eq!(CurrencyUnit::Usd.to_string(), "usd"); + assert_eq!(CurrencyUnit::Eur.to_string(), "eur"); + + // Custom currency + assert_eq!( + CurrencyUnit::Custom("GBP".to_string(), 1001).to_string(), + "gbp:1001" + ); + } + + #[test] + fn test_custom_currency_valid_index() { + let valid_index = 5; + let result = CurrencyUnit::new_custom("MyCurrency".to_string(), valid_index); + assert!(result.is_ok()); + if let Ok(CurrencyUnit::Custom(name, index)) = result { + assert_eq!(name, "MyCurrency"); + assert_eq!(index, valid_index); + } + } + + #[test] + fn test_custom_currency_invalid_indexes() { + let invalid_indexes = [0, 1, 2, 3, 2147483648]; + for &index in &invalid_indexes { + let result = CurrencyUnit::new_custom("InvalidCurrency".to_string(), index); + assert!(result.is_err(), "Index {} should not be allowed", index); + if let Err(err) = result { + assert!( + err.contains(&index.to_string()), + "Error message should mention the invalid index" + ); + } + } + } }