Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions servers/fastapi/api/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from api.lifespan import app_lifespan
from api.middlewares import UserConfigEnvUpdateMiddleware
from api.v1.ppt.router import API_V1_PPT_ROUTER
Expand All @@ -24,3 +25,6 @@
)

app.add_middleware(UserConfigEnvUpdateMiddleware)

# Mount static files directory
app.mount("/static", StaticFiles(directory="/app/servers/fastapi/static"), name="static")
95 changes: 93 additions & 2 deletions servers/fastapi/api/v1/organisations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form
from sqlmodel import Session, select
from schemas.organisations import (
OrganisationCreate,
Expand All @@ -9,7 +9,9 @@
LoginRequest,
OrganisationOnboardingRequest,
OrganisationCreateResponse,
SimpleOrganisationOnboardRequest
SimpleOrganisationOnboardRequest,
OrganisationUpdate,
OrganisationUpdateResponse
)
from models.sql.organisation import Organisation
from models.sql.user import User
Expand All @@ -20,6 +22,7 @@
verify_password,
get_current_user
)
from utils.file_upload import save_logo

router = APIRouter()

Expand Down Expand Up @@ -170,3 +173,91 @@ async def get_current_user_info(
):
"""Get current user information"""
return current_user

@router.get("/{org_id}", response_model=OrganisationResponse)
async def get_organisation(
org_id: str,
session: Session = Depends(get_async_session),
current_user: User = Depends(get_current_user)
):
"""
Get organization details by ID.

This endpoint:
1. Retrieves the organization details by ID
2. Returns the organization details
"""
# Check if the user belongs to the organization
if current_user.organisation_id != org_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to access this organization"
)

# Get the organization
org = await session.get(Organisation, org_id)
if not org:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Organization not found"
)

return org

@router.put("/update", response_model=OrganisationUpdateResponse)
async def update_organisation(
name: str = Form(None),
logo: UploadFile = File(None),
session: Session = Depends(get_async_session),
current_user: User = Depends(get_current_user)
):
"""
Update the current user's organization details.

This endpoint:
1. Updates the organization name and/or logo
2. Returns the updated organization details
"""
# Check if the user is an admin
if not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to update this organization"
)

# Get the organization
org_id = current_user.organisation_id
org = await session.get(Organisation, org_id)
if not org:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Organization not found"
)

# Update the organization name if provided
if name:
org.name = name

# Update the organization logo if provided
if logo:
logo_path = await save_logo(logo)
org.logo_url = logo_path

# Save the changes
session.add(org)
await session.commit()
await session.refresh(org)

# Convert Organisation to OrganisationResponse
org_response = OrganisationResponse(
id=org.id,
name=org.name,
logo_url=org.logo_url,
created_at=org.created_at
)

return OrganisationUpdateResponse(
success=True,
message=f"Organisation '{org.name}' updated successfully",
organisation=org_response
)
23 changes: 15 additions & 8 deletions servers/fastapi/dependencies/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ async def get_current_user_id(authorization: Optional[str] = Header(None)) -> st
This function validates the JWT token and extracts the user_id from the token payload.
"""
if not authorization:
# For development, we'll return a default user_id if no token is provided
# In production, this should raise an HTTPException for unauthorized access
return "default_user"
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authorization header is missing",
headers={"WWW-Authenticate": "Bearer"},
)

try:
# Remove 'Bearer ' prefix if present
Expand All @@ -32,8 +34,13 @@ async def get_current_user_id(authorization: Optional[str] = Header(None)) -> st
)

return user_id
except JWTError:
# For backward compatibility during development, if token validation fails,
# fall back to using the token itself as the user_id
# In production, this should be removed and only valid JWT tokens should be accepted
return authorization.replace("Bearer ", "")
except JWTError as e:
# Log the error for debugging
print(f"JWT Error: {str(e)}")

# Raise an exception for invalid tokens
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication token",
headers={"WWW-Authenticate": "Bearer"}
)
11 changes: 11 additions & 0 deletions servers/fastapi/schemas/organisations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel, EmailStr
from datetime import datetime
from typing import Optional
from fastapi import UploadFile, File

# Organisation Schemas
class OrganisationCreate(BaseModel):
Expand Down Expand Up @@ -55,3 +56,13 @@ class OrganisationCreateResponse(BaseModel):
success: bool
message: str
organisation_id: str

# Organisation Update Schema
class OrganisationUpdate(BaseModel):
name: Optional[str] = None
logo_url: Optional[str] = None

class OrganisationUpdateResponse(BaseModel):
success: bool
message: str
organisation: OrganisationResponse
4 changes: 2 additions & 2 deletions servers/fastapi/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import sys

# When running inside Docker, we need to use the service name instead of localhost
# BASE_URL = "http://localhost:3001/api/v1" # For running on host
BASE_URL = "http://localhost:80/api/v1" # For running inside Docker container
BASE_URL = "http://localhost:3001/api/v1" # For running on host
# BASE_URL = "http://localhost:80/api/v1" # For running inside Docker container
TEST_ORG_NAME = "Test Organization"
TEST_ADMIN_NAME = "Test Admin"
TEST_ADMIN_EMAIL = "[email protected]"
Expand Down
134 changes: 134 additions & 0 deletions servers/fastapi/test_update_org.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Test script to verify the organization update API endpoint.
This script tests the following:
1. Logging in to get a JWT token
2. Getting organization details
3. Updating organization details with a new name and logo
"""
import asyncio
import json
import requests
import sys
import os
from pathlib import Path

# When running inside Docker, we need to use the service name instead of localhost
# BASE_URL = "http://localhost:3001/api/v1" # For running on host
BASE_URL = "http://development:80/api/v1" # For running inside Docker container
TEST_ADMIN_EMAIL = "[email protected]"
TEST_ADMIN_PASSWORD = "securepassword123"
TEST_LOGO_PATH = Path(__file__).parent / "test_logo.png"

def print_header(message):
print("\n" + "=" * 50)
print(message)
print("=" * 50)

def create_test_logo():
"""Create a simple test logo file if it doesn't exist"""
if not TEST_LOGO_PATH.exists():
# Create a simple 100x100 black square as a test logo
try:
from PIL import Image
img = Image.new('RGB', (100, 100), color = (0, 0, 0))
img.save(TEST_LOGO_PATH)
print(f"Created test logo at {TEST_LOGO_PATH}")
except ImportError:
print("PIL not installed, cannot create test logo")
print("Please create a test logo manually at", TEST_LOGO_PATH)
sys.exit(1)

def test_login():
print_header("Testing: Login with Admin User")

url = f"{BASE_URL}/organisations/login"
headers = {"Content-Type": "application/json"}
data = {
"email": TEST_ADMIN_EMAIL,
"password": TEST_ADMIN_PASSWORD
}

response = requests.post(url, headers=headers, json=data)

if response.status_code == 200:
result = response.json()
print(f"✅ Successfully logged in as {result['user']['full_name']}")
print(f"✅ JWT Token: {result['access_token'][:20]}...")
return result['access_token'], result['user']['organisation_id']
else:
print(f"❌ Failed to login: {response.status_code}")
print(response.text)
return None, None

def test_get_organisation(token, org_id):
print_header("Testing: Get Organization Details")

url = f"{BASE_URL}/organisations/{org_id}"
headers = {"Authorization": f"Bearer {token}"}

response = requests.get(url, headers=headers)

if response.status_code == 200:
result = response.json()
print(f"✅ Successfully retrieved organization: {result['name']}")
print(f"✅ Organization ID: {result['id']}")
print(f"✅ Current Logo URL: {result['logo_url']}")
return True
else:
print(f"❌ Failed to get organization: {response.status_code}")
print(response.text)
return False

def test_update_organisation(token, org_id):
print_header("Testing: Update Organization Details")

url = f"{BASE_URL}/organisations/update"
headers = {"Authorization": f"Bearer {token}"}

# Prepare multipart form data
files = {
'logo': ('logo.png', open(TEST_LOGO_PATH, 'rb'), 'image/png')
}
data = {
'name': f"Updated Test Organization {os.urandom(4).hex()}"
}

response = requests.put(url, headers=headers, files=files, data=data)

if response.status_code == 200:
result = response.json()
print(f"✅ Successfully updated organization: {result['organisation']['name']}")
print(f"✅ New Logo URL: {result['organisation']['logo_url']}")
return True
else:
print(f"❌ Failed to update organization: {response.status_code}")
print(response.text)
return False

def main():
print("Starting organization update API tests...")

# Create test logo if it doesn't exist
create_test_logo()

# Test logging in to get a JWT token
token, org_id = test_login()
if not token or not org_id:
print("Cannot continue tests without a JWT token and organization ID")
return

# Test getting organization details
if not test_get_organisation(token, org_id):
print("Cannot continue tests without organization details")
return

# Test updating organization details
test_update_organisation(token, org_id)

# Test getting updated organization details
test_get_organisation(token, org_id)

print("\nTests completed!")

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion servers/fastapi/utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# JWT Configuration
SECRET_KEY = "your-secret-key-keep-it-secret" # Change this in production!
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
ACCESS_TOKEN_EXPIRE_MINUTES = 43200 # 30 days

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
Expand Down
50 changes: 50 additions & 0 deletions servers/fastapi/utils/file_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import shutil
import uuid
from fastapi import UploadFile
from pathlib import Path

# Define the base directory for file uploads
UPLOAD_DIR = Path("/app/servers/fastapi/static/uploads")
LOGO_DIR = UPLOAD_DIR / "logos"

# Ensure directories exist
os.makedirs(LOGO_DIR, exist_ok=True)

async def save_upload_file(upload_file: UploadFile, directory: Path = UPLOAD_DIR) -> str:
"""
Save an uploaded file to the specified directory and return the file path.

Args:
upload_file: The uploaded file
directory: The directory to save the file to (default: UPLOAD_DIR)

Returns:
The relative path to the saved file
"""
# Generate a unique filename to avoid collisions
file_extension = os.path.splitext(upload_file.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"

# Create the full file path
file_path = directory / unique_filename

# Save the file
with open(file_path, "wb") as buffer:
shutil.copyfileobj(upload_file.file, buffer)

# Return the relative path from the static directory
relative_path = str(file_path).replace("/app/servers/fastapi/static", "")
return relative_path

async def save_logo(logo_file: UploadFile) -> str:
"""
Save an organization logo and return the file path.

Args:
logo_file: The uploaded logo file

Returns:
The relative path to the saved logo
"""
return await save_upload_file(logo_file, LOGO_DIR)