Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,25 @@

load_dotenv()

USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
USE_PARALLEL_RUNTIME = bool(os.getenv("USE_PARALLEL_RUNTIME", False))
SEMAPHORE_LIMIT = int(os.getenv("SEMAPHORE_LIMIT", 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv("MAX_REFLEXION_ITERATIONS", 0))
DEFAULT_PAGE_LIMIT = 20
# Max tokens for edge extraction operations
EXTRACT_EDGES_MAX_TOKENS = 16384

RUNTIME_QUERY: LiteralString = (
'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
"CYPHER runtime = parallel parallelRuntimeSupport=all\n"
if USE_PARALLEL_RUNTIME
else ""
)


def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None:
return (
neo_date.to_native()
if isinstance(neo_date, neo4j_time.DateTime)
else datetime.fromisoformat(neo_date)
if neo_date
else None
else datetime.fromisoformat(neo_date) if neo_date else None
)


Expand All @@ -59,41 +61,41 @@ def get_default_group_id(provider: GraphProvider) -> str:
For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
"""
if provider == GraphProvider.FALKORDB:
return '_'
return "_"
else:
return ''
return ""


def lucene_sanitize(query: str) -> str:
# Escape special characters from a query before passing into Lucene
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
escape_map = str.maketrans(
{
'+': r'\+',
'-': r'\-',
'&': r'\&',
'|': r'\|',
'!': r'\!',
'(': r'\(',
')': r'\)',
'{': r'\{',
'}': r'\}',
'[': r'\[',
']': r'\]',
'^': r'\^',
'"': r'\"',
'~': r'\~',
'*': r'\*',
'?': r'\?',
':': r'\:',
'\\': r'\\',
'/': r'\/',
'O': r'\O',
'R': r'\R',
'N': r'\N',
'T': r'\T',
'A': r'\A',
'D': r'\D',
"+": r"\+",
"-": r"\-",
"&": r"\&",
"|": r"\|",
"!": r"\!",
"(": r"\(",
")": r"\)",
"{": r"\{",
"}": r"\}",
"[": r"\[",
"]": r"\]",
"^": r"\^",
'"': r"\"",
"~": r"\~",
"*": r"\*",
"?": r"\?",
":": r"\:",
"\\": r"\\",
"/": r"\/",
"O": r"\O",
"R": r"\R",
"N": r"\N",
"T": r"\T",
"A": r"\A",
"D": r"\D",
}
)

Expand All @@ -118,7 +120,9 @@ async def _wrap_coroutine(coroutine):
async with semaphore:
return await coroutine

return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
return await asyncio.gather(
*(_wrap_coroutine(coroutine) for coroutine in coroutines)
)


def validate_group_id(group_id: str) -> bool:
Expand All @@ -141,14 +145,15 @@ def validate_group_id(group_id: str) -> bool:

# Check if string contains only ASCII alphanumeric characters, dashes, or underscores
# Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
if not re.match(r"^[a-zA-Z0-9_-]+$", group_id):
raise GroupIdValidationError(group_id)

return True


def validate_excluded_entity_types(
excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
excluded_entity_types: list[str] | None,
entity_types: dict[str, type[BaseModel]] | None = None,
) -> bool:
"""
Validate that excluded entity types are valid type names.
Expand All @@ -167,15 +172,15 @@ def validate_excluded_entity_types(
return True

# Build set of available type names
available_types = {'Entity'} # Default type is always available
available_types = {"Entity"} # Default type is always available
if entity_types:
available_types.update(entity_types.keys())

# Check for invalid type names
invalid_types = set(excluded_entity_types) - available_types
if invalid_types:
raise ValueError(
f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
f"Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}"
)

return True
108 changes: 58 additions & 50 deletions graphiti_core/llm_client/anthropic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,30 +39,30 @@
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
except ImportError:
raise ImportError(
'anthropic is required for AnthropicClient. '
'Install it with: pip install graphiti-core[anthropic]'
"anthropic is required for AnthropicClient. "
"Install it with: pip install graphiti-core[anthropic]"
) from None


logger = logging.getLogger(__name__)

AnthropicModel = Literal[
'claude-3-7-sonnet-latest',
'claude-3-7-sonnet-20250219',
'claude-3-5-haiku-latest',
'claude-3-5-haiku-20241022',
'claude-3-5-sonnet-latest',
'claude-3-5-sonnet-20241022',
'claude-3-5-sonnet-20240620',
'claude-3-opus-latest',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
'claude-2.1',
'claude-2.0',
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-latest",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-opus-latest",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-2.1",
"claude-2.0",
]

DEFAULT_MODEL: AnthropicModel = 'claude-3-7-sonnet-latest'
DEFAULT_MODEL: AnthropicModel = "claude-3-7-sonnet-latest"


class AnthropicClient(LLMClient):
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
) -> None:
if config is None:
config = LLMConfig()
config.api_key = os.getenv('ANTHROPIC_API_KEY')
config.api_key = os.getenv("ANTHROPIC_API_KEY")
config.max_tokens = max_tokens

if config.model is None:
Expand Down Expand Up @@ -129,15 +129,17 @@ def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
ValueError: If JSON cannot be extracted or parsed
"""
try:
json_start = text.find('{')
json_end = text.rfind('}') + 1
json_start = text.find("{")
json_end = text.rfind("}") + 1
if json_start >= 0 and json_end > json_start:
json_str = text[json_start:json_end]
return json.loads(json_str)
else:
raise ValueError(f'Could not extract JSON from model response: {text}')
raise ValueError(f"Could not extract JSON from model response: {text}")
except (JSONDecodeError, ValueError) as e:
raise ValueError(f'Could not extract JSON from model response: {text}') from e
raise ValueError(
f"Could not extract JSON from model response: {text}"
) from e

def _create_tool(
self, response_model: type[BaseModel] | None = None
Expand All @@ -155,25 +157,27 @@ def _create_tool(
# Use the response_model to define the tool
model_schema = response_model.model_json_schema()
tool_name = response_model.__name__
description = model_schema.get('description', f'Extract {tool_name} information')
description = model_schema.get(
"description", f"Extract {tool_name} information"
)
else:
# Create a generic JSON output tool
tool_name = 'generic_json_output'
description = 'Output data in JSON format'
tool_name = "generic_json_output"
description = "Output data in JSON format"
model_schema = {
'type': 'object',
'additionalProperties': True,
'description': 'Any JSON object containing the requested information',
"type": "object",
"additionalProperties": True,
"description": "Any JSON object containing the requested information",
}

tool = {
'name': tool_name,
'description': description,
'input_schema': model_schema,
"name": tool_name,
"description": description,
"input_schema": model_schema,
}
tool_list = [tool]
tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
tool_choice = {'type': 'tool', 'name': tool_name}
tool_choice = {"type": "tool", "name": tool_name}
tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
return tool_list_cast, tool_choice_cast

Expand Down Expand Up @@ -201,11 +205,9 @@ async def _generate_response(
Exception: If an error occurs during the generation process.
"""
system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
user_messages = [{"role": m.role, "content": m.content} for m in messages[1:]]
user_messages_cast = typing.cast(list[MessageParam], user_messages)

# TODO: Replace hacky min finding solution after fixing hardcoded EXTRACT_EDGES_MAX_TOKENS = 16384 in
# edge_operations.py. Throws errors with cheaper models that lower max_tokens.
max_creation_tokens: int = min(
max_tokens if max_tokens is not None else self.config.max_tokens,
DEFAULT_MAX_TOKENS,
Expand All @@ -226,7 +228,7 @@ async def _generate_response(

# Extract the tool output from the response
for content_item in result.content:
if content_item.type == 'tool_use':
if content_item.type == "tool_use":
if isinstance(content_item.input, dict):
tool_args: dict[str, typing.Any] = content_item.input
else:
Expand All @@ -235,25 +237,27 @@ async def _generate_response(

# If we didn't get a proper tool_use response, try to extract from text
for content_item in result.content:
if content_item.type == 'text':
if content_item.type == "text":
return self._extract_json_from_text(content_item.text)
else:
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
f"Could not extract structured data from model response: {result.content}"
)

# If we get here, we couldn't parse a structured response
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
f"Could not extract structured data from model response: {result.content}"
)

except anthropic.RateLimitError as e:
raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
raise RateLimitError(
f"Rate limit exceeded. Please try again later. Error: {e}"
) from e
except anthropic.APIError as e:
# Special case for content policy violations. We convert these to RefusalError
# to bypass the retry mechanism, as retrying policy-violating content will always fail.
# This avoids wasting API calls and provides more specific error messaging to the user.
if 'refused to respond' in str(e).lower():
if "refused to respond" in str(e).lower():
raise RefusalError(str(e)) from e
raise e
except Exception as e:
Expand Down Expand Up @@ -313,27 +317,31 @@ async def generate_response(
if retry_count >= max_retries:
if isinstance(e, ValidationError):
logger.error(
f'Validation error after {retry_count}/{max_retries} attempts: {e}'
f"Validation error after {retry_count}/{max_retries} attempts: {e}"
)
else:
logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
logger.error(
f"Max retries ({max_retries}) exceeded. Last error: {e}"
)
raise e

if isinstance(e, ValidationError):
response_model_cast = typing.cast(type[BaseModel], response_model)
error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
error_context = f"The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}"
else:
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response.'
f"The previous response attempt was invalid. "
f"Error type: {e.__class__.__name__}. "
f"Error details: {str(e)}. "
f"Please try again with a valid response."
)

# Common retry logic
retry_count += 1
messages.append(Message(role='user', content=error_context))
logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')
messages.append(Message(role="user", content=error_context))
logger.warning(
f"Retrying after error (attempt {retry_count}/{max_retries}): {e}"
)

# If we somehow get here, raise the last error
raise last_error or Exception('Max retries exceeded with no specific error')
raise last_error or Exception("Max retries exceeded with no specific error")
4 changes: 2 additions & 2 deletions graphiti_core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
create_entity_edge_embeddings,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather, EXTRACT_EDGES_MAX_TOKENS
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
Expand Down Expand Up @@ -118,7 +118,7 @@ async def extract_edges(
) -> list[EntityEdge]:
start = time()

extract_edges_max_tokens = 16384
extract_edges_max_tokens = EXTRACT_EDGES_MAX_TOKENS
llm_client = clients.llm_client

edge_type_signature_map: dict[str, tuple[str, str]] = {
Expand Down