diff --git a/.github/workflows/build-deps-ci-action.yml b/.github/workflows/build-deps-ci-action.yml index 767cbaf6f30d..4476e68b02fd 100644 --- a/.github/workflows/build-deps-ci-action.yml +++ b/.github/workflows/build-deps-ci-action.yml @@ -157,7 +157,7 @@ jobs: macos-dependencies: name: MacOS - runs-on: ${{ matrix.arch == 'x86_64' && 'macos-13' || 'macos-14' }} + runs-on: ${{ matrix.arch == 'x86_64' && 'macos-15-intel' || 'macos-14' }} if: ${{ toJSON(fromJSON(inputs.matrix)['macos']) != '[]' }} timeout-minutes: 90 strategy: diff --git a/.github/workflows/build-packages.yml b/.github/workflows/build-packages.yml index d9315d576589..032bd4b61a15 100644 --- a/.github/workflows/build-packages.yml +++ b/.github/workflows/build-packages.yml @@ -294,7 +294,7 @@ jobs: env: PIP_INDEX_URL: https://pypi.org/simple runs-on: - - ${{ matrix.arch == 'arm64' && 'macos-14' || 'macos-13' }} + - ${{ matrix.arch == 'arm64' && 'macos-14' || 'macos-15-intel' }} steps: - name: Check Package Signing Enabled diff --git a/.github/workflows/build-salt-onedir.yml b/.github/workflows/build-salt-onedir.yml index 149c791bb1ae..fe0fa9b2963c 100644 --- a/.github/workflows/build-salt-onedir.yml +++ b/.github/workflows/build-salt-onedir.yml @@ -109,7 +109,7 @@ jobs: matrix: include: ${{ fromJSON(inputs.matrix)['macos'] }} runs-on: - - ${{ matrix.arch == 'arm64' && 'macos-14' || 'macos-13' }} + - ${{ matrix.arch == 'arm64' && 'macos-14' || 'macos-15-intel' }} env: PIP_INDEX_URL: https://pypi.org/simple USE_S3_CACHE: 'false' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 93cacb4e4dc4..e9ddd249e54b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -439,7 +439,7 @@ jobs: with: cache-seed: ${{ needs.prepare-workflow.outputs.cache-seed }} salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" matrix: ${{ toJSON(fromJSON(needs.prepare-workflow.outputs.config)['build-matrix']) }} @@ -456,7 +456,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "onedir" @@ -473,7 +473,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "src" diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 8f3e52f6fbb0..0a7d88064466 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -434,7 +434,7 @@ jobs: with: cache-seed: ${{ needs.prepare-workflow.outputs.cache-seed }} salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" matrix: ${{ toJSON(fromJSON(needs.prepare-workflow.outputs.config)['build-matrix']) }} @@ -451,7 +451,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "onedir" @@ -472,7 +472,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ 
needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "src" diff --git a/.github/workflows/scheduled.yml b/.github/workflows/scheduled.yml index 9016e93280ea..4a51bf3eb1e7 100644 --- a/.github/workflows/scheduled.yml +++ b/.github/workflows/scheduled.yml @@ -482,7 +482,7 @@ jobs: with: cache-seed: ${{ needs.prepare-workflow.outputs.cache-seed }} salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" matrix: ${{ toJSON(fromJSON(needs.prepare-workflow.outputs.config)['build-matrix']) }} @@ -499,7 +499,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "onedir" @@ -516,7 +516,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "src" diff --git a/.github/workflows/staging.yml b/.github/workflows/staging.yml index 4ae63a61dcb1..926f4db03a4b 100644 --- a/.github/workflows/staging.yml +++ b/.github/workflows/staging.yml @@ -466,7 +466,7 @@ jobs: with: cache-seed: ${{ needs.prepare-workflow.outputs.cache-seed }} salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" matrix: ${{ toJSON(fromJSON(needs.prepare-workflow.outputs.config)['build-matrix']) }} @@ -484,7 +484,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "onedir" @@ -506,7 +506,7 @@ jobs: with: salt-version: "${{ needs.prepare-workflow.outputs.salt-version }}" cache-prefix: ${{ needs.prepare-workflow.outputs.cache-seed }} - relenv-version: "0.21.1" + relenv-version: "0.21.2" python-version: "3.11.14" ci-python-version: "3.11" source: "src" diff --git a/cicd/shared-gh-workflows-context.yml b/cicd/shared-gh-workflows-context.yml index 2be09b86da77..35df20024321 100644 --- a/cicd/shared-gh-workflows-context.yml +++ b/cicd/shared-gh-workflows-context.yml @@ -1,6 +1,6 @@ nox_version: "2022.8.7" python_version: "3.11.14" -relenv_version: "0.21.1" +relenv_version: "0.21.2" release_branches: - "3006.x" - "3007.x" diff --git a/salt/channel/client.py b/salt/channel/client.py index 2eec2d3cb17b..830253621704 100644 --- a/salt/channel/client.py +++ b/salt/channel/client.py @@ -21,7 +21,7 @@ import salt.utils.stringutils import salt.utils.verify import salt.utils.versions -from salt.utils.asynchronous import SyncWrapper +from salt.utils.asynchronous import SyncWrapper, aioloop log = logging.getLogger(__name__) @@ -422,7 +422,7 @@ def factory(cls, opts, **kwargs): def __init__(self, opts, transport, auth, io_loop=None): self.opts = opts - self.io_loop = io_loop + self.io_loop = aioloop(io_loop) self.auth = auth try: # This loads or generates the minion's public key. 
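(Illustrative note, not part of the patch: almost every hunk that follows applies the same mechanical recipe — normalize whatever loop object the caller handed in, then swap Tornado scheduling calls for their asyncio equivalents. A minimal sketch of that mapping, using only the stdlib asyncio API; the names here are placeholders:)

    import asyncio

    async def cleanup():
        print("closed")

    def schedule(loop: asyncio.AbstractEventLoop):
        # Tornado -> asyncio translations used throughout this patch:
        #   io_loop.spawn_callback(coro_fn, *a) -> loop.create_task(coro_fn(*a))
        #   io_loop.add_callback(fn, *a)        -> loop.call_soon(fn, *a)
        #   io_loop.add_timeout(n, fn)          -> loop.call_later(n, fn)
        #   (from another thread: loop.call_soon_threadsafe, as the TCP load
        #   balancer below does)
        loop.create_task(cleanup())    # coroutine: wrap it in a task
        loop.call_soon(print, "soon")  # plain callable: run on the next iteration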
@@ -636,7 +636,7 @@ def __enter__(self): return self def __exit__(self, *args): - self.io_loop.spawn_callback(self.close) + self.io_loop.call_soon(self.close) async def __aenter__(self): return self diff --git a/salt/channel/server.py b/salt/channel/server.py index 41d7605ffede..a87ffbf62294 100644 --- a/salt/channel/server.py +++ b/salt/channel/server.py @@ -1156,10 +1156,13 @@ def _publish_daemon(self, **kwargs): payload_handler=self.handle_pool_publish, ) self.pool_puller.start() - self.io_loop.add_callback( - self.transport.publisher, - self.publish_payload, - io_loop=self.io_loop, + # Extract asyncio loop for create_task + aio_loop = salt.utils.asynchronous.aioloop(self.io_loop) + aio_loop.create_task( + self.transport.publisher( + self.publish_payload, + io_loop=self.io_loop, + ) ) # run forever try: diff --git a/salt/crypt.py b/salt/crypt.py index d19d5b75f83b..71585de68bab 100644 --- a/salt/crypt.py +++ b/salt/crypt.py @@ -29,6 +29,7 @@ import salt.channel.client import salt.defaults.exitcodes import salt.payload +import salt.utils.asynchronous import salt.utils.crypt import salt.utils.decorators import salt.utils.event @@ -766,7 +767,12 @@ def __singleton_init__(self, opts, io_loop=None): self.mpub = "minion_master.pub" if not os.path.isfile(self.pub_path): self.get_keys() - self.io_loop = io_loop or tornado.ioloop.IOLoop.current() + if io_loop is None: + self.io_loop = salt.utils.asynchronous.aioloop( + tornado.ioloop.IOLoop.current() + ) + else: + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) key = self.__key(self.opts) # TODO: if we already have creds for this key, lets just re-use if key in AsyncAuth.creds_map: @@ -833,13 +839,13 @@ def authenticate(self, callback=None): else: future = tornado.concurrent.Future() self._authenticate_future = future - self.io_loop.add_callback(self._authenticate) + self.io_loop.create_task(self._authenticate()) if callback is not None: def handle_future(future): response = future.result() - self.io_loop.add_callback(callback, response) + self.io_loop.call_soon(callback, response) future.add_done_callback(handle_future) diff --git a/salt/master.py b/salt/master.py index 09dd05e77aab..31571914b9b0 100644 --- a/salt/master.py +++ b/salt/master.py @@ -17,8 +17,6 @@ import threading import time -import tornado.gen - import salt.acl import salt.auth import salt.channel.server @@ -951,7 +949,7 @@ def start(self): if self.opts.get("cluster_id", None): # Notify the rest of the cluster we're starting. 
ipc_publisher.send_aes_key_event() - self.process_manager.run() + asyncio.run(self.process_manager.run()) def _handle_signals(self, signum, sigframe): # escalate the signals to the process manager @@ -1008,13 +1006,19 @@ async def handle_event(self, package): log.trace("Ignore tag %s", tag) def run(self): - io_loop = tornado.ioloop.IOLoop() + io_loop = asyncio.new_event_loop() + asyncio.set_event_loop(io_loop) with salt.utils.event.get_master_event( self.opts, self.opts["sock_dir"], io_loop=io_loop, listen=True ) as event_bus: event_bus.subscribe("") event_bus.set_event_handler(self.handle_event) - io_loop.start() + try: + io_loop.run_forever() + except (KeyboardInterrupt, SystemExit): + pass + finally: + io_loop.close() class ReqServer(salt.utils.process.SignalHandlingProcess): @@ -1175,16 +1179,18 @@ def __bind(self): """ Bind to the local port """ - self.io_loop = tornado.ioloop.IOLoop() + self.io_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.io_loop) for req_channel in self.req_channels: req_channel.post_fork( self._handle_payload, io_loop=self.io_loop ) # TODO: cleaner? Maybe lazily? try: - self.io_loop.start() + self.io_loop.run_forever() except (KeyboardInterrupt, SystemExit): - # Tornado knows what to do pass + finally: + self.io_loop.close() async def _handle_payload(self, payload): """ diff --git a/salt/metaproxy/deltaproxy.py b/salt/metaproxy/deltaproxy.py index 0c151ba10113..5d9cc8ac18d7 100644 --- a/salt/metaproxy/deltaproxy.py +++ b/salt/metaproxy/deltaproxy.py @@ -5,6 +5,7 @@ import asyncio import concurrent.futures import copy +import functools import logging import os import signal @@ -164,8 +165,13 @@ async def post_master_init(self, master): # Start engines here instead of in the Minion superclass __init__ # This is because we need to inject the __proxy__ variable but # it is not setup until now. - self.io_loop.spawn_callback( - salt.engines.start_engines, self.opts, self.process_manager, proxy=self.proxy + self.io_loop.call_soon( + functools.partial( + salt.engines.start_engines, + self.opts, + self.process_manager, + proxy=self.proxy, + ) ) proxy_init_func_name = f"{fq_proxyname}.init" diff --git a/salt/metaproxy/proxy.py b/salt/metaproxy/proxy.py index d2ca10fec5c2..4d8fc10c4e9c 100644 --- a/salt/metaproxy/proxy.py +++ b/salt/metaproxy/proxy.py @@ -4,6 +4,7 @@ import asyncio import copy +import functools import logging import os import signal @@ -160,8 +161,13 @@ async def post_master_init(self, master): # Start engines here instead of in the Minion superclass __init__ # This is because we need to inject the __proxy__ variable but # it is not setup until now. 
- self.io_loop.spawn_callback( - salt.engines.start_engines, self.opts, self.process_manager, proxy=self.proxy + self.io_loop.call_soon( + functools.partial( + salt.engines.start_engines, + self.opts, + self.process_manager, + proxy=self.proxy, + ) ) if ( diff --git a/salt/minion.py b/salt/minion.py index 4a4b17eef456..7b6a25a33e9b 100644 --- a/salt/minion.py +++ b/salt/minion.py @@ -40,6 +40,7 @@ import salt.syspaths import salt.transport import salt.utils.args +import salt.utils.asynchronous import salt.utils.context import salt.utils.ctx import salt.utils.data @@ -1045,11 +1046,13 @@ def __init__(self, opts): self.max_auth_wait = self.opts["acceptance_wait_time_max"] self.minions = [] self.jid_queue = [] - self.io_loop = tornado.ioloop.IOLoop.current() + try: + self.io_loop = asyncio.get_running_loop() + except RuntimeError: + self.io_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.io_loop) self.process_manager = ProcessManager(name="MultiMinionProcessManager") - self.io_loop.spawn_callback( - self.process_manager.run, **{"asynchronous": True} - ) # Tornado backward compat + self.io_loop.create_task(self.process_manager.run(asynchronous=True)) self.event_publisher = None self.event = None @@ -1062,10 +1065,11 @@ def __del__(self): def _bind(self): # start up the event publisher, so we can see events during startup self.event_publisher = salt.transport.ipc_publish_server("minion", self.opts) - self.io_loop.spawn_callback( - self.event_publisher.publisher, - self.event_publisher.publish_payload, - self.io_loop, + self.io_loop.create_task( + self.event_publisher.publisher( + self.event_publisher.publish_payload, + io_loop=self.io_loop, + ) ) self.event = salt.utils.event.get_event( "minion", opts=self.opts, io_loop=self.io_loop @@ -1135,7 +1139,7 @@ def _spawn_minions(self, timeout=60): loaded_base_name="salt.loader.{}".format(s_opts["master"]), jid_queue=self.jid_queue, ) - self.io_loop.spawn_callback(self._connect_minion, minion) + self.io_loop.create_task(self._connect_minion(minion)) self.io_loop.call_later(timeout, self._check_minions) async def _connect_minion(self, minion): @@ -1203,7 +1207,12 @@ def tune_in(self): self._spawn_minions() # serve forever! - self.io_loop.start() + try: + self.io_loop.run_forever() + except (KeyboardInterrupt, SystemExit): + pass + finally: + self.io_loop.close() @property def restart(self): @@ -1219,9 +1228,7 @@ def stop(self, signum, parent_sig_handler): Called from cli.daemons.Minion._handle_signals(). Adds stop_async as callback to the io_loop to prevent blocking. 
""" - self.io_loop.add_callback( # pylint: disable=not-callable - self.stop_async, signum, parent_sig_handler - ) + self.io_loop.create_task(self.stop_async(signum, parent_sig_handler)) async def stop_async(self, signum, parent_sig_handler): """ @@ -1297,9 +1304,18 @@ def __init__( self.req_channel = None if io_loop is None: - self.io_loop = tornado.ioloop.IOLoop.current() + try: + self.io_loop = asyncio.get_running_loop() + except RuntimeError: + self.io_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.io_loop) else: - self.io_loop = io_loop + # Accept either asyncio loop or Tornado IOLoop (extract asyncio loop) + if isinstance(io_loop, asyncio.AbstractEventLoop): + self.io_loop = io_loop + else: + # Assume it's a Tornado IOLoop, extract the asyncio loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) # Warn if ZMQ < 3.2 if zmq: @@ -1344,11 +1360,11 @@ def __init__( time.sleep(sleep_time) self.process_manager = ProcessManager(name="MinionProcessManager") - self.io_loop.spawn_callback(self.process_manager.run, **{"asynchronous": True}) + self.io_loop.create_task(self.process_manager.run(asynchronous=True)) # We don't have the proxy setup yet, so we can't start engines # Engines need to be able to access __proxy__ if not salt.utils.platform.is_proxy(): - self.io_loop.spawn_callback( + self.io_loop.call_soon( salt.engines.start_engines, self.opts, self.process_manager ) @@ -1385,7 +1401,7 @@ def on_connect_master_future_done(future): if timeout: self.io_loop.call_later(timeout, self.io_loop.stop) try: - self.io_loop.start() + self.io_loop.run_forever() except KeyboardInterrupt: self.destroy() # I made the following 3 line oddity to preserve traceback. @@ -2534,7 +2550,7 @@ def _state_run(self): else: data["fun"] = "state.highstate" data["arg"] = [] - self.io_loop.add_callback(self._handle_decoded_payload, data) + self.io_loop.create_task(self._handle_decoded_payload(data)) def _refresh_grains_watcher(self, refresh_interval_in_minutes): """ @@ -3255,7 +3271,7 @@ def tune_in(self, start=True): self.setup_scheduler(before_connect=True) self.sync_connect_master() if self.connected: - self.io_loop.add_callback(self._fire_master_minion_start) + self.io_loop.create_task(self._fire_master_minion_start()) log.info("Minion is ready to receive requests!") # Make sure to gracefully handle SIGUSR1 @@ -3298,11 +3314,12 @@ def ping_timeout_handler(*_): "minion is running under an init system." 
) - self.io_loop.add_callback( - self._fire_master_main, - "ping", - "minion_ping", - timeout_handler=ping_timeout_handler, + self.io_loop.create_task( + self._fire_master_main( + "ping", + "minion_ping", + timeout_handler=ping_timeout_handler, + ) ) except Exception: # pylint: disable=broad-except log.warning( @@ -3320,14 +3337,17 @@ def ping_timeout_handler(*_): if start: try: - self.io_loop.start() + self.io_loop.run_forever() if self.restart: self.destroy() except ( KeyboardInterrupt, RuntimeError, - ): # A RuntimeError can be re-raised by Tornado on shutdown + ): # A RuntimeError can be re-raised during shutdown self.destroy() + finally: + if not self.io_loop.is_closed(): + self.io_loop.close() async def _handle_payload(self, payload): if payload is not None and payload["enc"] == "aes": @@ -3611,9 +3631,18 @@ def __init__(self, opts, io_loop=None): self.jid_forward_cache = set() if io_loop is None: - self.io_loop = tornado.ioloop.IOLoop.current() + try: + self.io_loop = asyncio.get_running_loop() + except RuntimeError: + self.io_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.io_loop) else: - self.io_loop = io_loop + # Accept either asyncio loop or Tornado IOLoop (extract asyncio loop) + if isinstance(io_loop, asyncio.AbstractEventLoop): + self.io_loop = io_loop + else: + # Assume it's a Tornado IOLoop, extract the asyncio loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) # List of events self.raw_events = [] @@ -3626,6 +3655,8 @@ def __init__(self, opts, io_loop=None): self.tries = collections.defaultdict(int) # Active pub futures: {master_id: (future, [job_ret, ...]), ...} self.pub_futures = {} + # Local client (set in tune_in()) + self.local = None def _spawn_syndics(self): """ @@ -3648,7 +3679,7 @@ async def connect(future, s_opts): except Exception as exc: # pylint: disable=broad-except future.set_exception(exc) - self.io_loop.spawn_callback(connect, future, s_opts) + self.io_loop.create_task(connect(future, s_opts)) async def _connect_syndic(self, opts): """ @@ -3845,7 +3876,13 @@ def tune_in(self): # Make sure to gracefully handle SIGUSR1 enable_sigusr1_handler() - self.io_loop.start() + try: + self.io_loop.run_forever() + except (KeyboardInterrupt, SystemExit): + pass + finally: + if not self.io_loop.is_closed(): + self.io_loop.close() async def _process_event(self, raw): # TODO: cleanup: Move down into event class diff --git a/salt/transport/tcp.py b/salt/transport/tcp.py index 1994a41adf03..380ce2f1b425 100644 --- a/salt/transport/tcp.py +++ b/salt/transport/tcp.py @@ -8,6 +8,7 @@ import asyncio import asyncio.exceptions import errno +import inspect import logging import multiprocessing import queue @@ -183,12 +184,24 @@ def run(self): # Wait for a free slot to be available to put # the connection into. # Sockets are picklable on Windows in Python 3. - self.socket_queue.put((connection, address), True, None) + try: + self.socket_queue.put((connection, address), True, None) + except Exception: # pylint: disable=broad-except + log.exception( + "Failed to enqueue connection from %s into load balancer queue", + address, + ) + connection.close() except OSError as e: # ECONNABORTED indicates that there was a connection # but it was closed while still in the accept queue. # (observed on FreeBSD). 
- name = self._socket.getsockname() + name = None + if self._socket is not None: + try: + name = self._socket.getsockname() + except OSError: + name = None if isinstance(name, tuple): name = name[0] if tornado.util.errno_from_exception(e) == errno.ECONNABORTED: @@ -229,7 +242,7 @@ class PublishClient(salt.transport.base.PublishClient): def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231 super().__init__(opts, io_loop, **kwargs) self.opts = opts - self.io_loop = io_loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self.unpacker = salt.utils.msgpack.Unpacker() self.connected = False self._closing = False @@ -429,7 +442,7 @@ async def recv(self, timeout=None): self._stream = None stream.close() if self.disconnect_callback: - self.disconnect_callback() + await self.disconnect_callback() await self.connect() log.debug("Re-connected - continue") continue @@ -634,6 +647,7 @@ def __init__(self, message_handler, *args, **kwargs): self._closing = False super().__init__(*args, **kwargs) self.io_loop = io_loop + self.asyncio_loop = salt.utils.asynchronous.aioloop(io_loop) self.clients = [] self.message_handler = message_handler @@ -656,9 +670,18 @@ async def handle_stream( # pylint: disable=arguments-differ,invalid-overridden- for framed_msg in unpacker: framed_msg = salt.transport.frame.decode_embedded_strs(framed_msg) header = framed_msg["head"] - self.io_loop.spawn_callback( - self.message_handler, stream, framed_msg["body"], header - ) + try: + log.trace("Dispatching message handler for %s", address) + result = self.message_handler( + stream, framed_msg["body"], header + ) + except Exception: # pylint: disable=broad-except + log.exception( + "Unhandled exception while running message handler" + ) + else: + if inspect.isawaitable(result): + self.asyncio_loop.create_task(result) except _StreamClosedError: log.trace("req client disconnected %s", address) self.remove_client((stream, address)) @@ -703,7 +726,7 @@ def __init__(self, socket_queue, message_handler, *args, **kwargs): super().__init__(message_handler, *args, **kwargs) self.socket_queue = socket_queue self._stop = threading.Event() - self.thread = threading.Thread(target=self.socket_queue_thread) + self.thread = threading.Thread(target=self.socket_queue_thread, daemon=True) self.thread.start() def close(self): @@ -720,10 +743,13 @@ def socket_queue_thread(self): if self._stop.is_set(): break continue - # 'self.io_loop' initialized in super class - # 'salt.ext.tornado.tcpserver.TCPServer'. - # 'self._handle_connection' defined in same super class. - self.io_loop.spawn_callback( + log.trace( + "LoadBalancerWorker received queued connection from %s", address + ) + # Schedule handling the connection using the event loop. 
+ # Must use call_soon_threadsafe since we're in a background thread + aio_loop = salt.utils.asynchronous.aioloop(self.io_loop) + aio_loop.call_soon_threadsafe( self._handle_connection, client_socket, address ) except (KeyboardInterrupt, SystemExit): @@ -785,9 +811,11 @@ def __init__( self.source_port = source_port self.connect_callback = connect_callback self.disconnect_callback = disconnect_callback - self.io_loop = io_loop or tornado.ioloop.IOLoop.current() - with salt.utils.asynchronous.current_ioloop(self.io_loop): - self._tcp_client = TCPClientKeepAlive(opts, resolver=resolver) + if io_loop is None: + io_loop = tornado.ioloop.IOLoop.current() + self.io_loop = io_loop + self.asyncio_loop = salt.utils.asynchronous.aioloop(io_loop) + self._tcp_client = TCPClientKeepAlive(opts, resolver=resolver) # TODO: max queue size self.send_future_map = {} # mapping of request_id -> Future @@ -795,7 +823,7 @@ def __init__( self._on_recv = None self._closing = False self._closed = False - self._connecting_future = tornado.concurrent.Future() + self._connecting_future = self.asyncio_loop.create_future() self._stream_return_running = False self._stream = None @@ -806,9 +834,12 @@ def close(self): if self._closing or self._closed: return self._closing = True - self.io_loop.add_timeout(1, self.check_close) + if not self.send_future_map: + self.io_loop.call_later(0, self.check_close) + else: + self.io_loop.call_later(1, self.check_close) - async def check_close(self): + def check_close(self): if not self.send_future_map: self._tcp_client.close() if self._stream: @@ -817,7 +848,7 @@ async def check_close(self): self._closed = True self._closing = False else: - self.io_loop.add_timeout(1, self.check_close) + self.io_loop.call_later(1, self.check_close) # pylint: disable=W1701 def __del__(self): @@ -860,7 +891,7 @@ async def connect(self): self._closing = False self._closed = False if not self._stream_return_running: - self.io_loop.spawn_callback(self._stream_return) + return_task = self.asyncio_loop.create_task(self._stream_return()) if self.connect_callback: self.connect_callback(True) @@ -882,7 +913,7 @@ async def _stream_return(self): # self.remove_message_timeout(message_id) else: if self._on_recv is not None: - self.io_loop.spawn_callback(self._on_recv, header, body) + self.io_loop.call_soon(self._on_recv, header, body) else: log.error( "Got response for message_id %s that we are not" @@ -970,7 +1001,7 @@ async def send(self, msg, timeout=None, callback=None, raw=False): message_id = self._message_id() header = {"mid": message_id} - future = tornado.concurrent.Future() + future = self.asyncio_loop.create_future() if callback is not None: @@ -998,7 +1029,7 @@ async def _do_send(): # Run send in a callback so we can wait on the future, in case we time # out before we are able to connect. 
- self.io_loop.add_callback(_do_send) + send_task = self.asyncio_loop.create_task(_do_send()) return await future @@ -1053,7 +1084,12 @@ def __init__( ssl=None, ): super().__init__(ssl_options=ssl) - self.io_loop = io_loop + if io_loop is None: + self.io_loop = salt.utils.asynchronous.aioloop( + tornado.ioloop.IOLoop.current() + ) + else: + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self.opts = opts self._closing = False self.clients = set() @@ -1119,7 +1155,7 @@ def handle_stream(self, stream, address): log.debug("Subscriber at %s connected", address) client = Subscriber(stream, address) self.clients.add(client) - self.io_loop.spawn_callback(self._stream_read, client) + self.io_loop.create_task(self._stream_read(client)) # TODO: ACK the publish through IPC async def publish_payload(self, package, topic_list=None): @@ -1198,7 +1234,12 @@ def __init__( # Placeholders for attributes to be populated by method calls self.sock = None - self.io_loop = io_loop or tornado.ioloop.IOLoop.current() + if io_loop is None: + self.io_loop = salt.utils.asynchronous.aioloop( + tornado.ioloop.IOLoop.current() + ) + else: + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self._closing = False def start(self): @@ -1242,10 +1283,7 @@ async def handle_stream(self, stream): unpacker.feed(wire_bytes) for framed_msg in unpacker: body = framed_msg["body"] - self.io_loop.spawn_callback( - self.payload_handler, - body, - ) + self.io_loop.create_task(self.payload_handler(body)) except tornado.iostream.StreamClosedError: if self.path: log.trace("Client disconnected from IPC %s", self.path) @@ -1277,7 +1315,7 @@ def handle_connection(self, connection, address): stream = tornado.iostream.IOStream( connection, ) - self.io_loop.spawn_callback(self.handle_stream, stream) + self.io_loop.create_task(self.handle_stream(stream)) except Exception as exc: # pylint: disable=broad-except log.error("IPC streaming error: %s", exc) @@ -1597,7 +1635,12 @@ def __init__(self, host, port, path, io_loop=None): to the server. 
""" - self.io_loop = io_loop or tornado.ioloop.IOLoop.current() + if io_loop is None: + self.io_loop = salt.utils.asynchronous.aioloop( + tornado.ioloop.IOLoop.current() + ) + else: + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self.host = host self.port = port self.path = path @@ -1622,7 +1665,7 @@ def connect(self, callback=None, timeout=None): future = tornado.concurrent.Future() self._connecting_future = future # self._connect(timeout) - self.io_loop.spawn_callback(self._connect, timeout) + self.io_loop.create_task(self._connect(timeout)) if callback is not None: @@ -1739,7 +1782,7 @@ class RequestClient(salt.transport.base.RequestClient): def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231 super().__init__(opts, io_loop, **kwargs) self.opts = opts - self.io_loop = io_loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) parse = urllib.parse.urlparse(self.opts["master_uri"]) master_host, master_port = parse.netloc.rsplit(":", 1) @@ -1800,7 +1843,7 @@ async def connect(self): # pylint: disable=invalid-overridden-method if not self._stream_return_running: self.task = asyncio.create_task(self._stream_return()) if self.connect_callback is not None: - self.connect_callback() + await self.connect_callback() async def _stream_return(self): self._stream_return_running = True @@ -1819,7 +1862,7 @@ async def _stream_return(self): self.send_future_map.pop(message_id).set_result(body) else: if self._on_recv is not None: - self.io_loop.spawn_callback(self._on_recv, header, body) + self.io_loop.call_soon(self._on_recv, header, body) else: log.error( "Got response for message_id %s that we are not" @@ -1910,7 +1953,7 @@ async def _do_send(): # Run send in a callback so we can wait on the future, in case we time # out before we are able to connect. 
- self.io_loop.add_callback(_do_send) + self.io_loop.create_task(_do_send()) recv = await future return recv diff --git a/salt/transport/ws.py b/salt/transport/ws.py index fe99d099af8b..4353fa463811 100644 --- a/salt/transport/ws.py +++ b/salt/transport/ws.py @@ -13,6 +13,7 @@ import salt.payload import salt.transport.base import salt.transport.frame +import salt.utils.asynchronous from salt.transport.tcp import ( USE_LOAD_BALANCER, LoadBalancerServer, @@ -43,7 +44,10 @@ class PublishClient(salt.transport.base.PublishClient): def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231 self.opts = opts + if io_loop is None: + io_loop = tornado.ioloop.IOLoop.current() self.io_loop = io_loop + self.asyncio_loop = salt.utils.asynchronous.aioloop(io_loop) self.connected = False self._closing = False @@ -72,24 +76,30 @@ def __init__(self, opts, io_loop, **kwargs): # pylint: disable=W0231 self._closing = False self.on_recv_task = None - async def _close(self): - if self._session is not None: - await self._session.close() - self._session = None + def close(self): + if self._closing: + return + self._closing = True + # Cancel the receive task but don't await it (like TCP does) if self.on_recv_task: self.on_recv_task.cancel() - await self.on_recv_task self.on_recv_task = None - if self._ws is not None: - await self._ws.close() - self._ws = None + # Schedule async cleanup but don't wait for it + if self._session is not None or self._ws is not None: + self.asyncio_loop.create_task(self._async_cleanup()) self._closed = True - def close(self): - if self._closing: - return - self._closing = True - self.io_loop.spawn_callback(self._close) + async def _async_cleanup(self): + """Background cleanup of async resources""" + try: + if self._session is not None: + await self._session.close() + self._session = None + if self._ws is not None: + await self._ws.close() + self._ws = None + except Exception: # pylint: disable=broad-except + pass # Cleanup is best-effort # pylint: disable=W1701 def __del__(self): @@ -110,10 +120,10 @@ async def getstream(self, **kwargs): if self.source_ip or self.source_port: kwargs.update(source_ip=self.source_ip, source_port=self.source_port) ws = None + session = None start = time.monotonic() timeout = kwargs.get("timeout", None) while ws is None and (not self._closed and not self._closing): - session = None try: ctx = None if self.ssl is not None: @@ -263,6 +273,7 @@ def __init__( pull_path_perms=0o600, pub_path_perms=0o600, started=None, + _shutdown=None, ssl=None, ): self.opts = opts @@ -277,9 +288,16 @@ def __init__( self.ssl = ssl self.clients = set() self._run = None + if _shutdown is None: + self._shutdown = multiprocessing.Event() # Cross-process shutdown signal + else: + self._shutdown = _shutdown self.pub_writer = None self.pub_reader = None self._connecting = None + self.runner = None + self.site = None + self.puller = None if started is None: self.started = multiprocessing.Event() else: @@ -310,6 +328,7 @@ def __getstate__(self): "pub_path_perms": self.pub_path_perms, "ssl": self.ssl, "started": self.started, + "_shutdown": self._shutdown, } def publish_daemon( @@ -321,19 +340,19 @@ def publish_daemon( """ Bind to the interface specified in the configuration file """ - io_loop = tornado.ioloop.IOLoop() - io_loop.add_callback( - self.publisher, - publish_payload, - presence_callback, - remove_presence_callback, - io_loop, + # Use asyncio event loop directly like ZeroMQ does + io_loop = salt.utils.asynchronous.aioloop(tornado.ioloop.IOLoop()) + + # Set up asyncio 
signal handler to stop the loop on SIGTERM + import signal + + io_loop.add_signal_handler(signal.SIGTERM, io_loop.stop) + + publisher_task = io_loop.create_task( + self.publisher(publish_payload, io_loop=io_loop) ) - # run forever try: - io_loop.start() - except (KeyboardInterrupt, SystemExit): - pass + io_loop.run_forever() finally: self.close() @@ -348,19 +367,30 @@ async def publisher( io_loop = tornado.ioloop.IOLoop.current() if self._run is None: self._run = asyncio.Event() - self._run.set() + + # Monitor the multiprocessing shutdown event and stop the loop + async def monitor_shutdown(): + loop = asyncio.get_running_loop() + # Wait for shutdown signal in executor to avoid blocking + await loop.run_in_executor(None, self._shutdown.wait) + self._run.set() + loop.stop() + + asyncio.create_task(monitor_shutdown()) ctx = None if self.ssl is not None: ctx = salt.transport.base.ssl_context(self.ssl, server_side=True) if self.pub_path: server = aiohttp.web.Server(self.handle_request) - runner = aiohttp.web.ServerRunner(server) - await runner.setup() + self.runner = aiohttp.web.ServerRunner(server) + await self.runner.setup() with salt.utils.files.set_umask(0o177): log.info("Publisher binding to socket %s", self.pub_path) - site = aiohttp.web.UnixSite(runner, self.pub_path, ssl_context=ctx) - await site.start() + self.site = aiohttp.web.UnixSite( + self.runner, self.pub_path, ssl_context=ctx + ) + await self.site.start() os.chmod(self.pub_path, self.pub_path_perms) else: sock = _get_socket(self.opts) @@ -370,11 +400,11 @@ async def publisher( sock.bind((self.pub_host, self.pub_port)) sock.listen(self.backlog) server = aiohttp.web.Server(self.handle_request) - runner = aiohttp.web.ServerRunner(server) - await runner.setup() - site = aiohttp.web.SockSite(runner, sock, ssl_context=ctx) + self.runner = aiohttp.web.ServerRunner(server) + await self.runner.setup() + self.site = aiohttp.web.SockSite(self.runner, sock, ssl_context=ctx) log.info("Publisher binding to socket %s:%s", self.pub_host, self.pub_port) - await site.start() + await self.site.start() self._pub_payload = publish_payload if self.pull_path: @@ -388,9 +418,13 @@ async def publisher( self.pull_handler, self.pull_host, self.pull_port ) self.started.set() - while self._run.is_set(): - await asyncio.sleep(0.3) - await self.server.stop() + # Wait for shutdown signal instead of polling + await self._run.wait() + # Properly shut down aiohttp server + await self.site.stop() + await self.runner.cleanup() + # Close the puller server + self.puller.close() await self.puller.wait_closed() async def pull_handler(self, reader, writer): @@ -425,8 +459,14 @@ async def handle_request(self, request): ws = aiohttp.web.WebSocketResponse() await ws.prepare(request) self.clients.add(ws) - while True: - await asyncio.sleep(1) + try: + # Keep connection alive until client disconnects + async for msg in ws: + if msg.type == aiohttp.WSMsgType.ERROR: + log.error("ws connection closed with exception %s", ws.exception()) + break + finally: + self.clients.discard(ws) async def _connect(self): if self.pull_path: @@ -469,8 +509,8 @@ def close(self): self.pub_writer.close() self.pub_writer = None self.pub_reader = None - if self._run is not None: - self._run.clear() + # Signal shutdown across processes + self._shutdown.set() if self._connecting: self._connecting.cancel() @@ -509,7 +549,9 @@ def post_fork(self, message_handler, io_loop): self.message_handler = message_handler self._run = asyncio.Event() self._started = asyncio.Event() - self._run.set() + + # Convert 
to asyncio loop + io_loop = salt.utils.asynchronous.aioloop(io_loop) async def server(): server = aiohttp.web.Server(self.handle_message) @@ -524,12 +566,12 @@ async def server(): self._started.set() # pause here for very long time by serving HTTP requests and # waiting for keyboard interruption - while self._run.is_set(): - await asyncio.sleep(0.3) + # Wait for shutdown signal instead of polling + await self._run.wait() await self.site.stop() self._socket.close() - io_loop.spawn_callback(server) + io_loop.create_task(server()) async def handle_message(self, request): try: @@ -554,7 +596,7 @@ async def handle_message(self, request): log.error("ws connection closed with exception %s", ws.exception()) def close(self): - self._run.clear() + self._run.set() # Signal shutdown if self._socket is not None: self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() @@ -570,7 +612,7 @@ def __init__(self, opts, io_loop): # pylint: disable=W0231 self.sending = False self.ws = None self.session = None - self.io_loop = io_loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self._closing = False self._closed = False self.ssl = self.opts.get("ssl", None) diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 924a838b7c58..b96ded25e63c 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -27,6 +27,7 @@ import salt.payload import salt.transport.base +import salt.utils.asynchronous import salt.utils.files import salt.utils.process import salt.utils.stringutils @@ -206,7 +207,7 @@ def _legacy_setup( def __init__(self, opts, io_loop, **kwargs): super().__init__(opts, io_loop, **kwargs) self.opts = opts - self.io_loop = io_loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self._legacy_setup( _id=opts.get("id", ""), role=opts.get("__role", ""), @@ -402,7 +403,7 @@ async def consume(running): "Exception while consuming%s %s", self.uri, exc, exc_info=True ) - task = self.io_loop.spawn_callback(consume, running) + task = self.io_loop.create_task(consume(running)) self.callbacks[callback] = running, task @@ -560,7 +561,7 @@ async def callback(): task.add_done_callback(self.tasks.discard) self.tasks.add(task) - io_loop.add_callback(callback) + callback_task = salt.utils.asynchronous.aioloop(io_loop).create_task(callback()) async def request_handler(self): while not self._event.is_set(): @@ -804,7 +805,7 @@ def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): future.set_exception(exc) if future.done(): - if isinstance(future.exception, SaltReqTimeoutError): + if isinstance(future.exception(), SaltReqTimeoutError): log.trace("Request timed out while sending. reconnecting.") else: log.trace( @@ -847,7 +848,7 @@ def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): break if future.done(): - if isinstance(future.exception, SaltReqTimeoutError): + if isinstance(future.exception(), SaltReqTimeoutError): log.trace( "Request timed out while waiting for a response. reconnecting." 
) @@ -880,7 +881,9 @@ def __init__(self, socket): def start_io_loop(self, io_loop): log.trace("Event monitor start!") self._running.set() - io_loop.spawn_callback(self.consume) + self._running_task = salt.utils.asynchronous.aioloop(io_loop).create_task( + self.consume() + ) async def consume(self): while self._running.is_set(): @@ -934,6 +937,11 @@ def stop(self): self._socket.disable_monitor() except zmq.Error: pass + if self._monitor_socket is not None: + try: + self._monitor_socket.close(0) + except zmq.Error: + pass self._socket = None self._running.clear() self._monitor_socket = None @@ -1032,10 +1040,13 @@ def publish_daemon( This method represents the Publish Daemon process. It is intended to be run in a thread or process as it creates and runs its own ioloop. """ - io_loop = tornado.ioloop.IOLoop() - io_loop.add_callback(self.publisher, publish_payload, io_loop=io_loop) + io_loop = salt.utils.asynchronous.aioloop(tornado.ioloop.IOLoop()) + + publisher_task = io_loop.create_task( + self.publisher(publish_payload, io_loop=io_loop) + ) try: - io_loop.start() + io_loop.run_forever() finally: self.close() @@ -1222,23 +1233,25 @@ def __init__(self, opts, io_loop, linger=0): # pylint: disable=W0231 self.master_uri = self.get_master_uri(opts) self.linger = linger if io_loop is None: - self.io_loop = tornado.ioloop.IOLoop.current() + self.io_loop = salt.utils.asynchronous.aioloop( + tornado.ioloop.IOLoop.current() + ) else: - self.io_loop = io_loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self.context = None self.send_queue = [] # mapping of message -> future self.send_future_map = {} self._closing = False self.socket = None - # self._queue = asyncio.Queue() - self._queue = tornado.queues.Queue() + self._queue = asyncio.Queue() async def connect(self): # pylint: disable=invalid-overridden-method if self.socket is None: self._connect_called = True self._closing = False # wire up sockets + self._queue = asyncio.Queue() self._init_socket() def _init_socket(self): @@ -1263,13 +1276,19 @@ def _init_socket(self): self.socket.setsockopt(zmq.IPV4ONLY, 0) self.socket.linger = self.linger self.socket.connect(self.master_uri) - self.io_loop.spawn_callback(self._send_recv, self.socket) + self.send_recv_task = self.io_loop.create_task( + self._send_recv(self.socket, self._queue) + ) + self.send_recv_task._log_destroy_pending = False # TODO: timeout all in-flight sessions, or error def close(self): if self._closing: return self._closing = True + # Save socket reference before clearing it for use in callback + self._queue.put_nowait((None, None)) + task_socket = self.socket if self.socket: self.socket.close() self.socket = None @@ -1277,6 +1296,38 @@ def close(self): # This hangs if closing the stream causes an import error self.context.term() self.context = None + # if getattr(self, "send_recv_task", None): + # task = self.send_recv_task + # if not task.done(): + # task.cancel() + + # # Suppress "Task was destroyed but it is pending!" 
warnings + # # by ensuring the task knows its exception will be handled + # task._log_destroy_pending = False + + # def _drain_cancelled(cancelled_task): + # try: + # cancelled_task.exception() + # except asyncio.CancelledError: # pragma: no cover + # # Task was cancelled - log the expected messages + # log.trace("Send socket closed while polling.") + # log.trace("Send and receive coroutine ending %s", task_socket) + # except ( + # Exception # pylint: disable=broad-exception-caught + # ): # pragma: no cover + # log.trace( + # "Exception while cancelling send/receive task.", + # exc_info=True, + # ) + # log.trace("Send and receive coroutine ending %s", task_socket) + + # task.add_done_callback(_drain_cancelled) + # else: + # try: + # task.result() + # except Exception as exc: # pylint: disable=broad-except + # log.trace("Exception while retrieving send/receive task: %r", exc) + # self.send_recv_task = None async def send(self, load, timeout=60): """ @@ -1319,7 +1370,7 @@ def get_master_uri(opts): # if we've reached here something is very abnormal raise SaltException("ReqChannel: missing master_uri/master_ip in self.opts") - async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): + async def _send_recv(self, socket, queue, _TimeoutError=tornado.gen.TimeoutError): """ Long running send/receive coroutine. This should be started once for each socket created. Once started, the coroutine will run until the @@ -1332,14 +1383,9 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): # close method is called. This allows us to fail gracefully once it's # been closed. while send_recv_running: - # try: - # future, message = await asyncio.wait_for(self._queue.get(), 0.3) - # except TimeoutError: try: - future, message = await self._queue.get( - timeout=datetime.timedelta(milliseconds=300) - ) - except _TimeoutError: + future, message = await asyncio.wait_for(queue.get(), 0.3) + except asyncio.TimeoutError as exc: try: # For some reason yielding here doesn't work becaues the # future always has a result? @@ -1348,7 +1394,10 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): except _TimeoutError: # This is what we expect if the socket is still alive pass - except zmq.eventloop.future.CancelledError: + except ( + zmq.eventloop.future.CancelledError, + asyncio.exceptions.CancelledError, + ): log.trace("Loop closed while polling send socket.") # The ioloop was closed before polling finished. send_recv_running = False @@ -1359,8 +1408,16 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): break continue + if future is None: + log.trace("Received send/recv shutdown sentinel") + send_recv_running = False + break try: await socket.send(message) + except asyncio.CancelledError as exc: + log.trace("Loop closed while sending.") + send_recv_running = False + future.set_exception(exc) except zmq.eventloop.future.CancelledError as exc: log.trace("Loop closed while sending.") # The ioloop was closed before polling finished. send_recv_running = False future.set_exception(exc) @@ -1384,7 +1441,10 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): future.set_exception(exc) if future.done(): - if isinstance(future.exception, SaltReqTimeoutError): + if isinstance(future.exception(), asyncio.CancelledError): + send_recv_running = False + break + elif isinstance(future.exception(), SaltReqTimeoutError): log.trace("Request timed out while sending. reconnecting.")
reconnecting.") else: log.trace( @@ -1401,12 +1461,16 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): try: # Time is in milliseconds. ready = await socket.poll(300, zmq.POLLIN) + except asyncio.CancelledError as exc: + log.trace("Loop closed while polling receive socket.") + send_recv_running = False + future.set_exception(exc) except zmq.eventloop.future.CancelledError as exc: log.trace("Loop closed while polling receive socket.") send_recv_running = False future.set_exception(exc) except zmq.ZMQError as exc: - log.trace("Recieve socket closed while polling.") + log.trace("Receive socket closed while polling.") send_recv_running = False future.set_exception(exc) @@ -1414,6 +1478,10 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): try: recv = await socket.recv() received = True + except asyncio.CancelledError as exc: + log.trace("Loop closed while receiving.") + send_recv_running = False + future.set_exception(exc) except zmq.eventloop.future.CancelledError as exc: log.trace("Loop closed while receiving.") send_recv_running = False @@ -1427,12 +1495,16 @@ async def _send_recv(self, socket, _TimeoutError=tornado.gen.TimeoutError): break if future.done(): - if isinstance(future.exception, SaltReqTimeoutError): - log.trace( + exc = future.exception() + if isinstance(exc, asyncio.CancelledError): + send_recv_running = False + break + elif isinstance(exc, SaltReqTimeoutError): + log.error( "Request timed out while waiting for a response. reconnecting." ) else: - log.trace("The request ended with an error. reconnecting.") + log.error("The request ended with an error. reconnecting. %r", exc) self.close() await self.connect() send_recv_running = False diff --git a/salt/utils/asynchronous.py b/salt/utils/asynchronous.py index d89e3ba002e4..06d1959e369b 100644 --- a/salt/utils/asynchronous.py +++ b/salt/utils/asynchronous.py @@ -14,6 +14,22 @@ log = logging.getLogger(__name__) +def aioloop(io_loop, warn=False): + """ + Ensure the ioloop is an asyncio loop not a tornado ioloop. + """ + if isinstance(io_loop, asyncio.AbstractEventLoop): + return io_loop + elif isinstance(io_loop, tornado.ioloop.IOLoop): + if warn: + import traceback + + log.warning("Passed tornado loop %s", "".join(traceback.format_stack())) + return io_loop.asyncio_loop + else: + raise RuntimeError("Loop must be AbstractEventLoop (prefered) or IOLoop") + + @contextlib.contextmanager def current_ioloop(io_loop): """ diff --git a/salt/utils/event.py b/salt/utils/event.py index 39675dd53766..60257d56f850 100644 --- a/salt/utils/event.py +++ b/salt/utils/event.py @@ -235,7 +235,7 @@ def __init__( self.node = node self.keep_loop = keep_loop if io_loop is not None: - self.io_loop = io_loop + self.io_loop = salt.utils.asynchronous.aioloop(io_loop) self._run_io_loop_sync = False else: self.io_loop = None @@ -270,6 +270,7 @@ def __init__( # and don't read out events from the buffer on an on-going basis, # the buffer will grow resulting in big memory usage. self.connect_pub() + self._publish_tasks = [] @classmethod def __load_cache_regex(cls): @@ -319,6 +320,30 @@ def unsubscribe(self, tag, match_type=None): ): self.pending_events.append(evt) + def _schedule(self, func, *args, **kwargs): + """ + Schedule ``func`` on the underlying asyncio event loop. + + ``func`` can be a coroutine function or a regular callable. If it + returns a coroutine, we ensure it's converted into a task so that any + exceptions are surfaced instead of being silently ignored. 
+ """ + if self.io_loop is None: + raise RuntimeError("No asyncio event loop available for scheduling") + + loop = salt.utils.asynchronous.aioloop(self.io_loop) + + if asyncio.iscoroutinefunction(func): + loop.create_task(func(*args, **kwargs)) + return + + def runner(): + result = func(*args, **kwargs) + if asyncio.iscoroutine(result): + loop.create_task(result) + + loop.call_soon(runner) + def connect_pub(self, timeout=None): """ Establish the publish connection @@ -355,7 +380,7 @@ def connect_pub(self, timeout=None): self.subscriber = salt.transport.ipc_publish_client( self.node, self.opts, io_loop=self.io_loop ) - self.io_loop.spawn_callback(self.subscriber.connect) + self._connect_task = self.io_loop.create_task(self.subscriber.connect()) # For the asynchronous case, the connect will be defered to when # set_event_handler() is invoked. @@ -368,6 +393,11 @@ def close_pub(self): """ if not self.cpub: return + if hasattr(self, "_connect_task"): + task = self._connect_task + if task and not task.done(): + task.cancel() + self._connect_task = None self.subscriber.close() self.subscriber = None self.pending_events = [] @@ -420,6 +450,10 @@ def close_pull(self): self.pusher.close() self.pusher = None self.cpush = False + for task in self._publish_tasks: + if task and not task.done(): + task.cancel() + self._publish_tasks.clear() @classmethod def unpack(cls, raw): @@ -797,7 +831,8 @@ def fire_event(self, data, tag, timeout=1000): ) raise else: - asyncio.create_task(self.pusher.publish(msg)) + task = self.io_loop.create_task(self.pusher.publish(msg)) + self._publish_tasks.append(task) return True def fire_master(self, data, tag, timeout=1000): @@ -916,7 +951,7 @@ def set_event_handler(self, event_handler): if not self.cpub: self.connect_pub() # This will handle reconnects - self.io_loop.spawn_callback(self.subscriber.on_recv, event_handler) + self._schedule(self.subscriber.on_recv, event_handler) # pylint: disable=W1701 def __del__(self): diff --git a/salt/utils/process.py b/salt/utils/process.py index eeac663c03af..c074653c524c 100644 --- a/salt/utils/process.py +++ b/salt/utils/process.py @@ -2,6 +2,7 @@ Functions for daemonizing and otherwise modifying running processes """ +import asyncio import contextlib import copy import errno @@ -21,8 +22,6 @@ import threading import time -from tornado import gen - import salt._logging import salt.defaults.exitcodes import salt.utils.files @@ -608,8 +607,7 @@ def send_signal_to_processes(self, signal_): # Otherwise, it's a dead process, remove it from the process map del self._process_map[pid] - @gen.coroutine - def run(self, asynchronous=False): + async def run(self, asynchronous=False): """ Load and start all available api modules """ @@ -618,12 +616,14 @@ def run(self, asynchronous=False): appendproctitle(self.name) # make sure to kill the subprocesses if the parent is killed - if signal.getsignal(signal.SIGTERM) is signal.SIG_DFL: - # There are no SIGTERM handlers installed, install ours - signal.signal(signal.SIGTERM, self._handle_signals) - if signal.getsignal(signal.SIGINT) is signal.SIG_DFL: - # There are no SIGINT handlers installed, install ours - signal.signal(signal.SIGINT, self._handle_signals) + # Only set up signal handlers if we're in the main thread + if threading.current_thread() == threading.main_thread(): + if signal.getsignal(signal.SIGTERM) is signal.SIG_DFL: + # There are no SIGTERM handlers installed, install ours + signal.signal(signal.SIGTERM, self._handle_signals) + if signal.getsignal(signal.SIGINT) is signal.SIG_DFL: + # 
+ signal.signal(signal.SIGINT, self._handle_signals) while True: log.trace("Process manager iteration") try: # in case someone died while we were waiting... self.check_children() # The event-based subprocesses management code was removed from here # because os.wait() conflicts with the subprocesses management logic # implemented in `multiprocessing` package. See #35480 for details. + + # In synchronous mode with no processes, exit after checking children + # but before sleeping (to avoid unnecessary 10s delay in tests) + if not asynchronous and not self._process_map: + break + if asynchronous: - yield gen.sleep(10) + await asyncio.sleep(10) else: time.sleep(10) + + # Check again after sleep - in async mode, exit if no processes if not self._process_map: break # OSError is raised if a signal handler is called (SIGTERM) during os.wait diff --git a/tests/pytests/functional/transport/zeromq/test_request_client.py index dd730f9586a0..339a212afd0b 100644 --- a/tests/pytests/functional/transport/zeromq/test_request_client.py +++ b/tests/pytests/functional/transport/zeromq/test_request_client.py @@ -131,7 +131,9 @@ async def test_request_client_send_recv_socket_closed( with caplog.at_level(logging.TRACE): request_client.close() await asyncio.sleep(0.5) - assert "Send socket closed while polling." in caplog.messages + # The tornado version would see this log. + # assert "Send socket closed while polling." in caplog.messages + assert "Received send/recv shutdown sentinel" in caplog.messages assert f"Send and receive coroutine ending {socket}" in caplog.messages @@ -391,3 +393,97 @@ async def recv(*args, **kwargs): request_client.close() serve_socket.close() ctx.term() + + +async def test_request_client_uses_asyncio_queue(io_loop, minion_opts, port): + """ + Test that RequestClient uses asyncio.Queue instead of tornado.queues.Queue. + This verifies the conversion from Tornado to pure asyncio. + """ + minion_opts["master_uri"] = f"tcp://127.0.0.1:{port}" + request_client = salt.transport.zeromq.RequestClient(minion_opts, io_loop) + try: + # Verify the queue is an asyncio.Queue + assert isinstance(request_client._queue, asyncio.Queue) + # Verify it has asyncio.Queue methods + assert hasattr(request_client._queue, "get") + assert hasattr(request_client._queue, "put") + # Verify it doesn't have Tornado-specific attributes + assert not hasattr(request_client._queue, "get_timeout") + finally: + request_client.close()
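(The two tests on either side of this point exercise the queue conversion made in _send_recv above. For orientation, a sketch — not project code — of the pattern: tornado.queues.Queue.get(timeout=timedelta(...)) raising tornado.gen.TimeoutError becomes asyncio.wait_for() raising asyncio.TimeoutError, and a (None, None) sentinel queued by close() replaces loop teardown as the shutdown signal:)

    import asyncio

    async def send_recv_loop(queue: asyncio.Queue):
        while True:
            try:
                future, message = await asyncio.wait_for(queue.get(), 0.3)
            except asyncio.TimeoutError:
                continue  # idle tick, mirroring the 300 ms poll above
            if future is None:
                break  # (None, None) shutdown sentinel queued by close()
            # ... send message on the socket and resolve future with the reply ...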
+ """ + minion_opts["master_uri"] = f"tcp://127.0.0.1:{port}" + ctx = zmq.Context() + serve_socket = ctx.socket(zmq.REP) + serve_socket.bind(minion_opts["master_uri"]) + + request_client = salt.transport.zeromq.RequestClient(minion_opts, io_loop) + + try: + await request_client.connect() + + # The queue should timeout without any messages + # This tests that asyncio.wait_for with asyncio.TimeoutError works + # The _send_recv loop should handle the timeout gracefully + + # Send a request - it should queue properly + future = asyncio.Future() + await request_client._queue.put((future, b"test_message")) + + # Wait briefly + await asyncio.sleep(0.1) + + # The _send_recv loop should have picked up the message + # and attempted to send it (though no handler is set up) + assert request_client._queue.qsize() == 0 + + finally: + request_client.close() + serve_socket.close() + ctx.term() + + +async def test_request_client_asyncio_cancelled_error_handling( + io_loop, request_client, minion_opts, port, caplog +): + """ + Test that RequestClient properly handles asyncio.CancelledError. + This verifies the new asyncio.CancelledError exception handlers. + """ + minion_opts["master_uri"] = f"tcp://127.0.0.1:{port}" + ctx = zmq.Context() + serve_socket = ctx.socket(zmq.REP) + serve_socket.bind(minion_opts["master_uri"]) + + await request_client.connect() + + socket = request_client.socket + + async def send(*args, **kwargs): + """ + Mock send to raise asyncio.CancelledError + """ + raise asyncio.CancelledError() + + socket.send = send + + with caplog.at_level(logging.TRACE): + with pytest.raises(asyncio.CancelledError): + try: + await request_client.send("meh") + await asyncio.sleep(0.3) + assert "Loop closed while sending." in caplog.messages + assert f"Send and receive coroutine ending {socket}" in caplog.messages + finally: + request_client.close() + serve_socket.close() + ctx.term() diff --git a/tests/pytests/functional/utils/test_process.py b/tests/pytests/functional/utils/test_process.py index 14525c426afc..8798c0bf66ee 100644 --- a/tests/pytests/functional/utils/test_process.py +++ b/tests/pytests/functional/utils/test_process.py @@ -5,6 +5,7 @@ Test salt's process utility module """ +import asyncio import os import pathlib import time @@ -71,3 +72,87 @@ def target(): break assert len(process_list.processes) == 0 assert _get_num_fds(pid) == num - 2 + + +async def test_process_manager_run_async(): + """ + Test that ProcessManager.run() is now an async coroutine. + This tests the conversion from Tornado @gen.coroutine to async/await. + """ + process_manager = salt.utils.process.ProcessManager(wait_for_kill=5) + try: + # Verify run() is an async coroutine + import inspect + + assert inspect.iscoroutinefunction(process_manager.run) + + # Create a task to run the process manager asynchronously + task = asyncio.create_task(process_manager.run(asynchronous=True)) + + # Let it run briefly + await asyncio.sleep(0.5) + + # Verify the task is running + assert not task.done() + + # Cancel the task + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: + process_manager.terminate() + + +async def test_process_manager_run_uses_asyncio_sleep(): + """ + Test that ProcessManager.run() uses asyncio.sleep() instead of gen.sleep(). 
+ """ + process_manager = salt.utils.process.ProcessManager(wait_for_kill=5) + try: + # Start the async run + task = asyncio.create_task(process_manager.run(asynchronous=True)) + + # Wait a bit to ensure it's looping with asyncio.sleep + await asyncio.sleep(0.1) + + # Verify it's still running (would hang if gen.sleep was used incorrectly) + assert not task.done() + + # Clean up + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: + process_manager.terminate() + + +def test_process_manager_run_synchronous(): + """ + Test that ProcessManager.run() can still run synchronously. + """ + process_manager = salt.utils.process.ProcessManager(wait_for_kill=5) + try: + # When asynchronous=False, it should use time.sleep and exit quickly + # since there are no processes + import threading + + ran = [] + + def run_sync(): + # This should complete quickly since there are no processes + asyncio.run(process_manager.run(asynchronous=False)) + ran.append(True) + + thread = threading.Thread(target=run_sync) + thread.start() + thread.join(timeout=2) + + # Should have completed + assert not thread.is_alive() + assert ran == [True] + finally: + process_manager.terminate() diff --git a/tests/pytests/unit/test_minion.py b/tests/pytests/unit/test_minion.py index 9a376b07887a..f9143cf3a754 100644 --- a/tests/pytests/unit/test_minion.py +++ b/tests/pytests/unit/test_minion.py @@ -658,7 +658,13 @@ async def test_when_ping_interval_is_set_the_callback_should_be_added_to_periodi try: try: minion.connected = MagicMock(side_effect=(False, True)) - minion._fire_master_minion_start = MagicMock() + + # _fire_master_minion_start is now called as a coroutine via create_task + # so it must be an async function + async def async_mock(): + pass + + minion._fire_master_minion_start = async_mock minion.tune_in(start=False) except RuntimeError: pass @@ -834,10 +840,13 @@ def test_minion_manage_beacons(minion_opts): "salt.utils.process.SignalHandlingProcess.join", MagicMock(return_value=True), ): + minion = None try: minion_opts["beacons"] = {} - io_loop = MagicMock() + # io_loop must be a real Tornado IOLoop because our code calls + # salt.utils.asynchronous.aioloop() on it + io_loop = tornado.ioloop.IOLoop() mock_functions = {"test.ping": None} minion = salt.minion.Minion(minion_opts, io_loop=io_loop) @@ -853,7 +862,8 @@ def test_minion_manage_beacons(minion_opts): assert "ps" in minion.opts["beacons"] assert minion.opts["beacons"]["ps"] == bdata finally: - minion.destroy() + if minion is not None: + minion.destroy() def test_prep_ip_port(): @@ -1214,7 +1224,8 @@ async def test_minion_manager_async_stop(io_loop, minion_opts, tmp_path): assert mm.event is not None # Check io_loop is running - assert mm.io_loop.asyncio_loop.is_running() + # mm.io_loop is now an asyncio.AbstractEventLoop (not Tornado IOLoop) + assert mm.io_loop.is_running() # Wait for the ipc socket to be created, meaning the publish server is listening. while not list(pathlib.Path(minion_opts["sock_dir"]).glob("*")): @@ -1251,3 +1262,79 @@ async def test_minion_manager_async_stop(io_loop, minion_opts, tmp_path): parent_signal_handler.assert_called_once_with(signal.SIGTERM, None) assert mm.event_publisher is None assert mm.event is None + + +def test_minion_io_loop_is_asyncio_loop(minion_opts): + """ + Test that Minion io_loop is converted to asyncio.AbstractEventLoop. + This verifies the salt.utils.asynchronous.aioloop() conversion. 
+ """ + minion = salt.minion.Minion(minion_opts, load_grains=False) + try: + # Verify io_loop is an asyncio loop, not a Tornado IOLoop + assert isinstance(minion.io_loop, asyncio.AbstractEventLoop) + # Ensure it has asyncio methods + assert hasattr(minion.io_loop, "create_task") + assert hasattr(minion.io_loop, "call_soon") + # Ensure it doesn't have Tornado-specific methods + assert not hasattr(minion.io_loop, "spawn_callback") + finally: + minion.destroy() + + +def test_minion_io_loop_with_provided_loop(minion_opts): + """ + Test that Minion io_loop conversion works when a loop is provided. + """ + # Create a Tornado IOLoop + tornado_loop = tornado.ioloop.IOLoop() + try: + minion = salt.minion.Minion( + minion_opts, io_loop=tornado_loop, load_grains=False + ) + try: + # Should still be converted to asyncio loop + assert isinstance(minion.io_loop, asyncio.AbstractEventLoop) + # Should be the same underlying loop + assert minion.io_loop is tornado_loop.asyncio_loop + finally: + minion.destroy() + finally: + tornado_loop.close() + + +def test_minion_manager_io_loop_is_asyncio_loop(minion_opts): + """ + Test that MinionManager io_loop is converted to asyncio.AbstractEventLoop. + """ + with patch("salt.utils.process.SignalHandlingProcess.start"): + with patch("salt.utils.verify.valid_id"): + mm = salt.minion.MinionManager(minion_opts) + try: + # Verify io_loop is an asyncio loop + assert isinstance(mm.io_loop, asyncio.AbstractEventLoop) + # Ensure it has asyncio methods + assert hasattr(mm.io_loop, "create_task") + assert hasattr(mm.io_loop, "call_soon") + # Ensure it doesn't have Tornado-specific methods + assert not hasattr(mm.io_loop, "spawn_callback") + finally: + mm.destroy() + + +def test_syndic_manager_io_loop_is_asyncio_loop(minion_opts): + """ + Test that SyndicManager io_loop is converted to asyncio.AbstractEventLoop. 
+ """ + minion_opts["order_masters"] = True + sm = salt.minion.SyndicManager(minion_opts) + try: + # Verify io_loop is an asyncio loop + assert isinstance(sm.io_loop, asyncio.AbstractEventLoop) + # Ensure it has asyncio methods + assert hasattr(sm.io_loop, "create_task") + assert hasattr(sm.io_loop, "call_soon") + # Ensure it doesn't have Tornado-specific methods + assert not hasattr(sm.io_loop, "spawn_callback") + finally: + sm.destroy() diff --git a/tests/pytests/unit/transport/test_publish_client.py b/tests/pytests/unit/transport/test_publish_client.py index c1e7e5685508..e4b3d39e4e67 100644 --- a/tests/pytests/unit/transport/test_publish_client.py +++ b/tests/pytests/unit/transport/test_publish_client.py @@ -304,7 +304,7 @@ async def test_recv_timeout_zero(): """ host = "127.0.0.1" port = 11122 - ioloop = MagicMock() + ioloop = asyncio.get_running_loop() mock_stream = MagicMock() mock_unpacker = MagicMock() mock_unpacker.__iter__.return_value = [] diff --git a/tests/pytests/unit/transport/test_tcp.py b/tests/pytests/unit/transport/test_tcp.py index b5d1cb56bdf2..f6c64ae9ada9 100644 --- a/tests/pytests/unit/transport/test_tcp.py +++ b/tests/pytests/unit/transport/test_tcp.py @@ -7,7 +7,6 @@ import pytest import tornado import tornado.concurrent -import tornado.gen import tornado.ioloop import tornado.iostream from pytestshellutils.utils import ports @@ -37,16 +36,15 @@ def fake_crypto(): @pytest.fixture def _fake_authd(io_loop): - @tornado.gen.coroutine - def return_nothing(): - raise tornado.gen.Return() + async def return_nothing(*args, **kwargs): + return None with patch( "salt.crypt.AsyncAuth.authenticated", new_callable=PropertyMock ) as mock_authed, patch( "salt.crypt.AsyncAuth.authenticate", autospec=True, - return_value=return_nothing(), + side_effect=return_nothing, ), patch( "salt.crypt.AsyncAuth.gen_token", autospec=True, return_value=42 ): @@ -269,13 +267,9 @@ def test_tcp_pub_server_channel_publish_filtering_str_list(temp_salt_master): @pytest.fixture(scope="function") -def salt_message_client(): - io_loop_mock = MagicMock(spec=tornado.ioloop.IOLoop) - io_loop_mock.asyncio_loop = None - io_loop_mock.call_later.side_effect = lambda *args, **kwargs: (args, kwargs) - +def salt_message_client(io_loop): client = salt.transport.tcp.MessageClient( - {}, "127.0.0.1", ports.get_unused_localhost_port(), io_loop=io_loop_mock + {}, "127.0.0.1", ports.get_unused_localhost_port(), io_loop=io_loop ) try: @@ -395,20 +389,19 @@ def xtest_client_reconnect_backoff(client_socket): opts, client_socket.listen_on, client_socket.port ) - def _sleep(t): + async def _sleep(t): client.close() assert t == 5 return - # return tornado.gen.sleep() + # return asyncio.sleep() - @tornado.gen.coroutine - def connect(*args, **kwargs): + async def connect(*args, **kwargs): raise Exception("err") client._tcp_client.connect = connect try: - with patch("tornado.gen.sleep", side_effect=_sleep): + with patch("asyncio.sleep", side_effect=_sleep): client.io_loop.run_sync(client.connect) finally: client.close() @@ -448,7 +441,7 @@ async def test_when_async_req_channel_with_syndic_role_should_use_syndic_master_ @pytest.mark.usefixtures("_fake_authd", "_fake_crypticle", "_fake_keys") async def test_mixin_should_use_correct_path_when_syndic(): - mockloop = MagicMock() + mockloop = asyncio.get_running_loop() expected_pubkey_path = os.path.join("/etc/salt/pki/minion", "syndic_master.pub") opts = { "master_uri": "tcp://127.0.0.1:4506", @@ -497,6 +490,8 @@ async def test_presence_removed_on_stream_closed(): opts = 
{"presence_events": True} io_loop_mock = MagicMock(spec=tornado.ioloop.IOLoop) + # Add asyncio_loop attribute for aioloop() compatibility + io_loop_mock.asyncio_loop = MagicMock() with patch("salt.master.AESFuncs.__init__", return_value=None): server = salt.transport.tcp.PubServer(opts, io_loop=io_loop_mock) @@ -620,7 +615,7 @@ def read_bytes(self, *args, **kwargs): await server.handle_stream(stream, address) # Let loop iterate so callback gets called - await tornado.gen.sleep(0.01) + await asyncio.sleep(0.01) assert received assert [msg] == received @@ -664,9 +659,9 @@ async def test_message_client_stream_return_exception(minion_opts, io_loop): ] try: io_loop.add_callback(client._stream_return) - await tornado.gen.sleep(0.01) + await asyncio.sleep(0.01) client.close() - await tornado.gen.sleep(0.01) + await asyncio.sleep(0.01) assert client._stream is None finally: client.close() diff --git a/tests/pytests/unit/transport/test_zeromq.py b/tests/pytests/unit/transport/test_zeromq.py index 6844e4dbf8d6..bc22aabb242b 100644 --- a/tests/pytests/unit/transport/test_zeromq.py +++ b/tests/pytests/unit/transport/test_zeromq.py @@ -1,3 +1,4 @@ +import asyncio import ctypes import hashlib import logging @@ -9,7 +10,6 @@ import msgpack import pytest -import tornado.gen import tornado.ioloop import zmq.eventloop.future @@ -267,17 +267,16 @@ def run_loop_in_thread(loop, evt): """ loop.make_current() - @tornado.gen.coroutine - def stopper(): - yield tornado.gen.sleep(0.1) + async def stopper(): + await asyncio.sleep(0.1) while True: if not evt.is_set(): loop.stop() break - yield tornado.gen.sleep(0.3) + await asyncio.sleep(0.3) loop.add_callback(evt.set) - loop.add_callback(stopper) + loop.spawn_callback(stopper) try: loop.start() finally: @@ -349,13 +348,16 @@ def __exit__(self, *args, **kwargs): # pylint: enable=W1701 @classmethod - @tornado.gen.coroutine - def _handle_payload(cls, payload): + async def _handle_payload(cls, payload): """ TODO: something besides echo """ - cls.mock._handle_payload_hook() - raise tornado.gen.Return((payload, {"fun": "send_clear"})) + hook_result = cls.mock._handle_payload_hook() + if asyncio.iscoroutine(hook_result): + hook_result = await hook_result + if hook_result is not None: + return hook_result + return payload, {"fun": "send_clear"} def test_master_uri(): @@ -477,7 +479,7 @@ def test_serverside_exception(temp_salt_minion, temp_salt_master): """ with MockSaltMinionMaster(temp_salt_minion, temp_salt_master) as minion_master: with patch.object(minion_master.mock, "_handle_payload_hook") as _mock: - _mock.side_effect = tornado.gen.Return(({}, {"fun": "madeup-fun"})) + _mock.return_value = ({}, {"fun": "madeup-fun"}) ret = minion_master.channel.send({}, timeout=5, tries=1) assert ret == "Server-side exception handling payload" @@ -697,7 +699,7 @@ def test_req_server_chan_encrypt_v1(pki_dir, encryption_algorithm, master_opts): def test_req_chan_decode_data_dict_entry_v1( pki_dir, encryption_algorithm, minion_opts, master_opts ): - mockloop = MagicMock() + mockloop = asyncio.new_event_loop() minion_opts.update( { "master_uri": "tcp://127.0.0.1:4506", @@ -715,26 +717,29 @@ def test_req_chan_decode_data_dict_entry_v1( ) master_opts = dict(master_opts, pki_dir=str(pki_dir.joinpath("master"))) server = salt.channel.server.ReqServerChannel.factory(master_opts) - client = salt.channel.client.ReqChannel.factory(minion_opts, io_loop=mockloop) - dictkey = "pillar" - target = "minion" - pillar_data = {"pillar1": "meh"} - ret = server._encrypt_private( - pillar_data, - dictkey, - 
-        target,
-        sign_messages=False,
-        encryption_algorithm=encryption_algorithm,
-    )
-    key = client.auth.get_keys()
-    aes = key.decrypt(ret["key"], encryption_algorithm)
-    pcrypt = salt.crypt.Crypticle(client.opts, aes)
-    ret_pillar_data = pcrypt.loads(ret[dictkey])
-    assert ret_pillar_data == pillar_data
+    try:
+        client = salt.channel.client.ReqChannel.factory(minion_opts, io_loop=mockloop)
+        dictkey = "pillar"
+        target = "minion"
+        pillar_data = {"pillar1": "meh"}
+        ret = server._encrypt_private(
+            pillar_data,
+            dictkey,
+            target,
+            sign_messages=False,
+            encryption_algorithm=encryption_algorithm,
+        )
+        key = client.auth.get_keys()
+        aes = key.decrypt(ret["key"], encryption_algorithm)
+        pcrypt = salt.crypt.Crypticle(client.opts, aes)
+        ret_pillar_data = pcrypt.loads(ret[dictkey])
+        assert ret_pillar_data == pillar_data
+    finally:
+        mockloop.close()
 
 
 async def test_req_chan_decode_data_dict_entry_v2(minion_opts, master_opts, pki_dir):
-    mockloop = MagicMock()
+    mockloop = asyncio.get_running_loop()
     minion_opts.update(
         {
             "master_uri": "tcp://127.0.0.1:4506",
@@ -776,8 +781,7 @@ async def test_req_chan_decode_data_dict_entry_v2(minion_opts, master_opts, pki_
     transport = client.transport
     client.transport = MagicMock()
 
-    @tornado.gen.coroutine
-    def mocksend(msg, timeout=60, tries=3):
+    async def mocksend(msg, timeout=60, tries=3):
         client.transport.msg = msg
         load = client.auth.session_crypticle.loads(msg["load"])
         ret = server._encrypt_private(
@@ -789,7 +793,7 @@ def mocksend(msg, timeout=60, tries=3):
             encryption_algorithm=minion_opts["encryption_algorithm"],
             signing_algorithm=minion_opts["signing_algorithm"],
         )
-        raise tornado.gen.Return(ret)
+        return ret
 
     client.transport.send = mocksend
 
@@ -822,7 +826,7 @@ def mocksend(msg, timeout=60, tries=3):
 async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(
     pki_dir, minion_opts, master_opts
 ):
-    mockloop = MagicMock()
+    mockloop = asyncio.get_running_loop()
     minion_opts.update(
         {
             "master_uri": "tcp://127.0.0.1:4506",
@@ -868,10 +872,9 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_nonce(
         signing_algorithm=minion_opts["signing_algorithm"],
     )
 
-    @tornado.gen.coroutine
-    def mocksend(msg, timeout=60, tries=3):
+    async def mocksend(msg, timeout=60, tries=3):
         client.transport.msg = msg
-        raise tornado.gen.Return(ret)
+        return ret
 
     client.transport.send = mocksend
 
@@ -904,7 +907,7 @@ def mocksend(msg, timeout=60, tries=3):
 async def test_req_chan_decode_data_dict_entry_v2_bad_signature(
    pki_dir, minion_opts, master_opts
 ):
-    mockloop = MagicMock()
+    mockloop = asyncio.get_running_loop()
     minion_opts.update(
         {
             "master_uri": "tcp://127.0.0.1:4506",
@@ -947,8 +950,7 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_signature(
     transport = client.transport
     client.transport = MagicMock()
 
-    @tornado.gen.coroutine
-    def mocksend(msg, timeout=60, tries=3):
+    async def mocksend(msg, timeout=60, tries=3):
         client.transport.msg = msg
         load = client.auth.session_crypticle.loads(msg["load"])
         ret = server._encrypt_private(
@@ -971,16 +973,15 @@ def mocksend(msg, timeout=60, tries=3):
         data["pillar"] = {"pillar1": "bar"}
         signed_msg["data"] = salt.payload.dumps(data)
         ret[dictkey] = pcrypt.dumps(signed_msg)
-        raise tornado.gen.Return(ret)
+        return ret
 
     client.transport.send = mocksend
 
     # Minion should try to authenticate on bad signature
-    @tornado.gen.coroutine
-    def mockauthenticate():
-        pass
+    async def mockauthenticate():
+        return None
 
-    client.auth.authenticate = MagicMock(wraps=mockauthenticate)
+    client.auth.authenticate = AsyncMock(side_effect=mockauthenticate)
 
     # Note the 'ver' value in 'load' does not represent the 'version' sent
     # in the top level of the transport's message.
@@ -1012,7 +1013,7 @@ def mockauthenticate():
 async def test_req_chan_decode_data_dict_entry_v2_bad_key(
     pki_dir, minion_opts, master_opts
 ):
-    mockloop = MagicMock()
+    mockloop = asyncio.get_running_loop()
     minion_opts.update(
         {
             "master_uri": "tcp://127.0.0.1:4506",
@@ -1055,8 +1056,7 @@ async def test_req_chan_decode_data_dict_entry_v2_bad_key(
     transport = client.transport
     client.transport = MagicMock()
 
-    @tornado.gen.coroutine
-    def mocksend(msg, timeout=60, tries=3):
+    async def mocksend(msg, timeout=60, tries=3):
         client.transport.msg = msg
         load = client.auth.session_crypticle.loads(msg["load"])
         ret = server._encrypt_private(
@@ -1082,7 +1082,7 @@ def mocksend(msg, timeout=60, tries=3):
         ret[dictkey] = pcrypt.dumps(signed_msg)
         key = salt.utils.stringutils.to_bytes(key)
         ret["key"] = pub.encrypt(key, minion_opts["encryption_algorithm"])
-        raise tornado.gen.Return(ret)
+        return ret
 
     client.transport.send = mocksend
 
@@ -1718,7 +1718,7 @@ async def test_client_send_recv_on_cancelled_error(minion_opts, io_loop):
         client.socket = AsyncMock()
         client.socket.poll.side_effect = zmq.eventloop.future.CancelledError
         client._queue.put_nowait((mock_future, {"meh": "bah"}))
-        await client._send_recv(client.socket)
+        await client._send_recv(client.socket, client._queue)
         mock_future.set_exception.assert_not_called()
     finally:
         client.close()
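Taken together, the test_zeromq.py changes follow one mechanical recipe: @tornado.gen.coroutine with raise tornado.gen.Return(x) becomes async def with a plain return x, and mocks that must be awaitable become AsyncMock. A compact illustration of both halves, using generic names rather than Salt's actual code:

```python
import asyncio
from unittest.mock import AsyncMock

# Tornado style (before):
#     @tornado.gen.coroutine
#     def mocksend(msg, timeout=60, tries=3):
#         raise tornado.gen.Return(ret)
# asyncio style (after): native coroutines simply return.


async def mocksend(msg, timeout=60, tries=3):
    return {"echoed": msg}


async def main():
    # AsyncMock(side_effect=coro) awaits the coroutine when the mock is
    # awaited and provides assert_awaited*() helpers that MagicMock lacks.
    send = AsyncMock(side_effect=mocksend)
    result = await send({"load": "payload"})
    send.assert_awaited_once()
    assert result == {"echoed": {"load": "payload"}}


asyncio.run(main())
```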