diff --git a/livekit-plugins/livekit-plugins-asyncai/README.md b/livekit-plugins/livekit-plugins-asyncai/README.md
new file mode 100644
index 0000000000..4adc99afc8
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/README.md
@@ -0,0 +1,15 @@
+# Async plugin for LiveKit Agents
+
+Support for text to speech with [Async](https://async.ai/).
+
+See https://docs.livekit.io/agents/integrations/tts/asyncai/ for more information.
+
+## Installation
+
+```bash
+pip install livekit-plugins-asyncai
+```
+
+## Pre-requisites
+
+You'll need an API key from Async. It can be set as an environment variable: `ASYNCAI_API_KEY`
\ No newline at end of file
diff --git a/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/__init__.py b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/__init__.py
new file mode 100644
index 0000000000..b7e4d100b9
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/__init__.py
@@ -0,0 +1,44 @@
+# Copyright 2023 LiveKit, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AsyncAI plugin for LiveKit Agents
+
+See https://docs.livekit.io/agents/integrations/tts/asyncai/ for more information.
+"""
+
+from .tts import TTS
+from .version import __version__
+
+__all__ = ["TTS", "__version__"]
+
+from livekit.agents import Plugin
+
+from .log import logger
+
+
+class AsyncAIPlugin(Plugin):
+    def __init__(self) -> None:
+        super().__init__(__name__, __version__, __package__, logger)
+
+
+Plugin.register_plugin(AsyncAIPlugin())
+
+# Cleanup docs of unexported modules
+_module = dir()
+NOT_IN_ALL = [m for m in _module if m not in __all__]
+
+__pdoc__ = {}
+
+for n in NOT_IN_ALL:
+    __pdoc__[n] = False
diff --git a/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/log.py b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/log.py
new file mode 100644
index 0000000000..c629ba440a
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/log.py
@@ -0,0 +1,3 @@
+import logging
+
+logger = logging.getLogger("livekit.plugins.asyncai")
diff --git a/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/models.py b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/models.py
new file mode 100644
index 0000000000..e253285090
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/models.py
@@ -0,0 +1,7 @@
+from typing import Literal
+
+TTSEncoding = Literal["pcm_s16le", "pcm_f32le", "pcm_mulaw"]
+
+TTSModels = Literal["asyncflow_multilingual_v1.0", "asyncflow_v2.0"]
+TTSLanguages = Literal["en", "de", "es", "fr", "it"]
+TTSDefaultVoiceId = "e0f39dc4-f691-4e78-bba5-5c636692cc04"
diff --git a/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/tts.py b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/tts.py
new file mode 100644
index 0000000000..ee839af025
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/tts.py
@@ -0,0 +1,334 @@
+# Copyright 2023 LiveKit, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import json
+import os
+import weakref
+from dataclasses import dataclass, replace
+from typing import Union, cast
+
+import aiohttp
+
+from livekit.agents import (
+    APIConnectionError,
+    APIConnectOptions,
+    APIStatusError,
+    APITimeoutError,
+    tokenize,
+    tts,
+    utils,
+)
+from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
+
+from .log import logger
+from .models import (
+    TTSDefaultVoiceId,
+    TTSEncoding,
+    TTSModels,
+)
+
+API_AUTH_HEADER = "api_key"
+API_VERSION_HEADER = "version"
+API_VERSION = "v1"
+
+BUFFERED_WORDS_COUNT = 10
+
+
+@dataclass
+class _TTSOptions:
+    """Synthesis settings held by a TTS instance and copied into each stream."""
+
+    model: TTSModels | str
+    encoding: TTSEncoding
+    sample_rate: int
+    voice: str | list[float]
+    api_key: str
+    language: str
+    base_url: str
+
+    def get_http_url(self, path: str) -> str:
+        """Return ``base_url`` with ``path`` appended."""
+        return f"{self.base_url}{path}"
+
+    def get_ws_url(self, path: str) -> str:
+        """Return ``base_url`` + ``path`` with the first "http" replaced by "ws"."""
+        return f"{self.base_url.replace('http', 'ws', 1)}{path}"
+
+
+class TTS(tts.TTS):
+    """Streaming text-to-speech over the Async WebSocket API."""
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        model: TTSModels | str = "asyncflow_multilingual_v1.0",
+        language: str = "en",
+        encoding: TTSEncoding = "pcm_s16le",
+        voice: str = TTSDefaultVoiceId,
+        sample_rate: int = 32000,
+        http_session: aiohttp.ClientSession | None = None,
+        base_url: str = "https://api.async.ai",
+    ) -> None:
+        """
+        Create a new instance of Async TTS.
+
+        See https://docs.async.ai/text-to-speech-websocket-3477526w0 for more details
+        on the Async API.
+
+        Args:
+            model (TTSModels, optional): The Async TTS model to use. Defaults to "asyncflow_multilingual_v1.0".
+            language (str, optional): The language code for synthesis. Defaults to "en".
+            encoding (TTSEncoding, optional): The audio encoding format. Defaults to "pcm_s16le".
+            voice (str, optional): The voice ID.
+            sample_rate (int, optional): The audio sample rate in Hz. Defaults to 32000.
+            api_key (str, optional): The Async API key. If not provided, it will be
+                read from the ASYNCAI_API_KEY environment variable.
+            http_session (aiohttp.ClientSession | None, optional): An existing aiohttp
+                ClientSession to use. If not provided, a new session will be created.
+            base_url (str, optional): The base URL for the Async API. Defaults to "https://api.async.ai".
+
+        Raises:
+            ValueError: If no API key is given and ASYNCAI_API_KEY is not set.
+        """
+
+        super().__init__(
+            capabilities=tts.TTSCapabilities(streaming=True),
+            sample_rate=sample_rate,
+            num_channels=1,
+        )
+        async_api_key = api_key or os.environ.get("ASYNCAI_API_KEY")
+        if not async_api_key:
+            raise ValueError("ASYNCAI_API_KEY must be set")
+
+        self._opts = _TTSOptions(
+            model=model,
+            language=language,
+            encoding=encoding,
+            sample_rate=sample_rate,
+            voice=voice,
+            api_key=async_api_key,
+            base_url=base_url,
+        )
+        self._session = http_session
+        self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
+            connect_cb=self._connect_ws,
+            close_cb=self._close_ws,
+            max_session_duration=300,
+            mark_refreshed_on_get=True,
+        )
+        self._streams = weakref.WeakSet[SynthesizeStream]()
+
+    async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
+        """Open a WebSocket and send the init payload (model, voice, output format)."""
+        session = self._ensure_session()
+        # NOTE(review): the API key is sent in the query string; API_AUTH_HEADER and
+        # API_VERSION_HEADER are defined above but unused. Confirm whether the service
+        # accepts header-based auth, which would keep the key out of logged URLs.
+        url = self._opts.get_ws_url(
+            f"/text_to_speech/websocket/ws?api_key={self._opts.api_key}&version={API_VERSION}"
+        )
+
+        init_payload = {
+            "model_id": self._opts.model,
+            "voice": {"mode": "id", "id": self._opts.voice},
+            "output_format": {
+                "container": "raw",
+                "encoding": self._opts.encoding,
+                "sample_rate": self._opts.sample_rate,
+            },
+        }
+        ws = await asyncio.wait_for(session.ws_connect(url), timeout)
+        try:
+            await ws.send_str(json.dumps(init_payload))
+        except Exception:
+            # Do not leak the socket if the init payload cannot be delivered.
+            await ws.close()
+            raise
+        return ws
+
+    async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
+        await ws.close()
+
+    def _ensure_session(self) -> aiohttp.ClientSession:
+        if not self._session:
+            self._session = utils.http_context.http_session()
+
+        return self._session
+
+    def prewarm(self) -> None:
+        self._pool.prewarm()
+
+    def update_options(
+        self,
+        *,
+        model: NotGivenOr[TTSModels | str] = NOT_GIVEN,
+        language: NotGivenOr[str] = NOT_GIVEN,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+    ) -> None:
+        """
+        Update the Text-to-Speech (TTS) configuration options.
+
+        This method allows updating the TTS settings, including model type, language and voice.
+        If any parameter is not provided, the existing value will be retained.
+
+        Changing the model or voice drops the pooled connections, because both are
+        sent only once, in the init payload of each connection.
+
+        Args:
+            model (TTSModels, optional): The Async TTS model to use.
+            language (str, optional): The language code for synthesis.
+            voice (str, optional): The voice ID.
+        """
+        if is_given(model):
+            self._opts.model = model
+        if is_given(language):
+            self._opts.language = language
+        if is_given(voice):
+            self._opts.voice = cast(Union[str, list[float]], voice)
+
+        if is_given(model) or is_given(voice):
+            self._pool.invalidate()
+
+    def stream(
+        self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
+    ) -> SynthesizeStream:
+        """Create a streaming synthesis session and track it so ``aclose`` can close it."""
+        stream = SynthesizeStream(tts=self, conn_options=conn_options)
+        self._streams.add(stream)
+        return stream
+
+    def synthesize(
+        self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
+    ) -> tts.ChunkedStream:
+        """Non-streaming synthesis is not implemented by this plugin.
+
+        Raises:
+            NotImplementedError: Always; use ``stream()`` instead.
+        """
+        raise NotImplementedError(
+            "Async TTS only supports streaming synthesis; use TTS.stream() instead"
+        )
+
+    async def aclose(self) -> None:
+        """Close every tracked stream, then the connection pool."""
+        for stream in list(self._streams):
+            await stream.aclose()
+
+        self._streams.clear()
+        await self._pool.aclose()
+
+
+class SynthesizeStream(tts.SynthesizeStream):
+    """Sentence-buffered synthesis over one pooled Async WebSocket connection."""
+
+    def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
+        super().__init__(tts=tts, conn_options=conn_options)
+        self._tts: TTS = tts
+        self._sent_tokenizer_stream = tokenize.basic.SentenceTokenizer(
+            min_sentence_len=BUFFERED_WORDS_COUNT
+        ).stream()
+        # Snapshot the options so later update_options() calls leave this stream as is.
+        self._opts = replace(tts._opts)
+
+    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
+        """Send tokenized input text to Async and emit the decoded audio it returns."""
+        request_id = utils.shortuuid()
+        output_emitter.initialize(
+            request_id=request_id,
+            sample_rate=self._opts.sample_rate,
+            num_channels=1,
+            mime_type="audio/pcm",
+            stream=True,
+        )
+
+        async def _sentence_stream_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+            async for ev in self._sent_tokenizer_stream:
+                token_pkt = {"transcript": ev.token + " ", "force": True}
+                self._mark_started()
+                await ws.send_str(json.dumps(token_pkt))
+
+            # end_pkt = {}
+            # end_pkt["transcript"] = ""
+            # await ws.send_str(json.dumps(end_pkt))
+
+        async def _input_task() -> None:
+            async for data in self._input_ch:
+                if isinstance(data, self._FlushSentinel):
+                    self._sent_tokenizer_stream.flush()
+                    continue
+
+                self._sent_tokenizer_stream.push_text(data)
+
+            self._sent_tokenizer_stream.end_input()
+
+        async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
+            current_segment_id: str | None = None
+            while True:
+                msg = await ws.receive()
+                if msg.type in (
+                    aiohttp.WSMsgType.CLOSED,
+                    aiohttp.WSMsgType.CLOSE,
+                    aiohttp.WSMsgType.CLOSING,
+                ):
+                    raise APIStatusError(
+                        "Async connection closed unexpectedly", request_id=request_id
+                    )
+
+                if msg.type != aiohttp.WSMsgType.TEXT:
+                    logger.warning("unexpected Async message type %s", msg.type)
+                    continue
+
+                data = json.loads(msg.data)
+                if current_segment_id is None:
+                    current_segment_id = "new_segment"
+                    output_emitter.start_segment(segment_id="new_segment")
+
+                audio = data.get("audio")
+                is_final = data.get("final") is True
+                if audio:
+                    output_emitter.push(base64.b64decode(audio))
+                if is_final:
+                    output_emitter.end_input()
+                    break
+                if not audio:
+                    # Warn only when a message carries neither audio nor a final marker, so
+                    # ordinary audio chunks are not reported as unexpected.
+                    logger.warning("unexpected message %s", data)
+
+        try:
+            async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
+                tasks = [
+                    asyncio.create_task(_input_task()),
+                    asyncio.create_task(_sentence_stream_task(ws)),
+                    asyncio.create_task(_recv_task(ws)),
+                ]
+
+                try:
+                    await asyncio.gather(*tasks)
+                finally:
+                    await utils.aio.gracefully_cancel(*tasks)
+        except asyncio.TimeoutError:
+            raise APITimeoutError() from None
+        except aiohttp.ClientResponseError as e:
+            raise APIStatusError(
+                message=e.message, status_code=e.status, request_id=None, body=None
+            ) from None
+        except Exception as e:
+            raise APIConnectionError() from e
diff --git a/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/version.py b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/version.py
new file mode 100644
index 0000000000..d7b9a51fce
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/livekit/plugins/asyncai/version.py
@@ -0,0 +1,15 @@
+# Copyright 2023 LiveKit, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "1.2.14"
diff --git a/livekit-plugins/livekit-plugins-asyncai/pyproject.toml b/livekit-plugins/livekit-plugins-asyncai/pyproject.toml
new file mode 100644
index 0000000000..3df19b1377
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-asyncai/pyproject.toml
@@ -0,0 +1,39 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "livekit-plugins-asyncai"
+dynamic = ["version"]
+description = "LiveKit Agents Plugin for AsyncAI"
+readme = "README.md"
+license = "Apache-2.0"
+requires-python = ">=3.9.0"
+authors = [{ name = "LiveKit", email = "hello@livekit.io" }]
+keywords = ["webrtc", "realtime", "audio", "video", "livekit"]
+classifiers = [
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: Apache Software License",
+    "Topic :: Multimedia :: Sound/Audio",
+    "Topic :: Multimedia :: Video",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3 :: Only",
+]
+dependencies = ["livekit-agents>=1.2.14"]
+
+[project.urls]
+Documentation = "https://docs.livekit.io"
+Website = "https://livekit.io/"
+Source = "https://github.com/livekit/agents"
+
+[tool.hatch.version]
+path = "livekit/plugins/asyncai/version.py"
+
+[tool.hatch.build.targets.wheel]
+packages = ["livekit"]
+
+[tool.hatch.build.targets.sdist]
+include = ["/livekit"]