diff --git a/canopen/pdo/__init__.py b/canopen/pdo/__init__.py index 533309f8..9729934d 100644 --- a/canopen/pdo/__init__.py +++ b/canopen/pdo/__init__.py @@ -24,17 +24,19 @@ class PDO(PdoBase): :param tpdo: TPDO object holding the Transmit PDO mappings """ - def __init__(self, node, rpdo, tpdo): + def __init__(self, node, rpdo: PdoBase, tpdo: PdoBase): super(PDO, self).__init__(node) self.rx = rpdo.map self.tx = tpdo.map - self.map = {} - # the object 0x1A00 equals to key '1' so we remove 1 from the key + self.map = PdoMaps(0, 0, self) + # Combine RX and TX entries, but only via mapping parameter index. Relative index + # numbers would be ambiguous. + # The object 0x1A00 equals to key '1' so we remove 1 from the key for key, value in self.rx.items(): - self.map[0x1A00 + (key - 1)] = value + self.map.maps[self.rx.map_offset + (key - 1)] = value for key, value in self.tx.items(): - self.map[0x1600 + (key - 1)] = value + self.map.maps[self.tx.map_offset + (key - 1)] = value class RPDO(PdoBase): diff --git a/canopen/pdo/base.py b/canopen/pdo/base.py index 2e335e54..b050206b 100644 --- a/canopen/pdo/base.py +++ b/canopen/pdo/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import binascii +import contextlib import logging import math import threading @@ -33,24 +34,28 @@ class PdoBase(Mapping): def __init__(self, node: Union[LocalNode, RemoteNode]): self.network: canopen.network.Network = canopen.network._UNINITIALIZED_NETWORK - self.map: Optional[PdoMaps] = None + self.map: PdoMaps # must initialize in derived classes self.node: Union[LocalNode, RemoteNode] = node def __iter__(self): return iter(self.map) - def __getitem__(self, key): - if isinstance(key, int) and (0x1A00 <= key <= 0x1BFF or # By TPDO ID (512) - 0x1600 <= key <= 0x17FF or # By RPDO ID (512) - 0 < key <= 512): # By PDO Index - return self.map[key] - else: - for pdo_map in self.map.values(): - try: - return pdo_map[key] - except KeyError: - # ignore if one specific PDO does not have the key and try the next one - continue + def __getitem__(self, key: Union[int, str]): + if isinstance(key, int): + if key == 0: + raise KeyError("PDO index zero requested for 1-based sequence") + if ( + 0 < key <= 512 # By PDO Index + or 0x1600 <= key <= 0x17FF # By RPDO ID (512) + or 0x1A00 <= key <= 0x1BFF # By TPDO ID (512) + ): + return self.map[key] + for pdo_map in self.map.values(): + try: + return pdo_map[key] + except KeyError: + # ignore if one specific PDO does not have the key and try the next one + continue raise KeyError(f"PDO: {key} was not found in any map") def __len__(self): @@ -140,10 +145,10 @@ def stop(self): pdo_map.stop() -class PdoMaps(Mapping): +class PdoMaps(Mapping[int, 'PdoMap']): """A collection of transmit or receive maps.""" - def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None): + def __init__(self, com_offset: int, map_offset: int, pdo_node: PdoBase, cob_base=None): """ :param com_offset: :param map_offset: @@ -151,6 +156,11 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None): :param cob_base: """ self.maps: dict[int, PdoMap] = {} + self.com_offset = com_offset + self.map_offset = map_offset + if not com_offset and not map_offset: + # Skip generating entries without parameter index offsets + return for map_no in range(512): if com_offset + map_no in pdo_node.node.object_dictionary: new_map = PdoMap( @@ -163,7 +173,9 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None): self.maps[map_no + 1] = new_map def __getitem__(self, key: int) -> PdoMap: - return self.maps[key] + with contextlib.suppress(KeyError): + return self.maps[key] + return self.maps[key + 1 - self.map_offset] def __iter__(self) -> Iterator[int]: return iter(self.maps) diff --git a/test/test_pdo.py b/test/test_pdo.py index b8bb0599..f07406a4 100644 --- a/test/test_pdo.py +++ b/test/test_pdo.py @@ -50,15 +50,39 @@ def test_pdo_getitem(self): self.assertEqual(node.tpdo[1]['BOOLEAN value 2'].raw, True) # Test different types of access - self.assertEqual(node.pdo[0x1600]['INTEGER16 value'].raw, -3) - self.assertEqual(node.pdo['INTEGER16 value'].raw, -3) - self.assertEqual(node.pdo.tx[1]['INTEGER16 value'].raw, -3) - self.assertEqual(node.pdo[0x2001].raw, -3) - self.assertEqual(node.tpdo[0x2001].raw, -3) - self.assertEqual(node.pdo[0x2002].raw, 0xf) - self.assertEqual(node.pdo['0x2002'].raw, 0xf) - self.assertEqual(node.tpdo[0x2002].raw, 0xf) - self.assertEqual(node.pdo[0x1600][0x2002].raw, 0xf) + by_mapping_record = node.pdo[0x1A00] + self.assertIsInstance(by_mapping_record, canopen.pdo.PdoMap) + self.assertEqual(by_mapping_record['INTEGER16 value'].raw, -3) + self.assertIs(node.tpdo[0x1A00], by_mapping_record) + by_object_name = node.pdo['INTEGER16 value'] + self.assertIsInstance(by_object_name, canopen.pdo.PdoVariable) + self.assertIs(by_object_name.od, node.object_dictionary['INTEGER16 value']) + self.assertEqual(by_object_name.raw, -3) + by_pdo_index = node.pdo.tx[1] + self.assertIs(by_pdo_index, by_mapping_record) + by_object_index = node.pdo[0x2001] + self.assertIsInstance(by_object_index, canopen.pdo.PdoVariable) + self.assertIs(by_object_index, by_object_name) + by_object_index_tpdo = node.tpdo[0x2001] + self.assertIs(by_object_index_tpdo, by_object_name) + by_object_index = node.pdo[0x2002] + self.assertEqual(by_object_index.raw, 0xf) + self.assertIs(node.pdo['0x2002'], by_object_index) + self.assertIs(node.tpdo[0x2002], by_object_index) + self.assertIs(node.pdo[0x1A00][0x2002], by_object_index) + + self.assertRaises(KeyError, lambda: node.pdo[0]) + self.assertRaises(KeyError, lambda: node.tpdo[0]) + self.assertRaises(KeyError, lambda: node.pdo['DOES NOT EXIST']) + + def test_pdo_maps_iterate(self): + node = self.node + self.assertEqual(len(node.pdo), sum(1 for _ in node.pdo)) + self.assertEqual(len(node.tpdo), sum(1 for _ in node.tpdo)) + self.assertEqual(len(node.rpdo), sum(1 for _ in node.rpdo)) + + pdo = node.tpdo[1] + self.assertEqual(len(pdo), sum(1 for _ in pdo)) def test_pdo_save(self): self.node.tpdo.save()