diff --git a/LocalPreferences.toml b/LocalPreferences.toml index 513fc75593..57209ef082 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -16,6 +16,9 @@ # possible values: "device", "unified", "host" #default_memory = "device" +# From PrecompileTools, whether or not to precompile the GPUCompiler + Inference stack +#precompile_workload = true + [CUDA_Driver_jll] # whether to attempt to load a forwards-compatibile userspace driver. # only turn this off if you experience issues, e.g., when using a local diff --git a/Project.toml b/Project.toml index 21c476cbca..1342f7d1bf 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -68,6 +69,7 @@ Libdl = "1" LinearAlgebra = "1" Logging = "1" NVTX = "0.3.2" +PrecompileTools = "1.2.1" Preferences = "1" PrettyTables = "2" Printf = "1" diff --git a/src/precompile.jl b/src/precompile.jl index fc95f362ba..3228a8aab9 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -14,3 +14,15 @@ precompile(run_and_collect, (Cmd,)) precompile(cudaconvert, (Function,)) precompile(Core.kwfunc(cudacall), (NamedTuple{(:threads, :blocks), Tuple{Int64, Int64}},typeof(cudacall),CuFunction,Type{Tuple{}})) precompile(Core.kwfunc(launch), (NamedTuple{(:threads, :blocks), Tuple{Int64, Int64}},typeof(launch),CuFunction)) + +using PrecompileTools: @setup_workload, @compile_workload +@setup_workload let + @compile_workload begin + target = PTXCompilerTarget(; cap=v"7.5") + params = CUDACompilerParams(; cap=v"7.5", ptx=v"7.5") + config = CompilerConfig(target, params) + mi = GPUCompiler.methodinstance(typeof(identity), Tuple{Nothing}) + job = CompilerJob(mi, config) + GPUCompiler.code_native(devnull, job) + end +end