Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/flag_gems/runtime/backend/_kunlunxin/ops/ge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

import triton
import triton.language as tl
Expand All @@ -16,7 +17,11 @@ def ge_func(x, y):

def ge(A, B):
logger.debug("GEMS GE")
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = ge_func(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res


Expand All @@ -28,4 +33,9 @@ def ge_func_scalar(x, y):

def ge_scalar(A, B):
logger.debug("GEMS GE SCALAR")
return ge_func_scalar(A, B)
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = ge_func_scalar(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res
12 changes: 11 additions & 1 deletion src/flag_gems/runtime/backend/_kunlunxin/ops/gt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

import triton
import triton.language as tl
Expand All @@ -16,7 +17,11 @@ def gt_func(x, y):

def gt(A, B):
logger.debug("GEMS GT")
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = gt_func(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res


Expand All @@ -28,4 +33,9 @@ def gt_func_scalar(x, y):

def gt_scalar(A, B):
logger.debug("GEMS GT SCALAR")
return gt_func_scalar(A, B)
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = gt_func_scalar(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res
10 changes: 9 additions & 1 deletion src/flag_gems/runtime/backend/_kunlunxin/ops/le.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

import triton
import triton.language as tl
Expand All @@ -16,7 +17,11 @@ def le_func(x, y):

def le(A, B):
logger.debug("GEMS LE")
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = le_func(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res


Expand All @@ -28,4 +33,7 @@ def le_func_scalar(x, y):

def le_scalar(A, B):
logger.debug("GEMS LE SCALAR")
return le_func_scalar(A, B)
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
res = le_func_scalar(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic for modifying os.environ is not exception-safe and does not properly restore the previous state. The same robust approach suggested for other functions in this PR should be applied here.

Suggested change
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
res = le_func_scalar(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
key = "TRITONXPU_COMPARE_FUSION"
original_value = os.environ.get(key)
os.environ[key] = "1"
try:
res = le_func_scalar(A, B)
finally:
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value

return res
12 changes: 11 additions & 1 deletion src/flag_gems/runtime/backend/_kunlunxin/ops/lt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

import triton
import triton.language as tl
Expand All @@ -16,7 +17,11 @@ def lt_func(x, y):

def lt(A, B):
logger.debug("GEMS LT")
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = lt_func(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res


Expand All @@ -28,4 +33,9 @@ def lt_func_scalar(x, y):

def lt_scalar(A, B):
logger.debug("GEMS LT SCALAR")
return lt_func_scalar(A, B)
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = lt_func_scalar(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res
12 changes: 11 additions & 1 deletion src/flag_gems/runtime/backend/_kunlunxin/ops/ne.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

import triton
import triton.language as tl
Expand All @@ -16,7 +17,11 @@ def ne_func(x, y):

def ne(A, B):
logger.debug("GEMS NE")
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = ne_func(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res


Expand All @@ -28,4 +33,9 @@ def ne_func_scalar(x, y):

def ne_scalar(A, B):
logger.debug("GEMS NE SCALAR")
return ne_func_scalar(A, B)
os.environ["TRITONXPU_COMPARE_FUSION"] = "1"
os.environ["TRITONXPU_FP16_FAST"] = "1"
res = ne_func_scalar(A, B)
del os.environ["TRITONXPU_COMPARE_FUSION"]
del os.environ["TRITONXPU_FP16_FAST"]
return res
Loading