[TRITON] Gluon softmax implementation #1227
Draft
Changes
- Wrote Gluon implementation of online softmax kernel, found in `aiter/ops/triton/gluon/softmax.py`.
- […] `num_warps`, `waves_per_eu`).
- `aiter/ops/triton/softmax.py` was modified so that the `num_stages` parameter is correctly passed into the `tl.range` loops.
- […] `num_stages` parameter, since the compiler does not seem to generate pipelined code with the use of hardware-specific load/store operations (ex: `gl.amd.cdna4.buffer_load`).
- […] `num_stages = 2`) is manually written in the Gluon kernel.
- `threads_per_warp` and `warps_per_cta` are both determined by the hardware context.
- `size_per_thread` is determined by the given block size, but bounded by the number of elements that can be loaded by a single thread according to the input tensor's datatype.

Testing
- A benchmark script for softmax can be found in `op_tests/op_benchmarks/triton/bench_softmax.py`, outputting time and bandwidth metrics. Run with the `-gluon` flag to get the timing results for the Gluon kernel. By default, the benchmark runs softmax on tensors of shape `N = 8192` and `M` ranging from 1 to 8192. If the `-N` flag is passed, the `N` size varies instead at `M = 8192`.
- Correctness for the Gluon kernel was tested with the `op_tests/triton_tests/test_softmax.py` script.

Below is a bandwidth plot comparing the performance of the Triton and Gluon kernels:
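
For readers unfamiliar with the online softmax formulation the kernel is based on, here is a minimal pure-Python sketch of the idea for a single row. It is not the Gluon kernel from this PR: the function names and the `block_size` parameter are made up for illustration, and finite inputs are assumed. The row is walked in blocks while a running max and a running sum of exponentials are kept, and the sum is rescaled whenever the max grows.

```python
import math


def online_softmax_row(row, block_size=4):
    """Softmax of one row via the online (streaming) formulation.

    Illustrative only: block_size stands in for the kernel's block size.
    """
    running_max = float("-inf")
    running_sum = 0.0
    # Pass 1: stream over the row, updating the max and the rescaled sum.
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum from the old max to the new max.
        running_sum *= math.exp(running_max - new_max)
        running_sum += sum(math.exp(x - new_max) for x in block)
        running_max = new_max
    # Pass 2: normalize with the final max and sum.
    return [math.exp(x - running_max) / running_sum for x in row]


def reference_softmax_row(row):
    """Plain max-subtracted softmax, used only as a comparison."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]
```

Comparing `online_softmax_row` against `reference_softmax_row` on a row whose length is not a multiple of `block_size` is a quick way to check the rescaling step.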
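
The hand-written pipelining mentioned under Changes (the `num_stages = 2` equivalent) follows a common prologue-plus-prefetch shape, which can be sketched in plain Python. This is only a structural illustration under assumptions: `load_block` is a hypothetical stand-in for a hardware load such as `gl.amd.cdna4.buffer_load`, and plain Python runs sequentially, so nothing actually overlaps here.

```python
import math


def load_block(row, start, block_size):
    # Hypothetical stand-in for a hardware load; not a real Gluon call.
    return row[start:start + block_size]


def pipelined_max_and_sum(row, block_size=4):
    """Running max and sum of exponentials with a two-stage software pipeline."""
    num_blocks = (len(row) + block_size - 1) // block_size
    running_max, running_sum = float("-inf"), 0.0
    # Prologue: fetch block 0 before entering the loop.
    current = load_block(row, 0, block_size)
    for i in range(num_blocks):
        # Issue the load for block i + 1 before computing on block i, so
        # that on real hardware the load latency can hide behind the math.
        if i + 1 < num_blocks:
            upcoming = load_block(row, (i + 1) * block_size, block_size)
        else:
            upcoming = []
        new_max = max(running_max, max(current))
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in current
        )
        running_max = new_max
        # Rotate buffers: the prefetched block becomes the current one.
        current = upcoming
    return running_max, running_sum
```

The point of the structure is that each iteration has one load in flight while the previous block is being reduced, which is what a compiler-generated two-stage pipeline would otherwise arrange.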
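
The `size_per_thread` rule described under Changes can be illustrated with a small helper. This is a guess at the shape of the computation, not code from the PR: the function name, the 16-byte per-thread load width, and the default `threads_per_warp` and `num_warps` values are all assumptions chosen for the example.

```python
def pick_size_per_thread(block_size, elem_bytes, threads_per_warp=64,
                         num_warps=4, max_load_bytes=16):
    """Elements per thread: driven by the block size, capped by load width.

    All defaults are assumptions for illustration, including max_load_bytes.
    """
    total_threads = threads_per_warp * num_warps
    # Elements each thread would own if the block were split evenly.
    from_block = max(1, block_size // total_threads)
    # Cap: how many elements of this datatype fit in one per-thread load.
    max_elems = max(1, max_load_bytes // elem_bytes)
    return min(from_block, max_elems)
```

Under these assumptions, a block of 8192 two-byte elements is capped at 8 elements per thread, while the same block of four-byte elements is capped at 4.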
