diff --git a/backend/apps/azure_openai/main.py b/backend/apps/azure_openai/main.py
new file mode 100644
index 00000000000..ff3922c2bd2
--- /dev/null
+++ b/backend/apps/azure_openai/main.py
@@ -0,0 +1,249 @@
+from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+
+import requests
+import json
+from pydantic import BaseModel
+
+
+from apps.web.models.users import Users
+from constants import ERROR_MESSAGES
+from utils.utils import decode_token, get_current_user
+from config import AZURE_OPENAI_API_BASE_URL, AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, CACHE_DIR
+
+import hashlib
+from pathlib import Path
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+app.state.AZURE_OPENAI_API_BASE_URL = AZURE_OPENAI_API_BASE_URL
+app.state.AZURE_OPENAI_API_KEY = AZURE_OPENAI_API_KEY
+app.state.AZURE_OPENAI_API_VERSION = AZURE_OPENAI_API_VERSION
+
+
+class UrlUpdateForm(BaseModel):
+    url: str
+
+
+class KeyUpdateForm(BaseModel):
+    key: str
+
+
+class VersionUpdateForm(BaseModel):
+    version: str
+
+
+@app.get("/url")
+async def get_azure_openai_url(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"AZURE_OPENAI_API_BASE_URL": app.state.AZURE_OPENAI_API_BASE_URL}
+    else:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.post("/url/update")
+async def update_azure_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        app.state.AZURE_OPENAI_API_BASE_URL = form_data.url
+        return {"AZURE_OPENAI_API_BASE_URL": app.state.AZURE_OPENAI_API_BASE_URL}
+    else:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.get("/key")
+async def get_azure_openai_key(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"AZURE_OPENAI_API_KEY": app.state.AZURE_OPENAI_API_KEY}
+    else:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.post("/key/update")
+async def update_azure_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
+    print(form_data)
+    if user and user.role == "admin":
+        app.state.AZURE_OPENAI_API_KEY = form_data.key
+        return {"AZURE_OPENAI_API_KEY": app.state.AZURE_OPENAI_API_KEY}
+    else:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.get("/version")
+async def get_azure_openai_version(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"AZURE_OPENAI_API_VERSION": app.state.AZURE_OPENAI_API_VERSION}
+    else:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.post("/version/update")
+async def update_azure_openai_version(form_data: VersionUpdateForm, user=Depends(get_current_user)):
+    print(form_data)
+    if user and user.role == "admin":
+        app.state.AZURE_OPENAI_API_VERSION = form_data.version
+        return {"AZURE_OPENAI_API_VERSION": app.state.AZURE_OPENAI_API_VERSION}
+    else:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.post("/audio/speech")
+async def speech(request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/audio/speech"
+
+    if user.role not in ["user", "admin"]:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    # if app.state.AZURE_OPENAI_API_KEY == "":
+    #     raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    body = await request.body()
+
+    name = hashlib.sha256(body).hexdigest()
+
+    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
+    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
+
+    # Check if the file already exists in the cache
+    if file_path.is_file():
+        return FileResponse(file_path)
+
+    headers = {}
+    headers["Authorization"] = f"Bearer {app.state.AZURE_OPENAI_API_KEY}"
+    headers["Content-Type"] = "application/json"
+
+    r = None  # defined up front so the error handler below can inspect it even if the request itself fails
+    try:
+        print("openai")
+        r = requests.post(
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Save the streaming content to a file
+        with open(file_path, "wb") as f:
+            for chunk in r.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        with open(file_body_path, "w") as f:
+            json.dump(json.loads(body.decode("utf-8")), f)
+
+        # Return the saved file
+        return FileResponse(file_path)
+
+    except Exception as e:
+        print(e)
+        error_detail = "Ollama WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(status_code=r.status_code if r else 500, detail=error_detail)
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/{path}"
+    print(target_url, app.state.AZURE_OPENAI_API_KEY)
+
+    if user.role not in ["user", "admin"]:
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    # if app.state.AZURE_OPENAI_API_KEY == "":
+    #     raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    body = await request.body()
+
+    # TODO: Remove below after gpt-4-vision fix from Open AI
+    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
+    try:
+        body = body.decode("utf-8")
+        body = json.loads(body)
+
+        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+        # This is a workaround until OpenAI fixes the issue with this model
+        if body.get("model") == "gpt-4-vision-preview":
+            if "max_tokens" not in body:
+                body["max_tokens"] = 4000
+                print("Modified body_dict:", body)
+
+        # Convert the modified body back to JSON
+        body = json.dumps(body)
+    except json.JSONDecodeError as e:
+        print("Error loading request body into a dictionary:", e)
+
+    headers = {}
+    headers["Content-Type"] = "application/json"
+
+    target_url = f"{target_url}?api-version={app.state.AZURE_OPENAI_API_VERSION}&api-key={app.state.AZURE_OPENAI_API_KEY}"
+
+    r = None  # defined up front so the error handler below can inspect it even if the request itself fails
+    try:
+        r = requests.request(
+            method=request.method,
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        else:
+            # For non-SSE, read the response and return it
+            # response_data = (
+            #     r.json()
+            #     if r.headers.get("Content-Type", "")
+            #     == "application/json"
+            #     else r.text
+            # )
+
+            response_data = r.json()
+
+            if "/azure-openai" in app.state.AZURE_OPENAI_API_BASE_URL and path == "models":
+                response_data["data"] = list(
+                    filter(
+                        lambda model: "gpt" in model["id"], response_data["data"])
+                )
+
+            return response_data
+    except Exception as e:
+        print(e)
+        error_detail = "Ollama WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(status_code=r.status_code if r else 500, detail=error_detail)
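The proxy route above forwards whatever path the client requests and appends the `api-version` and `api-key` query parameters itself, so browser code only ever talks to `/azure-openai/api/...` and the key stays server-side (apart from the admin-only `/key` endpoint). A minimal sketch of that URL construction follows; the helper name, the sample resource and deployment URL, and the `api-version` value are assumptions rather than part of the diff, and since the backend does no other rewriting, `AZURE_OPENAI_API_BASE_URL` presumably has to point at a deployment-scoped Azure endpoint already.

```ts
// Hypothetical helper mirroring the proxy's URL construction in backend/apps/azure_openai/main.py.
const buildAzureTargetUrl = (
	baseUrl: string, // app.state.AZURE_OPENAI_API_BASE_URL
	path: string, // the proxied path, e.g. 'chat/completions'
	apiVersion: string,
	apiKey: string
): string => `${baseUrl}/${path}?api-version=${apiVersion}&api-key=${apiKey}`;

// Example (all values illustrative): a request to /azure-openai/api/chat/completions
// would be forwarded to something like the URL printed below.
console.log(
	buildAzureTargetUrl(
		'https://my-resource.openai.azure.com/openai/deployments/my-gpt4-deployment',
		'chat/completions',
		'2023-07-01-preview',
		'<azure-api-key>'
	)
);
```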
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index 346d5de8944..31642089d80 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -41,7 +41,8 @@ async def get_openai_url(user=Depends(get_current_user)):
     if user and user.role == "admin":
         return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
     else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
 @app.post("/url/update")
@@ -50,7 +51,8 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_u
         app.state.OPENAI_API_BASE_URL = form_data.url
         return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
     else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
 @app.get("/key")
@@ -58,7 +60,8 @@ async def get_openai_key(user=Depends(get_current_user)):
     if user and user.role == "admin":
         return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
     else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
 @app.post("/key/update")
@@ -67,7 +70,8 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
         app.state.OPENAI_API_KEY = form_data.key
         return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
     else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
 @app.post("/audio/speech")
@@ -75,7 +79,8 @@ async def speech(request: Request, user=Depends(get_current_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
 
     if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
     # if app.state.OPENAI_API_KEY == "":
     #     raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 
@@ -138,7 +143,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     print(target_url, app.state.OPENAI_API_KEY)
 
     if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(
+            status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
     # if app.state.OPENAI_API_KEY == "":
     #     raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 
@@ -195,9 +201,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
 
             response_data = r.json()
 
-            if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
+            if "/openai" in app.state.OPENAI_API_BASE_URL and path == "models":
                 response_data["data"] = list(
-                    filter(lambda model: "gpt" in model["id"], response_data["data"])
+                    filter(
+                        lambda model: "gpt" in model["id"], response_data["data"])
                 )
 
             return response_data
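In `backend/apps/openai/main.py` the only behavioural change, besides wrapping long lines, is that the GPT-only filter applied to the `/models` response now keys off the substring `/openai` in the configured base URL instead of `openai`. A rough sketch of what that filter does is below; the sample model IDs are made up, and the actual list returned by the upstream API will differ.

```ts
// Illustrative reproduction of the backend's "keep only GPT models" filter.
const response = {
	data: [{ id: 'gpt-4' }, { id: 'gpt-3.5-turbo' }, { id: 'whisper-1' }, { id: 'dall-e-3' }]
};

const filtered = response.data.filter((model) => model.id.includes('gpt'));
console.log(filtered.map((m) => m.id)); // ['gpt-4', 'gpt-3.5-turbo']
```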
diff --git a/backend/config.py b/backend/config.py
index 65ee2298710..f901881fd51 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -66,6 +66,14 @@
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
 
+####################################
+# AZURE_OPENAI_API
+####################################
+
+AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
+AZURE_OPENAI_API_BASE_URL = os.environ.get("AZURE_OPENAI_API_BASE_URL", "")
+AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
+
 ####################################
 # WEBUI
 ####################################
diff --git a/backend/main.py b/backend/main.py
index f7a82b66394..867b38ba6c1 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -10,6 +10,7 @@
 
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
+from apps.azure_openai.main import app as azure_openai_app
 
 from apps.web.main import app as webui_app
 from apps.rag.main import app as rag_app
@@ -55,6 +56,7 @@ async def check_url(request: Request, call_next):
 
 app.mount("/ollama/api", ollama_app)
 app.mount("/openai/api", openai_app)
+app.mount("/azure-openai/api", azure_openai_app)
 
 app.mount("/rag/api/v1", rag_app)
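`backend/config.py` reads three new environment variables (`AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_API_BASE_URL`, `AZURE_OPENAI_API_VERSION`), all defaulting to empty strings, and `backend/main.py` mounts the new sub-app at `/azure-openai/api`. The frontend constant `AZURE_OPENAI_API_BASE_URL` imported from `$lib/constants` is not part of this diff; a plausible definition, assuming it mirrors the existing OpenAI constant and points at that mount path, might look like this sketch.

```ts
// Assumed addition to src/lib/constants.ts (not shown in this diff); the exact
// WEBUI_BASE_URL handling should follow whatever the existing constants file does.
export const WEBUI_BASE_URL = '';
export const AZURE_OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/azure-openai/api`;
```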
diff --git a/src/lib/apis/azureopenai/index.ts b/src/lib/apis/azureopenai/index.ts
new file mode 100644
index 00000000000..8f6cc6f41ec
--- /dev/null
+++ b/src/lib/apis/azureopenai/index.ts
@@ -0,0 +1,329 @@
+import { AZURE_OPENAI_API_BASE_URL } from '$lib/constants';
+
+export const getAzureOpenAIUrl = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/url`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.AZURE_OPENAI_API_BASE_URL;
+};
+
+export const updateAzureOpenAIUrl = async (token: string = '', url: string) => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/url/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			url: url
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.AZURE_OPENAI_API_BASE_URL;
+};
+
+export const getAzureOpenAIKey = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/key`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.AZURE_OPENAI_API_KEY;
+};
+
+export const updateAzureOpenAIKey = async (token: string = '', key: string) => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/key/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			key: key
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.AZURE_OPENAI_API_KEY;
+};
+
+export const getAzureOpenAIApiVersion = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/version`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.AZURE_OPENAI_API_VERSION;
+};
+
+export const updateAzureOpenAIApiVersion = async (token: string = '', version: string) => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/version/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			version: version
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.AZURE_OPENAI_API_VERSION;
+};
+
+export const getAzureOpenAIModels = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/models`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+			return [];
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	const models = Array.isArray(res) ? res : res?.data ?? null;
+
+	return models
+		? models
+				.map((model) => ({ name: model.id, external: true }))
+				.sort((a, b) => {
+					return a.name.localeCompare(b.name);
+				})
+		: models;
+};
+
+export const getAzureOpenAIModelsDirect = async (
+	base_url: string = 'https://api.openai.com/v1',
+	api_key: string = ''
+) => {
+	let error = null;
+
+	const res = await fetch(`${base_url}/models`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${api_key}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	const models = Array.isArray(res) ? res : res?.data ?? null;
+
+	return models
+		.map((model) => ({ name: model.id, external: true }))
+		.filter((model) => (base_url.includes('openai') ? model.name.includes('gpt') : true))
+		.sort((a, b) => {
+			return a.name.localeCompare(b.name);
+		});
+};
+
+export const generateAzureOpenAIChatCompletion = async (token: string = '', body: object) => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/chat/completions`, {
+		method: 'POST',
+		headers: {
+			Authorization: `Bearer ${token}`,
+			'Content-Type': 'application/json'
+		},
+		body: JSON.stringify(body)
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const synthesizeOpenAISpeech = async (
+	token: string = '',
+	speaker: string = 'alloy',
+	text: string = ''
+) => {
+	let error = null;
+
+	const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/audio/speech`, {
+		method: 'POST',
+		headers: {
+			Authorization: `Bearer ${token}`,
+			'Content-Type': 'application/json'
+		},
+		body: JSON.stringify({
+			model: 'tts-1',
+			input: text,
+			voice: speaker
+		})
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
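The new client wraps every backend endpoint in a small fetch helper. Below is a sketch of how a component might use it to list models and stream a chat completion; the model name and message payload are illustrative, and the SSE handling is reduced to printing raw chunks, since `generateAzureOpenAIChatCompletion` returns the raw `Response` and leaves stream parsing to the caller.

```ts
import {
	getAzureOpenAIModels,
	generateAzureOpenAIChatCompletion
} from '$lib/apis/azureopenai';

const example = async (token: string) => {
	// Lists models via the proxied /models endpoint; entries look like { name, external: true }.
	const models = await getAzureOpenAIModels(token);
	console.log(models);

	// Illustrative request body; the proxy forwards it unchanged apart from
	// the gpt-4-vision-preview max_tokens workaround shown earlier.
	const res = await generateAzureOpenAIChatCompletion(token, {
		model: 'gpt-35-turbo',
		stream: true,
		messages: [{ role: 'user', content: 'Hello from the Azure OpenAI proxy!' }]
	});

	if (res && res.ok && res.body) {
		const reader = res.body.getReader();
		const decoder = new TextDecoder();
		// Minimal stream consumption: print each SSE chunk as it arrives.
		while (true) {
			const { value, done } = await reader.read();
			if (done) break;
			console.log(decoder.decode(value));
		}
	}
};
```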
diff --git a/src/lib/components/chat/Settings/External.svelte b/src/lib/components/chat/Settings/External.svelte
index 455370eb1e9..2db092ef08d 100644
--- a/src/lib/components/chat/Settings/External.svelte
+++ b/src/lib/components/chat/Settings/External.svelte
@@ -1,10 +1,19 @@
@@ -29,6 +59,7 @@
 		class="flex flex-col h-full justify-between space-y-3 text-sm"
 		on:submit|preventDefault={() => {
 			updateOpenAIHandler();
+			updateAzureOpenAIHandler();
 			dispatch('save');
 
 			// saveSettings({
@@ -73,6 +104,64 @@
 						WebUI will make requests to '{OPENAI_API_BASE_URL}/chat'
 
+
+
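For completeness, here is a short sketch of using the speech helper from a browser context: `/azure-openai/api/audio/speech` caches the synthesized MP3 on the backend and returns the file, so the client can play it directly. The voice and input values are just examples, and this assumes the configured upstream actually serves an OpenAI-style `/audio/speech` endpoint.

```ts
import { synthesizeOpenAISpeech } from '$lib/apis/azureopenai';

const speak = async (token: string, text: string) => {
	const res = await synthesizeOpenAISpeech(token, 'alloy', text);

	if (res && res.ok) {
		// The backend returns the cached MP3 file; play it straight from a blob URL.
		const blob = await res.blob();
		const audio = new Audio(URL.createObjectURL(blob));
		await audio.play();
	}
};
```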