Skip to content

Don't merge || Azure integration #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 249 additions & 0 deletions backend/apps/azure_openai/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse

import requests
import json
from pydantic import BaseModel


from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import AZURE_OPENAI_API_BASE_URL, AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, CACHE_DIR

import hashlib
from pathlib import Path

app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

app.state.AZURE_OPENAI_API_BASE_URL = AZURE_OPENAI_API_BASE_URL
app.state.AZURE_OPENAI_API_KEY = AZURE_OPENAI_API_KEY
app.state.AZURE_OPENAI_API_VERSION = AZURE_OPENAI_API_VERSION


class UrlUpdateForm(BaseModel):
url: str


class KeyUpdateForm(BaseModel):
key: str


class VersionUpdateForm(BaseModel):
version: str


@app.get("/url")
async def get_azure_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"AZURE_OPENAI_API_BASE_URL": app.state.AZURE_OPENAI_API_BASE_URL}
else:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/url/update")
async def update_azure_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
if user and user.role == "admin":
app.state.AZURE_OPENAI_API_BASE_URL = form_data.url
return {"AZURE_OPENAI_API_BASE_URL": app.state.AZURE_OPENAI_API_BASE_URL}
else:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.get("/key")
async def get_azure_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"AZURE_OPENAI_API_KEY": app.state.AZURE_OPENAI_API_KEY}
else:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/key/update")
async def update_azure_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
print(form_data)
if user and user.role == "admin":
app.state.AZURE_OPENAI_API_KEY = form_data.key
return {"AZURE_OPENAI_API_KEY": app.state.AZURE_OPENAI_API_KEY}
else:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.get("/version")
async def get_azure_openai_version(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"AZURE_OPENAI_API_VERSION": app.state.AZURE_OPENAI_API_VERSION}
else:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/version/update")
async def update_azure_openai_version(form_data: VersionUpdateForm, user=Depends(get_current_user)):
print(form_data)
if user and user.role == "admin":
app.state.AZURE_OPENAI_API_VERSION = form_data.version
return {"AZURE_OPENAI_API_VERSION": app.state.AZURE_OPENAI_API_VERSION}
else:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/audio/speech"

if user.role not in ["user", "admin"]:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# if app.state.AZURE_OPENAI_API_KEY == "":
# raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

body = await request.body()

name = hashlib.sha256(body).hexdigest()

SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)

headers = {}
headers["Authorization"] = f"Bearer {app.state.AZURE_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"

try:
print("openai")
r = requests.post(
url=target_url,
data=body,
headers=headers,
stream=True,
)

r.raise_for_status()

# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)

# Return the saved file
return FileResponse(file_path)

except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"

raise HTTPException(status_code=r.status_code, detail=error_detail)


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.AZURE_OPENAI_API_KEY)

if user.role not in ["user", "admin"]:
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# if app.state.AZURE_OPENAI_API_KEY == "":
# raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

body = await request.body()

# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
try:
body = body.decode("utf-8")
body = json.loads(body)

# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if body.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in body:
body["max_tokens"] = 4000
print("Modified body_dict:", body)

# Convert the modified body back to JSON
body = json.dumps(body)
except json.JSONDecodeError as e:
print("Error loading request body into a dictionary:", e)

headers = {}
headers["Content-Type"] = "application/json"

target_url = f"{target_url}?api-version={app.state.AZURE_OPENAI_API_VERSION}&api-key={app.state.AZURE_OPENAI_API_KEY}"

try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)

r.raise_for_status()

# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )

response_data = r.json()

if "/azure-openai" in app.state.AZURE_OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(
lambda model: "gpt" in model["id"], response_data["data"])
)

return response_data
except Exception as e:
print(e)
error_detail = "Ollama WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"

raise HTTPException(status_code=r.status_code, detail=error_detail)
23 changes: 15 additions & 8 deletions backend/apps/openai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ async def get_openai_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/url/update")
Expand All @@ -50,15 +51,17 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_u
app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.get("/key")
async def get_openai_key(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/key/update")
Expand All @@ -67,15 +70,17 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"

if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# if app.state.OPENAI_API_KEY == "":
# raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

Expand Down Expand Up @@ -138,7 +143,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
print(target_url, app.state.OPENAI_API_KEY)

if user.role not in ["user", "admin"]:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# if app.state.OPENAI_API_KEY == "":
# raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

Expand Down Expand Up @@ -195,9 +201,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):

response_data = r.json()

if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
if "/openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
filter(
lambda model: "gpt" in model["id"], response_data["data"])
)

return response_data
Expand Down
8 changes: 8 additions & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@
OPENAI_API_BASE_URL = "https://api.openai.com/v1"


####################################
# AZURE_OPENAI_API
####################################

AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
AZURE_OPENAI_API_BASE_URL = os.environ.get("AZURE_OPENAI_API_BASE_URL", "")
AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")

####################################
# WEBUI
####################################
Expand Down
2 changes: 2 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.azure_openai.main import app as azure_openai_app

from apps.web.main import app as webui_app
from apps.rag.main import app as rag_app
Expand Down Expand Up @@ -55,6 +56,7 @@ async def check_url(request: Request, call_next):

app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/azure-openai/api", azure_openai_app)
app.mount("/rag/api/v1", rag_app)


Expand Down
Loading