From 6db71dc63a1ab869e394cb5307b79f952b3f26b9 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 7 Feb 2025 13:03:58 -0800 Subject: [PATCH] Re-organize SLL ops, pt 9 (#3665) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3665 X-link: https://github.com/facebookresearch/FBGEMM/pull/740 - Move cpu_sll and meta_sll to their own folders Differential Revision: D69227334 --- fbgemm_gpu/fbgemm_gpu/sll/__init__.py | 89 ++----------------- fbgemm_gpu/fbgemm_gpu/sll/cpu/__init__.py | 78 ++++++++++++++++ .../fbgemm_gpu/sll/{ => cpu}/cpu_sll.py | 0 fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py | 35 ++++++++ .../fbgemm_gpu/sll/{ => meta}/meta_sll.py | 0 5 files changed, 118 insertions(+), 84 deletions(-) create mode 100644 fbgemm_gpu/fbgemm_gpu/sll/cpu/__init__.py rename fbgemm_gpu/fbgemm_gpu/sll/{ => cpu}/cpu_sll.py (100%) create mode 100644 fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py rename fbgemm_gpu/fbgemm_gpu/sll/{ => meta}/meta_sll.py (100%) diff --git a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py index bd89e4ff5..9822d8566 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py @@ -8,31 +8,8 @@ # pyre-strict import torch - -from fbgemm_gpu.sll.cpu_sll import ( # noqa F401 - cpu_array_jagged_bmm_jagged_out, - cpu_dense_jagged_cat_jagged_out, - cpu_jagged2_softmax, - cpu_jagged2_to_padded_dense, - cpu_jagged_dense_bmm, - cpu_jagged_dense_elementwise_add, - cpu_jagged_dense_elementwise_mul_jagged_out, - cpu_jagged_dense_flash_attention, - cpu_jagged_flash_attention_basic, - cpu_jagged_jagged_bmm, - cpu_jagged_jagged_bmm_jagged_out, - cpu_jagged_self_substraction_jagged_out, - cpu_jagged_softmax, -) - -from fbgemm_gpu.sll.meta_sll import ( # noqa F401 - meta_array_jagged_bmm_jagged_out, - meta_jagged2_softmax, - meta_jagged_dense_elementwise_mul_jagged_out, - meta_jagged_jagged_bmm_jagged_out, - meta_jagged_self_substraction_jagged_out, -) - +from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations +from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations from fbgemm_gpu.utils import TorchLibraryFragment lib = TorchLibraryFragment("fbgemm") @@ -198,68 +175,12 @@ # need the autograd forward to save the context because we don't need to do # backward. -# pyre-ignore[5] -sll_cpu_registrations = { - "sll_jagged_dense_bmm": { - "CPU": cpu_jagged_dense_bmm, - "AutogradCPU": cpu_jagged_dense_bmm, - }, - "sll_jagged_jagged_bmm": { - "CPU": cpu_jagged_jagged_bmm, - "AutogradCPU": cpu_jagged_jagged_bmm, - }, - "sll_dense_jagged_cat_jagged_out": { - "CPU": cpu_dense_jagged_cat_jagged_out, - }, - "sll_jagged_self_substraction_jagged_out": { - "CPU": cpu_jagged_self_substraction_jagged_out, - "Meta": meta_jagged_self_substraction_jagged_out, - }, - "sll_jagged2_to_padded_dense": { - "CPU": cpu_jagged2_to_padded_dense, - "AutogradCPU": cpu_jagged2_to_padded_dense, - }, - "sll_jagged_dense_elementwise_mul_jagged_out": { - "CPU": cpu_jagged_dense_elementwise_mul_jagged_out, - "AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out, - "Meta": meta_jagged_dense_elementwise_mul_jagged_out, - }, - "sll_jagged_softmax": { - "CPU": cpu_jagged_softmax, - "AutogradCPU": cpu_jagged_softmax, - }, - "sll_jagged2_softmax": { - "CPU": cpu_jagged2_softmax, - "AutogradCPU": cpu_jagged2_softmax, - "AutogradMeta": meta_jagged2_softmax, - }, - "sll_array_jagged_bmm_jagged_out": { - "CPU": cpu_array_jagged_bmm_jagged_out, - "AutogradCPU": cpu_array_jagged_bmm_jagged_out, - "AutogradMeta": meta_array_jagged_bmm_jagged_out, - }, - "sll_jagged_jagged_bmm_jagged_out": { - "CPU": cpu_jagged_jagged_bmm_jagged_out, - "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out, - "AutogradMeta": meta_jagged_jagged_bmm_jagged_out, - }, - "sll_jagged_flash_attention_basic": { - "CPU": cpu_jagged_flash_attention_basic, - "AutogradCPU": cpu_jagged_flash_attention_basic, - }, - "sll_jagged_dense_elementwise_add": { - "CPU": cpu_jagged_dense_elementwise_add, - "AutogradCPU": cpu_jagged_dense_elementwise_add, - }, - "sll_jagged_dense_flash_attention": { - "CPU": cpu_jagged_dense_flash_attention, - "AutogradCPU": cpu_jagged_dense_flash_attention, - }, -} - for op_name, dispatches in sll_cpu_registrations.items(): lib.register(op_name, dispatches) +for op_name, dispatches in sll_meta_registrations.items(): + lib.register(op_name, dispatches) + if torch.cuda.is_available(): from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations diff --git a/fbgemm_gpu/fbgemm_gpu/sll/cpu/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/cpu/__init__.py new file mode 100644 index 000000000..a8c84c576 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/sll/cpu/__init__.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from fbgemm_gpu.sll.cpu.cpu_sll import ( # noqa F401 + cpu_array_jagged_bmm_jagged_out, + cpu_dense_jagged_cat_jagged_out, + cpu_jagged2_softmax, + cpu_jagged2_to_padded_dense, + cpu_jagged_dense_bmm, + cpu_jagged_dense_elementwise_add, + cpu_jagged_dense_elementwise_mul_jagged_out, + cpu_jagged_dense_flash_attention, + cpu_jagged_flash_attention_basic, + cpu_jagged_jagged_bmm, + cpu_jagged_jagged_bmm_jagged_out, + cpu_jagged_self_substraction_jagged_out, + cpu_jagged_softmax, +) + +# pyre-ignore[5] +op_registrations = { + "sll_jagged_dense_bmm": { + "CPU": cpu_jagged_dense_bmm, + "AutogradCPU": cpu_jagged_dense_bmm, + }, + "sll_jagged_jagged_bmm": { + "CPU": cpu_jagged_jagged_bmm, + "AutogradCPU": cpu_jagged_jagged_bmm, + }, + "sll_dense_jagged_cat_jagged_out": { + "CPU": cpu_dense_jagged_cat_jagged_out, + }, + "sll_jagged_self_substraction_jagged_out": { + "CPU": cpu_jagged_self_substraction_jagged_out, + }, + "sll_jagged2_to_padded_dense": { + "CPU": cpu_jagged2_to_padded_dense, + "AutogradCPU": cpu_jagged2_to_padded_dense, + }, + "sll_jagged_dense_elementwise_mul_jagged_out": { + "CPU": cpu_jagged_dense_elementwise_mul_jagged_out, + "AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out, + }, + "sll_jagged_softmax": { + "CPU": cpu_jagged_softmax, + "AutogradCPU": cpu_jagged_softmax, + }, + "sll_jagged2_softmax": { + "CPU": cpu_jagged2_softmax, + "AutogradCPU": cpu_jagged2_softmax, + }, + "sll_array_jagged_bmm_jagged_out": { + "CPU": cpu_array_jagged_bmm_jagged_out, + "AutogradCPU": cpu_array_jagged_bmm_jagged_out, + }, + "sll_jagged_jagged_bmm_jagged_out": { + "CPU": cpu_jagged_jagged_bmm_jagged_out, + "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out, + }, + "sll_jagged_flash_attention_basic": { + "CPU": cpu_jagged_flash_attention_basic, + "AutogradCPU": cpu_jagged_flash_attention_basic, + }, + "sll_jagged_dense_elementwise_add": { + "CPU": cpu_jagged_dense_elementwise_add, + "AutogradCPU": cpu_jagged_dense_elementwise_add, + }, + "sll_jagged_dense_flash_attention": { + "CPU": cpu_jagged_dense_flash_attention, + "AutogradCPU": cpu_jagged_dense_flash_attention, + }, +} diff --git a/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/cpu/cpu_sll.py similarity index 100% rename from fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py rename to fbgemm_gpu/fbgemm_gpu/sll/cpu/cpu_sll.py diff --git a/fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py new file mode 100644 index 000000000..23920961e --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from fbgemm_gpu.sll.meta.meta_sll import ( # noqa F401 + meta_array_jagged_bmm_jagged_out, + meta_jagged2_softmax, + meta_jagged_dense_elementwise_mul_jagged_out, + meta_jagged_jagged_bmm_jagged_out, + meta_jagged_self_substraction_jagged_out, +) + +# pyre-ignore[5] +op_registrations = { + "sll_jagged_self_substraction_jagged_out": { + "Meta": meta_jagged_self_substraction_jagged_out, + }, + "sll_jagged_dense_elementwise_mul_jagged_out": { + "Meta": meta_jagged_dense_elementwise_mul_jagged_out, + }, + "sll_jagged2_softmax": { + "AutogradMeta": meta_jagged2_softmax, + }, + "sll_array_jagged_bmm_jagged_out": { + "AutogradMeta": meta_array_jagged_bmm_jagged_out, + }, + "sll_jagged_jagged_bmm_jagged_out": { + "AutogradMeta": meta_jagged_jagged_bmm_jagged_out, + }, +} diff --git a/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/meta/meta_sll.py similarity index 100% rename from fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py rename to fbgemm_gpu/fbgemm_gpu/sll/meta/meta_sll.py