Skip to content

fix issue #203 #206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 34 additions & 47 deletions src/backtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,17 @@ def __init__(
"margin_requirement": initial_margin_requirement, # The margin ratio required for shorts
"positions": {
ticker: {
"long": 0, # Number of shares held long
"short": 0, # Number of shares held short
"long": 0, # Number of shares held long
"short": 0, # Number of shares held short
"long_cost_basis": 0.0, # Average cost basis per share (long)
"short_cost_basis": 0.0, # Average cost basis per share (short)
"short_margin_used": 0.0 # Dollars of margin used for this ticker's short
"short_cost_basis": 0.0, # Average cost basis per share (short)
"short_margin_used": 0.0, # Dollars of margin used for this ticker's short
"short_proceeds": 0.0 # Added to track short sale proceeds
} for ticker in tickers
},
"realized_gains": {
ticker: {
"long": 0.0, # Realized gains from long positions
"long": 0.0, # Realized gains from long positions
"short": 0.0, # Realized gains from short positions
} for ticker in tickers
}
Expand All @@ -85,13 +86,14 @@ def __init__(
def execute_trade(self, ticker: str, action: str, quantity: float, current_price: float):
"""
Execute trades with support for both long and short positions.
`quantity` is the number of shares the agent wants to buy/sell/short/cover.
We will only trade integer shares to keep it simple.
For short sales, we immediately update the cash balance by adding the sale proceeds
and subtracting the required margin. When covering, we subtract the cover cost,
add back the released margin, and realize profit or loss.
"""
if quantity <= 0:
return 0

quantity = int(quantity) # force integer shares
quantity = int(quantity) # Force integer shares
position = self.portfolio["positions"][ticker]

if action == "buy":
Expand All @@ -100,14 +102,11 @@ def execute_trade(self, ticker: str, action: str, quantity: float, current_price
# Weighted average cost basis for the new total
old_shares = position["long"]
old_cost_basis = position["long_cost_basis"]
new_shares = quantity
total_shares = old_shares + new_shares

total_shares = old_shares + quantity
if total_shares > 0:
total_old_cost = old_cost_basis * old_shares
total_new_cost = cost
position["long_cost_basis"] = (total_old_cost + total_new_cost) / total_shares

position["long"] += quantity
self.portfolio["cash"] -= cost
return quantity
Expand All @@ -119,12 +118,10 @@ def execute_trade(self, ticker: str, action: str, quantity: float, current_price
old_shares = position["long"]
old_cost_basis = position["long_cost_basis"]
total_shares = old_shares + max_quantity

if total_shares > 0:
total_old_cost = old_cost_basis * old_shares
total_new_cost = cost
position["long_cost_basis"] = (total_old_cost + total_new_cost) / total_shares

position["long"] += max_quantity
self.portfolio["cash"] -= cost
return max_quantity
Expand All @@ -150,31 +147,27 @@ def execute_trade(self, ticker: str, action: str, quantity: float, current_price
elif action == "short":
"""
Typical short sale flow:
1) Receive proceeds = current_price * quantity
2) Post margin_required = proceeds * margin_ratio
3) Net effect on cash = +proceeds - margin_required
1) Receive proceeds = current_price * quantity
2) Post margin_required = proceeds * margin_ratio
3) Net effect on cash = +proceeds - margin_required
"""
proceeds = current_price * quantity
margin_required = proceeds * self.portfolio["margin_requirement"]
if margin_required <= self.portfolio["cash"]:
# Weighted average short cost basis
old_short_shares = position["short"]
old_shares = position["short"]
old_cost_basis = position["short_cost_basis"]
new_shares = quantity
total_shares = old_short_shares + new_shares

total_shares = old_shares + quantity
if total_shares > 0:
total_old_cost = old_cost_basis * old_short_shares
total_new_cost = current_price * new_shares
total_old_cost = old_cost_basis * old_shares
total_new_cost = current_price * quantity
position["short_cost_basis"] = (total_old_cost + total_new_cost) / total_shares

position["short"] += quantity

# Update margin usage
position["short_margin_used"] += margin_required
self.portfolio["margin_used"] += margin_required

# Increase cash by proceeds, then subtract the required margin
# Update cash: add sale proceeds then reserve the required margin
self.portfolio["cash"] += proceeds
self.portfolio["cash"] -= margin_required
return quantity
Expand All @@ -189,20 +182,16 @@ def execute_trade(self, ticker: str, action: str, quantity: float, current_price
if max_quantity > 0:
proceeds = current_price * max_quantity
margin_required = proceeds * margin_ratio

old_short_shares = position["short"]
old_shares = position["short"]
old_cost_basis = position["short_cost_basis"]
total_shares = old_short_shares + max_quantity

total_shares = old_shares + max_quantity
if total_shares > 0:
total_old_cost = old_cost_basis * old_short_shares
total_old_cost = old_cost_basis * old_shares
total_new_cost = current_price * max_quantity
position["short_cost_basis"] = (total_old_cost + total_new_cost) / total_shares

position["short"] += max_quantity
position["short_margin_used"] += margin_required
self.portfolio["margin_used"] += margin_required

self.portfolio["cash"] += proceeds
self.portfolio["cash"] -= margin_required
return max_quantity
Expand All @@ -211,22 +200,19 @@ def execute_trade(self, ticker: str, action: str, quantity: float, current_price
elif action == "cover":
"""
When covering shares:
1) Pay cover cost = current_price * quantity
2) Release a proportional share of the margin
3) Net effect on cash = -cover_cost + released_margin
1) Pay cover cost = current_price * quantity
2) Release a proportional share of the margin
3) Net effect on cash = -cover_cost + released_margin
"""
quantity = min(quantity, position["short"])
if quantity > 0:
cover_cost = quantity * current_price
avg_short_price = position["short_cost_basis"] if position["short"] > 0 else 0
realized_gain = (avg_short_price - current_price) * quantity

if position["short"] > 0:
portion = quantity / position["short"]
else:
portion = 1.0

margin_to_release = portion * position["short_margin_used"]
# Determine the proportion of the short position being covered
proportion = quantity / position["short"] if position["short"] > 0 else 1.0
margin_to_release = proportion * position["short_margin_used"]

position["short"] -= quantity
position["short_margin_used"] -= margin_to_release
Expand All @@ -249,9 +235,9 @@ def execute_trade(self, ticker: str, action: str, quantity: float, current_price
def calculate_portfolio_value(self, current_prices):
"""
Calculate total portfolio value, including:
- cash
- market value of long positions
- unrealized gains/losses for short positions
- cash
- market value of long positions
- unrealized gains/losses for short positions
"""
total_value = self.portfolio["cash"]

Expand All @@ -265,10 +251,11 @@ def calculate_portfolio_value(self, current_prices):

# Short position unrealized PnL = short_shares * (short_cost_basis - current_price)
if position["short"] > 0:
total_value += position["short"] * (position["short_cost_basis"] - price)
short_value = - (position["short"] * price)
total_value += short_value

return total_value

def prefetch_data(self):
"""Pre-fetch all data needed for the backtest period."""
print("\nPre-fetching data for the entire backtest period...")
Expand Down