Skip to content

Commit 3db1b12

Browse files
authored
Add Iota operation to direct bindings (#4879)
## PR List - #4876 - #4877 - #4878 - #4879 **<< This PR.** - #4880 - #4881 - #4882 - #4883 - #4884
1 parent 78408ac commit 3db1b12

File tree

4 files changed

+80
-1
lines changed

4 files changed

+80
-1
lines changed

python/python_direct/ops.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,37 @@ TensorView
22812281
)");
22822282
}
22832283

2284+
void bindTensorFactoryOps(py::module_& ops) {
2285+
ops.def(
2286+
"iota",
2287+
[](Val* length, Val* start, Val* step, PrimDataType dtype)
2288+
-> TensorView* { return iota(length, start, step, dtype); },
2289+
py::arg("length"),
2290+
py::arg("start").none(true) = py::none(),
2291+
py::arg("step").none(true) = py::none(),
2292+
py::arg("dtype") = DataType::Int,
2293+
R"(
2294+
Create a tensor with values from 0 to length-1.
2295+
2296+
Parameters
2297+
----------
2298+
length : Val
2299+
The length of the tensor.
2300+
start : Val, optional
2301+
The start of the tensor. When the default is None, start is set to zero.
2302+
step : Val, optional
2303+
The step of the tensor. When the default is None, step is set to zero.
2304+
dtype : PrimDataType, optional
2305+
The data type of the tensor. Default is DataType::Int.
2306+
2307+
Returns
2308+
-------
2309+
TensorView
2310+
The tensor with values from 0 to length-1.
2311+
)",
2312+
py::return_value_policy::reference);
2313+
}
2314+
22842315
} // namespace
22852316

22862317
void bindOperations(py::module& nvfuser) {
@@ -2290,12 +2321,13 @@ void bindOperations(py::module& nvfuser) {
22902321
bindBinaryOps(nvf_ops);
22912322
bindTernaryOps(nvf_ops);
22922323
bindReductionOps(nvf_ops);
2324+
bindScanOps(nvf_ops);
22932325
bindCastOps(nvf_ops);
22942326
bindMatmulOps(nvf_ops);
22952327
bindMetadataOps(nvf_ops);
22962328
bindTensorUtilityOps(nvf_ops);
22972329
bindIndexingOps(nvf_ops);
2298-
bindScanOps(nvf_ops);
2330+
bindTensorFactoryOps(nvf_ops);
22992331
}
23002332

23012333
} // namespace nvfuser::python

python/python_direct/python_translate.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,29 @@ class PythonTranslator : public OptInConstDispatch {
981981
{lsop->out()});
982982
}
983983

984+
// Map IotaOp to python frontend
985+
void handle(const IotaOp* iop) final {
986+
NVF_ERROR(iop != nullptr);
987+
TensorView* out_tv = iop->output(0)->as<TensorView>();
988+
visited_vals_.insert(out_tv);
989+
990+
dispatch(iop->length());
991+
dispatch(iop->start());
992+
dispatch(iop->step());
993+
994+
static const auto default_args = std::make_tuple(
995+
KeywordArgument<decltype(iop->length())>{"length", std::nullopt},
996+
KeywordArgument<decltype(iop->start())>{"start", nullptr},
997+
KeywordArgument<decltype(iop->step())>{"step", nullptr},
998+
KeywordArgument<DataType>{"dtype", DataType::Int});
999+
printer_.generateKwargsOperation(
1000+
"fd.ops.iota",
1001+
std::make_tuple(),
1002+
default_args,
1003+
std::make_tuple(iop->length(), iop->start(), iop->step(), iop->dtype()),
1004+
{out_tv});
1005+
}
1006+
9841007
// Map IndexSelectOp to IndexSelectOpRecord
9851008
void handle(const IndexSelectOp* isop) final {
9861009
NVF_ERROR(isop != nullptr);

tests/python/direct/test_python_frontend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,26 @@ def fusion_func(fd: FusionDefinition) -> None:
887887
nvf_out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
888888
for out in nvf_out:
889889
nvfuser_direct_test.assertTrue(out.allclose(x[:, 1:, 2:]))
890+
891+
892+
def test_iota(nvfuser_direct_test):
893+
inputs = [
894+
(2, 0, 2, DataType.Int),
895+
(3, 100, 1, DataType.Int32),
896+
]
897+
898+
def fusion_func(fd: FusionDefinition):
899+
for input in inputs:
900+
c0 = fd.define_scalar(input[0])
901+
c1 = None if input[1] is None else fd.define_scalar(input[1])
902+
c2 = None if input[2] is None else fd.define_scalar(input[2])
903+
dt = input[3]
904+
t3 = fd.ops.iota(c0, c1, c2, dt)
905+
fd.add_output(t3)
906+
907+
nvf_out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, [])
908+
909+
eager_out1 = torch.tensor([0, 2], dtype=torch.long, device="cuda")
910+
eager_out2 = torch.tensor([100, 101, 102], dtype=torch.int, device="cuda")
911+
nvfuser_direct_test.assertEqual(eager_out1, nvf_out[0])
912+
nvfuser_direct_test.assertEqual(eager_out2, nvf_out[1])

tests/python/opinfo/opinfos.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
13261326
ArgumentType.ConstantScalar,
13271327
ArgumentType.Constant,
13281328
),
1329+
supports_direct_bindings=True,
13291330
)
13301331
tensor_creation_ops.append(iota_opinfo)
13311332

0 commit comments

Comments
 (0)