We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 13d0e74 commit c9b0454Copy full SHA for c9b0454
python/triton/language/standard.py
@@ -52,10 +52,12 @@ def sigmoid(x):
52
@math._add_math_1arg_docstr("softmax")
53
def softmax(x, dim=None, ieee_rounding=False):
54
if dim is None:
55
- dim: core.constexpr = x.shape[-1]
56
- z = x - max(x, dim, keep_dims=True)
+ _dim: core.constexpr = x.shape[-1]
+ else:
57
+ _dim: core.constexpr = dim
58
+ z = x - max(x, _dim, keep_dims=True)
59
num = math.exp(z)
- den = sum(num, dim, keep_dims=True)
60
+ den = sum(num, _dim, keep_dims=True)
61
return math.fdiv(num, den, ieee_rounding)
62
63
0 commit comments