21
21
"""
22
22
23
23
from abc import ABC , abstractmethod
24
+ from pytools import memoize_method
24
25
25
26
import numpy as np
26
27
import loopy as lp
28
+ from loopy .kernel .data import LocalInameTag
29
+ import pymbolic .primitives as prim
27
30
28
31
from sumpy .tools import KernelCacheMixin , gather_loopy_arguments
29
32
from loopy .version import MOST_RECENT_LANGUAGE_VERSION
@@ -70,7 +73,7 @@ def __init__(self, ctx, expansion, kernels,
70
73
71
74
self .ctx = ctx
72
75
self .expansion = expansion
73
- self .kernels = kernels
76
+ self .kernels = tuple ( kernels )
74
77
self .name = name or self .default_name
75
78
self .device = device
76
79
@@ -81,15 +84,18 @@ def __init__(self, ctx, expansion, kernels,
81
84
def default_name (self ):
82
85
pass
83
86
87
+ @memoize_method
88
+ def get_cached_loopy_knl_and_optimizations (self ):
89
+ return self .expansion .get_loopy_evaluator (self .kernels )
90
+
84
91
def get_cache_key (self ):
85
92
return (type (self ).__name__ , self .expansion , tuple (self .kernels ))
86
93
87
94
def add_loopy_eval_callable (
88
95
self , loopy_knl : lp .TranslationUnit ) -> lp .TranslationUnit :
89
- inner_knl = self .expansion . get_loopy_evaluator ( self . kernels )
96
+ inner_knl , _ = self .get_cached_loopy_knl_and_optimizations ( )
90
97
loopy_knl = lp .merge ([loopy_knl , inner_knl ])
91
98
loopy_knl = lp .inline_callable_kernel (loopy_knl , "e2p" )
92
- loopy_knl = lp .remove_unused_inames (loopy_knl )
93
99
for kernel in self .kernels :
94
100
loopy_knl = kernel .prepare_loopy_kernel (loopy_knl )
95
101
loopy_knl = lp .tag_array_axes (loopy_knl , "targets" , "sep,C" )
@@ -117,33 +123,41 @@ class E2PFromSingleBox(E2PBase):
117
123
def default_name (self ):
118
124
return "e2p_from_single_box"
119
125
120
- def get_kernel (self ):
126
+ def get_kernel (self , max_ntargets_in_one_box ):
121
127
ncoeffs = len (self .expansion )
122
128
loopy_args = self .get_loopy_args ()
129
+ max_work_items = min (32 , max (ncoeffs , max_ntargets_in_one_box ))
123
130
124
131
loopy_knl = lp .make_kernel (
125
132
[
126
133
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}" ,
127
- "{[itgt,idim]: itgt_start<=itgt<itgt_end and 0<=idim<dim}" ,
134
+ "{[idim]: 0<=idim<dim}" ,
135
+ "{[itgt_offset]: 0<=itgt_offset<max_ntargets_in_one_box}" ,
128
136
"{[icoeff]: 0<=icoeff<ncoeffs}" ,
129
137
"{[iknl]: 0<=iknl<nresults}" ,
138
+ "{[dummy]: 0<=dummy<max_work_items}" ,
130
139
],
131
140
self .get_kernel_scaling_assignment ()
132
141
+ ["""
133
142
for itgt_box
134
- <> tgt_ibox = target_boxes[itgt_box]
135
- <> itgt_start = box_target_starts[tgt_ibox]
136
- <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
143
+ <> tgt_ibox = target_boxes[itgt_box] {id=fetch_init0}
144
+ <> itgt_start = box_target_starts[tgt_ibox] {id=fetch_init1}
145
+ <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox] \
146
+ {id=fetch_init2}
137
147
138
148
<> center[idim] = centers[idim, tgt_ibox] {id=fetch_center}
139
149
140
150
<> coeffs[icoeff] = \
141
151
src_expansions[tgt_ibox - src_base_ibox, icoeff] \
142
152
{id=fetch_coeffs}
143
153
144
- for itgt
145
- <> tgt[idim] = targets[idim, itgt] {id=fetch_tgt,dup=idim}
146
- <> result_temp[iknl] = 0 {id=init_result,dup=iknl}
154
+ for itgt_offset
155
+ <> itgt = itgt_start + itgt_offset
156
+ <> run_itgt = itgt<itgt_end
157
+ <> tgt[idim] = targets[idim, itgt] {id=fetch_tgt, \
158
+ dup=idim,if=run_itgt}
159
+ <> result_temp[iknl] = 0 {id=init_result,dup=iknl, \
160
+ if=run_itgt}
147
161
[iknl]: result_temp[iknl] = e2p(
148
162
[iknl]: result_temp[iknl],
149
163
[icoeff]: coeffs[icoeff],
@@ -155,9 +169,9 @@ def get_kernel(self):
155
169
targets,
156
170
""" + "," .join (arg .name for arg in loopy_args ) + """
157
171
) {dep=fetch_coeffs:fetch_center:init_result:fetch_tgt,\
158
- id=update_result}
172
+ id=update_result,if=run_itgt }
159
173
result[iknl, itgt] = result_temp[iknl] * kernel_scaling \
160
- {id=write_result,dep=update_result}
174
+ {id=write_result,dep=update_result,if=run_itgt }
161
175
end
162
176
end
163
177
""" ],
@@ -182,7 +196,9 @@ def get_kernel(self):
182
196
silenced_warnings = "write_race(*_result)" ,
183
197
default_offset = lp .auto ,
184
198
fixed_parameters = {"dim" : self .dim , "nresults" : len (self .kernels ),
185
- "ncoeffs" : ncoeffs },
199
+ "ncoeffs" : ncoeffs ,
200
+ "max_work_items" : max_work_items ,
201
+ "max_ntargets_in_one_box" : max_ntargets_in_one_box },
186
202
lang_version = MOST_RECENT_LANGUAGE_VERSION )
187
203
188
204
loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
@@ -191,13 +207,39 @@ def get_kernel(self):
191
207
192
208
return loopy_knl
193
209
194
- def get_optimized_kernel (self ):
195
- # FIXME
196
- knl = self .get_kernel ()
210
+ def get_optimized_kernel (self , max_ntargets_in_one_box ):
211
+ inner_knl , optimizations = self . get_cached_loopy_knl_and_optimizations ()
212
+ knl = self .get_kernel (max_ntargets_in_one_box = max_ntargets_in_one_box )
197
213
knl = lp .tag_inames (knl , {"itgt_box" : "g.0" })
214
+ knl = lp .split_iname (knl , "itgt_offset" , 32 , inner_tag = "l.0" )
215
+ knl = lp .split_iname (knl , "icoeff" , 32 , inner_tag = "l.0" )
216
+ knl = lp .add_inames_to_insn (knl , "dummy" ,
217
+ "id:fetch_init* or id:fetch_center or id:kernel_scaling" )
198
218
knl = lp .add_inames_to_insn (knl , "itgt_box" , "id:kernel_scaling" )
219
+ knl = lp .tag_inames (knl , {"dummy" : "l.0" })
220
+ knl = lp .set_temporary_address_space (knl , "coeffs" , lp .AddressSpace .LOCAL )
199
221
knl = lp .set_options (knl ,
200
- enforce_variable_access_ordered = "no_check" )
222
+ enforce_variable_access_ordered = "no_check" , write_code = False )
223
+
224
+ for transform in optimizations :
225
+ knl = transform (knl )
226
+
227
+ # If there are inames tagged as local in the inner kernel
228
+ # we need to remove the iname itgt_offset_inner from instructions
229
+ # within those inames and also remove the predicate run_itgt
230
+ # which depends on itgt_offset_inner
231
+ tagged_inames = [iname .name for iname in
232
+ knl .default_entrypoint .inames .values () if
233
+ iname .name .startswith ("e2p_" ) and any (
234
+ isinstance (tag , LocalInameTag ) for tag in iname .tags )]
235
+ if tagged_inames :
236
+ insn_ids = [insn .id for insn in knl .default_entrypoint .instructions
237
+ if any (iname in tagged_inames for iname in insn .within_inames )]
238
+ match = " or " .join (f"id:{ insn_id } " for insn_id in insn_ids )
239
+ knl = lp .remove_inames_from_insn (knl ,
240
+ frozenset (["itgt_offset_inner" ]), match )
241
+ knl = lp .remove_predicates_from_insn (knl ,
242
+ frozenset ([prim .Variable ("run_itgt" )]), match )
201
243
202
244
return knl
203
245
@@ -210,7 +252,9 @@ def __call__(self, queue, **kwargs):
210
252
:arg centers:
211
253
:arg targets:
212
254
"""
213
- knl = self .get_cached_optimized_kernel ()
255
+ max_ntargets_in_one_box = kwargs .pop ("max_ntargets_in_one_box" )
256
+ knl = self .get_cached_optimized_kernel (
257
+ max_ntargets_in_one_box = max_ntargets_in_one_box )
214
258
215
259
centers = kwargs .pop ("centers" )
216
260
# "1" may be passed for rscale, which won't have its type
@@ -229,42 +273,49 @@ class E2PFromCSR(E2PBase):
229
273
def default_name (self ):
230
274
return "e2p_from_csr"
231
275
232
- def get_kernel (self ):
276
+ def get_kernel (self , max_ntargets_in_one_box ):
233
277
ncoeffs = len (self .expansion )
234
278
loopy_args = self .get_loopy_args ()
279
+ max_work_items = min (32 , max (ncoeffs , max_ntargets_in_one_box ))
235
280
236
281
loopy_knl = lp .make_kernel (
237
282
[
238
283
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}" ,
239
- "{[itgt ]: itgt_start<=itgt<itgt_end }" ,
284
+ "{[itgt_offset ]: 0<=itgt_offset<max_ntargets_in_one_box }" ,
240
285
"{[isrc_box]: isrc_box_start<=isrc_box<isrc_box_end }" ,
241
286
"{[idim]: 0<=idim<dim}" ,
242
287
"{[icoeff]: 0<=icoeff<ncoeffs}" ,
243
288
"{[iknl]: 0<=iknl<nresults}" ,
289
+ "{[dummy]: 0<=dummy<max_work_items}" ,
244
290
],
245
291
self .get_kernel_scaling_assignment ()
246
292
+ ["""
247
293
for itgt_box
248
- <> tgt_ibox = target_boxes[itgt_box]
249
- <> itgt_start = box_target_starts[tgt_ibox]
250
- <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
251
-
252
- for itgt
253
- <> tgt[idim] = targets[idim,itgt] {id=fetch_tgt,dup=idim}
254
-
255
- <> isrc_box_start = source_box_starts[itgt_box]
256
- <> isrc_box_end = source_box_starts[itgt_box+1]
257
-
258
- <> result_temp[iknl] = 0 {id=init_result,dup=iknl}
259
- for isrc_box
260
- <> src_ibox = source_box_lists[isrc_box]
261
- <> coeffs[icoeff] = \
294
+ <> tgt_ibox = target_boxes[itgt_box] {id=init_box0}
295
+ <> itgt_start = box_target_starts[tgt_ibox] {id=init_box1}
296
+ <> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox] \
297
+ {id=init_box2}
298
+ <> isrc_box_start = source_box_starts[itgt_box] {id=init_box3}
299
+ <> isrc_box_end = source_box_starts[itgt_box+1] {id=init_box4}
300
+
301
+ <> result_temp[itgt_offset, iknl] = 0 \
302
+ {id=init_result,dup=iknl}
303
+ for isrc_box
304
+ <> src_ibox = source_box_lists[isrc_box] {id=fetch_src_box}
305
+ <> coeffs[icoeff] = \
262
306
src_expansions[src_ibox - src_base_ibox, icoeff] \
263
- {id=fetch_coeffs,dup=icoeff }
264
- <> center[idim] = centers[idim, src_ibox] \
307
+ {id=fetch_coeffs}
308
+ <> center[idim] = centers[idim, src_ibox] \
265
309
{dup=idim,id=fetch_center}
266
- [iknl]: result_temp[iknl] = e2p(
267
- [iknl]: result_temp[iknl],
310
+
311
+ for itgt_offset
312
+ <> itgt = itgt_start + itgt_offset
313
+ <> run_itgt = itgt<itgt_end
314
+ <> tgt[idim] = targets[idim,itgt] \
315
+ {id=fetch_tgt,dup=idim,if=run_itgt}
316
+
317
+ [iknl]: result_temp[itgt_offset, iknl] = e2p(
318
+ [iknl]: result_temp[itgt_offset, iknl],
268
319
[icoeff]: coeffs[icoeff],
269
320
[idim]: center[idim],
270
321
[idim]: tgt[idim],
@@ -274,11 +325,18 @@ def get_kernel(self):
274
325
targets,
275
326
""" + "," .join (arg .name for arg in loopy_args ) + """
276
327
) {id=update_result, \
277
- dep=fetch_coeffs:fetch_center:fetch_tgt:init_result}
328
+ dep=fetch_coeffs:fetch_center:fetch_tgt:init_result, \
329
+ if=run_itgt}
278
330
end
279
- result[iknl, itgt] = result[iknl, itgt] + result_temp[iknl] \
280
- * kernel_scaling \
281
- {dep=update_result:init_result,id=write_result,dup=iknl}
331
+ end
332
+ for itgt_offset
333
+ <> itgt2 = itgt_start + itgt_offset {id=init_itgt_for_write}
334
+ <> run_itgt2 = itgt_start + itgt_offset < itgt_end \
335
+ {id=init_cond_for_write}
336
+ result[iknl, itgt2] = result[iknl, itgt2] + result_temp[ \
337
+ itgt_offset, iknl] * kernel_scaling \
338
+ {dep=update_result:init_result,id=write_result, \
339
+ dup=iknl,if=run_itgt2}
282
340
end
283
341
end
284
342
""" ],
@@ -306,28 +364,48 @@ def get_kernel(self):
306
364
fixed_parameters = {
307
365
"ncoeffs" : ncoeffs ,
308
366
"dim" : self .dim ,
367
+ "max_work_items" : max_work_items ,
368
+ "max_ntargets_in_one_box" : max_ntargets_in_one_box ,
309
369
"nresults" : len (self .kernels )},
310
370
lang_version = MOST_RECENT_LANGUAGE_VERSION )
311
371
312
372
loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
313
373
loopy_knl = lp .tag_inames (loopy_knl , "iknl*:unr" )
314
- loopy_knl = lp .prioritize_loops (loopy_knl , "itgt_box,itgt, isrc_box" )
374
+ loopy_knl = lp .prioritize_loops (loopy_knl , "itgt_box,isrc_box,itgt_offset " )
315
375
loopy_knl = self .add_loopy_eval_callable (loopy_knl )
316
376
loopy_knl = lp .tag_array_axes (loopy_knl , "targets" , "sep,C" )
317
377
318
378
return loopy_knl
319
379
320
- def get_optimized_kernel (self ):
321
- # FIXME
322
- knl = self .get_kernel ()
323
- knl = lp .tag_inames (knl , {"itgt_box" : "g.0" })
380
+ def get_optimized_kernel (self , max_ntargets_in_one_box ):
381
+ _ , optimizations = self .get_cached_loopy_knl_and_optimizations ()
382
+ knl = self .get_kernel (max_ntargets_in_one_box = max_ntargets_in_one_box )
383
+ knl = lp .tag_inames (knl , {"itgt_box" : "g.0" , "dummy" : "l.0" })
384
+ knl = lp .unprivatize_temporaries_with_inames (knl ,
385
+ "itgt_offset" , "result_temp" )
386
+ knl = lp .split_iname (knl , "itgt_offset" , 32 , inner_tag = "l.0" )
387
+ knl = lp .split_iname (knl , "icoeff" , 32 , inner_tag = "l.0" )
388
+ knl = lp .privatize_temporaries_with_inames (knl ,
389
+ "itgt_offset_outer" , "result_temp" )
390
+ knl = lp .duplicate_inames (knl , "itgt_offset_outer" , "id:init_result" )
391
+ knl = lp .duplicate_inames (knl , "itgt_offset_outer" ,
392
+ "id:write_result or id:init_itgt_for_write or id:init_cond_for_write" )
393
+ knl = lp .add_inames_to_insn (knl , "dummy" ,
394
+ "id:init_box* or id:fetch_src_box or id:fetch_center "
395
+ "or id:kernel_scaling" )
324
396
knl = lp .add_inames_to_insn (knl , "itgt_box" , "id:kernel_scaling" )
397
+ knl = lp .add_inames_to_insn (knl , "itgt_offset_inner" , "id:fetch_init*" )
398
+ knl = lp .set_temporary_address_space (knl , "coeffs" , lp .AddressSpace .LOCAL )
325
399
knl = lp .set_options (knl ,
326
- enforce_variable_access_ordered = "no_check" )
400
+ enforce_variable_access_ordered = "no_check" , write_code = False )
401
+ for transform in optimizations :
402
+ knl = transform (knl )
327
403
return knl
328
404
329
405
def __call__ (self , queue , ** kwargs ):
330
- knl = self .get_cached_optimized_kernel ()
406
+ max_ntargets_in_one_box = kwargs .pop ("max_ntargets_in_one_box" )
407
+ knl = self .get_cached_optimized_kernel (
408
+ max_ntargets_in_one_box = max_ntargets_in_one_box )
331
409
332
410
centers = kwargs .pop ("centers" )
333
411
# "1" may be passed for rscale, which won't have its type
0 commit comments