Skip to content

Commit 8e47955

Browse files
cuda: support download toolchains (#7459)
* cuda: support download toolchains * limit plat
1 parent 7998da2 commit 8e47955

File tree

1 file changed

+49
-8
lines changed

1 file changed

+49
-8
lines changed

packages/c/cuda/xmake.lua

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
package("cuda")
2-
2+
set_kind("toolchain")
33
set_homepage("https://developer.nvidia.com/cuda-zone/")
44
set_description("CUDA® is a parallel computing platform and programming model developed by NVIDIA for general computing on graphical processing units (GPUs).")
55

6+
if is_plat("windows") then
7+
add_urls("https://developer.download.nvidia.com/compute/cuda/$(version)_windows.exe", {
8+
version = function (version)
9+
local driver_version_map = {
10+
["12.6.3"] = "561.17",
11+
}
12+
return format("%s/local_installers/cuda_%s_%s", version, version, driver_version_map[tostring(version)])
13+
end})
14+
15+
add_versions("12.6.3", "d73e937c75aaa8114da3aff4eee96f9cae03d4b9d70a30b962ccf3c9b4d7a8e1")
16+
end
17+
618
add_configs("utils", {description = "Enabled cuda utilities.", default = {}, type = "table"})
19+
add_configs("debug", {description = "Enable debug symbols.", default = false, type = "boolean", readonly = true})
720

8-
on_load(function (package)
9-
import("detect.sdks.find_cuda")
10-
local cuda = find_cuda()
11-
if cuda then
12-
package:addenv("PATH", cuda.bindir)
13-
end
14-
end)
21+
set_policy("package.precompiled", false)
1522

1623
on_fetch(function (package, opt)
1724
if opt.system then
@@ -35,3 +42,37 @@ package("cuda")
3542
end
3643
end
3744
end)
45+
46+
on_load("windows", function (package)
47+
package:mark_as_pathenv("CUDA_PATH")
48+
package:setenv("CUDA_PATH", ".")
49+
end)
50+
51+
on_install("windows|x64", function(package)
52+
import("lib.detect.find_tool")
53+
import("lib.detect.find_directory")
54+
55+
if package:is_plat("windows") then
56+
local z7 = assert(find_tool("7z"), "7z tool not found!")
57+
os.vrunv(z7.program, {"x", "-y", package:originfile()})
58+
59+
-- reference: https://github.com/ScoopInstaller/Main/blob/master/bucket/cuda.json
60+
local names = {"bin", "extras", "include", "lib", "libnvvp", "nvml", "nvvm", "compute-sanitizer"}
61+
for _, dir in ipairs(os.dirs("*")) do
62+
if dir:startswith("cuda_") or dir:startswith("lib") then
63+
for _, name in ipairs(names) do
64+
local util_dir = find_directory(name, path.join(dir, "*"))
65+
if util_dir then
66+
os.vcp(path.join(util_dir, "*"), package:installdir(name))
67+
end
68+
end
69+
end
70+
end
71+
end
72+
end)
73+
74+
on_test(function (package)
75+
if not package:is_cross() then
76+
os.vrun("nvcc -V")
77+
end
78+
end)

0 commit comments

Comments
 (0)