Skip to content

Commit

Permalink
v0.4.16: fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
FindDefinition committed Sep 6, 2024
1 parent fc9d019 commit 66b17a3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Changelog
## [0.4.16] - 2024-09-06
### Fixed
- fix a bug that bool is captured as int64_t in inliner.
- fix a bug of additional vars in inliner.

## [0.4.15] - 2024-08-14
### Added
- add compiled conversion to greatly reduce launch overhead of inline function. currently only inline cuda/metal kernels in cumm support this.
Expand Down
33 changes: 20 additions & 13 deletions pccm/builder/inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ def gcs(*instances):


def get_base_type_string(obj):
if isinstance(obj, int):
# bool is int, so we must check it first
if isinstance(obj, bool):
return "bool", False
elif isinstance(obj, int):
return "int64_t", False
elif isinstance(obj, (float, np.floating)):
return "float", False
elif isinstance(obj, (bool, )):
return "bool", False
elif isinstance(obj, str):
return "std::string", False
elif isinstance(obj, np.integer):
Expand Down Expand Up @@ -93,7 +94,7 @@ def _get_captures_in_code(code_str: str):
all_captures.append(
CaptureStmt(cap_name, is_expr, sym_range, [rep_range]))
unique_name[all_captures[-1].name] = all_captures[-1]
return all_captures
return all_captures, it.identifiers


class MultiTypeKindError(Exception):
Expand Down Expand Up @@ -136,7 +137,7 @@ def capture_vars(self, *, _frame_cnt: int = 2):
with self.capture_to_new_code(code_for_inspect):
yield
code_str = code_for_inspect.inspect_body()
all_captures = _get_captures_in_code(code_str)
all_captures, _ = _get_captures_in_code(code_str)

for cap in all_captures:
if not cap.is_expr:
Expand Down Expand Up @@ -737,7 +738,7 @@ def inline(self,
code_str.split("\n"))))
if not exist:
# 1. extract captured vars
all_captures = _get_captures_in_code(code_str)
all_captures, all_identifiers = _get_captures_in_code(code_str)
# 2. find captures in prev frame
container_fcode = FunctionCode()
inner_fcode = FunctionCode()
Expand Down Expand Up @@ -834,21 +835,23 @@ def inline(self,
replaces.append(replace)
cnt_addi = len(non_nested_stmts)
for k, v in additional_vars.items():
_, mapped_cpp_type = nested_type_analysis(
cpp_type, mapped_cpp_type = nested_type_analysis(
v, self.plugins, user_arg=user_arg)
container_fcode.arg(k,
str(mapped_cpp_type),
array=mapped_cpp_type.count,
userdata=v)
args.append(v)
if generate_non_nested_code:
non_nested_stmt, import_stmts = _non_nested_compile(
f"{PCCM_INLINE_ARG_PREFIX}_{cnt_addi}",
"local_vars", mapped_cpp_type, v,
"local_vars", cpp_type, v,
CaptureStmt(k, False, (0, 0),
[]), self.plugins, user_arg)
non_nested_stmts.append(non_nested_stmt)
non_nested_import_stmts.extend(import_stmts)
v = _nested_apply_plugin_transform(
cpp_type, v, self.plugins, user_arg)
container_fcode.arg(k,
str(mapped_cpp_type),
array=mapped_cpp_type.count,
userdata=v)
args.append(v)
cnt_addi += 1
func_obj = None
if generate_non_nested_code:
Expand Down Expand Up @@ -909,6 +912,10 @@ def inline(self,
inner_decl = self.create_inner_decl(inner_code_str,
container_fcode,
inner_fcode, user_arg)
all_arg_names = set([arg.name for arg in container_fcode.arguments])
if len(container_fcode.arguments) != len(all_arg_names):
raise ValueError(f"you have duplicate arg names. {all_arg_names}, additional keys: {additional_vars.keys()}")
# validate additional vars
# now we have complete code. we need to determine a history build dir and use it to build library if need.
# here we must reserve build dir because we need to rebuild when dependency change.
if meta is None:
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.4.15
0.4.16

0 comments on commit 66b17a3

Please sign in to comment.