diff --git a/src/dodal/devices/electron_analyser/abstract/base_driver_io.py b/src/dodal/devices/electron_analyser/abstract/base_driver_io.py index 8b81c7fc255..f56d38cc3d8 100644 --- a/src/dodal/devices/electron_analyser/abstract/base_driver_io.py +++ b/src/dodal/devices/electron_analyser/abstract/base_driver_io.py @@ -153,11 +153,12 @@ async def set(self, region: TAbstractBaseRegion): self.energy_source.selected_source.set(region.excitation_energy_source) excitation_energy = await self.energy_source.energy.get_value() - # Copy region so doesn't alter the actual region and switch to kinetic energy - ke_region = region.model_copy() - ke_region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) - + # Switch to kinetic energy as epics doesn't support BINDING. + ke_region = region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) await self._set_region(ke_region) + # Set the true energy mode from original region so binding_energy_axis can be + # calculated correctly. + await self.energy_mode.set(region.energy_mode) @abstractmethod async def _set_region(self, ke_region: TAbstractBaseRegion): diff --git a/src/dodal/devices/electron_analyser/abstract/base_region.py b/src/dodal/devices/electron_analyser/abstract/base_region.py index f64cf19d946..5d024e2df55 100644 --- a/src/dodal/devices/electron_analyser/abstract/base_region.py +++ b/src/dodal/devices/electron_analyser/abstract/base_region.py @@ -1,7 +1,7 @@ import re from abc import ABC from collections.abc import Callable -from typing import Generic, TypeVar +from typing import Generic, Self, TypeVar from pydantic import BaseModel, Field, model_validator @@ -88,28 +88,43 @@ def is_kinetic_energy(self) -> bool: return self.energy_mode == EnergyMode.KINETIC def switch_energy_mode( - self, energy_mode: EnergyMode, excitation_energy: float - ) -> None: + self, energy_mode: EnergyMode, excitation_energy: float, copy: bool = True + ) -> Self: """ - Switch region to new energy mode: Kinetic or Binding. Updates the low_energy, - centre_energy, high_energy, and energy_mode, only if it switches to a new one. + Switch region with to a new energy mode with a new energy mode: Kinetic or Binding. + It caculates new values for low_energy, centre_energy, high_energy, via the + excitation enerrgy. It doesn't calculate anything if the region is already of + the same energy mode. Parameters: - energy_mode: mode you want to switch the region to. - excitation_energy: the energy to calculate the new values of low_energy, - centre_energy, and high_energy. + energy_mode: Mode you want to switch the region to. + excitation_energy: Energy conversion for low_energy, centre_energy, and + high_energy for new energy mode. + copy: Defaults to True. If true, create a copy of this region for the new + energy_mode and return it. If False, alter this region for the + energy_mode and return it self. + + Returns: + Region with selected energy mode and new calculated energy values. """ + switched_r = self.model_copy() if copy else self conv = ( to_binding_energy if energy_mode == EnergyMode.BINDING else to_kinetic_energy ) - self.low_energy = conv(self.low_energy, self.energy_mode, excitation_energy) - self.centre_energy = conv( - self.centre_energy, self.energy_mode, excitation_energy + switched_r.low_energy = conv( + switched_r.low_energy, switched_r.energy_mode, excitation_energy ) - self.high_energy = conv(self.high_energy, self.energy_mode, excitation_energy) - self.energy_mode = energy_mode + switched_r.centre_energy = conv( + switched_r.centre_energy, switched_r.energy_mode, excitation_energy + ) + switched_r.high_energy = conv( + switched_r.high_energy, switched_r.energy_mode, excitation_energy + ) + switched_r.energy_mode = energy_mode + + return switched_r @model_validator(mode="before") @classmethod diff --git a/src/dodal/devices/electron_analyser/detector.py b/src/dodal/devices/electron_analyser/detector.py index a0d78990a95..7a44fbde7e5 100644 --- a/src/dodal/devices/electron_analyser/detector.py +++ b/src/dodal/devices/electron_analyser/detector.py @@ -70,7 +70,7 @@ def __init__( driver: TAbstractAnalyserDriverIO, name: str = "", ): - # Save driver as direct child so particpates with connect() + # Save driver as direct child so participates with connect() self.driver = driver self._sequence_class = sequence_class controller = ConstantDeadTimeController[TAbstractAnalyserDriverIO](driver, 0) diff --git a/src/dodal/devices/electron_analyser/specs/driver_io.py b/src/dodal/devices/electron_analyser/specs/driver_io.py index 59603110ce0..fbf83ee1917 100644 --- a/src/dodal/devices/electron_analyser/specs/driver_io.py +++ b/src/dodal/devices/electron_analyser/specs/driver_io.py @@ -66,7 +66,6 @@ def __init__( async def _set_region(self, ke_region: SpecsRegion[TLensMode, TPsuMode]): await asyncio.gather( self.region_name.set(ke_region.name), - self.energy_mode.set(ke_region.energy_mode), self.low_energy.set(ke_region.low_energy), self.high_energy.set(ke_region.high_energy), self.slices.set(ke_region.slices), diff --git a/src/dodal/devices/electron_analyser/vgscienta/driver_io.py b/src/dodal/devices/electron_analyser/vgscienta/driver_io.py index 19c878f1785..454fa3d0954 100644 --- a/src/dodal/devices/electron_analyser/vgscienta/driver_io.py +++ b/src/dodal/devices/electron_analyser/vgscienta/driver_io.py @@ -74,7 +74,6 @@ def __init__( async def _set_region(self, ke_region: VGScientaRegion[TLensMode, TPassEnergyEnum]): await asyncio.gather( self.region_name.set(ke_region.name), - self.energy_mode.set(ke_region.energy_mode), self.low_energy.set(ke_region.low_energy), self.centre_energy.set(ke_region.centre_energy), self.high_energy.set(ke_region.high_energy), diff --git a/tests/devices/electron_analyser/abstract/test_base_driver_io.py b/tests/devices/electron_analyser/abstract/test_base_driver_io.py index afcae67a43b..283a19e03ce 100644 --- a/tests/devices/electron_analyser/abstract/test_base_driver_io.py +++ b/tests/devices/electron_analyser/abstract/test_base_driver_io.py @@ -1,5 +1,5 @@ from typing import get_origin -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from bluesky import plan_stubs as bps @@ -7,9 +7,10 @@ from bluesky.utils import FailedStatus from ophyd_async.core import StrictEnum, init_devices from ophyd_async.epics.adcore import ADImageMode +from ophyd_async.testing import get_mock_put from dodal.devices import b07, i09 -from dodal.devices.electron_analyser import DualEnergySource, EnergySource +from dodal.devices.electron_analyser import DualEnergySource, EnergyMode, EnergySource from dodal.devices.electron_analyser.abstract import ( AbstractAnalyserDriverIO, AbstractBaseRegion, @@ -49,23 +50,43 @@ async def sim_driver( @pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) -def test_driver_set( +async def test_driver_set( sim_driver: AbstractAnalyserDriverIO, region: AbstractBaseRegion, RE: RunEngine, ) -> None: sim_driver._set_region = AsyncMock() - if isinstance(sim_driver.energy_source, DualEnergySource): - sim_driver.energy_source.selected_source.set = MagicMock() + # Patch switch_energy_mode so we can check on calls, but still run the real function + with patch.object( + AbstractBaseRegion, + "switch_energy_mode", + side_effect=AbstractBaseRegion.switch_energy_mode, # run the real method + autospec=True, + ) as mock_switch_energy_mode: + RE(bps.mv(sim_driver, region)) + + mock_switch_energy_mode.assert_called_once_with( + region, + EnergyMode.KINETIC, + await sim_driver.energy_source.energy.get_value(), + ) - RE(bps.mv(sim_driver, region)) + if isinstance(sim_driver.energy_source, DualEnergySource): + get_mock_put( + sim_driver.energy_source.selected_source + ).assert_called_once_with(region.excitation_energy_source, wait=True) + + # Check interal _set_region was set with ke_region + ke_region = mock_switch_energy_mode.call_args[0][0].switch_energy_mode( + EnergyMode.KINETIC, + await sim_driver.energy_source.energy.get_value(), + ) + sim_driver._set_region.assert_called_once_with(ke_region) - if isinstance(sim_driver.energy_source, DualEnergySource): - sim_driver.energy_source.selected_source.set.assert_called_once_with( # type: ignore - region.excitation_energy_source + get_mock_put(sim_driver.energy_mode).assert_called_once_with( + region.energy_mode, wait=True ) - sim_driver._set_region.assert_called_once() def test_driver_throws_error_with_wrong_lens_mode( diff --git a/tests/devices/electron_analyser/abstract/test_base_region.py b/tests/devices/electron_analyser/abstract/test_base_region.py index e082880ed7c..35578f5ae0f 100644 --- a/tests/devices/electron_analyser/abstract/test_base_region.py +++ b/tests/devices/electron_analyser/abstract/test_base_region.py @@ -86,9 +86,13 @@ def test_region_kinetic_and_binding_energy( assert r.is_kinetic_energy() != is_binding_energy -def assert_region_field_energy_from_switching_energy_modes_is_correct( - region: AbstractBaseRegion, field: str, excitation_energy: float +@pytest.mark.parametrize("field", ["low_energy", "centre_energy", "high_energy"]) +@pytest.mark.parametrize("copy", [True, False]) +@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) +def test_each_energy_field_for_region_is_correct_when_switching_energy_modes( + region: AbstractBaseRegion, field: str, copy: bool ) -> None: + excitation_energy = 100 conversion_func_map = { EnergyMode.KINETIC: to_binding_energy, EnergyMode.BINDING: to_kinetic_energy, @@ -111,23 +115,16 @@ def assert_region_field_energy_from_switching_energy_modes_is_correct( ] expected_e_values = [original_energy, converted_energy, original_energy] - # Do full cycle of switching energy modes - # First check shouldn't see change as region is the same energy mode + # Do full cycle of switching energy modes. + # First check shouldn't see change as region is the same energy mode. # Second check cycles to the opposite energy mode, check it is correct via opposite # energy mode. # Third check cycles back so should be original value. for e_mode, e_expected in zip(e_mode_sequence, expected_e_values, strict=False): - region.switch_energy_mode(e_mode, excitation_energy) - assert getattr(region, field) == e_expected - assert region.energy_mode == e_mode - - -@pytest.mark.parametrize("region", TEST_SEQUENCE_REGION_NAMES, indirect=True) -def test_each_energy_field_for_region_is_correct_when_switching_energy_modes( - region: AbstractBaseRegion, -) -> None: - excitation_energy = 100 - for field in ("low_energy", "centre_energy", "high_energy"): - assert_region_field_energy_from_switching_energy_modes_is_correct( - region, field, excitation_energy - ) + new_r = region.switch_energy_mode(e_mode, excitation_energy, copy) + assert getattr(new_r, field) == e_expected + assert new_r.energy_mode == e_mode + if copy: + assert new_r is not region + else: + assert new_r is region diff --git a/tests/devices/electron_analyser/specs/test_driver_io.py b/tests/devices/electron_analyser/specs/test_driver_io.py index 25475401076..47913b91005 100644 --- a/tests/devices/electron_analyser/specs/test_driver_io.py +++ b/tests/devices/electron_analyser/specs/test_driver_io.py @@ -52,9 +52,6 @@ async def test_analyser_sets_region_correctly( ) -> None: RE(bps.mv(sim_driver, region), wait=True) - excitation_energy = await sim_driver.energy_source.energy.get_value() - region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) - get_mock_put(sim_driver.region_name).assert_called_once_with(region.name, wait=True) get_mock_put(sim_driver.energy_mode).assert_called_once_with( region.energy_mode, wait=True @@ -66,18 +63,20 @@ async def test_analyser_sets_region_correctly( region.lens_mode, wait=True ) + excitation_energy = await sim_driver.energy_source.energy.get_value() + ke_region = region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) get_mock_put(sim_driver.low_energy).assert_called_once_with( - region.low_energy, wait=True + ke_region.low_energy, wait=True ) - if region.acquisition_mode == AcquisitionMode.FIXED_ENERGY: + if ke_region.acquisition_mode == AcquisitionMode.FIXED_ENERGY: get_mock_put(sim_driver.centre_energy).assert_called_once_with( - region.centre_energy, wait=True + ke_region.centre_energy, wait=True ) else: get_mock_put(sim_driver.centre_energy).assert_not_called() get_mock_put(sim_driver.high_energy).assert_called_once_with( - region.high_energy, wait=True + ke_region.high_energy, wait=True ) get_mock_put(sim_driver.pass_energy).assert_called_once_with( region.pass_energy, wait=True @@ -116,8 +115,7 @@ async def test_analyser_sets_region_and_read_configuration_is_correct( prefix = sim_driver.name + "-" excitation_energy = await sim_driver.energy_source.energy.get_value() - region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) - + ke_region = region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) await assert_configuration( sim_driver, { @@ -125,9 +123,9 @@ async def test_analyser_sets_region_and_read_configuration_is_correct( f"{prefix}energy_mode": partial_reading(region.energy_mode), f"{prefix}acquisition_mode": partial_reading(region.acquisition_mode), f"{prefix}lens_mode": partial_reading(region.lens_mode), - f"{prefix}low_energy": partial_reading(region.low_energy), + f"{prefix}low_energy": partial_reading(ke_region.low_energy), f"{prefix}centre_energy": partial_reading(ANY), - f"{prefix}high_energy": partial_reading(region.high_energy), + f"{prefix}high_energy": partial_reading(ke_region.high_energy), f"{prefix}energy_step": partial_reading(ANY), f"{prefix}pass_energy": partial_reading(region.pass_energy), f"{prefix}slices": partial_reading(region.slices), @@ -180,10 +178,13 @@ async def test_specs_analyser_binding_energy_axis( excitation_energy = await sim_driver.energy_source.energy.get_value() # Check binding energy is correct - is_binding = await sim_driver.energy_mode.get_value() == EnergyMode.BINDING + is_region_binding = region.is_binding_energy() + is_driver_binding = await sim_driver.energy_mode.get_value() == EnergyMode.BINDING + # Catch that driver correctly reflects what region energy mode is. + assert is_region_binding == is_driver_binding energy_axis = await sim_driver.energy_axis.get_value() expected_binding_energy_axis = np.array( - [excitation_energy - e if is_binding else e for e in energy_axis] + [excitation_energy - e if is_driver_binding else e for e in energy_axis] ) await assert_value(sim_driver.binding_energy_axis, expected_binding_energy_axis) diff --git a/tests/devices/electron_analyser/vgscienta/test_driver_io.py b/tests/devices/electron_analyser/vgscienta/test_driver_io.py index cb8bb90d6b8..a4c4f737189 100644 --- a/tests/devices/electron_analyser/vgscienta/test_driver_io.py +++ b/tests/devices/electron_analyser/vgscienta/test_driver_io.py @@ -49,9 +49,6 @@ async def test_analyser_sets_region_correctly( ) -> None: RE(bps.mv(sim_driver, region), wait=True) - excitation_energy = await sim_driver.energy_source.energy.get_value() - region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) - get_mock_put(sim_driver.region_name).assert_called_once_with(region.name, wait=True) get_mock_put(sim_driver.energy_mode).assert_called_once_with( region.energy_mode, wait=True @@ -62,14 +59,16 @@ async def test_analyser_sets_region_correctly( get_mock_put(sim_driver.lens_mode).assert_called_once_with( region.lens_mode, wait=True ) + excitation_energy = await sim_driver.energy_source.energy.get_value() + ke_region = region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) get_mock_put(sim_driver.low_energy).assert_called_once_with( - region.low_energy, wait=True + ke_region.low_energy, wait=True ) get_mock_put(sim_driver.centre_energy).assert_called_once_with( - region.centre_energy, wait=True + ke_region.centre_energy, wait=True ) get_mock_put(sim_driver.high_energy).assert_called_once_with( - region.high_energy, wait=True + ke_region.high_energy, wait=True ) get_mock_put(sim_driver.pass_energy).assert_called_once_with( region.pass_energy, wait=True @@ -111,8 +110,7 @@ async def test_analyser_sets_region_and_read_configuration_is_correct( prefix = sim_driver.name + "-" excitation_energy = await sim_driver.energy_source.energy.get_value() - region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) - + ke_region = region.switch_energy_mode(EnergyMode.KINETIC, excitation_energy) await assert_configuration( sim_driver, { @@ -120,9 +118,9 @@ async def test_analyser_sets_region_and_read_configuration_is_correct( f"{prefix}energy_mode": partial_reading(region.energy_mode), f"{prefix}acquisition_mode": partial_reading(region.acquisition_mode), f"{prefix}lens_mode": partial_reading(region.lens_mode), - f"{prefix}low_energy": partial_reading(region.low_energy), - f"{prefix}centre_energy": partial_reading(region.centre_energy), - f"{prefix}high_energy": partial_reading(region.high_energy), + f"{prefix}low_energy": partial_reading(ke_region.low_energy), + f"{prefix}centre_energy": partial_reading(ke_region.centre_energy), + f"{prefix}high_energy": partial_reading(ke_region.high_energy), f"{prefix}energy_step": partial_reading(region.energy_step), f"{prefix}pass_energy": partial_reading(region.pass_energy), f"{prefix}slices": partial_reading(region.slices), @@ -186,9 +184,14 @@ async def test_analayser_binding_energy_is_correct( # Check binding energy is correct energy_axis = [1, 2, 3, 4, 5] set_mock_value(sim_driver.energy_axis, np.array(energy_axis, dtype=float)) - is_binding = await sim_driver.energy_mode.get_value() == EnergyMode.BINDING + + # Check binding energy is correct + is_region_binding = region.is_binding_energy() + is_driver_binding = await sim_driver.energy_mode.get_value() == EnergyMode.BINDING + # Catch that driver correctly reflects what region energy mode is. + assert is_region_binding == is_driver_binding expected_binding_energy_axis = np.array( - [excitation_energy - e if is_binding else e for e in energy_axis] + [excitation_energy - e if is_driver_binding else e for e in energy_axis] ) await assert_value(sim_driver.binding_energy_axis, expected_binding_energy_axis)