diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e69de29..0000000 diff --git a/plugin.cpp b/plugin.cpp index 4f3bb60..bd7673c 100644 --- a/plugin.cpp +++ b/plugin.cpp @@ -1,14 +1,51 @@ #include +struct access_info +{ + cexpr_t* underlying_expr = nullptr; + uint64_t mask; + ea_t ea; + uint8_t shift_value; + + explicit operator bool() const { return underlying_expr != nullptr; } +}; + // makes sure that the immediate / cot_num is on the right hand side -std::pair normalize_binop( cexpr_t* expr ) +inline std::pair normalize_binop( cexpr_t* expr ) { const auto num = expr->find_num_op(); return { expr->theother( num ), num ? num : expr->y }; } +inline void replace_or_delete( cexpr_t* expr, cexpr_t* replacement, bool success ) +{ + if ( !replacement ) + return; + + if ( success ) + expr->replace_by( replacement ); + else + delete replacement; +} + +inline void merge_accesses( cexpr_t*& original, cexpr_t* access, ctype_t op, ea_t ea, tinfo_t type ) +{ + if ( !access ) + return; + + if ( !original ) + original = access; + else + { + original = new cexpr_t( op, original, access ); + original->type = std::move( type ); + original->exflags = 0; + original->ea = ea; + } +} + // used for the allocation of helper names -char* alloc_cstr( const char* str ) +inline char* alloc_cstr( const char* str ) { const auto len = strlen( str ) + 1; auto alloc = hexrays_alloc( len ); @@ -18,7 +55,7 @@ char* alloc_cstr( const char* str ) } // selects (adds memref expr) for the first member that is a struct inside of an union -void select_first_union_field( cexpr_t*& expr ) +inline void select_first_union_field( cexpr_t*& expr ) { if ( !expr->type.is_union() ) return; @@ -42,27 +79,65 @@ void select_first_union_field( cexpr_t*& expr ) } } -// creates a function type of signature `type(bitfield_type, number_type)` -// if creation fails, returns "unknown" type -tinfo_t create_bitfunc( tinfo_t bitfield_type, tinfo_t number_type, tinfo_t type ) +inline cexpr_t* create_bitfield_access( access_info& info, udt_member_t& member, ea_t original_ea, tinfo_t& common_type ) { func_type_data_t data; data.flags = FTI_PURE; - data.rettype = type; - data.push_back( funcarg_t{ .type = bitfield_type } ); - data.push_back( funcarg_t{ .type = number_type } ); + data.rettype = member.size == 1 ? tinfo_t{ BTF_BOOL } : common_type; + data.cc = CM_CC_UNKNOWN; + data.push_back( funcarg_t{ .type = info.underlying_expr->type } ); + data.push_back( funcarg_t{ .type = common_type } ); - tinfo_t t; - if ( t.create_func( data ) ) - return t; + tinfo_t functype; + if ( !functype.create_func( data ) ) + { + msg( "[bitfields] failed to create a bitfield access function type.\n" ); + return nullptr; + } - return tinfo_t{}; + // construct the callable + auto call_fn = new cexpr_t(); + call_fn->op = cot_helper; + call_fn->type = functype; + call_fn->exflags = 0; + call_fn->helper = alloc_cstr( "b" ); + + // construct the call args + auto call_args = new carglist_t( std::move( functype ) ); + + call_args->push_back( carg_t{} ); + auto& arg0 = ( *call_args )[ 0 ]; + static_cast< cexpr_t& >( arg0 ) = *info.underlying_expr; + arg0.ea = info.ea; + + call_args->push_back( carg_t{} ); + auto& arg1 = ( *call_args )[ 1 ]; + arg1.op = cot_helper; + arg1.type = common_type; + arg1.exflags = EXFL_ALONE; + arg1.helper = alloc_cstr( member.name.c_str() ); + + // construct the call / access itself + auto access = new cexpr_t( cot_call, call_fn ); + access->type = member.size == 1 ? tinfo_t{ BTF_BOOL } : common_type; + access->exflags = 0; + access->a = call_args; + access->ea = original_ea; + + return access; +} + +inline uint64_t bitfield_access_mask( udt_member_t& member ) +{ + uint64_t mask = 0; + for ( int i = member.offset; i < member.offset + member.size; ++i ) + mask |= ( 1ull << i ); + return mask; } // executes callback for each member in `type` where its offset coincides with `and_mask`. -// `cmp_mask` is used to calculate whether the bitfield is enabled or not and passed as second arg to the callback. -// If the bitfield is multi-bit, the cmp_mask and boolean passed to the callback has no meaning. -template bool for_each_bitfield( Callback cb, tinfo_t type, uint64_t and_mask, uint64_t cmp_mask ) +// `cmp_mask` is used to calculate enabled bits in the bitfield. +template bool for_each_bitfield( Callback cb, tinfo_t type, uint64_t and_mask ) { udt_member_t member; for ( size_t i = 0; i < 64; ++i ) @@ -80,118 +155,141 @@ template bool for_each_bitfield( Callback cb, tinfo_t type, uint if ( member.offset != i ) continue; - if ( member.size != 1 ) + uint64_t mask = bitfield_access_mask( member ); + if ( member.size != 1 && ( and_mask & mask ) != mask ) { - uint64_t mask = 0; - for ( int i = member.offset; i < member.offset + member.size; ++i ) - mask |= ( 1ull << i ); - - if ( ( and_mask & mask ) != mask ) - { - msg( "[bitfields] bad offset (%ull) and size (%ull) combo of a field for given mask (%ull)\n", member.offset, member.size, and_mask ); - return false; - } + msg( "[bitfields] bad offset (%ull) and size (%ull) combo of a field for given mask (%ull)\n", member.offset, member.size, and_mask ); + return false; } - cb( member, ( cmp_mask & ( 1ull << i ) ) != 0 ); + cb( member ); } return true; } -void handle_comparison( cexpr_t* expr ) +// handles various cases of potential bitfield access. +// * (*(type*)&x >> imm1) & imm2 +// * *(type*)&x & imm +// * HIDWORD(*(type*)&x) +inline access_info unwrap_access( cexpr_t* expr ) { - // (x & imm) == imm - auto [eq, eq_num] = normalize_binop( expr ); - if ( eq->op != cot_band || eq_num->op != cot_num ) - return; + access_info res; + if ( expr->op == cot_band ) + { + auto num = expr->find_num_op(); + if ( !num ) + return res; - auto [band, band_num] = normalize_binop( eq ); + res.mask = num->n->_value; + res.shift_value = 0; + expr = expr->theother( num ); + } + else if ( expr->op == cot_call ) + { + if ( expr->x->op != cot_helper || expr->a->size() != 1 ) + return res; + + constexpr static std::tuple functions[] = { + {"LOBYTE", 0x00'00'00'00'00'00'00'FF, 0 * 8}, + {"LOWORD", 0x00'00'00'00'00'00'FF'FF, 0 * 8}, + {"LODWORD", 0x00'00'00'00'FF'FF'FF'FF, 0 * 8}, + {"HIBYTE", 0xFF'00'00'00'00'00'00'00, 7 * 8}, + {"HIWORD", 0xFF'FF'00'00'00'00'00'00, 6 * 8}, + {"HIDWORD", 0xFF'FF'FF'FF'00'00'00'00, 4 * 8}, + {"BYTE1", 0x00'00'00'00'00'00'FF'00, 1 * 8}, + {"BYTE2", 0x00'00'00'00'00'FF'00'00, 2 * 8}, + {"BYTE3", 0x00'00'00'00'FF'00'00'00, 3 * 8}, + {"BYTE4", 0x00'00'00'FF'00'00'00'00, 4 * 8}, + {"BYTE5", 0x00'00'FF'00'00'00'00'00, 5 * 8}, + {"BYTE6", 0x00'FF'00'00'00'00'00'00, 6 * 8}, + {"WORD1", 0x00'00'00'00'FF'FF'00'00, 2 * 8}, + {"WORD2", 0x00'00'FF'FF'00'00'00'00, 4 * 8}, + }; + + // check if it's one of the functions we care for + auto it = std::ranges::find( functions, expr->x->helper, [ ] ( auto&& func ) { return std::get<0>( func ); } ); + if ( it == std::end( functions ) ) + return res; + + expr = &( *expr->a )[ 0 ]; + res.mask = std::get<1>( *it ); + res.shift_value = std::get<2>( *it ); + } + else + return res; - // unwrap shifts that compiler might generate - // ((x >> imm) & imm) == imm - uint64_t mask_shift_right = 0; - if ( band->op == cot_ushr ) + if ( expr->op == cot_ushr ) { - auto shiftnum = band->find_num_op(); + auto shiftnum = expr->find_num_op(); if ( !shiftnum ) - return; + return res; - mask_shift_right = shiftnum->n->_value; - band = band->theother( shiftnum ); + expr = expr->theother( shiftnum ); + if ( res.shift_value == 0 ) + res.mask <<= shiftnum->n->_value; + + res.shift_value += ( uint8_t ) shiftnum->n->_value; } - if ( band->op != cot_ptr || band_num->op != cot_num - || band->x->op != cot_cast - || band->x->x->op != cot_ref ) - return; + if ( expr->op != cot_ptr || expr->x->op != cot_cast || expr->x->x->op != cot_ref ) + return res; - // original member ref without the `*(type*)&` part - auto orig = band->x->x->x; - auto type = orig->type; + res.underlying_expr = expr->x->x->x; // extract the ea from one of the expression parts for union selection to work // thanks to @RolfRolles for help with making it work - ea_t use_ea = band->x->x->ea; - use_ea = use_ea != BADADDR ? use_ea : band->x->ea; - use_ea = use_ea != BADADDR ? use_ea : band->ea; + ea_t use_ea = expr->x->x->ea; + use_ea = use_ea != BADADDR ? use_ea : expr->x->ea; + use_ea = use_ea != BADADDR ? use_ea : expr->ea; if ( use_ea == BADADDR ) msg( "[bitfields] can't find parent ea - won't be able to save union selection\n" ); + res.ea = use_ea; + + return res; +} + +inline void handle_comparison( cexpr_t* expr ) +{ + auto [eq, eq_num] = normalize_binop( expr ); + if ( eq_num->op != cot_num ) + return; - // invert comparison mask if it's a not equal check - const auto cmp_mask = expr->op == cot_eq ? eq_num->n->_value : ~eq_num->n->_value; + auto info = unwrap_access( eq ); + if ( !info ) + return; cexpr_t* replacement = nullptr; auto success = for_each_bitfield( - [ &, eq_num = eq_num, band_num = band_num ] ( udt_member_t& member, bool enabled ) + [ &, eq_num = eq_num ] ( udt_member_t& member ) { - const auto fret = member.size > 1 ? eq_num->type : tinfo_t{ BTF_BOOL }; - auto ftype = create_bitfunc( orig->type, band_num->type, fret ); - if ( ftype.is_unknown() ) - { - msg( "[bitfields] Failed to create bitflag function prototype\n" ); + // construct the call / access itself + auto access = create_bitfield_access( info, member, expr->ea, eq_num->type ); + if ( !access ) return; - } - - // construct the callable - auto call_fn = new cexpr_t(); - call_fn->op = cot_helper; - call_fn->type = ftype; - call_fn->exflags = 0; - call_fn->helper = alloc_cstr( "b" ); - // construct the call args - auto call_args = new carglist_t( ftype ); - - call_args->push_back( carg_t{} ); - auto& arg0 = ( *call_args )[ 0 ]; - static_cast< cexpr_t& >( arg0 ) = *orig; - arg0.ea = use_ea; - - call_args->push_back( carg_t{} ); - auto& arg1 = ( *call_args )[ 1 ]; - arg1.op = cot_helper; - arg1.type = tinfo_t{ ( type_t ) ( band_num->type.get_size() > 4 ? BTF_UINT64 : BTF_UINT32 ) }; - arg1.exflags = EXFL_ALONE; - arg1.helper = alloc_cstr( member.name.c_str() ); - - // construct the call / access itself - auto access = new cexpr_t( cot_call, call_fn ); - access->type = fret; - access->exflags = 0; - access->a = call_args; - access->ea = expr->ea; + const auto mask = bitfield_access_mask( member ); + const auto value = ( ( eq_num->n->_value << info.shift_value ) & mask ) >> member.offset; // if the flag is multi byte, reconstruct the comparison if ( member.size > 1 ) { - access = new cexpr_t( expr->op, access, new cexpr_t( *eq_num ) ); + auto num = new cnumber_t(); + num->assign( value, access->type.get_size(), member.type.is_signed() ? type_signed : type_unsigned ); + + auto num_expr = new cexpr_t(); + num_expr->op = cot_num; + num_expr->type = access->type; + num_expr->n = num; + num_expr->exflags = 0; + + access = new cexpr_t( expr->op, access, num_expr ); access->type = tinfo_t{ BTF_BOOL }; access->exflags = 0; access->ea = expr->ea; } // otherwise the flag is single bit; if the flag is disabled, add logical not - else if ( !enabled ) + else if ( value ^ ( expr->op == cot_eq ) ) { access = new cexpr_t( cot_lnot, access ); access->type = tinfo_t{ BTF_BOOL }; @@ -199,27 +297,34 @@ void handle_comparison( cexpr_t* expr ) access->ea = expr->ea; } - if ( !replacement ) - replacement = access; - else - { - replacement = new cexpr_t( cot_land, replacement, access ); - replacement->type = tinfo_t{ BTF_BOOL }; - replacement->exflags = 0; - } - }, type, band_num->n->_value << mask_shift_right, cmp_mask << mask_shift_right ); + merge_accesses( replacement, access, cot_land, expr->ea, tinfo_t{ BTF_BOOL } ); + }, info.underlying_expr->type, info.mask ); - if ( replacement ) - { - if ( success ) - expr->replace_by( replacement ); - else - delete expr; - } + replace_or_delete( expr, replacement, success ); +} + +inline void handle_assignment( cexpr_t* expr ) +{ + auto rhs = expr->y; + auto info = unwrap_access( rhs ); + if ( !info ) + return; + + cexpr_t* replacement = nullptr; + auto success = for_each_bitfield( + [ & ] ( udt_member_t& member ) + { + // TODO: for assignment where more than 1 field is being accessed create a new bitfield type for the result + // that would contain the correctly masked and shifted fields + const auto access = create_bitfield_access( info, member, expr->y->ea, expr->x->type ); + merge_accesses( replacement, access, cot_bor, rhs->ea, expr->x->type ); + }, info.underlying_expr->type, info.mask ); + + replace_or_delete( expr->y, replacement, success ); } // match special bit functions -void handle_call( cexpr_t* expr ) +inline void handle_call( cexpr_t* expr ) { constexpr static size_t num_bitmask_funcs = 8; constexpr static std::string_view functions[] = { @@ -243,17 +348,12 @@ void handle_call( cexpr_t* expr ) "_interlockedbittestandset64" }; - // 2 args - if ( expr->a->size() != 2 ) - return; - - // second arg has to be a number - auto& arg1 = ( *expr->a )[ 1 ]; - if ( arg1.op != cot_num ) + // we expect a helper whose name is one of special functions + if ( expr->x->op != cot_helper ) return; - // we expect a helper whose name is one of those functions - if ( expr->x->op != cot_helper ) + // 2 args + if ( expr->a->size() != 2 ) return; // (type*)& is expected for first arg @@ -261,6 +361,11 @@ void handle_call( cexpr_t* expr ) if ( arg0->op != cot_cast || arg0->x->op != cot_ref ) return; + // second arg has to be a number + auto& arg1 = ( *expr->a )[ 1 ]; + if ( arg1.op != cot_num ) + return; + // these functions will reference the union directly, so select a field for a start select_first_union_field( arg0->x->x ); arg0 = arg0->x->x; @@ -273,12 +378,12 @@ void handle_call( cexpr_t* expr ) auto mask = arg1.n->_value; // if it's a bitmask function make the mask 1 << n - if ( std::distance( std::begin( functions ), it ) >= num_bitmask_funcs ) + if ( std::distance( functions, it ) >= num_bitmask_funcs ) mask = ( 1ull << mask ); cexpr_t* replacement = nullptr; bool success = for_each_bitfield( - [ & ] ( udt_member_t& member, bool ) + [ & ] ( udt_member_t& member ) { auto helper = new cexpr_t(); helper->op = cot_helper; @@ -287,27 +392,13 @@ void handle_call( cexpr_t* expr ) helper->exflags = EXFL_ALONE; helper->helper = alloc_cstr( member.name.c_str() ); - if ( !replacement ) - replacement = helper; - else - { - replacement = new cexpr_t( cot_bor, replacement, helper ); - replacement->type = arg1.type; - replacement->ea = arg1.ea; - replacement->exflags = 0; - } - }, arg0->type, mask, ( uint64_t ) -1 ); + merge_accesses( replacement, helper, cot_bor, arg1.ea, arg1.type ); + }, arg0->type, mask ); - if ( replacement ) - { - if ( success ) - arg1.replace_by( replacement ); - else - delete replacement; - } + replace_or_delete( &arg1, replacement, success ); } -auto bitfields_optimizer = hex::hexrays_callback_for( +inline auto bitfields_optimizer = hex::hexrays_callback_for( [ ] ( cfunc_t* cfunc, ctree_maturity_t maturity )->ssize_t { if ( maturity != CMAT_FINAL ) @@ -323,6 +414,8 @@ auto bitfields_optimizer = hex::hexrays_callback_for( handle_comparison( expr ); else if ( expr->op == cot_call ) handle_call( expr ); + else if ( expr->op == cot_asg ) + handle_assignment( expr ); return 0; }