Skip to content

Commit c9b0454

Browse files
authored
[Triton] Fix constexpr variable in softmax (triton-lang#6555)
1 parent 13d0e74 commit c9b0454

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

python/triton/language/standard.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ def sigmoid(x):
5252
@math._add_math_1arg_docstr("softmax")
5353
def softmax(x, dim=None, ieee_rounding=False):
5454
if dim is None:
55-
dim: core.constexpr = x.shape[-1]
56-
z = x - max(x, dim, keep_dims=True)
55+
_dim: core.constexpr = x.shape[-1]
56+
else:
57+
_dim: core.constexpr = dim
58+
z = x - max(x, _dim, keep_dims=True)
5759
num = math.exp(z)
58-
den = sum(num, dim, keep_dims=True)
60+
den = sum(num, _dim, keep_dims=True)
5961
return math.fdiv(num, den, ieee_rounding)
6062

6163

0 commit comments

Comments
 (0)