Add Ollama support for any hosted model, dynamically loaded at runtime #197

Open · wants to merge 2 commits into main
Changes from all commits
4 changes: 3 additions & 1 deletion .env.example
@@ -17,4 +17,6 @@ GOOGLE_API_KEY=your-google-api-key
 FINANCIAL_DATASETS_API_KEY=your-financial-datasets-api-key
 # For running LLMs hosted by openai (gpt-4o, gpt-4o-mini, etc.)
 # Get your OpenAI API key from https://platform.openai.com/
-OPENAI_API_KEY=your-openai-api-key
+OPENAI_API_KEY=your-openai-api-key
+
+OLLAMA_BASE_URL=http://localhost:11434
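Here, 11434 is Ollama's default local port. A minimal sketch of how a consumer of this variable might read it; the fallback default is an assumption for illustration, not something this diff adds:

import os

# Hypothetical reader for the new variable; the default mirrors Ollama's standard port
base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
print(base_url)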
216 changes: 181 additions & 35 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ colorama = "^0.4.6"
 questionary = "^2.1.0"
 rich = "^13.9.4"
 langchain-google-genai = "^2.0.11"
+langchain-ollama = "^0.3.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.0"
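After running `poetry install`, a quick smoke test confirms the new dependency resolves. The model tag below is hypothetical; constructing the client is local, but invoking it would need a running Ollama server:

# Smoke test for the new dependency; "llama3.1" is a hypothetical local model tag
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1")
print(type(llm).__name__)  # -> "ChatOllama"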
4 changes: 2 additions & 2 deletions src/backtester.py
@@ -10,7 +10,7 @@
 import numpy as np
 import itertools
 
-from llm.models import LLM_ORDER, get_model_info
+from llm.models import get_llm_order, get_model_info
 from utils.analysts import ANALYST_ORDER
 from main import run_hedge_fund
 from tools.api import (
@@ -720,7 +720,7 @@ def analyze_performance(self):
     # Select LLM model
     model_choice = questionary.select(
         "Select your LLM model:",
-        choices=[questionary.Choice(display, value=value) for display, value, _ in LLM_ORDER],
+        choices=[questionary.Choice(display, value=value) for display, value, _ in get_llm_order()],
         style=questionary.Style([
             ("selected", "fg:green bold"),
             ("pointer", "fg:green bold"),
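Each choice consumed here is the `(display, value, provider)` triple produced by `to_choice_tuple()`, as implied by the unpacking `display, value, _`. A small sketch; the concrete strings printed are illustrative only:

from llm.models import get_llm_order

# Prints one line per selectable model, including any live Ollama models
for display, value, _provider in get_llm_order():
    print(f"{display} -> {value}")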
40 changes: 34 additions & 6 deletions src/llm/models.py
@@ -1,9 +1,11 @@
 import os
+import requests
 from langchain_anthropic import ChatAnthropic
 from langchain_deepseek import ChatDeepSeek
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
 from langchain_openai import ChatOpenAI
+from langchain_ollama import ChatOllama
 from enum import Enum
 from pydantic import BaseModel
 from typing import Tuple
@@ -16,7 +18,7 @@ class ModelProvider(str, Enum):
     GEMINI = "Gemini"
     GROQ = "Groq"
     OPENAI = "OpenAI"
-
+    OLLAMA = "Ollama"
 
 
 class LLMModel(BaseModel):
@@ -106,14 +108,38 @@ def is_gemini(self) -> bool:
     ),
 ]
 
-# Create LLM_ORDER in the format expected by the UI
-LLM_ORDER = [model.to_choice_tuple() for model in AVAILABLE_MODELS]
+def get_available_ollama_models():
+    """Fetch available models from an Ollama server"""
+    try:
+        base_url = os.environ.get("OLLAMA_BASE_URL")
+        response = requests.get(f"{base_url}/api/tags")
+        if response.status_code == 200:
+            models = response.json().get("models", [])
+            models.sort(key=lambda x: x['name'].lower())
+            return [
+                LLMModel(
+                    display_name=f"[ollama] {model['name']}",
+                    model_name=model['name'],
+                    provider=ModelProvider.OLLAMA
+                )
+                for model in models
+            ]
+        else:
+            print(f"Failed to fetch Ollama models: {response.status_code} - {response.text}")
+            return []
+    except Exception as e:
+        # If no Ollama server is running, return an empty list
+        return []
+
+def get_llm_order():
+    """Get the available LLM models including Ollama models"""
+    return [model.to_choice_tuple() for model in AVAILABLE_MODELS + get_available_ollama_models()]
 
 def get_model_info(model_name: str) -> LLMModel | None:
     """Get model information by model_name"""
-    return next((model for model in AVAILABLE_MODELS if model.model_name == model_name), None)
+    return next((model for model in AVAILABLE_MODELS + get_available_ollama_models() if model.model_name == model_name), None)
 
-def get_model(model_name: str, model_provider: ModelProvider) -> ChatOpenAI | ChatGroq | None:
+def get_model(model_name: str, model_provider: ModelProvider) -> ChatOpenAI | ChatGroq | ChatAnthropic | ChatDeepSeek | ChatGoogleGenerativeAI | ChatOllama | None:
     if model_provider == ModelProvider.GROQ:
         api_key = os.getenv("GROQ_API_KEY")
         if not api_key:
@@ -146,4 +172,6 @@
         if not api_key:
             print(f"API Key Error: Please make sure GOOGLE_API_KEY is set in your .env file.")
             raise ValueError("Google API key not found. Please make sure GOOGLE_API_KEY is set in your .env file.")
-        return ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
+        return ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
+    elif model_provider == ModelProvider.OLLAMA:
+        return ChatOllama(model=model_name)
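For context, `get_available_ollama_models` parses Ollama's `/api/tags` endpoint, which returns a JSON object whose `models` array holds one entry per locally pulled model; only each entry's `name` field is used. A usage sketch under stated assumptions: the model tag is hypothetical and a local Ollama server must be running. Note also that `ChatOllama` defaults to `http://localhost:11434`, so a non-default `OLLAMA_BASE_URL` would additionally have to be forwarded via its `base_url` parameter; the diff as written does not do that.

# Sketch only; assumes Ollama serves a pulled model tagged "llama3.1:8b" (hypothetical).
# /api/tags responds with e.g. {"models": [{"name": "llama3.1:8b", ...}]}, and the
# loader above turns each entry into an LLMModel displayed as "[ollama] <name>".
from llm.models import get_model_info, get_model

info = get_model_info("llama3.1:8b")   # resolved through the live /api/tags listing
if info is not None:
    llm = get_model(info.model_name, info.provider)
    print(llm.invoke("Reply with one word.").content)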
4 changes: 2 additions & 2 deletions src/main.py
@@ -18,7 +18,7 @@
 from utils.display import print_trading_output
 from utils.analysts import ANALYST_ORDER, get_analyst_nodes
 from utils.progress import progress
-from llm.models import LLM_ORDER, get_model_info
+from llm.models import get_llm_order, get_model_info
 
 import argparse
 from datetime import datetime
@@ -198,7 +198,7 @@ def create_workflow(selected_analysts=None):
     # Select LLM model
     model_choice = questionary.select(
         "Select your LLM model:",
-        choices=[questionary.Choice(display, value=value) for display, value, _ in LLM_ORDER],
+        choices=[questionary.Choice(display, value=value) for display, value, _ in get_llm_order()],
         style=questionary.Style([
             ("selected", "fg:green bold"),
             ("pointer", "fg:green bold"),