Skip to content

Commit 9155103

Browse files
committed
Clean up some of the magic helpers and add typing
1 parent 9592ee9 commit 9155103

File tree

4 files changed

+36
-14
lines changed

4 files changed

+36
-14
lines changed

examples/python/fastapi/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,14 @@ async def updates2():
8484
# This is also identical, but yielding a string of fragments automatically calls merge_fragments
8585
# and dicts automatically calls merge_signals
8686
@app.get("/updates3")
87+
# Wraps the resulting async generator in a DatastarStreamingResponse
8788
@sse_generator
8889
async def updates3():
8990
while True:
91+
# Implicit merge_fragments
9092
yield f"""<span id="currentTime">{datetime.now().isoformat()}"""
9193
await asyncio.sleep(1)
94+
# Implicit merge_signals
9295
yield {"currentTime": f"{datetime.now().isoformat()}"}
9396
await asyncio.sleep(1)
9497

sdk/python/src/datastar_py/quart.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from quart import make_response as _make_response
44

5-
from .sse import ServerSentEventGenerator, SSE_HEADERS, _async_map, _wrap_event
5+
from .sse import ServerSentEventGenerator, SSE_HEADERS, _sse_iterable_wrapper
66

77

88
async def make_datastar_response(async_generator):
@@ -14,7 +14,6 @@ async def make_datastar_response(async_generator):
1414
def sse_generator(generator_func):
1515
@wraps(generator_func)
1616
async def _wrapper(*args, **kwargs):
17-
content = _async_map(_wrap_event, generator_func(*args, **kwargs))
18-
17+
content = _sse_iterable_wrapper(generator_func(*args, **kwargs))
1918
return await make_datastar_response(content)
2019
return _wrapper

sdk/python/src/datastar_py/sse.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
from __future__ import annotations
2+
13
import json
24
from itertools import chain
3-
from typing import Optional, Protocol, Union, runtime_checkable
5+
from typing import (
6+
Optional,
7+
Protocol,
8+
Union,
9+
runtime_checkable,
10+
Callable,
11+
AsyncIterable,
12+
TypeVar,
13+
Iterable,
14+
)
415

516
import datastar_py.consts as consts
617

@@ -178,15 +189,29 @@ def redirect(cls, location: str):
178189
return cls.execute_script(f"setTimeout(() => window.location = '{location}')")
179190

180191

181-
def _wrap_event(event):
182-
if isinstance(event, _HtmlProvider) or (isinstance(event, str) and event.startswith("<")):
192+
def _wrap_event(event: str | _HtmlProvider | dict) -> str:
193+
if isinstance(event, _HtmlProvider) or (
194+
isinstance(event, str) and event.startswith("<")
195+
):
183196
return ServerSentEventGenerator.merge_fragments(event)
184197
elif isinstance(event, dict):
185198
return ServerSentEventGenerator.merge_signals(event)
186199
else:
187200
return event
188201

189202

190-
async def _async_map(func, async_iter):
203+
async def _async_map(func: Callable, async_iter: AsyncIterable) -> AsyncIterable:
191204
async for item in async_iter:
192205
yield func(item)
206+
207+
208+
SyncOrAsyncIterable = TypeVar("SyncOrAsyncIterable", AsyncIterable, Iterable)
209+
210+
211+
def _sse_iterable_wrapper(iterable: SyncOrAsyncIterable) -> SyncOrAsyncIterable:
212+
"""Wraps an iterable to allow implicitly turning fragments and dictionaries
213+
into merge-fragments and merge-signals events."""
214+
if isinstance(iterable, AsyncIterable):
215+
return _async_map(_wrap_event, iterable)
216+
else:
217+
return map(_wrap_event, iterable)

sdk/python/src/datastar_py/starlette.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from starlette.responses import StreamingResponse as _StreamingResponse
55

6-
from .sse import SSE_HEADERS, ServerSentEventGenerator, _wrap_event, _async_map
6+
from .sse import SSE_HEADERS, ServerSentEventGenerator, _wrap_event, _async_map, _sse_iterable_wrapper
77

88

99
class DatastarStreamingResponse(_StreamingResponse):
@@ -17,10 +17,5 @@ def sse_generator(generator_func):
1717
@wraps(generator_func)
1818
def _wrapper(*args, **kwargs):
1919
content = generator_func(*args, **kwargs)
20-
if isinstance(content, typing.AsyncIterable):
21-
content = _async_map(_wrap_event, content)
22-
else:
23-
content = map(_wrap_event, content)
24-
25-
return DatastarStreamingResponse(content)
20+
return DatastarStreamingResponse(_sse_iterable_wrapper(content))
2621
return _wrapper

0 commit comments

Comments
 (0)