Skip to content

Commit cd6ebd5

Browse files
feat: add use_process flag for @daft.func(...) (#5323)
## Changes Made

Adds a `use_process` flag to `@daft.func` so that it gets the same performance benefits that `use_process` already provides for `daft.udf`.

## Related Issues

<!-- Link to related GitHub issues, e.g., "Closes #123" -->

## Checklist

- [ ] Documented in API Docs (if applicable)
- [ ] Documented in User Guide (if applicable)
- [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [ ] Documentation builds and is formatted properly
1 parent 6cb34b3 commit cd6ebd5

File tree

13 files changed

+61
-15
lines changed

13 files changed

+61
-15
lines changed

daft/daft/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,7 @@ def row_wise_udf(
13431343
name: str,
13441344
inner: Callable[..., Any],
13451345
return_dtype: PyDataType,
1346+
use_process: bool | None,
13461347
original_args: tuple[tuple[Any, ...], dict[str, Any]],
13471348
expr_args: list[PyExpr],
13481349
) -> PyExpr: ...

daft/udf/__init__.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class _PartialUdf:
2828

2929
return_dtype: DataTypeLike | None
3030
unnest: bool
31+
use_process: bool | None
3132

3233
@overload
3334
def __call__(self, fn: Callable[P, Iterator[T]]) -> GeneratorUdf[P, T]: ... # type: ignore[overload-overlap]
@@ -36,9 +37,9 @@ def __call__(self, fn: Callable[P, T]) -> RowWiseUdf[P, T]: ...
3637

3738
def __call__(self, fn: Callable[P, Any]) -> GeneratorUdf[P, Any] | RowWiseUdf[P, Any]:
3839
if isgeneratorfunction(fn):
39-
return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest)
40+
return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest, use_process=self.use_process)
4041
else:
41-
return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest)
42+
return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest, use_process=self.use_process)
4243

4344

4445
class _DaftFuncDecorator:
@@ -212,20 +213,41 @@ class _DaftFuncDecorator:
212213
"""
213214

214215
@overload
215-
def __new__(cls, *, return_dtype: DataTypeLike | None = None, unnest: bool = False) -> _PartialUdf: ... # type: ignore[misc]
216+
def __new__( # type: ignore[misc]
217+
cls,
218+
*,
219+
return_dtype: DataTypeLike | None = None,
220+
unnest: bool = False,
221+
use_process: bool | None = None,
222+
) -> _PartialUdf: ...
216223
@overload
217224
def __new__( # type: ignore[misc]
218-
cls, fn: Callable[P, Iterator[T]], *, return_dtype: DataTypeLike | None = None, unnest: bool = False
225+
cls,
226+
fn: Callable[P, Iterator[T]],
227+
*,
228+
return_dtype: DataTypeLike | None = None,
229+
unnest: bool = False,
230+
use_process: bool | None = None,
219231
) -> GeneratorUdf[P, T]: ...
220232
@overload
221233
def __new__( # type: ignore[misc]
222-
cls, fn: Callable[P, T], *, return_dtype: DataTypeLike | None = None, unnest: bool = False
234+
cls,
235+
fn: Callable[P, T],
236+
*,
237+
return_dtype: DataTypeLike | None = None,
238+
unnest: bool = False,
239+
use_process: bool | None = None,
223240
) -> RowWiseUdf[P, T]: ...
224241

225242
def __new__( # type: ignore[misc]
226-
cls, fn: Callable[P, Any] | None = None, *, return_dtype: DataTypeLike | None = None, unnest: bool = False
243+
cls,
244+
fn: Callable[P, Any] | None = None,
245+
*,
246+
return_dtype: DataTypeLike | None = None,
247+
unnest: bool = False,
248+
use_process: bool | None = None,
227249
) -> _PartialUdf | GeneratorUdf[P, Any] | RowWiseUdf[P, Any]:
228-
partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest)
250+
partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest, use_process=use_process)
229251
return partial_udf if fn is None else partial_udf(fn)
230252

231253

daft/udf/generator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ class GeneratorUdf(Generic[P, T]):
3232
If no values are yielded for an input, a null value is inserted.
3333
"""
3434

35-
def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool):
35+
def __init__(
36+
self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool, use_process: bool | None
37+
):
3638
self._inner = fn
3739
self.name = get_unique_function_name(fn)
3840
self.unnest = unnest
41+
self.use_process = use_process
3942

4043
# attempt to extract return type from an Iterator or Generator type hint
4144
if return_dtype is None:
@@ -85,7 +88,9 @@ def inner_rowwise(*args: P.args, **kwargs: P.kwargs) -> list[T]:
8588
return_dtype_rowwise = DataType.list(self.return_dtype)
8689

8790
expr = Expression._from_pyexpr(
88-
row_wise_udf(self.name, inner_rowwise, return_dtype_rowwise._dtype, (args, kwargs), expr_args)
91+
row_wise_udf(
92+
self.name, inner_rowwise, return_dtype_rowwise._dtype, self.use_process, (args, kwargs), expr_args
93+
)
8994
).explode()
9095

9196
if self.unnest:

daft/udf/row_wise.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ class RowWiseUdf(Generic[P, T]):
3333
Row-wise functions are called with data from one row at a time, and map that to a single output value for that row.
3434
"""
3535

36-
def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool):
36+
def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool, use_process: bool | None):
3737
self._inner = fn
3838
self.name = get_unique_function_name(fn)
3939
self.unnest = unnest
40+
self.use_process = use_process
4041

4142
if return_dtype is None:
4243
type_hints = get_type_hints(fn)
@@ -69,7 +70,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
6970
return self._inner(*args, **kwargs)
7071

7172
expr = Expression._from_pyexpr(
72-
row_wise_udf(self.name, self._inner, self.return_dtype._dtype, (args, kwargs), expr_args)
73+
row_wise_udf(self.name, self._inner, self.return_dtype._dtype, self.use_process, (args, kwargs), expr_args)
7374
)
7475

7576
if self.unnest:

src/daft-dsl/src/expr/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,7 @@ impl Expr {
14781478
return_dtype,
14791479
original_args,
14801480
args: old_children,
1481+
use_process,
14811482
}))) => {
14821483
assert!(
14831484
children.len() == old_children.len(),
@@ -1490,6 +1491,7 @@ impl Expr {
14901491
return_dtype: return_dtype.clone(),
14911492
original_args: original_args.clone(),
14921493
args: children,
1494+
use_process: *use_process,
14931495
})))
14941496
}
14951497
}

src/daft-dsl/src/functions/python/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize};
1919
use super::FunctionExpr;
2020
#[cfg(feature = "python")]
2121
use crate::python::PyExpr;
22-
use crate::{Expr, ExprRef, functions::scalar::ScalarFn};
22+
use crate::{Expr, ExprRef, functions::scalar::ScalarFn, python_udf::PyScalarFn};
2323

2424
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
2525
pub enum MaybeInitializedUDF {
@@ -375,13 +375,13 @@ pub fn get_udf_properties(expr: &ExprRef) -> UDFProperties {
375375
concurrency: *concurrency,
376376
use_process: *use_process,
377377
});
378-
} else if let Expr::ScalarFn(ScalarFn::Python(py)) = e.as_ref() {
378+
} else if let Expr::ScalarFn(ScalarFn::Python(PyScalarFn::RowWise(pyfn))) = e.as_ref() {
379379
udf_properties = Some(UDFProperties {
380-
name: py.name().to_string(),
380+
name: pyfn.function_name.to_string(),
381381
resource_request: None,
382382
batch_size: None,
383383
concurrency: None,
384-
use_process: None,
384+
use_process: pyfn.use_process,
385385
});
386386
}
387387
Ok(TreeNodeRecursion::Continue)

src/daft-dsl/src/python.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ pub fn row_wise_udf(
238238
name: &str,
239239
inner: PyObject,
240240
return_dtype: PyDataType,
241+
use_process: Option<bool>,
241242
original_args: PyObject,
242243
expr_args: Vec<PyExpr>,
243244
) -> PyExpr {
@@ -250,6 +251,7 @@ pub fn row_wise_udf(
250251
name,
251252
inner.into(),
252253
return_dtype.into(),
254+
use_process,
253255
original_args.into(),
254256
args,
255257
)

src/daft-dsl/src/python_udf.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub fn row_wise_udf(
5858
name: &str,
5959
inner: RuntimePyObject,
6060
return_dtype: DataType,
61+
use_process: Option<bool>,
6162
original_args: RuntimePyObject,
6263
args: Vec<ExprRef>,
6364
) -> Expr {
@@ -67,6 +68,7 @@ pub fn row_wise_udf(
6768
return_dtype,
6869
original_args,
6970
args,
71+
use_process,
7072
})))
7173
}
7274

@@ -77,6 +79,7 @@ pub struct RowWisePyFn {
7779
pub return_dtype: DataType,
7880
pub original_args: RuntimePyObject,
7981
pub args: Vec<ExprRef>,
82+
pub use_process: Option<bool>,
8083
}
8184

8285
impl Display for RowWisePyFn {
@@ -101,6 +104,7 @@ impl RowWisePyFn {
101104
return_dtype: self.return_dtype.clone(),
102105
original_args: self.original_args.clone(),
103106
args: children,
107+
use_process: self.use_process,
104108
}
105109
}
106110

src/daft-ir/src/proto/functions.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub fn from_proto_function(message: proto::ScalarFn) -> ProtoResult<ir::Expr> {
4040
return_dtype: from_proto(row_wise_fn.return_dtype)?,
4141
original_args: from_proto(row_wise_fn.original_args)?,
4242
args: args.into_inner(),
43+
use_process: row_wise_fn.use_process,
4344
};
4445
ir::rex::from_py_rowwise_func(func)
4546
}
@@ -100,6 +101,7 @@ pub fn scalar_fn_to_proto(sf: &ir::functions::scalar::ScalarFn) -> ProtoResult<p
100101
return_dtype: Some(row_wise_fn.return_dtype.to_proto()?),
101102
inner: Some(row_wise_fn.inner.to_proto()?),
102103
original_args: Some(row_wise_fn.original_args.to_proto()?),
104+
use_process: row_wise_fn.use_process,
103105
},
104106
)),
105107
})),

src/daft-logical-plan/src/ops/project.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ fn replace_column_with_semantic_id(
486486
return_dtype,
487487
original_args,
488488
args: children,
489+
use_process,
489490
}))) => {
490491
let transforms = children
491492
.iter()
@@ -508,6 +509,7 @@ fn replace_column_with_semantic_id(
508509
return_dtype: return_dtype.clone(),
509510
original_args: original_args.clone(),
510511
args: new_children,
512+
use_process: *use_process,
511513
}),
512514
))))
513515
}

0 commit comments

Comments
 (0)