Skip to content

Commit

Permalink
Handle inductor config attribute error
Browse files Browse the repository at this point in the history
Summary:
We need to add inductor dependencies to run amd target with inductor and max-autotune.

Trying to revert D54820824 which seems to cause issues in OSS.

Reviewed By: nmacchioni

Differential Revision: D54820824

fbshipit-source-id: d20c54bff19dffce2fff80be3d08295d7a5c76d4
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Mar 14, 2024
1 parent d24ff04 commit 6ba2382
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@
"debug",
]

def _try_get_inductor_config():
try:
return torch._inductor.config.shallow_copy_dict()
except AttributeError:
# access torch inductor config directly
# if torch._inductor module does not has config attribute
from torch._inductor import config as inductor_config
return inductor_config.shallow_copy_dict()

def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -91,8 +100,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
)

# inductor boolean configs
from torch._inductor import config as inductor_config
inductor_config_dict = inductor_config.shallow_copy_dict()
inductor_config_dict = _try_get_inductor_config()
for inductor_config_key in INDUCTOR_CONFIG_KEYS:
inductor_config_key_arg = inductor_config_key.replace(".", "-")
parser.add_argument(
Expand Down Expand Up @@ -136,7 +144,7 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
if compile_threads := args.torchinductor_compile_threads:
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = str(compile_threads)
# Deal with boolean inductor configs
inductor_config_dict = torch._inductor.config.shallow_copy_dict()
inductor_config_dict = _try_get_inductor_config()
for inductor_config_key in INDUCTOR_CONFIG_KEYS:
inductor_config_key_arg = inductor_config_key.replace(".", "_")
if getattr(args, f"no_pt2_{inductor_config_key_arg}", None) == False:
Expand Down

0 comments on commit 6ba2382

Please sign in to comment.