diff --git a/homewizard_energy/v2/__init__.py b/homewizard_energy/v2/__init__.py index 8aa4698..2868ccf 100644 --- a/homewizard_energy/v2/__init__.py +++ b/homewizard_energy/v2/__init__.py @@ -5,20 +5,13 @@ import asyncio import logging import ssl -from collections.abc import Callable, Coroutine +from collections.abc import Coroutine from http import HTTPStatus -from typing import Any, TypeVar +from typing import Any, Callable, TypeVar +import aiohttp import async_timeout import backoff -from aiohttp.client import ( - ClientError, - ClientResponseError, - ClientSession, - ClientTimeout, - TCPConnector, -) -from aiohttp.hdrs import METH_DELETE, METH_GET, METH_POST, METH_PUT from homewizard_energy.errors import ( DisabledError, @@ -30,6 +23,7 @@ from .cacert import CACERT from .models import Device, Measurement, System, SystemUpdate +from .websocket import Websocket _LOGGER = logging.getLogger(__name__) @@ -54,9 +48,10 @@ async def wrapper(self, *args, **kwargs) -> T: class HomeWizardEnergyV2: """Communicate with a HomeWizard Energy device.""" - _clientsession: ClientSession | None = None + _clientsession: aiohttp.ClientSession | None = None _close_clientsession: bool = False _request_timeout: int = 10 + _websocket: Websocket | None = None def __init__( self, @@ -89,6 +84,16 @@ def host(self) -> str: """ return self._host + @property + def websocket(self) -> Websocket: + """Return the websocket object. + + Create a new websocket object if it does not exist. + """ + if self._websocket is None: + self._websocket = Websocket(self) + return self._websocket + @authorized_method async def device(self) -> Device: """Return the device object.""" @@ -115,7 +120,7 @@ async def system( if update is not None: data = update.as_dict() status, response = await self._request( - "/api/system", method=METH_PUT, data=data + "/api/system", method=aiohttp.hdrs.METH_PUT, data=data ) else: @@ -133,7 +138,7 @@ async def identify( self, ) -> bool: """Send identify request.""" - await self._request("/api/system/identify", method=METH_PUT) + await self._request("/api/system/identify", method=aiohttp.hdrs.METH_PUT) return True async def get_token( @@ -142,7 +147,7 @@ async def get_token( ) -> str: """Get authorization token from device.""" status, response = await self._request( - "/api/user", method=METH_POST, data={"name": f"local/{name}"} + "/api/user", method=aiohttp.hdrs.METH_POST, data={"name": f"local/{name}"} ) if status == HTTPStatus.FORBIDDEN: @@ -168,7 +173,7 @@ async def delete_token( """Delete authorization token from device.""" status, response = await self._request( "/api/user", - method=METH_DELETE, + method=aiohttp.hdrs.METH_DELETE, data={"name": name} if name is not None else None, ) @@ -180,11 +185,34 @@ async def delete_token( if name is None: self._token = None - async def _get_clientsession(self) -> ClientSession: + @property + def token(self) -> str | None: + """Return the token of the device. + + Returns: + token: The used token + + """ + return self._token + + @property + def request_timeout(self) -> int: + """Return the request timeout of the device. + + Returns: + request_timeout: The used request timeout + + """ + return self._request_timeout + + async def get_clientsession(self) -> aiohttp.ClientSession: """ Get a clientsession that is tuned for communication with the HomeWizard Energy Device """ + if self._clientsession is not None: + return self._clientsession + def _build_ssl_context() -> ssl.SSLContext: context = ssl.create_default_context(cadata=CACERT) if self._identifier is not None: @@ -199,26 +227,28 @@ def _build_ssl_context() -> ssl.SSLContext: loop = asyncio.get_running_loop() context = await loop.run_in_executor(None, _build_ssl_context) - connector = TCPConnector( + connector = aiohttp.TCPConnector( enable_cleanup_closed=True, ssl=context, limit_per_host=1, ) - return ClientSession( - connector=connector, timeout=ClientTimeout(total=self._request_timeout) + self._clientsession = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self._request_timeout), ) + return self._clientsession + @backoff.on_exception(backoff.expo, RequestError, max_tries=5, logger=None) async def _request( - self, path: str, method: str = METH_GET, data: object = None + self, path: str, method: str = aiohttp.hdrs.METH_GET, data: object = None ) -> Any: """Make a request to the API.""" - if self._clientsession is None: - self._clientsession = await self._get_clientsession() + _clientsession = await self.get_clientsession() - if self._clientsession.closed: + if _clientsession.closed: # Avoid runtime errors when connection is closed. # This solves an issue when updates were scheduled and clientsession was closed. return None @@ -235,7 +265,7 @@ async def _request( try: async with async_timeout.timeout(self._request_timeout): - resp = await self._clientsession.request( + resp = await _clientsession.request( method, url, json=data, @@ -249,7 +279,7 @@ async def _request( raise RequestError( f"Timeout occurred while connecting to the HomeWizard Energy device at {self.host}" ) from exception - except (ClientError, ClientResponseError) as exception: + except (aiohttp.ClientError, aiohttp.ClientResponseError) as exception: raise RequestError( f"Error occurred while communicating with the HomeWizard Energy device at {self.host}" ) from exception @@ -276,6 +306,7 @@ async def close(self) -> None: _LOGGER.debug("Closing clientsession") if self._clientsession is not None: await self._clientsession.close() + self._clientsession = None async def __aenter__(self) -> HomeWizardEnergyV2: """Async enter. diff --git a/homewizard_energy/v2/const.py b/homewizard_energy/v2/const.py index 520289d..d1884f4 100644 --- a/homewizard_energy/v2/const.py +++ b/homewizard_energy/v2/const.py @@ -1,6 +1,16 @@ """Constants for HomeWizard Energy.""" -SUPPORTED_API_VERSION = "v1" +from enum import StrEnum + +SUPPORTED_API_VERSION = "2.0.0" SUPPORTS_STATE = ["HWE-SKT"] SUPPORTS_IDENTIFY = ["HWE-SKT", "HWE-P1", "HWE-WTR"] + + +class WebsocketTopic(StrEnum): + """Websocket topics.""" + + DEVICE = "device" + MEASUREMENT = "measurement" + SYSTEM = "system" diff --git a/homewizard_energy/v2/websocket.py b/homewizard_energy/v2/websocket.py new file mode 100644 index 0000000..19688eb --- /dev/null +++ b/homewizard_energy/v2/websocket.py @@ -0,0 +1,188 @@ +"""Websocket client for HomeWizard Energy API.""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Callable + +import aiohttp + +from homewizard_energy.errors import UnauthorizedError + +from .const import WebsocketTopic +from .models import Device, Measurement, System + +if TYPE_CHECKING: + from . import HomeWizardEnergyV2 + +OnMessageCallbackType = Callable[[str, Device | Measurement | System], None] + +_LOGGER = logging.getLogger(__name__) + + +class Websocket: + """Websocket client for HomeWizard Energy API.""" + + _connect_lock: asyncio.Lock = asyncio.Lock() + _ws_connection: aiohttp.ClientWebSocketResponse | None = None + _ws_subscriptions: list[tuple[str, OnMessageCallbackType]] = [] + + _ws_authenticated: bool = False + + def __init__(self, parent: "HomeWizardEnergyV2"): + self._parent = parent + + async def connect(self) -> bool: + """Connect the websocket.""" + + if self._connect_lock.locked(): + _LOGGER.debug("Another connect is already happening") + return False + try: + await asyncio.wait_for(self._connect_lock.acquire(), timeout=0.1) + except asyncio.TimeoutError: + _LOGGER.debug("Failed to get connection lock") + + start_event = asyncio.Event() + _LOGGER.debug("Scheduling WS connect...") + asyncio.create_task(self._websocket_loop(start_event)) + + try: + await asyncio.wait_for( + start_event.wait(), timeout=self._parent.request_timeout + ) + except asyncio.TimeoutError: + _LOGGER.warning("Timed out while waiting for Websocket to connect") + await self.disconnect() + + self._connect_lock.release() + if self._ws_connection is None: + _LOGGER.debug("Failed to connect to Websocket") + return False + _LOGGER.debug("Connected to Websocket successfully") + return True + + async def disconnect(self) -> None: + """Disconnect the websocket.""" + if self._ws_connection is not None and not self._ws_connection.closed: + await self._ws_connection.close() + self._ws_connection = None + + def subscribe( + self, topic: WebsocketTopic, ws_callback: OnMessageCallbackType + ) -> Callable[[], None]: + """ + Subscribe to raw websocket messages. + + Returns a callback that will unsubscribe. + """ + + def _unsub_ws_callback() -> None: + self._ws_subscriptions.remove({topic, ws_callback}) + + _LOGGER.debug("Adding subscription: %s, %s", topic, ws_callback) + self._ws_subscriptions.append((topic, ws_callback)) + + if self._ws_connection is not None and self._ws_authenticated: + asyncio.create_task( + self._ws_connection.send_json({"type": "subscribe", "data": topic}) + ) + + return _unsub_ws_callback + + async def _websocket_loop(self, start_event: asyncio.Event) -> None: + _LOGGER.debug("Connecting WS...") + + _clientsession = await self._parent.get_clientsession() + + # catch any and all errors for Websocket so we can clean up correctly + try: + self._ws_connection = await _clientsession.ws_connect( + f"wss://{self._parent.host}/api/ws", ssl=False + ) + start_event.set() + + async for msg in self._ws_connection: + _LOGGER.info("Received message: %s", msg) + if not await self._process_message(msg): + break + except aiohttp.ClientError as e: + _LOGGER.exception("Websocket disconnect error: %s", e) + finally: + _LOGGER.debug("Websocket disconnected") + if self._ws_connection is not None and not self._ws_connection.closed: + await self._ws_connection.close() + self._ws_connection = None + # make sure event does not timeout + start_event.set() + + async def _on_authorization_requested(self, msg_type: str, msg_data: str) -> None: + del msg_type, msg_data + + _LOGGER.info("Authorization requested") + if self._ws_authenticated: + raise UnauthorizedError("Already authenticated") + + await self._ws_connection.send_json( + {"type": "authorization", "data": self._parent.token} + ) + + async def _on_authorized(self, msg_type: str, msg_data: str) -> None: + del msg_type, msg_data + + _LOGGER.info("Authorized") + self._ws_authenticated = True + + # Send subscription requests + print(self._ws_subscriptions) + for topic, _ in self._ws_subscriptions: + _LOGGER.info("Sending subscription request for %s", topic) + await self._ws_connection.send_json({"type": "subscribe", "data": topic}) + + async def _process_message(self, msg: aiohttp.WSMessage) -> bool: + if msg.type == aiohttp.WSMsgType.ERROR: + raise ValueError(f"Error from Websocket: {msg.data}") + + _LOGGER.debug("Received message: %s", msg.data) + + if msg.type == aiohttp.WSMsgType.TEXT: + try: + msg = msg.json() + except ValueError as ex: + raise ValueError(f"Invalid JSON received: {msg.data}") from ex + + if "type" not in msg: + raise ValueError(f"Missing 'type' in message: {msg}") + + msg_type = msg.get("type") + msg_data = msg.get("data") + parsed_data = None + + match msg_type: + case "authorization_requested": + await self._on_authorization_requested(msg_type, msg_data) + return True + + case "authorized": + await self._on_authorized(msg_type, msg_data) + return True + + case WebsocketTopic.MEASUREMENT: + parsed_data = Measurement.from_dict(msg_data) + + case WebsocketTopic.SYSTEM: + parsed_data = System.from_dict(msg_data) + + case WebsocketTopic.DEVICE: + parsed_data = Device.from_dict(msg_data) + + if parsed_data is None: + raise ValueError(f"Unknown message type: {msg_type}") + + for topic, callback in self._ws_subscriptions: + if topic == msg_type: + try: + await callback(topic, parsed_data) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error processing websocket message") + + return True