diff --git a/.env.example b/.env.example
index 6665c79cc..1ba566b0a 100644
--- a/.env.example
+++ b/.env.example
@@ -1,5 +1,7 @@
 OPENAI_API_KEY=
+OLLAMA_HOST=http://localhost:11434
+
 # Neo4j database connection
 NEO4J_URI=
 NEO4J_PORT=
diff --git a/README.md b/README.md
index 5dfe57fe2..a6a239ecb 100644
--- a/README.md
+++ b/README.md
@@ -196,6 +196,9 @@ pip install graphiti-core[groq]
 # Install with Google Gemini support
 pip install graphiti-core[google-genai]
 
+# Install with Ollama support
+pip install graphiti-core[ollama]
+
 # Install with multiple providers
 pip install graphiti-core[anthropic,groq,google-genai]
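Reviewer note: a minimal wiring sketch for the two clients this PR introduces, assuming a local Ollama server at the default `OLLAMA_HOST` and that the models have already been pulled. The model name, host, and the `LLMConfig` keyword arguments shown are illustrative, not prescribed by this diff; the constructed clients would typically be handed to the `Graphiti` constructor's `llm_client` / `embedder` parameters.

```python
import asyncio

from graphiti_core.embedder.ollama import OllamaEmbedder, OllamaEmbedderConfig
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.ollama_client import OllamaClient


async def main() -> None:
    # Both clients default to a local server; base_url is optional.
    llm = OllamaClient(config=LLMConfig(model='qwen3:4b', base_url='http://localhost:11434'))
    embedder = OllamaEmbedder(OllamaEmbedderConfig(base_url='http://localhost:11434'))

    vector = await embedder.create('knowledge graphs of episodic data')
    print(llm.model, len(vector))


asyncio.run(main())
```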
+ """ + # Ollama's embed returns an object with 'embedding' or similar fields + try: + # Support call with client.embed for async client + result = await self.client.embed(model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=input_data) # type: ignore[arg-type] + except Exception as e: + logger.error(f'Ollama embed error: {e}') + raise + + # Extract embedding and coerce to list[float] + values: list[float] | None = None + + if hasattr(result, 'embeddings'): + emb = result.embeddings + if isinstance(emb, list) and len(emb) > 0: + values = emb[0] if isinstance(emb[0], list | tuple) else emb # type: ignore + elif isinstance(result, dict): + if 'embedding' in result and isinstance(result['embedding'], list | tuple): + values = list(result['embedding']) # type: ignore + elif 'embeddings' in result and isinstance(result['embeddings'], list) and len(result['embeddings']) > 0: + first = result['embeddings'][0] + if isinstance(first, dict) and 'embedding' in first and isinstance(first['embedding'], list | tuple): + values = list(first['embedding']) + elif isinstance(first, list | tuple): + values = list(first) + + # If result itself is a list (some clients return list for single input) + if values is None and isinstance(result, list | tuple): + # assume it's already the embedding vector + values = list(result) # type: ignore + if not values: + raise ValueError('No embeddings returned from Ollama API in create()') + + return values + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + if not input_data_list: + return [] + + all_embeddings: list[list[float]] = [] + + for i in range(0, len(input_data_list), self.batch_size): + batch = input_data_list[i : i + self.batch_size] + try: + result = await self.client.embed(model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=batch) + + # result may be dict with 'embeddings' list or single 'embedding' + if isinstance(result, dict) and 'embeddings' in result: + for emb in result['embeddings']: + if isinstance(emb, dict) and 'embedding' in emb and isinstance(emb['embedding'], list | tuple): + all_embeddings.append(list(emb['embedding'])) + elif isinstance(emb, list | tuple): + all_embeddings.append(list(emb)) + else: + # unexpected shape + raise ValueError('Unexpected embedding shape in batch result') + else: + # Fallback: maybe result itself is a list of vectors + if isinstance(result, list): + all_embeddings.extend(result) + else: + # Single embedding returned for whole batch; if so, duplicate per item + embedding = None + if isinstance(result, dict) and 'embedding' in result: + embedding = result['embedding'] + if embedding is None: + raise ValueError('No embeddings returned') + for _ in batch: + all_embeddings.append(embedding) + + except Exception as e: + logger.warning(f'Batch embedding failed for batch {i // self.batch_size + 1}, falling back to individual processing: {e}') + for item in batch: + try: + single = await self.client.embed(model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL, input=item) + emb = None + if hasattr(result, 'embeddings'): + _emb = result.embeddings + if isinstance(_emb, list) and len(_emb) > 0: + emb = _emb[0] if isinstance(_emb[0], list | tuple) else _emb # type: ignore + elif isinstance(single, dict) and 'embedding' in single: + emb = single['embedding'] + elif isinstance(single, dict) and 'embeddings' in single: + emb = single['embeddings'] + elif isinstance(single, list | tuple): + emb = single[0] if single else None # type: ignore + if not emb: + raise 
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index bebbdc7cd..05fa635e4 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -275,6 +275,8 @@ def _get_provider_type(self, client) -> str:
             return 'gemini'
         elif 'groq' in class_name:
             return 'groq'
+        elif 'ollama' in class_name:
+            return 'ollama'
         # Database providers
         elif 'neo4j' in class_name:
             return 'neo4j'
diff --git a/graphiti_core/llm_client/ollama_client.py b/graphiti_core/llm_client/ollama_client.py
new file mode 100644
index 000000000..68f9deebc
--- /dev/null
+++ b/graphiti_core/llm_client/ollama_client.py
@@ -0,0 +1,148 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import logging
+import typing
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ollama import AsyncClient
+else:
+    try:
+        from ollama import AsyncClient
+    except ImportError:
+        raise ImportError(
+            'ollama is required for OllamaClient. Install it with: pip install graphiti-core[ollama]'
+        ) from None
+
+from pydantic import BaseModel
+
+from ..prompts.models import Message
+from .client import LLMClient
+from .config import LLMConfig, ModelSize
+from .errors import RateLimitError
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'qwen3:4b'
+DEFAULT_MAX_TOKENS = 8192
+
+
+class OllamaClient(LLMClient):
+    """Ollama async client wrapper for Graphiti.
+
+    This client expects the `ollama` python package to be installed. It uses the
+    AsyncClient.chat(...) API to generate chat responses. The response content
+    is expected to be JSON, which is parsed and returned as a dict.
+    """
+
+    def __init__(self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any | None = None):
+        if config is None:
+            config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
+        elif config.max_tokens is None:
+            config.max_tokens = DEFAULT_MAX_TOKENS
+        super().__init__(config, cache)
+
+        # Allow injecting a preconfigured AsyncClient for testing
+        if client is None:
+            # AsyncClient accepts a host plus httpx kwargs; it has no api_key parameter
+            try:
+                host = config.base_url.removesuffix('/v1').rstrip('/') if config.base_url else None
+                self.client = AsyncClient(host=host)
+            except TypeError as e:
+                logger.warning(f'Error creating AsyncClient: {e}')
+                self.client = AsyncClient()
+        else:
+            self.client = client
+
+    async def _generate_response(
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+        model_size: ModelSize = ModelSize.medium,
+    ) -> dict[str, typing.Any]:
+        msgs: list[dict[str, str]] = []
+        for m in messages:
+            if m.role == 'user':
+                msgs.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                msgs.append({'role': 'system', 'content': m.content})
+
+        try:
+            # Prepare options; Ollama names its generation-length cap 'num_predict'
+            options: dict[str, typing.Any] = {}
+            if max_tokens is not None:
+                options['num_predict'] = max_tokens
+            if self.temperature is not None:
+                options['temperature'] = self.temperature
+
+            # If a response_model is provided, pass its JSON schema as the response format
+            schema = None
+            if response_model is not None:
+                try:
+                    schema = response_model.model_json_schema()
+                except Exception:
+                    schema = None
+
+            response = await self.client.chat(
+                model=self.model or DEFAULT_MODEL,
+                messages=msgs,
+                stream=False,
+                format=schema,
+                options=options,
+            )
+
+            # Extract the message content
+            content: str | None = None
+            if isinstance(response, dict) and 'message' in response and isinstance(response['message'], dict):
+                content = response['message'].get('content')
+            else:
+                # Some clients return objects with a .message attribute instead of dicts
+                msg = getattr(response, 'message', None)
+                if isinstance(msg, dict):
+                    content = msg.get('content')
+                elif msg is not None:
+                    content = getattr(msg, 'content', None)
+
+            if content is None:
+                # Fall back to the string representation
+                content = str(response)
+
+            # If a structured response was requested, validate it with the pydantic model
+            if response_model is not None:
+                try:
+                    validated = response_model.model_validate_json(content)
+                    # return the model as a dict
+                    return validated.model_dump()  # type: ignore[attr-defined]
+                except Exception as e:
+                    logger.error(f'Failed to validate response with response_model: {e}')
+                    # fall through and try plain JSON parsing
+
+            # Otherwise try to parse the content as JSON
+            try:
+                return json.loads(content)
+            except Exception:
+                return {'text': content}
+        except Exception as e:
+            # Map obvious ollama rate limit / response errors to RateLimitError when possible
+            err_name = e.__class__.__name__
+            status_code = getattr(e, 'status_code', None) or getattr(e, 'status', None)
+            if err_name in ('RequestError', 'ResponseError') and status_code == 429:
+                raise RateLimitError from e
+            logger.error(f'Error in generating LLM response (ollama): {e}')
+            raise
\ No newline at end of file
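Reviewer note: the structured-output path above rests on two pydantic calls — `model_json_schema()` produces the dict handed to Ollama's `format=` parameter, and `model_validate_json()` checks the returned text. A quick round-trip illustration; `ExtractedEntity` is just an example model, not something defined in this PR.

```python
from pydantic import BaseModel


class ExtractedEntity(BaseModel):
    name: str
    summary: str


# This is the dict OllamaClient passes as `format=` so the server constrains its JSON output.
schema = ExtractedEntity.model_json_schema()
print(sorted(schema['properties']))  # ['name', 'summary']

# And this is how the raw response content is turned back into a plain dict.
validated = ExtractedEntity.model_validate_json('{"name": "Luna", "summary": "a grey cat"}')
print(validated.model_dump())  # {'name': 'Luna', 'summary': 'a grey cat'}
```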
+ """ + + def __init__(self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any | None = None): + if config is None: + config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS) + elif config.max_tokens is None: + config.max_tokens = DEFAULT_MAX_TOKENS + super().__init__(config, cache) + + # Allow injecting a preconfigured AsyncClient for testing + if client is None: + # AsyncClient accepts host and other httpx args; pass api_key/base_url when available + try: + host = config.base_url.rstrip('/v1') if config.base_url else None + self.client = AsyncClient(host=host) + except TypeError as e: + logger.warning(f"Error creating AsyncClient: {e}") + self.client = AsyncClient() + else: + self.client = client + + async def _generate_response( + self, + messages: list[Message], + response_model: type[BaseModel] | None = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + model_size: ModelSize = ModelSize.medium, + ) -> dict[str, typing.Any]: + msgs: list[dict[str, str]] = [] + for m in messages: + if m.role == 'user': + msgs.append({'role': 'user', 'content': m.content}) + elif m.role == 'system': + msgs.append({'role': 'system', 'content': m.content}) + + try: + # Prepare options + options: dict[str, typing.Any] = {} + if max_tokens is not None: + options['max_tokens'] = max_tokens + if self.temperature is not None: + options['temperature'] = self.temperature + + # If a response_model is provided, try to get its JSON schema for format + schema = None + if response_model is not None: + try: + schema = response_model.model_json_schema() + except Exception: + schema = None + response = await self.client.chat( + model=self.model or DEFAULT_MODEL, + messages=msgs, + stream=False, + format=schema, + options=options, + ) + + # Extract content + content: str | None = None + if isinstance(response, dict) and 'message' in response and isinstance(response['message'], dict): + content = response['message'].get('content') + else: + # Some clients return objects with a .message attribute instead of dicts + msg = getattr(response, 'message', None) + + if isinstance(msg, dict): + content = msg.get('content') + elif msg is not None: + content = getattr(msg, 'content', None) + + if content is None: + # fallback to string + content = str(response) + + # If structured response requested, validate with pydantic model + if response_model is not None: + # Use pydantic v2 model validate json method + try: + validated = response_model.model_validate_json(content) + # return model as dict + return validated.model_dump() # type: ignore[attr-defined] + except Exception as e: + logger.error(f'Failed to validate response with response_model: {e}') + # fallthrough to try json loads + + # Try parse JSON otherwise + try: + return json.loads(content) + except Exception: + return {'text': content} + except Exception as e: + # map obvious ollama rate limit / response errors to RateLimitError when possible + err_name = e.__class__.__name__ + status_code = getattr(e, 'status_code', None) or getattr(e, 'status', None) + if err_name in ('RequestError', 'ResponseError') and status_code == 429: + raise RateLimitError from e + logger.error(f'Error in generating LLM response (ollama): {e}') + raise \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 07ca5ec00..6be109d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ voyageai = ["voyageai>=0.2.3"] neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"] sentence-transformers = ["sentence-transformers>=3.2.1"] neptune = 
["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"] +ollama = ["ollama>=0.5.3"] tracing = ["opentelemetry-api>=1.20.0", "opentelemetry-sdk>=1.20.0"] dev = [ "pyright>=1.1.404", @@ -56,6 +57,7 @@ dev = [ "sentence-transformers>=3.2.1", "transformers>=4.45.2", "voyageai>=0.2.3", + "ollama>=0.5.3", "pytest>=8.3.3", "pytest-asyncio>=0.24.0", "pytest-xdist>=3.6.1", diff --git a/tests/llm_client/test_ollama_client.py b/tests/llm_client/test_ollama_client.py new file mode 100644 index 000000000..6c6f4c4c8 --- /dev/null +++ b/tests/llm_client/test_ollama_client.py @@ -0,0 +1,95 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/llm_client/test_ollama_client.py + + +import pytest +from pydantic import BaseModel, Field + +from graphiti_core.llm_client.ollama_client import OllamaClient +from graphiti_core.prompts.models import Message + +# Skip tests if no Ollama API/key or local server available +# pytestmark = pytest.mark.skipif( +# 'OLLAMA_HOST' not in os.environ, +# reason='Ollama API/host not available', +# ) + + +# Rename to avoid pytest collection as a test class +class SimpleResponseModel(BaseModel): + message: str = Field(..., description='A message from the model') + + +class Pet(BaseModel): + name: str + animal: str + age: int + color: str | None + favorite_toy: str | None + + +class PetList(BaseModel): + pets: list[Pet] + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_generate_simple_response(): + client = OllamaClient() + + messages = [ + Message( + role='user', + content="Respond with a JSON object containing a 'message' field with value 'Hello, world!'", + ) + ] + + try: + response = await client.generate_response(messages, response_model=SimpleResponseModel) + assert isinstance(response, dict) + assert 'message' in response + assert response['message'] == 'Hello, world!' + except Exception as e: + pytest.skip(f'Test skipped due to Ollama API error: {str(e)}') + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_structured_output_with_pydantic(): + client = OllamaClient() + + messages = [ + Message( + role='user', + content=''' + I have two pets. + A cat named Luna who is 5 years old and loves playing with yarn. She has grey fur. + I also have a 2 year old black cat named Loki who loves tennis balls. + ''', + ) + ] + + try: + response = await client.generate_response(messages, response_model=PetList) + assert isinstance(response, dict) + assert 'pets' in response + assert isinstance(response['pets'], list) + assert len(response['pets']) >= 1 + except Exception as e: + pytest.skip(f'Test skipped due to Ollama API error: {str(e)}') +