diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py index fd48a4a84..17add596f 100644 --- a/fx2ait/fx2ait/ait_splitter.py +++ b/fx2ait/fx2ait/ait_splitter.py @@ -85,6 +85,7 @@ def create_ait_operator_support( else [ ops.OpSupports.decline_if_input_dtype(torch.int64), ops.OpSupports.decline_if_input_dtype(torch.int32), + ops.OpSupports.decline_if_input_dtype(torch.uint8), ] ) chained_not_supported_ops += [ diff --git a/fx2ait/fx2ait/test/test_ait_splitter.py b/fx2ait/fx2ait/test/test_ait_splitter.py index 5b010c3a3..bbb7a0ed3 100644 --- a/fx2ait/fx2ait/test/test_ait_splitter.py +++ b/fx2ait/fx2ait/test/test_ait_splitter.py @@ -165,6 +165,23 @@ def forward(self, a): {"_run_on_gpu_0"}, ) + # nodes w/ uint8 input should not be lowered + mod = acc_tracer.trace(test_mod, [x]) + splitter = AITSplitter( + mod, + (x.to(torch.uint8).cuda(),), + operator_support, + settings, + ) + + split_results_int = splitter.generate_split_results() + + self.assertTrue(len(split_results_int), 1) + self.assertEqual( + dict(split_results_int.split_module.named_children()).keys(), + {"_run_on_gpu_0"}, + ) + # nodes w/ integer input should be lowered mod = acc_tracer.trace(test_mod, [x]) settings.allow_int_inputs = True