From d196e7f84f00d95abfd4ef0b7ec4b796e1fbee03 Mon Sep 17 00:00:00 2001 From: jeffthibault Date: Fri, 17 Feb 2023 08:14:40 -0500 Subject: [PATCH] refactor relay threads --- nostr/relay.py | 81 +++++++++++++++++++++++------------------- nostr/relay_manager.py | 32 ++++++++++------- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/nostr/relay.py b/nostr/relay.py index 8ab88f8..f780ca1 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -2,7 +2,7 @@ import time from dataclasses import dataclass from queue import Queue -from threading import Lock +from threading import Lock, Thread from typing import Optional from websocket import WebSocketApp from .event import Event @@ -21,7 +21,6 @@ def to_json_object(self) -> dict[str, bool]: - @dataclass class RelayProxyConnectionConfig: host: Optional[str] = None @@ -37,15 +36,13 @@ class Relay: policy: RelayPolicy = RelayPolicy() ssl_options: Optional[dict] = None proxy_config: RelayProxyConnectionConfig = None + error_threshold: int = 5 def __post_init__(self): - self.queue = Queue() + self.outgoing_messages = Queue() self.subscriptions: dict[str, Subscription] = {} self.num_sent_events: int = 0 - self.connected: bool = False - self.reconnect: bool = True self.error_counter: int = 0 - self.error_threshold: int = 0 self.lock: Lock = Lock() self.ws: WebSocketApp = WebSocketApp( self.url, @@ -54,42 +51,55 @@ def __post_init__(self): on_error=self._on_error, on_close=self._on_close ) + self._connection_thread: Thread = None - def connect(self): - self.ws.run_forever( - sslopt=self.ssl_options, - http_proxy_host=self.proxy_config.host if self.proxy_config is not None else None, - http_proxy_port=self.proxy_config.port if self.proxy_config is not None else None, - proxy_type=self.proxy_config.type if self.proxy_config is not None else None, - ) + def connect(self, is_reconnect=False): + if not self.is_connected(): + with self.lock: + self._connection_thread = Thread( + target=self.ws.run_forever, + kwargs={ + "sslopt": self.ssl_options, + "http_proxy_host": self.proxy_config.host if self.proxy_config is not None else None, + "http_proxy_port": self.proxy_config.port if self.proxy_config is not None else None, + "proxy_type": self.proxy_config.type if self.proxy_config is not None else None + }, + name=f"{self.url}-connection" + ) + self._connection_thread.start() + + if not is_reconnect: + Thread( + target=self.outgoing_messages_worker, + name=f"{self.url}-outgoing-messages-worker", + daemon=True + ).start() + time.sleep(1) + def close(self): - self.ws.close() + if self.is_connected(): + self.ws.close() - def check_reconnect(self): - try: - self.close() - except: - pass - self.connected = False - if self.reconnect: - time.sleep(1) - self.connect() + def is_connected(self) -> bool: + with self.lock: + if self._connection_thread is None or not self._connection_thread.is_alive(): + return False + else: + return True def publish(self, message: str): - self.queue.put(message) + self.outgoing_messages.put(message) - def queue_worker(self): + def outgoing_messages_worker(self): while True: - if self.connected: - message = self.queue.get() + if self.is_connected(): + message = self.outgoing_messages.get() try: self.ws.send(message) self.num_sent_events += 1 except: - self.queue.put(message) - else: - time.sleep(0.1) + self.outgoing_messages.put(message) def add_subscription(self, id, filters: Filters): with self.lock: @@ -115,21 +125,18 @@ def to_json_object(self) -> dict: } def _on_open(self, class_obj): - self.connected = True + pass def _on_close(self, class_obj, status_code, message): - self.connected = False + self.error_counter = 0 def _on_message(self, class_obj, message: str): self.message_pool.add_message(message, self.url) def _on_error(self, class_obj, error): - self.connected = False self.error_counter += 1 - if self.error_threshold and self.error_counter > self.error_threshold: - pass - else: - self.check_reconnect() + if self.error_counter > self.error_threshold: + self.close() def _is_valid_message(self, message: str) -> bool: message = message.strip("\n") diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index ffe3868..16bb21f 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -20,11 +20,20 @@ class RelayException(Exception): @dataclass class RelayManager: + connection_monitor_interval_secs: int = 5 + def __post_init__(self): self.relays: dict[str, Relay] = {} self.message_pool: MessagePool = MessagePool() self.lock: Lock = Lock() + threading.Thread( + target=self._relay_connection_monitor, + name="relay-connection-monitor", + daemon=True + ).start() + + def add_relay( self, url: str, @@ -36,19 +45,7 @@ def add_relay( with self.lock: self.relays[url] = relay - - threading.Thread( - target=relay.connect, - name=f"{relay.url}-thread" - ).start() - - threading.Thread( - target=relay.queue_worker, - name=f"{relay.url}-queue", - daemon=True - ).start() - - time.sleep(1) + relay.connect() def remove_relay(self, url: str): with self.lock: @@ -109,3 +106,12 @@ def publish_event(self, event: Event): for relay in self.relays.values(): if relay.policy.should_write: relay.publish(event.to_message()) + + def _relay_connection_monitor(self): + while True: + with self.lock: + for relay in self.relays.values(): + if not relay.is_connected(): + relay.connect(True) + + time.sleep(self.connection_monitor_interval_secs)