From 4a9220f0f5b8c7b69b54a5396ee724f5109b0d68 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Tue, 9 Sep 2025 12:28:17 +0800 Subject: [PATCH 1/7] update --- extension/extension_config.cmake | 2 +- extension/fts/src/function/create_fts_index.cpp | 10 ++++++---- extension/fts/src/function/fts_config.cpp | 7 +++++-- extension/fts/src/include/function/fts_config.h | 8 ++++++++ 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/extension/extension_config.cmake b/extension/extension_config.cmake index 92154136db9..aa4ba7e7b44 100644 --- a/extension/extension_config.cmake +++ b/extension/extension_config.cmake @@ -1,6 +1,6 @@ set(EXTENSION_LIST azure delta duckdb fts httpfs iceberg json llm postgres sqlite unity_catalog vector neo4j algo) -#set(EXTENSION_STATIC_LINK_LIST fts) +set(EXTENSION_STATIC_LINK_LIST fts) string(JOIN ", " joined_extensions ${EXTENSION_STATIC_LINK_LIST}) message(STATUS "Static link extensions: ${joined_extensions}") foreach(extension IN LISTS EXTENSION_STATIC_LINK_LIST) diff --git a/extension/fts/src/function/create_fts_index.cpp b/extension/fts/src/function/create_fts_index.cpp index 2062137dcb0..26938ab9dfd 100644 --- a/extension/fts/src/function/create_fts_index.cpp +++ b/extension/fts/src/function/create_fts_index.cpp @@ -174,8 +174,9 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& // Create the terms_in_doc table which servers as a temporary table to store the // relationship between terms and docs. auto appearsInfoTableName = FTSUtils::getAppearsInfoTableName(tableID, indexName); - query += stringFormat("CREATE NODE TABLE `{}` (ID SERIAL, term string, docID INT64, primary " - "key(ID));", + query += stringFormat( + "CREATE NODE TABLE `{}` (ID SERIAL, term string, term_origin string, docID INT64, primary " + "key(ID));", appearsInfoTableName); auto tableName = ftsBindData->tableName; auto tableEntry = catalog::Catalog::Get(context)->getTableCatalogEntry( @@ -189,7 +190,7 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& "WITH t AS t1, id AS id1 " "WHERE t1 is NOT NULL AND SIZE(t1) > 0 AND " "NOT EXISTS {MATCH (s:`{}` {sw: t1})} " - "RETURN STEM(t1, '{}'), id1);", + "RETURN STEM(t1, '{}'), t1, id1);", appearsInfoTableName, tableName, FTSUtils::getTokenizeMacroName(tableID, indexName), propertyName, ftsBindData->createFTSConfig.stopWordsTableInfo.tableName, ftsBindData->createFTSConfig.stemmer); @@ -206,7 +207,8 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& auto termsTableName = FTSUtils::getTermsTableName(tableID, indexName); // Create the dic table which records all distinct terms and their document frequency. - query += stringFormat("CREATE NODE TABLE `{}` (term STRING, df UINT64, PRIMARY KEY(term));", + query += stringFormat( + "CREATE NODE TABLE `{}` (term STRING, term_origin STRING, df UINT64, PRIMARY KEY(term));", termsTableName); query += stringFormat("COPY `{}` FROM " "(MATCH (t:`{}`) " diff --git a/extension/fts/src/function/fts_config.cpp b/extension/fts/src/function/fts_config.cpp index 44ca6905518..6b2058b9a2e 100644 --- a/extension/fts/src/function/fts_config.cpp +++ b/extension/fts/src/function/fts_config.cpp @@ -149,14 +149,17 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ common::StringUtils::replaceAll(ignorePatternQuery, "?", ""); IgnorePattern::validate(ignorePattern); IgnorePattern::validate(ignorePatternQuery); - } else if (lowerCaseName == "tokenizer") { - value.validateType(common::LogicalTypeID::STRING); + } else if (Tokenizer::NAME == lowerCaseName) { + value.validateType(Tokenizer::TYPE); tokenizerInfo.tokenizer = common::StringUtils::getLower(value.getValue()); Tokenizer::validate(tokenizerInfo.tokenizer); } else if (lowerCaseName == "jieba_dict_dir") { value.validateType(common::LogicalTypeID::STRING); tokenizerInfo.jiebaDictDir = common::StringUtils::getLower(value.getValue()); + } else if (AdvancedWildCardPattern::NAME == lowerCaseName) { + value.validateType(AdvancedWildCardPattern::TYPE); + advancedWildCardPattern = value.getValue(); } else { throw common::BinderException{"Unrecognized optional parameter: " + name}; } diff --git a/extension/fts/src/include/function/fts_config.h b/extension/fts/src/include/function/fts_config.h index 2a1d8d2628b..d6e60b528a5 100644 --- a/extension/fts/src/include/function/fts_config.h +++ b/extension/fts/src/include/function/fts_config.h @@ -4,6 +4,7 @@ #include "common/types/types.h" #include "function/table/bind_input.h" +#include "function/table/optional_params.h" namespace kuzu { namespace fts_extension { @@ -16,6 +17,12 @@ struct Stemmer { static void validate(const std::string& stemmer); }; +struct AdvancedWildCardPattern { + static constexpr const char* NAME = "advanced_wild_card_pattern"; + static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::BOOL; + static constexpr bool DEFAULT_VALUE = false; +}; + enum class StopWordsSource : uint8_t { FILE = 0, TABLE = 1, @@ -78,6 +85,7 @@ struct CreateFTSConfig { std::string ignorePattern = IgnorePattern::DEFAULT_VALUE; std::string ignorePatternQuery = IgnorePattern::DEFAULT_VALUE_QUERY; TokenizerInfo tokenizerInfo; + bool advancedWildCardPattern = AdvancedWildCardPattern::DEFAULT_VALUE; CreateFTSConfig() = default; CreateFTSConfig(main::ClientContext& context, common::table_id_t tableID, From 5dda78fb08facfb9b58c1e01b19bbee8b1455567 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Wed, 10 Sep 2025 15:43:01 +0800 Subject: [PATCH 2/7] update --- extension/extension_config.cmake | 2 +- extension/fts/src/function/CMakeLists.txt | 4 +- .../fts/src/function/create_fts_index.cpp | 26 +++- extension/fts/src/function/fts_config.cpp | 11 +- .../fts/src/function/query_fts/CMakeLists.txt | 10 ++ .../{ => query_fts}/query_fts_bind_data.cpp | 44 ++++-- .../{ => query_fts}/query_fts_index.cpp | 93 +++--------- .../query_fts/query_fts_pattern_match.cpp | 135 ++++++++++++++++++ .../query_fts/query_fts_term_lookup.cpp | 44 ++++++ .../fts/src/include/function/fts_config.h | 11 +- .../{ => query_fts}/query_fts_bind_data.h | 21 ++- .../{ => query_fts}/query_fts_index.h | 0 .../query_fts/query_fts_pattern_match.h | 25 ++++ .../query_fts/query_fts_term_lookup.h | 32 +++++ extension/fts/src/include/utils/fts_utils.h | 10 ++ extension/fts/src/main/fts_extension.cpp | 2 +- .../test_files/advanced_pattern_match.test | 24 ++++ 17 files changed, 384 insertions(+), 110 deletions(-) create mode 100644 extension/fts/src/function/query_fts/CMakeLists.txt rename extension/fts/src/function/{ => query_fts}/query_fts_bind_data.cpp (51%) rename extension/fts/src/function/{ => query_fts}/query_fts_index.cpp (83%) create mode 100644 extension/fts/src/function/query_fts/query_fts_pattern_match.cpp create mode 100644 extension/fts/src/function/query_fts/query_fts_term_lookup.cpp rename extension/fts/src/include/function/{ => query_fts}/query_fts_bind_data.h (77%) rename extension/fts/src/include/function/{ => query_fts}/query_fts_index.h (100%) create mode 100644 extension/fts/src/include/function/query_fts/query_fts_pattern_match.h create mode 100644 extension/fts/src/include/function/query_fts/query_fts_term_lookup.h create mode 100644 extension/fts/test/test_files/advanced_pattern_match.test diff --git a/extension/extension_config.cmake b/extension/extension_config.cmake index aa4ba7e7b44..92154136db9 100644 --- a/extension/extension_config.cmake +++ b/extension/extension_config.cmake @@ -1,6 +1,6 @@ set(EXTENSION_LIST azure delta duckdb fts httpfs iceberg json llm postgres sqlite unity_catalog vector neo4j algo) -set(EXTENSION_STATIC_LINK_LIST fts) +#set(EXTENSION_STATIC_LINK_LIST fts) string(JOIN ", " joined_extensions ${EXTENSION_STATIC_LINK_LIST}) message(STATUS "Static link extensions: ${joined_extensions}") foreach(extension IN LISTS EXTENSION_STATIC_LINK_LIST) diff --git a/extension/fts/src/function/CMakeLists.txt b/extension/fts/src/function/CMakeLists.txt index 644389029dc..5e693626b51 100644 --- a/extension/fts/src/function/CMakeLists.txt +++ b/extension/fts/src/function/CMakeLists.txt @@ -1,11 +1,11 @@ +add_subdirectory(query_fts) + add_library(kuzu_fts_function OBJECT create_fts_index.cpp drop_fts_index.cpp fts_config.cpp fts_index_utils.cpp - query_fts_index.cpp - query_fts_bind_data.cpp stem.cpp tokenize.cpp) diff --git a/extension/fts/src/function/create_fts_index.cpp b/extension/fts/src/function/create_fts_index.cpp index 26938ab9dfd..ec0ef2fac66 100644 --- a/extension/fts/src/function/create_fts_index.cpp +++ b/extension/fts/src/function/create_fts_index.cpp @@ -145,6 +145,20 @@ static std::string formatStrInCypher(const std::string& input) { return result; } +static std::string createAdvancedPatternMatchTable(const CreateFTSBindData& bindData) { + std::string query; + auto appearsInfoTableName = + FTSUtils::getAppearsInfoTableName(bindData.tableID, bindData.indexName); + auto originalTermsTableName = + FTSUtils::getOrigTermsTableName(bindData.tableID, bindData.indexName); + query += common::stringFormat("CREATE NODE TABLE `{}`(term string, primary key(term));", + originalTermsTableName); + query += + common::stringFormat("COPY `{}` FROM (match (doc:`{}`) return distinct doc.term_origin);", + originalTermsTableName, appearsInfoTableName); + return query; +} + std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& bindData) { auto ftsBindData = bindData.constPtrCast(); auto tableID = ftsBindData->tableID; @@ -207,14 +221,18 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& auto termsTableName = FTSUtils::getTermsTableName(tableID, indexName); // Create the dic table which records all distinct terms and their document frequency. - query += stringFormat( - "CREATE NODE TABLE `{}` (term STRING, term_origin STRING, df UINT64, PRIMARY KEY(term));", + query += stringFormat("CREATE NODE TABLE `{}` (term STRING, df UINT64, PRIMARY KEY(term));", termsTableName); query += stringFormat("COPY `{}` FROM " "(MATCH (t:`{}`) " "RETURN t.term, CAST(count(distinct t.docID) AS UINT64));", termsTableName, appearsInfoTableName); + // If the advanced_pattern_match is enabled, we need to create two additional tables. + if (ftsBindData->createFTSConfig.advancedPatternMatch) { + query += createAdvancedPatternMatchTable(*ftsBindData); + } + auto appearsInTableName = FTSUtils::getAppearsInTableName(tableID, indexName); // Finally, create a terms table that records the documents in which the terms appear, along // with the frequency of each term. @@ -238,8 +256,10 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& properties += "]"; std::string params; params += stringFormat("stemmer := '{}', ", ftsBindData->createFTSConfig.stemmer); - params += stringFormat("stopWords := '{}'", + params += stringFormat("stopWords := '{}', ", ftsBindData->createFTSConfig.stopWordsTableInfo.stopWords); + params += stringFormat("advanced_pattern_match := {}", + ftsBindData->createFTSConfig.advancedPatternMatch ? "true" : "false"); query += stringFormat("CALL _CREATE_FTS_INDEX('{}', '{}', {}, {});", tableName, indexName, properties, params); query += stringFormat("RETURN 'Index {} has been created.' as result;", ftsBindData->indexName); diff --git a/extension/fts/src/function/fts_config.cpp b/extension/fts/src/function/fts_config.cpp index 6b2058b9a2e..e2d2dedaffc 100644 --- a/extension/fts/src/function/fts_config.cpp +++ b/extension/fts/src/function/fts_config.cpp @@ -157,9 +157,9 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ value.validateType(common::LogicalTypeID::STRING); tokenizerInfo.jiebaDictDir = common::StringUtils::getLower(value.getValue()); - } else if (AdvancedWildCardPattern::NAME == lowerCaseName) { - value.validateType(AdvancedWildCardPattern::TYPE); - advancedWildCardPattern = value.getValue(); + } else if (AdvancedPatternMatch::NAME == lowerCaseName) { + value.validateType(AdvancedPatternMatch::TYPE); + advancedPatternMatch = value.getValue(); } else { throw common::BinderException{"Unrecognized optional parameter: " + name}; } @@ -168,7 +168,8 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ FTSConfig CreateFTSConfig::getFTSConfig() const { return FTSConfig{stemmer, stopWordsTableInfo.tableName, stopWordsTableInfo.stopWords, - ignorePattern, ignorePatternQuery, tokenizerInfo.tokenizer, tokenizerInfo.jiebaDictDir}; + ignorePattern, ignorePatternQuery, tokenizerInfo.tokenizer, tokenizerInfo.jiebaDictDir, + advancedPatternMatch}; } void FTSConfig::serialize(common::Serializer& serializer) const { @@ -179,6 +180,7 @@ void FTSConfig::serialize(common::Serializer& serializer) const { serializer.serializeValue(ignorePatternQuery); serializer.serializeValue(tokenizer); serializer.serializeValue(jiebaDictDir); + serializer.serializeValue(advancedPatternMatch); } FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { @@ -190,6 +192,7 @@ FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { deserializer.deserializeValue(config.ignorePatternQuery); deserializer.deserializeValue(config.tokenizer); deserializer.deserializeValue(config.jiebaDictDir); + deserializer.deserializeValue(config.advancedPatternMatch); return config; } diff --git a/extension/fts/src/function/query_fts/CMakeLists.txt b/extension/fts/src/function/query_fts/CMakeLists.txt new file mode 100644 index 00000000000..93df47a9844 --- /dev/null +++ b/extension/fts/src/function/query_fts/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(kuzu_query_fts_function + OBJECT + query_fts_index.cpp + query_fts_pattern_match.cpp + query_fts_bind_data.cpp + query_fts_term_lookup.cpp) + +set(FTS_EXTENSION_OBJECT_FILES + ${FTS_EXTENSION_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/extension/fts/src/function/query_fts_bind_data.cpp b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp similarity index 51% rename from extension/fts/src/function/query_fts_bind_data.cpp rename to extension/fts/src/function/query_fts/query_fts_bind_data.cpp index 11f2752b916..405219393ba 100644 --- a/extension/fts/src/function/query_fts_bind_data.cpp +++ b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp @@ -1,4 +1,4 @@ -#include "function/query_fts_bind_data.h" +#include "function/query_fts/query_fts_bind_data.h" #include "binder/binder.h" #include "binder/expression/expression_util.h" @@ -42,19 +42,47 @@ void QueryFTSOptionalParams::evaluateParams(main::ClientContext* context) { topK.evaluateParam(context); } +QueryFTSBindData::QueryFTSBindData(binder::expression_vector columns, + graph::NativeGraphEntry graphEntry, std::shared_ptr docs, + std::shared_ptr query, const catalog::IndexCatalogEntry& entry, + std::unique_ptr optionalParams, common::idx_t numDocs, double avgDocLen) + : GDSBindData{std::move(columns), std::move(graphEntry), binder::expression_vector{docs}}, + query{std::move(query)}, entry{entry}, + outputTableID{output[0]->constCast().getTableIDs()[0]}, + numDocs{numDocs}, avgDocLen{avgDocLen}, + patternMatchAlgo{PatternMatchFactory::getPatternMatchAlgo( + entry.getAuxInfo().cast().config.advancedPatternMatch)} { + auto& nodeExpr = output[0]->constCast(); + KU_ASSERT(nodeExpr.getNumEntries() == 1); + outputTableID = nodeExpr.getEntry(0)->getTableID(); + this->optionalParams = std::move(optionalParams); +} + +catalog::TableCatalogEntry* QueryFTSBindData::getTermsEntry(main::ClientContext& context) const { + auto catalog = catalog::Catalog::Get(context); + return catalog->getTableCatalogEntry(transaction::Transaction::Get(context), + FTSUtils::getTermsTableName(entry.getTableID(), entry.getIndexName())); +} + +catalog::TableCatalogEntry* QueryFTSBindData::getOrigTermsEntry( + main::ClientContext& context) const { + auto catalog = catalog::Catalog::Get(context); + return catalog->getTableCatalogEntry(transaction::Transaction::Get(context), + FTSUtils::getOrigTermsTableName(entry.getTableID(), entry.getIndexName())); +} + std::vector QueryFTSBindData::getQueryTerms(main::ClientContext& context) const { auto queryInStr = ExpressionUtil::evaluateLiteral(&context, query, LogicalType::STRING()); auto config = entry.getAuxInfo().cast().config; FTSUtils::normalizeQuery(queryInStr, config.ignorePatternQuery); auto terms = FTSUtils::tokenizeString(queryInStr, config); - auto stopWordsTable = - StorageManager::Get(context) - ->getTable(catalog::Catalog::Get(context) - ->getTableCatalogEntry(transaction::Transaction::Get(context), - config.stopWordsTableName) - ->getTableID()) - ->ptrCast(); + auto stopWordsTable = StorageManager::Get(context) + ->getTable(catalog::Catalog::Get(context) + ->getTableCatalogEntry(transaction::Transaction::Get(context), + config.stopWordsTableName) + ->getTableID()) + ->ptrCast(); return FTSUtils::stemTerms(terms, entry.getAuxInfo().cast().config, MemoryManager::Get(context), stopWordsTable, transaction::Transaction::Get(context), optionalParams->constCast().conjunctive.getParamVal(), diff --git a/extension/fts/src/function/query_fts_index.cpp b/extension/fts/src/function/query_fts/query_fts_index.cpp similarity index 83% rename from extension/fts/src/function/query_fts_index.cpp rename to extension/fts/src/function/query_fts/query_fts_index.cpp index f663d93fac0..68cf614f6b5 100644 --- a/extension/fts/src/function/query_fts_index.cpp +++ b/extension/fts/src/function/query_fts/query_fts_index.cpp @@ -1,4 +1,4 @@ -#include "function/query_fts_index.h" +#include "function/query_fts/query_fts_index.h" #include @@ -11,16 +11,16 @@ #include "common/types/internal_id_util.h" #include "function/fts_index_utils.h" #include "function/gds/gds_utils.h" -#include "function/query_fts_bind_data.h" +#include "function/query_fts/query_fts_bind_data.h" +#include "function/query_fts/query_fts_pattern_match.h" +#include "function/query_fts/query_fts_term_lookup.h" #include "graph/on_disk_graph.h" #include "index/fts_index.h" #include "planner/operator/logical_hash_join.h" #include "planner/operator/logical_table_function_call.h" #include "planner/planner.h" #include "processor/execution_context.h" -#include "re2.h" #include "storage/storage_manager.h" -#include "storage/table/node_table.h" #include "utils/fts_utils.h" namespace kuzu { @@ -242,71 +242,14 @@ class QFTSVertexCompute final : public VertexCompute { std::unique_ptr writer; }; -using VCQueryTerm = std::variant>; -class MatchTermsVertexCompute final : public VertexCompute { -public: - explicit MatchTermsVertexCompute(std::unordered_map& resDfs, - std::vector& queryTerms) - : resDfs{resDfs}, queryTerms{queryTerms} {} - void vertexCompute(const graph::VertexScanState::Chunk& chunk) override { - auto terms = chunk.getProperties(0); - auto dfs = chunk.getProperties(1); - auto nodeIds = chunk.getNodeIDs(); - for (auto& queryTerm : queryTerms) { - // queryTerm.index() is 0 for string, 1 for unique_ptr - if (queryTerm.index() == 0) { - std::string& queryString = std::get<0>(queryTerm); - for (auto i = 0u; i < chunk.size(); ++i) { - if (queryString == terms[i].getAsString()) { - resDfs[nodeIds[i].offset] = dfs[i]; - } - } - } else { - RE2& regex = *std::get<1>(queryTerm); - for (auto i = 0u; i < chunk.size(); ++i) { - if (RE2::FullMatch(terms[i].getAsString(), regex)) { - resDfs[nodeIds[i].offset] = dfs[i]; - } - } - } - } - } - std::unique_ptr copy() override { - return std::make_unique(resDfs, queryTerms); - } - -private: - std::unordered_map& resDfs; - std::vector& queryTerms; -}; - static constexpr char SCORE_PROP_NAME[] = "score"; -static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df"; static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf"; static constexpr char DOC_LEN_PROP_NAME[] = "len"; static constexpr char DOC_ID_PROP_NAME[] = "docID"; static std::unordered_map getDFs(main::ClientContext& context, processor::ExecutionContext* executionContext, graph::Graph* graph, - catalog::TableCatalogEntry* termsEntry, std::vector& queryTerms) { - auto storageManager = StorageManager::Get(context); - auto tableID = termsEntry->getTableID(); - auto& termsNodeTable = storageManager->getTable(tableID)->cast(); - auto tx = transaction::Transaction::Get(context); - auto dfColumnID = termsEntry->getColumnID(DOC_FREQUENCY_PROP_NAME); - std::vector vectorTypes; - vectorTypes.push_back(LogicalType::INTERNAL_ID()); - vectorTypes.push_back(LogicalType::UINT64()); - auto dataChunk = Table::constructDataChunk(MemoryManager::Get(context), std::move(vectorTypes)); - dataChunk.state->getSelVectorUnsafe().setSelSize(1); - auto nodeIDVector = &dataChunk.getValueVectorMutable(0); - auto dfVector = &dataChunk.getValueVectorMutable(1); - auto termsVector = ValueVector(LogicalType::STRING(), MemoryManager::Get(context)); - termsVector.state = dataChunk.state; - auto nodeTableScanState = - NodeTableScanState(nodeIDVector, std::vector{dfVector}, dataChunk.state); - nodeTableScanState.setToTable(transaction::Transaction::Get(context), &termsNodeTable, - {dfColumnID}, {}); + const QueryFTSBindData& bindData, std::vector& queryTerms) { std::unordered_map dfs; std::vector vcQueryTerms; vcQueryTerms.reserve(queryTerms.size()); @@ -323,22 +266,17 @@ static std::unordered_map getDFs(main::ClientContext& contex vcQueryTerms.emplace_back(std::in_place_type, queryTerm); } } + if (hasWildcardQueryTerm) { - auto matchVc = MatchTermsVertexCompute{dfs, vcQueryTerms}; - GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchVc, - termsEntry, std::vector{"term", DOC_FREQUENCY_PROP_NAME}); + bindData.patternMatchAlgo(dfs, vcQueryTerms, executionContext, graph, bindData); } else { + TermsDFLookup termsDFLookup{bindData.getTermsEntry(context), context}; for (auto& queryTerm : queryTerms) { - termsVector.setValue(0, queryTerm); - offset_t offset = 0; - if (!termsNodeTable.lookupPK(tx, &termsVector, 0 /* vectorPos */, offset)) { + auto offsetDFPair = termsDFLookup.lookupTermDF(queryTerm); + if (offsetDFPair.first == INVALID_OFFSET) { continue; } - auto nodeID = nodeID_t{offset, tableID}; - nodeIDVector->setValue(0, nodeID); - termsNodeTable.initScanState(tx, nodeTableScanState, tableID, offset); - [[maybe_unused]] auto res = termsNodeTable.lookup(tx, nodeTableScanState); - dfs.emplace(offset, dfVector->getValue(0)); + dfs.emplace(offsetDFPair); } } return dfs; @@ -381,7 +319,7 @@ static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { } auto termsEntry = graphEntry->nodeInfos[0].entry; auto queryTerms = qFTSBindData.getQueryTerms(clientContext); - auto dfs = getDFs(clientContext, input.context, graph, termsEntry, queryTerms); + auto dfs = getDFs(clientContext, input.context, graph, qFTSBindData, queryTerms); // Do edge compute to extend terms -> docs and save the term frequency and document frequency // for each term-doc pair. The reason why we store the term frequency and document frequency // is that: we need the `len` property from the docs table which is only available during the @@ -459,7 +397,12 @@ static std::unique_ptr bindFunc(main::ClientContext* context, FTSUtils::getDocsTableName(tableEntry->getTableID(), indexName)); auto appearsInEntry = catalog->getTableCatalogEntry(transaction, FTSUtils::getAppearsInTableName(tableEntry->getTableID(), indexName)); - auto graphEntry = graph::NativeGraphEntry({termsEntry, docsEntry}, {appearsInEntry}); + std::vector nodeEntries{termsEntry, docsEntry}; + if (ftsIndexEntry->getAuxInfo().cast().config.advancedPatternMatch) { + nodeEntries.push_back(catalog->getTableCatalogEntry(transaction, + FTSUtils::getOrigTermsTableName(tableEntry->getTableID(), indexName))); + } + auto graphEntry = graph::NativeGraphEntry(std::move(nodeEntries), {appearsInEntry}); expression_vector columns; auto& docsNode = nodeOutput->constCast(); diff --git a/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp b/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp new file mode 100644 index 00000000000..295d53ea226 --- /dev/null +++ b/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp @@ -0,0 +1,135 @@ +#include "function/query_fts/query_fts_pattern_match.h" + +#include + +#include "catalog/fts_index_catalog_entry.h" +#include "function/gds/compute.h" +#include "function/gds/gds_utils.h" +#include "function/query_fts/query_fts_bind_data.h" +#include "function/query_fts/query_fts_term_lookup.h" +#include "libstemmer.h" +#include "storage/storage_manager.h" +#include "utils/fts_utils.h" + +using namespace kuzu::function; +using namespace kuzu::processor; + +namespace kuzu { +namespace fts_extension { + +class MatchTermVertexCompute : public function::VertexCompute { +public: + MatchTermVertexCompute(std::vector& queryTerms, + std::unordered_map& resDfs) + : queryTerms{queryTerms}, resDfs{resDfs} {} + + virtual void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) = 0; + + void vertexCompute(const graph::VertexScanState::Chunk& chunk) override { + auto terms = chunk.getProperties(0); + for (auto& queryTerm : queryTerms) { + // queryTerm.index() is 0 for string, 1 for unique_ptr + if (queryTerm.index() == 0) { + std::string& queryString = std::get<0>(queryTerm); + for (auto i = 0u; i < chunk.size(); ++i) { + if (queryString == terms[i].getAsString()) { + handleMatchedTerm(i, chunk); + } + } + } else { + RE2& regex = *std::get<1>(queryTerm); + for (auto i = 0u; i < chunk.size(); ++i) { + if (RE2::FullMatch(terms[i].getAsString(), regex)) { + handleMatchedTerm(i, chunk); + } + } + } + } + } + +protected: + std::vector& queryTerms; + std::unordered_map& resDfs; +}; + +class BasicMatchVertexCompute final : public MatchTermVertexCompute { +public: + explicit BasicMatchVertexCompute(std::unordered_map& resDfs, + std::vector& queryTerms) + : MatchTermVertexCompute{queryTerms, resDfs} {} + + void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) override { + auto dfs = chunk.getProperties(1); + auto nodeIds = chunk.getNodeIDs(); + resDfs[nodeIds[itr].offset] = dfs[itr]; + } + + std::unique_ptr copy() override { + return std::make_unique(resDfs, queryTerms); + } +}; + +class AdvancedMatchVertexCompute final : public MatchTermVertexCompute { +public: + AdvancedMatchVertexCompute(std::unordered_map& resDfs, + std::vector& queryTerms, const QueryFTSBindData& bindData, + main::ClientContext& context) + : MatchTermVertexCompute{queryTerms, resDfs}, + sbStemmer{sb_stemmer_new( + reinterpret_cast( + bindData.entry.getAuxInfo().cast().config.stemmer.c_str()), + "UTF_8")}, + bindData{bindData}, context{context}, + termsDFLookup{bindData.getTermsEntry(context), context} {} + + ~AdvancedMatchVertexCompute() override { sb_stemmer_delete(sbStemmer); } + + void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) override { + auto term = chunk.getProperties(0)[itr]; + auto stemData = sb_stemmer_stem(sbStemmer, + reinterpret_cast(term.getData()), term.len); + auto result = termsDFLookup.lookupTermDF(reinterpret_cast(stemData)); + KU_ASSERT(result.first != common::INVALID_OFFSET); + resDfs.insert(result); + } + + std::unique_ptr copy() override { + return std::make_unique(resDfs, queryTerms, bindData, context); + } + +private: + sb_stemmer* sbStemmer; + const QueryFTSBindData& bindData; + main::ClientContext& context; + TermsDFLookup termsDFLookup; +}; + +static void basicMatchAlgo(std::unordered_map& dfs, + std::vector& vcQueryTerms, ExecutionContext* executionContext, graph::Graph* graph, + const QueryFTSBindData& bindData) { + auto matchVc = BasicMatchVertexCompute{dfs, vcQueryTerms}; + GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchVc, + bindData.getTermsEntry(*executionContext->clientContext), + std::vector{"term", TermsDFLookup::DOC_FREQUENCY_PROP_NAME}); +} + +static void advancedMatchAlgo(std::unordered_map& dfs, + std::vector& vcQueryTerms, ExecutionContext* executionContext, graph::Graph* graph, + const QueryFTSBindData& bindData) { + auto matchOrigTermVc = + AdvancedMatchVertexCompute{dfs, vcQueryTerms, bindData, *executionContext->clientContext}; + GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchOrigTermVc, + bindData.getOrigTermsEntry(*executionContext->clientContext), + std::vector{"term"}); +} + +pattern_match_algo PatternMatchFactory::getPatternMatchAlgo(bool isAdvanced) { + if (isAdvanced) { + return advancedMatchAlgo; + } else { + return basicMatchAlgo; + } +} + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/function/query_fts/query_fts_term_lookup.cpp b/extension/fts/src/function/query_fts/query_fts_term_lookup.cpp new file mode 100644 index 00000000000..8c072e0f2e6 --- /dev/null +++ b/extension/fts/src/function/query_fts/query_fts_term_lookup.cpp @@ -0,0 +1,44 @@ +#include "function/query_fts/query_fts_term_lookup.h" + +#include "storage/storage_manager.h" +#include "transaction/transaction.h" + +namespace kuzu { +namespace fts_extension { + +using namespace kuzu::common; +using namespace kuzu::catalog; +using namespace kuzu::main; +using namespace kuzu::storage; +using namespace kuzu::transaction; + +TermsDFLookup::TermsDFLookup(TableCatalogEntry* termsEntry, ClientContext& context) + : dataChunkState{DataChunkState::getSingleValueDataChunkState()}, + termsVector{LogicalType::STRING(), MemoryManager::Get(context)}, + nodeIDVector{LogicalType::INTERNAL_ID()}, dfVector{LogicalType::UINT64()}, + termsTable{ + StorageManager::Get(context)->getTable(termsEntry->getTableID())->cast()}, + nodeTableScanState{&nodeIDVector, std::vector{&dfVector}, dataChunkState}, + dfColumnID{termsEntry->getColumnID(DOC_FREQUENCY_PROP_NAME)}, trx{Transaction::Get(context)} { + termsVector.state = dataChunkState; + nodeIDVector.state = dataChunkState; + dfVector.state = dataChunkState; + nodeTableScanState.setToTable(transaction::Transaction::Get(context), &termsTable, {dfColumnID}, + {}); +} + +std::pair TermsDFLookup::lookupTermDF(const std::string& term) { + termsVector.setValue(0, term); + offset_t offset = 0; + if (!termsTable.lookupPK(trx, &termsVector, 0 /* vectorPos */, offset)) { + return {INVALID_OFFSET, UINT64_MAX}; + } + auto nodeID = nodeID_t{offset, termsTable.getTableID()}; + nodeIDVector.setValue(0, nodeID); + termsTable.initScanState(trx, nodeTableScanState, termsTable.getTableID(), offset); + termsTable.lookup(trx, nodeTableScanState); + return {offset, dfVector.getValue(0)}; +} + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/include/function/fts_config.h b/extension/fts/src/include/function/fts_config.h index d6e60b528a5..4b461beb9cc 100644 --- a/extension/fts/src/include/function/fts_config.h +++ b/extension/fts/src/include/function/fts_config.h @@ -17,8 +17,8 @@ struct Stemmer { static void validate(const std::string& stemmer); }; -struct AdvancedWildCardPattern { - static constexpr const char* NAME = "advanced_wild_card_pattern"; +struct AdvancedPatternMatch { + static constexpr const char* NAME = "advanced_pattern_match"; static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::BOOL; static constexpr bool DEFAULT_VALUE = false; }; @@ -85,7 +85,7 @@ struct CreateFTSConfig { std::string ignorePattern = IgnorePattern::DEFAULT_VALUE; std::string ignorePatternQuery = IgnorePattern::DEFAULT_VALUE_QUERY; TokenizerInfo tokenizerInfo; - bool advancedWildCardPattern = AdvancedWildCardPattern::DEFAULT_VALUE; + bool advancedPatternMatch = AdvancedPatternMatch::DEFAULT_VALUE; CreateFTSConfig() = default; CreateFTSConfig(main::ClientContext& context, common::table_id_t tableID, @@ -104,15 +104,16 @@ struct FTSConfig { std::string ignorePatternQuery = ""; std::string tokenizer = ""; std::string jiebaDictDir = ""; + bool advancedPatternMatch = false; FTSConfig() = default; FTSConfig(std::string stemmer, std::string stopWordsTableName, std::string stopWordsSource, std::string ignorePattern, std::string ignorePatternQuery, std::string tokenizer, - std::string jiebaDictDir) + std::string jiebaDictDir, bool advancedPatternMatch) : stemmer{std::move(stemmer)}, stopWordsTableName{std::move(stopWordsTableName)}, stopWordsSource{std::move(stopWordsSource)}, ignorePattern{std::move(ignorePattern)}, ignorePatternQuery{std::move(ignorePatternQuery)}, tokenizer{std::move(tokenizer)}, - jiebaDictDir{std::move(jiebaDictDir)} {} + jiebaDictDir{std::move(jiebaDictDir)}, advancedPatternMatch{advancedPatternMatch} {} void serialize(common::Serializer& serializer) const; diff --git a/extension/fts/src/include/function/query_fts_bind_data.h b/extension/fts/src/include/function/query_fts/query_fts_bind_data.h similarity index 77% rename from extension/fts/src/include/function/query_fts_bind_data.h rename to extension/fts/src/include/function/query_fts/query_fts_bind_data.h index 716fbcdc780..f92585d4781 100644 --- a/extension/fts/src/include/function/query_fts_bind_data.h +++ b/extension/fts/src/include/function/query_fts/query_fts_bind_data.h @@ -4,6 +4,7 @@ #include "catalog/catalog_entry/index_catalog_entry.h" #include "function/fts_config.h" #include "function/gds/gds.h" +#include "function/query_fts/query_fts_pattern_match.h" namespace kuzu { namespace fts_extension { @@ -35,24 +36,22 @@ struct QueryFTSBindData final : public function::GDSBindData { common::table_id_t outputTableID; common::idx_t numDocs; double avgDocLen; + pattern_match_algo patternMatchAlgo; QueryFTSBindData(binder::expression_vector columns, graph::NativeGraphEntry graphEntry, std::shared_ptr docs, std::shared_ptr query, const catalog::IndexCatalogEntry& entry, std::unique_ptr optionalParams, common::idx_t numDocs, - double avgDocLen) - : GDSBindData{std::move(columns), std::move(graphEntry), binder::expression_vector{docs}}, - query{std::move(query)}, entry{entry}, - outputTableID{output[0]->constCast().getTableIDs()[0]}, - numDocs{numDocs}, avgDocLen{avgDocLen} { - auto& nodeExpr = output[0]->constCast(); - KU_ASSERT(nodeExpr.getNumEntries() == 1); - outputTableID = nodeExpr.getEntry(0)->getTableID(); - this->optionalParams = std::move(optionalParams); - } + double avgDocLen); + QueryFTSBindData(const QueryFTSBindData& other) : GDSBindData{other}, query{other.query}, entry{other.entry}, - outputTableID{other.outputTableID}, numDocs{other.numDocs}, avgDocLen{other.avgDocLen} {} + outputTableID{other.outputTableID}, numDocs{other.numDocs}, avgDocLen{other.avgDocLen}, + patternMatchAlgo{other.patternMatchAlgo} {} + + catalog::TableCatalogEntry* getTermsEntry(main::ClientContext& context) const; + + catalog::TableCatalogEntry* getOrigTermsEntry(main::ClientContext& context) const; std::vector getQueryTerms(main::ClientContext& context) const; diff --git a/extension/fts/src/include/function/query_fts_index.h b/extension/fts/src/include/function/query_fts/query_fts_index.h similarity index 100% rename from extension/fts/src/include/function/query_fts_index.h rename to extension/fts/src/include/function/query_fts/query_fts_index.h diff --git a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h new file mode 100644 index 00000000000..282e7994c5f --- /dev/null +++ b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h @@ -0,0 +1,25 @@ +#pragma once + +#include "graph/graph.h" +#include "processor/execution_context.h" +#include "re2.h" + +namespace kuzu { +namespace fts_extension { + +struct FTSConfig; +struct QueryFTSBindData; + +using VCQueryTerm = std::variant>; + +using pattern_match_algo = std::function& dfs, + std::vector& vcQueryTerms, processor::ExecutionContext* executionContext, + graph::Graph* graph, const QueryFTSBindData& bindData)>; + +class PatternMatchFactory { +public: + static pattern_match_algo getPatternMatchAlgo(bool isAdvanced); +}; + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/include/function/query_fts/query_fts_term_lookup.h b/extension/fts/src/include/function/query_fts/query_fts_term_lookup.h new file mode 100644 index 00000000000..f713ec89359 --- /dev/null +++ b/extension/fts/src/include/function/query_fts/query_fts_term_lookup.h @@ -0,0 +1,32 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "main/client_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/table/node_table.h" + +namespace kuzu { +namespace fts_extension { + +class TermsDFLookup { +public: + static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df"; + +public: + TermsDFLookup(catalog::TableCatalogEntry* termsEntry, main::ClientContext& context); + + std::pair lookupTermDF(const std::string& term); + +private: + std::shared_ptr dataChunkState; + common::ValueVector termsVector; + common::ValueVector nodeIDVector; + common::ValueVector dfVector; + storage::NodeTable& termsTable; + storage::NodeTableScanState nodeTableScanState; + common::column_id_t dfColumnID; + transaction::Transaction* trx; +}; + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/include/utils/fts_utils.h b/extension/fts/src/include/utils/fts_utils.h index 7649733745f..d4d5d1bef1f 100644 --- a/extension/fts/src/include/utils/fts_utils.h +++ b/extension/fts/src/include/utils/fts_utils.h @@ -48,6 +48,16 @@ struct FTSUtils { return common::stringFormat("{}_terms", getInternalTablePrefix(tableID, indexName)); } + static std::string getOrigTermsTableName(common::table_id_t tableID, + const std::string& indexName) { + return common::stringFormat("{}_orig_terms", getInternalTablePrefix(tableID, indexName)); + } + + static std::string getOrigTermsRelTableName(common::table_id_t tableID, + const std::string& indexName) { + return common::stringFormat("{}_orig_terms_rel", getInternalTablePrefix(tableID, indexName)); + } + static std::string getAppearsInTableName(common::table_id_t tableID, const std::string& indexName) { return common::stringFormat("{}_appears_in", getInternalTablePrefix(tableID, indexName)); diff --git a/extension/fts/src/main/fts_extension.cpp b/extension/fts/src/main/fts_extension.cpp index edcbf59d6a9..201bb8f7833 100644 --- a/extension/fts/src/main/fts_extension.cpp +++ b/extension/fts/src/main/fts_extension.cpp @@ -4,7 +4,7 @@ #include "catalog/fts_index_catalog_entry.h" #include "function/create_fts_index.h" #include "function/drop_fts_index.h" -#include "function/query_fts_index.h" +#include "function/query_fts/query_fts_index.h" #include "function/stem.h" #include "function/tokenize.h" #include "index/fts_index.h" diff --git a/extension/fts/test/test_files/advanced_pattern_match.test b/extension/fts/test/test_files/advanced_pattern_match.test new file mode 100644 index 00000000000..8d5f56b8f0c --- /dev/null +++ b/extension/fts/test/test_files/advanced_pattern_match.test @@ -0,0 +1,24 @@ +-DATASET CSV fts-basic + +-- + +-CASE WildcardBasic + +-LOAD_DYNAMIC_EXTENSION fts +-STATEMENT CREATE NODE TABLE news (content string, primary key(content)); +---- ok +-STATEMENT create (n:news {content: "alice is a canadian runner"}) +---- ok +-STATEMENT create (n:news {content: "carol is running in the playground"}) +---- ok +-STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_0', ['content'], advanced_pattern_match := FALSE); +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runn*') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030 +-STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_1', ['content'], advanced_pattern_match := TRUE); +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runn*') RETURN node.content, score order by score +---- 2 +alice is a canadian runner|0.301030 +carol is running in the playground|0.301030 From fede65a9f377312d50377141e2588292e06dceaa Mon Sep 17 00:00:00 2001 From: CI Bot Date: Wed, 10 Sep 2025 07:51:59 +0000 Subject: [PATCH 3/7] ci: auto code format --- .../src/function/query_fts/query_fts_bind_data.cpp | 13 +++++++------ extension/fts/src/include/utils/fts_utils.h | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/extension/fts/src/function/query_fts/query_fts_bind_data.cpp b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp index 405219393ba..3abde4fbcca 100644 --- a/extension/fts/src/function/query_fts/query_fts_bind_data.cpp +++ b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp @@ -77,12 +77,13 @@ std::vector QueryFTSBindData::getQueryTerms(main::ClientContext& co auto config = entry.getAuxInfo().cast().config; FTSUtils::normalizeQuery(queryInStr, config.ignorePatternQuery); auto terms = FTSUtils::tokenizeString(queryInStr, config); - auto stopWordsTable = StorageManager::Get(context) - ->getTable(catalog::Catalog::Get(context) - ->getTableCatalogEntry(transaction::Transaction::Get(context), - config.stopWordsTableName) - ->getTableID()) - ->ptrCast(); + auto stopWordsTable = + StorageManager::Get(context) + ->getTable(catalog::Catalog::Get(context) + ->getTableCatalogEntry(transaction::Transaction::Get(context), + config.stopWordsTableName) + ->getTableID()) + ->ptrCast(); return FTSUtils::stemTerms(terms, entry.getAuxInfo().cast().config, MemoryManager::Get(context), stopWordsTable, transaction::Transaction::Get(context), optionalParams->constCast().conjunctive.getParamVal(), diff --git a/extension/fts/src/include/utils/fts_utils.h b/extension/fts/src/include/utils/fts_utils.h index d4d5d1bef1f..709530b8dc0 100644 --- a/extension/fts/src/include/utils/fts_utils.h +++ b/extension/fts/src/include/utils/fts_utils.h @@ -55,7 +55,8 @@ struct FTSUtils { static std::string getOrigTermsRelTableName(common::table_id_t tableID, const std::string& indexName) { - return common::stringFormat("{}_orig_terms_rel", getInternalTablePrefix(tableID, indexName)); + return common::stringFormat("{}_orig_terms_rel", + getInternalTablePrefix(tableID, indexName)); } static std::string getAppearsInTableName(common::table_id_t tableID, From ec14a00ad776b123c075d4bfdb8bc2440091492e Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Wed, 10 Sep 2025 16:03:08 +0800 Subject: [PATCH 4/7] update --- .../src/include/function/query_fts/query_fts_pattern_match.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h index 282e7994c5f..546ce5d7350 100644 --- a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h +++ b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "graph/graph.h" #include "processor/execution_context.h" #include "re2.h" From 8d0a0bc1f3289e6234c091915c0e749c9bc27ea8 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Wed, 10 Sep 2025 16:07:05 +0800 Subject: [PATCH 5/7] update --- .../fts/src/include/function/query_fts/query_fts_pattern_match.h | 1 + 1 file changed, 1 insertion(+) diff --git a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h index 546ce5d7350..83c851255ef 100644 --- a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h +++ b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include "graph/graph.h" From 1ef39d59ca4ac2e245496b8a78880877db7c1265 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Thu, 11 Sep 2025 12:49:24 +0800 Subject: [PATCH 6/7] update --- .../fts/src/function/create_fts_index.cpp | 12 +++---- extension/fts/src/function/fts_config.cpp | 12 +++---- .../query_fts/query_fts_bind_data.cpp | 16 ++++----- .../function/query_fts/query_fts_index.cpp | 3 +- .../query_fts/query_fts_pattern_match.cpp | 35 ++++++++++--------- .../fts/src/include/function/fts_config.h | 12 +++---- .../query_fts/query_fts_pattern_match.h | 7 +++- .../test_files/advanced_pattern_match.test | 24 ------------- extension/fts/test/test_files/wildcard.test | 31 ++++++++++++++++ 9 files changed, 83 insertions(+), 69 deletions(-) delete mode 100644 extension/fts/test/test_files/advanced_pattern_match.test diff --git a/extension/fts/src/function/create_fts_index.cpp b/extension/fts/src/function/create_fts_index.cpp index ec0ef2fac66..3ee5d1a9de5 100644 --- a/extension/fts/src/function/create_fts_index.cpp +++ b/extension/fts/src/function/create_fts_index.cpp @@ -145,7 +145,7 @@ static std::string formatStrInCypher(const std::string& input) { return result; } -static std::string createAdvancedPatternMatchTable(const CreateFTSBindData& bindData) { +static std::string createTablesForExactTermMatch(const CreateFTSBindData& bindData) { std::string query; auto appearsInfoTableName = FTSUtils::getAppearsInfoTableName(bindData.tableID, bindData.indexName); @@ -228,9 +228,9 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& "RETURN t.term, CAST(count(distinct t.docID) AS UINT64));", termsTableName, appearsInfoTableName); - // If the advanced_pattern_match is enabled, we need to create two additional tables. - if (ftsBindData->createFTSConfig.advancedPatternMatch) { - query += createAdvancedPatternMatchTable(*ftsBindData); + // If the exact_term_match is enabled, we need to create an additional tables. + if (ftsBindData->createFTSConfig.exactTermMatch) { + query += createTablesForExactTermMatch(*ftsBindData); } auto appearsInTableName = FTSUtils::getAppearsInTableName(tableID, indexName); @@ -258,8 +258,8 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& params += stringFormat("stemmer := '{}', ", ftsBindData->createFTSConfig.stemmer); params += stringFormat("stopWords := '{}', ", ftsBindData->createFTSConfig.stopWordsTableInfo.stopWords); - params += stringFormat("advanced_pattern_match := {}", - ftsBindData->createFTSConfig.advancedPatternMatch ? "true" : "false"); + params += stringFormat("exact_term_match := {}", + ftsBindData->createFTSConfig.exactTermMatch ? "true" : "false"); query += stringFormat("CALL _CREATE_FTS_INDEX('{}', '{}', {}, {});", tableName, indexName, properties, params); query += stringFormat("RETURN 'Index {} has been created.' as result;", ftsBindData->indexName); diff --git a/extension/fts/src/function/fts_config.cpp b/extension/fts/src/function/fts_config.cpp index e2d2dedaffc..2c1226b4c6a 100644 --- a/extension/fts/src/function/fts_config.cpp +++ b/extension/fts/src/function/fts_config.cpp @@ -157,9 +157,9 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ value.validateType(common::LogicalTypeID::STRING); tokenizerInfo.jiebaDictDir = common::StringUtils::getLower(value.getValue()); - } else if (AdvancedPatternMatch::NAME == lowerCaseName) { - value.validateType(AdvancedPatternMatch::TYPE); - advancedPatternMatch = value.getValue(); + } else if (ExactTermMatch::NAME == lowerCaseName) { + value.validateType(ExactTermMatch::TYPE); + exactTermMatch = value.getValue(); } else { throw common::BinderException{"Unrecognized optional parameter: " + name}; } @@ -169,7 +169,7 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ FTSConfig CreateFTSConfig::getFTSConfig() const { return FTSConfig{stemmer, stopWordsTableInfo.tableName, stopWordsTableInfo.stopWords, ignorePattern, ignorePatternQuery, tokenizerInfo.tokenizer, tokenizerInfo.jiebaDictDir, - advancedPatternMatch}; + exactTermMatch}; } void FTSConfig::serialize(common::Serializer& serializer) const { @@ -180,7 +180,7 @@ void FTSConfig::serialize(common::Serializer& serializer) const { serializer.serializeValue(ignorePatternQuery); serializer.serializeValue(tokenizer); serializer.serializeValue(jiebaDictDir); - serializer.serializeValue(advancedPatternMatch); + serializer.serializeValue(exactTermMatch); } FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { @@ -192,7 +192,7 @@ FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { deserializer.deserializeValue(config.ignorePatternQuery); deserializer.deserializeValue(config.tokenizer); deserializer.deserializeValue(config.jiebaDictDir); - deserializer.deserializeValue(config.advancedPatternMatch); + deserializer.deserializeValue(config.exactTermMatch); return config; } diff --git a/extension/fts/src/function/query_fts/query_fts_bind_data.cpp b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp index 3abde4fbcca..02dc3acd417 100644 --- a/extension/fts/src/function/query_fts/query_fts_bind_data.cpp +++ b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp @@ -51,7 +51,8 @@ QueryFTSBindData::QueryFTSBindData(binder::expression_vector columns, outputTableID{output[0]->constCast().getTableIDs()[0]}, numDocs{numDocs}, avgDocLen{avgDocLen}, patternMatchAlgo{PatternMatchFactory::getPatternMatchAlgo( - entry.getAuxInfo().cast().config.advancedPatternMatch)} { + entry.getAuxInfo().cast().config.exactTermMatch ? TermMatchType::EXACT : + TermMatchType::STEM)} { auto& nodeExpr = output[0]->constCast(); KU_ASSERT(nodeExpr.getNumEntries() == 1); outputTableID = nodeExpr.getEntry(0)->getTableID(); @@ -77,13 +78,12 @@ std::vector QueryFTSBindData::getQueryTerms(main::ClientContext& co auto config = entry.getAuxInfo().cast().config; FTSUtils::normalizeQuery(queryInStr, config.ignorePatternQuery); auto terms = FTSUtils::tokenizeString(queryInStr, config); - auto stopWordsTable = - StorageManager::Get(context) - ->getTable(catalog::Catalog::Get(context) - ->getTableCatalogEntry(transaction::Transaction::Get(context), - config.stopWordsTableName) - ->getTableID()) - ->ptrCast(); + auto stopWordsTable = StorageManager::Get(context) + ->getTable(catalog::Catalog::Get(context) + ->getTableCatalogEntry(transaction::Transaction::Get(context), + config.stopWordsTableName) + ->getTableID()) + ->ptrCast(); return FTSUtils::stemTerms(terms, entry.getAuxInfo().cast().config, MemoryManager::Get(context), stopWordsTable, transaction::Transaction::Get(context), optionalParams->constCast().conjunctive.getParamVal(), diff --git a/extension/fts/src/function/query_fts/query_fts_index.cpp b/extension/fts/src/function/query_fts/query_fts_index.cpp index 68cf614f6b5..bf86407f2f9 100644 --- a/extension/fts/src/function/query_fts/query_fts_index.cpp +++ b/extension/fts/src/function/query_fts/query_fts_index.cpp @@ -382,7 +382,6 @@ static std::unique_ptr bindFunc(main::ClientContext* context, auto inputTableName = getParamVal(*input, 0); auto indexName = getParamVal(*input, 1); auto query = input->getParam(2); - auto tableEntry = FTSIndexUtils::bindNodeTable(*context, inputTableName, indexName, FTSIndexUtils::IndexOperation::QUERY); auto catalog = catalog::Catalog::Get(*context); @@ -398,7 +397,7 @@ static std::unique_ptr bindFunc(main::ClientContext* context, auto appearsInEntry = catalog->getTableCatalogEntry(transaction, FTSUtils::getAppearsInTableName(tableEntry->getTableID(), indexName)); std::vector nodeEntries{termsEntry, docsEntry}; - if (ftsIndexEntry->getAuxInfo().cast().config.advancedPatternMatch) { + if (ftsIndexEntry->getAuxInfo().cast().config.exactTermMatch) { nodeEntries.push_back(catalog->getTableCatalogEntry(transaction, FTSUtils::getOrigTermsTableName(tableEntry->getTableID(), indexName))); } diff --git a/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp b/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp index 295d53ea226..6ba01068086 100644 --- a/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp +++ b/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp @@ -52,9 +52,9 @@ class MatchTermVertexCompute : public function::VertexCompute { std::unordered_map& resDfs; }; -class BasicMatchVertexCompute final : public MatchTermVertexCompute { +class StemTermMatchVertexCompute final : public MatchTermVertexCompute { public: - explicit BasicMatchVertexCompute(std::unordered_map& resDfs, + explicit StemTermMatchVertexCompute(std::unordered_map& resDfs, std::vector& queryTerms) : MatchTermVertexCompute{queryTerms, resDfs} {} @@ -65,13 +65,13 @@ class BasicMatchVertexCompute final : public MatchTermVertexCompute { } std::unique_ptr copy() override { - return std::make_unique(resDfs, queryTerms); + return std::make_unique(resDfs, queryTerms); } }; -class AdvancedMatchVertexCompute final : public MatchTermVertexCompute { +class ExactTermMatchVertexCompute final : public MatchTermVertexCompute { public: - AdvancedMatchVertexCompute(std::unordered_map& resDfs, + ExactTermMatchVertexCompute(std::unordered_map& resDfs, std::vector& queryTerms, const QueryFTSBindData& bindData, main::ClientContext& context) : MatchTermVertexCompute{queryTerms, resDfs}, @@ -82,7 +82,7 @@ class AdvancedMatchVertexCompute final : public MatchTermVertexCompute { bindData{bindData}, context{context}, termsDFLookup{bindData.getTermsEntry(context), context} {} - ~AdvancedMatchVertexCompute() override { sb_stemmer_delete(sbStemmer); } + ~ExactTermMatchVertexCompute() override { sb_stemmer_delete(sbStemmer); } void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) override { auto term = chunk.getProperties(0)[itr]; @@ -94,7 +94,7 @@ class AdvancedMatchVertexCompute final : public MatchTermVertexCompute { } std::unique_ptr copy() override { - return std::make_unique(resDfs, queryTerms, bindData, context); + return std::make_unique(resDfs, queryTerms, bindData, context); } private: @@ -104,30 +104,33 @@ class AdvancedMatchVertexCompute final : public MatchTermVertexCompute { TermsDFLookup termsDFLookup; }; -static void basicMatchAlgo(std::unordered_map& dfs, +static void stemTermMatch(std::unordered_map& dfs, std::vector& vcQueryTerms, ExecutionContext* executionContext, graph::Graph* graph, const QueryFTSBindData& bindData) { - auto matchVc = BasicMatchVertexCompute{dfs, vcQueryTerms}; + auto matchVc = StemTermMatchVertexCompute{dfs, vcQueryTerms}; GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchVc, bindData.getTermsEntry(*executionContext->clientContext), std::vector{"term", TermsDFLookup::DOC_FREQUENCY_PROP_NAME}); } -static void advancedMatchAlgo(std::unordered_map& dfs, +static void exactTermMatch(std::unordered_map& dfs, std::vector& vcQueryTerms, ExecutionContext* executionContext, graph::Graph* graph, const QueryFTSBindData& bindData) { auto matchOrigTermVc = - AdvancedMatchVertexCompute{dfs, vcQueryTerms, bindData, *executionContext->clientContext}; + ExactTermMatchVertexCompute{dfs, vcQueryTerms, bindData, *executionContext->clientContext}; GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchOrigTermVc, bindData.getOrigTermsEntry(*executionContext->clientContext), std::vector{"term"}); } -pattern_match_algo PatternMatchFactory::getPatternMatchAlgo(bool isAdvanced) { - if (isAdvanced) { - return advancedMatchAlgo; - } else { - return basicMatchAlgo; +pattern_match_algo PatternMatchFactory::getPatternMatchAlgo(TermMatchType termMatchType) { + switch (termMatchType) { + case TermMatchType::EXACT: + return exactTermMatch; + case TermMatchType::STEM: + return stemTermMatch; + default: + KU_UNREACHABLE; } } diff --git a/extension/fts/src/include/function/fts_config.h b/extension/fts/src/include/function/fts_config.h index 4b461beb9cc..1fcd2476510 100644 --- a/extension/fts/src/include/function/fts_config.h +++ b/extension/fts/src/include/function/fts_config.h @@ -17,8 +17,8 @@ struct Stemmer { static void validate(const std::string& stemmer); }; -struct AdvancedPatternMatch { - static constexpr const char* NAME = "advanced_pattern_match"; +struct ExactTermMatch { + static constexpr const char* NAME = "exact_term_match"; static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::BOOL; static constexpr bool DEFAULT_VALUE = false; }; @@ -85,7 +85,7 @@ struct CreateFTSConfig { std::string ignorePattern = IgnorePattern::DEFAULT_VALUE; std::string ignorePatternQuery = IgnorePattern::DEFAULT_VALUE_QUERY; TokenizerInfo tokenizerInfo; - bool advancedPatternMatch = AdvancedPatternMatch::DEFAULT_VALUE; + bool exactTermMatch = ExactTermMatch::DEFAULT_VALUE; CreateFTSConfig() = default; CreateFTSConfig(main::ClientContext& context, common::table_id_t tableID, @@ -104,16 +104,16 @@ struct FTSConfig { std::string ignorePatternQuery = ""; std::string tokenizer = ""; std::string jiebaDictDir = ""; - bool advancedPatternMatch = false; + bool exactTermMatch = false; FTSConfig() = default; FTSConfig(std::string stemmer, std::string stopWordsTableName, std::string stopWordsSource, std::string ignorePattern, std::string ignorePatternQuery, std::string tokenizer, - std::string jiebaDictDir, bool advancedPatternMatch) + std::string jiebaDictDir, bool exactTermMatch) : stemmer{std::move(stemmer)}, stopWordsTableName{std::move(stopWordsTableName)}, stopWordsSource{std::move(stopWordsSource)}, ignorePattern{std::move(ignorePattern)}, ignorePatternQuery{std::move(ignorePatternQuery)}, tokenizer{std::move(tokenizer)}, - jiebaDictDir{std::move(jiebaDictDir)}, advancedPatternMatch{advancedPatternMatch} {} + jiebaDictDir{std::move(jiebaDictDir)}, exactTermMatch{exactTermMatch} {} void serialize(common::Serializer& serializer) const; diff --git a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h index 83c851255ef..2a2c1040ff0 100644 --- a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h +++ b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h @@ -19,9 +19,14 @@ using pattern_match_algo = std::function& vcQueryTerms, processor::ExecutionContext* executionContext, graph::Graph* graph, const QueryFTSBindData& bindData)>; +enum class TermMatchType : uint8_t { + STEM = 0, + EXACT = 1, +}; + class PatternMatchFactory { public: - static pattern_match_algo getPatternMatchAlgo(bool isAdvanced); + static pattern_match_algo getPatternMatchAlgo(TermMatchType termMatchType); }; } // namespace fts_extension diff --git a/extension/fts/test/test_files/advanced_pattern_match.test b/extension/fts/test/test_files/advanced_pattern_match.test deleted file mode 100644 index 8d5f56b8f0c..00000000000 --- a/extension/fts/test/test_files/advanced_pattern_match.test +++ /dev/null @@ -1,24 +0,0 @@ --DATASET CSV fts-basic - --- - --CASE WildcardBasic - --LOAD_DYNAMIC_EXTENSION fts --STATEMENT CREATE NODE TABLE news (content string, primary key(content)); ----- ok --STATEMENT create (n:news {content: "alice is a canadian runner"}) ----- ok --STATEMENT create (n:news {content: "carol is running in the playground"}) ----- ok --STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_0', ['content'], advanced_pattern_match := FALSE); ----- ok --STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runn*') RETURN node.content, score order by score ----- 1 -alice is a canadian runner|0.301030 --STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_1', ['content'], advanced_pattern_match := TRUE); ----- ok --STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runn*') RETURN node.content, score order by score ----- 2 -alice is a canadian runner|0.301030 -carol is running in the playground|0.301030 diff --git a/extension/fts/test/test_files/wildcard.test b/extension/fts/test/test_files/wildcard.test index b14ab52d8b5..72743c7cd90 100644 --- a/extension/fts/test/test_files/wildcard.test +++ b/extension/fts/test/test_files/wildcard.test @@ -51,3 +51,34 @@ Abcdefg|This book is a test?ax*alphabetical? ---- 2 Echoes of the Past|A deep dive into the history of ancient civilizations. Computers|The hiory*?story*a?b?c of computing + +-CASE exact_term_match +-LOAD_DYNAMIC_EXTENSION fts +-STATEMENT CREATE NODE TABLE news (content string, primary key(content)); +---- ok +-STATEMENT create (n:news {content: "alice is a canadian runner"}) +---- ok +-STATEMENT create (n:news {content: "carol is running in the playground"}) +---- ok +-STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_0', ['content'], exact_term_match := FALSE); +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runn*') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runn?ng') RETURN node.content, score order by score +---- 0 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runne?') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030 +-STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_1', ['content'], exact_term_match := TRUE); +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runn*') RETURN node.content, score order by score +---- 2 +alice is a canadian runner|0.301030 +carol is running in the playground|0.301030 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runn?ng') RETURN node.content, score order by score +---- 1 +carol is running in the playground|0.301030 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runne?') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030 From bc4862d1b7e4b435ac4d32a6ca0aa4f124c0dec8 Mon Sep 17 00:00:00 2001 From: CI Bot Date: Thu, 11 Sep 2025 04:51:56 +0000 Subject: [PATCH 7/7] ci: auto code format --- .../src/function/query_fts/query_fts_bind_data.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/extension/fts/src/function/query_fts/query_fts_bind_data.cpp b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp index 02dc3acd417..4b58cefe73e 100644 --- a/extension/fts/src/function/query_fts/query_fts_bind_data.cpp +++ b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp @@ -78,12 +78,13 @@ std::vector QueryFTSBindData::getQueryTerms(main::ClientContext& co auto config = entry.getAuxInfo().cast().config; FTSUtils::normalizeQuery(queryInStr, config.ignorePatternQuery); auto terms = FTSUtils::tokenizeString(queryInStr, config); - auto stopWordsTable = StorageManager::Get(context) - ->getTable(catalog::Catalog::Get(context) - ->getTableCatalogEntry(transaction::Transaction::Get(context), - config.stopWordsTableName) - ->getTableID()) - ->ptrCast(); + auto stopWordsTable = + StorageManager::Get(context) + ->getTable(catalog::Catalog::Get(context) + ->getTableCatalogEntry(transaction::Transaction::Get(context), + config.stopWordsTableName) + ->getTableID()) + ->ptrCast(); return FTSUtils::stemTerms(terms, entry.getAuxInfo().cast().config, MemoryManager::Get(context), stopWordsTable, transaction::Transaction::Get(context), optionalParams->constCast().conjunctive.getParamVal(),