From 7d5ceb06396ba862ff49d40ad60c90d3acb261d3 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 11:19:55 -0500 Subject: [PATCH 01/15] make file-array serializable --- src/daft-core/src/array/ops/from_arrow.rs | 13 ++----------- src/daft-schema/src/dtype.rs | 6 ++---- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index 742ae7b910..90d0d5f0f8 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -6,7 +6,7 @@ use common_error::{DaftError, DaftResult}; use crate::{ array::{DataArray, FixedSizeListArray, ListArray, StructArray}, datatypes::{ - DaftDataType, DaftLogicalType, DaftPhysicalType, DataType, Field, FieldRef, FileArray, + DaftDataType, DaftLogicalType, DaftPhysicalType, DataType, Field, FieldRef, FileType, logical::LogicalArray, }, prelude::*, @@ -208,16 +208,6 @@ impl FromArrow for MapArray { } } -impl FromArrow for FileArray { - fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { - Err(DaftError::TypeError(format!( - "Attempting to create Daft FileArray with type {} from arrow array with type {:?}", - field.dtype, - arrow_arr.data_type() - ))) - } -} - #[cfg(feature = "python")] impl FromArrow for PythonArray { fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { @@ -271,3 +261,4 @@ impl_logical_from_arrow!(FixedShapeTensorType); impl_logical_from_arrow!(SparseTensorType); impl_logical_from_arrow!(FixedShapeSparseTensorType); impl_logical_from_arrow!(FixedShapeImageType); +impl_logical_from_arrow!(FileType); diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 419de22e91..0ae801f960 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -272,7 +272,8 @@ impl DataType { | Self::Tensor(..) | Self::FixedShapeTensor(..) | Self::SparseTensor(..) - | Self::FixedShapeSparseTensor(..) => { + | Self::FixedShapeSparseTensor(..) + | Self::File => { let physical = Box::new(self.to_physical()); let logical_extension = Self::Extension( DAFT_SUPER_EXTENSION_NAME.into(), @@ -294,9 +295,6 @@ impl DataType { Self::Unknown => Err(DaftError::TypeError(format!( "Can not convert {self:?} into arrow type" ))), - Self::File => Err(DaftError::TypeError(format!( - "Can not convert {self:?} into arrow type" - ))), } } From b9cf3a6949541cdd3fb3555cac158cb4abe45740 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 11:29:06 -0500 Subject: [PATCH 02/15] perf: allow Series::from_literals to take in optional datatype This makes it faster because we don't need to infer the datatype when we know what it is upfront. also makes the `cast` logic check if it's the same type and avoid casting if same type. This speeds up the literal creation a lot. --- src/daft-core/src/array/ops/cast.rs | 2 +- src/daft-core/src/lit/mod.rs | 12 +++++++---- src/daft-core/src/python/series.rs | 2 +- src/daft-core/src/series/from_lit.rs | 32 ++++++++++++++++------------ src/daft-core/src/series/mod.rs | 2 +- src/daft-dsl/src/python_udf.rs | 5 +++-- src/daft-sql/src/planner.rs | 2 +- 7 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index d257170195..f72cb5f2a7 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -608,7 +608,7 @@ impl PythonArray { .collect::>>() })?; - Ok(Series::from_literals(literals)? + Ok(Series::from_literals(literals, None)? .cast(dtype)? .rename(self.name())) } diff --git a/src/daft-core/src/lit/mod.rs b/src/daft-core/src/lit/mod.rs index cc78d2a834..de56c8c3fd 100644 --- a/src/daft-core/src/lit/mod.rs +++ b/src/daft-core/src/lit/mod.rs @@ -583,10 +583,14 @@ impl Literal { /// /// This method is lossy, AKA it is not guaranteed that `lit.cast(dtype).get_type() == dtype`. /// This is because null literals always have the null data type. - pub fn cast(&self, dtype: &DataType) -> DaftResult { - Series::from(self.clone()) - .cast(dtype) - .and_then(|s| Self::try_from_single_value_series(&s)) + pub fn cast(self, dtype: &DataType) -> DaftResult { + if &self.get_type() == dtype { + Ok(self) + } else { + Series::from(self) + .cast(dtype) + .and_then(|s| Self::try_from_single_value_series(&s)) + } } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 8318e44e5a..e67f081c7d 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -114,7 +114,7 @@ impl PySeries { (literals_with_supertype, supertype) }; - let mut series = Series::from_literals(literals)?.cast(&dtype)?; + let mut series = Series::from_literals(literals, None)?.cast(&dtype)?; if let Some(name) = name { series = series.rename(name); } diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index 39c8b91de4..cd0d8f8936 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -58,16 +58,19 @@ impl Series { /// /// Literals must all be the same type or null, this function does not do any casting or coercion. /// If that is desired, you should handle it for each literal before converting it into a series. - pub fn from_literals(values: Vec) -> DaftResult { - let dtype = values.iter().try_fold(DataType::Null, |acc, v| { - let dtype = v.get_type(); - combine_lit_types(&acc, &dtype).ok_or_else(|| { - DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - acc, dtype - )) - }) - })?; + pub fn from_literals(values: Vec, dtype: Option) -> DaftResult { + let dtype = match dtype { + Some(dtype) => dtype, + None => values.iter().try_fold(DataType::Null, |acc, v| { + let dtype = v.get_type(); + combine_lit_types(&acc, &dtype).ok_or_else(|| { + DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + acc, dtype + )) + }) + })?, + }; let field = Field::new("literal", dtype.clone()); @@ -184,7 +187,7 @@ impl Series { }) .collect::>(); - Ok(Self::from_literals(child_values)?.rename(&f.name)) + Ok(Self::from_literals(child_values, None)?.rename(&f.name)) }) .collect::>()?; @@ -385,7 +388,8 @@ impl Series { impl From for Series { fn from(value: Literal) -> Self { - Self::from_literals(vec![value]) + let dtype = Some(value.get_type()); + Self::from_literals(vec![value], dtype) .expect("Series::try_from should not fail on single literal value") } } @@ -512,7 +516,7 @@ mod test { ])] fn test_literal_series_roundtrip_basics(#[case] literals: Vec) { let expected = [vec![Literal::Null], literals, vec![Literal::Null]].concat(); - let series = Series::from_literals(expected.clone()).unwrap(); + let series = Series::from_literals(expected.clone(), None).unwrap(); let actual = series.to_literals().collect::>(); assert_eq!(expected, actual) @@ -521,7 +525,7 @@ mod test { #[test] fn test_literals_to_series_mismatched() { let values = vec![Literal::UInt64(1), Literal::Utf8("test".to_string())]; - let actual = Series::from_literals(values); + let actual = Series::from_literals(values, None); assert!(actual.is_err()); } } diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index c38ccadfb6..7b9b7a3bae 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -262,7 +262,7 @@ macro_rules! series { // put into a vec first for compile-time type consistency checking let elements = vec![$($element),+]; let elements_lit = elements.into_iter().map(Literal::from).collect::>(); - Series::from_literals(elements_lit).unwrap() + Series::from_literals(elements_lit, None).unwrap() } }; } diff --git a/src/daft-dsl/src/python_udf.rs b/src/daft-dsl/src/python_udf.rs index b19687a347..f0696a091f 100644 --- a/src/daft-dsl/src/python_udf.rs +++ b/src/daft-dsl/src/python_udf.rs @@ -191,7 +191,6 @@ impl RowWisePyFn { args: &[Series], num_rows: usize, ) -> DaftResult<(Series, std::time::Duration)> { - use daft_core::python::PySeries; use pyo3::prelude::*; let inner_ref = self.inner.as_ref(); @@ -216,13 +215,15 @@ impl RowWisePyFn { } let result = func.call1((inner_ref, args_ref, &py_args))?; + let result = Literal::from_pyobj(&result, Some(&self.return_dtype))?; + py_args.clear(); DaftResult::Ok(result) }) .collect::>>()?; Ok(( - PySeries::from_pylist_impl(name, outputs, self.return_dtype.clone())?.series, + Series::from_literals(outputs, Some(self.return_dtype.clone()))?.rename(name), gil_contention_time, )) }) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 658c68209b..265d1390a6 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1512,7 +1512,7 @@ impl SQLPlanner<'_> { }) .collect::>>()?; - let s = Series::from_literals(values)?; + let s = Series::from_literals(values, None)?; let s = FixedSizeListArray::new( Field::new("tuple", s.data_type().clone()) .to_fixed_size_list_field(exprs.len())?, From bb8ebf2781a98594c00f0d7d6bfefef30ed9b3c8 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 11:37:25 -0500 Subject: [PATCH 03/15] skip inference for struct types too --- src/daft-core/src/series/from_lit.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index cd0d8f8936..43dc31d941 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -187,7 +187,8 @@ impl Series { }) .collect::>(); - Ok(Self::from_literals(child_values, None)?.rename(&f.name)) + Ok(Self::from_literals(child_values, Some(f.dtype.clone()))? + .rename(&f.name)) }) .collect::>()?; From 82b86f5e00c769873ff7c1fc6ba8616290f57d27 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 11:41:52 -0500 Subject: [PATCH 04/15] skip inference for python contructor --- src/daft-core/src/python/series.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index e67f081c7d..b625084a41 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -114,7 +114,7 @@ impl PySeries { (literals_with_supertype, supertype) }; - let mut series = Series::from_literals(literals, None)?.cast(&dtype)?; + let mut series = Series::from_literals(literals, Some(dtype))?; if let Some(name) = name { series = series.rename(name); } From da0f54f4d2fedc8a27f6f9499dea4c26ce02bd73 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 16:42:35 -0500 Subject: [PATCH 05/15] more literal optimizations --- src/arrow2/src/trusted_len.rs | 2 + src/daft-core/src/array/ops/get.rs | 10 +- src/daft-core/src/array/ops/get_lit.rs | 19 +-- src/daft-core/src/series/from_lit.rs | 186 +++++++++++++++-------- src/daft-core/src/series/mod.rs | 1 + src/daft-dsl/src/functions/python/mod.rs | 4 +- src/daft-dsl/src/python_udf.rs | 37 ++--- 7 files changed, 163 insertions(+), 96 deletions(-) diff --git a/src/arrow2/src/trusted_len.rs b/src/arrow2/src/trusted_len.rs index 67b941c63b..98ce95bac2 100644 --- a/src/arrow2/src/trusted_len.rs +++ b/src/arrow2/src/trusted_len.rs @@ -11,6 +11,8 @@ use std::slice::Iter; /// Consumers of this trait must inspect Iterator::size_hint()’s upper bound. pub unsafe trait TrustedLen: Iterator {} +unsafe impl TrustedLen for std::ops::Range {} + unsafe impl TrustedLen for Iter<'_, T> {} unsafe impl B> TrustedLen for std::iter::Map {} diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index ffaaa36187..584212cc33 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -48,9 +48,13 @@ macro_rules! impl_array_arrow_get { impl $ArrayT { #[inline] pub fn get(&self, idx: usize) -> Option<$output> { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); + let arrow_array = self.as_arrow(); let is_valid = arrow_array .validity() diff --git a/src/daft-core/src/array/ops/get_lit.rs b/src/daft-core/src/array/ops/get_lit.rs index a5e3ff70c6..adbed80fca 100644 --- a/src/daft-core/src/array/ops/get_lit.rs +++ b/src/daft-core/src/array/ops/get_lit.rs @@ -141,12 +141,12 @@ impl FixedShapeSparseTensorArray { impl MapArray { pub fn get_lit(&self, idx: usize) -> Literal { - assert!( - idx < self.len(), - "Out of bounds: {} vs len: {}", - idx, - self.len() - ); + // assert!( + // idx < self.len(), + // "Out of bounds: {} vs len: {}", + // idx, + // self.len() + // ); map_or_null(self.get(idx), |entry: Series| { let entry = entry.struct_().unwrap(); @@ -179,12 +179,7 @@ macro_rules! impl_array_get_lit { ($type:ty, $variant:ident) => { impl $type { pub fn get_lit(&self, idx: usize) -> Literal { - assert!( - idx < self.len(), - "Out of bounds: {} vs len: {}", - idx, - self.len() - ); + // don't need to do assertions here because it also happens in `self.get` map_or_null(self.get(idx), Literal::$variant) } } diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index 43dc31d941..4fdf743317 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -1,3 +1,9 @@ +use std::sync::Arc; + +use arrow2::{ + array::{MutableArray, MutablePrimitiveArray}, + trusted_len::TrustedLen, +}; use common_error::{DaftError, DaftResult}; use common_image::CowImage; @@ -53,81 +59,102 @@ pub(crate) fn combine_lit_types(left: &DataType, right: &DataType) -> Option, dtype: Option) -> DaftResult { - let dtype = match dtype { - Some(dtype) => dtype, - None => values.iter().try_fold(DataType::Null, |acc, v| { - let dtype = v.get_type(); - combine_lit_types(&acc, &dtype).ok_or_else(|| { - DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - acc, dtype - )) - }) - })?, - }; +/// How to handle errors when combining literals. +pub enum OnError { + /// Errors result in Null values + Null, + /// Errors will be raised + Raise, +} +impl Series { + pub fn from_literals_iter> + TrustedLen>( + values: I, + dtype: DataType, + on_error: OnError, + ) -> DaftResult { let field = Field::new("literal", dtype.clone()); + // DaftError isn't `Clone` so just using a string to store the error instead. + let mut err: Option = None; + macro_rules! unwrap_inner { ($expr:expr, $variant:ident) => { match $expr { - Literal::$variant(val) => Some(val), - Literal::Null => None, - _ => unreachable!("datatype is already checked"), + Ok(Literal::$variant(val)) => Some(val), + Ok(Literal::Null) => None, + Err(e) => { + match on_error { + OnError::Null => None, + OnError::Raise => { + err = Some(e.to_string()); + None + }, + + } + + } + // SAFETY: This is safe because we have already checked that all literals have the same data type or null. + _ => unsafe { std::hint::unreachable_unchecked() }, } }; ($expr:expr, $literal:pat => $value:expr) => { match $expr { - $literal => Some($value), - Literal::Null => None, - _ => unreachable!("datatype is already checked"), + Ok($literal) => Some($value), + Ok(Literal::Null) => None, + Err(e) => { + match on_error { + OnError::Null => None, + OnError::Raise => { + err = Some(e.to_string()); + None + }, + + } + } + // SAFETY: This is safe because we have already checked that all literals have the same data type or null. + _ => unsafe { std::hint::unreachable_unchecked() }, } }; } macro_rules! from_iter_with_str { ($arr_type:ty, $variant:ident) => {{ - <$arr_type>::from_iter( - "literal", - values.into_iter().map(|lit| unwrap_inner!(lit, $variant)), - ) - .into_series() + <$arr_type>::from_iter("literal", values.map(|lit| unwrap_inner!(lit, $variant))) + .into_series() }}; } - macro_rules! from_iter_with_field { + macro_rules! from_mutable_primitive_array { ($arr_type:ty, $variant:ident) => {{ - <$arr_type>::from_iter( - field, - values.into_iter().map(|lit| unwrap_inner!(lit, $variant)), - ) - .into_series() + let mut arr = MutablePrimitiveArray::<$arr_type>::with_capacity(values.len()); + for value in values { + let value = unwrap_inner!(value, $variant); + arr.push(value); + } + + Series::from_arrow(Arc::new(field), arr.as_box())? }}; } - Ok(match dtype { + let s = match dtype { DataType::Null => NullArray::full_null("literal", &dtype, values.len()).into_series(), DataType::Boolean => from_iter_with_str!(BooleanArray, Boolean), DataType::Utf8 => from_iter_with_str!(Utf8Array, Utf8), DataType::Binary => from_iter_with_str!(BinaryArray, Binary), - DataType::Int8 => from_iter_with_field!(Int8Array, Int8), - DataType::UInt8 => from_iter_with_field!(UInt8Array, UInt8), - DataType::Int16 => from_iter_with_field!(Int16Array, Int16), - DataType::UInt16 => from_iter_with_field!(UInt16Array, UInt16), - DataType::Int32 => from_iter_with_field!(Int32Array, Int32), - DataType::UInt32 => from_iter_with_field!(UInt32Array, UInt32), - DataType::Int64 => from_iter_with_field!(Int64Array, Int64), - DataType::UInt64 => from_iter_with_field!(UInt64Array, UInt64), + DataType::Int8 => from_mutable_primitive_array!(i8, Int8), + DataType::UInt8 => from_mutable_primitive_array!(u8, UInt8), + DataType::Int16 => from_mutable_primitive_array!(i16, Int16), + DataType::UInt16 => from_mutable_primitive_array!(u16, UInt16), + DataType::Int32 => from_mutable_primitive_array!(i32, Int32), + DataType::UInt32 => from_mutable_primitive_array!(u32, UInt32), + DataType::Int64 => from_mutable_primitive_array!(i64, Int64), + DataType::UInt64 => from_mutable_primitive_array!(u64, UInt64), DataType::Interval => from_iter_with_str!(IntervalArray, Interval), - DataType::Float32 => from_iter_with_field!(Float32Array, Float32), - DataType::Float64 => from_iter_with_field!(Float64Array, Float64), + DataType::Float32 => from_mutable_primitive_array!(f32, Float32), + DataType::Float64 => from_mutable_primitive_array!(f64, Float64), + DataType::Decimal128 { .. } => Decimal128Array::from_iter( field, values @@ -165,7 +192,6 @@ impl Series { } DataType::List(child_dtype) => { let data = values - .iter() .map(|v| { unwrap_inner!(v, List) .map(|s| s.cast(&child_dtype)) @@ -175,6 +201,8 @@ impl Series { ListArray::try_from(("literal", data.as_slice()))?.into_series() } DataType::Struct(fields) => { + let values = values.collect::>(); + let children = fields .iter() .enumerate() @@ -193,7 +221,9 @@ impl Series { .collect::>()?; let validity = arrow2::bitmap::Bitmap::from_trusted_len_iter( - values.iter().map(|v| *v != Literal::Null), + values + .iter() + .map(|v| v.as_ref().is_ok_and(|v| v != &Literal::Null)), ); StructArray::new(field, children, Some(validity)).into_series() @@ -209,6 +239,7 @@ impl Series { } #[cfg(feature = "python")] DataType::File => { + let values = values.collect::>(); use std::sync::Arc; use common_file::FileReference; @@ -276,8 +307,9 @@ impl Series { DataType::File => unreachable!("File type is only supported with the python feature"), DataType::Tensor(_) => { - let (data, shapes) = values - .iter() + let values = values.collect::>(); + + let (data, shapes) = values.iter() .map(|v| { unwrap_inner!(v, Literal::Tensor { data, shape } => (data, shape.as_slice())).unzip() }) @@ -293,6 +325,8 @@ impl Series { TensorArray::new(field, physical).into_series() } DataType::SparseTensor(..) => { + let values = values.collect::>(); + let (values, indices, shapes) = values .iter() .map(|v| { @@ -319,17 +353,19 @@ impl Series { SparseTensorArray::new(field, physical).into_series() } DataType::Embedding(inner_dtype, size) => { - let validity = arrow2::bitmap::Bitmap::from_trusted_len_iter( - values.iter().map(|v| *v != Literal::Null), - ); - - let data = values + let (validity, data): (Vec<_>, Vec<_>) = values .into_iter() .map(|v| { - unwrap_inner!(v, Embedding) - .unwrap_or_else(|| Self::full_null("literal", &inner_dtype, size)) + ( + v.as_ref().is_ok_and(|v| v != &Literal::Null), + unwrap_inner!(v, Embedding) + .unwrap_or_else(|| Self::full_null("literal", &inner_dtype, size)), + ) }) - .collect::>(); + .unzip(); + + let validity = arrow2::bitmap::Bitmap::from(validity); + let flat_child = Self::concat(&data.iter().collect::>())?; let physical = @@ -369,8 +405,7 @@ impl Series { } DataType::Image(image_mode) => { let data = values - .iter() - .map(|v| unwrap_inner!(v, Image).map(|img| CowImage::from(&img.0))) + .map(|v| unwrap_inner!(v, Image).map(|img| CowImage::from(img.0))) .collect::>(); image_array_from_img_buffers("literal", &data, image_mode)?.into_series() @@ -383,7 +418,36 @@ impl Series { | DataType::FixedShapeTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::Unknown => unreachable!("Literal should never have data type: {dtype}"), - }) + }; + if let Some(e) = err { + Err(DaftError::ComputeError(e)) + } else { + Ok(s) + } + } + + /// Converts a vec of literals into a Series. + /// + /// Literals must all be the same type or null, this function does not do any casting or coercion. + /// If that is desired, you should handle it for each literal before converting it into a series. + pub fn from_literals(values: Vec, dtype: Option) -> DaftResult { + let dtype = match dtype { + Some(dtype) => dtype, + None => values.iter().try_fold(DataType::Null, |acc, v| { + let dtype = v.get_type(); + combine_lit_types(&acc, &dtype).ok_or_else(|| { + DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + acc, dtype + )) + }) + })?, + }; + Self::from_literals_iter( + values.into_iter().map(DaftResult::Ok), + dtype, + OnError::Raise, + ) } } diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index 7b9b7a3bae..e14ee04b44 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -124,6 +124,7 @@ impl Series { /// /// This function will check the provided [`Field`] (and all its associated potentially nested fields/dtypes) against /// the provided [`arrow2::array::Array`] for compatibility, and returns an error if they do not match. + #[inline] pub fn from_arrow( field: FieldRef, arrow_arr: Box, diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index 8d73e3ba2b..883378bc94 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -379,9 +379,9 @@ pub fn get_udf_properties(expr: &ExprRef) -> UDFProperties { udf_properties = Some(UDFProperties { name: py.name().to_string(), resource_request: None, - batch_size: Some(512), + batch_size: None, concurrency: None, - use_process: None, + use_process: Some(true), }); } Ok(TreeNodeRecursion::Continue) diff --git a/src/daft-dsl/src/python_udf.rs b/src/daft-dsl/src/python_udf.rs index f0696a091f..5161c689ed 100644 --- a/src/daft-dsl/src/python_udf.rs +++ b/src/daft-dsl/src/python_udf.rs @@ -198,32 +198,33 @@ impl RowWisePyFn { let name = args[0].name(); let start_time = std::time::Instant::now(); Python::with_gil(|py| { + use daft_core::series::from_lit::OnError; + let gil_contention_time = start_time.elapsed(); let func = py .import(pyo3::intern!(py, "daft.udf.row_wise"))? .getattr(pyo3::intern!(py, "__call_func"))?; let mut py_args = Vec::with_capacity(args.len()); - // pre-allocating py_args vector so we're not creating a new vector for each iteration - let outputs = (0..num_rows) - .map(|i| { - for s in args { - let idx = if s.len() == 1 { 0 } else { i }; - let lit = s.get_lit(idx); - let pyarg = lit.into_pyobject(py)?; - py_args.push(pyarg); - } - - let result = func.call1((inner_ref, args_ref, &py_args))?; - let result = Literal::from_pyobj(&result, Some(&self.return_dtype))?; - - py_args.clear(); - DaftResult::Ok(result) - }) - .collect::>>()?; + // pre-allocating py_args vector so we're not creating a new vector for each iteration + let outputs = (0..num_rows).map(|i| { + for s in args { + let idx = if s.len() == 1 { 0 } else { i }; + let lit = s.get_lit(idx); + let pyarg = lit.into_pyobject(py)?; + py_args.push(pyarg); + } + + let result = func.call1((inner_ref, args_ref, &py_args))?; + let lit = Literal::from_pyobj(&result, Some(&self.return_dtype))?; + + py_args.clear(); + DaftResult::Ok(lit) + }); Ok(( - Series::from_literals(outputs, Some(self.return_dtype.clone()))?.rename(name), + Series::from_literals_iter(outputs, self.return_dtype.clone(), OnError::Raise)? + .rename(name), gil_contention_time, )) }) From 5e952c0e3b7b04b0eef06cabd70b0ec275518f9a Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 16:55:15 -0500 Subject: [PATCH 06/15] cleanup --- src/daft-core/src/array/ops/get_lit.rs | 12 ++++++------ src/daft-core/src/series/from_lit.rs | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/daft-core/src/array/ops/get_lit.rs b/src/daft-core/src/array/ops/get_lit.rs index adbed80fca..8dcb6794db 100644 --- a/src/daft-core/src/array/ops/get_lit.rs +++ b/src/daft-core/src/array/ops/get_lit.rs @@ -141,12 +141,12 @@ impl FixedShapeSparseTensorArray { impl MapArray { pub fn get_lit(&self, idx: usize) -> Literal { - // assert!( - // idx < self.len(), - // "Out of bounds: {} vs len: {}", - // idx, - // self.len() - // ); + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); map_or_null(self.get(idx), |entry: Series| { let entry = entry.struct_().unwrap(); diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index 4fdf743317..01b2031a20 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -68,6 +68,12 @@ pub enum OnError { } impl Series { + /// Creates a series from an iterator of Result. + /// The `on_error` indicates how to handle the result. + /// + /// Unlike `from_literals`, `from_literal_iter` does not do any dtype inference + /// as such, it'll panic if the dtype variant does not match the expected literal variant. + /// So it is up to the caller to ensure the datatypes match! pub fn from_literals_iter> + TrustedLen>( values: I, dtype: DataType, From 37390cf637a8f876ca888ed64330116d5c4eeef9 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 1 Oct 2025 16:56:33 -0500 Subject: [PATCH 07/15] cleanup --- src/daft-core/src/series/from_lit.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index 01b2031a20..425ccb61b0 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -68,11 +68,11 @@ pub enum OnError { } impl Series { - /// Creates a series from an iterator of Result. - /// The `on_error` indicates how to handle the result. - /// - /// Unlike `from_literals`, `from_literal_iter` does not do any dtype inference - /// as such, it'll panic if the dtype variant does not match the expected literal variant. + /// Creates a series from an iterator of Result. + /// The `on_error` indicates how to handle the result. + /// + /// Unlike `from_literals`, `from_literal_iter` does not do any dtype inference + /// as such, it'll panic if the dtype variant does not match the expected literal variant. /// So it is up to the caller to ensure the datatypes match! pub fn from_literals_iter> + TrustedLen>( values: I, @@ -98,10 +98,8 @@ impl Series { }, } - } - // SAFETY: This is safe because we have already checked that all literals have the same data type or null. - _ => unsafe { std::hint::unreachable_unchecked() }, + _ => unreachable!() } }; @@ -119,8 +117,7 @@ impl Series { } } - // SAFETY: This is safe because we have already checked that all literals have the same data type or null. - _ => unsafe { std::hint::unreachable_unchecked() }, + _ => unreachable!() } }; } From b93a44c0db95bebfc107fe3c29589ff15df9f1fd Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 11:57:45 -0500 Subject: [PATCH 08/15] add tests and better docstrings --- src/daft-core/src/series/from_lit.rs | 282 +++++++++++++++++++++++---- src/daft-core/src/series/mod.rs | 2 +- 2 files changed, 240 insertions(+), 44 deletions(-) diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index 425ccb61b0..013b3a0d6b 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use arrow2::{ array::{MutableArray, MutablePrimitiveArray}, trusted_len::TrustedLen, + types::months_days_ns, }; use common_error::{DaftError, DaftResult}; use common_image::CowImage; @@ -68,12 +69,11 @@ pub enum OnError { } impl Series { - /// Creates a series from an iterator of Result. + /// Creates a series from an iterator of `Result`. /// The `on_error` indicates how to handle the result. /// - /// Unlike `from_literals`, `from_literal_iter` does not do any dtype inference - /// as such, it'll panic if the dtype variant does not match the expected literal variant. - /// So it is up to the caller to ensure the datatypes match! + /// Unlike `from_literals`, `from_literal_iter` does not do any dtype inference. + /// If the datatype of `Literal` != dtype, then it will either result in a null value, or an error depending on the `on_error` setting. pub fn from_literals_iter> + TrustedLen>( values: I, dtype: DataType, @@ -82,24 +82,45 @@ impl Series { let field = Field::new("literal", dtype.clone()); // DaftError isn't `Clone` so just using a string to store the error instead. - let mut err: Option = None; + let mut errs: Vec = Vec::with_capacity(values.len()); + // TODO: ideally we'd want to fail fast if `OnError::Raise`, but you can't return from the outer function inside a closure, so we just push the error to the vector and continue + // Then we check if there are any errors at the end. + // Instead, we should just immediately return an error one is encountered. + // But this requires significant refactoring to remove the iter closures and use for loops for all variants. + // For now, I created the `unwrap_inner_fail_fast` that'll immediately return, but it can't be used in closures. macro_rules! unwrap_inner { ($expr:expr, $variant:ident) => { match $expr { Ok(Literal::$variant(val)) => Some(val), Ok(Literal::Null) => None, - Err(e) => { - match on_error { - OnError::Null => None, - OnError::Raise => { - err = Some(e.to_string()); - None - }, + Err(e) => match on_error { + OnError::Null => None, + OnError::Raise => { + errs.push(e.to_string()); + None + } + }, + Ok(other) => { + let ty = other.get_type(); + if &ty != &dtype { + match on_error { + OnError::Null => None, + OnError::Raise => { + errs.push(format!( + "All literals must have the same data type or null. Found: {ty} vs {dtype}" + )); + None + } + } + } else { + errs.push(format!( + "All literals must have the same data type or null. Found: {ty} vs {dtype}" + )); + None } } - _ => unreachable!() } }; @@ -107,17 +128,100 @@ impl Series { match $expr { Ok($literal) => Some($value), Ok(Literal::Null) => None, - Err(e) => { - match on_error { - OnError::Null => None, - OnError::Raise => { - err = Some(e.to_string()); - None - }, + Err(e) => match on_error { + OnError::Null => None, + OnError::Raise => { + errs.push(e.to_string()); + None + } + }, + Ok(other) => { + let ty = other.get_type(); + if &ty != &dtype { + match on_error { + OnError::Null => None, + OnError::Raise => { + errs.push(format!( + "All literals must have the same data type or null. Found: {ty} vs {dtype}" + )); + None + } + } + } else { + errs.push(format!( + "All literals must have the same data type or null. Found: {ty} vs {dtype}" + )); + None + } + } + } + }; + } + /// Same as `unwrap_inner` but immediately return out of `from_literals_iter` instead of collecting the errors. + /// This can't be used inside closures that don't return a Result + macro_rules! unwrap_inner_fail_fast { + ($expr:expr, $variant:ident) => { + match $expr { + Ok(Literal::$variant(val)) => Some(val), + Ok(Literal::Null) => None, + Err(e) => match on_error { + OnError::Null => None, + OnError::Raise => { + return Err(e) + } + }, + Ok(other) => { + let ty = other.get_type(); + if &ty != &dtype { + match on_error { + OnError::Null => None, + OnError::Raise => { + return Err(DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + ty, dtype + ))) + } + } + } else { + return Err(DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + ty, dtype + ))) + } + } + } + }; + ($expr:expr, $literal:pat => $value:expr) => { + match $expr { + Ok($literal) => Some($value), + Ok(Literal::Null) => None, + Err(e) => match on_error { + OnError::Null => None, + OnError::Raise => { + errs.push(e.to_string()); + None + } + }, + Ok(other) => { + let ty = other.get_type(); + if &ty != &dtype { + match on_error { + OnError::Null => None, + OnError::Raise => { + return Err(DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + ty, dtype + ))) + } + } + } else { + return Err(DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + ty, dtype + ))) } } - _ => unreachable!() } }; } @@ -133,7 +237,7 @@ impl Series { ($arr_type:ty, $variant:ident) => {{ let mut arr = MutablePrimitiveArray::<$arr_type>::with_capacity(values.len()); for value in values { - let value = unwrap_inner!(value, $variant); + let value = unwrap_inner_fail_fast!(value, $variant); arr.push(value); } @@ -154,9 +258,21 @@ impl Series { DataType::UInt32 => from_mutable_primitive_array!(u32, UInt32), DataType::Int64 => from_mutable_primitive_array!(i64, Int64), DataType::UInt64 => from_mutable_primitive_array!(u64, UInt64), - DataType::Interval => from_iter_with_str!(IntervalArray, Interval), DataType::Float32 => from_mutable_primitive_array!(f32, Float32), DataType::Float64 => from_mutable_primitive_array!(f64, Float64), + DataType::Interval => { + let mut arr = MutablePrimitiveArray::::with_capacity(values.len()); + for value in values { + let value = + unwrap_inner!(value, Literal::Interval(d) => months_days_ns::from(d)); + arr.push(value); + } + + Self::from_arrow( + Arc::new(Field::new("literal", DataType::Interval)), + arr.as_box(), + )? + } DataType::Decimal128 { .. } => Decimal128Array::from_iter( field, @@ -166,15 +282,28 @@ impl Series { ) .into_series(), DataType::Timestamp(_, _) => { - let data = values - .into_iter() - .map(|lit| unwrap_inner!(lit, Literal::Timestamp(ts, ..) => ts)); - let physical = Int64Array::from_iter(Field::new("literal", DataType::Int64), data); + let mut arr = MutablePrimitiveArray::::with_capacity(values.len()); + for value in values { + let value = unwrap_inner_fail_fast!(value, Literal::Timestamp(ts, ..) => ts); + arr.push(value); + } + let physical = Int64Array::from_arrow( + Arc::new(Field::new("literal", DataType::Int64)), + arr.as_box(), + )?; TimestampArray::new(field, physical).into_series() } DataType::Date => { - let data = values.into_iter().map(|lit| unwrap_inner!(lit, Date)); - let physical = Int32Array::from_iter(Field::new("literal", DataType::Int32), data); + let mut arr = MutablePrimitiveArray::::with_capacity(values.len()); + for value in values { + let value = unwrap_inner_fail_fast!(value, Date); + arr.push(value); + } + let physical = Int32Array::from_arrow( + Arc::new(Field::new("literal", DataType::Int32)), + arr.as_box(), + )?; + DateArray::new(field, physical).into_series() } DataType::Time(_) => { @@ -193,17 +322,17 @@ impl Series { DurationArray::new(field, physical).into_series() } - DataType::List(child_dtype) => { + DataType::List(ref child_dtype) => { let data = values .map(|v| { - unwrap_inner!(v, List) - .map(|s| s.cast(&child_dtype)) + unwrap_inner_fail_fast!(v, List) + .map(|s| s.cast(child_dtype)) .transpose() }) .collect::>>()?; ListArray::try_from(("literal", data.as_slice()))?.into_series() } - DataType::Struct(fields) => { + DataType::Struct(ref fields) => { let values = values.collect::>(); let children = fields @@ -355,14 +484,14 @@ impl Series { SparseTensorArray::new(field, physical).into_series() } - DataType::Embedding(inner_dtype, size) => { + DataType::Embedding(ref inner_dtype, ref size) => { let (validity, data): (Vec<_>, Vec<_>) = values .into_iter() .map(|v| { ( v.as_ref().is_ok_and(|v| v != &Literal::Null), unwrap_inner!(v, Embedding) - .unwrap_or_else(|| Self::full_null("literal", &inner_dtype, size)), + .unwrap_or_else(|| Self::full_null("literal", inner_dtype, *size)), ) }) .unzip(); @@ -377,8 +506,8 @@ impl Series { EmbeddingArray::new(field, physical).into_series() } DataType::Map { - key: key_dtype, - value: value_dtype, + key: ref key_dtype, + value: ref value_dtype, } => { let data = values .into_iter() @@ -391,8 +520,8 @@ impl Series { .to_exploded_field() .expect("Expected physical type of map to be list"), vec![ - k.cast(&key_dtype)?.rename("key"), - v.cast(&value_dtype)?.rename("value"), + k.cast(key_dtype)?.rename("key"), + v.cast(value_dtype)?.rename("value"), ], None, ) @@ -422,10 +551,13 @@ impl Series { | DataType::FixedShapeSparseTensor(..) | DataType::Unknown => unreachable!("Literal should never have data type: {dtype}"), }; - if let Some(e) = err { - Err(DaftError::ComputeError(e)) - } else { + if errs.is_empty() { Ok(s) + } else { + Err(DaftError::ComputeError(format!( + "Errors occurred while creating series: {:?}", + errs + ))) } } @@ -464,12 +596,13 @@ impl From for Series { #[cfg(test)] mod test { + use common_error::DaftError; use common_image::Image; use image::{GrayImage, RgbaImage}; use indexmap::indexmap; use rstest::rstest; - use crate::{datatypes::IntervalValue, prelude::*, series}; + use crate::{datatypes::IntervalValue, prelude::*, series, series::from_lit::OnError}; #[rstest] #[case::null(vec![Literal::Null, Literal::Null])] @@ -596,4 +729,67 @@ mod test { let actual = Series::from_literals(values, None); assert!(actual.is_err()); } + + #[test] + fn test_literals_to_series_mismatched_and_dtype() { + let values = vec![Literal::UInt64(1), Literal::Utf8("test".to_string())]; + let actual = Series::from_literals(values, Some(DataType::Utf8)); + assert!(actual.is_err()); + } + + #[test] + fn test_literals_to_series_ignore_errors() { + let values = vec![ + Ok(Literal::Boolean(true)), + Err(DaftError::ValueError("null".to_string())), + ] + .into_iter(); + + let dtype = DataType::Boolean; + let res = Series::from_literals_iter(values, dtype, OnError::Null).unwrap(); + let value = res.get_lit(1); + assert_eq!(value, Literal::Null); + } + + #[test] + fn test_literals_to_series_mismatch_ignore() { + let values = vec![ + Ok(Literal::Boolean(true)), + Ok(Literal::Int8(1)), + Err(DaftError::ValueError("null".to_string())), + ] + .into_iter(); + + let dtype = DataType::Boolean; + let res = Series::from_literals_iter(values, dtype, OnError::Null) + .expect("Failed to create series"); + assert_eq!(res.get_lit(0), Literal::Boolean(true)); + assert_eq!(res.get_lit(1), Literal::Null); + assert_eq!(res.get_lit(2), Literal::Null); + } + + #[test] + fn test_literals_to_series_mismatch_raise() { + let values = vec![ + Ok(Literal::Boolean(true)), + Ok(Literal::Int8(1)), + Err(DaftError::ValueError("null".to_string())), + ] + .into_iter(); + + let dtype = DataType::Boolean; + let res = Series::from_literals_iter(values, dtype, OnError::Raise); + assert!(res.is_err()) + } + + #[test] + fn test_literals_to_series_mismatched_and_dtype_ignore() { + let values = vec![Literal::UInt64(1), Literal::Utf8("test".to_string())]; + let actual = + Series::from_literals_iter(values.into_iter().map(Ok), DataType::Utf8, OnError::Null); + assert!(actual.is_ok()); + let actual = actual.unwrap(); + assert_eq!(actual.get_lit(0), Literal::Null); + assert_eq!(actual.get_lit(1), Literal::Utf8("test".to_string())); + } } diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index e14ee04b44..c65b5c7c7c 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -207,7 +207,7 @@ impl Series { /// Attempts to downcast the Series to a primitive slice /// This will return an error if the Series is not of the physical type `T` /// # Example - /// ```rust,no_run + /// ```rust,ignore /// let i32_arr: &[i32] = series.try_as_slice::()?; /// /// let f64_arr: &[f64] = series.try_as_slice::()?; From 434cbdc23583263311633bb45c052f40947c53df Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 12:08:03 -0500 Subject: [PATCH 09/15] lit casting --- src/daft-core/src/lit/mod.rs | 5 ++++- src/daft-core/src/python/series.rs | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/daft-core/src/lit/mod.rs b/src/daft-core/src/lit/mod.rs index de56c8c3fd..85841eed5d 100644 --- a/src/daft-core/src/lit/mod.rs +++ b/src/daft-core/src/lit/mod.rs @@ -22,6 +22,7 @@ use serde::{Deserialize, Serialize}; use crate::{ datatypes::IntervalValue, prelude::*, + series::from_lit::combine_lit_types, utils::display::{ display_date32, display_decimal128, display_duration, display_series_in_literal, display_time64, display_timestamp, @@ -584,7 +585,9 @@ impl Literal { /// This method is lossy, AKA it is not guaranteed that `lit.cast(dtype).get_type() == dtype`. /// This is because null literals always have the null data type. pub fn cast(self, dtype: &DataType) -> DaftResult { - if &self.get_type() == dtype { + if &self.get_type() == dtype + || (combine_lit_types(&self.get_type(), dtype).as_ref() == Some(dtype)) + { Ok(self) } else { Series::from(self) diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index b625084a41..e67f081c7d 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -114,7 +114,7 @@ impl PySeries { (literals_with_supertype, supertype) }; - let mut series = Series::from_literals(literals, Some(dtype))?; + let mut series = Series::from_literals(literals, None)?.cast(&dtype)?; if let Some(name) = name { series = series.rename(name); } From f4e28d782d26df9f0245871847104e9f960fa46d Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 12:42:22 -0500 Subject: [PATCH 10/15] disable useprocess --- src/daft-dsl/src/functions/python/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index 883378bc94..187ab3103b 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -381,7 +381,7 @@ pub fn get_udf_properties(expr: &ExprRef) -> UDFProperties { resource_request: None, batch_size: None, concurrency: None, - use_process: Some(true), + use_process: None, }); } Ok(TreeNodeRecursion::Continue) From 3de4b8df5ee4b7020042fb4559a31f1e0346a0a4 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 13:25:26 -0500 Subject: [PATCH 11/15] add use_process flag to daft.func --- daft/daft/__init__.pyi | 1 + daft/udf/__init__.py | 36 ++++++++++++++++++----- daft/udf/generator.py | 9 ++++-- daft/udf/row_wise.py | 5 ++-- src/daft-core/src/series/mod.rs | 3 +- src/daft-dsl/src/expr/mod.rs | 2 ++ src/daft-dsl/src/functions/python/mod.rs | 8 ++--- src/daft-dsl/src/python.rs | 2 ++ src/daft-dsl/src/python_udf.rs | 4 +++ src/daft-ir/src/proto/functions.rs | 2 ++ src/daft-logical-plan/src/ops/project.rs | 2 ++ src/daft-logical-plan/src/partitioning.rs | 2 ++ src/daft-proto/proto/v1/daft.proto | 1 + src/daft-proto/src/generated/daft.v1.rs | 2 ++ 14 files changed, 62 insertions(+), 17 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 347ef98c3b..eb664e1406 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1343,6 +1343,7 @@ def row_wise_udf( name: str, inner: Callable[..., Any], return_dtype: PyDataType, + use_process: bool | None, original_args: tuple[tuple[Any, ...], dict[str, Any]], expr_args: list[PyExpr], ) -> PyExpr: ... diff --git a/daft/udf/__init__.py b/daft/udf/__init__.py index 43aa94c2ae..ca44f7b58f 100644 --- a/daft/udf/__init__.py +++ b/daft/udf/__init__.py @@ -28,6 +28,7 @@ class _PartialUdf: return_dtype: DataTypeLike | None unnest: bool + use_process: bool | None @overload def __call__(self, fn: Callable[P, Iterator[T]]) -> GeneratorUdf[P, T]: ... # type: ignore[overload-overlap] @@ -36,9 +37,9 @@ def __call__(self, fn: Callable[P, T]) -> RowWiseUdf[P, T]: ... def __call__(self, fn: Callable[P, Any]) -> GeneratorUdf[P, Any] | RowWiseUdf[P, Any]: if isgeneratorfunction(fn): - return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest) + return GeneratorUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest, use_process=self.use_process) else: - return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest) + return RowWiseUdf(fn, return_dtype=self.return_dtype, unnest=self.unnest, use_process=self.use_process) class _DaftFuncDecorator: @@ -212,20 +213,41 @@ class _DaftFuncDecorator: """ @overload - def __new__(cls, *, return_dtype: DataTypeLike | None = None, unnest: bool = False) -> _PartialUdf: ... # type: ignore[misc] + def __new__( # type: ignore[misc] + cls, + *, + return_dtype: DataTypeLike | None = None, + unnest: bool = False, + use_process: bool | None = None, + ) -> _PartialUdf: ... @overload def __new__( # type: ignore[misc] - cls, fn: Callable[P, Iterator[T]], *, return_dtype: DataTypeLike | None = None, unnest: bool = False + cls, + fn: Callable[P, Iterator[T]], + *, + return_dtype: DataTypeLike | None = None, + unnest: bool = False, + use_process: bool | None = None, ) -> GeneratorUdf[P, T]: ... @overload def __new__( # type: ignore[misc] - cls, fn: Callable[P, T], *, return_dtype: DataTypeLike | None = None, unnest: bool = False + cls, + fn: Callable[P, T], + *, + return_dtype: DataTypeLike | None = None, + unnest: bool = False, + use_process: bool | None = None, ) -> RowWiseUdf[P, T]: ... def __new__( # type: ignore[misc] - cls, fn: Callable[P, Any] | None = None, *, return_dtype: DataTypeLike | None = None, unnest: bool = False + cls, + fn: Callable[P, Any] | None = None, + *, + return_dtype: DataTypeLike | None = None, + unnest: bool = False, + use_process: bool | None = None, ) -> _PartialUdf | GeneratorUdf[P, Any] | RowWiseUdf[P, Any]: - partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest) + partial_udf = _PartialUdf(return_dtype=return_dtype, unnest=unnest, use_process=use_process) return partial_udf if fn is None else partial_udf(fn) diff --git a/daft/udf/generator.py b/daft/udf/generator.py index 91216ea3df..5c4297b664 100644 --- a/daft/udf/generator.py +++ b/daft/udf/generator.py @@ -32,10 +32,13 @@ class GeneratorUdf(Generic[P, T]): If no values are yielded for an input, a null value is inserted. """ - def __init__(self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool): + def __init__( + self, fn: Callable[P, Iterator[T]], return_dtype: DataTypeLike | None, unnest: bool, use_process: bool | None + ): self._inner = fn self.name = get_unique_function_name(fn) self.unnest = unnest + self.use_process = use_process # attempt to extract return type from an Iterator or Generator type hint if return_dtype is None: @@ -85,7 +88,9 @@ def inner_rowwise(*args: P.args, **kwargs: P.kwargs) -> list[T]: return_dtype_rowwise = DataType.list(self.return_dtype) expr = Expression._from_pyexpr( - row_wise_udf(self.name, inner_rowwise, return_dtype_rowwise._dtype, (args, kwargs), expr_args) + row_wise_udf( + self.name, inner_rowwise, return_dtype_rowwise._dtype, self.use_process, (args, kwargs), expr_args + ) ).explode() if self.unnest: diff --git a/daft/udf/row_wise.py b/daft/udf/row_wise.py index 6e596f9aca..cc27830d89 100644 --- a/daft/udf/row_wise.py +++ b/daft/udf/row_wise.py @@ -33,10 +33,11 @@ class RowWiseUdf(Generic[P, T]): Row-wise functions are called with data from one row at a time, and map that to a single output value for that row. """ - def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool): + def __init__(self, fn: Callable[P, T], return_dtype: DataTypeLike | None, unnest: bool, use_process: bool | None): self._inner = fn self.name = get_unique_function_name(fn) self.unnest = unnest + self.use_process = use_process if return_dtype is None: type_hints = get_type_hints(fn) @@ -69,7 +70,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Expression | T: return self._inner(*args, **kwargs) expr = Expression._from_pyexpr( - row_wise_udf(self.name, self._inner, self.return_dtype._dtype, (args, kwargs), expr_args) + row_wise_udf(self.name, self._inner, self.return_dtype._dtype, self.use_process, (args, kwargs), expr_args) ) if self.unnest: diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index c65b5c7c7c..063bf25404 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -124,7 +124,6 @@ impl Series { /// /// This function will check the provided [`Field`] (and all its associated potentially nested fields/dtypes) against /// the provided [`arrow2::array::Array`] for compatibility, and returns an error if they do not match. - #[inline] pub fn from_arrow( field: FieldRef, arrow_arr: Box, @@ -207,7 +206,7 @@ impl Series { /// Attempts to downcast the Series to a primitive slice /// This will return an error if the Series is not of the physical type `T` /// # Example - /// ```rust,ignore + /// ```rust,ign /// let i32_arr: &[i32] = series.try_as_slice::()?; /// /// let f64_arr: &[f64] = series.try_as_slice::()?; diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 87557eea9e..0487b248c7 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -1478,6 +1478,7 @@ impl Expr { return_dtype, original_args, args: old_children, + use_process, }))) => { assert!( children.len() == old_children.len(), @@ -1490,6 +1491,7 @@ impl Expr { return_dtype: return_dtype.clone(), original_args: original_args.clone(), args: children, + use_process: *use_process, }))) } } diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index 187ab3103b..9d3102a39c 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize}; use super::FunctionExpr; #[cfg(feature = "python")] use crate::python::PyExpr; -use crate::{Expr, ExprRef, functions::scalar::ScalarFn}; +use crate::{Expr, ExprRef, functions::scalar::ScalarFn, python_udf::PyScalarFn}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum MaybeInitializedUDF { @@ -375,13 +375,13 @@ pub fn get_udf_properties(expr: &ExprRef) -> UDFProperties { concurrency: *concurrency, use_process: *use_process, }); - } else if let Expr::ScalarFn(ScalarFn::Python(py)) = e.as_ref() { + } else if let Expr::ScalarFn(ScalarFn::Python(PyScalarFn::RowWise(pyfn))) = e.as_ref() { udf_properties = Some(UDFProperties { - name: py.name().to_string(), + name: pyfn.function_name.to_string(), resource_request: None, batch_size: None, concurrency: None, - use_process: None, + use_process: pyfn.use_process, }); } Ok(TreeNodeRecursion::Continue) diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a4a563aaa2..4d5a5456be 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -238,6 +238,7 @@ pub fn row_wise_udf( name: &str, inner: PyObject, return_dtype: PyDataType, + use_process: Option, original_args: PyObject, expr_args: Vec, ) -> PyExpr { @@ -250,6 +251,7 @@ pub fn row_wise_udf( name, inner.into(), return_dtype.into(), + use_process, original_args.into(), args, ) diff --git a/src/daft-dsl/src/python_udf.rs b/src/daft-dsl/src/python_udf.rs index 5161c689ed..d6241109c4 100644 --- a/src/daft-dsl/src/python_udf.rs +++ b/src/daft-dsl/src/python_udf.rs @@ -58,6 +58,7 @@ pub fn row_wise_udf( name: &str, inner: RuntimePyObject, return_dtype: DataType, + use_process: Option, original_args: RuntimePyObject, args: Vec, ) -> Expr { @@ -67,6 +68,7 @@ pub fn row_wise_udf( return_dtype, original_args, args, + use_process, }))) } @@ -77,6 +79,7 @@ pub struct RowWisePyFn { pub return_dtype: DataType, pub original_args: RuntimePyObject, pub args: Vec, + pub use_process: Option, } impl Display for RowWisePyFn { @@ -101,6 +104,7 @@ impl RowWisePyFn { return_dtype: self.return_dtype.clone(), original_args: self.original_args.clone(), args: children, + use_process: self.use_process, } } diff --git a/src/daft-ir/src/proto/functions.rs b/src/daft-ir/src/proto/functions.rs index 7abd20a4af..3973a5c7f8 100644 --- a/src/daft-ir/src/proto/functions.rs +++ b/src/daft-ir/src/proto/functions.rs @@ -40,6 +40,7 @@ pub fn from_proto_function(message: proto::ScalarFn) -> ProtoResult { return_dtype: from_proto(row_wise_fn.return_dtype)?, original_args: from_proto(row_wise_fn.original_args)?, args: args.into_inner(), + use_process: row_wise_fn.use_process, }; ir::rex::from_py_rowwise_func(func) } @@ -100,6 +101,7 @@ pub fn scalar_fn_to_proto(sf: &ir::functions::scalar::ScalarFn) -> ProtoResult

{ let transforms = children .iter() @@ -508,6 +509,7 @@ fn replace_column_with_semantic_id( return_dtype: return_dtype.clone(), original_args: original_args.clone(), args: new_children, + use_process: *use_process, }), )))) } diff --git a/src/daft-logical-plan/src/partitioning.rs b/src/daft-logical-plan/src/partitioning.rs index 222d7ae6de..7e56d5ff7e 100644 --- a/src/daft-logical-plan/src/partitioning.rs +++ b/src/daft-logical-plan/src/partitioning.rs @@ -399,6 +399,7 @@ fn translate_clustering_spec_expr( return_dtype, original_args, args: children, + use_process, }))) => { let new_children = children .iter() @@ -411,6 +412,7 @@ fn translate_clustering_spec_expr( return_dtype: return_dtype.clone(), original_args: original_args.clone(), args: new_children, + use_process: *use_process, }), )))) } diff --git a/src/daft-proto/proto/v1/daft.proto b/src/daft-proto/proto/v1/daft.proto index 7a21621826..b5202e9ca0 100644 --- a/src/daft-proto/proto/v1/daft.proto +++ b/src/daft-proto/proto/v1/daft.proto @@ -365,6 +365,7 @@ message ScalarFn { DataType return_dtype = 2; PyObject inner = 3; PyObject original_args = 4; + optional bool use_process = 5; } } } diff --git a/src/daft-proto/src/generated/daft.v1.rs b/src/daft-proto/src/generated/daft.v1.rs index 9b66954c4c..a504c80007 100644 --- a/src/daft-proto/src/generated/daft.v1.rs +++ b/src/daft-proto/src/generated/daft.v1.rs @@ -533,6 +533,8 @@ pub mod scalar_fn { pub inner: ::core::option::Option, #[prost(message, optional, tag = "4")] pub original_args: ::core::option::Option, + #[prost(bool, optional, tag = "5")] + pub use_process: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum Variant { From f9fa4af02a708cb4d39b926590a55bf986bbb27f Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 13:27:20 -0500 Subject: [PATCH 12/15] whoops --- src/daft-core/src/series/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index 063bf25404..c23589a218 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -206,7 +206,7 @@ impl Series { /// Attempts to downcast the Series to a primitive slice /// This will return an error if the Series is not of the physical type `T` /// # Example - /// ```rust,ign + /// ```rust,ignore /// let i32_arr: &[i32] = series.try_as_slice::()?; /// /// let f64_arr: &[f64] = series.try_as_slice::()?; From 121aee36c2eb9bc6e3369662066e1989eeb9d8a7 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 14:38:53 -0500 Subject: [PATCH 13/15] fix bug in literal conversion --- src/daft-dsl/src/python_udf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-dsl/src/python_udf.rs b/src/daft-dsl/src/python_udf.rs index 5161c689ed..5f6b5b6bd2 100644 --- a/src/daft-dsl/src/python_udf.rs +++ b/src/daft-dsl/src/python_udf.rs @@ -217,7 +217,7 @@ impl RowWisePyFn { } let result = func.call1((inner_ref, args_ref, &py_args))?; - let lit = Literal::from_pyobj(&result, Some(&self.return_dtype))?; + let lit = Literal::from_pyobj(&result, None)?; py_args.clear(); DaftResult::Ok(lit) From 1f56c93ad880ae741268c257c6e887412cc6790e Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 16:15:15 -0500 Subject: [PATCH 14/15] revert series::from_literals changes --- src/daft-core/src/array/ops/cast.rs | 2 +- src/daft-core/src/series/from_lit.rs | 424 +++++---------------------- src/daft-core/src/series/mod.rs | 5 +- src/daft-dsl/src/python_udf.rs | 36 ++- src/daft-sql/src/planner.rs | 2 +- 5 files changed, 99 insertions(+), 370 deletions(-) diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index f72cb5f2a7..d257170195 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -608,7 +608,7 @@ impl PythonArray { .collect::>>() })?; - Ok(Series::from_literals(literals, None)? + Ok(Series::from_literals(literals)? .cast(dtype)? .rename(self.name())) } diff --git a/src/daft-core/src/series/from_lit.rs b/src/daft-core/src/series/from_lit.rs index 013b3a0d6b..39c8b91de4 100644 --- a/src/daft-core/src/series/from_lit.rs +++ b/src/daft-core/src/series/from_lit.rs @@ -1,10 +1,3 @@ -use std::sync::Arc; - -use arrow2::{ - array::{MutableArray, MutablePrimitiveArray}, - trusted_len::TrustedLen, - types::months_days_ns, -}; use common_error::{DaftError, DaftResult}; use common_image::CowImage; @@ -60,220 +53,78 @@ pub(crate) fn combine_lit_types(left: &DataType, right: &DataType) -> Option`. - /// The `on_error` indicates how to handle the result. + /// Converts a vec of literals into a Series. /// - /// Unlike `from_literals`, `from_literal_iter` does not do any dtype inference. - /// If the datatype of `Literal` != dtype, then it will either result in a null value, or an error depending on the `on_error` setting. - pub fn from_literals_iter> + TrustedLen>( - values: I, - dtype: DataType, - on_error: OnError, - ) -> DaftResult { - let field = Field::new("literal", dtype.clone()); + /// Literals must all be the same type or null, this function does not do any casting or coercion. + /// If that is desired, you should handle it for each literal before converting it into a series. + pub fn from_literals(values: Vec) -> DaftResult { + let dtype = values.iter().try_fold(DataType::Null, |acc, v| { + let dtype = v.get_type(); + combine_lit_types(&acc, &dtype).ok_or_else(|| { + DaftError::ValueError(format!( + "All literals must have the same data type or null. Found: {} vs {}", + acc, dtype + )) + }) + })?; - // DaftError isn't `Clone` so just using a string to store the error instead. - let mut errs: Vec = Vec::with_capacity(values.len()); + let field = Field::new("literal", dtype.clone()); - // TODO: ideally we'd want to fail fast if `OnError::Raise`, but you can't return from the outer function inside a closure, so we just push the error to the vector and continue - // Then we check if there are any errors at the end. - // Instead, we should just immediately return an error one is encountered. - // But this requires significant refactoring to remove the iter closures and use for loops for all variants. - // For now, I created the `unwrap_inner_fail_fast` that'll immediately return, but it can't be used in closures. macro_rules! unwrap_inner { ($expr:expr, $variant:ident) => { match $expr { - Ok(Literal::$variant(val)) => Some(val), - Ok(Literal::Null) => None, - Err(e) => match on_error { - OnError::Null => None, - OnError::Raise => { - errs.push(e.to_string()); - None - } - }, - Ok(other) => { - let ty = other.get_type(); - if &ty != &dtype { - match on_error { - OnError::Null => None, - OnError::Raise => { - errs.push(format!( - "All literals must have the same data type or null. Found: {ty} vs {dtype}" - )); - None - } - } - } else { - errs.push(format!( - "All literals must have the same data type or null. Found: {ty} vs {dtype}" - )); - None - - } - } + Literal::$variant(val) => Some(val), + Literal::Null => None, + _ => unreachable!("datatype is already checked"), } }; ($expr:expr, $literal:pat => $value:expr) => { match $expr { - Ok($literal) => Some($value), - Ok(Literal::Null) => None, - Err(e) => match on_error { - OnError::Null => None, - OnError::Raise => { - errs.push(e.to_string()); - None - } - }, - Ok(other) => { - let ty = other.get_type(); - if &ty != &dtype { - match on_error { - OnError::Null => None, - OnError::Raise => { - errs.push(format!( - "All literals must have the same data type or null. Found: {ty} vs {dtype}" - )); - None - } - } - } else { - errs.push(format!( - "All literals must have the same data type or null. Found: {ty} vs {dtype}" - )); - None - } - } - } - }; - } - - /// Same as `unwrap_inner` but immediately return out of `from_literals_iter` instead of collecting the errors. - /// This can't be used inside closures that don't return a Result - macro_rules! unwrap_inner_fail_fast { - ($expr:expr, $variant:ident) => { - match $expr { - Ok(Literal::$variant(val)) => Some(val), - Ok(Literal::Null) => None, - Err(e) => match on_error { - OnError::Null => None, - OnError::Raise => { - return Err(e) - } - }, - Ok(other) => { - let ty = other.get_type(); - if &ty != &dtype { - match on_error { - OnError::Null => None, - OnError::Raise => { - return Err(DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - ty, dtype - ))) - } - } - } else { - return Err(DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - ty, dtype - ))) - } - } - } - }; - ($expr:expr, $literal:pat => $value:expr) => { - match $expr { - Ok($literal) => Some($value), - Ok(Literal::Null) => None, - Err(e) => match on_error { - OnError::Null => None, - OnError::Raise => { - errs.push(e.to_string()); - None - } - }, - Ok(other) => { - let ty = other.get_type(); - if &ty != &dtype { - match on_error { - OnError::Null => None, - OnError::Raise => { - return Err(DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - ty, dtype - ))) - } - } - } else { - return Err(DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - ty, dtype - ))) - } - } + $literal => Some($value), + Literal::Null => None, + _ => unreachable!("datatype is already checked"), } }; } macro_rules! from_iter_with_str { ($arr_type:ty, $variant:ident) => {{ - <$arr_type>::from_iter("literal", values.map(|lit| unwrap_inner!(lit, $variant))) - .into_series() + <$arr_type>::from_iter( + "literal", + values.into_iter().map(|lit| unwrap_inner!(lit, $variant)), + ) + .into_series() }}; } - macro_rules! from_mutable_primitive_array { + macro_rules! from_iter_with_field { ($arr_type:ty, $variant:ident) => {{ - let mut arr = MutablePrimitiveArray::<$arr_type>::with_capacity(values.len()); - for value in values { - let value = unwrap_inner_fail_fast!(value, $variant); - arr.push(value); - } - - Series::from_arrow(Arc::new(field), arr.as_box())? + <$arr_type>::from_iter( + field, + values.into_iter().map(|lit| unwrap_inner!(lit, $variant)), + ) + .into_series() }}; } - let s = match dtype { + Ok(match dtype { DataType::Null => NullArray::full_null("literal", &dtype, values.len()).into_series(), DataType::Boolean => from_iter_with_str!(BooleanArray, Boolean), DataType::Utf8 => from_iter_with_str!(Utf8Array, Utf8), DataType::Binary => from_iter_with_str!(BinaryArray, Binary), - DataType::Int8 => from_mutable_primitive_array!(i8, Int8), - DataType::UInt8 => from_mutable_primitive_array!(u8, UInt8), - DataType::Int16 => from_mutable_primitive_array!(i16, Int16), - DataType::UInt16 => from_mutable_primitive_array!(u16, UInt16), - DataType::Int32 => from_mutable_primitive_array!(i32, Int32), - DataType::UInt32 => from_mutable_primitive_array!(u32, UInt32), - DataType::Int64 => from_mutable_primitive_array!(i64, Int64), - DataType::UInt64 => from_mutable_primitive_array!(u64, UInt64), - DataType::Float32 => from_mutable_primitive_array!(f32, Float32), - DataType::Float64 => from_mutable_primitive_array!(f64, Float64), - DataType::Interval => { - let mut arr = MutablePrimitiveArray::::with_capacity(values.len()); - for value in values { - let value = - unwrap_inner!(value, Literal::Interval(d) => months_days_ns::from(d)); - arr.push(value); - } - - Self::from_arrow( - Arc::new(Field::new("literal", DataType::Interval)), - arr.as_box(), - )? - } - + DataType::Int8 => from_iter_with_field!(Int8Array, Int8), + DataType::UInt8 => from_iter_with_field!(UInt8Array, UInt8), + DataType::Int16 => from_iter_with_field!(Int16Array, Int16), + DataType::UInt16 => from_iter_with_field!(UInt16Array, UInt16), + DataType::Int32 => from_iter_with_field!(Int32Array, Int32), + DataType::UInt32 => from_iter_with_field!(UInt32Array, UInt32), + DataType::Int64 => from_iter_with_field!(Int64Array, Int64), + DataType::UInt64 => from_iter_with_field!(UInt64Array, UInt64), + DataType::Interval => from_iter_with_str!(IntervalArray, Interval), + DataType::Float32 => from_iter_with_field!(Float32Array, Float32), + DataType::Float64 => from_iter_with_field!(Float64Array, Float64), DataType::Decimal128 { .. } => Decimal128Array::from_iter( field, values @@ -282,28 +133,15 @@ impl Series { ) .into_series(), DataType::Timestamp(_, _) => { - let mut arr = MutablePrimitiveArray::::with_capacity(values.len()); - for value in values { - let value = unwrap_inner_fail_fast!(value, Literal::Timestamp(ts, ..) => ts); - arr.push(value); - } - let physical = Int64Array::from_arrow( - Arc::new(Field::new("literal", DataType::Int64)), - arr.as_box(), - )?; + let data = values + .into_iter() + .map(|lit| unwrap_inner!(lit, Literal::Timestamp(ts, ..) => ts)); + let physical = Int64Array::from_iter(Field::new("literal", DataType::Int64), data); TimestampArray::new(field, physical).into_series() } DataType::Date => { - let mut arr = MutablePrimitiveArray::::with_capacity(values.len()); - for value in values { - let value = unwrap_inner_fail_fast!(value, Date); - arr.push(value); - } - let physical = Int32Array::from_arrow( - Arc::new(Field::new("literal", DataType::Int32)), - arr.as_box(), - )?; - + let data = values.into_iter().map(|lit| unwrap_inner!(lit, Date)); + let physical = Int32Array::from_iter(Field::new("literal", DataType::Int32), data); DateArray::new(field, physical).into_series() } DataType::Time(_) => { @@ -322,19 +160,18 @@ impl Series { DurationArray::new(field, physical).into_series() } - DataType::List(ref child_dtype) => { + DataType::List(child_dtype) => { let data = values + .iter() .map(|v| { - unwrap_inner_fail_fast!(v, List) - .map(|s| s.cast(child_dtype)) + unwrap_inner!(v, List) + .map(|s| s.cast(&child_dtype)) .transpose() }) .collect::>>()?; ListArray::try_from(("literal", data.as_slice()))?.into_series() } - DataType::Struct(ref fields) => { - let values = values.collect::>(); - + DataType::Struct(fields) => { let children = fields .iter() .enumerate() @@ -347,15 +184,12 @@ impl Series { }) .collect::>(); - Ok(Self::from_literals(child_values, Some(f.dtype.clone()))? - .rename(&f.name)) + Ok(Self::from_literals(child_values)?.rename(&f.name)) }) .collect::>()?; let validity = arrow2::bitmap::Bitmap::from_trusted_len_iter( - values - .iter() - .map(|v| v.as_ref().is_ok_and(|v| v != &Literal::Null)), + values.iter().map(|v| *v != Literal::Null), ); StructArray::new(field, children, Some(validity)).into_series() @@ -371,7 +205,6 @@ impl Series { } #[cfg(feature = "python")] DataType::File => { - let values = values.collect::>(); use std::sync::Arc; use common_file::FileReference; @@ -439,9 +272,8 @@ impl Series { DataType::File => unreachable!("File type is only supported with the python feature"), DataType::Tensor(_) => { - let values = values.collect::>(); - - let (data, shapes) = values.iter() + let (data, shapes) = values + .iter() .map(|v| { unwrap_inner!(v, Literal::Tensor { data, shape } => (data, shape.as_slice())).unzip() }) @@ -457,8 +289,6 @@ impl Series { TensorArray::new(field, physical).into_series() } DataType::SparseTensor(..) => { - let values = values.collect::>(); - let (values, indices, shapes) = values .iter() .map(|v| { @@ -484,20 +314,18 @@ impl Series { SparseTensorArray::new(field, physical).into_series() } - DataType::Embedding(ref inner_dtype, ref size) => { - let (validity, data): (Vec<_>, Vec<_>) = values + DataType::Embedding(inner_dtype, size) => { + let validity = arrow2::bitmap::Bitmap::from_trusted_len_iter( + values.iter().map(|v| *v != Literal::Null), + ); + + let data = values .into_iter() .map(|v| { - ( - v.as_ref().is_ok_and(|v| v != &Literal::Null), - unwrap_inner!(v, Embedding) - .unwrap_or_else(|| Self::full_null("literal", inner_dtype, *size)), - ) + unwrap_inner!(v, Embedding) + .unwrap_or_else(|| Self::full_null("literal", &inner_dtype, size)) }) - .unzip(); - - let validity = arrow2::bitmap::Bitmap::from(validity); - + .collect::>(); let flat_child = Self::concat(&data.iter().collect::>())?; let physical = @@ -506,8 +334,8 @@ impl Series { EmbeddingArray::new(field, physical).into_series() } DataType::Map { - key: ref key_dtype, - value: ref value_dtype, + key: key_dtype, + value: value_dtype, } => { let data = values .into_iter() @@ -520,8 +348,8 @@ impl Series { .to_exploded_field() .expect("Expected physical type of map to be list"), vec![ - k.cast(key_dtype)?.rename("key"), - v.cast(value_dtype)?.rename("value"), + k.cast(&key_dtype)?.rename("key"), + v.cast(&value_dtype)?.rename("value"), ], None, ) @@ -537,7 +365,8 @@ impl Series { } DataType::Image(image_mode) => { let data = values - .map(|v| unwrap_inner!(v, Image).map(|img| CowImage::from(img.0))) + .iter() + .map(|v| unwrap_inner!(v, Image).map(|img| CowImage::from(&img.0))) .collect::>(); image_array_from_img_buffers("literal", &data, image_mode)?.into_series() @@ -550,59 +379,25 @@ impl Series { | DataType::FixedShapeTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::Unknown => unreachable!("Literal should never have data type: {dtype}"), - }; - if errs.is_empty() { - Ok(s) - } else { - Err(DaftError::ComputeError(format!( - "Errors occurred while creating series: {:?}", - errs - ))) - } - } - - /// Converts a vec of literals into a Series. - /// - /// Literals must all be the same type or null, this function does not do any casting or coercion. - /// If that is desired, you should handle it for each literal before converting it into a series. - pub fn from_literals(values: Vec, dtype: Option) -> DaftResult { - let dtype = match dtype { - Some(dtype) => dtype, - None => values.iter().try_fold(DataType::Null, |acc, v| { - let dtype = v.get_type(); - combine_lit_types(&acc, &dtype).ok_or_else(|| { - DaftError::ValueError(format!( - "All literals must have the same data type or null. Found: {} vs {}", - acc, dtype - )) - }) - })?, - }; - Self::from_literals_iter( - values.into_iter().map(DaftResult::Ok), - dtype, - OnError::Raise, - ) + }) } } impl From for Series { fn from(value: Literal) -> Self { - let dtype = Some(value.get_type()); - Self::from_literals(vec![value], dtype) + Self::from_literals(vec![value]) .expect("Series::try_from should not fail on single literal value") } } #[cfg(test)] mod test { - use common_error::DaftError; use common_image::Image; use image::{GrayImage, RgbaImage}; use indexmap::indexmap; use rstest::rstest; - use crate::{datatypes::IntervalValue, prelude::*, series, series::from_lit::OnError}; + use crate::{datatypes::IntervalValue, prelude::*, series}; #[rstest] #[case::null(vec![Literal::Null, Literal::Null])] @@ -717,7 +512,7 @@ mod test { ])] fn test_literal_series_roundtrip_basics(#[case] literals: Vec) { let expected = [vec![Literal::Null], literals, vec![Literal::Null]].concat(); - let series = Series::from_literals(expected.clone(), None).unwrap(); + let series = Series::from_literals(expected.clone()).unwrap(); let actual = series.to_literals().collect::>(); assert_eq!(expected, actual) @@ -726,70 +521,7 @@ mod test { #[test] fn test_literals_to_series_mismatched() { let values = vec![Literal::UInt64(1), Literal::Utf8("test".to_string())]; - let actual = Series::from_literals(values, None); + let actual = Series::from_literals(values); assert!(actual.is_err()); } - - #[test] - fn test_literals_to_series_mismatched_and_dtype() { - let values = vec![Literal::UInt64(1), Literal::Utf8("test".to_string())]; - let actual = Series::from_literals(values, Some(DataType::Utf8)); - assert!(actual.is_err()); - } - - #[test] - fn test_literals_to_series_ignore_errors() { - let values = vec![ - Ok(Literal::Boolean(true)), - Err(DaftError::ValueError("null".to_string())), - ] - .into_iter(); - - let dtype = DataType::Boolean; - let res = Series::from_literals_iter(values, dtype, OnError::Null).unwrap(); - let value = res.get_lit(1); - assert_eq!(value, Literal::Null); - } - - #[test] - fn test_literals_to_series_mismatch_ignore() { - let values = vec![ - Ok(Literal::Boolean(true)), - Ok(Literal::Int8(1)), - Err(DaftError::ValueError("null".to_string())), - ] - .into_iter(); - - let dtype = DataType::Boolean; - let res = Series::from_literals_iter(values, dtype, OnError::Null) - .expect("Failed to create series"); - assert_eq!(res.get_lit(0), Literal::Boolean(true)); - assert_eq!(res.get_lit(1), Literal::Null); - assert_eq!(res.get_lit(2), Literal::Null); - } - - #[test] - fn test_literals_to_series_mismatch_raise() { - let values = vec![ - Ok(Literal::Boolean(true)), - Ok(Literal::Int8(1)), - Err(DaftError::ValueError("null".to_string())), - ] - .into_iter(); - - let dtype = DataType::Boolean; - let res = Series::from_literals_iter(values, dtype, OnError::Raise); - assert!(res.is_err()) - } - - #[test] - fn test_literals_to_series_mismatched_and_dtype_ignore() { - let values = vec![Literal::UInt64(1), Literal::Utf8("test".to_string())]; - let actual = - Series::from_literals_iter(values.into_iter().map(Ok), DataType::Utf8, OnError::Null); - assert!(actual.is_ok()); - let actual = actual.unwrap(); - assert_eq!(actual.get_lit(0), Literal::Null); - assert_eq!(actual.get_lit(1), Literal::Utf8("test".to_string())); - } } diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index c65b5c7c7c..c38ccadfb6 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -124,7 +124,6 @@ impl Series { /// /// This function will check the provided [`Field`] (and all its associated potentially nested fields/dtypes) against /// the provided [`arrow2::array::Array`] for compatibility, and returns an error if they do not match. - #[inline] pub fn from_arrow( field: FieldRef, arrow_arr: Box, @@ -207,7 +206,7 @@ impl Series { /// Attempts to downcast the Series to a primitive slice /// This will return an error if the Series is not of the physical type `T` /// # Example - /// ```rust,ignore + /// ```rust,no_run /// let i32_arr: &[i32] = series.try_as_slice::()?; /// /// let f64_arr: &[f64] = series.try_as_slice::()?; @@ -263,7 +262,7 @@ macro_rules! series { // put into a vec first for compile-time type consistency checking let elements = vec![$($element),+]; let elements_lit = elements.into_iter().map(Literal::from).collect::>(); - Series::from_literals(elements_lit, None).unwrap() + Series::from_literals(elements_lit).unwrap() } }; } diff --git a/src/daft-dsl/src/python_udf.rs b/src/daft-dsl/src/python_udf.rs index 5f6b5b6bd2..b19687a347 100644 --- a/src/daft-dsl/src/python_udf.rs +++ b/src/daft-dsl/src/python_udf.rs @@ -191,6 +191,7 @@ impl RowWisePyFn { args: &[Series], num_rows: usize, ) -> DaftResult<(Series, std::time::Duration)> { + use daft_core::python::PySeries; use pyo3::prelude::*; let inner_ref = self.inner.as_ref(); @@ -198,33 +199,30 @@ impl RowWisePyFn { let name = args[0].name(); let start_time = std::time::Instant::now(); Python::with_gil(|py| { - use daft_core::series::from_lit::OnError; - let gil_contention_time = start_time.elapsed(); let func = py .import(pyo3::intern!(py, "daft.udf.row_wise"))? .getattr(pyo3::intern!(py, "__call_func"))?; let mut py_args = Vec::with_capacity(args.len()); - // pre-allocating py_args vector so we're not creating a new vector for each iteration - let outputs = (0..num_rows).map(|i| { - for s in args { - let idx = if s.len() == 1 { 0 } else { i }; - let lit = s.get_lit(idx); - let pyarg = lit.into_pyobject(py)?; - py_args.push(pyarg); - } - - let result = func.call1((inner_ref, args_ref, &py_args))?; - let lit = Literal::from_pyobj(&result, None)?; - - py_args.clear(); - DaftResult::Ok(lit) - }); + let outputs = (0..num_rows) + .map(|i| { + for s in args { + let idx = if s.len() == 1 { 0 } else { i }; + let lit = s.get_lit(idx); + let pyarg = lit.into_pyobject(py)?; + py_args.push(pyarg); + } + + let result = func.call1((inner_ref, args_ref, &py_args))?; + py_args.clear(); + DaftResult::Ok(result) + }) + .collect::>>()?; + Ok(( - Series::from_literals_iter(outputs, self.return_dtype.clone(), OnError::Raise)? - .rename(name), + PySeries::from_pylist_impl(name, outputs, self.return_dtype.clone())?.series, gil_contention_time, )) }) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 265d1390a6..658c68209b 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1512,7 +1512,7 @@ impl SQLPlanner<'_> { }) .collect::>>()?; - let s = Series::from_literals(values, None)?; + let s = Series::from_literals(values)?; let s = FixedSizeListArray::new( Field::new("tuple", s.data_type().clone()) .to_fixed_size_list_field(exprs.len())?, From 97515ffa4f2c634305a6817b53c2e1ed5ff5da1a Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 2 Oct 2025 16:16:02 -0500 Subject: [PATCH 15/15] oops --- src/daft-core/src/python/series.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index e67f081c7d..8318e44e5a 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -114,7 +114,7 @@ impl PySeries { (literals_with_supertype, supertype) }; - let mut series = Series::from_literals(literals, None)?.cast(&dtype)?; + let mut series = Series::from_literals(literals)?.cast(&dtype)?; if let Some(name) = name { series = series.rename(name); }