-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
109 lines (93 loc) · 4.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from fastapi import FastAPI, Query
from fastapi.responses import PlainTextResponse
import uvicorn
from config import Config, parse_args
from openai_client import LLMClient
from db import DB
proxy_server = FastAPI()

# Build the configuration from a TOML file when ``args.config`` is set,
# otherwise from the individual command-line arguments.
args = parse_args()
config = Config.from_toml(args.config) if args.config else Config.from_args(args)

# Process-wide singletons shared by the request handlers in this module.
client = LLMClient.from_config(config)
db = DB.from_config(config.database_config)
def _sync_language_targets(src_lang: str, tgt_lang: str) -> None:
    """Reset the shared history when the requested language pair changes.

    An empty source or target language on the current history is treated as a
    "needs reset" marker (the handler blanks the targets when history is
    disabled or a completion fails), so it also triggers a reset. After a
    reset the history is optionally seeded with the latest stored
    translations for the new pair.
    """
    history = client.chat_history
    needs_reset = (
        "" in (history.src_lang, history.tgt_lang)
        or history.src_lang != src_lang
        or history.tgt_lang != tgt_lang
    )
    if not needs_reset:
        return

    client.reset_history()
    client.set_language_targets(src_lang, tgt_lang)

    db_config = client.config.database_config
    if db_config.use_latest_records:
        translation_records = db.get_latest_translations(
            src_lang, tgt_lang, db_config.init_latest_records
        )
        if translation_records:
            client.apply_latest_translations(translation_records)
            print("Latest translation Applied.")


def _record_turn(completion_res: str, src_lang: str, tgt_lang: str) -> None:
    """Store the assistant reply in the history and enforce ``max_history``.

    When history is disabled, the language targets are blanked instead so the
    next request rebuilds the history from scratch.
    """
    history_config = client.config.history_config
    if not history_config.use_history:
        client.set_language_targets("", "")
        return

    client.chat_history.add_assistant_content(completion_res)
    if history_config.max_history <= -1:
        # A negative limit disables trimming.
        return

    # NOTE(review): the offsets 1 and 2 (plus one more when a system prompt is
    # used) presumably skip leading non-turn messages; confirm against ChatHistory.
    system_offset = 1 if client.prompt.system_prompt.use_system_prompt else 0
    turn_count = len(client.chat_history.chat_history) - (1 + system_offset)
    if turn_count < history_config.max_history:
        return

    if history_config.use_latest_history:
        # Keep the newest messages: drop the leading ones and rebuild the
        # history from the remainder on top of a fresh reset.
        kept_messages = client.chat_history.chat_history[2 + system_offset:]
        client.reset_history()
        client.set_language_targets(src_lang, tgt_lang)
        for turn in kept_messages:
            client.chat_history.add_message(turn["role"], turn["content"])
    else:
        client.chat_history.delete_latest_turns(2)


@proxy_server.get("/translate", response_class=PlainTextResponse)
async def translation_handler(
    text: str,
    tgt_lang: str = Query(..., alias="to"),
    src_lang: str = Query(..., alias="from"),
):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` using the shared client.

    The conversation history is reset (and optionally seeded from the database)
    whenever the requested language pair differs from the one the client holds.
    When ``use_cached_translation`` is enabled and the database already has a
    translation for this text, it is served without querying the model.
    Otherwise the model is queried. Freshly generated translations are stored
    when ``cache_translation`` is enabled.

    NOTE(review): declared ``async`` with no ``await``, so the whole body runs on
    the event loop without interleaving. That keeps the shared ``client`` history
    consistent across concurrent requests, but any blocking call stalls the server.

    Args:
        text: The text to translate (query parameter ``text``).
        tgt_lang: Target language code (query parameter ``to``).
        src_lang: Source language code (query parameter ``from``).

    Returns:
        The translated text, served as ``text/plain``.

    Raises:
        Whatever ``client.request_completion()`` raises; the history is marked
        for reset first so the unanswered user turn is not reused.
    """
    _sync_language_targets(src_lang, tgt_lang)
    client.chat_history.add_user_content(
        client.prompt.template.get_src_filled_prompt(text)
    )

    db_config = client.config.database_config
    cached_text = None
    if db_config.use_cached_translation:
        cached_text = db.fetch_translation(src_lang, tgt_lang, text)

    if cached_text:
        print("Got cached translation!")
        # Wrap the cached text in the target tags so it can flow through the
        # same history/extraction path as a real model completion.
        tag = client.prompt.template.tag
        completion_res = f"{tag.tgt_start}{cached_text}{tag.tgt_end}"
    else:
        try:
            completion_res = client.request_completion()
        except Exception:
            # The user turn above has no reply; blank the targets so the next
            # request starts from a clean history, then propagate the error.
            client.set_language_targets("", "")
            raise

    _record_turn(completion_res, src_lang, tgt_lang)
    translated_text = client.prompt.template.get_translated_text(completion_res)

    # A cache hit is already stored, so only new translations need the lookup/save.
    if db_config.cache_translation and not cached_text:
        if not db.fetch_translation(src_lang, tgt_lang, text):
            db.save_translation(src_lang, tgt_lang, text, translated_text)

    print(f"Original: {text}\nTranslated: {translated_text}")
    return translated_text
@proxy_server.get("/reset")
async def reset_handler():
    """Clear the shared client's conversation history.

    Delegates to ``client.reset_history()``; what the history contains
    afterwards (e.g. a system prompt) is decided by the client.

    Returns:
        str: The fixed acknowledgement ``"Reset successful"``.
    """
    client.reset_history()
    acknowledgement = "Reset successful"
    return acknowledgement
@proxy_server.get("/status")
async def status_handler():
    """Liveness check for the proxy.

    Performs no work beyond answering, so a response means the server is up.

    Returns:
        str: The fixed message ``"Server is running"``.
    """
    status_message = "Server is running"
    return status_message
if __name__ == "__main__":
    # Serve the app on the host/port taken from the loaded configuration.
    server_settings = config.server_config
    uvicorn.run(
        proxy_server,
        host=server_settings.host,
        port=server_settings.port,
    )