Skip to content

Commit 442cde8

Browse files
authored
Removing explicit checks for async then/map/watch callbacks (#183)
1 parent d2e3e4c commit 442cde8

File tree

3 files changed

+13
-26
lines changed

3 files changed

+13
-26
lines changed

docs/api/chat.mdx

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,11 +1910,6 @@ def map(
19101910
~~~
19111911
"""
19121912
for callback in callbacks:
1913-
if not asyncio.iscoroutinefunction(callback):
1914-
raise TypeError(
1915-
f"Callback '{get_qualified_name(callback)}' must be an async function",
1916-
)
1917-
19181913
if allow_duplicates:
19191914
continue
19201915

@@ -2661,11 +2656,6 @@ def then(
26612656
~~~
26622657
"""
26632658
for callback in callbacks:
2664-
if not asyncio.iscoroutinefunction(callback):
2665-
raise TypeError(
2666-
f"Callback '{get_qualified_name(callback)}' must be an async function",
2667-
)
2668-
26692659
if allow_duplicates:
26702660
continue
26712661

rigging/chat.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def __call__(
618618
self,
619619
chat: Chat,
620620
/,
621-
) -> t.Awaitable[Chat | None]: ...
621+
) -> t.Awaitable[Chat | None] | Chat | None: ...
622622

623623

624624
@runtime_checkable
@@ -642,7 +642,7 @@ def __call__(
642642
self,
643643
chats: list[Chat],
644644
/,
645-
) -> t.Awaitable[list[Chat]]: ...
645+
) -> t.Awaitable[list[Chat]] | list[Chat]: ...
646646

647647

648648
@runtime_checkable
@@ -773,7 +773,9 @@ async def traced_watch_callback(chats: list[Chat]) -> None:
773773
chat_count=len(chats),
774774
chat_ids=[str(c.uuid) for c in chats],
775775
):
776-
await callback(chats)
776+
result = callback(chats)
777+
if inspect.isawaitable(result):
778+
await result
777779

778780
return traced_watch_callback
779781

@@ -1100,11 +1102,6 @@ async def process(chat: Chat) -> Chat | None:
11001102
```
11011103
"""
11021104
for callback in callbacks:
1103-
if not asyncio.iscoroutinefunction(callback):
1104-
raise TypeError(
1105-
f"Callback '{get_qualified_name(callback)}' must be an async function",
1106-
)
1107-
11081105
if allow_duplicates:
11091106
continue
11101107

@@ -1147,11 +1144,6 @@ async def process(chats: list[Chat]) -> list[Chat]:
11471144
```
11481145
"""
11491146
for callback in callbacks:
1150-
if not asyncio.iscoroutinefunction(callback):
1151-
raise TypeError(
1152-
f"Callback '{get_qualified_name(callback)}' must be an async function",
1153-
)
1154-
11551147
if allow_duplicates:
11561148
continue
11571149

@@ -1565,9 +1557,8 @@ async def complete() -> None:
15651557
exit_stack.push_async_callback(complete)
15661558

15671559
result = callback(state.chat)
1568-
15691560
if inspect.isawaitable(result):
1570-
result = await result # type: ignore [assignment]
1561+
result = await result
15711562

15721563
if result is None or isinstance(result, Chat):
15731564
state.chat = result or state.chat

tests/test_message_slicing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,13 @@ def test_slice_with_empty_string_target() -> None:
953953
"""Test marking slice with empty string target."""
954954
message = Message("assistant", "Some content here")
955955

956-
slice_obj = message.mark_slice("")
956+
# Expect a "Empty string target provided" warning
957+
with warnings.catch_warnings(record=True) as w:
958+
warnings.simplefilter("always")
959+
slice_obj = message.mark_slice("")
960+
assert len(w) == 1
961+
assert issubclass(w[-1].category, MessageWarning)
962+
assert "Empty string target provided" in str(w[-1].message)
957963

958964
# Empty string should not create a valid slice
959965
assert slice_obj is None

0 commit comments

Comments
 (0)