diff --git a/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py b/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py index dbf5f2772..66f316018 100644 --- a/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py +++ b/mmrazor/models/task_modules/demo_inputs/mmpose_demo_input.py @@ -29,7 +29,7 @@ def demo_mmpose_inputs(model, for_training=False, batch_size=1): imgs = torch.randn(*input_shape) batch_data_samples = [] - from mmpose.models.heads import RTMHead + from mmpose.models.heads import RTMCCHead as RTMHead if isinstance(model.head, HeatmapHead): batch_data_samples = get_packed_inputs( batch_size,