diff --git a/koala/db.py b/koala/db.py index a0f7919e..fc238d40 100644 --- a/koala/db.py +++ b/koala/db.py @@ -14,6 +14,13 @@ from functools import wraps from pathlib import Path +if os.name == 'nt': + print("Windows Detected: Database Encryption Disabled") + import sqlite3 +else: + print("Linux Detected: Database Encryption Enabled") + from pysqlcipher3 import dbapi2 as sqlite3 + from sqlalchemy import select, delete, and_, create_engine, func as sql_func from sqlalchemy.orm import sessionmaker @@ -53,7 +60,7 @@ def _get_sql_url(db_path, encrypted: bool, db_key=None): logger.debug("Database Path: "+DATABASE_PATH) engine = create_engine(_get_sql_url(db_path=DATABASE_PATH, encrypted=ENCRYPTED_DB, - db_key=DB_KEY), future=True) + db_key=DB_KEY), module=sqlite3) Session = sessionmaker(future=True) Session.configure(bind=engine) @@ -213,25 +220,15 @@ def get_all_available_guild_extensions(guild_id: int, session: Session): # [extension.extension_id for extension in session.execute(sql_select_all).all()] -def fetch_all_tables(): +def clear_all_tables(): """ - Fetches all table names within the database + Clears all the data from the given tables """ with session_manager() as session: - return [table.name for table in - session.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;").all()] - - -def clear_all_tables(tables): - """ - Clears al the data from the given tables - - :param tables: a list of all tables to be cleared - """ - with session_manager() as session: - for table in tables: - session.execute('DELETE FROM ' + table + ';') - session.commit() + for table in reversed(mapper_registry.metadata.sorted_tables): + print('Clear table %s' % table) + session.execute(table.delete()) + session.commit() setup() diff --git a/requirements.txt b/requirements.txt index e1ba4dfd..bff288df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ pytest-ordering==0.6 python-dotenv==1.0.0 requests==2.28.2 six==1.16.0 -sqlalchemy==1.4.37 +sqlalchemy==2.0.4 toml==0.10.2 twitchAPI==3.9.0 urllib3==1.26.14 diff --git a/tests/cogs/base/test_core.py b/tests/cogs/base/test_core.py index db103d1b..91979766 100644 --- a/tests/cogs/base/test_core.py +++ b/tests/cogs/base/test_core.py @@ -208,7 +208,7 @@ async def test_list_enabled_extensions(bot: commands.Bot): @mock.patch("koalabot.ENABLED_COGS", ["announce"]) @pytest.mark.asyncio async def test_get_extensions(bot: commands.Bot): - koalabot.load_all_cogs(bot) + await koalabot.load_all_cogs(bot) guild: discord.Guild = dpytest.get_config().guilds[0] resp = core.get_all_available_guild_extensions(guild.id) print(resp) diff --git a/tests/cogs/intro_cog/test_db.py b/tests/cogs/intro_cog/test_db.py index 25d9778f..eee9f540 100644 --- a/tests/cogs/intro_cog/test_db.py +++ b/tests/cogs/intro_cog/test_db.py @@ -137,4 +137,5 @@ async def test_on_member_join(): @pytest.fixture(scope='session', autouse=True) def setup_db(): - koala_db.clear_all_tables(koala_db.fetch_all_tables()) + + koala_db.clear_all_tables() diff --git a/tests/cogs/twitch_alert/test_db.py b/tests/cogs/twitch_alert/test_db.py index 6c6522d6..4d56f72a 100644 --- a/tests/cogs/twitch_alert/test_db.py +++ b/tests/cogs/twitch_alert/test_db.py @@ -26,6 +26,7 @@ from koala.cogs.twitch_alert import utils from koala.cogs.twitch_alert.models import TwitchAlerts, TeamInTwitchAlert, UserInTwitchTeam, UserInTwitchAlert from koala.db import session_manager, setup +from koala.models import mapper_registry # Constants DB_PATH = "Koala.db" @@ -66,13 +67,11 @@ def twitch_alert_db_manager_tables(twitch_alert_db_manager): def test_create_tables(): setup() tables = ['TwitchAlerts', 'UserInTwitchAlert', 'TeamInTwitchAlert', 'UserInTwitchTeam'] - sql_check_table_exists = "SELECT name FROM sqlite_master " \ - "WHERE type='table' AND " \ - "name IN ('TwitchAlerts', 'UserInTwitchAlert', 'TeamInTwitchAlert', 'UserInTwitchTeam');" - with session_manager() as session: - tables_found = session.execute(sql_check_table_exists).all() + tables_found = mapper_registry.metadata.tables for table in tables_found: - assert table.name in tables + if table in tables: + tables.remove(table) + assert tables == [] def test_new_ta(twitch_alert_db_manager_tables): diff --git a/tests/test_koalabot.py b/tests/test_koalabot.py index bb92a1cc..7194cfa2 100644 --- a/tests/test_koalabot.py +++ b/tests/test_koalabot.py @@ -20,7 +20,7 @@ # Own modules import koalabot -from koala.db import clear_all_tables, fetch_all_tables +from koala.db import clear_all_tables from tests.tests_utils.utils import FakeAuthor from tests.tests_utils.last_ctx_cog import LastCtxCog @@ -43,7 +43,7 @@ async def test_ctx(bot: commands.Bot): @pytest.fixture(scope='session', autouse=True) def setup_db(): - clear_all_tables(fetch_all_tables()) + clear_all_tables() @pytest_asyncio.fixture(scope='function', autouse=True)