Skip to content

fix: apply proxy to (almost) all requests #336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions monty/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,14 @@ async def main() -> None:
sync_commands_debug=True,
sync_on_cog_actions=True,
)

kwargs = {}
if constants.Client.proxy is not None:
kwargs["proxy"] = constants.Client.proxy

bot = Monty(
redis_session=redis_session,
command_prefix=constants.Client.default_command_prefix,
activity=disnake.Game(name=f"Commands: {constants.Client.default_command_prefix}help"),
allowed_mentions=disnake.AllowedMentions(everyone=False),
intents=_intents,
command_sync_flags=command_sync_flags,
**kwargs,
proxy=constants.Client.proxy,
)

try:
Expand Down
47 changes: 37 additions & 10 deletions monty/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import redis
import redis.asyncio
import sqlalchemy as sa
import yarl
from disnake.ext import commands
from multidict import CIMultiDict, CIMultiDictProxy
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
Expand Down Expand Up @@ -64,6 +65,12 @@ def __init__(self, redis_session: redis.asyncio.Redis, proxy: str = None, **kwar
if TEST_GUILDS:
kwargs["test_guilds"] = TEST_GUILDS
log.warn("registering as test_guilds")

if proxy:
kwargs["proxy"] = proxy # pass proxy to disnake client
if "connector" not in kwargs:
kwargs["connector"] = self.create_connector(proxy=proxy)

super().__init__(**kwargs)

self.redis_session = redis_session
Expand Down Expand Up @@ -180,20 +187,40 @@ async def _request(
sys.version_info, constants.Client.version, constants.Source.github
)

# this is also used by gql
self.http_request_class = self._create_http_request_class(proxy=proxy)

self.http_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(
resolver=aiohttp.AsyncResolver(),
family=socket.AF_INET,
verify_ssl=not bool(proxy and proxy.startswith("http://")),
),
connector=self.create_connector(proxy=proxy),
request_class=self.http_request_class,
trace_configs=trace_configs,
headers=multidict.CIMultiDict({"User-agent": user_agent}),
)
if proxy:
partial_request = functools.partial(_request, self.http_session, proxy=proxy)
else:
partial_request = functools.partial(_request, self.http_session)
self.http_session._request = partial_request
self.http_session._request = functools.partial(_request, self.http_session)

def create_connector(self, proxy: str = None) -> aiohttp.BaseConnector:
"""Create a TCPConnector, changing the ssl setting based on the proxy value."""
return aiohttp.TCPConnector(
resolver=aiohttp.AsyncResolver(),
family=socket.AF_INET,
ssl=not (proxy and proxy.startswith("http://")),
)

def _create_http_request_class(self, proxy: str = None) -> type[aiohttp.ClientRequest]:
"""Create a ClientRequest type, which inserts the proxy into every request's args (if set)."""
if not proxy:
return aiohttp.ClientRequest # default

proxy_url = yarl.URL(proxy)
verify_ssl = not proxy.startswith("http://")

class ProxyClientRequest(aiohttp.ClientRequest):
def __init__(self, *args: Any, **kwargs: Any):
kwargs["proxy"] = proxy_url
kwargs["ssl"] = verify_ssl
super().__init__(*args, **kwargs)

return ProxyClientRequest

async def get_self_invite_perms(self) -> disnake.Permissions:
"""Sets the internal invite_permissions and fetches them."""
Expand Down
9 changes: 8 additions & 1 deletion monty/exts/info/github_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,14 @@ def __init__(self, bot: Monty) -> None:
self.bot = bot

transport = AIOHTTPTransport(
url="https://api.github.com/graphql", timeout=20, headers=GITHUB_REQUEST_HEADERS, ssl=True
url="https://api.github.com/graphql",
timeout=20,
headers=GITHUB_REQUEST_HEADERS,
ssl=True,
client_session_args={
# used for applying proxy settings
"request_class": bot.http_request_class,
},
)

self.gql = gql.Client(transport=transport, fetch_schema_from_transport=True)
Expand Down