Skip to content

Commit

Permalink
hotfix: bugfix to #756 (#757)
Browse files Browse the repository at this point in the history
This pull request includes changes to the `flashinfer/jit/__init__.py`
file to improve the import structure and handle the `prebuilt_ops_uri`
import conditionally.

Improvements to import structure:

* Removed the unconditional import of `prebuilt_ops_uri` from
`aot_config` and added it conditionally within the try-except block to
handle cases where `_kernels` or `_kernels_sm90` are not available.
[[1]](diffhunk://#diff-39845bb8e1f81f9ca7d510e99dad0c15c7596ebe1ac909dc1b3c25b742700b5cL20)
[[2]](diffhunk://#diff-39845bb8e1f81f9ca7d510e99dad0c15c7596ebe1ac909dc1b3c25b742700b5cR48-R52)

cc @ByronHsu
  • Loading branch information
yzh119 authored Jan 27, 2025
1 parent aa9c5d9 commit 68d1177
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions flashinfer/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# Re-export
from .activation import gen_act_and_mul_module as gen_act_and_mul_module
from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri
from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
from .attention import gen_batch_decode_module as gen_batch_decode_module
from .attention import gen_batch_prefill_module as gen_batch_prefill_module
Expand All @@ -40,13 +39,15 @@
from .attention import get_batch_prefill_uri as get_batch_prefill_uri
from .attention import get_single_decode_uri as get_single_decode_uri
from .attention import get_single_prefill_uri as get_single_prefill_uri
from .core import clear_cache_dir, load_cuda_ops # noqa: F401
from .core import clear_cache_dir, load_cuda_ops # noqa: F401
from .env import *
from .utils import parallel_load_modules as parallel_load_modules

try:
from .. import _kernels, _kernels_sm90 # noqa: F401
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri

has_prebuilt_ops = True
except ImportError:
prebuilt_ops_uri = {}
has_prebuilt_ops = False

0 comments on commit 68d1177

Please sign in to comment.