From 1176bdc0e459a5c63ab1fba2b2502d13c59b5406 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Mon, 15 Sep 2025 13:56:01 -0400 Subject: [PATCH 01/13] Draft pr for instructions --- src/function/list/CMakeLists.txt | 1 + src/function/list/list_cosine_similarity.cpp | 65 +++++++++++++++++++ .../function/list/vector_list_functions.h | 6 ++ 3 files changed, 72 insertions(+) create mode 100644 src/function/list/list_cosine_similarity.cpp diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index d19cc2606f2..b92681952dd 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -7,6 +7,7 @@ add_library(kuzu_list_function list_append_function.cpp list_concat_function.cpp list_contains_function.cpp + list_cosine_similarity.cpp list_creation.cpp list_distinct_function.cpp list_extract_function.cpp diff --git a/src/function/list/list_cosine_similarity.cpp b/src/function/list/list_cosine_similarity.cpp new file mode 100644 index 00000000000..c03d53b5136 --- /dev/null +++ b/src/function/list/list_cosine_similarity.cpp @@ -0,0 +1,65 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListCosineSimilarity { + template + static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { + auto leftDataVector = common::ListVector::getDataVector(&leftVector); + auto rightDataVector = common::ListVector::getDataVector(&rightVector); + result = 0; + // for test, returning the sum of elements in 2 lists + for (auto i=0u; i < left.size(); i++) { + if (leftDataVector->isNull(left.offset + i)) { + continue; + } + result += leftDataVector->getValue(left.offset + i); + } + for (auto i=0u; i < right.size(); i++) { + if (rightDataVector->isNull(right.offset + i)) { + continue; + } + result +=rightataVector->getValue(right.offset + i); + } + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + auto scalarFunction = input.definition->ptrCast(); + types.push_back(input.arguments[0]->getDataType().copy()); + types.push_back(input.arguments[1]->getDataType().copy()); + const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); + // justify datatypes + if ((types[0] != types[1])|| + (types[0].getLogicalTypeID() == LogicalType::INT64().getLogicalTypeID())) { + throw BinderException(stringFormat("Unsupported inner data type for {}: {}", + input.definition->name, + types[0].getLogicalTypeID() == LogicalType::INT64().getLogicalTypeID()? + LogicalTypeUtils::toString(types[1].getLogicalTypeID()): LogicalTypeUtils::toString(types[0].getLogicalTypeID()) + )); + } + return std::make_unique(std::move(types), resultType); +} + +function_set ListConsineSimilarityFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::INT64, execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index 52c08f463b7..e61f964696d 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -212,5 +212,11 @@ struct ListHasAllFunction { static function_set getFunctionSet(); }; +struct ListConsineSimilarityFunction { + static constexpr const char* name = "LIST_COSINE_SIMILARITY"; + + static function_set getFunctionSet(); +}; + } // namespace function } // namespace kuzu From 795ef5151bfbd9a6fc9dabc11f2a96b2315e8f7e Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:56:45 -0400 Subject: [PATCH 02/13] Draft implement of list cosine similarity --- src/function/function_collection.cpp | 1 + src/function/list/list_cosine_similarity.cpp | 105 ++++++++++++++----- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index f54cd664496..111c1fbbe16 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -128,6 +128,7 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction), SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction), SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), + SCALAR_FUNCTION(ListConsineSimilarityFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/list_cosine_similarity.cpp b/src/function/list/list_cosine_similarity.cpp index c03d53b5136..7a75aa4ae2e 100644 --- a/src/function/list/list_cosine_similarity.cpp +++ b/src/function/list/list_cosine_similarity.cpp @@ -4,6 +4,9 @@ #include "function/list/functions/list_function_utils.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" +#include "math.h" +#include "common/vector/value_vector.h" +#include using namespace kuzu::common; @@ -11,51 +14,95 @@ namespace kuzu { namespace function { struct ListCosineSimilarity { - template + template static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { - auto leftDataVector = common::ListVector::getDataVector(&leftVector); - auto rightDataVector = common::ListVector::getDataVector(&rightVector); - result = 0; - // for test, returning the sum of elements in 2 lists - for (auto i=0u; i < left.size(); i++) { - if (leftDataVector->isNull(left.offset + i)) { - continue; - } - result += leftDataVector->getValue(left.offset + i); - } - for (auto i=0u; i < right.size(); i++) { - if (rightDataVector->isNull(right.offset + i)) { - continue; - } - result +=rightataVector->getValue(right.offset + i); + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + KU_ASSERT(left.size == right.size); + simsimd_distance_t tmpResult = 0.0; + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) { + simsimd_cos_f32(leftElements, rightElements, left.size, &tmpResult); + } else { + simsimd_cos_f64(leftElements, rightElements, left.size, &tmpResult); } + result = 1.0 - tmpResult; } }; +static void validateChildType(const LogicalType& type, const std::string& functionName) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + return; + default: + throw BinderException( + stringFormat("{} requires argument type to be FLOAT[] or DOUBLE[].", functionName)); + } +} + +static LogicalType validateListFunctionParameters(const LogicalType& leftType, + const LogicalType& rightType, const std::string& functionName) { + const auto& leftChildType = ListType::getChildType(leftType); + const auto& rightChildType = ListType::getChildType(rightType); + validateChildType(leftChildType, functionName); + validateChildType(rightChildType, functionName); + if (leftType.getLogicalTypeID() == common::LogicalTypeID::LIST) { + return leftType.copy(); + } else if (rightType.getLogicalTypeID() == common::LogicalTypeID::LIST) { + return rightType.copy(); + } + throw BinderException( + stringFormat("{} requires at least one argument to be LIST.", + functionName)); +} + +template +static scalar_func_exec_t getBinaryListExecFuncSwitchResultType() { + auto execFunc = + ScalarFunction::BinaryExecListStructFunction; + return execFunc; +} + +template +scalar_func_exec_t getScalarExecFunc(LogicalType type) { + scalar_func_exec_t execFunc; + switch (ListType::getChildType(type).getLogicalTypeID()) { + case LogicalTypeID::FLOAT: + execFunc = getBinaryListExecFuncSwitchResultType(); + break; + case LogicalTypeID::DOUBLE: + execFunc = getBinaryListExecFuncSwitchResultType(); + break; + default: + KU_UNREACHABLE; + } + return execFunc; +} + static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { std::vector types; - auto scalarFunction = input.definition->ptrCast(); + //auto scalarFunction = input.definition->ptrCast(); types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); - const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); - // justify datatypes - if ((types[0] != types[1])|| - (types[0].getLogicalTypeID() == LogicalType::INT64().getLogicalTypeID())) { - throw BinderException(stringFormat("Unsupported inner data type for {}: {}", - input.definition->name, - types[0].getLogicalTypeID() == LogicalType::INT64().getLogicalTypeID()? - LogicalTypeUtils::toString(types[1].getLogicalTypeID()): LogicalTypeUtils::toString(types[0].getLogicalTypeID()) - )); + auto paramType = validateListFunctionParameters(types[0], types[1], input.definition->name); + //const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); + input.definition->ptrCast()->execFunc = std::move(getScalarExecFunc(paramType.copy())); + auto bindData = std::make_unique(ListType::getChildType(paramType).copy()); + std::vector paramTypes; + for (auto& _ : input.arguments) { + (void)_; + bindData->paramTypes.push_back(paramType.copy()); } - return std::make_unique(std::move(types), resultType); + return bindData; } function_set ListConsineSimilarityFunction::getFunctionSet() { function_set result; - auto execFunc = ScalarFunction::BinaryExecFunction; + //auto execFunc = ScalarFunction::BinaryExecListStructFunction; auto function = std::make_unique(name, - std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::INT64, execFunc); + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::ANY); function->bindFunc = bindFunc; result.push_back(std::move(function)); return result; From 3853f9d4ed62f58f6bce27ba2835ef9de6ce3f97 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Tue, 16 Sep 2025 11:03:59 -0400 Subject: [PATCH 03/13] Add list cosine distance and list distance, modifed list_cosine_similarity.cpp into list_binary float function --- src/function/function_collection.cpp | 3 +- src/function/list/CMakeLists.txt | 2 +- ...ity.cpp => list_binary_float_function.cpp} | 72 +++++++++++++++++-- .../function/list/vector_list_functions.h | 14 +++- 4 files changed, 81 insertions(+), 10 deletions(-) rename src/function/list/{list_cosine_similarity.cpp => list_binary_float_function.cpp} (57%) diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index 111c1fbbe16..ecc50b17cb0 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -128,7 +128,8 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction), SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction), SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), - SCALAR_FUNCTION(ListConsineSimilarityFunction), + SCALAR_FUNCTION(ListCosineSimilarityFunction),SCALAR_FUNCTION(ListCosineDistanceFunction), + SCALAR_FUNCTION(ListDistanceFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index b92681952dd..691c1e585e7 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -7,7 +7,7 @@ add_library(kuzu_list_function list_append_function.cpp list_concat_function.cpp list_contains_function.cpp - list_cosine_similarity.cpp + list_binary_float_function.cpp list_creation.cpp list_distinct_function.cpp list_extract_function.cpp diff --git a/src/function/list/list_cosine_similarity.cpp b/src/function/list/list_binary_float_function.cpp similarity index 57% rename from src/function/list/list_cosine_similarity.cpp rename to src/function/list/list_binary_float_function.cpp index 7a75aa4ae2e..85f6b888874 100644 --- a/src/function/list/list_cosine_similarity.cpp +++ b/src/function/list/list_binary_float_function.cpp @@ -19,6 +19,10 @@ struct ListCosineSimilarity { common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + if (left.size!=right.size) { + throw BinderException(stringFormat("LIST_COSINE_SIMILARITY requires both arguments to be in same size: left : {} ; right : {}" + ,left.size, right.size)); + } KU_ASSERT(left.size == right.size); simsimd_distance_t tmpResult = 0.0; static_assert(std::is_same_v || std::is_same_v); @@ -31,6 +35,50 @@ struct ListCosineSimilarity { } }; +struct ListCosineDistance { + template + static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + if (left.size!=right.size) { + throw BinderException(stringFormat("LIST_COSINE_DISTANCE requires both arguments to be in same size: left : {} ; right : {}" + ,left.size, right.size)); + } + KU_ASSERT(left.size == right.size); + simsimd_distance_t tmpResult = 0.0; + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) { + simsimd_cos_f32(leftElements, rightElements, left.size, &tmpResult); + } else { + simsimd_cos_f64(leftElements, rightElements, left.size, &tmpResult); + } + result = tmpResult; + } +}; + +struct ListDistance { + template + static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + if (left.size!=right.size) { + throw BinderException(stringFormat("LIST_DISTANCE requires both arguments to be in same size: left : {} ; right : {}" + ,left.size, right.size)); + } + KU_ASSERT(left.size == right.size); + simsimd_distance_t tmpResult = 0.0; + static_assert(std::is_same_v || std::is_same_v); + if constexpr (std::is_same_v) { + simsimd_l2sq_f32(leftElements, rightElements, left.size, &tmpResult); + } else { + simsimd_l2sq_f64(leftElements, rightElements, left.size, &tmpResult); + } + result = std::sqrt(tmpResult); + } +}; + static void validateChildType(const LogicalType& type, const std::string& functionName) { switch (type.getLogicalTypeID()) { case LogicalTypeID::DOUBLE: @@ -81,14 +129,13 @@ scalar_func_exec_t getScalarExecFunc(LogicalType type) { return execFunc; } +template static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { std::vector types; - //auto scalarFunction = input.definition->ptrCast(); types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); auto paramType = validateListFunctionParameters(types[0], types[1], input.definition->name); - //const auto& resultType = ListType::getChildType(input.arguments[0]->dataType); - input.definition->ptrCast()->execFunc = std::move(getScalarExecFunc(paramType.copy())); + input.definition->ptrCast()->execFunc = std::move(getScalarExecFunc(paramType.copy())); auto bindData = std::make_unique(ListType::getChildType(paramType).copy()); std::vector paramTypes; for (auto& _ : input.arguments) { @@ -97,16 +144,27 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp } return bindData; } - -function_set ListConsineSimilarityFunction::getFunctionSet() { +template +function_set templateGetFunctionSet(const std::string& name) { function_set result; - //auto execFunc = ScalarFunction::BinaryExecListStructFunction; auto function = std::make_unique(name, std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::ANY); - function->bindFunc = bindFunc; + function->bindFunc = bindFunc; result.push_back(std::move(function)); return result; } +function_set ListCosineSimilarityFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ListCosineDistanceFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ListDistanceFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + } // namespace function } // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index e61f964696d..36276d84065 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -212,11 +212,23 @@ struct ListHasAllFunction { static function_set getFunctionSet(); }; -struct ListConsineSimilarityFunction { +struct ListCosineSimilarityFunction { static constexpr const char* name = "LIST_COSINE_SIMILARITY"; static function_set getFunctionSet(); }; +struct ListCosineDistanceFunction { + static constexpr const char* name = "LIST_COSINE_DISTANCE"; + + static function_set getFunctionSet(); +}; + +struct ListDistanceFunction { + static constexpr const char* name = "LIST_DISTANCE"; + + static function_set getFunctionSet(); +}; + } // namespace function } // namespace kuzu From c67dfeb161904daff89c5986428ac2672b536e33 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:57:05 -0400 Subject: [PATCH 04/13] Add list has any function --- src/function/function_collection.cpp | 2 +- src/function/list/CMakeLists.txt | 3 +- src/function/list/list_has_all.cpp | 2 +- src/function/list/list_has_any.cpp | 68 +++++++++++++++++++ .../function/list/vector_list_functions.h | 6 ++ 5 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 src/function/list/list_has_any.cpp diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index ecc50b17cb0..0f9aca3edfc 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -129,7 +129,7 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction), SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), SCALAR_FUNCTION(ListCosineSimilarityFunction),SCALAR_FUNCTION(ListCosineDistanceFunction), - SCALAR_FUNCTION(ListDistanceFunction), + SCALAR_FUNCTION(ListDistanceFunction),SCALAR_FUNCTION(ListHasAnyFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index 691c1e585e7..79f4cd3ff94 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -27,7 +27,8 @@ add_library(kuzu_list_function list_single.cpp size_function.cpp quantifier_functions.cpp - list_has_all.cpp) + list_has_all.cpp + list_has_any.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/list/list_has_all.cpp b/src/function/list/list_has_all.cpp index a62c0f06222..f235664b1f7 100644 --- a/src/function/list/list_has_all.cpp +++ b/src/function/list/list_has_all.cpp @@ -35,7 +35,7 @@ struct ListHasAll { } }; -std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { std::vector types; for (auto& arg : input.arguments) { if (arg->dataType == LogicalType::ANY()) { diff --git a/src/function/list/list_has_any.cpp b/src/function/list/list_has_any.cpp new file mode 100644 index 00000000000..187d08f946f --- /dev/null +++ b/src/function/list/list_has_any.cpp @@ -0,0 +1,68 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListHasAny { + static void operation(common::list_entry_t& left, common::list_entry_t& right, uint8_t& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + int64_t pos = 0; + auto rightDataVector = ListVector::getDataVector(&rightVector); + result = false; + for (auto i = 0u; i < right.size; i++) { + common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(), + [&](T) { + if (rightDataVector->isNull(right.offset + i)) { + return; + } + ListPosition::operation(left, + *(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos, + leftVector, *ListVector::getDataVector(&rightVector), resultVector); + result = (pos != 0); + }); + if (result) { + return; + } + } + } +}; + +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + for (auto& arg : input.arguments) { + if (arg->dataType == LogicalType::ANY()) { + types.push_back(LogicalType::LIST(LogicalType::INT64())); + } else { + types.push_back(arg->dataType.copy()); + } + } + if (types[0] != types[1]) { + throw common::BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListHasAnyFunction::name, input.arguments[0]->getDataType().toString(), + input.arguments[1]->getDataType().toString())); + } + return std::make_unique(std::move(types), LogicalType::BOOL()); +} + +function_set ListHasAnyFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::BOOL, + execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index 36276d84065..cf21a89de14 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -212,6 +212,12 @@ struct ListHasAllFunction { static function_set getFunctionSet(); }; +struct ListHasAnyFunction { + static constexpr const char* name = "LIST_HAS_ANY"; + + static function_set getFunctionSet(); +}; + struct ListCosineSimilarityFunction { static constexpr const char* name = "LIST_COSINE_SIMILARITY"; From 661da210b77d61aa9be60b6327c1a7b6074f84c0 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:14:34 -0400 Subject: [PATCH 05/13] Add support of list funtion List Intersect --- src/function/function_collection.cpp | 1 + src/function/list/CMakeLists.txt | 3 +- src/function/list/list_intersect.cpp | 87 +++++++++++++++++++ .../function/list/vector_list_functions.h | 6 ++ 4 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 src/function/list/list_intersect.cpp diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index 0f9aca3edfc..b482629203b 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -130,6 +130,7 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), SCALAR_FUNCTION(ListCosineSimilarityFunction),SCALAR_FUNCTION(ListCosineDistanceFunction), SCALAR_FUNCTION(ListDistanceFunction),SCALAR_FUNCTION(ListHasAnyFunction), + SCALAR_FUNCTION(ListIntersectFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index 79f4cd3ff94..6366fe07321 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -28,7 +28,8 @@ add_library(kuzu_list_function size_function.cpp quantifier_functions.cpp list_has_all.cpp - list_has_any.cpp) + list_has_any.cpp + list_intersect.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/list/list_intersect.cpp b/src/function/list/list_intersect.cpp new file mode 100644 index 00000000000..7778171f70a --- /dev/null +++ b/src/function/list/list_intersect.cpp @@ -0,0 +1,87 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/functions/list_unique_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListIntersect { + static void operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + int64_t pos = 0; + result = common::ListVector::addList(&resultVector, left.size>right.size? left.size:right.size); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + auto rightDataVector = common::ListVector::getDataVector(&rightVector); + auto rightPos = right.offset; + std::vector rightOffsets; + for (auto i=0u; i < right.size; i++) { + common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(), + [&](T) { + if (rightDataVector->isNull(right.offset + i)) { + return; + } + ListPosition::operation(left, + *(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos, + leftVector, *ListVector::getDataVector(&rightVector), resultVector); + }); + if (pos !=0) { + //resultDataVector->copyFromVectorData(resultPos++, rightDataVector, rightPos+i); + rightOffsets.push_back(rightPos+i); + } + } + common::ValueVector tempVec(kuzu::common::LogicalType::LIST(rightDataVector->dataType.getLogicalTypeID()), nullptr); + auto tempDataVec=common::ListVector::getDataVector(&tempVec); + auto temp = common::ListVector::addList(&tempVec, rightOffsets.size()); + auto tempPos = temp.offset; + for (auto i=0u; icopyFromVectorData(tempPos++, rightDataVector, rightPos+rightOffsets.at(i)); + } + auto numUniqueValues = ListUnique::appendListElementsToValueSet(temp, tempVec); + result = common::ListVector::addList(&resultVector, numUniqueValues); + resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultDataVectorBuffer = + common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */); + ListUnique::appendListElementsToValueSet(temp, tempVec, nullptr, + [&resultDataVector, &resultDataVectorBuffer](common::ValueVector& dataVector, + uint64_t pos) -> void { + resultDataVector->copyFromVectorData(resultDataVectorBuffer, &dataVector, + dataVector.getData() + pos * dataVector.getNumBytesPerValue()); + resultDataVectorBuffer += dataVector.getNumBytesPerValue(); + }); + } + +}; +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + types.push_back(input.arguments[0]->getDataType().copy()); + types.push_back(input.arguments[1]->getDataType().copy()); + if (types[0] != types[1]) { + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(ListIntersectFunction::name, + types[0].toString(), types[1].toString())); + } + return std::make_unique(std::move(types), types[0].copy()); +} + +function_set ListIntersectFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index cf21a89de14..8d896b77c56 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -236,5 +236,11 @@ struct ListDistanceFunction { static function_set getFunctionSet(); }; +struct ListIntersectFunction { + static constexpr const char* name = "LIST_INTERSECT"; + + static function_set getFunctionSet(); +}; + } // namespace function } // namespace kuzu From 3955f444cba64b8553c6c38ebd9ae6fded64bd20 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:17:50 -0400 Subject: [PATCH 06/13] Removes unnecessary lines in list_intersect --- src/function/list/list_intersect.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/function/list/list_intersect.cpp b/src/function/list/list_intersect.cpp index 7778171f70a..1bfac31f53e 100644 --- a/src/function/list/list_intersect.cpp +++ b/src/function/list/list_intersect.cpp @@ -17,9 +17,6 @@ struct ListIntersect { common::list_entry_t& result, common::ValueVector& leftVector, common::ValueVector& rightVector, common::ValueVector& resultVector) { int64_t pos = 0; - result = common::ListVector::addList(&resultVector, left.size>right.size? left.size:right.size); - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - auto resultPos = result.offset; auto rightDataVector = common::ListVector::getDataVector(&rightVector); auto rightPos = right.offset; std::vector rightOffsets; @@ -34,7 +31,6 @@ struct ListIntersect { leftVector, *ListVector::getDataVector(&rightVector), resultVector); }); if (pos !=0) { - //resultDataVector->copyFromVectorData(resultPos++, rightDataVector, rightPos+i); rightOffsets.push_back(rightPos+i); } } @@ -47,7 +43,7 @@ struct ListIntersect { } auto numUniqueValues = ListUnique::appendListElementsToValueSet(temp, tempVec); result = common::ListVector::addList(&resultVector, numUniqueValues); - resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); auto resultDataVectorBuffer = common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */); ListUnique::appendListElementsToValueSet(temp, tempVec, nullptr, From 8125e1d4bbafac04d9bf7d49ea881deb35a38883 Mon Sep 17 00:00:00 2001 From: CI Bot Date: Wed, 17 Sep 2025 19:20:01 +0000 Subject: [PATCH 07/13] ci: auto code format --- src/function/function_collection.cpp | 4 +- .../list/list_binary_float_function.cpp | 59 +++++++++++-------- src/function/list/list_intersect.cpp | 27 +++++---- .../function/list/vector_list_functions.h | 2 +- 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index b482629203b..659516d8a5c 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -128,8 +128,8 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction), SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction), SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), - SCALAR_FUNCTION(ListCosineSimilarityFunction),SCALAR_FUNCTION(ListCosineDistanceFunction), - SCALAR_FUNCTION(ListDistanceFunction),SCALAR_FUNCTION(ListHasAnyFunction), + SCALAR_FUNCTION(ListCosineSimilarityFunction), SCALAR_FUNCTION(ListCosineDistanceFunction), + SCALAR_FUNCTION(ListDistanceFunction), SCALAR_FUNCTION(ListHasAnyFunction), SCALAR_FUNCTION(ListIntersectFunction), // Cast functions diff --git a/src/function/list/list_binary_float_function.cpp b/src/function/list/list_binary_float_function.cpp index 85f6b888874..cede4fd51a1 100644 --- a/src/function/list/list_binary_float_function.cpp +++ b/src/function/list/list_binary_float_function.cpp @@ -1,11 +1,12 @@ +#include "math.h" + #include "common/exception/binder.h" #include "common/exception/message.h" #include "common/type_utils.h" +#include "common/vector/value_vector.h" #include "function/list/functions/list_function_utils.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" -#include "math.h" -#include "common/vector/value_vector.h" #include using namespace kuzu::common; @@ -14,14 +15,16 @@ namespace kuzu { namespace function { struct ListCosineSimilarity { - template - static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, - common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { + template + static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); - if (left.size!=right.size) { - throw BinderException(stringFormat("LIST_COSINE_SIMILARITY requires both arguments to be in same size: left : {} ; right : {}" - ,left.size, right.size)); + if (left.size != right.size) { + throw BinderException(stringFormat("LIST_COSINE_SIMILARITY requires both arguments to " + "be in same size: left : {} ; right : {}", + left.size, right.size)); } KU_ASSERT(left.size == right.size); simsimd_distance_t tmpResult = 0.0; @@ -36,14 +39,16 @@ struct ListCosineSimilarity { }; struct ListCosineDistance { - template - static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, - common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { + template + static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); - if (left.size!=right.size) { - throw BinderException(stringFormat("LIST_COSINE_DISTANCE requires both arguments to be in same size: left : {} ; right : {}" - ,left.size, right.size)); + if (left.size != right.size) { + throw BinderException(stringFormat("LIST_COSINE_DISTANCE requires both arguments to be " + "in same size: left : {} ; right : {}", + left.size, right.size)); } KU_ASSERT(left.size == right.size); simsimd_distance_t tmpResult = 0.0; @@ -58,14 +63,16 @@ struct ListCosineDistance { }; struct ListDistance { - template - static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, common::ValueVector& leftVector, - common::ValueVector& rightVector, common::ValueVector& /*resultVector*/) { + template + static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); - if (left.size!=right.size) { - throw BinderException(stringFormat("LIST_DISTANCE requires both arguments to be in same size: left : {} ; right : {}" - ,left.size, right.size)); + if (left.size != right.size) { + throw BinderException(stringFormat( + "LIST_DISTANCE requires both arguments to be in same size: left : {} ; right : {}", + left.size, right.size)); } KU_ASSERT(left.size == right.size); simsimd_distance_t tmpResult = 0.0; @@ -102,8 +109,7 @@ static LogicalType validateListFunctionParameters(const LogicalType& leftType, return rightType.copy(); } throw BinderException( - stringFormat("{} requires at least one argument to be LIST.", - functionName)); + stringFormat("{} requires at least one argument to be LIST.", functionName)); } template @@ -129,13 +135,14 @@ scalar_func_exec_t getScalarExecFunc(LogicalType type) { return execFunc; } -template +template static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { std::vector types; types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); auto paramType = validateListFunctionParameters(types[0], types[1], input.definition->name); - input.definition->ptrCast()->execFunc = std::move(getScalarExecFunc(paramType.copy())); + input.definition->ptrCast()->execFunc = + std::move(getScalarExecFunc(paramType.copy())); auto bindData = std::make_unique(ListType::getChildType(paramType).copy()); std::vector paramTypes; for (auto& _ : input.arguments) { @@ -144,10 +151,10 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp } return bindData; } -template +template function_set templateGetFunctionSet(const std::string& name) { function_set result; - auto function = std::make_unique(name, + auto function = std::make_unique(name, std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::ANY); function->bindFunc = bindFunc; result.push_back(std::move(function)); diff --git a/src/function/list/list_intersect.cpp b/src/function/list/list_intersect.cpp index 1bfac31f53e..49250a79ca9 100644 --- a/src/function/list/list_intersect.cpp +++ b/src/function/list/list_intersect.cpp @@ -14,15 +14,15 @@ namespace function { struct ListIntersect { static void operation(common::list_entry_t& left, common::list_entry_t& right, - common::list_entry_t& result, common::ValueVector& leftVector, common::ValueVector& rightVector, - common::ValueVector& resultVector) { + common::list_entry_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector) { int64_t pos = 0; auto rightDataVector = common::ListVector::getDataVector(&rightVector); auto rightPos = right.offset; std::vector rightOffsets; - for (auto i=0u; i < right.size; i++) { + for (auto i = 0u; i < right.size; i++) { common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(), - [&](T) { + [&](T) { if (rightDataVector->isNull(right.offset + i)) { return; } @@ -30,16 +30,18 @@ struct ListIntersect { *(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos, leftVector, *ListVector::getDataVector(&rightVector), resultVector); }); - if (pos !=0) { - rightOffsets.push_back(rightPos+i); + if (pos != 0) { + rightOffsets.push_back(rightPos + i); } } - common::ValueVector tempVec(kuzu::common::LogicalType::LIST(rightDataVector->dataType.getLogicalTypeID()), nullptr); - auto tempDataVec=common::ListVector::getDataVector(&tempVec); + common::ValueVector tempVec( + kuzu::common::LogicalType::LIST(rightDataVector->dataType.getLogicalTypeID()), nullptr); + auto tempDataVec = common::ListVector::getDataVector(&tempVec); auto temp = common::ListVector::addList(&tempVec, rightOffsets.size()); auto tempPos = temp.offset; - for (auto i=0u; icopyFromVectorData(tempPos++, rightDataVector, rightPos+rightOffsets.at(i)); + for (auto i = 0u; i < rightOffsets.size(); i++) { + tempDataVec->copyFromVectorData(tempPos++, rightDataVector, + rightPos + rightOffsets.at(i)); } auto numUniqueValues = ListUnique::appendListElementsToValueSet(temp, tempVec); result = common::ListVector::addList(&resultVector, numUniqueValues); @@ -54,15 +56,14 @@ struct ListIntersect { resultDataVectorBuffer += dataVector.getNumBytesPerValue(); }); } - }; static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { std::vector types; types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); if (types[0] != types[1]) { - throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(ListIntersectFunction::name, - types[0].toString(), types[1].toString())); + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListIntersectFunction::name, types[0].toString(), types[1].toString())); } return std::make_unique(std::move(types), types[0].copy()); } diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index 8d896b77c56..e8b21370775 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -236,7 +236,7 @@ struct ListDistanceFunction { static function_set getFunctionSet(); }; -struct ListIntersectFunction { +struct ListIntersectFunction { static constexpr const char* name = "LIST_INTERSECT"; static function_set getFunctionSet(); From a3f8e4ab98db5f65eb8cbbd2650524f1cc99d9e6 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:48:35 -0400 Subject: [PATCH 08/13] Implement list_select function --- src/function/function_collection.cpp | 2 +- src/function/list/CMakeLists.txt | 3 +- src/function/list/list_select_function.cpp | 64 +++++++++++++++++++ .../function/list/vector_list_functions.h | 6 ++ 4 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 src/function/list/list_select_function.cpp diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index 659516d8a5c..36b9f01f122 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -130,7 +130,7 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction), SCALAR_FUNCTION(ListCosineSimilarityFunction), SCALAR_FUNCTION(ListCosineDistanceFunction), SCALAR_FUNCTION(ListDistanceFunction), SCALAR_FUNCTION(ListHasAnyFunction), - SCALAR_FUNCTION(ListIntersectFunction), + SCALAR_FUNCTION(ListIntersectFunction), SCALAR_FUNCTION(ListSelectFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index 6366fe07321..a70d907b728 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -29,7 +29,8 @@ add_library(kuzu_list_function quantifier_functions.cpp list_has_all.cpp list_has_any.cpp - list_intersect.cpp) + list_intersect.cpp + list_select_function.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/list/list_select_function.cpp b/src/function/list/list_select_function.cpp new file mode 100644 index 00000000000..b18405f7eb5 --- /dev/null +++ b/src/function/list/list_select_function.cpp @@ -0,0 +1,64 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/functions/list_unique_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListSelect { + static void operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, right.size); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + auto leftDataVector = common::ListVector::getDataVector(&leftVector); + auto rightDataVector = common::ListVector::getDataVector(&rightVector); + auto rightPos = right.offset; + for (auto i=0u; i < right.size; i++) { + auto leftIndexPos=rightDataVector->getValue(rightPos+i); + if ((leftIndexPos<0)||(leftIndexPos>=left.size)) { + throw BinderException(stringFormat("LIST_SELECTION encounters index out of range : {} , min: {}, max: {}", leftIndexPos, 0, left.size-1)); + } + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftIndexPos); + } + } +}; +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + types.push_back(input.arguments[0]->getDataType().copy()); + types.push_back(input.arguments[1]->getDataType().copy()); + if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListIntersectFunction::name, types[0].toString(), types[1].toString())); + auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); + auto thisListTypeInfo=ku_dynamic_cast(thisExtraTypeInfo); + if (thisListTypeInfo->getChildType().getLogicalTypeID()!=LogicalTypeID::INT64) { + throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of INT"); + } + } + return std::make_unique(std::move(types), types[0].copy()); +} + +function_set ListSelectFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index e8b21370775..9734dbd3af4 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -242,5 +242,11 @@ struct ListIntersectFunction { static function_set getFunctionSet(); }; +struct ListSelectFunction { + static constexpr const char* name = "LIST_SELECT"; + + static function_set getFunctionSet(); +}; + } // namespace function } // namespace kuzu From 1bb696b3de2aa21d5e5850ee12a3faafe6be27ca Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Fri, 19 Sep 2025 16:30:52 -0400 Subject: [PATCH 09/13] implemented list_where, and debug a typecheck error for list_select --- src/function/function_collection.cpp | 1 + src/function/list/CMakeLists.txt | 3 +- src/function/list/list_select_function.cpp | 3 +- src/function/list/list_where_function.cpp | 77 +++++++++++++++++++ .../function/list/vector_list_functions.h | 6 ++ 5 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 src/function/list/list_where_function.cpp diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index 36b9f01f122..f5e3af82388 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -131,6 +131,7 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(ListCosineSimilarityFunction), SCALAR_FUNCTION(ListCosineDistanceFunction), SCALAR_FUNCTION(ListDistanceFunction), SCALAR_FUNCTION(ListHasAnyFunction), SCALAR_FUNCTION(ListIntersectFunction), SCALAR_FUNCTION(ListSelectFunction), + SCALAR_FUNCTION(ListWhereFunction), // Cast functions SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction), diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index a70d907b728..5d6e9977758 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -30,7 +30,8 @@ add_library(kuzu_list_function list_has_all.cpp list_has_any.cpp list_intersect.cpp - list_select_function.cpp) + list_select_function.cpp + list_where_function.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/list/list_select_function.cpp b/src/function/list/list_select_function.cpp index b18405f7eb5..6670cc36733 100644 --- a/src/function/list/list_select_function.cpp +++ b/src/function/list/list_select_function.cpp @@ -39,9 +39,10 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( ListIntersectFunction::name, types[0].toString(), types[1].toString())); + } else { auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); auto thisListTypeInfo=ku_dynamic_cast(thisExtraTypeInfo); - if (thisListTypeInfo->getChildType().getLogicalTypeID()!=LogicalTypeID::INT64) { + if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::INT64) { throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of INT"); } } diff --git a/src/function/list/list_where_function.cpp b/src/function/list/list_where_function.cpp new file mode 100644 index 00000000000..f59c158e635 --- /dev/null +++ b/src/function/list/list_where_function.cpp @@ -0,0 +1,77 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "common/types/types.h" +#include "function/list/functions/list_function_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/functions/list_unique_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListWhere { + static void operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector) { + if (right.size!=left.size) { + throw BinderException(stringFormat("LIST_WHERE expecting lists of same size, receiving size {} and size {}", left.size, left.size)); + } + auto leftDataVector = common::ListVector::getDataVector(&leftVector); + auto leftPos = left.offset; + auto rightDataVector = common::ListVector::getDataVector(&rightVector); + auto rightPos = right.offset; + list_size_t resultSize=0; + std::vector maskListBools; + for (auto i=0u; i < right.size; i++) { + auto maskBool=rightDataVector->getValue(rightPos+i); + maskListBools.push_back(maskBool); + if (maskBool) { + resultSize++; + } + } + result = common::ListVector::addList(&resultVector, resultSize); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + for (auto i=0u; i < right.size; i++) { + auto maskBool=maskListBools.at(i); + if (maskBool) { + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos+i); + } + } + } +}; +static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { + std::vector types; + types.push_back(input.arguments[0]->getDataType().copy()); + types.push_back(input.arguments[1]->getDataType().copy()); + if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( + ListIntersectFunction::name, types[0].toString(), types[1].toString())); + } else { + auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); + auto thisListTypeInfo=ku_dynamic_cast(thisExtraTypeInfo); + if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::BOOL) { + throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of BOOL"); + } + } + return std::make_unique(std::move(types), types[0].copy()); +} + +function_set ListWhereFunction::getFunctionSet() { + function_set result; + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + auto function = std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST, + execFunc); + function->bindFunc = bindFunc; + result.push_back(std::move(function)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index 9734dbd3af4..087a504d476 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -248,5 +248,11 @@ struct ListSelectFunction { static function_set getFunctionSet(); }; +struct ListWhereFunction { + static constexpr const char* name = "LIST_WHERE"; + + static function_set getFunctionSet(); +}; + } // namespace function } // namespace kuzu From bc079f43e6720fc07f3bf86634a0c1f5ef03816a Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:26:20 -0400 Subject: [PATCH 10/13] Fix structure constructing bugs in list_intersect, add test for list_intersect andother former list functions --- .../list/list_binary_float_function.cpp | 4 + src/function/list/list_intersect.cpp | 2 +- src/function/list/list_zip_function.cpp | 0 test/test_files/function/list.test | 168 ++++++++++++++++++ 4 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 src/function/list/list_zip_function.cpp diff --git a/src/function/list/list_binary_float_function.cpp b/src/function/list/list_binary_float_function.cpp index cede4fd51a1..d4eb40fb7e3 100644 --- a/src/function/list/list_binary_float_function.cpp +++ b/src/function/list/list_binary_float_function.cpp @@ -99,6 +99,10 @@ static void validateChildType(const LogicalType& type, const std::string& functi static LogicalType validateListFunctionParameters(const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) { + if ((leftType.getPhysicalType()!=common::PhysicalTypeID::LIST)|| + (rightType.getPhysicalType()!=common::PhysicalTypeID::LIST)) { + throw BinderException(stringFormat("Function {} did not receive correct arguments",functionName)); + } const auto& leftChildType = ListType::getChildType(leftType); const auto& rightChildType = ListType::getChildType(rightType); validateChildType(leftChildType, functionName); diff --git a/src/function/list/list_intersect.cpp b/src/function/list/list_intersect.cpp index 49250a79ca9..45deacf2ca4 100644 --- a/src/function/list/list_intersect.cpp +++ b/src/function/list/list_intersect.cpp @@ -35,7 +35,7 @@ struct ListIntersect { } } common::ValueVector tempVec( - kuzu::common::LogicalType::LIST(rightDataVector->dataType.getLogicalTypeID()), nullptr); + kuzu::common::LogicalType::LIST(rightDataVector->dataType.copy()), nullptr, nullptr); auto tempDataVec = common::ListVector::getDataVector(&tempVec); auto temp = common::ListVector::addList(&tempVec, rightOffsets.size()); auto tempPos = temp.offset; diff --git a/src/function/list/list_zip_function.cpp b/src/function/list/list_zip_function.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/test_files/function/list.test b/test/test_files/function/list.test index 4cebbfdae55..daacad677e5 100644 --- a/test/test_files/function/list.test +++ b/test/test_files/function/list.test @@ -2138,3 +2138,171 @@ True -STATEMENT RETURN LIST_CAT(null, null) ---- 1 +-CASE ListCosineSimilarity +-STATEMENT RETURN list_cosine_similarity([1,2,3],[4,5,6]) +---- error +Binder exception: LIST_COSINE_SIMILARITY requires argument type to be FLOAT[] or DOUBLE[]. + +-STATEMENT RETURN list_cosine_similarity([1.0,2.0,3.0],[4.0,5.0,6.0]) +---- 1 +0.974632 + +-STATEMENT RETURN list_cosine_similarity([1.0,NULL,3.0,4.0],[4.0,5.0,6.0]) +---- error +Binder exception: LIST_COSINE_SIMILARITY requires both arguments to be in same size: left : 4 ; right : 3 + +-STATEMENT RETURN list_cosine_similarity(null,null) +---- error +Binder exception: Function LIST_COSINE_SIMILARITY did not receive correct arguments + +-STATEMENT RETURN list_cosine_similarity(2,3,4,5) +---- error +Binder exception: Function LIST_COSINE_SIMILARITY did not receive correct arguments: +Actual: (INT64,INT64,INT64,INT64) +Expected: (LIST,LIST) -> ANY + +-STATEMENT RETURN list_cosine_similarity([1.5, 2.4, 3.3, 1.23, 4.56],[4.53,6.23,6.55,3.42,2.44]) +---- 1 +0.835090 + +-CASE ListCosineDistance +-STATEMENT RETURN list_cosine_distance([1,2,3],[4,5,6]) +---- error +Binder exception: LIST_COSINE_DISTANCE requires argument type to be FLOAT[] or DOUBLE[]. + +-STATEMENT RETURN list_cosine_distance([1.0,2.0,3.0],[4.0,5.0,6.0]) +---- 1 +0.025368 + +-STATEMENT RETURN list_cosine_distance([1.0,NULL,3.0,4.0],[4.0,5.0,6.0]) +---- error +Binder exception: LIST_COSINE_DISTANCE requires both arguments to be in same size: left : 4 ; right : 3 + +-STATEMENT RETURN list_cosine_distance(null,null) +---- error +Binder exception: Function LIST_COSINE_DISTANCE did not receive correct arguments + +-STATEMENT RETURN list_cosine_distance(2,3,4,5) +---- error +Binder exception: Function LIST_COSINE_DISTANCE did not receive correct arguments: +Actual: (INT64,INT64,INT64,INT64) +Expected: (LIST,LIST) -> ANY + +-STATEMENT RETURN list_cosine_distance([1.5, 2.4, 3.3, 1.23, 4.56],[4.53,6.23,6.55,3.42,2.44]) +---- 1 +0.164910 + +-CASE ListDistance +-STATEMENT RETURN list_distance([1,2,3],[4,5,6]) +---- error +Binder exception: LIST_DISTANCE requires argument type to be FLOAT[] or DOUBLE[]. + +-STATEMENT RETURN list_distance([1.0,2.0,3.0],[4.0,5.0,6.0]) +---- 1 +5.196152 + +-STATEMENT RETURN list_distance([1.0,NULL,3.0,4.0],[4.0,5.0,6.0]) +---- error +Binder exception: LIST_DISTANCE requires both arguments to be in same size: left : 4 ; right : 3 + +-STATEMENT RETURN list_distance(null,null) +---- error +Binder exception: Function LIST_DISTANCE did not receive correct arguments + +-STATEMENT RETURN list_distance(2,3,4,5) +---- error +Error: Binder exception: Function LIST_DISTANCE did not receive correct arguments: +Actual: (INT64,INT64,INT64,INT64) +Expected: (LIST,LIST) -> ANY + +-STATEMENT RETURN list_distance([1.5, 2.4, 3.3, 1.23, 4.56],[4.53,6.23,6.55,3.42,2.44]) +---- 1 +6.610809 + +-CASE ListHasAny +-STATEMENT RETURN list_has_any([1,2,3],[2,3,4]) +---- 1 +True + +-STATEMENT RETURN list_has_any([1,4,5],[2,3,6]) +---- 1 +False + +-STATEMENT RETURN list_has_any([1,2,3],['2','3','6']) +---- error +Binder exception: Cannot bind LIST_HAS_ANY with parameter type INT64[] and STRING[]. + +-STATEMENT RETURN list_has_any(['2','3','4'],['2','3','6']) +---- 1 +True + +-STATEMENT RETURN list_has_any([True,False,False],[False,False,True]) +---- 1 +True + +-STATEMENT RETURN list_has_any([1,2],[null]) +---- 1 +False + +-STATEMENT RETURN list_has_any([null,null],[null]) +---- 1 +False + +-STATEMENT RETURN list_has_any([{a: 3, b: 4}],[{b: 2, c: 3}, {a: 3, b: 4}]) +---- 1 +True + +-STATEMENT RETURN list_has_any([],[]) +---- 1 +False + +-STATEMENT RETURN list_has_any(null, [1,3,2]) +---- 1 + +-STATEMENT RETURN list_has_any([1,2], null) +---- 1 + +-STATEMENT RETURN list_has_any(null, null) +---- 1 + +-CASE ListIntersect +-STATEMENT RETURN list_intersect([1,2,3,4,5,5],[5,4,3,3,2,1]) +---- 1 +[5,4,3,2,1] + +-STATEMENT RETURN list_intersect([1,2,3,4,5,5],[]) +---- 1 +[] + +-STATEMENT RETURN list_intersect([1,2,3,4,5,5],[]) +---- 1 +[] + +-STATEMENT RETURN list_intersect([],[5,4,3,2,1]) +---- 1 +[] + +-STATEMENT RETURN list_intersect([],[]) +---- 1 +[] + +-STATEMENT RETURN list_intersect([null],[null,null]) +---- 1 +[] + +-STATEMENT RETURN list_intersect(['1','2','3','33','4'],['2','2','1','3','4']) +---- 1 +[2,1,3,4] + +-STATEMENT RETURN list_intersect([{a: 44, b: 48}, {a: 45, b: 56}],[{a:44, b:48}]) +---- 1 +[{a: 44, b: 48}] + +-STATEMENT RETURN list_intersect([null, 1, 2, 4, 8],[null, 2, null, 3, null]) +---- 1 +[2] + +-STATEMENT RETURN list_intersect([true, false, true, false],[false]) +---- 1 +[False] + From 9b584b3bcc6256becd06496709130d03f6c48c66 Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:41:19 -0400 Subject: [PATCH 11/13] Add test for all required list function except for list_zip, modify list_insersect and list_where to make them act similar to those in duckdb --- src/function/list/list_select_function.cpp | 11 +++- src/function/list/list_where_function.cpp | 22 +++++-- test/test_files/function/list.test | 77 ++++++++++++++++++++++ 3 files changed, 100 insertions(+), 10 deletions(-) diff --git a/src/function/list/list_select_function.cpp b/src/function/list/list_select_function.cpp index 6670cc36733..88714942f72 100644 --- a/src/function/list/list_select_function.cpp +++ b/src/function/list/list_select_function.cpp @@ -24,11 +24,13 @@ struct ListSelect { auto rightDataVector = common::ListVector::getDataVector(&rightVector); auto rightPos = right.offset; for (auto i=0u; i < right.size; i++) { - auto leftIndexPos=rightDataVector->getValue(rightPos+i); + auto leftIndexPos=rightDataVector->getValue(rightPos+i)-1; if ((leftIndexPos<0)||(leftIndexPos>=left.size)) { - throw BinderException(stringFormat("LIST_SELECTION encounters index out of range : {} , min: {}, max: {}", leftIndexPos, 0, left.size-1)); + // append null to result if out of index + resultDataVector->setNull(resultPos++, true); + } else { + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftIndexPos); } - resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftIndexPos); } } }; @@ -36,6 +38,9 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp std::vector types; types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); + if (types[0].getPhysicalType()!=PhysicalTypeID::LIST) { + throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of INT"); + } if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( ListIntersectFunction::name, types[0].toString(), types[1].toString())); diff --git a/src/function/list/list_where_function.cpp b/src/function/list/list_where_function.cpp index f59c158e635..b60ae599e60 100644 --- a/src/function/list/list_where_function.cpp +++ b/src/function/list/list_where_function.cpp @@ -17,9 +17,6 @@ struct ListWhere { static void operation(common::list_entry_t& left, common::list_entry_t& right, common::list_entry_t& result, common::ValueVector& leftVector, common::ValueVector& rightVector, common::ValueVector& resultVector) { - if (right.size!=left.size) { - throw BinderException(stringFormat("LIST_WHERE expecting lists of same size, receiving size {} and size {}", left.size, left.size)); - } auto leftDataVector = common::ListVector::getDataVector(&leftVector); auto leftPos = left.offset; auto rightDataVector = common::ListVector::getDataVector(&rightVector); @@ -27,11 +24,14 @@ struct ListWhere { list_size_t resultSize=0; std::vector maskListBools; for (auto i=0u; i < right.size; i++) { + if (rightDataVector->isNull(rightPos+i)) { + throw BinderException("NULLs are not allowed as list elements in the second input parameter."); + } auto maskBool=rightDataVector->getValue(rightPos+i); - maskListBools.push_back(maskBool); - if (maskBool) { + if (maskBool){ resultSize++; } + maskListBools.push_back(maskBool); } result = common::ListVector::addList(&resultVector, resultSize); auto resultDataVector = common::ListVector::getDataVector(&resultVector); @@ -39,7 +39,12 @@ struct ListWhere { for (auto i=0u; i < right.size; i++) { auto maskBool=maskListBools.at(i); if (maskBool) { - resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos+i); + if (leftPos+icopyFromVectorData(resultPos++, leftDataVector, leftPos+i); + } else { + resultDataVector->setNull(resultPos++,true); + } + } } } @@ -48,6 +53,9 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp std::vector types; types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); + if (types[0].getPhysicalType()!=PhysicalTypeID::LIST) { + throw BinderException("LIST_WHERE expecting argument type: LIST of ANY, LIST of BOOL"); + } if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( ListIntersectFunction::name, types[0].toString(), types[1].toString())); @@ -55,7 +63,7 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); auto thisListTypeInfo=ku_dynamic_cast(thisExtraTypeInfo); if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::BOOL) { - throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of BOOL"); + throw BinderException("LIST_WHERE expecting argument type: LIST of ANY, LIST of BOOL"); } } return std::make_unique(std::move(types), types[0].copy()); diff --git a/test/test_files/function/list.test b/test/test_files/function/list.test index daacad677e5..6402a7955e8 100644 --- a/test/test_files/function/list.test +++ b/test/test_files/function/list.test @@ -2306,3 +2306,80 @@ False ---- 1 [False] +-CASE ListSelect +-STATEMENT RETURN list_select([1,2,3,4,5],[1,2,3]) +---- 1 +[1,2,3] + +-STATEMENT RETURN list_select(['s','t','d','u','e','f','~'],[2,2,3,4,5]) +---- 1 +[t,t,d,u,e] + +-STATEMENT RETURN list_select([1,2,3,4,5],[1.0,2.0,3.0,5.0]) +---- error +Binder exception: LIST_SELECT expecting argument type: LIST of ANY, LIST of INT + +-STATEMENT RETURN list_select([[1],[2,3],[2,3,4],[1,2,3],[1]],[1,2,3,4,2,3,2]) +---- 1 +[[1],[2,3],[2,3,4],[1,2,3],[2,3],[2,3,4],[2,3]] + +-STATEMENT Return list_select([1,2,3,4,5],[0,1,2,3,4,5,6]) +---- 1 +[,1,2,3,4,5,] + +-STATEMENT Return list_select([1,2,3,4,5],[null]) +---- 1 +[] + +-STATEMENT Return list_select([1,2,3,4,5],null) +---- error +Binder exception: Cannot bind LIST_INTERSECT with parameter type INT64[] and ANY. + +-STATEMENT Return list_select(null,[0,1,2,3,4,5]) +---- error +Binder exception: LIST_SELECT expecting argument type: LIST of ANY, LIST of INT + +-STATEMENT RETURN list_select([{q:1, p:2},{q:2, p:3},{a: 3, q: 4}],[1,2,2,1]) +---- 1 +[{a: 1, q: 2},{a: 2, q: 3},{a: 2, q: 3},{a: 1, q: 2}] + +-STATEMENT RETURN list_select([1,2,3,4,5],[null, 2, null, 4, 5]) +---- 1 +[,2,,4,5] + +-CASE ListWhere +-STATEMENT RETURN list_where([1,2,3,4,5],[true, false, true, true, false]) +---- 1 +[1,3,4] + +-STATEMENT RETURN list_where([{a:2, b:3},{a:3, b:5},{a:4, c:1},{a:21, u:33}],[true, false, true, true, false]) +---- 1 +[{a: 2, u: 3},{a: 4, u: 1},{a: 21, u: 33}] + +-STATEMENT RETURN list_where([1,2,3,4,null],[true, false, true, true, true]) +---- 1 +[1,3,4,] + +-STATEMENT RETURN list_where([1,2,3,4],[true, false, true]) +---- 1 +[1,3] + +-STATEMENT RETURN list_where([1,2,3,4],[true, false, true,false, true, true]) +---- 1 +[1,3,,] + +-STATEMENT RETURN list_where(['a','b','c','6','r'],[false, false, true,false, true]) +---- 1 +[c,r] + +-STATEMENT RETURN list_where(null,[false, false, true,false, true]) +---- error +Binder exception: LIST_WHERE expecting argument type: LIST of ANY, LIST of BOOL + +-STATEMENT RETURN list_where([1,2,3,4],null) +---- error +Binder exception: Cannot bind LIST_INTERSECT with parameter type INT64[] and ANY. + +-STATEMENT RETURN list_where([1,2,3,4],[true, false, null, true]) +---- error +Binder exception: NULLs are not allowed as list elements in the second input parameter. From 728d14c42b933fdc12e4eee0cc9afb47498c5cdf Mon Sep 17 00:00:00 2001 From: Jasonlmx <56242204+Jasonlmx@users.noreply.github.com> Date: Fri, 26 Sep 2025 12:05:21 -0400 Subject: [PATCH 12/13] Correct unexpected error for listDistance test in list.test --- test/test_files/function/list.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_files/function/list.test b/test/test_files/function/list.test index 6402a7955e8..d1268307287 100644 --- a/test/test_files/function/list.test +++ b/test/test_files/function/list.test @@ -2211,7 +2211,7 @@ Binder exception: Function LIST_DISTANCE did not receive correct arguments -STATEMENT RETURN list_distance(2,3,4,5) ---- error -Error: Binder exception: Function LIST_DISTANCE did not receive correct arguments: +Binder exception: Function LIST_DISTANCE did not receive correct arguments: Actual: (INT64,INT64,INT64,INT64) Expected: (LIST,LIST) -> ANY From 26d0cd7b8199e18b75dac37c980fe57dc28fa12d Mon Sep 17 00:00:00 2001 From: CI Bot Date: Fri, 26 Sep 2025 16:07:32 +0000 Subject: [PATCH 13/13] ci: auto code format --- .../list/list_binary_float_function.cpp | 7 ++-- src/function/list/list_select_function.cpp | 16 ++++----- src/function/list/list_where_function.cpp | 34 +++++++++---------- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/function/list/list_binary_float_function.cpp b/src/function/list/list_binary_float_function.cpp index d4eb40fb7e3..ee5abe40318 100644 --- a/src/function/list/list_binary_float_function.cpp +++ b/src/function/list/list_binary_float_function.cpp @@ -99,9 +99,10 @@ static void validateChildType(const LogicalType& type, const std::string& functi static LogicalType validateListFunctionParameters(const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) { - if ((leftType.getPhysicalType()!=common::PhysicalTypeID::LIST)|| - (rightType.getPhysicalType()!=common::PhysicalTypeID::LIST)) { - throw BinderException(stringFormat("Function {} did not receive correct arguments",functionName)); + if ((leftType.getPhysicalType() != common::PhysicalTypeID::LIST) || + (rightType.getPhysicalType() != common::PhysicalTypeID::LIST)) { + throw BinderException( + stringFormat("Function {} did not receive correct arguments", functionName)); } const auto& leftChildType = ListType::getChildType(leftType); const auto& rightChildType = ListType::getChildType(rightType); diff --git a/src/function/list/list_select_function.cpp b/src/function/list/list_select_function.cpp index 88714942f72..aec949e963d 100644 --- a/src/function/list/list_select_function.cpp +++ b/src/function/list/list_select_function.cpp @@ -23,9 +23,9 @@ struct ListSelect { auto leftDataVector = common::ListVector::getDataVector(&leftVector); auto rightDataVector = common::ListVector::getDataVector(&rightVector); auto rightPos = right.offset; - for (auto i=0u; i < right.size; i++) { - auto leftIndexPos=rightDataVector->getValue(rightPos+i)-1; - if ((leftIndexPos<0)||(leftIndexPos>=left.size)) { + for (auto i = 0u; i < right.size; i++) { + auto leftIndexPos = rightDataVector->getValue(rightPos + i) - 1; + if ((leftIndexPos < 0) || (leftIndexPos >= left.size)) { // append null to result if out of index resultDataVector->setNull(resultPos++, true); } else { @@ -38,16 +38,16 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp std::vector types; types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); - if (types[0].getPhysicalType()!=PhysicalTypeID::LIST) { + if (types[0].getPhysicalType() != PhysicalTypeID::LIST) { throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of INT"); } - if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { + if (types[1].getPhysicalType() != PhysicalTypeID::LIST) { throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( ListIntersectFunction::name, types[0].toString(), types[1].toString())); } else { - auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); - auto thisListTypeInfo=ku_dynamic_cast(thisExtraTypeInfo); - if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::INT64) { + auto thisExtraTypeInfo = types[1].getExtraTypeInfo(); + auto thisListTypeInfo = ku_dynamic_cast(thisExtraTypeInfo); + if (thisListTypeInfo->getChildType().getPhysicalType() != PhysicalTypeID::INT64) { throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of INT"); } } diff --git a/src/function/list/list_where_function.cpp b/src/function/list/list_where_function.cpp index b60ae599e60..6349b62d116 100644 --- a/src/function/list/list_where_function.cpp +++ b/src/function/list/list_where_function.cpp @@ -21,14 +21,15 @@ struct ListWhere { auto leftPos = left.offset; auto rightDataVector = common::ListVector::getDataVector(&rightVector); auto rightPos = right.offset; - list_size_t resultSize=0; + list_size_t resultSize = 0; std::vector maskListBools; - for (auto i=0u; i < right.size; i++) { - if (rightDataVector->isNull(rightPos+i)) { - throw BinderException("NULLs are not allowed as list elements in the second input parameter."); + for (auto i = 0u; i < right.size; i++) { + if (rightDataVector->isNull(rightPos + i)) { + throw BinderException( + "NULLs are not allowed as list elements in the second input parameter."); } - auto maskBool=rightDataVector->getValue(rightPos+i); - if (maskBool){ + auto maskBool = rightDataVector->getValue(rightPos + i); + if (maskBool) { resultSize++; } maskListBools.push_back(maskBool); @@ -36,15 +37,14 @@ struct ListWhere { result = common::ListVector::addList(&resultVector, resultSize); auto resultDataVector = common::ListVector::getDataVector(&resultVector); auto resultPos = result.offset; - for (auto i=0u; i < right.size; i++) { - auto maskBool=maskListBools.at(i); + for (auto i = 0u; i < right.size; i++) { + auto maskBool = maskListBools.at(i); if (maskBool) { - if (leftPos+icopyFromVectorData(resultPos++, leftDataVector, leftPos+i); + if (leftPos + i < left.size) { + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos + i); } else { - resultDataVector->setNull(resultPos++,true); + resultDataVector->setNull(resultPos++, true); } - } } } @@ -53,16 +53,16 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp std::vector types; types.push_back(input.arguments[0]->getDataType().copy()); types.push_back(input.arguments[1]->getDataType().copy()); - if (types[0].getPhysicalType()!=PhysicalTypeID::LIST) { + if (types[0].getPhysicalType() != PhysicalTypeID::LIST) { throw BinderException("LIST_WHERE expecting argument type: LIST of ANY, LIST of BOOL"); } - if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { + if (types[1].getPhysicalType() != PhysicalTypeID::LIST) { throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( ListIntersectFunction::name, types[0].toString(), types[1].toString())); } else { - auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); - auto thisListTypeInfo=ku_dynamic_cast(thisExtraTypeInfo); - if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::BOOL) { + auto thisExtraTypeInfo = types[1].getExtraTypeInfo(); + auto thisListTypeInfo = ku_dynamic_cast(thisExtraTypeInfo); + if (thisListTypeInfo->getChildType().getPhysicalType() != PhysicalTypeID::BOOL) { throw BinderException("LIST_WHERE expecting argument type: LIST of ANY, LIST of BOOL"); } }