Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION(ListReduceFunction), SCALAR_FUNCTION(ListAnyFunction),
SCALAR_FUNCTION(ListAllFunction), SCALAR_FUNCTION(ListNoneFunction),
SCALAR_FUNCTION(ListSingleFunction), SCALAR_FUNCTION(ListHasAllFunction),
SCALAR_FUNCTION(ListCosineSimilarityFunction), SCALAR_FUNCTION(ListCosineDistanceFunction),
SCALAR_FUNCTION(ListDistanceFunction), SCALAR_FUNCTION(ListHasAnyFunction),
SCALAR_FUNCTION(ListIntersectFunction), SCALAR_FUNCTION(ListSelectFunction),
SCALAR_FUNCTION(ListWhereFunction),

// Cast functions
SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction),
Expand Down
7 changes: 6 additions & 1 deletion src/function/list/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_library(kuzu_list_function
list_append_function.cpp
list_concat_function.cpp
list_contains_function.cpp
list_binary_float_function.cpp
list_creation.cpp
list_distinct_function.cpp
list_extract_function.cpp
Expand All @@ -26,7 +27,11 @@ add_library(kuzu_list_function
list_single.cpp
size_function.cpp
quantifier_functions.cpp
list_has_all.cpp)
list_has_all.cpp
list_has_any.cpp
list_intersect.cpp
list_select_function.cpp
list_where_function.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_list_function>
Expand Down
182 changes: 182 additions & 0 deletions src/function/list/list_binary_float_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include "math.h"

#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/type_utils.h"
#include "common/vector/value_vector.h"
#include "function/list/functions/list_function_utils.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"
#include <simsimd.h>

using namespace kuzu::common;

namespace kuzu {
namespace function {

struct ListCosineSimilarity {
template<std::floating_point T>
static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
if (left.size != right.size) {
throw BinderException(stringFormat("LIST_COSINE_SIMILARITY requires both arguments to "
"be in same size: left : {} ; right : {}",
left.size, right.size));
}
KU_ASSERT(left.size == right.size);
simsimd_distance_t tmpResult = 0.0;
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
if constexpr (std::is_same_v<T, float>) {
simsimd_cos_f32(leftElements, rightElements, left.size, &tmpResult);
} else {
simsimd_cos_f64(leftElements, rightElements, left.size, &tmpResult);
}
result = 1.0 - tmpResult;
}
};

struct ListCosineDistance {
template<std::floating_point T>
static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
if (left.size != right.size) {
throw BinderException(stringFormat("LIST_COSINE_DISTANCE requires both arguments to be "
"in same size: left : {} ; right : {}",
left.size, right.size));
}
KU_ASSERT(left.size == right.size);
simsimd_distance_t tmpResult = 0.0;
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
if constexpr (std::is_same_v<T, float>) {
simsimd_cos_f32(leftElements, rightElements, left.size, &tmpResult);
} else {
simsimd_cos_f64(leftElements, rightElements, left.size, &tmpResult);
}
result = tmpResult;
}
};

struct ListDistance {
template<std::floating_point T>
static void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
if (left.size != right.size) {
throw BinderException(stringFormat(
"LIST_DISTANCE requires both arguments to be in same size: left : {} ; right : {}",
left.size, right.size));
}
KU_ASSERT(left.size == right.size);
simsimd_distance_t tmpResult = 0.0;
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
if constexpr (std::is_same_v<T, float>) {
simsimd_l2sq_f32(leftElements, rightElements, left.size, &tmpResult);
} else {
simsimd_l2sq_f64(leftElements, rightElements, left.size, &tmpResult);
}
result = std::sqrt(tmpResult);
}
};

static void validateChildType(const LogicalType& type, const std::string& functionName) {
switch (type.getLogicalTypeID()) {
case LogicalTypeID::DOUBLE:
case LogicalTypeID::FLOAT:
return;
default:
throw BinderException(
stringFormat("{} requires argument type to be FLOAT[] or DOUBLE[].", functionName));
}
}

static LogicalType validateListFunctionParameters(const LogicalType& leftType,
const LogicalType& rightType, const std::string& functionName) {
if ((leftType.getPhysicalType() != common::PhysicalTypeID::LIST) ||
(rightType.getPhysicalType() != common::PhysicalTypeID::LIST)) {
throw BinderException(
stringFormat("Function {} did not receive correct arguments", functionName));
}
const auto& leftChildType = ListType::getChildType(leftType);
const auto& rightChildType = ListType::getChildType(rightType);
validateChildType(leftChildType, functionName);
validateChildType(rightChildType, functionName);
if (leftType.getLogicalTypeID() == common::LogicalTypeID::LIST) {
return leftType.copy();
} else if (rightType.getLogicalTypeID() == common::LogicalTypeID::LIST) {
return rightType.copy();
}
throw BinderException(
stringFormat("{} requires at least one argument to be LIST.", functionName));
}

template<typename OPERATION, typename RESULT>
static scalar_func_exec_t getBinaryListExecFuncSwitchResultType() {
auto execFunc =
ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, RESULT, OPERATION>;
return execFunc;
}

template<typename OPERATION>
scalar_func_exec_t getScalarExecFunc(LogicalType type) {
scalar_func_exec_t execFunc;
switch (ListType::getChildType(type).getLogicalTypeID()) {
case LogicalTypeID::FLOAT:
execFunc = getBinaryListExecFuncSwitchResultType<OPERATION, float>();
break;
case LogicalTypeID::DOUBLE:
execFunc = getBinaryListExecFuncSwitchResultType<OPERATION, double>();
break;
default:
KU_UNREACHABLE;
}
return execFunc;
}

template<typename OPERATION>
static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
std::vector<LogicalType> types;
types.push_back(input.arguments[0]->getDataType().copy());
types.push_back(input.arguments[1]->getDataType().copy());
auto paramType = validateListFunctionParameters(types[0], types[1], input.definition->name);
input.definition->ptrCast<ScalarFunction>()->execFunc =
std::move(getScalarExecFunc<OPERATION>(paramType.copy()));
auto bindData = std::make_unique<FunctionBindData>(ListType::getChildType(paramType).copy());
std::vector<LogicalType> paramTypes;
for (auto& _ : input.arguments) {
(void)_;
bindData->paramTypes.push_back(paramType.copy());
}
return bindData;
}
template<typename OPERATION>
function_set templateGetFunctionSet(const std::string& name) {
function_set result;
auto function = std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::ANY);
function->bindFunc = bindFunc<OPERATION>;
result.push_back(std::move(function));
return result;
}

function_set ListCosineSimilarityFunction::getFunctionSet() {
return templateGetFunctionSet<ListCosineSimilarity>(name);
}

function_set ListCosineDistanceFunction::getFunctionSet() {
return templateGetFunctionSet<ListCosineDistance>(name);
}

function_set ListDistanceFunction::getFunctionSet() {
return templateGetFunctionSet<ListDistance>(name);
}

} // namespace function
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/function/list/list_has_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct ListHasAll {
}
};

std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
std::vector<LogicalType> types;
for (auto& arg : input.arguments) {
if (arg->dataType == LogicalType::ANY()) {
Expand Down
68 changes: 68 additions & 0 deletions src/function/list/list_has_any.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/type_utils.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

struct ListHasAny {
static void operation(common::list_entry_t& left, common::list_entry_t& right, uint8_t& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& resultVector) {
int64_t pos = 0;
auto rightDataVector = ListVector::getDataVector(&rightVector);
result = false;
for (auto i = 0u; i < right.size; i++) {
common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(),
[&]<typename T>(T) {
if (rightDataVector->isNull(right.offset + i)) {
return;
}
ListPosition::operation(left,
*(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos,
leftVector, *ListVector::getDataVector(&rightVector), resultVector);
result = (pos != 0);
});
if (result) {
return;
}
}
}
};

static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
std::vector<LogicalType> types;
for (auto& arg : input.arguments) {
if (arg->dataType == LogicalType::ANY()) {
types.push_back(LogicalType::LIST(LogicalType::INT64()));
} else {
types.push_back(arg->dataType.copy());
}
}
if (types[0] != types[1]) {
throw common::BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(
ListHasAnyFunction::name, input.arguments[0]->getDataType().toString(),
input.arguments[1]->getDataType().toString()));
}
return std::make_unique<FunctionBindData>(std::move(types), LogicalType::BOOL());
}

function_set ListHasAnyFunction::getFunctionSet() {
function_set result;
auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
uint8_t, ListHasAny>;
auto function = std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::BOOL,
execFunc);
function->bindFunc = bindFunc;
result.push_back(std::move(function));
return result;
}

} // namespace function
} // namespace kuzu
84 changes: 84 additions & 0 deletions src/function/list/list_intersect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/type_utils.h"
#include "function/list/functions/list_function_utils.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/functions/list_unique_function.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

struct ListIntersect {
static void operation(common::list_entry_t& left, common::list_entry_t& right,
common::list_entry_t& result, common::ValueVector& leftVector,
common::ValueVector& rightVector, common::ValueVector& resultVector) {
int64_t pos = 0;
auto rightDataVector = common::ListVector::getDataVector(&rightVector);
auto rightPos = right.offset;
std::vector<offset_t> rightOffsets;
for (auto i = 0u; i < right.size; i++) {
common::TypeUtils::visit(ListType::getChildType(rightVector.dataType).getPhysicalType(),
[&]<typename T>(T) {
if (rightDataVector->isNull(right.offset + i)) {
return;
}
ListPosition::operation(left,
*(T*)ListVector::getListValuesWithOffset(&rightVector, right, i), pos,
leftVector, *ListVector::getDataVector(&rightVector), resultVector);
});
if (pos != 0) {
rightOffsets.push_back(rightPos + i);
}
}
common::ValueVector tempVec(
kuzu::common::LogicalType::LIST(rightDataVector->dataType.copy()), nullptr, nullptr);
auto tempDataVec = common::ListVector::getDataVector(&tempVec);
auto temp = common::ListVector::addList(&tempVec, rightOffsets.size());
auto tempPos = temp.offset;
for (auto i = 0u; i < rightOffsets.size(); i++) {
tempDataVec->copyFromVectorData(tempPos++, rightDataVector,
rightPos + rightOffsets.at(i));
}
auto numUniqueValues = ListUnique::appendListElementsToValueSet(temp, tempVec);
result = common::ListVector::addList(&resultVector, numUniqueValues);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto resultDataVectorBuffer =
common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */);
ListUnique::appendListElementsToValueSet(temp, tempVec, nullptr,
[&resultDataVector, &resultDataVectorBuffer](common::ValueVector& dataVector,
uint64_t pos) -> void {
resultDataVector->copyFromVectorData(resultDataVectorBuffer, &dataVector,
dataVector.getData() + pos * dataVector.getNumBytesPerValue());
resultDataVectorBuffer += dataVector.getNumBytesPerValue();
});
}
};
static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
std::vector<LogicalType> types;
types.push_back(input.arguments[0]->getDataType().copy());
types.push_back(input.arguments[1]->getDataType().copy());
if (types[0] != types[1]) {
throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(
ListIntersectFunction::name, types[0].toString(), types[1].toString()));
}
return std::make_unique<FunctionBindData>(std::move(types), types[0].copy());
}

function_set ListIntersectFunction::getFunctionSet() {
function_set result;
auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ListIntersect>;
auto function = std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST,
execFunc);
function->bindFunc = bindFunc;
result.push_back(std::move(function));
return result;
}

} // namespace function
} // namespace kuzu
Loading