Skip to content

Commit

Permalink
Make token_amount from odra-bdd generic over the decimals.
Browse files Browse the repository at this point in the history
Added USDC and WCSPR as common usage examples.
  • Loading branch information
kubaplas committed Feb 12, 2025
1 parent cd99c6c commit 40ebfeb
Showing 1 changed file with 101 additions and 24 deletions.
125 changes: 101 additions & 24 deletions odra-bdd/src/types/token_amount.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use cucumber::Parameter;
use odra::casper_types::U256;
use std::fmt::Display;
use std::marker::PhantomData;
use std::ops::Deref;
use std::str::FromStr;

#[derive(Parameter, Debug, Clone, Copy)]
#[param(regex = r"\d+(\.\d+)?", name = "token_amount")]
pub struct TokenAmount {
#[derive(Debug, Clone, Copy)]
pub struct TokenAmount<const DECIMALS: usize> {
amount: U256,
precision: usize
precision: usize,
_phantom: PhantomData<[(); DECIMALS]>
}

impl PartialEq for TokenAmount {
impl<const DECIMALS: usize> PartialEq for TokenAmount<DECIMALS> {
fn eq(&self, other: &Self) -> bool {
let min_precision = self.precision.min(other.precision);

Expand All @@ -26,37 +27,45 @@ impl PartialEq for TokenAmount {
}
}

impl PartialOrd for TokenAmount {
impl<const DECIMALS: usize> PartialOrd for TokenAmount<DECIMALS> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.amount.partial_cmp(&other.amount)
}
}

impl TokenAmount {
impl<const DECIMALS: usize> TokenAmount<DECIMALS> {
pub fn new(amount: U256, precision: usize) -> Self {
TokenAmount { amount, precision }
TokenAmount {
amount,
precision,
_phantom: PhantomData
}
}

pub fn amount(&self) -> U256 {
self.amount
}

fn multiplier() -> U256 {
U256::from(10u64).pow(U256::from(DECIMALS))
}
}

impl Deref for TokenAmount {
impl<const DECIMALS: usize> Deref for TokenAmount<DECIMALS> {
type Target = U256;

fn deref(&self) -> &Self::Target {
&self.amount
}
}

impl Display for TokenAmount {
impl<const DECIMALS: usize> Display for TokenAmount<DECIMALS> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.amount)
}
}

impl FromStr for TokenAmount {
impl<const DECIMALS: usize> FromStr for TokenAmount<DECIMALS> {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Expand All @@ -66,47 +75,115 @@ impl FromStr for TokenAmount {
// No decimal point
let amount = U256::from_dec_str(parts[0]).map_err(|e| e.to_string())?;
Ok(TokenAmount {
amount: amount * U256::from(1_000_000_000u64),
precision: 0
amount: amount * Self::multiplier(),
precision: 0,
_phantom: PhantomData
})
}
2 => {
// Has decimal point
let whole = U256::from_dec_str(parts[0]).map_err(|e| e.to_string())?;
let mut decimal = parts[1].to_string();
let precision = 9 - decimal.len();
// Pad with zeros if less than 9 decimal places
while decimal.len() < 9 {
let precision = DECIMALS - decimal.len();
// Pad with zeros if less than DECIMALS decimal places
while decimal.len() < DECIMALS {
decimal.push('0');
}
// Truncate if more than 9 decimal places
decimal.truncate(9);
// Truncate if more than DECIMALS decimal places
decimal.truncate(DECIMALS);
let fractional = U256::from_dec_str(&decimal).map_err(|e| e.to_string())?;

Ok(TokenAmount {
amount: (whole * U256::from(1_000_000_000u64)) + fractional,
precision
amount: (whole * Self::multiplier()) + fractional,
precision,
_phantom: PhantomData
})
}
_ => Err("Invalid token amount format".to_string())
}
}
}

impl From<U256> for TokenAmount {
impl<const DECIMALS: usize> From<U256> for TokenAmount<DECIMALS> {
fn from(amount: U256) -> Self {
TokenAmount {
amount,
precision: 9
precision: DECIMALS,
_phantom: PhantomData
}
}
}

impl From<&U256> for TokenAmount {
impl<const DECIMALS: usize> From<&U256> for TokenAmount<DECIMALS> {
fn from(amount: &U256) -> Self {
TokenAmount {
amount: *amount,
precision: 9
precision: DECIMALS,
_phantom: PhantomData
}
}
}

#[derive(Debug, Clone, Copy, Parameter)]
#[param(regex = r"\d+(\.\d+)?", name = "usdc")]
pub struct USDCAmount(TokenAmount<6>);

impl Deref for USDCAmount {
type Target = TokenAmount<6>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl FromStr for USDCAmount {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
TokenAmount::<6>::from_str(s).map(USDCAmount)
}
}

impl From<U256> for USDCAmount {
fn from(amount: U256) -> Self {
USDCAmount(TokenAmount::<6>::from(amount))
}
}

impl From<&U256> for USDCAmount {
fn from(amount: &U256) -> Self {
USDCAmount(TokenAmount::<6>::from(amount))
}
}

#[derive(Debug, Clone, Copy, Parameter)]
#[param(regex = r"\d+(\.\d+)?", name = "wcspr")]
pub struct WrappedCSPRAmount(TokenAmount<9>);

impl Deref for WrappedCSPRAmount {
type Target = TokenAmount<9>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl FromStr for WrappedCSPRAmount {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
TokenAmount::<9>::from_str(s).map(WrappedCSPRAmount)
}
}

impl From<U256> for WrappedCSPRAmount {
fn from(amount: U256) -> Self {
WrappedCSPRAmount(TokenAmount::<9>::from(amount))
}
}

impl From<&U256> for WrappedCSPRAmount {
fn from(amount: &U256) -> Self {
WrappedCSPRAmount(TokenAmount::<9>::from(amount))
}
}

0 comments on commit 40ebfeb

Please sign in to comment.