diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py
index 9feb30735..f52ca4728 100644
--- a/graphiti_core/helpers.py
+++ b/graphiti_core/helpers.py
@@ -33,13 +33,17 @@
 load_dotenv()
 
-USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
-SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
-MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
+USE_PARALLEL_RUNTIME = bool(os.getenv("USE_PARALLEL_RUNTIME", False))
+SEMAPHORE_LIMIT = int(os.getenv("SEMAPHORE_LIMIT", 20))
+MAX_REFLEXION_ITERATIONS = int(os.getenv("MAX_REFLEXION_ITERATIONS", 0))
 DEFAULT_PAGE_LIMIT = 20
 
+# Max tokens for edge extraction operations
+EXTRACT_EDGES_MAX_TOKENS = 16384
+
 RUNTIME_QUERY: LiteralString = (
-    'CYPHER runtime = parallel parallelRuntimeSupport=all\n' if USE_PARALLEL_RUNTIME else ''
+    "CYPHER runtime = parallel parallelRuntimeSupport=all\n"
+    if USE_PARALLEL_RUNTIME
+    else ""
 )
@@ -47,9 +51,7 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
     return (
         neo_date.to_native()
         if isinstance(neo_date, neo4j_time.DateTime)
-        else datetime.fromisoformat(neo_date)
-        if neo_date
-        else None
+        else datetime.fromisoformat(neo_date) if neo_date else None
     )
@@ -59,9 +61,9 @@ def get_default_group_id(provider: GraphProvider) -> str:
     For most databases, the default group id is an empty string, while there
     are database types that require a specific default group id.
     """
     if provider == GraphProvider.FALKORDB:
-        return '_'
+        return "_"
     else:
-        return ''
+        return ""
 
 
 def lucene_sanitize(query: str) -> str:
@@ -69,31 +71,31 @@
     # + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
     escape_map = str.maketrans(
         {
-            '+': r'\+',
-            '-': r'\-',
-            '&': r'\&',
-            '|': r'\|',
-            '!': r'\!',
-            '(': r'\(',
-            ')': r'\)',
-            '{': r'\{',
-            '}': r'\}',
-            '[': r'\[',
-            ']': r'\]',
-            '^': r'\^',
-            '"': r'\"',
-            '~': r'\~',
-            '*': r'\*',
-            '?': r'\?',
-            ':': r'\:',
-            '\\': r'\\',
-            '/': r'\/',
-            'O': r'\O',
-            'R': r'\R',
-            'N': r'\N',
-            'T': r'\T',
-            'A': r'\A',
-            'D': r'\D',
+            "+": r"\+",
+            "-": r"\-",
+            "&": r"\&",
+            "|": r"\|",
+            "!": r"\!",
+            "(": r"\(",
+            ")": r"\)",
+            "{": r"\{",
+            "}": r"\}",
+            "[": r"\[",
+            "]": r"\]",
+            "^": r"\^",
+            '"': r"\"",
+            "~": r"\~",
+            "*": r"\*",
+            "?": r"\?",
+            ":": r"\:",
+            "\\": r"\\",
+            "/": r"\/",
+            "O": r"\O",
+            "R": r"\R",
+            "N": r"\N",
+            "T": r"\T",
+            "A": r"\A",
+            "D": r"\D",
         }
     )
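Reviewer note: the escape map in the `lucene_sanitize` hunk above covers Lucene's reserved query characters. A minimal sketch of how such a translation table behaves, assuming the function applies it with `str.translate` (the return statement falls outside this hunk), using a trimmed-down map:

```python
# Sketch only: a trimmed-down version of the escape map above.
escape_map = str.maketrans(
    {
        "+": r"\+",
        "(": r"\(",
        ")": r"\)",
        "*": r"\*",
    }
)

# Each reserved character is replaced by its backslash-escaped form.
print("graph (memory) + rag*".translate(escape_map))
# -> graph \(memory\) \+ rag\*
```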
@@ -118,7 +120,9 @@ async def _wrap_coroutine(coroutine):
         async with semaphore:
             return await coroutine
 
-    return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
+    return await asyncio.gather(
+        *(_wrap_coroutine(coroutine) for coroutine in coroutines)
+    )
 
 
 def validate_group_id(group_id: str) -> bool:
@@ -141,14 +145,15 @@ def validate_group_id(group_id: str) -> bool:
     # Check if string contains only ASCII alphanumeric characters, dashes, or underscores
     # Pattern matches: letters (a-z, A-Z), digits (0-9), hyphens (-), and underscores (_)
-    if not re.match(r'^[a-zA-Z0-9_-]+$', group_id):
+    if not re.match(r"^[a-zA-Z0-9_-]+$", group_id):
         raise GroupIdValidationError(group_id)
 
     return True
 
 
 def validate_excluded_entity_types(
-    excluded_entity_types: list[str] | None, entity_types: dict[str, type[BaseModel]] | None = None
+    excluded_entity_types: list[str] | None,
+    entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> bool:
     """
     Validate that excluded entity types are valid type names.
@@ -167,7 +172,7 @@ def validate_excluded_entity_types(
         return True
 
     # Build set of available type names
-    available_types = {'Entity'}  # Default type is always available
+    available_types = {"Entity"}  # Default type is always available
     if entity_types:
         available_types.update(entity_types.keys())
 
@@ -175,7 +180,7 @@
     invalid_types = set(excluded_entity_types) - available_types
     if invalid_types:
         raise ValueError(
-            f'Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}'
+            f"Invalid excluded entity types: {sorted(invalid_types)}. Available types: {sorted(available_types)}"
         )
 
     return True
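Reviewer note: the `semaphore_gather` hunk above is the bounded-concurrency primitive the rest of the codebase leans on, with the limit read from the `SEMAPHORE_LIMIT` environment variable. A self-contained sketch of the same pattern, assuming a `max_coroutines` override keyword (illustrative; the library's actual signature may differ):

```python
import asyncio
import os

SEMAPHORE_LIMIT = int(os.getenv("SEMAPHORE_LIMIT", 20))


async def semaphore_gather(*coroutines, max_coroutines: int | None = None):
    # Cap how many coroutines run at once; the rest queue on the semaphore.
    semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)

    async def _wrap_coroutine(coroutine):
        async with semaphore:
            return await coroutine

    return await asyncio.gather(*(_wrap_coroutine(c) for c in coroutines))


async def main():
    async def work(i: int) -> int:
        await asyncio.sleep(0.1)
        return i

    # At most SEMAPHORE_LIMIT of these 100 tasks run concurrently.
    results = await semaphore_gather(*(work(i) for i in range(100)))
    print(len(results))  # 100


asyncio.run(main())
```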
diff --git a/graphiti_core/llm_client/anthropic_client.py b/graphiti_core/llm_client/anthropic_client.py
index 916370984..b2bf6f081 100644
--- a/graphiti_core/llm_client/anthropic_client.py
+++ b/graphiti_core/llm_client/anthropic_client.py
@@ -39,30 +39,30 @@
     from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
 except ImportError:
     raise ImportError(
-        'anthropic is required for AnthropicClient. '
-        'Install it with: pip install graphiti-core[anthropic]'
+        "anthropic is required for AnthropicClient. "
+        "Install it with: pip install graphiti-core[anthropic]"
     ) from None
 
 logger = logging.getLogger(__name__)
 
 AnthropicModel = Literal[
-    'claude-3-7-sonnet-latest',
-    'claude-3-7-sonnet-20250219',
-    'claude-3-5-haiku-latest',
-    'claude-3-5-haiku-20241022',
-    'claude-3-5-sonnet-latest',
-    'claude-3-5-sonnet-20241022',
-    'claude-3-5-sonnet-20240620',
-    'claude-3-opus-latest',
-    'claude-3-opus-20240229',
-    'claude-3-sonnet-20240229',
-    'claude-3-haiku-20240307',
-    'claude-2.1',
-    'claude-2.0',
+    "claude-3-7-sonnet-latest",
+    "claude-3-7-sonnet-20250219",
+    "claude-3-5-haiku-latest",
+    "claude-3-5-haiku-20241022",
+    "claude-3-5-sonnet-latest",
+    "claude-3-5-sonnet-20241022",
+    "claude-3-5-sonnet-20240620",
+    "claude-3-opus-latest",
+    "claude-3-opus-20240229",
+    "claude-3-sonnet-20240229",
+    "claude-3-haiku-20240307",
+    "claude-2.1",
+    "claude-2.0",
 ]
 
-DEFAULT_MODEL: AnthropicModel = 'claude-3-7-sonnet-latest'
+DEFAULT_MODEL: AnthropicModel = "claude-3-7-sonnet-latest"
 
 
 class AnthropicClient(LLMClient):
@@ -95,7 +95,7 @@ def __init__(
     ) -> None:
         if config is None:
             config = LLMConfig()
-            config.api_key = os.getenv('ANTHROPIC_API_KEY')
+            config.api_key = os.getenv("ANTHROPIC_API_KEY")
         config.max_tokens = max_tokens
 
         if config.model is None:
@@ -129,15 +129,17 @@ def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
             ValueError: If JSON cannot be extracted or parsed
         """
         try:
-            json_start = text.find('{')
-            json_end = text.rfind('}') + 1
+            json_start = text.find("{")
+            json_end = text.rfind("}") + 1
             if json_start >= 0 and json_end > json_start:
                 json_str = text[json_start:json_end]
                 return json.loads(json_str)
             else:
-                raise ValueError(f'Could not extract JSON from model response: {text}')
+                raise ValueError(f"Could not extract JSON from model response: {text}")
         except (JSONDecodeError, ValueError) as e:
-            raise ValueError(f'Could not extract JSON from model response: {text}') from e
+            raise ValueError(
+                f"Could not extract JSON from model response: {text}"
+            ) from e
 
     def _create_tool(
         self, response_model: type[BaseModel] | None = None
@@ -155,25 +157,27 @@
             # Use the response_model to define the tool
             model_schema = response_model.model_json_schema()
             tool_name = response_model.__name__
-            description = model_schema.get('description', f'Extract {tool_name} information')
+            description = model_schema.get(
+                "description", f"Extract {tool_name} information"
+            )
         else:
             # Create a generic JSON output tool
-            tool_name = 'generic_json_output'
-            description = 'Output data in JSON format'
+            tool_name = "generic_json_output"
+            description = "Output data in JSON format"
             model_schema = {
-                'type': 'object',
-                'additionalProperties': True,
-                'description': 'Any JSON object containing the requested information',
+                "type": "object",
+                "additionalProperties": True,
+                "description": "Any JSON object containing the requested information",
             }
 
         tool = {
-            'name': tool_name,
-            'description': description,
-            'input_schema': model_schema,
+            "name": tool_name,
+            "description": description,
+            "input_schema": model_schema,
         }
         tool_list = [tool]
         tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
-        tool_choice = {'type': 'tool', 'name': tool_name}
+        tool_choice = {"type": "tool", "name": tool_name}
         tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
         return tool_list_cast, tool_choice_cast
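Reviewer note: `_create_tool` implements structured output by converting the Pydantic response model into a single Anthropic tool and forcing the model to call it via `tool_choice`. A minimal standalone sketch of that pattern (the model name, prompt, and `ExtractedEdge` schema are illustrative, not from this codebase):

```python
from anthropic import AsyncAnthropic
from pydantic import BaseModel


class ExtractedEdge(BaseModel):
    source: str
    target: str
    relation: str


async def extract(client: AsyncAnthropic) -> ExtractedEdge:
    schema = ExtractedEdge.model_json_schema()
    tool = {
        "name": "ExtractedEdge",
        "description": schema.get("description", "Extract ExtractedEdge information"),
        "input_schema": schema,
    }
    result = await client.messages.create(
        model="claude-3-7-sonnet-latest",
        max_tokens=1024,
        messages=[{"role": "user", "content": "Alice manages Bob."}],
        tools=[tool],
        # Forcing the tool means the reply arrives as structured tool input,
        # not free-form text that would need a JSON-extraction fallback.
        tool_choice={"type": "tool", "name": "ExtractedEdge"},
    )
    for item in result.content:
        if item.type == "tool_use" and isinstance(item.input, dict):
            return ExtractedEdge.model_validate(item.input)
    raise ValueError("Model did not return structured tool input")
```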
@@ -201,11 +205,9 @@
             Exception: If an error occurs during the generation process.
         """
         system_message = messages[0]
-        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
+        user_messages = [{"role": m.role, "content": m.content} for m in messages[1:]]
         user_messages_cast = typing.cast(list[MessageParam], user_messages)
 
-        # TODO: Replace hacky min finding solution after fixing hardcoded EXTRACT_EDGES_MAX_TOKENS = 16384 in
-        # edge_operations.py. Throws errors with cheaper models that lower max_tokens.
         max_creation_tokens: int = min(
             max_tokens if max_tokens is not None else self.config.max_tokens,
             DEFAULT_MAX_TOKENS,
@@ -226,7 +228,7 @@
             # Extract the tool output from the response
             for content_item in result.content:
-                if content_item.type == 'tool_use':
+                if content_item.type == "tool_use":
                     if isinstance(content_item.input, dict):
                         tool_args: dict[str, typing.Any] = content_item.input
                     else:
@@ -235,25 +237,27 @@
             # If we didn't get a proper tool_use response, try to extract from text
             for content_item in result.content:
-                if content_item.type == 'text':
+                if content_item.type == "text":
                     return self._extract_json_from_text(content_item.text)
             else:
                 raise ValueError(
-                    f'Could not extract structured data from model response: {result.content}'
+                    f"Could not extract structured data from model response: {result.content}"
                 )
 
             # If we get here, we couldn't parse a structured response
             raise ValueError(
-                f'Could not extract structured data from model response: {result.content}'
+                f"Could not extract structured data from model response: {result.content}"
             )
 
         except anthropic.RateLimitError as e:
-            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
+            raise RateLimitError(
+                f"Rate limit exceeded. Please try again later. Error: {e}"
+            ) from e
         except anthropic.APIError as e:
             # Special case for content policy violations. We convert these to RefusalError
             # to bypass the retry mechanism, as retrying policy-violating content will always fail.
             # This avoids wasting API calls and provides more specific error messaging to the user.
-            if 'refused to respond' in str(e).lower():
+            if "refused to respond" in str(e).lower():
                 raise RefusalError(str(e)) from e
             raise e
         except Exception as e:
@@ -313,27 +317,31 @@ async def generate_response(
                 if retry_count >= max_retries:
                     if isinstance(e, ValidationError):
                         logger.error(
-                            f'Validation error after {retry_count}/{max_retries} attempts: {e}'
+                            f"Validation error after {retry_count}/{max_retries} attempts: {e}"
                         )
                     else:
-                        logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
+                        logger.error(
+                            f"Max retries ({max_retries}) exceeded. Last error: {e}"
+                        )
                     raise e
 
                 if isinstance(e, ValidationError):
                     response_model_cast = typing.cast(type[BaseModel], response_model)
-                    error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
+                    error_context = f"The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}"
                 else:
                     error_context = (
-                        f'The previous response attempt was invalid. '
-                        f'Error type: {e.__class__.__name__}. '
-                        f'Error details: {str(e)}. '
-                        f'Please try again with a valid response.'
+                        f"The previous response attempt was invalid. "
+                        f"Error type: {e.__class__.__name__}. "
+                        f"Error details: {str(e)}. "
+                        f"Please try again with a valid response."
                     )
 
                 # Common retry logic
                 retry_count += 1
-                messages.append(Message(role='user', content=error_context))
-                logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')
+                messages.append(Message(role="user", content=error_context))
+                logger.warning(
+                    f"Retrying after error (attempt {retry_count}/{max_retries}): {e}"
+                )
 
         # If we somehow get here, raise the last error
-        raise last_error or Exception('Max retries exceeded with no specific error')
+        raise last_error or Exception("Max retries exceeded with no specific error")
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index ef78db438..561e6450a 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -29,7 +29,7 @@
     create_entity_edge_embeddings,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
+from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather, EXTRACT_EDGES_MAX_TOKENS
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
@@ -118,7 +118,7 @@ async def extract_edges(
 ) -> list[EntityEdge]:
     start = time()
 
-    extract_edges_max_tokens = 16384
+    extract_edges_max_tokens = EXTRACT_EDGES_MAX_TOKENS
     llm_client = clients.llm_client
 
     edge_type_signature_map: dict[str, tuple[str, str]] = {
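Reviewer note: this change resolves the TODO deleted above. `EXTRACT_EDGES_MAX_TOKENS` now lives in `helpers.py` instead of being hardcoded in `edge_operations.py`, and the Anthropic client clamps every request with `min(...)` so a large extraction budget can never exceed the client-wide ceiling that cheaper, lower-`max_tokens` models tolerate. A worked sketch of the clamp (the `DEFAULT_MAX_TOKENS` value and `clamp_max_tokens` helper are assumed for illustration):

```python
DEFAULT_MAX_TOKENS = 8192         # client-wide ceiling; value assumed for illustration
EXTRACT_EDGES_MAX_TOKENS = 16384  # edge-extraction budget, now defined in helpers.py


def clamp_max_tokens(requested: int | None, configured: int) -> int:
    # min() keeps a caller-supplied budget from exceeding what the model allows.
    return min(requested if requested is not None else configured, DEFAULT_MAX_TOKENS)


print(clamp_max_tokens(EXTRACT_EDGES_MAX_TOKENS, 4096))  # 8192: capped at the ceiling
print(clamp_max_tokens(None, 4096))                      # 4096: falls back to config
```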