Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions prefect_qiskit/primitives/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from typing import Literal

from prefect import task
from prefect.artifacts import create_table_artifact
from prefect.blocks.abstract import CredentialsBlock
from prefect.context import TaskRunContext
Expand All @@ -27,6 +26,9 @@
from prefect_qiskit.models import JobMetrics
from prefect_qiskit.primitives.job import PrimitiveJob

# TODO integration of metrics database
# TODO automatic job split (workflow optimization)


async def retry_on_failure(_task, _task_run, state):
try:
Expand All @@ -41,14 +43,6 @@ async def retry_on_failure(_task, _task_run, state):
return False


# TODO integration of metrics database
# TODO automatic job split (workflow optimization)


@task(
name="run_primitive",
retry_condition_fn=retry_on_failure,
)
async def run_primitive(
*,
primitive_blocs: list[SamplerPub] | list[EstimatorPub],
Expand All @@ -58,9 +52,9 @@ async def run_primitive(
enable_analytics: bool = True,
options: dict | None = None,
) -> PrimitiveResult:
"""
This function implements a Prefect task to manage the execution of
Qiskit Primitives on an abstract layer,
"""A core logic to make a primitive job and returns a result.

This function manages the execution of Qiskit Primitives on an abstract layer,
providing built-in execution failure protection.

It accepts PUBs and options, submitting this data to quantum computers via the vendor's API.
Expand Down
84 changes: 42 additions & 42 deletions prefect_qiskit/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
"""

import asyncio
from collections.abc import Callable
from functools import partial

from prefect._internal.compatibility.async_dispatch import async_dispatch
from prefect.blocks.core import Block
from prefect.cache_policies import NO_CACHE, CacheKeyFnPolicy
from prefect.tasks import Task
from prefect.utilities.asyncutils import run_coro_as_sync
from pydantic import Field, model_validator
from qiskit.primitives import PrimitiveResult
Expand All @@ -35,7 +36,7 @@
from typing_extensions import Self

from prefect_qiskit.models import AsyncRuntimeClientInterface
from prefect_qiskit.primitives.runner import run_primitive
from prefect_qiskit.primitives.runner import retry_on_failure, run_primitive
from prefect_qiskit.utils.pub_hasher import pub_hasher
from prefect_qiskit.vendors import QiskitAerCredentials, QuantumCredentialsT

Expand Down Expand Up @@ -218,14 +219,11 @@ async def async_sampler(
Returns:
Qiskit PrimitiveResult object.
"""
opted_run_primitive = self._setup_runner(tags=tags)
opted_run_primitive = self.build_runner_task(tags=tags)

return await opted_run_primitive(
primitive_blocs=list(map(SamplerPub.coerce, sampler_pubs)),
program_type="sampler",
resource_name=self.resource_name,
credentials=self.credentials,
enable_analytics=self.enable_job_analytics,
options=options,
)

Expand All @@ -246,14 +244,11 @@ def sampler(
Returns:
Qiskit PrimitiveResult object.
"""
opted_run_primitive = self._setup_runner(tags=tags)
opted_run_primitive = self.build_runner_task(tags=tags)

coro = opted_run_primitive(
primitive_blocs=list(map(SamplerPub.coerce, sampler_pubs)),
program_type="sampler",
resource_name=self.resource_name,
credentials=self.credentials,
enable_analytics=self.enable_job_analytics,
options=options,
)
return asyncio.run(coro)
Expand All @@ -274,14 +269,11 @@ async def async_estimator(
Returns:
Qiskit PrimitiveResult object.
"""
opted_run_primitive = self._setup_runner(tags=tags)
opted_run_primitive = self.build_runner_task(tags=tags)

return await opted_run_primitive(
primitive_blocs=list(map(EstimatorPub.coerce, estimator_pubs)),
program_type="estimator",
resource_name=self.resource_name,
credentials=self.credentials,
enable_analytics=self.enable_job_analytics,
options=options,
)

Expand All @@ -302,51 +294,59 @@ def estimator(
Returns:
Qiskit PrimitiveResult object.
"""
opted_run_primitive = self._setup_runner(tags=tags)
opted_run_primitive = self.build_runner_task(tags=tags)

coro = opted_run_primitive(
primitive_blocs=list(map(EstimatorPub.coerce, estimator_pubs)),
program_type="estimator",
resource_name=self.resource_name,
credentials=self.credentials,
enable_analytics=self.enable_job_analytics,
options=options,
)
return asyncio.run(coro)

def _setup_runner(
def build_runner_task(
self,
tags: list[str] | None = None,
) -> Callable:
"""A helper function to setup a primitive runner.
**kwargs,
) -> Task:
"""Build Prefect Task object that runs primitive.

Args:
tags: Arbitrary labels to add to the Primitive task tags of execution.
kwargs: Keyword arguments to instantiate Prefect Task.

Returns:
A primitive runner callable with configured task options.
A configured prefect task.
"""
configured_runner = partial(
run_primitive,
resource_name=self.resource_name,
credentials=self.credentials,
enable_analytics=self.enable_job_analytics,
)

if self.execution_cache:
task_options = {
"cache_policy": CacheKeyFnPolicy(cache_key_fn=_primitive_cache),
"persist_result": True,
"result_serializer": "compressed/pickle",
}
kwargs.update(
{
"cache_policy": CacheKeyFnPolicy(cache_key_fn=_primitive_cache),
"persist_result": True,
"result_serializer": "compressed/pickle",
}
)
else:
task_options = {
"cache_policy": NO_CACHE,
"persist_result": False,
}
task_options.update(
{
"retries": self.max_retry,
"retry_delay_seconds": self.retry_delay,
"timeout_seconds": self.timeout,
"tags": tags or ["primitive-execute"],
}
)
kwargs.update(
{
"cache_policy": NO_CACHE,
"persist_result": False,
}
)

return run_primitive.with_options(**task_options)
return Task(
fn=configured_runner,
name="run_primitive",
retries=self.max_retry,
retry_delay_seconds=self.retry_delay,
timeout_seconds=self.timeout,
retry_condition_fn=retry_on_failure,
**kwargs,
)


def _primitive_cache(_, parameters) -> str:
Expand Down
81 changes: 46 additions & 35 deletions tests/unit/test_primitive_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from prefect_qiskit.exceptions import RuntimeJobFailure
from prefect_qiskit.models import JobMetrics
from prefect_qiskit.primitives.runner import run_primitive
from prefect_qiskit.runtime import QuantumRuntime
from prefect_qiskit.vendors.qiskit_aer import QiskitAerCredentials
from prefect_qiskit.vendors.qiskit_aer.client import QiskitAerClient

Expand Down Expand Up @@ -52,15 +52,17 @@ async def test_retry_until_success(
)

# API will intentionally fail once
result = await run_primitive.with_options(
retries=2,
retry_delay_seconds=1,
)(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
runner_task = QuantumRuntime(
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_analytics=False,
enable_job_analytics=False,
max_retry=2,
retry_delay=1,
).build_runner_task()

result = await runner_task(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
)

assert result[0].data.meas.get_counts() == {"00": 512, "11": 512}
Expand Down Expand Up @@ -89,16 +91,18 @@ async def test_unretryable(
side_effect=RuntimeJobFailure("unretryable failure", job_id="test_job_123", retry=False),
)

runner_task = QuantumRuntime(
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_job_analytics=False,
max_retry=2,
retry_delay=1,
).build_runner_task()

with pytest.raises(RuntimeJobFailure):
await run_primitive.with_options(
retries=2,
retry_delay_seconds=1,
)(
await runner_task(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_analytics=False,
)

# Don't submit unretryable job more than once
Expand All @@ -123,16 +127,18 @@ async def test_not_infinite_loop(
side_effect=RuntimeJobFailure("retryable failure", job_id="test_job_123", retry=True),
)

runner_task = QuantumRuntime(
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_job_analytics=False,
max_retry=2,
retry_delay=1,
).build_runner_task()

with pytest.raises(RuntimeJobFailure):
await run_primitive.with_options(
retries=2,
retry_delay_seconds=1,
)(
await runner_task(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_analytics=False,
)

# Retry until retry limit
Expand Down Expand Up @@ -160,16 +166,18 @@ async def test_retry_on_task_timeout(
side_effect=lambda _: "COMPLETED" if time.time() - test_start > 10 else "QUEUED",
)

result = await run_primitive.with_options(
retries=2,
retry_delay_seconds=0,
timeout_seconds=6,
)(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
runner_task = QuantumRuntime(
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_analytics=False,
enable_job_analytics=False,
max_retry=2,
retry_delay=0,
timeout=6,
).build_runner_task()

result = await runner_task(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
)

assert result[0].data.meas.get_counts() == {"00": 512, "11": 512}
Expand Down Expand Up @@ -213,15 +221,18 @@ async def test_job_metrics(
),
)

await run_primitive.with_options(
runner_task = QuantumRuntime(
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_job_analytics=True,
).build_runner_task(
task_run_name="test_job_metrics",
tags=["tag1", "tag2"],
)(
)

await runner_task(
primitive_blocs=[bell_circuit_pub],
program_type="sampler",
resource_name="aer_simulator",
credentials=aer_credentials_2q,
enable_analytics=True,
options=test_options,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/vendors/ibm/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_sampler_e2e(
"program_type": "sampler",
"num_pubs": 1,
"job_id": "test-job-id-123",
"tags": ["primitive-execute"],
"tags": [],
"timestamp.created": ANY,
"timestamp.started": ANY,
"timestamp.completed": ANY,
Expand Down
Loading