Skip to content

Commit 154bf81

Browse files
committed
Optimize L2P for GPUs
1 parent 64a2196 commit 154bf81

File tree

7 files changed

+474
-51
lines changed

7 files changed

+474
-51
lines changed

sumpy/e2p.py

Lines changed: 128 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,12 @@
2121
"""
2222

2323
from abc import ABC, abstractmethod
24+
from pytools import memoize_method
2425

2526
import numpy as np
2627
import loopy as lp
28+
from loopy.kernel.data import LocalInameTag
29+
import pymbolic.primitives as prim
2730

2831
from sumpy.tools import KernelCacheMixin, gather_loopy_arguments
2932
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
@@ -70,7 +73,7 @@ def __init__(self, ctx, expansion, kernels,
7073

7174
self.ctx = ctx
7275
self.expansion = expansion
73-
self.kernels = kernels
76+
self.kernels = tuple(kernels)
7477
self.name = name or self.default_name
7578
self.device = device
7679

@@ -81,15 +84,18 @@ def __init__(self, ctx, expansion, kernels,
8184
def default_name(self):
8285
pass
8386

87+
@memoize_method
88+
def get_cached_loopy_knl_and_optimizations(self):
89+
return self.expansion.get_loopy_evaluator(self.kernels)
90+
8491
def get_cache_key(self):
8592
return (type(self).__name__, self.expansion, tuple(self.kernels))
8693

8794
def add_loopy_eval_callable(
8895
self, loopy_knl: lp.TranslationUnit) -> lp.TranslationUnit:
89-
inner_knl = self.expansion.get_loopy_evaluator(self.kernels)
96+
inner_knl, _ = self.get_cached_loopy_knl_and_optimizations()
9097
loopy_knl = lp.merge([loopy_knl, inner_knl])
9198
loopy_knl = lp.inline_callable_kernel(loopy_knl, "e2p")
92-
loopy_knl = lp.remove_unused_inames(loopy_knl)
9399
for kernel in self.kernels:
94100
loopy_knl = kernel.prepare_loopy_kernel(loopy_knl)
95101
loopy_knl = lp.tag_array_axes(loopy_knl, "targets", "sep,C")
@@ -117,33 +123,41 @@ class E2PFromSingleBox(E2PBase):
117123
def default_name(self):
118124
return "e2p_from_single_box"
119125

120-
def get_kernel(self):
126+
def get_kernel(self, max_ntargets_in_one_box):
121127
ncoeffs = len(self.expansion)
122128
loopy_args = self.get_loopy_args()
129+
max_work_items = min(32, max(ncoeffs, max_ntargets_in_one_box))
123130

124131
loopy_knl = lp.make_kernel(
125132
[
126133
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
127-
"{[itgt,idim]: itgt_start<=itgt<itgt_end and 0<=idim<dim}",
134+
"{[idim]: 0<=idim<dim}",
135+
"{[itgt_offset]: 0<=itgt_offset<max_ntargets_in_one_box}",
128136
"{[icoeff]: 0<=icoeff<ncoeffs}",
129137
"{[iknl]: 0<=iknl<nresults}",
138+
"{[dummy]: 0<=dummy<max_work_items}",
130139
],
131140
self.get_kernel_scaling_assignment()
132141
+ ["""
133142
for itgt_box
134-
<> tgt_ibox = target_boxes[itgt_box]
135-
<> itgt_start = box_target_starts[tgt_ibox]
136-
<> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
143+
<> tgt_ibox = target_boxes[itgt_box] {id=fetch_init0}
144+
<> itgt_start = box_target_starts[tgt_ibox] {id=fetch_init1}
145+
<> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox] \
146+
{id=fetch_init2}
137147
138148
<> center[idim] = centers[idim, tgt_ibox] {id=fetch_center}
139149
140150
<> coeffs[icoeff] = \
141151
src_expansions[tgt_ibox - src_base_ibox, icoeff] \
142152
{id=fetch_coeffs}
143153
144-
for itgt
145-
<> tgt[idim] = targets[idim, itgt] {id=fetch_tgt,dup=idim}
146-
<> result_temp[iknl] = 0 {id=init_result,dup=iknl}
154+
for itgt_offset
155+
<> itgt = itgt_start + itgt_offset
156+
<> run_itgt = itgt<itgt_end
157+
<> tgt[idim] = targets[idim, itgt] {id=fetch_tgt, \
158+
dup=idim,if=run_itgt}
159+
<> result_temp[iknl] = 0 {id=init_result,dup=iknl, \
160+
if=run_itgt}
147161
[iknl]: result_temp[iknl] = e2p(
148162
[iknl]: result_temp[iknl],
149163
[icoeff]: coeffs[icoeff],
@@ -155,9 +169,9 @@ def get_kernel(self):
155169
targets,
156170
""" + ",".join(arg.name for arg in loopy_args) + """
157171
) {dep=fetch_coeffs:fetch_center:init_result:fetch_tgt,\
158-
id=update_result}
172+
id=update_result,if=run_itgt}
159173
result[iknl, itgt] = result_temp[iknl] * kernel_scaling \
160-
{id=write_result,dep=update_result}
174+
{id=write_result,dep=update_result,if=run_itgt}
161175
end
162176
end
163177
"""],
@@ -182,7 +196,9 @@ def get_kernel(self):
182196
silenced_warnings="write_race(*_result)",
183197
default_offset=lp.auto,
184198
fixed_parameters={"dim": self.dim, "nresults": len(self.kernels),
185-
"ncoeffs": ncoeffs},
199+
"ncoeffs": ncoeffs,
200+
"max_work_items": max_work_items,
201+
"max_ntargets_in_one_box": max_ntargets_in_one_box},
186202
lang_version=MOST_RECENT_LANGUAGE_VERSION)
187203

188204
loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
@@ -191,13 +207,39 @@ def get_kernel(self):
191207

192208
return loopy_knl
193209

194-
def get_optimized_kernel(self):
195-
# FIXME
196-
knl = self.get_kernel()
210+
def get_optimized_kernel(self, max_ntargets_in_one_box):
211+
inner_knl, optimizations = self.get_cached_loopy_knl_and_optimizations()
212+
knl = self.get_kernel(max_ntargets_in_one_box=max_ntargets_in_one_box)
197213
knl = lp.tag_inames(knl, {"itgt_box": "g.0"})
214+
knl = lp.split_iname(knl, "itgt_offset", 32, inner_tag="l.0")
215+
knl = lp.split_iname(knl, "icoeff", 32, inner_tag="l.0")
216+
knl = lp.add_inames_to_insn(knl, "dummy",
217+
"id:fetch_init* or id:fetch_center or id:kernel_scaling")
198218
knl = lp.add_inames_to_insn(knl, "itgt_box", "id:kernel_scaling")
219+
knl = lp.tag_inames(knl, {"dummy": "l.0"})
220+
knl = lp.set_temporary_address_space(knl, "coeffs", lp.AddressSpace.LOCAL)
199221
knl = lp.set_options(knl,
200-
enforce_variable_access_ordered="no_check")
222+
enforce_variable_access_ordered="no_check", write_code=False)
223+
224+
for transform in optimizations:
225+
knl = transform(knl)
226+
227+
# If there are inames tagged as local in the inner kernel
228+
# we need to remove the iname itgt_offset_inner from instructions
229+
# within those inames and also remove the predicate run_itgt
230+
# which depends on itgt_offset_inner
231+
tagged_inames = [iname.name for iname in
232+
knl.default_entrypoint.inames.values() if
233+
iname.name.startswith("e2p_") and any(
234+
isinstance(tag, LocalInameTag) for tag in iname.tags)]
235+
if tagged_inames:
236+
insn_ids = [insn.id for insn in knl.default_entrypoint.instructions
237+
if any(iname in tagged_inames for iname in insn.within_inames)]
238+
match = " or ".join(f"id:{insn_id}" for insn_id in insn_ids)
239+
knl = lp.remove_inames_from_insn(knl,
240+
frozenset(["itgt_offset_inner"]), match)
241+
knl = lp.remove_predicates_from_insn(knl,
242+
frozenset([prim.Variable("run_itgt")]), match)
201243

202244
return knl
203245

@@ -210,7 +252,9 @@ def __call__(self, queue, **kwargs):
210252
:arg centers:
211253
:arg targets:
212254
"""
213-
knl = self.get_cached_optimized_kernel()
255+
max_ntargets_in_one_box = kwargs.pop("max_ntargets_in_one_box")
256+
knl = self.get_cached_optimized_kernel(
257+
max_ntargets_in_one_box=max_ntargets_in_one_box)
214258

215259
centers = kwargs.pop("centers")
216260
# "1" may be passed for rscale, which won't have its type
@@ -229,42 +273,49 @@ class E2PFromCSR(E2PBase):
229273
def default_name(self):
230274
return "e2p_from_csr"
231275

232-
def get_kernel(self):
276+
def get_kernel(self, max_ntargets_in_one_box):
233277
ncoeffs = len(self.expansion)
234278
loopy_args = self.get_loopy_args()
279+
max_work_items = min(32, max(ncoeffs, max_ntargets_in_one_box))
235280

236281
loopy_knl = lp.make_kernel(
237282
[
238283
"{[itgt_box]: 0<=itgt_box<ntgt_boxes}",
239-
"{[itgt]: itgt_start<=itgt<itgt_end}",
284+
"{[itgt_offset]: 0<=itgt_offset<max_ntargets_in_one_box}",
240285
"{[isrc_box]: isrc_box_start<=isrc_box<isrc_box_end }",
241286
"{[idim]: 0<=idim<dim}",
242287
"{[icoeff]: 0<=icoeff<ncoeffs}",
243288
"{[iknl]: 0<=iknl<nresults}",
289+
"{[dummy]: 0<=dummy<max_work_items}",
244290
],
245291
self.get_kernel_scaling_assignment()
246292
+ ["""
247293
for itgt_box
248-
<> tgt_ibox = target_boxes[itgt_box]
249-
<> itgt_start = box_target_starts[tgt_ibox]
250-
<> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox]
251-
252-
for itgt
253-
<> tgt[idim] = targets[idim,itgt] {id=fetch_tgt,dup=idim}
254-
255-
<> isrc_box_start = source_box_starts[itgt_box]
256-
<> isrc_box_end = source_box_starts[itgt_box+1]
257-
258-
<> result_temp[iknl] = 0 {id=init_result,dup=iknl}
259-
for isrc_box
260-
<> src_ibox = source_box_lists[isrc_box]
261-
<> coeffs[icoeff] = \
294+
<> tgt_ibox = target_boxes[itgt_box] {id=init_box0}
295+
<> itgt_start = box_target_starts[tgt_ibox] {id=init_box1}
296+
<> itgt_end = itgt_start+box_target_counts_nonchild[tgt_ibox] \
297+
{id=init_box2}
298+
<> isrc_box_start = source_box_starts[itgt_box] {id=init_box3}
299+
<> isrc_box_end = source_box_starts[itgt_box+1] {id=init_box4}
300+
301+
<> result_temp[itgt_offset, iknl] = 0 \
302+
{id=init_result,dup=iknl}
303+
for isrc_box
304+
<> src_ibox = source_box_lists[isrc_box] {id=fetch_src_box}
305+
<> coeffs[icoeff] = \
262306
src_expansions[src_ibox - src_base_ibox, icoeff] \
263-
{id=fetch_coeffs,dup=icoeff}
264-
<> center[idim] = centers[idim, src_ibox] \
307+
{id=fetch_coeffs}
308+
<> center[idim] = centers[idim, src_ibox] \
265309
{dup=idim,id=fetch_center}
266-
[iknl]: result_temp[iknl] = e2p(
267-
[iknl]: result_temp[iknl],
310+
311+
for itgt_offset
312+
<> itgt = itgt_start + itgt_offset
313+
<> run_itgt = itgt<itgt_end
314+
<> tgt[idim] = targets[idim,itgt] \
315+
{id=fetch_tgt,dup=idim,if=run_itgt}
316+
317+
[iknl]: result_temp[itgt_offset, iknl] = e2p(
318+
[iknl]: result_temp[itgt_offset, iknl],
268319
[icoeff]: coeffs[icoeff],
269320
[idim]: center[idim],
270321
[idim]: tgt[idim],
@@ -274,11 +325,18 @@ def get_kernel(self):
274325
targets,
275326
""" + ",".join(arg.name for arg in loopy_args) + """
276327
) {id=update_result, \
277-
dep=fetch_coeffs:fetch_center:fetch_tgt:init_result}
328+
dep=fetch_coeffs:fetch_center:fetch_tgt:init_result, \
329+
if=run_itgt}
278330
end
279-
result[iknl, itgt] = result[iknl, itgt] + result_temp[iknl] \
280-
* kernel_scaling \
281-
{dep=update_result:init_result,id=write_result,dup=iknl}
331+
end
332+
for itgt_offset
333+
<> itgt2 = itgt_start + itgt_offset {id=init_itgt_for_write}
334+
<> run_itgt2 = itgt_start + itgt_offset < itgt_end \
335+
{id=init_cond_for_write}
336+
result[iknl, itgt2] = result[iknl, itgt2] + result_temp[ \
337+
itgt_offset, iknl] * kernel_scaling \
338+
{dep=update_result:init_result,id=write_result, \
339+
dup=iknl,if=run_itgt2}
282340
end
283341
end
284342
"""],
@@ -306,28 +364,48 @@ def get_kernel(self):
306364
fixed_parameters={
307365
"ncoeffs": ncoeffs,
308366
"dim": self.dim,
367+
"max_work_items": max_work_items,
368+
"max_ntargets_in_one_box": max_ntargets_in_one_box,
309369
"nresults": len(self.kernels)},
310370
lang_version=MOST_RECENT_LANGUAGE_VERSION)
311371

312372
loopy_knl = lp.tag_inames(loopy_knl, "idim*:unr")
313373
loopy_knl = lp.tag_inames(loopy_knl, "iknl*:unr")
314-
loopy_knl = lp.prioritize_loops(loopy_knl, "itgt_box,itgt,isrc_box")
374+
loopy_knl = lp.prioritize_loops(loopy_knl, "itgt_box,isrc_box,itgt_offset")
315375
loopy_knl = self.add_loopy_eval_callable(loopy_knl)
316376
loopy_knl = lp.tag_array_axes(loopy_knl, "targets", "sep,C")
317377

318378
return loopy_knl
319379

320-
def get_optimized_kernel(self):
321-
# FIXME
322-
knl = self.get_kernel()
323-
knl = lp.tag_inames(knl, {"itgt_box": "g.0"})
380+
def get_optimized_kernel(self, max_ntargets_in_one_box):
381+
_, optimizations = self.get_cached_loopy_knl_and_optimizations()
382+
knl = self.get_kernel(max_ntargets_in_one_box=max_ntargets_in_one_box)
383+
knl = lp.tag_inames(knl, {"itgt_box": "g.0", "dummy": "l.0"})
384+
knl = lp.unprivatize_temporaries_with_inames(knl,
385+
"itgt_offset", "result_temp")
386+
knl = lp.split_iname(knl, "itgt_offset", 32, inner_tag="l.0")
387+
knl = lp.split_iname(knl, "icoeff", 32, inner_tag="l.0")
388+
knl = lp.privatize_temporaries_with_inames(knl,
389+
"itgt_offset_outer", "result_temp")
390+
knl = lp.duplicate_inames(knl, "itgt_offset_outer", "id:init_result")
391+
knl = lp.duplicate_inames(knl, "itgt_offset_outer",
392+
"id:write_result or id:init_itgt_for_write or id:init_cond_for_write")
393+
knl = lp.add_inames_to_insn(knl, "dummy",
394+
"id:init_box* or id:fetch_src_box or id:fetch_center "
395+
"or id:kernel_scaling")
324396
knl = lp.add_inames_to_insn(knl, "itgt_box", "id:kernel_scaling")
397+
knl = lp.add_inames_to_insn(knl, "itgt_offset_inner", "id:fetch_init*")
398+
knl = lp.set_temporary_address_space(knl, "coeffs", lp.AddressSpace.LOCAL)
325399
knl = lp.set_options(knl,
326-
enforce_variable_access_ordered="no_check")
400+
enforce_variable_access_ordered="no_check", write_code=False)
401+
for transform in optimizations:
402+
knl = transform(knl)
327403
return knl
328404

329405
def __call__(self, queue, **kwargs):
330-
knl = self.get_cached_optimized_kernel()
406+
max_ntargets_in_one_box = kwargs.pop("max_ntargets_in_one_box")
407+
knl = self.get_cached_optimized_kernel(
408+
max_ntargets_in_one_box=max_ntargets_in_one_box)
331409

332410
centers = kwargs.pop("centers")
333411
# "1" may be passed for rscale, which won't have its type

sumpy/expansion/local.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,14 @@ def loopy_translate_from(self, src_expansion):
405405
f"A direct loopy kernel for translation from "
406406
f"{src_expansion} to {self} is not implemented.")
407407

408+
def loopy_evaluate(self, kernels):
409+
from sumpy.expansion.loopy import (make_l2p_loopy_kernel_for_volume_taylor,
410+
make_e2p_loopy_kernel)
411+
try:
412+
return make_l2p_loopy_kernel_for_volume_taylor(self, kernels)
413+
except NotImplementedError:
414+
return make_e2p_loopy_kernel(self, kernels)
415+
408416

409417
class VolumeTaylorLocalExpansion(
410418
VolumeTaylorExpansion,

0 commit comments

Comments
 (0)