@@ -289,8 +289,8 @@ def has_type(self, dtype):
289289 """
290290 return self .context .get (dtype , None ) is not None
291291
292- def compile_program (self , name , source , dtype , fast , timestamp ):
293- # type: (str, str, np.dtype, bool, float) -> cl.Program
292+ def compile_program (self , name , source , dtype , fast , timestamp , kernel_names ):
293+ # type: (str, str, np.dtype, bool, float, list[str]) -> list[cl.Kernel]
294294 """
295295 Compile the program for the device in the given context and return its kernels.
296296 """
@@ -299,17 +299,18 @@ def compile_program(self, name, source, dtype, fast, timestamp):
299299 tag = generate .tag_source (source )
300300 key = "%s-%s-%s%s" % (name , dtype , tag , ("-fast" if fast else "" ))
301301 # Check timestamp on program.
302- program , program_timestamp = self .compiled .get (key , (None , np .inf ))
303- if program_timestamp < timestamp :
302+ program , compile_timestamp , kernels = self .compiled .get (key , (None , np .inf , [] ))
303+ if compile_timestamp < timestamp :
304304 del self .compiled [key ]
305305 if key not in self .compiled :
306306 context = self .context [dtype ]
307307 logging .info ("building %s for OpenCL %s" , key ,
308308 context .devices [0 ].name .strip ())
309309 program = compile_model (self .context [dtype ],
310310 str (source ), dtype , fast )
311- self .compiled [key ] = (program , timestamp )
312- return program
311+ kernels = [getattr (program , k ) for k in kernel_names ]
312+ self .compiled [key ] = (program , timestamp , kernels )
313+ return kernels
313314
314315
315316def _create_some_context ():
@@ -457,20 +458,18 @@ def get_function(self, name):
457458 def _prepare_program (self ):
458459 # type: () -> None
459460 env = environment ()
461+ variants = ['Iq' , 'Iqxy' , 'Imagnetic' ]
462+ kernel_names = [generate .kernel_name (self .info , k ) for k in variants ]
460463 timestamp = generate .ocl_timestamp (self .info )
461- program = env .compile_program (
464+ kernels = env .compile_program (
462465 self .info .name ,
463466 self .source ['opencl' ],
464467 self .dtype ,
465468 self .fast ,
466- timestamp )
467- variants = ['Iq' , 'Iqxy' , 'Imagnetic' ]
468- names = [generate .kernel_name (self .info , k ) for k in variants ]
469- functions = [getattr (program , k ) for k in names ]
470- self ._kernels = {k : v for k , v in zip (variants , functions )}
471- # Keep a handle to program so GC doesn't collect.
472- self ._program = program
473-
469+ timestamp ,
470+ kernel_names ,
471+ )
472+ self ._kernels = {k : v for k , v in zip (variants , kernels )}
474473
475474# TODO: Check that we don't need a destructor for buffers which go out of scope.
476475class GpuInput :
0 commit comments