diff --git a/contracts/rewards/src/contract.rs b/contracts/rewards/src/contract.rs index baeb60968..7c8883575 100644 --- a/contracts/rewards/src/contract.rs +++ b/contracts/rewards/src/contract.rs @@ -1,5 +1,3 @@ -use itertools::Itertools; - use axelar_wasm_std::nonempty; #[cfg(not(feature = "library"))] use cosmwasm_std::entry_point; @@ -7,12 +5,12 @@ use cosmwasm_std::{ to_json_binary, BankMsg, Binary, Coin, Deps, DepsMut, Env, MessageInfo, Response, }; use error_stack::ResultExt; +use itertools::Itertools; use crate::{ - contract::execute::Contract, error::ContractError, msg::{ExecuteMsg, InstantiateMsg, QueryMsg}, - state::{Config, Epoch, ParamsSnapshot, PoolId, CONFIG, PARAMS}, + state::{self, Config, Epoch, ParamsSnapshot, PoolId, CONFIG, PARAMS}, }; mod execute; @@ -67,25 +65,30 @@ pub fn execute( chain_name, contract: info.sender.clone(), }; - Contract::new(deps) - .record_participation(event_id, worker_address, pool_id, env.block.height) - .map_err(axelar_wasm_std::ContractError::from)?; + execute::record_participation( + deps.storage, + event_id, + worker_address, + pool_id, + env.block.height, + ) + .map_err(axelar_wasm_std::ContractError::from)?; Ok(Response::new()) } ExecuteMsg::AddRewards { pool_id } => { deps.api.addr_validate(pool_id.contract.as_str())?; - let mut contract = Contract::new(deps); let amount = info .funds .iter() - .find(|coin| coin.denom == contract.config.rewards_denom) + .find(|coin| coin.denom == state::load_config(deps.storage).rewards_denom) .filter(|_| info.funds.len() == 1) // filter here to make sure expected denom is the only one attached to this message, and other funds aren't silently swallowed .ok_or(ContractError::WrongDenom)? .amount; - contract.add_rewards( + execute::add_rewards( + deps.storage, pool_id, nonempty::Uint128::try_from(amount).change_context(ContractError::ZeroRewards)?, )?; @@ -98,10 +101,9 @@ pub fn execute( } => { deps.api.addr_validate(pool_id.contract.as_str())?; - let mut contract = Contract::new(deps); - let rewards = contract - .distribute_rewards(pool_id, env.block.height, epoch_count) - .map_err(axelar_wasm_std::ContractError::from)?; + let rewards = + execute::distribute_rewards(deps.storage, pool_id, env.block.height, epoch_count) + .map_err(axelar_wasm_std::ContractError::from)?; let msgs = rewards .into_iter() @@ -109,7 +111,7 @@ pub fn execute( .map(|(addr, amount)| BankMsg::Send { to_address: addr.into(), amount: vec![Coin { - denom: contract.config.rewards_denom.clone(), + denom: state::load_config(deps.storage).rewards_denom.clone(), amount, }], }); @@ -117,7 +119,7 @@ pub fn execute( Ok(Response::new().add_messages(msgs)) } ExecuteMsg::UpdateParams { params } => { - Contract::new(deps).update_params(params, env.block.height, info.sender)?; + execute::update_params(deps.storage, params, env.block.height, info.sender)?; Ok(Response::new()) } diff --git a/contracts/rewards/src/contract/execute.rs b/contracts/rewards/src/contract/execute.rs index 44de76910..36ec48f0f 100644 --- a/contracts/rewards/src/contract/execute.rs +++ b/contracts/rewards/src/contract/execute.rs @@ -1,234 +1,190 @@ use std::collections::HashMap; use axelar_wasm_std::{nonempty, FnExt}; -use cosmwasm_std::{Addr, DepsMut, OverflowError, OverflowOperation, Uint128}; +use cosmwasm_std::{Addr, OverflowError, OverflowOperation, Storage, Uint128}; use error_stack::{Report, Result}; use crate::{ error::ContractError, msg::Params, - state::{ - Config, Epoch, EpochTally, Event, ParamsSnapshot, PoolId, RewardsStore, StorageState, - Store, CONFIG, - }, + state::{self, Config, Epoch, EpochTally, Event, ParamsSnapshot, PoolId, StorageState}, }; const DEFAULT_EPOCHS_TO_PROCESS: u64 = 10; const EPOCH_PAYOUT_DELAY: u64 = 2; -pub struct Contract -where - S: Store, -{ - pub store: S, - pub config: Config, +fn require_governance(config: Config, sender: Addr) -> Result<(), ContractError> { + if config.governance != sender { + return Err(ContractError::Unauthorized.into()); + } + Ok(()) } -impl<'a> Contract> { - pub fn new(deps: DepsMut) -> Contract { - let config = CONFIG.load(deps.storage).expect("couldn't load config"); - Contract { - store: RewardsStore { - storage: deps.storage, - }, - config, +pub(crate) fn record_participation( + storage: &mut dyn Storage, + event_id: nonempty::String, + worker: Addr, + pool_id: PoolId, + block_height: u64, +) -> Result<(), ContractError> { + let current_params = state::load_params(storage); + let cur_epoch = Epoch::current(¤t_params, block_height)?; + + let event = load_or_store_event(storage, event_id, pool_id.clone(), cur_epoch.epoch_num)?; + + state::load_epoch_tally(storage, pool_id.clone(), event.epoch_num)? + .unwrap_or(EpochTally::new(pool_id, cur_epoch, current_params.params)) + .record_participation(worker) + .then(|mut tally| { + if matches!(event, StorageState::New(_)) { + tally.event_count = tally.event_count.saturating_add(1) + } + state::save_epoch_tally(storage, &tally) + }) +} + +fn load_or_store_event( + storage: &mut dyn Storage, + event_id: nonempty::String, + pool_id: PoolId, + cur_epoch_num: u64, +) -> Result, ContractError> { + let event = state::load_event(storage, event_id.to_string(), pool_id.clone())?; + + match event { + None => { + let event = Event::new(event_id, pool_id, cur_epoch_num); + state::save_event(storage, &event)?; + Ok(StorageState::New(event)) } + Some(event) => Ok(StorageState::Existing(event)), } } -#[allow(dead_code)] -impl Contract -where - S: Store, -{ - /// Returns the current epoch. The current epoch is computed dynamically based on the current - /// block height and the epoch duration. If the epoch duration is updated, we store the epoch - /// in which the update occurs as the last checkpoint - fn current_epoch(&self, cur_block_height: u64) -> Result { - let current_params = self.store.load_params(); - Epoch::current(¤t_params, cur_block_height) - } +pub(crate) fn distribute_rewards( + storage: &mut dyn Storage, + pool_id: PoolId, + cur_block_height: u64, + epoch_process_limit: Option, +) -> Result, ContractError> { + let epoch_process_limit = epoch_process_limit.unwrap_or(DEFAULT_EPOCHS_TO_PROCESS); + let cur_epoch = Epoch::current(&state::load_params(storage), cur_block_height)?; - fn require_governance(&self, sender: Addr) -> Result<(), ContractError> { - if self.config.governance != sender { - return Err(ContractError::Unauthorized.into()); - } - Ok(()) - } + let from = state::load_rewards_watermark(storage, pool_id.clone())? + .map_or(0, |last_processed| last_processed.saturating_add(1)); - pub fn record_participation( - &mut self, - event_id: nonempty::String, - worker: Addr, - pool_id: PoolId, - block_height: u64, - ) -> Result<(), ContractError> { - let cur_epoch = self.current_epoch(block_height)?; - - let event = self.load_or_store_event(event_id, pool_id.clone(), cur_epoch.epoch_num)?; - - self.store - .load_epoch_tally(pool_id.clone(), event.epoch_num)? - .unwrap_or(EpochTally::new( - pool_id, - cur_epoch, - self.store.load_params().params, - )) - .record_participation(worker) - .then(|mut tally| { - if matches!(event, StorageState::New(_)) { - tally.event_count = tally.event_count.saturating_add(1) - } - self.store.save_epoch_tally(&tally) - }) - } + let to = std::cmp::min( + (from.saturating_add(epoch_process_limit)).saturating_sub(1), // for process limit =1 "from" and "to" must be equal + cur_epoch.epoch_num.saturating_sub(EPOCH_PAYOUT_DELAY), + ); - fn load_or_store_event( - &mut self, - event_id: nonempty::String, - pool_id: PoolId, - cur_epoch_num: u64, - ) -> Result, ContractError> { - let event = self - .store - .load_event(event_id.to_string(), pool_id.clone())?; - - match event { - None => { - let event = Event::new(event_id, pool_id, cur_epoch_num); - self.store.save_event(&event)?; - Ok(StorageState::New(event)) - } - Some(event) => Ok(StorageState::Existing(event)), - } + if to < from || cur_epoch.epoch_num < EPOCH_PAYOUT_DELAY { + return Err(ContractError::NoRewardsToDistribute.into()); } - pub fn distribute_rewards( - &mut self, - pool_id: PoolId, - cur_block_height: u64, - epoch_process_limit: Option, - ) -> Result, ContractError> { - let epoch_process_limit = epoch_process_limit.unwrap_or(DEFAULT_EPOCHS_TO_PROCESS); - let cur_epoch = self.current_epoch(cur_block_height)?; - - let from = self - .store - .load_rewards_watermark(pool_id.clone())? - .map_or(0, |last_processed| last_processed.saturating_add(1)); - - let to = std::cmp::min( - (from.saturating_add(epoch_process_limit)).saturating_sub(1), // for process limit =1 "from" and "to" must be equal - cur_epoch.epoch_num.saturating_sub(EPOCH_PAYOUT_DELAY), - ); + let rewards = process_rewards_for_epochs(storage, pool_id.clone(), from, to)?; + state::save_rewards_watermark(storage, pool_id, to)?; + Ok(rewards) +} - if to < from || cur_epoch.epoch_num < EPOCH_PAYOUT_DELAY { - return Err(ContractError::NoRewardsToDistribute.into()); - } +fn process_rewards_for_epochs( + storage: &mut dyn Storage, + pool_id: PoolId, + from: u64, + to: u64, +) -> Result, ContractError> { + let rewards = cumulate_rewards(storage, &pool_id, from, to)?; + state::load_rewards_pool_or_new(storage, pool_id.clone())? + .sub_reward(rewards.values().sum())? + .then(|pool| state::save_rewards_pool(storage, &pool))?; - let rewards = self.process_rewards_for_epochs(pool_id.clone(), from, to)?; - self.store.save_rewards_watermark(pool_id, to)?; - Ok(rewards) - } + Ok(rewards) +} - fn process_rewards_for_epochs( - &mut self, - pool_id: PoolId, - from: u64, - to: u64, - ) -> Result, ContractError> { - let rewards = self.cumulate_rewards(&pool_id, from, to)?; - self.store - .load_rewards_pool(pool_id.clone())? - .sub_reward(rewards.values().sum())? - .then(|pool| self.store.save_rewards_pool(&pool))?; - - Ok(rewards) - } +fn cumulate_rewards( + storage: &mut dyn Storage, + pool_id: &PoolId, + from: u64, + to: u64, +) -> Result, ContractError> { + iterate_epoch_tallies(storage, pool_id, from, to) + .map(|tally| tally.rewards_by_worker()) + .try_fold(HashMap::new(), merge_rewards) +} - fn cumulate_rewards( - &mut self, - pool_id: &PoolId, - from: u64, - to: u64, - ) -> Result, ContractError> { - self.iterate_epoch_tallies(pool_id, from, to) - .map(|tally| tally.rewards_by_worker()) - .try_fold(HashMap::new(), merge_rewards) - } +fn iterate_epoch_tallies<'a>( + storage: &'a mut dyn Storage, + pool_id: &'a PoolId, + from: u64, + to: u64, +) -> impl Iterator + 'a { + (from..=to).filter_map(|epoch_num| { + state::load_epoch_tally(storage, pool_id.clone(), epoch_num).unwrap_or_default() + }) +} - fn iterate_epoch_tallies<'a>( - &'a mut self, - pool_id: &'a PoolId, - from: u64, - to: u64, - ) -> impl Iterator + 'a { - (from..=to).filter_map(|epoch_num| { - self.store - .load_epoch_tally(pool_id.clone(), epoch_num) - .unwrap_or_default() +pub(crate) fn update_params( + storage: &mut dyn Storage, + new_params: Params, + block_height: u64, + sender: Addr, +) -> Result<(), ContractError> { + require_governance(state::load_config(storage), sender)?; + let cur_epoch = Epoch::current(&state::load_params(storage), block_height)?; + // If the param update reduces the epoch duration such that the current epoch immediately ends, + // start a new epoch at this block, incrementing the current epoch number by 1. + // This prevents us from jumping forward an arbitrary number of epochs, and maintains consistency for past events. + // (i.e. we are in epoch 0, which started at block 0 and epoch duration is 1000. At epoch 500, the params + // are updated to shorten the epoch duration to 100 blocks. We set the epoch number to 1, to prevent skipping + // epochs 1-4, and so all events prior to the start of epoch 1 have an epoch number of 0) + let should_end = cur_epoch + .block_height_started + .checked_add(u64::from(new_params.epoch_duration)) + .ok_or_else(|| { + OverflowError::new( + OverflowOperation::Add, + cur_epoch.block_height_started, + new_params.epoch_duration, + ) }) - } - - pub fn update_params( - &mut self, - new_params: Params, - block_height: u64, - sender: Addr, - ) -> Result<(), ContractError> { - self.require_governance(sender)?; - let cur_epoch = self.current_epoch(block_height)?; - // If the param update reduces the epoch duration such that the current epoch immediately ends, - // start a new epoch at this block, incrementing the current epoch number by 1. - // This prevents us from jumping forward an arbitrary number of epochs, and maintains consistency for past events. - // (i.e. we are in epoch 0, which started at block 0 and epoch duration is 1000. At epoch 500, the params - // are updated to shorten the epoch duration to 100 blocks. We set the epoch number to 1, to prevent skipping - // epochs 1-4, and so all events prior to the start of epoch 1 have an epoch number of 0) - let should_end = cur_epoch - .block_height_started - .checked_add(u64::from(new_params.epoch_duration)) - .ok_or_else(|| { - OverflowError::new( - OverflowOperation::Add, - cur_epoch.block_height_started, - new_params.epoch_duration, - ) - }) - .map_err(ContractError::from)? - < block_height; - let cur_epoch = if should_end { - Epoch { - block_height_started: block_height, - epoch_num: cur_epoch.epoch_num.checked_add(1).expect( - "epoch number should be strictly smaller than the current block height", - ), - } - } else { - cur_epoch - }; - self.store.save_params(&ParamsSnapshot { + .map_err(ContractError::from)? + < block_height; + let cur_epoch = if should_end { + Epoch { + block_height_started: block_height, + epoch_num: cur_epoch + .epoch_num + .checked_add(1) + .expect("epoch number should be strictly smaller than the current block height"), + } + } else { + cur_epoch + }; + state::save_params( + storage, + &ParamsSnapshot { params: new_params, created_at: cur_epoch, - })?; - Ok(()) - } + }, + )?; + Ok(()) +} - pub fn add_rewards( - &mut self, - pool_id: PoolId, - amount: nonempty::Uint128, - ) -> Result<(), ContractError> { - let mut pool = self.store.load_rewards_pool(pool_id)?; - pool.balance = pool - .balance - .checked_add(Uint128::from(amount)) - .map_err(Into::::into) - .map_err(Report::from)?; - - self.store.save_rewards_pool(&pool)?; - - Ok(()) - } +pub(crate) fn add_rewards( + storage: &mut dyn Storage, + pool_id: PoolId, + amount: nonempty::Uint128, +) -> Result<(), ContractError> { + let mut pool = state::load_rewards_pool_or_new(storage, pool_id)?; + pool.balance = pool + .balance + .checked_add(Uint128::from(amount)) + .map_err(Into::::into) + .map_err(Report::from)?; + + state::save_rewards_pool(storage, &pool)?; + + Ok(()) } /// Merges rewards_2 into rewards_1. For each (address, amount) pair in rewards_2, @@ -258,44 +214,44 @@ fn merge_rewards( #[cfg(test)] mod test { - use std::{ - collections::HashMap, - sync::{Arc, RwLock}, - }; + use super::*; + + use std::collections::HashMap; use axelar_wasm_std::nonempty; use connection_router_api::ChainName; - use cosmwasm_std::{Addr, Uint128, Uint64}; + use cosmwasm_std::{ + testing::{mock_dependencies, MockApi, MockQuerier, MockStorage}, + Addr, OwnedDeps, Uint128, Uint64, + }; use crate::{ error::ContractError, msg::Params, state::{ - self, Config, Epoch, EpochTally, Event, ParamsSnapshot, PoolId, RewardsPool, Store, - TallyId, + self, Config, Epoch, EpochTally, Event, ParamsSnapshot, PoolId, RewardsPool, CONFIG, }, }; - use super::Contract; - /// Tests that the current epoch is computed correctly when the expected epoch is the same as the stored epoch #[test] fn current_epoch_same_epoch_is_idempotent() { let cur_epoch_num = 1u64; let block_height_started = 250u64; let epoch_duration = 100u64; - let contract = setup(cur_epoch_num, block_height_started, epoch_duration); - let new_epoch = contract.current_epoch(block_height_started).unwrap(); + let (mock_deps, _) = setup(cur_epoch_num, block_height_started, epoch_duration); + let current_params = state::load_params(mock_deps.as_ref().storage); + + let new_epoch = Epoch::current(¤t_params, block_height_started).unwrap(); assert_eq!(new_epoch.epoch_num, cur_epoch_num); assert_eq!(new_epoch.block_height_started, block_height_started); - let new_epoch = contract.current_epoch(block_height_started + 1).unwrap(); + let new_epoch = Epoch::current(¤t_params, block_height_started + 1).unwrap(); assert_eq!(new_epoch.epoch_num, cur_epoch_num); assert_eq!(new_epoch.block_height_started, block_height_started); - let new_epoch = contract - .current_epoch(block_height_started + epoch_duration - 1) - .unwrap(); + let new_epoch = + Epoch::current(¤t_params, block_height_started + epoch_duration - 1).unwrap(); assert_eq!(new_epoch.epoch_num, cur_epoch_num); assert_eq!(new_epoch.block_height_started, block_height_started); } @@ -307,11 +263,11 @@ mod test { let cur_epoch_num = 1u64; let block_height_started = 250u64; let epoch_duration = 100u64; - let contract = setup(cur_epoch_num, block_height_started, epoch_duration); - assert!(contract.current_epoch(block_height_started - 1).is_err()); - assert!(contract - .current_epoch(block_height_started - epoch_duration) - .is_err()); + let (mock_deps, _) = setup(cur_epoch_num, block_height_started, epoch_duration); + let current_params = state::load_params(mock_deps.as_ref().storage); + + assert!(Epoch::current(¤t_params, block_height_started - 1).is_err()); + assert!(Epoch::current(¤t_params, block_height_started - epoch_duration).is_err()); } /// Tests that the current epoch is computed correctly when the expected epoch is different from the stored epoch @@ -320,7 +276,7 @@ mod test { let cur_epoch_num = 1u64; let block_height_started = 250u64; let epoch_duration = 100u64; - let contract = setup(cur_epoch_num, block_height_started, epoch_duration); + let (mock_deps, _) = setup(cur_epoch_num, block_height_started, epoch_duration); // elements are (height, expected epoch number, expected epoch start) let test_cases = vec![ @@ -347,7 +303,8 @@ mod test { ]; for (height, expected_epoch_num, expected_block_start) in test_cases { - let new_epoch = contract.current_epoch(height).unwrap(); + let new_epoch = + Epoch::current(&state::load_params(mock_deps.as_ref().storage), height).unwrap(); assert_eq!(new_epoch.epoch_num, expected_epoch_num); assert_eq!(new_epoch.block_height_started, expected_block_start); @@ -361,7 +318,7 @@ mod test { let epoch_block_start = 250u64; let epoch_duration = 100u64; - let mut contract = setup(cur_epoch_num, epoch_block_start, epoch_duration); + let (mut mock_deps, _) = setup(cur_epoch_num, epoch_block_start, epoch_duration); let pool_id = PoolId { chain_name: "mock-chain".parse().unwrap(), @@ -380,18 +337,21 @@ mod test { // simulates a worker participating in only part_count events if i < *part_count { let event_id = i.to_string().try_into().unwrap(); - contract - .record_participation(event_id, worker.clone(), pool_id.clone(), cur_height) - .unwrap(); + record_participation( + mock_deps.as_mut().storage, + event_id, + worker.clone(), + pool_id.clone(), + cur_height, + ) + .unwrap(); } } cur_height += 1; } - let tally = contract - .store - .load_epoch_tally(pool_id, cur_epoch_num) - .unwrap(); + let tally = + state::load_epoch_tally(mock_deps.as_ref().storage, pool_id, cur_epoch_num).unwrap(); assert!(tally.is_some()); let tally = tally.unwrap(); @@ -412,7 +372,7 @@ mod test { let block_height_started = 250u64; let epoch_duration = 100u64; - let mut contract = setup(starting_epoch_num, block_height_started, epoch_duration); + let (mut mock_deps, _) = setup(starting_epoch_num, block_height_started, epoch_duration); let pool_id = PoolId { chain_name: "mock-chain".parse().unwrap(), @@ -428,23 +388,29 @@ mod test { let height_at_epoch_end = block_height_started + epoch_duration - 1; // workers participate in consecutive blocks for (i, workers) in workers.iter().enumerate() { - contract - .record_participation( - "some event".to_string().try_into().unwrap(), - workers.clone(), - pool_id.clone(), - height_at_epoch_end + i as u64, - ) - .unwrap(); + record_participation( + mock_deps.as_mut().storage, + "some event".to_string().try_into().unwrap(), + workers.clone(), + pool_id.clone(), + height_at_epoch_end + i as u64, + ) + .unwrap(); } - let cur_epoch = contract.current_epoch(height_at_epoch_end).unwrap(); + let cur_epoch = Epoch::current( + &state::load_params(mock_deps.as_ref().storage), + height_at_epoch_end, + ) + .unwrap(); assert_ne!(starting_epoch_num + 1, cur_epoch.epoch_num); - let tally = contract - .store - .load_epoch_tally(pool_id.clone(), starting_epoch_num) - .unwrap(); + let tally = state::load_epoch_tally( + mock_deps.as_ref().storage, + pool_id.clone(), + starting_epoch_num, + ) + .unwrap(); assert!(tally.is_some()); let tally = tally.unwrap(); @@ -455,10 +421,9 @@ mod test { assert_eq!(tally.participation.get(&w.to_string()), Some(&1)); } - let tally = contract - .store - .load_epoch_tally(pool_id, starting_epoch_num + 1) - .unwrap(); + let tally = + state::load_epoch_tally(mock_deps.as_ref().storage, pool_id, starting_epoch_num + 1) + .unwrap(); assert!(tally.is_none()); } @@ -469,7 +434,7 @@ mod test { let block_height_started = 250u64; let epoch_duration = 100u64; - let mut contract = setup(cur_epoch_num, block_height_started, epoch_duration); + let (mut mock_deps, _) = setup(cur_epoch_num, block_height_started, epoch_duration); let mut simulated_participation = HashMap::new(); simulated_participation.insert( @@ -506,21 +471,23 @@ mod test { for (worker, (worker_contract, events_participated)) in &simulated_participation { for i in 0..*events_participated { let event_id = i.to_string().try_into().unwrap(); - contract - .record_participation( - event_id, - worker.clone(), - worker_contract.clone(), - block_height_started, - ) - .unwrap(); + record_participation( + mock_deps.as_mut().storage, + event_id, + worker.clone(), + worker_contract.clone(), + block_height_started, + ) + .unwrap(); } } for (worker, (worker_contract, events_participated)) in simulated_participation { - let tally = contract - .store - .load_epoch_tally(worker_contract.clone(), cur_epoch_num) - .unwrap(); + let tally = state::load_epoch_tally( + mock_deps.as_ref().storage, + worker_contract.clone(), + cur_epoch_num, + ) + .unwrap(); assert!(tally.is_some()); let tally = tally.unwrap(); @@ -536,13 +503,13 @@ mod test { /// Test that rewards parameters are updated correctly. In this test we don't change the epoch duration, so /// that computation of the current epoch is unaffected. #[test] - fn update_params() { + fn successfully_update_params() { let initial_epoch_num = 1u64; let initial_epoch_start = 250u64; let initial_rewards_per_epoch = 100u128; let initial_participation_threshold = (1, 2); let epoch_duration = 100u64; - let mut contract = setup_with_params( + let (mut mock_deps, config) = setup_with_params( initial_epoch_num, initial_epoch_start, epoch_duration, @@ -562,20 +529,22 @@ mod test { }; // the epoch shouldn't change when the params are updated, since we are not changing the epoch duration - let expected_epoch = contract.current_epoch(cur_height).unwrap(); + let expected_epoch = + Epoch::current(&state::load_params(mock_deps.as_ref().storage), cur_height).unwrap(); - contract - .update_params( - new_params.clone(), - cur_height, - contract.config.governance.clone(), - ) - .unwrap(); - let stored = contract.store.load_params(); + update_params( + mock_deps.as_mut().storage, + new_params.clone(), + cur_height, + config.governance.clone(), + ) + .unwrap(); + let stored = state::load_params(mock_deps.as_ref().storage); assert_eq!(stored.params, new_params); // current epoch shouldn't have changed - let cur_epoch = contract.current_epoch(cur_height).unwrap(); + let cur_epoch = + Epoch::current(&state::load_params(mock_deps.as_ref().storage), cur_height).unwrap(); assert_eq!(expected_epoch.epoch_num, cur_epoch.epoch_num); assert_eq!( expected_epoch.block_height_started, @@ -592,7 +561,7 @@ mod test { let initial_epoch_num = 1u64; let initial_epoch_start = 250u64; let epoch_duration = 100u64; - let mut contract = setup(initial_epoch_num, initial_epoch_start, epoch_duration); + let (mut mock_deps, _) = setup(initial_epoch_num, initial_epoch_start, epoch_duration); let new_params = Params { rewards_per_epoch: cosmwasm_std::Uint128::from(100u128).try_into().unwrap(), @@ -600,7 +569,8 @@ mod test { epoch_duration: epoch_duration.try_into().unwrap(), }; - let res = contract.update_params( + let res = update_params( + mock_deps.as_mut().storage, new_params.clone(), initial_epoch_start, Addr::unchecked("some non governance address"), @@ -618,7 +588,7 @@ mod test { let initial_epoch_num = 1u64; let initial_epoch_start = 250u64; let initial_epoch_duration = 100u64; - let mut contract = setup( + let (mut mock_deps, config) = setup( initial_epoch_num, initial_epoch_start, initial_epoch_duration, @@ -629,36 +599,40 @@ mod test { let cur_height = initial_epoch_start + initial_epoch_duration * epochs_elapsed + 10; // add 10 here just to be a little past the epoch boundary // epoch shouldn't change if we are extending the duration - let epoch_prior_to_update = contract.current_epoch(cur_height).unwrap(); + let initial_params_snapshot = state::load_params(mock_deps.as_ref().storage); + let epoch_prior_to_update = Epoch::current(&initial_params_snapshot, cur_height).unwrap(); let new_epoch_duration = initial_epoch_duration * 2; let new_params = Params { epoch_duration: (new_epoch_duration).try_into().unwrap(), - ..contract.store.load_params().params // keep everything besides epoch duration the same + ..initial_params_snapshot.params // keep everything besides epoch duration the same }; - contract - .update_params( - new_params.clone(), - cur_height, - contract.config.governance.clone(), - ) - .unwrap(); + update_params( + mock_deps.as_mut().storage, + new_params.clone(), + cur_height, + config.governance.clone(), + ) + .unwrap(); + + let updated_params_snapshot = state::load_params(mock_deps.as_ref().storage); // current epoch shouldn't change - let epoch = contract.current_epoch(cur_height).unwrap(); + let epoch = Epoch::current(&updated_params_snapshot, cur_height).unwrap(); assert_eq!(epoch, epoch_prior_to_update); // we increased the epoch duration, so adding the initial epoch duration should leave us in the same epoch - let epoch = contract - .current_epoch(cur_height + initial_epoch_duration) - .unwrap(); + let epoch = Epoch::current( + &updated_params_snapshot, + cur_height + initial_epoch_duration, + ) + .unwrap(); assert_eq!(epoch, epoch_prior_to_update); // check that we can correctly compute the start of the next epoch - let next_epoch = contract - .current_epoch(cur_height + new_epoch_duration) - .unwrap(); + let next_epoch = + Epoch::current(&updated_params_snapshot, cur_height + new_epoch_duration).unwrap(); assert_eq!(next_epoch.epoch_num, epoch_prior_to_update.epoch_num + 1); assert_eq!( next_epoch.block_height_started, @@ -673,7 +647,7 @@ mod test { let initial_epoch_num = 1u64; let initial_epoch_start = 256u64; let initial_epoch_duration = 100u64; - let mut contract = setup( + let (mut mock_deps, config) = setup( initial_epoch_num, initial_epoch_start, initial_epoch_duration, @@ -684,30 +658,33 @@ mod test { let cur_height = initial_epoch_start + initial_epoch_duration * epochs_elapsed; let new_epoch_duration = initial_epoch_duration / 2; - let epoch_prior_to_update = contract.current_epoch(cur_height).unwrap(); + + let initial_params_snapshot = state::load_params(mock_deps.as_ref().storage); + let epoch_prior_to_update = Epoch::current(&initial_params_snapshot, cur_height).unwrap(); // we are shortening the epoch, but not so much it causes the epoch number to change. We want to remain in the same epoch assert!(cur_height - epoch_prior_to_update.block_height_started < new_epoch_duration); let new_params = Params { epoch_duration: new_epoch_duration.try_into().unwrap(), - ..contract.store.load_params().params + ..initial_params_snapshot.params }; - contract - .update_params( - new_params.clone(), - cur_height, - contract.config.governance.clone(), - ) - .unwrap(); + update_params( + mock_deps.as_mut().storage, + new_params.clone(), + cur_height, + config.governance.clone(), + ) + .unwrap(); + + let updated_params_snapshot = state::load_params(mock_deps.as_ref().storage); // current epoch shouldn't have changed - let epoch = contract.current_epoch(cur_height).unwrap(); + let epoch = Epoch::current(&updated_params_snapshot, cur_height).unwrap(); assert_eq!(epoch_prior_to_update, epoch); // adding the new epoch duration should increase the epoch number by 1 - let epoch = contract - .current_epoch(cur_height + new_epoch_duration) - .unwrap(); + let epoch = + Epoch::current(&updated_params_snapshot, cur_height + new_epoch_duration).unwrap(); assert_eq!(epoch.epoch_num, epoch_prior_to_update.epoch_num + 1); assert_eq!( epoch.block_height_started, @@ -722,7 +699,7 @@ mod test { let initial_epoch_num = 1u64; let initial_epoch_start = 250u64; let initial_epoch_duration = 100u64; - let mut contract = setup( + let (mut mock_deps, config) = setup( initial_epoch_num, initial_epoch_start, initial_epoch_duration, @@ -735,29 +712,32 @@ mod test { // simulate progressing far enough into the epoch such that shortening the epoch duration would change the epoch let cur_height = initial_epoch_start + initial_epoch_duration * epochs_elapsed + new_epoch_duration * 2; - let epoch_prior_to_update = contract.current_epoch(cur_height).unwrap(); + + let initial_params_snapshot = state::load_params(mock_deps.as_ref().storage); + let epoch_prior_to_update = Epoch::current(&initial_params_snapshot, cur_height).unwrap(); let new_params = Params { epoch_duration: 10.try_into().unwrap(), - ..contract.store.load_params().params + ..initial_params_snapshot.params }; - contract - .update_params( - new_params.clone(), - cur_height, - contract.config.governance.clone(), - ) - .unwrap(); + update_params( + mock_deps.as_mut().storage, + new_params.clone(), + cur_height, + config.governance.clone(), + ) + .unwrap(); + + let updated_params_snapshot = state::load_params(mock_deps.as_ref().storage); // should be in new epoch now - let epoch = contract.current_epoch(cur_height).unwrap(); + let epoch = Epoch::current(&updated_params_snapshot, cur_height).unwrap(); assert_eq!(epoch.epoch_num, epoch_prior_to_update.epoch_num + 1); assert_eq!(epoch.block_height_started, cur_height); // moving forward the new epoch duration # of blocks should increment the epoch - let epoch = contract - .current_epoch(cur_height + new_epoch_duration) - .unwrap(); + let epoch = + Epoch::current(&updated_params_snapshot, cur_height + new_epoch_duration).unwrap(); assert_eq!(epoch.epoch_num, epoch_prior_to_update.epoch_num + 2); assert_eq!(epoch.block_height_started, cur_height + new_epoch_duration); } @@ -769,29 +749,37 @@ mod test { let block_height_started = 250u64; let epoch_duration = 100u64; - let mut contract = setup(cur_epoch_num, block_height_started, epoch_duration); + let (mut mock_deps, _) = setup(cur_epoch_num, block_height_started, epoch_duration); let pool_id = PoolId { chain_name: "mock-chain".parse().unwrap(), contract: Addr::unchecked("some contract"), }; - let pool = contract.store.load_rewards_pool(pool_id.clone()).unwrap(); + let pool = + state::load_rewards_pool_or_new(mock_deps.as_ref().storage, pool_id.clone()).unwrap(); assert!(pool.balance.is_zero()); let initial_amount = Uint128::from(100u128); - contract - .add_rewards(pool_id.clone(), initial_amount.try_into().unwrap()) - .unwrap(); + add_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + initial_amount.try_into().unwrap(), + ) + .unwrap(); - let pool = contract.store.load_rewards_pool(pool_id.clone()).unwrap(); + let pool = + state::load_rewards_pool_or_new(mock_deps.as_ref().storage, pool_id.clone()).unwrap(); assert_eq!(pool.balance, initial_amount); let added_amount = Uint128::from(500u128); - contract - .add_rewards(pool_id.clone(), added_amount.try_into().unwrap()) - .unwrap(); + add_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + added_amount.try_into().unwrap(), + ) + .unwrap(); - let pool = contract.store.load_rewards_pool(pool_id).unwrap(); + let pool = state::load_rewards_pool_or_new(mock_deps.as_ref().storage, pool_id).unwrap(); assert_eq!(pool.balance, initial_amount + added_amount); } @@ -802,7 +790,7 @@ mod test { let block_height_started = 250u64; let epoch_duration = 100u64; - let mut contract = setup(cur_epoch_num, block_height_started, epoch_duration); + let (mut mock_deps, _) = setup(cur_epoch_num, block_height_started, epoch_duration); // a vector of (worker contract, rewards amounts) pairs let test_data = vec![ (Addr::unchecked("contract_1"), vec![100, 200, 50]), @@ -819,12 +807,12 @@ mod test { }; for amount in rewards { - contract - .add_rewards( - pool_id.clone(), - cosmwasm_std::Uint128::from(*amount).try_into().unwrap(), - ) - .unwrap(); + add_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + cosmwasm_std::Uint128::from(*amount).try_into().unwrap(), + ) + .unwrap(); } } @@ -834,7 +822,8 @@ mod test { contract: worker_contract.clone(), }; - let pool = contract.store.load_rewards_pool(pool_id).unwrap(); + let pool = + state::load_rewards_pool_or_new(mock_deps.as_ref().storage, pool_id).unwrap(); assert_eq!( pool.balance, cosmwasm_std::Uint128::from(rewards.iter().sum::()) @@ -844,14 +833,14 @@ mod test { /// Tests that rewards are distributed correctly based on participation #[test] - fn distribute_rewards() { + fn successfully_distribute_rewards() { let cur_epoch_num = 0u64; let block_height_started = 0u64; let epoch_duration = 1000u64; let rewards_per_epoch = 100u128; let participation_threshold = (2, 3); - let mut contract = setup_with_params( + let (mut mock_deps, _) = setup_with_params( cur_epoch_num, block_height_started, epoch_duration, @@ -901,7 +890,8 @@ mod test { for (epoch, events) in events_participated.iter().enumerate().take(epoch_count) { for event in events { let event_id = event.to_string() + &epoch.to_string() + "event"; - let _ = contract.record_participation( + let _ = record_participation( + mock_deps.as_mut().storage, event_id.clone().try_into().unwrap(), worker.clone(), pool_id.clone(), @@ -914,18 +904,19 @@ mod test { // we add 2 epochs worth of rewards. There were 2 epochs of participation, but only 2 epochs where rewards should be given out // These tests we are accounting correctly, and only removing from the pool when we actually give out rewards let rewards_added = 2 * rewards_per_epoch; - let _ = contract.add_rewards( + let _ = add_rewards( + mock_deps.as_mut().storage, pool_id.clone(), Uint128::from(rewards_added).try_into().unwrap(), ); - let rewards_claimed = contract - .distribute_rewards( - pool_id, - block_height_started + epoch_duration * (epoch_count + 2) as u64, - None, - ) - .unwrap(); + let rewards_claimed = distribute_rewards( + mock_deps.as_mut().storage, + pool_id, + block_height_started + epoch_duration * (epoch_count + 2) as u64, + None, + ) + .unwrap(); assert_eq!(rewards_claimed.len(), worker_participation_per_epoch.len()); for (worker, rewards) in expected_rewards_per_worker { @@ -943,7 +934,7 @@ mod test { let rewards_per_epoch = 100u128; let participation_threshold = (1, 2); - let mut contract = setup_with_params( + let (mut mock_deps, _) = setup_with_params( cur_epoch_num, block_height_started, epoch_duration, @@ -958,7 +949,8 @@ mod test { for height in block_height_started..block_height_started + epoch_duration * 9 { let event_id = height.to_string() + "event"; - let _ = contract.record_participation( + let _ = record_participation( + mock_deps.as_mut().storage, event_id.try_into().unwrap(), worker.clone(), pool_id.clone(), @@ -967,7 +959,8 @@ mod test { } let rewards_added = 1000u128; - let _ = contract.add_rewards( + let _ = add_rewards( + mock_deps.as_mut().storage, pool_id.clone(), Uint128::from(rewards_added).try_into().unwrap(), ); @@ -978,9 +971,13 @@ mod test { // distribute 5 epochs worth of rewards let epochs_to_process = 5; - let rewards_claimed = contract - .distribute_rewards(pool_id.clone(), cur_height, Some(epochs_to_process)) - .unwrap(); + let rewards_claimed = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + cur_height, + Some(epochs_to_process), + ) + .unwrap(); assert_eq!(rewards_claimed.len(), 1); assert!(rewards_claimed.contains_key(&worker)); assert_eq!( @@ -989,9 +986,13 @@ mod test { ); // distribute the remaining epochs worth of rewards - let rewards_claimed = contract - .distribute_rewards(pool_id.clone(), cur_height, None) - .unwrap(); + let rewards_claimed = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + cur_height, + None, + ) + .unwrap(); assert_eq!(rewards_claimed.len(), 1); assert!(rewards_claimed.contains_key(&worker)); assert_eq!( @@ -1012,7 +1013,7 @@ mod test { let rewards_per_epoch = 100u128; let participation_threshold = (8, 10); - let mut contract = setup_with_params( + let (mut mock_deps, _) = setup_with_params( cur_epoch_num, block_height_started, epoch_duration, @@ -1025,7 +1026,8 @@ mod test { contract: Addr::unchecked("worker_contract"), }; - let _ = contract.record_participation( + let _ = record_participation( + mock_deps.as_mut().storage, "event".try_into().unwrap(), worker.clone(), pool_id.clone(), @@ -1033,37 +1035,50 @@ mod test { ); let rewards_added = 1000u128; - let _ = contract.add_rewards( + let _ = add_rewards( + mock_deps.as_mut().storage, pool_id.clone(), Uint128::from(rewards_added).try_into().unwrap(), ); // too early, still in the same epoch - let err = contract - .distribute_rewards(pool_id.clone(), block_height_started, None) - .unwrap_err(); + let err = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + block_height_started, + None, + ) + .unwrap_err(); assert_eq!(err.current_context(), &ContractError::NoRewardsToDistribute); // next epoch, but still too early to claim rewards - let err = contract - .distribute_rewards(pool_id.clone(), block_height_started + epoch_duration, None) - .unwrap_err(); + let err = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + block_height_started + epoch_duration, + None, + ) + .unwrap_err(); assert_eq!(err.current_context(), &ContractError::NoRewardsToDistribute); // can claim now, two epochs after participation - let rewards_claimed = contract - .distribute_rewards( - pool_id.clone(), - block_height_started + epoch_duration * 2, - None, - ) - .unwrap(); + let rewards_claimed = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + block_height_started + epoch_duration * 2, + None, + ) + .unwrap(); assert_eq!(rewards_claimed.len(), 1); // should error if we try again - let err = contract - .distribute_rewards(pool_id, block_height_started + epoch_duration * 2, None) - .unwrap_err(); + let err = distribute_rewards( + mock_deps.as_mut().storage, + pool_id, + block_height_started + epoch_duration * 2, + None, + ) + .unwrap_err(); assert_eq!(err.current_context(), &ContractError::NoRewardsToDistribute); } @@ -1077,7 +1092,7 @@ mod test { let rewards_per_epoch = 100u128; let participation_threshold = (8, 10); - let mut contract = setup_with_params( + let (mut mock_deps, _) = setup_with_params( cur_epoch_num, block_height_started, epoch_duration, @@ -1090,7 +1105,8 @@ mod test { contract: Addr::unchecked("worker_contract"), }; - let _ = contract.record_participation( + let _ = record_participation( + mock_deps.as_mut().storage, "event".try_into().unwrap(), worker.clone(), pool_id.clone(), @@ -1099,31 +1115,37 @@ mod test { // rewards per epoch is 100, we only add 10 let rewards_added = 10u128; - let _ = contract.add_rewards( + let _ = add_rewards( + mock_deps.as_mut().storage, pool_id.clone(), Uint128::from(rewards_added).try_into().unwrap(), ); - let err = contract - .distribute_rewards( - pool_id.clone(), - block_height_started + epoch_duration * 2, - None, - ) - .unwrap_err(); + let err = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + block_height_started + epoch_duration * 2, + None, + ) + .unwrap_err(); assert_eq!( err.current_context(), &ContractError::PoolBalanceInsufficient ); // add some more rewards let rewards_added = 90u128; - let _ = contract.add_rewards( + let _ = add_rewards( + mock_deps.as_mut().storage, pool_id.clone(), Uint128::from(rewards_added).try_into().unwrap(), ); - let result = - contract.distribute_rewards(pool_id, block_height_started + epoch_duration * 2, None); + let result = distribute_rewards( + mock_deps.as_mut().storage, + pool_id, + block_height_started + epoch_duration * 2, + None, + ); assert!(result.is_ok()); assert_eq!(result.unwrap().len(), 1); } @@ -1137,7 +1159,7 @@ mod test { let rewards_per_epoch = 100u128; let participation_threshold = (8, 10); - let mut contract = setup_with_params( + let (mut mock_deps, _) = setup_with_params( cur_epoch_num, block_height_started, epoch_duration, @@ -1150,7 +1172,8 @@ mod test { contract: Addr::unchecked("worker_contract"), }; - let _ = contract.record_participation( + let _ = record_participation( + mock_deps.as_mut().storage, "event".try_into().unwrap(), worker.clone(), pool_id.clone(), @@ -1158,123 +1181,82 @@ mod test { ); let rewards_added = 1000u128; - let _ = contract.add_rewards( + let _ = add_rewards( + mock_deps.as_mut().storage, pool_id.clone(), Uint128::from(rewards_added).try_into().unwrap(), ); - let rewards_claimed = contract - .distribute_rewards( - pool_id.clone(), - block_height_started + epoch_duration * 2, - None, - ) - .unwrap(); + let rewards_claimed = distribute_rewards( + mock_deps.as_mut().storage, + pool_id.clone(), + block_height_started + epoch_duration * 2, + None, + ) + .unwrap(); assert_eq!(rewards_claimed.len(), 1); // try to claim again, shouldn't get an error - let err = contract - .distribute_rewards(pool_id, block_height_started + epoch_duration * 2, None) - .unwrap_err(); + let err = distribute_rewards( + mock_deps.as_mut().storage, + pool_id, + block_height_started + epoch_duration * 2, + None, + ) + .unwrap_err(); assert_eq!(err.current_context(), &ContractError::NoRewardsToDistribute); } - fn create_contract( - params_store: Arc>, - events_store: Arc>>, - tally_store: Arc>>, - rewards_store: Arc>>, - watermark_store: Arc>>, - ) -> Contract { - let mut store = state::MockStore::new(); - let params_store_cloned = params_store.clone(); - store - .expect_load_params() - .returning(move || params_store_cloned.read().unwrap().clone()); - store.expect_save_params().returning(move |new_params| { - let mut params_store = params_store.write().unwrap(); - *params_store = new_params.clone(); - Ok(()) - }); - let events_store_cloned = events_store.clone(); - store.expect_load_event().returning(move |id, pool_id| { - let events_store = events_store_cloned.read().unwrap(); - Ok(events_store.get(&(id, pool_id)).cloned()) - }); - store.expect_save_event().returning(move |event| { - let mut events_store = events_store.write().unwrap(); - events_store.insert( - (event.event_id.clone().into(), event.pool_id.clone()), - event.clone(), - ); - Ok(()) - }); - let tally_store_cloned = tally_store.clone(); - store - .expect_load_epoch_tally() - .returning(move |pool_id, epoch_num| { - let tally_store = tally_store_cloned.read().unwrap(); - let tally_id = TallyId { - pool_id: pool_id.clone(), - epoch_num, - }; - Ok(tally_store.get(&tally_id).cloned()) - }); - store.expect_save_epoch_tally().returning(move |tally| { - let mut tally_store = tally_store.write().unwrap(); - let tally_id = TallyId { - pool_id: tally.pool_id.clone(), - epoch_num: tally.epoch.epoch_num, - }; - tally_store.insert(tally_id, tally.clone()); - Ok(()) + type MockDeps = OwnedDeps; + + fn set_initial_storage( + params_store: ParamsSnapshot, + events_store: Vec, + tally_store: Vec, + rewards_store: Vec, + watermark_store: HashMap, + ) -> (MockDeps, Config) { + let mut deps = mock_dependencies(); + let storage = deps.as_mut().storage; + + state::save_params(storage, ¶ms_store).unwrap(); + + events_store.iter().for_each(|event| { + state::save_event(storage, event).unwrap(); }); - let rewards_store_cloned = rewards_store.clone(); - store.expect_load_rewards_pool().returning(move |pool_id| { - let rewards_store = rewards_store_cloned.read().unwrap(); - Ok(rewards_store.get(&pool_id).cloned().unwrap_or(RewardsPool { - id: pool_id, - balance: Uint128::zero(), - })) + tally_store.iter().for_each(|tally| { + state::save_epoch_tally(storage, tally).unwrap(); }); - store.expect_save_rewards_pool().returning(move |pool| { - let mut rewards_store = rewards_store.write().unwrap(); - rewards_store.insert(pool.id.clone(), pool.clone()); - Ok(()) + + rewards_store.iter().for_each(|pool| { + state::save_rewards_pool(storage, pool).unwrap(); }); - let watermark_store_cloned = watermark_store.clone(); - store - .expect_load_rewards_watermark() - .returning(move |contract| { - let watermark_store = watermark_store_cloned.read().unwrap(); - Ok(watermark_store.get(&contract).cloned()) - }); - store - .expect_save_rewards_watermark() - .returning(move |pool_id, epoch_num| { - let mut watermark_store = watermark_store.write().unwrap(); - watermark_store.insert(pool_id, epoch_num); - Ok(()) + watermark_store + .into_iter() + .for_each(|(pool_id, epoch_num)| { + state::save_rewards_watermark(storage, pool_id, epoch_num).unwrap(); }); - Contract { - store, - config: Config { - governance: Addr::unchecked("governance"), - rewards_denom: "AXL".to_string(), - }, - } + + let config = Config { + governance: Addr::unchecked("governance"), + rewards_denom: "AXL".to_string(), + }; + + CONFIG.save(storage, &config).unwrap(); + + (deps, config) } fn setup_with_stores( - params_store: Arc>, - events_store: Arc>>, - tally_store: Arc>>, - rewards_store: Arc>>, - watermark_store: Arc>>, - ) -> Contract { - create_contract( + params_store: ParamsSnapshot, + events_store: Vec, + tally_store: Vec, + rewards_store: Vec, + watermark_store: HashMap, + ) -> (MockDeps, Config) { + set_initial_storage( params_store, events_store, tally_store, @@ -1289,7 +1271,7 @@ mod test { epoch_duration: u64, rewards_per_epoch: u128, participation_threshold: (u64, u64), - ) -> Contract { + ) -> (MockDeps, Config) { let rewards_per_epoch: nonempty::Uint128 = cosmwasm_std::Uint128::from(rewards_per_epoch) .try_into() .unwrap(); @@ -1306,11 +1288,11 @@ mod test { }, created_at: current_epoch.clone(), }; - let params_snapshot = Arc::new(RwLock::new(params_snapshot)); - let rewards_store = Arc::new(RwLock::new(HashMap::new())); - let events_store = Arc::new(RwLock::new(HashMap::new())); - let tally_store = Arc::new(RwLock::new(HashMap::new())); - let watermark_store = Arc::new(RwLock::new(HashMap::new())); + + let rewards_store = Vec::new(); + let events_store = Vec::new(); + let tally_store = Vec::new(); + let watermark_store = HashMap::new(); setup_with_stores( params_snapshot, events_store, @@ -1324,7 +1306,7 @@ mod test { cur_epoch_num: u64, block_height_started: u64, epoch_duration: u64, - ) -> Contract { + ) -> (MockDeps, Config) { let participation_threshold = (1, 2); let rewards_per_epoch = 100u128; setup_with_params( diff --git a/contracts/rewards/src/contract/query.rs b/contracts/rewards/src/contract/query.rs index c55f94f7a..2a01fdc15 100644 --- a/contracts/rewards/src/contract/query.rs +++ b/contracts/rewards/src/contract/query.rs @@ -41,7 +41,7 @@ mod tests { use crate::{ msg::Params, - state::{EpochTally, ParamsSnapshot, RewardsPool, RewardsStore, Store}, + state::{EpochTally, ParamsSnapshot, RewardsPool}, }; use super::*; @@ -71,9 +71,8 @@ mod tests { balance: initial_balance, }; - let mut store = RewardsStore { storage }; - store.save_params(¶ms_snapshot).unwrap(); - store.save_rewards_pool(&rewards_pool).unwrap(); + state::save_params(storage, ¶ms_snapshot).unwrap(); + state::save_rewards_pool(storage, &rewards_pool).unwrap(); (params_snapshot, pool_id) } @@ -113,12 +112,12 @@ mod tests { let block_height = 1000; let last_distribution_epoch = 5u64; - let mut store = RewardsStore { - storage: deps.as_mut().storage, - }; - store - .save_rewards_watermark(pool_id.clone(), last_distribution_epoch) - .unwrap(); + state::save_rewards_watermark( + deps.as_mut().storage, + pool_id.clone(), + last_distribution_epoch, + ) + .unwrap(); let res = rewards_pool(deps.as_mut().storage, pool_id.clone(), block_height).unwrap(); assert_eq!( @@ -152,16 +151,15 @@ mod tests { participation_threshold: (2, 3).try_into().unwrap(), }; - let mut store = RewardsStore { - storage: deps.as_mut().storage, - }; - store - .save_epoch_tally(&EpochTally::new( + state::save_epoch_tally( + deps.as_mut().storage, + &EpochTally::new( pool_id.clone(), Epoch::current(¤t_params, old_block_height).unwrap(), tally_params.clone(), - )) - .unwrap(); + ), + ) + .unwrap(); let cur_block_height = 1000; let res = rewards_pool(deps.as_mut().storage, pool_id.clone(), cur_block_height).unwrap(); @@ -195,16 +193,15 @@ mod tests { participation_threshold: (2, 3).try_into().unwrap(), }; - let mut store = RewardsStore { - storage: deps.as_mut().storage, - }; - store - .save_epoch_tally(&EpochTally::new( + state::save_epoch_tally( + deps.as_mut().storage, + &EpochTally::new( pool_id.clone(), Epoch::current(¤t_params, block_height).unwrap(), tally_params.clone(), - )) - .unwrap(); + ), + ) + .unwrap(); let res = rewards_pool(deps.as_mut().storage, pool_id.clone(), block_height).unwrap(); assert_eq!( diff --git a/contracts/rewards/src/state.rs b/contracts/rewards/src/state.rs index e3fa9ad0a..a2125bd46 100644 --- a/contracts/rewards/src/state.rs +++ b/contracts/rewards/src/state.rs @@ -7,7 +7,6 @@ use cosmwasm_schema::cw_serde; use cosmwasm_std::{Addr, StdResult, Storage, Uint128}; use cw_storage_plus::{Item, Key, KeyDeserialize, Map, Prefixer, PrimaryKey}; use error_stack::{Result, ResultExt}; -use mockall::automock; use crate::{error::ContractError, msg::Params}; @@ -32,6 +31,15 @@ pub struct PoolId { pub contract: Addr, } +impl PoolId { + pub fn new(chain_name: ChainName, contract: Addr) -> Self { + PoolId { + chain_name, + contract, + } + } +} + impl PrimaryKey<'_> for PoolId { type Prefix = ChainName; type SubPrefix = (); @@ -218,13 +226,9 @@ pub struct RewardsPool { } impl RewardsPool { - #[allow(dead_code)] - pub fn new(chain_name: ChainName, contract: Addr) -> Self { + pub fn new(id: PoolId) -> Self { RewardsPool { - id: PoolId { - chain_name, - contract, - }, + id, balance: Uint128::zero(), } } @@ -239,38 +243,6 @@ impl RewardsPool { } } -#[automock] -pub trait Store { - fn load_params(&self) -> ParamsSnapshot; - - fn load_rewards_watermark(&self, pool_id: PoolId) -> Result, ContractError>; - - fn load_event(&self, event_id: String, pool_id: PoolId) - -> Result, ContractError>; - - fn load_epoch_tally( - &self, - pool_id: PoolId, - epoch_num: u64, - ) -> Result, ContractError>; - - fn load_rewards_pool(&self, pool_id: PoolId) -> Result; - - fn save_params(&mut self, params: &ParamsSnapshot) -> Result<(), ContractError>; - - fn save_rewards_watermark( - &mut self, - pool_id: PoolId, - epoch_num: u64, - ) -> Result<(), ContractError>; - - fn save_event(&mut self, event: &Event) -> Result<(), ContractError>; - - fn save_epoch_tally(&mut self, tally: &EpochTally) -> Result<(), ContractError>; - - fn save_rewards_pool(&mut self, pool: &RewardsPool) -> Result<(), ContractError>; -} - /// Current rewards parameters, along with when the params were updated pub const PARAMS: Item = Item::new("params"); @@ -289,93 +261,118 @@ const WATERMARKS: Map = Map::new("rewards_watermarks"); pub const CONFIG: Item = Item::new("config"); -pub struct RewardsStore<'a> { - pub storage: &'a mut dyn Storage, +pub(crate) fn load_config(storage: &dyn Storage) -> Config { + CONFIG.load(storage).expect("couldn't load config") } -impl Store for RewardsStore<'_> { - fn load_params(&self) -> ParamsSnapshot { - load_params(self.storage) - } +pub(crate) fn load_params(storage: &dyn Storage) -> ParamsSnapshot { + PARAMS.load(storage).expect("params should exist") +} - fn load_rewards_watermark(&self, pool_id: PoolId) -> Result, ContractError> { - WATERMARKS - .may_load(self.storage, pool_id) - .change_context(ContractError::LoadRewardsWatermark) - } +pub(crate) fn load_rewards_watermark( + storage: &dyn Storage, + pool_id: PoolId, +) -> Result, ContractError> { + WATERMARKS + .may_load(storage, pool_id) + .change_context(ContractError::LoadRewardsWatermark) +} - fn load_event( - &self, - event_id: String, - pool_id: PoolId, - ) -> Result, ContractError> { - EVENTS - .may_load(self.storage, (event_id, pool_id)) - .change_context(ContractError::LoadEvent) - } +pub(crate) fn load_event( + storage: &dyn Storage, + event_id: String, + pool_id: PoolId, +) -> Result, ContractError> { + EVENTS + .may_load(storage, (event_id, pool_id)) + .change_context(ContractError::LoadEvent) +} - fn load_epoch_tally( - &self, - pool_id: PoolId, - epoch_num: u64, - ) -> Result, ContractError> { - load_epoch_tally(self.storage, pool_id, epoch_num) - } +pub(crate) fn load_epoch_tally( + storage: &dyn Storage, + pool_id: PoolId, + epoch_num: u64, +) -> Result, ContractError> { + TALLIES + .may_load(storage, TallyId { pool_id, epoch_num }) + .change_context(ContractError::LoadEpochTally) +} - fn load_rewards_pool(&self, pool_id: PoolId) -> Result { - POOLS - .may_load(self.storage, pool_id.clone()) - .change_context(ContractError::LoadRewardsPool) - .map(|pool| { - pool.unwrap_or(RewardsPool { - id: pool_id, - balance: Uint128::zero(), - }) - }) - } +pub(crate) fn may_load_rewards_pool( + storage: &dyn Storage, + pool_id: PoolId, +) -> Result, ContractError> { + POOLS + .may_load(storage, pool_id.clone()) + .change_context(ContractError::LoadRewardsPool) +} - fn save_params(&mut self, params: &ParamsSnapshot) -> Result<(), ContractError> { - PARAMS - .save(self.storage, params) - .change_context(ContractError::SaveParams) - } +pub(crate) fn load_rewards_pool_or_new( + storage: &dyn Storage, + pool_id: PoolId, +) -> Result { + may_load_rewards_pool(storage, pool_id.clone()) + .map(|pool| pool.unwrap_or(RewardsPool::new(pool_id))) +} - fn save_rewards_watermark( - &mut self, - pool_id: PoolId, - epoch_num: u64, - ) -> Result<(), ContractError> { - WATERMARKS - .save(self.storage, pool_id, &epoch_num) - .change_context(ContractError::SaveRewardsWatermark) - } +pub(crate) fn load_rewards_pool( + storage: &dyn Storage, + pool_id: PoolId, +) -> Result { + may_load_rewards_pool(storage, pool_id.clone())? + .ok_or(ContractError::RewardsPoolNotFound.into()) +} - fn save_event(&mut self, event: &Event) -> Result<(), ContractError> { - EVENTS - .save( - self.storage, - (event.event_id.clone().into(), event.pool_id.clone()), - event, - ) - .change_context(ContractError::SaveEvent) - } +pub(crate) fn save_params( + storage: &mut dyn Storage, + params: &ParamsSnapshot, +) -> Result<(), ContractError> { + PARAMS + .save(storage, params) + .change_context(ContractError::SaveParams) +} - fn save_epoch_tally(&mut self, tally: &EpochTally) -> Result<(), ContractError> { - let tally_id = TallyId { - pool_id: tally.pool_id.clone(), - epoch_num: tally.epoch.epoch_num, - }; +pub(crate) fn save_rewards_watermark( + storage: &mut dyn Storage, + pool_id: PoolId, + epoch_num: u64, +) -> Result<(), ContractError> { + WATERMARKS + .save(storage, pool_id, &epoch_num) + .change_context(ContractError::SaveRewardsWatermark) +} - TALLIES - .save(self.storage, tally_id, tally) - .change_context(ContractError::SaveEpochTally) - } +pub(crate) fn save_event(storage: &mut dyn Storage, event: &Event) -> Result<(), ContractError> { + EVENTS + .save( + storage, + (event.event_id.clone().into(), event.pool_id.clone()), + event, + ) + .change_context(ContractError::SaveEvent) +} - fn save_rewards_pool(&mut self, pool: &RewardsPool) -> Result<(), ContractError> { - POOLS - .save(self.storage, pool.id.clone(), pool) - .change_context(ContractError::SaveRewardsPool) - } +pub(crate) fn save_epoch_tally( + storage: &mut dyn Storage, + tally: &EpochTally, +) -> Result<(), ContractError> { + let tally_id = TallyId { + pool_id: tally.pool_id.clone(), + epoch_num: tally.epoch.epoch_num, + }; + + TALLIES + .save(storage, tally_id, tally) + .change_context(ContractError::SaveEpochTally) +} + +pub(crate) fn save_rewards_pool( + storage: &mut dyn Storage, + pool: &RewardsPool, +) -> Result<(), ContractError> { + POOLS + .save(storage, pool.id.clone(), pool) + .change_context(ContractError::SaveRewardsPool) } pub(crate) enum StorageState { @@ -394,42 +391,9 @@ impl Deref for StorageState { } } -pub(crate) fn load_rewards_pool( - storage: &dyn Storage, - pool_id: PoolId, -) -> Result { - POOLS - .may_load(storage, pool_id) - .change_context(ContractError::LoadRewardsPool)? - .ok_or(ContractError::RewardsPoolNotFound.into()) -} - -pub(crate) fn load_params(storage: &dyn Storage) -> ParamsSnapshot { - PARAMS.load(storage).expect("params should exist") -} - -pub(crate) fn load_rewards_watermark( - storage: &dyn Storage, - pool_id: PoolId, -) -> Result, ContractError> { - WATERMARKS - .may_load(storage, pool_id) - .change_context(ContractError::LoadRewardsWatermark) -} - -pub(crate) fn load_epoch_tally( - storage: &dyn Storage, - pool_id: PoolId, - epoch_num: u64, -) -> Result, ContractError> { - TALLIES - .may_load(storage, TallyId { pool_id, epoch_num }) - .change_context(ContractError::LoadEpochTally) -} - #[cfg(test)] mod test { - use super::{Epoch, EpochTally, Event, PoolId, RewardsPool, RewardsStore, Store}; + use super::*; use crate::error::ContractError; use crate::{msg::Params, state::ParamsSnapshot}; use connection_router_api::ChainName; @@ -522,9 +486,6 @@ mod test { #[test] fn save_and_load_params() { let mut mock_deps = mock_dependencies(); - let mut store = RewardsStore { - storage: &mut mock_deps.storage, - }; let params = ParamsSnapshot { params: Params { participation_threshold: (Uint64::new(1), Uint64::new(2)).try_into().unwrap(), @@ -537,8 +498,8 @@ mod test { }, }; // save an initial params, then load it - assert!(store.save_params(¶ms).is_ok()); - let loaded = store.load_params(); + assert!(save_params(mock_deps.as_mut().storage, ¶ms).is_ok()); + let loaded = load_params(mock_deps.as_ref().storage); assert_eq!(loaded, params); // now store a new params, and check that it was updated @@ -552,17 +513,14 @@ mod test { block_height_started: 101, }, }; - assert!(store.save_params(&new_params).is_ok()); - let loaded = store.load_params(); + assert!(save_params(mock_deps.as_mut().storage, &new_params).is_ok()); + let loaded = load_params(mock_deps.as_mut().storage); assert_eq!(loaded, new_params); } #[test] fn save_and_load_rewards_watermark() { let mut mock_deps = mock_dependencies(); - let mut store = RewardsStore { - storage: &mut mock_deps.storage, - }; let epoch = Epoch { epoch_num: 10, block_height_started: 1000, @@ -573,24 +531,29 @@ mod test { }; // should be empty at first - let loaded = store.load_rewards_watermark(pool_id.clone()); + let loaded = load_rewards_watermark(mock_deps.as_ref().storage, pool_id.clone()); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); // save the first watermark - let res = store.save_rewards_watermark(pool_id.clone(), epoch.epoch_num); + let res = + save_rewards_watermark(mock_deps.as_mut().storage, pool_id.clone(), epoch.epoch_num); assert!(res.is_ok()); - let loaded = store.load_rewards_watermark(pool_id.clone()); + let loaded = load_rewards_watermark(mock_deps.as_ref().storage, pool_id.clone()); assert!(loaded.is_ok()); assert!(loaded.as_ref().unwrap().is_some()); assert_eq!(loaded.unwrap().unwrap(), epoch.epoch_num); // now store a new watermark, should overwrite - let res = store.save_rewards_watermark(pool_id.clone(), epoch.epoch_num + 1); + let res = save_rewards_watermark( + mock_deps.as_mut().storage, + pool_id.clone(), + epoch.epoch_num + 1, + ); assert!(res.is_ok()); - let loaded = store.load_rewards_watermark(pool_id); + let loaded = load_rewards_watermark(mock_deps.as_ref().storage, pool_id); assert!(loaded.is_ok()); assert!(loaded.as_ref().unwrap().is_some()); assert_eq!(loaded.unwrap().unwrap(), epoch.epoch_num + 1); @@ -601,15 +564,19 @@ mod test { contract: Addr::unchecked("some other contract"), }; // should be empty at first - let loaded = store.load_rewards_watermark(diff_pool_id.clone()); + let loaded = load_rewards_watermark(mock_deps.as_ref().storage, diff_pool_id.clone()); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); // save the first watermark for this contract - let res = store.save_rewards_watermark(diff_pool_id.clone(), epoch.epoch_num + 7); + let res = save_rewards_watermark( + mock_deps.as_mut().storage, + diff_pool_id.clone(), + epoch.epoch_num + 7, + ); assert!(res.is_ok()); - let loaded = store.load_rewards_watermark(diff_pool_id.clone()); + let loaded = load_rewards_watermark(mock_deps.as_ref().storage, diff_pool_id.clone()); assert!(loaded.is_ok()); assert!(loaded.as_ref().unwrap().is_some()); assert_eq!(loaded.unwrap().unwrap(), epoch.epoch_num + 7); @@ -618,9 +585,6 @@ mod test { #[test] fn save_and_load_event() { let mut mock_deps = mock_dependencies(); - let mut store = RewardsStore { - storage: &mut mock_deps.storage, - }; let event = Event { pool_id: PoolId { @@ -631,11 +595,15 @@ mod test { epoch_num: 2, }; - let res = store.save_event(&event); + let res = save_event(mock_deps.as_mut().storage, &event); assert!(res.is_ok()); // check that we load the event that we just saved - let loaded = store.load_event(event.event_id.clone().into(), event.pool_id.clone()); + let loaded = load_event( + mock_deps.as_ref().storage, + event.event_id.clone().into(), + event.pool_id.clone(), + ); assert!(loaded.is_ok()); assert!(loaded.as_ref().unwrap().is_some()); assert_eq!(loaded.unwrap().unwrap(), event); @@ -645,17 +613,29 @@ mod test { chain_name: "mock-chain".parse().unwrap(), contract: Addr::unchecked("different contract"), }; - let loaded = store.load_event("some other event".into(), diff_pool_id.clone()); + let loaded = load_event( + mock_deps.as_ref().storage, + "some other event".into(), + diff_pool_id.clone(), + ); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); // same event id but different contract address, should still return none - let loaded = store.load_event(event.event_id.clone().into(), diff_pool_id); + let loaded = load_event( + mock_deps.as_ref().storage, + event.event_id.clone().into(), + diff_pool_id, + ); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); // different event id, but same contract address, should still return none - let loaded = store.load_event("some other event".into(), event.pool_id); + let loaded = load_event( + mock_deps.as_ref().storage, + "some other event".into(), + event.pool_id, + ); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); } @@ -663,9 +643,6 @@ mod test { #[test] fn save_and_load_epoch_tally() { let mut mock_deps = mock_dependencies(); - let mut store = RewardsStore { - storage: &mut mock_deps.storage, - }; let epoch_num = 10; let rewards_rate = Uint128::from(100u128).try_into().unwrap(); @@ -689,17 +666,18 @@ mod test { tally = tally.record_participation(Addr::unchecked("worker")); - let res = store.save_epoch_tally(&tally); + let res = save_epoch_tally(mock_deps.as_mut().storage, &tally); assert!(res.is_ok()); // check that we load the tally that we just saved - let loaded = store.load_epoch_tally(pool_id.clone(), epoch_num); + let loaded = load_epoch_tally(mock_deps.as_ref().storage, pool_id.clone(), epoch_num); assert!(loaded.is_ok()); assert!(loaded.as_ref().unwrap().is_some()); assert_eq!(loaded.unwrap().unwrap(), tally); // different contract but same epoch should return none - let loaded = store.load_epoch_tally( + let loaded = load_epoch_tally( + mock_deps.as_ref().storage, PoolId { chain_name: "mock-chain".parse().unwrap(), contract: Addr::unchecked("different contract"), @@ -710,12 +688,12 @@ mod test { assert!(loaded.unwrap().is_none()); // different epoch but same contract should return none - let loaded = store.load_epoch_tally(pool_id.clone(), epoch_num + 1); + let loaded = load_epoch_tally(mock_deps.as_ref().storage, pool_id.clone(), epoch_num + 1); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); // different epoch and different contract should return none - let loaded = store.load_epoch_tally(pool_id.clone(), epoch_num + 1); + let loaded = load_epoch_tally(mock_deps.as_ref().storage, pool_id.clone(), epoch_num + 1); assert!(loaded.is_ok()); assert!(loaded.unwrap().is_none()); } @@ -723,24 +701,28 @@ mod test { #[test] fn save_and_load_rewards_pool() { let mut mock_deps = mock_dependencies(); - let mut store = RewardsStore { - storage: &mut mock_deps.storage, - }; let chain_name: ChainName = "mock-chain".parse().unwrap(); - let pool = RewardsPool::new(chain_name.clone(), Addr::unchecked("some contract")); - let res = store.save_rewards_pool(&pool); + let pool = RewardsPool::new(PoolId::new( + chain_name.clone(), + Addr::unchecked("some contract"), + )); + let res = save_rewards_pool(mock_deps.as_mut().storage, &pool); assert!(res.is_ok()); - let loaded = store.load_rewards_pool(pool.id.clone()); + let loaded = load_rewards_pool_or_new(mock_deps.as_ref().storage, pool.id.clone()); assert!(loaded.is_ok()); assert_eq!(loaded.unwrap(), pool); - let loaded = store.load_rewards_pool(PoolId { - chain_name: chain_name.clone(), - contract: Addr::unchecked("a different contract"), - }); + // return new pool when pool is not found + let loaded = load_rewards_pool_or_new( + mock_deps.as_ref().storage, + PoolId { + chain_name: chain_name.clone(), + contract: Addr::unchecked("a different contract"), + }, + ); assert!(loaded.is_ok()); assert!(loaded.as_ref().unwrap().balance.is_zero()); }