diff --git a/src/common/wallet.c b/src/common/wallet.c index 5821ff011..6b956ab13 100644 --- a/src/common/wallet.c +++ b/src/common/wallet.c @@ -424,7 +424,12 @@ int parse_policy_map_key_info(buffer_t *buffer, policy_map_key_info_t *out, int return 0; } -static int parse_placeholder(buffer_t *in_buf, int version, policy_node_key_placeholder_t *out) { +// parses a placeholder from in_buf, storing it in out. On success, the pointed placeholder_index is +// stored in out->placeholder_index, and then it's incremented. +static int parse_placeholder(buffer_t *in_buf, + int version, + policy_node_key_placeholder_t *out, + uint16_t *placeholder_index) { char c; if (!buffer_read_u8(in_buf, (uint8_t *) &c) || c != '@') { return WITH_ERROR(-1, "Expected key placeholder starting with '@'"); @@ -489,6 +494,8 @@ static int parse_placeholder(buffer_t *in_buf, int version, policy_node_key_plac return WITH_ERROR(-1, "Invalid version number"); } + out->placeholder_index = *placeholder_index; + ++(*placeholder_index); return 0; } @@ -544,6 +551,15 @@ static int parse_script(buffer_t *in_buf, unsigned int context_flags) { int n_wrappers = 0; + // Keep track of how many key placeholders have been created while parsing + // This allows to know the counter even in recursive calls + static uint16_t key_placeholder_count = 0; + + if (depth == 0) { + // reset the counter on function entry, but not in recursive calls + key_placeholder_count = 0; + } + policy_node_t *outermost_node = (policy_node_t *) buffer_get_cur(out_buf); policy_node_with_script_t *inner_wrapper = NULL; // pointer to the inner wrapper, if any @@ -1396,7 +1412,7 @@ static int parse_script(buffer_t *in_buf, node->base.type = token; - if (0 > parse_placeholder(in_buf, version, key_placeholder)) { + if (0 > parse_placeholder(in_buf, version, key_placeholder, &key_placeholder_count)) { return WITH_ERROR(-1, "Couldn't parse key placeholder"); } @@ -1466,7 +1482,7 @@ static int parse_script(buffer_t *in_buf, } i_policy_node_key_placeholder(&node->key_placeholder, key_placeholder); - if (0 > parse_placeholder(in_buf, version, key_placeholder)) { + if (0 > parse_placeholder(in_buf, version, key_placeholder, &key_placeholder_count)) { return WITH_ERROR(-1, "Couldn't parse key placeholder"); } @@ -1606,7 +1622,8 @@ static int parse_script(buffer_t *in_buf, return WITH_ERROR(-1, "Out of memory"); } - if (0 > parse_placeholder(in_buf, version, key_placeholder)) { + if (0 > + parse_placeholder(in_buf, version, key_placeholder, &key_placeholder_count)) { return WITH_ERROR(-1, "Error parsing key placeholder"); } diff --git a/src/common/wallet.h b/src/common/wallet.h index 5435292f7..219111ad1 100644 --- a/src/common/wallet.h +++ b/src/common/wallet.h @@ -300,6 +300,8 @@ typedef struct { // common between V1 and V2 int16_t key_index; // index of the key + int16_t + placeholder_index; // index of the placeholder in the descriptor template, in parsing order } policy_node_key_placeholder_t; DEFINE_REL_PTR(policy_node_key_placeholder, policy_node_key_placeholder_t) diff --git a/src/crypto.c b/src/crypto.c index 37486988e..ccb95cfe6 100644 --- a/src/crypto.c +++ b/src/crypto.c @@ -78,6 +78,28 @@ static const uint8_t BIP0341_taptweak_tag[] = {'T', 'a', 'p', 'T', 'w', 'e', 'a' static const uint8_t BIP0341_tapbranch_tag[] = {'T', 'a', 'p', 'B', 'r', 'a', 'n', 'c', 'h'}; static const uint8_t BIP0341_tapleaf_tag[] = {'T', 'a', 'p', 'L', 'e', 'a', 'f'}; +// Copy of cx_ecfp_scalar_mult_no_throw, but without using randomization for the scalar +// multiplication. Therefore, it is faster, but not safe to use on private data, as it is vulnerable +// to timing attacks. +cx_err_t cx_ecfp_scalar_mult_unsafe(cx_curve_t curve, uint8_t *P, const uint8_t *k, size_t k_len) { + size_t size; + cx_ecpoint_t ecP; + cx_err_t error; + + CX_CHECK(cx_ecdomain_parameters_length(curve, &size)); + CX_CHECK(cx_bn_lock(size, 0)); + + CX_CHECK(cx_ecpoint_alloc(&ecP, curve)); + CX_CHECK(cx_ecpoint_init(&ecP, P + 1, size, P + 1 + size, size)); + CX_CHECK(cx_ecpoint_scalarmul(&ecP, k, k_len)); + P[0] = 0x04; + CX_CHECK(cx_ecpoint_export(&ecP, &P[1], size, &P[1 + size], size)); + +end: + cx_bn_unlock(); + return error; +} + /** * Gets the point on the SECP256K1 that corresponds to kG, where G is the curve's generator point. * Returns -1 if point is Infinity or any error occurs; 0 otherwise. @@ -88,6 +110,16 @@ static int secp256k1_point(const uint8_t k[static 32], uint8_t out[static 65]) { return 0; } +/** + * Equivalent to secp256k1_point, but it does not use randomization; it is faster, but only to be + * used with public data, as it is vulnerable to timing attacks. + */ +static int secp256k1_point_unsafe(const uint8_t k[static 32], uint8_t out[static 65]) { + memcpy(out, secp256k1_generator, 65); + if (CX_OK != cx_ecfp_scalar_mult_unsafe(CX_CURVE_SECP256K1, out, k, 32)) return -1; + return 0; +} + int bip32_CKDpub(const serialized_extended_pubkey_t *parent, uint32_t index, serialized_extended_pubkey_t *child) { @@ -126,7 +158,9 @@ int bip32_CKDpub(const serialized_extended_pubkey_t *parent, { // make sure that heavy memory allocations are freed as soon as possible // compute point(I_L) uint8_t P[65]; - if (0 > secp256k1_point(I_L, P)) return -1; + // as the arguments of bip32_CKDpub are public keys, we do not need to use math functions + // hardened against side channels attacks, which are slower + if (0 > secp256k1_point_unsafe(I_L, P)) return -1; uint8_t K_par[65]; crypto_get_uncompressed_pubkey(parent->compressed_pubkey, K_par); @@ -543,7 +577,9 @@ int crypto_tr_tweak_pubkey(const uint8_t pubkey[static 32], return -1; } - if (0 > secp256k1_point(t, Q)) { + // as the arguments of bip32_CKDpub are public keys, we do not need to use math functions + // hardened against side channels attacks, which are slower + if (0 > secp256k1_point_unsafe(t, Q)) { // point at infinity, or error return -1; } diff --git a/src/handler/lib/policy.c b/src/handler/lib/policy.c index c7b7800d1..791484b55 100644 --- a/src/handler/lib/policy.c +++ b/src/handler/lib/policy.c @@ -469,10 +469,16 @@ __attribute__((warn_unused_result)) static int get_derived_pubkey( // we derive the // child of this pubkey // we reuse the same memory of ext_pubkey - bip32_CKDpub(&ext_pubkey, - wdi->change ? key_placeholder->num_second : key_placeholder->num_first, - &ext_pubkey); - bip32_CKDpub(&ext_pubkey, wdi->address_index, &ext_pubkey); + if (0 > derive_first_step_for_pubkey(&ext_pubkey, + key_placeholder, + wdi->sign_psbt_cache, + wdi->change, + &ext_pubkey)) { + return -1; + } + if (0 > bip32_CKDpub(&ext_pubkey, wdi->address_index, &ext_pubkey)) { + return -1; + } memcpy(out, ext_pubkey.compressed_pubkey, 33); diff --git a/src/handler/lib/policy.h b/src/handler/lib/policy.h index 121560ce4..ad9f897b3 100644 --- a/src/handler/lib/policy.h +++ b/src/handler/lib/policy.h @@ -2,6 +2,7 @@ #include "../../boilerplate/dispatcher.h" #include "../../common/wallet.h" +#include "../../handler/sign_psbt/sign_psbt_cache.h" /** * Parses a serialized wallet policy, saving the wallet header, the policy map descriptor and the @@ -48,6 +49,8 @@ typedef struct { uint32_t n_keys; // The number of key information placeholders in the policy size_t address_index; // The address index to use in the derivation bool change; // whether a change address or a receive address is derived + sign_psbt_cache_t + *sign_psbt_cache; // If not NULL, the cache for key derivations used during signing } wallet_derivation_info_t; /** diff --git a/src/handler/sign_psbt.c b/src/handler/sign_psbt.c index 6cadf10c0..1736c912a 100644 --- a/src/handler/sign_psbt.c +++ b/src/handler/sign_psbt.c @@ -48,6 +48,7 @@ #include "handlers.h" +#include "sign_psbt/sign_psbt_cache.h" #include "sign_psbt/compare_wallet_script_at_path.h" #include "sign_psbt/extract_bip32_derivation.h" #include "sign_psbt/update_hashes_with_map_value.h" @@ -373,6 +374,7 @@ static int read_change_and_index_from_psbt_bip32_derivation( dispatcher_context_t *dc, placeholder_info_t *placeholder_info, in_out_info_t *in_out, + sign_psbt_cache_t *sign_psbt_cache, int psbt_key_type, buffer_t *data, const merkleized_map_commitment_t *map_commitment, @@ -424,30 +426,35 @@ static int read_change_and_index_from_psbt_bip32_derivation( } } - uint32_t change = fpt_der[1 + der_len - 2]; + uint32_t change_step = fpt_der[1 + der_len - 2]; uint32_t addr_index = fpt_der[1 + der_len - 1]; - // check that we can indeed derive the same key from the current placeholder - serialized_extended_pubkey_t pubkey; - if (0 > bip32_CKDpub(&placeholder_info->pubkey, change, &pubkey)) return -1; - if (0 > bip32_CKDpub(&pubkey, addr_index, &pubkey)) return -1; - - int pk_offset = is_tap ? 1 : 0; - if (memcmp(pubkey.compressed_pubkey + pk_offset, bip32_derivation_pubkey, key_len) != 0) { - return 0; - } - // check if the 'change' derivation step is indeed coherent with placeholder - if (change == placeholder_info->placeholder.num_first) { + if (change_step == placeholder_info->placeholder.num_first) { in_out->is_change = false; in_out->address_index = addr_index; - } else if (change == placeholder_info->placeholder.num_second) { + } else if (change_step == placeholder_info->placeholder.num_second) { in_out->is_change = true; in_out->address_index = addr_index; } else { return 0; } + // check that we can indeed derive the same key from the current placeholder + serialized_extended_pubkey_t pubkey; + if (0 > derive_first_step_for_pubkey(&placeholder_info->pubkey, + &placeholder_info->placeholder, + sign_psbt_cache, + in_out->is_change, + &pubkey)) + return -1; + if (0 > bip32_CKDpub(&pubkey, addr_index, &pubkey)) return -1; + + int pk_offset = is_tap ? 1 : 0; + if (memcmp(pubkey.compressed_pubkey + pk_offset, bip32_derivation_pubkey, key_len) != 0) { + return 0; + } + in_out->placeholder_found = true; return 1; } @@ -463,6 +470,7 @@ static int read_change_and_index_from_psbt_bip32_derivation( */ static int is_in_out_internal(dispatcher_context_t *dispatcher_context, const sign_psbt_state_t *state, + sign_psbt_cache_t *sign_psbt_cache, const in_out_info_t *in_out_info, bool is_input) { // If we did not find any info about the pubkey associated to the placeholder we're considering, @@ -477,6 +485,7 @@ static int is_in_out_internal(dispatcher_context_t *dispatcher_context, } return compare_wallet_script_at_path(dispatcher_context, + sign_psbt_cache, in_out_info->is_change, in_out_info->address_index, state->wallet_policy_map, @@ -751,6 +760,7 @@ static bool find_first_internal_key_placeholder(dispatcher_context_t *dc, typedef struct { placeholder_info_t *placeholder_info; input_info_t *input; + sign_psbt_cache_t *sign_psbt_cache; } input_keys_callback_data_t; /** @@ -781,6 +791,7 @@ static void input_keys_callback(dispatcher_context_t *dc, read_change_and_index_from_psbt_bip32_derivation(dc, callback_data->placeholder_info, &callback_data->input->in_out, + callback_data->sign_psbt_cache, key_type, data, map_commitment, @@ -794,6 +805,7 @@ static void input_keys_callback(dispatcher_context_t *dc, static bool __attribute__((noinline)) preprocess_inputs(dispatcher_context_t *dc, sign_psbt_state_t *st, + sign_psbt_cache_t *sign_psbt_cache, uint8_t internal_inputs[static BITVECTOR_REAL_SIZE(MAX_N_INPUTS_CAN_SIGN)]) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); @@ -810,7 +822,8 @@ preprocess_inputs(dispatcher_context_t *dc, memset(&input, 0, sizeof(input)); input_keys_callback_data_t callback_data = {.input = &input, - .placeholder_info = &placeholder_info}; + .placeholder_info = &placeholder_info, + .sign_psbt_cache = sign_psbt_cache}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -908,7 +921,7 @@ preprocess_inputs(dispatcher_context_t *dc, // check if the input is internal; if not, continue - int is_internal = is_in_out_internal(dc, st, &input.in_out, true); + int is_internal = is_in_out_internal(dc, st, sign_psbt_cache, &input.in_out, true); if (is_internal < 0) { PRINTF("Error checking if input %d is internal\n", cur_input_index); SEND_SW(dc, SW_INCORRECT_DATA); @@ -1007,6 +1020,7 @@ preprocess_inputs(dispatcher_context_t *dc, typedef struct { placeholder_info_t *placeholder_info; output_info_t *output; + sign_psbt_cache_t *sign_psbt_cache; } output_keys_callback_data_t; /** @@ -1029,6 +1043,7 @@ static void output_keys_callback(dispatcher_context_t *dc, read_change_and_index_from_psbt_bip32_derivation(dc, callback_data->placeholder_info, &callback_data->output->in_out, + callback_data->sign_psbt_cache, key_type, data, map_commitment, @@ -1042,6 +1057,7 @@ static void output_keys_callback(dispatcher_context_t *dc, static bool __attribute__((noinline)) preprocess_outputs(dispatcher_context_t *dc, sign_psbt_state_t *st, + sign_psbt_cache_t *sign_psbt_cache, uint8_t internal_outputs[static BITVECTOR_REAL_SIZE(MAX_N_OUTPUTS_CAN_SIGN)]) { /** OUTPUTS VERIFICATION FLOW * @@ -1067,7 +1083,8 @@ preprocess_outputs(dispatcher_context_t *dc, memset(&output, 0, sizeof(output)); output_keys_callback_data_t callback_data = {.output = &output, - .placeholder_info = &placeholder_info}; + .placeholder_info = &placeholder_info, + .sign_psbt_cache = sign_psbt_cache}; int res = call_get_merkleized_map_with_callback( dc, (void *) &callback_data, @@ -1122,7 +1139,7 @@ preprocess_outputs(dispatcher_context_t *dc, output.in_out.scriptPubKey_len = result_len; - int is_internal = is_in_out_internal(dc, st, &output.in_out, false); + int is_internal = is_in_out_internal(dc, st, sign_psbt_cache, &output.in_out, false); if (is_internal < 0) { PRINTF("Error checking if output %d is internal\n", cur_output_index); @@ -2423,6 +2440,7 @@ compute_segwit_hashes(dispatcher_context_t *dc, sign_psbt_state_t *st, segwit_ha static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_t *dc, sign_psbt_state_t *st, + sign_psbt_cache_t *sign_psbt_cache, segwit_hashes_t *hashes, placeholder_info_t *placeholder_info, input_info_t *input, @@ -2573,7 +2591,8 @@ static bool __attribute__((noinline)) sign_transaction_input(dispatcher_context_ .change = input->in_out.is_change ? 1 : 0, .keys_merkle_root = st->wallet_header.keys_info_merkle_root, .n_keys = st->wallet_header.n_keys, - .wallet_version = st->wallet_header.version}, + .wallet_version = st->wallet_header.version, + .sign_psbt_cache = sign_psbt_cache}, r_policy_node_tree(&policy->tree), input->taptree_hash)) { PRINTF("Error while computing taptree hash\n"); @@ -2603,22 +2622,22 @@ fill_taproot_placeholder_info(dispatcher_context_t *dc, sign_psbt_state_t *st, const input_info_t *input, const policy_node_t *tapleaf_ptr, - placeholder_info_t *placeholder_info) { + placeholder_info_t *placeholder_info, + sign_psbt_cache_t *sign_psbt_cache) { cx_sha256_t hash_context; crypto_tr_tapleaf_hash_init(&hash_context); - // we compute the tapscript once just to compute its length - // this avoids having to store it - int tapscript_len = get_wallet_internal_script_hash( - dc, - tapleaf_ptr, - &(wallet_derivation_info_t){.wallet_version = st->wallet_header.version, + wallet_derivation_info_t wdi = {.wallet_version = st->wallet_header.version, .keys_merkle_root = st->wallet_header.keys_info_merkle_root, .n_keys = st->wallet_header.n_keys, .change = input->in_out.is_change, - .address_index = input->in_out.address_index}, - WRAPPED_SCRIPT_TYPE_TAPSCRIPT, - NULL); + .address_index = input->in_out.address_index, + .sign_psbt_cache = sign_psbt_cache}; + + // we compute the tapscript once just to compute its length + // this avoids having to store it + int tapscript_len = + get_wallet_internal_script_hash(dc, tapleaf_ptr, &wdi, WRAPPED_SCRIPT_TYPE_TAPSCRIPT, NULL); if (tapscript_len < 0) { PRINTF("Failed to compute tapleaf script\n"); return false; @@ -2628,17 +2647,11 @@ fill_taproot_placeholder_info(dispatcher_context_t *dc, crypto_hash_update_varint(&hash_context.header, tapscript_len); // we compute it again to get add the actual script code to the hash computation - if (0 > - get_wallet_internal_script_hash( - dc, - tapleaf_ptr, - &(wallet_derivation_info_t){.wallet_version = st->wallet_header.version, - .keys_merkle_root = st->wallet_header.keys_info_merkle_root, - .n_keys = st->wallet_header.n_keys, - .change = input->in_out.is_change, - .address_index = input->in_out.address_index}, - WRAPPED_SCRIPT_TYPE_TAPSCRIPT, - &hash_context.header)) { + if (0 > get_wallet_internal_script_hash(dc, + tapleaf_ptr, + &wdi, + WRAPPED_SCRIPT_TYPE_TAPSCRIPT, + &hash_context.header)) { return false; // should never happen! } crypto_hash_digest(&hash_context.header, placeholder_info->tapleaf_hash, 32); @@ -2649,6 +2662,7 @@ fill_taproot_placeholder_info(dispatcher_context_t *dc, static bool __attribute__((noinline)) sign_transaction(dispatcher_context_t *dc, sign_psbt_state_t *st, + sign_psbt_cache_t *sign_psbt_cache, const uint8_t internal_inputs[static BITVECTOR_REAL_SIZE(MAX_N_INPUTS_CAN_SIGN)]) { LOG_PROCESSOR(__FILE__, __LINE__, __func__); @@ -2717,10 +2731,17 @@ sign_transaction(dispatcher_context_t *dc, st, &input, tapleaf_ptr, - &placeholder_info)) + &placeholder_info, + sign_psbt_cache)) return false; - if (!sign_transaction_input(dc, st, &hashes, &placeholder_info, &input, i)) { + if (!sign_transaction_input(dc, + st, + sign_psbt_cache, + &hashes, + &placeholder_info, + &input, + i)) { // we do not send a status word, since sign_transaction_input // already does it on failure return false; @@ -2745,6 +2766,9 @@ void handler_sign_psbt(dispatcher_context_t *dc, uint8_t protocol_version) { // read APDU inputs, intialize global state and read global PSBT map if (!init_global_state(dc, &st)) return; + sign_psbt_cache_t cache; + init_sign_psbt_cache(&cache); + // bitmap to keep track of which inputs are internal uint8_t internal_inputs[BITVECTOR_REAL_SIZE(MAX_N_INPUTS_CAN_SIGN)]; memset(internal_inputs, 0, sizeof(internal_inputs)); @@ -2761,14 +2785,14 @@ void handler_sign_psbt(dispatcher_context_t *dc, uint8_t protocol_version) { * - detect internal inputs that should be signed, and if there are external inputs or unusual * sighashes */ - if (!preprocess_inputs(dc, &st, internal_inputs)) return; + if (!preprocess_inputs(dc, &st, &cache, internal_inputs)) return; /** OUTPUTS VERIFICATION FLOW * * For each output, check if it's a change address. * Check if it's an acceptable output. */ - if (!preprocess_outputs(dc, &st, internal_outputs)) return; + if (!preprocess_outputs(dc, &st, &cache, internal_outputs)) return; if (G_swap_state.called_from_swap) { /** SWAP CHECKS @@ -2794,7 +2818,7 @@ void handler_sign_psbt(dispatcher_context_t *dc, uint8_t protocol_version) { * For each internal placeholder, and for each internal input, sign using the * appropriate algorithm. */ - int sign_result = sign_transaction(dc, &st, internal_inputs); + int sign_result = sign_transaction(dc, &st, &cache, internal_inputs); if (!G_swap_state.called_from_swap) { ui_post_processing_confirm_transaction(dc, sign_result); diff --git a/src/handler/sign_psbt/compare_wallet_script_at_path.c b/src/handler/sign_psbt/compare_wallet_script_at_path.c index 08f2c2f8d..fffc008b9 100644 --- a/src/handler/sign_psbt/compare_wallet_script_at_path.c +++ b/src/handler/sign_psbt/compare_wallet_script_at_path.c @@ -9,6 +9,7 @@ #include "../../common/read.h" int compare_wallet_script_at_path(dispatcher_context_t *dispatcher_context, + sign_psbt_cache_t *sign_psbt_cache, uint32_t change, uint32_t address_index, const policy_node_t *policy, @@ -28,7 +29,8 @@ int compare_wallet_script_at_path(dispatcher_context_t *dispatcher_context, .keys_merkle_root = keys_merkle_root, .n_keys = n_keys, .change = change, - .address_index = address_index}, + .address_index = address_index, + .sign_psbt_cache = sign_psbt_cache}, wallet_script); if (wallet_script_len < 0) { PRINTF("Failed to get wallet script\n"); diff --git a/src/handler/sign_psbt/compare_wallet_script_at_path.h b/src/handler/sign_psbt/compare_wallet_script_at_path.h index cdcbee1b7..2940ebf0c 100644 --- a/src/handler/sign_psbt/compare_wallet_script_at_path.h +++ b/src/handler/sign_psbt/compare_wallet_script_at_path.h @@ -3,11 +3,13 @@ #include "../../boilerplate/dispatcher.h" #include "../../common/merkle.h" #include "../../common/wallet.h" +#include "../../handler/sign_psbt/sign_psbt_cache.h" /** * TODO */ int compare_wallet_script_at_path(dispatcher_context_t *dispatcher_context, + sign_psbt_cache_t *sign_psbt_cache, uint32_t change, uint32_t address_index, const policy_node_t *policy, diff --git a/src/handler/sign_psbt/sign_psbt_cache.c b/src/handler/sign_psbt/sign_psbt_cache.c new file mode 100644 index 000000000..c89f17af2 --- /dev/null +++ b/src/handler/sign_psbt/sign_psbt_cache.c @@ -0,0 +1,35 @@ +#include "sign_psbt_cache.h" + +int derive_first_step_for_pubkey(const serialized_extended_pubkey_t *base_key, + const policy_node_key_placeholder_t *placeholder, + sign_psbt_cache_t *cache, + bool is_change, + serialized_extended_pubkey_t *out_pubkey) { + uint32_t change_step = is_change ? placeholder->num_second : placeholder->num_first; + + // make sure a cache was provided, and the index is less than the size of the cache + if (placeholder->placeholder_index >= MAX_CACHED_KEY_EXPRESSIONS || !cache) { + // do not use the cache, derive the key directly + return bip32_CKDpub(base_key, change_step, out_pubkey); + } + + if (!cache->derived_child[placeholder->placeholder_index] + .is_child_pubkey_initialized[is_change]) { + // key not in cache; compute it and store it in the cache + if (0 > bip32_CKDpub( + base_key, + change_step, + &cache->derived_child[placeholder->placeholder_index].child_pubkeys[is_change])) + return -1; + + cache->derived_child[placeholder->placeholder_index] + .is_child_pubkey_initialized[is_change] = true; + } + + // now that we are guaranteed that the key is in cache, we just copy it + memcpy(out_pubkey, + &cache->derived_child[placeholder->placeholder_index].child_pubkeys[is_change], + sizeof(serialized_extended_pubkey_t)); + + return 0; +} diff --git a/src/handler/sign_psbt/sign_psbt_cache.h b/src/handler/sign_psbt/sign_psbt_cache.h new file mode 100644 index 000000000..e0c7b38d5 --- /dev/null +++ b/src/handler/sign_psbt/sign_psbt_cache.h @@ -0,0 +1,59 @@ +#pragma once + +#include "../crypto.h" +#include "../common/wallet.h" + +// This allows to keep the cache size small, while only paying a performance hit for any extremely +// complicated policy with more than 16 key expressions in total (should that occur in practice). +#define MAX_CACHED_KEY_EXPRESSIONS 16 + +// This structure contains all the information that is deterministically computed during the signing +// flow, and might be accessed multiple times. +// Currently, it only contains the derived child keys of the root keys in the key expressions of the +// wallet policy. +typedef struct sign_psbt_cache_s { + struct { + // 0 for the receiving address, 1 for the change address + bool is_child_pubkey_initialized[2]; + serialized_extended_pubkey_t child_pubkeys[2]; + } derived_child[MAX_CACHED_KEY_EXPRESSIONS]; // 78 * 2 * MAX_CACHED_KEY_EXPRESSIONS bytes +} sign_psbt_cache_t; + +/** + * Initializes the sign_psbt_cache_t structure. + * It must be called before a sign_psbt_cache_t is used. + * + * @param[in] cache Pointer to the cache structure to be initialized. + */ +static inline void init_sign_psbt_cache(sign_psbt_cache_t *cache) { + memset(cache, 0, sizeof(sign_psbt_cache_t)); +} + +/* +Public keys in a wallet policy always have two derivation steps: the first is typically 0 or 1, +while the second step is the address index and is usually not reused in different UTXOs. +Therefore, the inputs (and change addresses) will often share the same first step. +By caching the intermediate pubkeys, we avoid recomputing the same BIP-32 pubkey derivations +multiple times. This is particularly important for transactions with many inputs, as the total +number of BIP-32 derivations is cut almost by half when using the cache. +*/ + +/** + * Derives the first step for a public key in a placeholder, using a precomputed value from the + * cache if available. If the key is not in the cache, it is computed and stored in the cache, + * unless the index is placeholder index is too large. + * + * @param[in] base_key Pointer to the base serialized extended public key. + * @param[in] placeholder Pointer to the policy node key placeholder, which contains derivation + * information. + * @param[in] cache Pointer to the cache structure used to store derived child keys. + * @param[in] is_change true if deriving the change address, false otherwise. + * @param[out] out_pubkey Pointer to the output serialized extended public key. + * + * @return 0 on success, -1 on failure. + */ +int derive_first_step_for_pubkey(const serialized_extended_pubkey_t *base_key, + const policy_node_key_placeholder_t *placeholder, + sign_psbt_cache_t *cache, + bool is_change, + serialized_extended_pubkey_t *out_pubkey); \ No newline at end of file