Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,7 @@ def row_wise_udf(
name: str,
inner: Callable[..., Any],
return_dtype: PyDataType,
use_process: bool | None,
original_args: tuple[tuple[Any, ...], dict[str, Any]],
expr_args: list[PyExpr],
) -> PyExpr: ...
Expand Down
36 changes: 29 additions & 7 deletions daft/udf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class _PartialUdf:

return_dtype: DataTypeLike | None
unnest: bool
use_process: bool | None

@overload
def __call__(self, fn: Callable[P, Iterator[T]]) -> GeneratorUdf[P, T]: ... # type: ignore[overload-overlap]
Expand All @@ -36,9 +37,9 @@ def __call__(self, fn: Callable[P, T]) -> RowWiseUdf[P, T]: ...

def __call__(self, fn: Callable[P, Any]) -> GeneratorUdf[P, Any] | RowWiseUdf[P, Any]:
if isgeneratorfunction(fn):
return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest)
return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest, use_process=self.use_process)
else:
return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest)
return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest, use_process=self.use_process)


class _DaftFuncDecorator:
Expand Down Expand Up @@ -212,20 +213,41 @@ class _DaftFuncDecorator:
"""

@overload
def __new__(cls, *, return_dtype: DataTypeLike | None = None, unnest: bool = False) -> _PartialUdf: ... # type: ignore[misc]
def __new__( # type: ignore[misc]
cls,
*,
return_dtype: DataTypeLike | None = None,
unnest: bool = False,
use_process: bool | None = None,
) -> _PartialUdf: ...
@overload
def __new__( # type: ignore[misc]
cls, fn: Callable[P, Iterator[T]], *, return_dtype: DataTypeLike | None = None, unnest: bool = False
cls,
fn: Callable[P, Iterator[T]],
*,
return_dtype: DataTypeLike | None = None,
unnest: bool = False,
use_process: bool | None = None,
) -> GeneratorUdf[P, T]: ...
@overload
def __new__( # type: ignore[misc]
cls, fn: Callable[P, T], *, return_dtype: DataTypeLike | None = None, unnest: bool = False
cls,
fn: Callable[P, T],
*,
return_dtype: DataTypeLike | None = None,
unnest: bool = False,
use_process: bool | None = None,
) -> RowWiseUdf[P, T]: ...

def __new__( # type: ignore[misc]
cls, fn: Callable[P, Any] | None = None, *, return_dtype: DataTypeLike | None = None, unnest: bool = False
cls,
fn: Callable[P, Any] | None = None,
*,
return_dtype: DataTypeLike | None = None,
unnest: bool = False,
use_process: bool | None = None,
) -> _PartialUdf | GeneratorUdf[P, Any] | RowWiseUdf[P, Any]:
partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest)
partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest, use_process=use_process)
return partial_udf if fn is None else partial_udf(fn)


Expand Down
9 changes: 7 additions & 2 deletions daft/udf/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class GeneratorUdf(Generic[P, T]):
If no values are yielded for an input, a null value is inserted.
"""

def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool):
def __init__(
self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool, use_process: bool | None
):
self._inner = fn
self.name = get_unique_function_name(fn)
self.unnest = unnest
self.use_process = use_process

# attempt to extract return type from an Iterator or Generator type hint
if return_dtype is None:
Expand Down Expand Up @@ -85,7 +88,9 @@ def inner_rowwise(*args: P.args, **kwargs: P.kwargs) -> list[T]:
return_dtype_rowwise = DataType.list(self.return_dtype)

expr = Expression._from_pyexpr(
row_wise_udf(self.name, inner_rowwise, return_dtype_rowwise._dtype, (args, kwargs), expr_args)
row_wise_udf(
self.name, inner_rowwise, return_dtype_rowwise._dtype, self.use_process, (args, kwargs), expr_args
)
).explode()

if self.unnest:
Expand Down
5 changes: 3 additions & 2 deletions daft/udf/row_wise.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ class RowWiseUdf(Generic[P, T]):
Row-wise functions are called with data from one row at a time, and map that to a single output value for that row.
"""

def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool):
def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool, use_process: bool | None):
self._inner = fn
self.name = get_unique_function_name(fn)
self.unnest = unnest
self.use_process = use_process

if return_dtype is None:
type_hints = get_type_hints(fn)
Expand Down Expand Up @@ -69,7 +70,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T:
return self._inner(*args, **kwargs)

expr = Expression._from_pyexpr(
row_wise_udf(self.name, self._inner, self.return_dtype._dtype, (args, kwargs), expr_args)
row_wise_udf(self.name, self._inner, self.return_dtype._dtype, self.use_process, (args, kwargs), expr_args)
)

if self.unnest:
Expand Down
2 changes: 2 additions & 0 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,7 @@ impl Expr {
return_dtype,
original_args,
args: old_children,
use_process,
}))) => {
assert!(
children.len() == old_children.len(),
Expand All @@ -1490,6 +1491,7 @@ impl Expr {
return_dtype: return_dtype.clone(),
original_args: original_args.clone(),
args: children,
use_process: *use_process,
})))
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize};
use super::FunctionExpr;
#[cfg(feature = "python")]
use crate::python::PyExpr;
use crate::{Expr, ExprRef, functions::scalar::ScalarFn};
use crate::{Expr, ExprRef, functions::scalar::ScalarFn, python_udf::PyScalarFn};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum MaybeInitializedUDF {
Expand Down Expand Up @@ -375,13 +375,13 @@ pub fn get_udf_properties(expr: &ExprRef) -> UDFProperties {
concurrency: *concurrency,
use_process: *use_process,
});
} else if let Expr::ScalarFn(ScalarFn::Python(py)) = e.as_ref() {
} else if let Expr::ScalarFn(ScalarFn::Python(PyScalarFn::RowWise(pyfn))) = e.as_ref() {
udf_properties = Some(UDFProperties {
name: py.name().to_string(),
name: pyfn.function_name.to_string(),
resource_request: None,
batch_size: None,
concurrency: None,
use_process: None,
use_process: pyfn.use_process,
});
}
Ok(TreeNodeRecursion::Continue)
Expand Down
2 changes: 2 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ pub fn row_wise_udf(
name: &str,
inner: PyObject,
return_dtype: PyDataType,
use_process: Option<bool>,
original_args: PyObject,
expr_args: Vec<PyExpr>,
) -> PyExpr {
Expand All @@ -250,6 +251,7 @@ pub fn row_wise_udf(
name,
inner.into(),
return_dtype.into(),
use_process,
original_args.into(),
args,
)
Expand Down
4 changes: 4 additions & 0 deletions src/daft-dsl/src/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub fn row_wise_udf(
name: &str,
inner: RuntimePyObject,
return_dtype: DataType,
use_process: Option<bool>,
original_args: RuntimePyObject,
args: Vec<ExprRef>,
) -> Expr {
Expand All @@ -67,6 +68,7 @@ pub fn row_wise_udf(
return_dtype,
original_args,
args,
use_process,
})))
}

Expand All @@ -77,6 +79,7 @@ pub struct RowWisePyFn {
pub return_dtype: DataType,
pub original_args: RuntimePyObject,
pub args: Vec<ExprRef>,
pub use_process: Option<bool>,
}

impl Display for RowWisePyFn {
Expand All @@ -101,6 +104,7 @@ impl RowWisePyFn {
return_dtype: self.return_dtype.clone(),
original_args: self.original_args.clone(),
args: children,
use_process: self.use_process,
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/daft-ir/src/proto/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub fn from_proto_function(message: proto::ScalarFn) -> ProtoResult<ir::Expr> {
return_dtype: from_proto(row_wise_fn.return_dtype)?,
original_args: from_proto(row_wise_fn.original_args)?,
args: args.into_inner(),
use_process: row_wise_fn.use_process,
};
ir::rex::from_py_rowwise_func(func)
}
Expand Down Expand Up @@ -100,6 +101,7 @@ pub fn scalar_fn_to_proto(sf: &ir::functions::scalar::ScalarFn) -> ProtoResult<p
return_dtype: Some(row_wise_fn.return_dtype.to_proto()?),
inner: Some(row_wise_fn.inner.to_proto()?),
original_args: Some(row_wise_fn.original_args.to_proto()?),
use_process: row_wise_fn.use_process,
},
)),
})),
Expand Down
2 changes: 2 additions & 0 deletions src/daft-logical-plan/src/ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ fn replace_column_with_semantic_id(
return_dtype,
original_args,
args: children,
use_process,
}))) => {
let transforms = children
.iter()
Expand All @@ -508,6 +509,7 @@ fn replace_column_with_semantic_id(
return_dtype: return_dtype.clone(),
original_args: original_args.clone(),
args: new_children,
use_process: *use_process,
}),
))))
}
Expand Down
2 changes: 2 additions & 0 deletions src/daft-logical-plan/src/partitioning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ fn translate_clustering_spec_expr(
return_dtype,
original_args,
args: children,
use_process,
}))) => {
let new_children = children
.iter()
Expand All @@ -411,6 +412,7 @@ fn translate_clustering_spec_expr(
return_dtype: return_dtype.clone(),
original_args: original_args.clone(),
args: new_children,
use_process: *use_process,
}),
))))
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-proto/proto/v1/daft.proto
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ message ScalarFn {
DataType return_dtype = 2;
PyObject inner = 3;
PyObject original_args = 4;
optional bool use_process = 5;
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/daft-proto/src/generated/daft.v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ pub mod scalar_fn {
pub inner: ::core::option::Option<super::super::PyObject>,
#[prost(message, optional, tag = "4")]
pub original_args: ::core::option::Option<super::super::PyObject>,
#[prost(bool, optional, tag = "5")]
pub use_process: ::core::option::Option<bool>,
}
#[derive(Clone, PartialEq, ::prost::Oneof)]
pub enum Variant {
Expand Down