|
2 | 2 | from __future__ import division
|
3 | 3 |
|
4 | 4 | import asyncio
|
| 5 | +import concurrent |
5 | 6 | import socket
|
6 | 7 | import socks
|
7 | 8 |
|
@@ -110,19 +111,32 @@ def _connect(self, fam, typ):
|
110 | 111 | sock = None
|
111 | 112 | timeout = self.timeout
|
112 | 113 |
|
113 |
| - async def resolve_hostname(host, port, fam=0, typ=0, proto=0, flags=0): |
114 |
| - loop = asyncio.get_event_loop() |
| 114 | + async def async_getaddrinfo(host, port, fam=0, typ=0, proto=0, flags=0): |
| 115 | + loop = asyncio.get_running_loop() |
115 | 116 | try:
|
116 | 117 | result = await loop.getaddrinfo(host, port, family=fam, type=typ, proto=proto, flags=flags)
|
117 | 118 | except asyncio.exceptions.CancelledError:
|
118 | 119 | result = []
|
119 | 120 | return result
|
120 | 121 |
|
| 122 | + def run_async_in_thread(coro): |
| 123 | + loop = asyncio.new_event_loop() |
| 124 | + asyncio.set_event_loop(loop) |
| 125 | + future = loop.run_until_complete(coro) |
| 126 | + loop.close() |
| 127 | + return future |
| 128 | + |
| 129 | + # Using asyncio to avoid process blocking when DNS resolution fail. It's probably better |
| 130 | + # to use async all the ways to `sock.connect`. However, let's keep the changes small |
| 131 | + # until we have the needs. |
| 132 | + def sync_getaddrinfo(*args): |
| 133 | + # Run in a separate thread to avoid deadlocks when users nest eventloops. |
| 134 | + with concurrent.futures.ThreadPoolExecutor() as executor: |
| 135 | + future = executor.submit(run_async_in_thread, async_getaddrinfo(*args)) |
| 136 | + return future.result() |
| 137 | + |
121 | 138 | with self.waitfor('Opening connection to %s on port %s' % (self.rhost, self.rport)) as h:
|
122 |
| - # Using asyncio to avoid blocking when DNS resolution fail. It's probably better |
123 |
| - # to use async all the ways to `sock.connect`. However, let's keep the changes |
124 |
| - # small until we have the needs. |
125 |
| - hostnames = asyncio.run(resolve_hostname(self.rhost, self.rport, fam, typ, 0, socket.AI_PASSIVE)) |
| 139 | + hostnames = sync_getaddrinfo(self.rhost, self.rport, fam, typ, 0, socket.AI_PASSIVE) |
126 | 140 | for res in hostnames:
|
127 | 141 | self.family, self.type, self.proto, _canonname, sockaddr = res
|
128 | 142 |
|
|
0 commit comments