diff --git a/pyth-sdk-solana/src/state.rs b/pyth-sdk-solana/src/state.rs index 6202ba6..cb9865b 100644 --- a/pyth-sdk-solana/src/state.rs +++ b/pyth-sdk-solana/src/state.rs @@ -1,13 +1,23 @@ //! Structures and functions for interacting with Solana on-chain account data. +//! +//! NOTE(2023-05-12): enums defined here use u32 corresponding with +//! uint32_t that's currently used in pyth-client's oracle.h struct +//! definitions. Enum correctness is validated with bytemuck's checked +//! casting functions and derive(CheckedBitPattern) on the relevant +//! enums. use borsh::{ BorshDeserialize, BorshSerialize, }; -use bytemuck::{ +use bytemuck::checked::{ cast_slice, from_bytes, try_cast_slice, + CheckedCastError, +}; +use bytemuck::{ + CheckedBitPattern, Pod, PodCastError, Zeroable, @@ -46,10 +56,11 @@ pub const PROD_ATTR_SIZE: usize = PROD_ACCT_SIZE - PROD_HDR_SIZE; BorshDeserialize, serde::Serialize, serde::Deserialize, + CheckedBitPattern, )] -#[repr(C)] +#[repr(u32)] pub enum AccountType { - Unknown, + Unknown = 0, Mapping, Product, Price, @@ -73,10 +84,11 @@ impl Default for AccountType { BorshDeserialize, serde::Serialize, serde::Deserialize, + CheckedBitPattern, )] -#[repr(C)] +#[repr(u32)] pub enum CorpAction { - NoCorpAct, + NoCorpAct = 0, } impl Default for CorpAction { @@ -97,10 +109,11 @@ impl Default for CorpAction { BorshDeserialize, serde::Serialize, serde::Deserialize, + CheckedBitPattern, )] -#[repr(C)] +#[repr(u32)] pub enum PriceType { - Unknown, + Unknown = 0, Price, } @@ -122,11 +135,12 @@ impl Default for PriceType { BorshDeserialize, serde::Serialize, serde::Deserialize, + CheckedBitPattern, )] -#[repr(C)] +#[repr(u32)] pub enum PriceStatus { /// The price feed is not currently updating for an unknown reason. - Unknown, + Unknown = 0, /// The price feed is updating as expected. Trading, /// The price feed is not currently updating because trading in the product has been halted. @@ -410,14 +424,14 @@ impl PriceAccount { } } -fn load(data: &[u8]) -> Result<&T, PodCastError> { +fn load(data: &[u8]) -> Result<&T, CheckedCastError> { let size = size_of::(); if data.len() >= size { Ok(from_bytes(cast_slice::(try_cast_slice( &data[0..size], )?))) } else { - Err(PodCastError::SizeMismatch) + Err(CheckedCastError::PodCastError(PodCastError::SizeMismatch)) } } @@ -502,6 +516,7 @@ fn get_attr_str(buf: &[u8]) -> (&str, &[u8]) { #[cfg(test)] mod test { + use bytemuck::checked::try_from_bytes; use pyth_sdk::{ Identifier, Price, @@ -737,4 +752,26 @@ mod test { assert_eq!(price_account.get_price_no_older_than(&clock, 1), None); } + + /// Ensure that bytemuck::checked::* casting functions accept + /// valid bytes + #[test] + fn test_happy_recognized_price_status() { + let happy_status_bytes = 2u32.to_le_bytes(); + + let happy_status_result = try_from_bytes::(happy_status_bytes.as_slice()); + + assert_eq!(happy_status_result, Ok(&PriceStatus::Halted)); + } + + /// Ensure that bytemuck::checked::* casting functions reject + /// invalid bytes + #[test] + fn test_sad_unrecognized_price_status() { + let sad_status_bytes = 42_000u32.to_le_bytes(); + + let sad_status_result = try_from_bytes::(sad_status_bytes.as_slice()); + + assert!(sad_status_result.is_err()); + } }