diff --git a/prefect_qiskit/primitives/runner.py b/prefect_qiskit/primitives/runner.py index 7023afa..5e5c2b4 100644 --- a/prefect_qiskit/primitives/runner.py +++ b/prefect_qiskit/primitives/runner.py @@ -13,7 +13,6 @@ from typing import Literal -from prefect import task from prefect.artifacts import create_table_artifact from prefect.blocks.abstract import CredentialsBlock from prefect.context import TaskRunContext @@ -27,6 +26,9 @@ from prefect_qiskit.models import JobMetrics from prefect_qiskit.primitives.job import PrimitiveJob +# TODO integration of metrics database +# TODO automatic job split (workflow optimization) + async def retry_on_failure(_task, _task_run, state): try: @@ -41,14 +43,6 @@ async def retry_on_failure(_task, _task_run, state): return False -# TODO integration of metrics database -# TODO automatic job split (workflow optimization) - - -@task( - name="run_primitive", - retry_condition_fn=retry_on_failure, -) async def run_primitive( *, primitive_blocs: list[SamplerPub] | list[EstimatorPub], @@ -58,9 +52,9 @@ async def run_primitive( enable_analytics: bool = True, options: dict | None = None, ) -> PrimitiveResult: - """ - This function implements a Prefect task to manage the execution of - Qiskit Primitives on an abstract layer, + """A core logic to make a primitive job and returns a result. + + This function manages the execution of Qiskit Primitives on an abstract layer, providing built-in execution failure protection. It accepts PUBs and options, submitting this data to quantum computers via the vendor's API. diff --git a/prefect_qiskit/runtime.py b/prefect_qiskit/runtime.py index f90cac7..6dcddf7 100644 --- a/prefect_qiskit/runtime.py +++ b/prefect_qiskit/runtime.py @@ -21,11 +21,12 @@ """ import asyncio -from collections.abc import Callable +from functools import partial from prefect._internal.compatibility.async_dispatch import async_dispatch from prefect.blocks.core import Block from prefect.cache_policies import NO_CACHE, CacheKeyFnPolicy +from prefect.tasks import Task from prefect.utilities.asyncutils import run_coro_as_sync from pydantic import Field, model_validator from qiskit.primitives import PrimitiveResult @@ -35,7 +36,7 @@ from typing_extensions import Self from prefect_qiskit.models import AsyncRuntimeClientInterface -from prefect_qiskit.primitives.runner import run_primitive +from prefect_qiskit.primitives.runner import retry_on_failure, run_primitive from prefect_qiskit.utils.pub_hasher import pub_hasher from prefect_qiskit.vendors import QiskitAerCredentials, QuantumCredentialsT @@ -218,14 +219,11 @@ async def async_sampler( Returns: Qiskit PrimitiveResult object. """ - opted_run_primitive = self._setup_runner(tags=tags) + opted_run_primitive = self.build_runner_task(tags=tags) return await opted_run_primitive( primitive_blocs=list(map(SamplerPub.coerce, sampler_pubs)), program_type="sampler", - resource_name=self.resource_name, - credentials=self.credentials, - enable_analytics=self.enable_job_analytics, options=options, ) @@ -246,14 +244,11 @@ def sampler( Returns: Qiskit PrimitiveResult object. """ - opted_run_primitive = self._setup_runner(tags=tags) + opted_run_primitive = self.build_runner_task(tags=tags) coro = opted_run_primitive( primitive_blocs=list(map(SamplerPub.coerce, sampler_pubs)), program_type="sampler", - resource_name=self.resource_name, - credentials=self.credentials, - enable_analytics=self.enable_job_analytics, options=options, ) return asyncio.run(coro) @@ -274,14 +269,11 @@ async def async_estimator( Returns: Qiskit PrimitiveResult object. """ - opted_run_primitive = self._setup_runner(tags=tags) + opted_run_primitive = self.build_runner_task(tags=tags) return await opted_run_primitive( primitive_blocs=list(map(EstimatorPub.coerce, estimator_pubs)), program_type="estimator", - resource_name=self.resource_name, - credentials=self.credentials, - enable_analytics=self.enable_job_analytics, options=options, ) @@ -302,51 +294,59 @@ def estimator( Returns: Qiskit PrimitiveResult object. """ - opted_run_primitive = self._setup_runner(tags=tags) + opted_run_primitive = self.build_runner_task(tags=tags) coro = opted_run_primitive( primitive_blocs=list(map(EstimatorPub.coerce, estimator_pubs)), program_type="estimator", - resource_name=self.resource_name, - credentials=self.credentials, - enable_analytics=self.enable_job_analytics, options=options, ) return asyncio.run(coro) - def _setup_runner( + def build_runner_task( self, - tags: list[str] | None = None, - ) -> Callable: - """A helper function to setup a primitive runner. + **kwargs, + ) -> Task: + """Build Prefect Task object that runs primitive. Args: - tags: Arbitrary labels to add to the Primitive task tags of execution. + kwargs: Keyword arguments to instantiate Prefect Task. Returns: - A primitive runner callable with configured task options. + A configured prefect task. """ + configured_runner = partial( + run_primitive, + resource_name=self.resource_name, + credentials=self.credentials, + enable_analytics=self.enable_job_analytics, + ) + if self.execution_cache: - task_options = { - "cache_policy": CacheKeyFnPolicy(cache_key_fn=_primitive_cache), - "persist_result": True, - "result_serializer": "compressed/pickle", - } + kwargs.update( + { + "cache_policy": CacheKeyFnPolicy(cache_key_fn=_primitive_cache), + "persist_result": True, + "result_serializer": "compressed/pickle", + } + ) else: - task_options = { - "cache_policy": NO_CACHE, - "persist_result": False, - } - task_options.update( - { - "retries": self.max_retry, - "retry_delay_seconds": self.retry_delay, - "timeout_seconds": self.timeout, - "tags": tags or ["primitive-execute"], - } - ) + kwargs.update( + { + "cache_policy": NO_CACHE, + "persist_result": False, + } + ) - return run_primitive.with_options(**task_options) + return Task( + fn=configured_runner, + name="run_primitive", + retries=self.max_retry, + retry_delay_seconds=self.retry_delay, + timeout_seconds=self.timeout, + retry_condition_fn=retry_on_failure, + **kwargs, + ) def _primitive_cache(_, parameters) -> str: diff --git a/tests/unit/test_primitive_runner.py b/tests/unit/test_primitive_runner.py index 1ad8510..c286da3 100644 --- a/tests/unit/test_primitive_runner.py +++ b/tests/unit/test_primitive_runner.py @@ -24,7 +24,7 @@ from prefect_qiskit.exceptions import RuntimeJobFailure from prefect_qiskit.models import JobMetrics -from prefect_qiskit.primitives.runner import run_primitive +from prefect_qiskit.runtime import QuantumRuntime from prefect_qiskit.vendors.qiskit_aer import QiskitAerCredentials from prefect_qiskit.vendors.qiskit_aer.client import QiskitAerClient @@ -52,15 +52,17 @@ async def test_retry_until_success( ) # API will intentionally fail once - result = await run_primitive.with_options( - retries=2, - retry_delay_seconds=1, - )( - primitive_blocs=[bell_circuit_pub], - program_type="sampler", + runner_task = QuantumRuntime( resource_name="aer_simulator", credentials=aer_credentials_2q, - enable_analytics=False, + enable_job_analytics=False, + max_retry=2, + retry_delay=1, + ).build_runner_task() + + result = await runner_task( + primitive_blocs=[bell_circuit_pub], + program_type="sampler", ) assert result[0].data.meas.get_counts() == {"00": 512, "11": 512} @@ -89,16 +91,18 @@ async def test_unretryable( side_effect=RuntimeJobFailure("unretryable failure", job_id="test_job_123", retry=False), ) + runner_task = QuantumRuntime( + resource_name="aer_simulator", + credentials=aer_credentials_2q, + enable_job_analytics=False, + max_retry=2, + retry_delay=1, + ).build_runner_task() + with pytest.raises(RuntimeJobFailure): - await run_primitive.with_options( - retries=2, - retry_delay_seconds=1, - )( + await runner_task( primitive_blocs=[bell_circuit_pub], program_type="sampler", - resource_name="aer_simulator", - credentials=aer_credentials_2q, - enable_analytics=False, ) # Don't submit unretryable job more than once @@ -123,16 +127,18 @@ async def test_not_infinite_loop( side_effect=RuntimeJobFailure("retryable failure", job_id="test_job_123", retry=True), ) + runner_task = QuantumRuntime( + resource_name="aer_simulator", + credentials=aer_credentials_2q, + enable_job_analytics=False, + max_retry=2, + retry_delay=1, + ).build_runner_task() + with pytest.raises(RuntimeJobFailure): - await run_primitive.with_options( - retries=2, - retry_delay_seconds=1, - )( + await runner_task( primitive_blocs=[bell_circuit_pub], program_type="sampler", - resource_name="aer_simulator", - credentials=aer_credentials_2q, - enable_analytics=False, ) # Retry until retry limit @@ -160,16 +166,18 @@ async def test_retry_on_task_timeout( side_effect=lambda _: "COMPLETED" if time.time() - test_start > 10 else "QUEUED", ) - result = await run_primitive.with_options( - retries=2, - retry_delay_seconds=0, - timeout_seconds=6, - )( - primitive_blocs=[bell_circuit_pub], - program_type="sampler", + runner_task = QuantumRuntime( resource_name="aer_simulator", credentials=aer_credentials_2q, - enable_analytics=False, + enable_job_analytics=False, + max_retry=2, + retry_delay=0, + timeout=6, + ).build_runner_task() + + result = await runner_task( + primitive_blocs=[bell_circuit_pub], + program_type="sampler", ) assert result[0].data.meas.get_counts() == {"00": 512, "11": 512} @@ -213,15 +221,18 @@ async def test_job_metrics( ), ) - await run_primitive.with_options( + runner_task = QuantumRuntime( + resource_name="aer_simulator", + credentials=aer_credentials_2q, + enable_job_analytics=True, + ).build_runner_task( task_run_name="test_job_metrics", tags=["tag1", "tag2"], - )( + ) + + await runner_task( primitive_blocs=[bell_circuit_pub], program_type="sampler", - resource_name="aer_simulator", - credentials=aer_credentials_2q, - enable_analytics=True, options=test_options, ) diff --git a/tests/vendors/ibm/test_e2e.py b/tests/vendors/ibm/test_e2e.py index 73c2957..9c103f1 100644 --- a/tests/vendors/ibm/test_e2e.py +++ b/tests/vendors/ibm/test_e2e.py @@ -178,7 +178,7 @@ def test_sampler_e2e( "program_type": "sampler", "num_pubs": 1, "job_id": "test-job-id-123", - "tags": ["primitive-execute"], + "tags": [], "timestamp.created": ANY, "timestamp.started": ANY, "timestamp.completed": ANY,