From 37023c49238d3750568198162acb07dc67ec6a5d Mon Sep 17 00:00:00 2001 From: dazziedez Date: Mon, 9 Sep 2024 16:51:24 +0200 Subject: [PATCH 1/4] modified: jishaku/features/sql.py --- jishaku/features/sql.py | 59 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/jishaku/features/sql.py b/jishaku/features/sql.py index 4470101b..79878eb8 100644 --- a/jishaku/features/sql.py +++ b/jishaku/features/sql.py @@ -320,6 +320,65 @@ def format_column_row(self, row: sqlite3.Row) -> str: return f"{row['type']}{not_null}{default_value}{primary_key}" +try: + import sqlalchemy + import sqlalchemy.ext.asyncio +except ImportError: + pass +else: + @adapter(sqlalchemy.ext.asyncio.AsyncEngine, sqlalchemy.ext.asyncio.AsyncConnection) + class SQLAlchemyAsyncAdapter(Adapter[typing.Union[sqlalchemy.ext.asyncio.AsyncEngine, sqlalchemy.ext.asyncio.AsyncConnection]]): + def __init__(self, connection: typing.Union[sqlalchemy.ext.asyncio.AsyncEngine, sqlalchemy.ext.asyncio.AsyncConnection]): + super().__init__(connection) + self.connection: sqlalchemy.ext.asyncio.AsyncConnection = None # type: ignore + + @contextlib.asynccontextmanager + async def use(self): + if isinstance(self.connector, sqlalchemy.ext.asyncio.AsyncEngine): + async with self.connector.connect() as connection: + self.connection = connection + yield + else: + self.connection = self.connector + yield + + def info(self) -> str: + return f"SQLAlchemy Async {sqlalchemy.__version__} {type(self.connector).__name__}" + + async def fetchrow(self, query: str) -> typing.Dict[str, typing.Any]: + result = await self.connection.execute(query) + return dict(result.first()) if result.first() else None + + async def fetch(self, query: str) -> typing.List[typing.Dict[str, typing.Any]]: + result = await self.connection.execute(query) + return [dict(row) for row in result] + + async def execute(self, query: str) -> str: + result = await self.connection.execute(query) + return str(result.rowcount) + " row(s) affected" + + async def table_summary(self, table_query: typing.Optional[str]) -> typing.Dict[str, typing.Dict[str, str]]: + tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict(dict) + + if table_query: + inspector = await sqlalchemy.inspect(self.connection) + table_info = await inspector.get_columns(table_query) + for column in table_info: + tables[table_query][column['name']] = self.format_column_info(column) + else: + inspector = await sqlalchemy.inspect(self.connection) + for table_name in await inspector.get_table_names(): + table_info = await inspector.get_columns(table_name) + for column in table_info: + tables[table_name][column['name']] = self.format_column_info(column) + + return tables + + def format_column_info(self, column: dict) -> str: + column_type = str(column['type']) + nullable = 'NOT NULL' if not column['nullable'] else '' + default = f" DEFAULT {column['default']}" if column['default'] is not None else '' + return f"{column_type} {nullable}{default}" # pylint: enable=missing-class-docstring,missing-function-docstring From e3ec73b0cde495662c80350632fe0fc67913d9d7 Mon Sep 17 00:00:00 2001 From: dazziedez Date: Mon, 9 Sep 2024 21:06:14 +0200 Subject: [PATCH 2/4] make it work --- jishaku/features/sql.py | 71 ++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/jishaku/features/sql.py b/jishaku/features/sql.py index 79878eb8..9e6c79b2 100644 --- a/jishaku/features/sql.py +++ b/jishaku/features/sql.py @@ -321,64 +321,61 @@ def format_column_row(self, row: sqlite3.Row) -> str: return f"{row['type']}{not_null}{default_value}{primary_key}" try: - import sqlalchemy - import sqlalchemy.ext.asyncio + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + from sqlalchemy import text + from sqlalchemy.engine.reflection import Inspector except ImportError: pass else: - @adapter(sqlalchemy.ext.asyncio.AsyncEngine, sqlalchemy.ext.asyncio.AsyncConnection) - class SQLAlchemyAsyncAdapter(Adapter[typing.Union[sqlalchemy.ext.asyncio.AsyncEngine, sqlalchemy.ext.asyncio.AsyncConnection]]): - def __init__(self, connection: typing.Union[sqlalchemy.ext.asyncio.AsyncEngine, sqlalchemy.ext.asyncio.AsyncConnection]): - super().__init__(connection) - self.connection: sqlalchemy.ext.asyncio.AsyncConnection = None # type: ignore + @adapter(async_sessionmaker) + class SQLAlchemyAsyncSessionAdapter(Adapter[async_sessionmaker]): + def __init__(self, session_maker: async_sessionmaker): + super().__init__(session_maker) + self.session: AsyncSession = None # type: ignore @contextlib.asynccontextmanager async def use(self): - if isinstance(self.connector, sqlalchemy.ext.asyncio.AsyncEngine): - async with self.connector.connect() as connection: - self.connection = connection - yield - else: - self.connection = self.connector + async with self.connector() as session: + self.session = session yield def info(self) -> str: - return f"SQLAlchemy Async {sqlalchemy.__version__} {type(self.connector).__name__}" + return f"SQLAlchemy {AsyncSession.__module__.split('.')[1]} AsyncSession" async def fetchrow(self, query: str) -> typing.Dict[str, typing.Any]: - result = await self.connection.execute(query) - return dict(result.first()) if result.first() else None + result = await self.session.execute(text(query)) + row = result.fetchone() + return dict(row._mapping) if row else None async def fetch(self, query: str) -> typing.List[typing.Dict[str, typing.Any]]: - result = await self.connection.execute(query) - return [dict(row) for row in result] + result = await self.session.execute(text(query)) + return [dict(row._mapping) for row in result.fetchall()] async def execute(self, query: str) -> str: - result = await self.connection.execute(query) - return str(result.rowcount) + " row(s) affected" + result = await self.session.execute(text(query)) + await self.session.commit() + return f"{result.rowcount} row(s) affected" - async def table_summary(self, table_query: typing.Optional[str]) -> typing.Dict[str, typing.Dict[str, str]]: - tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict(dict) + async def table_summary( + self, table_query: typing.Optional[str] + ) -> typing.Dict[str, typing.Dict[str, str]]: + tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict( + dict + ) + + inspector = Inspector.from_engine(self.session.get_bind()) if table_query: - inspector = await sqlalchemy.inspect(self.connection) - table_info = await inspector.get_columns(table_query) - for column in table_info: - tables[table_query][column['name']] = self.format_column_info(column) + table_names = [table_query] else: - inspector = await sqlalchemy.inspect(self.connection) - for table_name in await inspector.get_table_names(): - table_info = await inspector.get_columns(table_name) - for column in table_info: - tables[table_name][column['name']] = self.format_column_info(column) + table_names = inspector.get_table_names() - return tables + for table_name in table_names: + columns = inspector.get_columns(table_name) + for column in columns: + tables[table_name][column["name"]] = str(column["type"]) - def format_column_info(self, column: dict) -> str: - column_type = str(column['type']) - nullable = 'NOT NULL' if not column['nullable'] else '' - default = f" DEFAULT {column['default']}" if column['default'] is not None else '' - return f"{column_type} {nullable}{default}" + return tables # pylint: enable=missing-class-docstring,missing-function-docstring From 91d08f5c802d7ca6341ada0036059bc89ee65009 Mon Sep 17 00:00:00 2001 From: freezer <79106393+dazziedez@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:22:22 +0200 Subject: [PATCH 3/4] update for schema fix --- jishaku/features/sql.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jishaku/features/sql.py b/jishaku/features/sql.py index 9e6c79b2..da60ec61 100644 --- a/jishaku/features/sql.py +++ b/jishaku/features/sql.py @@ -319,10 +319,9 @@ def format_column_row(self, row: sqlite3.Row) -> str: primary_key = " PRIMARY KEY" if row['pk'] else "" return f"{row['type']}{not_null}{default_value}{primary_key}" - try: from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - from sqlalchemy import text + from sqlalchemy import text, inspect from sqlalchemy.engine.reflection import Inspector except ImportError: pass @@ -362,18 +361,19 @@ async def table_summary( tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict( dict ) - - inspector = Inspector.from_engine(self.session.get_bind()) + + engine = self.session.get_bind() + inspector = inspect(engine) if table_query: table_names = [table_query] else: - table_names = inspector.get_table_names() + table_names = await self.session.run_sync(inspector.get_table_names) for table_name in table_names: - columns = inspector.get_columns(table_name) + columns = await self.session.run_sync(lambda: inspector.get_columns(table_name)) for column in columns: - tables[table_name][column["name"]] = str(column["type"]) + tables[table_name][column['name']] = str(column['type']) return tables From 243a77f4bc2fe73e5488896d50f9d26e1b1a0849 Mon Sep 17 00:00:00 2001 From: freezer <79106393+dazziedez@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:29:14 +0200 Subject: [PATCH 4/4] rely on raw sql queries i don't even know anymore --- jishaku/features/sql.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/jishaku/features/sql.py b/jishaku/features/sql.py index da60ec61..15b5f6b7 100644 --- a/jishaku/features/sql.py +++ b/jishaku/features/sql.py @@ -355,25 +355,36 @@ async def execute(self, query: str) -> str: await self.session.commit() return f"{result.rowcount} row(s) affected" - async def table_summary( - self, table_query: typing.Optional[str] - ) -> typing.Dict[str, typing.Dict[str, str]]: - tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict( - dict - ) + async def table_summary(self, table_query: typing.Optional[str]) -> typing.Dict[str, typing.Dict[str, str]]: + tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict(dict) - engine = self.session.get_bind() - inspector = inspect(engine) + async def get_table_names(): + result = await self.session.execute(text( + "SELECT tablename FROM pg_catalog.pg_tables " + "WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + )) + return [row[0] for row in result.fetchall()] + + async def get_column_info(table_name): + result = await self.session.execute(text( + "SELECT column_name, data_type, is_nullable " + "FROM information_schema.columns " + "WHERE table_name = :table_name" + ), {"table_name": table_name}) + return result.fetchall() if table_query: table_names = [table_query] else: - table_names = await self.session.run_sync(inspector.get_table_names) + table_names = await get_table_names() for table_name in table_names: - columns = await self.session.run_sync(lambda: inspector.get_columns(table_name)) + columns = await get_column_info(table_name) for column in columns: - tables[table_name][column['name']] = str(column['type']) + column_type = f"{column.data_type.upper()}" + if column.is_nullable == 'NO': + column_type += " NOT NULL" + tables[table_name][column.column_name] = column_type return tables