Skip to content

Commit

Permalink
[API][Backend] Fix hcl.print with UInt supported (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanlatias authored Apr 29, 2020
1 parent a339907 commit 01be7f5
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 9 deletions.
5 changes: 3 additions & 2 deletions python/heterocl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,9 @@ def print(vals, format=""):

def get_format(val):
if isinstance(val, (TensorSlice, Scalar, _expr.Expr)):
if util.get_type(val.dtype)[0] == "int":
return "%d"
if (util.get_type(val.dtype)[0] == "int"
or util.get_type(val.dtype)[0] == "uint"):
return "%lld"
else:
return "%f"
elif isinstance(val, int):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_print_expr():

outputs = get_stdout("print_expr").split("\n")

N = 4
N = 5
for i in range(0, N):
assert outputs[i] == outputs[i+N]

Expand Down
25 changes: 22 additions & 3 deletions tests/test_api_print_cases/print_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,26 @@ def kernel(A):

print(hcl_A.asnumpy()[5])

# case2: float
# case1: uint

hcl.init(hcl.UInt(4))

A = hcl.placeholder((10,))

def kernel(A):
hcl.print(A[5])

s = hcl.create_schedule([A], kernel)
f = hcl.build(s)

np_A = np.random.randint(20, 30, size=(10,))
hcl_A = hcl.asarray(np_A)

f(hcl_A)

print(hcl_A.asnumpy()[5])

# case3: float

hcl.init(hcl.Float())

Expand All @@ -39,7 +58,7 @@ def kernel(A):

print("%.4f" % hcl_A.asnumpy()[5])

# case3: fixed points
# case4: fixed points

hcl.init(hcl.UFixed(6, 4))

Expand All @@ -58,7 +77,7 @@ def kernel(A):

print("%.4f" % hcl_A.asnumpy()[5])

# case4: two ints
# case5: two ints

hcl.init()

Expand Down
7 changes: 4 additions & 3 deletions tvm/src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ void CodeGenLLVM::VisitStmt_(const Print* op) {
values.push_back(MakeValue(v));
types.push_back(v.type());
if (v.type().is_int() || v.type().is_uint()) {
llvm_types.push_back(LLVMType(v.type()));
llvm_types.push_back(t_int64_);
} else {
llvm_types.push_back(llvm::Type::getDoubleTy(*ctx_));
}
Expand All @@ -1423,14 +1423,15 @@ void CodeGenLLVM::VisitStmt_(const Print* op) {
#if TVM_LLVM_VERSION <= 60
llvm::Function* printf_call = llvm::cast<llvm::Function>(module_->getOrInsertFunction("printf", call_ftype));
#else
llvm::Function *printf_call = llvm::cast<llvm::Function>(module_->getOrInsertFunction("printf", call_ftype).getCallee());
llvm::Function* printf_call = llvm::cast<llvm::Function>(module_->getOrInsertFunction("printf", call_ftype).getCallee());
#endif
std::vector<llvm::Value*> printf_args;
std::string format = op->format;
printf_args.push_back(builder_->CreateGlobalStringPtr(format));
for (size_t i = 0; i < op->values.size(); i++) {
if (types[i].is_int() || types[i].is_uint()) {
printf_args.push_back(values[i]);
llvm::Value* ivalue = CreateCast(types[i], Int(64), values[i]);
printf_args.push_back(ivalue);
} else { // fixed or float
llvm::Value* fvalue = CreateCast(types[i], Float(64), values[i]);
printf_args.push_back(fvalue);
Expand Down

0 comments on commit 01be7f5

Please sign in to comment.