Skip to content

Commit

Permalink
Adding per-method tracers to the module utility. Changing set_output_…
Browse files Browse the repository at this point in the history
…data_ptr to take in a method name. (#5279)

Summary:
Pull Request resolved: #5279

* Adding per-method tracers to the executorch module utilty to be able to profile/trace methods individually
* Enabling per-method output data pointers to be able to use per-method input/output memory planning through the module.

Reviewed By: tarun292

Differential Revision: D62520386

fbshipit-source-id: 6287701183e664d68435c48d6ed2b566e3d10d93
  • Loading branch information
meta-emilian authored and facebook-github-bot committed Sep 12, 2024
1 parent 4053a18 commit 8888c0d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
15 changes: 8 additions & 7 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
return result;
}

runtime::Error Module::load_method(const std::string& method_name) {
runtime::Error Module::load_method(
const std::string& method_name,
torch::executor::EventTracer* tracer) {
if (!is_method_loaded(method_name)) {
ET_CHECK_OK_OR_RETURN_ERROR(load());

Expand Down Expand Up @@ -151,9 +153,7 @@ runtime::Error Module::load_method(const std::string& method_name) {
method_holder.planned_memory.get(),
temp_allocator_.get());
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
method_name.c_str(),
method_holder.memory_manager.get(),
event_tracer_.get()));
method_name.c_str(), method_holder.memory_manager.get(), tracer));
methods_.emplace(method_name, std::move(method_holder));
}
return runtime::Error::Ok;
Expand Down Expand Up @@ -185,10 +185,11 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(

runtime::Error Module::set_output_data_ptr(
runtime::EValue output_value,
size_t output_index) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method("forward"));
size_t output_index,
const std::string& method_name) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& output_tensor = output_value.toTensor();
auto& method = methods_.at("forward").method;
auto& method = methods_.at(method_name).method;
return method->set_output_data_ptr(
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
}
Expand Down
7 changes: 5 additions & 2 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ class Module {
* @returns An Error to indicate success or failure.
*/
ET_NODISCARD
runtime::Error load_method(const std::string& method_name);
runtime::Error load_method(
const std::string& method_name,
torch::executor::EventTracer* tracer = nullptr);

/**
* Checks if a specific method is loaded.
Expand Down Expand Up @@ -318,7 +320,8 @@ class Module {
*/
runtime::Error set_output_data_ptr(
runtime::EValue output_value,
size_t output_index);
size_t output_index,
const std::string& method_name = "forward");

private:
struct MethodHolder {
Expand Down

0 comments on commit 8888c0d

Please sign in to comment.