From 76a043d7986234bdf99b74472d0522816addeba9 Mon Sep 17 00:00:00 2001 From: William Sanville Date: Tue, 11 Feb 2025 14:54:55 -0800 Subject: [PATCH] Evaluate Intrinsics in StringAnalyzer Summary: Logic in this block should be the same for both of these functions; if arguments are known to be non-null strings then this should just work. Reviewed By: NTillmann Differential Revision: D67682098 fbshipit-source-id: be97ad36bd318e17b6797a8c630cfd044c2118a7 --- .../IPConstantPropagation.cpp | 18 +++++++-- .../IPConstantPropagation.h | 3 ++ opt/shrinker/ShrinkerPass.cpp | 8 ++++ opt/shrinker/ShrinkerPass.h | 1 + .../ConstantPropagationAnalysis.cpp | 38 ++++++++++++++++++- .../ConstantPropagationAnalysis.h | 20 ++++++++-- service/method-inliner/CallSiteSummaries.cpp | 3 +- service/shrinker/Shrinker.cpp | 4 +- service/shrinker/Shrinker.h | 1 + .../IPConstantPropagationTest.cpp | 21 +++++----- .../StringPropagationTest.cpp | 18 ++++++--- 11 files changed, 108 insertions(+), 27 deletions(-) diff --git a/opt/constant-propagation/IPConstantPropagation.cpp b/opt/constant-propagation/IPConstantPropagation.cpp index 295949ed665..98c4d74b28e 100644 --- a/opt/constant-propagation/IPConstantPropagation.cpp +++ b/opt/constant-propagation/IPConstantPropagation.cpp @@ -62,6 +62,7 @@ using CombinedAnalyzer = class AnalyzerGenerator { const ImmutableAttributeAnalyzerState* m_immut_analyzer_state; const ApiLevelAnalyzerState* m_api_level_analyzer_state; + const StringAnalyzerState* m_string_analyzer_state; const PackageNameState* m_package_name_state; const State& m_cp_state; @@ -69,10 +70,12 @@ class AnalyzerGenerator { explicit AnalyzerGenerator( const ImmutableAttributeAnalyzerState* immut_analyzer_state, const ApiLevelAnalyzerState* api_level_analyzer_state, + const StringAnalyzerState* string_analyzer_state, const PackageNameState* package_name_state, const State& cp_state) : m_immut_analyzer_state(immut_analyzer_state), m_api_level_analyzer_state(api_level_analyzer_state), + m_string_analyzer_state(string_analyzer_state), m_package_name_state(package_name_state), m_cp_state(cp_state) { // Initialize the singletons that `operator()` needs ahead of time to @@ -114,7 +117,7 @@ class AnalyzerGenerator { CombinedAnalyzer( class_under_init, immut_analyzer_state, wps_accessor_ptr, EnumFieldAnalyzerState::get(), BoxedBooleanAnalyzerState::get(), - nullptr, nullptr, + const_cast(m_string_analyzer_state), nullptr, *const_cast(m_api_level_analyzer_state), const_cast(m_package_name_state), immut_analyzer_state, nullptr), @@ -139,6 +142,7 @@ std::unique_ptr PassImpl::analyze( const Scope& scope, const ImmutableAttributeAnalyzerState* immut_analyzer_state, const ApiLevelAnalyzerState* api_level_analyzer_state, + const StringAnalyzerState* string_analyzer_state, const PackageNameState* package_name_state, const State& cp_state) { auto method_override_graph = mog::build_graph(scope); @@ -161,7 +165,7 @@ std::unique_ptr PassImpl::analyze( auto fp_iter = std::make_unique( cg, AnalyzerGenerator(immut_analyzer_state, api_level_analyzer_state, - package_name_state, cp_state), + string_analyzer_state, package_name_state, cp_state), cg_for_wps); // Run the bootstrap. All field value and method return values are // represented by Top. @@ -299,17 +303,25 @@ void PassImpl::run(const DexStoresVector& stores, immutable_state::analyze_constructors(scope, &immut_analyzer_state); ApiLevelAnalyzerState api_level_analyzer_state = ApiLevelAnalyzerState::get(min_sdk); + auto string_analyzer_state = StringAnalyzerState::get(); auto package_name_state = PackageNameState::get(package_name); State cp_state; auto fp_iter = analyze(scope, &immut_analyzer_state, &api_level_analyzer_state, - &package_name_state, cp_state); + &string_analyzer_state, &package_name_state, cp_state); m_stats.fp_iter = fp_iter->get_stats(); TypeSystem type_system(scope); optimize(scope, type_system, xstores, *fp_iter, &immut_analyzer_state, cp_state); } +void PassImpl::eval_pass(DexStoresVector& stores, + ConfigFiles& conf, + PassManager&) { + auto string_analyzer_state = StringAnalyzerState::get(); + string_analyzer_state.set_methods_as_root(); +} + void PassImpl::run_pass(DexStoresVector& stores, ConfigFiles& config, PassManager& mgr) { diff --git a/opt/constant-propagation/IPConstantPropagation.h b/opt/constant-propagation/IPConstantPropagation.h index 27fec9861de..601bfde84b0 100644 --- a/opt/constant-propagation/IPConstantPropagation.h +++ b/opt/constant-propagation/IPConstantPropagation.h @@ -80,6 +80,8 @@ class PassImpl : public Pass { "they are read, in order to ignore the default value 0."); } + void eval_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; + void run_pass(DexStoresVector& stores, ConfigFiles& conf, PassManager& mgr) override; @@ -100,6 +102,7 @@ class PassImpl : public Pass { const Scope&, const ImmutableAttributeAnalyzerState*, const ApiLevelAnalyzerState*, + const StringAnalyzerState*, const PackageNameState*, const State&); diff --git a/opt/shrinker/ShrinkerPass.cpp b/opt/shrinker/ShrinkerPass.cpp index 1ff19bcfcab..51f31ce23d6 100644 --- a/opt/shrinker/ShrinkerPass.cpp +++ b/opt/shrinker/ShrinkerPass.cpp @@ -8,6 +8,7 @@ #include "ShrinkerPass.h" #include "ConfigFiles.h" +#include "ConstantPropagationAnalysis.h" #include "PassManager.h" #include "ScopedMetrics.h" #include "Shrinker.h" @@ -43,6 +44,13 @@ void ShrinkerPass::bind_config() { "relevant when using constant-propagaation)"); } +void ShrinkerPass::eval_pass(DexStoresVector& stores, + ConfigFiles& conf, + PassManager&) { + auto string_analyzer_state = constant_propagation::StringAnalyzerState::get(); + string_analyzer_state.set_methods_as_root(); +} + void ShrinkerPass::run_pass(DexStoresVector& stores, ConfigFiles& conf, PassManager& mgr) { diff --git a/opt/shrinker/ShrinkerPass.h b/opt/shrinker/ShrinkerPass.h index 95d52a67735..f471a40b75e 100644 --- a/opt/shrinker/ShrinkerPass.h +++ b/opt/shrinker/ShrinkerPass.h @@ -28,6 +28,7 @@ class ShrinkerPass : public Pass { } void bind_config() override; + void eval_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; void run_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; private: diff --git a/service/constant-propagation/ConstantPropagationAnalysis.cpp b/service/constant-propagation/ConstantPropagationAnalysis.cpp index 2241185cb69..6134266dff4 100644 --- a/service/constant-propagation/ConstantPropagationAnalysis.cpp +++ b/service/constant-propagation/ConstantPropagationAnalysis.cpp @@ -994,7 +994,39 @@ bool BoxedBooleanAnalyzer::analyze_invoke( } } -bool StringAnalyzer::analyze_invoke(const IRInstruction* insn, +static boost::optional s_string_analyzer_state{ + boost::none}; +static std::mutex s_string_analyzer_state_mtx; +StringAnalyzerState StringAnalyzerState::get() { + std::lock_guard lock(s_string_analyzer_state_mtx); + if (s_string_analyzer_state == boost::none) { + std::unordered_set methods; + auto kotlin_are_equal = DexMethod::get_method( + "Lkotlin/jvm/internal/Intrinsics;.areEqual:(Ljava/lang/Object;Ljava/" + "lang/Object;)Z"); + if (kotlin_are_equal != nullptr && kotlin_are_equal->as_def() != nullptr) { + methods.emplace(kotlin_are_equal->as_def()); + } + s_string_analyzer_state = StringAnalyzerState(methods); + // For tests. + g_redex->add_destruction_task([]() { + std::lock_guard task_lock(s_string_analyzer_state_mtx); + s_string_analyzer_state = boost::none; + }); + } + return *s_string_analyzer_state; +} + +void StringAnalyzerState::set_methods_as_root() { + for (const auto& method : string_equality_methods) { + if (!method->is_external()) { + method->rstate.set_root(); + } + } +} + +bool StringAnalyzer::analyze_invoke(const StringAnalyzerState* state, + const IRInstruction* insn, ConstantEnvironment* env) { DexMethod* method = resolve_method(insn->get_method(), opcode_to_search(insn)); @@ -1013,7 +1045,9 @@ bool StringAnalyzer::analyze_invoke(const IRInstruction* insn, return nullptr; }; - if (method == method::java_lang_String_equals()) { + if (method == method::java_lang_String_equals() || + (state != nullptr && state->string_equality_methods.find(method) != + state->string_equality_methods.end())) { always_assert(insn->srcs_size() == 2); if (const auto* arg0 = maybe_string(0)) { if (const auto* arg1 = maybe_string(1)) { diff --git a/service/constant-propagation/ConstantPropagationAnalysis.h b/service/constant-propagation/ConstantPropagationAnalysis.h index 5d261efce7e..a994363a0e4 100644 --- a/service/constant-propagation/ConstantPropagationAnalysis.h +++ b/service/constant-propagation/ConstantPropagationAnalysis.h @@ -410,16 +410,28 @@ class BoxedBooleanAnalyzer final ConstantEnvironment*); }; -class StringAnalyzer - : public InstructionAnalyzerBase { +struct StringAnalyzerState { + static StringAnalyzerState get(); + explicit StringAnalyzerState(const std::unordered_set& methods) + : string_equality_methods(methods) {} + void set_methods_as_root(); + std::unordered_set string_equality_methods; +}; + +class StringAnalyzer : public InstructionAnalyzerBase { public: - static bool analyze_const_string(const IRInstruction* insn, + static bool analyze_const_string(const StringAnalyzerState* state, + const IRInstruction* insn, ConstantEnvironment* env) { env->set(RESULT_REGISTER, StringDomain(insn->get_string())); return true; } - static bool analyze_invoke(const IRInstruction*, ConstantEnvironment*); + static bool analyze_invoke(const StringAnalyzerState* state, + const IRInstruction*, + ConstantEnvironment*); }; class ConstantClassObjectAnalyzer diff --git a/service/method-inliner/CallSiteSummaries.cpp b/service/method-inliner/CallSiteSummaries.cpp index 0cad3863d18..fdf935e6fa8 100644 --- a/service/method-inliner/CallSiteSummaries.cpp +++ b/service/method-inliner/CallSiteSummaries.cpp @@ -302,7 +302,8 @@ CallSiteSummarizer::get_invoke_call_site_summaries( m_shrinker.get_immut_analyzer_state(), m_shrinker.get_immut_analyzer_state(), constant_propagation::EnumFieldAnalyzerState::get(), - constant_propagation::BoxedBooleanAnalyzerState::get(), nullptr, + constant_propagation::BoxedBooleanAnalyzerState::get(), + /* TODO update */ nullptr, constant_propagation::ApiLevelAnalyzerState::get(), m_shrinker.get_package_name_state(), nullptr, m_shrinker.get_immut_analyzer_state(), nullptr)); diff --git a/service/shrinker/Shrinker.cpp b/service/shrinker/Shrinker.cpp index 491f85551fb..78e401b59ec 100644 --- a/service/shrinker/Shrinker.cpp +++ b/service/shrinker/Shrinker.cpp @@ -94,6 +94,7 @@ Shrinker::Shrinker( m_pure_methods(configured_pure_methods), m_finalish_field_names(configured_finalish_field_names), m_finalish_fields(configured_finalish_fields), + m_string_analyzer_state(constant_propagation::StringAnalyzerState::get()), m_package_name_state( constant_propagation::PackageNameState::get(package_name)) { // Initialize the singletons that `operator()` needs ahead of time to @@ -155,7 +156,8 @@ constant_propagation::Transform::Stats Shrinker::constant_propagation( constant_propagation::ConstantPrimitiveAndBoxedAnalyzer( &m_immut_analyzer_state, &m_immut_analyzer_state, constant_propagation::EnumFieldAnalyzerState::get(), - constant_propagation::BoxedBooleanAnalyzerState::get(), nullptr, + constant_propagation::BoxedBooleanAnalyzerState::get(), + &m_string_analyzer_state, constant_propagation::ApiLevelAnalyzerState::get(m_min_sdk), &m_package_name_state, nullptr, &m_immut_analyzer_state, nullptr), /* imprecise_switches */ true); diff --git a/service/shrinker/Shrinker.h b/service/shrinker/Shrinker.h index e7a518e0e69..aca5dfa6ec8 100644 --- a/service/shrinker/Shrinker.h +++ b/service/shrinker/Shrinker.h @@ -166,6 +166,7 @@ class Shrinker { std::unordered_set m_finalish_fields; constant_propagation::ImmutableAttributeAnalyzerState m_immut_analyzer_state; + constant_propagation::StringAnalyzerState m_string_analyzer_state; constant_propagation::PackageNameState m_package_name_state; constant_propagation::State m_cp_state; diff --git a/test/unit/constant-propagation/IPConstantPropagationTest.cpp b/test/unit/constant-propagation/IPConstantPropagationTest.cpp index 0c6cb6cb578..c76963a742a 100644 --- a/test/unit/constant-propagation/IPConstantPropagationTest.cpp +++ b/test/unit/constant-propagation/IPConstantPropagationTest.cpp @@ -49,6 +49,7 @@ struct InterproceduralConstantPropagationTest : public RedexTest { const std::string package_name = "com.facebook.redextest"; ImmutableAttributeAnalyzerState m_immut_analyzer_state; ApiLevelAnalyzerState m_api_level_analyzer_state; + StringAnalyzerState m_string_analyzer_state = StringAnalyzerState::get(); PackageNameState m_package_name_state = PackageNameState::get(package_name); State m_cp_state; }; @@ -1084,7 +1085,7 @@ TEST_F(InterproceduralConstantPropagationTest, constantFieldAfterClinit) { auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); EXPECT_EQ(wps.get_field_value(field_qux), SignedConstantDomain(0)); EXPECT_EQ(wps.get_field_value(field_corge), SignedConstantDomain(1)); @@ -1177,7 +1178,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); EXPECT_EQ(wps.get_field_value(field_qux), ConstantValue::top()); @@ -1667,7 +1668,7 @@ TEST_F(InterproceduralConstantPropagationTest, whiteBoxReturnValues) { config.max_heap_analysis_iterations = 1; auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // Make sure we mark methods that have a reachable return-void statement as @@ -1704,7 +1705,7 @@ TEST_F(InterproceduralConstantPropagationTest, min_sdk) { config.max_heap_analysis_iterations = 1; auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // Make sure we mark methods that have a reachable return-void statement as @@ -1818,7 +1819,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // as the field is definitely-assigned, 0 was not added to the numeric // interval domain @@ -1894,7 +1895,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // as the field is definitely-assigned, even with the branching in the // constructor, 0 was not added to the numeric interval domain @@ -1963,7 +1964,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // 0 is included in the numeric interval as 'this' escaped before the // assignment @@ -2030,7 +2031,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // 0 is included in the numeric interval as 'this' escaped before the // assignment @@ -2098,7 +2099,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // 0 is included in the numeric interval as no actual constructor was ever // called @@ -2169,7 +2170,7 @@ TEST_F(InterproceduralConstantPropagationTest, auto fp_iter = InterproceduralConstantPropagationPass(config).analyze( scope, &m_immut_analyzer_state, &m_api_level_analyzer_state, - &m_package_name_state, m_cp_state); + &m_string_analyzer_state, &m_package_name_state, m_cp_state); auto& wps = fp_iter->get_whole_program_state(); // 0 is included in the numeric interval as the field was read before written EXPECT_EQ(wps.get_field_value(field_f), SignedConstantDomain(0, 42)); diff --git a/test/unit/constant-propagation/StringPropagationTest.cpp b/test/unit/constant-propagation/StringPropagationTest.cpp index df636957cce..9a357bd9d1a 100644 --- a/test/unit/constant-propagation/StringPropagationTest.cpp +++ b/test/unit/constant-propagation/StringPropagationTest.cpp @@ -53,7 +53,8 @@ TEST_F(StringTest, neq) { ) )"); - do_const_prop(code.get(), StringAnalyzer()); + auto state = cp::StringAnalyzerState::get(); + do_const_prop(code.get(), StringAnalyzer(&state, nullptr)); auto expected_code = assembler::ircode_from_string(R"( ( @@ -85,7 +86,8 @@ TEST_F(StringTest, equals_false) { std::unordered_set pure_methods{ method::java_lang_String_equals()}; config.pure_methods = &pure_methods; - do_const_prop(code.get(), StringAnalyzer(), config); + auto state = cp::StringAnalyzerState::get(); + do_const_prop(code.get(), StringAnalyzer(&state, nullptr), config); auto expected_code = assembler::ircode_from_string(R"( ( @@ -119,7 +121,8 @@ TEST_F(StringTest, equals_true) { std::unordered_set pure_methods{ method::java_lang_String_equals()}; config.pure_methods = &pure_methods; - do_const_prop(code.get(), StringAnalyzer(), config); + auto state = cp::StringAnalyzerState::get(); + do_const_prop(code.get(), StringAnalyzer(&state, nullptr), config); auto expected_code = assembler::ircode_from_string(R"( ( @@ -151,7 +154,8 @@ TEST_F(StringTest, hashCode) { std::unordered_set pure_methods{ method::java_lang_String_hashCode()}; config.pure_methods = &pure_methods; - do_const_prop(code.get(), StringAnalyzer(), config); + auto state = cp::StringAnalyzerState::get(); + do_const_prop(code.get(), StringAnalyzer(&state, nullptr), config); auto expected_code = assembler::ircode_from_string(R"( ( @@ -232,8 +236,10 @@ TEST_F(StringTest, package_equals_true) { std::unordered_set pure_methods{ method::java_lang_String_equals()}; config.pure_methods = &pure_methods; - auto state = cp::PackageNameState::get("com.facebook.redextest"); - do_const_prop(code.get(), PackageStringAnalyzer(&state, nullptr, nullptr), + auto package_state = cp::PackageNameState::get("com.facebook.redextest"); + auto string_state = cp::StringAnalyzerState::get(); + do_const_prop(code.get(), + PackageStringAnalyzer(&package_state, &string_state, nullptr), config); auto expected_code = assembler::ircode_from_string(R"(