diff --git a/README.md b/README.md index 9bf8f04bc1..dc7312b765 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ - [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). - [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge. - [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage). +- [X] [2024.11.04]🎯📢You can now [use FalkorDB for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-falkordb-for-storage). - [X] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`. - [X] [2024.10.20]🎯📢We've added a new feature to LightRAG: Graph Visualization. - [X] [2024.10.18]🎯📢We've added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author! @@ -264,7 +265,7 @@ A full list of LightRAG init parameters: | **workspace** | str | Workspace name for data isolation between different LightRAG Instances | | | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` | | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` | -| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` | +| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`FalkorDBStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` | | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` | | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | @@ -819,6 +820,49 @@ see test_neo4j.py for a working example.
+ + Using FalkorDB for Storage + +* FalkorDB is a high-performance graph database that's Redis module compatible and supports the Cypher query language +* Running FalkorDB in Docker is recommended for seamless local testing +* See: https://hub.docker.com/r/falkordb/falkordb + +```python +export FALKORDB_HOST="localhost" +export FALKORDB_PORT="6379" +export FALKORDB_PASSWORD="password" # optional +export FALKORDB_USERNAME="username" # optional +export FALKORDB_GRAPH_NAME="lightrag_graph" # optional, defaults to namespace + +# Setup logger for LightRAG +setup_logger("lightrag", level="INFO") + +# When you launch the project be sure to override the default KG: NetworkX +# by specifying graph_storage="FalkorDBStorage". + +# Note: Default settings use NetworkX +# Initialize LightRAG with FalkorDB implementation. +async def initialize_rag(): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model + graph_storage="FalkorDBStorage", #<-----------override KG default + ) + + # Initialize database connections + await rag.initialize_storages() + # Initialize pipeline status for document processing + await initialize_pipeline_status() + + return rag +``` + +see examples/falkordb_example.py for a working example. + +
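+
+A minimal usage sketch, assuming the `initialize_rag()` helper above plus the usual LightRAG imports (`QueryParam`, etc.) and an `OPENAI_API_KEY` in the environment:
+
+```python
+import asyncio
+
+async def main():
+    rag = await initialize_rag()
+    try:
+        await rag.ainsert("FalkorDB is a Redis-compatible graph database.")
+        print(await rag.aquery("What is FalkorDB?", param=QueryParam(mode="hybrid")))
+    finally:
+        await rag.finalize_storages()
+
+asyncio.run(main())
+```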
+ +
+ Using PostgreSQL Storage For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE). PostgreSQL version 16.6 or higher is supported. @@ -934,8 +978,9 @@ The `workspace` parameter ensures data isolation between different LightRAG inst - **For databases that store data in collections, it's done by adding a workspace prefix to the collection name:** `RedisKVStorage`, `RedisDocStatusStorage`, `MilvusVectorDBStorage`, `QdrantVectorDBStorage`, `MongoKVStorage`, `MongoDocStatusStorage`, `MongoVectorDBStorage`, `MongoGraphStorage`, `PGGraphStorage`. - **For relational databases, data isolation is achieved by adding a `workspace` field to the tables for logical data separation:** `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`. - **For the Neo4j graph database, logical data isolation is achieved through labels:** `Neo4JStorage` +- **For the FalkorDB graph database, logical data isolation is achieved through labels:** `FalkorDBStorage` -To maintain compatibility with legacy data, the default workspace for PostgreSQL non-graph storage is `default` and, for PostgreSQL AGE graph storage is null, for Neo4j graph storage is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables to override the common `WORKSPACE` environment variable configuration. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`. +To maintain compatibility with legacy data, the default workspace for PostgreSQL non-graph storage is `default` and, for PostgreSQL AGE graph storage is null, for Neo4j graph storage is `base`, and for FalkorDB graph storage is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables to override the common `WORKSPACE` environment variable configuration. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`, `FALKORDB_WORKSPACE`. ## Edit Entities and Relations diff --git a/env.example b/env.example index 5eef3913e0..c2e3e3a9b2 100644 --- a/env.example +++ b/env.example @@ -261,6 +261,7 @@ OLLAMA_EMBEDDING_NUM_CTX=8192 # LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage # LIGHTRAG_GRAPH_STORAGE=NetworkXStorage # LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage +# LIGHTRAG_GRAPH_STORAGE=FalkorDBStorage ### Redis Storage (Recommended for production deployment) # LIGHTRAG_KV_STORAGE=RedisKVStorage @@ -324,6 +325,12 @@ NEO4J_LIVENESS_CHECK_TIMEOUT=30 NEO4J_KEEP_ALIVE=true # NEO4J_WORKSPACE=forced_workspace_name +# FalkorDB Configuration +FALKORDB_URI=falkordb://xxxxxxxx.falkordb.cloud +FALKORDB_GRAPH_NAME=lightrag_graph +# FALKORDB_HOST=localhost +# FALKORDB_PORT=6379 + ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ #MONGO_URI=mongodb+srv://xxxx diff --git a/examples/falkordb_example.py b/examples/falkordb_example.py new file mode 100644 index 0000000000..8e3aeb6aed --- /dev/null +++ b/examples/falkordb_example.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +""" +Example of using LightRAG with FalkorDB - Updated Version +========================================================= +Fixed imports and modern LightRAG syntax. + +Prerequisites: +1. 
FalkorDB running: docker run -p 6379:6379 falkordb/falkordb:latest +2. OpenAI API key in .env file +3. Required packages: pip install lightrag falkordb openai python-dotenv +""" + +import asyncio +import os +from dotenv import load_dotenv +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed +from lightrag.kg.shared_storage import initialize_pipeline_status + +# Load environment variables +load_dotenv() + + +async def main(): + """Example usage of LightRAG with FalkorDB""" + + # Set up environment for FalkorDB + os.environ.setdefault("FALKORDB_HOST", "localhost") + os.environ.setdefault("FALKORDB_PORT", "6379") + os.environ.setdefault("FALKORDB_GRAPH_NAME", "lightrag_example") + os.environ.setdefault("FALKORDB_WORKSPACE", "example_workspace") + + # Initialize LightRAG with FalkorDB + rag = LightRAG( + working_dir="./falkordb_example", + llm_model_func=gpt_4o_mini_complete, # Updated function name + embedding_func=openai_embed, # Updated function name + graph_storage="FalkorDBStorage", # Specify FalkorDB backend + ) + + # Initialize storage connections + await rag.initialize_storages() + await initialize_pipeline_status() + + # Example text to process + sample_text = """ + FalkorDB is a high-performance graph database built on Redis. + It supports OpenCypher queries and provides excellent performance for graph operations. + LightRAG can now use FalkorDB as its graph storage backend, enabling scalable + knowledge graph operations with Redis-based persistence. This integration + allows developers to leverage both the speed of Redis and the power of + graph databases for advanced AI applications. + """ + + print("Inserting text into LightRAG with FalkorDB backend...") + await rag.ainsert(sample_text) + + # Check what was created + storage = rag.chunk_entity_relation_graph + nodes = await storage.get_all_nodes() + edges = await storage.get_all_edges() + print(f"Knowledge graph created: {len(nodes)} nodes, {len(edges)} edges") + + print("\nQuerying the knowledge graph...") + + # Test different query modes + questions = [ + "What is FalkorDB and how does it relate to LightRAG?", + "What are the benefits of using Redis with graph databases?", + "How does FalkorDB support OpenCypher queries?", + ] + + for i, question in enumerate(questions, 1): + print(f"\n--- Question {i} ---") + print(f"Q: {question}") + + try: + response = await rag.aquery( + question, param=QueryParam(mode="hybrid", top_k=3) + ) + print(f"A: {response}") + except Exception as e: + print(f"Error querying: {e}") + + # Show some graph statistics + print("\n--- Graph Statistics ---") + try: + all_labels = await storage.get_all_labels() + print(f"Unique entities: {len(all_labels)}") + + if nodes: + print("Sample entities:") + for i, node in enumerate(nodes[:3]): + entity_id = node.get("entity_id", "Unknown") + entity_type = node.get("entity_type", "Unknown") + print(f" {i+1}. {entity_id} ({entity_type})") + + if edges: + print("Sample relationships:") + for i, edge in enumerate(edges[:2]): + source = edge.get("source", "Unknown") + target = edge.get("target", "Unknown") + print(f" {i+1}. 
{source} → {target}") + + except Exception as e: + print(f"Error getting statistics: {e}") + + +if __name__ == "__main__": + print("LightRAG with FalkorDB Example") + print("==============================") + print("Note: This requires FalkorDB running on localhost:6379") + print( + "You can start FalkorDB with: docker run -p 6379:6379 falkordb/falkordb:latest" + ) + print() + + # Check OpenAI API key + if not os.getenv("OPENAI_API_KEY"): + print("❌ Please set your OpenAI API key in .env file!") + print(" Create a .env file with: OPENAI_API_KEY=your-actual-api-key") + exit(1) + + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\n👋 Example interrupted. Goodbye!") + except Exception as e: + print(f"\n💥 Unexpected error: {e}") + print("🔧 Make sure FalkorDB is running and your .env file is configured") diff --git a/examples/graph_visual_with_falkordb.py b/examples/graph_visual_with_falkordb.py new file mode 100644 index 0000000000..6bce2a6462 --- /dev/null +++ b/examples/graph_visual_with_falkordb.py @@ -0,0 +1,279 @@ +import os +import xml.etree.ElementTree as ET +import falkordb + +# Constants +WORKING_DIR = "./dickens" +BATCH_SIZE_NODES = 500 +BATCH_SIZE_EDGES = 100 + +# FalkorDB connection credentials +FALKORDB_HOST = "localhost" +FALKORDB_PORT = 6379 +FALKORDB_GRAPH_NAME = "dickens_graph" + + +def xml_to_json(xml_file): + try: + tree = ET.parse(xml_file) + root = tree.getroot() + + # Print the root element's tag and attributes to confirm the file has been correctly loaded + print(f"Root element: {root.tag}") + print(f"Root attributes: {root.attrib}") + + data = {"nodes": [], "edges": []} + + # Use namespace + namespace = {"": "http://graphml.graphdrawing.org/xmlns"} + + for node in root.findall(".//node", namespace): + node_data = { + "id": node.get("id").strip('"'), + "entity_type": node.find("./data[@key='d1']", namespace).text.strip('"') + if node.find("./data[@key='d1']", namespace) is not None + else "", + "description": node.find("./data[@key='d2']", namespace).text + if node.find("./data[@key='d2']", namespace) is not None + else "", + "source_id": node.find("./data[@key='d3']", namespace).text + if node.find("./data[@key='d3']", namespace) is not None + else "", + } + data["nodes"].append(node_data) + + for edge in root.findall(".//edge", namespace): + edge_data = { + "source": edge.get("source").strip('"'), + "target": edge.get("target").strip('"'), + "weight": float(edge.find("./data[@key='d5']", namespace).text) + if edge.find("./data[@key='d5']", namespace) is not None + else 1.0, + "description": edge.find("./data[@key='d6']", namespace).text + if edge.find("./data[@key='d6']", namespace) is not None + else "", + "keywords": edge.find("./data[@key='d7']", namespace).text + if edge.find("./data[@key='d7']", namespace) is not None + else "", + "source_id": edge.find("./data[@key='d8']", namespace).text + if edge.find("./data[@key='d8']", namespace) is not None + else "", + } + data["edges"].append(edge_data) + + return data + + except ET.ParseError as e: + print(f"Error parsing XML: {e}") + return None + except Exception as e: + print(f"Unexpected error: {e}") + return None + + +def insert_nodes_and_edges_to_falkordb(data): + """Insert graph data into FalkorDB""" + try: + # Connect to FalkorDB + db = falkordb.FalkorDB(host=FALKORDB_HOST, port=FALKORDB_PORT) + graph = db.select_graph(FALKORDB_GRAPH_NAME) + + print(f"Connected to FalkorDB at {FALKORDB_HOST}:{FALKORDB_PORT}") + print(f"Using graph: {FALKORDB_GRAPH_NAME}") + + nodes = data["nodes"] + edges = 
data["edges"] + + print(f"Total nodes to insert: {len(nodes)}") + print(f"Total edges to insert: {len(edges)}") + + # Insert nodes in batches + for i in range(0, len(nodes), BATCH_SIZE_NODES): + batch_nodes = nodes[i : i + BATCH_SIZE_NODES] + + # Build UNWIND query for batch insert + query = """ + UNWIND $nodes AS node + CREATE (n:Entity { + entity_id: node.id, + entity_type: node.entity_type, + description: node.description, + source_id: node.source_id + }) + """ + + graph.query(query, {"nodes": batch_nodes}) + print(f"Inserted nodes {i+1} to {min(i + BATCH_SIZE_NODES, len(nodes))}") + + # Insert edges in batches + for i in range(0, len(edges), BATCH_SIZE_EDGES): + batch_edges = edges[i : i + BATCH_SIZE_EDGES] + + # Build UNWIND query for batch insert + query = """ + UNWIND $edges AS edge + MATCH (source:Entity {entity_id: edge.source}) + MATCH (target:Entity {entity_id: edge.target}) + CREATE (source)-[r:DIRECTED { + weight: edge.weight, + description: edge.description, + keywords: edge.keywords, + source_id: edge.source_id + }]-(target) + """ + + graph.query(query, {"edges": batch_edges}) + print(f"Inserted edges {i+1} to {min(i + BATCH_SIZE_EDGES, len(edges))}") + + print("Data insertion completed successfully!") + + # Print some statistics + node_count_result = graph.query("MATCH (n:Entity) RETURN count(n) AS count") + edge_count_result = graph.query( + "MATCH ()-[r:DIRECTED]-() RETURN count(r) AS count" + ) + + node_count = ( + node_count_result.result_set[0][0] if node_count_result.result_set else 0 + ) + edge_count = ( + edge_count_result.result_set[0][0] if edge_count_result.result_set else 0 + ) + + print("Final statistics:") + print(f"- Nodes in database: {node_count}") + print(f"- Edges in database: {edge_count}") + + except Exception as e: + print(f"Error inserting data into FalkorDB: {e}") + + +def query_graph_data(): + """Query and display some sample data from FalkorDB""" + try: + # Connect to FalkorDB + db = falkordb.FalkorDB(host=FALKORDB_HOST, port=FALKORDB_PORT) + graph = db.select_graph(FALKORDB_GRAPH_NAME) + + print("\n=== Sample Graph Data ===") + + # Get some sample nodes + query = ( + "MATCH (n:Entity) RETURN n.entity_id, n.entity_type, n.description LIMIT 5" + ) + result = graph.query(query) + + print("\nSample Nodes:") + if result.result_set: + for record in result.result_set: + print(f"- {record[0]} ({record[1]}): {record[2][:100]}...") + + # Get some sample edges + query = """ + MATCH (a:Entity)-[r:DIRECTED]-(b:Entity) + RETURN a.entity_id, b.entity_id, r.weight, r.description + LIMIT 5 + """ + result = graph.query(query) + + print("\nSample Edges:") + if result.result_set: + for record in result.result_set: + print( + f"- {record[0]} -> {record[1]} (weight: {record[2]}): {record[3][:100]}..." 
+ ) + + # Get node degree statistics + query = """ + MATCH (n:Entity) + OPTIONAL MATCH (n)-[r]-() + WITH n, count(r) AS degree + RETURN min(degree) AS min_degree, max(degree) AS max_degree, avg(degree) AS avg_degree + """ + result = graph.query(query) + + print("\nNode Degree Statistics:") + if result.result_set: + record = result.result_set[0] + print(f"- Min degree: {record[0]}") + print(f"- Max degree: {record[1]}") + print(f"- Avg degree: {record[2]:.2f}") + + except Exception as e: + print(f"Error querying FalkorDB: {e}") + + +def clear_graph(): + """Clear all data from the FalkorDB graph""" + try: + db = falkordb.FalkorDB(host=FALKORDB_HOST, port=FALKORDB_PORT) + graph = db.select_graph(FALKORDB_GRAPH_NAME) + + # Delete all nodes and relationships + graph.query("MATCH (n) DETACH DELETE n") + print("Graph cleared successfully!") + + except Exception as e: + print(f"Error clearing graph: {e}") + + +def main(): + xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml") + + if not os.path.exists(xml_file): + print( + f"Error: File {xml_file} not found. Please ensure the GraphML file exists." + ) + print( + "This file is typically generated by LightRAG after processing documents." + ) + return + + print("FalkorDB Graph Visualization Example") + print("====================================") + print(f"Processing file: {xml_file}") + print(f"FalkorDB connection: {FALKORDB_HOST}:{FALKORDB_PORT}") + print(f"Graph name: {FALKORDB_GRAPH_NAME}") + print() + + # Parse XML to JSON + print("1. Parsing GraphML file...") + data = xml_to_json(xml_file) + if data is None: + print("Failed to parse XML file.") + return + + print(f" Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") + + # Ask user what to do + while True: + print("\nOptions:") + print("1. Clear existing graph data") + print("2. Insert data into FalkorDB") + print("3. Query sample data") + print("4. Exit") + + choice = input("\nSelect an option (1-4): ").strip() + + if choice == "1": + print("\n2. Clearing existing graph data...") + clear_graph() + + elif choice == "2": + print("\n2. Inserting data into FalkorDB...") + insert_nodes_and_edges_to_falkordb(data) + + elif choice == "3": + print("\n3. Querying sample data...") + query_graph_data() + + elif choice == "4": + print("Goodbye!") + break + + else: + print("Invalid choice. 
Please try again.") + + +if __name__ == "__main__": + main() diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 8d42441ac7..28f505f885 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -12,6 +12,7 @@ "implementations": [ "NetworkXStorage", "Neo4JStorage", + "FalkorDBStorage", "PGGraphStorage", "MongoGraphStorage", "MemgraphStorage", @@ -51,6 +52,7 @@ # Graph Storage Implementations "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], + "FalkorDBStorage": ["FALKORDB_HOST", "FALKORDB_PORT"], "MongoGraphStorage": [], "MemgraphStorage": ["MEMGRAPH_URI"], "AGEStorage": [ @@ -85,6 +87,7 @@ "NanoVectorDBStorage": ".kg.nano_vector_db_impl", "JsonDocStatusStorage": ".kg.json_doc_status_impl", "Neo4JStorage": ".kg.neo4j_impl", + "FalkorDBStorage": ".kg.falkordb_impl", "MilvusVectorDBStorage": ".kg.milvus_impl", "MongoKVStorage": ".kg.mongo_impl", "MongoDocStatusStorage": ".kg.mongo_impl", diff --git a/lightrag/kg/falkordb_impl.py b/lightrag/kg/falkordb_impl.py new file mode 100644 index 0000000000..c815ac6f94 --- /dev/null +++ b/lightrag/kg/falkordb_impl.py @@ -0,0 +1,1069 @@ +import os +import re +import asyncio +from dataclasses import dataclass +from typing import final +import configparser +from concurrent.futures import ThreadPoolExecutor + +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) + +import logging +from ..utils import logger +from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP +import pipmaster as pm + +if not pm.is_installed("falkordb"): + pm.install("falkordb") + +import falkordb +import redis.exceptions + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + + +# Set falkordb logger level to ERROR to suppress warning logs +logging.getLogger("falkordb").setLevel(logging.ERROR) + + +@final +@dataclass +class FalkorDBStorage(BaseGraphStorage): + def __init__(self, namespace, global_config, embedding_func, workspace=None): + # Check FALKORDB_WORKSPACE environment variable and override workspace if set + falkordb_workspace = os.environ.get("FALKORDB_WORKSPACE") + if falkordb_workspace and falkordb_workspace.strip(): + workspace = falkordb_workspace + + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + ) + self._db = None + self._graph = None + self._executor = ThreadPoolExecutor(max_workers=4) + + def _get_workspace_label(self) -> str: + """Get workspace label, return 'base' for compatibility when workspace is empty""" + workspace = getattr(self, "workspace", None) + return workspace if workspace else "base" + + async def initialize(self): + HOST = os.environ.get( + "FALKORDB_HOST", config.get("falkordb", "host", fallback="localhost") + ) + PORT = int( + os.environ.get( + "FALKORDB_PORT", config.get("falkordb", "port", fallback=6379) + ) + ) + PASSWORD = os.environ.get( + "FALKORDB_PASSWORD", config.get("falkordb", "password", fallback=None) + ) + USERNAME = os.environ.get( + "FALKORDB_USERNAME", config.get("falkordb", "username", fallback=None) + ) + GRAPH_NAME = os.environ.get( + 
"FALKORDB_GRAPH_NAME", + config.get( + "falkordb", + "graph_name", + fallback=re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace), + ), + ) + + try: + # Create FalkorDB connection + self._db = falkordb.FalkorDB( + host=HOST, + port=PORT, + password=PASSWORD, + username=USERNAME, + ) + + # Select the graph (creates if doesn't exist) + self._graph = self._db.select_graph(GRAPH_NAME) + + # Test connection with a simple query + await self._run_query("RETURN 1") + + # Create index for workspace nodes on entity_id if it doesn't exist + workspace_label = self._get_workspace_label() + try: + index_query = ( + f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)" + ) + await self._run_query(index_query) + logger.info( + f"Created index for {workspace_label} nodes on entity_id in FalkorDB" + ) + except Exception as e: + # Index may already exist, which is not an error + logger.debug(f"Index creation may have failed or already exists: {e}") + + logger.info(f"Connected to FalkorDB at {HOST}:{PORT}, graph: {GRAPH_NAME}") + + except Exception as e: + logger.error(f"Failed to connect to FalkorDB at {HOST}:{PORT}: {e}") + raise + + async def finalize(self): + """Close the FalkorDB connection and release all resources""" + if self._executor: + self._executor.shutdown(wait=True) + self._executor = None + if self._db: + # FalkorDB doesn't have an explicit close method for the client + self._db = None + self._graph = None + + async def __aexit__(self, exc_type, exc, tb): + """Ensure connection is closed when context manager exits""" + await self.finalize() + + async def _run_query(self, query: str, params: dict = None): + """Run a query asynchronously using thread pool""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, lambda: self._graph.query(query, params or {}) + ) + + async def index_done_callback(self) -> None: + # FalkorDB handles persistence automatically + pass + + async def has_node(self, node_id: str) -> bool: + """ + Check if a node with the given label exists in the database + + Args: + node_id: Label of the node to check + + Returns: + bool: True if node exists, False otherwise + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query + """ + workspace_label = self._get_workspace_label() + try: + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" + result = await self._run_query(query, {"entity_id": node_id.strip()}) + return result.result_set[0][0] if result.result_set else False + except Exception as e: + logger.error(f"Error checking node existence for {node_id}: {str(e)}") + raise + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if an edge exists between two nodes + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + bool: True if edge exists, False otherwise + + Raises: + ValueError: If either node_id is invalid + Exception: If there is an error executing the query + """ + workspace_label = self._get_workspace_label() + try: + query = ( + f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await self._run_query( + query, + { + "source_entity_id": source_node_id, + "target_entity_id": target_node_id, + }, + ) + return result.result_set[0][0] if result.result_set else False + except Exception as e: + logger.error( + f"Error 
checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + raise + + async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query + """ + workspace_label = self._get_workspace_label() + try: + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + result = await self._run_query(query, {"entity_id": node_id}) + + if result.result_set and len(result.result_set) > 0: + node = result.result_set[0][0] # Get the first node + # Convert FalkorDB node to dictionary + node_dict = {key: value for key, value in node.properties.items()} + return node_dict + return None + except Exception as e: + logger.error(f"Error getting node for {node_id}: {str(e)}") + raise + + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + """ + Retrieve multiple nodes in one query using UNWIND. + + Args: + node_ids: List of node entity IDs to fetch. + + Returns: + A dictionary mapping each node_id to its node data (or None if not found). + """ + workspace_label = self._get_workspace_label() + query = f""" + UNWIND $node_ids AS id + MATCH (n:`{workspace_label}` {{entity_id: id}}) + RETURN n.entity_id AS entity_id, n + """ + result = await self._run_query(query, {"node_ids": node_ids}) + nodes = {} + + if result.result_set and len(result.result_set) > 0: + for record in result.result_set: + entity_id = record[0] + node = record[1] + node_dict = {key: value for key, value in node.properties.items()} + nodes[entity_id] = node_dict + + return nodes + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query + """ + workspace_label = self._get_workspace_label() + try: + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await self._run_query(query, {"entity_id": node_id}) + + if result.result_set and len(result.result_set) > 0: + degree = result.result_set[0][0] + return degree + else: + logger.warning(f"No node found with label '{node_id}'") + return 0 + except Exception as e: + logger.error(f"Error getting node degree for {node_id}: {str(e)}") + raise + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + """ + Retrieve the degree for multiple nodes in a single query using UNWIND. + + Args: + node_ids: List of node labels (entity_id values) to look up. + + Returns: + A dictionary mapping each node_id to its degree (number of relationships). + If a node is not found, its degree will be set to 0. 
+ """ + workspace_label = self._get_workspace_label() + query = f""" + UNWIND $node_ids AS id + MATCH (n:`{workspace_label}` {{entity_id: id}}) + OPTIONAL MATCH (n)-[r]-() + RETURN n.entity_id AS entity_id, COUNT(r) AS degree + """ + result = await self._run_query(query, {"node_ids": node_ids}) + degrees = {} + + if result.result_set and len(result.result_set) > 0: + for record in result.result_set: + entity_id = record[0] + degrees[entity_id] = record[1] + + # For any node_id that did not return a record, set degree to 0. + for nid in node_ids: + if nid not in degrees: + logger.warning(f"No node found with label '{nid}'") + degrees[nid] = 0 + + return degrees + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + return degrees + + async def edge_degrees_batch( + self, edge_pairs: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: + """ + Calculate the combined degree for each edge (sum of the source and target node degrees) + in batch using the already implemented node_degrees_batch. + + Args: + edge_pairs: List of (src, tgt) tuples. + + Returns: + A dictionary mapping each (src, tgt) tuple to the sum of their degrees. + """ + # Collect unique node IDs from all edge pairs. + unique_node_ids = {src for src, _ in edge_pairs} + unique_node_ids.update({tgt for _, tgt in edge_pairs}) + + # Get degrees for all nodes in one go. + degrees = await self.node_degrees_batch(list(unique_node_ids)) + + # Sum up degrees for each edge pair. + edge_degrees = {} + for src, tgt in edge_pairs: + edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) + return edge_degrees + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + """Get edge properties between two nodes. 
+ + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + ValueError: If either node_id is invalid + Exception: If there is an error executing the query + """ + workspace_label = self._get_workspace_label() + try: + query = f""" + MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) + RETURN properties(r) as edge_properties + """ + result = await self._run_query( + query, + { + "source_entity_id": source_node_id, + "target_entity_id": target_node_id, + }, + ) + + if result.result_set and len(result.result_set) > 0: + edge_result = result.result_set[0][0] # Get properties dict + + # Ensure required keys exist with defaults + required_keys = { + "weight": 1.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {source_node_id} and {target_node_id} " + f"missing {key}, using default: {default_value}" + ) + + return edge_result + + # Return None when no edge found + return None + + except Exception as e: + logger.error( + f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + raise + + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: + """ + Retrieve edge properties for multiple (src, tgt) pairs in one query. + + Args: + pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] + + Returns: + A dictionary mapping (src, tgt) tuples to their edge properties. + """ + + workspace_label = self._get_workspace_label() + query = f""" + UNWIND $pairs AS pair + MATCH (start:`{workspace_label}` {{entity_id: pair.src}})-[r]-(end:`{workspace_label}` {{entity_id: pair.tgt}}) + RETURN pair.src AS src_id, pair.tgt AS tgt_id, properties(r) AS edge_properties + """ + result = await self._run_query(query, {"pairs": pairs}) + edges_dict = {} + + if result.result_set and len(result.result_set) > 0: + for record in result.result_set: + if record and len(record) >= 3: + src = record[0] + tgt = record[1] + edge_props = record[2] if record[2] else {} + + edge_result = {} + for key, default in { + "weight": 1.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + edge_result[key] = edge_props.get(key, default) + + edges_dict[(src, tgt)] = edge_result + + # Add default properties for pairs not found + for pair_dict in pairs: + src = pair_dict["src"] + tgt = pair_dict["tgt"] + if (src, tgt) not in edges_dict: + edges_dict[(src, tgt)] = { + "weight": 1.0, + "source_id": None, + "description": None, + "keywords": None, + } + + return edges_dict + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Retrieves all edges (relationships) for a particular node identified by its label. 
+ + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + ValueError: If source_node_id is invalid + Exception: If there is an error executing the query + """ + try: + workspace_label = self._get_workspace_label() + query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + result = await self._run_query(query, {"entity_id": source_node_id}) + + edges = [] + if result.result_set: + for record in result.result_set: + source_node = record[0] + connected_node = record[2] + + # Skip if either node is None + if not source_node or not connected_node: + continue + + source_label = source_node.properties.get("entity_id") + target_label = connected_node.properties.get("entity_id") + + if source_label and target_label: + edges.append((source_label, target_label)) + + return edges + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: + """ + Batch retrieve edges for multiple nodes in one query using UNWIND. + For each node, returns both outgoing and incoming edges to properly represent + the undirected graph nature. + + Args: + node_ids: List of node IDs (entity_id) for which to retrieve edges. + + Returns: + A dictionary mapping each node ID to its list of edge tuples (source, target). + For each node, the list includes both: + - Outgoing edges: (queried_node, connected_node) + - Incoming edges: (connected_node, queried_node) + """ + workspace_label = self._get_workspace_label() + query = f""" + UNWIND $node_ids AS id + MATCH (n:`{workspace_label}` {{entity_id: id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) + RETURN id AS queried_id, n.entity_id AS node_entity_id, + connected.entity_id AS connected_entity_id, + startNode(r).entity_id AS start_entity_id + """ + result = await self._run_query(query, {"node_ids": node_ids}) + + # Initialize the dictionary with empty lists for each node ID + edges_dict = {node_id: [] for node_id in node_ids} + + # Process results to include both outgoing and incoming edges + if result.result_set: + for record in result.result_set: + queried_id = record[0] + node_entity_id = record[1] + connected_entity_id = record[2] + start_entity_id = record[3] + + # Skip if either node is None + if not node_entity_id or not connected_entity_id: + continue + + # Determine the actual direction of the edge + # If the start node is the queried node, it's an outgoing edge + # Otherwise, it's an incoming edge + if start_entity_id == node_entity_id: + # Outgoing edge: (queried_node -> connected_node) + edges_dict[queried_id].append((node_entity_id, connected_entity_id)) + else: + # Incoming edge: (connected_node -> queried_node) + edges_dict[queried_id].append((connected_entity_id, node_entity_id)) + + return edges_dict + + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + workspace_label = self._get_workspace_label() + query = f""" + UNWIND $chunk_ids AS chunk_id + MATCH (n:`{workspace_label}`) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await self._run_query( + query, {"chunk_ids": chunk_ids, "sep": GRAPH_FIELD_SEP} + ) + nodes = [] + + if 
result.result_set: + for record in result.result_set: + node = record[0] + node_dict = {key: value for key, value in node.properties.items()} + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + workspace_label = self._get_workspace_label() + query = f""" + UNWIND $chunk_ids AS chunk_id + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await self._run_query( + query, {"chunk_ids": chunk_ids, "sep": GRAPH_FIELD_SEP} + ) + edges = [] + + if result.result_set: + for record in result.result_set: + edge_properties = record[2] + edge_properties["source"] = record[0] + edge_properties["target"] = record[1] + edges.append(edge_properties) + + return edges + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((redis.exceptions.RedisError, Exception)), + ) + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Upsert a node in the FalkorDB database. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + workspace_label = self._get_workspace_label() + properties = node_data + entity_type = properties["entity_type"] + if "entity_id" not in properties: + raise ValueError( + "FalkorDB: node properties must contain an 'entity_id' field" + ) + + try: + query = f""" + MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) + SET n += $properties + SET n:`{entity_type}` + """ + await self._run_query( + query, {"entity_id": node_id, "properties": properties} + ) + except Exception as e: + logger.error(f"Error during upsert: {str(e)}") + raise + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((redis.exceptions.RedisError, Exception)), + ) + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: + """ + Upsert an edge and its properties between two nodes identified by their labels. + Ensures both source and target nodes exist and are unique before creating the edge. + Uses entity_id property to uniquely identify nodes. 
+ + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge + + Raises: + ValueError: If either source or target node does not exist or is not unique + """ + try: + edge_properties = edge_data + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) + WITH source + MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + await self._run_query( + query, + { + "source_entity_id": source_node_id, + "target_entity_id": target_node_id, + "properties": edge_properties, + }, + ) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = None, + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maximum nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + """ + # Get max_nodes from global_config if not provided + if max_nodes is None: + max_nodes = self.global_config.get("max_graph_nodes", 1000) + else: + # Limit max_nodes to not exceed global_config max_graph_nodes + max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) + + workspace_label = self._get_workspace_label() + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + try: + if node_label == "*": + # Get all nodes with highest degree + query = f""" + MATCH (n:`{workspace_label}`) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect(n) AS nodes + UNWIND nodes AS node + OPTIONAL MATCH (node)-[rel]-(connected) + WHERE connected IN nodes + RETURN collect(DISTINCT node) AS filtered_nodes, + collect(DISTINCT rel) AS relationships + """ + graph_result = await self._run_query(query, {"max_nodes": max_nodes}) + else: + # Get subgraph starting from specific node + # Simple BFS implementation since FalkorDB might not have APOC + query = f""" + MATCH path = (start:`{workspace_label}` {{entity_id: $entity_id}})-[*0..{max_depth}]-(connected) + WITH nodes(path) AS path_nodes, relationships(path) AS path_rels + UNWIND path_nodes AS node + WITH collect(DISTINCT node) AS all_nodes, path_rels + UNWIND path_rels AS rel + WITH all_nodes, collect(DISTINCT rel) AS all_rels + RETURN all_nodes[0..{max_nodes}] AS filtered_nodes, all_rels AS relationships + """ + graph_result = await self._run_query(query, {"entity_id": node_label}) + + if graph_result.result_set: + record = graph_result.result_set[0] + nodes_list = record[0] if record[0] else [] + relationships_list = record[1] if record[1] else [] + + # Check if truncated + if len(nodes_list) >= max_nodes: + result.is_truncated = True + + # Handle nodes + for node in nodes_list: + node_id = str(id(node)) # Use internal node ID + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node.properties.get("entity_id", "")], + 
properties=dict(node.properties), + ) + ) + seen_nodes.add(node_id) + + # Handle relationships + for rel in relationships_list: + edge_id = str(id(rel)) # Use internal relationship ID + if edge_id not in seen_edges: + # Get start and end node IDs + start_node_id = str(rel.src_node) + end_node_id = str(rel.dest_node) + + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=rel.relation, + source=start_node_id, + target=end_node_id, + properties=dict(rel.properties), + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except Exception as e: + logger.error(f"Error in get_knowledge_graph: {str(e)}") + # Return empty graph on error + pass + + return result + + async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + """ + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}`) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await self._run_query(query) + labels = [] + + if result.result_set: + for record in result.result_set: + labels.append(record[0]) + + return labels + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((redis.exceptions.RedisError, Exception)), + ) + async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + """ + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + DETACH DELETE n + """ + await self._run_query(query, {"entity_id": node_id}) + logger.debug(f"Deleted node with label '{node_id}'") + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((redis.exceptions.RedisError, Exception)), + ) + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + for node in nodes: + await self.delete_node(node) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((redis.exceptions.RedisError, Exception)), + ) + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + for source, target in edges: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) + DELETE r + """ + await self._run_query( + query, {"source_entity_id": source, "target_entity_id": target} + ) + logger.debug(f"Deleted edge from '{source}' to '{target}'") + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise + + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph. 
+ + Returns: + A list of all nodes, where each node is a dictionary of its properties + """ + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}`) + RETURN n + """ + result = await self._run_query(query) + nodes = [] + + if result.result_set: + for record in result.result_set: + node = record[0] + node_dict = {key: value for key, value in node.properties.items()} + # Add node id (entity_id) to the dictionary for easier access + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + + return nodes + + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph. + + Returns: + A list of all edges, where each edge is a dictionary of its properties + """ + workspace_label = self._get_workspace_label() + query = f""" + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) + RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties + """ + result = await self._run_query(query) + edges = [] + + if result.result_set: + for record in result.result_set: + edge_properties = record[2] + edge_properties["source"] = record[0] + edge_properties["target"] = record[1] + edges.append(edge_properties) + + return edges + + async def drop(self) -> dict[str, str]: + """Drop all data from current workspace storage and clean up resources + + This method will delete all nodes and relationships in the current workspace only. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "workspace data dropped"} + - On failure: {"status": "error", "message": ""} + """ + workspace_label = self._get_workspace_label() + try: + # Delete all nodes and relationships in current workspace only + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + await self._run_query(query) + + logger.info( + f"Process {os.getpid()} drop FalkorDB workspace '{workspace_label}'" + ) + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } + except Exception as e: + logger.error(f"Error dropping FalkorDB workspace '{workspace_label}': {e}") + return {"status": "error", "message": str(e)} + + async def get_popular_labels(self, limit: int = 300) -> list[str]: + """Get popular labels by node degree (most connected entities) + + Args: + limit: Maximum number of labels to return + + Returns: + List of labels sorted by degree (highest first) + """ + workspace_label = self._get_workspace_label() + try: + query = f""" + MATCH (n:`{workspace_label}`) + WHERE n.entity_id IS NOT NULL + OPTIONAL MATCH (n)-[r]-() + WITH n.entity_id AS label, count(r) AS degree + ORDER BY degree DESC, label ASC + LIMIT {limit} + RETURN label + """ + result = await self._run_query(query) + labels = [] + + if result.result_set: + for record in result.result_set: + labels.append(record[0]) + + logger.debug( + f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})" + ) + return labels + except Exception as e: + logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}") + return [] + + async def search_labels(self, query: str, limit: int = 50) -> list[str]: + """Search labels with fuzzy matching + + Args: + query: Search query string + limit: Maximum number of results to return + + Returns: + List of matching labels sorted by relevance + """ + workspace_label = self._get_workspace_label() + query_lower = query.lower().strip() + + if not query_lower: + return [] + + try: + # FalkorDB search using CONTAINS with relevance scoring + 
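+            # Relevance tiers implemented in the CASE below: exact match scores
+            # 1000, prefix match 500, any other substring 100 - size(label),
+            # so shorter labels rank higher among plain substring matches.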
cypher_query = f""" + MATCH (n:`{workspace_label}`) + WHERE n.entity_id IS NOT NULL + WITH n.entity_id AS label, toLower(n.entity_id) AS label_lower + WHERE label_lower CONTAINS $query_lower + WITH label, label_lower, + CASE + WHEN label_lower = $query_lower THEN 1000 + WHEN label_lower STARTS WITH $query_lower THEN 500 + ELSE 100 - size(label) + END AS score + ORDER BY score DESC, label ASC + LIMIT {limit} + RETURN label + """ + + result = await self._run_query(cypher_query, {"query_lower": query_lower}) + labels = [] + + if result.result_set: + for record in result.result_set: + labels.append(record[0]) + + logger.debug( + f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})" + ) + return labels + except Exception as e: + logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") + return []
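+
+
+# Manual smoke test (sketch only, not part of the storage API). Assumes a local
+# FalkorDB on localhost:6379 and that graph-only use works without an
+# embedding_func. Uncomment to run directly:
+#
+#   import asyncio
+#
+#   async def _demo():
+#       storage = FalkorDBStorage(
+#           namespace="demo",
+#           global_config={"max_graph_nodes": 1000},
+#           embedding_func=None,
+#       )
+#       await storage.initialize()
+#       await storage.upsert_node(
+#           "Alice", {"entity_id": "Alice", "entity_type": "Person"}
+#       )
+#       await storage.upsert_node(
+#           "Bob", {"entity_id": "Bob", "entity_type": "Person"}
+#       )
+#       await storage.upsert_edge("Alice", "Bob", {"weight": 1.0, "description": "knows"})
+#       print(await storage.get_all_labels())          # expect ['Alice', 'Bob']
+#       print(await storage.has_edge("Alice", "Bob"))  # expect True
+#       await storage.finalize()
+#
+#   asyncio.run(_demo())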