Skip to content

Commit 8fb43a9

Browse files
Paul KienzleDrPaulSharp
authored andcommitted
Fix RepeatedKernelRetrieval error in OpenCL
1 parent 72e7de6 commit 8fb43a9

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed

sasmodels/kernelcl.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ def has_type(self, dtype):
289289
"""
290290
return self.context.get(dtype, None) is not None
291291

292-
def compile_program(self, name, source, dtype, fast, timestamp):
293-
# type: (str, str, np.dtype, bool, float) -> cl.Program
292+
def compile_program(self, name, source, dtype, fast, timestamp, kernel_names):
293+
# type: (str, str, np.dtype, bool, float, list[str]) -> cl.Program
294294
"""
295295
Compile the program for the device in the given context.
296296
"""
@@ -299,17 +299,18 @@ def compile_program(self, name, source, dtype, fast, timestamp):
299299
tag = generate.tag_source(source)
300300
key = "%s-%s-%s%s"%(name, dtype, tag, ("-fast" if fast else ""))
301301
# Check timestamp on program.
302-
program, program_timestamp = self.compiled.get(key, (None, np.inf))
303-
if program_timestamp < timestamp:
302+
program, compile_timestamp, kernels = self.compiled.get(key, (None, np.inf, []))
303+
if compile_timestamp < timestamp:
304304
del self.compiled[key]
305305
if key not in self.compiled:
306306
context = self.context[dtype]
307307
logging.info("building %s for OpenCL %s", key,
308308
context.devices[0].name.strip())
309309
program = compile_model(self.context[dtype],
310310
str(source), dtype, fast)
311-
self.compiled[key] = (program, timestamp)
312-
return program
311+
kernels = [getattr(program, k) for k in kernel_names]
312+
self.compiled[key] = (program, timestamp, kernels)
313+
return kernels
313314

314315

315316
def _create_some_context():
@@ -457,20 +458,18 @@ def get_function(self, name):
457458
def _prepare_program(self):
458459
# type: (str) -> None
459460
env = environment()
461+
variants = ['Iq', 'Iqxy', 'Imagnetic']
462+
kernel_names = [generate.kernel_name(self.info, k) for k in variants]
460463
timestamp = generate.ocl_timestamp(self.info)
461-
program = env.compile_program(
464+
kernels = env.compile_program(
462465
self.info.name,
463466
self.source['opencl'],
464467
self.dtype,
465468
self.fast,
466-
timestamp)
467-
variants = ['Iq', 'Iqxy', 'Imagnetic']
468-
names = [generate.kernel_name(self.info, k) for k in variants]
469-
functions = [getattr(program, k) for k in names]
470-
self._kernels = {k: v for k, v in zip(variants, functions)}
471-
# Keep a handle to program so GC doesn't collect.
472-
self._program = program
473-
469+
timestamp,
470+
kernel_names,
471+
)
472+
self._kernels = {k: v for k, v in zip(variants, kernels)}
474473

475474
# TODO: Check that we don't need a destructor for buffers which go out of scope.
476475
class GpuInput:

sasmodels/kernelcuda.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,8 @@ def has_type(self, dtype):
274274
"""
275275
return has_type(dtype)
276276

277-
def compile_program(self, name, source, dtype, fast, timestamp):
278-
# type: (str, str, np.dtype, bool, float) -> SourceModule
277+
def compile_program(self, name, source, dtype, fast, timestamp, kernel_names):
278+
# type: (str, str, np.dtype, bool, float, list[str]) -> SourceModule
279279
"""
280280
Compile the program for the device in the given context.
281281
"""
@@ -284,14 +284,15 @@ def compile_program(self, name, source, dtype, fast, timestamp):
284284
tag = generate.tag_source(source)
285285
key = "%s-%s-%s%s"%(name, dtype, tag, ("-fast" if fast else ""))
286286
# Check timestamp on program.
287-
program, program_timestamp = self.compiled.get(key, (None, np.inf))
288-
if program_timestamp < timestamp:
287+
program, compile_timestamp, kernels = self.compiled.get(key, (None, np.inf, []))
288+
if compile_timestamp < timestamp:
289289
del self.compiled[key]
290290
if key not in self.compiled:
291291
logging.info("building %s for CUDA", key)
292292
program = compile_model(str(source), dtype, fast)
293-
self.compiled[key] = (program, timestamp)
294-
return program
293+
kernels = [getattr(program, k) for k in kernel_names]
294+
self.compiled[key] = (program, timestamp, kernels)
295+
return kernels
295296

296297

297298
class GpuModel(KernelModel):
@@ -349,19 +350,17 @@ def get_function(self, name):
349350
def _prepare_program(self):
350351
# type: (str) -> None
351352
env = environment()
353+
variants = ['Iq', 'Iqxy', 'Imagnetic']
354+
kernel_names = [generate.kernel_name(self.info, k) for k in variants]
352355
timestamp = generate.ocl_timestamp(self.info)
353-
program = env.compile_program(
356+
kernels = env.compile_program(
354357
self.info.name,
355358
self.source['opencl'],
356359
self.dtype,
357360
self.fast,
358-
timestamp)
359-
variants = ['Iq', 'Iqxy', 'Imagnetic']
360-
names = [generate.kernel_name(self.info, k) for k in variants]
361-
functions = [program.get_function(k) for k in names]
362-
self._kernels = {k: v for k, v in zip(variants, functions)}
363-
# Keep a handle to program so GC doesn't collect.
364-
self._program = program
361+
timestamp,
362+
kernel_names)
363+
self._kernels = {k: v for k, v in zip(variants, kernels)}
365364

366365

367366
# TODO: Check that we don't need a destructor for buffers which go out of scope.

0 commit comments

Comments
 (0)