Skip to content

Commit 39c13e8

Browse files
authored
separate inference from web server (#1342)
1 parent: d441e0f · commit: 39c13e8

File tree

1 file changed

+85
-76
lines changed

1 file changed

+85
-76
lines changed

06_gpu_and_ml/speech-to-text/streaming_parakeet.py

Lines changed: 85 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@
122122

123123
@app.cls(volumes={"/cache": model_cache}, gpu="a10g", image=image)
124124
@modal.concurrent(max_inputs=14, target_inputs=10)
125-
class Parakeet:
125+
class ParakeetModel:
126126
@modal.enter()
127127
def load(self):
128128
import logging
@@ -146,81 +146,7 @@ def transcribe(self, audio_bytes: bytes) -> str:
146146

147147
return output[0].text
148148

149-
@modal.asgi_app()
150-
def web(self):
151-
from fastapi import FastAPI, Response, WebSocket
152-
from fastapi.responses import HTMLResponse
153-
from fastapi.staticfiles import StaticFiles
154-
155-
web_app = FastAPI()
156-
web_app.mount("/static", StaticFiles(directory="/frontend"))
157-
158-
@web_app.get("/status")
159-
async def status():
160-
return Response(status_code=200)
161-
162-
# serve frontend
163-
@web_app.get("/")
164-
async def index():
165-
return HTMLResponse(content=open("/frontend/index.html").read())
166-
167-
@web_app.websocket("/ws")
168-
async def run_with_websocket(ws: WebSocket):
169-
from fastapi import WebSocketDisconnect
170-
from pydub import AudioSegment
171-
172-
await ws.accept()
173-
174-
# initialize an empty audio segment
175-
audio_segment = AudioSegment.empty()
176-
177-
try:
178-
while True:
179-
# receive a chunk of audio data and convert it to an audio segment
180-
chunk = await ws.receive_bytes()
181-
if chunk == END_OF_STREAM:
182-
await ws.send_bytes(END_OF_STREAM)
183-
break
184-
audio_segment, text = await self.handle_audio_chunk(
185-
chunk, audio_segment
186-
)
187-
if text:
188-
await ws.send_text(text)
189-
except Exception as e:
190-
if not isinstance(e, WebSocketDisconnect):
191-
print(f"Error handling websocket: {type(e)}: {e}")
192-
try:
193-
await ws.close(code=1011, reason="Internal server error")
194-
except Exception as e:
195-
print(f"Error closing websocket: {type(e)}: {e}")
196-
197-
return web_app
198-
199149
@modal.method()
200-
async def run_with_queue(self, q: modal.Queue):
201-
from pydub import AudioSegment
202-
203-
# initialize an empty audio segment
204-
audio_segment = AudioSegment.empty()
205-
206-
try:
207-
while True:
208-
# receive a chunk of audio data and convert it to an audio segment
209-
chunk = await q.get.aio(partition="audio")
210-
211-
if chunk == END_OF_STREAM:
212-
await q.put.aio(END_OF_STREAM, partition="transcription")
213-
break
214-
215-
audio_segment, text = await self.handle_audio_chunk(
216-
chunk, audio_segment
217-
)
218-
if text:
219-
await q.put.aio(text, partition="transcription")
220-
except Exception as e:
221-
print(f"Error handling queue: {type(e)}: {e}")
222-
return
223-
224150
async def handle_audio_chunk(
225151
self,
226152
chunk: bytes,
@@ -272,6 +198,89 @@ async def handle_audio_chunk(
272198
print("❌ Transcription error:", e)
273199
raise e
274200

201+
@modal.method()
202+
async def run_with_queue(self, q: modal.Queue):
203+
from pydub import AudioSegment
204+
205+
# initialize an empty audio segment
206+
audio_segment = AudioSegment.empty()
207+
208+
try:
209+
while True:
210+
# receive a chunk of audio data and convert it to an audio segment
211+
chunk = await q.get.aio(partition="audio")
212+
213+
if chunk == END_OF_STREAM:
214+
await q.put.aio(END_OF_STREAM, partition="transcription")
215+
break
216+
217+
audio_segment, text = await self.handle_audio_chunk.remote.aio(
218+
chunk, audio_segment
219+
)
220+
if text:
221+
await q.put.aio(text, partition="transcription")
222+
except Exception as e:
223+
print(f"Error handling queue: {type(e)}: {e}")
224+
return
225+
226+
227+
@app.cls(image=image)
228+
@modal.concurrent(max_inputs=1000)
229+
class WebServer:
230+
@modal.asgi_app()
231+
def web(self):
232+
from fastapi import FastAPI, Response, WebSocket
233+
from fastapi.responses import HTMLResponse
234+
from fastapi.staticfiles import StaticFiles
235+
236+
web_app = FastAPI()
237+
web_app.mount("/static", StaticFiles(directory="/frontend"))
238+
239+
@web_app.get("/status")
240+
async def status():
241+
return Response(status_code=200)
242+
243+
# serve frontend
244+
@web_app.get("/")
245+
async def index():
246+
return HTMLResponse(content=open("/frontend/index.html").read())
247+
248+
@web_app.websocket("/ws")
249+
async def run_with_websocket(ws: WebSocket):
250+
from fastapi import WebSocketDisconnect
251+
252+
await ws.accept()
253+
254+
from pydub import AudioSegment
255+
256+
model = ParakeetModel()
257+
258+
# initialize an empty audio segment
259+
audio_segment = AudioSegment.empty()
260+
261+
try:
262+
while True:
263+
# receive a chunk of audio data and convert it to an audio segment
264+
chunk = await ws.receive_bytes()
265+
if chunk == END_OF_STREAM:
266+
await ws.send_bytes(END_OF_STREAM)
267+
break
268+
(
269+
audio_segment,
270+
text,
271+
) = await model.handle_audio_chunk.remote.aio(chunk, audio_segment)
272+
if text:
273+
await ws.send_text(text)
274+
except Exception as e:
275+
if not isinstance(e, WebSocketDisconnect):
276+
print(f"Error handling websocket: {type(e)}: {e}")
277+
try:
278+
await ws.close(code=1011, reason="Internal server error")
279+
except Exception as e:
280+
print(f"Error closing websocket: {type(e)}: {e}")
281+
282+
return web_app
283+
275284

276285
# ## Running transcription from a local Python client
277286

@@ -299,7 +308,7 @@ async def main(audio_url: str = AUDIO_URL):
299308

300309
print("🎤 Starting Transcription")
301310
with modal.Queue.ephemeral() as q:
302-
Parakeet().run_with_queue.spawn(q)
311+
ParakeetModel().run_with_queue.spawn(q)
303312
send = asyncio.create_task(send_audio(q, audio_data))
304313
recv = asyncio.create_task(receive_text(q))
305314
await asyncio.gather(send, recv)

0 commit comments

Comments
 (0)