22from asyncio import wait_for
33from contextlib import nullcontext
44from dataclasses import dataclass
5- from unittest .mock import AsyncMock , MagicMock , call , patch
5+ from unittest .mock import AsyncMock , MagicMock , patch
66
77import numpy as np
88import pytest
@@ -268,6 +268,20 @@ def common_grid_scan_params(request: pytest.FixtureRequest):
268268 return request .getfixturevalue (request .param )
269269
270270
271+ @pytest .fixture (
272+ params = [
273+ ["zebra_fast_grid_scan" , "zebra_grid_scan_params" ],
274+ ["panda_fast_grid_scan" , "panda_grid_scan_params" ],
275+ ],
276+ ids = ["zebra" , "panda" ],
277+ )
278+ def grid_scan_devices_with_params (request : pytest .FixtureRequest ):
279+ return (
280+ request .getfixturevalue (request .param [0 ]),
281+ request .getfixturevalue (request .param [1 ]),
282+ )
283+
284+
271285@pytest .mark .parametrize (
272286 "grid_position, expected" ,
273287 [
@@ -403,33 +417,37 @@ async def test_i02_1_gridscan_has_2d_behaviour(
403417
404418# TODO check all signals, parametrize for Panda
405419async def test_gridscan_prepare_writes_values_and_checks_readback (
406- zebra_fast_grid_scan : ZebraFastGridScan ,
407- zebra_grid_scan_params : ZebraGridScanParams ,
420+ grid_scan_devices_with_params ,
408421):
409- params = zebra_grid_scan_params
422+ grid_scan_device , grid_scan_params = grid_scan_devices_with_params
423+ params = grid_scan_params
410424 for signal , value in {
411- zebra_fast_grid_scan .x_scan_valid : 1 ,
412- zebra_fast_grid_scan .y_scan_valid : 1 ,
413- zebra_fast_grid_scan .z_scan_valid : 1 ,
414- zebra_fast_grid_scan .scan_invalid : 0 ,
425+ grid_scan_device .x_scan_valid : 1 ,
426+ grid_scan_device .y_scan_valid : 1 ,
427+ grid_scan_device .z_scan_valid : 1 ,
428+ grid_scan_device .scan_invalid : 0 ,
415429 }.items ():
416430 set_mock_value (signal , value )
417431
418- signals_to_check = [zebra_fast_grid_scan .x_steps , zebra_fast_grid_scan .y_steps ]
419- for signal in signals_to_check :
432+ signal_names_to_param_names = {
433+ signal .name : p_name
434+ for p_name , signal in grid_scan_device .movable_params .items ()
435+ }
436+ signals = [
437+ grid_scan_device .movable_params [k ] for k in signal_names_to_param_names .values ()
438+ ]
439+ for signal in signals :
420440 set_mock_put_proceeds (signal , False )
421441
422- status = zebra_fast_grid_scan .prepare (params )
442+ status = grid_scan_device .prepare (params )
423443
424- for signal in signals_to_check :
444+ for signal in signals :
425445 put = get_mock_put (signal )
426446 assert not status .done
427447 while True :
428448 if len (put .mock_calls ) > 0 :
429449 put .assert_called_once_with (
430- zebra_grid_scan_params .__dict__ [
431- signal .name [signal .name .rfind ("-" ) + 1 :]
432- ],
450+ grid_scan_params .__dict__ [signal_names_to_param_names [signal .name ]],
433451 wait = True ,
434452 )
435453 break
@@ -441,34 +459,35 @@ async def test_gridscan_prepare_writes_values_and_checks_readback(
441459
442460
443461async def test_gridscan_prepare_checks_validity_after_writes (
444- zebra_fast_grid_scan : ZebraFastGridScan , zebra_grid_scan_params : ZebraGridScanParams
462+ grid_scan_devices_with_params ,
445463):
464+ grid_scan_device , grid_scan_params = grid_scan_devices_with_params
446465 parent = MagicMock ()
447466 expected_signals_to_set = {}
448467
449- for key in zebra_grid_scan_params .__dict__ .keys ():
450- if signal := getattr (zebra_fast_grid_scan , key , None ):
468+ for key in grid_scan_params .__dict__ .keys ():
469+ if signal := getattr (grid_scan_device , key , None ):
451470 expected_signals_to_set [key ] = signal
452471
453472 for key , signal in expected_signals_to_set .items ():
454473 parent .attach_mock (get_mock_put (signal ), key )
455474
456475 checked_signals = {
457- zebra_fast_grid_scan .x_scan_valid : 1 ,
458- zebra_fast_grid_scan .y_scan_valid : 1 ,
459- zebra_fast_grid_scan .z_scan_valid : 1 ,
460- zebra_fast_grid_scan .scan_invalid : 0 ,
476+ grid_scan_device .x_scan_valid : 1 ,
477+ grid_scan_device .y_scan_valid : 1 ,
478+ grid_scan_device .z_scan_valid : 1 ,
479+ grid_scan_device .scan_invalid : 0 ,
461480 }
462481 for signal , expected_value in checked_signals .items ():
463482 set_mock_value (signal , 0 if expected_value else 1 )
464483
465- status = zebra_fast_grid_scan .prepare (zebra_grid_scan_params )
466- await asyncio .sleep (0.1 )
467- assert not status .done
484+ status = grid_scan_device .prepare (grid_scan_params )
468485 for key in expected_signals_to_set :
469- parent .assert_has_calls (
470- [getattr (call , key )(zebra_grid_scan_params .__dict__ [key ], wait = True )]
471- )
486+ mock_put = getattr (parent , key )
487+ while len (mock_put .mock_calls ) == 0 :
488+ await asyncio .sleep (0.1 )
489+ mock_put .assert_called_with (grid_scan_params .__dict__ [key ], wait = True )
490+ assert not status .done
472491
473492 for signal , expected_value in checked_signals .items ():
474493 set_mock_value (signal , expected_value )
@@ -477,22 +496,24 @@ async def test_gridscan_prepare_checks_validity_after_writes(
477496
478497
479498async def test_gridscan_prepare_times_out_for_validity_check (
480- zebra_fast_grid_scan : ZebraFastGridScan , zebra_grid_scan_params : ZebraGridScanParams
499+ grid_scan_devices_with_params ,
481500):
501+ grid_scan_device , grid_scan_params = grid_scan_devices_with_params
482502 checked_signals = {
483- zebra_fast_grid_scan .x_scan_valid : 1 ,
484- zebra_fast_grid_scan .y_scan_valid : 1 ,
485- zebra_fast_grid_scan .z_scan_valid : 1 ,
486- zebra_fast_grid_scan .scan_invalid : 0 ,
503+ grid_scan_device .x_scan_valid : 1 ,
504+ grid_scan_device .y_scan_valid : 1 ,
505+ grid_scan_device .z_scan_valid : 1 ,
506+ grid_scan_device .scan_invalid : 0 ,
487507 }
508+ device_name = grid_scan_device .name
488509 for signal , expected_value in checked_signals .items ():
489- if signal .name != "fake_FGS -scan_invalid" :
510+ if signal .name != f" { device_name } -scan_invalid" :
490511 set_mock_value (signal , 0 if expected_value else 1 )
491512
492- status = zebra_fast_grid_scan .prepare (zebra_grid_scan_params )
513+ status = grid_scan_device .prepare (grid_scan_params )
493514
494515 with pytest .raises (
495516 TimeoutError ,
496- match = "fake_FGS -x_scan_valid didn't match 1 in 0.5s, last value 0.0" ,
517+ match = f" { device_name } -x_scan_valid didn't match 1 in 0.5s, last value 0.0" ,
497518 ):
498519 await status
0 commit comments