Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 91 additions & 22 deletions src/dodal/devices/fast_grid_scan.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import numpy as np
from bluesky.plan_stubs import mv
from bluesky.protocols import Flyable
from bluesky.plan_stubs import prepare
from bluesky.protocols import Flyable, Preparable
from numpy import ndarray
from ophyd_async.core import (
AsyncStatus,
Expand All @@ -13,6 +14,7 @@
SignalRW,
StandardReadable,
derived_signal_r,
set_and_wait_for_value,
soft_signal_r_and_setter,
wait_for_value,
)
Expand All @@ -29,6 +31,10 @@
from dodal.parameters.experiment_parameter_base import AbstractExperimentWithBeamParams


class GridScanInvalidException(RuntimeError):
"""Raised when the gridscan parameters are not valid."""


@dataclass
class GridAxis:
start: float
Expand Down Expand Up @@ -144,7 +150,7 @@ def z_axis(self) -> GridAxis:
return GridAxis(self.z2_start_mm, self.z_step_size_mm, self.z_steps)


ParamType = TypeVar("ParamType", bound=GridScanParamsCommon, covariant=True)
ParamType = TypeVar("ParamType", bound=GridScanParamsCommon)


class WithDwellTime(BaseModel):
Expand Down Expand Up @@ -190,7 +196,9 @@ def __init__(self, prefix: str, name: str = "", has_prog_num=True) -> None:
self.program_number = soft_signal_r_and_setter(float, -1)[0]


class FastGridScanCommon(StandardReadable, Flyable, ABC, Generic[ParamType]):
class FastGridScanCommon(
StandardReadable, Flyable, ABC, Preparable, Generic[ParamType]
):
"""Device containing the minimal signals for a general fast grid scan.

When the motion program is started, the goniometer will move in a snake-like grid trajectory,
Expand Down Expand Up @@ -231,8 +239,9 @@ def __init__(
self.KICKOFF_TIMEOUT: float = 5.0

self.COMPLETE_STATUS: float = 60.0
self.VALIDITY_CHECK_TIMEOUT = 0.5

self.movable_params: dict[str, Signal] = {
self._movable_params: dict[str, Signal] = {
"x_steps": self.x_steps,
"y_steps": self.y_steps,
"x_step_size_mm": self.x_step_size,
Expand Down Expand Up @@ -284,6 +293,47 @@ def _create_motion_program(
self, motion_controller_prefix: str
) -> MotionProgram: ...

@AsyncStatus.wrap
async def prepare(self, value: ParamType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this is a lot nicer than doing it plan side!

"""
Submit the gridscan parameters to the device for validation prior to
gridscan kickoff
Args:
value: the gridscan parameters

Raises:
GridScanInvalidException: if the gridscan parameters were not valid
"""
set_statuses = []

LOGGER.info("Applying gridscan parameters...")
# Create arguments for bps.mv
for key, signal in self._movable_params.items():
param_value = value.__dict__[key]
set_statuses.append(await set_and_wait_for_value(signal, param_value)) # type: ignore

# Counter should always start at 0
set_statuses.append(await set_and_wait_for_value(self.position_counter, 0))

LOGGER.info("Gridscan parameters applied, waiting for sets to complete...")

# wait for parameter sets to complete
await asyncio.gather(*set_statuses)

LOGGER.info("Sets confirmed, waiting for validity checks to pass...")
# XXX Can we use x/y/z scan valid to distinguish between SampleException/pin invalid
# and other non-sample-related errors?
Comment on lines +324 to +325
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should: I'm not sure if there are cases where the scan is invalid but it's not a sample error but maybe I'm missing something?

try:
await wait_for_value(
self.scan_invalid, 0.0, timeout=self.VALIDITY_CHECK_TIMEOUT
)
except TimeoutError as e:
raise GridScanInvalidException(
f"Gridscan parameters not validated after {self.VALIDITY_CHECK_TIMEOUT}s"
) from e

LOGGER.info("Gridscan validity confirmed, gridscan is now prepared.")


class FastGridScanThreeD(FastGridScanCommon[ParamType]):
"""Device for standard 3D FGS.
Expand All @@ -309,10 +359,10 @@ def __init__(self, prefix: str, infix: str, name: str = "") -> None:

super().__init__(full_prefix, prefix, name)

self.movable_params["z_step_size_mm"] = self.z_step_size
self.movable_params["z2_start_mm"] = self.z2_start
self.movable_params["y2_start_mm"] = self.y2_start
self.movable_params["z_steps"] = self.z_steps
self._movable_params["z_step_size_mm"] = self.z_step_size
self._movable_params["z2_start_mm"] = self.z2_start
self._movable_params["y2_start_mm"] = self.y2_start
self._movable_params["z_steps"] = self.z_steps

def _create_expected_images_signal(self):
return derived_signal_r(
Expand All @@ -329,7 +379,35 @@ def _calculate_expected_images(self, x: int, y: int, z: int) -> int:
return first_grid + second_grid

def _create_scan_invalid_signal(self, prefix: str) -> SignalR[float]:
return epics_signal_r(float, f"{prefix}SCAN_INVALID")
self.x_scan_valid = epics_signal_r(float, f"{prefix}X_SCAN_VALID")
self.y_scan_valid = epics_signal_r(float, f"{prefix}Y_SCAN_VALID")
self.z_scan_valid = epics_signal_r(float, f"{prefix}Z_SCAN_VALID")
self.device_scan_invalid = epics_signal_r(float, f"{prefix}SCAN_INVALID")

def compute_derived_value(
x_scan_valid: float,
y_scan_valid: float,
z_scan_valid: float,
device_scan_invalid: float,
) -> float:
return (
1.0
if not (
x_scan_valid
and y_scan_valid
and z_scan_valid
and not device_scan_invalid
)
else 0.0
)

return derived_signal_r(
compute_derived_value,
x_scan_valid=self.x_scan_valid,
y_scan_valid=self.y_scan_valid,
z_scan_valid=self.z_scan_valid,
device_scan_invalid=self.device_scan_invalid,
)

def _create_motion_program(self, motion_controller_prefix: str):
return MotionProgram(motion_controller_prefix)
Expand All @@ -349,7 +427,7 @@ def __init__(self, prefix: str, name: str = "") -> None:
self.dwell_time_ms = epics_signal_rw_rbv(float, f"{full_prefix}DWELL_TIME")
self.x_counter = epics_signal_r(int, f"{full_prefix}X_COUNTER")
super().__init__(prefix, infix, name)
self.movable_params["dwell_time_ms"] = self.dwell_time_ms
self._movable_params["dwell_time_ms"] = self.dwell_time_ms

def _create_position_counter(self, prefix: str):
return epics_signal_rw(
Expand Down Expand Up @@ -380,20 +458,11 @@ def __init__(self, prefix: str, name: str = "") -> None:
)
super().__init__(prefix, infix, name)

self.movable_params["run_up_distance_mm"] = self.run_up_distance_mm
self._movable_params["run_up_distance_mm"] = self.run_up_distance_mm

def _create_position_counter(self, prefix: str):
return epics_signal_rw(int, f"{prefix}Y_COUNTER")


def set_fast_grid_scan_params(scan: FastGridScanCommon[ParamType], params: ParamType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could: Maybe a docstring in here on what to expect if it's invalid

to_move = []

# Create arguments for bps.mv
for key in scan.movable_params.keys():
to_move.extend([scan.movable_params[key], params.__dict__[key]])

# Counter should always start at 0
to_move.extend([scan.position_counter, 0])

yield from mv(*to_move)
yield from prepare(scan, params, wait=True)
2 changes: 1 addition & 1 deletion src/dodal/devices/i02_1/fast_grid_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
# See https://github.com/DiamondLightSource/mx-bluesky/issues/1203
self.dwell_time_ms = epics_signal_rw_rbv(float, f"{full_prefix}EXPOSURE_TIME")

self.movable_params["dwell_time_ms"] = self.dwell_time_ms
self._movable_params["dwell_time_ms"] = self.dwell_time_ms

def _create_expected_images_signal(self):
return derived_signal_r(
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/i02_1/test_fast_grid_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def fast_grid_scan():
async def test_i02_1_gridscan_has_2d_behaviour(fast_grid_scan: ZebraFastGridScanTwoD):
three_d_movables = ["z_step_size_mm", "z2_start_mm", "y2_start_mm", "z_steps"]
for movable in three_d_movables:
assert movable not in fast_grid_scan.movable_params.keys()
assert movable not in fast_grid_scan._movable_params.keys()
set_mock_value(fast_grid_scan.x_steps, 5)
set_mock_value(fast_grid_scan.y_steps, 4)
assert await fast_grid_scan.expected_images.get_value() == 20
Loading