diff --git a/sbot/arduino.py b/sbot/arduino.py index faac7721..c6fd9a87 100644 --- a/sbot/arduino.py +++ b/sbot/arduino.py @@ -358,11 +358,19 @@ def analog_value(self) -> float: ADC_MIN = 0 self._check_if_disabled() - if self.mode not in ANALOG_READ_MODES: - raise IOError(f'Analog read is not supported in {self.mode}') if not self._supports_analog: - raise IOError('Pin does not support analog read') - response = self._serial.query(f'PIN:{self._index}:ANALOG:GET?') + raise IOError(f'Analog read is not supported on pin {self._index}') + + # Combine the mode and response queries into a single pipeline + mode, response = self._serial.query_multi([ + f'PIN:{self._index}:MODE:GET?', + f'PIN:{self._index}:ANALOG:GET?', + ]) + mode = GPIOPinMode(mode) + + if mode not in ANALOG_READ_MODES: + raise IOError(f'Analog read is not supported in {self.mode}') + # map the response from the ADC range to the voltage range return map_to_float(int(response), ADC_MIN, ADC_MAX, 0.0, 5.0) diff --git a/sbot/serial_wrapper.py b/sbot/serial_wrapper.py index 38333f9d..e5001317 100644 --- a/sbot/serial_wrapper.py +++ b/sbot/serial_wrapper.py @@ -6,6 +6,7 @@ """ from __future__ import annotations +import itertools import logging import sys import threading @@ -122,47 +123,102 @@ def stop(self) -> None: """ self._disconnect() + def _connect_if_needed(self) -> None: + if not self.serial.is_open: + if not self._connect(): + # If the serial port cannot be opened raise an error, + # this will be caught by the retry decorator + raise BoardDisconnectionError(( + f'Connection to board {self.identity.board_type}:' + f'{self.identity.asset_tag} could not be established', + )) + @retry(times=3, exceptions=(BoardDisconnectionError, UnicodeDecodeError)) - def query(self, data: str) -> str: + def query_multi(self, commands: list[str]) -> list[str]: """ Send a command to the board and return the response. - This method will automatically reconnect to the board and retry the command + This method will automatically reconnect to the board and retry the commands up to 3 times on serial errors. - :param data: The data to write to the board. + :param commands: The commands to write to the board. :raises BoardDisconnectionError: If the serial connection fails during the transaction, including failing to respond to the command. - :return: The response from the board with the trailing newline removed. + :return: The responses from the board with the trailing newlines removed. """ + # Verify no command has a newline in it, and build a command `bytes` from the + # list of commands + encoded_commands: list[bytes] = [] + invalid_commands: list[tuple[str, str]] = [] + + for command in commands: + if '\n' in command: + invalid_commands.append(("contains newline", command)) + else: + try: + byte_form = command.encode(encoding='utf-8') + except UnicodeEncodeError as e: + invalid_commands.append((str(e), command)) + else: + encoded_commands.append(byte_form) + encoded_commands.append(b'\n') + + if invalid_commands: + invalid_commands.sort() + + invalid_command_groups = dict(itertools.groupby( + invalid_commands, + key=lambda x: x[0], + )) + + error_message = "\n".join( + ["Invalid commands:"] + + [ + f" {reason}: " + ", ".join( + repr(command) + for _, command in grouped_commands + ) + for reason, grouped_commands in invalid_command_groups.items() + ], + ) + raise ValueError(error_message) + + full_commands = b''.join(encoded_commands) + with self._lock: - if not self.serial.is_open: - if not self._connect(): - # If the serial port cannot be opened raise an error, - # this will be caught by the retry decorator - raise BoardDisconnectionError(( - f'Connection to board {self.identity.board_type}:' - f'{self.identity.asset_tag} could not be established', - )) + # If the serial port is not open, try to connect + self._connect_if_needed() # TODO: Write me + # Contain all the serial IO in a try-catch; on error, disconnect and raise an error try: - logger.log(TRACE, f'Serial write - {data!r}') - cmd = data + '\n' - self.serial.write(cmd.encode()) - - response = self.serial.readline() - try: - response_str = response.decode().rstrip('\n') - except UnicodeDecodeError as e: - logger.warning( - f"Board {self.identity.board_type}:{self.identity.asset_tag} " - f"returned invalid characters: {response!r}") - raise e - logger.log( - TRACE, f'Serial read - {response_str!r}') - - if b'\n' not in response: - # If readline times out no error is raised, it returns an incomplete string + # Send the commands to the board + self.serial.write(full_commands) + + # Log the commands + for command in commands: + logger.log(TRACE, f"Serial write - {command!r}") + + # Read as many lines as there are commands + responses_binary = [ + self.serial.readline() + for _ in range(len(commands)) + ] + + # Log the responses. For backwards compatibility reasons, we decode + # these separately here before any error processing, so that the + # logs are correct even if an error occurs. + for response_binary in responses_binary: + response_decoded = response_binary.decode( + "utf-8", + errors="replace", + ).rstrip('\n') + logger.log(TRACE, f"Serial read - {response_decoded!r}") + + # Check all responses have a trailing newline (an incomplete + # response will not). + # This is within the lock and try-catch to ensure the serial port + # is closed on error. + if not all(response.endswith(b'\n') for response in responses_binary): logger.warning(( f'Connection to board {self.identity.board_type}:' f'{self.identity.asset_tag} timed out waiting for response' @@ -176,15 +232,51 @@ def query(self, data: str) -> str: 'disconnected during transaction' )) - if response_str.startswith('NACK'): - _, error_msg = response_str.split(':', maxsplit=1) - logger.error(( - f'Board {self.identity.board_type}:{self.identity.asset_tag} ' - f'returned NACK on write command: {error_msg}' - )) - raise RuntimeError(error_msg) + # Decode all the responses as UTF-8 + try: + responses_decoded = [ + response.decode("utf-8").rstrip('\n') + for response in responses_binary + ] + except UnicodeDecodeError as e: + logger.warning( + f"Board {self.identity.board_type}:{self.identity.asset_tag} " + f"returned invalid characters: {responses_binary!r}") + raise e + + # Collect any NACK responses; if any, raise an error + nack_prefix = 'NACK:' + nack_responses = [ + response + for response in responses_decoded + if response.startswith(nack_prefix) + ] + + if nack_responses: + errors = [response[len(nack_prefix):] for response in nack_responses] + # We can't use exception groups due to needing to support Python 3.8 + raise ( + RuntimeError(errors[0]) + if len(errors) == 1 + else RuntimeError("Multiple errors: " + ", ".join(errors)) + ) + + # Return the list of responses + return responses_decoded + + def query(self, data: str) -> str: + """ + Send a command to the board and return the response. - return response_str + This method will automatically reconnect to the board and retry the command + up to 3 times on serial errors. + + :param data: The data to write to the board. + :raises BoardDisconnectionError: If the serial connection fails during the transaction, + including failing to respond to the command. + :return: The response from the board with the trailing newline removed. + """ + return self.query_multi([data])[0] def write(self, data: str) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index fe0aa435..8b09b256 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,8 @@ def query(self, request: str) -> str: """ # Assert that we have not run out of responses # and that the request is the next one we expect - assert self.request_index < len(self.responses) + if self.request_index >= len(self.responses): + raise AssertionError(f"Unexpected request: {request}") assert request == self.responses[self.request_index][0] # Fetch the response and increment the request index @@ -73,6 +74,10 @@ def query(self, request: str) -> str: self.request_index += 1 return response + def query_multi(self, commands: list[str]) -> list[str]: + """Send multiple commands and return the responses.""" + return [self.query(command) for command in commands] + def write(self, request: str) -> None: """Send a command without waiting for a response.""" _ = self.query(request) diff --git a/tests/test_arduino.py b/tests/test_arduino.py index b348c187..c238e59d 100644 --- a/tests/test_arduino.py +++ b/tests/test_arduino.py @@ -6,7 +6,7 @@ from __future__ import annotations import re -from typing import NamedTuple +from typing import NamedTuple, Generator import pytest @@ -25,7 +25,7 @@ class MockArduino(NamedTuple): @pytest.fixture -def arduino_serial(monkeypatch) -> None: +def arduino_serial(monkeypatch) -> Generator[MockArduino, None, None]: serial_wrapper = MockSerialWrapper([ ("*IDN?", "Student Robotics:Arduino:X:2.0"), # Called by Arduino.__init__ ]) @@ -84,83 +84,136 @@ def test_arduino_ultrasound(arduino_serial: MockArduino) -> None: arduino.ultrasound_measure(25, 2) -def test_arduino_pins(arduino_serial: MockArduino) -> None: - """ - Test the arduino pins properties and methods. - - This test uses the mock serial wrapper to simulate the arduino. - """ +@pytest.mark.parametrize("pin,expected,command,response", [ + (2, GPIOPinMode.OUTPUT, "PIN:2:MODE:GET?", "OUTPUT"), + (10, GPIOPinMode.INPUT_PULLUP, "PIN:10:MODE:GET?", "INPUT_PULLUP"), + (AnalogPins.A0, GPIOPinMode.INPUT, "PIN:14:MODE:GET?", "INPUT"), +]) +def test_arduino_get_pin_mode( + arduino_serial: MockArduino, + pin: int, + expected: GPIOPinMode, + command: str, + response: str, +) -> None: arduino = arduino_serial.arduino_board arduino_serial.serial_wrapper._add_responses([ - ("PIN:2:MODE:GET?", "OUTPUT"), - ("PIN:10:MODE:GET?", "INPUT_PULLUP"), - ("PIN:14:MODE:GET?", "INPUT"), - ("PIN:2:MODE:SET:OUTPUT", "ACK"), - ("PIN:10:MODE:SET:INPUT_PULLUP", "ACK"), - ("PIN:14:MODE:SET:INPUT", "ACK"), - # PIN::ANALOG:GET? + (command, response), ]) - # Test that we can get the mode of a pin - assert arduino.pins[2].mode == GPIOPinMode.OUTPUT - assert arduino.pins[10].mode == GPIOPinMode.INPUT_PULLUP - assert arduino.pins[AnalogPins.A0].mode == GPIOPinMode.INPUT + assert arduino.pins[pin].mode == expected + + +def test_arduino_set_invalid_pin_mode(arduino_serial: MockArduino) -> None: + arduino = arduino_serial.arduino_board with pytest.raises(IOError): arduino.pins[2].mode = 1 - # Test that we can set the mode of a pin - arduino.pins[2].mode = GPIOPinMode.OUTPUT - arduino.pins[10].mode = GPIOPinMode.INPUT_PULLUP - arduino.pins[AnalogPins.A0].mode = GPIOPinMode.INPUT - # Test that we can get the digital value of a pin +@pytest.mark.parametrize("pin,mode,command,response", [ + (2, GPIOPinMode.OUTPUT, "PIN:2:MODE:SET:OUTPUT", "ACK"), + (10, GPIOPinMode.INPUT_PULLUP, "PIN:10:MODE:SET:INPUT_PULLUP", "ACK"), + (AnalogPins.A0, GPIOPinMode.INPUT, "PIN:14:MODE:SET:INPUT", "ACK"), +]) +def test_arduino_set_pin_mode( + arduino_serial: MockArduino, + pin: int, + mode: GPIOPinMode, + command: str, + response: str, +) -> None: + arduino = arduino_serial.arduino_board arduino_serial.serial_wrapper._add_responses([ - ("PIN:2:MODE:GET?", "OUTPUT"), # mode is read before digital value - ("PIN:2:DIGITAL:GET?", "1"), - ("PIN:10:MODE:GET?", "INPUT_PULLUP"), - ("PIN:10:DIGITAL:GET?", "0"), - ("PIN:14:MODE:GET?", "INPUT"), - ("PIN:14:DIGITAL:GET?", "1"), + (command, response), ]) - assert arduino.pins[2].digital_value is True - assert arduino.pins[10].digital_value is False - assert arduino.pins[AnalogPins.A0].digital_value is True - # Test that we can set the digital value of a pin + arduino.pins[pin].mode = mode + + +@pytest.mark.parametrize("pin,expected,command,response,mode_command,mode_response", [ + (2, True, "PIN:2:DIGITAL:GET?", "1", "PIN:2:MODE:GET?", "OUTPUT"), + (10, False, "PIN:10:DIGITAL:GET?", "0", "PIN:10:MODE:GET?", "INPUT_PULLUP"), + (AnalogPins.A0, True, "PIN:14:DIGITAL:GET?", "1", "PIN:14:MODE:GET?", "INPUT"), +]) +def test_arduino_get_digital_value( + arduino_serial: MockArduino, + pin: int, + expected: bool, + command: str, + response: str, + mode_command: str, + mode_response: str, +) -> None: + arduino = arduino_serial.arduino_board arduino_serial.serial_wrapper._add_responses([ - ("PIN:2:MODE:GET?", "OUTPUT"), # mode is read before digital value - ("PIN:2:DIGITAL:SET:1", "ACK"), - ("PIN:2:MODE:GET?", "OUTPUT"), - ("PIN:2:DIGITAL:SET:0", "ACK"), - ("PIN:10:MODE:GET?", "INPUT_PULLUP"), - ("PIN:10:MODE:GET?", "INPUT_PULLUP"), - ("PIN:14:MODE:GET?", "INPUT"), - ("PIN:14:MODE:GET?", "INPUT"), + (mode_command, mode_response), + (command, response), ]) - arduino.pins[2].digital_value = True - arduino.pins[2].digital_value = False - with pytest.raises(IOError, match=r"Digital write is not supported.*"): - arduino.pins[10].digital_value = False + + assert arduino.pins[pin].digital_value == expected + + +@pytest.mark.parametrize("pin,value,command,response,mode_command,mode_response", [ + (2, True, "PIN:2:DIGITAL:SET:1", "ACK", "PIN:2:MODE:GET?", "OUTPUT"), + (2, False, "PIN:2:DIGITAL:SET:0", "ACK", "PIN:2:MODE:GET?", "OUTPUT"), +]) +def test_arduino_set_digital_value( + arduino_serial: MockArduino, + pin: int, + value: bool, + command: str, + response: str, + mode_command: str, + mode_response: str, +) -> None: + arduino = arduino_serial.arduino_board + arduino_serial.serial_wrapper._add_responses([ + (mode_command, mode_response), + (command, response), + ]) + + arduino.pins[pin].digital_value = value + + +@pytest.mark.parametrize("pin,mode_command,mode_response", [ + (10, "PIN:10:MODE:GET?", "INPUT_PULLUP"), + (AnalogPins.A0, "PIN:14:MODE:GET?", "INPUT"), +]) +def test_arduino_set_invalid_digital_value( + arduino_serial: MockArduino, + pin: int, + mode_command: str, + mode_response: str, +) -> None: + arduino = arduino_serial.arduino_board + arduino_serial.serial_wrapper._add_responses([ + (mode_command, mode_response), + (mode_command, mode_response), + ]) + with pytest.raises(IOError, match=r"Digital write is not supported.*"): - arduino.pins[AnalogPins.A0].digital_value = True + arduino.pins[pin].digital_value = False + - # Test that we can get the analog value of a pin +def test_arduino_get_analog_value(arduino_serial: MockArduino) -> None: + arduino = arduino_serial.arduino_board arduino_serial.serial_wrapper._add_responses([ - ("PIN:2:MODE:GET?", "OUTPUT"), # mode is read before analog value - ("PIN:2:MODE:GET?", "OUTPUT"), - ("PIN:10:MODE:GET?", "INPUT"), ("PIN:14:MODE:GET?", "INPUT"), ("PIN:14:ANALOG:GET?", "1000"), ]) - with pytest.raises(IOError, match=r"Analog read is not supported.*"): - arduino.pins[2].analog_value - with pytest.raises(IOError, match=r"Pin does not support analog read"): - arduino.pins[10].analog_value + # 4.888 = round((5 / 1023) * 1000, 3) assert arduino.pins[AnalogPins.A0].analog_value == 4.888 +def test_arduino_get_invalid_analog_value_from_digital_only_pin(arduino_serial: MockArduino) -> None: + arduino = arduino_serial.arduino_board + + with pytest.raises(IOError, match=r".*not support.*"): + arduino.pins[2].analog_value + + def test_invalid_properties(arduino_serial: MockArduino) -> None: """ Test that settng invalid properties raise an AttributeError.