Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions extension/fts/src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
add_subdirectory(query_fts)

# Object library for the index-maintenance side of the FTS extension.
# query_fts_index.cpp and query_fts_bind_data.cpp are intentionally NOT listed
# here: they are compiled by the kuzu_query_fts_function target in the
# query_fts/ subdirectory, and listing them in both object libraries would
# compile and link them twice.
add_library(kuzu_fts_function
        OBJECT
        create_fts_index.cpp
        drop_fts_index.cpp
        fts_config.cpp
        fts_index_utils.cpp
        stem.cpp
        tokenize.cpp)

Expand Down
30 changes: 26 additions & 4 deletions extension/fts/src/function/create_fts_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ static std::string formatStrInCypher(const std::string& input) {
return result;
}

// Builds the extra DDL needed when exact_term_match is enabled: a node table
// keyed on the original (un-stemmed) term, populated with the distinct
// term_origin values recorded in the appears-info table.
static std::string createTablesForExactTermMatch(const CreateFTSBindData& bindData) {
    const auto appearsInfo =
        FTSUtils::getAppearsInfoTableName(bindData.tableID, bindData.indexName);
    const auto origTerms = FTSUtils::getOrigTermsTableName(bindData.tableID, bindData.indexName);
    auto ddl = common::stringFormat("CREATE NODE TABLE `{}`(term string, primary key(term));",
        origTerms);
    ddl += common::stringFormat(
        "COPY `{}` FROM (match (doc:`{}`) return distinct doc.term_origin);", origTerms,
        appearsInfo);
    return ddl;
}

std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& bindData) {
auto ftsBindData = bindData.constPtrCast<CreateFTSBindData>();
auto tableID = ftsBindData->tableID;
Expand Down Expand Up @@ -174,8 +188,9 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData&
// Create the terms_in_doc table which serves as a temporary table to store the
// relationship between terms and docs.
auto appearsInfoTableName = FTSUtils::getAppearsInfoTableName(tableID, indexName);
query += stringFormat("CREATE NODE TABLE `{}` (ID SERIAL, term string, docID INT64, primary "
"key(ID));",
query += stringFormat(
"CREATE NODE TABLE `{}` (ID SERIAL, term string, term_origin string, docID INT64, primary "
"key(ID));",
appearsInfoTableName);
auto tableName = ftsBindData->tableName;
auto tableEntry = catalog::Catalog::Get(context)->getTableCatalogEntry(
Expand All @@ -189,7 +204,7 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData&
"WITH t AS t1, id AS id1 "
"WHERE t1 is NOT NULL AND SIZE(t1) > 0 AND "
"NOT EXISTS {MATCH (s:`{}` {sw: t1})} "
"RETURN STEM(t1, '{}'), id1);",
"RETURN STEM(t1, '{}'), t1, id1);",
appearsInfoTableName, tableName, FTSUtils::getTokenizeMacroName(tableID, indexName),
propertyName, ftsBindData->createFTSConfig.stopWordsTableInfo.tableName,
ftsBindData->createFTSConfig.stemmer);
Expand All @@ -213,6 +228,11 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData&
"RETURN t.term, CAST(count(distinct t.docID) AS UINT64));",
termsTableName, appearsInfoTableName);

// If exact_term_match is enabled, we need to create additional tables.
if (ftsBindData->createFTSConfig.exactTermMatch) {
query += createTablesForExactTermMatch(*ftsBindData);
}

auto appearsInTableName = FTSUtils::getAppearsInTableName(tableID, indexName);
// Finally, create a terms table that records the documents in which the terms appear, along
// with the frequency of each term.
Expand All @@ -236,8 +256,10 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData&
properties += "]";
std::string params;
params += stringFormat("stemmer := '{}', ", ftsBindData->createFTSConfig.stemmer);
params += stringFormat("stopWords := '{}'",
params += stringFormat("stopWords := '{}', ",
ftsBindData->createFTSConfig.stopWordsTableInfo.stopWords);
params += stringFormat("exact_term_match := {}",
ftsBindData->createFTSConfig.exactTermMatch ? "true" : "false");
query += stringFormat("CALL _CREATE_FTS_INDEX('{}', '{}', {}, {});", tableName, indexName,
properties, params);
query += stringFormat("RETURN 'Index {} has been created.' as result;", ftsBindData->indexName);
Expand Down
12 changes: 9 additions & 3 deletions extension/fts/src/function/fts_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,17 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_
common::StringUtils::replaceAll(ignorePatternQuery, "?", "");
IgnorePattern::validate(ignorePattern);
IgnorePattern::validate(ignorePatternQuery);
} else if (lowerCaseName == "tokenizer") {
value.validateType(common::LogicalTypeID::STRING);
} else if (Tokenizer::NAME == lowerCaseName) {
value.validateType(Tokenizer::TYPE);
tokenizerInfo.tokenizer = common::StringUtils::getLower(value.getValue<std::string>());
Tokenizer::validate(tokenizerInfo.tokenizer);
} else if (lowerCaseName == "jieba_dict_dir") {
value.validateType(common::LogicalTypeID::STRING);
tokenizerInfo.jiebaDictDir =
common::StringUtils::getLower(value.getValue<std::string>());
} else if (ExactTermMatch::NAME == lowerCaseName) {
value.validateType(ExactTermMatch::TYPE);
exactTermMatch = value.getValue<bool>();
} else {
throw common::BinderException{"Unrecognized optional parameter: " + name};
}
Expand All @@ -165,7 +168,8 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_

FTSConfig CreateFTSConfig::getFTSConfig() const {
    // Snapshot every bind-time option (stemmer, stop words, ignore patterns,
    // tokenizer settings, and the exact-term-match flag) into an FTSConfig.
    // NOTE: this removes a duplicated argument line left over from a merge —
    // the argument list appeared once without and once with exactTermMatch.
    return FTSConfig{stemmer, stopWordsTableInfo.tableName, stopWordsTableInfo.stopWords,
        ignorePattern, ignorePatternQuery, tokenizerInfo.tokenizer, tokenizerInfo.jiebaDictDir,
        exactTermMatch};
}

void FTSConfig::serialize(common::Serializer& serializer) const {
Expand All @@ -176,6 +180,7 @@ void FTSConfig::serialize(common::Serializer& serializer) const {
serializer.serializeValue(ignorePatternQuery);
serializer.serializeValue(tokenizer);
serializer.serializeValue(jiebaDictDir);
serializer.serializeValue(exactTermMatch);
}

FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) {
Expand All @@ -187,6 +192,7 @@ FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) {
deserializer.deserializeValue(config.ignorePatternQuery);
deserializer.deserializeValue(config.tokenizer);
deserializer.deserializeValue(config.jiebaDictDir);
deserializer.deserializeValue(config.exactTermMatch);
return config;
}

Expand Down
10 changes: 10 additions & 0 deletions extension/fts/src/function/query_fts/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Object library for the query-side FTS functions: index scan, term pattern
# matching, bind data, and term lookup.
add_library(kuzu_query_fts_function
OBJECT
query_fts_index.cpp
query_fts_pattern_match.cpp
query_fts_bind_data.cpp
query_fts_term_lookup.cpp)

# Export the object files to the parent scope so the extension links them in.
set(FTS_EXTENSION_OBJECT_FILES
${FTS_EXTENSION_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_query_fts_function>
PARENT_SCOPE)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "function/query_fts_bind_data.h"
#include "function/query_fts/query_fts_bind_data.h"

#include "binder/binder.h"
#include "binder/expression/expression_util.h"
Expand Down Expand Up @@ -42,6 +42,36 @@ void QueryFTSOptionalParams::evaluateParams(main::ClientContext* context) {
topK.evaluateParam(context);
}

// Binds the FTS query function: captures the query expression, the index
// catalog entry, corpus statistics (numDocs, avgDocLen), and selects the
// term-matching algorithm (EXACT vs STEM) from the index configuration.
QueryFTSBindData::QueryFTSBindData(binder::expression_vector columns,
    graph::NativeGraphEntry graphEntry, std::shared_ptr<binder::Expression> docs,
    std::shared_ptr<binder::Expression> query, const catalog::IndexCatalogEntry& entry,
    std::unique_ptr<QueryFTSOptionalParams> optionalParams, common::idx_t numDocs, double avgDocLen)
    : GDSBindData{std::move(columns), std::move(graphEntry), binder::expression_vector{docs}},
      query{std::move(query)}, entry{entry},
      outputTableID{output[0]->constCast<binder::NodeExpression>().getTableIDs()[0]},
      numDocs{numDocs}, avgDocLen{avgDocLen},
      patternMatchAlgo{PatternMatchFactory::getPatternMatchAlgo(
          entry.getAuxInfo().cast<FTSIndexAuxInfo>().config.exactTermMatch ? TermMatchType::EXACT :
                                                                            TermMatchType::STEM)} {
    // The output node is bound to exactly one table, so the table ID resolved
    // in the member-initializer list is unambiguous; the previous body-level
    // reassignment of outputTableID was redundant and has been removed.
    KU_ASSERT(output[0]->constCast<binder::NodeExpression>().getNumEntries() == 1);
    this->optionalParams = std::move(optionalParams);
}

// Resolves the catalog entry of the (stemmed) terms table backing this index.
catalog::TableCatalogEntry* QueryFTSBindData::getTermsEntry(main::ClientContext& context) const {
    auto termsTableName = FTSUtils::getTermsTableName(entry.getTableID(), entry.getIndexName());
    auto transaction = transaction::Transaction::Get(context);
    return catalog::Catalog::Get(context)->getTableCatalogEntry(transaction, termsTableName);
}

// Resolves the catalog entry of the original (un-stemmed) terms table; this
// table only exists when the index was created with exact_term_match enabled.
catalog::TableCatalogEntry* QueryFTSBindData::getOrigTermsEntry(
    main::ClientContext& context) const {
    auto origTermsTableName =
        FTSUtils::getOrigTermsTableName(entry.getTableID(), entry.getIndexName());
    auto transaction = transaction::Transaction::Get(context);
    return catalog::Catalog::Get(context)->getTableCatalogEntry(transaction, origTermsTableName);
}

std::vector<std::string> QueryFTSBindData::getQueryTerms(main::ClientContext& context) const {
auto queryInStr =
ExpressionUtil::evaluateLiteral<std::string>(&context, query, LogicalType::STRING());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "function/query_fts_index.h"
#include "function/query_fts/query_fts_index.h"

#include <queue>

Expand All @@ -11,16 +11,16 @@
#include "common/types/internal_id_util.h"
#include "function/fts_index_utils.h"
#include "function/gds/gds_utils.h"
#include "function/query_fts_bind_data.h"
#include "function/query_fts/query_fts_bind_data.h"
#include "function/query_fts/query_fts_pattern_match.h"
#include "function/query_fts/query_fts_term_lookup.h"
#include "graph/on_disk_graph.h"
#include "index/fts_index.h"
#include "planner/operator/logical_hash_join.h"
#include "planner/operator/logical_table_function_call.h"
#include "planner/planner.h"
#include "processor/execution_context.h"
#include "re2.h"
#include "storage/storage_manager.h"
#include "storage/table/node_table.h"
#include "utils/fts_utils.h"

namespace kuzu {
Expand Down Expand Up @@ -242,71 +242,14 @@ class QFTSVertexCompute final : public VertexCompute {
std::unique_ptr<QFTSOutputWriter> writer;
};

// A query term is either a literal string (exact compare) or a compiled
// wildcard pattern (RE2 full match).
using VCQueryTerm = std::variant<std::string, std::unique_ptr<RE2>>;

// Scans the term vertices and, for every term matching any query term, records
// its document frequency keyed by the vertex offset.
class MatchTermsVertexCompute final : public VertexCompute {
public:
    explicit MatchTermsVertexCompute(std::unordered_map<offset_t, uint64_t>& resDfs,
        std::vector<VCQueryTerm>& queryTerms)
        : resDfs{resDfs}, queryTerms{queryTerms} {}

    void vertexCompute(const graph::VertexScanState::Chunk& chunk) override {
        auto terms = chunk.getProperties<ku_string_t>(0);
        auto dfs = chunk.getProperties<uint64_t>(1);
        auto nodeIds = chunk.getNodeIDs();
        for (auto& queryTerm : queryTerms) {
            for (auto i = 0u; i < chunk.size(); ++i) {
                auto candidate = terms[i].getAsString();
                // Dispatch on the variant alternative: literal equality for
                // plain terms, full regex match for wildcard terms.
                bool matched = std::holds_alternative<std::string>(queryTerm) ?
                                   std::get<std::string>(queryTerm) == candidate :
                                   RE2::FullMatch(candidate,
                                       *std::get<std::unique_ptr<RE2>>(queryTerm));
                if (matched) {
                    resDfs[nodeIds[i].offset] = dfs[i];
                }
            }
        }
    }

    std::unique_ptr<VertexCompute> copy() override {
        return std::make_unique<MatchTermsVertexCompute>(resDfs, queryTerms);
    }

private:
    std::unordered_map<offset_t, uint64_t>& resDfs; // vertex offset -> document frequency
    std::vector<VCQueryTerm>& queryTerms;           // terms to match against
};

// Column/property names used when scanning the FTS internal tables.
static constexpr char SCORE_PROP_NAME[] = "score";
static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df";
static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf";
static constexpr char DOC_LEN_PROP_NAME[] = "len";
static constexpr char DOC_ID_PROP_NAME[] = "docID";

static std::unordered_map<offset_t, uint64_t> getDFs(main::ClientContext& context,
processor::ExecutionContext* executionContext, graph::Graph* graph,
catalog::TableCatalogEntry* termsEntry, std::vector<std::string>& queryTerms) {
auto storageManager = StorageManager::Get(context);
auto tableID = termsEntry->getTableID();
auto& termsNodeTable = storageManager->getTable(tableID)->cast<NodeTable>();
auto tx = transaction::Transaction::Get(context);
auto dfColumnID = termsEntry->getColumnID(DOC_FREQUENCY_PROP_NAME);
std::vector<LogicalType> vectorTypes;
vectorTypes.push_back(LogicalType::INTERNAL_ID());
vectorTypes.push_back(LogicalType::UINT64());
auto dataChunk = Table::constructDataChunk(MemoryManager::Get(context), std::move(vectorTypes));
dataChunk.state->getSelVectorUnsafe().setSelSize(1);
auto nodeIDVector = &dataChunk.getValueVectorMutable(0);
auto dfVector = &dataChunk.getValueVectorMutable(1);
auto termsVector = ValueVector(LogicalType::STRING(), MemoryManager::Get(context));
termsVector.state = dataChunk.state;
auto nodeTableScanState =
NodeTableScanState(nodeIDVector, std::vector{dfVector}, dataChunk.state);
nodeTableScanState.setToTable(transaction::Transaction::Get(context), &termsNodeTable,
{dfColumnID}, {});
const QueryFTSBindData& bindData, std::vector<std::string>& queryTerms) {
std::unordered_map<offset_t, uint64_t> dfs;
std::vector<VCQueryTerm> vcQueryTerms;
vcQueryTerms.reserve(queryTerms.size());
Expand All @@ -323,22 +266,17 @@ static std::unordered_map<offset_t, uint64_t> getDFs(main::ClientContext& contex
vcQueryTerms.emplace_back(std::in_place_type<std::string>, queryTerm);
}
}

if (hasWildcardQueryTerm) {
auto matchVc = MatchTermsVertexCompute{dfs, vcQueryTerms};
GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchVc,
termsEntry, std::vector<std::string>{"term", DOC_FREQUENCY_PROP_NAME});
bindData.patternMatchAlgo(dfs, vcQueryTerms, executionContext, graph, bindData);
} else {
TermsDFLookup termsDFLookup{bindData.getTermsEntry(context), context};
for (auto& queryTerm : queryTerms) {
termsVector.setValue(0, queryTerm);
offset_t offset = 0;
if (!termsNodeTable.lookupPK(tx, &termsVector, 0 /* vectorPos */, offset)) {
auto offsetDFPair = termsDFLookup.lookupTermDF(queryTerm);
if (offsetDFPair.first == INVALID_OFFSET) {
continue;
}
auto nodeID = nodeID_t{offset, tableID};
nodeIDVector->setValue(0, nodeID);
termsNodeTable.initScanState(tx, nodeTableScanState, tableID, offset);
[[maybe_unused]] auto res = termsNodeTable.lookup(tx, nodeTableScanState);
dfs.emplace(offset, dfVector->getValue<uint64_t>(0));
dfs.emplace(offsetDFPair);
}
}
return dfs;
Expand Down Expand Up @@ -381,7 +319,7 @@ static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) {
}
auto termsEntry = graphEntry->nodeInfos[0].entry;
auto queryTerms = qFTSBindData.getQueryTerms(clientContext);
auto dfs = getDFs(clientContext, input.context, graph, termsEntry, queryTerms);
auto dfs = getDFs(clientContext, input.context, graph, qFTSBindData, queryTerms);
// Do edge compute to extend terms -> docs and save the term frequency and document frequency
// for each term-doc pair. The reason why we store the term frequency and document frequency
// is that: we need the `len` property from the docs table which is only available during the
Expand Down Expand Up @@ -444,7 +382,6 @@ static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
auto inputTableName = getParamVal(*input, 0);
auto indexName = getParamVal(*input, 1);
auto query = input->getParam(2);

auto tableEntry = FTSIndexUtils::bindNodeTable(*context, inputTableName, indexName,
FTSIndexUtils::IndexOperation::QUERY);
auto catalog = catalog::Catalog::Get(*context);
Expand All @@ -459,7 +396,12 @@ static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
FTSUtils::getDocsTableName(tableEntry->getTableID(), indexName));
auto appearsInEntry = catalog->getTableCatalogEntry(transaction,
FTSUtils::getAppearsInTableName(tableEntry->getTableID(), indexName));
auto graphEntry = graph::NativeGraphEntry({termsEntry, docsEntry}, {appearsInEntry});
std::vector<catalog::TableCatalogEntry*> nodeEntries{termsEntry, docsEntry};
if (ftsIndexEntry->getAuxInfo().cast<FTSIndexAuxInfo>().config.exactTermMatch) {
nodeEntries.push_back(catalog->getTableCatalogEntry(transaction,
FTSUtils::getOrigTermsTableName(tableEntry->getTableID(), indexName)));
}
auto graphEntry = graph::NativeGraphEntry(std::move(nodeEntries), {appearsInEntry});

expression_vector columns;
auto& docsNode = nodeOutput->constCast<NodeExpression>();
Expand Down
Loading