Skip to content

Commit 1059382

Browse files
committed
Add full op
1 parent 3db1b12 commit 1059382

File tree

4 files changed

+80
-0
lines changed

4 files changed

+80
-0
lines changed

python/python_direct/ops.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,15 @@ TensorView
22812281
)");
22822282
}
22832283

2284+
template <class ShapeType>
2285+
TensorView* full_op_fn(
2286+
ShapeType generic_output_shape,
2287+
Val* fill_value,
2288+
PrimDataType dtype) {
2289+
std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
2290+
return full(output_shape, fill_value, dtype);
2291+
}
2292+
22842293
void bindTensorFactoryOps(py::module_& ops) {
22852294
ops.def(
22862295
"iota",
@@ -2308,6 +2317,37 @@ Returns
23082317
-------
23092318
TensorView
23102319
The tensor with values from 0 to length-1.
2320+
)",
2321+
py::return_value_policy::reference);
2322+
ops.def(
2323+
"full",
2324+
full_op_fn<py::list>,
2325+
py::arg("shape"),
2326+
py::arg("fill_value"),
2327+
py::arg("dtype"),
2328+
py::return_value_policy::reference);
2329+
ops.def(
2330+
"full",
2331+
full_op_fn<py::tuple>,
2332+
py::arg("shape"),
2333+
py::arg("fill_value"),
2334+
py::arg("dtype"),
2335+
R"(
2336+
Create a tensor with all elements set to a specified value.
2337+
2338+
Parameters
2339+
----------
2340+
shape : list or tuple
2341+
The shape of the tensor.
2342+
fill_value : Val
2343+
The value to fill the tensor with.
2344+
dtype : PrimDataType
2345+
The data type of the tensor.
2346+
2347+
Returns
2348+
-------
2349+
TensorView
2350+
The tensor with all elements set to the specified value.
23112351
)",
23122352
py::return_value_policy::reference);
23132353
}

python/python_direct/python_translate.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,25 @@ class PythonTranslator : public OptInConstDispatch {
981981
{lsop->out()});
982982
}
983983

984+
// Map FullOp to python frontend
985+
void handle(const FullOp* fop) final {
986+
NVF_ERROR(fop != nullptr);
987+
TensorView* out_tv = fop->output(0)->as<TensorView>();
988+
visited_vals_.insert(out_tv);
989+
990+
// Fill value can be dynamic so create it
991+
dispatch(fop->getFillValue());
992+
993+
static const std::vector<std::string> argument_names = {
994+
"shape", "fill_value", "dtype"};
995+
printer_.generateKwargsOperation(
996+
"fd.ops.full",
997+
std::make_tuple(),
998+
argument_names,
999+
std::make_tuple(getShape(out_tv), fop->getFillValue(), out_tv->dtype()),
1000+
{out_tv});
1001+
}
1002+
9841003
// Map IotaOp to python frontend
9851004
void handle(const IotaOp* iop) final {
9861005
NVF_ERROR(iop != nullptr);

tests/python/direct/test_python_frontend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,23 @@ def fusion_func(fd: FusionDefinition):
910910
eager_out2 = torch.tensor([100, 101, 102], dtype=torch.int, device="cuda")
911911
nvfuser_direct_test.assertEqual(eager_out1, nvf_out[0])
912912
nvfuser_direct_test.assertEqual(eager_out2, nvf_out[1])
913+
914+
915+
def test_scalar_only_inputs(nvfuser_direct_test):
916+
# We don't allow scalar outputs, currently,
917+
# so a tensor has to be returned
918+
def fusion_func(fd: FusionDefinition):
919+
s0 = fd.define_scalar()
920+
s1 = fd.define_scalar()
921+
s2 = fd.ops.add(s0, s1)
922+
c0 = fd.define_scalar(1.0, DataType.Float)
923+
t3 = fd.ops.full(shape=[2, 2], fill_value=c0, dtype=DataType.Float)
924+
t4 = fd.ops.mul(t3, s2)
925+
fd.add_output(t4)
926+
927+
with FusionDefinition() as fd:
928+
fusion_func(fd)
929+
930+
nvf_out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, [2.0, 3.0])
931+
eager_out = torch.full([2, 2], 1.0) * 5.0
932+
nvfuser_direct_test.assertEqual(eager_out, nvf_out[0])

tests/python/opinfo/opinfos.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,7 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
13111311
ArgumentType.Symbolic,
13121312
ArgumentType.Constant,
13131313
),
1314+
supports_direct_bindings=True,
13141315
)
13151316
tensor_creation_ops.append(full_opinfo)
13161317

0 commit comments

Comments
 (0)