diff --git a/mockafka/aiokafka/aiokafka_consumer.py b/mockafka/aiokafka/aiokafka_consumer.py index 9bddc6f..09a3b4f 100644 --- a/mockafka/aiokafka/aiokafka_consumer.py +++ b/mockafka/aiokafka/aiokafka_consumer.py @@ -244,6 +244,24 @@ async def getmany( return dict(result) + def __aiter__(self): + if self._is_closed: + raise ConsumerStoppedError() + return self + + async def __anext__(self) -> ConsumerRecord[bytes, bytes]: + while True: + try: + result = await self.getone() + if result is None: + # Follow the lead of `getone`, though note that we should + # address this as part of any fix to + # https://github.com/alm0ra/mockafka-py/issues/117 + raise StopAsyncIteration + return result + except ConsumerStoppedError: + raise StopAsyncIteration from None + async def __aenter__(self) -> Self: await self.start() return self diff --git a/tests/test_aiokafka/test_aiokafka_consumer.py b/tests/test_aiokafka/test_aiokafka_consumer.py index 9112e3b..badcadf 100644 --- a/tests/test_aiokafka/test_aiokafka_consumer.py +++ b/tests/test_aiokafka/test_aiokafka_consumer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import itertools from unittest import IsolatedAsyncioTestCase @@ -17,6 +18,13 @@ ) from mockafka.kafka_store import KafkaStore +if sys.version_info < (3, 10): + def aiter(async_iterable): # noqa: A001 + return async_iterable.__aiter__() + + async def anext(async_iterable): # noqa: A001 + return await async_iterable.__anext__() + @pytest.mark.asyncio class TestAIOKAFKAFakeConsumer(IsolatedAsyncioTestCase): @@ -40,7 +48,7 @@ def topic(self): def create_topic(self): self.kafka.create_partition(topic=self.test_topic, partitions=16) - async def produce_message(self): + async def produce_two_messages(self): await self.producer.send( topic=self.test_topic, partition=0, key=b"test", value=b"test" ) @@ -51,6 +59,40 @@ async def produce_message(self): async def test_consume(self): await self.test_poll_with_commit() + async def test_async_iterator(self): + self.create_topic() + await self.produce_two_messages() + self.consumer.subscribe(topics=[self.test_topic]) + await self.consumer.start() + + iterator = aiter(self.consumer) + message = await anext(iterator) + self.assertEqual(message.value, b"test") + + message = await anext(iterator) + self.assertEqual(message.value, b"test1") + + # Technically at this point aiokafka's consumer would block + # indefinitely, however since that's not useful in tests we instead stop + # iterating. + with pytest.raises(StopAsyncIteration): + await anext(iterator) + + async def test_async_iterator_closed_early(self): + self.create_topic() + await self.produce_two_messages() + self.consumer.subscribe(topics=[self.test_topic]) + await self.consumer.start() + + iterator = aiter(self.consumer) + message = await anext(iterator) + self.assertEqual(message.value, b"test") + + await self.consumer.stop() + + with pytest.raises(StopAsyncIteration): + await anext(iterator) + async def test_start(self): # check consumer store is empty await self.consumer.start() @@ -69,7 +111,7 @@ async def test_start(self): async def test_poll_without_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() self.consumer.subscribe(topics=[self.test_topic]) await self.consumer.start() @@ -83,7 +125,7 @@ async def test_poll_without_commit(self): async def test_partition_specific_poll_without_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() self.consumer.subscribe(topics=[self.test_topic]) await self.consumer.start() @@ -99,7 +141,7 @@ async def test_partition_specific_poll_without_commit(self): async def test_poll_with_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() self.consumer.subscribe(topics=[self.test_topic]) await self.consumer.start() @@ -116,7 +158,7 @@ async def test_poll_with_commit(self): async def test_getmany_without_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() await self.producer.send( topic=self.test_topic, partition=2, key=b"test2", value=b"test2" ) @@ -145,7 +187,7 @@ async def test_getmany_without_commit(self): async def test_getmany_with_limit_without_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() await self.producer.send( topic=self.test_topic, partition=0, key=b"test2", value=b"test2" ) @@ -182,7 +224,7 @@ async def test_getmany_with_limit_without_commit(self): async def test_getmany_specific_poll_without_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() await self.producer.send( topic=self.test_topic, partition=1, key=b"test2", value=b"test2" ) @@ -210,7 +252,7 @@ async def test_getmany_specific_poll_without_commit(self): async def test_getmany_with_commit(self): self.create_topic() - await self.produce_message() + await self.produce_two_messages() await self.producer.send( topic=self.test_topic, partition=2, key=b"test2", value=b"test2" ) @@ -287,7 +329,7 @@ async def test_lifecycle(self): self.assertEqual(self.consumer.subscribed_topic, topics) - await self.produce_message() + await self.produce_two_messages() messages = { tp: self.summarise(msgs) @@ -336,7 +378,7 @@ async def test_context_manager(self): async with self.consumer as consumer: self.assertEqual(self.consumer, consumer) - await self.produce_message() + await self.produce_two_messages() messages = { tp: self.summarise(msgs) @@ -373,3 +415,6 @@ async def test_consumer_is_stopped(self): self.consumer.subscribe(topics=topics) with self.assertRaises(ConsumerStoppedError): await self.consumer.getone() + + with self.assertRaises(ConsumerStoppedError): + aiter(self.consumer)