-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcuda.bzl
42 lines (34 loc) · 1.1 KB
/
cuda.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def _cuda_binary(ctx):
    """Implementation of `cuda_binary`: compile and link `srcs` with nvcc.

    Registers a single shell action of the form

        /usr/local/cuda/bin/nvcc -D__CUDACC__ <flags> -x=cu \
            -Xcompiler "-O3 -Ofast -Wall -Wextra -DWITH_CUDA" \
            <src> <src> ... -o <target name> -I<include> ...

    whose only output is an executable file named after the target.

    Args:
        ctx: the rule context for a `cuda_binary` target.

    Returns:
        A one-element list holding a DefaultInfo that exposes the compiled
        binary both as the target's default output and as its executable.
    """

    # Declare the output file exactly once; the same File object is used as
    # the action output and in the returned DefaultInfo.
    executable = ctx.actions.declare_file(ctx.attr.name)

    # User flags come first. The fixed part forces CUDA language mode for
    # every input (-x=cu) and forwards the quoted option group to the host
    # compiler via -Xcompiler. The embedded double quotes rely on shell
    # parsing, which is why run_shell (not run) is used below.
    default_flags = ctx.attr.flags + \
        " -x=cu -Xcompiler \"-O3 -Ofast -Wall -Wextra -DWITH_CUDA\""

    # Collect tokens in a list and join them with spaces, so each source path
    # stays a separate argument even when there are several srcs.
    # NOTE(review): the nvcc location is hard-coded; hosts without CUDA at
    # /usr/local/cuda will fail this action.
    cmd = ["/usr/local/cuda/bin/nvcc", "-D__CUDACC__", default_flags]
    cmd.extend([src.path for src in ctx.files.srcs])
    cmd.extend(["-o", executable.path])
    cmd.extend(["-I" + include for include in ctx.attr.includes])

    ctx.actions.run_shell(
        outputs = [executable],
        # Headers are listed as inputs so they are visible in the sandbox;
        # they are not named on the command line.
        inputs = ctx.files.srcs + ctx.files.hdrs,
        command = " ".join(cmd),
        mnemonic = "CudaCompile",
        progress_message = "compile cuda",
        use_default_shell_env = True,
    )
    return [DefaultInfo(
        files = depset([executable]),
        executable = executable,
    )]
# Executable rule that builds a single binary from `srcs` with nvcc; the
# resulting target can be run with `bazel run`.
cuda_binary = rule(
    implementation = _cuda_binary,
    executable = True,
    attrs = {
        # Extra nvcc flags, placed before the rule's fixed flags. The value is
        # spliced into a shell command, so it is shell-interpreted.
        "flags": attr.string(default = ""),
        # Sources handed to nvcc. Every input is compiled as CUDA (-x=cu), so
        # .cu files are accepted in addition to .cc.
        "srcs": attr.label_list(default = [], allow_files = [".cc", ".cu"]),
        # Headers made available to the compile action (not passed as
        # arguments); CUDA .cuh headers are accepted in addition to .h.
        "hdrs": attr.label_list(default = [], allow_files = [".h", ".cuh"]),
        # Directories added to the include path with -I.
        "includes": attr.string_list(default = []),
        # NOTE(review): this output is declared but never read by the
        # implementation in this file; setting it likely yields an
        # output that no action creates — confirm intent or remove.
        "out": attr.output(mandatory = False),
    },
)