diff --git a/benchmark/test_unary_pointwise_perf.py b/benchmark/test_unary_pointwise_perf.py
index 8d421878f..7c5215431 100644
--- a/benchmark/test_unary_pointwise_perf.py
+++ b/benchmark/test_unary_pointwise_perf.py
@@ -40,6 +40,7 @@ def get_tflops(self, op, *args, **kwargs):
 
 forward_operations = [
     ("abs", torch.abs, FLOAT_DTYPES),
+    ("acosh", torch.acosh, FLOAT_DTYPES),
     ("angle", torch.angle, COMPLEX_DTYPES + [torch.float32] + INT_DTYPES + BOOL_DTYPES),
     ("erf", torch.erf, FLOAT_DTYPES),
     ("exp", torch.exp, FLOAT_DTYPES),
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 81cb9b48b..5fba7c60d 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -43,6 +43,7 @@ def enable(
             ("_weight_norm_interface_backward", weight_norm_interface_backward),
             ("abs", abs),
             ("abs_", abs_),
+            ("acosh", acosh),
             ("add.Tensor", add),
             ("add_.Tensor", add_),
             ("addcdiv", addcdiv),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index e592f4bb8..ab3cc5aba 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -1,4 +1,5 @@
 from flag_gems.ops.abs import abs, abs_
+from flag_gems.ops.acosh import acosh
 from flag_gems.ops.add import add, add_
 from flag_gems.ops.addcdiv import addcdiv
 from flag_gems.ops.addcmul import addcmul
@@ -207,6 +208,7 @@
     "_upsample_bicubic2d_aa",
     "abs",
     "abs_",
+    "acosh",
     "add",
     "add_",
     "addcdiv",
diff --git a/src/flag_gems/ops/acosh.py b/src/flag_gems/ops/acosh.py
new file mode 100644
index 000000000..698317d97
--- /dev/null
+++ b/src/flag_gems/ops/acosh.py
@@ -0,0 +1,42 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import pointwise_dynamic
+
+logger = logging.getLogger(__name__)
+
+
+@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
+@triton.jit
+def acosh_forward_kernel(x):
+    # acosh(x) = log(x + sqrt(x^2 - 1)). Upcast to fp32 first: Triton's sqrt/log
+    # intrinsics expect fp32/fp64 operands, so fp16/bf16/integer inputs need it.
+    x_fp32 = x.to(tl.float32)
+    return tl.log(x_fp32 + tl.sqrt(x_fp32 * x_fp32 - 1.0))
+
+
+def acosh(input: torch.Tensor, *, out: torch.Tensor = None):
+    """
+    Returns a new tensor with the inverse hyperbolic cosine of the elements of input.
+
+    Elements smaller than 1 are outside the domain of acosh and produce NaN.
+
+    Args:
+        input (Tensor): the input tensor
+        out (Tensor, optional): the output tensor
+
+    Returns:
+        Tensor: the output tensor with the inverse hyperbolic cosine values
+    """
+    logger.debug("GEMS ACOSH")
+    result = acosh_forward_kernel(input)
+
+    if out is not None:
+        # Copy into the caller-provided tensor and hand that same tensor back.
+        out.copy_(result)
+        return out
+
+    return result
diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py
index 8471d128e..0b0da153d 100644
--- a/tests/test_unary_pointwise_ops.py
+++ b/tests/test_unary_pointwise_ops.py
@@ -33,6 +33,20 @@ def test_accuracy_abs(shape, dtype):
     gems_assert_equal(res_out, ref_out)
 
 
+@pytest.mark.acosh
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_acosh(shape, dtype):
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device).exp() + 0.5
+
+    ref_inp = to_reference(inp, True)
+    ref_out = torch.acosh(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.acosh(inp)
+
+    gems_assert_close(res_out, ref_out, dtype, equal_nan=True)
+
+
 @pytest.mark.inplace
 @pytest.mark.abs_
 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)