From 5d35656a3561931fc00f812a257774315872960f Mon Sep 17 00:00:00 2001 From: Keerthivasan D Date: Sat, 10 Feb 2024 13:36:33 +0530 Subject: [PATCH 1/2] Keerthi: Add ui options to integrate with Azure OpenAI --- backend/apps/azure_openai/main.py | 223 +++++++++++++++ backend/apps/openai/main.py | 23 +- backend/config.py | 10 + backend/main.py | 2 + src/lib/apis/azureopenai/index.ts | 262 ++++++++++++++++++ .../components/chat/Settings/External.svelte | 62 +++++ src/lib/constants.ts | 1 + 7 files changed, 575 insertions(+), 8 deletions(-) create mode 100644 backend/apps/azure_openai/main.py create mode 100644 src/lib/apis/azureopenai/index.ts diff --git a/backend/apps/azure_openai/main.py b/backend/apps/azure_openai/main.py new file mode 100644 index 00000000000..8b6f36f757d --- /dev/null +++ b/backend/apps/azure_openai/main.py @@ -0,0 +1,223 @@ +from fastapi import FastAPI, Request, Response, HTTPException, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse, FileResponse + +import requests +import json +from pydantic import BaseModel + + +from apps.web.models.users import Users +from constants import ERROR_MESSAGES +from utils.utils import decode_token, get_current_user +from config import AZURE_OPENAI_API_BASE_URL, AZURE_OPENAI_API_KEY, CACHE_DIR + +import hashlib +from pathlib import Path + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.state.AZURE_OPENAI_API_BASE_URL = AZURE_OPENAI_API_BASE_URL +app.state.AZURE_OPENAI_API_KEY = AZURE_OPENAI_API_KEY + + +class UrlUpdateForm(BaseModel): + url: str + + +class KeyUpdateForm(BaseModel): + key: str + + +@app.get("/url") +async def get_azure_openai_url(user=Depends(get_current_user)): + if user and user.role == "admin": + return {"AZURE_OPENAI_API_BASE_URL": app.state.AZURE_OPENAI_API_BASE_URL} + else: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.post("/url/update") +async def update_azure_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)): + if user and user.role == "admin": + app.state.AZURE_OPENAI_API_BASE_URL = form_data.url + return {"AZURE_OPENAI_API_BASE_URL": app.state.AZURE_OPENAI_API_BASE_URL} + else: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.get("/key") +async def get_azure_openai_key(user=Depends(get_current_user)): + if user and user.role == "admin": + return {"AZURE_OPENAI_API_KEY": app.state.AZURE_OPENAI_API_KEY} + else: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.post("/key/update") +async def update_azure_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)): + print(form_data) + if user and user.role == "admin": + app.state.AZURE_OPENAI_API_KEY = form_data.key + return {"AZURE_OPENAI_API_KEY": app.state.AZURE_OPENAI_API_KEY} + else: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.post("/audio/speech") +async def speech(request: Request, user=Depends(get_current_user)): + target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/audio/speech" + + if user.role not in ["user", "admin"]: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + # if app.state.AZURE_OPENAI_API_KEY == "": + # raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + body = await request.body() + + 
name = hashlib.sha256(body).hexdigest() + + SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") + SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) + file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3") + file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json") + + # Check if the file already exists in the cache + if file_path.is_file(): + return FileResponse(file_path) + + headers = {} + headers["Authorization"] = f"Bearer {app.state.AZURE_OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" + + try: + print("openai") + r = requests.post( + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + # Save the streaming content to a file + with open(file_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + with open(file_body_path, "w") as f: + json.dump(json.loads(body.decode("utf-8")), f) + + # Return the saved file + return FileResponse(file_path) + + except Exception as e: + print(e) + error_detail = "Ollama WebUI: Server Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" + + raise HTTPException(status_code=r.status_code, detail=error_detail) + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(path: str, request: Request, user=Depends(get_current_user)): + target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/{path}" + print(target_url, app.state.AZURE_OPENAI_API_KEY) + + if user.role not in ["user", "admin"]: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + # if app.state.AZURE_OPENAI_API_KEY == "": + # raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + body = await request.body() + + # TODO: Remove below after gpt-4-vision fix from Open AI + # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + try: + body = body.decode("utf-8") + body = json.loads(body) + + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if body.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in body: + body["max_tokens"] = 4000 + print("Modified body_dict:", body) + + # Convert the modified body back to JSON + body = json.dumps(body) + except json.JSONDecodeError as e: + print("Error loading request body into a dictionary:", e) + + headers = {} + headers["Authorization"] = f"Bearer {app.state.AZURE_OPENAI_API_KEY}" + headers["Content-Type"] = "application/json" + + try: + r = requests.request( + method=request.method, + url=target_url, + data=body, + headers=headers, + stream=True, + ) + + r.raise_for_status() + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + return StreamingResponse( + r.iter_content(chunk_size=8192), + status_code=r.status_code, + headers=dict(r.headers), + ) + else: + # For non-SSE, read the response and return it + # response_data = ( + # r.json() + # if r.headers.get("Content-Type", "") + # == "application/json" + # else r.text + # ) + + response_data = r.json() + + if "/azure-openai" in app.state.AZURE_OPENAI_API_BASE_URL and path == "models": + response_data["data"] = list( + filter( + lambda model: "gpt" in model["id"], response_data["data"]) + ) + + return response_data + except Exception as e: + print(e) + error_detail = "Ollama WebUI: Server 
Connection Error" + if r is not None: + try: + res = r.json() + if "error" in res: + error_detail = f"External: {res['error']}" + except: + error_detail = f"External: {e}" + + raise HTTPException(status_code=r.status_code, detail=error_detail) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 346d5de8944..31642089d80 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -41,7 +41,8 @@ async def get_openai_url(user=Depends(get_current_user)): if user and user.role == "admin": return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.post("/url/update") @@ -50,7 +51,8 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_u app.state.OPENAI_API_BASE_URL = form_data.url return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL} else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.get("/key") @@ -58,7 +60,8 @@ async def get_openai_key(user=Depends(get_current_user)): if user and user.role == "admin": return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.post("/key/update") @@ -67,7 +70,8 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u app.state.OPENAI_API_KEY = form_data.key return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY} else: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) @app.post("/audio/speech") @@ -75,7 +79,8 @@ async def speech(request: Request, user=Depends(get_current_user)): target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech" if user.role not in ["user", "admin"]: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) # if app.state.OPENAI_API_KEY == "": # raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) @@ -138,7 +143,8 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): print(target_url, app.state.OPENAI_API_KEY) if user.role not in ["user", "admin"]: - raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) # if app.state.OPENAI_API_KEY == "": # raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) @@ -195,9 +201,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): response_data = r.json() - if "openai" in app.state.OPENAI_API_BASE_URL and path == "models": + if "/openai" in app.state.OPENAI_API_BASE_URL and path == "models": response_data["data"] = list( - filter(lambda model: "gpt" in model["id"], response_data["data"]) + filter( + lambda model: "gpt" in model["id"], response_data["data"]) ) return response_data diff --git a/backend/config.py b/backend/config.py index 65ee2298710..2b2f4f1fc1c 100644 --- a/backend/config.py +++ b/backend/config.py @@ -66,6 +66,16 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1" +#################################### +# 
AZURE_OPENAI_API +#################################### + +AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "") +AZURE_OPENAI_API_BASE_URL = os.environ.get("AZURE_OPENAI_API_BASE_URL", "") + +if AZURE_OPENAI_API_BASE_URL == "": + AZURE_OPENAI_API_BASE_URL = "https://api.openai.com/v1" + #################################### # WEBUI #################################### diff --git a/backend/main.py b/backend/main.py index f7a82b66394..867b38ba6c1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -10,6 +10,7 @@ from apps.ollama.main import app as ollama_app from apps.openai.main import app as openai_app +from apps.azure_openai.main import app as azure_openai_app from apps.web.main import app as webui_app from apps.rag.main import app as rag_app @@ -55,6 +56,7 @@ async def check_url(request: Request, call_next): app.mount("/ollama/api", ollama_app) app.mount("/openai/api", openai_app) +app.mount("/azure-openai/api", azure_openai_app) app.mount("/rag/api/v1", rag_app) diff --git a/src/lib/apis/azureopenai/index.ts b/src/lib/apis/azureopenai/index.ts new file mode 100644 index 00000000000..cbb709c5490 --- /dev/null +++ b/src/lib/apis/azureopenai/index.ts @@ -0,0 +1,262 @@ +import { AZURE_OPENAI_API_BASE_URL } from '$lib/constants'; + +export const getAzureOpenAIUrl = async (token: string = '') => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/url`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AZURE_OPENAI_API_BASE_URL; +}; + +export const updateAzureOpenAIUrl = async (token: string = '', url: string) => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/url/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + url: url + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AZURE_OPENAI_API_BASE_URL; +}; + +export const getAzureOpenAIKey = async (token: string = '') => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/key`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AZURE_OPENAI_API_KEY; +}; + +export const updateAzureOpenAIKey = async (token: string = '', key: string) => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/key/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && 
{ authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + key: key + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AZURE_OPENAI_API_KEY; +}; + +export const getAzureOpenAIModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; + return []; + }); + + if (error) { + throw error; + } + + const models = Array.isArray(res) ? res : res?.data ?? null; + + return models + ? models + .map((model) => ({ name: model.id, external: true })) + .sort((a, b) => { + return a.name.localeCompare(b.name); + }) + : models; +}; + +export const getAzureOpenAIModelsDirect = async ( + base_url: string = 'https://api.openai.com/v1', + api_key: string = '' +) => { + let error = null; + + const res = await fetch(`${base_url}/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${api_key}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`; + return null; + }); + + if (error) { + throw error; + } + + const models = Array.isArray(res) ? res : res?.data ?? null; + + return models + .map((model) => ({ name: model.id, external: true })) + .filter((model) => (base_url.includes('openai') ? 
model.name.includes('gpt') : true)) + .sort((a, b) => { + return a.name.localeCompare(b.name); + }); +}; + +export const generateAzureOpenAIChatCompletion = async (token: string = '', body: object) => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/chat/completions`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(body) + }).catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const synthesizeOpenAISpeech = async ( + token: string = '', + speaker: string = 'alloy', + text: string = '' +) => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/audio/speech`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + model: 'tts-1', + input: text, + voice: speaker + }) + }).catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/chat/Settings/External.svelte b/src/lib/components/chat/Settings/External.svelte index 455370eb1e9..66c3cbf68ff 100644 --- a/src/lib/components/chat/Settings/External.svelte +++ b/src/lib/components/chat/Settings/External.svelte @@ -1,10 +1,17 @@ @@ -29,6 +51,7 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={() => { updateOpenAIHandler(); + updateAzureOpenAIHandler(); dispatch('save'); // saveSettings({ @@ -73,6 +96,45 @@ WebUI will make requests to '{OPENAI_API_BASE_URL}/chat' + +
+ +
+
Azure OpenAI API Key
+
+
+ +
+
+
+ Adds optional support for models hosted on Azure OpenAI.
+
+ +
+ +
+
Azure OpenAI API Base URL
+
+
+ +
+
+
+ WebUI will make requests to '{AZURE_OPENAI_API_BASE_URL}' +
+
diff --git a/src/lib/constants.ts b/src/lib/constants.ts index a85a6ccc2d1..9ba2253b1fa 100644 --- a/src/lib/constants.ts +++ b/src/lib/constants.ts @@ -5,6 +5,7 @@ export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``; export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`; export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`; export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`; +export const AZURE_OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/azure-openai/api`; export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`; export const WEB_UI_VERSION = 'v1.0.0-alpha-static'; From 5684f0b095577b021fab08bb743c6642b3a183df Mon Sep 17 00:00:00 2001 From: Keerthivasan D Date: Sun, 11 Feb 2024 23:03:01 +0530 Subject: [PATCH 2/2] Keerthi: [WIP] Azure Integration --- backend/apps/azure_openai/main.py | 30 ++++++++- backend/config.py | 4 +- src/lib/apis/azureopenai/index.ts | 67 +++++++++++++++++++ .../components/chat/Settings/External.svelte | 29 +++++++- src/lib/components/chat/SettingsModal.svelte | 9 ++- src/routes/(app)/+layout.svelte | 7 ++ 6 files changed, 139 insertions(+), 7 deletions(-) diff --git a/backend/apps/azure_openai/main.py b/backend/apps/azure_openai/main.py index 8b6f36f757d..ff3922c2bd2 100644 --- a/backend/apps/azure_openai/main.py +++ b/backend/apps/azure_openai/main.py @@ -10,7 +10,7 @@ from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import decode_token, get_current_user -from config import AZURE_OPENAI_API_BASE_URL, AZURE_OPENAI_API_KEY, CACHE_DIR +from config import AZURE_OPENAI_API_BASE_URL, AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, CACHE_DIR import hashlib from pathlib import Path @@ -26,6 +26,7 @@ app.state.AZURE_OPENAI_API_BASE_URL = AZURE_OPENAI_API_BASE_URL app.state.AZURE_OPENAI_API_KEY = AZURE_OPENAI_API_KEY +app.state.AZURE_OPENAI_API_VERSION = AZURE_OPENAI_API_VERSION class UrlUpdateForm(BaseModel): @@ -36,6 +37,10 @@ class KeyUpdateForm(BaseModel): key: str +class VersionUpdateForm(BaseModel): + version: str + + @app.get("/url") async def get_azure_openai_url(user=Depends(get_current_user)): if user and user.role == "admin": @@ -75,6 +80,26 @@ async def update_azure_openai_key(form_data: KeyUpdateForm, user=Depends(get_cur status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) +@app.get("/version") +async def get_azure_openai_version(user=Depends(get_current_user)): + if user and user.role == "admin": + return {"AZURE_OPENAI_API_VERSION": app.state.AZURE_OPENAI_API_VERSION} + else: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + +@app.post("/version/update") +async def update_azure_openai_version(form_data: VersionUpdateForm, user=Depends(get_current_user)): + print(form_data) + if user and user.role == "admin": + app.state.AZURE_OPENAI_API_VERSION = form_data.version + return {"AZURE_OPENAI_API_VERSION": app.state.AZURE_OPENAI_API_VERSION} + else: + raise HTTPException( + status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) + + @app.post("/audio/speech") async def speech(request: Request, user=Depends(get_current_user)): target_url = f"{app.state.AZURE_OPENAI_API_BASE_URL}/audio/speech" @@ -170,9 +195,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)): print("Error loading request body into a dictionary:", e) headers = {} - headers["Authorization"] = f"Bearer {app.state.AZURE_OPENAI_API_KEY}" headers["Content-Type"] = "application/json" + target_url = 
f"{target_url}?api-version={app.state.AZURE_OPENAI_API_VERSION}&api-key={app.state.AZURE_OPENAI_API_KEY}" + try: r = requests.request( method=request.method, diff --git a/backend/config.py b/backend/config.py index 2b2f4f1fc1c..f901881fd51 100644 --- a/backend/config.py +++ b/backend/config.py @@ -72,9 +72,7 @@ AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "") AZURE_OPENAI_API_BASE_URL = os.environ.get("AZURE_OPENAI_API_BASE_URL", "") - -if AZURE_OPENAI_API_BASE_URL == "": - AZURE_OPENAI_API_BASE_URL = "https://api.openai.com/v1" +AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "") #################################### # WEBUI diff --git a/src/lib/apis/azureopenai/index.ts b/src/lib/apis/azureopenai/index.ts index cbb709c5490..8f6cc6f41ec 100644 --- a/src/lib/apis/azureopenai/index.ts +++ b/src/lib/apis/azureopenai/index.ts @@ -134,6 +134,73 @@ export const updateAzureOpenAIKey = async (token: string = '', key: string) => { return res.AZURE_OPENAI_API_KEY; }; +export const getAzureOpenAIApiVersion = async (token: string = '') => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/version`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AZURE_OPENAI_API_VERSION; +}; + +export const updateAzureOpenAIApiVersion = async (token: string = '', version: string) => { + let error = null; + + const res = await fetch(`${AZURE_OPENAI_API_BASE_URL}/version/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + }, + body: JSON.stringify({ + version: version + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + if ('detail' in err) { + error = err.detail; + } else { + error = 'Server connection failed'; + } + return null; + }); + + if (error) { + throw error; + } + + return res.AZURE_OPENAI_API_VERSION; +}; + export const getAzureOpenAIModels = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/Settings/External.svelte b/src/lib/components/chat/Settings/External.svelte index 66c3cbf68ff..2db092ef08d 100644 --- a/src/lib/components/chat/Settings/External.svelte +++ b/src/lib/components/chat/Settings/External.svelte @@ -4,7 +4,9 @@ getAzureOpenAIKey, getAzureOpenAIUrl, updateAzureOpenAIKey, - updateAzureOpenAIUrl + updateAzureOpenAIUrl, + getAzureOpenAIApiVersion, + updateAzureOpenAIApiVersion } from '$lib/apis/azureopenai'; import { models, user } from '$lib/stores'; import { createEventDispatcher, onMount } from 'svelte'; @@ -26,6 +28,7 @@ let AZURE_OPENAI_API_KEY = ''; let AZURE_OPENAI_API_BASE_URL = ''; + let AZURE_OPENAI_API_VERSION = ''; const updateAzureOpenAIHandler = async () => { AZURE_OPENAI_API_BASE_URL = await updateAzureOpenAIUrl( @@ -33,6 +36,10 @@ AZURE_OPENAI_API_BASE_URL ); AZURE_OPENAI_API_KEY = await updateAzureOpenAIKey(localStorage.token, AZURE_OPENAI_API_KEY); + AZURE_OPENAI_API_VERSION = await updateAzureOpenAIApiVersion( + localStorage.token, + AZURE_OPENAI_API_VERSION + ); await 
models.set(await getAzureModels()); }; @@ -43,6 +50,7 @@ OPENAI_API_KEY = await getOpenAIKey(localStorage.token); AZURE_OPENAI_API_BASE_URL = await getAzureOpenAIUrl(localStorage.token); AZURE_OPENAI_API_KEY = await getAzureOpenAIKey(localStorage.token); + AZURE_OPENAI_API_VERSION = await getAzureOpenAIApiVersion(localStorage.token); } }); @@ -135,6 +143,25 @@ >
+ +
+ +
+
Azure OpenAI API Version
+
+
+ +
+
+
+ WebUI will call the Azure OpenAI API with api-version '{AZURE_OPENAI_API_VERSION}'
+
diff --git a/src/lib/components/chat/SettingsModal.svelte b/src/lib/components/chat/SettingsModal.svelte index 12874c12a34..70a8995ae18 100644 --- a/src/lib/components/chat/SettingsModal.svelte +++ b/src/lib/components/chat/SettingsModal.svelte @@ -4,6 +4,7 @@ import { getOllamaModels } from '$lib/apis/ollama'; import { getOpenAIModels } from '$lib/apis/openai'; + import { getAzureOpenAIModels } from '$lib/apis/azureopenai'; import Modal from '../common/Modal.svelte'; import Account from './Settings/Account.svelte'; @@ -42,6 +43,12 @@ return null; }); models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : [])); + + const azureOpenAIModels = await getAzureOpenAIModels(localStorage.token).catch((error) => { + console.log(error); + return null; + }); + models.push(...(azureOpenAIModels ? [{ name: 'hr' }, ...azureOpenAIModels] : [])); } return models; @@ -384,4 +391,4 @@ input[type='number'] { -moz-appearance: textfield; /* Firefox */ } - \ No newline at end of file + diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte index cc3672c4f82..e05cb059232 100644 --- a/src/routes/(app)/+layout.svelte +++ b/src/routes/(app)/+layout.svelte @@ -12,6 +12,7 @@ import { getPrompts } from '$lib/apis/prompts'; import { getOpenAIModels } from '$lib/apis/openai'; + import { getAzureOpenAIModels } from '$lib/apis/azureopenai'; import { user, @@ -61,6 +62,12 @@ models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : [])); + const azureOpenAIModels = await getAzureOpenAIModels(localStorage.token).catch((error) => { + console.log(error); + return null; + }); + models.push(...(azureOpenAIModels ? [{ name: 'hr' }, ...azureOpenAIModels] : [])); + return models; };