Skip to content

Commit 2f04b51

Browse files
deduplicate hip_check
1 parent 343413f commit 2f04b51

File tree

4 files changed

+359
-13
lines changed

4 files changed

+359
-13
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .hip import *

kernel_tuner/backends/hip/hip.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
"""This module contains all HIP specific kernel_tuner functions."""
2+
3+
import ctypes
4+
import ctypes.util
5+
import logging
6+
7+
import numpy as np
8+
9+
from kernel_tuner.backends.backend import GPUBackend
10+
from kernel_tuner.observers.hip import HipRuntimeObserver
11+
from kernel_tuner.backends.hip.util import hip_check
12+
13+
try:
14+
from hip import hip, hiprtc
15+
except (ImportError, RuntimeError):
16+
hip = None
17+
hiprtc = None
18+
19+
# Lookup table from the string form of a numpy dtype (``str(arr.dtype)``)
# to the ctypes scalar type used to wrap a value of that dtype.
dtype_map = dict(
    bool=ctypes.c_bool,
    # signed integers
    int8=ctypes.c_int8,
    int16=ctypes.c_int16,
    int32=ctypes.c_int32,
    int64=ctypes.c_int64,
    # unsigned integers
    uint8=ctypes.c_uint8,
    uint16=ctypes.c_uint16,
    uint32=ctypes.c_uint32,
    uint64=ctypes.c_uint64,
    # floating point
    float32=ctypes.c_float,
    float64=ctypes.c_double,
)

# Integer value 0, kept as a module-level name.
# NOTE(review): presumably mirrors hipError_t.hipSuccess -- confirm.
hipSuccess = 0
34+
35+
class HipFunctions(GPUBackend):
    """Class that groups the HIP functions and maintains state about the device."""

    def __init__(self, device=0, iterations=7, compiler_options=None, observers=None):
        """Instantiate HipFunctions object used for interacting with the HIP device.

        Instantiating this object will inspect and store certain device properties at
        runtime, which are used during compilation and/or execution of kernels by the
        kernel tuner. It also maintains a reference to the most recently compiled
        source module for copying data to constant memory before kernel launch.

        :param device: Number of HIP device to use for this context
        :type device: int

        :param iterations: Number of iterations used while benchmarking a kernel, 7 by default.
        :type iterations: int

        :param compiler_options: Extra options appended to the compiler flags in
            :meth:`compile`. ``None`` is treated as no extra options.
        :type compiler_options: list(string) or None

        :param observers: Observers to register with this device. A
            HipRuntimeObserver is always appended.
        :type observers: list or None

        :raises ImportError: If HIP Python could not be imported.
        """
        # The module-level import of HIP Python is guarded so that this module
        # can be imported (documentation, tests) without HIP Python installed;
        # the failure is reported here, when a backend is actually requested.
        if not hip or not hiprtc:
            raise ImportError(
                "Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python."
            )

        logging.debug("HipFunction instantiated")

        # Get device properties
        props = hip.hipDeviceProp_t()
        hip_check(hip.hipGetDeviceProperties(props, device))

        self.name = props.name.decode("utf-8")
        self.max_threads = props.maxThreadsPerBlock
        self.device = device
        self.compiler_options = compiler_options or []
        self.iterations = iterations

        env = dict()
        env["device_name"] = self.name
        env["iterations"] = self.iterations
        env["compiler_options"] = compiler_options
        self.env = env

        # Create stream and events
        self.stream = hip_check(hip.hipStreamCreate())
        self.start = hip_check(hip.hipEventCreate())
        self.end = hip_check(hip.hipEventCreate())

        # Default dynamically allocated shared memory size
        self.smem_size = 0

        self.current_module = None

        # Setup observers. When a non-empty list is passed in, that same list
        # object is extended with the runtime observer.
        self.observers = observers or []
        self.observers.append(HipRuntimeObserver(self))
        for obs in self.observers:
            obs.register_device(self)

    def ready_argument_list(self, arguments):
        """Ready argument list to be passed to the HIP function.

        :param arguments: List of arguments to be passed to the HIP function.
            The order should match the argument list on the HIP function.
            Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
        :type arguments: list(numpy objects)

        :returns: List of arguments to be passed to the HIP function.
        :rtype: list

        :raises ValueError: If an argument is neither a numpy array nor a numpy
            scalar, or is a numpy scalar of an unsupported dtype.
        """
        logging.debug("HipFunction ready_argument_list called")
        prepared_args = []

        for arg in arguments:
            # Handle numpy arrays
            if isinstance(arg, np.ndarray):
                # Allocate device memory
                device_ptr = hip_check(hip.hipMalloc(arg.nbytes))

                # Copy data to device using hipMemcpy
                hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

                prepared_args.append(device_ptr)

            # Handle numpy scalar types
            elif isinstance(arg, np.generic):
                # The dtype is only looked up once the argument is known to be
                # a numpy scalar, so objects without a ``dtype`` attribute end
                # up in the ValueError branch below instead of failing with an
                # AttributeError.
                dtype_str = str(arg.dtype)

                # Convert numpy scalar to corresponding ctypes
                if dtype_str in dtype_map:
                    ctype_arg = dtype_map[dtype_str](arg)
                    prepared_args.append(ctype_arg)
                # 16-bit float is not supported, view it as uint16
                elif dtype_str in ("float16", "bfloat16"):
                    ctype_arg = ctypes.c_uint16(arg.view(np.uint16))
                    prepared_args.append(ctype_arg)
                else:
                    raise ValueError(f"Invalid argument type {dtype_str}: {arg}")

            else:
                raise ValueError(f"Invalid argument type {type(arg)}: {arg}")

        return prepared_args

    def compile(self, kernel_instance):
        """Call the HIP compiler to compile the kernel, return the function.

        :param kernel_instance: An object representing the specific instance of the tunable kernel
            in the parameter space.
        :type kernel_instance: kernel_tuner.core.KernelInstance

        :returns: A HIP kernel function that can be called.
        :rtype: hipFunction_t

        :raises RuntimeError: With the compilation log as message if the
            compiler reports an error.
        """
        logging.debug("HipFunction compile called")

        # Format kernel string
        kernel_string = kernel_instance.kernel_string
        kernel_name = kernel_instance.name
        if 'extern "C"' not in kernel_string:
            kernel_string = 'extern "C" {\n' + kernel_string + "\n}"

        # Create program
        prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), kernel_name.encode(), 0, [], []))

        try:
            # Get properties of the device this object was created for (not
            # unconditionally device 0), so the architecture flag matches.
            props = hip.hipDeviceProp_t()
            hip_check(hip.hipGetDeviceProperties(props, self.device))

            # Setup compilation options
            arch = props.gcnArchName
            cflags = [b"--offload-arch=" + arch]
            cflags.extend([opt.encode() if isinstance(opt, str) else opt for opt in self.compiler_options])

            # Compile program; the status is inspected by hand so that the
            # compilation log can be attached to the error.
            (err,) = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags)
            if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
                # Get compilation log if there's an error
                log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog))
                log = bytearray(log_size)
                hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
                raise RuntimeError(log.decode())

            # Get compiled code
            code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
            code = bytearray(code_size)
            hip_check(hiprtc.hiprtcGetCode(prog, code))

            # Load module and get function
            module = hip_check(hip.hipModuleLoadData(code))
            self.current_module = module
            kernel = hip_check(hip.hipModuleGetFunction(module, kernel_name.encode()))

        finally:
            # The compiled code has been copied into ``code`` by this point, so
            # the program object is released on success as well as on failure.
            hip_check(hiprtc.hiprtcDestroyProgram(prog.createRef()))

        return kernel

    def start_event(self):
        """Records the event that marks the start of a measurement."""
        logging.debug("HipFunction start_event called")

        hip_check(hip.hipEventRecord(self.start, self.stream))

    def stop_event(self):
        """Records the event that marks the end of a measurement."""
        logging.debug("HipFunction stop_event called")

        hip_check(hip.hipEventRecord(self.end, self.stream))

    def kernel_finished(self):
        """Returns True if the kernel has finished, False otherwise.

        :raises RuntimeError: Via hip_check, if querying the end event reports
            an error other than "not ready".
        """
        logging.debug("HipFunction kernel_finished called")

        # The query result is a tuple whose first element is the status code.
        # "Not ready" is an expected answer here, so the status is inspected
        # directly instead of being passed straight to hip_check.
        status = hip.hipEventQuery(self.end)
        if status[0] == hip.hipError_t.hipSuccess:
            return True
        elif status[0] == hip.hipError_t.hipErrorNotReady:
            return False
        else:
            hip_check(status)

    def synchronize(self):
        """Halts execution until device has finished its tasks."""
        logging.debug("HipFunction synchronize called")

        hip_check(hip.hipDeviceSynchronize())

    def run_kernel(self, func, gpu_args, threads, grid, stream=None):
        """Runs the HIP kernel passed as 'func'.

        :param func: A HIP kernel compiled for this specific kernel configuration
        :type func: hipFunction_t

        :param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray
            objects or ctypes values
        :type gpu_args: list

        :param threads: A tuple listing the number of threads in each dimension of
            the thread block
        :type threads: tuple(int, int, int)

        :param grid: A tuple listing the number of thread blocks in each dimension
            of the grid
        :type grid: tuple(int, int, int)

        :param stream: Stream to launch the kernel on. Defaults to the stream
            created when this object was instantiated.
        """
        logging.debug("HipFunction run_kernel called")

        if stream is None:
            stream = self.stream

        # Create dim3 objects for grid and block dimensions
        grid_dim = hip.dim3(x=grid[0], y=grid[1], z=grid[2])
        block_dim = hip.dim3(x=threads[0], y=threads[1], z=threads[2])

        # Launch kernel with the arguments
        hip_check(
            hip.hipModuleLaunchKernel(
                func,
                *grid_dim,
                *block_dim,
                sharedMemBytes=self.smem_size,
                stream=stream,
                kernelParams=None,
                extra=tuple(gpu_args),
            )
        )

    def memset(self, allocation, value, size):
        """Set the memory in allocation to the value in value.

        :param allocation: A GPU memory allocation (DeviceArray)
        :type allocation: DeviceArray or int

        :param value: The value to set the memory to
        :type value: int (8-bit unsigned)

        :param size: The size of the allocation unit in bytes
        :type size: int
        """
        logging.debug("HipFunction memset called")

        hip_check(hip.hipMemset(allocation, value, size))

    def memcpy_dtoh(self, dest, src):
        """Perform a device to host memory copy.

        The number of bytes copied is ``dest.nbytes``.

        :param dest: A numpy array in host memory to store the data
        :type dest: numpy.ndarray

        :param src: A GPU memory allocation unit
        :type src: DeviceArray or int
        """
        logging.debug("HipFunction memcpy_dtoh called")

        hip_check(hip.hipMemcpy(dest, src, dest.nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))

    def memcpy_htod(self, dest, src):
        """Perform a host to device memory copy.

        The number of bytes copied is ``src.nbytes``.

        :param dest: A GPU memory allocation unit
        :type dest: DeviceArray or int

        :param src: A numpy array in host memory to copy from
        :type src: numpy.ndarray
        """
        logging.debug("HipFunction memcpy_htod called")

        hip_check(hip.hipMemcpy(dest, src, src.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

    def copy_constant_memory_args(self, cmem_args):
        """Adds constant memory arguments to the most recently compiled module.

        :param cmem_args: A dictionary containing the data to be passed to the
            device constant memory. The format to be used is as follows: A
            string key is used to name the constant memory symbol to which the
            value needs to be copied. Similar to regular arguments, these need
            to be numpy objects, such as numpy.ndarray or numpy.int32, and so on.
        :type cmem_args: dict(string: numpy.ndarray, ...)
        """
        logging.debug("HipFunction copy_constant_memory_args called")

        # Iterate over dictionary
        for symbol_name, data in cmem_args.items():
            # Get symbol pointer and size using hipModuleGetGlobal
            dptr, _ = hip_check(hip.hipModuleGetGlobal(self.current_module, symbol_name.encode()))

            # Copy data to the global memory location
            hip_check(hip.hipMemcpy(dptr, data, data.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

    def copy_shared_memory_args(self, smem_args):
        """Add shared memory arguments to the kernel.

        Stores ``smem_args["size"]``; it is passed as the dynamic shared memory
        size on subsequent kernel launches.

        :param smem_args: Dictionary with a ``"size"`` entry.
        :type smem_args: dict
        """
        logging.debug("HipFunction copy_shared_memory_args called")

        self.smem_size = smem_args["size"]

    def copy_texture_memory_args(self, texmem_args):
        """Copy texture memory arguments. Not yet implemented.

        :raises NotImplementedError: Always.
        """
        logging.debug("HipFunction copy_texture_memory_args called")

        raise NotImplementedError("HIP backend does not support texture memory")

    # NOTE(review): presumably the units of the quantities reported by this
    # backend and its observers -- confirm against the callers.
    units = {"time": "ms", "power": "s,mW", "energy": "J"}

kernel_tuner/backends/hip/util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
try:
    from hip import hip, hiprtc
except (ImportError, RuntimeError):
    # HIP Python is optional; keep this module importable without it and
    # leave both names defined so later checks can test them against None.
    hip = None
    hiprtc = None


def _hip_text(value):
    """Turn the return value of a HIP error-text query into a ``str``.

    Accepts either the bare text or a tuple/list whose last element is the
    text, and decodes ``bytes`` so messages do not show up as ``b'...'``.
    """
    if isinstance(value, (tuple, list)):
        value = value[-1]
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).decode("utf-8", errors="replace")
    return str(value)


def hip_check(call_result):
    """Check the return value of a HIP Python call and unwrap its results.

    :param call_result: The tuple returned by a HIP Python call: the status
        code first, followed by zero or more result values.
    :type call_result: tuple

    :returns: The single result value if there is exactly one, otherwise a
        tuple with the remaining values (empty if there are none).

    :raises RuntimeError: If the status code is a hipError_t other than
        hipSuccess, or a hiprtcResult other than HIPRTC_SUCCESS.
    """
    err = call_result[0]
    result = call_result[1:]
    if len(result) == 1:
        result = result[0]

    # Without HIP Python there are no status types to compare against; the
    # results are passed through unchanged in that case.
    if hip is not None and isinstance(err, hip.hipError_t):
        if err != hip.hipError_t.hipSuccess:
            error_name = _hip_text(hip.hipGetErrorName(err))
            error_str = _hip_text(hip.hipGetErrorString(err))
            raise RuntimeError(f"{error_name}: {error_str}")
    elif hiprtc is not None and isinstance(err, hiprtc.hiprtcResult):
        # Runtime-compilation calls report their status with a separate type,
        # which must be checked too or their failures go unnoticed.
        if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS:
            raise RuntimeError(str(err))
    return result
18+

kernel_tuner/observers/hip.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from kernel_tuner.observers.observer import BenchmarkObserver
4+
from kernel_tuner.backends.hip.util import hip_check
45

56
try:
67
from hip import hip, hiprtc
@@ -9,19 +10,6 @@
910
hiprtc = None
1011

1112

12-
def hip_check(call_result):
13-
"""helper function to check return values of hip calls"""
14-
err = call_result[0]
15-
result = call_result[1:]
16-
if len(result) == 1:
17-
result = result[0]
18-
if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
19-
_, error_name = hip.hipGetErrorName(err)
20-
_, error_str = hip.hipGetErrorString(err)
21-
raise RuntimeError(f"{error_name}: {error_str}")
22-
return result
23-
24-
2513
class HipRuntimeObserver(BenchmarkObserver):
2614
"""Observer that measures time using CUDA events during benchmarking."""
2715

0 commit comments

Comments
 (0)