Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PasswordError to handle error correctly on python side #53

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bittensor_wallet/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

ConfigurationError = _.ConfigurationError
KeyFileError = _.KeyFileError
PasswordError = _.PasswordError
34 changes: 33 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use std::{error, fmt};

// KeyFileError
#[pyclass(extends=PyException)]
#[derive(Debug)]
pub struct KeyFileError {
pub message: String,
}

/// Error thrown when the keyfile is corrupt, non-writable, non-readable or the password used to decrypt is invalid.
/// Error thrown when the keyfile is corrupt, non-writable, non-readable.
#[pymethods]
impl KeyFileError {
#[new]
Expand All @@ -31,6 +32,7 @@ impl fmt::Display for KeyFileError {

impl error::Error for KeyFileError {}

// ConfigurationError
#[pyclass(extends=PyException)]
#[derive(Debug)]
pub struct ConfigurationError {
Expand Down Expand Up @@ -59,3 +61,33 @@ impl fmt::Display for ConfigurationError {
}

impl error::Error for ConfigurationError {}

// PasswordError
#[pyclass(extends=PyException)]
#[derive(Debug)]
pub struct PasswordError {
pub message: String,
}

/// PasswordError occurs if the password used for decryption is invalid.
#[pymethods]
impl PasswordError {
#[new]
#[pyo3(signature = (message=None))]
pub fn new(message: Option<String>) -> Self {
let msg = message.unwrap_or_default();
PasswordError { message: msg }
}

pub fn __str__(&self) -> String {
self.message.clone()
}
}

impl fmt::Display for PasswordError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "PasswordError: {}", self.message)
}
}

impl error::Error for PasswordError {}
12 changes: 6 additions & 6 deletions src/keyfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use passwords::analyzer;
use passwords::scorer;
use serde_json::json;

use crate::errors::KeyFileError;
use crate::errors::{KeyFileError, PasswordError};
use crate::keypair::Keypair;
use crate::utils;

Expand Down Expand Up @@ -414,7 +414,7 @@ pub fn decrypt_keyfile_data(
.ok_or(PyErr::new::<PyRuntimeError, _>("Invalid nonce."))?;
let ciphertext = &data[secretbox::NONCEBYTES..];
secretbox::open(ciphertext, &nonce, key)
.map_err(|_| PyErr::new::<KeyFileError, _>("Wrong password for nacl decryption."))
.map_err(|_| PyErr::new::<PasswordError, _>("Wrong password for nacl decryption."))
}

// decrypt of keyfile_data with legacy way
Expand All @@ -428,7 +428,7 @@ pub fn decrypt_keyfile_data(
let keyfile_data_str = from_utf8(keyfile_data)?;
fernet
.decrypt(keyfile_data_str)
.map_err(|_| PyErr::new::<KeyFileError, _>("Wrong password for nacl decryption."))
.map_err(|_| PyErr::new::<PasswordError, _>("Wrong password for legacy decryption."))
}

let mut password = password;
Expand All @@ -453,21 +453,21 @@ pub fn decrypt_keyfile_data(
if keyfile_data_is_encrypted_nacl(py, keyfile_data)? {
let key = derive_key(password.as_bytes());
let decrypted_data = nacl_decrypt(keyfile_data, &key)
.map_err(|_| PyErr::new::<KeyFileError, _>("Wrong password for decryption."))?;
.map_err(|_| PyErr::new::<PasswordError, _>("Wrong password for decryption."))?;
return Ok(PyBytes::new_bound(py, &decrypted_data).into_py(py));
}

// Ansible Vault decryption
if keyfile_data_is_encrypted_ansible(py, keyfile_data)? {
let decrypted_data = decrypt_vault(keyfile_data, password.as_str())
.map_err(|_| PyErr::new::<KeyFileError, _>("Wrong password for decryption."))?;
.map_err(|_| PyErr::new::<PasswordError, _>("Wrong password for decryption."))?;
return Ok(PyBytes::new_bound(py, &decrypted_data).into_py(py));
}

// Legacy decryption
if keyfile_data_is_encrypted_legacy(py, keyfile_data)? {
let decrypted_data = legacy_decrypt(&password, keyfile_data)
.map_err(|_| PyErr::new::<KeyFileError, _>("Wrong password for decryption."))?;
.map_err(|_| PyErr::new::<PasswordError, _>("Wrong password for decryption."))?;
return Ok(PyBytes::new_bound(py, &decrypted_data).into_py(py));
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ fn register_errors_module(main_module: &Bound<'_, PyModule>) -> PyResult<()> {
let errors_module = PyModule::new_bound(main_module.py(), "errors")?;
errors_module.add_class::<errors::ConfigurationError>()?;
errors_module.add_class::<errors::KeyFileError>()?;
errors_module.add_class::<errors::PasswordError>()?;
main_module.add_submodule(&errors_module)
}

Expand Down
18 changes: 12 additions & 6 deletions src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ pub fn display_mnemonic_msg(mnemonic: String, key_type: &str) {
}

// Function to safely retrieve attribute as Option<String> from passed python object
fn get_attribute_string(py: Python, obj: &Bound<PyAny>, attr_name: &str) -> PyResult<Option<String>> {
fn get_attribute_string(
py: Python,
obj: &Bound<PyAny>,
attr_name: &str,
) -> PyResult<Option<String>> {
match obj.getattr(attr_name) {
Ok(attr) => {
if attr.is_none() {
Expand Down Expand Up @@ -82,9 +86,8 @@ impl Wallet {
hotkey: Option<String>,
path: Option<String>,
config: Option<PyObject>,
py: Python
py: Python,
) -> PyResult<Wallet> {

// default config's values if config and config.wallet exist
let mut conf_name: Option<String> = None;
let mut conf_hotkey: Option<String> = None;
Expand Down Expand Up @@ -128,7 +131,10 @@ impl Wallet {
let final_path = if let Some(path) = path {
path
} else if let Some(conf_path) = conf_path {
conf_path.strip_prefix("~/").unwrap_or(&conf_path).to_string()
conf_path
.strip_prefix("~/")
.unwrap_or(&conf_path)
.to_string()
} else {
BT_WALLET_PATH.to_string()
};
Expand Down Expand Up @@ -183,8 +189,8 @@ impl Wallet {
env::var("BT_WALLET_NAME").unwrap_or_else(|_| BT_WALLET_NAME.to_string());
let default_hotkey =
env::var("BT_WALLET_HOTKEY").unwrap_or_else(|_| BT_WALLET_HOTKEY.to_string());
let default_path =
env::var("BT_WALLET_PATH").unwrap_or_else(|_| format!("~/{}", BT_WALLET_PATH.to_string()));
let default_path = env::var("BT_WALLET_PATH")
.unwrap_or_else(|_| format!("~/{}", BT_WALLET_PATH.to_string()));

let prefix_str = if let Some(value) = prefix {
format!("\"{}\"", value)
Expand Down
Loading