Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
to automatically extract a schema from an existing Neo4j database.
"""

import asyncio
from pprint import pprint

import neo4j

from neo4j_graphrag.experimental.components.schema import (
SchemaFromExistingGraphExtractor,
GraphSchema,
)


URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"


async def main() -> None:
"""Run the example."""

with neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
) as driver:
extractor = SchemaFromExistingGraphExtractor(
driver,
# optional:
neo4j_database=DATABASE,
additional_patterns=True,
additional_node_types=True,
additional_relationship_types=True,
additional_properties=True,
)
schema: GraphSchema = await extractor.run()
# schema.store_as_json("my_schema.json")
pprint(schema.model_dump())


if __name__ == "__main__":
asyncio.run(main())
177 changes: 175 additions & 2 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from __future__ import annotations

import json

import neo4j
import logging
import warnings
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
Expand Down Expand Up @@ -44,6 +46,10 @@
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
from neo4j_graphrag.schema import get_structured_schema


logger = logging.getLogger(__name__)


class PropertyType(BaseModel):
Expand Down Expand Up @@ -306,7 +312,12 @@ def from_file(
raise SchemaValidationError(str(e)) from e


class SchemaBuilder(Component):
class BaseSchemaBuilder(Component):
async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
raise NotImplementedError()


class SchemaBuilder(BaseSchemaBuilder):
"""
A builder class for constructing GraphSchema objects from given entities,
relations, and their interrelationships defined in a potential schema.
Expand Down Expand Up @@ -424,7 +435,7 @@ async def run(
)


class SchemaFromTextExtractor(Component):
class SchemaFromTextExtractor(BaseSchemaBuilder):
"""
A component for constructing GraphSchema objects from the output of an LLM after
automatic schema extraction from text.
Expand Down Expand Up @@ -621,3 +632,165 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
"patterns": extracted_patterns,
}
)


class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
"""A class to build a GraphSchema object from an existing graph.

Uses the get_structured_schema function to extract existing node labels,
relationship types, properties and existence constraints.

By default, the built schema does not allow any additional item (property,
node label, relationship type or pattern).

Args:
driver (neo4j.Driver): connection to the neo4j database.
additional_properties (bool, default False): see GraphSchema
additional_node_types (bool, default False): see GraphSchema
additional_relationship_types (bool, default False): see GraphSchema:
additional_patterns (bool, default False): see GraphSchema:
neo4j_database (Optional | str): name of the neo4j database to use
"""

def __init__(
self,
driver: neo4j.Driver,
additional_properties: bool | None = None,
additional_node_types: bool | None = None,
additional_relationship_types: bool | None = None,
additional_patterns: bool | None = None,
neo4j_database: Optional[str] = None,
) -> None:
self.driver = driver
self.database = neo4j_database

self.additional_properties = additional_properties
self.additional_node_types = additional_node_types
self.additional_relationship_types = additional_relationship_types
self.additional_patterns = additional_patterns

@staticmethod
def _extract_required_properties(
structured_schema: dict[str, Any],
) -> list[tuple[str, str]]:
"""Extract a list of (node label (or rel type), property name) for which
an "EXISTENCE" or "KEY" constraint is defined in the DB.

Args:

structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.

Returns:

list of tuples of (node label (or rel type), property name)

"""
schema_metadata = structured_schema.get("metadata", {})
existence_constraint = [] # list of (node label, property name)
for constraint in schema_metadata.get("constraint", []):
if constraint["type"] in (
"NODE_PROPERTY_EXISTENCE",
"NODE_KEY",
"RELATIONSHIP_PROPERTY_EXISTENCE",
"RELATIONSHIP_KEY",
):
properties = constraint["properties"]
labels = constraint["labelsOrTypes"]
# note: existence constraint only apply to a single property
# and a single label
prop = properties[0]
lab = labels[0]
existence_constraint.append((lab, prop))
return existence_constraint

def _to_schema_entity_dict(
self,
key: str,
property_dict: list[dict[str, Any]],
existence_constraint: list[tuple[str, str]],
) -> dict[str, Any]:
entity_dict: dict[str, Any] = {
"label": key,
"properties": [
{
"name": p["property"],
"type": p["type"],
"required": (key, p["property"]) in existence_constraint,
}
for p in property_dict
],
}
if self.additional_properties:
entity_dict["additional_properties"] = self.additional_properties
return entity_dict

async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
structured_schema = get_structured_schema(self.driver, database=self.database)
existence_constraint = self._extract_required_properties(structured_schema)

# node label with properties
node_labels = set(structured_schema["node_props"].keys())
node_types = [
self._to_schema_entity_dict(key, properties, existence_constraint)
for key, properties in structured_schema["node_props"].items()
]

# relationships with properties
rel_labels = set(structured_schema["rel_props"].keys())
relationship_types = [
self._to_schema_entity_dict(key, properties, existence_constraint)
for key, properties in structured_schema["rel_props"].items()
]

patterns = [
(s["start"], s["type"], s["end"])
for s in structured_schema["relationships"]
]

# deal with nodes and relationships without properties
for source, rel, target in patterns:
if source not in node_labels:
if self.additional_properties is False:
logger.warning(
f"SCHEMA: found node label {source} without property and additional_properties=False: this node label will always be pruned!"
)
node_labels.add(source)
node_types.append(
{
"label": source,
}
)
if target not in node_labels:
if self.additional_properties is False:
logger.warning(
f"SCHEMA: found node label {target} without property and additional_properties=False: this node label will always be pruned!"
)
node_labels.add(target)
node_types.append(
{
"label": target,
}
)
if rel not in rel_labels:
rel_labels.add(rel)
relationship_types.append(
{
"label": rel,
}
)
schema_dict: dict[str, Any] = {
"node_types": node_types,
"relationship_types": relationship_types,
"patterns": patterns,
}
if self.additional_node_types is not None:
schema_dict["additional_node_types"] = self.additional_node_types
if self.additional_relationship_types is not None:
schema_dict["additional_relationship_types"] = (
self.additional_relationship_types
)
if self.additional_patterns is not None:
schema_dict["additional_patterns"] = self.additional_patterns
return GraphSchema.model_validate(
schema_dict,
)
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
SchemaBuilder,
GraphSchema,
SchemaFromTextExtractor,
BaseSchemaBuilder,
)
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
Expand Down Expand Up @@ -178,7 +179,7 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]:
def _get_chunk_embedder(self) -> TextChunkEmbedder:
return TextChunkEmbedder(embedder=self.get_default_embedder())

def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
def _get_schema(self) -> BaseSchemaBuilder:
"""
Get the appropriate schema component based on configuration.
Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema.
Expand Down
Loading