Skip to content

Refactored commited files to follow pylint standards #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/data/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_fie
"""Merge existing and new data, avoiding duplicates based on a key field."""
if not existing:
return new_data

# Create a set of existing keys for O(1) lookup
existing_keys = {item[key_field] for item in existing}

# Only add items that don't exist yet
merged = existing.copy()
merged.extend([item for item in new_data if item[key_field] not in existing_keys])
Expand Down
21 changes: 12 additions & 9 deletions src/llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ class LLMModel(BaseModel):
def to_choice_tuple(self) -> Tuple[str, str, str]:
"""Convert to format needed for questionary choices"""
return (self.display_name, self.model_name, self.provider.value)

def has_json_mode(self) -> bool:
"""Check if the model supports JSON mode"""
return not self.is_deepseek() and not self.is_gemini()

def is_deepseek(self) -> bool:
"""Check if the model is a DeepSeek model"""
return self.model_name.startswith("deepseek")

def is_gemini(self) -> bool:
"""Check if the model is a Gemini model"""
return self.model_name.startswith("gemini")
Expand Down Expand Up @@ -118,32 +118,35 @@ def get_model(model_name: str, model_provider: ModelProvider) -> ChatOpenAI | Ch
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
# Print error to console
print(f"API Key Error: Please make sure GROQ_API_KEY is set in your .env file.")
print("API Key Error: Please make sure GROQ_API_KEY is set in your .env file.")
raise ValueError("Groq API key not found. Please make sure GROQ_API_KEY is set in your .env file.")
return ChatGroq(model=model_name, api_key=api_key)
elif model_provider == ModelProvider.OPENAI:
# Get and validate API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
# Print error to console
print(f"API Key Error: Please make sure OPENAI_API_KEY is set in your .env file.")
print("API Key Error: Please make sure OPENAI_API_KEY is set in your .env file.")
raise ValueError("OpenAI API key not found. Please make sure OPENAI_API_KEY is set in your .env file.")
return ChatOpenAI(model=model_name, api_key=api_key)
elif model_provider == ModelProvider.ANTHROPIC:
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
print(f"API Key Error: Please make sure ANTHROPIC_API_KEY is set in your .env file.")
print("API Key Error: Please make sure ANTHROPIC_API_KEY is set in your .env file.")
raise ValueError("Anthropic API key not found. Please make sure ANTHROPIC_API_KEY is set in your .env file.")
return ChatAnthropic(model=model_name, api_key=api_key)
elif model_provider == ModelProvider.DEEPSEEK:
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
print(f"API Key Error: Please make sure DEEPSEEK_API_KEY is set in your .env file.")
print("API Key Error: Please make sure DEEPSEEK_API_KEY is set in your .env file.")
raise ValueError("DeepSeek API key not found. Please make sure DEEPSEEK_API_KEY is set in your .env file.")
return ChatDeepSeek(model=model_name, api_key=api_key)
elif model_provider == ModelProvider.GEMINI:
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
print(f"API Key Error: Please make sure GOOGLE_API_KEY is set in your .env file.")
print("API Key Error: Please make sure GOOGLE_API_KEY is set in your .env file.")
raise ValueError("Google API key not found. Please make sure GOOGLE_API_KEY is set in your .env file.")
return ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
return ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
else:
print("Error: Model provider not found.")
return None
49 changes: 34 additions & 15 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import sys
import json

import argparse
from datetime import datetime
from dateutil.relativedelta import relativedelta
from tabulate import tabulate
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from langgraph.graph import END, StateGraph
from colorama import Fore, Back, Style, init
import questionary

from agents.ben_graham import ben_graham_agent
from agents.bill_ackman import bill_ackman_agent
from agents.fundamentals import fundamentals_agent
Expand All @@ -13,19 +19,16 @@
from agents.risk_manager import risk_management_agent
from agents.sentiment import sentiment_agent
from agents.warren_buffett import warren_buffett_agent
from graph.state import AgentState
from agents.valuation import valuation_agent

from graph.state import AgentState

from utils.display import print_trading_output
from utils.analysts import ANALYST_ORDER, get_analyst_nodes
from utils.progress import progress
from llm.models import LLM_ORDER, get_model_info

import argparse
from datetime import datetime
from dateutil.relativedelta import relativedelta
from tabulate import tabulate
from utils.visualize import save_graph_as_png
import json

from llm.models import LLM_ORDER, get_model_info

# Load environment variables from .env file
load_dotenv()
Expand Down Expand Up @@ -60,6 +63,7 @@ def run_hedge_fund(
model_name: str = "gpt-4o",
model_provider: str = "OpenAI",
):
"""Run the hedge fund trading system."""
# Start progress tracking
progress.start()

Expand Down Expand Up @@ -107,7 +111,7 @@ def start(state: AgentState):
return state


def create_workflow(selected_analysts=None):
def create_workflow(selected_analysts = None):
"""Create the workflow with selected analysts."""
workflow = StateGraph(AgentState)
workflow.add_node("start_node", start)
Expand Down Expand Up @@ -154,16 +158,31 @@ def create_workflow(selected_analysts=None):
default=0.0,
help="Initial margin requirement. Defaults to 0.0"
)
parser.add_argument("--tickers", type=str, required=True, help="Comma-separated list of stock ticker symbols")
parser.add_argument(
"--tickers",
type=str,
required=True,
help="Comma-separated list of stock ticker symbols"
)
parser.add_argument(
"--start-date",
type=str,
help="Start date (YYYY-MM-DD). Defaults to 3 months before end date",
)
parser.add_argument("--end-date", type=str, help="End date (YYYY-MM-DD). Defaults to today")
parser.add_argument("--show-reasoning", action="store_true", help="Show reasoning from each agent")
parser.add_argument(
"--show-agent-graph", action="store_true", help="Show the agent graph"
"--end-date",
type=str,
help="End date (YYYY-MM-DD). Defaults to today"
)
parser.add_argument(
"--show-reasoning",
action="store_true",
help="Show reasoning from each agent"
)
parser.add_argument(
"--show-agent-graph",
action="store_true",
help="Show the agent graph"
)

args = parser.parse_args()
Expand Down Expand Up @@ -237,13 +256,13 @@ def create_workflow(selected_analysts=None):
try:
datetime.strptime(args.start_date, "%Y-%m-%d")
except ValueError:
raise ValueError("Start date must be in YYYY-MM-DD format")
raise ValueError("Start date must be in YYYY-MM-DD format") from ValueError

if args.end_date:
try:
datetime.strptime(args.end_date, "%Y-%m-%d")
except ValueError:
raise ValueError("End date must be in YYYY-MM-DD format")
raise ValueError("End date must be in YYYY-MM-DD format") from ValueError

# Set the start and end dates
end_date = args.end_date or datetime.now().strftime("%Y-%m-%d")
Expand Down
17 changes: 8 additions & 9 deletions src/tools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,33 +211,33 @@ def get_company_news(

all_news = []
current_end_date = end_date

while True:
url = f"https://api.financialdatasets.ai/news/?ticker={ticker}&end_date={current_end_date}"
if start_date:
url += f"&start_date={start_date}"
url += f"&limit={limit}"

response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")

data = response.json()
response_model = CompanyNewsResponse(**data)
company_news = response_model.news

if not company_news:
break

all_news.extend(company_news)

# Only continue pagination if we have a start_date and got a full page
if not start_date or len(company_news) < limit:
break

# Update end_date to the oldest date from current batch for next iteration
current_end_date = min(news.date for news in company_news).split('T')[0]

# If we've reached or passed the start_date, we can stop
if current_end_date <= start_date:
break
Expand All @@ -250,7 +250,6 @@ def get_company_news(
return all_news



def get_market_cap(
ticker: str,
end_date: str,
Expand Down