diff --git a/monty/__main__.py b/monty/__main__.py index 10b514e2..d816e899 100644 --- a/monty/__main__.py +++ b/monty/__main__.py @@ -94,11 +94,6 @@ async def main() -> None: sync_commands_debug=True, sync_on_cog_actions=True, ) - - kwargs = {} - if constants.Client.proxy is not None: - kwargs["proxy"] = constants.Client.proxy - bot = Monty( redis_session=redis_session, command_prefix=constants.Client.default_command_prefix, @@ -106,7 +101,7 @@ async def main() -> None: allowed_mentions=disnake.AllowedMentions(everyone=False), intents=_intents, command_sync_flags=command_sync_flags, - **kwargs, + proxy=constants.Client.proxy, ) try: diff --git a/monty/bot.py b/monty/bot.py index 82045f0f..84e1d101 100644 --- a/monty/bot.py +++ b/monty/bot.py @@ -18,6 +18,7 @@ import redis import redis.asyncio import sqlalchemy as sa +import yarl from disnake.ext import commands from multidict import CIMultiDict, CIMultiDictProxy from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -64,6 +65,12 @@ def __init__(self, redis_session: redis.asyncio.Redis, proxy: str = None, **kwar if TEST_GUILDS: kwargs["test_guilds"] = TEST_GUILDS log.warn("registering as test_guilds") + + if proxy: + kwargs["proxy"] = proxy # pass proxy to disnake client + if "connector" not in kwargs: + kwargs["connector"] = self.create_connector(proxy=proxy) + super().__init__(**kwargs) self.redis_session = redis_session @@ -180,20 +187,40 @@ async def _request( sys.version_info, constants.Client.version, constants.Source.github ) + # this is also used by gql + self.http_request_class = self._create_http_request_class(proxy=proxy) + self.http_session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector( - resolver=aiohttp.AsyncResolver(), - family=socket.AF_INET, - verify_ssl=not bool(proxy and proxy.startswith("http://")), - ), + connector=self.create_connector(proxy=proxy), + request_class=self.http_request_class, trace_configs=trace_configs, headers=multidict.CIMultiDict({"User-agent": user_agent}), ) - if proxy: - partial_request = functools.partial(_request, self.http_session, proxy=proxy) - else: - partial_request = functools.partial(_request, self.http_session) - self.http_session._request = partial_request + self.http_session._request = functools.partial(_request, self.http_session) + + def create_connector(self, proxy: str = None) -> aiohttp.BaseConnector: + """Create a TCPConnector, changing the ssl setting based on the proxy value.""" + return aiohttp.TCPConnector( + resolver=aiohttp.AsyncResolver(), + family=socket.AF_INET, + ssl=not (proxy and proxy.startswith("http://")), + ) + + def _create_http_request_class(self, proxy: str = None) -> type[aiohttp.ClientRequest]: + """Create a ClientRequest type, which inserts the proxy into every request's args (if set).""" + if not proxy: + return aiohttp.ClientRequest # default + + proxy_url = yarl.URL(proxy) + verify_ssl = not proxy.startswith("http://") + + class ProxyClientRequest(aiohttp.ClientRequest): + def __init__(self, *args: Any, **kwargs: Any): + kwargs["proxy"] = proxy_url + kwargs["ssl"] = verify_ssl + super().__init__(*args, **kwargs) + + return ProxyClientRequest async def get_self_invite_perms(self) -> disnake.Permissions: """Sets the internal invite_permissions and fetches them.""" diff --git a/monty/exts/info/github_info.py b/monty/exts/info/github_info.py index c86b590b..5c37ed89 100644 --- a/monty/exts/info/github_info.py +++ b/monty/exts/info/github_info.py @@ -208,7 +208,14 @@ def __init__(self, bot: Monty) -> None: self.bot = bot transport = AIOHTTPTransport( - url="https://api.github.com/graphql", timeout=20, headers=GITHUB_REQUEST_HEADERS, ssl=True + url="https://api.github.com/graphql", + timeout=20, + headers=GITHUB_REQUEST_HEADERS, + ssl=True, + client_session_args={ + # used for applying proxy settings + "request_class": bot.http_request_class, + }, ) self.gql = gql.Client(transport=transport, fetch_schema_from_transport=True)