diff --git a/NOTICE.txt b/NOTICE.txt index f9f07b531..c6ca760fd 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1602,7 +1602,7 @@ Apache License anyio -4.10.0 +4.11.0 UNKNOWN The MIT License (MIT) diff --git a/connectors/__init__.py b/connectors/__init__.py index 859571c45..a076371b9 100644 --- a/connectors/__init__.py +++ b/connectors/__init__.py @@ -6,4 +6,4 @@ import os with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: - __version__ = f.read().strip() + __version__: str = f.read().strip() diff --git a/connectors/access_control.py b/connectors/access_control.py index 85beb67d5..ea13ebe3a 100644 --- a/connectors/access_control.py +++ b/connectors/access_control.py @@ -4,6 +4,8 @@ # you may not use this file except in compliance with the Elastic License 2.0. # +from typing import Dict, List, Optional, Union + ACCESS_CONTROL = "_allow_access_control" DLS_QUERY = """{ "bool": { @@ -27,14 +29,18 @@ }""" -def prefix_identity(prefix, identity): +def prefix_identity( + prefix: Optional[str], identity: Optional[Union[str, int]] +) -> Optional[str]: if prefix is None or identity is None: return None return f"{prefix}:{identity}" -def es_access_control_query(access_control): +def es_access_control_query( + access_control: List[Optional[str]], +) -> Dict[str, Dict[str, Dict[str, Union[Dict[str, List[str]], str]]]]: # filter out 'None' values filtered_access_control = list( filter( diff --git a/connectors/agent/cli.py b/connectors/agent/cli.py index 38ed1e77d..a9a8934fb 100644 --- a/connectors/agent/cli.py +++ b/connectors/agent/cli.py @@ -6,6 +6,7 @@ import asyncio import functools import signal +from logging import Logger from elastic_agent_client.util.async_tools import ( sleeps_for_retryable, @@ -14,10 +15,10 @@ from connectors.agent.component import ConnectorsAgentComponent from connectors.agent.logger import get_logger -logger = get_logger("cli") +logger: Logger = get_logger("cli") -def main(args=None): +def main(args=None) -> None: """Script entry point into running Connectors Service on Agent. It initialises an event loop, creates a component and runs the component. diff --git a/connectors/agent/component.py b/connectors/agent/component.py index 8c65fb0d2..84ea85227 100644 --- a/connectors/agent/component.py +++ b/connectors/agent/component.py @@ -4,6 +4,7 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import sys +from logging import Logger from elastic_agent_client.client import V2Options, VersionInfo from elastic_agent_client.reader import new_v2_from_reader @@ -16,7 +17,7 @@ from connectors.agent.service_manager import ConnectorServiceManager from connectors.services.base import MultiService -logger = get_logger("component") +logger: Logger = get_logger("component") CONNECTOR_SERVICE = "connector-service" @@ -30,7 +31,7 @@ class ConnectorsAgentComponent: and provides applied interface to be able to run it in 2 simple methods: run and stop. """ - def __init__(self): + def __init__(self) -> None: """Inits the class. Init should be safe to call without expectations of side effects (connections to Agent, blocking or anything). @@ -42,7 +43,7 @@ def __init__(self): self.buffer = sys.stdin.buffer self.config_wrapper = ConnectorsAgentConfigurationWrapper() - async def run(self): + async def run(self) -> None: """Start reading from Agent protocol and run Connectors Service with settings reported by agent. 
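# Illustrative sketch (not part of the diff): the Optional-aware contract that the new
# annotations on connectors/access_control.py make explicit. prefix_identity() returns
# None when either part is missing, and es_access_control_query() is expected to drop
# those None entries before building the DLS query. This is a standalone re-implementation
# for illustration; the sample identities below are made up.
from typing import List, Optional, Union

def prefix_identity(prefix: Optional[str], identity: Optional[Union[str, int]]) -> Optional[str]:
    if prefix is None or identity is None:
        return None
    return f"{prefix}:{identity}"

identities: List[Optional[str]] = [
    prefix_identity("user", "jdoe"),   # "user:jdoe"
    prefix_identity("group", None),    # None -> filtered out downstream
    prefix_identity(None, 42),         # None -> filtered out downstream
]
non_null = [i for i in identities if i is not None]
assert non_null == ["user:jdoe"]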
This method can block if it's not running from Agent - it expects the client to be able to read messages @@ -72,7 +73,7 @@ async def run(self): await self.multi_service.run() - def stop(self, sig): + def stop(self, sig) -> None: """Shutdown everything running in the component. Attempts to gracefully shutdown the services that are running under the component. diff --git a/connectors/agent/config.py b/connectors/agent/config.py index 33d4eb7a7..fdcfaf215 100644 --- a/connectors/agent/config.py +++ b/connectors/agent/config.py @@ -4,12 +4,17 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import base64 +from logging import Logger +from typing import Any, Dict, List, Union +from unittest.mock import Mock + +from elastic_agent_client.client import Unit from connectors.agent.logger import get_logger from connectors.config import add_defaults from connectors.utils import nested_get_from_dict -logger = get_logger("config") +logger: Logger = get_logger("config") class ConnectorsAgentConfigurationWrapper: @@ -21,7 +26,7 @@ class ConnectorsAgentConfigurationWrapper: - Indicating that configuration has changed so that the user of the class can trigger the restart """ - def __init__(self): + def __init__(self) -> None: """Inits the class. There's default config that allows us to run connectors service. When final @@ -37,7 +42,9 @@ def __init__(self): self.specific_config = {} - def try_update(self, connector_id, service_type, output_unit): + def try_update( + self, connector_id: str, service_type: str, output_unit: Union[Mock, Unit] + ) -> bool: """Try update the configuration and see if it changed. This method takes the check-in event data (connector_id, service_type and output) coming @@ -103,7 +110,7 @@ def try_update(self, connector_id, service_type, output_unit): logger.debug("No changes detected for connectors-relevant configurations") return False - def config_changed(self, new_config): + def config_changed(self, new_config: Dict[str, Any]) -> bool: """See if configuration passed in new_config will update currently stored configuration This method takes the new configuration received from the agent and see if there are any changes @@ -175,7 +182,7 @@ def _connectors_config_changes(): return False - def get(self): + def get(self) -> Dict[str, Any]: """Get current Connectors Service configuration. This method combines three configs with higher ones taking precedence: @@ -194,5 +201,7 @@ def get(self): return configuration - def get_specific_config(self): + def get_specific_config( + self, + ) -> Dict[str, Union[List[Dict[str, str]], Dict[str, int]]]: return self.specific_config diff --git a/connectors/agent/connector_record_manager.py b/connectors/agent/connector_record_manager.py index fe231dddd..d2f7a40a9 100644 --- a/connectors/agent/connector_record_manager.py +++ b/connectors/agent/connector_record_manager.py @@ -4,11 +4,14 @@ # you may not use this file except in compliance with the Elastic License 2.0. # +from logging import Logger +from typing import Optional, Tuple + from connectors.agent.logger import get_logger from connectors.protocol import ConnectorIndex from connectors.utils import generate_random_id -logger = get_logger("agent_connector_record_manager") +logger: Logger = get_logger("agent_connector_record_manager") class ConnectorRecordManager: @@ -17,10 +20,12 @@ class ConnectorRecordManager: exist in the connector index. It creates the connector record if necessary. 
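# Illustrative sketch (not part of the diff): the layered-configuration idea behind
# ConnectorsAgentConfigurationWrapper.get() and connectors.config.add_defaults(),
# where defaults are merged with more specific settings and the more specific layer
# wins on conflicts. This helper is an assumption-level approximation, not the
# project's _merge_dicts implementation.
from typing import Any, Dict

def merge(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)   # recurse into nested sections
        else:
            merged[key] = value                       # override wins
    return merged

defaults = {"elasticsearch": {"host": "http://localhost:9200", "max_retries": 5}}
agent_reported = {"elasticsearch": {"host": "https://example.invalid:9200"}}
config = merge(defaults, agent_reported)
assert config["elasticsearch"]["max_retries"] == 5          # kept from defaults
assert config["elasticsearch"]["host"].startswith("https")  # overridden by the agent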
""" - def __init__(self): + def __init__(self) -> None: self.connector_index = None - async def ensure_connector_records_exist(self, agent_config, connector_name=None): + async def ensure_connector_records_exist( + self, agent_config, connector_name: Optional[str] = None + ) -> None: """ Ensure that connector records exist for all connectors specified in the agent configuration. @@ -71,7 +76,7 @@ async def ensure_connector_records_exist(self, agent_config, connector_name=None f"Skipping connector creation. Connector record for {connector_id} already exists." ) - def _check_agent_config_ready(self, agent_config): + def _check_agent_config_ready(self, agent_config) -> Tuple[bool, Optional[str]]: """ Validates the agent configuration to check if all info is present to create a connector record. diff --git a/connectors/agent/logger.py b/connectors/agent/logger.py index 1c431ce69..1fdb6b80e 100644 --- a/connectors/agent/logger.py +++ b/connectors/agent/logger.py @@ -4,17 +4,18 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import logging +from typing import TextIO, Union import ecs_logging -root_logger = logging.getLogger("agent_component") -handler = logging.StreamHandler() +root_logger: logging.Logger = logging.getLogger("agent_component") +handler: logging.StreamHandler[TextIO] = logging.StreamHandler() handler.setFormatter(ecs_logging.StdlibFormatter()) root_logger.addHandler(handler) root_logger.setLevel(logging.INFO) -def get_logger(module): +def get_logger(module: str) -> logging.Logger: logger = root_logger.getChild(module) if logger.hasHandlers(): @@ -25,5 +26,5 @@ def get_logger(module): return logger -def update_logger_level(log_level): +def update_logger_level(log_level: Union[int, str]) -> None: root_logger.setLevel(log_level) diff --git a/connectors/agent/protocol.py b/connectors/agent/protocol.py index b136691cc..5ccc6ba21 100644 --- a/connectors/agent/protocol.py +++ b/connectors/agent/protocol.py @@ -4,6 +4,9 @@ # you may not use this file except in compliance with the Elastic License 2.0. # +from logging import Logger + +from elastic_agent_client.client import V2 from elastic_agent_client.generated import elastic_agent_client_pb2 as proto from elastic_agent_client.handler.action import BaseActionHandler from elastic_agent_client.handler.checkin import BaseCheckinHandler @@ -11,7 +14,7 @@ from connectors.agent.connector_record_manager import ConnectorRecordManager from connectors.agent.logger import get_logger -logger = get_logger("protocol") +logger: Logger = get_logger("protocol") CONNECTORS_INPUT_TYPE = "connectors-py" @@ -49,10 +52,10 @@ class ConnectorCheckinHandler(BaseCheckinHandler): def __init__( self, - client, + client: V2, agent_connectors_config_wrapper, service_manager, - ): + ) -> None: """Inits the class. Initing this class should not produce side-effects. 
@@ -62,7 +65,7 @@ def __init__( self.service_manager = service_manager self.connector_record_manager = ConnectorRecordManager() - async def apply_from_client(self): + async def apply_from_client(self) -> None: """Implementation of BaseCheckinHandler.apply_from_client This method is called by the Agent Protocol handlers when there's a check-in event diff --git a/connectors/agent/service_manager.py b/connectors/agent/service_manager.py index 104d967c7..6b177e326 100644 --- a/connectors/agent/service_manager.py +++ b/connectors/agent/service_manager.py @@ -14,7 +14,7 @@ ) from connectors.utils import CancellableSleeps -logger = get_logger("service_manager") +logger: logging.Logger = get_logger("service_manager") class ConnectorServiceManager: @@ -28,7 +28,7 @@ class ConnectorServiceManager: """ - def __init__(self, configuration): + def __init__(self, configuration) -> None: """Inits ConnectorServiceManager with shared ConnectorsAgentConfigurationWrapper. This service is supposed to be ran once, and after it's stopped or finished running it's not @@ -41,7 +41,7 @@ def __init__(self, configuration): self._running = False self._sleeps = CancellableSleeps() - async def run(self): + async def run(self) -> None: """Starts the running loop of the service. Once started, the service attempts to run all needed connector subservices @@ -81,7 +81,7 @@ async def run(self): finally: logger.info("Finished running, exiting") - def stop(self): + def stop(self) -> None: """Stop the service manager and all running subservices. Running stop attempts to gracefully shutdown all subservices currently running. @@ -92,7 +92,7 @@ def stop(self): if self._multi_service: self._multi_service.shutdown(None) - def restart(self): + def restart(self) -> None: """Restart the service manager and all running subservices. Running restart attempts to gracefully shutdown all subservices currently running. 
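# Illustrative sketch (not part of the diff): the run/stop/restart life cycle that
# ConnectorServiceManager exposes — a loop that keeps subservices running until
# stop() is called, where restart() only wakes the current wait so the loop starts
# over. The asyncio.Event below approximates connectors.utils.CancellableSleeps,
# which cancels pending sleep tasks instead.
import asyncio

class TinyServiceManager:
    def __init__(self) -> None:
        self._running = False
        self._wakeup = asyncio.Event()

    async def run(self) -> None:
        self._running = True
        while self._running:
            print("(re)starting subservices")
            self._wakeup.clear()
            await self._wakeup.wait()      # blocks until stop() or restart()

    def restart(self) -> None:
        self._wakeup.set()                 # loop iterates again -> subservices restart

    def stop(self) -> None:
        self._running = False
        self._wakeup.set()                 # loop exits

async def demo() -> None:
    manager = TinyServiceManager()
    task = asyncio.create_task(manager.run())
    await asyncio.sleep(0)                 # let run() start
    manager.restart()
    await asyncio.sleep(0)
    manager.stop()
    await task

asyncio.run(demo())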
diff --git a/connectors/build_info.py b/connectors/build_info.py index c2667b251..a716040d8 100644 --- a/connectors/build_info.py +++ b/connectors/build_info.py @@ -11,7 +11,7 @@ # This references a file that's built in .buildkite/publish/publish-common.sh # See https://github.com/elastic/connectors/pull/3154 for more info -yaml_path = os.path.join(os.path.dirname(__file__), "build.yaml") +yaml_path: str = os.path.join(os.path.dirname(__file__), "build.yaml") if os.path.exists(yaml_path): __build_info__ = "" with open(yaml_path) as f: diff --git a/connectors/cli/auth.py b/connectors/cli/auth.py index c24176419..162a0b090 100644 --- a/connectors/cli/auth.py +++ b/connectors/cli/auth.py @@ -5,6 +5,7 @@ # import asyncio import os +from typing import Optional import yaml from elasticsearch import ApiError @@ -15,7 +16,13 @@ class Auth: - def __init__(self, host, username=None, password=None, api_key=None): + def __init__( + self, + host: str, + username: Optional[str] = None, + password: Optional[str] = None, + api_key: Optional[str] = None, + ) -> None: elastic_config = { "host": host, "username": username, @@ -28,14 +35,14 @@ def __init__(self, host, username=None, password=None, api_key=None): self.cli_client = CLIClient(self.elastic_config) - def authenticate(self): + def authenticate(self) -> bool: if asyncio.run(self.__ping_es_client()): self.__save_config() return True else: return False - def is_config_present(self): + def is_config_present(self) -> bool: return os.path.isfile(CONFIG_FILE_PATH) async def __ping_es_client(self): diff --git a/connectors/cli/connector.py b/connectors/cli/connector.py index 6bd00d327..b634b8abc 100644 --- a/connectors/cli/connector.py +++ b/connectors/cli/connector.py @@ -5,6 +5,7 @@ # import asyncio from collections import OrderedDict +from typing import Dict, Union from connectors.es import DEFAULT_LANGUAGE from connectors.es.cli_client import CLIClient @@ -25,7 +26,7 @@ class IndexAlreadyExists(Exception): class Connector: - def __init__(self, config): + def __init__(self, config) -> None: self.config = config # initialize ES client @@ -62,8 +63,8 @@ def create( configuration, is_native, name, - language=DEFAULT_LANGUAGE, - from_index=False, + language: str = DEFAULT_LANGUAGE, + from_index: bool = False, ): return asyncio.run( self.__create( @@ -186,7 +187,7 @@ async def __create_connector( await self.cli_client.close() await self.connector_index.close() - def default_scheduling(self): + def default_scheduling(self) -> Dict[str, Dict[str, Union[bool, str]]]: return { "access_control": {"enabled": False, "interval": EVERYDAY_AT_MIDNIGHT}, "full": {"enabled": False, "interval": EVERYDAY_AT_MIDNIGHT}, diff --git a/connectors/cli/index.py b/connectors/cli/index.py index 4f0d91af3..87da922f4 100644 --- a/connectors/cli/index.py +++ b/connectors/cli/index.py @@ -16,7 +16,7 @@ class Index: - def __init__(self, config): + def __init__(self, config) -> None: self.elastic_config = config self.cli_client = CLIClient(self.elastic_config) self.connectors_index = ConnectorIndex(self.elastic_config) diff --git a/connectors/cli/job.py b/connectors/cli/job.py index 86e80fb1b..fbd6197d4 100644 --- a/connectors/cli/job.py +++ b/connectors/cli/job.py @@ -21,7 +21,7 @@ class Job: - def __init__(self, config): + def __init__(self, config) -> None: self.config = config self.cli_client = CLIClient(self.config) self.sync_job_index = SyncJobIndex(self.config) diff --git a/connectors/config.py b/connectors/config.py index 17990a120..7e4acd429 100644 --- a/connectors/config.py +++ 
b/connectors/config.py @@ -5,6 +5,7 @@ # import os +from typing import Dict, Optional, Union from envyaml import EnvYAML @@ -28,7 +29,19 @@ def load_config(config_file): return configuration -def add_defaults(config, default_config=None): +def add_defaults( + config, + default_config: Optional[ + Dict[ + str, + Union[ + Dict[str, str], + Dict[str, Union[Dict[str, Union[Dict[str, float], int]], int, str]], + Dict[str, Union[int, str]], + ], + ] + ] = None, +): if default_config is None: default_config = _default_config() configuration = dict(_merge_dicts(default_config, config)) @@ -57,7 +70,16 @@ def add_defaults(config, default_config=None): } -def _default_config(): +def _default_config() -> ( + Dict[ + str, + Union[ + Dict[str, str], + Dict[str, Union[Dict[str, Union[Dict[str, float], int]], int, str]], + Dict[str, Union[int, str]], + ], + ] +): return { "elasticsearch": { "host": "http://localhost:9200", @@ -146,7 +168,7 @@ def _default_config(): } -def _ent_search_config(configuration): +def _ent_search_config(configuration) -> None: if "ENT_SEARCH_CONFIG_PATH" not in os.environ: return logger.info("Found ENT_SEARCH_CONFIG_PATH, loading ent-search config") @@ -169,7 +191,7 @@ def _ent_search_config(configuration): logger.debug(f"Overridden {connector_field}") -def _nest_configs(configuration, field, value): +def _nest_configs(configuration, field, value) -> None: """ Update configuration field value taking into account the nesting. @@ -221,7 +243,7 @@ class DataSourceFrameworkConfig: preventing them from requiring substantial changes to access new configs that may be added. """ - def __init__(self, max_file_size): + def __init__(self, max_file_size) -> None: """ Should not be called directly. Use the Builder. """ diff --git a/connectors/connectors_cli.py b/connectors/connectors_cli.py index ffccc9aba..07d0f8803 100644 --- a/connectors/connectors_cli.py +++ b/connectors/connectors_cli.py @@ -14,6 +14,7 @@ import asyncio import json import os +from typing import BinaryIO, List, TextIO, Union import click import yaml @@ -32,7 +33,7 @@ __all__ = ["main"] -def load_config(ctx, config): +def load_config(ctx, config: Union[BinaryIO, TextIO, bytes, str]): if config: return yaml.safe_load(config) elif os.path.isfile(CONFIG_FILE_PATH): @@ -53,7 +54,7 @@ def load_config(ctx, config): @click.version_option(__version__, "-v", "--version", message="%(version)s") @click.option("-c", "--config", type=click.File("rb")) @click.pass_context -def cli(ctx, config): +def cli(ctx, config: Union[BinaryIO, TextIO, bytes, str]) -> None: # print help page if no subcommands provided if ctx.invoked_subcommand is None: click.echo(ctx.get_help()) @@ -77,7 +78,7 @@ def cli(ctx, config): default="basic", help="Authentication method", ) -def login(host, method): +def login(host, method) -> None: if method == "basic": username = click.prompt("Username") password = click.prompt("Password", hide_input=True) @@ -113,13 +114,13 @@ def login(host, method): # Connector group @click.group(invoke_without_command=False, help="Connectors management") @click.pass_context -def connector(ctx): +def connector(ctx) -> None: pass @click.command(name="list", help="List all existing connectors") @click.pass_obj -def list_connectors(obj): +def list_connectors(obj) -> None: connector = Connector(config=obj["config"]["elasticsearch"]) coro = connector.list_connectors() @@ -159,11 +160,11 @@ def list_connectors(obj): click.echo(e) -language_keys = [DEFAULT_LANGUAGE] +language_keys: List[str] = [DEFAULT_LANGUAGE] # Support blank values for 
languge -def validate_language(ctx, param, value): +def validate_language(ctx, param, value) -> None: if value not in language_keys: return None @@ -171,7 +172,7 @@ def validate_language(ctx, param, value): # override click's default 'choices' prompt with something a bit nicer -def interactive_service_type_prompt(): +def interactive_service_type_prompt() -> str: options = list(_default_config()["sources"].keys()) print(f"{Fore.GREEN}?{Style.RESET_ALL} Service type:") # noqa: T201 result = TerminalMenu( @@ -240,14 +241,14 @@ def create( obj, index_name, service_type, - index_language, + index_language: str, is_native, - from_index, + from_index: bool, from_file, update_config, connector_service_config, name, -): +) -> None: connector_configuration = {} if from_file: with open(from_file) as fd: @@ -386,13 +387,13 @@ def prompt(): # Index group @click.group(invoke_without_command=False, help="Search indices management") @click.pass_obj -def index(obj): +def index(obj) -> None: pass @click.command(name="list", help="Show all indices") @click.pass_obj -def list_indices(obj): +def list_indices(obj) -> None: index = Index(config=obj["config"]["elasticsearch"]) indices = index.list_indices() @@ -420,7 +421,7 @@ def list_indices(obj): @click.command(help="Remove all documents from the index") @click.pass_obj @click.argument("index", nargs=1) -def clean(obj, index): +def clean(obj, index) -> None: index_cli = Index(config=obj["config"]["elasticsearch"]) click.confirm( click.style("Are you sure you want to clean " + index + "?", fg="yellow"), @@ -445,7 +446,7 @@ def clean(obj, index): @click.command(help="Delete an index") @click.pass_obj @click.argument("index", nargs=1) -def delete(obj, index): +def delete(obj, index) -> None: index_cli = Index(config=obj["config"]["elasticsearch"]) click.confirm( click.style("Are you sure you want to delete " + index + "?", fg="yellow"), @@ -472,7 +473,7 @@ def delete(obj, index): # Job group @click.group(invoke_without_command=False, help="Sync jobs management") @click.pass_obj -def job(obj): +def job(obj) -> None: pass @@ -493,7 +494,7 @@ def job(obj): help="Output format", type=click.Choice(["json", "text"]), ) -def start(obj, i, t, output_format): +def start(obj, i, t, output_format) -> None: job_cli = Job(config=obj["config"]["elasticsearch"]) job_id = job_cli.start(connector_id=i, job_type=t) @@ -521,7 +522,7 @@ def start(obj, i, t, output_format): @click.command(name="list", help="List of jobs sorted by date.") @click.pass_obj @click.argument("connector_id", nargs=1) -def list_jobs(obj, connector_id): +def list_jobs(obj, connector_id) -> None: job_cli = Job(config=obj["config"]["elasticsearch"]) jobs = job_cli.list_jobs(connector_id=connector_id) @@ -566,7 +567,7 @@ def list_jobs(obj, connector_id): @click.command(help="Cancel a job") @click.pass_obj @click.argument("job_id") -def cancel(obj, job_id): +def cancel(obj, job_id) -> None: job_cli = Job(config=obj["config"]["elasticsearch"]) click.confirm( click.style("Are you sure you want to cancel jobs?", fg="yellow"), abort=True @@ -599,7 +600,7 @@ def cancel(obj, job_id): help="Output format", type=click.Choice(["json", "text"]), ) -def view_job(obj, job_id, output_format): +def view_job(obj, job_id, output_format) -> None: job_cli = Job(config=obj["config"]["elasticsearch"]) job = job_cli.job(job_id=job_id) result = { @@ -633,7 +634,7 @@ def view_job(obj, job_id, output_format): cli.add_command(job) -def main(args=None): +def main(args=None) -> None: cli() # pyright: ignore diff --git 
a/connectors/content_extraction.py b/connectors/content_extraction.py index b57d838c1..50d6b3c4f 100644 --- a/connectors/content_extraction.py +++ b/connectors/content_extraction.py @@ -5,13 +5,18 @@ # import os +from typing import Dict, Optional import aiofiles import aiohttp +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientConnectionError, ServerTimeoutError +from aiohttp.client_reqrep import ClientResponse from connectors.logger import logger +__EXTRACTION_CONFIG = {} # setup by cli.py on startup + class ContentExtraction: """Content extraction manager @@ -21,18 +26,18 @@ class ContentExtraction: Requires the data extraction service to be running """ - __EXTRACTION_CONFIG = {} # setup by cli.py on startup - @classmethod - def get_extraction_config(cls): + def get_extraction_config(cls) -> Dict[str, Dict[str, str]]: return __EXTRACTION_CONFIG @classmethod - def set_extraction_config(cls, extraction_config): + def set_extraction_config( + cls, extraction_config: Optional[Dict[str, Dict[str, str]]] + ) -> None: global __EXTRACTION_CONFIG __EXTRACTION_CONFIG = extraction_config - def __init__(self): + def __init__(self) -> None: self.session = None self.extraction_config = ContentExtraction.get_extraction_config() @@ -60,13 +65,13 @@ def __init__(self): "Extraction service has been initialised but no extraction service configuration was found. No text will be extracted for this sync." ) - def _check_configured(self): + def _check_configured(self) -> bool: if self.host is not None: return True return False - def _begin_session(self): + def _begin_session(self) -> Optional[ClientSession]: if self.session is not None: return self.session @@ -76,19 +81,19 @@ def _begin_session(self): headers=self.headers, ) - async def _end_session(self): + async def _end_session(self) -> None: if not self.session: return await self.session.close() - def get_volume_dir(self): + def get_volume_dir(self) -> Optional[str]: if self.host is None: return None return self.volume_dir - async def extract_text(self, filepath, original_filename): + async def extract_text(self, filepath: str, original_filename: str) -> str: """Sends a text extraction request to tika-server using the supplied filename. Args: filepath: local path to the tempfile for extraction @@ -131,27 +136,29 @@ async def extract_text(self, filepath, original_filename): return content - async def send_filepointer(self, filepath, filename): + async def send_filepointer(self, filepath: str, filename: str) -> str: async with self._begin_session().put( f"{self.host}/extract_text/?local_file_path={filepath}", ) as response: return await self.parse_extraction_resp(filename, response) - async def send_file(self, filepath, filename): + async def send_file(self, filepath: str, filename: str) -> str: async with self._begin_session().put( f"{self.host}/extract_text/", data=self.file_sender(filepath), ) as response: return await self.parse_extraction_resp(filename, response) - async def file_sender(self, filepath): + async def file_sender(self, filepath: str): async with aiofiles.open(filepath, "rb") as f: chunk = await f.read(self.chunk_size) while chunk: yield chunk chunk = await f.read(self.chunk_size) - async def parse_extraction_resp(self, filename, response): + async def parse_extraction_resp( + self, filename: str, response: ClientResponse + ) -> str: """Parses the response from the tika-server and logs any extraction failures. Returns `extracted_text` from the response. 
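# Illustrative sketch (not part of the diff): the streaming upload pattern behind
# ContentExtraction.send_file() — the file is read in chunks with aiofiles and fed
# to aiohttp as an async generator, so large attachments never have to be held in
# memory at once. The host URL and chunk size below are placeholders, not documented
# defaults of the extraction service.
import aiofiles
import aiohttp

CHUNK_SIZE = 65536
EXTRACTION_HOST = "http://localhost:8090"   # placeholder for a local extraction service

async def file_sender(filepath: str):
    async with aiofiles.open(filepath, "rb") as f:
        chunk = await f.read(CHUNK_SIZE)
        while chunk:
            yield chunk
            chunk = await f.read(CHUNK_SIZE)

async def send_file(filepath: str) -> str:
    async with aiohttp.ClientSession() as session:
        async with session.put(
            f"{EXTRACTION_HOST}/extract_text/", data=file_sender(filepath)
        ) as response:
            payload = await response.json()
            return payload.get("extracted_text", "")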
diff --git a/connectors/es/client.py b/connectors/es/client.py index f68276044..ef4848f6b 100644 --- a/connectors/es/client.py +++ b/connectors/es/client.py @@ -7,6 +7,7 @@ import logging import time from enum import Enum +from typing import Tuple from elastic_transport import ConnectionTimeout from elastic_transport.client_utils import url_to_node_config @@ -45,7 +46,7 @@ class License(Enum): class ESClient: user_agent = f"{USER_AGENT_BASE}/service" - def __init__(self, config): + def __init__(self, config) -> None: self.serverless = config.get("serverless", False) self.config = config self.configured_host = config.get("host", "http://localhost:9200") @@ -114,11 +115,11 @@ def __init__(self, config): self.client = AsyncElasticsearch(**options) self._keep_waiting = True - def stop_waiting(self): + def stop_waiting(self) -> None: self._keep_waiting = False self._sleeps.cancel() - async def has_active_license_enabled(self, license_): + async def has_active_license_enabled(self, license_) -> Tuple[bool, License]: """This method checks, whether an active license or a more powerful active license is enabled. Returns: @@ -151,7 +152,7 @@ async def has_active_license_enabled(self, license_): actual_license, ) - async def close(self): + async def close(self) -> None: await self._retrier.close() await self.client.close() @@ -208,8 +209,8 @@ def __init__( logger_, max_retries, retry_interval, - retry_strategy=RetryStrategy.LINEAR_BACKOFF, - ): + retry_strategy: RetryStrategy = RetryStrategy.LINEAR_BACKOFF, + ) -> None: self._logger = logger_ self._sleeps = CancellableSleeps() self._keep_retrying = True @@ -218,11 +219,11 @@ def __init__( self._retry_interval = retry_interval self._retry_strategy = retry_strategy - async def close(self): + async def close(self) -> None: self._sleeps.cancel() self._keep_retrying = False - async def _sleep(self, retry): + async def _sleep(self, retry) -> None: time_to_sleep = time_to_sleep_between_retries( self._retry_strategy, self._retry_interval, retry ) @@ -260,7 +261,7 @@ async def execute_with_retry(self, func): raise RetryInterruptedError(msg) -def with_concurrency_control(retries=3): +def with_concurrency_control(retries: int = 3): def wrapper(func): @functools.wraps(func) async def wrapped(*args, **kwargs): diff --git a/connectors/es/document.py b/connectors/es/document.py index 3c3ba55d4..7e241a23e 100644 --- a/connectors/es/document.py +++ b/connectors/es/document.py @@ -42,41 +42,41 @@ def get(self, *keys, default=None): return default return value - async def reload(self): + async def reload(self) -> None: doc_source = await self.index.fetch_response_by_id(self.id) self._seq_no = doc_source.get("_seq_no") self._primary_term = doc_source.get("_primary_term") self._source = doc_source.get("_source", {}) - def log_debug(self, msg, *args, **kwargs): + def log_debug(self, msg, *args, **kwargs) -> None: self.logger.debug( msg, *args, **kwargs, ) - def log_info(self, msg, *args, **kwargs): + def log_info(self, msg, *args, **kwargs) -> None: self.logger.info( msg, *args, **kwargs, ) - def log_warning(self, msg, *args, **kwargs): + def log_warning(self, msg, *args, **kwargs) -> None: self.logger.warning( msg, *args, **kwargs, ) - def log_error(self, msg, *args, **kwargs): + def log_error(self, msg, *args, **kwargs) -> None: self.logger.error( msg, *args, **kwargs, ) - def log_exception(self, msg, *args, exc_info=True, **kwargs): + def log_exception(self, msg, *args, exc_info: bool = True, **kwargs) -> None: self.logger.exception( msg, *args, @@ -84,21 +84,21 @@ def 
log_exception(self, msg, *args, exc_info=True, **kwargs): **kwargs, ) - def log_critical(self, msg, *args, **kwargs): + def log_critical(self, msg, *args, **kwargs) -> None: self.logger.critical( msg, *args, **kwargs, ) - def log_fatal(self, msg, *args, **kwargs): + def log_fatal(self, msg, *args, **kwargs) -> None: self.logger.fatal( msg, *args, **kwargs, ) - def _prefix(self): + def _prefix(self) -> None: """Return a string which will be prefixed to the log message when filebeat is not turned on""" return None diff --git a/connectors/es/index.py b/connectors/es/index.py index c5309458e..2e2e7b771 100644 --- a/connectors/es/index.py +++ b/connectors/es/index.py @@ -24,7 +24,7 @@ class TemporaryConnectorApiWrapper(ESClient): this class will be removed. """ - def __init__(self, elastic_config): + def __init__(self, elastic_config) -> None: super().__init__(elastic_config) async def connector_get(self, connector_id, include_deleted): @@ -37,7 +37,7 @@ async def connector_get(self, connector_id, include_deleted): class ESApi(ESClient): - def __init__(self, elastic_config): + def __init__(self, elastic_config) -> None: super().__init__(elastic_config) self._api_wrapper = TemporaryConnectorApiWrapper(elastic_config) @@ -46,7 +46,7 @@ async def connector_check_in(self, connector_id): partial(self.client.connector.check_in, connector_id=connector_id) ) - async def connector_get(self, connector_id, include_deleted=False): + async def connector_get(self, connector_id, include_deleted: bool = False): return await self._retrier.execute_with_retry( partial(self._api_wrapper.connector_get, connector_id, include_deleted) ) @@ -147,7 +147,7 @@ class ESIndex(ESClient): elastic_config (dict): Elasticsearch configuration and credentials """ - def __init__(self, index_name, elastic_config): + def __init__(self, index_name, elastic_config) -> None: # initialize elasticsearch client super().__init__(elastic_config) self.api = ESApi(elastic_config) @@ -230,7 +230,9 @@ async def update_by_script(self, doc_id, script): ) ) - async def get_all_docs(self, query=None, sort=None, page_size=DEFAULT_PAGE_SIZE): + async def get_all_docs( + self, query=None, sort=None, page_size: int = DEFAULT_PAGE_SIZE + ): """ Lookup for elasticsearch documents using {query} diff --git a/connectors/es/management_client.py b/connectors/es/management_client.py index 0756d5c07..80728b320 100644 --- a/connectors/es/management_client.py +++ b/connectors/es/management_client.py @@ -4,8 +4,12 @@ # you may not use this file except in compliance with the Elastic License 2.0. # +from _asyncio import Task from functools import partial +from typing import Any, Dict, Generator, List, Optional, Union +from unittest.mock import AsyncMock +from elastic_transport import ApiResponse, HeadApiResponse, ObjectApiResponse from elasticsearch import ApiError from elasticsearch import ( NotFoundError as ElasticNotFoundError, @@ -31,12 +35,24 @@ class ESManagementClient(ESClient): This client, on the contrary, is used to manage a number of indices outside of connector protocol operations. 
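# Illustrative sketch (not part of the diff): what RetryStrategy.LINEAR_BACKOFF — the
# default the annotated ESClient retrier keeps — implies, next to constant and
# exponential strategies. This mirrors the idea behind
# connectors.utils.time_to_sleep_between_retries rather than reproducing its exact code.
from enum import Enum

class RetryStrategy(Enum):
    CONSTANT = 0
    LINEAR_BACKOFF = 1
    EXPONENTIAL_BACKOFF = 2

def time_to_sleep(strategy: RetryStrategy, interval: float, retry: int) -> float:
    if strategy is RetryStrategy.CONSTANT:
        return interval
    if strategy is RetryStrategy.LINEAR_BACKOFF:
        return interval * retry               # 1x, 2x, 3x, ...
    return interval * (2 ** retry)            # exponential growth

print([time_to_sleep(RetryStrategy.LINEAR_BACKOFF, 1.0, r) for r in range(1, 4)])  # [1.0, 2.0, 3.0]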
""" - def __init__(self, config): + def __init__( + self, + config: Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + float, + ], + ], + ) -> None: logger.debug(f"ESManagementClient connecting to {config['host']}") # initialize ESIndex instance super().__init__(config) - async def ensure_exists(self, indices=None): + async def ensure_exists(self, indices: Optional[List[str]] = None) -> None: if indices is None: indices = [] @@ -50,7 +66,9 @@ async def ensure_exists(self, indices=None): ) logger.debug(f"Created index {index}") - async def create_content_index(self, search_index_name, language_code): + async def create_content_index( + self, search_index_name: str, language_code: Optional[str] + ) -> Union[ObjectApiResponse, ApiResponse, AsyncMock]: return await self._retrier.execute_with_retry( partial( self.client.indices.create, @@ -59,8 +77,12 @@ async def create_content_index(self, search_index_name, language_code): ) async def ensure_ingest_pipeline_exists( - self, pipeline_id, version, description, processors - ): + self, + pipeline_id: Union[str, int], + version: int, + description: str, + processors: List[Union[str, Dict[str, Dict[str, Union[str, List[str], bool]]]]], + ) -> None: try: await self._retrier.execute_with_retry( partial(self.client.ingest.get_pipeline, id=pipeline_id) @@ -76,12 +98,12 @@ async def ensure_ingest_pipeline_exists( ) ) - async def delete_indices(self, indices): + async def delete_indices(self, indices: List[str]) -> None: await self._retrier.execute_with_retry( partial(self.client.indices.delete, index=indices, ignore_unavailable=True) ) - async def clean_index(self, index_name): + async def clean_index(self, index_name: str) -> AsyncMock: return await self._retrier.execute_with_retry( partial( self.client.delete_by_query, @@ -91,7 +113,7 @@ async def clean_index(self, index_name): ) ) - async def list_indices(self, index="*"): + async def list_indices(self, index: str = "*") -> Dict[Any, Any]: """ List indices using Elasticsearch.stats API. Includes the number of documents in each index. """ @@ -105,7 +127,7 @@ async def list_indices(self, index="*"): return indices - async def list_indices_serverless(self, index="*"): + async def list_indices_serverless(self, index: str = "*"): """ List indices in a serverless environment. This method is a workaround to the fact that the `indices.stats` API is not available in serverless environments. @@ -125,12 +147,16 @@ async def list_indices_serverless(self, index="*"): return indices - async def index_exists(self, index_name): + async def index_exists(self, index_name: str) -> Union[HeadApiResponse, AsyncMock]: return await self._retrier.execute_with_retry( partial(self.client.indices.exists, index=index_name) ) - async def get_index_or_alias(self, index_name, ignore_unavailable=False): + async def get_index_or_alias( + self, index_name: str, ignore_unavailable: bool = False + ) -> Optional[ + Union[Dict[str, Dict[Any, Any]], Dict[str, Dict[str, Dict[Any, Any]]]] + ]: """ Get index definition (mappings and settings) by its name or its alias. 
""" @@ -169,7 +195,7 @@ async def get_index_or_alias(self, index_name, ignore_unavailable=False): return None - async def upsert(self, _id, index_name, doc): + async def upsert(self, _id: str, index_name: str, doc: Dict[str, str]) -> AsyncMock: return await self._retrier.execute_with_retry( partial( self.client.index, @@ -179,7 +205,22 @@ async def upsert(self, _id, index_name, doc): ) ) - async def bulk_insert(self, operations, pipeline): + async def bulk_insert( + self, + operations: List[Union[Dict[str, Dict[str, str]], Dict[str, str], Any]], + pipeline: str, + ) -> Generator[ + Task, + None, + Union[ + ObjectApiResponse, + Dict[str, List[Dict[str, Dict[str, str]]]], + Dict[ + str, Union[bool, List[Dict[str, Dict[str, Union[Dict[str, str], str]]]]] + ], + Dict[str, List[Any]], + ], + ]: return await self._retrier.execute_with_retry( partial( self.client.bulk, @@ -188,7 +229,7 @@ async def bulk_insert(self, operations, pipeline): ) ) - async def yield_existing_documents_metadata(self, index): + async def yield_existing_documents_metadata(self, index: str) -> None: """Returns an iterator on the `id` and `_timestamp` fields of all documents in an index. WARNING @@ -211,7 +252,7 @@ async def yield_existing_documents_metadata(self, index): yield doc_id, timestamp - async def get_connector_secret(self, connector_secret_id): + async def get_connector_secret(self, connector_secret_id: str) -> str: secret = await self._retrier.execute_with_retry( partial( self.client.perform_request, @@ -221,7 +262,7 @@ async def get_connector_secret(self, connector_secret_id): ) return secret.get("value") - async def create_connector_secret(self, secret_value): + async def create_connector_secret(self, secret_value: str) -> str: secret = await self._retrier.execute_with_retry( partial( self.client.perform_request, diff --git a/connectors/es/sink.py b/connectors/es/sink.py index dae0a74dc..f6d1dd1fa 100644 --- a/connectors/es/sink.py +++ b/connectors/es/sink.py @@ -23,20 +23,24 @@ import functools import logging import time +from typing import Any, Dict, Optional, Tuple from connectors.config import ( DEFAULT_ELASTICSEARCH_MAX_RETRIES, DEFAULT_ELASTICSEARCH_RETRY_INTERVAL, ) from connectors.es import TIMESTAMP_FIELD +from connectors.es.client import License from connectors.es.management_client import ESManagementClient +from connectors.exceptions import DocumentIngestionError from connectors.filtering.basic_rule import BasicRuleEngine, parse from connectors.logger import logger, tracer -from connectors.protocol import Filter, JobType +from connectors.protocol import JobType from connectors.protocol.connectors import ( DELETED_DOCUMENT_COUNT, INDEXED_DOCUMENT_COUNT, INDEXED_DOCUMENT_VOLUME, + Filter, ) from connectors.utils import ( DEFAULT_CHUNK_MEM_SIZE, @@ -107,16 +111,12 @@ class ContentIndexDoesNotExistError(Exception): class ElasticsearchOverloadedError(Exception): - def __init__(self, cause=None): + def __init__(self, cause=None) -> None: msg = "Connector was unable to ingest data into overloaded Elasticsearch. Make sure Elasticsearch instance is healthy, has enough resources and content index is healthy." super().__init__(msg) self.__cause__ = cause -class DocumentIngestionError(Exception): - pass - - class Sink: """Send bulk operations in batches by consuming a queue. 
@@ -145,8 +145,8 @@ def __init__( retry_interval, error_monitor, logger_=None, - enable_bulk_operations_logging=False, - ): + enable_bulk_operations_logging: bool = False, + ) -> None: self.client = client self.queue = queue self.chunk_size = chunk_size @@ -162,7 +162,7 @@ def __init__( self._enable_bulk_operations_logging = enable_bulk_operations_logging self.counters = Counters() - def _bulk_op(self, doc, operation=OP_INDEX): + def _bulk_op(self, doc, operation: str = OP_INDEX): doc_id = doc["_id"] index = doc["_index"] @@ -231,7 +231,9 @@ def _map_id_to_op(self, operations): result[doc["_id"]] = op return result - async def _process_bulk_response(self, res, ids_to_ops, do_log=False): + async def _process_bulk_response( + self, res, ids_to_ops, do_log: bool = False + ) -> None: for item in res.get("items", []): if OP_INDEX in item: action_item = OP_INDEX @@ -302,7 +304,7 @@ async def _process_bulk_response(self, res, ids_to_ops, do_log=False): ) self.counters.increment(RESULT_SUCCESS) - def _populate_stats(self, stats, res): + def _populate_stats(self, stats, res) -> None: for item in res["items"]: for op, data in item.items(): # "result" is only present in successful operations @@ -327,7 +329,7 @@ def _populate_stats(self, stats, res): f"Sink stats - no. of docs indexed: {self.counters.get(INDEXED_DOCUMENT_COUNT)}, volume of docs indexed: {round(self.counters.get(INDEXED_DOCUMENT_VOLUME))} bytes, no. of docs deleted: {self.counters.get(DELETED_DOCUMENT_COUNT)}" ) - def force_cancel(self): + def force_cancel(self) -> None: self._canceled = True async def fetch_doc(self): @@ -336,7 +338,7 @@ async def fetch_doc(self): return await self.queue.get() - async def run(self): + async def run(self) -> None: try: await self._run() except asyncio.CancelledError: @@ -352,7 +354,7 @@ async def run(self): return raise - async def _run(self): + async def _run(self) -> None: """Creates batches of bulk calls given a queue of items. An item is a (size, object) tuple. 
Exits when the @@ -441,14 +443,14 @@ def __init__( client, queue, index, - filter_=None, - sync_rules_enabled=False, - content_extraction_enabled=True, - display_every=DEFAULT_DISPLAY_EVERY, - concurrent_downloads=DEFAULT_CONCURRENT_DOWNLOADS, + filter_: Optional[Filter] = None, + sync_rules_enabled: bool = False, + content_extraction_enabled: bool = True, + display_every: int = DEFAULT_DISPLAY_EVERY, + concurrent_downloads: int = DEFAULT_CONCURRENT_DOWNLOADS, logger_=None, - skip_unchanged_documents=False, - ): + skip_unchanged_documents: bool = False, + ) -> None: if filter_ is None: filter_ = Filter() self.client = client @@ -468,7 +470,7 @@ def __init__( self._canceled = False self.skip_unchanged_documents = skip_unchanged_documents - async def _deferred_index(self, lazy_download, doc_id, doc, operation): + async def _deferred_index(self, lazy_download, doc_id, doc, operation) -> None: try: data = await lazy_download(doit=True, timestamp=doc[TIMESTAMP_FIELD]) @@ -495,16 +497,16 @@ async def _deferred_index(self, lazy_download, doc_id, doc, operation): f"Failed to do deferred index operation for doc {doc_id}: {ex}" ) - def force_cancel(self): + def force_cancel(self) -> None: self._canceled = True - async def put_doc(self, doc): + async def put_doc(self, doc) -> None: if self._canceled: raise ForceCanceledError await self.queue.put(doc) - async def run(self, generator, job_type): + async def run(self, generator, job_type) -> None: sanitized_generator = ( (sanitize(doc), *other) async for doc, *other in generator ) @@ -550,7 +552,7 @@ async def _decorate_with_metrics_span(self, generator): async for doc in generator: yield doc - async def get_docs(self, generator, skip_unchanged_documents=False): + async def get_docs(self, generator, skip_unchanged_documents: bool = False) -> None: """Iterate on a generator of documents to fill a queue of bulk operations for the `Sink` to consume. Extraction happens in a separate task, when a document contains files. @@ -671,7 +673,7 @@ async def _load_existing_docs(self): return existing_ids - async def get_docs_incrementally(self, generator): + async def get_docs_incrementally(self, generator) -> None: """Iterate on a generator of documents to fill a queue with bulk operations for the `Sink` to consume. A document might be discarded if its timestamp has not changed. @@ -738,7 +740,7 @@ async def get_docs_incrementally(self, generator): await self.put_doc(END_DOCS) - async def get_access_control_docs(self, generator): + async def get_access_control_docs(self, generator) -> None: """Iterate on a generator of access control documents to fill a queue with bulk operations for the `Sink` to consume. A document might be discarded if its timestamp has not changed. 
@@ -798,7 +800,7 @@ async def get_access_control_docs(self, generator): await self.enqueue_docs_to_delete(existing_ids) await self.put_doc(END_DOCS) - async def enqueue_docs_to_delete(self, existing_ids): + async def enqueue_docs_to_delete(self, existing_ids) -> None: self._logger.debug(f"Delete {len(existing_ids)} docs from index '{self.index}'") for doc_id in existing_ids.keys(): await self.put_doc( @@ -812,7 +814,7 @@ async def enqueue_docs_to_delete(self, existing_ids): def _log_progress( self, - ): + ) -> None: self._logger.info( "Sync progress -- " f"created: {self.counters.get(CREATES_QUEUED)} | " @@ -840,7 +842,7 @@ class SyncOrchestrator: - once they are both over, returns totals """ - def __init__(self, elastic_config, logger_=None): + def __init__(self, elastic_config, logger_=None) -> None: self._logger = logger_ or logger self._logger.debug(f"SyncOrchestrator connecting to {elastic_config['host']}") self.es_management_client = ESManagementClient(elastic_config) @@ -854,18 +856,18 @@ def __init__(self, elastic_config, logger_=None): error_monitor_config = elastic_config.get("bulk", {}).get("error_monitor", {}) self.error_monitor = ErrorMonitor(error_monitor_config) - async def close(self): + async def close(self) -> None: await self.es_management_client.close() await self.cancel() - async def has_active_license_enabled(self, license_): + async def has_active_license_enabled(self, license_) -> Tuple[bool, License]: # TODO: think how to make it not a proxy method to the client return await self.es_management_client.has_active_license_enabled(license_) - def extract_index_or_alias(self, get_index_response, expected_index_name): + def extract_index_or_alias(self, get_index_response, expected_index_name) -> None: return None - async def prepare_content_index(self, index_name, language_code=None): + async def prepare_content_index(self, index_name, language_code=None) -> None: """Creates the index, given a mapping/settings if it does not exist.""" self._logger.debug(f"Checking index {index_name}") @@ -884,7 +886,7 @@ async def prepare_content_index(self, index_name, language_code=None): ) self._logger.info(f"Content index successfully created: {index_name}") - def done(self): + def done(self) -> bool: """ An async task (which this mimics) should be "done" if: - it was canceled @@ -903,13 +905,13 @@ def done(self): sink_done = True if self._sink_task is None or self._sink_task.done() else False return extractor_done and sink_done - def _sink_task_running(self): + def _sink_task_running(self) -> bool: return self._sink_task is not None and not self._sink_task.done() - def _extractor_task_running(self): + def _extractor_task_running(self) -> bool: return self._extractor_task is not None and not self._extractor_task.done() - async def cancel(self): + async def cancel(self) -> None: if self._sink_task_running(): self._logger.info( f"Canceling the Sink task: {self._sink_task.get_name()}" # pyright: ignore @@ -955,7 +957,7 @@ async def cancel(self): self._sink.force_cancel() self._extractor.force_cancel() - def ingestion_stats(self): + def ingestion_stats(self) -> Dict[str, Any]: stats = {} if self._extractor is not None: stats.update(self._extractor.counters.to_dict()) @@ -979,13 +981,13 @@ async def async_bulk( generator, pipeline, job_type, - filter_=None, - sync_rules_enabled=False, - content_extraction_enabled=True, + filter_: Optional[Filter] = None, + sync_rules_enabled: bool = False, + content_extraction_enabled: bool = True, options=None, - skip_unchanged_documents=False, - 
enable_bulk_operations_logging=False, - ): + skip_unchanged_documents: bool = False, + enable_bulk_operations_logging: bool = False, + ) -> None: """Performs a batch of `_bulk` calls, given a generator of documents Arguments: @@ -1068,7 +1070,7 @@ async def async_bulk( ) self._sink_task.add_done_callback(functools.partial(self.sink_task_callback)) - def sink_task_callback(self, task): + def sink_task_callback(self, task) -> None: if task.cancelled(): self._logger.warning( f"{type(self._sink).__name__}: {task.get_name()} was cancelled before completion" @@ -1080,7 +1082,7 @@ def sink_task_callback(self, task): ) self.error = task.exception() - def extractor_task_callback(self, task): + def extractor_task_callback(self, task) -> None: if task.cancelled(): self._logger.warning( f"{type(self._extractor).__name__}: {task.get_name()} was cancelled before completion" diff --git a/connectors/exceptions.py b/connectors/exceptions.py new file mode 100644 index 000000000..f67bfd02d --- /dev/null +++ b/connectors/exceptions.py @@ -0,0 +1,9 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# + + +class DocumentIngestionError(Exception): + pass diff --git a/connectors/filtering/basic_rule.py b/connectors/filtering/basic_rule.py index 301339a5f..836ea69f7 100644 --- a/connectors/filtering/basic_rule.py +++ b/connectors/filtering/basic_rule.py @@ -7,17 +7,20 @@ import datetime import re from enum import Enum +from typing import Any, Dict, List, Optional, Sized, Union from dateutil.parser import ParserError, parser from connectors.logger import logger from connectors.utils import Format, shorten_str -IS_BOOL_FALSE = re.compile("^(false|f|no|n|off)$", re.I) -IS_BOOL_TRUE = re.compile("^(true|t|yes|y|on)$", re.I) +IS_BOOL_FALSE: re.Pattern[str] = re.compile("^(false|f|no|n|off)$", re.I) +IS_BOOL_TRUE: re.Pattern[str] = re.compile("^(true|t|yes|y|on)$", re.I) -def parse(basic_rules_json): +def parse( + basic_rules_json: Optional[List[Dict[str, Union[str, int]]]], +) -> List[Union["BasicRule", Any]]: """Parse a basic rules json array to BasicRule objects. Arguments: @@ -44,14 +47,14 @@ def parse(basic_rules_json): ) -def to_float(value): +def to_float(value: float) -> float: try: return float(value) except ValueError: return value -def to_datetime(value): +def to_datetime(value: str) -> datetime.datetime: try: date_parser = parser() parsed_date_or_datetime = date_parser.parse(timestr=value) @@ -68,7 +71,7 @@ def to_datetime(value): return value -def to_bool(value): +def to_bool(value: str) -> Union[bool, Sized]: if len(value) == 0 or IS_BOOL_FALSE.match(value): return False @@ -84,11 +87,11 @@ class RuleMatchStats: It's an internal class and is not expected to be used outside the module. 
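# Illustrative sketch (not part of the diff): the coercion helpers typed in
# connectors/filtering/basic_rule.py. Rule values arrive as strings and are coerced
# towards the type found in the document, falling back to the original value when
# parsing fails. Standalone approximations of to_float/to_bool for illustration; the
# real module also coerces datetimes via dateutil.
import re

IS_BOOL_FALSE = re.compile("^(false|f|no|n|off)$", re.I)
IS_BOOL_TRUE = re.compile("^(true|t|yes|y|on)$", re.I)

def to_float(value):
    try:
        return float(value)
    except ValueError:
        return value            # leave non-numeric strings untouched

def to_bool(value: str):
    if len(value) == 0 or IS_BOOL_FALSE.match(value):
        return False
    if IS_BOOL_TRUE.match(value):
        return True
    return value                # not recognisably boolean -> unchanged

assert to_float("42") == 42.0 and to_float("n/a") == "n/a"
assert to_bool("no") is False and to_bool("on") is True and to_bool("maybe") == "maybe"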
""" - def __init__(self, policy, matches_count): + def __init__(self, policy: "Policy", matches_count: int) -> None: self.policy = policy self.matches_count = matches_count - def __add__(self, other): + def __add__(self, other: Optional[int]) -> "RuleMatchStats": if other is None: return self @@ -100,7 +103,7 @@ def __add__(self, other): msg = f"__add__ is not implemented for '{type(other)}'" raise NotImplementedError(msg) - def __eq__(self, other): + def __eq__(self, other: "RuleMatchStats") -> bool: return self.policy == other.policy and self.matches_count == other.matches_count @@ -114,13 +117,15 @@ class BasicRuleEngine: It also records stats, which basic rule matched how many documents with a certain policy. """ - def __init__(self, rules): + def __init__(self, rules: Optional[Union[List[None], List["BasicRule"]]]) -> None: self.rules = rules self.rules_match_stats = { BasicRule.DEFAULT_RULE_ID: RuleMatchStats(Policy.INCLUDE, 0) } - def should_ingest(self, document): + def should_ingest( + self, document: Dict[str, Union[str, float, int, datetime.datetime]] + ) -> bool: """Check, whether a document should be ingested or not. By default, the document will be ingested, if it doesn't match any rule. @@ -172,7 +177,7 @@ class Rule(Enum): RULES = [EQUALS, STARTS_WITH, ENDS_WITH, CONTAINS, REGEX, GREATER_THAN, LESS_THAN] @classmethod - def is_string_rule(cls, string): + def is_string_rule(cls, string: str) -> bool: try: cls.from_string(string) return True @@ -180,7 +185,7 @@ def is_string_rule(cls, string): return False @classmethod - def from_string(cls, string): + def from_string(cls, string: str) -> "Rule": match string.casefold(): case "equals": return Rule.EQUALS @@ -212,7 +217,7 @@ class Policy(Enum): POLICIES = [INCLUDE, EXCLUDE] @classmethod - def is_string_policy(cls, string): + def is_string_policy(cls, string: str) -> bool: try: cls.from_string(string) return True @@ -220,7 +225,7 @@ def is_string_policy(cls, string): return False @classmethod - def from_string(cls, string): + def from_string(cls, string: str) -> "Policy": match string.casefold(): case "include": return Policy.INCLUDE @@ -237,7 +242,15 @@ class BasicRule: DEFAULT_RULE_ID = "DEFAULT" SHORTEN_UUID_BY = 26 # UUID: 32 random chars + 4 hyphens; keep 10 characters - def __init__(self, id_, order, policy, field, rule, value): + def __init__( + self, + id_: Union[str, int], + order: int, + policy: Policy, + field: str, + rule: Rule, + value: Union[str, int, float], + ) -> None: self.id_ = id_ self.order = order self.policy = policy @@ -246,7 +259,7 @@ def __init__(self, id_, order, policy, field, rule, value): self.value = value @classmethod - def from_json(cls, basic_rule_json): + def from_json(cls, basic_rule_json: Dict[str, Union[str, int]]) -> "BasicRule": return cls( id_=basic_rule_json["id"], order=basic_rule_json["order"], @@ -256,7 +269,9 @@ def from_json(cls, basic_rule_json): value=basic_rule_json["value"], ) - def matches(self, document): + def matches( + self, document: Dict[str, Union[str, float, int, datetime.datetime]] + ) -> bool: """Check whether a document matches the basic rule. 
A basic rule matches or doesn't match a document based on the following comparisons: @@ -301,13 +316,15 @@ def matches(self, document): case Rule.EQUALS: return document_value == coerced_rule_value - def is_default_rule(self): + def is_default_rule(self) -> bool: return self.id_ == BasicRule.DEFAULT_RULE_ID - def is_include(self): + def is_include(self) -> bool: return self.policy == Policy.INCLUDE - def coerce_rule_value_based_on_document_value(self, doc_value): + def coerce_rule_value_based_on_document_value( + self, doc_value: Any + ) -> Union[float, str, datetime.datetime]: """Coerce the value inside the basic rule. This method tries to coerce the value inside the basic rule to the type used in the document. @@ -334,7 +351,7 @@ def coerce_rule_value_based_on_document_value(self, doc_value): ) return str(self.value) - def __str__(self): + def __str__(self) -> str: def _format_field(key, value): if isinstance(value, Enum): return f"{key}: {value.value}" @@ -345,7 +362,7 @@ def _format_field(key, value): ] return "Basic rule: " + ", ".join(formatted_fields) - def __format__(self, format_spec): + def __format__(self, format_spec: str) -> str: if format_spec == Format.SHORT.value: # order uses 0 based indexing return f"Basic rule {self.order + 1} (id: '{shorten_str(self.id_, BasicRule.SHORTEN_UUID_BY)}')" diff --git a/connectors/filtering/validation.py b/connectors/filtering/validation.py index d07a420bf..5b63e10dc 100644 --- a/connectors/filtering/validation.py +++ b/connectors/filtering/validation.py @@ -5,6 +5,7 @@ # from copy import deepcopy from enum import Enum +from typing import Any, Dict, List, Optional import fastjsonschema @@ -22,13 +23,13 @@ class SyncRuleValidationResult: ADVANCED_RULES = "advanced_snippet" - def __init__(self, rule_id, is_valid, validation_message): + def __init__(self, rule_id, is_valid, validation_message) -> None: self.rule_id = rule_id self.is_valid = is_valid self.validation_message = validation_message @classmethod - def valid_result(cls, rule_id): + def valid_result(cls, rule_id) -> "SyncRuleValidationResult": return SyncRuleValidationResult( rule_id=rule_id, is_valid=True, validation_message="Valid rule" ) @@ -52,7 +53,7 @@ class FilterValidationError: not in the context of each other) -> both rules belong to one error. """ - def __init__(self, ids=None, messages=None): + def __init__(self, ids=None, messages=None) -> None: if ids is None: ids = [] if messages is None: @@ -67,7 +68,7 @@ def __eq__(self, other): return self.ids == other.ids and self.messages == other.messages - def __str__(self): + def __str__(self) -> str: return f"(ids: {self.ids}, messages: {self.messages})" @@ -77,7 +78,7 @@ class FilteringValidationState(Enum): EDITED = "edited" @classmethod - def to_s(cls, value): + def to_s(cls, value) -> Optional[str]: match value: case FilteringValidationState.VALID: return "valid" @@ -94,14 +95,18 @@ class FilteringValidationResult: These errors will be derived from a single SyncRuleValidationResult which can be added to a FilteringValidationResult. 
""" - def __init__(self, state=FilteringValidationState.VALID, errors=None): + def __init__( + self, + state: FilteringValidationState = FilteringValidationState.VALID, + errors=None, + ) -> None: if errors is None: errors = [] self.state = state self.errors = errors - def __add__(self, other): + def __add__(self, other) -> "FilteringValidationResult": if other is None: return self @@ -128,7 +133,7 @@ def __eq__(self, other): return self.state == other.state and self.errors == other.errors - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "state": FilteringValidationState.to_s(self.state), "errors": [vars(error) for error in self.errors], @@ -144,7 +149,7 @@ class FilteringValidator: def __init__( self, basic_rules_validators=None, advanced_rules_validators=None, logger_=None - ): + ) -> None: self.basic_rules_validators = ( [] if basic_rules_validators is None else basic_rules_validators ) @@ -153,7 +158,7 @@ def __init__( ) self._logger = logger_ or logger - async def validate(self, filtering): + async def validate(self, filtering) -> FilteringValidationResult: def _is_valid_str(result): if result is None: return "Unknown (check validator implementation as it should never return 'None')" @@ -228,7 +233,7 @@ class BasicRulesSetSemanticValidator(BasicRulesSetValidator): """ @classmethod - def validate(cls, rules): + def validate(cls, rules) -> List[SyncRuleValidationResult]: rules_dict = {} for rule in rules: @@ -254,7 +259,9 @@ def validate(cls, rules): ] @classmethod - def semantic_duplicates_validation_results(cls, basic_rule, semantic_duplicate): + def semantic_duplicates_validation_results( + cls, basic_rule, semantic_duplicate + ) -> List[SyncRuleValidationResult]: def semantic_duplicate_msg(rule_one, rule_two): return f"{format(rule_one, Format.SHORT.value)} is semantically equal to {format(rule_two, Format.SHORT.value)}." 
@@ -291,7 +298,7 @@ class BasicRuleNoMatchAllRegexValidator(BasicRuleValidator): MATCH_ALL_REGEXPS = [".*", "(.*)"] @classmethod - def validate(cls, basic_rule_json): + def validate(cls, basic_rule_json) -> SyncRuleValidationResult: basic_rule = BasicRule.from_json(basic_rule_json) # default rule uses match all regex, which is intended if basic_rule.is_default_rule(): @@ -336,7 +343,7 @@ class BasicRuleAgainstSchemaValidator(BasicRuleValidator): ) @classmethod - def validate(cls, rule): + def validate(cls, rule) -> SyncRuleValidationResult: try: BasicRuleAgainstSchemaValidator.SCHEMA(rule) diff --git a/connectors/kibana.py b/connectors/kibana.py index 51a9258f5..db298bd7d 100644 --- a/connectors/kibana.py +++ b/connectors/kibana.py @@ -9,6 +9,7 @@ import os import sys from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from typing import Optional, Sequence from connectors.config import load_config from connectors.es import DEFAULT_LANGUAGE @@ -20,7 +21,7 @@ CONNECTORS_INDEX = ".elastic-connectors-v1" JOBS_INDEX = ".elastic-connectors-sync-jobs-v1" -DEFAULT_CONFIG = os.path.join(os.path.dirname(__file__), "..", "config.yml") +DEFAULT_CONFIG: str = os.path.join(os.path.dirname(__file__), "..", "config.yml") DEFAULT_PIPELINE = { "version": 1, "description": "For testing", @@ -43,11 +44,11 @@ ], } -logger = logging.getLogger("kibana-fake") +logger: logging.Logger = logging.getLogger("kibana-fake") set_extra_logger(logger, log_level=logging.DEBUG, prefix="KBN-FAKE") -async def prepare(service_type, index_name, config, connector_definition=None): +async def prepare(service_type, index_name, config, connector_definition=None) -> None: klass = get_source_klass(config["sources"][service_type]) es = ESManagementClient(config["elasticsearch"]) connector_index = ConnectorIndex(config["elasticsearch"]) @@ -100,7 +101,7 @@ async def prepare(service_type, index_name, config, connector_definition=None): await es.close() -async def upsert_index(es, index): +async def upsert_index(es, index) -> None: """Override the index with new mappings and settings. 
If the index with such name exists, it's deleted and then created again @@ -124,7 +125,7 @@ async def upsert_index(es, index): await es.create_content_index(index, DEFAULT_LANGUAGE) -def _parser(): +def _parser() -> ArgumentParser: parser = ArgumentParser( prog="fake-kibana", formatter_class=ArgumentDefaultsHelpFormatter ) @@ -156,9 +157,9 @@ def _parser(): return parser -def main(args=None): +def main(argv: Optional[Sequence[str]] = None) -> int: parser = _parser() - args = parser.parse_args(args=args) + args = parser.parse_args(args=argv) connector_definition_file = args.connector_definition config_file = args.config_file diff --git a/connectors/logger.py b/connectors/logger.py index b7b107428..3d17632e8 100644 --- a/connectors/logger.py +++ b/connectors/logger.py @@ -13,7 +13,8 @@ import time from datetime import datetime, timezone from functools import wraps -from typing import AsyncGenerator +from types import TracebackType +from typing import AsyncGenerator, Mapping, Optional, Tuple, Type, Union import ecs_logging from dateutil.tz import tzlocal @@ -40,17 +41,19 @@ class ColorFormatter(logging.Formatter): DATE_FMT = "%H:%M:%S" - def __init__(self, prefix): + def __init__(self, prefix) -> None: self.custom_format = "[" + prefix + "][%(asctime)s][%(levelname)s] %(message)s" super().__init__(datefmt=self.DATE_FMT) self.local_tz = tzlocal() - def converter(self, timestamp): + def converter(self, timestamp: float) -> datetime: dt = datetime.fromtimestamp(timestamp, self.local_tz) return dt.astimezone(timezone.utc) # override logging.Formatter to use an aware datetime object - def formatTime(self, record, datefmt=None): + def formatTime( + self, record: logging.LogRecord, datefmt: Optional[str] = None + ) -> str: dt = self.converter(record.created) if datefmt: s = dt.strftime(datefmt) @@ -61,20 +64,20 @@ def formatTime(self, record, datefmt=None): s = dt.isoformat() return s - def format(self, record): # noqa: A003 + def format(self, record: logging.LogRecord) -> str: # noqa: A003 self._style._fmt = self.COLORS[record.levelno] + self.custom_format + self.RESET return super().format(record) class DocumentLogger: - def __init__(self, prefix, extra): + def __init__(self, prefix, extra) -> None: self._prefix = prefix self._extra = extra def isEnabledFor(self, level): return logger.isEnabledFor(level) - def debug(self, msg, *args, **kwargs): + def debug(self, msg, *args, **kwargs) -> None: logger.debug( msg, *args, @@ -83,7 +86,7 @@ def debug(self, msg, *args, **kwargs): **kwargs, ) - def info(self, msg, *args, **kwargs): + def info(self, msg, *args, **kwargs) -> None: logger.info( msg, *args, @@ -92,7 +95,7 @@ def info(self, msg, *args, **kwargs): **kwargs, ) - def warning(self, msg, *args, **kwargs): + def warning(self, msg, *args, **kwargs) -> None: logger.warning( msg, *args, @@ -101,7 +104,7 @@ def warning(self, msg, *args, **kwargs): **kwargs, ) - def error(self, msg, *args, **kwargs): + def error(self, msg, *args, **kwargs) -> None: logger.error( msg, *args, @@ -110,7 +113,7 @@ def error(self, msg, *args, **kwargs): **kwargs, ) - def exception(self, msg, *args, exc_info=True, **kwargs): + def exception(self, msg, *args, exc_info: bool = True, **kwargs) -> None: logger.exception( msg, *args, @@ -120,7 +123,7 @@ def exception(self, msg, *args, exc_info=True, **kwargs): **kwargs, ) - def critical(self, msg, *args, **kwargs): + def critical(self, msg, *args, **kwargs) -> None: logger.critical( msg, *args, @@ -129,7 +132,7 @@ def critical(self, msg, *args, **kwargs): **kwargs, ) - def 
fatal(self, msg, *args, **kwargs): + def fatal(self, msg, *args, **kwargs) -> None: logger.fatal( msg, *args, @@ -140,7 +143,21 @@ def fatal(self, msg, *args, **kwargs): class ExtraLogger(logging.Logger): - def _log(self, level, msg, args, exc_info=None, prefix=None, extra=None): + def _log( + self, + level: int, + msg: str, + args: Union[Mapping[str, object], Tuple[object, ...]], + exc_info: Union[ + None, + BaseException, + bool, + Tuple[Type[BaseException], BaseException, Optional[TracebackType]], + Tuple[None, ...], + ] = None, + prefix=None, + extra: Optional[Mapping[str, object]] = None, + ) -> None: if ( not (hasattr(self, "filebeat") and self.filebeat) # pyright: ignore and prefix @@ -158,7 +175,7 @@ def _log(self, level, msg, args, exc_info=None, prefix=None, extra=None): super(ExtraLogger, self)._log(level, msg, args, exc_info, extra) -def set_logger(log_level=logging.INFO, filebeat=False): +def set_logger(log_level: int = logging.INFO, filebeat: bool = False): global logger if filebeat: formatter = ecs_logging.StdlibFormatter() @@ -180,7 +197,12 @@ def set_logger(log_level=logging.INFO, filebeat=False): return logger -def set_extra_logger(logger, log_level=logging.INFO, prefix="BYOC", filebeat=False): +def set_extra_logger( + logger: Optional[str], + log_level: Union[int, str] = logging.INFO, + prefix: str = "BYOC", + filebeat: bool = False, +) -> None: if isinstance(logger, str): logger = logging.getLogger(logger) handler = logging.StreamHandler() @@ -226,14 +248,14 @@ def timed_execution(name, func_name, slow_log=None, canceled=None): class _TracedAsyncGenerator: - def __init__(self, generator, name, func_name, slow_log=None): + def __init__(self, generator, name, func_name, slow_log=None) -> None: self.gen = generator self.name = name self.slow_log = slow_log self.func_name = func_name self.counter = 0 - def __aiter__(self): + def __aiter__(self) -> "_TracedAsyncGenerator": return self async def __anext__(self): diff --git a/connectors/preflight_check.py b/connectors/preflight_check.py index ee104c940..91b35f7bf 100644 --- a/connectors/preflight_check.py +++ b/connectors/preflight_check.py @@ -15,7 +15,7 @@ class PreflightCheck: - def __init__(self, config, version): + def __init__(self, config, version) -> None: self.version = version self.config = config self.elastic_config = config["elasticsearch"] @@ -29,13 +29,13 @@ def __init__(self, config, version): self._sleeps = CancellableSleeps() self.running = False - def stop(self): + def stop(self) -> None: self.running = False self._sleeps.cancel() if self.es_management_client is not None: self.es_management_client.stop_waiting() - def shutdown(self, sig): + def shutdown(self, sig) -> None: logger.info(f"Caught {sig.name}. Graceful shutdown.") self.stop() @@ -84,7 +84,7 @@ async def _check_es_server(self): versions_compatible = await self._versions_compatible(version.get("number")) return versions_compatible, is_serverless - async def _versions_compatible(self, es_version): + async def _versions_compatible(self, es_version) -> bool: """ Checks if the Connector and ES versions are compatible """ @@ -128,7 +128,7 @@ async def _versions_compatible(self, es_version): ) return True - async def _check_local_extraction_setup(self): + async def _check_local_extraction_setup(self) -> None: if self.extraction_config is None: logger.info( "Extraction service is not configured, skipping its preflight check." 
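For context on the DocumentLogger signatures annotated above: the class forwards every level to the shared module logger while injecting a fixed prefix and structured extra fields. A minimal standalone illustration of that shape, not the real class (which also wraps warning, error, exception, critical and fatal):

# Minimal standalone illustration: a thin wrapper that forwards to a shared
# logging.Logger while injecting a fixed prefix and structured `extra` fields.
import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
_shared = logging.getLogger("sketch")


class PrefixedLogger:
    def __init__(self, prefix: str, extra: dict) -> None:
        self._prefix = prefix
        self._extra = extra

    def info(self, msg: str, *args, **kwargs) -> None:
        # Prepend the prefix and forward the structured fields as `extra`.
        _shared.info(f"{self._prefix} {msg}", *args, extra=dict(self._extra), **kwargs)


doc_logger = PrefixedLogger("[connector-123]", {"labels.connector_id": "123"})
doc_logger.info("Heartbeat sent")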
@@ -161,7 +161,7 @@ async def _check_local_extraction_setup(self): finally: await session.close() - async def _check_system_indices_with_retries(self): + async def _check_system_indices_with_retries(self) -> bool: attempts = 0 while self.running: try: @@ -189,7 +189,7 @@ async def _check_system_indices_with_retries(self): await self._sleeps.sleep(self.preflight_idle) return False - def _validate_configuration(self): + def _validate_configuration(self) -> bool: # "Native" mode configured_native_types = "native_service_types" in self.config force_allowed_native = self.config.get("_force_allow_native", False) diff --git a/connectors/protocol/connectors.py b/connectors/protocol/connectors.py index d8c2bf318..908b6f38e 100644 --- a/connectors/protocol/connectors.py +++ b/connectors/protocol/connectors.py @@ -18,6 +18,8 @@ from copy import deepcopy from datetime import datetime, timezone from enum import Enum +from typing import Any, Dict, List, Optional, Set, Sized, Tuple, Type, Union +from unittest.mock import MagicMock, Mock from elasticsearch import ( ApiError, @@ -35,6 +37,7 @@ from connectors.logger import logger from connectors.source import ( DEFAULT_CONFIGURATION, + BaseDataSource, DataSourceConfiguration, get_source_klass, ) @@ -78,8 +81,8 @@ CONNECTORS_INDEX = ".elastic-connectors" JOBS_INDEX = ".elastic-connectors-sync-jobs" -CONCRETE_CONNECTORS_INDEX = CONNECTORS_INDEX + "-v1" -CONCRETE_JOBS_INDEX = JOBS_INDEX + "-v1" +CONCRETE_CONNECTORS_INDEX: str = CONNECTORS_INDEX + "-v1" +CONCRETE_JOBS_INDEX: str = JOBS_INDEX + "-v1" CONNECTORS_ACCESS_CONTROL_INDEX_PREFIX = ".search-acl-filter-" JOB_NOT_FOUND_ERROR = "Couldn't find the job" @@ -89,7 +92,7 @@ INDEXED_DOCUMENT_VOLUME = "indexed_document_volume" DELETED_DOCUMENT_COUNT = "deleted_document_count" TOTAL_DOCUMENT_COUNT = "total_document_count" -ALLOWED_INGESTION_STATS_KEYS = ( +ALLOWED_INGESTION_STATS_KEYS: Tuple[str, ...] 
= ( INDEXED_DOCUMENT_COUNT, INDEXED_DOCUMENT_VOLUME, DELETED_DOCUMENT_COUNT, @@ -156,12 +159,23 @@ class ProtocolError(Exception): class ConnectorIndex(ESIndex): - def __init__(self, elastic_config): + def __init__( + self, + elastic_config: Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + ], + ], + ) -> None: logger.debug(f"ConnectorIndex connecting to {elastic_config['host']}") # initialize ESIndex instance super().__init__(index_name=CONNECTORS_INDEX, elastic_config=elastic_config) - async def heartbeat(self, doc_id): + async def heartbeat(self, doc_id) -> None: if self.feature_use_connectors_api: await self.api.connector_check_in(doc_id) else: @@ -169,12 +183,12 @@ async def heartbeat(self, doc_id): async def connector_put( self, - connector_id, - service_type, - connector_name=None, - index_name=None, - is_native=False, - ): + connector_id: str, + service_type: str, + connector_name: Optional[str] = None, + index_name: Optional[str] = None, + is_native: bool = False, + ) -> None: await self.api.connector_put( connector_id=connector_id, service_type=service_type, @@ -183,7 +197,9 @@ async def connector_put( is_native=is_native, ) - async def connector_exists(self, connector_id, include_deleted=False): + async def connector_exists( + self, connector_id: str, include_deleted: bool = False + ) -> bool: try: doc = await self.api.connector_get( connector_id=connector_id, include_deleted=include_deleted @@ -198,8 +214,12 @@ async def connector_exists(self, connector_id, include_deleted=False): raise e async def connector_update_scheduling( - self, connector_id, full=None, incremental=None, access_control=None - ): + self, + connector_id: str, + full: Optional[Dict[str, Union[bool, str]]] = None, + incremental: None = None, + access_control: None = None, + ) -> None: scheduling = {} if full is not None: @@ -216,13 +236,20 @@ async def connector_update_scheduling( ) async def connector_update_configuration( - self, connector_id, schema=None, values=None - ): + self, + connector_id: str, + schema: Optional[Dict[str, Dict[str, Union[str, bool, int]]]] = None, + values: None = None, + ) -> None: await self.api.connector_update_configuration( connector_id=connector_id, configuration=schema, values=values ) - async def supported_connectors(self, native_service_types=None, connector_ids=None): + async def supported_connectors( + self, + native_service_types: Optional[Sized] = None, + connector_ids: Optional[Sized] = None, + ) -> None: if native_service_types is None: native_service_types = [] if connector_ids is None: @@ -261,13 +288,30 @@ async def supported_connectors(self, native_service_types=None, connector_ids=No async for connector in self.get_all_docs(query=query): yield connector - def _create_object(self, doc_source): + def _create_object( + self, + doc_source: Dict[ + str, + Union[ + Dict[str, str], + str, + Dict[ + str, + Optional[ + Union[ + Dict[str, Dict[str, str]], Dict[str, Union[bool, str]], str + ] + ], + ], + ], + ], + ) -> "Connector": return Connector( self, doc_source, ) - async def get_connector_by_index(self, index_name): + async def get_connector_by_index(self, index_name: str) -> "Connector": connectors = [ connector async for connector in self.get_all_docs( @@ -287,7 +331,7 @@ async def all_connectors(self): yield connector -def filter_ingestion_stats(ingestion_stats): +def filter_ingestion_stats(ingestion_stats: Optional[Dict[str, int]]) -> Dict[str, int]: if ingestion_stats is None: return {} @@ -298,80 
+342,80 @@ def filter_ingestion_stats(ingestion_stats): class SyncJob(ESDocument): @property - def status(self): + def status(self) -> JobStatus: return JobStatus(self.get("status")) @property - def error(self): + def error(self) -> str: return self.get("error") @property - def connector_id(self): + def connector_id(self) -> Optional[str]: return self.get("connector", "id") @property - def index_name(self): + def index_name(self) -> Optional[str]: return self.get("connector", "index_name") @property - def language(self): + def language(self) -> str: return self.get("connector", "language") @property - def service_type(self): + def service_type(self) -> Optional[str]: return self.get("connector", "service_type") @property - def configuration(self): + def configuration(self) -> DataSourceConfiguration: return DataSourceConfiguration(self.get("connector", "configuration")) @property - def filtering(self): + def filtering(self) -> "Filter": return Filter(self.get("connector", "filtering", default={})) @property - def pipeline(self): + def pipeline(self) -> "Pipeline": return Pipeline(self.get("connector", "pipeline")) @property - def sync_cursor(self): + def sync_cursor(self) -> Dict[str, str]: return self.get("connector", "sync_cursor") @property - def terminated(self): + def terminated(self) -> bool: return self.status in (JobStatus.ERROR, JobStatus.COMPLETED, JobStatus.CANCELED) @property - def indexed_document_count(self): + def indexed_document_count(self) -> int: return self.get(INDEXED_DOCUMENT_COUNT, default=0) @property - def indexed_document_volume(self): + def indexed_document_volume(self) -> int: return self.get(INDEXED_DOCUMENT_VOLUME, default=0) @property - def deleted_document_count(self): + def deleted_document_count(self) -> int: return self.get(DELETED_DOCUMENT_COUNT, default=0) @property - def total_document_count(self): + def total_document_count(self) -> int: return self.get(TOTAL_DOCUMENT_COUNT, default=0) @property - def job_type(self): + def job_type(self) -> JobType: return JobType(self.get("job_type")) - def is_content_sync(self): + def is_content_sync(self) -> bool: return self.job_type in (JobType.FULL, JobType.INCREMENTAL) - async def validate_filtering(self, validator): + async def validate_filtering(self, validator: Mock) -> None: validation_result = await validator.validate_filtering(self.filtering) if validation_result.state != FilteringValidationState.VALID: msg = f"Filtering in state {validation_result.state}, errors: {validation_result.errors}." 
raise InvalidFilteringError(msg) - def _wrap_errors(self, operation_name, e): + def _wrap_errors(self, operation_name: str, e: TypeError): if isinstance(e, ApiError): reason = None if e.info is not None and "error" in e.info and "reason" in e.info["error"]: @@ -382,7 +426,7 @@ def _wrap_errors(self, operation_name, e): msg = f"Failed to {operation_name} for job {self.id} because of {e.__class__.__name__}: {str(e)}" raise ProtocolError(msg) from e - async def claim(self, sync_cursor=None): + async def claim(self, sync_cursor: Optional[Dict[str, str]] = None) -> None: try: if self.index.feature_use_connectors_api: await self.index.api.connector_sync_job_claim( @@ -402,7 +446,11 @@ async def claim(self, sync_cursor=None): except Exception as e: self._wrap_errors("claim job", e) - async def update_metadata(self, ingestion_stats=None, connector_metadata=None): + async def update_metadata( + self, + ingestion_stats: Optional[Dict[str, int]] = None, + connector_metadata: Optional[Sized] = None, + ) -> None: try: ingestion_stats = filter_ingestion_stats(ingestion_stats) if self.index.feature_use_connectors_api: @@ -425,7 +473,11 @@ async def update_metadata(self, ingestion_stats=None, connector_metadata=None): except Exception as e: self._wrap_errors("update metadata and stats", e) - async def done(self, ingestion_stats=None, connector_metadata=None): + async def done( + self, + ingestion_stats: None = None, + connector_metadata: Optional[Sized] = None, + ) -> None: try: await self._terminate( JobStatus.COMPLETED, None, ingestion_stats, connector_metadata @@ -433,7 +485,12 @@ async def done(self, ingestion_stats=None, connector_metadata=None): except Exception as e: self._wrap_errors("terminate as completed", e) - async def fail(self, error, ingestion_stats=None, connector_metadata=None): + async def fail( + self, + error: Union[str, int, Exception], + ingestion_stats: None = None, + connector_metadata: Optional[Sized] = None, + ) -> None: try: if isinstance(error, str): message = error @@ -449,7 +506,11 @@ async def fail(self, error, ingestion_stats=None, connector_metadata=None): except Exception as e: self._wrap_errors("terminate as failed", e) - async def cancel(self, ingestion_stats=None, connector_metadata=None): + async def cancel( + self, + ingestion_stats: None = None, + connector_metadata: Optional[Sized] = None, + ) -> None: try: await self._terminate( JobStatus.CANCELED, None, ingestion_stats, connector_metadata @@ -457,7 +518,11 @@ async def cancel(self, ingestion_stats=None, connector_metadata=None): except Exception as e: self._wrap_errors("terminate as canceled", e) - async def suspend(self, ingestion_stats=None, connector_metadata=None): + async def suspend( + self, + ingestion_stats: None = None, + connector_metadata: Optional[Sized] = None, + ) -> None: try: await self._terminate( JobStatus.SUSPENDED, None, ingestion_stats, connector_metadata @@ -466,8 +531,12 @@ async def suspend(self, ingestion_stats=None, connector_metadata=None): self._wrap_errors("terminate as suspended", e) async def _terminate( - self, status, error=None, ingestion_stats=None, connector_metadata=None - ): + self, + status: JobStatus, + error: Optional[str] = None, + ingestion_stats: None = None, + connector_metadata: Optional[Sized] = None, + ) -> None: ingestion_stats = filter_ingestion_stats(ingestion_stats) if connector_metadata is None: connector_metadata = {} @@ -486,10 +555,10 @@ async def _terminate( doc["metadata"] = connector_metadata await self.index.update(doc_id=self.id, doc=doc) - def 
_prefix(self): + def _prefix(self) -> str: return f"[Connector id: {self.connector_id}, index name: {self.index_name}, Sync job id: {self.id}]" - def _extra(self): + def _extra(self) -> Dict[str, Optional[str]]: return { "labels.sync_job_id": self.id, "labels.connector_id": self.connector_id, @@ -501,19 +570,71 @@ def _extra(self): class Filtering: DEFAULT_DOMAIN = "DEFAULT" - def __init__(self, filtering=None): + def __init__( + self, + filtering: Optional[ + Union[ + List[ + Union[ + Dict[ + str, + Union[ + str, + Dict[ + str, + Union[ + Dict[str, Dict[str, Dict[str, Dict[Any, Any]]]], + List[Dict[str, int]], + Dict[str, str], + ], + ], + Dict[str, str], + ], + ], + Dict[str, Union[Dict[str, str], str]], + ] + ], + List[ + Union[ + Dict[ + str, + Union[ + str, + Dict[ + str, + Union[ + Dict[str, Dict[str, Dict[str, Dict[Any, Any]]]], + List[Dict[str, int]], + Dict[str, str], + ], + ], + ], + ], + Dict[ + str, + Union[ + str, Dict[str, Union[str, List[Dict[str, List[str]]]]] + ], + ], + ] + ], + ] + ] = None, + ) -> None: if filtering is None: filtering = [] self.filtering = filtering - def get_active_filter(self, domain=DEFAULT_DOMAIN): + def get_active_filter(self, domain: str = DEFAULT_DOMAIN) -> "Filter": return self.get_filter(filter_state="active", domain=domain) - def get_draft_filter(self, domain=DEFAULT_DOMAIN): + def get_draft_filter(self, domain: str = DEFAULT_DOMAIN) -> "Filter": return self.get_filter(filter_state="draft", domain=domain) - def get_filter(self, filter_state="active", domain=DEFAULT_DOMAIN): + def get_filter( + self, filter_state: str = "active", domain: str = DEFAULT_DOMAIN + ) -> "Filter": return next( ( Filter(filter_[filter_state]) @@ -523,12 +644,30 @@ def get_filter(self, filter_state="active", domain=DEFAULT_DOMAIN): Filter(), ) - def to_list(self): + def to_list( + self, + ) -> List[ + Dict[ + str, + Union[ + str, + Dict[ + str, + Union[ + Dict[str, Dict[str, Dict[str, Dict[Any, Any]]]], + List[Dict[str, int]], + Dict[str, str], + ], + ], + Dict[str, Union[str, List[Dict[str, List[str]]]]], + ], + ] + ]: return list(self.filtering) class Filter(dict): - def __init__(self, filter_=None): + def __init__(self, filter_: Optional[Any] = None) -> None: if filter_ is None: filter_ = {} @@ -540,17 +679,19 @@ def __init__(self, filter_=None): "validation", {"state": FilteringValidationState.VALID.value, "errors": []} ) - def get_advanced_rules(self): + def get_advanced_rules(self) -> Any: return self.advanced_rules.get("value", {}) - def has_advanced_rules(self): + def has_advanced_rules(self) -> bool: advanced_rules = self.get_advanced_rules() return advanced_rules is not None and len(advanced_rules) > 0 - def has_validation_state(self, validation_state): + def has_validation_state(self, validation_state: FilteringValidationState) -> bool: return FilteringValidationState(self.validation["state"]) == validation_state - def transform_filtering(self): + def transform_filtering( + self, + ) -> Union[Dict[str, Union[Dict[Any, Any], List[Any]]], "Filter"]: """ Transform the filtering in .elastic-connectors to filtering ready-to-use in .elastic-connectors-sync-jobs """ @@ -571,7 +712,7 @@ def transform_filtering(self): class Pipeline(UserDict): - def __init__(self, data): + def __init__(self, data: Dict[str, Union[bool, str]]) -> None: if data is None: data = {} default = PIPELINE_DEFAULT.copy() @@ -591,18 +732,18 @@ class Features: NATIVE_CONNECTOR_API_KEYS = "native_connector_api_keys" - def __init__(self, features=None): + def __init__(self, features: Optional[Any] 
= None) -> None: if features is None: features = {} self.features = features - def incremental_sync_enabled(self): + def incremental_sync_enabled(self) -> bool: return nested_get_from_dict( self.features, ["incremental_sync", "enabled"], default=False ) - def document_level_security_enabled(self): + def document_level_security_enabled(self) -> bool: return nested_get_from_dict( self.features, ["document_level_security", "enabled"], default=False ) @@ -612,7 +753,7 @@ def native_connector_api_keys_enabled(self): self.features, ["native_connector_api_keys", "enabled"], default=True ) - def sync_rules_enabled(self): + def sync_rules_enabled(self) -> bool: return any( [ self.feature_enabled(Features.BASIC_RULES_NEW), @@ -622,7 +763,7 @@ def sync_rules_enabled(self): ] ) - def feature_enabled(self, feature): + def feature_enabled(self, feature: str) -> bool: match feature: case Features.BASIC_RULES_NEW: return nested_get_from_dict( @@ -642,66 +783,66 @@ def feature_enabled(self, feature): class Connector(ESDocument): @property - def status(self): + def status(self) -> Status: return Status(self.get("status")) @property - def service_type(self): + def service_type(self) -> Optional[str]: return self.get("service_type") @property - def last_seen(self): + def last_seen(self) -> Optional[datetime]: return self._property_as_datetime("last_seen") @property - def native(self): + def native(self) -> bool: return self.get("is_native", default=False) @property - def full_sync_scheduling(self): + def full_sync_scheduling(self) -> Dict[str, Union[bool, str]]: return self.get("scheduling", "full", default={}) @property - def incremental_sync_scheduling(self): + def incremental_sync_scheduling(self) -> Dict[str, Union[bool, str]]: return self.get("scheduling", "incremental", default={}) @property - def access_control_sync_scheduling(self): + def access_control_sync_scheduling(self) -> Dict[str, Union[bool, str]]: return self.get("scheduling", "access_control", default={}) @property - def configuration(self): + def configuration(self) -> DataSourceConfiguration: return DataSourceConfiguration(self.get("configuration")) @property - def index_name(self): + def index_name(self) -> Optional[str]: return self.get("index_name") @property - def language(self): + def language(self) -> str: return self.get("language") @property - def filtering(self): + def filtering(self) -> Filtering: return Filtering(self.get("filtering")) @property - def pipeline(self): + def pipeline(self) -> Pipeline: return Pipeline(self.get("pipeline")) @property - def features(self): + def features(self) -> Features: return Features(self.get("features")) @property - def last_sync_status(self): + def last_sync_status(self) -> JobStatus: return JobStatus(self.get("last_sync_status")) @property - def last_access_control_sync_status(self): + def last_access_control_sync_status(self) -> JobStatus: return JobStatus(self.get("last_access_control_sync_status")) - def _property_as_datetime(self, key): + def _property_as_datetime(self, key: str) -> Optional[datetime]: value = self.get(key) if value is not None: value = parse_datetime_string(value) # pyright: ignore @@ -711,15 +852,15 @@ def _property_as_datetime(self, key): return value @property - def last_sync_scheduled_at(self): + def last_sync_scheduled_at(self) -> datetime: return self._property_as_datetime("last_sync_scheduled_at") @property - def last_incremental_sync_scheduled_at(self): + def last_incremental_sync_scheduled_at(self) -> datetime: return 
self._property_as_datetime("last_incremental_sync_scheduled_at") @property - def last_access_control_sync_scheduled_at(self): + def last_access_control_sync_scheduled_at(self) -> datetime: return self._property_as_datetime("last_access_control_sync_scheduled_at") def last_sync_scheduled_at_by_job_type(self, job_type): @@ -735,14 +876,14 @@ def last_sync_scheduled_at_by_job_type(self, job_type): raise ValueError(msg) @property - def sync_cursor(self): + def sync_cursor(self) -> Dict[str, str]: return self.get("sync_cursor") @property - def api_key_secret_id(self): + def api_key_secret_id(self) -> str: return self.get("api_key_secret_id") - async def heartbeat(self, interval, force=False): + async def heartbeat(self, interval: int, force: bool = False) -> None: if ( force or self.last_seen is None @@ -751,7 +892,7 @@ async def heartbeat(self, interval, force=False): self.log_debug("Sending heartbeat") await self.index.heartbeat(doc_id=self.id) - def next_sync(self, job_type, now): + def next_sync(self, job_type: JobType, now: datetime) -> Optional[str]: """Returns the datetime in UTC timezone when the next sync for a given job type will run, return None if it's disabled.""" match job_type: @@ -769,7 +910,7 @@ def next_sync(self, job_type, now): return None return next_run(scheduling_property.get("interval"), now) - async def _update_datetime(self, field, new_ts): + async def _update_datetime(self, field: str, new_ts: Optional[datetime]) -> None: await self.index.update( doc_id=self.id, doc={field: iso_utc(new_ts)}, @@ -777,7 +918,9 @@ async def _update_datetime(self, field, new_ts): if_primary_term=self._primary_term, ) - async def update_last_sync_scheduled_at_by_job_type(self, job_type, new_ts): + async def update_last_sync_scheduled_at_by_job_type( + self, job_type: JobType, new_ts: datetime + ) -> None: match job_type: case JobType.ACCESS_CONTROL: await self._update_datetime( @@ -793,7 +936,7 @@ async def update_last_sync_scheduled_at_by_job_type(self, job_type, new_ts): msg = f"Unknown job type: {job_type}" raise ValueError(msg) - async def sync_starts(self, job_type): + async def sync_starts(self, job_type: JobType) -> None: if job_type == JobType.ACCESS_CONTROL: last_sync_information = { "last_access_control_sync_status": JobStatus.IN_PROGRESS.value, @@ -820,18 +963,20 @@ async def sync_starts(self, job_type): if_primary_term=self._primary_term, ) - async def error(self, error): + async def error(self, error: str) -> None: doc = { "status": Status.ERROR.value, "error": str(error), } await self.index.update(doc_id=self.id, doc=doc) - async def connected(self): + async def connected(self) -> None: doc = {"status": Status.CONNECTED.value, "error": None} await self.index.update(doc_id=self.id, doc=doc) - async def sync_done(self, job, cursor=None): + async def sync_done( + self, job: Optional[Mock], cursor: Optional[Dict[str, str]] = None + ) -> None: job_status = JobStatus.ERROR if job is None else job.status job_error = JOB_NOT_FOUND_ERROR if job is None else job.error job_type = job.job_type if job is not None else None @@ -883,7 +1028,7 @@ async def sync_done(self, job, cursor=None): await self.index.update(doc_id=self.id, doc=doc) @with_concurrency_control() - async def prepare(self, config, sources): + async def prepare(self, config: Dict[str, str], sources: Dict[str, str]) -> None: """Prepares the connector, given a configuration If the connector id and the service type is in the config, we want to populate the service type and then sets the default configuration. 
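Earlier in this file, next_sync is annotated to return None when a job type's scheduling is disabled and otherwise to hand the configured cron interval to next_run. A hedged, standalone sketch of that control flow, with the cron evaluation stubbed out since the real helper lives elsewhere in the package:

# Illustrative sketch only (scheduling semantics, not the real Connector class):
# pick the scheduling block for a job type, return None when it is disabled,
# otherwise hand the cron interval to a "next run" helper (stubbed here).
from datetime import datetime, timedelta, timezone
from typing import Optional


def fake_next_run(interval: str, now: datetime) -> datetime:
    # Placeholder for a real cron evaluation of `interval`.
    return now + timedelta(minutes=30)


def next_sync(scheduling: dict, job_type: str, now: datetime) -> Optional[datetime]:
    block = scheduling.get(job_type, {})
    if not block.get("enabled", False):
        return None
    return fake_next_run(block["interval"], now)


scheduling = {"full": {"enabled": True, "interval": "0 0 * * * ?"}}
print(next_sync(scheduling, "full", datetime.now(timezone.utc)))         # a datetime
print(next_sync(scheduling, "incremental", datetime.now(timezone.utc)))  # None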
@@ -958,7 +1103,9 @@ async def prepare(self, config, sources): ) await self.reload() - def validated_doc(self, source_klass): + def validated_doc( + self, source_klass: Type[BaseDataSource] + ) -> Dict[str, Union[Dict[str, Dict[str, Optional[Union[str, int, bool]]]], str]]: simple_config = source_klass.get_simple_configuration() current_config = self.configuration.to_dict() @@ -993,8 +1140,33 @@ def validated_doc(self, source_klass): return doc def updated_configuration_fields( - self, missing_keys, current_config, simple_default_config - ): + self, + missing_keys: Union[List[str], Set[str]], + current_config: Dict[ + str, + Union[ + Dict[str, Optional[Union[str, int, bool]]], + Dict[str, Union[str, int]], + Dict[str, Union[str, bool, int]], + ], + ], + simple_default_config: Dict[ + str, + Union[ + Dict[str, Optional[Union[str, int, bool]]], + Dict[str, Union[str, int]], + Dict[str, Union[str, bool, int]], + ], + ], + ) -> Dict[ + str, + Union[ + Dict[str, Optional[Union[str, int, bool]]], + Dict[str, Union[str, bool, int]], + Dict[str, str], + Dict[str, int], + ], + ]: self.log_warning( f"Detected an existing connector: {self.id} ({self.service_type}) that was previously {Status.CONNECTED.value} but is now missing configuration: {missing_keys}. Values for the new fields will be automatically set. Please review these configuration values as part of your upgrade." ) @@ -1014,8 +1186,10 @@ def updated_configuration_fields( return draft_config def updated_configuration_field_properties( - self, fields_missing_properties, simple_config - ): + self, + fields_missing_properties: Dict[str, Dict[str, str]], + simple_config: Dict[str, Dict[str, Optional[Union[str, int, bool]]]], + ) -> Dict[str, Dict[str, Optional[Union[str, int, bool]]]]: """Checks the field properties for every field in a configuration. If a field is missing field properties, add those field properties with default values. 
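updated_configuration_field_properties, as its docstring says, fills in missing field properties with default values by deep-merging the defaults under the current values. The generic sketch below (field names invented for illustration) makes the merge direction concrete:

# Standalone sketch of the "fill in missing field properties from defaults" idea;
# the field names and defaults below are illustrative, not from a real connector.
from copy import deepcopy


def deep_merge(base: dict, override: dict) -> dict:
    # Values in `override` win; nested dicts are merged recursively.
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


defaults = {"host": {"label": "Host", "required": True, "value": ""}}
current = {"host": {"value": "localhost"}}  # missing label/required properties
print(deep_merge(defaults, current))
# {'host': {'label': 'Host', 'required': True, 'value': 'localhost'}}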
@@ -1033,7 +1207,7 @@ def updated_configuration_field_properties( return deep_merge_dicts(filtered_simple_config, fields_missing_properties) @with_concurrency_control() - async def validate_filtering(self, validator): + async def validate_filtering(self, validator: Mock) -> None: await self.reload() draft_filter = self.filtering.get_draft_filter() if not draft_filter.has_validation_state(FilteringValidationState.EDITED): @@ -1074,7 +1248,7 @@ async def validate_filtering(self, validator): ) await self.reload() - async def document_count(self): + async def document_count(self) -> int: if not self.index.serverless: await self.index.client.indices.refresh( index=self.index_name, ignore_unavailable=True @@ -1084,10 +1258,10 @@ async def document_count(self): ) return result["count"] - def _prefix(self): + def _prefix(self) -> str: return f"[Connector id: {self.id}, index name: {self.index_name}]" - def _extra(self): + def _extra(self) -> Dict[str, Optional[str]]: return { "labels.connector_id": self.id, "labels.index_name": self.index_name, @@ -1095,7 +1269,7 @@ def _extra(self): } -IDLE_JOBS_THRESHOLD = 60 * 5 # 5 minutes +IDLE_JOBS_THRESHOLD: int = 60 * 5 # 5 minutes class SyncJobIndex(ESIndex): @@ -1106,10 +1280,21 @@ class SyncJobIndex(ESIndex): elastic_config (dict): Elasticsearch configuration and credentials """ - def __init__(self, elastic_config): + def __init__( + self, + elastic_config: Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + ], + ], + ) -> None: super().__init__(index_name=JOBS_INDEX, elastic_config=elastic_config) - def _create_object(self, doc_source): + def _create_object(self, doc_source) -> SyncJob: """ Args: doc_source (dict): A raw Elasticsearch document @@ -1121,7 +1306,9 @@ def _create_object(self, doc_source): doc_source=doc_source, ) - async def create(self, connector, trigger_method, job_type): + async def create( + self, connector: Mock, trigger_method: JobTriggerMethod, job_type: JobType + ) -> MagicMock: if self.feature_use_connectors_api: response = await self.api.connector_sync_job_create( connector_id=connector.id, @@ -1159,7 +1346,9 @@ async def create(self, connector, trigger_method, job_type): return api_response["_id"] - async def pending_jobs(self, connector_ids, job_types): + async def pending_jobs( + self, connector_ids: List[int], job_types: Optional[str] + ) -> None: if not job_types: return if not isinstance(job_types, list): diff --git a/connectors/service_cli.py b/connectors/service_cli.py index 239ccff1b..c16311b0d 100755 --- a/connectors/service_cli.py +++ b/connectors/service_cli.py @@ -17,6 +17,8 @@ import logging import os import signal +from asyncio.events import AbstractEventLoop +from typing import Optional import click from click import ClickException, UsageError @@ -35,7 +37,7 @@ from connectors.utils import sleeps_for_retryable -async def _start_service(actions, config, loop): +async def _start_service(actions, config, loop) -> Optional[int]: """Starts the service. 
Steps: @@ -80,7 +82,7 @@ def _get_uvloop(): return uvloop -def get_event_loop(uvloop=False): +def get_event_loop(uvloop: bool = False) -> AbstractEventLoop: if uvloop: # activate uvloop if lib is present try: @@ -102,7 +104,9 @@ def get_event_loop(uvloop=False): return loop -def run(action, config_file, log_level, filebeat, service_type, uvloop): +def run( + action, config_file, log_level, filebeat: bool, service_type, uvloop: bool +) -> Optional[int]: """Loads the config file, sets the logger and executes an action. Actions: @@ -225,7 +229,7 @@ def run(action, config_file, log_level, filebeat, service_type, uvloop): help="Service type to get default configuration for if action is config.", ) @click.option("--uvloop", is_flag=True, default=False, help="Use uvloop if possible.") -def main(action, config_file, log_level, filebeat, service_type, uvloop): +def main(action, config_file, log_level, filebeat: bool, service_type, uvloop: bool): """Entry point to the service, responsible for all operations. Parses the arguments and calls `run` with them. diff --git a/connectors/services/access_control_sync_job_execution.py b/connectors/services/access_control_sync_job_execution.py index ee789366a..59a54c8fc 100644 --- a/connectors/services/access_control_sync_job_execution.py +++ b/connectors/services/access_control_sync_job_execution.py @@ -12,15 +12,15 @@ class AccessControlSyncJobExecutionService(JobExecutionService): name = "sync_access_control" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__(config, "access_control_sync_job_execution_service") @cached_property - def display_name(self): + def display_name(self) -> str: return "access control sync job execution" @cached_property - def max_concurrency_config(self): + def max_concurrency_config(self) -> str: return "service.max_concurrent_access_control_syncs" @cached_property @@ -31,7 +31,7 @@ def job_types(self): def max_concurrency(self): return self.service_config.get("max_concurrent_access_control_syncs") - def should_execute(self, connector, sync_job): + def should_execute(self, connector, sync_job) -> bool: if not connector.features.document_level_security_enabled(): sync_job.log_debug("DLS is not enabled for the connector, skip the job...") return False diff --git a/connectors/services/base.py b/connectors/services/base.py index b5ff8c897..9c28f9ff3 100644 --- a/connectors/services/base.py +++ b/connectors/services/base.py @@ -14,6 +14,8 @@ import asyncio import time from copy import deepcopy +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Union +from unittest.mock import Mock from connectors.logger import DocumentLogger, logger from connectors.utils import CancellableSleeps @@ -34,7 +36,7 @@ class ServiceAlreadyRunningError(Exception): _SERVICES = {} -def get_services(names, config): +def get_services(names: List[str], config: DefaultDict[Any, Any]) -> "MultiService": """Instantiates a list of services given their names and a config. returns a `MultiService` instance. 
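get_services instantiates services by name and wraps them in a MultiService. The standalone sketch below illustrates that registry-and-wrapper idea; it registers classes via __init_subclass__ rather than the metaclass the real code uses, and the service names are only examples:

# Standalone sketch of a name -> service registry plus a wrapper that runs them all.
from typing import Dict, List, Type

_REGISTRY: Dict[str, Type["BaseSketchService"]] = {}


class BaseSketchService:
    name = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register every named subclass so it can be looked up by name later.
        if cls.name is not None:
            _REGISTRY[cls.name] = cls

    def __init__(self, config: dict) -> None:
        self.config = config

    def run(self) -> None:
        print(f"running {self.name}")


class CleanupSketchService(BaseSketchService):
    name = "cleanup"


class ScheduleSketchService(BaseSketchService):
    name = "schedule"


class MultiRunner:
    def __init__(self, services: List[BaseSketchService]) -> None:
        self._services = services

    def run(self) -> None:
        # The real MultiService runs each service as an asyncio task; this is sequential.
        for service in self._services:
            service.run()


def get_services_sketch(names: List[str], config: dict) -> MultiRunner:
    return MultiRunner([_REGISTRY[name](config) for name in names])


get_services_sketch(["cleanup", "schedule"], {}).run()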
@@ -50,7 +52,12 @@ def get_service(name, config): class _Registry(type): """Metaclass used to register a service class in an internal registry.""" - def __new__(cls, name, bases, dct): + def __new__( + cls: type, + name: str, + bases: Tuple[()], + dct: Dict[str, Optional[Union[str, Callable]]], + ) -> type: service_name = dct.get("name") class_instance = super().__new__(cls, name, bases, dct) if service_name is not None: @@ -69,7 +76,7 @@ class BaseService(metaclass=_Registry): name = None # using None here avoids registering this class - def __init__(self, config, service_name): + def __init__(self, config: Dict[str, Any], service_name: str) -> None: self.config = config self.logger = DocumentLogger( f"[{service_name}]", {"service_name": service_name} @@ -83,14 +90,14 @@ def __init__(self, config, service_name): self._sleeps = CancellableSleeps() self.errors = [0, time.time()] - def stop(self): + def stop(self) -> None: self.running = False self._sleeps.cancel() async def _run(self): raise NotImplementedError() - async def run(self): + async def run(self) -> None: """Runs the service""" if self.running: msg = f"{self.__class__.__name__} is already running." @@ -102,7 +109,7 @@ async def run(self): finally: self.stop() - def raise_if_spurious(self, exception): + def raise_if_spurious(self, exception: TypeError) -> None: errors, first = self.errors errors += 1 @@ -118,7 +125,7 @@ def raise_if_spurious(self, exception): self.errors[0] = errors self.errors[1] = first - def _parse_connectors(self): + def _parse_connectors(self) -> Dict[str, Dict[str, str]]: connectors = {} configured_connectors = deepcopy(self.config.get("connectors")) if configured_connectors is not None: @@ -147,7 +154,17 @@ def _parse_connectors(self): return connectors - def _override_es_config(self, connector): + def _override_es_config( + self, connector: Mock + ) -> Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + ], + ]: es_config = deepcopy(self.es_config) if connector.id not in self.connectors: return es_config @@ -166,10 +183,10 @@ def _override_es_config(self, connector): class MultiService: """Wrapper class to run multiple services against the same config.""" - def __init__(self, *services): + def __init__(self, *services) -> None: self._services = services - async def run(self): + async def run(self) -> None: """Runs every service in a task and wait for all tasks.""" tasks = [asyncio.create_task(service.run()) for service in self._services] @@ -182,7 +199,7 @@ async def run(self): except asyncio.CancelledError: logger.error("Service did not handle cancellation gracefully.") - def shutdown(self, sig): + def shutdown(self, sig: str) -> None: logger.info(f"Caught {sig}. 
Graceful shutdown.") for service in self._services: diff --git a/connectors/services/content_sync_job_execution.py b/connectors/services/content_sync_job_execution.py index e83e5c521..f6959ed55 100644 --- a/connectors/services/content_sync_job_execution.py +++ b/connectors/services/content_sync_job_execution.py @@ -5,6 +5,8 @@ # from functools import cached_property +from typing import Dict, List, Union +from unittest.mock import Mock from connectors.protocol import JobStatus, JobType from connectors.services.job_execution import JobExecutionService @@ -13,15 +15,35 @@ class ContentSyncJobExecutionService(JobExecutionService): name = "sync_content" - def __init__(self, config): + def __init__( + self, + config: Dict[ + str, + Union[ + str, + Dict[str, Union[float, int, str]], + List[Dict[str, str]], + Dict[str, str], + Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + ], + ], + ], + ], + ) -> None: super().__init__(config, "content_sync_job_execution_service") @cached_property - def display_name(self): + def display_name(self) -> str: return "content sync job execution" @cached_property - def max_concurrency_config(self): + def max_concurrency_config(self) -> str: return "service.max_concurrent_content_syncs" @cached_property @@ -32,7 +54,7 @@ def job_types(self): def max_concurrency(self): return self.service_config.get("max_concurrent_content_syncs") - def should_execute(self, connector, sync_job): + def should_execute(self, connector: Mock, sync_job: Mock) -> bool: if connector.last_sync_status == JobStatus.IN_PROGRESS: sync_job.log_debug("Connector is still syncing content, skip the job...") return False diff --git a/connectors/services/job_cleanup.py b/connectors/services/job_cleanup.py index 52c3d8eea..761579c37 100644 --- a/connectors/services/job_cleanup.py +++ b/connectors/services/job_cleanup.py @@ -7,6 +7,8 @@ A task periodically clean up orphaned and idle jobs. 
""" +from typing import Dict, List, Union + from connectors.es.index import DocumentNotFoundError from connectors.es.management_client import ESManagementClient from connectors.protocol import ConnectorIndex, SyncJobIndex @@ -18,13 +20,15 @@ class JobCleanUpService(BaseService): name = "cleanup" - def __init__(self, config): + def __init__( + self, config: Dict[str, Union[Dict[str, str], Dict[str, int], List[str]]] + ) -> None: super().__init__(config, "job_cleanup_service") self.idling = int(self.service_config.get("job_cleanup_interval", 60 * 5)) self.native_service_types = self.config.get("native_service_types", []) or [] self.connector_ids = list(self.connectors.keys()) - async def _run(self): + async def _run(self) -> int: self.logger.debug("Successfully started Job cleanup task...") self.connector_index = ConnectorIndex(self.es_config) self.es_management_client = ESManagementClient(self.es_config) @@ -46,7 +50,7 @@ async def _run(self): await self.sync_job_index.close() return 0 - async def _process_orphaned_idle_jobs(self): + async def _process_orphaned_idle_jobs(self) -> None: try: self.logger.debug("Cleaning up orphaned idle jobs") connector_ids = [ @@ -78,7 +82,7 @@ async def _process_orphaned_idle_jobs(self): self.logger.error(e, exc_info=True) self.raise_if_spurious(e) - async def _process_idle_jobs(self): + async def _process_idle_jobs(self) -> None: try: self.logger.debug("Start cleaning up idle jobs...") connector_ids = [ diff --git a/connectors/services/job_execution.py b/connectors/services/job_execution.py index e7788afb5..74047613b 100644 --- a/connectors/services/job_execution.py +++ b/connectors/services/job_execution.py @@ -4,6 +4,8 @@ # you may not use this file except in compliance with the Elastic License 2.0. # from functools import cached_property +from typing import Dict, List, Union +from unittest.mock import Mock from connectors.es.client import License from connectors.es.index import DocumentNotFoundError @@ -22,13 +24,34 @@ class JobExecutionService(BaseService): name = "execute" - def __init__(self, config, service_name): + def __init__( + self, + config: Dict[ + str, + Union[ + str, + Dict[str, Union[float, int, str]], + List[Dict[str, str]], + Dict[str, str], + Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + ], + ], + ], + ], + service_name: str, + ) -> None: super().__init__(config, service_name) self.idling = self.service_config["idling"] self.source_list = config["sources"] self.sync_job_pool = ConcurrentTasks(max_concurrency=self.max_concurrency) - def stop(self): + def stop(self) -> None: super().stop() self.sync_job_pool.cancel() @@ -51,7 +74,7 @@ def max_concurrency(self): def should_execute(self, connector, sync_job): raise NotImplementedError() - async def _sync(self, sync_job): + async def _sync(self, sync_job: Mock) -> None: if sync_job.service_type not in self.source_list: msg = f"Couldn't find data source class for {sync_job.service_type}" raise DataSourceError(msg) @@ -96,7 +119,7 @@ async def _sync(self, sync_job): f"{self.display_name.capitalize()} service is already running {self.max_concurrency} sync jobs and can't run more at this poinit. Increase '{self.max_concurrency_config}' in config if you want the service to run more sync jobs." 
# pyright: ignore ) - async def _run(self): + async def _run(self) -> int: self.connector_index = ConnectorIndex(self.es_config) self.sync_job_index = SyncJobIndex(self.es_config) diff --git a/connectors/services/job_scheduling.py b/connectors/services/job_scheduling.py index 079da3ad4..3a4efd96f 100644 --- a/connectors/services/job_scheduling.py +++ b/connectors/services/job_scheduling.py @@ -13,7 +13,10 @@ import functools from datetime import datetime, timezone +from typing import Dict, List, Union +from unittest.mock import Mock +import connectors.protocol.connectors from connectors.es.client import License, with_concurrency_control from connectors.es.index import DocumentNotFoundError from connectors.protocol import ( @@ -34,7 +37,27 @@ class JobSchedulingService(BaseService): name = "schedule" - def __init__(self, config): + def __init__( + self, + config: Dict[ + str, + Union[ + str, + Dict[str, Union[float, int, str]], + List[Dict[str, str]], + Dict[str, str], + Dict[ + str, + Union[ + str, + bool, + Dict[str, Union[int, bool, Dict[str, Union[bool, int, float]]]], + int, + ], + ], + ], + ], + ) -> None: super().__init__(config, "job_scheduling_service") self.idling = self.service_config["idling"] self.heartbeat_interval = self.service_config["heartbeat"] @@ -46,11 +69,11 @@ def __init__(self, config): ) self.schedule_tasks_pool = ConcurrentTasks(max_concurrency=self.max_concurrency) - def stop(self): + def stop(self) -> None: super().stop() self.schedule_tasks_pool.cancel() - async def _schedule(self, connector): + async def _schedule(self, connector: Mock) -> None: # To do some first-time stuff just_started = self.first_run self.first_run = False @@ -149,7 +172,7 @@ async def _schedule(self, connector): await self._try_schedule_sync(connector, JobType.FULL) - async def _run(self): + async def _run(self) -> int: """Main event loop.""" self.connector_index = ConnectorIndex(self.es_config) self.sync_job_index = SyncJobIndex(self.es_config) @@ -203,7 +226,9 @@ async def _run(self): await self.sync_job_index.close() return 0 - async def _try_schedule_sync(self, connector, job_type): + async def _try_schedule_sync( + self, connector: Mock, job_type: connectors.protocol.connectors.JobType + ) -> None: this_wake_up_time = datetime.now(timezone.utc) last_wake_up_time = self.last_wake_up_time diff --git a/connectors/source.py b/connectors/source.py index 9129af439..13fe1bdf0 100644 --- a/connectors/source.py +++ b/connectors/source.py @@ -13,7 +13,9 @@ from decimal import Decimal from enum import Enum from functools import cache +from os import PathLike from pydoc import locate +from typing import Dict, List, Type, Union import aiofiles from aiofiles.os import remove @@ -26,6 +28,7 @@ BasicRuleAgainstSchemaValidator, BasicRuleNoMatchAllRegexValidator, BasicRulesSetSemanticValidator, + FilteringValidationResult, FilteringValidator, ) from connectors.logger import logger @@ -37,7 +40,7 @@ hash_id, ) -CHUNK_SIZE = 1024 * 64 # 64KB default SSD page size +CHUNK_SIZE: int = 1024 * 64 # 64KB default SSD page size CURSOR_SYNC_TIMESTAMP = "cursor_timestamp" DEFAULT_CONFIGURATION = { @@ -81,11 +84,11 @@ def __init__( default_value=None, depends_on=None, label=None, - required=True, - field_type="str", + required: bool = True, + field_type: str = "str", validations=None, value=None, - ): + ) -> None: if depends_on is None: depends_on = [] if label is None: @@ -129,7 +132,7 @@ def value(self): def value(self, value): self._value = value - def _convert(self, value, field_type_): + def _convert(self, 
value, field_type_: str): cast_type = locate(field_type_) if cast_type not in TYPE_DEFAULTS: # unsupported type @@ -166,7 +169,7 @@ def _convert(self, value, field_type_): return cast_type(value) - def is_value_empty(self): + def is_value_empty(self) -> bool: """Checks if the `value` field is empty or not. This always checks `value` and never `default_value`. """ @@ -185,7 +188,7 @@ def is_value_empty(self): # int and bool return value is None - def validate(self): + def validate(self) -> List[str]: """Used to validate the `value` of a Field using its `validations`. If `value` is empty and the field is not required, the validation is run on the `default_value` instead. @@ -263,7 +266,7 @@ def validate(self): class DataSourceConfiguration: """Holds the configuration needed by the source class""" - def __init__(self, config): + def __init__(self, config) -> None: self._raw_config = config self._config = {} self._defaults = {} @@ -283,7 +286,7 @@ def __init__(self, config): else: self.set_field(key, label=key.capitalize(), value=str(value)) - def set_defaults(self, default_config): + def set_defaults(self, default_config) -> None: for name, item in default_config.items(): self._defaults[name] = item.get("value") if name in self._config: @@ -308,11 +311,11 @@ def set_field( default_value=None, depends_on=None, label=None, - required=True, - field_type="str", + required: bool = True, + field_type: str = "str", validations=None, value=None, - ): + ) -> None: self._config[name] = Field( name, default_value, @@ -330,13 +333,13 @@ def get_field(self, name): def get_fields(self): return self._config.values() - def is_empty(self): + def is_empty(self) -> bool: return len(self._config) == 0 def to_dict(self): return dict(self._raw_config) - def check_valid(self): + def check_valid(self) -> None: """Validates every Field against its `validations`. Raises ConfigurableFieldValueError if any validation errors are found. @@ -364,7 +367,7 @@ def check_valid(self): msg = f"Field validation errors: {'; '.join(validation_errors)}" raise ConfigurableFieldValueError(msg) - def dependencies_satisfied(self, field): + def dependencies_satisfied(self, field) -> bool: """Used to check if a Field has its dependencies satisfied. Returns True if all dependencies are satisfied, or no dependencies exist. @@ -419,20 +422,20 @@ def __init__(self, configuration): # this will be overwritten by set_framework_config() self.framework_config = DataSourceFrameworkConfig.Builder().build() - def __str__(self): + def __str__(self) -> str: return f"Datasource `{self.__class__.name}`" - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ self._set_internal_logger() - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: # no op for BaseDataSource # if there are internal class (e.g. Client class) to which the logger need to be set, # this method needs to be implemented pass - def set_framework_config(self, framework_config): + def set_framework_config(self, framework_config) -> None: """Called by the framework, this exposes framework-wide configuration to be used by the DataSource""" self.framework_config = framework_config @@ -459,7 +462,17 @@ def get_default_configuration(cls): raise NotImplementedError @classmethod - def basic_rules_validators(cls): + def basic_rules_validators( + cls, + ) -> List[ + Type[ + Union[ + BasicRuleAgainstSchemaValidator, + BasicRuleNoMatchAllRegexValidator, + BasicRulesSetSemanticValidator, + ] + ] + ]: """Return default basic rule validators. 
Basic rule validators are executed in the order they appear in the list. @@ -472,7 +485,7 @@ def basic_rules_validators(cls): ] @classmethod - def hash_id(cls, _id): + def hash_id(cls, _id) -> str: """Called, when an `_id` is too long to be ingested into elasticsearch. This method can be overridden to execute a hash function on a document `_id`, @@ -483,7 +496,7 @@ def hash_id(cls, _id): return hash_id(_id) @classmethod - def features(cls): + def features(cls) -> Dict[str, Union[Dict[str, Dict[str, bool]], Dict[str, bool]]]: """Returns features available for the data source""" return { "sync_rules": { @@ -505,13 +518,13 @@ def features(cls): }, } - def set_features(self, features): + def set_features(self, features) -> None: if self._features is not None: self._logger.warning(f"'_features' already set in {self.__class__.name}") self._logger.debug(f"Setting '_features' for {self.__class__.name}") self._features = features - async def validate_filtering(self, filtering): + async def validate_filtering(self, filtering) -> FilteringValidationResult: """Execute all basic rule and advanced rule validators.""" return await FilteringValidator( @@ -528,7 +541,7 @@ def advanced_rules_validators(self): """ return [] - async def changed(self): + async def changed(self) -> bool: """When called, returns True if something has changed in the backend. Otherwise, returns False and the next sync is skipped. @@ -538,7 +551,7 @@ async def changed(self): """ return True - async def validate_config(self): + async def validate_config(self) -> None: """When called, validates configuration of the connector that is contained in self.configuration If connector configuration is invalid, this method will raise an exception @@ -546,7 +559,7 @@ async def validate_config(self): """ self.configuration.check_valid() - def validate_config_fields(self): + def validate_config_fields(self) -> None: """ "Checks if any fields in a configuration are missing. If a field is missing, raises an error. Ignores additional non-standard fields. @@ -574,7 +587,7 @@ async def ping(self): """ raise NotImplementedError - async def close(self): + async def close(self) -> None: """Called when the source is closed. Can be used to close connections @@ -646,7 +659,7 @@ async def get_file(doit=True, timestamp=None): """ raise NotImplementedError - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options) -> None: """Receives the bulk options every time a sync happens, so they can be tweaked if needed. @@ -699,7 +712,7 @@ def sync_cursor(self): return self._sync_cursor @staticmethod - def is_premium(): + def is_premium() -> bool: """Returns True if this DataSource is a Premium (paid license gated) connector. Otherwise, returns False. @@ -710,12 +723,12 @@ def is_premium(): def get_file_extension(self, filename): return get_file_extension(filename) - def can_file_be_downloaded(self, file_extension, filename, file_size): + def can_file_be_downloaded(self, file_extension, filename, file_size) -> bool: return self.is_valid_file_type( file_extension, filename ) and self.is_file_size_within_limit(file_size, filename) - def is_valid_file_type(self, file_extension, filename): + def is_valid_file_type(self, file_extension, filename) -> bool: if file_extension == "": self._logger.debug( f"Files without extension are not supported, skipping {filename}." 
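can_file_be_downloaded, annotated above, is the conjunction of the extension check and the size check. A simplified, self-contained version of that gate follows; the extension allow-list and size limit are invented for illustration:

# Simplified, standalone version of the download gate sketched above: reject files
# with no extension, an unsupported extension, or a size over the limit.
ALLOWED_EXTENSIONS = {".txt", ".pdf", ".docx"}  # illustrative allow-list
MAX_FILE_SIZE = 10 * 1024 * 1024  # illustrative 10 MiB limit


def can_download(filename: str, file_extension: str, file_size: int) -> bool:
    if file_extension == "":
        print(f"Files without extension are not supported, skipping {filename}.")
        return False
    if file_extension.lower() not in ALLOWED_EXTENSIONS:
        print(f"Unsupported extension {file_extension}, skipping {filename}.")
        return False
    if file_size > MAX_FILE_SIZE:
        print(f"{filename} is larger than {MAX_FILE_SIZE} bytes, skipping.")
        return False
    return True


print(can_download("report.pdf", ".pdf", 2_000_000))  # True
print(can_download("archive", "", 1_000))             # False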
@@ -730,7 +743,7 @@ def is_valid_file_type(self, file_extension, filename): return True - def is_file_size_within_limit(self, file_size, filename): + def is_file_size_within_limit(self, file_size, filename) -> bool: if ( file_size > self.framework_config.max_file_size and not self.configuration.get("use_text_extraction_service") @@ -748,7 +761,7 @@ async def download_and_extract_file( source_filename, file_extension, download_func, - return_doc_if_failed=False, + return_doc_if_failed: bool = False, ): """ Performs all the steps required for handling binary content: @@ -808,7 +821,7 @@ async def create_temp_file(self, file_extension): async def download_to_temp_file( self, temp_filename, source_filename, async_buffer, chunked_download_func - ): + ) -> None: self._logger.debug(f"Download beginning for file: {source_filename}.") async for data in chunked_download_func(): await async_buffer.write(data) @@ -850,7 +863,9 @@ async def handle_file_content_extraction(self, doc, source_filename, temp_filena return doc - async def remove_temp_file(self, temp_filename): + async def remove_temp_file( + self, temp_filename: Union[PathLike[bytes], PathLike[str], bytes, str] + ) -> None: try: await remove(temp_filename) except Exception as e: @@ -864,7 +879,7 @@ def last_sync_time(self): return default_time return self._sync_cursor.get(CURSOR_SYNC_TIMESTAMP, default_time) - def update_sync_timestamp_cursor(self, timestamp): + def update_sync_timestamp_cursor(self, timestamp) -> None: if self._sync_cursor is None: self._sync_cursor = {} self._sync_cursor[CURSOR_SYNC_TIMESTAMP] = timestamp diff --git a/connectors/sources/atlassian.py b/connectors/sources/atlassian.py index b44cb4311..c18f74f90 100644 --- a/connectors/sources/atlassian.py +++ b/connectors/sources/atlassian.py @@ -4,6 +4,14 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
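The BaseDataSource helpers annotated above (get_file_extension, can_file_be_downloaded, download_and_extract_file) form the usual pipeline a concrete connector runs before ingesting binary content. A minimal sketch of how a subclass typically wires them together, assuming the usual `from connectors.source import BaseDataSource` import; the attachment shape and the _chunked_download helper are invented for illustration, not taken from this diff:

    from functools import partial

    class MyDataSource(BaseDataSource):  # hypothetical subclass for illustration
        async def get_content(self, attachment, timestamp=None, doit=False):
            filename = attachment["name"]
            file_extension = self.get_file_extension(filename)
            if not (doit and self.can_file_be_downloaded(file_extension, filename, attachment["size"])):
                return
            document = {"_id": attachment["id"], "_timestamp": attachment["modified_at"]}
            # download_func is an async generator yielding byte chunks, which
            # download_to_temp_file consumes with `async for`.
            return await self.download_and_extract_file(
                document, filename, file_extension, partial(self._chunked_download, attachment)
            )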
# +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union + +if TYPE_CHECKING: + from connectors.sources.confluence import ConfluenceClient, ConfluenceDataSource + from connectors.sources.jira import JiraClient, JiraDataSource + import fastjsonschema from fastjsonschema import JsonSchemaValueException @@ -34,10 +42,12 @@ class AtlassianAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: Type[AdvancedRulesValidator]) -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: Union[List[Dict[str, str]], Dict[str, List[str]]] + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -50,7 +60,9 @@ async def validate(self, advanced_rules): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, advanced_rules: Union[List[Dict[str, str]], Dict[str, List[str]]] + ) -> SyncRuleValidationResult: try: AtlassianAdvancedRulesValidator.SCHEMA(advanced_rules) except JsonSchemaValueException as e: @@ -65,46 +77,52 @@ async def _remote_validation(self, advanced_rules): ) -def prefix_account_id(account_id): +def prefix_account_id(account_id: str) -> Optional[str]: return prefix_identity("account_id", account_id) -def prefix_group_id(group_id): +def prefix_group_id(group_id: str) -> Optional[str]: return prefix_identity("group_id", group_id) -def prefix_role_key(role_key): +def prefix_role_key(role_key: str) -> Optional[str]: return prefix_identity("role_key", role_key) -def prefix_account_name(account_name): +def prefix_account_name(account_name: str) -> Optional[str]: return prefix_identity("name", account_name.replace(" ", "-")) -def prefix_account_email(email): +def prefix_account_email(email: str) -> Optional[str]: return prefix_identity("email_address", email) -def prefix_account_locale(locale): +def prefix_account_locale(locale: str) -> Optional[str]: return prefix_identity("locale", locale) -def prefix_user(user): +def prefix_user(user: str) -> Optional[str]: if not user: return return prefix_identity("user", user) -def prefix_group(group): +def prefix_group(group: str) -> Optional[str]: return prefix_identity("group", group) class AtlassianAccessControl: - def __init__(self, source, client): + def __init__( + self, + source: Union[JiraDataSource, ConfluenceDataSource], + client: Union[Type[JiraClient], ConfluenceClient, JiraClient], + ) -> None: self.source = source self.client = client - def access_control_query(self, access_control): + def access_control_query( + self, access_control: List[str] + ) -> Dict[str, Dict[str, Dict[str, Union[Dict[str, List[str]], str]]]]: return es_access_control_query(access_control) async def fetch_all_users(self, url): @@ -155,7 +173,10 @@ async def fetch_user_for_confluence(self, url): user = await self.client.api_call(url=url) yield await user.json() - async def user_access_control_doc(self, user): + async def user_access_control_doc( + self, + user: Dict[str, Union[str, bool, Dict[str, Union[List[Dict[str, str]], int]]]], + ) -> Dict[str, Any]: """Generate a user access control document. This method generates a user access control document based on the provided user information. 
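The prefix helpers above namespace Atlassian identities before they are attached to documents, and AtlassianAccessControl.access_control_query turns such a list into the DLS filter that user_access_control_doc merges into each access control document. A hedged illustration; the identity values are invented and the exact "prefix:value" output format is an assumption based on the helper names:

    access_control = [
        prefix_account_id("5b10ac8d82e05b22cc7d4ef5"),  # e.g. "account_id:5b10ac8d82e05b22cc7d4ef5" (assumed format)
        prefix_group("confluence-users"),               # e.g. "group:confluence-users"
        prefix_account_email("alice@example.com"),      # e.g. "email_address:alice@example.com"
    ]
    # `atlassian_access_control` stands in for a constructed AtlassianAccessControl instance.
    query = atlassian_access_control.access_control_query(access_control=access_control)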
@@ -216,7 +237,7 @@ async def user_access_control_doc(self, user): return user_document | self.access_control_query(access_control=access_control) - def is_active_atlassian_user(self, user_info): + def is_active_atlassian_user(self, user_info: Dict[str, Union[bool, str]]) -> bool: from connectors.sources.confluence import CONFLUENCE_CLOUD from connectors.sources.jira import JIRA_CLOUD diff --git a/connectors/sources/azure_blob_storage.py b/connectors/sources/azure_blob_storage.py index d975880cb..dd7b76244 100644 --- a/connectors/sources/azure_blob_storage.py +++ b/connectors/sources/azure_blob_storage.py @@ -5,11 +5,15 @@ # """Azure Blob Storage source module responsible to fetch documents from Azure Blob Storage""" +from _asyncio import Future +from datetime import datetime from functools import partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Union +import azure.storage.blob.aio._container_client_async from azure.storage.blob.aio import BlobServiceClient, ContainerClient -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration BLOB_SCHEMA = { "title": "name", @@ -31,7 +35,7 @@ class AzureBlobStorageDataSource(BaseDataSource): service_type = "azure_blob_storage" incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the connection to the azure base client Args: @@ -44,7 +48,7 @@ def __init__(self, configuration): self.containers = self.configuration["containers"] self.container_clients = {} - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options: Dict[str, int]) -> None: """Tweak bulk options as per concurrent downloads support by azure blob storage Args: @@ -56,7 +60,16 @@ def tweak_bulk_options(self, options): options["concurrent_downloads"] = self.concurrent_downloads @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[int, str]]], List[str], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Azure Blob Storage Returns: @@ -116,7 +129,7 @@ def get_default_configuration(cls): }, } - def _configure_connection_string(self): + def _configure_connection_string(self) -> str: """Generates connection string for ABS Returns: @@ -125,7 +138,7 @@ def _configure_connection_string(self): return f'AccountName={self.configuration["account_name"]};AccountKey={self.configuration["account_key"]};BlobEndpoint={self.configuration["blob_endpoint"]}' - async def ping(self): + async def ping(self) -> None: """Verify the connection with Azure Blob Storage""" self._logger.info("Generating connection string...") self.connection_string = self._configure_connection_string() @@ -139,14 +152,18 @@ async def ping(self): self._logger.exception("Error while connecting to the Azure Blob Storage.") raise - async def close(self): + async def close(self) -> None: if not self.container_clients: return for container_client in self.container_clients.values(): await container_client.close() self.container_clients = {} - def prepare_blob_doc(self, blob, container_metadata): + def prepare_blob_doc( + self, + blob: Dict[str, Union[str, Dict[str, str], datetime, int]], + container_metadata: Dict[str, str], + ) -> Dict[str, Any]: """Prepare key mappings to blob document Args: @@ -169,7 +186,12 @@ def prepare_blob_doc(self, blob, 
container_metadata): document[elasticsearch_field] = blob[azure_blob_storage_field] return document - async def get_content(self, blob, timestamp=None, doit=None): + async def get_content( + self, + blob: Dict[str, Union[int, str]], + timestamp: None = None, + doit: Optional[bool] = None, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: """Get blob content via specific blob client Args: @@ -203,7 +225,9 @@ async def get_content(self, blob, timestamp=None, doit=None): partial(self.blob_download_func, filename, blob["container"], file_size), ) - def _get_container_client(self, container_name): + def _get_container_client( + self, container_name: str + ) -> azure.storage.blob.aio._container_client_async.ContainerClient: if self.container_clients.get(container_name) is None: try: self.container_clients[container_name] = ( @@ -221,7 +245,7 @@ def _get_container_client(self, container_name): else: return self.container_clients[container_name] - async def blob_download_func(self, blob_name, container_name, file_size): + async def blob_download_func(self, blob_name, container_name: str, file_size): container_client = self._get_container_client(container_name=container_name) offset = 0 length = INITIAL_DOWNLOAD_SIZE @@ -238,7 +262,7 @@ async def blob_download_func(self, blob_name, container_name, file_size): file_size = file_size - length yield content - async def get_container(self, container_list): + async def get_container(self, container_list: List[str]) -> Iterator[None]: """Get containers from Azure Blob Storage via azure base client Args: container_list (list): List of containers @@ -272,7 +296,7 @@ async def get_container(self, container_list): f"Something went wrong while fetching containers. Error: {exception}" ) - async def get_blob(self, container): + async def get_blob(self, container: Dict[str, Union[Dict[str, str], str]]) -> None: """Get blobs for a specific container from Azure Blob Storage via container client Args: diff --git a/connectors/sources/box.py b/connectors/sources/box.py index 7bf2eeb22..58c3402dc 100644 --- a/connectors/sources/box.py +++ b/connectors/sources/box.py @@ -8,17 +8,20 @@ import asyncio import logging import os +from _asyncio import Future, Task from datetime import datetime, timedelta from functools import cached_property, partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Union import aiofiles import aiohttp from aiofiles.os import remove from aiofiles.tempfile import NamedTemporaryFile +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import ( TIKA_SUPPORTED_FILETYPES, CacheWithTimeout, @@ -43,7 +46,7 @@ RETRY_INTERVAL = 2 CHUNK_SIZE = 1024 FETCH_LIMIT = 1000 -QUEUE_MEM_SIZE = 5 * 1024 * 1024 # ~ 5 MB +QUEUE_MEM_SIZE: int = 5 * 1024 * 1024 # ~ 5 MB MAX_CONCURRENCY = 2000 MAX_CONCURRENT_DOWNLOADS = 15 FIELDS = "name,modified_at,size,type,sequence_id,etag,created_at,modified_at,content_created_at,content_modified_at,description,created_by,modified_by,owned_by,parent,item_status" @@ -54,9 +57,9 @@ refresh_token = None if "BOX_BASE_URL" in os.environ: - BASE_URL = os.environ.get("BOX_BASE_URL") + BASE_URL: Optional[str] = os.environ.get("BOX_BASE_URL") else: - BASE_URL = "https://api.box.com" + BASE_URL: Optional[str] = "https://api.box.com" class TokenError(Exception): @@ -68,7 +71,9 @@ class 
NotFound(Exception): class AccessToken: - def __init__(self, configuration, http_session): + def __init__( + self, configuration: DataSourceConfiguration, http_session: ClientSession + ) -> None: global refresh_token self.client_id = configuration["client_id"] self.client_secret = configuration["client_secret"] @@ -79,14 +84,14 @@ def __init__(self, configuration, http_session): self.is_enterprise = configuration["is_enterprise"] self.enterprise_id = configuration["enterprise_id"] - async def get(self): + async def get(self) -> str: if cached_value := self._token_cache.get_value(): return cached_value logger.debug("No token cache found; fetching new token") await self._set_access_token() return self.access_token - async def _set_access_token(self): + async def _set_access_token(self) -> None: logger.debug("Generating an access token") try: if self.is_enterprise == BOX_FREE: @@ -136,7 +141,7 @@ async def _set_access_token(self): class BoxClient: - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger @@ -148,10 +153,10 @@ def __init__(self, configuration): configuration=configuration, http_session=self._http_session ) - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - async def _put_to_sleep(self, retry_after): + async def _put_to_sleep(self, retry_after: int) -> Iterator[Task]: self._logger.debug( f"Connector will attempt to retry after {retry_after} seconds." ) @@ -159,7 +164,9 @@ async def _put_to_sleep(self, retry_after): msg = "Rate limit exceeded." raise Exception(msg) - def debug_query_string(self, params): + def debug_query_string( + self, params: Optional[Dict[str, Union[int, str]]] + ) -> Optional[str]: if self._logger.isEnabledFor(logging.DEBUG): return ( "&".join(f"{key}={value}" for key, value in params.items()) @@ -167,7 +174,7 @@ def debug_query_string(self, params): else "" ) - async def _handle_client_errors(self, exception): + async def _handle_client_errors(self, exception: ClientResponseError) -> None: match exception.status: case 401: await self.token._set_access_token() @@ -187,7 +194,12 @@ async def _handle_client_errors(self, exception): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=NotFound, ) - async def get(self, url, headers, params=None): + async def get( + self, + url: str, + headers: Dict[str, str], + params: Optional[Dict[str, Union[int, str]]] = None, + ) -> Iterator[Task]: self._logger.debug( f"Calling GET {url}?{self.debug_query_string(params=params)}" ) @@ -200,7 +212,9 @@ async def get(self, url, headers, params=None): except Exception: raise - async def paginated_call(self, url, params, headers): + async def paginated_call( + self, url: str, params: Dict[str, str], headers: Dict[Any, Any] + ) -> Iterator[Task]: try: offset = 0 while True: @@ -216,10 +230,10 @@ async def paginated_call(self, url, params, headers): except Exception: raise - async def ping(self): + async def ping(self) -> None: await self.get(url=ENDPOINTS["PING"], headers={}) - async def close(self): + async def close(self) -> None: self._sleeps.cancel() await self._http_session.close() @@ -229,7 +243,7 @@ class BoxDataSource(BaseDataSource): service_type = "box" incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self.configuration = configuration 
self.tasks = 0 @@ -238,14 +252,14 @@ def __init__(self, configuration): self.fetchers = ConcurrentTasks(max_concurrency=MAX_CONCURRENCY) self.concurrent_downloads = configuration["concurrent_downloads"] - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(logger_=self._logger) @cached_property - def client(self): + def client(self) -> BoxClient: return BoxClient(configuration=self.configuration) - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options) -> None: """Tweak bulk options as per concurrent downloads support by Box Args: @@ -254,7 +268,16 @@ def tweak_bulk_options(self, options): options["concurrent_downloads"] = self.concurrent_downloads @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[Dict[str, Union[int, str]]], List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Box. Returns: @@ -310,12 +333,12 @@ def get_default_configuration(cls): }, } - async def close(self): + async def close(self) -> None: while not self.queue.empty(): await self.queue.get() await self.client.close() - async def ping(self): + async def ping(self) -> None: try: await self.client.ping() self._logger.debug("Successfully connected to Box") @@ -330,7 +353,7 @@ async def get_users_id(self): ): yield user.get("id") - async def _fetch(self, doc_id, user_id=None): + async def _fetch(self, doc_id: Union[str, int], user_id: None = None) -> None: self._logger.info( f"Fetching files and folders recursively for folder ID: {doc_id}" ) @@ -375,7 +398,9 @@ async def _fetch(self, doc_id, user_id=None): finally: await self.queue.put(FINISHED) - async def _get_document_with_content(self, url, attachment_name, document, user_id): + async def _get_document_with_content( + self, url: str, attachment_name: str, document: Dict[str, str], user_id: None + ) -> Generator[Future, None, Dict[str, str]]: file_data = await self.client.get( url=url, headers={"as-user": user_id} if user_id else {} ) @@ -400,8 +425,8 @@ async def _get_document_with_content(self, url, attachment_name, document, user_ return document def _pre_checks_for_get_content( - self, attachment_extension, attachment_name, attachment_size - ): + self, attachment_extension: str, attachment_name: str, attachment_size: int + ) -> bool: if attachment_extension == "": self._logger.debug( f"Files without extension are not supported, skipping {attachment_name}." @@ -421,7 +446,13 @@ def _pre_checks_for_get_content( return False return True - async def get_content(self, attachment, user_id=None, timestamp=None, doit=False): + async def get_content( + self, + attachment: Dict[str, Union[int, str]], + user_id: None = None, + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: """Extracts the content for Apache TIKA supported file types. 
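The Box AccessToken above caches the OAuth token and regenerates it on demand (and again on 401 responses, as _handle_client_errors shows). A minimal, hypothetical usage sketch; the session setup and the Authorization header shape are assumptions, not code from this diff:

    import aiohttp

    async def fetch_auth_headers(configuration):
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            token = AccessToken(configuration=configuration, http_session=session)
            access_token = await token.get()  # served from the cache while still valid
            return {"Authorization": f"Bearer {access_token}"}  # assumed header shape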
Args: diff --git a/connectors/sources/confluence.py b/connectors/sources/confluence.py index 1067e4948..f058c02bf 100644 --- a/connectors/sources/confluence.py +++ b/connectors/sources/confluence.py @@ -7,16 +7,24 @@ import asyncio import os +from _asyncio import Future, Task from copy import copy from functools import partial +from typing import Any, Dict, Generator, Iterator, List, Set, Union +from unittest.mock import AsyncMock from urllib.parse import urljoin import aiohttp +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError, ServerDisconnectedError from connectors.access_control import ACCESS_CONTROL from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.sources.atlassian import ( AtlassianAccessControl, AtlassianAdvancedRulesValidator, @@ -63,7 +71,7 @@ USER_QUERY = "expand=groups,applicationRoles" LABEL = "label" -URLS = { +URLS: Dict[str, str] = { SPACE: "rest/api/space?{api_query}", SPACE_PERMISSION: "rest/extender/1.0/permission/space/{space_key}/getSpacePermissionActors/VIEWSPACE", CONTENT: "rest/api/content/search?{api_query}", @@ -79,7 +87,7 @@ MAX_CONCURRENT_DOWNLOADS = 50 # Max concurrent download supported by confluence MAX_CONCURRENCY = 50 QUEUE_SIZE = 1024 -QUEUE_MEM_SIZE = 25 * 1024 * 1024 # Size in Megabytes +QUEUE_MEM_SIZE: int = 25 * 1024 * 1024 # Size in Megabytes SERVER_USER_BATCH = 1000 DATACENTER_USER_BATCH = 200 END_SIGNAL = "FINISHED_TASK" @@ -123,7 +131,7 @@ class Forbidden(Exception): class ConfluenceClient: """Confluence client to handle API calls made to Confluence""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger @@ -142,10 +150,10 @@ def __init__(self, configuration): self.ssl_ctx = False self.session = None - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def _get_session(self): + def _get_session(self) -> ClientSession: """Generate and return base client session with configuration fields Returns: @@ -189,7 +197,7 @@ def _get_session(self): ) return self.session - async def close_session(self): + async def close_session(self) -> None: """Closes unclosed client session""" self._sleeps.cancel() if self.session is None: @@ -197,7 +205,9 @@ async def close_session(self): await self.session.close() self.session = None - async def _handle_client_errors(self, url, exception): + async def _handle_client_errors( + self, url: str, exception: Union[TypeError, ValueError] + ) -> Iterator[Task]: if exception.status == 429: response_headers = exception.headers or {} retry_seconds = DEFAULT_RETRY_SECONDS @@ -245,7 +255,9 @@ async def _handle_client_errors(self, url, exception): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=NotFound, ) - async def api_call(self, url): + async def api_call( + self, url: str + ) -> Generator[Union[Task, Future], None, AsyncMock]: """Make a GET call for Atlassian API using the passed url with retry for the failed API calls. 
Args: @@ -273,7 +285,7 @@ async def api_call(self, url): except ClientResponseError as exception: await self._handle_client_errors(url=url, exception=exception) - async def paginated_api_call(self, url_name, **url_kwargs): + async def paginated_api_call(self, url_name: str, **url_kwargs) -> Iterator[Future]: """Make a paginated API call for Confluence objects using the passed url_name. Args: url_name (str): URL Name to identify the API endpoint to hit @@ -370,7 +382,9 @@ async def fetch_spaces(self): if (spaces == [WILDCARD]) or (space.get("key", "") in spaces): yield space - async def fetch_server_space_permission(self, url): + async def fetch_server_space_permission( + self, url: str + ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: try: permissions = await self.api_call(url=os.path.join(self.host_url, url)) permission = await permissions.json() @@ -399,7 +413,7 @@ async def fetch_page_blog_documents(self, api_query): document["labels"] = labels yield document, attachment_count - async def fetch_attachments(self, content_id): + async def fetch_attachments(self, content_id: str) -> Iterator[Future]: async for response in self.paginated_api_call( url_name=ATTACHMENT, api_query=ATTACHMENT_QUERY, @@ -408,12 +422,12 @@ async def fetch_attachments(self, content_id): for attachment in response.get("results", []): yield attachment - async def ping(self): + async def ping(self) -> None: await self.api_call( url=os.path.join(self.host_url, PING_URL), ) - async def fetch_confluence_server_users(self): + async def fetch_confluence_server_users(self) -> None: start_at = 0 if self.data_source_type == CONFLUENCE_DATA_CENTER: limit = DATACENTER_USER_BATCH @@ -433,7 +447,7 @@ async def fetch_confluence_server_users(self): yield response.get(key) start_at += limit - async def fetch_label(self, label_id): + async def fetch_label(self, label_id: int) -> List[None]: url = os.path.join(self.host_url, URLS[LABEL].format(id=label_id)) label_data = await self.api_call(url=url) labels = await label_data.json() @@ -449,7 +463,7 @@ class ConfluenceDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to Confluence Args: @@ -472,11 +486,22 @@ def __init__(self, configuration): else "publicName" ) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.confluence_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[Dict[str, Union[int, str]]], List[str], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Confluence Returns: @@ -614,7 +639,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. This method checks whether document level security (DLS) is enabled based on the provided configuration. 
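paginated_api_call above identifies the endpoint by url_name (the URLS templates above take the remaining keyword arguments) and keeps yielding pages until the results run out; fetch_attachments is one such caller. A short usage sketch along the same lines, with an invented content id and a hypothetical `confluence_client` instance:

    async for page in confluence_client.paginated_api_call(
        url_name=ATTACHMENT,          # endpoint template resolved by name and formatted with the kwargs
        api_query=ATTACHMENT_QUERY,
        id="123456",
    ):
        for attachment in page.get("results", []):
            ...  # each page is the parsed JSON body of one API response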
Returns: @@ -628,7 +653,11 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, + document: Dict[str, Union[str, int, List[Dict[str, str]], List[None]]], + access_control: List[Union[str, Any]], + ) -> Dict[str, Union[str, int, List[Dict[str, str]], List[None], List[str]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) @@ -636,7 +665,9 @@ def _decorate_with_access_control(self, document, access_control): return document - async def user_access_control_confluence_server(self, user): + async def user_access_control_confluence_server( + self, user: Dict[str, str] + ) -> Dict[str, Any]: """Generate a user access control document for confluence server. This method generates a user access control document based on the provided user information. @@ -685,7 +716,9 @@ async def user_access_control_confluence_server(self, user): access_control=access_control ) - async def user_access_control_data_center(self, user): + async def user_access_control_data_center( + self, user: Dict[str, str] + ) -> Dict[str, Any]: """Generate a user access control document for confluence data enter. This method generates a user access control document based on the provided user information. @@ -767,7 +800,7 @@ async def get_user(self): user=user ) - async def get_access_control(self): + async def get_access_control(self) -> None: """Get access control documents for active Atlassian users. This method fetches access control documents for active Atlassian users when document level security (DLS) @@ -792,7 +825,23 @@ async def get_access_control(self): async for user in users: yield user - def _get_access_control_from_permission(self, permissions, target_type): + def _get_access_control_from_permission( + self, + permissions: List[ + Union[ + Dict[ + str, + Union[ + int, + Dict[str, Dict[str, Union[List[Dict[str, str]], int]]], + Dict[str, str], + ], + ], + Any, + ] + ], + target_type: str, + ) -> Set[str]: if not self._dls_enabled(): return [] @@ -811,7 +860,12 @@ def _get_access_control_from_permission(self, permissions, target_type): return access_control - def _extract_identities(self, response): + def _extract_identities( + self, + response: Dict[ + str, Union[Dict[str, Union[List[Dict[str, str]], int]], Dict[str, int]] + ], + ) -> Set[str]: if not self._dls_enabled(): return set() @@ -828,7 +882,9 @@ def _extract_identities(self, response): return identities - def _extract_identities_for_datacenter(self, response): + def _extract_identities_for_datacenter( + self, response: Dict[str, Dict[str, List[Dict[str, str]]]] + ) -> Set[str]: if not self._dls_enabled(): return set() identities = set() @@ -844,14 +900,14 @@ def _extract_identities_for_datacenter(self, response): return identities - async def close(self): + async def close(self) -> None: """Closes unclosed client session""" await self.confluence_client.close_session() - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[AtlassianAdvancedRulesValidator]: return [AtlassianAdvancedRulesValidator(self)] - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options: Dict[str, int]) -> None: """Tweak bulk options as per concurrent downloads support by Confluence Args: @@ -859,7 +915,7 @@ def tweak_bulk_options(self, options): """ options["concurrent_downloads"] = self.concurrent_downloads - async def 
validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields. Also validate, if user configured spaces are available in Confluence. @@ -869,7 +925,7 @@ async def validate_config(self): await super().validate_config() await self._remote_validation() - async def _remote_validation(self): + async def _remote_validation(self) -> None: await self.confluence_client.ping() if self.spaces == [WILDCARD]: return @@ -883,7 +939,7 @@ async def _remote_validation(self): msg = f"Spaces '{', '.join(unavailable_spaces)}' are not available. Available spaces are: '{', '.join(space_keys)}'" raise ConfigurableFieldValueError(msg) - async def ping(self): + async def ping(self) -> None: """Verify the connection with Confluence""" try: await self.confluence_client.ping() @@ -891,7 +947,7 @@ async def ping(self): self._logger.warning(f"Error while connecting to Confluence: {e}") raise - def get_permission(self, permission): + def get_permission(self, permission: Dict[str, List[str]]) -> Set[str]: permissions = set() if permission.get("users"): for user in permission.get("users"): @@ -903,7 +959,9 @@ def get_permission(self, permission): return permissions - async def fetch_server_space_permission(self, space_key): + async def fetch_server_space_permission( + self, space_key: str + ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: if not self._dls_enabled(): return {} @@ -970,8 +1028,8 @@ async def fetch_documents(self, api_query): ) async def fetch_attachments( - self, content_id, parent_name, parent_space, parent_type - ): + self, content_id: str, parent_name: str, parent_space: str, parent_type: str + ) -> Iterator[Future]: """Fetches all the attachments present in the given content (pages and blog posts) Args: @@ -1072,7 +1130,13 @@ async def search_by_query(self, query): yield document, download_url - async def download_attachment(self, url, attachment, timestamp=None, doit=False): + async def download_attachment( + self, + url: str, + attachment: Dict[str, Union[str, int]], + timestamp: None = None, + doit: bool = False, + ) -> Iterator[Future]: """Downloads the content of the given attachment in chunks using REST API call Args: @@ -1108,7 +1172,13 @@ async def download_attachment(self, url, attachment, timestamp=None, doit=False) ), ) - async def _attachment_coro(self, document, access_control): + async def _attachment_coro( + self, + document: Dict[ + str, Union[str, int, List[Dict[str, str]], List[None], List[str]] + ], + access_control: List[Union[str, Any]], + ) -> None: """Coroutine to add attachments to Queue and download content Args: @@ -1143,7 +1213,28 @@ async def _attachment_coro(self, document, access_control): finally: await self.queue.put(END_SIGNAL) # pyright: ignore - def format_space(self, space): + def format_space( + self, + space: Dict[ + str, + Union[ + int, + str, + Dict[str, str], + List[ + Dict[ + str, + Union[ + int, + Dict[str, Dict[str, Union[List[Dict[str, str]], int]]], + Dict[str, str], + ], + ] + ], + Dict[str, Union[Dict[str, str], str]], + ], + ], + ) -> Dict[str, Any]: space_url = os.path.join( self.confluence_client.host_url, space.get("_links", {}).get("webui", "")[1:], @@ -1197,7 +1288,7 @@ async def _space_coro(self): self._logger.exception(f"Error while fetching spaces: {exception}") raise - async def _page_blog_coro(self, api_query, target_type): + async def _page_blog_coro(self, api_query: str, target_type: str) -> None: """Coroutine to add pages/blogposts to Queue Args: diff --git 
a/connectors/sources/directory.py b/connectors/sources/directory.py index ab62c5014..0136b4879 100644 --- a/connectors/sources/directory.py +++ b/connectors/sources/directory.py @@ -10,14 +10,15 @@ import functools import os from datetime import datetime, timezone -from pathlib import Path +from pathlib import Path, PosixPath +from typing import Any, Dict, Optional, Union import aiofiles -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import TIKA_SUPPORTED_FILETYPES, get_base64_value, hash_id -DEFAULT_DIR = os.environ.get("SYSTEM_DIR", os.path.dirname(__file__)) +DEFAULT_DIR: str = os.environ.get("SYSTEM_DIR", os.path.dirname(__file__)) class DirectoryDataSource(BaseDataSource): @@ -26,13 +27,13 @@ class DirectoryDataSource(BaseDataSource): name = "System Directory" service_type = "dir" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self.directory = os.path.abspath(self.configuration["directory"]) self.pattern = self.configuration["pattern"] @classmethod - def get_default_configuration(cls): + def get_default_configuration(cls) -> Dict[str, Dict[str, Any]]: return { "directory": { "label": "Directory path", @@ -50,16 +51,18 @@ def get_default_configuration(cls): }, } - async def ping(self): + async def ping(self) -> bool: return True - async def changed(self): + async def changed(self) -> bool: return True - def get_id(self, path): + def get_id(self, path: Union[str, PosixPath]) -> str: return hash_id(str(path)) - async def _download(self, path, timestamp=None, doit=None): + async def _download( + self, path: str, timestamp: Optional[str] = None, doit: Optional[bool] = None + ) -> Optional[Dict[str, Any]]: if not (doit and os.path.splitext(path)[-1] in TIKA_SUPPORTED_FILETYPES): return diff --git a/connectors/sources/dropbox.py b/connectors/sources/dropbox.py index fc71eb67e..678206a83 100644 --- a/connectors/sources/dropbox.py +++ b/connectors/sources/dropbox.py @@ -7,14 +7,21 @@ import json import os +from _asyncio import Future, Task from datetime import datetime from enum import Enum from functools import cached_property, partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union from urllib import parse import aiohttp import fastjsonschema -from aiohttp.client_exceptions import ClientResponseError, ServerDisconnectedError +from aiohttp.client import ClientSession +from aiohttp.client_exceptions import ( + ClientResponseError, + ServerDisconnectedError, + ServerTimeoutError, +) from connectors.access_control import ( ACCESS_CONTROL, @@ -26,7 +33,7 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import ( CancellableSleeps, RetryStrategy, @@ -61,13 +68,13 @@ ) logger.warning("IT'S SUPPOSED TO BE USED ONLY FOR TESTING") logger.warning("x" * 100) - BASE_URLS = { + BASE_URLS: Dict[str, str] = { "ACCESS_TOKEN_BASE_URL": os.environ["DROPBOX_API_URL"], "FILES_FOLDERS_BASE_URL": os.environ["DROPBOX_API_URL_V2"], "DOWNLOAD_BASE_URL": os.environ["DROPBOX_API_URL_V2"], } else: - BASE_URLS = { + BASE_URLS: Dict[str, str] = { "ACCESS_TOKEN_BASE_URL": "https://api.dropboxapi.com/", "FILES_FOLDERS_BASE_URL": f"https://api.dropboxapi.com/{API_VERSION}/", "DOWNLOAD_BASE_URL": f"https://content.dropboxapi.com/{API_VERSION}/", 
@@ -144,26 +151,26 @@ class BreakingField(Enum): HAS_MORE = "has_more" -def _prefix_user(user): +def _prefix_user(user: str) -> Optional[str]: if not user: return return prefix_identity("user", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: Union[str, int]) -> Optional[str]: return prefix_identity("user_id", user_id) -def _prefix_email(email): +def _prefix_email(email: str) -> Optional[str]: return prefix_identity("email", email) -def _prefix_group(group): +def _prefix_group(group: str) -> Optional[str]: return prefix_identity("group", group) class DropboxClient: - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self.path = ( @@ -181,7 +188,7 @@ def __init__(self, configuration): self.root_namespace_id = None self._logger = logger - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @retryable( @@ -189,7 +196,7 @@ def set_logger(self, logger_): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _set_access_token(self): + async def _set_access_token(self) -> None: if self.token_expiration_time and ( not isinstance(self.token_expiration_time, datetime) ): @@ -218,7 +225,7 @@ async def _set_access_token(self): ) self._logger.debug("Access Token generated successfully") - def check_errors(self, response): + def check_errors(self, response: Dict[str, str]): error_response = response.get("error") if error_response == "invalid_grant": msg = "Configured Refresh Token is invalid." @@ -231,18 +238,20 @@ def check_errors(self, response): raise Exception(msg) @cached_property - def _get_session(self): + def _get_session(self) -> ClientSession: self._logger.debug("Generating aiohttp client session") timeout = aiohttp.ClientTimeout(total=None) return aiohttp.ClientSession(timeout=timeout, raise_for_status=True) - async def close(self): + async def close(self) -> None: self._sleeps.cancel() await self._get_session.close() del self._get_session - def _get_request_headers(self, file_type, url_name, **kwargs): + def _get_request_headers( + self, file_type: Optional[str], url_name: str, **kwargs + ) -> Dict[str, str]: kwargs = kwargs["kwargs"] request_headers = { "Authorization": f"Bearer {self.access_token}", @@ -283,7 +292,13 @@ def _get_request_headers(self, file_type, url_name, **kwargs): ) return request_headers - def _get_retry_after(self, retry, exception): + def _get_retry_after( + self, + retry: int, + exception: Union[ + ServerDisconnectedError, ClientResponseError, ServerTimeoutError, Exception + ], + ) -> Tuple[int, int]: self._logger.warning( f"Retry count: {retry} out of {self.retry_count}. 
Exception: {exception}" ) @@ -292,7 +307,9 @@ def _get_retry_after(self, retry, exception): retry += 1 return retry, RETRY_INTERVAL**retry - async def _handle_client_errors(self, retry, exception): + async def _handle_client_errors( + self, retry: int, exception: ClientResponseError + ) -> Generator[Task, None, int]: retry, retry_seconds = self._get_retry_after(retry=retry, exception=exception) match exception.status: case 401: @@ -319,7 +336,14 @@ async def _handle_client_errors(self, retry, exception): await self._sleeps.sleep(retry_seconds) return retry - async def api_call(self, base_url, url_name, data=None, file_type=None, **kwargs): + async def api_call( + self, + base_url: str, + url_name: str, + data: Optional[str] = None, + file_type: None = None, + **kwargs, + ) -> Iterator[Optional[Task]]: retry = 1 url = parse.urljoin(base_url, url_name) while True: @@ -352,8 +376,12 @@ async def api_call(self, base_url, url_name, data=None, file_type=None, **kwargs await self._sleeps.sleep(retry_seconds) async def _paginated_api_call( - self, base_url, breaking_field, continue_endpoint=None, **kwargs - ): + self, + base_url: str, + breaking_field: str, + continue_endpoint: Optional[str] = None, + **kwargs, + ) -> None: """Make a paginated API call for fetching Dropbox files/folders. Args: @@ -404,7 +432,7 @@ async def _paginated_api_call( ) return - async def ping(self, endpoint): + async def ping(self, endpoint: str) -> Iterator[Optional[Task]]: return await anext( self.api_call( base_url=BASE_URLS["FILES_FOLDERS_BASE_URL"], @@ -611,10 +639,15 @@ class DropBoxAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "DropboxDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: List[ + Dict[str, Union[str, Dict[str, Union[Dict[str, str], str]]]] + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -627,7 +660,12 @@ async def validate(self, advanced_rules): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, + advanced_rules: List[ + Dict[str, Union[str, Dict[str, Union[Dict[str, str], str]]]] + ], + ) -> SyncRuleValidationResult: try: DropBoxAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -666,7 +704,7 @@ class DropboxDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the Dropbox Args: @@ -679,11 +717,20 @@ def __init__(self, configuration): "include_inherited_users_and_groups" ] - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.dropbox_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Dropbox Returns: @@ -761,7 +808,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. 
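The retry helpers above back off exponentially between attempts. A worked example, assuming RETRY_INTERVAL is 2 for this connector, as it is for the Box connector earlier in this diff:

    # api_call starts with retry = 1; each failure calls _get_retry_after, which
    # increments the counter and sleeps RETRY_INTERVAL ** retry seconds:
    #   1st failure: retry 1 -> 2, sleep 2 ** 2 = 4 s
    #   2nd failure: retry 2 -> 3, sleep 2 ** 3 = 8 s
    #   3rd failure: retry 3 -> 4, sleep 2 ** 4 = 16 s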
This method checks whether document level security (DLS) is enabled based on the provided configuration. Returns: bool: True if document level security is enabled, False otherwise. @@ -774,14 +821,18 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, document: Dict[str, Union[str, int, List[str]]], access_control: List[str] + ) -> Dict[str, Union[str, int, List[str]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) ) return document - async def _user_access_control_doc(self, user): + async def _user_access_control_doc( + self, user: Dict[str, Dict[str, Union[str, Dict[str, str], List[str]]]] + ) -> Dict[str, Any]: profile = user.get("profile", {}) email = profile.get("email") username = profile.get("name", {}).get("display_name") @@ -818,14 +869,14 @@ async def get_access_control(self): for user in users.get("members", []): yield await self._user_access_control_doc(user=user) - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields Also validate, if user configured path is available in Dropbox.""" await super().validate_config() await self._remote_validation() - async def _remote_validation(self): + async def _remote_validation(self) -> None: try: if self.dropbox_client.path not in ["", None]: await self.set_user_info() @@ -842,10 +893,10 @@ async def _remote_validation(self): msg = f"Error while validating path: {self.dropbox_client.path}. Error: {exception}" raise Exception(msg) from exception - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[DropBoxAdvancedRulesValidator]: return [DropBoxAdvancedRulesValidator(self)] - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options: Dict[str, int]) -> None: """Tweak bulk options as per concurrent downloads support by dropbox Args: @@ -853,10 +904,10 @@ def tweak_bulk_options(self, options): """ options["concurrent_downloads"] = self.concurrent_downloads - async def close(self): + async def close(self) -> None: await self.dropbox_client.close() - async def ping(self): + async def ping(self) -> None: try: endpoint = EndpointName.AUTHENTICATED_ADMIN.value await self.dropbox_client.ping(endpoint=endpoint) @@ -869,7 +920,7 @@ async def ping(self): else: raise - async def set_user_info(self): + async def set_user_info(self) -> None: try: endpoint = EndpointName.AUTHENTICATED_ADMIN.value response = await self.dropbox_client.ping(endpoint=endpoint) @@ -890,8 +941,13 @@ async def set_user_info(self): ) or root_info.get("root_namespace_id") async def get_content( - self, attachment, is_shared=False, folder_id=None, timestamp=None, doit=False - ): + self, + attachment: Dict[str, Union[bool, str, int]], + is_shared: bool = False, + folder_id: None = None, + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: """Extracts the content for allowed file types. 
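When document level security is enabled, _decorate_with_access_control above stores the deduplicated union of existing and newly resolved identities under ACCESS_CONTROL. A hedged sketch with invented values; `dropbox_source` stands in for a configured DropboxDataSource, and the prefixed identity strings are assumed from the helper names:

    doc = {"_id": "id:abc123", "type": "File", ACCESS_CONTROL: ["user_id:111"]}
    doc = dropbox_source._decorate_with_access_control(
        doc, [_prefix_email("alice@example.com"), _prefix_user_id("111")]
    )
    # doc[ACCESS_CONTROL] now holds "user_id:111" once plus the prefixed email;
    # with DLS disabled the document is returned unchanged.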
Args: @@ -937,7 +993,13 @@ async def get_content( ), ) - def download_func(self, is_shared, attachment, filename, folder_id): + def download_func( + self, + is_shared: bool, + attachment: Dict[str, Union[bool, str, int]], + filename: str, + folder_id: None, + ) -> Optional[partial]: if is_shared: return partial( self.dropbox_client.download_shared_file, url=attachment["url"] @@ -957,7 +1019,9 @@ def download_func(self, is_shared, attachment, filename, folder_id): else: return - def _adapt_dropbox_doc_to_es_doc(self, response): + def _adapt_dropbox_doc_to_es_doc( + self, response: Dict[str, Union[bool, str, int]] + ) -> Dict[str, Any]: is_file = response.get(".tag") == "file" if is_file and response.get("name").split(".")[-1] == PAPER: timestamp = response.get("client_modified") @@ -972,7 +1036,9 @@ def _adapt_dropbox_doc_to_es_doc(self, response): "_timestamp": timestamp, } - def _adapt_dropbox_shared_file_doc_to_es_doc(self, response): + def _adapt_dropbox_shared_file_doc_to_es_doc( + self, response: Dict[str, Union[bool, str, int]] + ) -> Dict[str, Any]: return { "_id": response.get("id"), "type": FILE, @@ -1003,7 +1069,12 @@ async def _fetch_shared_files(self): json_metadata, ) - async def advanced_sync(self, rule): + async def advanced_sync( + self, + rule: Dict[ + str, Union[str, Dict[str, List[Dict[str, str]]], Dict[str, List[str]]] + ], + ) -> None: async for response in self.dropbox_client.search_files_folders(rule=rule): for entry in response.get("matches"): data = entry.get("metadata", {}).get("metadata") @@ -1023,15 +1094,31 @@ async def advanced_sync(self, rule): else: yield self._adapt_dropbox_doc_to_es_doc(response=data), data - def get_group_id(self, permission, identity): + def get_group_id( + self, + permission: Dict[str, Union[Dict[str, str], Dict[str, Union[str, int]]]], + identity: str, + ) -> str: if identity in permission: return permission.get(identity).get("group_id") - def get_email(self, permission, identity): + def get_email(self, permission: Dict[str, Dict[str, str]], identity: str) -> str: if identity in permission: return permission.get(identity).get("email") - async def get_permission(self, permission, account_id): + async def get_permission( + self, + permission: Dict[ + str, + Union[ + List[Dict[str, Dict[str, str]]], + str, + bool, + List[Union[Any, Dict[str, Dict[str, Union[str, int]]]]], + ], + ], + account_id: int, + ) -> List[Optional[str]]: permissions = [] if identities := permission.get("users"): for identity in identities: @@ -1056,7 +1143,9 @@ async def get_permission(self, permission, account_id): permissions.append(_prefix_user_id(account_id)) return permissions - async def get_folder_permission(self, shared_folder_id, account_id): + async def get_folder_permission( + self, shared_folder_id: int, account_id: int + ) -> List[str]: if ( not shared_folder_id or shared_folder_id == self.dropbox_client.root_namespace_id @@ -1070,7 +1159,9 @@ async def get_folder_permission(self, shared_folder_id, account_id): permission=permission, account_id=account_id ) - async def get_file_permission_without_batching(self, file_id, account_id): + async def get_file_permission_without_batching( + self, file_id: int, account_id: int + ) -> List[str]: async for ( permission ) in self.dropbox_client.list_file_permission_without_batching(file_id=file_id): @@ -1078,7 +1169,7 @@ async def get_file_permission_without_batching(self, file_id, account_id): permission=permission, account_id=account_id ) - async def get_account_details(self): + async def 
get_account_details(self) -> Tuple[str, str]: response = await anext( self.dropbox_client.api_call( base_url=BASE_URLS["FILES_FOLDERS_BASE_URL"], @@ -1091,7 +1182,9 @@ async def get_account_details(self): member_id = account_details.get("admin_profile", {}).get("team_member_id") return account_id, member_id - async def get_permission_list(self, item_type, item, account_id): + async def get_permission_list( + self, item_type: str, item: Dict[str, int], account_id: int + ) -> List[str]: if item_type == FOLDER: shared_folder_id = item.get("shared_folder_id") or item.get( "parent_shared_folder_id" @@ -1128,7 +1221,18 @@ async def map_permission_with_document(self, batched_document, account_id): file_id[1], ) - def document_tuple(self, document, attachment, folder_id=None): + def document_tuple( + self, + document: Dict[str, Union[str, int, List[str]]], + attachment: Union[ + Dict[str, Union[str, int]], str, Dict[str, Union[bool, str, int]] + ], + folder_id: None = None, + ) -> Union[ + Tuple[Dict[str, Union[List[str], int, str]], partial], + Tuple[Dict[str, Union[int, str]], None], + Tuple[Dict[str, Union[int, str]], partial], + ]: if document.get("type") == FILE: if document.get("url"): return document, partial( @@ -1144,7 +1248,7 @@ def document_tuple(self, document, attachment, folder_id=None): else: return document, None - async def add_document_to_list(self, func, account_id, is_shared=False): + async def add_document_to_list(self, func, account_id, is_shared: bool = False): batched_document = {} calling_func = func() if is_shared else func(path=self.dropbox_client.path) diff --git a/connectors/sources/generic_database.py b/connectors/sources/generic_database.py index 5339faccb..af13e1605 100644 --- a/connectors/sources/generic_database.py +++ b/connectors/sources/generic_database.py @@ -4,7 +4,10 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import asyncio +from _asyncio import Future, Task from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, Iterator, List, Optional, Sized, Union from asyncpg.exceptions._base import InternalClientError from sqlalchemy.exc import ProgrammingError @@ -18,7 +21,7 @@ DEFAULT_WAIT_MULTIPLIER = 2 -def configured_tables(tables): +def configured_tables(tables: Union[str, List[str]]) -> List[Union[str, Any]]: """Split a string containing a comma-seperated list of tables by comma and strip the table names. Filter out `None` and zero-length values from the tables. @@ -43,11 +46,15 @@ def table_filter(table): ) -def is_wildcard(tables): +def is_wildcard(tables: Union[str, List[str]]) -> bool: return tables in (WILDCARD, [WILDCARD]) -def map_column_names(column_names, schema=None, tables=None): +def map_column_names( + column_names: List[Union[str, Any]], + schema: Optional[str] = None, + tables: Optional[Sized] = None, +) -> List[str]: prefix = "" if schema and len(schema.strip()) > 0: prefix += schema.strip() + "_" @@ -56,7 +63,9 @@ def map_column_names(column_names, schema=None, tables=None): return [f"{prefix}{column}".lower() for column in column_names] -def hash_id(tables, row, primary_key_columns): +def hash_id( + tables: List[str], row: Dict[str, Union[str, int]], primary_key_columns: List[str] +) -> str: """Generates an id using table names as prefix in sorted order and primary key values. 
Example: @@ -75,11 +84,11 @@ def hash_id(tables, row, primary_key_columns): async def fetch( - cursor_func, - fetch_columns=False, - fetch_size=DEFAULT_FETCH_SIZE, - retry_count=DEFAULT_RETRY_COUNT, -): + cursor_func: partial, + fetch_columns: bool = False, + fetch_size: int = DEFAULT_FETCH_SIZE, + retry_count: int = DEFAULT_RETRY_COUNT, +) -> Iterator[Union[Task, Future]]: @retryable( retries=retry_count, interval=DEFAULT_WAIT_MULTIPLIER, diff --git a/connectors/sources/github.py b/connectors/sources/github.py index 465f3eed5..b0bea34ba 100644 --- a/connectors/sources/github.py +++ b/connectors/sources/github.py @@ -7,12 +7,15 @@ import json import time +from _asyncio import Future, Task from enum import Enum from functools import cached_property, partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Union import aiohttp import fastjsonschema import gidgethub +from aiohttp.client import ClientSession from gidgethub import QueryError, sansio from gidgethub.abc import ( BadGraphQLRequest, @@ -31,7 +34,11 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import ( CancellableSleeps, RetryStrategy, @@ -79,15 +86,15 @@ } -def _prefix_email(email): +def _prefix_email(email: str) -> Optional[str]: return prefix_identity("email", email) -def _prefix_username(user): +def _prefix_username(user: str) -> Optional[str]: return prefix_identity("username", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: str) -> Optional[str]: return prefix_identity("user_id", user_id) @@ -651,8 +658,15 @@ class ForbiddenException(Exception): class GitHubClient: def __init__( - self, auth_method, base_url, app_id, private_key, token, ssl_enabled, ssl_ca - ): + self, + auth_method: str, + base_url: str, + app_id: int, + private_key: str, + token: str, + ssl_enabled: bool, + ssl_ca: str, + ) -> None: self._sleeps = CancellableSleeps() self._logger = logger self.auth_method = auth_method @@ -684,10 +698,10 @@ def __init__( # a variable to hold the current installation id, used to refresh the access token self._installation_id = None - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - async def _get_retry_after(self, resource_type): + async def _get_retry_after(self, resource_type: str) -> float: current_time = time.time() response = await self.get_github_item("/rate_limit") reset = nested_get_from_dict( @@ -696,7 +710,7 @@ async def _get_retry_after(self, resource_type): # Adding a 5 second delay to account for server delays return (reset - current_time) + 5 # pyright: ignore - async def _put_to_sleep(self, resource_type): + async def _put_to_sleep(self, resource_type: str) -> Iterator[Task]: retry_after = await self._get_retry_after(resource_type=resource_type) self._logger.debug( f"Rate limit exceeded. Retry after {retry_after} seconds. Resource type: {resource_type}" @@ -705,7 +719,7 @@ async def _put_to_sleep(self, resource_type): msg = "Rate limit exceeded." 
raise Exception(msg) - def _access_token(self): + def _access_token(self) -> str: if self.auth_method == PERSONAL_ACCESS_TOKEN: return self._personal_access_token if not self._installation_access_token: @@ -713,7 +727,7 @@ def _access_token(self): return self._installation_access_token # update the current installation id and re-generate access token - async def update_installation_id(self, installation_id): + async def update_installation_id(self, installation_id: int) -> None: self._logger.debug( f"Updating installation id - new ID: {installation_id}, original ID: {self._installation_id}" ) @@ -725,7 +739,7 @@ async def update_installation_id(self, installation_id): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _update_installation_access_token(self): + async def _update_installation_access_token(self) -> None: try: access_token_response = await get_installation_access_token( gh=self._get_client, @@ -744,7 +758,7 @@ async def _update_installation_access_token(self): raise @cached_property - def _get_session(self): + def _get_session(self) -> ClientSession: connector = aiohttp.TCPConnector(ssl=self.ssl_ctx) timeout = aiohttp.ClientTimeout(total=None) return aiohttp.ClientSession( @@ -754,7 +768,7 @@ def _get_session(self): ) @cached_property - def _get_client(self): + def _get_client(self) -> GitHubAPI: return GitHubAPI( session=self._get_session, requester="", @@ -767,7 +781,9 @@ def _get_client(self): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=UnauthorizedException, ) - async def graphql(self, query, variables=None): + async def graphql( + self, query: str, variables: Optional[Dict[str, str]] = None + ) -> Generator[Task, None, Dict[str, str]]: """Invoke GraphQL request to fetch repositories, pull requests, and issues. Args: @@ -826,7 +842,9 @@ async def graphql(self, query, variables=None): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=UnauthorizedException, ) - async def get_github_item(self, resource): + async def get_github_item( + self, resource: str + ) -> Dict[str, Union[str, int, Dict[str, Dict[str, int]]]]: """Execute request using getitem method of GitHubAPI which is using REST API. Using Rest API for fetching files and folder along with content. @@ -864,7 +882,9 @@ async def get_github_item(self, resource): ) raise - async def paginated_api_call(self, query, variables, keys): + async def paginated_api_call( + self, query: str, variables: Dict[str, Optional[str]], keys: List[str] + ): """Make a paginated API call for fetching GitHub objects. 
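A worked example of the rate-limit wait computed above, with invented numbers:

    # If the /rate_limit payload reports reset=1_700_000_300 for the resource type and
    # time.time() returns 1_700_000_000.0, _get_retry_after yields
    # (1_700_000_300 - 1_700_000_000.0) + 5 = 305.0 seconds; _put_to_sleep waits that
    # long and then raises, presumably so the @retryable wrapper can re-issue the request.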
Args: @@ -884,7 +904,7 @@ async def paginated_api_call(self, query, variables, keys): break variables["cursor"] = page_info["endCursor"] # pyright: ignore - def get_repo_details(self, repo_name): + def get_repo_details(self, repo_name: str) -> List[str]: return repo_name.split("/") @retryable( @@ -893,7 +913,7 @@ def get_repo_details(self, repo_name): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=UnauthorizedException, ) - async def get_personal_access_token_scopes(self): + async def get_personal_access_token_scopes(self) -> Set[str]: try: request_headers = sansio.create_headers( self._get_client.requester, @@ -926,7 +946,7 @@ async def get_personal_access_token_scopes(self): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _github_app_get(self, url): + async def _github_app_get(self, url: str) -> Optional[Tuple[bytes, Optional[str]]]: self._logger.debug(f"Making a get request to GitHub: {url}") try: return await self._get_client._make_request( @@ -993,7 +1013,7 @@ async def get_user_repos(self, user): ): yield repo - async def get_foreign_repo(self, repo_name): + async def get_foreign_repo(self, repo_name: str) -> None: owner, repo = self.get_repo_details(repo_name=repo_name) repo_variables = {"owner": owner, "repositoryName": repo} data = await self.graphql( @@ -1018,22 +1038,24 @@ async def _fetch_all_members(self, org_name): ): yield repo.get("node") - async def get_logged_in_user(self): + async def get_logged_in_user(self) -> Optional[str]: data = await self.graphql(query=GithubQuery.USER_QUERY.value) return nested_get_from_dict(data, ["viewer", "login"]) - async def ping(self): + async def ping(self) -> None: if self.auth_method == GITHUB_APP: await self._github_app_get(url="/app") else: await self.get_logged_in_user() - async def close(self): + async def close(self) -> None: self._sleeps.cancel() await self._get_session.close() del self._get_session - def bifurcate_repos(self, repos, owner): + def bifurcate_repos( + self, repos: List[str], owner: str + ) -> Tuple[List[str], List[str]]: foreign_repos, configured_repos = [], [] for repo_name in repos: if repo_name not in ["", None]: @@ -1069,10 +1091,17 @@ class GitHubAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "GitHubDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: Union[ + List[Dict[str, Dict[str, str]]], + List[Dict[str, Union[Dict[str, str], str]]], + List[Dict[str, Union[str, List[Dict[str, str]]]]], + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -1080,7 +1109,12 @@ async def validate(self, advanced_rules): return await self._remote_validation(advanced_rules) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, + advanced_rules: List[ + Dict[str, Union[Dict[str, str], str, List[Dict[str, str]]]] + ], + ) -> SyncRuleValidationResult: try: GitHubAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -1114,7 +1148,7 @@ class GitHubDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the GitHub instance. 
Args: @@ -1143,14 +1177,24 @@ def __init__(self, configuration): # and the value is the installation id self._installations = {} - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.github_client.set_logger(self._logger) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[GitHubAdvancedRulesValidator]: return [GitHubAdvancedRulesValidator(self)] @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for GitHub. Returns: @@ -1280,7 +1324,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. This method checks whether document level security (DLS) is enabled based on the provided configuration. Returns: @@ -1297,7 +1341,7 @@ def _dls_enabled(self): and self.configuration["use_document_level_security"] ) - async def _logged_in_user(self): + async def _logged_in_user(self) -> Optional[str]: if self.configuration["auth_method"] != PERSONAL_ACCESS_TOKEN: return None if self._user: @@ -1305,7 +1349,7 @@ async def _logged_in_user(self): self._user = await self.github_client.get_logged_in_user() return self._user - async def get_invalid_repos(self): + async def get_invalid_repos(self) -> List[str]: self._logger.debug( "Checking if there are any inaccessible repositories configured" ) @@ -1314,7 +1358,7 @@ async def get_invalid_repos(self): else: return await self._get_invalid_repos_for_personal_access_token() - async def _get_invalid_repos_for_github_app(self): + async def _get_invalid_repos_for_github_app(self) -> List[str]: # A github app can be installed on multiple orgs/personal accounts, # so the repo must be configured in the format of 'OWNER/REPO', any other format will be rejected invalid_repos = set( @@ -1336,7 +1380,9 @@ async def _get_invalid_repos_for_github_app(self): return list(invalid_repos) - async def _get_repo_object_for_github_app(self, owner, repo_name): + async def _get_repo_object_for_github_app( + self, owner: str, repo_name: str + ) -> Optional[Dict[str, str]]: await self._fetch_installations() full_repo_name = f"{owner}/{repo_name}" if owner not in self._installations: @@ -1369,7 +1415,7 @@ async def _get_repo_object_for_github_app(self, owner, repo_name): return cached_repo[owner][full_repo_name] - async def _get_invalid_repos_for_personal_access_token(self): + async def _get_invalid_repos_for_personal_access_token(self) -> List[str]: try: if self.configuration["repo_type"] == "other": logged_in_user = await self._logged_in_user() @@ -1421,7 +1467,7 @@ async def _get_invalid_repos_for_personal_access_token(self): ) raise - async def _user_access_control_doc(self, user): + async def _user_access_control_doc(self, user: Dict[str, str]) -> Dict[str, Any]: user_id = user.get("id", "") user_name = user.get("login", "") user_email = user.get("email", "") @@ -1441,7 +1487,7 @@ async def _user_access_control_doc(self, user): access_control=[_prefixed_user_id, _prefixed_user_name, _prefixed_email] ) - async def get_access_control(self): + async def get_access_control(self) -> None: if not self._dls_enabled(): self._logger.warning("DLS is not enabled. 
Skipping") return @@ -1463,7 +1509,7 @@ async def _get_owners(self): else: yield await self._logged_in_user() - async def _remote_validation(self): + async def _remote_validation(self) -> None: """Validate scope of the configured personal access token and accessibility of repositories Raises: @@ -1473,7 +1519,7 @@ async def _remote_validation(self): await self._validate_personal_access_token_scopes() await self._validate_configured_repos() - async def _validate_personal_access_token_scopes(self): + async def _validate_personal_access_token_scopes(self) -> None: if self.configuration["auth_method"] != PERSONAL_ACCESS_TOKEN: return @@ -1494,7 +1540,7 @@ async def _validate_personal_access_token_scopes(self): msg = "Configured token does not have required rights to fetch the content. Required scopes are 'repo', 'user', and 'read:org'." raise ConfigurableFieldValueError(msg) - async def _validate_configured_repos(self): + async def _validate_configured_repos(self) -> None: if WILDCARD in self.configured_repos: return @@ -1503,17 +1549,17 @@ async def _validate_configured_repos(self): msg = f"Inaccessible repositories '{', '.join(invalid_repos)}'." raise ConfigurableFieldValueError(msg) - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields Also validate, if user configured repositories are accessible or not and scope of the token """ await super().validate_config() await self._remote_validation() - async def close(self): + async def close(self) -> None: await self.github_client.close() - async def ping(self): + async def ping(self) -> None: try: await self.github_client.ping() self._logger.debug("Successfully connected to GitHub.") @@ -1521,13 +1567,19 @@ async def ping(self): self._logger.exception("Error while connecting to GitHub.") raise - def adapt_gh_doc_to_es_doc(self, github_document, schema): + def adapt_gh_doc_to_es_doc( + self, github_document: Dict[str, Union[str, int]], schema: Dict[str, str] + ) -> Dict[str, Union[str, int]]: return { es_field: github_document[github_field] for es_field, github_field in schema.items() } - def _prepare_pull_request_doc(self, pull_request, reviews): + def _prepare_pull_request_doc( + self, + pull_request: Dict[str, Any], + reviews: List[Dict[str, Union[str, List[Dict[str, str]]]]], + ) -> Dict[str, Any]: return { "_id": pull_request.pop("id"), "_timestamp": pull_request.pop("updatedAt"), @@ -1539,7 +1591,7 @@ def _prepare_pull_request_doc(self, pull_request, reviews): "requested_reviewers": pull_request.get("reviewRequests", {}).get("nodes"), } - def _prepare_issue_doc(self, issue): + def _prepare_issue_doc(self, issue: Dict[str, Any]) -> Dict[str, Any]: return { "_id": issue.pop("id"), "type": ObjectType.ISSUE.value, @@ -1549,7 +1601,19 @@ def _prepare_issue_doc(self, issue): "assignees_list": issue.get("assignees", {}).get("nodes"), } - def _prepare_review_doc(self, review): + def _prepare_review_doc( + self, + review: Dict[ + str, + Optional[ + Union[ + Dict[str, str], + str, + Dict[str, Union[Dict[str, Union[bool, str]], List[Dict[str, str]]]], + ] + ], + ], + ) -> Dict[str, Any]: # review.author can be None if the user was deleted, so need to be extra null-safe author = review.get("author", {}) or {} @@ -1560,7 +1624,7 @@ def _prepare_review_doc(self, review): "comments": review.get("comments", {}).get("nodes"), } - async def _fetch_installations(self): + async def _fetch_installations(self) -> Dict[str, Union[Dict[str, int], int]]: """Fetches 
GitHub App installations, and populates instance variable self._installations Only populates Organization installations when repo_type is organization, and only populates User installations when repo_type is other """ @@ -1676,7 +1740,9 @@ async def _fetch_repos(self): exc_info=True, ) - def _convert_repo_object_to_doc(self, repo_object): + def _convert_repo_object_to_doc( + self, repo_object: Dict[str, str] + ) -> Dict[str, str]: repo_object = repo_object.copy() repo_object.update( { @@ -1698,8 +1764,13 @@ async def _fetch_remaining_data( ) async def _fetch_remaining_fields( - self, type_obj, object_type, owner, repo, field_type - ): + self, + type_obj: Dict[str, Any], + object_type: str, + owner: str, + repo: str, + field_type: str, + ) -> None: sample_dict = { "reviews": { "query": GithubQuery.REVIEW_QUERY.value, @@ -1751,7 +1822,9 @@ async def _fetch_remaining_fields( else: type_obj[sample_dict[field_type]["es_field"]].extend(response) - async def _extract_pull_request(self, pull_request, owner, repo): + async def _extract_pull_request( + self, pull_request: Dict[str, Any], owner: str, repo: str + ): reviews = [ self._prepare_review_doc(review=review) for review in pull_request.get("reviews", {}).get("nodes") @@ -1772,10 +1845,10 @@ async def _extract_pull_request(self, pull_request, owner, repo): async def _fetch_pull_requests( self, - repo_name, - response_key, - filter_query=None, - ): + repo_name: str, + response_key: List[str], + filter_query: None = None, + ) -> None: self._logger.info( f"Fetching pull requests from '{repo_name}' with response_key '{response_key}' and filter query: '{filter_query}'" ) @@ -1832,9 +1905,9 @@ async def _extract_issues(self, response, owner, repo, response_key): async def _fetch_issues( self, - repo_name, - response_key, - filter_query=None, + repo_name: str, + response_key: List[str], + filter_query: None = None, ): self._logger.info( f"Fetching issues from repo: {repo_name} with response_key: '{response_key}' and filter_query: '{filter_query}'" @@ -1871,7 +1944,7 @@ async def _fetch_issues( exc_info=True, ) - async def _fetch_last_commit_timestamp(self, repo_name, path): + async def _fetch_last_commit_timestamp(self, repo_name: str, path: str) -> str: commit, *_ = await self.github_client.get_github_item( # pyright: ignore resource=self.github_client.endpoints["COMMITS"].format( repo_name=repo_name, path=path @@ -1879,7 +1952,12 @@ async def _fetch_last_commit_timestamp(self, repo_name, path): ) return commit["commit"]["committer"]["date"] - async def _format_file_document(self, repo_object, repo_name, schema): + async def _format_file_document( + self, + repo_object: Dict[str, Union[str, int]], + repo_name: str, + schema: Dict[str, str], + ) -> Tuple[Dict[str, Union[int, str]], Dict[str, Union[int, str]]]: file_name = repo_object["path"].split("/")[-1] file_extension = ( file_name[file_name.rfind(".") :] if "." 
in file_name else "" # noqa @@ -1904,7 +1982,7 @@ async def _format_file_document(self, repo_object, repo_name, schema): document["_id"] = f"{repo_name}/{repo_object['path']}" return document, repo_object - async def _fetch_files(self, repo_name, default_branch): + async def _fetch_files(self, repo_name: str, default_branch: str): self._logger.info( f"Fetching files from repo: '{repo_name}' (branch: '{default_branch}')" ) @@ -1956,7 +2034,12 @@ async def _fetch_files_by_path(self, repo_name, path): exc_info=True, ) - async def get_content(self, attachment, timestamp=None, doit=False): + async def get_content( + self, + attachment: Dict[str, Union[str, int]], + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: """Extracts the content for Apache TIKA supported file types. Args: @@ -1997,7 +2080,9 @@ async def download_func(self, url): else: yield - def _filter_rule_query(self, repo, query, query_type): + def _filter_rule_query( + self, repo: str, query: str, query_type: str + ) -> Tuple[bool, str]: """ Filters a query based on the query type. @@ -2024,20 +2109,22 @@ def _filter_rule_query(self, repo, query, query_type): else: return False, query - def is_previous_repo(self, repo_name): + def is_previous_repo(self, repo_name: str) -> bool: if repo_name in self.prev_repos: return True self.prev_repos.append(repo_name) return False - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, document: Dict[str, Any], access_control: List[str] + ) -> Dict[str, Any]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) ) return document - async def _fetch_access_control(self, repo_name): + async def _fetch_access_control(self, repo_name: str) -> List[Optional[str]]: owner, repo = self.github_client.get_repo_details(repo_name) collaborator_variables = { "orgName": owner, diff --git a/connectors/sources/gmail.py b/connectors/sources/gmail.py index 294e659cb..653057c37 100644 --- a/connectors/sources/gmail.py +++ b/connectors/sources/gmail.py @@ -4,6 +4,8 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
# from functools import cached_property +from typing import Any, Dict, List, Optional, Union +from unittest.mock import AsyncMock import fastjsonschema from aiogoogle import AuthError @@ -14,7 +16,12 @@ AdvancedRulesValidator, SyncRuleValidationResult, ) -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.protocol.connectors import Filter +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.sources.google import ( GMailClient, GoogleDirectoryClient, @@ -54,7 +61,9 @@ class GMailAdvancedRulesValidator(AdvancedRulesValidator): definition=SCHEMA_DEFINITION, ) - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: Dict[str, Union[List[int], List[str]]] + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -74,7 +83,7 @@ async def validate(self, advanced_rules): ) -def _message_doc(message): +def _message_doc(message: Dict[str, Optional[Union[int, str]]]) -> Dict[str, Any]: timestamp_field = "_timestamp" # We're using the `_attachment` field here so the attachment processor on the ES side decodes the base64 value @@ -98,7 +107,7 @@ def _message_doc(message): return es_doc -def _filtering_enabled(filtering): +def _filtering_enabled(filtering: Optional[Filter]) -> bool: return filtering is not None and filtering.has_advanced_rules() @@ -111,11 +120,18 @@ class GMailDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], Dict[str, Union[int, str]] + ], + ]: """Get the default configuration for the GMail connector. Returns: dict: Default configuration. @@ -165,7 +181,7 @@ def get_default_configuration(cls): }, } - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user inputs are valid or not for configuration fields. Raises: @@ -187,7 +203,7 @@ async def validate_config(self): await self._validate_google_directory_auth() await self._validate_gmail_auth() - async def _validate_gmail_auth(self): + async def _validate_gmail_auth(self) -> None: """ Validates, whether the provided configuration values allow the connector to authenticate against GMail API. Failed authentication indicates, that either the provided credentials are incorrect or mandatory GMail API @@ -203,7 +219,7 @@ async def _validate_gmail_auth(self): msg = f"GMail authentication was not successful. Check the values of the following fields: '{SERVICE_ACCOUNT_CREDENTIALS_LABEL}', '{SUBJECT_LABEL}' and '{CUSTOMER_ID_LABEL}'. Also make sure that the OAuth2 scopes for GMail are setup correctly." raise ConfigurableFieldValueError(msg) from e - async def _validate_google_directory_auth(self): + async def _validate_google_directory_auth(self) -> None: """ Validates, whether the provided configuration values allow the connector to authenticate against Google Directory API. Failed authentication indicates, that either the provided credentials are incorrect or mandatory @@ -220,10 +236,10 @@ async def _validate_google_directory_auth(self): msg = f"Google Directory authentication was not successful. 
Check the values of the following fields: '{SERVICE_ACCOUNT_CREDENTIALS_LABEL}', '{SUBJECT_LABEL}' and '{CUSTOMER_ID_LABEL}'. Also make sure that the OAuth2 scopes for Google Directory are setup correctly." raise ConfigurableFieldValueError(msg) from e - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[GMailAdvancedRulesValidator]: return [GMailAdvancedRulesValidator()] - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self._google_directory_client.set_logger(self._logger) @cached_property @@ -234,14 +250,14 @@ def _service_account_credentials(self): return service_account_credentials @cached_property - def _google_directory_client(self): + def _google_directory_client(self) -> GoogleDirectoryClient: return GoogleDirectoryClient( json_credentials=self._service_account_credentials, customer_id=self.configuration["customer_id"], subject=self.configuration["subject"], ) - def _gmail_client(self, subject): + def _gmail_client(self, subject: str) -> GMailClient: """Instantiates a GMail client for a corresponding subject. Args: subject (str): Email address for the subject. @@ -257,7 +273,7 @@ def _gmail_client(self, subject): gmail_client.set_logger(self._logger) return gmail_client - async def ping(self): + async def ping(self) -> None: for service_name, client in [ ("GMail", self._gmail_client(self.configuration["subject"])), ("Google directory", self._google_directory_client), @@ -269,17 +285,21 @@ async def ping(self): self._logger.exception(f"Error while connecting to {service_name}.") raise - def _dls_enabled(self): + def _dls_enabled(self) -> bool: return ( self._features is not None and self._features.document_level_security_enabled() and self.configuration["use_document_level_security"] ) - def access_control_query(self, access_control): + def access_control_query( + self, access_control: List[str] + ) -> Dict[str, Dict[str, Dict[str, Union[Dict[str, List[str]], str]]]]: return es_access_control_query(access_control) - def _user_access_control_doc(self, user, access_control): + def _user_access_control_doc( + self, user: Dict[str, str], access_control: List[str] + ) -> Dict[str, Any]: email = user.get(UserFields.EMAIL.value) created_at = user.get(UserFields.CREATION_DATE.value) @@ -289,7 +309,9 @@ def _user_access_control_doc(self, user, access_control): "created_at": created_at or iso_utc(), } | self.access_control_query(access_control) - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, document: Dict[str, str], access_control: List[str] + ) -> Dict[str, Union[str, List[str]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) @@ -297,7 +319,7 @@ def _decorate_with_access_control(self, document, access_control): return document - async def get_access_control(self): + async def get_access_control(self) -> None: """Yields all users found in the Google Workspace associated with the configured service account. 
Yields: @@ -316,8 +338,11 @@ async def get_access_control(self): yield self._user_access_control_doc(user, access_control) async def _message_doc_with_access_control( - self, access_control, gmail_client, message - ): + self, + access_control: List[str], + gmail_client: AsyncMock, + message: Dict[str, str], + ) -> Dict[str, Union[str, List[str]]]: message_id = message.get("id") message_content = await gmail_client.message(message_id) message_content["id"] = message_id @@ -329,7 +354,7 @@ async def _message_doc_with_access_control( return message_doc_with_access_control - async def get_docs(self, filtering=None): + async def get_docs(self, filtering: None = None) -> None: """Yields messages for all users present in the Google Workspace. Includes spam and trash messages, if the corresponding configuration value is set to `True`. diff --git a/connectors/sources/google.py b/connectors/sources/google.py index 286d5fe5e..ac433345e 100644 --- a/connectors/sources/google.py +++ b/connectors/sources/google.py @@ -5,13 +5,16 @@ # import json import os +from _asyncio import Future, Task +from asyncio.tasks import _GatheringFuture from enum import Enum +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union from aiogoogle import Aiogoogle, AuthError, HTTPError from aiogoogle.auth.creds import ServiceAccountCreds from aiogoogle.sessions.aiohttp_session import AiohttpSession -from connectors.logger import logger +from connectors.logger import ExtraLogger, logger from connectors.source import ConfigurableFieldValueError from connectors.utils import RetryStrategy, retryable @@ -22,14 +25,14 @@ "universe_domain" } -GOOGLE_API_FTEST_HOST = os.environ.get("GOOGLE_API_FTEST_HOST") -RUNNING_FTEST = ( +GOOGLE_API_FTEST_HOST: Optional[str] = os.environ.get("GOOGLE_API_FTEST_HOST") +RUNNING_FTEST: bool = ( "RUNNING_FTEST" in os.environ ) # Flag to check if a connector is run for ftest or not. RETRIES = 3 RETRY_INTERVAL = 2 -DEFAULT_TIMEOUT = 1 * 60 # 1 min +DEFAULT_TIMEOUT: int = 1 * 60 # 1 min DEFAULT_PAGE_SIZE = 100 @@ -57,11 +60,15 @@ class RetryableAiohttpSession(AiohttpSession): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def send(self, *args, **kwargs): + async def send( + self, *args, **kwargs + ) -> Generator[_GatheringFuture, None, Dict[str, Any]]: return await super().send(*args, **kwargs) -def load_service_account_json(service_account_credentials_json, google_service): +def load_service_account_json( + service_account_credentials_json: str, google_service: str +) -> Dict[str, str]: """ Load and parse a Google service account JSON configuration. @@ -101,7 +108,9 @@ def _load_json(json_string): raise ConfigurableFieldValueError(msg) -def validate_service_account_json(service_account_credentials, google_service): +def validate_service_account_json( + service_account_credentials: str, google_service: str +) -> None: """Validates whether service account JSON is a valid JSON string and checks for unexpected keys. 
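Illustrative aside, not part of this patch: both helpers above take the service-account credentials as a raw JSON string plus a service label (presumably used in error messages). load_service_account_json() parses the string and raises ConfigurableFieldValueError when it is not valid JSON; validate_service_account_json() additionally checks for unexpected keys. A minimal usage sketch with a hypothetical, truncated credentials payload and an illustrative label:

    from connectors.source import ConfigurableFieldValueError
    from connectors.sources.google import (
        load_service_account_json,
        validate_service_account_json,
    )

    creds_json = '{"project_id": "demo", "client_email": "svc@demo.iam.gserviceaccount.com"}'
    try:
        validate_service_account_json(creds_json, "Google Drive")  # label is illustrative
        credentials = load_service_account_json(creds_json, "Google Drive")
        print(credentials["project_id"])
    except ConfigurableFieldValueError as error:
        print(f"Credentials rejected: {error}")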
@@ -122,7 +131,14 @@ def validate_service_account_json(service_account_credentials, google_service): class GoogleServiceAccountClient: """A Google client to handle api calls made to the Google Workspace APIs using a service account.""" - def __init__(self, json_credentials, api, api_version, scopes, api_timeout): + def __init__( + self, + json_credentials: Dict[str, str], + api: str, + api_version: str, + scopes: List[Union[str, Any]], + api_timeout: int, + ) -> None: """Initialize the ServiceAccountCreds class using which api calls will be made. Args: json_credentials (dict): Service account credentials json. @@ -136,15 +152,15 @@ def __init__(self, json_credentials, api, api_version, scopes, api_timeout): self.api_timeout = api_timeout self._logger = logger - def set_logger(self, logger_): + def set_logger(self, logger_: ExtraLogger) -> None: self._logger = logger_ async def api_call_paged( self, - resource, - method, + resource: str, + method: str, **kwargs, - ): + ) -> Iterator[_GatheringFuture]: """Make a paged GET call to a Google Workspace API. Args: resource (aiogoogle.resource.Resource): Resource name for which the API call will be made. @@ -171,10 +187,10 @@ async def _call_api(google_client, method_object, kwargs): async def api_call( self, - resource, - method, + resource: str, + method: str, **kwargs, - ): + ) -> Generator[Optional[Union[_GatheringFuture, Task]], None, Union[str, Future]]: """Make a non-paged GET call to Google Workspace API. Args: resource (aiogoogle.resource.Resource): Resource name for which the API call will be made. @@ -192,7 +208,13 @@ async def _call_api(google_client, method_object, kwargs): return await anext(self._execute_api_call(resource, method, _call_api, kwargs)) - async def _execute_api_call(self, resource, method, call_api_func, kwargs): + async def _execute_api_call( + self, + resource: str, + method: str, + call_api_func: Callable, + kwargs: Dict[str, Union[str, int]], + ) -> Iterator[Union[_GatheringFuture, Task, None]]: """Execute the API call with common try/except logic. Args: resource (aiogoogle.resource.Resource): Resource name for which the API call will be made. 
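Illustrative aside, not part of this patch: GoogleServiceAccountClient distinguishes a single-shot call (api_call, awaited once; the GMail client further down in this diff treats its result as a parsed dict) from a paged call (api_call_paged, an async generator of result pages). A consumption sketch; the resource/method/kwargs names below follow the Admin SDK Directory API and are assumptions, not taken from this diff:

    async def list_workspace_users(client) -> None:
        # single-shot GET: one awaited call, one parsed JSON body
        admin = await client.api_call(resource="users", method="get", userKey="admin@example.com")
        print(admin.get("primaryEmail"))

        # paged GET: yields one parsed result page at a time
        async for page in client.api_call_paged(
            resource="users", method="list", customer="my_customer", maxResults=100
        ):
            for user in page.get("users", []):
                print(user.get("primaryEmail"))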
@@ -247,13 +269,19 @@ async def _execute_api_call(self, resource, method, call_api_func, kwargs): raise -def remove_universe_domain(json_credentials): +def remove_universe_domain(json_credentials: Dict[str, str]) -> None: if "universe_domain" in json_credentials: json_credentials.pop("universe_domain") class GoogleDirectoryClient: - def __init__(self, json_credentials, customer_id, subject, timeout=DEFAULT_TIMEOUT): + def __init__( + self, + json_credentials: Dict[str, str], + customer_id: str, + subject: str, + timeout: int = DEFAULT_TIMEOUT, + ) -> None: remove_universe_domain(json_credentials) json_credentials["subject"] = subject @@ -266,10 +294,10 @@ def __init__(self, json_credentials, customer_id, subject, timeout=DEFAULT_TIMEO api_timeout=timeout, ) - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - async def ping(self): + async def ping(self) -> None: try: await self._client.api_call( resource="users", @@ -295,7 +323,13 @@ async def users(self): class GMailClient: - def __init__(self, json_credentials, customer_id, subject, timeout=DEFAULT_TIMEOUT): + def __init__( + self, + json_credentials: Dict[str, str], + customer_id: str, + subject: str, + timeout: int = DEFAULT_TIMEOUT, + ) -> None: remove_universe_domain(json_credentials) # This override is needed to be able to fetch the messages for the corresponding user, otherwise we get a 403 Forbidden (see: https://issuetracker.google.com/issues/290567932) @@ -310,10 +344,10 @@ def __init__(self, json_credentials, customer_id, subject, timeout=DEFAULT_TIMEO api_timeout=timeout, ) - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - async def ping(self): + async def ping(self) -> None: try: await self._client.api_call( resource="users", method="getProfile", userId=self.user @@ -322,7 +356,10 @@ async def ping(self): raise async def messages( - self, query=None, includeSpamTrash=False, pageSize=DEFAULT_PAGE_SIZE + self, + query=None, + includeSpamTrash: bool = False, + pageSize: int = DEFAULT_PAGE_SIZE, ): fields = "id" @@ -338,7 +375,7 @@ async def messages( for message in page.get("messages", []): yield message - async def message(self, id_): + async def message(self, id_: str) -> Dict[str, str]: fields = "raw,internalDate" return await self._client.api_call( diff --git a/connectors/sources/google_cloud_storage.py b/connectors/sources/google_cloud_storage.py index 77e644b56..31c45083d 100644 --- a/connectors/sources/google_cloud_storage.py +++ b/connectors/sources/google_cloud_storage.py @@ -7,13 +7,16 @@ import os import urllib.parse +from _asyncio import Future, Task +from asyncio.tasks import _GatheringFuture from functools import cached_property, partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Union from aiogoogle import Aiogoogle, HTTPError from aiogoogle.auth.creds import ServiceAccountCreds from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.sources.google import ( load_service_account_json, validate_service_account_json, @@ -43,8 +46,8 @@ } RETRY_COUNT = 3 RETRY_INTERVAL = 2 -STORAGE_EMULATOR_HOST = os.environ.get("STORAGE_EMULATOR_HOST") -RUNNING_FTEST = ( +STORAGE_EMULATOR_HOST: Optional[str] = os.environ.get("STORAGE_EMULATOR_HOST") +RUNNING_FTEST: bool = ( "RUNNING_FTEST" in os.environ ) # Flag to check if a connector is run for ftest or not. 
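# --- Illustrative aside, not part of this patch -------------------------------------
# These two module-level flags support functional tests: RUNNING_FTEST is a pure
# presence check ("RUNNING_FTEST" in os.environ) and makes ping() further down return
# early instead of calling the live API, while STORAGE_EMULATOR_HOST is read so the
# client can presumably be pointed at a local GCS emulator such as fake-gcs-server.
# A hypothetical test setup (values are made up; both variables have to be in the
# environment before this module is imported):
#
#     import os
#     os.environ["RUNNING_FTEST"] = "1"  # any value works; only presence is checked
#     os.environ["STORAGE_EMULATOR_HOST"] = "http://localhost:4443"
# -------------------------------------------------------------------------------------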
REQUIRED_CREDENTIAL_KEYS = [ @@ -62,7 +65,7 @@ class GoogleCloudStorageClient: """A google client to handle api calls made to Google Cloud Storage.""" - def __init__(self, json_credentials): + def __init__(self, json_credentials: Dict[str, str]) -> None: """Initialize the ServiceAccountCreds class using which api calls will be made. Args: @@ -75,7 +78,7 @@ def __init__(self, json_credentials): self.user_project_id = self.service_account_credentials.project_id self._logger = logger - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @retryable( @@ -85,12 +88,12 @@ def set_logger(self, logger_): ) async def api_call( self, - resource, - method, - sub_method=None, - full_response=False, + resource: str, + method: str, + sub_method: Optional[str] = None, + full_response: bool = False, **kwargs, - ): + ) -> Iterator[None]: """Make a GET call for Google Cloud Storage with retry for the failed API calls. Args: @@ -156,7 +159,7 @@ class GoogleCloudStorageDataSource(BaseDataSource): service_type = "google_cloud_storage" incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the connection to the Google Cloud Storage Client. Args: @@ -164,11 +167,15 @@ def __init__(self, configuration): """ super().__init__(configuration=configuration) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self._google_storage_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, Union[Dict[str, Union[List[str], int, str]], Dict[str, Union[int, str]]] + ]: """Get the default configuration for Google Cloud Storage. Returns: @@ -200,7 +207,7 @@ def get_default_configuration(cls): }, } - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user inputs are valid or not for configuration field. Raises: @@ -213,7 +220,7 @@ async def validate_config(self): ) @cached_property - def _google_storage_client(self): + def _google_storage_client(self) -> GoogleCloudStorageClient: """Initialize and return the GoogleCloudStorageClient Returns: @@ -240,7 +247,7 @@ def _google_storage_client(self): return GoogleCloudStorageClient(json_credentials=required_credentials) - async def ping(self): + async def ping(self) -> None: """Verify the connection with Google Cloud Storage""" if RUNNING_FTEST: return @@ -260,7 +267,7 @@ async def ping(self): ) raise - async def fetch_buckets(self, buckets): + async def fetch_buckets(self, buckets: List[Any]) -> None: """Fetch the buckets from the Google Cloud Storage. Args: buckets (List): List of buckets. @@ -285,7 +292,9 @@ async def fetch_buckets(self, buckets): for bucket in buckets: yield {"items": [{"id": bucket, "name": bucket}]} - async def fetch_blobs(self, buckets): + async def fetch_blobs( + self, buckets: Dict[str, Union[str, List[Dict[str, str]]]] + ) -> Iterator[Optional[Task]]: """Fetches blobs stored in the bucket from Google Cloud Storage. Args: @@ -319,7 +328,7 @@ async def fetch_blobs(self, buckets): f"Something went wrong while fetching blobs from {bucket['name']}. Error: {exception}" ) - def prepare_blob_document(self, blob): + def prepare_blob_document(self, blob: Dict[str, str]) -> Dict[str, str]: """Apply key mappings to the blob document. 
Args: @@ -337,7 +346,9 @@ def prepare_blob_document(self, blob): ) return blob_document - def get_blob_document(self, blobs): + def get_blob_document( + self, blobs: Dict[str, Union[str, List[Dict[str, str]]]] + ) -> Iterator[Dict[str, Optional[str]]]: """Generate blob document. Args: @@ -349,7 +360,12 @@ def get_blob_document(self, blobs): for blob in blobs.get("items", []): yield self.prepare_blob_document(blob=blob) - async def get_content(self, blob, timestamp=None, doit=None): + async def get_content( + self, + blob: Dict[str, Optional[Union[str, int]]], + timestamp: None = None, + doit: Optional[bool] = None, + ) -> Generator[Union[Future, _GatheringFuture], None, Optional[Dict[str, str]]]: """Extracts the content for allowed file types. Args: @@ -405,7 +421,7 @@ async def get_content(self, blob, timestamp=None, doit=None): return document - async def get_docs(self, filtering=None): + async def get_docs(self, filtering: None = None) -> None: """Get buckets & blob documents from Google Cloud Storage. Yields: diff --git a/connectors/sources/google_drive.py b/connectors/sources/google_drive.py index 186521a11..16356f3d9 100644 --- a/connectors/sources/google_drive.py +++ b/connectors/sources/google_drive.py @@ -4,7 +4,21 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import asyncio +from _asyncio import Future, Task +from asyncio.tasks import _GatheringFuture from functools import cached_property, partial +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + List, + Optional, + Tuple, + Union, +) +from unittest.mock import AsyncMock, MagicMock from aiogoogle import HTTPError @@ -18,6 +32,7 @@ CURSOR_SYNC_TIMESTAMP, BaseDataSource, ConfigurableFieldValueError, + DataSourceConfiguration, ) from connectors.sources.google import ( GoogleServiceAccountClient, @@ -41,7 +56,7 @@ GOOGLE_API_MAX_CONCURRENCY = 25 # Max open connections to Google API -DRIVE_API_TIMEOUT = 1 * 60 # 1 min +DRIVE_API_TIMEOUT: int = 1 * 60 # 1 min FOLDER_MIME_TYPE = "application/vnd.google-apps.folder" @@ -66,7 +81,7 @@ class SyncCursorEmpty(Exception): class GoogleDriveClient(GoogleServiceAccountClient): """A google drive client to handle api calls made to Google Drive API.""" - def __init__(self, json_credentials, subject=None): + def __init__(self, json_credentials: Dict[str, str], subject: None = None) -> None: """Initialize the GoogleApiClient superclass. 
Args: @@ -88,10 +103,12 @@ def __init__(self, json_credentials, subject=None): api_timeout=DRIVE_API_TIMEOUT, ) - async def ping(self): + async def ping( + self, + ) -> Generator[Optional[Union[_GatheringFuture, Task]], None, Future]: return await self.api_call(resource="about", method="get", fields="kind") - async def list_drives(self): + async def list_drives(self) -> Iterator[_GatheringFuture]: """Fetch all shared drive (id, name) from Google Drive Yields: @@ -106,7 +123,9 @@ async def list_drives(self): ): yield drive - async def get_all_drives(self): + async def get_all_drives( + self, + ) -> Generator[Union[_GatheringFuture, Task], None, Dict[Any, Any]]: """Retrieves all shared drives from Google Drive Returns: @@ -138,7 +157,11 @@ async def list_folders(self): ): yield folder - async def get_all_folders(self): + async def get_all_folders( + self, + ) -> Generator[ + Union[_GatheringFuture, Task], None, Dict[str, Dict[str, Union[str, List[str]]]] + ]: """Retrieves all folders from Google Drive Returns: @@ -155,7 +178,7 @@ async def get_all_folders(self): return folders - async def list_files(self, fetch_permissions=False, last_sync_time=None): + async def list_files(self, fetch_permissions: bool = False, last_sync_time=None): """Get files from Google Drive. Files can have any type. Args: @@ -189,7 +212,7 @@ async def list_files(self, fetch_permissions=False, last_sync_time=None): yield file async def list_files_from_my_drive( - self, fetch_permissions=False, last_sync_time=None + self, fetch_permissions: bool = False, last_sync_time=None ): """Retrieves files from Google Drive, with an option to fetch permissions (DLS). @@ -257,7 +280,7 @@ async def list_permissions(self, file_id): class GoogleAdminDirectoryClient(GoogleServiceAccountClient): """A google admin directory client to handle api calls made to Google Admin API.""" - def __init__(self, json_credentials, subject): + def __init__(self, json_credentials: Dict[str, str], subject: str) -> None: """Initialize the GoogleApiClient superclass. Args: @@ -316,35 +339,35 @@ async def list_groups_for_user(self, user_id): yield group -def _prefix_group(group): +def _prefix_group(group: str) -> Optional[str]: return prefix_identity("group", group) -def _prefix_user(user): +def _prefix_user(user: str) -> Optional[str]: return prefix_identity("user", user) -def _prefix_domain(domain): +def _prefix_domain(domain: str) -> Optional[str]: return prefix_identity("domain", domain) -def _is_user_permission(permission_type): +def _is_user_permission(permission_type: str) -> bool: return permission_type == "user" -def _is_group_permission(permission_type): +def _is_group_permission(permission_type: str) -> bool: return permission_type == "group" -def _is_domain_permission(permission_type): +def _is_domain_permission(permission_type: str) -> bool: return permission_type == "domain" -def _is_anyone_permission(permission_type): +def _is_anyone_permission(permission_type: str) -> bool: return permission_type == "anyone" -def _get_domain_from_email(email): +def _get_domain_from_email(email: str) -> str: return email.split("@")[-1] @@ -356,7 +379,7 @@ class GoogleDriveDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the data source. 
Args: @@ -364,12 +387,27 @@ def __init__(self, configuration): """ super().__init__(configuration=configuration) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: if self._domain_wide_delegation_sync_enabled() or self._dls_enabled(): self.google_admin_directory_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[ + str, + Union[ + List[Dict[str, str]], List[Dict[str, Union[bool, str]]], int, str + ], + ], + Dict[str, Union[List[Dict[str, Union[int, str]]], List[str], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Google Drive. Returns: @@ -456,7 +494,7 @@ def get_default_configuration(cls): }, } - def google_drive_client(self, impersonate_email=None): + def google_drive_client(self, impersonate_email: None = None) -> GoogleDriveClient: """ Initialize and return an instance of the GoogleDriveClient. @@ -500,7 +538,7 @@ def google_drive_client(self, impersonate_email=None): return drive_client @cached_property - def google_admin_directory_client(self): + def google_admin_directory_client(self) -> GoogleAdminDirectoryClient: """Initialize and return the GoogleAdminDirectoryClient Returns: @@ -527,7 +565,7 @@ def google_admin_directory_client(self): return directory_client - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user inputs are valid or not for configuration field. Raises: @@ -541,7 +579,7 @@ async def validate_config(self): self._validate_google_workspace_admin_email() self._validate_google_workspace_email_for_shared_drives_sync() - def _validate_google_workspace_admin_email(self): + def _validate_google_workspace_admin_email(self) -> None: """ This method is used to validate the Google Workspace admin email address when Document Level Security (DLS) is enabled for the current configuration. The email address should not be empty, and it should have a valid email format (no @@ -569,7 +607,7 @@ def _validate_google_workspace_admin_email(self): msg = "Google Workspace admin email is malformed or contains whitespace characters." raise ConfigurableFieldValueError(msg) - def _validate_google_workspace_email_for_shared_drives_sync(self): + def _validate_google_workspace_email_for_shared_drives_sync(self) -> None: """ Validates the Google Workspace email address specified for shared drives synchronization. @@ -594,7 +632,7 @@ def _validate_google_workspace_email_for_shared_drives_sync(self): msg = "Google Workspace email for shared drives sync is malformed or contains whitespace characters." raise ConfigurableFieldValueError(msg) - async def ping(self): + async def ping(self) -> None: """Verify the connection with Google Drive""" try: if self._domain_wide_delegation_sync_enabled(): @@ -607,7 +645,7 @@ async def ping(self): self._logger.exception("Error while connecting to the Google Drive.") raise - def _get_google_workspace_admin_email(self): + def _get_google_workspace_admin_email(self) -> Optional[str]: """ Retrieves the Google Workspace admin email based on the current configuration. 
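Illustrative aside, not part of this patch: the module-level _prefix_user/_prefix_group/_prefix_domain and _is_*_permission helpers typed earlier in this file are the building blocks for turning one Google Drive permission entry into a DLS identity. The connector's own _process_permissions body is not visible in these hunks, so the combination below is hypothetical; the emailAddress/domain field names follow the Drive v3 permission resource:

    from typing import Optional

    def permission_to_identity(permission: dict) -> Optional[str]:
        # assumed to live alongside the helpers in connectors/sources/google_drive.py
        permission_type = permission.get("type", "")
        if _is_user_permission(permission_type):
            return _prefix_user(permission.get("emailAddress"))
        if _is_group_permission(permission_type):
            return _prefix_group(permission.get("emailAddress"))
        if _is_domain_permission(permission_type):
            return _prefix_domain(permission.get("domain"))
        if _is_anyone_permission(permission_type):
            return "anyone"  # hypothetical marker for publicly shared items
        return None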
@@ -630,10 +668,10 @@ def _get_google_workspace_admin_email(self): else: return None - def _google_google_workspace_email_for_shared_drives_sync(self): + def _google_google_workspace_email_for_shared_drives_sync(self) -> str: return self.configuration.get("google_workspace_email_for_shared_drives_sync") - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if Document Level Security is enabled""" if self._features is None: return False @@ -643,21 +681,44 @@ def _dls_enabled(self): return bool(self.configuration.get("use_document_level_security", False)) - def _domain_wide_delegation_sync_enabled(self): + def _domain_wide_delegation_sync_enabled(self) -> bool: """Check if Domain Wide delegation sync is enabled""" return bool( self.configuration.get("use_domain_wide_delegation_for_sync", False) ) - def _max_concurrency(self): + def _max_concurrency(self) -> int: """Get maximum concurrent open connections from the user config""" return self.configuration.get("max_concurrency") or GOOGLE_API_MAX_CONCURRENCY - def access_control_query(self, access_control): + def access_control_query( + self, access_control: List[str] + ) -> Dict[str, Dict[str, Dict[str, Union[str, Dict[str, List[str]]]]]]: return es_access_control_query(access_control) - async def _process_items_concurrently(self, items, process_item_func): + async def _process_items_concurrently( + self, + items: List[Dict[str, Union[str, bool, List[str], Dict[str, str]]]], + process_item_func: Union[AsyncMock, Callable], + ) -> Generator[ + _GatheringFuture, + None, + List[ + Union[ + Dict[str, Union[Dict[str, str], str]], + Dict[ + str, + Union[ + str, + Dict[str, str], + Dict[str, Dict[str, Union[str, Dict[str, List[str]]]]], + ], + ], + Tuple[Dict[str, Optional[Union[str, bool]]], None], + ] + ], + ]: """Process a list of items concurrently using a semaphore for concurrency control. This function applies the `process_item_func` to each item in the `items` list @@ -690,7 +751,9 @@ async def process_item(item, semaphore): # Gather the results of all tasks concurrently return await asyncio.gather(*tasks) - async def prepare_single_access_control_document(self, user): + async def prepare_single_access_control_document( + self, user: Dict[str, Union[Dict[str, str], str]] + ) -> Dict[str, Any]: """Generate access control document for a single user. Fetch group memberships for a given user. Generate a user_access_control query that includes information about user email, groups and domain. @@ -740,7 +803,7 @@ async def prepare_access_control_documents(self, users_page): for ac_doc in prepared_ac_docs: yield ac_doc - async def get_access_control(self): + async def get_access_control(self) -> None: """Yields an access control document for every user of Google Workspace organization. 
Yields: @@ -757,7 +820,13 @@ async def get_access_control(self): ): yield access_control_doc - async def resolve_paths(self, google_drive_client=None): + async def resolve_paths( + self, google_drive_client: Optional[GoogleDriveClient] = None + ) -> Generator[ + Union[_GatheringFuture, Task], + None, + Dict[str, Union[Dict[str, Union[str, List[str]]], Dict[str, str]]], + ]: """Builds a lookup between a folder id and its absolute path in Google Drive structure Returns: @@ -798,7 +867,12 @@ async def resolve_paths(self, google_drive_client=None): return folders - async def _download_content(self, file, file_extension, download_func): + async def _download_content( + self, + file: Dict[str, Optional[Union[str, int]]], + file_extension: str, + download_func: partial, + ) -> Generator[Union[Future, _GatheringFuture], None, Tuple[str, None, int]]: """Downloads the file from Google Drive and returns the encoded file content. Args: @@ -826,7 +900,12 @@ async def _download_content(self, file, file_extension, download_func): return attachment, body, file_size - async def get_google_workspace_content(self, client, file, timestamp=None): + async def get_google_workspace_content( + self, + client: GoogleDriveClient, + file: Dict[str, Optional[Union[str, int]]], + timestamp: None = None, + ) -> Optional[Dict[str, Any]]: """Exports Google Workspace documents to an allowed file type and extracts its text content. Shared Google Workspace documents are different than regular files. When shared from @@ -879,7 +958,12 @@ async def get_google_workspace_content(self, client, file, timestamp=None): return document - async def get_generic_file_content(self, client, file, timestamp=None): + async def get_generic_file_content( + self, + client: GoogleDriveClient, + file: Dict[str, Optional[Union[str, int]]], + timestamp: None = None, + ) -> Optional[Dict[str, Any]]: """Extracts the content from allowed file types supported by Apache Tika. Args: @@ -928,7 +1012,13 @@ async def get_generic_file_content(self, client, file, timestamp=None): return document - async def get_content(self, client, file, timestamp=None, doit=None): + async def get_content( + self, + client: GoogleDriveClient, + file: Dict[str, Optional[Union[str, int]]], + timestamp: Optional[str] = None, + doit: Optional[bool] = None, + ) -> Optional[Dict[str, Any]]: """Extracts the content from a file file. Args: @@ -956,7 +1046,9 @@ async def get_content(self, client, file, timestamp=None, doit=None): client, file, timestamp=timestamp ) - async def _get_permissions_on_shared_drive(self, client, file_id): + async def _get_permissions_on_shared_drive( + self, client: GoogleDriveClient, file_id: str + ) -> Generator[Union[_GatheringFuture, Task], None, List[Dict[str, str]]]: """Retrieves the permissions on a shared drive for the given file ID. Args: @@ -973,7 +1065,9 @@ async def _get_permissions_on_shared_drive(self, client, file_id): return permissions - def _process_permissions(self, permissions): + def _process_permissions( + self, permissions: List[Dict[str, str]] + ) -> List[Optional[str]]: """Formats the access permission list for Google Drive object. 
Args: @@ -1007,7 +1101,22 @@ def _process_permissions(self, permissions): return processed_permissions - async def prepare_file(self, client, file, paths): + async def prepare_file( + self, + client: Union[GoogleDriveClient, MagicMock], + file: Dict[str, Any], + paths: Dict[str, Union[Dict[str, Union[str, List[str]]], Dict[str, str]]], + ) -> Generator[ + Union[_GatheringFuture, Task], + None, + Union[ + Tuple[Dict[str, Optional[Union[str, bool]]], None], + Tuple[Dict[str, Optional[Union[str, List[str], bool]]], None], + Tuple[Dict[str, Optional[Union[str, int, bool, List[str]]]], None], + Tuple[Dict[str, Optional[Union[str, bool]]], str], + Tuple[Dict[str, Optional[Union[str, int, bool]]], None], + ], + ]: """Apply key mappings to the file document. Args: @@ -1197,7 +1306,7 @@ async def get_docs(self, filtering=None): ): yield file, partial(self.get_content, google_drive_client, file) - async def get_docs_incrementally(self, sync_cursor, filtering=None): + async def get_docs_incrementally(self, sync_cursor: None, filtering: None = None): """Executes the logic to fetch Google Drive objects incrementally in an async manner. Args: @@ -1334,7 +1443,7 @@ async def get_docs_incrementally(self, sync_cursor, filtering=None): ) self.update_sync_timestamp_cursor(timestamp) - def init_sync_cursor(self): + def init_sync_cursor(self) -> Dict[str, str]: if not self._sync_cursor: self._sync_cursor = { CURSOR_GOOGLE_DRIVE_KEY: {}, diff --git a/connectors/sources/graphql.py b/connectors/sources/graphql.py index c45410296..d59c0dde9 100644 --- a/connectors/sources/graphql.py +++ b/connectors/sources/graphql.py @@ -9,15 +9,23 @@ import re from copy import deepcopy from functools import cached_property +from typing import Dict, Iterator, List, Optional, Tuple, Union +from unittest.mock import AsyncMock import aiohttp +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError from graphql import parse, visit -from graphql.language.ast import VariableNode +from graphql.language.ast import DocumentNode, FieldNode, VariableNode +from graphql.language.source import Source from graphql.language.visitor import Visitor from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import ( CancellableSleeps, RetryStrategy, @@ -53,7 +61,7 @@ class FieldVisitor(Visitor): fields_dict = {} variables_dict = {} - def enter_field(self, node, *args): + def enter_field(self, node: FieldNode, *args) -> None: self.fields_dict[node.name.value] = [] self.variables_dict[node.name.value] = {} if node.arguments: @@ -70,7 +78,7 @@ class UnauthorizedException(Exception): class GraphQLClient: - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger @@ -85,11 +93,11 @@ def __init__(self, configuration): self.variables = {} self.headers = {} - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def session(self): + def session(self) -> ClientSession: timeout = aiohttp.ClientTimeout(total=self.configuration["connection_timeout"]) if self.authentication_method == BEARER: self.headers.update( @@ -120,7 +128,19 @@ def session(self): raise_for_status=True, ) - def extract_graphql_data_items(self, data): + def 
extract_graphql_data_items( + self, + data: Dict[ + str, + Union[ + Dict[str, str], + Dict[str, Union[str, Dict[str, str], Dict[str, Union[bool, str]]]], + List[Dict[str, str]], + Dict[str, Dict[str, str]], + List[Dict[str, Union[Dict[str, str], str]]], + ], + ], + ) -> Iterator[Dict[str, Union[str, Dict[str, str], Dict[str, Union[bool, str]]]]]: """Returns sub objects from the response based on graphql_object_to_id_map Args: @@ -155,7 +175,16 @@ def extract_graphql_data_items(self, data): doc["_id"] = doc.get(field_id) yield doc - def extract_pagination_info(self, data): + def extract_pagination_info( + self, + data: Dict[ + str, + Union[ + Dict[str, Dict[str, str]], + Dict[str, Union[str, Dict[str, str], Dict[str, Union[bool, str]]]], + ], + ], + ) -> Tuple[bool, str, str]: pagination_key_path = self.pagination_key.split(".") for key in pagination_key_path: if isinstance(data, dict): @@ -181,7 +210,9 @@ def extract_pagination_info(self, data): msg = "Pagination is enabled but the query is missing 'pageInfo'. Please include 'pageInfo { hasNextPage endCursor }' in the query to support pagination." raise ConfigurableFieldValueError(msg) - def validate_paginated_query(self, graphql_query, visitor): + def validate_paginated_query( + self, graphql_query: str, visitor: FieldVisitor + ) -> None: graphql_object = self.pagination_key.split(".")[-1] self._logger.debug(f"Finding pageInfo field in {graphql_object}.") if not ( @@ -191,7 +222,7 @@ def validate_paginated_query(self, graphql_query, visitor): msg = f"Pagination is enabled but 'pageInfo' not found. Please include 'pageInfo' field inside '{graphql_object}' and 'after' argument in '{graphql_object}'." raise ConfigurableFieldValueError(msg) - async def paginated_call(self, graphql_query): + async def paginated_call(self, graphql_query: Union[Source, str]): if self.pagination_model == CURSOR_PAGINATION: ast = parse(graphql_query) # pyright: ignore visitor = FieldVisitor() @@ -228,7 +259,13 @@ async def paginated_call(self, graphql_query): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def make_request(self, graphql_query): + async def make_request( + self, graphql_query: str + ) -> Union[ + AsyncMock, + Dict[str, Dict[str, Union[str, Dict[str, str], Dict[str, Union[bool, str]]]]], + Dict[str, List[Dict[str, Union[Dict[str, str], str]]]], + ]: try: if self.http_method == GET: return await self.get(graphql_query=graphql_query) @@ -243,7 +280,7 @@ async def make_request(self, graphql_query): except Exception: raise - async def get(self, graphql_query): + async def get(self, graphql_query: str) -> Dict[str, str]: params = {"query": graphql_query} async with self.session.get(url=self.url, params=params) as response: json_response = await response.json() @@ -253,7 +290,7 @@ async def get(self, graphql_query): msg = f"Error while executing query. Exception: {json_response['errors']}" raise Exception(msg) - async def post(self, graphql_query): + async def post(self, graphql_query: str) -> Dict[str, str]: """Invoke GraphQL request to fetch response. Args: @@ -279,12 +316,12 @@ async def post(self, graphql_query): msg = f"Error while executing query. 
Exception: {json_response['errors']}" raise Exception(msg) - async def close(self): + async def close(self) -> None: self._sleeps.cancel() await self.session.close() del self.session - async def ping(self): + async def ping(self) -> None: await self.make_request(graphql_query=PING_QUERY) @@ -294,7 +331,7 @@ class GraphQLDataSource(BaseDataSource): name = "GraphQL" service_type = "graphql" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the GraphQL instance. Args: @@ -303,11 +340,20 @@ def __init__(self, configuration): super().__init__(configuration=configuration) self.graphql_client = GraphQLClient(configuration=configuration) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.graphql_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for GraphQL. Returns: @@ -421,7 +467,7 @@ def get_default_configuration(cls): }, } - def is_query(self, ast): + def is_query(self, ast: DocumentNode) -> bool: for definition in ast.definitions: # pyright: ignore if ( hasattr(definition, "operation") @@ -430,14 +476,18 @@ def is_query(self, ast): return False return True - def validate_endpoints(self): + def validate_endpoints(self) -> bool: if re.match(URL_REGEX, self.graphql_client.url): return True return False def check_field_existence( - self, ast, field_path, graphql_field_id=None, check_id=False - ): + self, + ast: DocumentNode, + field_path: str, + graphql_field_id: Optional[str] = None, + check_id: bool = False, + ) -> Tuple[bool, bool]: def traverse(selections, path): for selection in selections: if selection.name.value == path[0]: @@ -460,7 +510,7 @@ def traverse(selections, path): return False, False - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields Also validate, if user configured repositories are accessible or not and scope of the token """ @@ -544,10 +594,10 @@ async def validate_config(self): ) raise ConfigurableFieldValueError(msg) - async def close(self): + async def close(self) -> None: await self.graphql_client.close() - async def ping(self): + async def ping(self) -> None: try: await self.graphql_client.ping() self._logger.debug("Successfully connected to GraphQL Instance.") @@ -555,7 +605,10 @@ async def ping(self): self._logger.exception("Error while connecting to GraphQL Instance.") raise - def yield_dict(self, documents): + def yield_dict( + self, + documents: Dict[str, Union[str, Dict[str, str], Dict[str, Union[bool, str]]]], + ) -> Iterator[Dict[str, Union[str, Dict[str, str], Dict[str, Union[bool, str]]]]]: if isinstance(documents, dict): yield documents elif isinstance(documents, list): @@ -563,7 +616,7 @@ def yield_dict(self, documents): if isinstance(document, dict): yield document - async def fetch_data(self, graphql_query): + async def fetch_data(self, graphql_query: str): if self.graphql_client.pagination_model == NO_PAGINATION: data = await self.graphql_client.make_request(graphql_query=graphql_query) for documents in self.graphql_client.extract_graphql_data_items(data=data): @@ -576,7 +629,7 @@ async def fetch_data(self, graphql_query): for document in self.yield_dict(data): yield document - async 
def get_docs(self, filtering=None): + async def get_docs(self, filtering: None = None): """Executes the logic to fetch GraphQL response in async manner. Args: diff --git a/connectors/sources/jira.py b/connectors/sources/jira.py index f256e01f9..9692306ef 100644 --- a/connectors/sources/jira.py +++ b/connectors/sources/jira.py @@ -6,18 +6,21 @@ """Jira source module responsible to fetch documents from Jira on-prem or cloud server.""" import asyncio +from _asyncio import Future, Task from copy import copy from datetime import datetime from functools import partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Union from urllib import parse import aiohttp import pytz +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError, ServerConnectionError from connectors.access_control import ACCESS_CONTROL from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.sources.atlassian import ( AtlassianAccessControl, AtlassianAdvancedRulesValidator, @@ -44,7 +47,7 @@ FETCH_SIZE = 100 MAX_USER_FETCH_LIMIT = 1000 -QUEUE_MEM_SIZE = 5 * 1024 * 1024 # Size in Megabytes +QUEUE_MEM_SIZE: int = 5 * 1024 * 1024 # Size in Megabytes MAX_CONCURRENCY = 5 MAX_CONCURRENT_DOWNLOADS = 100 # Max concurrent download supported by jira @@ -62,7 +65,7 @@ SECURITY_LEVEL_MEMBERS = "issue_security_members" PROJECT_ROLE_MEMBERS_BY_ROLE_ID = "project_role_members_by_role_id" ALL_FIELDS = "all_fields" -URLS = { +URLS: Dict[str, str] = { PING: "rest/api/2/myself", PROJECT: "rest/api/2/project?expand=description,lead,url", PROJECT_BY_KEY: "rest/api/2/project/{key}", @@ -114,7 +117,7 @@ class EmptyResponseError(Exception): class JiraClient: """Jira client to handle API calls made to Jira""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger @@ -134,10 +137,10 @@ def __init__(self, configuration): self.ssl_ctx = False self.session = None - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def _get_session(self): + def _get_session(self) -> ClientSession: """Generate and return base client session with configuration fields Returns: @@ -183,7 +186,7 @@ def _get_session(self): ) return self.session - async def close_session(self): + async def close_session(self) -> None: """Closes unclosed client session""" self._sleeps.cancel() if self.session is None: @@ -191,7 +194,9 @@ async def close_session(self): await self.session.close() self.session = None - async def _handle_client_errors(self, url, exception): + async def _handle_client_errors( + self, url: str, exception: Union[TypeError, ValueError] + ) -> Iterator[Task]: if exception.status == 429: response_headers = exception.headers or {} retry_seconds = DEFAULT_RETRY_SECONDS @@ -225,7 +230,9 @@ async def _handle_client_errors(self, url, exception): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=NotFound, ) - async def api_call(self, url_name=None, **url_kwargs): + async def api_call( + self, url_name: Optional[str] = None, **url_kwargs + ) -> Iterator[Optional[Task]]: """Make a GET call for Atlassian API using the passed url_name with retry for the failed API calls. 
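Aside on the Jira client hunks above: the rate-limit handling that `_handle_client_errors` and the retried `api_call` describe boils down to "on HTTP 429, sleep for Retry-After (or a default) and try again". A minimal standalone sketch of that pattern, not the connector's code; the retry count, default wait, and bare aiohttp session are assumptions for illustration:

    import asyncio
    import aiohttp

    DEFAULT_RETRY_SECONDS = 30  # assumed default wait, mirroring the constant used in the hunk
    RETRIES = 3                 # assumed number of attempts for this sketch

    async def get_with_retry_after(session: aiohttp.ClientSession, url: str) -> dict:
        """GET a JSON resource, honouring Retry-After on HTTP 429."""
        for _ in range(RETRIES):
            async with session.get(url) as response:
                if response.status == 429:
                    retry_seconds = int(
                        response.headers.get("Retry-After", DEFAULT_RETRY_SECONDS)
                    )
                    await asyncio.sleep(retry_seconds)
                    continue
                response.raise_for_status()
                return await response.json()
        msg = f"Giving up on {url} after {RETRIES} rate-limited attempts"
        raise RuntimeError(msg)
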
Args: @@ -400,12 +407,12 @@ async def issue_security_level_members(self, level_id): ): yield response - async def get_timezone(self): + async def get_timezone(self) -> str: async for response in self.api_call(url_name=PING): timezone = await response.json() return timezone.get("timeZone") - async def verify_projects(self): + async def verify_projects(self) -> None: if self.projects == ["*"]: return @@ -422,7 +429,7 @@ async def verify_projects(self): msg = f"Unable to verify projects: {self.projects}. Error: {exception}" raise Exception(msg) from exception - async def ping(self): + async def ping(self) -> None: await anext(self.api_call(url_name=PING)) async def get_jira_fields(self): @@ -444,7 +451,7 @@ class JiraDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the Jira Args: @@ -463,11 +470,22 @@ def __init__(self, configuration): self.project_permission_cache = {} self.custom_fields = {} - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.jira_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[Dict[str, Union[int, str]]], List[str], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Jira Returns: @@ -591,7 +609,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. This method checks whether document level security (DLS) is enabled based on the provided configuration. 
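The decoration step that follows (adding identities under the `ACCESS_CONTROL` key when DLS is enabled) is a set union of existing and new identities. A framework-free sketch, assuming a plain string key and list-of-strings identities; it is not the connector's method, only the same merge idea in isolation:

    from typing import Dict, List

    ACCESS_CONTROL = "_allow_access_control"  # key name assumed for this sketch

    def decorate_with_access_control(
        document: Dict, access_control: List[str], dls_enabled: bool
    ) -> Dict:
        """Merge new identities into the document's access-control list."""
        if dls_enabled:
            existing = document.get(ACCESS_CONTROL, [])
            document[ACCESS_CONTROL] = list(set(existing + access_control))
        return document

The set union keeps the list free of duplicates when a document is decorated more than once.
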
Returns: @@ -605,14 +623,45 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, + document: Dict[ + str, + Union[ + str, + Dict[ + str, + Union[ + Dict[str, str], + str, + List[Dict[str, Union[str, int]]], + Dict[str, Dict[str, List[Dict[str, str]]]], + ], + ], + Dict[ + str, + Union[ + Dict[str, str], + str, + List[Dict[str, Union[str, int]]], + Dict[str, Dict[Any, Any]], + ], + ], + Dict[str, str], + int, + ], + ], + access_control: Union[str, List[str]], + ) -> Dict[str, Any]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) ) return document - async def _project_access_control(self, project): + async def _project_access_control( + self, project: Dict[str, str] + ) -> List[Union[Any, str]]: if not self._dls_enabled(): return [] @@ -644,7 +693,9 @@ async def _project_access_control(self, project): ) return list(access_control) - async def _cache_project_access_control(self, project): + async def _cache_project_access_control( + self, project: Dict[str, str] + ) -> List[Union[Any, str]]: project_key = project.get("key") if project_key in self.project_permission_cache.keys(): project_access_controls = self.project_permission_cache.get(project_key) @@ -655,7 +706,9 @@ async def _cache_project_access_control(self, project): self.project_permission_cache[project_key] = project_access_controls return project_access_controls - async def _issue_access_control(self, issue_key, project): + async def _issue_access_control( + self, issue_key: str, project: Dict[str, str] + ) -> List[Union[Any, str]]: if not self._dls_enabled(): return [] @@ -729,7 +782,7 @@ async def _issue_access_control(self, issue_key, project): return project_access_controls return list(access_control) - async def get_access_control(self): + async def get_access_control(self) -> None: """Get access control documents for active Atlassian users. This method fetches access control documents for active Atlassian users when document level security (DLS) @@ -772,10 +825,10 @@ async def get_access_control(self): user=user ) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[AtlassianAdvancedRulesValidator]: return [AtlassianAdvancedRulesValidator(self)] - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options: Dict[str, int]) -> None: """Tweak bulk options as per concurrent downloads support by jira Args: @@ -783,11 +836,17 @@ def tweak_bulk_options(self, options): """ options["concurrent_downloads"] = self.concurrent_downloads - async def close(self): + async def close(self) -> None: """Closes unclosed client session""" await self.jira_client.close_session() - async def get_content(self, issue_key, attachment, timestamp=None, doit=False): + async def get_content( + self, + issue_key: str, + attachment: Dict[str, Union[str, int]], + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: """Extracts the content for allowed file types. 
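Before a connector downloads an attachment, it typically gates on extension and size. A rough pre-flight sketch of that check; the extension set and the 10 MB cap are assumptions for illustration, not the connector's exact values:

    import os

    SUPPORTED_FILETYPES = {".pdf", ".docx", ".txt"}  # illustrative subset, assumed
    MAX_FILE_SIZE = 10 * 1024 * 1024                 # assumed 10 MB cap for this sketch

    def should_download(filename: str, size: int) -> bool:
        """Return True only for supported, reasonably sized attachments."""
        extension = os.path.splitext(filename)[-1].lower()
        if extension not in SUPPORTED_FILETYPES:
            return False
        return 0 < size <= MAX_FILE_SIZE

    # should_download("report.pdf", 2048) -> True; should_download("image.raw", 2048) -> False
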
Args: @@ -838,7 +897,7 @@ async def get_content(self, issue_key, attachment, timestamp=None, doit=False): ), ) - async def ping(self): + async def ping(self) -> None: """Verify the connection with Jira""" try: await self.jira_client.ping() @@ -847,7 +906,7 @@ async def ping(self): self._logger.exception("Error while connecting to the Jira") raise - async def _put_projects(self, project, timestamp): + async def _put_projects(self, project: Dict[str, str], timestamp: str) -> None: """Store project documents to queue Args: @@ -869,7 +928,7 @@ async def _put_projects(self, project, timestamp): ) await self.queue.put((document_with_access_control, None)) # pyright: ignore - async def _get_projects(self): + async def _get_projects(self) -> None: """Get projects with the help of REST APIs Yields: @@ -889,7 +948,33 @@ async def _get_projects(self): f"Skipping data for type: {PROJECT}. Error: {exception}" ) - async def _put_issue(self, issue): + async def _put_issue( + self, + issue: Dict[ + str, + Union[ + str, + Dict[ + str, + Union[ + Dict[str, str], + str, + List[Dict[str, Union[str, int]]], + Dict[str, Dict[str, List[Dict[str, str]]]], + ], + ], + Dict[ + str, + Union[ + Dict[str, str], + str, + List[Dict[str, Union[str, int]]], + Dict[str, Dict[Any, Any]], + ], + ], + ], + ], + ) -> None: """Put specific issue as per the given issue_key in a queue Args: @@ -951,7 +1036,7 @@ async def _put_issue(self, issue): ) await self.queue.put("FINISHED") # pyright: ignore - async def _get_issues(self, custom_query=""): + async def _get_issues(self, custom_query: str = "") -> None: """Get issues with the help of REST APIs Yields: @@ -973,7 +1058,12 @@ async def _get_issues(self, custom_query=""): self.tasks += 1 await self.queue.put("FINISHED") # pyright: ignore - async def _put_attachment(self, attachments, issue_key, access_control): + async def _put_attachment( + self, + attachments: List[Dict[str, Union[str, int]]], + issue_key: str, + access_control: Union[str, List[str]], + ) -> None: """Put attachments of a specific issue in a queue Args: diff --git a/connectors/sources/microsoft_teams.py b/connectors/sources/microsoft_teams.py index 28a619f02..28ca9a25d 100644 --- a/connectors/sources/microsoft_teams.py +++ b/connectors/sources/microsoft_teams.py @@ -7,10 +7,12 @@ import asyncio import os +from _asyncio import Task from calendar import month_name from datetime import datetime, timedelta from enum import Enum from functools import cached_property, partial +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import aiofiles import aiohttp @@ -19,8 +21,8 @@ from aiohttp.client_exceptions import ClientResponseError from msal import ConfidentialClientApplication -from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.logger import ExtraLogger, logger +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import ( TIKA_SUPPORTED_FILETYPES, CacheWithTimeout, @@ -34,16 +36,16 @@ url_encode, ) -QUEUE_MEM_SIZE = 5 * 1024 * 1024 # Size in Megabytes +QUEUE_MEM_SIZE: int = 5 * 1024 * 1024 # Size in Megabytes MAX_CONCURRENCY = 80 -FILE_WRITE_CHUNK_SIZE = 1024 * 64 # 64KB default SSD page size +FILE_WRITE_CHUNK_SIZE: int = 1024 * 64 # 64KB default SSD page size MAX_FILE_SIZE = 10485760 TOKEN_EXPIRES = 3599 RETRY_COUNT = 3 RETRY_SECONDS = 30 RETRY_INTERVAL = 2 -RUNNING_FTEST = ( +RUNNING_FTEST: bool = ( "RUNNING_FTEST" in os.environ ) # Flag to check if a connector is run for ftest or not. 
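The Microsoft Teams hunk that follows redirects the Graph endpoints to a single override URL when functional tests are running. Outside the diff, that is just an environment-variable fallback; a tiny sketch with the URL values taken from the hunk and the fallback style assumed:

    import os

    # Point every Graph endpoint at a local stub when OVERRIDE_URL is set,
    # otherwise use the real service.
    BASE_URL = os.environ.get("OVERRIDE_URL", "https://graph.microsoft.com/v1.0")
    GRAPH_API_AUTH_URL = os.environ.get("OVERRIDE_URL", "https://login.microsoftonline.com")
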
@@ -54,14 +56,14 @@ ) logger.warning("IT'S SUPPOSED TO BE USED ONLY FOR TESTING") logger.warning("x" * 50) - override_url = os.environ["OVERRIDE_URL"] - BASE_URL = override_url - GRAPH_API_AUTH_URL = override_url - GRAPH_ACQUIRE_TOKEN_URL = override_url + override_url: str = os.environ["OVERRIDE_URL"] + BASE_URL: str = override_url + GRAPH_API_AUTH_URL: str = override_url + GRAPH_ACQUIRE_TOKEN_URL: str = override_url else: - GRAPH_API_AUTH_URL = "https://login.microsoftonline.com" - GRAPH_ACQUIRE_TOKEN_URL = "https://graph.microsoft.com/.default" # noqa S105 - BASE_URL = "https://graph.microsoft.com/v1.0" + GRAPH_API_AUTH_URL: str = "https://login.microsoftonline.com" + GRAPH_ACQUIRE_TOKEN_URL: str = "https://graph.microsoft.com/.default" # noqa S105 + BASE_URL: str = "https://graph.microsoft.com/v1.0" SCOPE = [ "User.Read.All", @@ -125,7 +127,7 @@ class EndSignal(Enum): class Schema: - def chat_messages(self): + def chat_messages(self) -> Dict[str, str]: return { "_id": "id", "_timestamp": "lastModifiedDateTime", @@ -137,13 +139,13 @@ def chat_messages(self): "message": "message", } - def chat_tabs(self): + def chat_tabs(self) -> Dict[str, str]: return { "_id": "id", "title": "displayName", } - def chat_attachments(self): + def chat_attachments(self) -> Dict[str, str]: return { "_id": "id", "name": "name", @@ -153,7 +155,7 @@ def chat_attachments(self): "creation_time": "createdDateTime", } - def meeting(self): + def meeting(self) -> Dict[str, str]: return { "_id": "id", "creation_time": "createdDateTime", @@ -168,14 +170,14 @@ def meeting(self): "original_end_timezone": "originalEndTimeZone", } - def teams(self): + def teams(self) -> Dict[str, str]: return { "_id": "id", "title": "displayName", "description": "description", } - def channel(self): + def channel(self) -> Dict[str, str]: return { "_id": "id", "url": "webUrl", @@ -184,10 +186,10 @@ def channel(self): "creation_time": "createdDateTime", } - def channel_tab(self): + def channel_tab(self) -> Dict[str, str]: return {"_id": "id", "title": "displayName", "url": "webUrl"} - def channel_message(self): + def channel_message(self) -> Dict[str, str]: return { "_id": "id", "url": "webUrl", @@ -195,7 +197,7 @@ def channel_message(self): "creation_time": "createdDateTime", } - def channel_attachment(self): + def channel_attachment(self) -> Dict[str, str]: return { "_id": "id", "name": "name", @@ -249,7 +251,14 @@ class TokenFetchFailed(Exception): class GraphAPIToken: """Class for handling access token for Microsoft Graph APIs""" - def __init__(self, tenant_id, client_id, client_secret, username, password): + def __init__( + self, + tenant_id: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + username: Optional[str], + password: Optional[str], + ) -> None: """Initializer. Args: @@ -268,7 +277,7 @@ def __init__(self, tenant_id, client_id, client_secret, username, password): self._token_cache_with_client = CacheWithTimeout() self._token_cache_with_username = CacheWithTimeout() - async def get_with_client(self): + async def get_with_client(self) -> str: """Get bearer token for provided credentials. If token has been retrieved, it'll be taken from the cache. @@ -295,7 +304,7 @@ async def get_with_client(self): return access_token - async def get_with_username_password(self): + async def get_with_username_password(self) -> str: """Get bearer token for provided credentials. If token has been retrieved, it'll be taken from the cache. 
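The caching behaviour both token getters describe (reuse a bearer token until it expires, then fetch a fresh one) can be sketched without MSAL. `CacheWithTimeout` below is a simplified stand-in for the utility the connector imports, not its real implementation, and the 3599-second default mirrors the `TOKEN_EXPIRES` constant above:

    import time
    from typing import Callable, Optional

    class CacheWithTimeout:
        """Hold a single value until its expiry passes (illustrative stand-in)."""

        def __init__(self) -> None:
            self._value: Optional[str] = None
            self._expires_at: float = 0.0

        def get_value(self) -> Optional[str]:
            if self._value is not None and time.monotonic() < self._expires_at:
                return self._value
            return None

        def set_value(self, value: str, ttl_seconds: float) -> None:
            self._value = value
            self._expires_at = time.monotonic() + ttl_seconds

    def get_token(cache: CacheWithTimeout, fetch: Callable[[], str], ttl: float = 3599) -> str:
        """Return the cached token if still valid, otherwise fetch and cache a new one."""
        cached = cache.get_value()
        if cached is not None:
            return cached
        token = fetch()
        cache.set_value(token, ttl)
        return token
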
@@ -326,7 +335,7 @@ async def get_with_username_password(self): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _fetch_token(self, is_acquire_for_client=False): + async def _fetch_token(self, is_acquire_for_client: bool = False): """Generate API token for usage with Graph API Args: is_acquire_for_client (boolean): True if token needs be generated using client. Default to false @@ -361,7 +370,14 @@ async def _fetch_token(self, is_acquire_for_client=False): class MicrosoftTeamsClient: """Client Class for API calls to Microsoft Teams""" - def __init__(self, tenant_id, client_id, client_secret, username, password): + def __init__( + self, + tenant_id: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + username: Optional[str], + password: Optional[str], + ) -> None: self._sleeps = CancellableSleeps() self._http_session = aiohttp.ClientSession( headers={ @@ -381,13 +397,13 @@ def __init__(self, tenant_id, client_id, client_secret, username, password): self._logger = logger - def set_logger(self, logger_): + def set_logger(self, logger_: ExtraLogger) -> None: self._logger = logger_ - async def fetch(self, url): + async def fetch(self, url: str) -> Dict[str, str]: return await self._get_json(absolute_url=url) - async def pipe(self, url, stream): + async def pipe(self, url, stream) -> None: try: async for response in self._get(absolute_url=url, use_token=False): async for data in response.content.iter_chunked(FILE_WRITE_CHUNK_SIZE): @@ -411,7 +427,7 @@ async def scroll(self, url): else: break - async def _get_json(self, absolute_url): + async def _get_json(self, absolute_url: str) -> Dict[str, Union[str, List[str]]]: try: async for response in self._get(absolute_url=absolute_url): return await response.json() @@ -426,7 +442,7 @@ async def _get_json(self, absolute_url): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=[NotFound, PermissionsMissing], ) - async def _get(self, absolute_url, use_token=True): + async def _get(self, absolute_url: str, use_token: bool = True) -> Iterator[Task]: try: if use_token: if any( @@ -460,7 +476,9 @@ async def _get(self, absolute_url, use_token=True): except ClientResponseError as e: await self._handle_client_response_error(absolute_url, e) - async def _handle_client_response_error(self, absolute_url, e): + async def _handle_client_response_error( + self, absolute_url: str, e: ClientResponseError + ) -> Iterator[Task]: if e.status == 429 or e.status == 503: response_headers = e.headers or {} updated_response_headers = { @@ -488,7 +506,7 @@ async def _handle_client_response_error(self, absolute_url, e): else: raise - async def ping(self): + async def ping(self) -> Dict[Any, Any]: return await self.fetch( url=URLS[UserEndpointName.PING.value].format(base_url=BASE_URL) ) @@ -528,7 +546,7 @@ async def get_user_drive(self, sender_id): ), ) - async def get_user_drive_root_children(self, drive_id): + async def get_user_drive_root_children(self, drive_id: int) -> Dict[str, str]: async for root_children_data in self.scroll( url=URLS[UserEndpointName.DRIVE_CHILDREN.value].format( base_url=BASE_URL, drive_id=drive_id @@ -560,7 +578,7 @@ async def get_calendars(self, user_id): for event in events: yield event - async def download_item(self, url, async_buffer): + async def download_item(self, url, async_buffer) -> None: await self.pipe(url=url, stream=async_buffer) async def get_teams(self): @@ -593,7 +611,9 @@ async def get_channel_messages(self, team_id, channel_id): ): yield channel_messages - async 
def get_channel_file(self, team_id, channel_id): + async def get_channel_file( + self, team_id: str, channel_id: str + ) -> Dict[str, Union[Dict[str, str], int, str]]: file = await self.fetch( url=URLS[TeamEndpointName.FILE.value].format( base_url=BASE_URL, team_id=team_id, channel_id=channel_id @@ -615,7 +635,7 @@ async def get_channel_drive_childrens(self, drive_id, item_id): yield documents yield child - async def close(self): + async def close(self) -> None: self._sleeps.cancel() await self._http_session.close() @@ -623,15 +643,15 @@ async def close(self): class MicrosoftTeamsFormatter: """Format documents""" - def __init__(self, schema): + def __init__(self, schema: Schema) -> None: self.schema = schema def map_document_with_schema( self, - document, - item, - document_type, - ): + document: Dict[str, Union[str, List[str], datetime]], + item: Dict[str, Any], + document_type: Callable, + ) -> None: """Prepare key mappings for documents Args: @@ -645,7 +665,9 @@ def map_document_with_schema( for elasticsearch_field, sharepoint_field in document_type().items(): document[elasticsearch_field] = item.get(sharepoint_field) - def format_doc(self, item, document_type, **kwargs): + def format_doc( + self, item: Dict[str, Any], document_type: Callable, **kwargs + ) -> Dict[str, Optional[Union[datetime, str, int]]]: document = {} for elasticsearch_field, sharepoint_field in kwargs["document"].items(): document[elasticsearch_field] = sharepoint_field @@ -654,7 +676,47 @@ def format_doc(self, item, document_type, **kwargs): ) return document - def format_user_chat_meeting_recording(self, item, url): + def format_user_chat_meeting_recording( + self, + item: Dict[ + str, + Optional[ + Union[ + str, + Dict[str, str], + Dict[ + str, + Optional[ + Union[str, Dict[str, Optional[Dict[str, Optional[str]]]]] + ], + ], + List[ + Dict[ + str, + Optional[ + Union[ + str, + Dict[str, str], + Dict[ + str, + Optional[ + Union[ + str, + List[ + Dict[str, Dict[str, Dict[str, str]]] + ], + ] + ], + ], + ] + ], + ] + ], + ] + ], + ], + url: str, + ) -> Dict[str, Any]: document = {"type": UserEndpointName.MEETING_RECORDING.value} document.update( { @@ -666,7 +728,7 @@ def format_user_chat_meeting_recording(self, item, url): ) return document - def get_calendar_detail(self, calendar): + def get_calendar_detail(self, calendar: Dict[str, Any]) -> str: body = "" organizer = calendar.get("organizer", {}).get("emailAddress").get("name") calendar_recurrence = calendar.get("recurrence") @@ -725,7 +787,7 @@ def get_calendar_detail(self, calendar): body = f"Schedule: {start_time} to {end_time} Organizer: {organizer}" return body - def format_user_calendars(self, item): + def format_user_calendars(self, item: Dict[str, Any]) -> Dict[str, Any]: document = {"type": UserEndpointName.MEETING.value} attendee_list = ( [ @@ -755,7 +817,29 @@ def format_user_calendars(self, item): ) return document - def format_channel_message(self, item, channel_name, message_content): + def format_channel_message( + self, + item: Dict[ + str, + Optional[ + Union[ + str, + Dict[str, Dict[str, str]], + Dict[str, str], + List[ + Dict[ + str, + Optional[ + Union[str, Dict[str, Dict[str, str]], Dict[str, str]] + ], + ] + ], + ] + ], + ], + channel_name: str, + message_content: str, + ) -> Dict[str, Any]: document = {"type": TeamEndpointName.MESSAGE.value} document.update( { # pyright: ignore @@ -774,7 +858,24 @@ def format_channel_message(self, item, channel_name, message_content): ) return document - def format_channel_meeting(self, reply): + def 
format_channel_meeting( + self, + reply: Dict[ + str, + Optional[ + Union[ + str, + Dict[str, str], + Dict[ + str, + Optional[ + Union[str, List[Dict[str, Dict[str, Dict[str, str]]]]] + ], + ], + ] + ], + ], + ) -> Dict[str, Any]: document = {"type": TeamEndpointName.MEETING.value} event = reply["eventDetail"] if event.get("@odata.type") == "#microsoft.graph.callEndedEventMessageDetail": @@ -806,7 +907,45 @@ def format_channel_meeting(self, reply): ) return document - async def format_user_chat_messages(self, chat, message, message_content, members): + async def format_user_chat_messages( + self, + chat: Dict[str, Optional[Union[List[Dict[str, str]], str]]], + message: Dict[ + str, + Optional[ + Union[ + str, + Dict[str, Dict[str, str]], + Dict[str, str], + List[ + Dict[ + str, + Optional[ + Union[str, Dict[str, Dict[str, str]], Dict[str, str]] + ], + ] + ], + ] + ], + ], + message_content: str, + members: Optional[str], + ) -> Dict[ + str, + Optional[ + Union[ + str, + Dict[str, Dict[str, str]], + Dict[str, str], + List[ + Dict[ + str, + Optional[Union[str, Dict[str, Dict[str, str]], Dict[str, str]]], + ] + ], + ] + ], + ]: if chat.get("topic"): message.update({"title": chat["topic"]}) else: @@ -823,7 +962,7 @@ async def format_user_chat_messages(self, chat, message, message_content, member ) return message - def format_attachment_names(self, attachments): + def format_attachment_names(self, attachments: List[Any]) -> str: if not attachments: return "" @@ -841,7 +980,7 @@ class MicrosoftTeamsDataSource(BaseDataSource): service_type = "microsoft_teams" incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the connection to the Microsoft Teams. Args: @@ -854,11 +993,11 @@ def __init__(self, configuration): self.schema = Schema() self.formatter = MicrosoftTeamsFormatter(self.schema) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(self._logger) @cached_property - def client(self): + def client(self) -> MicrosoftTeamsClient: tenant_id = self.configuration["tenant_id"] client_id = self.configuration["client_id"] client_secret = self.configuration["secret_value"] @@ -883,7 +1022,7 @@ async def _consumer(self): else: yield item - def verify_filename_for_extraction(self, filename): + def verify_filename_for_extraction(self, filename: str) -> Optional[bool]: attachment_extension = os.path.splitext(filename)[-1] if attachment_extension == "": self._logger.debug( @@ -897,7 +1036,9 @@ def verify_filename_for_extraction(self, filename): return return True - async def _download_content_for_attachment(self, download_func, original_filename): + async def _download_content_for_attachment( + self, download_func: partial, original_filename: str + ) -> str: attachment = None source_file_name = "" @@ -923,13 +1064,13 @@ async def _download_content_for_attachment(self, download_func, original_filenam return attachment - async def validate_config(self): + async def validate_config(self) -> None: await super().validate_config() # Check that we can log in into Graph API await self.client._api_token.get_with_username_password() - async def ping(self): + async def ping(self) -> None: """Verify the connection with Microsoft Teams""" try: await self.client.ping() @@ -938,7 +1079,7 @@ async def ping(self): self._logger.exception("Error while connecting to Microsoft Teams") raise - async def update_user_chat_attachments(self, **kwargs): + async def 
update_user_chat_attachments(self, **kwargs) -> None: async for attachments in self.client.get_user_chat_attachments( sender_id=kwargs["sender_id"], attachment_name=kwargs["attachment_name"], @@ -965,8 +1106,12 @@ async def update_user_chat_attachments(self, **kwargs): ) async def get_content( - self, user_attachment, download_url, timestamp=None, doit=False - ): + self, + user_attachment: Dict[str, Union[int, str]], + download_url: str, + timestamp: None = None, + doit: bool = False, + ) -> Optional[Dict[str, Any]]: """Extracts the content for allowed file types. Args: @@ -1004,8 +1149,13 @@ async def get_content( return document async def get_messages( - self, message, document_type=None, chat=None, channel_name=None, members=None - ): + self, + message: Dict[str, Any], + document_type: Optional[str] = None, + chat: Optional[Dict[str, Union[List[Dict[str, str]], str]]] = None, + channel_name: Optional[str] = None, + members: Optional[str] = None, + ) -> None: if not message.get("deletedDateTime") and ( "unknownFutureValue" not in message.get("messageType") ): @@ -1041,7 +1191,7 @@ async def get_messages( ) ) - async def user_chat_meeting_recording(self, message): + async def user_chat_meeting_recording(self, message: Dict[str, Any]) -> None: if ( message.get("eventDetail") and message["eventDetail"].get("@odata.type") @@ -1059,14 +1209,16 @@ async def user_chat_meeting_recording(self, message): ) ) - def get_chat_members(self, members): + def get_chat_members(self, members: List[Dict[str, str]]) -> str: return ",".join( member.get("displayName") for member in members if member.get("displayName", "") ) - async def user_chat_producer(self, chat): + async def user_chat_producer( + self, chat: Dict[str, Union[List[Dict[str, str]], str]] + ) -> None: members = self.get_chat_members(chat.get("members", [])) async for messages in self.client.get_user_chat_messages(chat_id=chat["id"]): for message in messages: @@ -1110,7 +1262,9 @@ async def user_chat_producer(self, chat): ) await self.queue.put(EndSignal.USER_CHAT_TASK_FINISHED) - async def get_channel_messages(self, message, channel_name): + async def get_channel_messages( + self, message: Dict[str, Any], channel_name: str + ) -> None: await self.get_messages(message=message, channel_name=channel_name) meeting_document = {} for reply in message.get("replies", []): @@ -1140,7 +1294,9 @@ async def get_channel_messages(self, message, channel_name): for document in meeting_document.values(): await self.queue.put((document, None)) - async def team_channel_producer(self, channel, team_id, team_name): + async def team_channel_producer( + self, channel: Dict[str, Optional[str]], team_id: str, team_name: str + ) -> None: channel_name = channel.get("displayName") await self.queue.put( ( @@ -1199,8 +1355,8 @@ async def team_channel_producer(self, channel, team_id, team_name): await self.queue.put(EndSignal.CHANNEL_TASK_FINISHED) async def get_channel_drive_producer( - self, drive_id, item_id, team_name, channel_name - ): + self, drive_id: str, item_id: str, team_name: str, channel_name: str + ) -> None: async for drive_child in self.client.get_channel_drive_childrens( drive_id=drive_id, item_id=item_id, @@ -1228,7 +1384,7 @@ async def get_channel_drive_producer( ) ) - async def teams_producer(self, team): + async def teams_producer(self, team: Dict[str, Optional[str]]) -> None: team_id = team.get("id") team_name = team.get("displayName") await self.queue.put( @@ -1254,7 +1410,7 @@ async def teams_producer(self, team): await 
self.queue.put(EndSignal.TEAM_TASK_FINISHED) - async def calendars_producer(self, user): + async def calendars_producer(self, user: Dict[str, str]) -> None: async for event in self.client.get_calendars(user_id=user["id"]): if event and not event.get("isCancelled"): await self.queue.put( @@ -1297,12 +1453,12 @@ async def get_docs(self, filtering=None): await self.fetchers.join() - async def close(self): + async def close(self) -> None: """Closes unclosed client session""" await self.client.close() @classmethod - def get_default_configuration(cls): + def get_default_configuration(cls) -> Dict[str, Dict[str, Union[int, str]]]: """Get the default configuration for Microsoft Teams. Returns: diff --git a/connectors/sources/mongo.py b/connectors/sources/mongo.py index 3293f07c3..5384b1ed7 100644 --- a/connectors/sources/mongo.py +++ b/connectors/sources/mongo.py @@ -9,6 +9,7 @@ from copy import deepcopy from datetime import datetime from tempfile import NamedTemporaryFile +from typing import Any, Dict, Iterator, List, Optional, Union import fastjsonschema from bson import OLD_UUID_SUBTYPE, Binary, DBRef, Decimal128, ObjectId @@ -21,7 +22,11 @@ AdvancedRulesValidator, SyncRuleValidationResult, ) -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import get_pem_format @@ -75,7 +80,9 @@ class MongoAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: Dict[str, Any] + ) -> SyncRuleValidationResult: try: MongoAdvancedRulesValidator.SCHEMA(advanced_rules) @@ -97,7 +104,7 @@ class MongoDataSource(BaseDataSource): service_type = "mongodb" advanced_rules_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self.client = None @@ -110,7 +117,16 @@ def __init__(self, configuration): self.collection = None @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[bool, str]]], List[str], int, str]], + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[int, str]], + ], + ]: return { "host": { "label": "Server hostname", @@ -172,7 +188,7 @@ def get_default_configuration(cls): } @contextmanager - def get_client(self): + def get_client(self) -> Iterator[AsyncIOMotorClient]: certfile = "" try: client_params = {} @@ -209,14 +225,24 @@ def get_client(self): exc_info=True, ) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[MongoAdvancedRulesValidator]: return [MongoAdvancedRulesValidator()] - async def ping(self): + async def ping(self) -> None: with self.get_client() as client: await client.admin.command("ping") - def remove_temp_file(self, temp_file): + def remove_temp_file( + self, + temp_file: Union[ + os.PathLike[bytes], + os.PathLike[str], + os.PathLike[Union[bytes, str]], + bytes, + int, + str, + ], + ) -> None: if os.path.exists(temp_file): try: os.remove(temp_file) @@ -227,7 +253,7 @@ def remove_temp_file(self, temp_file): ) # TODO: That's a lot of work. 
Find a better way - def serialize(self, doc): + def serialize(self, doc: Dict[str, Any]) -> Dict[str, Any]: def _serialize(value): if isinstance(value, ObjectId): value = str(value) @@ -285,7 +311,7 @@ async def get_docs(self, filtering=None): async for doc in collection.find(): yield self.serialize(doc), None - def check_conflicting_values(self, value): + def check_conflicting_values(self, value: Optional[bool]) -> None: if value == "true": value = True elif value == "false": @@ -297,7 +323,7 @@ def check_conflicting_values(self, value): msg = "The value of SSL/TLS must be the same in the hostname and configuration field." raise ConfigurableFieldValueError(msg) - async def validate_config(self): + async def validate_config(self) -> None: await super().validate_config() parsed_url = urllib.parse.urlparse(self.host) query_params = urllib.parse.parse_qs(parsed_url.query) diff --git a/connectors/sources/mssql.py b/connectors/sources/mssql.py index 377b9bdea..3472057c9 100644 --- a/connectors/sources/mssql.py +++ b/connectors/sources/mssql.py @@ -9,6 +9,7 @@ import os from functools import cached_property, partial from tempfile import NamedTemporaryFile +from typing import List, Optional, Sized import fastjsonschema from asyncpg.exceptions._base import InternalClientError @@ -47,31 +48,31 @@ class MSSQLQueries(Queries): """Class contains methods which return query""" - def ping(self): + def ping(self) -> str: """Query to ping source""" return "SELECT 1+1" - def all_tables(self, **kwargs): + def all_tables(self, **kwargs) -> str: """Query to get all tables""" return f"SELECT table_name FROM information_schema.tables WHERE TABLE_SCHEMA = '{ kwargs['schema'] }'" - def table_primary_key(self, **kwargs): + def table_primary_key(self, **kwargs) -> str: """Query to get the primary key""" return f"SELECT C.COLUMN_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS T JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE C ON C.CONSTRAINT_NAME=T.CONSTRAINT_NAME WHERE C.TABLE_NAME='{kwargs['table']}' and C.TABLE_SCHEMA='{kwargs['schema']}' and T.CONSTRAINT_TYPE='PRIMARY KEY'" - def table_data(self, **kwargs): + def table_data(self, **kwargs) -> str: """Query to get the table data""" return f'SELECT * FROM {kwargs["schema"]}."{kwargs["table"]}"' - def table_last_update_time(self, **kwargs): + def table_last_update_time(self, **kwargs) -> str: """Query to get the last update time of the table""" return f"SELECT last_user_update FROM sys.dm_db_index_usage_stats WHERE object_id=object_id('{kwargs['schema']}.{kwargs['table']}')" - def table_data_count(self, **kwargs): + def table_data_count(self, **kwargs) -> str: """Query to get the number of rows in the table""" return f'SELECT COUNT(*) FROM {kwargs["schema"]}."{kwargs["table"]}"' - def all_schemas(self): + def all_schemas(self) -> None: """Query to get all schemas of database""" pass @@ -92,7 +93,7 @@ class MSSQLAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source) -> None: self.source = source async def validate(self, advanced_rules): @@ -108,7 +109,7 @@ async def validate(self, advanced_rules): interval=DEFAULT_WAIT_MULTIPLIER, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self, advanced_rules): + async def _remote_validation(self, advanced_rules) -> SyncRuleValidationResult: try: MSSQLAdvancedRulesValidator.SCHEMA(advanced_rules) except JsonSchemaValueException as e: @@ -157,9 +158,9 @@ def __init__( ssl_ca, 
validate_host, logger_, - retry_count=DEFAULT_RETRY_COUNT, - fetch_size=DEFAULT_FETCH_SIZE, - ): + retry_count: int = DEFAULT_RETRY_COUNT, + fetch_size: int = DEFAULT_FETCH_SIZE, + ) -> None: self.host = host self.port = port self.user = user @@ -178,10 +179,10 @@ def __init__( self.queries = MSSQLQueries() self._logger = logger_ - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def close(self): + def close(self) -> None: if os.path.exists(self.certfile): try: os.remove(self.certfile) @@ -212,14 +213,14 @@ def engine(self): } return create_engine(connection_string, connect_args=connect_args) - def create_pem_file(self): + def create_pem_file(self) -> None: """Create pem file for SSL Verification""" pem_certificates = get_pem_format(key=self.ssl_ca) with NamedTemporaryFile(mode="w", suffix=".pem", delete=False) as cert: cert.write(pem_certificates) self.certfile = cert.name - async def get_cursor(self, query): + async def get_cursor(self, query: str): """Executes the passed query on the Non-Async supported Database server and return cursor. Args: @@ -255,7 +256,7 @@ async def ping(self): ) ) - async def get_tables_to_fetch(self, is_filtering=False): + async def get_tables_to_fetch(self, is_filtering: bool = False): tables = configured_tables(self.tables) if is_wildcard(tables) or is_filtering: if is_filtering: @@ -329,7 +330,7 @@ async def get_table_last_update_time(self, table): ): return last_update_time - async def data_streamer(self, table=None, query=None): + async def data_streamer(self, table=None, query: Optional[str] = None): """Streaming data from a table Args: @@ -372,7 +373,7 @@ class MSSQLDataSource(BaseDataSource): service_type = "mssql" advanced_rules_enabled = True - def __init__(self, configuration): + def __init__(self, configuration) -> None: """Setup connection to the Microsoft SQL database-server configured by user Args: @@ -397,7 +398,7 @@ def __init__(self, configuration): logger_=self._logger, ) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.mssql_client.set_logger(self._logger) @classmethod @@ -484,13 +485,13 @@ def get_default_configuration(cls): }, } - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[MSSQLAdvancedRulesValidator]: return [MSSQLAdvancedRulesValidator(self)] - async def close(self): + async def close(self) -> None: self.mssql_client.close() - async def ping(self): + async def ping(self) -> None: """Verify the connection with the database-server configured by user""" self._logger.debug("Validating the Connector Configuration...") try: @@ -511,7 +512,7 @@ def row2doc(self, row, doc_id, table, timestamp): ) return row - async def get_primary_key(self, tables): + async def get_primary_key(self, tables: Optional[Sized]) -> List[str]: self._logger.debug(f"Extracting primary keys for tables: {tables}") primary_key_columns = [] for table in tables: @@ -523,7 +524,12 @@ async def get_primary_key(self, tables): column_names=primary_key_columns, schema=self.schema, tables=tables ) - async def yield_rows_for_query(self, primary_key_columns, tables, query=None): + async def yield_rows_for_query( + self, + primary_key_columns, + tables: Optional[Sized], + query: Optional[str] = None, + ): if query is None: streamer = self.mssql_client.data_streamer(table=tables[0]) else: @@ -585,7 +591,9 @@ async def fetch_documents_from_query(self, tables, query, id_columns): f"Something went wrong while fetching document for query {query} and tables {', 
'.join(tables)}. Error: {exception}" ) - async def _yield_docs_custom_query(self, tables, query, id_columns=None): + async def _yield_docs_custom_query( + self, tables, query, id_columns: Optional[List[str]] = None + ): primary_key_columns = await self.get_primary_key(tables=tables) if id_columns: diff --git a/connectors/sources/mysql.py b/connectors/sources/mysql.py index ecb1c95c2..eaa2da03e 100644 --- a/connectors/sources/mysql.py +++ b/connectors/sources/mysql.py @@ -6,6 +6,9 @@ """MySQL source module responsible to fetch documents from MySQL""" import re +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, Union +from unittest.mock import AsyncMock, MagicMock import aiomysql import fastjsonschema @@ -15,7 +18,12 @@ AdvancedRulesValidator, SyncRuleValidationResult, ) -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.logger import ExtraLogger +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.sources.generic_database import ( configured_tables, is_wildcard, @@ -28,7 +36,9 @@ ssl_context, ) -SPLIT_BY_COMMA_OUTSIDE_BACKTICKS_PATTERN = re.compile(r"`(?:[^`]|``)+`|\w+") +SPLIT_BY_COMMA_OUTSIDE_BACKTICKS_PATTERN: re.Pattern[str] = re.compile( + r"`(?:[^`]|``)+`|\w+" +) MAX_POOL_SIZE = 10 DEFAULT_FETCH_SIZE = 5000 @@ -36,7 +46,7 @@ RETRY_INTERVAL = 2 -def format_list(list_): +def format_list(list_: Union[List[str], str]) -> str: return ", ".join(list_) @@ -56,10 +66,10 @@ class MySQLAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "MySqlDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate(self, advanced_rules: Any) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -72,7 +82,18 @@ async def validate(self, advanced_rules): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, + advanced_rules: List[ + Union[ + Dict[str, str], + Dict[str, Union[str, List[str]]], + Dict[str, List[str]], + str, + Dict[str, Union[str, bool, List[str]]], + ] + ], + ) -> SyncRuleValidationResult: try: MySQLAdvancedRulesValidator.SCHEMA(advanced_rules) except JsonSchemaValueException as e: @@ -107,17 +128,17 @@ async def _remote_validation(self, advanced_rules): class MySQLClient: def __init__( self, - host, - port, - user, - password, - ssl_enabled, - ssl_certificate, - logger_, - database=None, - max_pool_size=MAX_POOL_SIZE, - fetch_size=DEFAULT_FETCH_SIZE, - ): + host: str, + port: int, + user: str, + password: str, + ssl_enabled: bool, + ssl_certificate: str, + logger_: ExtraLogger, + database: None = None, + max_pool_size: int = MAX_POOL_SIZE, + fetch_size: int = DEFAULT_FETCH_SIZE, + ) -> None: self.host = host self.port = port self.user = user @@ -131,7 +152,7 @@ def __init__( self.connection = None self._logger = logger_ - async def __aenter__(self): + async def __aenter__(self) -> "MySQLClient": connection_string = { "host": self.host, "port": int(self.port), @@ -150,7 +171,9 @@ async def __aenter__(self): return self - async def __aexit__(self, exception_type, exception_value, exception_traceback): + async def __aexit__( + self, exception_type: None, exception_value: None, 
exception_traceback: None + ) -> None: self._sleeps.cancel() self.connection_pool.release(self.connection) @@ -162,14 +185,14 @@ async def __aexit__(self, exception_type, exception_value, exception_traceback): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def get_all_table_names(self): + async def get_all_table_names(self) -> List[str]: async with self.connection.cursor(aiomysql.cursors.SSCursor) as cursor: await cursor.execute( f"SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{self.database}'" ) return [table[0] for table in await cursor.fetchall()] - async def ping(self): + async def ping(self) -> None: try: await self.connection.ping() self._logger.info("Successfully connected to the MySQL Server.") @@ -182,13 +205,13 @@ async def ping(self): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def get_column_names_for_query(self, query): + async def get_column_names_for_query(self, query: str) -> List[str]: async with self.connection.cursor(aiomysql.cursors.SSCursor) as cursor: await cursor.execute(f"SELECT q.* FROM ({query}) as q LIMIT 0") return [f"{column[0]}" for column in cursor.description] - async def get_column_names_for_table(self, table): + async def get_column_names_for_table(self, table: str) -> List[Union[str, Any]]: return await self.get_column_names_for_query( f"SELECT * FROM `{self.database}`.`{table}`" ) @@ -198,7 +221,7 @@ async def get_column_names_for_table(self, table): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def get_primary_key_column_names(self, table): + async def get_primary_key_column_names(self, table: str) -> List[str]: async with self.connection.cursor(aiomysql.cursors.SSCursor) as cursor: await cursor.execute( f"SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = '{self.database}' AND TABLE_NAME = '{table}' AND COLUMN_KEY = 'PRI'" @@ -211,7 +234,7 @@ async def get_primary_key_column_names(self, table): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def get_last_update_time(self, table): + async def get_last_update_time(self, table: str) -> str: async with self.connection.cursor(aiomysql.cursors.SSCursor) as cursor: await cursor.execute( f"SELECT UPDATE_TIME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{self.database}' AND TABLE_NAME = '{table}'" @@ -241,7 +264,7 @@ async def yield_rows_for_table(self, table, primary_keys, table_row_count): break offset += self.fetch_size - async def _get_table_row_count_for_query(self, query): + async def _get_table_row_count_for_query(self, query: str) -> int: table_row_count_query = re.sub( r"SELECT\s.*?\sFROM", "SELECT COUNT(*) FROM", @@ -254,8 +277,8 @@ async def _get_table_row_count_for_query(self, query): return int(table_row_count[0]) def _update_query_with_pagination_attributes( - self, query, offset, primary_key_columns - ): + self, query: str, offset: int, primary_key_columns: List[str] + ) -> str: updated_query = "" has_orderby = bool(re.search(r"\bORDER\s+BY\b", query, flags=re.IGNORECASE)) # Checking if custom query has a semicolon at the end or not @@ -313,7 +336,13 @@ async def _fetchmany_in_batches(self, query): ) -def row2doc(row, column_names, primary_key_columns, table, timestamp): +def row2doc( + row: Union[List[str], Tuple[int, str, int]], + column_names: List[str], + primary_key_columns: Union[List[str], str], + table: Union[List[str], str], + timestamp: Optional[Union[str, datetime]], +) -> Dict[str, Union[int, str, 
List[str], datetime]]: row = dict(zip(column_names, row, strict=True)) row.update( { @@ -326,7 +355,11 @@ def row2doc(row, column_names, primary_key_columns, table, timestamp): return row -def generate_id(tables, row, primary_key_columns): +def generate_id( + tables: Union[List[str], str], + row: Dict[str, Union[str, int]], + primary_key_columns: Union[List[str], str], +) -> str: """Generates an id using table names as prefix in sorted order and primary key values. Example: @@ -351,7 +384,7 @@ class MySqlDataSource(BaseDataSource): service_type = "mysql" advanced_rules_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self._sleeps = CancellableSleeps() self.retry_count = self.configuration["retry_count"] @@ -361,7 +394,16 @@ def __init__(self, configuration): self.tables = self.configuration["tables"] @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: return { "host": { "label": "Host", @@ -430,7 +472,7 @@ def get_default_configuration(cls): }, } - def mysql_client(self): + def mysql_client(self) -> MySQLClient: return MySQLClient( host=self.configuration["host"], port=self.configuration["port"], @@ -443,13 +485,13 @@ def mysql_client(self): logger_=self._logger, ) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[MySQLAdvancedRulesValidator]: return [MySQLAdvancedRulesValidator(self)] - async def close(self): + async def close(self) -> None: self._sleeps.cancel() - async def validate_config(self): + async def validate_config(self) -> None: """Validates that user input is not empty and adheres to the specified constraints. Also validate, if the configured database and the configured tables are present and accessible using the configured user. @@ -465,20 +507,20 @@ async def validate_config(self): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self): + async def _remote_validation(self) -> None: async with self.mysql_client() as client: async with client.connection.cursor() as cursor: await self._validate_database_accessible(cursor) await self._validate_tables_accessible(cursor) - async def _validate_database_accessible(self, cursor): + async def _validate_database_accessible(self, cursor: AsyncMock) -> None: try: await cursor.execute(f"USE `{self.database}`;") except aiomysql.Error as e: msg = f"The database '{self.database}' is either not present or not accessible for the user '{self.configuration['user']}'." raise ConfigurableFieldValueError(msg) from e - async def _validate_tables_accessible(self, cursor): + async def _validate_tables_accessible(self, cursor: AsyncMock) -> None: non_accessible_tables = [] tables_to_validate = await self.get_tables_to_fetch() @@ -492,7 +534,7 @@ async def _validate_tables_accessible(self, cursor): msg = f"The tables '{format_list(non_accessible_tables)}' are either not present or not accessible for user '{self.configuration['user']}'." 
raise ConfigurableFieldValueError(msg) - async def ping(self): + async def ping(self) -> None: async with self.mysql_client() as client: await client.ping() @@ -523,7 +565,9 @@ async def get_docs(self, filtering=None): async for row in self.fetch_documents(tables): yield row, None - async def fetch_documents(self, tables, query=None, id_columns=None): + async def fetch_documents( + self, tables: List[str], query: Optional[str] = None, id_columns: None = None + ) -> None: """If query is not present it fetches all rows from all tables. Otherwise, the custom query is executed. @@ -577,7 +621,13 @@ async def _yield_all_docs_from_tables(self, client, tables): timestamp=last_update_time, ) - async def _yield_docs_custom_query(self, client, tables, query, id_columns): + async def _yield_docs_custom_query( + self, + client: Union[MySQLClient, MagicMock], + tables: Union[List[str], str], + query: str, + id_columns: None, + ) -> None: primary_key_columns = [ await client.get_primary_key_column_names(table) for table in tables ] @@ -617,7 +667,7 @@ async def _yield_docs_custom_query(self, client, tables, query, id_columns): timestamp=max(last_update_times) if len(last_update_times) else None, ) - async def get_tables_to_fetch(self): + async def get_tables_to_fetch(self) -> Union[List[str], str]: tables = configured_tables(self.tables) async with self.mysql_client() as client: diff --git a/connectors/sources/network_drive.py b/connectors/sources/network_drive.py index 46d529f03..11013120b 100644 --- a/connectors/sources/network_drive.py +++ b/connectors/sources/network_drive.py @@ -7,8 +7,11 @@ import asyncio import csv +from _asyncio import Future from collections import deque from functools import cached_property, partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Union +from unittest.mock import Mock import fastjsonschema import requests.exceptions @@ -19,6 +22,7 @@ SMBException, SMBOSError, SMBResponseException, + Unsuccessful, ) from smbprotocol.file_info import ( InfoType, @@ -44,7 +48,12 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.protocol.connectors import Filter +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import ( RetryStrategy, iso_utc, @@ -93,11 +102,11 @@ class NoLogonServerException(Exception): pass -def _prefix_user(user): +def _prefix_user(user: str) -> Optional[str]: return prefix_identity("user", user) -def _prefix_rid(rid): +def _prefix_rid(rid: str) -> Optional[str]: return prefix_identity("rid", rid) @@ -118,10 +127,12 @@ class NetworkDriveAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA_DEFINITION = {"type": "array", "items": RULES_OBJECT_SCHEMA_DEFINITION} SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "NASDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: Union[Dict[str, List[str]], List[Dict[str, str]]] + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -129,7 +140,9 @@ async def validate(self, advanced_rules): return await self.validate_pattern(advanced_rules) - async def validate_pattern(self, advanced_rules): + async def validate_pattern( + self, advanced_rules: Union[Dict[str, 
List[str]], List[Dict[str, str]]] + ) -> SyncRuleValidationResult: try: NetworkDriveAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -159,12 +172,12 @@ async def validate_pattern(self, advanced_rules): class SecurityInfo: - def __init__(self, user, password, server): + def __init__(self, user: str, password: str, server: str) -> None: self.username = user self.server_ip = server self.password = password - def get_descriptor(self, file_descriptor, info): + def get_descriptor(self, file_descriptor, info) -> SMB2CreateSDBuffer: """Get the Security Descriptor for the opened file.""" query_request = SMB2QueryInfoRequest() query_request["info_type"] = InfoType.SMB2_0_INFO_SECURITY @@ -187,7 +200,7 @@ def get_descriptor(self, file_descriptor, info): return security_descriptor @cached_property - def session(self): + def session(self) -> winrm.Session: return winrm.Session( self.server_ip, auth=(self.username, self.password), @@ -195,7 +208,7 @@ def session(self): server_cert_validation="ignore", ) - def parse_output(self, raw_output): + def parse_output(self, raw_output: Mock) -> Dict[str, str]: """ Formats and extracts key-value pairs from raw output data. @@ -243,16 +256,16 @@ def parse_output(self, raw_output): return formatted_result - def fetch_users(self): + def fetch_users(self) -> Dict[str, str]: users = self.session.run_ps(GET_USERS_COMMAND) return self.parse_output(users) - def fetch_groups(self): + def fetch_groups(self) -> Dict[str, str]: groups = self.session.run_ps(GET_GROUPS_COMMAND) return self.parse_output(groups) - def fetch_members(self, group_name): + def fetch_members(self, group_name: str) -> Dict[str, str]: members = self.session.run_ps(GET_GROUP_MEMBERS.format(name=group_name)) return self.parse_output(members) @@ -261,7 +274,9 @@ def fetch_members(self, group_name): class SMBSession: _connection = None - def __init__(self, server_ip, username, password, port): + def __init__( + self, server_ip: str, username: str, password: str, port: None + ) -> None: self.server_ip = server_ip self.username = username self.password = password @@ -269,7 +284,7 @@ def __init__(self, server_ip, username, password, port): self.session = None self._logger = logger - def create_connection(self): + def create_connection(self) -> None: """Creates an SMB session to the shared drive.""" try: self.session = smbclient.register_session( @@ -281,7 +296,7 @@ def create_connection(self): except SMBResponseException as exception: self.handle_smb_response_errors(exception=exception) - def handle_smb_response_errors(self, exception): + def handle_smb_response_errors(self, exception: Unsuccessful): msg = "" if exception.status == STATUS_INVALID_WORKSTATION: msg = f"Client does not have permission to access server: ({self.server_ip}:{self.port})." 
@@ -321,7 +336,7 @@ class NASDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the connection to the Network Drive Args: @@ -337,15 +352,42 @@ def __init__(self, configuration): self.identity_mappings = self.configuration["identity_mappings"] self.security_info = SecurityInfo(self.username, self.password, self.server_ip) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[NetworkDriveAdvancedRulesValidator]: return [NetworkDriveAdvancedRulesValidator(self)] @cached_property - def smb_connection(self): + def smb_connection(self) -> SMBSession: return SMBSession(self.server_ip, self.username, self.password, self.port) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[ + str, + Union[ + List[Dict[str, str]], + List[Dict[str, Union[bool, str]]], + List[str], + int, + str, + ], + ], + Dict[ + str, + Union[ + List[str], + List[Union[Dict[str, str], Dict[str, Union[bool, str]]]], + int, + str, + ], + ], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Network Drive. Returns: @@ -424,7 +466,7 @@ def get_default_configuration(cls): }, } - def format_document(self, file): + def format_document(self, file: Mock) -> Dict[str, str]: file_details = file._dir_info.fields document = { "path": file.path, @@ -443,7 +485,7 @@ def format_document(self, file): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=[SMBOSError, SMBException], ) - async def traverse_diretory(self, path): + async def traverse_diretory(self, path: str) -> Iterator[Future]: self._logger.debug( "Fetching the directory tree from remote server and content of directory on path" ) @@ -484,8 +526,8 @@ async def traverse_diretory(self, path): continue def is_match_with_previous_rules( - self, file_path, indexed_rules, match_with_previous_rules - ): + self, file_path: str, indexed_rules: Set[str], match_with_previous_rules: bool + ) -> bool: # Check if the file is matched with any of the previous indexed rules for indexed_rule in indexed_rules: if not match_with_previous_rules: @@ -502,7 +544,9 @@ def is_match_with_previous_rules( strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=[SMBOSError, SMBException], ) - async def traverse_directory_for_syncrule(self, path, glob_pattern, indexed_rules): + async def traverse_directory_for_syncrule( + self, path: str, glob_pattern: str, indexed_rules: Set[str] + ) -> Iterator[Future]: self._logger.debug( "Fetching the directory tree from remote server and content of directory on path" ) @@ -546,14 +590,16 @@ async def traverse_directory_for_syncrule(self, path, glob_pattern, indexed_rule ) continue - def get_base_path(self, pattern): + def get_base_path(self, pattern: str) -> str: wildcards = ["*", "?", "[", "{", "!", "^"] for i, char in enumerate(pattern): if char in wildcards: return rf"\\{self.server_ip}/{pattern[:i].rsplit('/', 1)[0]}/" return rf"\\{self.server_ip}/{pattern}" - async def fetch_filtered_directory(self, advanced_rules): + async def fetch_filtered_directory( + self, advanced_rules: List[Dict[str, str]] + ) -> Iterator[Future]: """ Fetch file and folder based on advanced rules. 
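Each advanced rule in this connector carries a glob pattern, and a file is indexed only if it matches one of them and has not already been claimed by an earlier rule. A local, smbclient-free sketch of that matching with plain `fnmatch` over path strings; the `{"pattern": ...}` rule shape is assumed for the example:

    from fnmatch import fnmatch
    from typing import Dict, Iterable, List, Set

    def filter_paths(paths: Iterable[str], advanced_rules: List[Dict[str, str]]) -> List[str]:
        """Keep each path at most once, on the first rule it matches."""
        matched: List[str] = []
        seen: Set[str] = set()
        for rule in advanced_rules:
            pattern = rule["pattern"]
            for path in paths:
                if path not in seen and fnmatch(path, pattern):
                    seen.add(path)
                    matched.append(path)
        return matched

    # filter_paths(["share/a.txt", "share/logs/x.log"], [{"pattern": "share/*.txt"}])
    # -> ["share/a.txt"]
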
@@ -589,21 +635,21 @@ async def fetch_filtered_directory(self, advanced_rules): f"Following advanced rules do not match with any path present in network drive or the rule is similar to another rule: {unmatched_rules}" ) - async def validate_config(self): + async def validate_config(self) -> None: await super().validate_config() path = self.configuration["drive_path"] if path.startswith("/") or path.startswith("\\"): message = f"SMB Path:{path} should not start with '/' in the beginning." raise ConfigurableFieldValueError(message) - async def ping(self): + async def ping(self) -> None: """Verify the connection with Network Drive""" await asyncio.to_thread(self.smb_connection.create_connection) await self.close() self._logger.info("Successfully connected to the Network Drive") - async def close(self): + async def close(self) -> None: """Close all the open smb sessions""" if self.smb_connection.session is None: return @@ -615,7 +661,7 @@ async def close(self): ), ) - async def fetch_file_content(self, path): + async def fetch_file_content(self, path: str) -> None: """Fetches the file content from the given drive path Args: @@ -641,7 +687,12 @@ async def fetch_file_content(self, path): f"Cannot read the contents of file on path:{path}. Error {error}" ) - async def get_content(self, file, timestamp=None, doit=None): + async def get_content( + self, + file: Dict[str, Union[str, int]], + timestamp: None = None, + doit: Optional[bool] = None, + ) -> Optional[Dict[str, str]]: """Get the content for a given file Args: @@ -676,7 +727,9 @@ async def get_content(self, file, timestamp=None, doit=None): partial(self.fetch_file_content, path=file["path"]), ) - def list_file_permission(self, file_path, file_type, mode, access): + def list_file_permission( + self, file_path: str, file_type: str, mode: str, access: str + ) -> Optional[List[str]]: try: with smbclient.open_file( file_path, @@ -695,7 +748,7 @@ def list_file_permission(self, file_path, file_type, mode, access): f"Cannot read the contents of file on path:{file_path}. 
Error {error}" ) - def _dls_enabled(self): + def _dls_enabled(self) -> bool: if ( self._features is None or not self._features.document_level_security_enabled() @@ -705,8 +758,12 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] async def _decorate_with_access_control( - self, document, file_path, file_type, groups_info - ): + self, + document: Dict[str, str], + file_path: str, + file_type: str, + groups_info: Dict[str, Dict[str, str]], + ) -> Generator[Future, None, Dict[str, Union[str, List[str]]]]: if self._dls_enabled(): allow_permissions, deny_permissions = await self.get_entity_permission( file_path=file_path, file_type=file_type, groups_info=groups_info @@ -717,7 +774,9 @@ async def _decorate_with_access_control( ) return document - async def _user_access_control_doc(self, user, sid, groups_info=None): + async def _user_access_control_doc( + self, user: str, sid: str, groups_info: Optional[List[str]] = None + ) -> Dict[str, Any]: rid = str(sid).split("-")[-1] prefixed_username = _prefix_user(user) rid_user = _prefix_rid(rid) @@ -737,7 +796,9 @@ async def _user_access_control_doc(self, user, sid, groups_info=None): "created_at": iso_utc(), } | es_access_control_query(access_control) - def read_user_info_csv(self): + def read_user_info_csv( + self, + ) -> List[Union[Dict[str, str], Any, Dict[str, Union[str, List[str]]]]]: with open(self.identity_mappings, encoding="utf-8") as file: user_info = [] try: @@ -757,7 +818,9 @@ def read_user_info_csv(self): ) return user_info - async def fetch_groups_info(self): + async def fetch_groups_info( + self, + ) -> Generator[Future, None, Dict[str, Dict[str, str]]]: self._logger.info( f"Fetching all groups and members for drive at path '{self.drive_path}'" ) @@ -772,7 +835,7 @@ async def fetch_groups_info(self): return groups_members - async def get_access_control(self): + async def get_access_control(self) -> None: if not self._dls_enabled(): self._logger.warning("DLS is not enabled. Skipping") return @@ -809,7 +872,11 @@ async def get_access_control(self): msg = "Something went wrong" raise requests.exceptions.ConnectionError(msg) from exception - async def get_entity_permission(self, file_path, file_type, groups_info): + async def get_entity_permission( + self, file_path: str, file_type: str, groups_info: Dict[str, Dict[str, str]] + ) -> Generator[ + Future, None, Union[Tuple[List[str], List[Any]], Tuple[List[str], List[str]]] + ]: """Processes permissions for a network drive, focusing on key terms: - SID (Security Identifier): The unique identifier for a user or group, it undergoes revision. @@ -871,7 +938,7 @@ async def get_entity_permission(self, file_path, file_type, groups_info): return allow_permissions, deny_permissions - async def get_docs(self, filtering=None): + async def get_docs(self, filtering: Optional[Filter] = None) -> Iterator[Future]: """Executes the logic to fetch files and folders in async manner. 
Yields: dictionary: Dictionary containing the Network Drive files and folders as documents diff --git a/connectors/sources/notion.py b/connectors/sources/notion.py index 3dc224f31..689296f9f 100644 --- a/connectors/sources/notion.py +++ b/connectors/sources/notion.py @@ -9,13 +9,26 @@ import json import os import re +from _asyncio import Task from copy import copy from functools import cached_property, partial -from typing import Any, Awaitable, Callable +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generator, + Iterator, + List, + Optional, + Union, +) +from unittest.mock import MagicMock from urllib.parse import unquote import aiohttp import fastjsonschema +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError from notion_client import APIResponseError, AsyncClient @@ -24,7 +37,11 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import CancellableSleeps, RetryStrategy, retryable RETRIES = 3 @@ -45,24 +62,24 @@ class NotFound(Exception): class NotionClient: """Notion API client""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger self.notion_secret_key = self.configuration["notion_secret_key"] - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def _get_client(self): + def _get_client(self) -> AsyncClient: return AsyncClient( auth=self.notion_secret_key, base_url=BASE_URL, ) @cached_property - def session(self): + def session(self) -> ClientSession: """Generate aiohttp client session. Returns: @@ -81,7 +98,7 @@ def session(self): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=NotFound, ) - async def get_via_session(self, url): + async def get_via_session(self, url: str) -> Iterator[Task]: self._logger.debug(f"Fetching data from url {url}") try: async with self.session.get(url=url) as response: @@ -106,8 +123,11 @@ async def get_via_session(self, url): skipped_exceptions=NotFound, ) async def fetch_results( - self, function: Callable[..., Awaitable[Any]], next_cursor=None, **kwargs: Any - ): + self, + function: Callable[..., Awaitable[Any]], + next_cursor: None = None, + **kwargs: Any, + ) -> Generator[Task, None, Dict[str, Optional[Union[List[Dict[str, str]], bool]]]]: try: return await function(start_cursor=next_cursor, **kwargs) except APIResponseError as exception: @@ -139,11 +159,11 @@ async def async_iterate_paginated_api( if not response["has_more"] or next_cursor is None: return - async def fetch_owner(self): + async def fetch_owner(self) -> None: """Fetch integration authorized owner""" await self._get_client.users.me() - async def close(self): + async def close(self) -> None: self._sleeps.cancel() await self._get_client.aclose() await self.session.close() @@ -160,7 +180,7 @@ async def fetch_users(self): if user_document.get("type") != "bot": yield user_document - async def fetch_child_blocks(self, block_id): + async def fetch_child_blocks(self, block_id: str) -> None: """Fetch child blocks recursively for a given block ID. Args: block_id (str): The ID of the parent block. 
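# Minimal sketch of the cursor pagination that NotionClient.fetch_results and
# async_iterate_paginated_api (typed above) implement: call the endpoint again with
# the cursor from the previous page until "has_more" is false. The fake_endpoint
# below simulates a two-page Notion-style response; it is not the real
# notion_client API.
import asyncio
from typing import Any, Dict, Optional

PAGES = [
    {"results": [{"id": "a"}], "has_more": True, "next_cursor": "c1"},
    {"results": [{"id": "b"}], "has_more": False, "next_cursor": None},
]


async def fake_endpoint(start_cursor: Optional[str] = None) -> Dict[str, Any]:
    return PAGES[0] if start_cursor is None else PAGES[1]


async def iterate_pages(function) -> None:
    next_cursor = None
    while True:
        response = await function(start_cursor=next_cursor)
        for item in response["results"]:
            print(item["id"])
        next_cursor = response.get("next_cursor")
        if not response["has_more"] or next_cursor is None:
            return


asyncio.run(iterate_pages(fake_endpoint))  # prints "a" then "b"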
@@ -220,7 +240,7 @@ async def fetch_comments(self, block_id): ): yield block_comment - async def query_database(self, database_id, body=None): + async def query_database(self, database_id: str, body: None = None) -> None: if body is None: body = {} async for result in self.async_iterate_paginated_api( @@ -278,11 +298,30 @@ class NotionAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=RULES_OBJECT_SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "NotionDataSource") -> None: self.source = source self._logger = logger - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: Dict[ + str, + Union[ + List[ + Dict[ + str, + Union[ + Dict[str, List[Dict[str, Union[Dict[str, str], str]]]], str + ], + ] + ], + List[Dict[str, Union[Dict[str, str], str]]], + List[Dict[str, str]], + List[Dict[str, Dict[str, str]]], + List[Dict[str, Union[Dict[str, Union[Dict[str, str], str]], str]]], + ], + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -291,7 +330,26 @@ async def validate(self, advanced_rules): self._logger.info("Remote validation started") return await self._remote_validation(advanced_rules) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, + advanced_rules: Dict[ + str, + Union[ + List[ + Dict[ + str, + Union[ + Dict[str, List[Dict[str, Union[Dict[str, str], str]]]], str + ], + ] + ], + List[Dict[str, Union[Dict[str, str], str]]], + List[Dict[str, str]], + List[Dict[str, Dict[str, str]]], + List[Dict[str, Union[Dict[str, Union[Dict[str, str], str]], str]]], + ], + ], + ) -> SyncRuleValidationResult: try: NotionAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -358,7 +416,7 @@ class NotionDataSource(BaseDataSource): advanced_rules_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the Notion instance. 
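# Sketch of the fastjsonschema pattern NotionAdvancedRulesValidator relies on: the
# schema is compiled once at class definition time, then the compiled validator is
# called and JsonSchemaValueException is turned into an invalid
# SyncRuleValidationResult. Schema and payloads below are simplified stand-ins for
# the connector's rule definitions.
import fastjsonschema

SCHEMA = fastjsonschema.compile(
    definition={
        "type": "object",
        "properties": {"searches": {"type": "array"}},
        "additionalProperties": False,
    }
)

SCHEMA({"searches": []})  # valid: returns the (possibly defaulted) data silently
try:
    SCHEMA({"unexpected": True})  # violates additionalProperties
except fastjsonschema.JsonSchemaValueException as e:
    print(f"invalid advanced rules: {e.message}")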
Args: @@ -375,10 +433,10 @@ def __init__(self, configuration): self._sleeps = CancellableSleeps() self.concurrent_downloads = self.configuration["concurrent_downloads"] - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.notion_client.set_logger(self._logger) - async def ping(self): + async def ping(self) -> None: try: await self.notion_client.fetch_owner() self._logger.info("Successfully connected to Notion.") @@ -386,10 +444,20 @@ async def ping(self): self._logger.exception("Error while connecting to Notion.") raise - async def close(self): + async def close(self) -> None: await self.notion_client.close() - async def get_entities(self, entity_type, entity_titles): + async def get_entities( + self, entity_type: str, entity_titles: List[str] + ) -> List[ + Dict[ + str, + Union[ + Dict[str, Dict[str, List[Dict[str, Dict[str, str]]]]], + List[Dict[str, str]], + ], + ] + ]: """Search for a database or page with the given title.""" invalid_titles = [] found_titles = set() @@ -435,7 +503,7 @@ async def get_entities(self, entity_type, entity_titles): raise ConfigurableFieldValueError(msg) return exact_match_results - async def validate_config(self): + async def validate_config(self) -> None: """Validates if user configured databases and pages are available in notion.""" await super().validate_config() await asyncio.gather( @@ -444,7 +512,11 @@ async def validate_config(self): ) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, Union[Dict[str, Union[List[str], int, str]], Dict[str, Union[int, str]]] + ]: """Get the default configuration for Notion. Returns: dict: Default configuration. @@ -491,10 +563,10 @@ def get_default_configuration(cls): }, } - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[NotionAdvancedRulesValidator]: return [NotionAdvancedRulesValidator(self)] - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options) -> None: """Tweak bulk options as per concurrent downloads support by Notion Args: @@ -503,14 +575,24 @@ def tweak_bulk_options(self, options): options["concurrent_downloads"] = self.concurrent_downloads - async def get_file_metadata(self, attachment_metadata, file_url): + async def get_file_metadata( + self, attachment_metadata: Dict[Any, Any], file_url: str + ) -> Dict[str, Union[int, str]]: response = await anext(self.notion_client.get_via_session(url=file_url)) attachment_metadata["extension"] = "." + response.url.path.split(".")[-1] attachment_metadata["size"] = response.content_length attachment_metadata["name"] = unquote(response.url.path.split("/")[-1]) return attachment_metadata - async def get_content(self, attachment, file_url, timestamp=None, doit=False): + async def get_content( + self, + attachment: Dict[ + str, Union[str, Dict[str, str], Dict[str, Union[Dict[str, str], str]]] + ], + file_url: Optional[str], + timestamp: None = None, + doit: bool = False, + ) -> Optional[MagicMock]: """Extracts the content for Apache TIKA supported file types. Args: @@ -551,7 +633,7 @@ async def get_content(self, attachment, file_url, timestamp=None, doit=False): ), ) - def _format_doc(self, data): + def _format_doc(self, data: Dict[str, Any]) -> Dict[str, Any]: """Format document for handling empty values & type casting. 
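# Worked example of the URL handling in get_file_metadata above: the attachment
# extension and name are derived from the (percent-encoded) path of the final
# download URL. The URL is a made-up example.
from urllib.parse import unquote, urlparse

path = urlparse("https://files.example.com/docs/Quarterly%20Report.pdf").path
extension = "." + path.split(".")[-1]
name = unquote(path.split("/")[-1])
assert extension == ".pdf"
assert name == "Quarterly Report.pdf"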
Args: @@ -570,7 +652,7 @@ def _format_doc(self, data): del data["properties"] return data - def generate_query(self): + def generate_query(self) -> Iterator[Dict[str, Union[Dict[str, str], str]]]: if self.pages == ["*"] and self.databases == ["*"]: yield {} else: @@ -585,7 +667,46 @@ def generate_query(self): "filter": {"value": "database", "property": "object"}, } - def is_connected_property_block(self, page_database): + def is_connected_property_block( + self, + page_database: Dict[ + str, + Union[ + str, + Dict[str, str], + Dict[ + str, + Union[ + Dict[ + str, Union[str, List[Dict[str, Union[Dict[str, str], str]]]] + ], + Dict[str, Union[str, List[Dict[str, str]]]], + ], + ], + Dict[ + str, + Dict[ + str, + Union[ + str, + List[ + Dict[ + str, + Optional[ + Union[ + str, + Dict[str, Optional[str]], + Dict[str, Union[bool, str]], + ] + ], + ] + ], + ], + ], + ], + ], + ], + ) -> bool: properties = page_database.get("properties") if properties is None: return False diff --git a/connectors/sources/onedrive.py b/connectors/sources/onedrive.py index bbb433c38..e7a19d72b 100644 --- a/connectors/sources/onedrive.py +++ b/connectors/sources/onedrive.py @@ -8,12 +8,26 @@ import asyncio import json import os +from _asyncio import Future, Task +from asyncio.tasks import _GatheringFuture from datetime import datetime, timedelta from functools import cached_property, partial +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + List, + Optional, + Tuple, + Union, +) from urllib import parse import aiohttp import fastjsonschema +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ( ClientPayloadError, ClientResponseError, @@ -31,7 +45,7 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import ( CacheWithTimeout, CancellableSleeps, @@ -57,7 +71,7 @@ BATCH = "batch" ITEM_FIELDS = "id,name,lastModifiedDateTime,content.downloadUrl,createdDateTime,size,webUrl,parentReference,file,folder" -ENDPOINTS = { +ENDPOINTS: Dict[str, str] = { PING: "drives", USERS: "users", GROUPS: "users/{user_id}/transitiveMemberOf", @@ -75,27 +89,27 @@ ) logger.warning("IT'S SUPPOSED TO BE USED ONLY FOR TESTING") logger.warning("x" * 50) - override_url = os.environ["OVERRIDE_URL"] - BASE_URL = override_url - GRAPH_API_AUTH_URL = override_url + override_url: str = os.environ["OVERRIDE_URL"] + BASE_URL: str = override_url + GRAPH_API_AUTH_URL: str = override_url else: - BASE_URL = "https://graph.microsoft.com/v1.0/" - GRAPH_API_AUTH_URL = "https://login.microsoftonline.com" + BASE_URL: str = "https://graph.microsoft.com/v1.0/" + GRAPH_API_AUTH_URL: str = "https://login.microsoftonline.com" -def _prefix_email(email): +def _prefix_email(email: str) -> Optional[str]: return prefix_identity("email", email) -def _prefix_user(user): +def _prefix_user(user: str) -> Optional[str]: return prefix_identity("user", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: str) -> Optional[str]: return prefix_identity("user_id", user_id) -def _prefix_group(group): +def _prefix_group(group: str) -> Optional[str]: return prefix_identity("group", group) @@ -142,10 +156,18 @@ class OneDriveAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA_DEFINITION = {"type": "array", "items": RULES_OBJECT_SCHEMA_DEFINITION} SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: 
"OneDriveDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: Union[ + List[Dict[str, List[Union[str, List[str]]]]], + List[Union[Dict[str, List[str]], Dict[str, Union[str, List[str]]]]], + List[Dict[str, Union[str, List[str]]]], + Dict[str, List[str]], + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -168,7 +190,7 @@ async def validate(self, advanced_rules): class AccessToken: """Class for handling access token for Microsoft Graph APIs""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self.tenant_id = configuration["tenant_id"] self.client_id = configuration["client_id"] self.client_secret = configuration["client_secret"] @@ -206,7 +228,7 @@ async def get(self): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _set_access_token(self): + async def _set_access_token(self) -> None: """Generate access token with configuration fields and stores it in the cache""" url = f"{GRAPH_API_AUTH_URL}/{self.tenant_id}/oauth2/v2.0/token" data = { @@ -230,18 +252,18 @@ async def _set_access_token(self): class OneDriveClient: """Client Class for API calls to OneDrive""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self.retry_count = self.configuration["retry_count"] self._logger = logger self.token = AccessToken(configuration=configuration) - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def session(self): + def session(self) -> ClientSession: """Generate base client session with configuration fields Returns: ClientSession: Base client session @@ -258,7 +280,7 @@ def session(self): }, ) - async def close_session(self): + async def close_session(self) -> None: self._sleeps.cancel() await self.session.close() del self.session @@ -269,7 +291,7 @@ async def close_session(self): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=NotFound, ) - async def get(self, url, header=None): + async def get(self, url: str, header: None = None) -> Iterator[Task]: access_token = await self.token.get() headers = {"authorization": f"Bearer {access_token}"} if header: @@ -289,7 +311,7 @@ async def get(self, url, header=None): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=NotFound, ) - async def post(self, url, payload=None): + async def post(self, url: str, payload: None = None) -> Iterator[Task]: access_token = await self.token.get() headers = { "authorization": f"Bearer {access_token}", @@ -314,7 +336,9 @@ async def post(self, url, payload=None): await self._sleeps.sleep(retry_seconds) raise - async def _handle_client_side_errors(self, e): + async def _handle_client_side_errors( + self, e: ClientResponseError + ) -> Iterator[Task]: if e.status == 429 or e.status == 503: response_headers = e.headers or {} retry_seconds = DEFAULT_RETRY_SECONDS @@ -341,7 +365,11 @@ async def _handle_client_side_errors(self, e): raise async def paginated_api_call( - self, url, params=None, fetch_size=FETCH_SIZE, header=None + self, + url, + params: Optional[str] = None, + fetch_size: int = FETCH_SIZE, + header=None, ): if params is None: params = {} @@ -369,7 +397,7 @@ async def paginated_api_call( ) break - async def list_users(self, 
include_groups=False): + async def list_users(self, include_groups: bool = False): header = None params = { "$filter": "accountEnabled eq true", @@ -403,7 +431,9 @@ async def list_file_permission(self, user_id, file_id): for permission_detail in response: yield permission_detail - async def get_owned_files(self, user_id, skipped_extensions=None, pattern=""): + async def get_owned_files( + self, user_id, skipped_extensions=None, pattern: str = "" + ): params = {"$select": ITEM_FIELDS} delta_endpoint = ENDPOINTS[DELTA].format(user_id=user_id) @@ -431,7 +461,7 @@ class OneDriveDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to OneDrive Args: @@ -442,14 +472,18 @@ def __init__(self, configuration): self.concurrent_downloads = self.configuration["concurrent_downloads"] @cached_property - def client(self): + def client(self) -> OneDriveClient: return OneDriveClient(self.configuration) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, Union[Dict[str, Union[List[str], int, str]], Dict[str, Union[int, str]]] + ]: """Get the default configuration for OneDrive Returns: @@ -509,7 +543,7 @@ def get_default_configuration(cls): }, } - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options) -> None: """Tweak bulk options as per concurrent downloads support by ServiceNow Args: @@ -518,14 +552,14 @@ def tweak_bulk_options(self, options): options["concurrent_downloads"] = self.concurrent_downloads - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[OneDriveAdvancedRulesValidator]: return [OneDriveAdvancedRulesValidator(self)] - async def close(self): + async def close(self) -> None: """Closes unclosed client session""" await self.client.close_session() - async def ping(self): + async def ping(self) -> None: """Verify the connection with OneDrive""" try: url = parse.urljoin(BASE_URL, ENDPOINTS[PING]) @@ -535,7 +569,13 @@ async def ping(self): self._logger.exception("Error while connecting to OneDrive") raise - async def get_content(self, file, download_url, timestamp=None, doit=False): + async def get_content( + self, + file: Dict[str, Union[str, int]], + download_url: str, + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: """Extracts the content for allowed file types. 
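# Sketch of the Microsoft Graph request behind OneDriveClient.list_users above: the
# users endpoint is joined onto BASE_URL and filtered to enabled accounts; the
# client attaches the OAuth bearer token and pages through results elsewhere in
# this module. The query below is illustrative, not the exact request the
# connector sends.
from urllib import parse

BASE_URL = "https://graph.microsoft.com/v1.0/"
url = parse.urljoin(BASE_URL, "users")
query = parse.urlencode({"$filter": "accountEnabled eq true"})
print(f"{url}?{query}")
# https://graph.microsoft.com/v1.0/users?%24filter=accountEnabled+eq+true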
Args: @@ -572,7 +612,9 @@ async def get_content(self, file, download_url, timestamp=None, doit=False): ), ) - def prepare_doc(self, file): + def prepare_doc( + self, file: Dict[str, Optional[Union[str, Dict[str, str], int, List[str]]]] + ) -> Dict[str, Any]: file_info = file.get("file", {}) or {} modified_document = { @@ -589,7 +631,7 @@ def prepare_doc(self, file): modified_document[ACCESS_CONTROL] = file[ACCESS_CONTROL] return modified_document - def _dls_enabled(self): + def _dls_enabled(self) -> bool: if self._features is None: return False @@ -598,7 +640,9 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - async def _decorate_with_access_control(self, document, user_id): + async def _decorate_with_access_control( + self, document: Dict[str, Union[str, Dict[str, str], int]], user_id: None + ) -> Dict[str, Union[str, Dict[str, str], int, List[str]]]: if self._dls_enabled(): entity_permissions = await self.get_entity_permission( user_id=user_id, file_id=document.get("id") @@ -608,7 +652,9 @@ async def _decorate_with_access_control(self, document, user_id): ) return document - async def _user_access_control_doc(self, user): + async def _user_access_control_doc( + self, user: Dict[str, Union[str, List[Dict[str, str]]]] + ) -> Dict[str, Any]: email = user.get("mail") username = user.get("userPrincipalName") @@ -638,7 +684,7 @@ async def _user_access_control_doc(self, user): "created_at": user.get("createdDateTime", iso_utc()), } | es_access_control_query(access_control) - async def get_access_control(self): + async def get_access_control(self) -> None: if not self._dls_enabled(): self._logger.warning("DLS is not enabled. Skipping") return @@ -647,7 +693,9 @@ async def get_access_control(self): async for user in self.client.list_users(include_groups=True): yield await self._user_access_control_doc(user=user) - async def get_entity_permission(self, user_id, file_id): + async def get_entity_permission( + self, user_id: Optional[str], file_id: str + ) -> List[str]: if not self._dls_enabled(): return [] @@ -691,15 +739,19 @@ async def get_entity_permission(self, user_id, file_id): return permissions - def _prepare_batch(self, request_id, url): + def _prepare_batch(self, request_id: str, url: str) -> Dict[str, Any]: return {"id": str(request_id), "method": "GET", "url": url, "retry_count": "0"} - def pop_batch_requests(self, batched_apis): + def pop_batch_requests( + self, batched_apis: List[Dict[str, str]] + ) -> List[Dict[str, str]]: batch = batched_apis[: min(GRAPH_API_MAX_BATCH_SIZE, len(batched_apis))] batched_apis[:] = batched_apis[len(batch) :] return batch - def lookup_request_by_id(self, requests, response_id): + def lookup_request_by_id( + self, requests: List[Dict[str, str]], response_id: str + ) -> Dict[str, str]: for request in requests: if request.get("id") == response_id: return request @@ -738,7 +790,14 @@ async def json_batching(self, batched_apis): else: batched_apis.append(request) - def send_document_to_es(self, entity, download_url): + def send_document_to_es( + self, entity: Dict[str, Any], download_url: Optional[str] + ) -> Union[ + Tuple[Dict[str, Optional[Union[str, int, List[str]]]], None], + Tuple[Dict[str, Optional[Union[str, int]]], None], + Tuple[Dict[str, Union[str, int, List[str]]], partial], + Tuple[Dict[str, Union[str, int]], partial], + ]: entity = self.prepare_doc(entity) if entity["type"] == FILE and download_url: @@ -747,8 +806,16 @@ def send_document_to_es(self, entity, download_url): return entity, None async def 
_bounbed_concurrent_tasks( - self, items, max_concurrency, calling_func, **kwargs - ): + self, + items: List[Dict[str, Union[str, Dict[str, str], int]]], + max_concurrency: int, + calling_func: Callable, + **kwargs, + ) -> Generator[ + _GatheringFuture, + None, + List[Dict[str, Union[str, Dict[str, str], int, List[str]]]], + ]: async def process_item(item, semaphore): async with semaphore: return await calling_func(item, **kwargs) @@ -759,7 +826,9 @@ async def process_item(item, semaphore): return await asyncio.gather(*tasks) - def build_owned_files_url(self, user): + def build_owned_files_url( + self, user: Dict[str, Union[str, List[Dict[str, str]]]] + ) -> Dict[str, Any]: user_id = user.get("id") files_uri = f"{ENDPOINTS[DELTA].format(user_id=user_id)}?$select={ITEM_FIELDS}" diff --git a/connectors/sources/oracle.py b/connectors/sources/oracle.py index f6a3459f6..ec67c9ee1 100644 --- a/connectors/sources/oracle.py +++ b/connectors/sources/oracle.py @@ -7,14 +7,17 @@ import asyncio import os +from _asyncio import Future from functools import cached_property, partial +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from urllib.parse import quote from asyncpg.exceptions._base import InternalClientError -from sqlalchemy import create_engine, text +from sqlalchemy import CursorResult, create_engine, text from sqlalchemy.exc import ProgrammingError -from connectors.source import BaseDataSource +from connectors.logger import ExtraLogger +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.sources.generic_database import ( DEFAULT_FETCH_SIZE, DEFAULT_RETRY_COUNT, @@ -35,33 +38,33 @@ class OracleQueries(Queries): """Class contains methods which return query""" - def ping(self): + def ping(self) -> str: """Query to ping source""" return "SELECT 1+1 FROM DUAL" - def all_tables(self, **kwargs): + def all_tables(self, **kwargs) -> str: """Query to get all tables""" return ( f"SELECT TABLE_NAME FROM all_tables where OWNER = UPPER('{kwargs['user']}')" ) - def table_primary_key(self, **kwargs): + def table_primary_key(self, **kwargs) -> str: """Query to get the primary key""" return f"SELECT cols.column_name FROM all_constraints cons, all_cons_columns cols WHERE cols.table_name = '{kwargs['table']}' AND cons.constraint_type = 'P' AND cons.constraint_name = cols.constraint_name AND cons.owner = UPPER('{kwargs['user']}') AND cons.owner = cols.owner ORDER BY cols.table_name, cols.position" - def table_data(self, **kwargs): + def table_data(self, **kwargs) -> str: """Query to get the table data""" return f"SELECT * FROM {kwargs['table']}" - def table_last_update_time(self, **kwargs): + def table_last_update_time(self, **kwargs) -> str: """Query to get the last update time of the table""" return f"SELECT SCN_TO_TIMESTAMP(MAX(ora_rowscn)) from {kwargs['table']}" - def table_data_count(self, **kwargs): + def table_data_count(self, **kwargs) -> str: """Query to get the number of rows in the table""" return f"SELECT COUNT(*) FROM {kwargs['table']}" - def all_schemas(self): + def all_schemas(self) -> None: """Query to get all schemas of database""" pass # Multiple schemas not supported in Oracle @@ -69,21 +72,21 @@ def all_schemas(self): class OracleClient: def __init__( self, - host, - port, - user, - password, - connection_source, - sid, - service_name, - tables, - protocol, - oracle_home, - wallet_config, - logger_, - retry_count=DEFAULT_RETRY_COUNT, - fetch_size=DEFAULT_FETCH_SIZE, - ): + host: str, + port: Optional[int], + user: str, + password: str, + 
connection_source: str, + sid: str, + service_name: str, + tables: Union[str, List[str]], + protocol: str, + oracle_home: str, + wallet_config: str, + logger_: Optional[ExtraLogger], + retry_count: int = DEFAULT_RETRY_COUNT, + fetch_size: int = DEFAULT_FETCH_SIZE, + ) -> None: self.host = host self.port = port self.user = user @@ -102,10 +105,10 @@ def __init__( self.queries = OracleQueries() self._logger = logger_ - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def close(self): + def close(self) -> None: if self.connection is not None: self.connection.close() @@ -131,7 +134,7 @@ def engine(self): else: return create_engine(connection_string) - async def get_cursor(self, query): + async def get_cursor(self, query: str) -> CursorResult[Any]: """Executes the passed query on the Non-Async supported Database server and return cursor. Args: @@ -159,7 +162,7 @@ async def get_cursor(self, query): ) raise - async def ping(self): + async def ping(self) -> Generator[Future, None, Tuple[int, str]]: return await anext( fetch( cursor_func=partial(self.get_cursor, self.queries.ping()), @@ -190,7 +193,7 @@ async def get_tables_to_fetch(self): for table in tables: yield table - async def get_table_row_count(self, table): + async def get_table_row_count(self, table: str) -> Generator[Future, None, int]: [row_count] = await anext( fetch( cursor_func=partial( @@ -205,7 +208,9 @@ async def get_table_row_count(self, table): ) return row_count - async def get_table_primary_key(self, table): + async def get_table_primary_key( + self, table: str + ) -> Generator[Future, None, List[str]]: self._logger.debug(f"Extracting primary keys for table '{table}'") primary_keys = [ key @@ -224,7 +229,9 @@ async def get_table_primary_key(self, table): self._logger.debug(f"Found primary keys for table '{table}'") return primary_keys - async def get_table_last_update_time(self, table): + async def get_table_last_update_time( + self, table: str + ) -> Generator[Future, None, str]: self._logger.debug(f"Fetching last updated time for table '{table}'") [last_update_time] = await anext( fetch( @@ -276,7 +283,7 @@ class OracleDataSource(BaseDataSource): name = "Oracle Database" service_type = "oracle" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup connection to the Oracle database-server configured by user Args: @@ -305,11 +312,13 @@ def __init__(self, configuration): logger_=self._logger, ) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.oracle_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[str, Dict[str, Union[str, int, bool, List[Dict[str, str]], List[str]]]]: return { "host": { "label": "Host", @@ -415,10 +424,10 @@ def get_default_configuration(cls): }, } - async def close(self): + async def close(self) -> None: self.oracle_client.close() - async def ping(self): + async def ping(self) -> None: """Verify the connection with the database-server configured by user""" self._logger.debug("Validating that the Connector can connect to Oracle...") try: diff --git a/connectors/sources/outlook.py b/connectors/sources/outlook.py index 1f3573fb0..7b310e113 100644 --- a/connectors/sources/outlook.py +++ b/connectors/sources/outlook.py @@ -7,15 +7,18 @@ import asyncio import os +from _asyncio import Task from copy import copy from datetime import date from functools import cached_property, partial +from 
typing import Any, Dict, Iterator, List, Optional, Union import aiofiles import aiohttp import exchangelib import requests.adapters from aiofiles.os import remove +from aiohttp.client import ClientSession from exchangelib import ( IMPERSONATION, OAUTH2, @@ -35,7 +38,7 @@ prefix_identity, ) from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import ( CancellableSleeps, RetryStrategy, @@ -50,7 +53,7 @@ RETRIES = 3 RETRY_INTERVAL = 2 -QUEUE_MEM_SIZE = 5 * 1024 * 1024 # Size in Megabytes +QUEUE_MEM_SIZE: int = 5 * 1024 * 1024 # Size in Megabytes OUTLOOK_SERVER = "outlook_server" OUTLOOK_CLOUD = "outlook_cloud" @@ -73,7 +76,7 @@ ) SEARCH_FILTER_FOR_ADMIN = "(&(objectClass=person)(|(cn=*admin*)(cn=*normal*)))" -MAIL_TYPES = [ +MAIL_TYPES: List[Dict[str, str]] = [ { "folder": "inbox", "constant": INBOX_MAIL_OBJECT, @@ -147,7 +150,7 @@ CERT_FILE = "outlook_cert.cer" -def ews_format_to_datetime(source_datetime, timezone): +def ews_format_to_datetime(source_datetime: str, timezone: str) -> str: """Change datetime format to user account timezone Args: datetime: Datetime in UTC format @@ -167,19 +170,19 @@ def ews_format_to_datetime(source_datetime, timezone): return source_datetime -def _prefix_email(email): +def _prefix_email(email: str) -> Optional[str]: return prefix_identity("email", email) -def _prefix_display_name(user): +def _prefix_display_name(user: str) -> Optional[str]: return prefix_identity("name", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: str) -> Optional[str]: return prefix_identity("user_id", user_id) -def _prefix_job(job_title): +def _prefix_job(job_title: str) -> Optional[str]: return prefix_identity("job_title", job_title) @@ -214,14 +217,14 @@ class SSLFailed(Exception): class ManageCertificate: - async def store_certificate(self, certificate): + async def store_certificate(self, certificate: str) -> None: async with aiofiles.open(CERT_FILE, "w") as file: await file.write(certificate) - def get_certificate_path(self): + def get_certificate_path(self) -> str: return os.path.join(os.getcwd(), CERT_FILE) - async def remove_certificate_file(self): + async def remove_certificate_file(self) -> None: if os.path.exists(CERT_FILE): await remove(CERT_FILE) @@ -229,7 +232,7 @@ async def remove_certificate_file(self): class RootCAAdapter(requests.adapters.HTTPAdapter): """Class to verify SSL Certificate for Exchange Servers""" - def cert_verify(self, conn, url, verify, cert): + def cert_verify(self, conn, url, verify, cert) -> None: try: super().cert_verify( conn=conn, @@ -246,8 +249,15 @@ class ExchangeUsers: """Fetch users from Exchange Active Directory""" def __init__( - self, ad_server, domain, exchange_server, user, password, ssl_enabled, ssl_ca - ): + self, + ad_server: str, + domain: str, + exchange_server: str, + user: str, + password: str, + ssl_enabled: bool, + ssl_ca: str, + ) -> None: self.ad_server = Server(host=ad_server) self.domain = domain self.exchange_server = exchange_server @@ -257,7 +267,7 @@ def __init__( self.ssl_ca = ssl_ca @cached_property - def _create_connection(self): + def _create_connection(self) -> Connection: return Connection( server=self.ad_server, user=self.user, @@ -266,10 +276,10 @@ def _create_connection(self): auto_bind=True, # pyright: ignore ) - async def close(self): + async def close(self) -> None: await ManageCertificate().remove_certificate_file() - def _fetch_normal_users(self, search_query): + def 
_fetch_normal_users(self, search_query: str) -> Iterator[Dict[str, str]]: try: has_value_for_normal_users, _, response, _ = self._create_connection.search( search_query, @@ -288,7 +298,7 @@ def _fetch_normal_users(self, search_query): msg = f"Something went wrong while fetching users. Error: {e}" raise UsersFetchFailed(msg) from e - def _fetch_admin_users(self, search_query): + def _fetch_admin_users(self, search_query: str) -> Iterator[str]: try: ( has_value_for_admin_users, @@ -352,16 +362,16 @@ async def get_user_accounts(self): class Office365Users: """Fetch users from Office365 Active Directory""" - def __init__(self, client_id, client_secret, tenant_id): + def __init__(self, client_id: str, client_secret: str, tenant_id: str) -> None: self.tenant_id = tenant_id self.client_id = client_id self.client_secret = client_secret @cached_property - def _get_session(self): + def _get_session(self) -> ClientSession: return aiohttp.ClientSession(raise_for_status=True) - async def close(self): + async def close(self) -> None: await self._get_session.close() del self._get_session @@ -388,7 +398,7 @@ def _check_errors(self, response): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=UnauthorizedException, ) - async def _fetch_token(self): + async def _fetch_token(self) -> str: try: async with self._get_session.post( url=f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token", @@ -409,7 +419,7 @@ async def _fetch_token(self): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def get_users(self): + async def get_users(self) -> Iterator[Task]: access_token = await self._fetch_token() filter_ = url_encode("accountEnabled eq true") url = f"https://graph.microsoft.com/v1.0/users?$top={TOP}&$filter={filter_}" @@ -588,7 +598,7 @@ def attachment_doc_formatter(self, attachment, attachment_type, timezone): class OutlookClient: """Outlook client to handle API calls made to Outlook""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger @@ -601,11 +611,11 @@ def __init__(self, configuration): else: self.ssl_ca = "" - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def _get_user_instance(self): + def _get_user_instance(self) -> Union[ExchangeUsers, Office365Users]: if self.is_cloud: return Office365Users( client_id=self.configuration["client_id"], @@ -628,7 +638,7 @@ async def _fetch_all_users(self): async for user in self._get_user_instance.get_users(): yield user - async def ping(self): + async def ping(self) -> None: await anext(self._get_user_instance.get_users()) async def get_mails(self, account): @@ -681,7 +691,7 @@ class OutlookDataSource(BaseDataSource): incremental_sync_enabled = True dls_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the Outlook Args: @@ -693,14 +703,29 @@ def __init__(self, configuration): self.doc_formatter = OutlookDocFormatter() @cached_property - def client(self): + def client(self) -> OutlookClient: return OutlookClient(configuration=self.configuration) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, 
Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[ + str, + Union[ + List[Union[Dict[str, str], Dict[str, Union[bool, str]]]], int, str + ], + ], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for Outlook Returns: @@ -807,7 +832,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. This method checks whether document level security (DLS) is enabled based on the provided configuration. Returns: @@ -833,7 +858,7 @@ async def get_access_control(self): elif users.get("attributes", {}).get("mail"): yield await self._user_access_control_doc_for_server(users=users) - async def _user_access_control_doc(self, user): + async def _user_access_control_doc(self, user: Dict[str, str]) -> Dict[str, Any]: user_id = user.get("id", "") display_name = user.get("displayName", "") user_email = user.get("mail", "") @@ -856,7 +881,9 @@ async def _user_access_control_doc(self, user): access_control=[_prefixed_user_id, _prefixed_display_name, _prefixed_email] ) - async def _user_access_control_doc_for_server(self, users): + async def _user_access_control_doc_for_server( + self, users: Dict[str, Union[Dict[str, str], str]] + ) -> Dict[str, Any]: name_metadata = users.get("dn", "").split("=", 1)[1] display_name = name_metadata.split(",", 1)[0] user_email = users.get("attributes", {}).get("mail") @@ -877,17 +904,23 @@ async def _user_access_control_doc_for_server(self, users): access_control=[_prefixed_user_id, _prefixed_display_name, _prefixed_email] ) - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, + document: Dict[str, Union[str, List[str], int, List[int]]], + access_control: List[str], + ) -> Dict[str, Union[str, List[str], int, List[int]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) ) return document - async def close(self): + async def close(self) -> None: await self.client._get_user_instance.close() - async def get_content(self, attachment, timezone, timestamp=None, doit=False): + async def get_content( + self, attachment, timezone, timestamp=None, doit: bool = False + ): """Extracts the content for allowed file types. 
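# Worked example of the DN handling in _user_access_control_doc_for_server above:
# the display name is the value of the first RDN ("CN=..."), taken from the "dn"
# attribute returned by the Active Directory search. The DN is a made-up example.
dn = "CN=Jane Smith,OU=Users,DC=example,DC=com"
name_metadata = dn.split("=", 1)[1]  # "Jane Smith,OU=Users,DC=example,DC=com"
display_name = name_metadata.split(",", 1)[0]
assert display_name == "Jane Smith"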
Args: @@ -1059,7 +1092,7 @@ async def _enqueue_calendars(self, calendar, child_calendar, timezone, account): ): yield doc - async def ping(self): + async def ping(self) -> None: """Verify the connection with Outlook""" await self.client.ping() self._logger.info("Successfully connected to Outlook") diff --git a/connectors/sources/postgresql.py b/connectors/sources/postgresql.py index 28be15d46..bb8f4474a 100644 --- a/connectors/sources/postgresql.py +++ b/connectors/sources/postgresql.py @@ -6,7 +6,9 @@ """Postgresql source module is responsible to fetch documents from PostgreSQL.""" import ssl +from _asyncio import Task from functools import cached_property, partial +from typing import Any, Dict, Generator, List, Optional, Sized, Tuple, Union from urllib.parse import quote import fastjsonschema @@ -15,12 +17,14 @@ from sqlalchemy import text from sqlalchemy.exc import ProgrammingError from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from connectors.filtering.validation import ( AdvancedRulesValidator, SyncRuleValidationResult, ) -from connectors.source import BaseDataSource +from connectors.logger import ExtraLogger +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.sources.generic_database import ( DEFAULT_FETCH_SIZE, DEFAULT_RETRY_COUNT, @@ -45,15 +49,15 @@ class PostgreSQLQueries(Queries): """Class contains methods which return query""" - def ping(self): + def ping(self) -> str: """Query to ping source""" return "SELECT 1+1" - def all_tables(self, **kwargs): + def all_tables(self, **kwargs) -> str: """Query to get all tables""" return f"SELECT table_name FROM information_schema.tables WHERE table_catalog = '{kwargs['database']}' and table_schema = '{kwargs['schema']}'" - def table_primary_key(self, **kwargs): + def table_primary_key(self, **kwargs) -> str: """Query to get the primary key""" return ( f"SELECT a.attname AS c " @@ -67,19 +71,19 @@ def table_primary_key(self, **kwargs): f"ORDER BY array_position(i.indkey, a.attnum)" ) - def table_data(self, **kwargs): + def table_data(self, **kwargs) -> str: """Query to get the table data""" return f'SELECT * FROM "{kwargs["schema"]}"."{kwargs["table"]}" ORDER BY {kwargs["columns"]} LIMIT {kwargs["limit"]} OFFSET {kwargs["offset"]}' - def table_last_update_time(self, **kwargs): + def table_last_update_time(self, **kwargs) -> str: """Query to get the last update time of the table""" return f'SELECT MAX(pg_xact_commit_timestamp(xmin)) FROM "{kwargs["schema"]}"."{kwargs["table"]}"' - def table_data_count(self, **kwargs): + def table_data_count(self, **kwargs) -> str: """Query to get the number of rows in the table""" return f'SELECT COUNT(*) FROM "{kwargs["schema"]}"."{kwargs["table"]}"' - def all_schemas(self): + def all_schemas(self) -> None: """Query to get all schemas of database""" pass @@ -100,10 +104,15 @@ class PostgreSQLAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "PostgreSQLDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: Union[ + List[Dict[str, Union[str, List[str]]]], List[Dict[str, str]] + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -116,7 +125,9 @@ async def validate(self, advanced_rules): 
interval=DEFAULT_WAIT_MULTIPLIER, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, advanced_rules: List[Dict[str, Union[str, List[str]]]] + ) -> SyncRuleValidationResult: try: PostgreSQLAdvancedRulesValidator.SCHEMA(advanced_rules) except JsonSchemaValueException as e: @@ -155,19 +166,19 @@ async def _remote_validation(self, advanced_rules): class PostgreSQLClient: def __init__( self, - host, - port, - user, - password, - database, - schema, - tables, - ssl_enabled, - ssl_ca, - logger_, - retry_count=DEFAULT_RETRY_COUNT, - fetch_size=DEFAULT_FETCH_SIZE, - ): + host: str, + port: Union[str, int], + user: str, + password: str, + database: str, + schema: str, + tables: Union[List[str], str], + ssl_enabled: bool, + ssl_ca: str, + logger_: Optional[ExtraLogger], + retry_count: int = DEFAULT_RETRY_COUNT, + fetch_size: int = DEFAULT_FETCH_SIZE, + ) -> None: self.host = host self.port = port self.user = user @@ -184,18 +195,18 @@ def __init__( self.connection = None self._logger = logger_ - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def engine(self): + def engine(self) -> AsyncEngine: connection_string = f"postgresql+asyncpg://{self.user}:{quote(self.password)}@{self.host}:{self.port}/{self.database}" return create_async_engine( connection_string, connect_args=self._get_connect_args(), ) - async def get_cursor(self, query): + async def get_cursor(self, query: str): """Execute the passed query on the Async supported Database server and return cursor. Args: @@ -215,7 +226,7 @@ async def get_cursor(self, query): ) raise - async def ping(self): + async def ping(self) -> Generator[Task, None, Tuple[int]]: return await anext( fetch( cursor_func=partial(self.get_cursor, self.queries.ping()), @@ -224,7 +235,7 @@ async def ping(self): ) ) - async def get_tables_to_fetch(self, is_filtering=False): + async def get_tables_to_fetch(self, is_filtering: bool = False): tables = configured_tables(self.tables) if is_wildcard(tables) or is_filtering: self._logger.info("Fetching all tables") @@ -245,7 +256,7 @@ async def get_tables_to_fetch(self, is_filtering=False): for table in tables: yield table - async def get_table_row_count(self, table): + async def get_table_row_count(self, table: str) -> int: [row_count] = await anext( fetch( cursor_func=partial( @@ -261,7 +272,7 @@ async def get_table_row_count(self, table): ) return row_count - async def get_table_primary_key(self, table): + async def get_table_primary_key(self, table: str) -> List[str]: primary_keys = [ key async for [key] in fetch( @@ -281,7 +292,7 @@ async def get_table_primary_key(self, table): return primary_keys - async def get_table_last_update_time(self, table): + async def get_table_last_update_time(self, table: str) -> str: self._logger.debug(f"Fetching last updated time for table '{table}'") [last_update_time] = await anext( fetch( @@ -371,7 +382,7 @@ async def data_streamer( self._logger.info(f"Found {record_count} records for '{query}' query") - def _get_connect_args(self): + def _get_connect_args(self) -> Dict[Any, Any]: """Convert string to pem format and create an SSL context Returns: @@ -394,7 +405,7 @@ class PostgreSQLDataSource(BaseDataSource): service_type = "postgresql" advanced_rules_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup connection to the PostgreSQL database-server configured by user 
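# Sketch of how PostgreSQLClient.engine above assembles the URL handed to
# create_async_engine: the password is percent-encoded with quote() so reserved
# characters survive the URL, and the asyncpg dialect backs the async engine.
# Credentials are placeholders and the ssl connect_args are omitted here.
from urllib.parse import quote

host, port, user, password, database = "127.0.0.1", 5432, "admin", "p@ss word", "erp"
connection_string = (
    f"postgresql+asyncpg://{user}:{quote(password)}@{host}:{port}/{database}"
)
assert connection_string == "postgresql+asyncpg://admin:p%40ss%20word@127.0.0.1:5432/erp"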
Args: @@ -418,11 +429,16 @@ def __init__(self, configuration): logger_=self._logger, ) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.postgresql_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Dict[str, Union[str, int, bool, List[str], List[Dict[str, Union[bool, str]]]]], + ]: return { "host": { "label": "Host", @@ -498,10 +514,10 @@ def get_default_configuration(cls): }, } - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[PostgreSQLAdvancedRulesValidator]: return [PostgreSQLAdvancedRulesValidator(self)] - async def ping(self): + async def ping(self) -> None: """Verify the connection with the database-server configured by user""" self._logger.debug("Pinging the PostgreSQL instance") try: @@ -510,7 +526,13 @@ async def ping(self): msg = f"Can't connect to Postgresql on {self.postgresql_client.host}." raise Exception(msg) from e - def row2doc(self, row, doc_id, table, timestamp): + def row2doc( + self, + row: Dict[str, Union[str, int]], + doc_id: str, + table: Union[List[str], str], + timestamp: str, + ) -> Dict[str, Union[str, int, List[str]]]: row.update( { "_id": doc_id, @@ -522,7 +544,7 @@ def row2doc(self, row, doc_id, table, timestamp): ) return row - async def get_primary_key(self, tables): + async def get_primary_key(self, tables: List[str]) -> Tuple[List[str], List[str]]: self._logger.debug(f"Extracting primary keys for tables: {tables}") primary_key_columns = [] for table in tables: @@ -666,7 +688,7 @@ async def _yield_all_docs_from_tables(self, table): async def yield_rows_for_query( self, primary_key_columns, - tables, + tables: Optional[Sized], query=None, row_count=None, order_by_columns=None, diff --git a/connectors/sources/redis.py b/connectors/sources/redis.py index 06298aa87..b079c695c 100644 --- a/connectors/sources/redis.py +++ b/connectors/sources/redis.py @@ -8,6 +8,7 @@ from enum import Enum from functools import cached_property from tempfile import NamedTemporaryFile +from typing import Any, Dict, List, Set, Tuple, Union import fastjsonschema import redis.asyncio as redis @@ -17,7 +18,11 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import get_pem_format, hash_id, iso_utc PAGE_SIZE = 1000 @@ -35,7 +40,7 @@ class KeyType(Enum): class RedisClient: """Redis client to handle method calls made to Redis""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self.configuration = configuration self._logger = logger self.host = self.configuration["host"] @@ -51,10 +56,10 @@ def __init__(self, configuration): self.cert_file = "" self.key_file = "" - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def store_ssl_key(self, key, suffix): + def store_ssl_key(self, key: str, suffix: str) -> str: if suffix == ".key": pem_certificates = get_pem_format( key=key, postfix="-----END RSA PRIVATE KEY-----" @@ -65,7 +70,7 @@ def store_ssl_key(self, key, suffix): cert.write(pem_certificates) return cert.name - def remove_temp_files(self): + def remove_temp_files(self) -> None: for file_path in [self.cert_file, self.key_file]: if os.path.exists(file_path): try: @@ -99,12 +104,12 @@ def _client(self): ) 
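# Sketch of the temp-file handling behind store_ssl_key/remove_temp_files above,
# assuming the file is created with delete=False so it persists after the
# with-block: the PEM material is written to a named file the redis client can
# reference by path, and it is removed explicitly once the session is closed.
# The file contents are placeholder text, not a real key.
import os
from tempfile import NamedTemporaryFile

with NamedTemporaryFile(mode="w", suffix=".key", delete=False) as cert:
    cert.write("-----BEGIN RSA PRIVATE KEY-----\nplaceholder\n-----END RSA PRIVATE KEY-----\n")
    path = cert.name

# later, once the client no longer needs the file:
if os.path.exists(path):
    os.remove(path)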
return self._redis_client - async def close(self): + async def close(self) -> None: if self._redis_client: await self._client.aclose() # pyright: ignore self.remove_temp_files() - async def validate_database(self, db): + async def validate_database(self, db: int) -> bool: try: await self._client.execute_command("SELECT", db) return True @@ -112,7 +117,7 @@ async def validate_database(self, db): self._logger.warning(f"Database {db} not found. Error: {exception}") return False - async def get_databases(self): + async def get_databases(self) -> None: """Returns number of databases from config_get response Returns: @@ -149,7 +154,9 @@ async def get_paginated_key(self, db, pattern, type_=None): ): yield key - async def get_key_value(self, key, key_type): + async def get_key_value( + self, key: str, key_type: str + ) -> Union[str, Set[int], Dict[str, str], List[int]]: """Fetch value of key for database. Args: @@ -188,13 +195,13 @@ async def get_key_value(self, key, key_type): ) return "" - async def get_key_metadata(self, key): + async def get_key_metadata(self, key: str) -> Tuple[str, str, int]: key_type = await self._client.type(key) key_value = await self.get_key_value(key=key, key_type=key_type) key_size = await self._client.memory_usage(key) return key_type, key_value, key_size - async def ping(self): + async def ping(self) -> None: await self._client.ping() @@ -221,17 +228,35 @@ class RedisAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "RedisDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: Union[ + List[Dict[str, int]], + Dict[str, int], + Dict[str, List[str]], + Dict[str, Union[str, int]], + List[Dict[str, Union[str, int]]], + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES ) return await self._remote_validation(advanced_rules) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, + advanced_rules: Union[ + List[Dict[str, int]], + Dict[str, int], + Dict[str, List[str]], + Dict[str, Union[str, int]], + List[Dict[str, Union[str, int]]], + ], + ) -> SyncRuleValidationResult: try: RedisAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -265,12 +290,20 @@ class RedisDataSource(BaseDataSource): service_type = "redis" advanced_rules_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self.client = RedisClient(configuration=configuration) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[int, str]], + ], + ]: return { "host": {"label": "Host", "order": 1, "type": "str"}, "port": {"label": "Port", "order": 2, "type": "int"}, @@ -330,16 +363,16 @@ def get_default_configuration(cls): }, } - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(self._logger) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[RedisAdvancedRulesValidator]: return [RedisAdvancedRulesValidator(self)] - async def close(self): + async def close(self) -> None: await self.client.close() - 
async def _remote_validation(self): + async def _remote_validation(self) -> None: """Validate configured databases Raises: ConfigurableFieldValueError: Unavailable services error. @@ -369,13 +402,13 @@ async def _remote_validation(self): if msg: raise ConfigurableFieldValueError(msg) - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields Also validate, if user configured databases are available in Redis.""" await super().validate_config() await self._remote_validation() - async def format_document(self, **kwargs): + async def format_document(self, **kwargs) -> Dict[str, Any]: """Prepare document for database records. Returns: @@ -393,7 +426,7 @@ async def format_document(self, **kwargs): } return document - async def ping(self): + async def ping(self) -> None: try: await self.client.ping() self._logger.info("Successfully connected to Redis.") @@ -401,7 +434,7 @@ async def ping(self): self._logger.exception("Error while connecting to Redis.") raise - async def get_db_records(self, db, pattern="*", type_=None): + async def get_db_records(self, db, pattern: str = "*", type_=None): async for key in self.client.get_paginated_key( db=db, pattern=pattern, type_=type_ ): diff --git a/connectors/sources/s3.py b/connectors/sources/s3.py index 3de3210d3..5f078ce46 100644 --- a/connectors/sources/s3.py +++ b/connectors/sources/s3.py @@ -5,8 +5,11 @@ # import asyncio import os +from _asyncio import Future from contextlib import AsyncExitStack from functools import partial +from typing import Dict, Generator, List, Optional, Union +from unittest.mock import MagicMock import aioboto3 import fastjsonschema @@ -19,7 +22,7 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import hash_id DEFAULT_PAGE_SIZE = 100 @@ -36,7 +39,7 @@ class S3Client: """Amazon S3 client to handle method calls made to S3""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self.configuration = configuration self._logger = logger self.session = aioboto3.Session( @@ -51,10 +54,10 @@ def __init__(self, configuration): self.clients = {} self.client_context = [] - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - async def client(self, region=None): + async def client(self, region: None = None): """This method creates context manager and client session object for s3. Args: region (str): Name of bucket region. Defaults to None @@ -83,17 +86,17 @@ async def client(self, region=None): self.clients[region_name] = s3_client return self.clients[region_name] - async def close_client(self): + async def close_client(self) -> None: """Closes unclosed client session""" for context in self.client_context: await context.aclose() - async def fetch_buckets(self): + async def fetch_buckets(self) -> None: """This method used to list all the buckets from Amazon S3""" s3 = await self.client() await s3.list_buckets() - async def get_bucket_list(self): + async def get_bucket_list(self) -> List[str]: """Returns bucket list from list_buckets response Returns: @@ -142,7 +145,7 @@ async def get_bucket_objects(self, bucket, **kwargs): f"Something went wrong while fetching documents from {bucket}. 
Error: {exception}" ) - async def get_bucket_region(self, bucket_name): + async def get_bucket_region(self, bucket_name: str) -> None: """This method return the name of region for a bucket. Args bucket_name (str): Name of bucket @@ -178,10 +181,12 @@ class S3AdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "S3DataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: Union[Dict[str, List[str]], List[Dict[str, str]]] + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -206,7 +211,7 @@ class S3DataSource(BaseDataSource): service_type = "s3" advanced_rules_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the connection to the Amazon S3. Args: @@ -215,13 +220,13 @@ def __init__(self, configuration): super().__init__(configuration=configuration) self.s3_client = S3Client(configuration=configuration) - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.s3_client.set_logger(self._logger) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[S3AdvancedRulesValidator]: return [S3AdvancedRulesValidator(self)] - async def ping(self): + async def ping(self) -> None: """Verify the connection with AWS""" try: await self.s3_client.fetch_buckets() @@ -305,7 +310,15 @@ async def get_docs(self, filtering=None): ), ) - async def get_content(self, doc, s3_client, timestamp=None, doit=None): + async def get_content( + self, + doc: Dict[str, Union[str, int]], + s3_client: Union[MagicMock, str], + timestamp: None = None, + doit: Optional[Union[bool, int]] = None, + ) -> Generator[ + Future, None, Optional[Union[Dict[str, str], Dict[str, Union[str, bytes]]]] + ]: if not (doit): return @@ -335,12 +348,16 @@ async def get_content(self, doc, s3_client, timestamp=None, doit=None): return document - async def close(self): + async def close(self) -> None: """Closes unclosed client session""" await self.s3_client.close_client() @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, Union[Dict[str, Union[List[str], int, str]], Dict[str, Union[int, str]]] + ]: """Get the default configuration for Amazon S3. 
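In the S3 `get_content` above, the `s3_client` parameter is typed `Union[MagicMock, str]` and the return type embeds `Generator[Future, None, ...]`, which is what trace-based annotation tools tend to record from test runs. A hand-written alternative might use a structural `Protocol` for the client and annotate the coroutine with the awaited value. The sketch below is a style assumption, not the project's API; `get_object`/`ContentLength` appear only to make the toy example runnable.

```python
import asyncio
from typing import Any, Dict, Optional, Protocol


class SupportsGetObject(Protocol):
    """Structural type for the one client call this toy example needs."""

    async def get_object(self, Bucket: str, Key: str) -> Dict[str, Any]: ...


async def get_content(
    doc: Dict[str, Any], s3_client: SupportsGetObject, doit: bool = False
) -> Optional[Dict[str, Any]]:
    # For a coroutine, the return annotation names the awaited value, so
    # `await get_content(...)` type-checks as Optional[Dict[str, Any]].
    if not doit:
        return None
    response = await s3_client.get_object(Bucket=doc["bucket"], Key=doc["filename"])
    return {"_id": doc["id"], "size": response.get("ContentLength", 0)}


class FakeS3:
    """A test double satisfies the Protocol structurally; no mock types needed."""

    async def get_object(self, Bucket: str, Key: str) -> Dict[str, Any]:
        return {"ContentLength": 42}


doc = {"id": "1", "bucket": "demo", "filename": "a.txt"}
print(asyncio.run(get_content(doc, FakeS3(), doit=True)))
```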
Returns: diff --git a/connectors/sources/salesforce.py b/connectors/sources/salesforce.py index 7bbb250df..6cc37795a 100644 --- a/connectors/sources/salesforce.py +++ b/connectors/sources/salesforce.py @@ -7,13 +7,17 @@ import os import re +from _asyncio import Future, Task from datetime import datetime from functools import cached_property, partial from itertools import groupby +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union import aiohttp import fastjsonschema +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError +from aiohttp.client_reqrep import ClientResponse from connectors.access_control import ( ACCESS_CONTROL, @@ -25,7 +29,12 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.protocol.connectors import Filter +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import ( TIKA_SUPPORTED_FILETYPES, CancellableSleeps, @@ -34,8 +43,8 @@ retryable, ) -SALESFORCE_EMULATOR_HOST = os.environ.get("SALESFORCE_EMULATOR_HOST") -RUNNING_FTEST = ( +SALESFORCE_EMULATOR_HOST: Optional[str] = os.environ.get("SALESFORCE_EMULATOR_HOST") +RUNNING_FTEST: bool = ( "RUNNING_FTEST" in os.environ ) # Flag to check if a connector is run for ftest or not. @@ -138,16 +147,16 @@ ] -def _prefix_user(user): +def _prefix_user(user: str) -> Optional[str]: if user: return prefix_identity("user", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: str) -> Optional[str]: return prefix_identity("user_id", user_id) -def _prefix_email(email): +def _prefix_email(email: Optional[str]) -> Optional[str]: return prefix_identity("email", email) @@ -188,7 +197,7 @@ class SalesforceServerError(Exception): class SalesforceClient: - def __init__(self, configuration, base_url): + def __init__(self, configuration: DataSourceConfiguration, base_url: str) -> None: self._logger = logger self._sleeps = CancellableSleeps() @@ -211,24 +220,24 @@ def __init__(self, configuration, base_url): for obj in configuration["custom_objects_to_sync"] ] - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def session(self): + def session(self) -> ClientSession: return aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=None), ) - async def ping(self): + async def ping(self) -> None: await self.session.head(self.base_url) - async def close(self): + async def close(self) -> None: self.api_token.clear() await self.session.close() del self.session - def modify_soql_query(self, query): + def modify_soql_query(self, query: str) -> str: lowered_query = query.lower() match_limit = re.search(r"(?i)(.*)FROM\s+(.*?)(?:LIMIT)(.*)", lowered_query) match_offset = re.search(r"(?i)(.*)FROM\s+(.*?)(?:OFFSET)(.*)", lowered_query) @@ -244,7 +253,7 @@ def modify_soql_query(self, query): return query - def _add_last_modified_date(self, query): + def _add_last_modified_date(self, query: str) -> str: lowered_query = query.lower() if ( not ("fields(all)" in lowered_query or "fields(standard)" in lowered_query) @@ -256,7 +265,7 @@ def _add_last_modified_date(self, query): return query - def _add_id(self, query): + def _add_id(self, query: str) -> str: lowered_query = query.lower() if not ( "fields(all)" in lowered_query or "fields(standard)" in lowered_query @@ -265,7 +274,7 @@ def _add_id(self, query): return query - async 
def get_sync_rules_results(self, rule): + async def get_sync_rules_results(self, rule: Dict[str, str]) -> None: if rule["language"] == "SOQL": query_with_id = self._add_id(query=rule["query"]) query = self._add_last_modified_date(query=query_with_id) @@ -297,7 +306,7 @@ async def _custom_objects(self): custom_objects.append(sobject.get("name")) return custom_objects - async def get_custom_objects(self): + async def get_custom_objects(self) -> None: for custom_object in self.custom_objects_to_sync: query = await self._custom_object_query(custom_object=custom_object) async for records in self._yield_non_bulk_query_pages(query): @@ -322,7 +331,7 @@ async def get_users_with_read_access(self, sobject): for record in records: yield record - async def get_username_by_id(self, user_list): + async def get_username_by_id(self, user_list: Tuple[str]) -> None: query = USERNAME_FROM_IDS.format(user_list=user_list) async for records in self._yield_non_bulk_query_pages(query): for record in records: @@ -334,7 +343,7 @@ async def get_file_access(self, document_id): for record in records: yield record - async def get_accounts(self): + async def get_accounts(self) -> Iterator[Task]: if not await self._is_queryable("Account"): self._logger.warning( "Object Account is not queryable, so they won't be ingested." @@ -346,7 +355,7 @@ async def get_accounts(self): for record in records: yield record - async def get_opportunities(self): + async def get_opportunities(self) -> None: if not await self._is_queryable("Opportunity"): self._logger.warning( "Object Opportunity is not queryable, so they won't be ingested." @@ -358,7 +367,7 @@ async def get_opportunities(self): for record in records: yield record - async def get_contacts(self): + async def get_contacts(self) -> None: if not await self._is_queryable("Contact"): self._logger.warning( "Object Contact is not queryable, so they won't be ingested." @@ -375,7 +384,7 @@ async def get_contacts(self): record["Owner"] = sobjects_by_id["User"].get(record.get("OwnerId"), {}) yield record - async def get_leads(self): + async def get_leads(self) -> None: if not await self._is_queryable("Lead"): self._logger.warning( "Object Lead is not queryable, so they won't be ingested." @@ -399,7 +408,7 @@ async def get_leads(self): yield record - async def get_campaigns(self): + async def get_campaigns(self) -> None: if not await self._is_queryable("Campaign"): self._logger.warning( "Object Campaign is not queryable, so they won't be ingested." @@ -411,7 +420,7 @@ async def get_campaigns(self): for record in records: yield record - async def get_cases(self): + async def get_cases(self) -> None: if not await self._is_queryable("Case"): self._logger.warning( "Object Case is not queryable, so they won't be ingested." 
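Several of the Salesforce `get_*` generators above follow the same shape: check `_is_queryable`, log a warning and stop if the object cannot be queried, otherwise yield records. A bare `return` is how an async generator ends early; a minimal sketch with invented data:

```python
import asyncio
from typing import AsyncIterator, Dict

QUERYABLE = {"account", "contact"}  # assumed data for the example


async def get_records(sobject: str) -> AsyncIterator[Dict[str, str]]:
    if sobject.lower() not in QUERYABLE:
        print(f"Object {sobject} is not queryable, so it won't be ingested.")
        return  # a bare return ends the async generator early
    for i in range(2):
        yield {"Id": f"{sobject}-{i}"}


async def main() -> None:
    async for record in get_records("Opportunity"):
        print(record)  # prints nothing: Opportunity was filtered out
    async for record in get_records("Account"):
        print(record)


asyncio.run(main())
```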
@@ -445,7 +454,7 @@ async def get_cases(self): record["Feeds"] = case_feeds_by_case_id.get(record.get("Id")) yield record - async def get_case_feeds(self, case_ids): + async def get_case_feeds(self, case_ids: List[str]) -> List[Dict[str, Any]]: query = await self._case_feeds_query(case_ids) all_case_feeds = [] async for case_feeds in self._yield_non_bulk_query_pages(query): @@ -453,7 +462,7 @@ async def get_case_feeds(self, case_ids): return all_case_feeds - async def queryable_sobjects(self): + async def queryable_sobjects(self) -> List[str]: """Cached async property""" if self._queryable_sobjects is not None: return self._queryable_sobjects @@ -469,9 +478,9 @@ async def queryable_sobjects(self): async def queryable_sobject_fields( self, - relevant_objects, - relevant_sobject_fields, - ): + relevant_objects: List[str], + relevant_sobject_fields: Optional[List[str]], + ) -> Dict[str, List[str]]: """Cached async property""" objects_to_query = [ obj for obj in relevant_objects if obj not in self._queryable_sobject_fields @@ -520,7 +529,7 @@ async def sobjects_cache_by_type(self): self._sobjects_cache_by_type["User"] = await self._prepare_sobject_cache("User") return self._sobjects_cache_by_type - async def _prepare_sobject_cache(self, sobject): + async def _prepare_sobject_cache(self, sobject: str) -> Dict[str, Dict[str, str]]: if not await self._is_queryable(sobject): self._logger.warning( f"{sobject} is not queryable, so they won't be cached." @@ -545,13 +554,15 @@ async def _prepare_sobject_cache(self, sobject): return sobjects - async def _is_queryable(self, sobject): + async def _is_queryable(self, sobject: str) -> bool: """User settings can cause sobjects to be non-queryable Querying these causes errors, so we try to filter those out in advance """ return sobject.lower() in await self.queryable_sobjects() - async def _select_queryable_fields(self, sobject, fields): + async def _select_queryable_fields( + self, sobject: str, fields: List[str] + ) -> List[str]: """User settings can cause fields to be non-queryable Querying these causes errors, so we try to filter those out in advance """ @@ -569,7 +580,9 @@ async def _select_queryable_fields(self, sobject, fields): return queryable_fields return [f for f in fields if f.lower() in queryable_fields] - async def _yield_non_bulk_query_pages(self, soql_query, endpoint=QUERY_ENDPOINT): + async def _yield_non_bulk_query_pages( + self, soql_query: str, endpoint: str = QUERY_ENDPOINT + ) -> Iterator[Task]: """loops through query response pages and yields lists of records""" url = f"{self.base_url}{endpoint}" params = {"q": soql_query} @@ -624,7 +637,22 @@ async def _yield_sosl_query_pages(self, sosl_query): ) yield response.get("searchRecords", []) - async def _execute_non_paginated_query(self, soql_query): + async def _execute_non_paginated_query( + self, soql_query: None + ) -> List[ + Dict[ + str, + Union[ + Dict[str, str], + str, + Dict[str, Union[str, Dict[str, str]]], + Dict[ + str, Union[int, bool, List[Dict[str, Union[str, Dict[str, str]]]]] + ], + Dict[str, Union[str, int]], + ], + ] + ]: """For quick queries, ignores pagination""" url = f"{self.base_url}{QUERY_ENDPOINT}" params = {"q": soql_query} @@ -634,7 +662,7 @@ async def _execute_non_paginated_query(self, soql_query): ) return response.get("records") - async def _auth_headers(self): + async def _auth_headers(self) -> Dict[str, str]: token = await self.api_token.token() return {"authorization": f"Bearer {token}"} @@ -643,7 +671,9 @@ async def _auth_headers(self): 
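`queryable_sobjects` above is documented as a "Cached async property". `functools.cached_property` cannot wrap an `async def` directly (it would cache the coroutine object, which is awaitable only once), so the result is memoized by hand on the instance. A stripped-down sketch of that idiom, with a fake fetch standing in for the HTTP call:

```python
import asyncio
from typing import List, Optional


class Client:
    def __init__(self) -> None:
        self._queryable_sobjects: Optional[List[str]] = None

    async def _fetch_sobjects(self) -> List[str]:
        await asyncio.sleep(0)  # stands in for the real HTTP round trip
        print("fetching...")
        return ["account", "contact"]

    async def queryable_sobjects(self) -> List[str]:
        # Memoize the awaited result, not the coroutine.
        if self._queryable_sobjects is None:
            self._queryable_sobjects = await self._fetch_sobjects()
        return self._queryable_sobjects


async def main() -> None:
    client = Client()
    await client.queryable_sobjects()  # fetches
    await client.queryable_sobjects()  # served from the cache, no second fetch


asyncio.run(main())
```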
interval=RETRY_INTERVAL, skipped_exceptions=[RateLimitedException, InvalidQueryException], ) - async def _get_json(self, url, params=None): + async def _get_json( + self, url: str, params: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: response_body = None try: response = await self._get(url, params=params) @@ -656,7 +686,9 @@ async def _get_json(self, url, params=None): except Exception as e: raise e - async def _get(self, url, params=None): + async def _get( + self, url: str, params: Optional[Dict[str, str]] = None + ) -> ClientResponse: self._logger.debug(f"Sending request. Url: {url}, params: {params}") headers = await self._auth_headers() return await self.session.get( @@ -665,14 +697,16 @@ async def _get(self, url, params=None): params=params, ) - async def _download(self, content_version_id): + async def _download(self, content_version_id: str): endpoint = CONTENT_VERSION_DOWNLOAD_ENDPOINT.replace( "", content_version_id ) response = await self._get(f"{self.base_url}{endpoint}") yield response - async def _handle_client_response_error(self, response_body, e): + async def _handle_client_response_error( + self, response_body: Optional[List[Dict[str, str]]], e: ClientResponseError + ): exception_details = f"status: {e.status}, message: {e.message}" if e.status == 401: @@ -713,13 +747,15 @@ async def _handle_client_response_error(self, response_body, e): ) raise SalesforceServerError(msg) - def _handle_response_body_error(self, error_list): + def _handle_response_body_error( + self, error_list: Optional[List[Dict[str, str]]] + ) -> List[Dict[str, str]]: if error_list is None or len(error_list) < 1: return [{"errorCode": "unknown"}] return error_list - async def _custom_object_query(self, custom_object): + async def _custom_object_query(self, custom_object: str) -> str: queryable_fields = await self._select_queryable_fields( custom_object, [], @@ -732,7 +768,7 @@ async def _custom_object_query(self, custom_object): .build() ) - async def _user_query(self): + async def _user_query(self) -> str: queryable_fields = await self._select_queryable_fields( "User", ["Name", "Email", "UserType"], @@ -746,7 +782,7 @@ async def _user_query(self): .build() ) - async def _accounts_query(self): + async def _accounts_query(self) -> str: queryable_fields = await self._select_queryable_fields( "Account", [ @@ -791,7 +827,7 @@ async def _accounts_query(self): .build() ) - async def _opportunities_query(self): + async def _opportunities_query(self) -> str: queryable_fields = await self._select_queryable_fields( "Opportunity", [ @@ -812,7 +848,7 @@ async def _opportunities_query(self): .build() ) - async def _contacts_query(self): + async def _contacts_query(self) -> str: queryable_fields = await self._select_queryable_fields( "Contact", [ @@ -837,7 +873,7 @@ async def _contacts_query(self): .build() ) - async def _leads_query(self): + async def _leads_query(self) -> str: queryable_fields = await self._select_queryable_fields( "Lead", [ @@ -868,7 +904,7 @@ async def _leads_query(self): .build() ) - async def _campaigns_query(self): + async def _campaigns_query(self) -> str: queryable_fields = await self._select_queryable_fields( "Campaign", [ @@ -893,7 +929,7 @@ async def _campaigns_query(self): .build() ) - async def _cases_query(self): + async def _cases_query(self) -> str: queryable_fields = await self._select_queryable_fields( "Case", [ @@ -925,7 +961,7 @@ async def _cases_query(self): .build() ) - async def _email_messages_join_query(self): + async def _email_messages_join_query(self) -> str: 
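The `@retryable(...)` decorators in these hunks take `retries`, an `interval`, a backoff `strategy`, and `skipped_exceptions` that should fail fast instead of being retried. The toy decorator below only illustrates that general shape; it is not the `connectors.utils.retryable` implementation, and the exception names are placeholders.

```python
import asyncio
import functools
from typing import Iterable, Type


def retryable(retries: int = 3, interval: float = 0.1,
              skipped_exceptions: Iterable[Type[BaseException]] = ()):
    skipped = tuple(skipped_exceptions)

    def wrapper(func):
        @functools.wraps(func)
        async def wrapped(*args, **kwargs):
            for attempt in range(1, retries + 1):
                try:
                    return await func(*args, **kwargs)
                except skipped:
                    raise  # e.g. an invalid query: retrying cannot help
                except Exception:
                    if attempt == retries:
                        raise
                    await asyncio.sleep(interval * attempt)  # simple linear backoff

        return wrapped

    return wrapper


class Flaky:
    def __init__(self) -> None:
        self.calls = 0

    @retryable(retries=3, skipped_exceptions=[ValueError])
    async def fetch(self) -> str:
        self.calls += 1
        if self.calls < 3:
            raise ConnectionError("transient")
        return "ok"


print(asyncio.run(Flaky().fetch()))  # retries twice, then prints "ok"
```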
"""For join with Case""" queryable_fields = await self._select_queryable_fields( "EmailMessage", @@ -955,7 +991,7 @@ async def _email_messages_join_query(self): .build() ) - async def _case_comments_join_query(self): + async def _case_comments_join_query(self) -> str: """For join with Case""" queryable_fields = await self._select_queryable_fields( "CaseComment", @@ -976,7 +1012,7 @@ async def _case_comments_join_query(self): .build() ) - async def _case_feeds_query(self, case_ids): + async def _case_feeds_query(self, case_ids: List[str]) -> str: queryable_fields = await self._select_queryable_fields( "CaseFeed", [ @@ -1003,7 +1039,7 @@ async def _case_feeds_query(self, case_ids): .build() ) - async def _case_feed_comments_join(self): + async def _case_feed_comments_join(self) -> str: queryable_fields = await self._select_queryable_fields( "FeedComment", [ @@ -1025,7 +1061,7 @@ async def _case_feed_comments_join(self): .build() ) - async def content_document_links_join(self): + async def content_document_links_join(self) -> str: """Cached async property for getting downloadable attached files This join is identical for all SObject queries""" if self._content_document_links_join is not None: @@ -1099,7 +1135,9 @@ async def content_document_links_join(self): class SalesforceAPIToken: - def __init__(self, session, base_url, client_id, client_secret): + def __init__( + self, session: ClientSession, base_url: str, client_id: str, client_secret: str + ) -> None: self._token = None self.session = session self.url = f"{base_url}{TOKEN_ENDPOINT}" @@ -1114,7 +1152,7 @@ def __init__(self, session, base_url, client_id, client_secret): interval=RETRY_INTERVAL, skipped_exceptions=[InvalidCredentialsException], ) - async def token(self): + async def token(self) -> str: if self._token: return self._token @@ -1142,49 +1180,49 @@ async def token(self): msg = f"Unexpected error while fetching Salesforce token. 
Status: {e.status}, message: {e.message}" raise TokenFetchException(msg) from e - def clear(self): + def clear(self) -> None: self._token = None class SalesforceSoqlBuilder: - def __init__(self, table): + def __init__(self, table: str) -> None: self.table_name = table self.fields = [] self.where = "" self.order_by = "" self.limit = "" - def with_id(self): + def with_id(self) -> "SalesforceSoqlBuilder": self.fields.append("Id") return self - def with_default_metafields(self): + def with_default_metafields(self) -> "SalesforceSoqlBuilder": self.fields.extend(["CreatedDate", "LastModifiedDate"]) return self - def with_fields(self, fields): + def with_fields(self, fields: List[str]) -> "SalesforceSoqlBuilder": self.fields.extend(fields) return self - def with_where(self, where_string): + def with_where(self, where_string: str) -> "SalesforceSoqlBuilder": self.where = f"WHERE {where_string}" return self - def with_order_by(self, order_by_string): + def with_order_by(self, order_by_string: str) -> "SalesforceSoqlBuilder": self.order_by = f"ORDER BY {order_by_string}" return self - def with_limit(self, limit): + def with_limit(self, limit: int) -> "SalesforceSoqlBuilder": self.limit = f"LIMIT {limit}" return self - def with_join(self, join): + def with_join(self, join: str) -> "SalesforceSoqlBuilder": if join: self.fields.append(f"(\n{join})\n") return self - def build(self): + def build(self) -> str: select_columns = ",\n".join(set(self.fields)) query_lines = [] @@ -1198,10 +1236,22 @@ def build(self): class SalesforceDocMapper: - def __init__(self, base_url): + def __init__(self, base_url: str) -> None: self.base_url = base_url - def map_content_document(self, content_document): + def map_content_document( + self, + content_document: Dict[ + str, + Union[ + Dict[str, str], + str, + int, + Dict[str, Union[str, Dict[str, str]]], + List[str], + ], + ], + ) -> Dict[str, Any]: content_version = content_document.get("LatestPublishedVersion", {}) or {} owner = content_document.get("Owner", {}) or {} created_by = content_document.get("CreatedBy", {}) or {} @@ -1225,7 +1275,7 @@ def map_content_document(self, content_document): "version_url": f"{self.base_url}/{content_version.get('Id')}", } - def map_salesforce_objects(self, _object): + def map_salesforce_objects(self, _object: Dict[str, Any]) -> Dict[str, Any]: def _format_datetime(datetime_): datetime_ = datetime_ or iso_utc() return datetime.strptime(datetime_, "%Y-%m-%dT%H:%M:%S.%f%z").strftime( @@ -1258,10 +1308,12 @@ class SalesforceAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "SalesforceDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: List[Dict[str, str]] + ) -> SyncRuleValidationResult: return await self._remote_validation(advanced_rules) @retryable( @@ -1269,7 +1321,9 @@ async def validate(self, advanced_rules): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, advanced_rules: List[Dict[str, str]] + ) -> SyncRuleValidationResult: try: SalesforceAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -1321,7 +1375,7 @@ class SalesforceDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, 
configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) base_url = ( @@ -1336,11 +1390,20 @@ def __init__(self, configuration): self.doc_mapper = SalesforceDocMapper(base_url) self.permissions = {} - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.salesforce_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: return { "domain": { "label": "Domain", @@ -1406,7 +1469,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. This method checks whether document level security (DLS) is enabled based on the provided configuration. Returns: @@ -1420,14 +1483,18 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, document: Dict[str, Any], access_control: List[Union[str, Any]] + ) -> Dict[str, Any]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) ) return document - async def _user_access_control_doc(self, user): + async def _user_access_control_doc( + self, user: Dict[str, Union[str, Dict[str, str]]] + ) -> Dict[str, Any]: email = user.get("Email") username = user.get("Name") @@ -1447,7 +1514,7 @@ async def _user_access_control_doc(self, user): "_timestamp": user.get("LastModifiedDate", iso_utc()), } | es_access_control_query(access_control) - async def get_access_control(self): + async def get_access_control(self) -> None: """Get access control documents for Salesforce users. This method fetches access control documents for Salesforce users when document level security (DLS) @@ -1470,11 +1537,11 @@ async def get_access_control(self): user_doc = await self._user_access_control_doc(user=user) yield user_doc - async def validate_config(self): + async def validate_config(self) -> None: await super().validate_config() await self._remote_validation() - async def _remote_validation(self): + async def _remote_validation(self) -> None: await self.salesforce_client.ping() if self.salesforce_client.sync_custom_objects: @@ -1492,10 +1559,10 @@ async def _remote_validation(self): msg = f"Custom objects {[obj[:-3] for obj in self.salesforce_client.custom_objects_to_sync if obj not in custom_object_response]} are not available." 
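`_decorate_with_access_control` above merges the document's existing `_allow_access_control` entries with the new identities via `set()`, so duplicates are dropped (and the resulting list order is not stable). A standalone sketch of just that step, with invented identities:

```python
from typing import Any, Dict, List

ACCESS_CONTROL = "_allow_access_control"


def decorate_with_access_control(
    document: Dict[str, Any], access_control: List[str], dls_enabled: bool = True
) -> Dict[str, Any]:
    if dls_enabled:
        # set() union removes duplicate identities but does not preserve order
        document[ACCESS_CONTROL] = list(
            set(document.get(ACCESS_CONTROL, []) + access_control)
        )
    return document


doc = {"_id": "1", ACCESS_CONTROL: ["user_id:42"]}
print(decorate_with_access_control(doc, ["user_id:42", "email:jane@example.com"]))
```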
raise ConfigurableFieldValueError(msg) - async def close(self): + async def close(self) -> None: await self.salesforce_client.close() - async def ping(self): + async def ping(self) -> None: try: await self.salesforce_client.ping() self._logger.debug("Successfully connected to Salesforce.") @@ -1503,16 +1570,16 @@ async def ping(self): self._logger.exception(f"Error while connecting to Salesforce: {e}") raise - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[SalesforceAdvancedRulesValidator]: return [SalesforceAdvancedRulesValidator(self)] - async def _get_advanced_sync_rules_result(self, rule): + async def _get_advanced_sync_rules_result(self, rule: Dict[str, str]) -> None: async for doc in self.salesforce_client.get_sync_rules_results(rule=rule): if sobject := doc.get("attributes", {}).get("type"): await self._fetch_users_with_read_access(sobject=sobject) yield doc - async def _fetch_users_with_read_access(self, sobject): + async def _fetch_users_with_read_access(self, sobject: str) -> None: if not self._dls_enabled(): self._logger.debug("DLS is not enabled. Skipping") return @@ -1547,7 +1614,7 @@ async def _fetch_users_with_read_access(self, sobject): self.permissions[sobject] = list(access_control) - async def get_docs(self, filtering=None): + async def get_docs(self, filtering: Optional[Filter] = None) -> None: # We collect all content documents and de-duplicate them before downloading and yielding content_docs = [] @@ -1695,7 +1762,9 @@ async def get_docs(self, filtering=None): yield self._decorate_with_access_control(doc, access_control), None - async def get_content(self, doc, content_version_id): + async def get_content( + self, doc: Dict[str, Union[int, str, List[str]]], content_version_id: str + ) -> Generator[Future, None, Dict[str, Union[int, str, List[str]]]]: file_size = doc["content_size"] filename = doc["title"] file_extension = self.get_file_extension(filename) @@ -1716,7 +1785,13 @@ async def get_content(self, doc, content_version_id): return_doc_if_failed=True, # we still ingest on download failure for Salesforce ) - def _parse_content_documents(self, record): + def _parse_content_documents( + self, record: Dict[str, Any] + ) -> List[ + Dict[ + str, Union[Dict[str, str], str, int, Dict[str, Union[str, Dict[str, str]]]] + ] + ]: content_docs = [] content_links = record.get("ContentDocumentLinks", {}) or {} content_links = content_links.get("records", []) or [] @@ -1730,7 +1805,36 @@ def _parse_content_documents(self, record): return content_docs - def _combine_duplicate_content_docs(self, content_docs): + def _combine_duplicate_content_docs( + self, + content_docs: List[ + Union[ + Dict[str, str], + Any, + Dict[ + str, + Union[ + Dict[str, str], str, int, Dict[str, Union[str, Dict[str, str]]] + ], + ], + ] + ], + ) -> List[ + Union[ + Dict[ + str, + Union[ + Dict[str, str], + str, + int, + Dict[str, Union[str, Dict[str, str]]], + List[str], + ], + ], + Any, + Dict[str, Union[str, List[str]]], + ] + ]: """Duplicate ContentDocuments may appear linked to multiple SObjects Here we ensure that we don't download any duplicates while retaining links""" grouped = {} diff --git a/connectors/sources/sandfly.py b/connectors/sources/sandfly.py index 49c3255bb..4b0a311d6 100644 --- a/connectors/sources/sandfly.py +++ b/connectors/sources/sandfly.py @@ -9,19 +9,26 @@ import json import socket +from _asyncio import Task from contextlib import asynccontextmanager from datetime import datetime, timedelta from functools import cached_property +from typing 
import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union # import aiofiles import aiohttp +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ( ClientResponseError, ) from connectors.es.sink import OP_INDEX -from connectors.logger import logger -from connectors.source import CURSOR_SYNC_TIMESTAMP, BaseDataSource +from connectors.logger import ExtraLogger, logger +from connectors.source import ( + CURSOR_SYNC_TIMESTAMP, + BaseDataSource, + DataSourceConfiguration, +) from connectors.utils import ( CacheWithTimeout, CancellableSleeps, @@ -38,11 +45,11 @@ CURSOR_SEQUENCE_ID_KEY = "sequence_id" -def extract_sandfly_date(datestr): +def extract_sandfly_date(datestr: str) -> datetime: return datetime.strptime(datestr, "%Y-%m-%dT%H:%M:%SZ") -def format_sandfly_date(date, flag): +def format_sandfly_date(date: datetime, flag: bool) -> str: if flag: return date.strftime("%Y-%m-%dT00:00:00Z") # date with time as midnight return date.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -73,7 +80,12 @@ class SandflyNotLicensed(Exception): class SandflyAccessToken: - def __init__(self, http_session, configuration, logger_): + def __init__( + self, + http_session: ClientSession, + configuration: Union[Dict[str, Union[bool, str, int]], DataSourceConfiguration], + logger_: ExtraLogger, + ) -> None: self._token_cache = CacheWithTimeout() self._http_session = http_session self._logger = logger_ @@ -82,10 +94,10 @@ def __init__(self, http_session, configuration, logger_): self.username = configuration["username"] self.password = configuration["password"] - def set_logger(self, logger_): + def set_logger(self, logger_: ExtraLogger) -> None: self._logger = logger_ - async def get(self, is_cache=True): + async def get(self, is_cache: bool = True) -> Generator[Task, None, str]: cached_value = self._token_cache.get_value() if is_cache else None if cached_value: @@ -102,7 +114,7 @@ async def get(self, is_cache=True): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _fetch_token(self): + async def _fetch_token(self) -> Tuple[str, int]: url = f"{self.server_url}/auth/login" request_headers = { "Accept": "application/json", @@ -125,20 +137,25 @@ async def _fetch_token(self): class SandflySession: - def __init__(self, http_session, token, logger_): + def __init__( + self, + http_session: ClientSession, + token: SandflyAccessToken, + logger_: ExtraLogger, + ) -> None: self._sleeps = CancellableSleeps() self._logger = logger_ self._http_session = http_session self._token = token - def set_logger(self, logger_): + def set_logger(self, logger_: ExtraLogger) -> None: self._logger = logger_ - def close(self): + def close(self) -> None: self._sleeps.cancel() - async def ping(self, server_url): + async def ping(self, server_url: str) -> bool: try: await self._http_session.head(server_url) return True @@ -158,7 +175,7 @@ async def ping(self, server_url): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=[ResourceNotFound, FetchTokenError], ) - async def _get(self, absolute_url): + async def _get(self, absolute_url: str) -> Iterator[Task]: try: access_token = await self._token.get() headers = { @@ -192,7 +209,13 @@ async def _get(self, absolute_url): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=[ResourceNotFound, FetchTokenError], ) - async def _post(self, absolute_url, payload): + async def _post( + self, + absolute_url: str, + payload: Dict[ + str, Union[int, str, Dict[str, List[Dict[str, str]]], List[Dict[str, str]]] + ], + ) -> 
Iterator[Task]: try: access_token = await self._token.get() headers = { @@ -219,7 +242,7 @@ async def _post(self, absolute_url, payload): except Exception: raise - async def content_get(self, url): + async def content_get(self, url: str) -> Generator[Task, None, str]: try: async with self._get(absolute_url=url) as response: return await response.text() @@ -229,7 +252,20 @@ async def content_get(self, url): ) raise - async def content_post(self, url, payload): + async def content_post( + self, + url: str, + payload: Dict[ + str, + Union[ + int, + Dict[str, List[Dict[str, str]]], + List[Dict[str, str]], + Dict[str, Union[str, List[Dict[str, str]]]], + str, + ], + ], + ) -> Generator[Task, None, str]: try: async with self._post(absolute_url=url, payload=payload) as response: return await response.text() @@ -241,7 +277,10 @@ async def content_post(self, url, payload): class SandflyClient: - def __init__(self, configuration): + def __init__( + self, + configuration: Union[Dict[str, Union[bool, str, int]], DataSourceConfiguration], + ) -> None: self._sleeps = CancellableSleeps() self._logger = logger @@ -268,16 +307,16 @@ def __init__(self, configuration): logger_=self._logger, ) - def set_logger(self, logger_): + def set_logger(self, logger_: ExtraLogger) -> None: self._logger = logger_ self.token.set_logger(self._logger) self.client.set_logger(self._logger) - async def close(self): + async def close(self) -> None: await self.http_session.close() self.client.close() - async def ping(self): + async def ping(self) -> bool: try: await self.client.ping(self.server_url) self._logger.info( @@ -352,7 +391,9 @@ async def get_results_by_id(self, sequence_id, enable_pass): for result_item in data_list: yield result_item, more_results - async def get_results_by_time(self, time_since, enable_pass): + async def get_results_by_time( + self, time_since: str, enable_pass: bool + ) -> Iterator[Task]: results_url = f"{self.server_url}/results" if enable_pass: @@ -429,7 +470,7 @@ async def get_hosts(self): for host in data_list: yield host - async def get_license(self): + async def get_license(self) -> Iterator[Task]: license_url = f"{self.server_url}/license" content = await self.client.content_get(url=license_url) @@ -445,7 +486,7 @@ class SandflyDataSource(BaseDataSource): service_type = "sandfly" incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self._logger = logger @@ -457,11 +498,11 @@ def __init__(self, configuration): self.fetch_days = self.configuration["fetch_days"] @cached_property - def client(self): + def client(self) -> SandflyClient: return SandflyClient(configuration=self.configuration) @classmethod - def get_default_configuration(cls): + def get_default_configuration(cls) -> Dict[str, Dict[str, Union[bool, str, int]]]: return { "server_url": { "label": "Sandfly Server URL", @@ -512,17 +553,17 @@ def get_default_configuration(cls): }, } - async def ping(self): + async def ping(self) -> bool: try: await self.client.ping() return True except Exception: raise - async def close(self): + async def close(self) -> None: await self.client.close() - def init_sync_cursor(self): + def init_sync_cursor(self) -> Dict[str, Union[str, int]]: if not self._sync_cursor: self._sync_cursor = { CURSOR_SEQUENCE_ID_KEY: 0, @@ -531,7 +572,17 @@ def init_sync_cursor(self): return self._sync_cursor - def _format_doc(self, doc_id, doc_time, doc_text, doc_field, doc_data): + def _format_doc( 
+ self, + doc_id: str, + doc_time: str, + doc_text: str, + doc_field: str, + doc_data: Dict[ + str, + Optional[Union[str, Dict[str, Dict[str, Dict[str, str]]], Dict[str, str]]], + ], + ) -> Dict[str, Any]: document = { "_id": doc_id, "_timestamp": doc_time, @@ -541,7 +592,9 @@ def _format_doc(self, doc_id, doc_time, doc_text, doc_field, doc_data): } return document - def extract_results_data(self, result_item, get_more_results): + def extract_results_data( + self, result_item: Dict[str, Union[Dict[str, str], str]], get_more_results: bool + ) -> Tuple[str, str, str, str]: last_sequence_id = result_item["sequence_id"] external_id = result_item["external_id"] timestamp = result_item["header"]["end_time"] @@ -559,7 +612,7 @@ def extract_results_data(self, result_item, get_more_results): return timestamp, key_data, last_sequence_id, doc_id - def extract_sshkey_data(self, key_item): + def extract_sshkey_data(self, key_item: Dict[str, str]) -> Tuple[str, str]: friendly = key_item["friendly_name"] key_value = key_item["key_value"] @@ -569,7 +622,12 @@ def extract_sshkey_data(self, key_item): return friendly, doc_id - def extract_host_data(self, host_item): + def extract_host_data( + self, + host_item: Dict[ + str, Optional[Union[str, Dict[str, Dict[str, Dict[str, str]]]]] + ], + ) -> Tuple[str, str]: hostid = host_item["host_id"] hostname = host_item["hostname"] @@ -588,7 +646,9 @@ def extract_host_data(self, host_item): return key_data, doc_id - def validate_license(self, license_data): + def validate_license( + self, license_data: Dict[str, Union[int, Dict[str, str], Dict[str, List[str]]]] + ) -> None: customer = license_data["customer"]["name"] expiry = license_data["date"]["expiry"] @@ -612,7 +672,7 @@ def validate_license(self, license_data): msg = f"Sandfly Server [{self.server_url}] is not licensed for Elasticsearch Replication" raise SandflyNotLicensed(msg) - async def get_docs(self, filtering=None): + async def get_docs(self, filtering: None = None): self.init_sync_cursor() async for license_data in self.client.get_license(): @@ -708,7 +768,9 @@ async def get_docs(self, filtering=None): if last_sequence_id is not None: self._sync_cursor[CURSOR_SEQUENCE_ID_KEY] = last_sequence_id - async def get_docs_incrementally(self, sync_cursor, filtering=None): + async def get_docs_incrementally( + self, sync_cursor: Optional[Dict[str, str]], filtering: None = None + ): self._sync_cursor = sync_cursor timestamp = iso_utc() diff --git a/connectors/sources/servicenow.py b/connectors/sources/servicenow.py index 75e92df3d..381c82ae7 100644 --- a/connectors/sources/servicenow.py +++ b/connectors/sources/servicenow.py @@ -10,13 +10,17 @@ import math import os import uuid +from _asyncio import Future, Task +from asyncio.tasks import _GatheringFuture from enum import Enum from functools import cached_property, partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Union from urllib.parse import urlencode import aiohttp import dateutil.parser as parser import fastjsonschema +from aiohttp.client import ClientSession from connectors.access_control import ( ACCESS_CONTROL, @@ -28,7 +32,11 @@ SyncRuleValidationResult, ) from connectors.logger import logger -from connectors.source import BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import ( CancellableSleeps, ConcurrentTasks, @@ -40,14 +48,14 @@ RETRIES = 3 RETRY_INTERVAL = 2 -QUEUE_MEM_SIZE = 25 * 
1024 * 1024 # Size in Megabytes +QUEUE_MEM_SIZE: int = 25 * 1024 * 1024 # Size in Megabytes CONCURRENT_TASKS = 1000 # Depends on total number of services and size of each service MAX_CONCURRENT_CLIENT_SUPPORT = 10 TABLE_FETCH_SIZE = 50 TABLE_BATCH_SIZE = 5 ATTACHMENT_BATCH_SIZE = 10 -RUNNING_FTEST = ( +RUNNING_FTEST: bool = ( "RUNNING_FTEST" in os.environ ) # Flag to check if a connector is run for ftest or not. @@ -74,15 +82,15 @@ ACLS_QUERY = "sys_security_acl.operation=read^sys_security_acl.name={table_name}" -def _prefix_email(email): +def _prefix_email(email: str) -> Optional[str]: return prefix_identity("email", email) -def _prefix_username(user): +def _prefix_username(user: str) -> Optional[str]: return prefix_identity("username", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: str) -> Optional[str]: return prefix_identity("user_id", user_id) @@ -99,7 +107,7 @@ class InvalidResponse(Exception): class ServiceNowClient: """ServiceNow Client""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the ServiceNow client. Args: @@ -112,11 +120,11 @@ def __init__(self, configuration): self.retry_count = self.configuration["retry_count"] self._logger = logger - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ @cached_property - def _get_session(self): + def _get_session(self) -> ClientSession: """Generate aiohttp client session with configuration fields. Returns: @@ -162,7 +170,7 @@ async def _read_response(self, response): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def get_table_length(self, table_name): + async def get_table_length(self, table_name: str) -> int: try: url = ENDPOINTS["TABLE"].format(table=table_name) params = {"sysparm_limit": 1} @@ -177,7 +185,7 @@ async def get_table_length(self, table_name): ) raise - def _prepare_url(self, url, params, offset): + def _prepare_url(self, url: str, params: Dict[str, str], offset: int) -> str: if not url.endswith("/file"): query = ORDER_BY_CREATION_DATE_QUERY if "sysparm_query" in params.keys(): @@ -195,7 +203,9 @@ def _prepare_url(self, url, params, offset): full_url = f"{url}?{params_string}" return full_url - async def get_filter_apis(self, rules, mapping): + async def get_filter_apis( + self, rules: List[Dict[str, str]], mapping: Dict[str, str] + ) -> List[Dict[str, Union[str, List[Dict[str, str]]]]]: apis = [] for rule in rules: params = {"sysparm_query": rule["query"]} @@ -209,7 +219,9 @@ async def get_filter_apis(self, rules, mapping): apis.extend(paginated_apis) return apis - def get_record_apis(self, url, params, total_count): + def get_record_apis( + self, url: str, params: Dict[str, str], total_count: int + ) -> List[Dict[str, Any]]: headers = [ {"name": "Content-Type", "value": "application/json"}, {"name": "Accept", "value": "application/json"}, @@ -230,7 +242,7 @@ def get_record_apis(self, url, params, total_count): ) return apis - def get_attachment_apis(self, url, ids): + def get_attachment_apis(self, url: str, ids: List[str]) -> List[Dict[str, Any]]: headers = [ {"name": "Content-Type", "value": "application/json"}, {"name": "Accept", "value": "application/json"}, @@ -248,7 +260,7 @@ def get_attachment_apis(self, url, ids): ) return apis - async def get_data(self, batched_apis): + async def get_data(self, batched_apis: Set[str]) -> Iterator[Task]: try: batch_data = self._prepare_batch(requests=batched_apis) async for response in 
self._batch_api_call(batch_data=batch_data): @@ -259,7 +271,7 @@ async def get_data(self, batched_apis): ) raise - def _prepare_batch(self, requests): + def _prepare_batch(self, requests: Set[str]) -> Dict[str, str]: return {"batch_request_id": str(uuid.uuid4()), "rest_requests": requests} @retryable( @@ -267,7 +279,7 @@ def _prepare_batch(self, requests): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def _batch_api_call(self, batch_data): + async def _batch_api_call(self, batch_data: Dict[str, Union[str, Set[str]]]): response = await self._api_call( url=ENDPOINTS["BATCH"], params={}, actions=batch_data, method="post" ) @@ -285,11 +297,17 @@ async def _api_call(self, url, params, actions, method): url=url, params=params, json=actions ) - async def download_func(self, url): + async def download_func(self, url: str): response = await self._api_call(url, {}, {}, "get") yield response - async def filter_services(self, configured_service): + async def filter_services( + self, configured_service: Union[List[str], Set[str]] + ) -> Union[ + Tuple[Dict[str, str], Set[Any]], + Tuple[Dict[str, str], List[str]], + Tuple[Dict[str, str], List[Any]], + ]: """Filter services based on service mappings. Args: @@ -342,14 +360,14 @@ async def filter_services(self, configured_service): ) raise - def _log_missing_sysparm_field(self, sys_id, field): + def _log_missing_sysparm_field(self, sys_id: str, field: str) -> None: msg = f"Entry in sys_db_object with sys_id '{sys_id}' is missing sysparm_field '{field}'. This is a non-issue if no invalid services are flagged." self._logger.debug(msg) - async def ping(self): + async def ping(self) -> None: await self.get_table_length(table_name="sys_db_object") - async def close_session(self): + async def close_session(self) -> None: """Closes unclosed client session""" self._sleeps.cancel() await self._get_session.close() @@ -371,10 +389,17 @@ class ServiceNowAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - def __init__(self, source): + def __init__(self, source: "ServiceNowDataSource") -> None: self.source = source - async def validate(self, advanced_rules): + async def validate( + self, + advanced_rules: Union[ + List[Dict[str, str]], + List[Union[Dict[str, Union[str, List[str]]], Dict[str, str]]], + List[Dict[str, Union[str, List[str]]]], + ], + ) -> SyncRuleValidationResult: if len(advanced_rules) == 0: return SyncRuleValidationResult.valid_result( SyncRuleValidationResult.ADVANCED_RULES @@ -382,7 +407,9 @@ async def validate(self, advanced_rules): return await self._remote_validation(advanced_rules) - async def _remote_validation(self, advanced_rules): + async def _remote_validation( + self, advanced_rules: List[Dict[str, Union[str, List[str]]]] + ) -> SyncRuleValidationResult: try: ServiceNowAdvancedRulesValidator.SCHEMA(advanced_rules) except fastjsonschema.JsonSchemaValueException as e: @@ -424,7 +451,7 @@ class ServiceNowDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the ServiceNow instance. 
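Like the Redis, S3, and Salesforce validators earlier in this diff, the ServiceNow advanced-rules validator compiles its JSON schema once with `fastjsonschema.compile` and treats `JsonSchemaValueException` as a validation failure. A minimal sketch with an assumed toy schema (the rule fields here are illustrative, not the connector's real `SCHEMA_DEFINITION`):

```python
import fastjsonschema

SCHEMA = fastjsonschema.compile(
    definition={
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "service": {"type": "string"},
                "query": {"type": "string"},
            },
            "required": ["service", "query"],
            "additionalProperties": False,
        },
    }
)


def validate(advanced_rules):
    try:
        SCHEMA(advanced_rules)
        return True, ""
    except fastjsonschema.JsonSchemaValueException as e:
        return False, str(e)


print(validate([{"service": "incident", "query": "active=true"}]))  # valid
print(validate([{"service": "incident"}]))                          # missing "query"
```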
Args: @@ -442,10 +469,10 @@ def __init__(self, configuration): self.queue = MemQueue(maxmemsize=QUEUE_MEM_SIZE, refresh_timeout=120) self.fetchers = ConcurrentTasks(max_concurrency=CONCURRENT_TASKS) - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[ServiceNowAdvancedRulesValidator]: return [ServiceNowAdvancedRulesValidator(self)] - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options: Dict[str, int]) -> None: """Tweak bulk options as per concurrent downloads support by ServiceNow Args: @@ -455,7 +482,11 @@ def tweak_bulk_options(self, options): options["concurrent_downloads"] = self.concurrent_downloads @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, Union[Dict[str, Union[List[str], int, str]], Dict[str, Union[int, str]]] + ]: return { "url": { "label": "Service URL", @@ -518,7 +549,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: """Check if document level security is enabled. This method checks whether document level security (DLS) is enabled based on the provided configuration. Returns: @@ -532,7 +563,7 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - async def _user_access_control_doc(self, user): + async def _user_access_control_doc(self, user: Dict[str, str]) -> Dict[str, Any]: user_id = user.get("_id", "") user_name = user.get("user_name", "") user_email = user.get("email", "") @@ -567,7 +598,7 @@ async def _fetch_users_by_roles(self, role): ): yield user - async def get_access_control(self): + async def get_access_control(self) -> None: if not self._dls_enabled(): self._logger.warning("DLS is not enabled. Skipping") return @@ -575,14 +606,16 @@ async def get_access_control(self): async for user in self._fetch_all_users(): yield await self._user_access_control_doc(user=user) - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, document: Dict[str, str], access_control: List[Union[Any, str]] + ) -> Dict[str, Union[str, List[str]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) ) return document - async def _remote_validation(self): + async def _remote_validation(self) -> None: """Validate configured services Raises: @@ -600,17 +633,17 @@ async def _remote_validation(self): msg = f"Services '{', '.join(self.invalid_services)}' are not available. Available services are: '{', '.join(set(self.servicenow_client.services) - set(self.invalid_services))}'" raise ConfigurableFieldValueError(msg) - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields Also validate, if user configured services are available in ServiceNow.""" await super().validate_config() await self._remote_validation() - async def close(self): + async def close(self) -> None: await self.servicenow_client.close_session() - async def ping(self): + async def ping(self) -> None: """Verify the connection with ServiceNow.""" try: @@ -621,7 +654,7 @@ async def ping(self): self._logger.exception("Error while connecting to the ServiceNow.") raise - def _format_doc(self, data): + def _format_doc(self, data: Dict[str, str]) -> Dict[str, str]: """Format document for handling empty values & type casting. 
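The `_prepare_url` helper earlier in this servicenow.py diff builds page URLs by merging `sysparm_*` parameters and encoding them with `urlencode`. A simplified sketch of that step; the parameter names follow ServiceNow's table API, but treat the values and endpoint as illustrative only:

```python
from typing import Dict
from urllib.parse import urlencode

TABLE_FETCH_SIZE = 50


def prepare_url(url: str, params: Dict[str, str], offset: int) -> str:
    query = dict(params)
    query.update(
        {
            "sysparm_limit": str(TABLE_FETCH_SIZE),
            "sysparm_offset": str(offset),
        }
    )
    return f"{url}?{urlencode(query)}"


print(prepare_url("/api/now/table/incident", {"sysparm_query": "active=true"}, 100))
```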
Args: @@ -640,7 +673,11 @@ def _format_doc(self, data): ) return data - async def _fetch_attachment_metadata(self, batched_apis, table_access_control): + async def _fetch_attachment_metadata( + self, + batched_apis: List[Dict[str, Union[str, List[Dict[str, str]]]]], + table_access_control: List[Union[Any, str]], + ) -> None: try: async for attachments_metadata in self.servicenow_client.get_data( batched_apis=batched_apis @@ -670,7 +707,9 @@ async def _fetch_attachment_metadata(self, batched_apis, table_access_control): finally: await self.queue.put(EndSignal.ATTACHMENT) - async def _attachment_metadata_producer(self, record_ids, table_access_control): + async def _attachment_metadata_producer( + self, record_ids: List[str], table_access_control: List[Union[Any, str]] + ) -> None: attachment_apis = None try: attachment_apis = self.servicenow_client.get_attachment_apis( @@ -716,7 +755,11 @@ async def _yield_table_data(self, batched_apis): exc_info=True, ) - async def _fetch_table_data(self, batched_apis, table_access_control): + async def _fetch_table_data( + self, + batched_apis: List[Dict[str, Union[str, List[Dict[str, str]]]]], + table_access_control: List[Union[Any, str]], + ) -> None: try: async for table_data in self.servicenow_client.get_data( batched_apis=batched_apis @@ -751,7 +794,7 @@ async def _fetch_table_data(self, batched_apis, table_access_control): finally: await self.queue.put(EndSignal.RECORD) - async def _fetch_access_controls(self, table_name): + async def _fetch_access_controls(self, table_name: str) -> List[str]: access_control, user_roles, roles = [], [], {} if table_name in DEFAULT_SERVICE_NAMES.keys(): async for role in self._table_data_generator( @@ -828,7 +871,12 @@ async def _table_data_generator(self, service_name, params): exc_info=True, ) - async def _table_data_producer(self, service_name, params, table_access_control): + async def _table_data_producer( + self, + service_name: str, + params: Dict[Any, Any], + table_access_control: List[Union[Any, str]], + ) -> None: self._logger.debug(f"Fetching {service_name} data") try: async for batched_apis in self._get_batched_apis(service_name, params): @@ -843,7 +891,7 @@ async def _table_data_producer(self, service_name, params, table_access_control) finally: await self.queue.put(EndSignal.SERVICE) - async def _consumer(self): + async def _consumer(self) -> Iterator[Future]: """Consume the queue for the documents. Yields: @@ -858,7 +906,9 @@ async def _consumer(self): else: yield item - async def get_docs(self, filtering=None): + async def get_docs( + self, filtering: None = None + ) -> Iterator[Union[Future, _GatheringFuture]]: """Get documents from ServiceNow. 
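`get_docs` above fans out producers (`_table_data_producer`, `_attachment_metadata_producer`) that push documents onto a queue and mark completion with `EndSignal` members, while `_consumer` drains the queue. The repository wires this up with its own `MemQueue` and `ConcurrentTasks`; the stand-in below uses a plain `asyncio.Queue` and a single producer just to show the shape, not the real bookkeeping:

```python
import asyncio
from enum import Enum
from typing import AsyncIterator, Dict


class EndSignal(Enum):
    RECORD = "RECORD_TASK_FINISHED"


async def producer(queue: asyncio.Queue) -> None:
    for i in range(3):
        await queue.put({"_id": str(i)})
    await queue.put(EndSignal.RECORD)  # tell the consumer this producer is done


async def consumer(queue: asyncio.Queue, producers: int) -> AsyncIterator[Dict[str, str]]:
    remaining = producers
    while remaining:
        item = await queue.get()
        if isinstance(item, EndSignal):
            remaining -= 1
        else:
            yield item


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    async for doc in consumer(queue, producers=1):
        print(doc)
    await task


asyncio.run(main())
```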
Args: @@ -929,7 +979,9 @@ async def get_docs(self, filtering=None): await self.fetchers.join() - async def get_content(self, metadata, timestamp=None, doit=False): + async def get_content( + self, metadata: Dict[str, str], timestamp: None = None, doit: bool = False + ) -> Generator[Future, None, Optional[Dict[str, str]]]: file_size = int(metadata["size_bytes"]) if not (doit and file_size > 0): return diff --git a/connectors/sources/sharepoint_online.py b/connectors/sources/sharepoint_online.py index 7adbf3450..03acc6f9b 100644 --- a/connectors/sources/sharepoint_online.py +++ b/connectors/sources/sharepoint_online.py @@ -6,20 +6,35 @@ import asyncio import os import re +from _asyncio import Future from collections.abc import Iterable, Sized from contextlib import asynccontextmanager from datetime import datetime, timedelta from functools import partial, wraps +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + List, + Optional, + Tuple, + Union, +) +from unittest.mock import AsyncMock, MagicMock import aiofiles import aiohttp import fastjsonschema from aiofiles.os import remove from aiofiles.tempfile import NamedTemporaryFile +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientPayloadError, ClientResponseError -from aiohttp.client_reqrep import RequestInfo +from aiohttp.client_reqrep import ClientResponse, RequestInfo from azure.identity.aio import CertificateCredential from fastjsonschema import JsonSchemaValueException +from freezegun.api import FakeDatetime from connectors.access_control import ( ACCESS_CONTROL, @@ -31,8 +46,12 @@ AdvancedRulesValidator, SyncRuleValidationResult, ) -from connectors.logger import logger -from connectors.source import CURSOR_SYNC_TIMESTAMP, BaseDataSource +from connectors.logger import ExtraLogger, logger +from connectors.source import ( + CURSOR_SYNC_TIMESTAMP, + BaseDataSource, + DataSourceConfiguration, +) from connectors.utils import ( TIKA_SUPPORTED_FILETYPES, CacheWithTimeout, @@ -57,20 +76,20 @@ ) logger.warning("IT'S SUPPOSED TO BE USED ONLY FOR TESTING") logger.warning("x" * 50) - override_url = os.environ["OVERRIDE_URL"] - GRAPH_API_URL = override_url - GRAPH_API_AUTH_URL = override_url - REST_API_AUTH_URL = override_url + override_url: str = os.environ["OVERRIDE_URL"] + GRAPH_API_URL: str = override_url + GRAPH_API_AUTH_URL: str = override_url + REST_API_AUTH_URL: str = override_url else: - GRAPH_API_URL = "https://graph.microsoft.com/v1.0" - GRAPH_API_AUTH_URL = "https://login.microsoftonline.com" - REST_API_AUTH_URL = "https://accounts.accesscontrol.windows.net" + GRAPH_API_URL: str = "https://graph.microsoft.com/v1.0" + GRAPH_API_AUTH_URL: str = "https://login.microsoftonline.com" + REST_API_AUTH_URL: str = "https://accounts.accesscontrol.windows.net" DEFAULT_RETRY_COUNT = 5 DEFAULT_RETRY_SECONDS = 30 DEFAULT_PARALLEL_CONNECTION_COUNT = 10 DEFAULT_BACKOFF_MULTIPLIER = 5 -FILE_WRITE_CHUNK_SIZE = 1024 * 64 # 64KB default SSD page size +FILE_WRITE_CHUNK_SIZE: int = 1024 * 64 # 64KB default SSD page size MAX_DOCUMENT_SIZE = 10485760 WILDCARD = "*" DRIVE_ITEMS_FIELDS = "id,content.downloadUrl,lastModifiedDateTime,lastModifiedBy,root,deleted,file,folder,package,name,webUrl,createdBy,createdDateTime,size,parentReference" @@ -98,7 +117,7 @@ REVIEWER = 7 SYSTEM = 0xFF # Note the exclusion of NONE(0), GUEST(1), RESTRICTED_READER(8), and RESTRICTED_GUEST(9) -VIEW_ROLE_TYPES = [ +VIEW_ROLE_TYPES: List[int] = [ READER, CONTRIBUTOR, WEB_DESIGNER, @@ -193,7 +212,13 @@ class MicrosoftSecurityToken: - 
https://learn.microsoft.com/en-us/azure/active-directory/develop/quickstart-register-app """ - def __init__(self, http_session, tenant_id, tenant_name, client_id): + def __init__( + self, + http_session: Optional[ClientSession], + tenant_id: Optional[str], + tenant_name: Optional[str], + client_id: Optional[str], + ) -> None: """Initializer. Args: @@ -258,7 +283,14 @@ async def _fetch_token(self): class SecretAPIToken(MicrosoftSecurityToken): - def __init__(self, http_session, tenant_id, tenant_name, client_id, client_secret): + def __init__( + self, + http_session: ClientSession, + tenant_id: Optional[str], + tenant_name: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + ) -> None: super().__init__(http_session, tenant_id, tenant_name, client_id) self._client_secret = client_secret @@ -270,7 +302,7 @@ class GraphAPIToken(SecretAPIToken): """Token to connect to Microsoft Graph API endpoints.""" @retryable(retries=3) - async def _fetch_token(self): + async def _fetch_token(self) -> Tuple[str, FakeDatetime]: """Fetch API token for usage with Graph API Returns: @@ -295,7 +327,7 @@ class SharepointRestAPIToken(SecretAPIToken): """Token to connect to Sharepoint REST API endpoints.""" @retryable(retries=DEFAULT_RETRY_COUNT) - async def _fetch_token(self): + async def _fetch_token(self) -> Tuple[str, FakeDatetime]: """Fetch API token for usage with Sharepoint REST API Returns: @@ -327,21 +359,21 @@ class EntraAPIToken(MicrosoftSecurityToken): def __init__( self, - http_session, - tenant_id, - tenant_name, - client_id, - certificate, - private_key, - scope, - ): + http_session: ClientSession, + tenant_id: str, + tenant_name: str, + client_id: str, + certificate: str, + private_key: str, + scope: str, + ) -> None: super().__init__(http_session, tenant_id, tenant_name, client_id) self._certificate = certificate self._private_key = private_key self._scope = scope @retryable(retries=3) - async def _fetch_token(self): + async def _fetch_token(self) -> Tuple[str, datetime]: """Fetch API token for usage with Graph API Returns: @@ -361,7 +393,7 @@ async def _fetch_token(self): return token.token, datetime.utcfromtimestamp(token.expires_on) -def retryable_aiohttp_call(retries): +def retryable_aiohttp_call(retries: int) -> Callable: # TODO: improve utils.retryable to allow custom logic # that can help choose what to retry def wrapper(func): @@ -386,7 +418,13 @@ async def wrapped(*args, **kwargs): class MicrosoftAPISession: - def __init__(self, http_session, api_token, scroll_field, logger_): + def __init__( + self, + http_session: ClientSession, + api_token: Union[GraphAPIToken, SharepointRestAPIToken], + scroll_field: str, + logger_: ExtraLogger, + ) -> None: self._http_session = http_session self._api_token = api_token @@ -399,21 +437,23 @@ def __init__(self, http_session, api_token, scroll_field, logger_): self._sleeps = CancellableSleeps() self._logger = logger_ - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def close(self): + def close(self) -> None: self._sleeps.cancel() - async def fetch(self, url): + async def fetch(self, url: str) -> Dict[str, str]: return await self._get_json(url) - async def post(self, url, payload): + async def post( + self, url: str, payload: Dict[str, str] + ) -> Dict[str, Union[List[Dict[str, str]], str]]: self._logger.debug(f"Post to url: '{url}' with body: {payload}") async with self._post(url, payload) as resp: return await resp.json() - async def pipe(self, url, stream): + async def pipe(self, 
url, stream) -> None: async with self._get(url) as resp: async for data in resp.content.iter_chunked(FILE_WRITE_CHUNK_SIZE): await stream.write(data) @@ -444,14 +484,23 @@ async def scroll_delta_url(self, url): else: break - async def _get_json(self, absolute_url): + async def _get_json( + self, absolute_url: str + ) -> Dict[ + str, Union[List[Dict[str, Union[str, int, Dict[str, str]]]], str, List[str]] + ]: self._logger.debug(f"Fetching url: {absolute_url}") async with self._get(absolute_url) as resp: return await resp.json() @asynccontextmanager @retryable_aiohttp_call(retries=DEFAULT_RETRY_COUNT) - async def _post(self, absolute_url, payload=None, retry_count=0): + async def _post( + self, + absolute_url: str, + payload: Optional[Dict[str, str]] = None, + retry_count: int = 0, + ): try: token = await self._api_token.get() headers = {"authorization": f"Bearer {token}"} @@ -473,7 +522,9 @@ async def _post(self, absolute_url, payload=None, retry_count=0): except ClientPayloadError as e: await self._handle_client_payload_error(e, retry_count) - async def _check_batch_items_for_errors(self, url, batch_resp): + async def _check_batch_items_for_errors( + self, url: str, batch_resp: ClientResponse + ) -> None: body = await batch_resp.json() responses = body.get("responses", []) for response in responses: @@ -494,7 +545,7 @@ async def _check_batch_items_for_errors(self, url, batch_resp): @asynccontextmanager @retryable_aiohttp_call(retries=DEFAULT_RETRY_COUNT) - async def _get(self, absolute_url, retry_count=0): + async def _get(self, absolute_url: str, retry_count: int = 0): try: token = await self._api_token.get() headers = {"authorization": f"Bearer {token}"} @@ -515,7 +566,9 @@ async def _get(self, absolute_url, retry_count=0): except ClientPayloadError as e: await self._handle_client_payload_error(e, retry_count) - async def _handle_client_payload_error(self, e, retry_count): + async def _handle_client_payload_error( + self, e: ClientPayloadError, retry_count: int + ): await self._sleeps.sleep( self._compute_retry_after( DEFAULT_RETRY_SECONDS, retry_count, DEFAULT_BACKOFF_MULTIPLIER @@ -524,7 +577,9 @@ async def _handle_client_payload_error(self, e, retry_count): raise e - async def _handle_client_response_error(self, absolute_url, e, retry_count): + async def _handle_client_response_error( + self, absolute_url: str, e: ClientResponseError, retry_count: int + ): if e.status == 429 or e.status == 503: response_headers = e.headers or {} @@ -561,7 +616,9 @@ async def _handle_client_response_error(self, absolute_url, e, retry_count): else: raise - def _compute_retry_after(self, retry_after, retry_count, backoff): + def _compute_retry_after( + self, retry_after: int, retry_count: int, backoff: int + ) -> int: # Wait for what Sharepoint API asks after the first failure. # Apply backoff if API is still not available. if retry_count <= 1: @@ -573,13 +630,13 @@ def _compute_retry_after(self, retry_after, retry_count, backoff): class SharepointOnlineClient: def __init__( self, - tenant_id, - tenant_name, - client_id, - client_secret=None, - certificate=None, - private_key=None, - ): + tenant_id: str, + tenant_name: str, + client_id: str, + client_secret: Optional[str] = None, + certificate: None = None, + private_key: None = None, + ) -> None: # Sharepoint / Graph API has quite strict throttling policies # If connector is overzealous, it can be banned for not respecting throttling policies # However if connector has a low setting for the tcp_connector limit, then it'll just be slow. 
@@ -640,7 +697,7 @@ def __init__( self._http_session, self.rest_api_token, "odata.nextLink", self._logger ) - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ self._graph_api_client.set_logger(self._logger) self._rest_api_client.set_logger(self._logger) @@ -654,7 +711,7 @@ async def groups(self): for group in page: yield group - async def group_sites(self, group_id): + async def group_sites(self, group_id: str) -> None: select = "" try: @@ -711,7 +768,7 @@ async def site_admins(self, site_web_url): self._logger.debug(f"No site admins found for site: '${site_web_url}'") return - async def site_groups_users(self, site_web_url, site_group_id): + async def site_groups_users(self, site_web_url: str, site_group_id: int) -> None: self._validate_sharepoint_rest_url(site_web_url) select_ = "Email,Id,UserPrincipalName,LoginName,Title" @@ -741,7 +798,7 @@ async def active_users_with_groups(self): except NotFound: return - async def group_members(self, group_id): + async def group_members(self, group_id: str) -> None: url = f"{GRAPH_API_URL}/groups/{group_id}/members" try: @@ -751,7 +808,7 @@ async def group_members(self, group_id): except NotFound: return - async def group_owners(self, group_id): + async def group_owners(self, group_id: str) -> None: select = "id,mail,userPrincipalName" url = f"{GRAPH_API_URL}/groups/{group_id}/owners?$select={select}" @@ -762,7 +819,7 @@ async def group_owners(self, group_id): except NotFound: return - async def site_users(self, site_web_url): + async def site_users(self, site_web_url: str) -> None: self._validate_sharepoint_rest_url(site_web_url) url = f"{site_web_url}/_api/web/siteusers" @@ -778,8 +835,8 @@ async def sites( self, sharepoint_host, allowed_root_sites, - enumerate_all_sites=True, - fetch_subsites=False, + enumerate_all_sites: bool = True, + fetch_subsites: bool = False, ): if allowed_root_sites == [WILDCARD] or enumerate_all_sites: self._logger.debug(f"Looking up all sites to fetch: {allowed_root_sites}") @@ -802,7 +859,7 @@ async def sites( f"Could not look up site '{allowed_site}' by relative path in parent site: {sharepoint_host}" ) - async def _all_sites(self, sharepoint_host, allowed_root_sites): + async def _all_sites(self, sharepoint_host: str, allowed_root_sites: List[Any]): select = "" try: async for page in self._graph_api_client.scroll( @@ -832,7 +889,9 @@ async def _fetch_site_and_subsites_by_path(self, sharepoint_host, allowed_site): async for site in self._recurse_sites(site_with_subsites): yield site - async def _fetch_site(self, sharepoint_host, allowed_site): + async def _fetch_site( + self, sharepoint_host: str, allowed_site: str + ) -> Dict[str, str]: self._logger.debug( f"Requesting site '{allowed_site}' by relative path in parent site: {sharepoint_host}" ) @@ -876,7 +935,7 @@ async def drive_items_delta(self, url): if "value" in response and len(response["value"]) > 0: yield DriveItemsPage(response["value"], delta_link) - async def drive_items(self, drive_id, url=None): + async def drive_items(self, drive_id, url: Optional[str] = None): url = ( ( f"{GRAPH_API_URL}/drives/{drive_id}/root/delta?$select={DRIVE_ITEMS_FIELDS}" @@ -888,7 +947,9 @@ async def drive_items(self, drive_id, url=None): async for page in self.drive_items_delta(url): yield page - async def drive_items_permissions_batch(self, drive_id, drive_item_ids): + async def drive_items_permissions_batch( + self, drive_id: int, drive_item_ids: List[Union[int, Any]] + ) -> None: requests = [] for item_id in drive_item_ids: @@ 
-910,7 +971,9 @@ async def drive_items_permissions_batch(self, drive_id, drive_item_ids): except NotFound: return - async def download_drive_item(self, drive_id, item_id, async_buffer): + async def download_drive_item( + self, drive_id: str, item_id: str, async_buffer: MagicMock + ) -> None: await self._graph_api_client.pipe( f"{GRAPH_API_URL}/drives/{drive_id}/items/{item_id}/content", async_buffer ) @@ -938,7 +1001,9 @@ async def site_list_has_unique_role_assignments(self, site_web_url, site_list_na except NotFound: return False - async def site_list_role_assignments(self, site_web_url, site_list_name): + async def site_list_role_assignments( + self, site_web_url: str, site_list_name: str + ) -> None: self._validate_sharepoint_rest_url(site_web_url) expand = "Member/users,RoleDefinitionBindings" @@ -953,8 +1018,8 @@ async def site_list_role_assignments(self, site_web_url, site_list_name): return async def site_list_item_has_unique_role_assignments( - self, site_web_url, site_list_name, list_item_id - ): + self, site_web_url: str, site_list_name: str, list_item_id: int + ) -> bool: self._validate_sharepoint_rest_url(site_web_url) url = f"{site_web_url}/_api/lists/GetByTitle('{site_list_name}')/items({list_item_id})/HasUniqueRoleAssignments" @@ -971,8 +1036,8 @@ async def site_list_item_has_unique_role_assignments( return False async def site_list_item_role_assignments( - self, site_web_url, site_list_name, list_item_id - ): + self, site_web_url: str, site_list_name: str, list_item_id: int + ) -> None: self._validate_sharepoint_rest_url(site_web_url) expand = "Member/users,RoleDefinitionBindings" @@ -996,7 +1061,9 @@ async def site_list_items(self, site_id, list_id): for site_list in page: yield site_list - async def site_list_item_attachments(self, site_web_url, list_title, list_item_id): + async def site_list_item_attachments( + self, site_web_url: str, list_title: str, list_item_id: str + ) -> None: self._validate_sharepoint_rest_url(site_web_url) url = f"{site_web_url}/_api/lists/GetByTitle('{list_title}')/items({list_item_id})?$expand=AttachmentFiles" @@ -1011,14 +1078,16 @@ async def site_list_item_attachments(self, site_web_url, list_title, list_item_i # Yes, makes no sense to me either. 
return - async def download_attachment(self, attachment_absolute_path, async_buffer): + async def download_attachment( + self, attachment_absolute_path: str, async_buffer: MagicMock + ) -> None: self._validate_sharepoint_rest_url(attachment_absolute_path) await self._rest_api_client.pipe( f"{attachment_absolute_path}/$value", async_buffer ) - async def site_pages(self, site_web_url): + async def site_pages(self, site_web_url: str) -> None: self._validate_sharepoint_rest_url(site_web_url) # select = "Id,Title,LayoutWebpartsContent,CanvasContent1,Description,Created,AuthorId,Modified,EditorId" @@ -1050,7 +1119,9 @@ async def site_pages(self, site_web_url): # Just to be on a safe side return - async def site_page_has_unique_role_assignments(self, site_web_url, site_page_id): + async def site_page_has_unique_role_assignments( + self, site_web_url: str, site_page_id: Union[str, int] + ) -> bool: self._validate_sharepoint_rest_url(site_web_url) url = f"{site_web_url}/_api/web/lists/GetByTitle('Site Pages')/items('{site_page_id}')/HasUniqueRoleAssignments" @@ -1061,7 +1132,9 @@ async def site_page_has_unique_role_assignments(self, site_web_url, site_page_id except NotFound: return False - async def site_page_role_assignments(self, site_web_url, site_page_id): + async def site_page_role_assignments( + self, site_web_url: str, site_page_id: int + ) -> None: self._validate_sharepoint_rest_url(site_web_url) expand = "Member/users,RoleDefinitionBindings" @@ -1075,7 +1148,9 @@ async def site_page_role_assignments(self, site_web_url, site_page_id): except NotFound: return - async def users_and_groups_for_role_assignment(self, site_web_url, role_assignment): + async def users_and_groups_for_role_assignment( + self, site_web_url: str, role_assignment: Dict[str, Union[str, int]] + ) -> List[Union[str, Any]]: self._validate_sharepoint_rest_url(site_web_url) if "PrincipalId" not in role_assignment: @@ -1093,7 +1168,7 @@ async def users_and_groups_for_role_assignment(self, site_web_url, role_assignme # This can also mean "not found" so handling it explicitly return [] - async def groups_user_transitive_member_of(self, user_id): + async def groups_user_transitive_member_of(self, user_id: str) -> None: url = f"{GRAPH_API_URL}/users/{user_id}/transitiveMemberOf" try: @@ -1103,12 +1178,12 @@ async def groups_user_transitive_member_of(self, user_id): except NotFound: return - async def tenant_details(self): + async def tenant_details(self) -> Dict[str, str]: url = f"{GRAPH_API_AUTH_URL}/common/userrealm/?user=cj@{self._tenant_name}.onmicrosoft.com&api-version=2.1&checkForMicrosoftAccount=false" return await self._rest_api_client.fetch(url) - def _validate_sharepoint_rest_url(self, url): + def _validate_sharepoint_rest_url(self, url: str) -> None: # TODO: make it better suitable for ftest if "OVERRIDE_URL" in os.environ: return @@ -1120,7 +1195,7 @@ def _validate_sharepoint_rest_url(self, url): msg = f"Unable to call Sharepoint REST API - tenant name is invalid. Authenticated for tenant name: {self._tenant_name}, actual tenant name for the service: {actual_tenant_name}. 
For url: {url}" raise InvalidSharepointTenant(msg) - async def close(self): + async def close(self) -> None: await self._http_session.close() self._graph_api_client.close() self._rest_api_client.close() @@ -1135,7 +1210,11 @@ class DriveItemsPage(Iterable, Sized): delta_link (str): Microsoft API deltaLink """ - def __init__(self, items, delta_link): + def __init__( + self, + items: List[Union[Dict[str, str], Dict[str, Union[Dict[str, str], str]], str]], + delta_link: Optional[str], + ) -> None: if items: self.items = items else: @@ -1146,14 +1225,16 @@ def __init__(self, items, delta_link): else: self._delta_link = None - def __len__(self): + def __len__(self) -> int: return len(self.items) - def __iter__(self): + def __iter__( + self, + ) -> Iterator[Union[Dict[str, str], Dict[str, Union[Dict[str, str], str]], str]]: for item in self.items: yield item - def delta_link(self): + def delta_link(self) -> str: return self._delta_link @@ -1168,7 +1249,9 @@ class SharepointOnlineAdvancedRulesValidator(AdvancedRulesValidator): SCHEMA = fastjsonschema.compile(definition=SCHEMA_DEFINITION) - async def validate(self, advanced_rules): + async def validate( + self, advanced_rules: Dict[str, Union[str, int]] + ) -> SyncRuleValidationResult: try: SharepointOnlineAdvancedRulesValidator.SCHEMA(advanced_rules) @@ -1183,23 +1266,23 @@ async def validate(self, advanced_rules): ) -def _prefix_group(group): +def _prefix_group(group: str) -> Optional[str]: return prefix_identity("group", group) -def _prefix_user(user): +def _prefix_user(user: str) -> Optional[str]: return prefix_identity("user", user) -def _prefix_user_id(user_id): +def _prefix_user_id(user_id: str) -> Optional[str]: return prefix_identity("user_id", user_id) -def _prefix_email(email): +def _prefix_email(email: str) -> Optional[str]: return prefix_identity("email", email) -def _get_login_name(raw_login_name): +def _get_login_name(raw_login_name: Optional[str]) -> Optional[str]: if raw_login_name and ( raw_login_name.startswith("i:0#.f|membership|") or raw_login_name.startswith("c:0o.c|federateddirectoryclaimprovider|") @@ -1213,7 +1296,7 @@ def _get_login_name(raw_login_name): return None -def _parse_created_date_time(created_date_time): +def _parse_created_date_time(created_date_time: str) -> Optional[datetime]: if created_date_time is None: return None return datetime.strptime(created_date_time, TIMESTAMP_FORMAT) @@ -1228,17 +1311,17 @@ class SharepointOnlineDataSource(BaseDataSource): dls_enabled = True incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self._client = None self.site_group_cache = {} - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(self._logger) @property - def client(self): + def client(self) -> Union[AsyncMock, SharepointOnlineClient]: if not self._client: tenant_id = self.configuration["tenant_id"] tenant_name = self.configuration["tenant_name"] @@ -1268,7 +1351,17 @@ def client(self): return self._client @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: return { "tenant_id": { "label": "Tenant ID", @@ -1399,7 +1492,7 @@ def get_default_configuration(cls): }, } - async def 
validate_config(self): + async def validate_config(self) -> None: await super().validate_config() # Check that we can log in into Graph API @@ -1444,14 +1537,18 @@ async def validate_config(self): msg = f"The specified SharePoint sites [{', '.join(missing)}] could not be retrieved during sync. Examples of sites available on the tenant:[{', '.join(retrieved_sites[:5])}]." raise Exception(msg) - def _site_path_from_web_url(self, web_url): + def _site_path_from_web_url(self, web_url: str) -> str: url_parts = web_url.split("/sites/") site_path_parts = url_parts[1:] return "/sites/".join( site_path_parts ) # just in case there was a /sites/ in the site path - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, + document: Dict[str, Union[str, Dict[str, str], int, List[str]]], + access_control: List[Union[str, Any]], + ) -> Dict[str, Union[str, Dict[str, str], int, List[str], List[Any]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) @@ -1459,7 +1556,9 @@ def _decorate_with_access_control(self, document, access_control): return document - async def _site_access_control(self, site): + async def _site_access_control( + self, site: Dict[str, Union[str, int, Dict[str, str]]] + ) -> Union[Tuple[List[Any], List[Any]], Tuple[List[str], List[str]]]: """Fetches all permissions for all owners, members and visitors of a given site. All groups and/or persons, which have permissions for a given site are returned with their given identity prefix ("user", "group" or "email"). For the given site all groups and its corresponding members and owners (username and/or email) are fetched. @@ -1511,7 +1610,7 @@ def _is_site_admin(user): return list(access_control), list(site_admins_access_control) - def _dls_enabled(self): + def _dls_enabled(self) -> bool: if self._features is None: return False @@ -1520,10 +1619,14 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - def access_control_query(self, access_control): + def access_control_query( + self, access_control: List[str] + ) -> Dict[str, Dict[str, Dict[str, Union[Dict[str, List[str]], str]]]]: return es_access_control_query(access_control) - async def _user_access_control_doc(self, user): + async def _user_access_control_doc( + self, user: Dict[str, Optional[str]] + ) -> Optional[Dict[str, Any]]: """Constructs a user access control document, which will be synced to the corresponding access control index. The `_id` of the user access control document will either be the username (can also be the email sometimes) or the email itself. Note: the `_id` field won't be prefixed with the corresponding identity prefix ("user" or "email"). @@ -1599,7 +1702,7 @@ async def _user_access_control_doc(self, user): "created_at": created_at, } | self.access_control_query(access_control) - async def get_access_control(self): + async def get_access_control(self) -> None: """Yields an access control document for every user of a site. Note: this method will cache users and emails it has already and skip the ingestion for those. 
@@ -1832,7 +1935,7 @@ async def get_docs(self, filtering=None): ) yield site_page, None - async def get_docs_incrementally(self, sync_cursor, filtering=None): + async def get_docs_incrementally(self, sync_cursor: None, filtering: None = None): self._sync_cursor = sync_cursor timestamp = iso_zulu() @@ -1959,7 +2062,7 @@ async def site_collections(self): yield site_collection - async def sites(self, hostname, collections, check_timestamp=False): + async def sites(self, hostname, collections, check_timestamp: bool = False): async for site in self.client.sites( hostname, collections, @@ -1975,7 +2078,7 @@ async def sites(self, hostname, collections, check_timestamp=False): yield site - async def site_drives(self, site, check_timestamp=False): + async def site_drives(self, site, check_timestamp: bool = False): async for site_drive in self.client.site_drives(site["id"]): if not check_timestamp or ( check_timestamp @@ -1987,8 +2090,16 @@ async def site_drives(self, site, check_timestamp=False): yield site_drive async def _with_drive_item_permissions( - self, drive_item, drive_item_permissions, site_web_url - ): + self, + drive_item: Dict[str, Union[str, int]], + drive_item_permissions: List[ + Dict[ + str, + Union[Dict[str, Dict[str, str]], List[Dict[str, Dict[str, str]]], str], + ] + ], + site_web_url: str, + ) -> Dict[str, Union[List[str], int, str]]: """Decorates a drive item with its permissions. Args: @@ -2108,7 +2219,7 @@ async def site_list_items( site_list_id, site_list_name, site_access_control, - check_timestamp=False, + check_timestamp: bool = False, ): site_id = site.get("id") site_web_url = site.get("webUrl") @@ -2216,7 +2327,9 @@ async def site_list_items( yield list_item, None - async def site_lists(self, site, site_access_control, check_timestamp=False): + async def site_lists( + self, site, site_access_control, check_timestamp: bool = False + ): async for site_list in self.client.site_lists(site["id"]): if not check_timestamp or ( check_timestamp @@ -2268,7 +2381,9 @@ async def site_lists(self, site, site_access_control, check_timestamp=False): yield site_list - async def _get_access_control_from_role_assignment(self, role_assignment): + async def _get_access_control_from_role_assignment( + self, role_assignment: Dict[str, Any] + ) -> List[Union[str, Any]]: """Extracts access control from a role assignment. 
Args: @@ -2335,7 +2450,9 @@ def _has_limited_access(role_assignment): return access_control - async def site_pages(self, site, site_access_control, check_timestamp=False): + async def site_pages( + self, site, site_access_control, check_timestamp: bool = False + ): site_id = site["id"] url = site["webUrl"] async for site_page in self.client.site_pages(url): @@ -2403,7 +2520,7 @@ async def site_pages(self, site, site_access_control, check_timestamp=False): yield site_page - def init_sync_cursor(self): + def init_sync_cursor(self) -> Dict[str, str]: if not self._sync_cursor: self._sync_cursor = { CURSOR_SITE_DRIVE_KEY: {}, @@ -2412,24 +2529,30 @@ def init_sync_cursor(self): return self._sync_cursor - def update_drive_delta_link(self, drive_id, link): + def update_drive_delta_link(self, drive_id: str, link: str) -> None: if not link: return self._sync_cursor[CURSOR_SITE_DRIVE_KEY][drive_id] = link - def get_drive_delta_link(self, drive_id): + def get_drive_delta_link(self, drive_id: str) -> str: return nested_get_from_dict( self._sync_cursor, [CURSOR_SITE_DRIVE_KEY, drive_id] ) - def drive_item_operation(self, item): + def drive_item_operation( + self, item: Dict[str, Optional[Union[str, Dict[str, str]]]] + ) -> str: if "deleted" in item: return OP_DELETE else: return OP_INDEX - def download_function(self, drive_item, max_drive_item_age): + def download_function( + self, + drive_item: Dict[str, Optional[Union[str, Dict[str, str], int, List[str]]]], + max_drive_item_age: Optional[int], + ) -> None: if "deleted" in drive_item: # deleted drive items do not contain `name` property in the payload # so drive_item['id'] is used @@ -2448,7 +2571,7 @@ def download_function(self, drive_item, max_drive_item_age): f"Not downloading file {drive_item['name']}: field \"@microsoft.graph.downloadUrl\" is missing" ) return None - + # assert isinstance(drive_item["name"], str) if not self.is_supported_format(drive_item["name"]): self._logger.debug( f"Not downloading file {drive_item['name']}: file type is not supported" @@ -2461,10 +2584,12 @@ def download_function(self, drive_item, max_drive_item_age): ) return None + # assert isinstance(drive_item["lastModifiedDateTime"], str) modified_date = datetime.strptime( drive_item["lastModifiedDateTime"], TIMESTAMP_FORMAT ) + # assert isinstance(drive_item["size"], int) if max_drive_item_age and modified_date < datetime.utcnow() - timedelta( days=max_drive_item_age ): @@ -2486,7 +2611,9 @@ def download_function(self, drive_item, max_drive_item_age): drive_item["_original_filename"] = drive_item.get("name", "") return partial(self.get_drive_item_content, drive_item) - async def get_attachment_content(self, attachment, timestamp=None, doit=False): + async def get_attachment_content( + self, attachment: Dict[str, str], timestamp: None = None, doit: bool = False + ) -> Optional[Dict[str, Any]]: if not doit: return @@ -2528,7 +2655,12 @@ async def get_attachment_content(self, attachment, timestamp=None, doit=False): return doc - async def get_drive_item_content(self, drive_item, timestamp=None, doit=False): + async def get_drive_item_content( + self, + drive_item: Dict[str, Union[str, int, datetime, Dict[str, str]]], + timestamp: None = None, + doit: bool = False, + ) -> Optional[Dict[str, Any]]: document_size = int(drive_item["size"]) if not (doit and document_size): @@ -2562,7 +2694,9 @@ async def get_drive_item_content(self, drive_item, timestamp=None, doit=False): return doc - async def _download_content(self, download_func, original_filename): + async def 
_download_content( + self, download_func: partial, original_filename: str + ) -> Generator[Future, None, Union[Tuple[str, None], Tuple[None, str]]]: attachment = None body = None source_file_name = "" @@ -2601,18 +2735,18 @@ async def _download_content(self, download_func, original_filename): return attachment, body - async def ping(self): + async def ping(self) -> None: pass - async def close(self): + async def close(self) -> None: await self.client.close() if self.extraction_service is not None: await self.extraction_service._end_session() - def advanced_rules_validators(self): + def advanced_rules_validators(self) -> List[SharepointOnlineAdvancedRulesValidator]: return [SharepointOnlineAdvancedRulesValidator()] - def is_supported_format(self, filename): + def is_supported_format(self, filename: str) -> bool: if "." not in filename: return False @@ -2622,7 +2756,9 @@ def is_supported_format(self, filename): return False - async def _access_control_for_member(self, member): + async def _access_control_for_member( + self, member: Dict[str, Optional[str]] + ) -> List[str]: """ Helper function for converting a generic "member" into an access control list. "Member" here is loose, and intended to work with multiple SPO API responses. @@ -2660,7 +2796,9 @@ async def _access_control_for_member(self, member): else: return self._access_control_for_user(member) - def _access_control_for_user(self, user): + def _access_control_for_user( + self, user: Dict[str, Optional[str]] + ) -> List[Optional[str]]: user_access_control = [] user_principal_name = user.get( @@ -2686,7 +2824,7 @@ def _access_control_for_user(self, user): return user_access_control - async def _access_control_for_group_id(self, group_id): + async def _access_control_for_group_id(self, group_id: str) -> List[Optional[str]]: def is_group_owners_reference(potential_group_id): """ Some group ids aren't actually group IDs, but are references to the _owners_ of a group. 
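The sharepoint_online.py hunk above annotates retryable_aiohttp_call(retries: int) -> Callable, but most of the decorator body falls outside the hunk. As a reading aid only, here is a minimal, self-contained sketch (the name retry_with_count and all behaviour are assumptions, not the connector's actual implementation) of a decorator factory that re-invokes an async callable and passes it an increasing retry_count, in the spirit of how _get(..., retry_count=0) and _post(..., retry_count=0) are wrapped:

    import asyncio
    from typing import Any, Callable

    def retry_with_count(retries: int) -> Callable:
        # Hypothetical sketch: retry an async callable up to `retries` times,
        # handing the current attempt number to the callee so it can decide
        # how long to back off before the next call.
        def wrapper(func: Callable) -> Callable:
            async def wrapped(*args: Any, **kwargs: Any) -> Any:
                attempt = 1
                while True:
                    try:
                        return await func(*args, **kwargs, retry_count=attempt)
                    except Exception:
                        if attempt >= retries:
                            raise
                        attempt += 1
            return wrapped
        return wrapper

    @retry_with_count(retries=3)
    async def flaky_fetch(url: str, retry_count: int = 0) -> str:
        # Fails on the first attempt, succeeds on the second.
        if retry_count < 2:
            raise ConnectionError("transient failure")
        return f"fetched {url} on attempt {retry_count}"

    print(asyncio.run(flaky_fetch("https://example.invalid/item")))

In the real client the wrapped methods are also async context managers, and the sleep/backoff behaviour lives in _handle_client_response_error, _handle_client_payload_error and _compute_retry_after rather than in the decorator itself, so treat this purely as an illustration of the retry_count plumbing.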
diff --git a/connectors/sources/sharepoint_server.py b/connectors/sources/sharepoint_server.py index 10fd8cf0f..e5dfc2f20 100644 --- a/connectors/sources/sharepoint_server.py +++ b/connectors/sources/sharepoint_server.py @@ -7,7 +7,9 @@ import os import re +from _asyncio import Future, Task from functools import partial +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union from urllib.parse import quote import httpx @@ -19,7 +21,12 @@ prefix_identity, ) from connectors.logger import logger -from connectors.source import CHUNK_SIZE, BaseDataSource, ConfigurableFieldValueError +from connectors.source import ( + CHUNK_SIZE, + BaseDataSource, + ConfigurableFieldValueError, + DataSourceConfiguration, +) from connectors.utils import ( TIKA_SUPPORTED_FILETYPES, CancellableSleeps, @@ -62,7 +69,7 @@ EDITOR = 6 REVIEWER = 7 SYSTEM = 0xFF -VIEW_ROLE_TYPES = [ +VIEW_ROLE_TYPES: List[int] = [ READER, CONTRIBUTOR, WEB_DESIGNER, @@ -73,7 +80,7 @@ ] -URLS = { +URLS: Dict[str, str] = { PING: "{site_url}/_api/web/webs", SITE: "{parent_site_url}/_api/web", SITES: "{parent_site_url}/_api/web/webs?$skip={skip}&$top={top}", @@ -90,7 +97,7 @@ UNIQUE_ROLES_FOR_ITEM: "{parent_site_url}/_api/lists/GetByTitle('{site_list_name}')/items({list_item_id})/HasUniqueRoleAssignments", ROLES_BY_TITLE_FOR_ITEM: "{parent_site_url}/_api/lists/GetByTitle('{site_list_name}')/items({list_item_id})/roleassignments?$expand=Member/users,RoleDefinitionBindings&$skip={skip}&$top={top}", } -SCHEMA = { +SCHEMA: Dict[str, Dict[str, str]] = { SITES: { "title": "Title", "url": "Url", @@ -128,7 +135,7 @@ class SharepointServerClient: """SharePoint client to handle API calls made to SharePoint""" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self.configuration = configuration self._logger = logger @@ -151,10 +158,10 @@ def __init__(self, configuration): else: self.ssl_ctx = False - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def _get_session(self): + def _get_session(self) -> httpx.AsyncClient: """Generate base client session using configuration fields Returns: @@ -190,12 +197,12 @@ def _get_session(self): def format_url( self, - site_url, - relative_url, - list_item_id=None, - content_type_id=None, - is_list_item_has_attachment=False, - ): + site_url: str, + relative_url: str, + list_item_id: Optional[str] = None, + content_type_id: Optional[str] = None, + is_list_item_has_attachment: bool = False, + ) -> str: if is_list_item_has_attachment: return ( site_url @@ -211,7 +218,7 @@ def format_url( else: return site_url + quote(relative_url) - async def close_session(self): + async def close_session(self) -> None: """Closes unclosed client session""" self._sleeps.cancel() if self.session is None: @@ -219,7 +226,9 @@ async def close_session(self): await self.session.aclose() # pyright: ignore self.session = None - async def _api_call(self, url_name, url="", **url_kwargs): + async def _api_call( + self, url_name: str, url: str = "", **url_kwargs + ) -> Iterator[Task]: """Make an API call to the SharePoint Server Args: @@ -269,7 +278,7 @@ async def _api_call(self, url_name, url="", **url_kwargs): await self._sleeps.sleep(RETRY_INTERVAL**retry) async def _fetch_data_with_next_url( - self, site_url, list_id, param_name, selected_field="" + self, site_url, list_id, param_name, selected_field: str = "" ): """Invokes a GET call to the SharePoint Server for calling list and drive 
item API. @@ -333,7 +342,7 @@ async def _fetch_data_with_query(self, site_url, param_name, **kwargs): if len(response_result) < TOP: break - async def get_sites(self, site_url): + async def get_sites(self, site_url: str) -> None: """Get sites from SharePoint Server Args: @@ -380,7 +389,9 @@ async def get_lists(self, site_url): for site_list in list_data: yield site_list - async def get_attachment(self, site_url, file_relative_url): + async def get_attachment( + self, site_url: str, file_relative_url: str + ) -> Dict[str, str]: """Execute the call for fetching attachment metadata Args: @@ -398,7 +409,7 @@ async def get_attachment(self, site_url, file_relative_url): ) ) - def verify_filename_for_extraction(self, filename, relative_url): + def verify_filename_for_extraction(self, filename: str, relative_url: str) -> None: attachment_extension = list(os.path.splitext(filename)) if "" in attachment_extension: attachment_extension.remove("") @@ -506,7 +517,7 @@ async def get_drive_items(self, list_id, site_url, **kwargs): yield result, file_relative_url - async def ping(self): + async def ping(self) -> None: """Executes the ping call in async manner""" site_url = "" if len(self.site_collections) > 0: @@ -552,7 +563,9 @@ async def site_role_assignments_using_title(self, site_url, site_list_name): for role in roles: yield role - async def site_list_has_unique_role_assignments(self, site_list_name, site_url): + async def site_list_has_unique_role_assignments( + self, site_list_name: str, site_url: str + ) -> bool: role = await anext( self._api_call( url_name=UNIQUE_ROLES, @@ -564,8 +577,8 @@ async def site_list_has_unique_role_assignments(self, site_list_name, site_url): return role.get("value", False) async def site_list_item_has_unique_role_assignments( - self, site_url, site_list_name, list_item_id - ): + self, site_url: str, site_list_name: str, list_item_id: int + ) -> bool: role = await anext( self._api_call( url_name=UNIQUE_ROLES_FOR_ITEM, @@ -589,7 +602,7 @@ async def site_list_item_role_assignments( for role in roles: yield role - def fix_relative_url(self, site_relative_url, item_relative_url): + def fix_relative_url(self, site_relative_url: str, item_relative_url: str) -> str: if item_relative_url is not None: item_relative_url = ( item_relative_url @@ -607,7 +620,7 @@ class SharepointServerDataSource(BaseDataSource): incremental_sync_enabled = True dls_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Setup the connection to the SharePoint Args: @@ -617,35 +630,45 @@ def __init__(self, configuration): self.sharepoint_client = SharepointServerClient(configuration=configuration) self.invalid_collections = [] - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.sharepoint_client.set_logger(self._logger) - def _prefix_user(self, user): + def _prefix_user(self, user: str) -> Optional[str]: return prefix_identity("user", user) - def _prefix_user_id(self, user_id): + def _prefix_user_id(self, user_id: int) -> Optional[str]: return prefix_identity("user_id", user_id) - def _prefix_email(self, email): + def _prefix_email(self, email: str) -> Optional[str]: return prefix_identity("email", email) - def _prefix_login_name(self, name): + def _prefix_login_name(self, name: str) -> Optional[str]: return prefix_identity("login_name", name) - def _prefix_group(self, group): + def _prefix_group(self, group: str) -> Optional[str]: return prefix_identity("group", group) - def _prefix_group_name(self, 
group_name): + def _prefix_group_name(self, group_name: str) -> Optional[str]: return prefix_identity("group_name", group_name) - def _get_login_name(self, raw_login_name): + def _get_login_name(self, raw_login_name: str) -> str: if raw_login_name: parts = raw_login_name.split("|") return parts[-1] return None @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], int, str]], + Dict[str, Union[List[Dict[str, Union[bool, str]]], int, str]], + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: """Get the default configuration for SharePoint Returns: @@ -745,7 +768,7 @@ def get_default_configuration(cls): }, } - def _dls_enabled(self): + def _dls_enabled(self) -> bool: if ( self._features is None or not self._features.document_level_security_enabled() @@ -754,7 +777,11 @@ def _dls_enabled(self): return self.configuration["use_document_level_security"] - def _decorate_with_access_control(self, document, access_control): + def _decorate_with_access_control( + self, + document: Dict[str, Optional[Union[str, int]]], + access_control: List[Union[str, Any]], + ) -> Dict[str, Optional[Union[List[str], int, str]]]: if self._dls_enabled(): document[ACCESS_CONTROL] = list( set(document.get(ACCESS_CONTROL, []) + access_control) @@ -762,10 +789,14 @@ def _decorate_with_access_control(self, document, access_control): return document - def access_control_query(self, access_control): + def access_control_query( + self, access_control: List[str] + ) -> Dict[str, Dict[str, Dict[str, Union[Dict[str, List[str]], str]]]]: return es_access_control_query(access_control) - async def _user_access_control_doc(self, user): + async def _user_access_control_doc( + self, user: Dict[str, Union[str, int, bool, Dict[str, str]]] + ) -> Dict[str, Any]: login_name = user.get("LoginName") prefixed_mail = self._prefix_email(user.get("Email")) prefixed_title = self._prefix_user(user.get("Title")) @@ -790,7 +821,9 @@ async def _user_access_control_doc(self, user): "created_at": iso_utc(), } | self.access_control_query(access_control) - async def _access_control_for_member(self, member): + async def _access_control_for_member( + self, member: Dict[str, Union[str, int, bool, Dict[str, str]]] + ) -> List[Optional[str]]: principal_type = member.get("PrincipalType") is_group = principal_type != IS_USER if principal_type else False user_access_control = [] @@ -823,7 +856,7 @@ async def _access_control_for_member(self, member): user_access_control.append(self._prefix_email(email)) return user_access_control - async def get_access_control(self): + async def get_access_control(self) -> None: """Yields an access control document for every user of a site. Note: this method will cache users and emails it has already and skip the ingestion for those. @@ -872,7 +905,27 @@ async def process_user(user): if user_doc: yield user_doc - async def _get_access_control_from_role_assignment(self, role_assignment): + async def _get_access_control_from_role_assignment( + self, + role_assignment: Dict[ + str, + Union[ + Dict[ + str, + Union[ + str, + int, + List[Dict[str, Union[str, int, bool, Dict[str, str]]]], + bool, + Dict[str, str], + ], + ], + List[Dict[str, Union[Dict[str, str], int, str]]], + int, + Dict[str, Union[str, int, bool, Dict[str, str]]], + ], + ], + ) -> List[Union[str, Any]]: """Extracts access control from a role assignment. 
Args: @@ -933,7 +986,9 @@ def _has_limited_access(role_assignment): return access_control - async def _site_access_control(self, site_url): + async def _site_access_control( + self, site_url: str + ) -> Union[Tuple[List[Any], List[str]], Tuple[List[Any], List[Any]]]: self._logger.debug(f"Looking at site with url {site_url}") if not self._dls_enabled(): return [], [] @@ -967,11 +1022,11 @@ def _is_site_admin(user): return list(access_control), list(site_admins_access_control) - async def close(self): + async def close(self) -> None: """Closes unclosed client session""" await self.sharepoint_client.close_session() - async def _remote_validation(self): + async def _remote_validation(self) -> None: """Validate configured collections Raises: ConfigurableFieldValueError: Unavailable services error. @@ -994,13 +1049,13 @@ async def _remote_validation(self): ) raise ConfigurableFieldValueError(msg) - async def validate_config(self): + async def validate_config(self) -> None: """Validates whether user input is empty or not for configuration fields Also validate, if user configured collections are available in SharePoint.""" await super().validate_config() await self._remote_validation() - async def ping(self): + async def ping(self) -> None: """Verify the connection with SharePoint""" try: await self.sharepoint_client.ping() @@ -1015,10 +1070,10 @@ async def ping(self): def map_document_with_schema( self, - document, - item, - document_type, - ): + document: Dict[str, Union[int, str]], + item: Dict[str, Any], + document_type: str, + ) -> None: """Prepare key mappings for documents Args: @@ -1034,13 +1089,13 @@ def map_document_with_schema( def format_lists( self, - site_url, - list_relative_url, - item, - document_type, - admin_access_control, - site_list_access_control, - ): + site_url: str, + list_relative_url: str, + item: Dict[str, Union[Dict[str, str], int, str]], + document_type: str, + admin_access_control: List[Any], + site_list_access_control: List[Union[str, Any]], + ) -> Dict[str, Union[List[str], str]]: """Prepare key mappings for list Args: @@ -1068,7 +1123,9 @@ def format_lists( admin_access_control.extend(site_list_access_control) return self._decorate_with_access_control(document, admin_access_control) - def format_sites(self, item): + def format_sites( + self, item: Dict[str, Optional[Union[str, int, Dict[str, Union[int, str]]]]] + ) -> Dict[str, Any]: """Prepare key mappings for site Args: @@ -1089,10 +1146,19 @@ def format_sites(self, item): def format_drive_item( self, - site_url, - server_relative_url, - item, - ): + site_url: str, + server_relative_url: Optional[str], + item: Dict[ + str, + Union[ + Dict[str, str], + int, + str, + Dict[str, Union[int, str]], + Dict[str, Dict[Any, Any]], + ], + ], + ) -> Dict[str, str]: """Prepare key mappings for drive items Args: @@ -1132,10 +1198,10 @@ def format_drive_item( def format_list_item( self, - item, - site_url=None, - server_relative_url=None, - ): + item: Dict[str, Union[str, int, bool, Dict[str, str]]], + site_url: Optional[str] = None, + server_relative_url: Optional[str] = None, + ) -> Dict[str, Any]: """Prepare key mappings for list items Args: @@ -1168,7 +1234,7 @@ def format_list_item( return document - async def get_lists(self, site_url, site_access_control): + async def get_lists(self, site_url: str, site_access_control: List[Any]) -> None: async for site_list in self.sharepoint_client.get_lists(site_url=site_url): has_unique_role_assignments = False site_list_access_control = [] @@ -1202,8 +1268,12 @@ async def 
get_lists(self, site_url, site_access_control): yield site_list, site_list_access_control async def fetch_list_item_permission( - self, site_url, site_list_name, list_item_id, site_access_control - ): + self, + site_url: str, + site_list_name: str, + list_item_id: int, + site_access_control: List[Any], + ) -> List[Union[str, Any]]: list_item_access_control = [] has_unique_role_assignments = False @@ -1234,7 +1304,7 @@ async def fetch_list_item_permission( list_item_access_control.extend(site_access_control) return list_item_access_control - async def get_docs(self, filtering=None): + async def get_docs(self, filtering: None = None) -> None: """Executes the logic to fetch SharePoint objects in an async manner. Yields: @@ -1368,8 +1438,13 @@ async def get_docs(self, filtering=None): ) async def get_content( - self, document, file_relative_url, site_url, timestamp=None, doit=False - ): + self, + document: Dict[str, Union[int, str]], + file_relative_url: str, + site_url: str, + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, Union[int, str]]]]: """Get content of list items and drive items Args: @@ -1414,8 +1489,12 @@ async def get_content( ) async def get_site_pages_content( - self, document, list_response, timestamp=None, doit=False - ): + self, + document: Dict[str, Union[int, str]], + list_response: Dict[str, Optional[str]], + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, Union[int, str]]]]: """Get content of site pages for SharePoint Args: diff --git a/connectors/sources/slack.py b/connectors/sources/slack.py index c2a741ce8..0ac47478a 100644 --- a/connectors/sources/slack.py +++ b/connectors/sources/slack.py @@ -6,14 +6,16 @@ import re import time +from _asyncio import Task from contextlib import asynccontextmanager from datetime import datetime +from typing import Any, Dict, Generator, List, Optional, Union import aiohttp from aiohttp.client_exceptions import ClientResponseError -from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.logger import ExtraLogger, logger +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import CancellableSleeps, dict_slice, retryable BASE_URL = "https://slack.com/api" @@ -23,7 +25,7 @@ NEXT_CURSOR = "next_cursor" DEFAULT_RETRY_SECONDS = 3 PAGE_SIZE = 200 -USER_ID_PATTERN = re.compile(r"<@([A-Z0-9]+)>") +USER_ID_PATTERN: re.Pattern[str] = re.compile(r"<@([A-Z0-9]+)>") # TODO list # Nice to haves: @@ -42,7 +44,7 @@ class ThrottledError(Exception): class SlackAPIError(Exception): """Internal exception class that wraps all non-ok responses from Slack""" - def __init__(self, reason): + def __init__(self, reason) -> None: self.__reason = reason @property @@ -51,7 +53,10 @@ def reason(self): class SlackClient: - def __init__(self, configuration): + def __init__( + self, + configuration: Union[Dict[str, Union[bool, str, int]], DataSourceConfiguration], + ) -> None: self.token = configuration["token"] self._http_session = aiohttp.ClientSession( headers=self._headers(), @@ -61,10 +66,10 @@ def __init__(self, configuration): self._logger = logger self._sleeps = CancellableSleeps() - def set_logger(self, logger_): + def set_logger(self, logger_: ExtraLogger) -> None: self._logger = logger_ - async def ping(self): + async def ping(self) -> bool: url = f"{BASE_URL}/auth.test" try: await self._get_json(url) @@ -72,7 +77,7 @@ async def ping(self): except SlackAPIError: return False - async 
def close(self): + async def close(self) -> None: await self._http_session.close() self._sleeps.cancel() @@ -148,13 +153,27 @@ async def list_users(self, cursor=None): if not cursor: break - def _add_cursor(self, url, cursor): + def _add_cursor(self, url: str, cursor: str) -> str: return f"{url}&{CURSOR}={cursor}" - def _get_next_cursor(self, response): + def _get_next_cursor( + self, + response: Dict[ + str, + Union[ + bool, + List[Dict[str, Union[bool, str]]], + Dict[str, str], + List[Dict[str, str]], + List[Union[Dict[str, Union[int, str]], Dict[str, str]]], + ], + ], + ) -> Optional[str]: return response.get(RESPONSE_METADATA, {}).get(NEXT_CURSOR) - async def _get_json(self, absolute_url): + async def _get_json( + self, absolute_url: str + ) -> Generator[Task, None, Dict[str, Any]]: self._logger.debug(f"Fetching url: {absolute_url}") async with self._call_api(absolute_url) as resp: json_content = await resp.json() @@ -166,7 +185,7 @@ async def _get_json(self, absolute_url): @asynccontextmanager @retryable(retries=3) - async def _call_api(self, absolute_url): + async def _call_api(self, absolute_url: str): try: async with self._http_session.get( absolute_url, @@ -176,7 +195,7 @@ async def _call_api(self, absolute_url): except ClientResponseError as e: await self._handle_client_response_error(e) - async def _handle_client_response_error(self, e): + async def _handle_client_response_error(self, e: ClientResponseError): if e.status == 429: response_headers = e.headers or {} if "Retry-After" in response_headers: @@ -197,7 +216,7 @@ async def _handle_client_response_error(self, e): else: raise - def _headers(self): + def _headers(self) -> Dict[str, str]: return { "Authorization": f"Bearer {self.token}", "accept": "application/json", @@ -208,7 +227,7 @@ class SlackDataSource(BaseDataSource): name = "Slack" service_type = "slack" - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: """Set up the connection to the Slack. Args: @@ -220,11 +239,11 @@ def __init__(self, configuration): self.n_days_to_fetch = configuration["fetch_last_n_days"] self.usernames = {} - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.slack_client.set_logger(self._logger) @classmethod - def get_default_configuration(cls): + def get_default_configuration(cls) -> Dict[str, Dict[str, Union[int, str]]]: return { "token": { "label": "Authentication Token", @@ -258,12 +277,12 @@ def get_default_configuration(cls): }, } - async def ping(self): + async def ping(self) -> None: if not await self.slack_client.ping(): msg = "Could not connect to Slack" raise Exception(msg) - async def close(self): + async def close(self) -> None: await self.slack_client.close() async def get_docs(self, filtering=None): @@ -300,7 +319,7 @@ async def channels_and_messages(self): ): yield self.remap_message(message, channel) - def get_username(self, user): + def get_username(self, user: Dict[str, str]) -> str: """ Given a user record from slack, try to find a good username for it. This is hard, because no one property is reliably present and optimal. @@ -322,7 +341,7 @@ def get_username(self, user): "id" ] # Some Users do not have any names (like Bots). 
For these, we fall back on ID - def remap_user(self, user): + def remap_user(self, user: Dict[str, Any]) -> Dict[str, Any]: user_profile = user.get("profile", {}) return { "_id": user["id"], @@ -346,7 +365,7 @@ def remap_user(self, user): ), ) - def remap_channel(self, channel): + def remap_channel(self, channel: Dict[str, Union[bool, str]]) -> Dict[str, Any]: topic = channel.get("topic", {}) topic_creator_id = topic.get("creator") purpose = channel.get("purpose", {}) @@ -373,7 +392,11 @@ def remap_channel(self, channel): ), ) - def remap_message(self, message, channel): + def remap_message( + self, + message: Dict[str, Union[int, str]], + channel: Dict[str, Union[bool, str, int]], + ) -> Dict[str, Any]: user_id = message.get("user", message.get("bot_id")) def convert_usernames( diff --git a/connectors/sources/zoom.py b/connectors/sources/zoom.py index a63ff1ab2..bc7e77b5d 100644 --- a/connectors/sources/zoom.py +++ b/connectors/sources/zoom.py @@ -6,15 +6,20 @@ """Zoom source module responsible to fetch documents from Zoom.""" import os +from _asyncio import Future, Task from contextlib import asynccontextmanager from datetime import datetime, timedelta from functools import cached_property, partial +from logging import Logger +from typing import Any, Dict, Generator, List, Optional, Tuple, Union import aiohttp +from aiohttp.client import ClientSession from aiohttp.client_exceptions import ClientResponseError +from freezegun.api import FakeDatetime -from connectors.logger import logger -from connectors.source import BaseDataSource +from connectors.logger import ExtraLogger, logger +from connectors.source import BaseDataSource, DataSourceConfiguration from connectors.utils import ( CacheWithTimeout, CancellableSleeps, @@ -30,12 +35,12 @@ MEETING_PAGE_SIZE = 300 if "OVERRIDE_URL" in os.environ: - override_url = os.environ["OVERRIDE_URL"] - BASE_URL = override_url - BASE_AUTH_URL = override_url + override_url: str = os.environ["OVERRIDE_URL"] + BASE_URL: str = override_url + BASE_AUTH_URL: str = override_url else: - BASE_URL = "https://api.zoom.us/v2" - BASE_AUTH_URL = "https://zoom.us" + BASE_URL: str = "https://api.zoom.us/v2" + BASE_AUTH_URL: str = "https://zoom.us" AUTH = ( "{base_auth_url}/oauth/token?grant_type=account_credentials&account_id={account_id}" @@ -51,11 +56,11 @@ } -def format_recording_date(date): +def format_recording_date(date: FakeDatetime) -> str: return date.strftime("%Y-%m-%d") -def format_chat_date(date): +def format_chat_date(date: FakeDatetime) -> str: return date.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -68,7 +73,12 @@ class ZoomResourceNotFound(Exception): class ZoomAPIToken: - def __init__(self, http_session, configuration, logger_): + def __init__( + self, + http_session: ClientSession, + configuration: DataSourceConfiguration, + logger_: Logger, + ) -> None: self._http_session = http_session self._token_cache = CacheWithTimeout() self._logger = logger_ @@ -76,10 +86,10 @@ def __init__(self, http_session, configuration, logger_): self.client_id = configuration["client_id"] self.client_secret = configuration["client_secret"] - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - async def get(self, is_cache=True): + async def get(self, is_cache: bool = True) -> Generator[Task, None, str]: cached_value = self._token_cache.get_value() if is_cache else None if cached_value: @@ -96,7 +106,7 @@ async def get(self, is_cache=True): interval=RETRY_INTERVAL, strategy=RetryStrategy.EXPONENTIAL_BACKOFF, ) - async def 
_fetch_token(self): + async def _fetch_token(self) -> Tuple[str, int]: self._logger.debug("Generating access token.") url = AUTH.format(base_auth_url=BASE_AUTH_URL, account_id=self.account_id) content = f"{self.client_id}:{self.client_secret}" @@ -120,17 +130,19 @@ async def _fetch_token(self): class ZoomAPISession: - def __init__(self, http_session, api_token, logger_): + def __init__( + self, http_session: ClientSession, api_token: ZoomAPIToken, logger_: ExtraLogger + ) -> None: self._sleeps = CancellableSleeps() self._logger = logger_ self._http_session = http_session self._api_token = api_token - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ - def close(self): + def close(self) -> None: self._sleeps.cancel() @asynccontextmanager @@ -140,7 +152,7 @@ def close(self): strategy=RetryStrategy.EXPONENTIAL_BACKOFF, skipped_exceptions=ZoomResourceNotFound, ) - async def _get(self, absolute_url): + async def _get(self, absolute_url: str): try: token = await self._api_token.get() headers = { @@ -164,7 +176,7 @@ async def _get(self, absolute_url): except Exception: raise - async def fetch(self, url): + async def fetch(self, url: str) -> Generator[Task, None, Any]: try: async with self._get(absolute_url=url) as response: return await response.json() @@ -173,7 +185,7 @@ async def fetch(self, url): f"Data for {url} is being skipped. Error: {exception}." ) - async def content(self, url): + async def content(self, url: str) -> Generator[Task, None, Optional[str]]: try: async with self._get(absolute_url=url) as response: return await response.text() @@ -182,7 +194,7 @@ async def content(self, url): f"Content for {url} is being skipped. Error: {exception}." ) - async def scroll(self, url): + async def scroll(self, url: str) -> None: scroll_url = url while True: @@ -200,7 +212,7 @@ async def scroll(self, url): class ZoomClient: - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: self._sleeps = CancellableSleeps() self._logger = logger @@ -218,12 +230,12 @@ def __init__(self, configuration): logger_=self._logger, ) - def set_logger(self, logger_): + def set_logger(self, logger_) -> None: self._logger = logger_ self.api_token.set_logger(self._logger) self.api_client.set_logger(self._logger) - async def close(self): + async def close(self) -> None: await self.http_session.close() self.api_client.close() @@ -244,7 +256,7 @@ async def get_meetings(self, user_id, meeting_type): for meeting in meetings.get("meetings", []) or []: yield meeting - async def get_past_meeting(self, meeting_id): + async def get_past_meeting(self, meeting_id: str) -> Dict[str, str]: url = APIS["PAST_MEETING"].format(base_url=BASE_URL, meeting_id=meeting_id) return await self.api_client.fetch(url=url) @@ -307,19 +319,28 @@ class ZoomDataSource(BaseDataSource): service_type = "zoom" incremental_sync_enabled = True - def __init__(self, configuration): + def __init__(self, configuration: DataSourceConfiguration) -> None: super().__init__(configuration=configuration) self.configuration = configuration - def _set_internal_logger(self): + def _set_internal_logger(self) -> None: self.client.set_logger(self._logger) @cached_property - def client(self): + def client(self) -> ZoomClient: return ZoomClient(configuration=self.configuration) @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, Union[int, str]]], int, str]], + Dict[str, 
Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ]: return { "account_id": { "label": "Account ID", @@ -364,11 +385,11 @@ def get_default_configuration(cls): }, } - async def validate_config(self): + async def validate_config(self) -> None: await super().validate_config() await self.client.api_token.get() - async def ping(self): + async def ping(self) -> None: try: await self.client.api_token.get() self._logger.debug("Successfully connected to Zoom.") @@ -376,10 +397,14 @@ async def ping(self): self._logger.debug("Error while connecting to Zoom.") raise - async def close(self): + async def close(self) -> None: await self.client.close() - def _format_doc(self, doc, doc_time): + def _format_doc( + self, + doc: Dict[str, Union[str, int, List[Dict[str, str]]]], + doc_time: Optional[str], + ) -> Dict[str, Optional[Union[str, List[Dict[str, str]], int]]]: doc = self.serialize(doc=doc) doc.update( { @@ -389,7 +414,12 @@ def _format_doc(self, doc, doc_time): ) return doc - async def get_content(self, chat_file, timestamp=None, doit=False): + async def get_content( + self, + chat_file: Dict[str, Union[int, str]], + timestamp: None = None, + doit: bool = False, + ) -> Generator[Future, None, Optional[Dict[str, str]]]: file_size = chat_file["file_size"] if not (doit and file_size > 0): return @@ -417,7 +447,9 @@ async def get_content(self, chat_file, timestamp=None, doit=False): ), ) - async def fetch_previous_meeting_details(self, meeting_id): + async def fetch_previous_meeting_details( + self, meeting_id: str + ) -> Dict[str, Union[str, List[Dict[str, str]]]]: previous_meeting = await self.client.get_past_meeting(meeting_id=meeting_id) if not previous_meeting: diff --git a/connectors/sync_job_runner.py b/connectors/sync_job_runner.py index 9e7015e5d..84dc162ef 100644 --- a/connectors/sync_job_runner.py +++ b/connectors/sync_job_runner.py @@ -5,12 +5,16 @@ # import asyncio import time +from _asyncio import Future +from typing import Any, Dict, Iterator, Optional, Union +from unittest.mock import Mock import elasticsearch from elasticsearch import ( AuthorizationException as ElasticAuthorizationException, ) +import connectors.protocol.connectors from connectors.config import DataSourceFrameworkConfig from connectors.es.client import License, with_concurrency_control from connectors.es.index import DocumentNotFoundError @@ -45,7 +49,7 @@ class SyncJobRunningError(Exception): class InsufficientESLicenseError(Exception): - def __init__(self, required_license, actual_license): + def __init__(self, required_license: License, actual_license: License) -> None: super().__init__( f"Minimum required Elasticsearch license: '{required_license.value}'. Actual license: '{actual_license.value}'." ) @@ -56,12 +60,12 @@ class SyncJobStartError(Exception): class ConnectorNotFoundError(Exception): - def __init__(self, connector_id): + def __init__(self, connector_id: str) -> None: super().__init__(f"Connector is not found for connector ID {connector_id}.") class ConnectorJobNotFoundError(Exception): - def __init__(self, job_id): + def __init__(self, job_id: str) -> None: super().__init__(f"Connector job is not found for job ID {job_id}.") @@ -70,7 +74,9 @@ class ConnectorJobCanceledError(Exception): class ConnectorJobNotRunningError(Exception): - def __init__(self, job_id, status): + def __init__( + self, job_id: str, status: connectors.protocol.connectors.JobStatus + ) -> None: super().__init__( f"Connector job (ID: {job_id}) is not running but in status of {status}." 
) @@ -96,12 +102,12 @@ class SyncJobRunner: def __init__( self, - source_klass, - sync_job, - connector, - es_config, - service_config, - ): + source_klass: Mock, + sync_job: Mock, + connector: Mock, + es_config: Dict[Any, Any], + service_config: Dict[Any, Any], + ) -> None: self.source_klass = source_klass self.data_provider = None self.sync_job = sync_job @@ -117,7 +123,7 @@ def __init__( "enable_operations_logging" ) - async def execute(self): + async def execute(self) -> None: if self.running: msg = f"Sync job {self.sync_job.id} is already running." raise SyncJobRunningError(msg) @@ -217,13 +223,17 @@ async def execute(self): if self.data_provider is not None: await self.data_provider.close() - def _data_source_framework_config(self): + def _data_source_framework_config(self) -> DataSourceFrameworkConfig: builder = DataSourceFrameworkConfig.Builder().with_max_file_size( self.service_config.get("max_file_download_size") ) return builder.build() - async def _execute_access_control_sync_job(self, job_type, bulk_options): + async def _execute_access_control_sync_job( + self, + job_type: connectors.protocol.connectors.JobType, + bulk_options: Dict[Any, Any], + ) -> None: if requires_platinum_license(self.sync_job, self.connector, self.source_klass): ( is_platinum_license_enabled, @@ -246,7 +256,9 @@ async def _execute_access_control_sync_job(self, job_type, bulk_options): enable_bulk_operations_logging=self._enable_bulk_operations_logging, ) - def _skip_unchanged_documents_enabled(self, job_type, data_provider): + def _skip_unchanged_documents_enabled( + self, job_type: connectors.protocol.connectors.JobType, data_provider: Mock + ) -> bool: """ Check if timestamp optimization is enabled for the current data source. Timestamp optimization can be enabled only for incremental jobs. 
@@ -265,7 +277,11 @@ def _skip_unchanged_documents_enabled(self, job_type, data_provider): is BaseDataSource.get_docs_incrementally ) - async def _execute_content_sync_job(self, job_type, bulk_options): + async def _execute_content_sync_job( + self, + job_type: connectors.protocol.connectors.JobType, + bulk_options: Dict[Any, Any], + ) -> None: if ( self.sync_job.job_type == JobType.INCREMENTAL and not self.connector.features.incremental_sync_enabled() @@ -300,7 +316,11 @@ async def _execute_content_sync_job(self, job_type, bulk_options): enable_bulk_operations_logging=self._enable_bulk_operations_logging, ) - async def _sync_done(self, sync_status, sync_error=None): + async def _sync_done( + self, + sync_status: connectors.protocol.connectors.JobStatus, + sync_error: Optional[Any] = None, + ) -> None: if self.job_reporting_task is not None and not self.job_reporting_task.done(): self.job_reporting_task.cancel() try: @@ -362,7 +382,7 @@ async def _sync_done(self, sync_status, sync_error=None): ) self.log_counters(ingestion_stats) - def log_counters(self, counters): + def log_counters(self, counters: Dict[str, int]) -> None: """ Logs out a dump of everything in "counters" @@ -381,7 +401,7 @@ def log_counters(self, counters): self.sync_job.log_info("----------------") @with_concurrency_control() - async def sync_starts(self): + async def sync_starts(self) -> None: if not await self.reload_connector(): msg = f"Couldn't reload connector {self.connector.id}" raise SyncJobStartError(msg) @@ -414,7 +434,7 @@ async def sync_starts(self): except Exception as e: raise SyncJobStartError from e - async def prepare_docs(self): + async def prepare_docs(self) -> None: self.sync_job.log_debug(f"Using pipeline {self.sync_job.pipeline}") async for doc, lazy_download, operation in self.generator(): @@ -446,7 +466,7 @@ async def prepare_docs(self): doc["_run_ml_inference"] = self.sync_job.pipeline["run_ml_inference"] yield doc, lazy_download, operation - async def generator(self): + async def generator(self) -> None: skip_unchanged_documents = self._skip_unchanged_documents_enabled( self.sync_job.job_type, self.data_provider ) @@ -478,7 +498,7 @@ async def generator(self): case _: raise UnsupportedJobType - async def update_ingestion_stats(self, interval): + async def update_ingestion_stats(self, interval: int) -> Iterator[Optional[Future]]: while True: await asyncio.sleep(interval) @@ -493,7 +513,7 @@ async def update_ingestion_stats(self, interval): } await self.sync_job.update_metadata(ingestion_stats=ingestion_stats) - async def check_job(self): + async def check_job(self) -> None: if not await self.reload_connector(): raise ConnectorNotFoundError(self.connector.id) @@ -506,7 +526,7 @@ async def check_job(self): if self.sync_job.status != JobStatus.IN_PROGRESS: raise ConnectorJobNotRunningError(self.sync_job.id, self.sync_job.status) - async def reload_sync_job(self): + async def reload_sync_job(self) -> bool: try: await self.sync_job.reload() return True @@ -514,7 +534,7 @@ async def reload_sync_job(self): self.sync_job.log_error("Couldn't reload sync job") return False - async def reload_connector(self): + async def reload_connector(self) -> bool: try: await self.connector.reload() return True @@ -522,7 +542,13 @@ async def reload_connector(self): self.connector.log_error("Couldn't reload connector") return False - def _content_extraction_enabled(self, sync_job_config, pipeline_config): + def _content_extraction_enabled( + self, + sync_job_config: Dict[str, Union[bool, str]], + pipeline_config: Union[ + 
connectors.protocol.connectors.Pipeline, Dict[str, bool], Dict[str, str] + ], + ) -> bool: if sync_job_config.get("use_text_extraction_service"): logger.debug( f"Binary content extraction via local extraction service is enabled for connector {self.connector.id} during sync job {self.sync_job.id}." diff --git a/connectors/utils.py b/connectors/utils.py index 59453e0a0..8aec0de93 100644 --- a/connectors/utils.py +++ b/connectors/utils.py @@ -18,19 +18,35 @@ import subprocess # noqa S404 import time import urllib.parse +from _asyncio import Future, Task from copy import deepcopy from datetime import datetime, timedelta, timezone from enum import Enum from time import strftime +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + List, + Optional, + Tuple, + Union, +) +from unittest.mock import AsyncMock, Mock import dateutil.parser as parser import pytz import tzcron from base64io import Base64IO from bs4 import BeautifulSoup +from freezegun.api import FakeDatetime from pympler import asizeof +from typing_extensions import Buffer -from connectors.logger import logger +from connectors.exceptions import DocumentIngestionError +from connectors.logger import _TracedAsyncGenerator, logger ACCESS_CONTROL_INDEX_PREFIX = ".search-acl-filter-" DEFAULT_CHUNK_SIZE = 500 @@ -86,17 +102,17 @@ class Format(Enum): SHORT = "short" -def parse_datetime_string(datetime): +def parse_datetime_string(datetime: str) -> datetime: return parser.parse(datetime) -def iso_utc(when=None): +def iso_utc(when: Optional[datetime] = None) -> str: if when is None: when = datetime.now(timezone.utc) return when.isoformat() -def with_utc_tz(ts): +def with_utc_tz(ts: datetime) -> datetime: """Ensure the timestmap has a timezone of UTC.""" if ts.tzinfo is None: return ts.replace(tzinfo=timezone.utc) @@ -104,17 +120,17 @@ def with_utc_tz(ts): return ts.astimezone(timezone.utc) -def iso_zulu(): +def iso_zulu() -> str: """Returns the current time in ISO Zulu format""" return datetime.now(timezone.utc).strftime(ISO_ZULU_TIMESTAMP_FORMAT) -def epoch_timestamp_zulu(): +def epoch_timestamp_zulu() -> str: """Returns the timestamp of the start of the epoch, in ISO Zulu format""" return strftime(ISO_ZULU_TIMESTAMP_FORMAT, time.gmtime(0)) -def next_run(quartz_definition, now): +def next_run(quartz_definition: str, now: datetime) -> datetime: """Returns the datetime in UTC timezone of the next run.""" # Year is optional and is never present. _, minutes, hours, day_of_month, month, day_of_week, year = ( @@ -153,7 +169,7 @@ class InvalidIndexNameError(ValueError): pass -def validate_index_name(name): +def validate_index_name(name: str) -> str: for char in INVALID_CHARS: if char in name: msg = f"Invalid character {char}" @@ -175,10 +191,12 @@ def validate_index_name(name): class CancellableSleeps: - def __init__(self): + def __init__(self) -> None: self._sleeps = set() - async def sleep(self, delay, result=None, *, loop=None): + async def sleep( + self, delay: Union[float, Mock, int], result: None = None, *, loop=None + ) -> None: async def _sleep(delay, result=None, *, loop=None): coro = asyncio.sleep(delay, result=result) task = asyncio.ensure_future(coro) @@ -193,7 +211,7 @@ async def _sleep(delay, result=None, *, loop=None): await _sleep(delay, result=result, loop=loop) - def cancel(self, sig=None): + def cancel(self, sig: None = None) -> None: if sig: logger.debug(f"Caught {sig}. 
Cancelling sleeps...") else: @@ -203,16 +221,16 @@ def cancel(self, sig=None): task.cancel() -def get_size(ob): +def get_size(ob: Any) -> int: """Returns size in Bytes""" return asizeof.asizeof(ob) -def get_file_extension(filename): +def get_file_extension(filename: str) -> str: return os.path.splitext(filename)[-1] -def get_base64_value(content): +def get_base64_value(content: Buffer) -> str: """ Returns the converted file passed into a base64 encoded value Args: @@ -221,7 +239,7 @@ def get_base64_value(content): return base64.b64encode(content).decode("utf-8") -def decode_base64_value(content): +def decode_base64_value(content: Union[str, Buffer]) -> bytes: """ Decodes the base64 encoded content Args: @@ -230,10 +248,22 @@ def decode_base64_value(content): return base64.b64decode(content) -_BASE64 = shutil.which("base64") +_BASE64: Optional[str] = shutil.which("base64") -def convert_to_b64(source, target=None, overwrite=False): +def convert_to_b64( + source: Union[os.PathLike[bytes], os.PathLike[str], bytes, str], + target: Union[None, os.PathLike[bytes], os.PathLike[str], bytes, int, str] = None, + overwrite: bool = False, +) -> Union[ + os.PathLike[bytes], + os.PathLike[str], + os.PathLike[Union[bytes, str]], + bytes, + int, + str, + None, +]: """Converts a `source` file to base64 using the system's `base64` When `target` is not provided, done in-place. @@ -289,27 +319,31 @@ def convert_to_b64(source, target=None, overwrite=False): class MemQueue(asyncio.Queue): def __init__( - self, maxsize=0, maxmemsize=0, refresh_interval=1.0, refresh_timeout=60 - ): + self, + maxsize: int = 0, + maxmemsize: int = 0, + refresh_interval: float = 1.0, + refresh_timeout: float = 60, + ) -> None: super().__init__(maxsize) self.maxmemsize = maxmemsize self.refresh_interval = refresh_interval self._current_memsize = 0 self.refresh_timeout = refresh_timeout - def qmemsize(self): + def qmemsize(self) -> int: return self._current_memsize - def _get(self): + def _get(self) -> Any: item_size, item = self._queue.popleft() # pyright: ignore self._current_memsize -= item_size return item_size, item - def _put(self, item): + def _put(self, item: Any) -> None: self._current_memsize += item[0] # pyright: ignore self._queue.append(item) # pyright: ignore - def full(self, next_item_size=0): + def full(self, next_item_size: int = 0) -> bool: full_by_numbers = super().full() if full_by_numbers: @@ -323,7 +357,7 @@ def full(self, next_item_size=0): return self._current_memsize + next_item_size >= self.maxmemsize - async def _putter_timeout(self, putter): + async def _putter_timeout(self, putter: Future) -> None: """This coroutine will set the result of the putter to QueueFull when a certain timeout it reached.""" start = time.time() while not putter.done(): @@ -338,7 +372,7 @@ async def _putter_timeout(self, putter): logger.debug("Queue Full") await asyncio.sleep(self.refresh_interval) - async def put(self, item): + async def put(self, item: Any) -> None: item_size = get_size(item) # This block is taken from the original put() method but with two @@ -387,14 +421,14 @@ async def put(self, item): super().put_nowait((item_size, item)) - def clear(self): + def clear(self) -> None: while not self.empty(): # Depending on your program, you may want to # catch QueueEmpty self.get_nowait() self.task_done() - def put_nowait(self, item): + def put_nowait(self, item: Union[str, int]) -> None: item_size = get_size(item) if self.full(item_size): msg = f"Queue is full: attempting to add item of size {item_size} bytes while 
{self.maxmemsize - self._current_memsize} free bytes left." @@ -408,7 +442,7 @@ class NonBlockingBoundedSemaphore(asyncio.BoundedSemaphore): This introduces a new try_acquire method, which will return if it can't acquire immediately. """ - def try_acquire(self): + def try_acquire(self) -> bool: if self.locked(): return False @@ -443,14 +477,14 @@ class ConcurrentTasks: await task_pool.join() """ - def __init__(self, max_concurrency=5): + def __init__(self, max_concurrency: int = 5) -> None: self.tasks = [] self._sem = NonBlockingBoundedSemaphore(max_concurrency) - def __len__(self): + def __len__(self) -> int: return len(self.tasks) - def _callback(self, task): + def _callback(self, task: Task) -> None: self.tasks.remove(task) self._sem.release() if task.cancelled(): @@ -462,7 +496,9 @@ def _callback(self, task): f"Exception found for task {task.get_name()}", exc_info=task.exception() ) - def _add_task(self, coroutine, name=None): + def _add_task( + self, coroutine: Union[functools.partial, Callable], name: Optional[str] = None + ) -> Task: task = asyncio.create_task(coroutine(), name=name) self.tasks.append(task) # _callback will be executed when the task is done, @@ -471,7 +507,9 @@ def _add_task(self, coroutine, name=None): task.add_done_callback(functools.partial(self._callback)) return task - async def put(self, coroutine, name=None): + async def put( + self, coroutine: Union[functools.partial, Callable], name: Optional[str] = None + ) -> Generator[Future, None, Task]: """Adds a coroutine for immediate execution. If the number of running tasks reach `max_concurrency`, this @@ -480,7 +518,9 @@ async def put(self, coroutine, name=None): await self._sem.acquire() return self._add_task(coroutine, name=name) - def try_put(self, coroutine, name=None): + def try_put( + self, coroutine: functools.partial, name: Optional[str] = None + ) -> Optional[Task]: """Tries to add a coroutine for immediate execution. 
If the number of running tasks reach `max_concurrency`, this @@ -491,7 +531,7 @@ def try_put(self, coroutine, name=None): return self._add_task(coroutine, name=name) return None - async def join(self, raise_on_error=False): + async def join(self, raise_on_error: bool = False) -> None: """Wait for all tasks to finish.""" try: await asyncio.gather(*self.tasks, return_exceptions=(not raise_on_error)) @@ -499,7 +539,7 @@ async def join(self, raise_on_error=False): self.cancel() raise - def raise_any_exception(self): + def raise_any_exception(self) -> None: for task in self.tasks: if task.done() and not task.cancelled(): if task.exception(): @@ -509,7 +549,7 @@ def raise_any_exception(self): self.cancel() # cancel all the pending tasks raise task.exception() - def cancel(self): + def cancel(self) -> None: """Cancels all tasks""" for task in self.tasks: task.cancel() @@ -529,11 +569,11 @@ class UnknownRetryStrategyError(Exception): def retryable( - retries=3, - interval=1.0, - strategy=RetryStrategy.LINEAR_BACKOFF, - skipped_exceptions=None, -): + retries: int = 3, + interval: float = 1.0, + strategy: RetryStrategy = RetryStrategy.LINEAR_BACKOFF, + skipped_exceptions: Optional[Any] = None, +) -> Callable: def wrapper(func): if skipped_exceptions is None: processed_skipped_exceptions = [] @@ -561,7 +601,13 @@ def wrapper(func): return wrapper -def retryable_async_function(func, retries, interval, strategy, skipped_exceptions): +def retryable_async_function( + func: Callable, + retries: int, + interval: Union[float, int], + strategy: RetryStrategy, + skipped_exceptions: List[Any], +) -> Callable: @functools.wraps(func) async def wrapped(*args, **kwargs): retry = 1 @@ -582,7 +628,13 @@ async def wrapped(*args, **kwargs): return wrapped -def retryable_async_generator(func, retries, interval, strategy, skipped_exceptions): +def retryable_async_generator( + func: Callable, + retries: int, + interval: Union[float, int], + strategy: RetryStrategy, + skipped_exceptions: List[Any], +) -> Callable: @functools.wraps(func) async def wrapped(*args, **kwargs): retry = 1 @@ -606,7 +658,13 @@ async def wrapped(*args, **kwargs): return wrapped -def retryable_sync_function(func, retries, interval, strategy, skipped_exceptions): +def retryable_sync_function( + func: Callable, + retries: int, + interval: int, + strategy: RetryStrategy, + skipped_exceptions: List[Any], +) -> Callable: @functools.wraps(func) def wrapped(*args, **kwargs): retry = 1 @@ -625,7 +683,9 @@ def wrapped(*args, **kwargs): return wrapped -def time_to_sleep_between_retries(strategy, interval, retry): +def time_to_sleep_between_retries( + strategy: Union[RetryStrategy, str], interval: Union[float, int], retry: int +) -> Union[float, int]: match strategy: case RetryStrategy.CONSTANT: return interval @@ -637,7 +697,7 @@ def time_to_sleep_between_retries(strategy, interval, retry): raise UnknownRetryStrategyError() -def ssl_context(certificate): +def ssl_context(certificate: str) -> ssl.SSLContext: """Convert string to pem format and create a SSL context Args: @@ -652,7 +712,7 @@ def ssl_context(certificate): return ctx -def url_encode(original_string): +def url_encode(original_string: str) -> str: """Performs encoding on the objects containing special characters in their url, and replaces single quote with two single quote since quote @@ -667,7 +727,7 @@ def url_encode(original_string): return urllib.parse.quote(original_string, safe="'") -def evaluate_timedelta(seconds, time_skew=0): +def evaluate_timedelta(seconds: int, time_skew: int = 0) -> 
str: """Adds seconds to the current utc time. Args: @@ -680,7 +740,7 @@ def evaluate_timedelta(seconds, time_skew=0): return iso_utc(when=modified_time) -def is_expired(expires_at): +def is_expired(expires_at: Optional[Union[FakeDatetime, datetime]]) -> bool: """Compares the given time with present time Args: @@ -692,7 +752,7 @@ def is_expired(expires_at): return datetime.utcnow() >= expires_at -def get_pem_format(key, postfix="-----END CERTIFICATE-----"): +def get_pem_format(key: str, postfix: str = "-----END CERTIFICATE-----") -> str: """Convert key into PEM format. Args: @@ -726,14 +786,14 @@ def get_pem_format(key, postfix="-----END CERTIFICATE-----"): return pem_format -def hash_id(_id): +def hash_id(_id: str) -> str: # Collision probability: 1.47*10^-29 # S105 rule considers this code unsafe, but we're not using it for security-related # things, only to generate pseudo-ids for documents return hashlib.md5(_id.encode("utf8")).hexdigest() # noqa S105 -def truncate_id(_id): +def truncate_id(_id: str) -> str: """Truncate ID of an object. We cannot guarantee that connector returns small IDs. @@ -754,7 +814,7 @@ def truncate_id(_id): return _id -def has_duplicates(strings_list): +def has_duplicates(strings_list: List[Union[str, Any]]) -> bool: seen = set() for s in strings_list: if s in seen: @@ -763,7 +823,10 @@ def has_duplicates(strings_list): return False -def filter_nested_dict_by_keys(key_list, source_dict): +def filter_nested_dict_by_keys( + key_list: List[Union[str, Any]], + source_dict: Dict[str, Union[Dict[str, int], Dict[Any, Any]]], +) -> Dict[str, Union[Dict[str, int], Dict[Any, Any]]]: """Filters a nested dict by the keys of the sub-level dict. This is used for checking if any configuration fields are missing properties. @@ -782,7 +845,9 @@ def filter_nested_dict_by_keys(key_list, source_dict): return filtered_dict -def deep_merge_dicts(base_dict, new_dict): +def deep_merge_dicts( + base_dict: Dict[str, Any], new_dict: Dict[str, Any] +) -> Dict[str, Any]: """Deep merges two nested dicts. Args: @@ -810,17 +875,17 @@ class CacheWithTimeout: Example of usage: cache = CacheWithTimeout() - cache.set_value(50, datetime.datetime.now() + datetime.timedelta(5) + cache.set_value(50, datetime.now() + timedelta(5)) value = cache.get() # 50 sleep(5) value = cache.get() # None """ - def __init__(self): + def __init__(self) -> None: self._value = None self._expiration_date = None - def get_value(self): + def get_value(self) -> Optional[str]: """Get the value that's stored inside if it hasn't expired. If the expiration_date is past due, None is returned instead. @@ -833,7 +898,9 @@ def get_value(self): return None - def set_value(self, value, expiration_date): + def set_value( + self, value: str, expiration_date: Union[FakeDatetime, float, datetime] + ) -> None: """Set the value in the cache with expiration date. Once expiration_date is past due, the value will be lost. 
@@ -842,7 +909,7 @@ def set_value(self, value, expiration_date): self._expiration_date = expiration_date -def html_to_text(html): +def html_to_text(html: Optional[str]) -> Optional[Union[Mock, str]]: if not html: return html try: @@ -853,7 +920,7 @@ def html_to_text(html): return BeautifulSoup(html, features="html.parser").get_text(separator="\n") -async def aenumerate(asequence, start=0): +async def aenumerate(asequence: _TracedAsyncGenerator, start: int = 0) -> None: i = start async for elem in asequence: try: @@ -862,7 +929,12 @@ async def aenumerate(asequence, start=0): i += 1 -def iterable_batches_generator(iterable, batch_size): +def iterable_batches_generator( + iterable: List[ + Union[Dict[str, str], Dict[str, Union[Dict[str, str], str]], int, Any] + ], + batch_size: int, +) -> Iterator[List[Union[Dict[str, str], Dict[str, Union[Dict[str, str], str]], int]]]: """Iterate over an iterable in batches. If the batch size is bigger than the number of remaining elements then all remaining elements will be returned. @@ -880,7 +952,15 @@ def iterable_batches_generator(iterable, batch_size): yield iterable[idx : min(idx + batch_size, num_items)] -def dict_slice(hsh, keys, default=None): +def dict_slice( + hsh: Dict[str, Union[bool, str, int]], + keys: Union[ + Tuple[str, str, str, str], + Tuple[str, str, str, str, str, str], + Tuple[str, str, str, str, str, str, str], + ], + default: None = None, +) -> Dict[str, Optional[Union[str, int]]]: """ Slice a dict by a subset of its keys. :param hsh: The input dictionary to slice @@ -890,7 +970,7 @@ def dict_slice(hsh, keys, default=None): return {k: hsh.get(k, default) for k in keys} -def base64url_to_base64(string): +def base64url_to_base64(string: Optional[str]) -> Optional[str]: if string is None: return string @@ -901,7 +981,7 @@ def base64url_to_base64(string): return string.replace("_", "/") -def validate_email_address(email_address): +def validate_email_address(email_address: str) -> bool: """Validates an email address against a regular expression. This method does not include any remote check against an SMTP server for example.""" @@ -909,7 +989,7 @@ def validate_email_address(email_address): return re.fullmatch(EMAIL_REGEX_PATTERN, email_address) is not None -def shorten_str(string, shorten_by): +def shorten_str(string: Optional[str], shorten_by: int) -> str: """ Shorten a string by removing characters from the middle, replacing them with '...'. 
@@ -952,7 +1032,9 @@ def shorten_str(string, shorten_by): return f"{string[:keep + 1]}...{string[-keep:]}" -def func_human_readable_name(func): +def func_human_readable_name( + func: Union[AsyncMock, functools.partial, Callable], +) -> str: if isinstance(func, functools.partial): return func.func.__name__ @@ -962,7 +1044,11 @@ def func_human_readable_name(func): return str(func) -def nested_get_from_dict(dictionary, keys, default=None): +def nested_get_from_dict( + dictionary: Any, + keys: Union[List[str], Tuple[str, str]], + default: Optional[Union[bool, str, float]] = None, +) -> Any: def nested_get(dictionary_, keys_, default_=None): if dictionary_ is None: return default_ @@ -978,7 +1064,9 @@ def nested_get(dictionary_, keys_, default_=None): return nested_get(dictionary, keys, default) -def sanitize(doc): +def sanitize( + doc: Dict[str, Union[int, datetime, str]], +) -> Dict[str, Union[datetime, str]]: if doc["_id"]: # guarantee that IDs are strings, and not numeric doc["_id"] = str(doc["_id"]) @@ -992,18 +1080,20 @@ class Counters: A utility to provide code readability to managing a collection of counts """ - def __init__(self): + def __init__(self) -> None: self._storage = {} - def increment(self, key, value=1, namespace=None): + def increment( + self, key: str, value: int = 1, namespace: Optional[str] = None + ) -> None: if namespace: key = f"{namespace}.{key}" self._storage[key] = self._storage.get(key, 0) + value - def get(self, key) -> int: + def get(self, key: str) -> int: return self._storage.get(key, 0) - def to_dict(self): + def to_dict(self) -> Dict[str, int]: return deepcopy(self._storage) @@ -1014,13 +1104,13 @@ class TooManyErrors(Exception): class ErrorMonitor: def __init__( self, - enabled=True, - max_total_errors=1000, - max_consecutive_errors=10, - max_error_rate=0.15, - error_window_size=100, - error_queue_size=10, - ): + enabled: bool = True, + max_total_errors: int = 1000, + max_consecutive_errors: int = 10, + max_error_rate: float = 0.15, + error_window_size: int = 100, + error_queue_size: int = 10, + ) -> None: # When disabled, only track errors self.enabled = enabled @@ -1039,12 +1129,14 @@ def __init__( self.total_success_count = 0 self.total_error_count = 0 - def track_success(self): + def track_success(self) -> None: self.consecutive_error_count = 0 self.total_success_count += 1 self._update_error_window(False) - def track_error(self, error): + def track_error( + self, error: Union[InvalidIndexNameError, DocumentIngestionError, Exception] + ) -> None: self.total_error_count += 1 self.consecutive_error_count += 1 @@ -1058,7 +1150,7 @@ def track_error(self, error): self._raise_if_necessary() - def _update_error_window(self, value): + def _update_error_window(self, value: bool) -> None: # We keep the errors array of the size self.error_window_size this way, imagine self.error_window_size = 5 # Error array inits as falses: # [ false, false, false, false, false ] @@ -1083,7 +1175,7 @@ def _update_error_window(self, value): self.error_window[self.error_window_index] = value self.error_window_index = (self.error_window_index + 1) % self.error_window_size - def _error_window_error_rate(self): + def _error_window_error_rate(self) -> float: if self.error_window_size == 0: return 0 @@ -1093,7 +1185,7 @@ def _error_window_error_rate(self): return error_rate - def _raise_if_necessary(self): + def _raise_if_necessary(self) -> None: if not self.enabled: return @@ -1110,7 +1202,7 @@ def _raise_if_necessary(self): raise TooManyErrors(msg) from self.last_error -def 
generate_random_id(length=4): +def generate_random_id(length: int = 4) -> str: return "".join( secrets.choice(string.ascii_letters + string.digits) for _ in range(length) ) diff --git a/scripts/deps-csv.py b/scripts/deps-csv.py index ef35810d7..7dd6a15be 100755 --- a/scripts/deps-csv.py +++ b/scripts/deps-csv.py @@ -14,7 +14,7 @@ URL = 3 -def main(dependencies_csv): +def main(dependencies_csv) -> None: """ The input is what we get from `pip-licenses --format=csv --with-urls` See: https://pypi.org/project/pip-licenses/#csv @@ -44,7 +44,7 @@ def main(dependencies_csv): if __name__ == "__main__": - depenencies_csv = sys.argv[1] + depenencies_csv: str = sys.argv[1] print(f"post-processing {depenencies_csv}") # noqa main(depenencies_csv) print(f"wrote output to {depenencies_csv}") # noqa diff --git a/scripts/testing/cli.py b/scripts/testing/cli.py index e410a0ad6..5c1ab1eee 100644 --- a/scripts/testing/cli.py +++ b/scripts/testing/cli.py @@ -17,7 +17,7 @@ __all__ = ["main"] -BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +BASE_DIR: str = os.path.dirname(os.path.abspath(__file__)) VM_STARTUP_SCRIPT_PATH = f"startup-script={BASE_DIR}/startup_scipt.sh" VM_INIT_ATTEMPTS = 30 @@ -42,7 +42,7 @@ @click.group() @click.pass_context -def cli(ctx): +def cli(ctx) -> None: pass @@ -92,7 +92,7 @@ def create_test_environment( es_password, test_case, delete, -): +) -> None: """ Creates a new VM and runs the tests """ @@ -115,7 +115,7 @@ def create_test_environment( @click.command(name="delete", help="Deletes a VM") @click.argument("name") @click.option("--vm-zone", default="europe-west1-b") -def delete_test_environment(name, vm_zone): +def delete_test_environment(name, vm_zone) -> None: """ Deletes the VM """ @@ -138,7 +138,7 @@ def delete_test_environment(name, vm_zone): cli.add_command(delete_test_environment) -def print_help(name, vm_zone): +def print_help(name, vm_zone) -> None: """ Prints a list of commands that can be used to interact with the setup """ @@ -185,7 +185,7 @@ def print_help(name, vm_zone): click.echo("List of sync jobs: " + click.style(sync_jobs_cmd, fg="green")) -def create_vm(name, vm_type, vm_zone): +def create_vm(name, vm_type, vm_zone) -> None: """ Creates a new VM and waits until the startup script finishes its work """ @@ -287,7 +287,7 @@ def render_connector_configuration(file_path): return configuration -def setup_stack(name, vm_zone, es_version, connectors_ref, es_host): +def setup_stack(name, vm_zone, es_version, connectors_ref, es_host) -> None: with click.progressbar(label="Setting up the stack...", length=100) as steps: # Upload pull-connectors file cmd = [ @@ -411,7 +411,7 @@ def connector_service_config(es_host, es_username, es_password): os.remove(file_name) -def run_scenarios(name, es_host, es_username, es_password, vm_zone, test_case): +def run_scenarios(name, es_host, es_username, es_password, vm_zone, test_case) -> None: """ Runs the scenarios from the test case file """ @@ -574,7 +574,7 @@ def run_scenarios(name, es_host, es_username, es_password, vm_zone, test_case): # TODO Read all the fields in one call -def read_from_vault(key): +def read_from_vault(key) -> str: """ Reads a secret from Vault and returns its value """ @@ -588,7 +588,7 @@ def read_from_vault(key): return result.stdout.decode("utf-8") -def run_gcloud_cmd(cmd): +def run_gcloud_cmd(cmd) -> subprocess.CompletedProcess[bytes]: """ Runs a gcloud command and raises an exception if the command failed """ @@ -600,7 +600,7 @@ def run_gcloud_cmd(cmd): return result -def main(args=None): +def 
main(args=None) -> None: cli() diff --git a/scripts/verify.py b/scripts/verify.py index d44002a60..a2e5629ba 100644 --- a/scripts/verify.py +++ b/scripts/verify.py @@ -6,16 +6,19 @@ # ruff: noqa: T201 import asyncio import os -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace +from typing import Optional, Sequence, Union from elasticsearch import AsyncElasticsearch from connectors.config import load_config -DEFAULT_CONFIG = os.path.join(os.path.dirname(__file__), "..", "config.yml") +DEFAULT_CONFIG: str = os.path.join(os.path.dirname(__file__), "..", "config.yml") -async def verify(service_type, index_name, size, config): +async def verify( + service_type, index_name: Union[None, Sequence[str], str], size, config +) -> None: config = config["elasticsearch"] host = config["host"] auth = config["username"], config["password"] @@ -56,7 +59,7 @@ async def verify(service_type, index_name, size, config): await client.close() -def _parser(): +def _parser() -> ArgumentParser: parser = ArgumentParser( prog="verify", formatter_class=ArgumentDefaultsHelpFormatter ) @@ -73,7 +76,7 @@ def _parser(): return parser -def main(args=None): +def main(args: Optional[Namespace] = None) -> None: parser = _parser() args = parser.parse_args(args=args) config_file = args.config_file diff --git a/setup.py b/setup.py index e70c03d79..6f726044d 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ from setuptools._vendor.packaging.markers import Marker try: - ARCH = os.uname().machine + ARCH: str = os.uname().machine except Exception as e: - ARCH = "x86_64" + ARCH: str = "x86_64" print( # noqa: T201 f"Defaulting to architecture '{ARCH}'. Unable to determine machine architecture due to error: {e}" ) @@ -75,7 +75,7 @@ def read_reqs(req_file): with open("README.md") as f: - long_description = f.read() + long_description: str = f.read() classifiers = [ diff --git a/tests/agent/test_agent_config.py b/tests/agent/test_agent_config.py index dc79bca25..0c28f428d 100644 --- a/tests/agent/test_agent_config.py +++ b/tests/agent/test_agent_config.py @@ -11,7 +11,7 @@ SERVICE_TYPE = "test-service-type" -def prepare_unit_mock(fields, log_level): +def prepare_unit_mock(fields, log_level) -> Mock: if not fields: fields = {} unit_mock = Mock() @@ -25,7 +25,7 @@ def prepare_unit_mock(fields, log_level): return unit_mock -def prepare_config_wrapper(): +def prepare_config_wrapper() -> ConnectorsAgentConfigurationWrapper: # populate with connectors list, so that we can test for changes in other config properties config_wrapper = ConnectorsAgentConfigurationWrapper() initial_config_unit = prepare_unit_mock({}, None) @@ -37,7 +37,7 @@ def prepare_config_wrapper(): return config_wrapper -def test_try_update_without_auth_data(): +def test_try_update_without_auth_data() -> None: config_wrapper = prepare_config_wrapper() unit_mock = prepare_unit_mock({}, None) @@ -52,7 +52,7 @@ def test_try_update_without_auth_data(): ) -def test_try_update_with_api_key_auth_data(): +def test_try_update_with_api_key_auth_data() -> None: hosts = ["https://localhost:9200"] api_key = "lemme_in" @@ -71,7 +71,7 @@ def test_try_update_with_api_key_auth_data(): assert config_wrapper.get()["elasticsearch"]["api_key"] == api_key -def test_try_update_with_non_encoded_api_key_auth_data(): +def test_try_update_with_non_encoded_api_key_auth_data() -> None: hosts = ["https://localhost:9200"] api_key = "something:else" encoded = "c29tZXRoaW5nOmVsc2U=" @@ -91,7 +91,7 @@ def 
test_try_update_with_non_encoded_api_key_auth_data(): assert config_wrapper.get()["elasticsearch"]["api_key"] == encoded -def test_try_update_with_basic_auth_auth_data(): +def test_try_update_with_basic_auth_auth_data() -> None: hosts = ["https://localhost:9200"] username = "elastic" password = "hold the door" @@ -114,7 +114,7 @@ def test_try_update_with_basic_auth_auth_data(): assert config_wrapper.get()["elasticsearch"]["password"] == password -def test_try_update_multiple_times_does_not_reset_config_values(): +def test_try_update_multiple_times_does_not_reset_config_values() -> None: hosts = ["https://localhost:9200"] api_key = "lemme_in" @@ -149,7 +149,7 @@ def test_try_update_multiple_times_does_not_reset_config_values(): assert config_wrapper.get()["service"]["log_level"] == log_level -def test_config_changed_when_new_variables_are_passed(): +def test_config_changed_when_new_variables_are_passed() -> None: hosts = ["https://localhost:9200"] api_key = "lemme_in_lalala" @@ -163,7 +163,7 @@ def test_config_changed_when_new_variables_are_passed(): assert config_wrapper.config_changed(new_config) is True -def test_config_changed_when_elasticsearch_config_changed(): +def test_config_changed_when_elasticsearch_config_changed() -> None: hosts = ["https://localhost:9200"] api_key = "lemme_in_lalala" @@ -189,7 +189,7 @@ def test_config_changed_when_elasticsearch_config_changed(): assert config_wrapper.config_changed(new_config) is True -def test_config_changed_when_elasticsearch_config_did_not_change(): +def test_config_changed_when_elasticsearch_config_did_not_change() -> None: hosts = ["https://localhost:9200"] api_key = "lemme_in_lalala" @@ -208,7 +208,7 @@ def test_config_changed_when_elasticsearch_config_did_not_change(): assert config_wrapper.config_changed(new_config) is True -def test_config_changed_when_log_level_config_changed(): +def test_config_changed_when_log_level_config_changed() -> None: config_wrapper = prepare_config_wrapper() config_wrapper.try_update( connector_id=CONNECTOR_ID, @@ -224,7 +224,7 @@ def test_config_changed_when_log_level_config_changed(): assert config_wrapper.config_changed(new_config) is True -def test_config_changed_when_log_level_config_did_not_change(): +def test_config_changed_when_log_level_config_did_not_change() -> None: config_wrapper = prepare_config_wrapper() config_wrapper.try_update( connector_id=CONNECTOR_ID, @@ -240,7 +240,7 @@ def test_config_changed_when_log_level_config_did_not_change(): assert config_wrapper.config_changed(new_config) is False -def test_config_changed_when_connectors_changed(): +def test_config_changed_when_connectors_changed() -> None: config_wrapper = ConnectorsAgentConfigurationWrapper() config_wrapper.try_update( @@ -258,7 +258,7 @@ def test_config_changed_when_connectors_changed(): assert config_wrapper.config_changed(new_config) is True -def test_config_changed_when_connectors_list_is_cleared(): +def test_config_changed_when_connectors_list_is_cleared() -> None: config_wrapper = ConnectorsAgentConfigurationWrapper() config_wrapper.try_update( @@ -274,7 +274,7 @@ def test_config_changed_when_connectors_list_is_cleared(): assert config_wrapper.config_changed(new_config) is True -def test_config_changed_when_connectors_list_is_extended(): +def test_config_changed_when_connectors_list_is_extended() -> None: config_wrapper = ConnectorsAgentConfigurationWrapper() config_wrapper.try_update( @@ -293,7 +293,7 @@ def test_config_changed_when_connectors_list_is_extended(): assert config_wrapper.config_changed(new_config) is 
True -def test_config_changed_when_connectors_did_not_change(): +def test_config_changed_when_connectors_did_not_change() -> None: config_wrapper = ConnectorsAgentConfigurationWrapper() config_wrapper.try_update( diff --git a/tests/agent/test_cli.py b/tests/agent/test_cli.py index ec1a6e41d..ed97b9745 100644 --- a/tests/agent/test_cli.py +++ b/tests/agent/test_cli.py @@ -12,7 +12,7 @@ @patch("connectors.agent.cli.ConnectorsAgentComponent", return_value=AsyncMock()) -def test_main_responds_to_sigterm(patch_component): +def test_main_responds_to_sigterm(patch_component) -> None: async def kill(): await asyncio.sleep(0.2) os.kill(os.getpid(), signal.SIGTERM) diff --git a/tests/agent/test_component.py b/tests/agent/test_component.py index 9bdfe9e17..67904adf2 100644 --- a/tests/agent/test_component.py +++ b/tests/agent/test_component.py @@ -12,17 +12,17 @@ class StubMultiService: - def __init__(self): + def __init__(self) -> None: self.running_stop = asyncio.Event() self.has_ran = False self.has_shutdown = False - async def run(self): + async def run(self) -> None: self.has_ran = True self.running_stop.clear() await self.running_stop.wait() - def shutdown(self, sig): + def shutdown(self, sig) -> None: self.has_shutdown = True self.running_stop.set() @@ -32,7 +32,7 @@ def shutdown(self, sig): @patch("connectors.agent.component.new_v2_from_reader", return_value=MagicMock()) async def test_try_update_without_auth_data( stub_multi_service, patch_new_v2_from_reader -): +) -> None: component = ConnectorsAgentComponent() async def stop_after_timeout(): diff --git a/tests/agent/test_connector_record_manager.py b/tests/agent/test_connector_record_manager.py index 9657f3943..46c5e0e5f 100644 --- a/tests/agent/test_connector_record_manager.py +++ b/tests/agent/test_connector_record_manager.py @@ -3,6 +3,7 @@ # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. 
# +from typing import Dict, List, Union from unittest.mock import AsyncMock, patch import pytest @@ -14,12 +15,12 @@ @pytest.fixture -def mock_connector_index(): +def mock_connector_index() -> AsyncMock: return AsyncMock(ConnectorIndex) @pytest.fixture -def mock_agent_config(): +def mock_agent_config() -> Dict[str, Union[Dict[str, str], List[Dict[str, str]]]]: return { "elasticsearch": {"host": "http://localhost:9200", "api_key": "dummy_key"}, "connectors": [{"connector_id": "1", "service_type": "service1"}], @@ -27,7 +28,7 @@ def mock_agent_config(): @pytest.fixture -def connector_record_manager(mock_connector_index): +def connector_record_manager(mock_connector_index) -> ConnectorRecordManager: manager = ConnectorRecordManager() manager.connector_index = mock_connector_index return manager @@ -36,7 +37,7 @@ def connector_record_manager(mock_connector_index): @pytest.mark.asyncio async def test_ensure_connector_records_exist_creates_connectors_if_not_exist( connector_record_manager, mock_agent_config -): +) -> None: random_connector_name_id = "1234" with patch( @@ -61,7 +62,7 @@ async def test_ensure_connector_records_exist_creates_connectors_if_not_exist( @pytest.mark.asyncio async def test_ensure_connector_records_exist_connector_already_exists( connector_record_manager, mock_agent_config -): +) -> None: connector_record_manager.connector_index.connector_exists = AsyncMock( return_value=True ) @@ -72,7 +73,7 @@ async def test_ensure_connector_records_exist_connector_already_exists( @pytest.mark.asyncio async def test_ensure_connector_records_raises_on_non_404_error( connector_record_manager, mock_agent_config -): +) -> None: connector_record_manager.connector_index.connector_exists = AsyncMock( side_effect=Exception("Unexpected error") ) @@ -87,7 +88,7 @@ async def test_ensure_connector_records_raises_on_non_404_error( @pytest.mark.asyncio async def test_ensure_connector_records_exist_agent_config_not_ready( connector_record_manager, -): +) -> None: invalid_config = {"connectors": []} await connector_record_manager.ensure_connector_records_exist(invalid_config) assert connector_record_manager.connector_index.connector_put.call_count == 0 @@ -96,7 +97,7 @@ async def test_ensure_connector_records_exist_agent_config_not_ready( @pytest.mark.asyncio async def test_ensure_connector_records_exist_exception_on_create( connector_record_manager, mock_agent_config -): +) -> None: connector_record_manager.connector_index.connector_exists = AsyncMock( return_value=False ) @@ -109,14 +110,14 @@ async def test_ensure_connector_records_exist_exception_on_create( def test_agent_config_ready_with_valid_config( connector_record_manager, mock_agent_config -): +) -> None: ready, _ = connector_record_manager._check_agent_config_ready(mock_agent_config) assert ready is True def test_agent_config_ready_with_invalid_config_missing_connectors( connector_record_manager, -): +) -> None: invalid_config = { "elasticsearch": {"host": "http://localhost:9200", "api_key": "dummy_key"} } @@ -126,7 +127,7 @@ def test_agent_config_ready_with_invalid_config_missing_connectors( def test_agent_config_ready_with_invalid_config_missing_elasticsearch( connector_record_manager, -): +) -> None: invalid_config = {"connectors": [{"connector_id": "1", "service_type": "service1"}]} ready, _ = connector_record_manager._check_agent_config_ready(invalid_config) assert ready is False diff --git a/tests/agent/test_protocol.py b/tests/agent/test_protocol.py index ac78eb833..b60602bc0 100644 --- a/tests/agent/test_protocol.py +++ 
b/tests/agent/test_protocol.py @@ -11,11 +11,12 @@ from google.protobuf.struct_pb2 import Struct from connectors.agent.config import ConnectorsAgentConfigurationWrapper +from connectors.agent.connector_record_manager import ConnectorRecordManager from connectors.agent.protocol import ConnectorActionHandler, ConnectorCheckinHandler @pytest.fixture(autouse=True) -def input_mock(): +def input_mock() -> Mock: unit_mock = Mock() unit_mock.unit_type = proto.UnitType.INPUT @@ -34,7 +35,7 @@ def _string_config_field_mock(value): @pytest.fixture(autouse=True) -def connector_record_manager_mock(): +def connector_record_manager_mock() -> Mock: connector_record_manager_mock = Mock() connector_record_manager_mock.ensure_connector_records_exist = AsyncMock( return_value=True @@ -44,7 +45,7 @@ def connector_record_manager_mock(): class TestConnectorActionHandler: @pytest.mark.asyncio - async def test_handle_action(self): + async def test_handle_action(self) -> None: action_handler = ConnectorActionHandler() with pytest.raises(NotImplementedError): @@ -54,8 +55,8 @@ async def test_handle_action(self): class TestConnectorCheckingHandler: @pytest.mark.asyncio async def test_apply_from_client_when_no_units_received( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper_mock = Mock() service_manager_mock = Mock() @@ -76,8 +77,8 @@ async def test_apply_from_client_when_no_units_received( @pytest.mark.asyncio async def test_apply_from_client_when_units_with_no_output( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper_mock = Mock() service_manager_mock = Mock() @@ -100,8 +101,8 @@ async def test_apply_from_client_when_units_with_no_output( @pytest.mark.asyncio async def test_apply_from_client_when_units_with_output_and_non_updating_config( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper_mock = Mock() @@ -128,8 +129,8 @@ async def test_apply_from_client_when_units_with_output_and_non_updating_config( @pytest.mark.asyncio async def test_apply_from_client_when_units_with_output_and_updating_config( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper_mock = Mock() @@ -157,8 +158,8 @@ async def test_apply_from_client_when_units_with_output_and_updating_config( @pytest.mark.asyncio async def test_apply_from_client_when_units_with_multiple_outputs_and_updating_config( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper_mock = Mock() @@ -197,8 +198,8 @@ async def test_apply_from_client_when_units_with_multiple_outputs_and_updating_c @pytest.mark.asyncio async def test_apply_from_client_when_units_with_multiple_mixed_outputs_and_updating_config( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper_mock = Mock() @@ -242,8 +243,8 @@ async def test_apply_from_client_when_units_with_multiple_mixed_outputs_and_upda @pytest.mark.asyncio async def 
test_apply_from_client_when_units_with_output_and_updating_log_level( - self, connector_record_manager_mock, input_mock - ): + self, connector_record_manager_mock: ConnectorRecordManager, input_mock + ) -> None: client_mock = Mock() config_wrapper = ConnectorsAgentConfigurationWrapper() diff --git a/tests/agent/test_service_manager.py b/tests/agent/test_service_manager.py index 6e4a528ba..06df5d021 100644 --- a/tests/agent/test_service_manager.py +++ b/tests/agent/test_service_manager.py @@ -13,7 +13,7 @@ @pytest.fixture(autouse=True) -def config_mock(): +def config_mock() -> Mock: config = Mock() config.get.return_value = { @@ -26,24 +26,24 @@ def config_mock(): class StubMultiService: - def __init__(self): + def __init__(self) -> None: self.running_stop = asyncio.Event() self.has_ran = False self.has_shutdown = False - async def run(self): + async def run(self) -> None: self.has_ran = True self.running_stop.clear() await self.running_stop.wait() - def shutdown(self, sig): + def shutdown(self, sig) -> None: self.has_shutdown = True self.running_stop.set() @pytest.mark.asyncio @patch("connectors.agent.service_manager.get_services", return_value=StubMultiService()) -async def test_run_and_stop_work_as_intended(patch_get_services, config_mock): +async def test_run_and_stop_work_as_intended(patch_get_services, config_mock) -> None: service_manager = ConnectorServiceManager(config_mock) async def stop_service_after_timeout(): @@ -58,7 +58,9 @@ async def stop_service_after_timeout(): @pytest.mark.asyncio @patch("connectors.agent.service_manager.get_services", return_value=StubMultiService()) -async def test_restart_starts_another_multiservice(patch_get_services, config_mock): +async def test_restart_starts_another_multiservice( + patch_get_services, config_mock +) -> None: service_manager = ConnectorServiceManager(config_mock) async def stop_service_after_timeout(): @@ -75,7 +77,9 @@ async def stop_service_after_timeout(): @pytest.mark.asyncio @patch("connectors.agent.service_manager.get_services", return_value=StubMultiService()) -async def test_cannot_run_same_service_manager_twice(patch_get_services, config_mock): +async def test_cannot_run_same_service_manager_twice( + patch_get_services, config_mock +) -> None: service_manager = ConnectorServiceManager(config_mock) with pytest.raises(ServiceAlreadyRunningError): diff --git a/tests/commons.py b/tests/commons.py index b3f17e085..b0d50647a 100644 --- a/tests/commons.py +++ b/tests/commons.py @@ -6,8 +6,9 @@ import math from functools import cached_property from random import choices +from typing import Any, List, Optional, Sized, Union -from faker import Faker +from faker.proxy import Faker class AsyncIterator: @@ -15,7 +16,7 @@ class AsyncIterator: Async documents generator fake class, which records the args and kwargs it was called with. """ - def __init__(self, items, reusable=False): + def __init__(self, items: List[Any], reusable: bool = False) -> None: """ AsyncIterator is a test-only abstraction to mock async iterables. 
By default it's usable only once: once iterated over, he iterator will not @@ -30,10 +31,10 @@ def __init__(self, items, reusable=False): self.call_count = 0 self.reusable = reusable - def __aiter__(self): + def __aiter__(self) -> "AsyncIterator": return self - async def __anext__(self): + async def __anext__(self) -> Any: if self.i >= len(self.items): if self.reusable: self.i = 0 @@ -43,7 +44,7 @@ async def __anext__(self): self.i += 1 return item - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> "AsyncIterator": self.call_count += 1 if args: @@ -54,17 +55,17 @@ def __call__(self, *args, **kwargs): return self - def assert_not_called(self): + def assert_not_called(self) -> None: if self.call_count != 0: msg = f"Expected zero calls. Actual number of calls: {self.call_count}." raise AssertionError(msg) - def assert_called_once(self): + def assert_called_once(self) -> None: if self.call_count != 1: msg = f"Expected one call. Actual number of calls: {self.call_count}." raise AssertionError(msg) - def assert_called_once_with(self, *args, **kwargs): + def assert_called_once_with(self, *args, **kwargs) -> None: self.assert_called_once() if len(self.call_args) > 0 and self.call_args[0] != args: @@ -77,7 +78,7 @@ def assert_called_once_with(self, *args, **kwargs): class WeightedFakeProvider: - def __init__(self, seed=None, weights=None): + def __init__(self, seed=None, weights: Optional[Sized] = None) -> None: self.seed = seed if weights and len(weights) != 4: msg = f"Exactly 4 weights should be provided. Got {len(weights)}: {weights}" @@ -99,11 +100,11 @@ def _texts(self): ] @cached_property - def fake(self): + def fake(self) -> Faker: return self.fake_provider.fake @cached_property - def _htmls(self): + def _htmls(self) -> List[str]: return [ self.fake_provider.small_html(), self.fake_provider.medium_html(), @@ -114,12 +115,14 @@ def _htmls(self): def get_text(self): return choices(self._texts, self.weights)[0] - def get_html(self): + def get_html(self) -> str: return choices(self._htmls, self.weights)[0] class FakeProvider: - def __init__(self, seed=None): + def __init__( + self, seed: Union[None, bytearray, bytes, float, int, str] = None + ) -> None: self.seed = seed self.fake = Faker() if seed: @@ -144,26 +147,26 @@ def large_text(self): def extra_large_text(self): return self.generate_text(20 * 1024 * 1024) - def small_html(self): + def small_html(self) -> str: # Around 100KB return self.generate_html(1) - def medium_html(self): + def medium_html(self) -> str: # Around 1MB return self.generate_html(1 * 10) - def large_html(self): + def large_html(self) -> str: # Around 8MB return self.generate_html(8 * 10) - def extra_large_html(self): + def extra_large_html(self) -> str: # Around 25MB return self.generate_html(25 * 10) def generate_text(self, max_size): return self.fake.text(max_nb_chars=max_size) - def generate_html(self, images_of_100kb): + def generate_html(self, images_of_100kb) -> str: img = self._cached_random_str # 100kb text = self.small_text() diff --git a/tests/conftest.py b/tests/conftest.py index a4f9844ff..8f6e997cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,25 +17,25 @@ class Logger: - def __init__(self, silent=True): + def __init__(self, silent: bool = True) -> None: self.logs = [] self.silent = silent - def debug(self, msg, exc_info=False): + def debug(self, msg, exc_info: bool = False) -> None: if not self.silent: print(msg) # noqa: T201 self.logs.append(msg) if exc_info: self.logs.append(traceback.format_exc()) - def 
assert_instance(self, instance): + def assert_instance(self, instance) -> None: for log in self.logs: if isinstance(log, instance): return msg = f"Could not find an instance of {instance}" raise AssertionError(msg) - def assert_not_present(self, lines): + def assert_not_present(self, lines) -> None: if isinstance(lines, str): lines = [lines] for msg in lines: @@ -43,7 +43,7 @@ def assert_not_present(self, lines): if isinstance(log, str) and msg in log: raise AssertionError(f"'{msg}' found in {self.logs}") - def assert_present(self, lines): + def assert_present(self, lines) -> None: if isinstance(lines, str): lines = [lines] for msg in lines: @@ -80,7 +80,7 @@ def catch_stdout(): @pytest.fixture -def patch_logger(silent=True): +def patch_logger(silent: bool = True): class PatchedLogger(Logger): def info(self, msg, *args, prefix=None, extra=None, exc_info=None): super(PatchedLogger, self).info(msg, *args) @@ -153,7 +153,7 @@ def mock_aws(): os.environ["AWS_ACCESS_KEY_ID"] = old_key -def assert_re(expr, items): +def assert_re(expr: re.Pattern[str], items) -> None: expr = re.compile(expr) for item in reversed(items): diff --git a/tests/es/test_cli_client.py b/tests/es/test_cli_client.py index 65393165b..0bf9c9fb1 100644 --- a/tests/es/test_cli_client.py +++ b/tests/es/test_cli_client.py @@ -7,7 +7,7 @@ from connectors.es.cli_client import CLIClient -def test_overrides_user_agent_header(): +def test_overrides_user_agent_header() -> None: config = { "username": "elastic", "password": "changeme", diff --git a/tests/es/test_client.py b/tests/es/test_client.py index ea9b48e2e..f8b74398e 100644 --- a/tests/es/test_client.py +++ b/tests/es/test_client.py @@ -26,7 +26,7 @@ @pytest.mark.asyncio -async def test_with_concurrency_control(): +async def test_with_concurrency_control() -> None: mock_func = Mock() num_retries = 10 @@ -80,7 +80,7 @@ class TestESClient: ], ) @pytest.mark.asyncio - async def test_has_license_enabled(self, enabled_license, licenses_enabled): + async def test_has_license_enabled(self, enabled_license, licenses_enabled) -> None: es_client = ESClient(BASIC_CONFIG) es_client.client = AsyncMock() es_client.client.license.get = AsyncMock( @@ -105,7 +105,9 @@ async def test_has_license_enabled(self, enabled_license, licenses_enabled): ], ) @pytest.mark.asyncio - async def test_has_licenses_disabled(self, enabled_license, licenses_disabled): + async def test_has_licenses_disabled( + self, enabled_license, licenses_disabled + ) -> None: es_client = ESClient(BASIC_CONFIG) es_client.client = AsyncMock() es_client.client.license.get = AsyncMock( @@ -117,7 +119,7 @@ async def test_has_licenses_disabled(self, enabled_license, licenses_disabled): assert not is_enabled @pytest.mark.asyncio - async def test_has_license_disabled_with_expired_license(self): + async def test_has_license_disabled_with_expired_license(self) -> None: es_client = ESClient(BASIC_CONFIG) es_client.client = AsyncMock() es_client.client.license.get = AsyncMock( @@ -132,7 +134,7 @@ async def test_has_license_disabled_with_expired_license(self): assert license_ == License.EXPIRED @pytest.mark.asyncio - async def test_auth_conflict_logs_message(self, patch_logger): + async def test_auth_conflict_logs_message(self, patch_logger) -> None: ESClient(BASIC_API_CONFIG) patch_logger.assert_present( "configured API key will be used over configured basic auth" @@ -146,11 +148,11 @@ async def test_auth_conflict_logs_message(self, patch_logger): (BASIC_API_CONFIG, "ApiKey foo"), ], ) - def test_es_client_with_auth(self, config, 
expected_auth_header): + def test_es_client_with_auth(self, config, expected_auth_header) -> None: es_client = ESClient(config) assert es_client.client._headers["Authorization"] == expected_auth_header - def test_esclient(self): + def test_esclient(self) -> None: # creating a client with a minimal config should create one with sane # defaults @@ -165,7 +167,7 @@ def test_esclient(self): assert es_client.client._headers["Authorization"] == basic @pytest.mark.asyncio - async def test_es_client_auth_error(self, mock_responses, patch_logger): + async def test_es_client_auth_error(self, mock_responses, patch_logger) -> None: headers = {"X-Elastic-Product": "Elasticsearch"} # if we get auth issues, we want to know about them @@ -217,7 +219,7 @@ async def test_es_client_auth_error(self, mock_responses, patch_logger): patch_logger.assert_present("missing authentication credentials") @pytest.mark.asyncio - async def test_es_client_no_server(self): + async def test_es_client_no_server(self) -> None: # if we can't reach the server, we need to catch it cleanly config = { "username": "elastic", @@ -239,14 +241,14 @@ async def test_es_client_no_server(self): assert not await es_client.ping() await es_client.close() - def test_sets_product_origin_header(self): + def test_sets_product_origin_header(self) -> None: config = {"headers": {"some-header": "some-value"}} es_client = ESClient(config) assert es_client.client._headers["X-elastic-product-origin"] == "connectors" - def test_sets_user_agent(self): + def test_sets_user_agent(self) -> None: config = {"headers": {"some-header": "some-value"}} es_client = ESClient(config) @@ -259,19 +261,19 @@ def test_sets_user_agent(self): class TestTransientElasticsearchRetrier: @cached_property - def logger_mock(self): + def logger_mock(self) -> Mock: return Mock() @cached_property - def max_retries(self): + def max_retries(self) -> int: return 5 @cached_property - def retry_interval(self): + def retry_interval(self) -> int: return 50 @pytest.mark.asyncio - async def test_execute_with_retry(self, patch_sleep): + async def test_execute_with_retry(self, patch_sleep) -> None: retrier = TransientElasticsearchRetrier( self.logger_mock, self.max_retries, self.retry_interval ) @@ -284,7 +286,7 @@ async def _func(): assert patch_sleep.not_called() @pytest.mark.asyncio - async def test_execute_with_retry_429_with_recovery(self, patch_sleep): + async def test_execute_with_retry_429_with_recovery(self, patch_sleep) -> None: retrier = TransientElasticsearchRetrier( self.logger_mock, self.max_retries, self.retry_interval ) @@ -311,7 +313,7 @@ async def _func(): assert patch_sleep.awaited_exactly(2) @pytest.mark.asyncio - async def test_execute_with_retry_429_no_recovery(self, patch_sleep): + async def test_execute_with_retry_429_no_recovery(self, patch_sleep) -> None: retrier = TransientElasticsearchRetrier( self.logger_mock, self.max_retries, self.retry_interval ) @@ -330,7 +332,7 @@ async def _func(): assert patch_sleep.awaited_exactly(self.max_retries) @pytest.mark.asyncio - async def test_execute_with_retry_connection_timeout(self, patch_sleep): + async def test_execute_with_retry_connection_timeout(self, patch_sleep) -> None: retrier = TransientElasticsearchRetrier( self.logger_mock, self.max_retries, self.retry_interval ) @@ -348,7 +350,7 @@ async def _func(): assert patch_sleep.awaited_exactly(self.max_retries) @pytest.mark.asyncio - async def test_execute_with_retry_cancelled_midway(self, patch_sleep): + async def test_execute_with_retry_cancelled_midway(self, patch_sleep) 
-> None: retrier = TransientElasticsearchRetrier( self.logger_mock, self.max_retries, self.retry_interval ) diff --git a/tests/es/test_document.py b/tests/es/test_document.py index 65b7f94a6..a81c9e901 100644 --- a/tests/es/test_document.py +++ b/tests/es/test_document.py @@ -20,18 +20,18 @@ {"_id": "1", "_source": "hahaha"}, ], ) -def test_es_document_raise(doc_source): +def test_es_document_raise(doc_source) -> None: with pytest.raises(InvalidDocumentSourceError): ESDocument(elastic_index=None, doc_source=doc_source) -def test_es_document_ok(): +def test_es_document_ok() -> None: doc_source = {"_id": "1", "_source": {}} es_document = ESDocument(elastic_index=None, doc_source=doc_source) assert isinstance(es_document, ESDocument) -def test_es_document_get(): +def test_es_document_get() -> None: source = { "_id": "test", "_seq_no": 1, @@ -62,7 +62,7 @@ def test_es_document_get(): @pytest.mark.asyncio -async def test_reload(): +async def test_reload() -> None: source = { "_id": "test", "_seq_no": 1, diff --git a/tests/es/test_index.py b/tests/es/test_index.py index 66653903c..aeb1c6a50 100644 --- a/tests/es/test_index.py +++ b/tests/es/test_index.py @@ -20,7 +20,7 @@ @pytest.mark.asyncio -async def test_es_index_create_object_error(mock_responses): +async def test_es_index_create_object_error(mock_responses) -> None: index = ESIndex(index_name, config) mock_responses.post( f"http://nowhere.com:9200/{index_name}/_refresh", headers=headers, status=200 @@ -44,12 +44,12 @@ class FakeDocument: class FakeIndex(ESIndex): - def _create_object(self, doc): + def _create_object(self, doc) -> FakeDocument: return FakeDocument() @pytest.mark.asyncio -async def test_fetch_by_id(mock_responses): +async def test_fetch_by_id(mock_responses) -> None: doc_id = "1" index = FakeIndex(index_name, config) mock_responses.post( @@ -70,7 +70,7 @@ async def test_fetch_by_id(mock_responses): @pytest.mark.asyncio -async def test_fetch_response_by_id(mock_responses): +async def test_fetch_response_by_id(mock_responses) -> None: doc_id = "1" index = ESIndex(index_name, config) doc_source = { @@ -98,7 +98,7 @@ async def test_fetch_response_by_id(mock_responses): @pytest.mark.asyncio -async def test_fetch_response_by_id_not_found(mock_responses): +async def test_fetch_response_by_id_not_found(mock_responses) -> None: doc_id = "1" index = FakeIndex(index_name, config) mock_responses.post( @@ -118,7 +118,7 @@ async def test_fetch_response_by_id_not_found(mock_responses): @pytest.mark.asyncio -async def test_fetch_response_by_id_api_error(mock_responses, patch_sleep): +async def test_fetch_response_by_id_api_error(mock_responses, patch_sleep) -> None: doc_id = "1" index = FakeIndex(index_name, config) mock_responses.post( @@ -139,7 +139,7 @@ async def test_fetch_response_by_id_api_error(mock_responses, patch_sleep): @pytest.mark.asyncio -async def test_index(mock_responses): +async def test_index(mock_responses) -> None: doc_id = "1" index = ESIndex(index_name, config) mock_responses.post( @@ -156,7 +156,7 @@ async def test_index(mock_responses): @pytest.mark.asyncio -async def test_update(mock_responses): +async def test_update(mock_responses) -> None: doc_id = "1" index = ESIndex(index_name, config) mock_responses.post( @@ -172,7 +172,7 @@ async def test_update(mock_responses): @pytest.mark.asyncio -async def test_update_with_concurrency_control(mock_responses): +async def test_update_with_concurrency_control(mock_responses) -> None: doc_id = "1" index = ESIndex(index_name, config) mock_responses.post( @@ -188,7 +188,7 @@ 
async def test_update_with_concurrency_control(mock_responses): @pytest.mark.asyncio -async def test_update_by_script(): +async def test_update_by_script() -> None: doc_id = "1" script = {"source": ""} index = ESIndex(index_name, config) @@ -202,7 +202,9 @@ async def test_update_by_script(): @pytest.mark.asyncio -async def test_get_all_docs_with_error(mock_responses, patch_logger, patch_sleep): +async def test_get_all_docs_with_error( + mock_responses, patch_logger, patch_sleep +) -> None: index = FakeIndex(index_name, config) mock_responses.post( f"http://nowhere.com:9200/{index_name}/_refresh", headers=headers, status=200 @@ -226,7 +228,7 @@ async def test_get_all_docs_with_error(mock_responses, patch_logger, patch_sleep @pytest.mark.asyncio -async def test_get_all_docs(mock_responses): +async def test_get_all_docs(mock_responses) -> None: index = FakeIndex(index_name, config) total = 3 mock_responses.post( @@ -258,7 +260,7 @@ async def test_get_all_docs(mock_responses): @pytest.mark.asyncio -async def test_es_api_connector_check_in(): +async def test_es_api_connector_check_in() -> None: connector_id = "id" es_api = ESApi(elastic_config=config) @@ -270,7 +272,7 @@ async def test_es_api_connector_check_in(): @pytest.mark.asyncio -async def test_es_api_connector_put(): +async def test_es_api_connector_put() -> None: connector_id = "id" service_type = "service_type" connector_name = "connector_name" @@ -294,7 +296,7 @@ async def test_es_api_connector_put(): @pytest.mark.asyncio -async def test_es_api_connector_update_scheduling(): +async def test_es_api_connector_update_scheduling() -> None: connector_id = "id" scheduling = {"enabled": "true", "interval": "0 4 5 1 *"} @@ -309,7 +311,7 @@ async def test_es_api_connector_update_scheduling(): @pytest.mark.asyncio -async def test_es_api_connector_update_configuration(): +async def test_es_api_connector_update_configuration() -> None: connector_id = "id" configuration = {"config_key": "config_value"} values = {} @@ -325,7 +327,7 @@ async def test_es_api_connector_update_configuration(): @pytest.mark.asyncio -async def test_es_api_connector_update_filtering_draft_validation(): +async def test_es_api_connector_update_filtering_draft_validation() -> None: connector_id = "id" validation_result = {"validation": "result"} @@ -342,7 +344,7 @@ async def test_es_api_connector_update_filtering_draft_validation(): @pytest.mark.asyncio -async def test_es_api_connector_activate_filtering_draft(): +async def test_es_api_connector_activate_filtering_draft() -> None: connector_id = "id" es_api = ESApi(elastic_config=config) @@ -356,7 +358,7 @@ async def test_es_api_connector_activate_filtering_draft(): @pytest.mark.asyncio -async def test_es_api_connector_sync_job_create(): +async def test_es_api_connector_sync_job_create() -> None: connector_id = "id" job_type = "full" trigger_method = "on_demand" @@ -372,7 +374,7 @@ async def test_es_api_connector_sync_job_create(): @pytest.mark.asyncio -async def test_es_api_connector_get(): +async def test_es_api_connector_get() -> None: connector_id = "id" include_deleted = False @@ -387,7 +389,7 @@ async def test_es_api_connector_get(): @pytest.mark.asyncio -async def test_es_api_connector_sync_job_claim(): +async def test_es_api_connector_sync_job_claim() -> None: sync_job_id = "sync_job_id_test" worker_hostname = "workerhostname" sync_cursor = {"foo": "bar"} @@ -405,7 +407,7 @@ async def test_es_api_connector_sync_job_claim(): @pytest.mark.asyncio -async def test_es_api_connector_sync_job_update_stats_with_metadata(): 
+async def test_es_api_connector_sync_job_update_stats_with_metadata() -> None: sync_job_id = "sync_job_id_test" ingestion_stats = {"ingestion": "stat"} metadata = {"meta": "data"} @@ -421,7 +423,7 @@ async def test_es_api_connector_sync_job_update_stats_with_metadata(): @pytest.mark.asyncio -async def test_es_api_connector_sync_job_update_stats_metadata_as_none(): +async def test_es_api_connector_sync_job_update_stats_metadata_as_none() -> None: sync_job_id = "sync_job_id_test" ingestion_stats = {"ingestion": "stat"} metadata = None # make sure metadata gets passed as '{}' if undefined diff --git a/tests/es/test_license.py b/tests/es/test_license.py index bb7d0f97f..01731a4a1 100644 --- a/tests/es/test_license.py +++ b/tests/es/test_license.py @@ -11,14 +11,14 @@ from connectors.protocol import JobType -def mock_source_klass(is_premium): +def mock_source_klass(is_premium) -> Mock: source_klass = Mock() source_klass.is_premium = Mock(return_value=is_premium) return source_klass -def mock_connector(document_level_security_enabled): +def mock_connector(document_level_security_enabled) -> Mock: connector = Mock() connector.features = Mock() connector.features.document_level_security_enabled = Mock( @@ -28,7 +28,7 @@ def mock_connector(document_level_security_enabled): return connector -def mock_sync_job(job_type): +def mock_sync_job(job_type) -> Mock: sync_job = Mock() sync_job.job_type = job_type @@ -45,7 +45,7 @@ def mock_sync_job(job_type): ) def test_requires_platinum_license( job_type, document_level_security_enabled, is_premium -): +) -> None: sync_job = mock_sync_job(job_type) connector = mock_connector(document_level_security_enabled) source_klass = mock_source_klass(is_premium) @@ -63,7 +63,7 @@ def test_requires_platinum_license( ) def test_does_not_require_platinum_license( job_type, document_level_security_enabled, is_premium -): +) -> None: sync_job = mock_sync_job(job_type) connector = mock_connector(document_level_security_enabled) source_klass = mock_source_klass(is_premium) diff --git a/tests/es/test_management_client.py b/tests/es/test_management_client.py index 18ac77858..0da48809a 100644 --- a/tests/es/test_management_client.py +++ b/tests/es/test_management_client.py @@ -32,13 +32,17 @@ def es_management_client(self): yield es_management_client @pytest.mark.asyncio - async def test_ensure_exists_when_no_indices_passed(self, es_management_client): + async def test_ensure_exists_when_no_indices_passed( + self, es_management_client + ) -> None: await es_management_client.ensure_exists() es_management_client.client.indices.exists.assert_not_called() @pytest.mark.asyncio - async def test_ensure_exists_when_indices_passed(self, es_management_client): + async def test_ensure_exists_when_indices_passed( + self, es_management_client + ) -> None: index_name = "search-mongo" es_management_client.client.indices.exists.return_value = False @@ -48,7 +52,7 @@ async def test_ensure_exists_when_indices_passed(self, es_management_client): es_management_client.client.indices.create.assert_called_with(index=index_name) @pytest.mark.asyncio - async def test_create_content_index(self, es_management_client): + async def test_create_content_index(self, es_management_client) -> None: index_name = "search-mongo" lang_code = "en" await es_management_client.create_content_index(index_name, lang_code) @@ -58,7 +62,7 @@ async def test_create_content_index(self, es_management_client): @pytest.mark.asyncio async def test_ensure_ingest_pipeline_exists_when_pipeline_do_not_exist( self, 
es_management_client - ): + ) -> None: pipeline_id = 1 version = 2 description = "that's a pipeline" @@ -84,7 +88,7 @@ async def test_ensure_ingest_pipeline_exists_when_pipeline_do_not_exist( @pytest.mark.asyncio async def test_ensure_ingest_pipeline_exists_when_pipeline_exists( self, es_management_client - ): + ) -> None: pipeline_id = 1 version = 2 description = "that's a pipeline" @@ -99,7 +103,7 @@ async def test_ensure_ingest_pipeline_exists_when_pipeline_exists( es_management_client.client.ingest.put_pipeline.assert_not_called() @pytest.mark.asyncio - async def test_delete_indices(self, es_management_client): + async def test_delete_indices(self, es_management_client) -> None: indices = ["search-mongo"] es_management_client.client.indices.delete = AsyncMock() @@ -109,7 +113,7 @@ async def test_delete_indices(self, es_management_client): ) @pytest.mark.asyncio - async def test_index_exists(self, es_management_client): + async def test_index_exists(self, es_management_client) -> None: index_name = "search-mongo" es_management_client.client.indices.exists = AsyncMock() @@ -117,7 +121,7 @@ async def test_index_exists(self, es_management_client): es_management_client.client.indices.exists.assert_awaited_with(index=index_name) @pytest.mark.asyncio - async def test_clean_index(self, es_management_client): + async def test_clean_index(self, es_management_client) -> None: index_name = "search-mongo" es_management_client.client.indices.exists = AsyncMock() @@ -127,13 +131,13 @@ async def test_clean_index(self, es_management_client): ) @pytest.mark.asyncio - async def test_list_indices(self, es_management_client): + async def test_list_indices(self, es_management_client) -> None: await es_management_client.list_indices(index="search-*") es_management_client.client.indices.stats.assert_awaited_with(index="search-*") @pytest.mark.asyncio - async def test_upsert(self, es_management_client): + async def test_upsert(self, es_management_client) -> None: _id = "123" index_name = "search-mongo" document = {"something": "something"} @@ -147,7 +151,7 @@ async def test_upsert(self, es_management_client): @pytest.mark.asyncio async def test_yield_existing_documents_metadata_when_index_does_not_exist( self, es_management_client, mock_responses - ): + ) -> None: es_management_client.index_exists = AsyncMock(return_value=False) records = [ @@ -171,7 +175,7 @@ async def test_yield_existing_documents_metadata_when_index_does_not_exist( @pytest.mark.asyncio async def test_yield_existing_documents_metadata_when_index_exists( self, es_management_client, mock_responses - ): + ) -> None: es_management_client.index_exists = AsyncMock(return_value=True) records = [ @@ -193,7 +197,9 @@ async def test_yield_existing_documents_metadata_when_index_exists( assert ids == ["1", "2"] @pytest.mark.asyncio - async def test_get_connector_secret(self, es_management_client, mock_responses): + async def test_get_connector_secret( + self, es_management_client, mock_responses + ) -> None: secret_id = "secret-id" es_management_client.client.perform_request = AsyncMock( @@ -209,7 +215,7 @@ async def test_get_connector_secret(self, es_management_client, mock_responses): @pytest.mark.asyncio async def test_get_connector_secret_when_secret_does_not_exist( self, es_management_client, mock_responses - ): + ) -> None: secret_id = "secret-id" error_meta = Mock() @@ -227,7 +233,9 @@ async def test_get_connector_secret_when_secret_does_not_exist( assert secret is None @pytest.mark.asyncio - async def test_create_connector_secret(self, 
es_management_client, mock_responses): + async def test_create_connector_secret( + self, es_management_client, mock_responses + ) -> None: secret_id = "secret-id" secret_value = "my-secret" @@ -245,7 +253,9 @@ async def test_create_connector_secret(self, es_management_client, mock_response ) @pytest.mark.asyncio - async def test_extract_index_or_alias_with_index(self, es_management_client): + async def test_extract_index_or_alias_with_index( + self, es_management_client + ) -> None: response = { "shapo-online": { "aliases": {"search-shapo-online": {}}, @@ -261,7 +271,9 @@ async def test_extract_index_or_alias_with_index(self, es_management_client): assert index == response["shapo-online"] @pytest.mark.asyncio - async def test_extract_index_or_alias_with_alias(self, es_management_client): + async def test_extract_index_or_alias_with_alias( + self, es_management_client + ) -> None: response = { "shapo-online": { "aliases": {"search-shapo-online": {}}, @@ -277,7 +289,9 @@ async def test_extract_index_or_alias_with_alias(self, es_management_client): assert index == response["shapo-online"] @pytest.mark.asyncio - async def test_extract_index_or_alias_when_none_present(self, es_management_client): + async def test_extract_index_or_alias_when_none_present( + self, es_management_client + ) -> None: response = { "shapo-online": { "aliases": {"search-shapo-online": {}}, @@ -293,7 +307,9 @@ async def test_extract_index_or_alias_when_none_present(self, es_management_clie assert index is None @pytest.mark.asyncio - async def test_get_index_or_alias(self, es_management_client, mock_responses): + async def test_get_index_or_alias( + self, es_management_client, mock_responses + ) -> None: secret_id = "secret-id" secret_value = "my-secret" diff --git a/tests/fake_sources.py b/tests/fake_sources.py index cfd49e9ee..5d0c83a6e 100644 --- a/tests/fake_sources.py +++ b/tests/fake_sources.py @@ -8,6 +8,7 @@ """ from functools import partial +from typing import Dict, Optional from unittest.mock import Mock from connectors.filtering.validation import ( @@ -23,7 +24,7 @@ class FakeSource(BaseDataSource): name = "Fakey" service_type = "fake" - def __init__(self, configuration): + def __init__(self, configuration) -> None: self.configuration = configuration if configuration.has_field("raise"): msg = "I break on init" @@ -31,16 +32,16 @@ def __init__(self, configuration): self.fail = configuration.has_field("fail") self.configuration_invalid = configuration.has_field("configuration_invalid") - async def changed(self): + async def changed(self) -> bool: return True - async def ping(self): + async def ping(self) -> None: pass - async def close(self): + async def close(self) -> None: pass - async def _dl(self, doc_id, timestamp=None, doit=None): + async def _dl(self, doc_id, timestamp=None, doit=None) -> Optional[Dict[str, str]]: if not doit: return return {"_id": doc_id, "_timestamp": timestamp, "text": "xx"} @@ -56,18 +57,18 @@ def get_default_configuration(cls): return {} @classmethod - async def validate_filtering(cls, filtering): + async def validate_filtering(cls, filtering) -> FilteringValidationResult: # being explicit about that this result should always be valid return FilteringValidationResult( state=FilteringValidationState.VALID, errors=[] ) - async def validate_config(self): + async def validate_config(self) -> None: if self.configuration_invalid: msg = "I fail when validating configuration" raise ValueError(msg) - def tweak_bulk_options(self, options): + def tweak_bulk_options(self, options) -> None: pass 
@@ -82,7 +83,7 @@ class FakeSourceFilteringValid(FakeSource): service_type = "filtering_state_valid" @classmethod - async def validate_filtering(cls, filtering): + async def validate_filtering(cls, filtering) -> FilteringValidationResult: # use separate fake source to not rely on the behaviour in FakeSource which is used in many tests return FilteringValidationResult( state=FilteringValidationState.VALID, errors=[] @@ -96,7 +97,7 @@ class FakeSourceFilteringStateInvalid(FakeSource): service_type = "filtering_state_invalid" @classmethod - async def validate_filtering(cls, filtering): + async def validate_filtering(cls, filtering) -> FilteringValidationResult: return FilteringValidationResult(state=FilteringValidationState.INVALID) @@ -107,7 +108,7 @@ class FakeSourceFilteringStateEdited(FakeSource): service_type = "filtering_state_edited" @classmethod - async def validate_filtering(cls, filtering): + async def validate_filtering(cls, filtering) -> FilteringValidationResult: return FilteringValidationResult(state=FilteringValidationState.EDITED) @@ -118,7 +119,7 @@ class FakeSourceFilteringErrorsPresent(FakeSource): service_type = "filtering_errors_present" @classmethod - async def validate_filtering(cls, filtering): + async def validate_filtering(cls, filtering) -> FilteringValidationResult: return FilteringValidationResult(errors=[Mock()]) @@ -167,5 +168,5 @@ class PremiumFake(FakeSource): service_type = "premium_fake" @classmethod - def is_premium(): + def is_premium() -> bool: return True diff --git a/tests/filtering/test_basic_rule.py b/tests/filtering/test_basic_rule.py index e6e7bb5aa..6d4c93029 100644 --- a/tests/filtering/test_basic_rule.py +++ b/tests/filtering/test_basic_rule.py @@ -6,6 +6,7 @@ import datetime import uuid +from typing import Any, Dict, Union import pytest @@ -26,7 +27,7 @@ BASIC_RULE_ONE_RULE = "equals" BASIC_RULE_ONE_VALUE = "value one" -BASIC_RULE_ONE_JSON = { +BASIC_RULE_ONE_JSON: Dict[str, Union[int, str]] = { "id": BASIC_RULE_ONE_ID, "order": BASIC_RULE_ONE_ORDER, "policy": BASIC_RULE_ONE_POLICY, @@ -42,7 +43,7 @@ BASIC_RULE_TWO_RULE = "contains" BASIC_RULE_TWO_VALUE = "value two" -BASIC_RULE_TWO_JSON = { +BASIC_RULE_TWO_JSON: Dict[str, Union[int, str]] = { "id": BASIC_RULE_TWO_ID, "order": BASIC_RULE_TWO_ORDER, "policy": BASIC_RULE_TWO_POLICY, @@ -58,7 +59,7 @@ BASIC_RULE_THREE_RULE = "contains" BASIC_RULE_THREE_VALUE = "value three" -BASIC_RULE_THREE_JSON = { +BASIC_RULE_THREE_JSON: Dict[str, Union[int, str]] = { "id": BASIC_RULE_THREE_ID, "order": BASIC_RULE_THREE_ORDER, "policy": BASIC_RULE_THREE_POLICY, @@ -74,7 +75,7 @@ BASIC_RULE_DEFAULT_RULE = "equals" BASIC_RULE_DEFAULT_VALUE = ".*" -BASIC_RULE_DEFAULT_JSON = { +BASIC_RULE_DEFAULT_JSON: Dict[str, Union[int, str]] = { "id": BASIC_RULE_DEFAULT_ID, "order": BASIC_RULE_DEFAULT_ORDER, "policy": BASIC_RULE_DEFAULT_POLICY, @@ -99,7 +100,7 @@ year=2022, month=1, day=1, hour=5, minute=10, microsecond=5 ) -DOCUMENT_ONE = { +DOCUMENT_ONE: Dict[str, Union[datetime.datetime, float, str]] = { DESCRIPTION_KEY: DESCRIPTION_VALUE, AMOUNT_FLOAT_KEY: AMOUNT_FLOAT_VALUE, AMOUNT_INT_KEY: AMOUNT_INT_VALUE, @@ -107,7 +108,7 @@ CREATED_AT_DATETIME_KEY: CREATED_AT_DATETIME_VALUE, } -DOCUMENT_TWO = { +DOCUMENT_TWO: Dict[str, Union[datetime.datetime, float, str]] = { DESCRIPTION_KEY: DESCRIPTION_VALUE[1:], AMOUNT_FLOAT_KEY: AMOUNT_FLOAT_VALUE, AMOUNT_INT_KEY: AMOUNT_INT_VALUE, @@ -115,7 +116,7 @@ CREATED_AT_DATETIME_KEY: CREATED_AT_DATETIME_VALUE, } -DOCUMENT_THREE = { +DOCUMENT_THREE: Dict[str, Union[datetime.datetime, float, 
str]] = { DESCRIPTION_KEY: DESCRIPTION_VALUE[2:], AMOUNT_FLOAT_KEY: AMOUNT_FLOAT_VALUE, AMOUNT_INT_KEY: AMOUNT_INT_VALUE, @@ -168,7 +169,7 @@ "increments, expected_count", [([1, 2, 3], 6), ([None, None, None], 0), ([2, None], 2)], ) -def test_rule_match_stats_increment(increments, expected_count): +def test_rule_match_stats_increment(increments, expected_count) -> None: rule_match_stats = RuleMatchStats(Policy.INCLUDE, 0) for increment in increments: @@ -186,7 +187,7 @@ def test_rule_match_stats_increment(increments, expected_count): ([RuleMatchStats(Policy.INCLUDE, 1), RuleMatchStats(Policy.EXCLUDE, 2)], False), ], ) -def test_rule_match_stats_eq(rule_match_stats, should_equal): +def test_rule_match_stats_eq(rule_match_stats, should_equal) -> None: if should_equal: assert all(stats == rule_match_stats[0] for stats in rule_match_stats[1:]) else: @@ -294,7 +295,9 @@ def test_rule_match_stats_eq(rule_match_stats, should_equal): ), ], ) -def test_engine_should_ingest(documents_should_ingest_tuples, rules, expected_stats): +def test_engine_should_ingest( + documents_should_ingest_tuples, rules, expected_stats +) -> None: engine = BasicRuleEngine(rules) for document_should_ingest_tuple in documents_should_ingest_tuples: @@ -310,7 +313,7 @@ def test_engine_should_ingest(documents_should_ingest_tuples, rules, expected_st ) -def basic_rule_one_policy_and_rule_uppercase(): +def basic_rule_one_policy_and_rule_uppercase() -> Dict[str, Any]: basic_rule_uppercase = BASIC_RULE_ONE_JSON basic_rule_uppercase["rule"] = basic_rule_uppercase["rule"].upper() @@ -319,19 +322,19 @@ def basic_rule_one_policy_and_rule_uppercase(): return basic_rule_uppercase -def contains_rule_one(basic_rules): +def contains_rule_one(basic_rules) -> bool: return any(is_rule_one(basic_rule) for basic_rule in basic_rules) -def contains_rule_two(basic_rules): +def contains_rule_two(basic_rules) -> bool: return any(is_rule_two(basic_rule) for basic_rule in basic_rules) -def contains_rule_three(basic_rules): +def contains_rule_three(basic_rules) -> bool: return any(is_rule_three(basic_rule) for basic_rule in basic_rules) -def contains_default_rule(basic_rules): +def contains_default_rule(basic_rules) -> bool: return any(is_default_rule(basic_rule) for basic_rule in basic_rules) @@ -390,7 +393,9 @@ def is_default_rule(basic_rule): ("eXcLuDe", Policy.EXCLUDE), ], ) -def test_from_string_policy_factory_method(policy_string, expected_parsed_policy): +def test_from_string_policy_factory_method( + policy_string, expected_parsed_policy +) -> None: assert Policy.from_string(policy_string) == expected_parsed_policy @@ -408,7 +413,7 @@ def test_from_string_policy_factory_method(policy_string, expected_parsed_policy ("exclusion", False), ], ) -def test_string_is_policy(policy_string, is_policy): +def test_string_is_policy(policy_string, is_policy) -> None: if is_policy: assert Policy.is_string_policy(policy_string) else: @@ -437,7 +442,7 @@ def test_string_is_policy(policy_string, is_policy): ("sTaRtS_wItH", Rule.STARTS_WITH), ], ) -def test_from_string_rule_factory_method(rule_string, expected_parsed_rule): +def test_from_string_rule_factory_method(rule_string, expected_parsed_rule) -> None: assert Rule.from_string(rule_string) == expected_parsed_rule @@ -466,42 +471,42 @@ def test_from_string_rule_factory_method(rule_string, expected_parsed_rule): ("ends with", False), ], ) -def test_is_string_rule(rule_string, is_rule): +def test_is_string_rule(rule_string, is_rule) -> None: if is_rule: assert Rule.is_string_rule(rule_string) else: assert not 
Rule.is_string_rule(rule_string) -def test_raise_value_error_if_argument_cannot_be_parsed_to_policy(): +def test_raise_value_error_if_argument_cannot_be_parsed_to_policy() -> None: with pytest.raises(ValueError): Policy.from_string("unknown") -def test_raise_value_error_if_argument_cannot_be_parsed_to_rule(): +def test_raise_value_error_if_argument_cannot_be_parsed_to_rule() -> None: with pytest.raises(ValueError): Rule.from_string("unknown") -def test_from_json(): +def test_from_json() -> None: basic_rule = BasicRule.from_json(BASIC_RULE_ONE_JSON) assert is_rule_one(basic_rule) -def test_parse_none_to_empty_array(): +def test_parse_none_to_empty_array() -> None: raw_basic_rules = None assert len(parse(raw_basic_rules)) == 0 -def test_parse_empty_basic_rules_to_empty_array(): +def test_parse_empty_basic_rules_to_empty_array() -> None: raw_basic_rules = [] assert len(parse(raw_basic_rules)) == 0 -def test_parse_one_raw_basic_rule_with_policy_and_rule_lowercase(): +def test_parse_one_raw_basic_rule_with_policy_and_rule_lowercase() -> None: raw_basic_rules = [BASIC_RULE_ONE_JSON] parsed_basic_rules = parse(raw_basic_rules) @@ -510,7 +515,7 @@ def test_parse_one_raw_basic_rule_with_policy_and_rule_lowercase(): assert contains_rule_one(parsed_basic_rules) -def test_parse_one_raw_basic_rule_with_policy_and_rule_uppercase(): +def test_parse_one_raw_basic_rule_with_policy_and_rule_uppercase() -> None: raw_basic_rules = [basic_rule_one_policy_and_rule_uppercase()] parsed_basic_rules = parse(raw_basic_rules) @@ -519,7 +524,7 @@ def test_parse_one_raw_basic_rule_with_policy_and_rule_uppercase(): assert contains_rule_one(parsed_basic_rules) -def test_parses_multiple_rules_correctly(): +def test_parses_multiple_rules_correctly() -> None: raw_basic_rules = [BASIC_RULE_ONE_JSON, BASIC_RULE_TWO_JSON] parsed_basic_rules = parse(raw_basic_rules) @@ -529,7 +534,7 @@ def test_parses_multiple_rules_correctly(): assert contains_rule_two(parsed_basic_rules) -def test_parser_rejects_default_rule(): +def test_parser_rejects_default_rule() -> None: raw_basic_rules = [BASIC_RULE_DEFAULT_JSON, BASIC_RULE_ONE_JSON] parsed_basic_rules = parse(raw_basic_rules) @@ -539,7 +544,7 @@ def test_parser_rejects_default_rule(): assert not contains_default_rule(parsed_basic_rules) -def test_rules_are_ordered_ascending_with_respect_to_the_order_property(): +def test_rules_are_ordered_ascending_with_respect_to_the_order_property() -> None: raw_basic_rules = [BASIC_RULE_ONE_JSON, BASIC_RULE_THREE_JSON, BASIC_RULE_TWO_JSON] parsed_basic_rules = parse(raw_basic_rules) @@ -553,7 +558,7 @@ def test_rules_are_ordered_ascending_with_respect_to_the_order_property(): assert is_rule_three(third_rule) -def test_no_field_leads_to_no_match(): +def test_no_field_leads_to_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -566,7 +571,7 @@ def test_no_field_leads_to_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_starts_with_string_matches(): +def test_starts_with_string_matches() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -579,7 +584,7 @@ def test_starts_with_string_matches(): assert basic_rule.matches(DOCUMENT_ONE) -def test_starts_with_string_no_match(): +def test_starts_with_string_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -592,7 +597,7 @@ def test_starts_with_string_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_ends_with_string_matches(): +def test_ends_with_string_matches() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -605,7 +610,7 @@ def 
test_ends_with_string_matches(): assert basic_rule.matches(DOCUMENT_ONE) -def test_ends_with_string_no_match(): +def test_ends_with_string_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -618,7 +623,7 @@ def test_ends_with_string_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_contains_with_string_matches(): +def test_contains_with_string_matches() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -631,7 +636,7 @@ def test_contains_with_string_matches(): assert basic_rule.matches(DOCUMENT_ONE) -def test_contains_with_string_no_match(): +def test_contains_with_string_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -644,7 +649,7 @@ def test_contains_with_string_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_regex_matches(): +def test_regex_matches() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -657,7 +662,7 @@ def test_regex_matches(): assert basic_rule.matches(DOCUMENT_ONE) -def test_regex_no_match(): +def test_regex_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -670,7 +675,7 @@ def test_regex_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_string_match(): +def test_less_than_string_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -683,7 +688,7 @@ def test_less_than_string_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_less_than_string_no_match_string_is_smaller_lexicographically(): +def test_less_than_string_no_match_string_is_smaller_lexicographically() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -696,7 +701,7 @@ def test_less_than_string_no_match_string_is_smaller_lexicographically(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_string_no_match_string_is_the_same(): +def test_less_than_string_no_match_string_is_the_same() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -709,7 +714,7 @@ def test_less_than_string_no_match_string_is_the_same(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_integer_match(): +def test_less_than_integer_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -722,7 +727,7 @@ def test_less_than_integer_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_less_than_integer_no_match_numbers_are_the_same(): +def test_less_than_integer_no_match_numbers_are_the_same() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -735,7 +740,7 @@ def test_less_than_integer_no_match_numbers_are_the_same(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_integer_no_match_document_value_is_greater(): +def test_less_than_integer_no_match_document_value_is_greater() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -748,7 +753,7 @@ def test_less_than_integer_no_match_document_value_is_greater(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_float_match(): +def test_less_than_float_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -761,7 +766,7 @@ def test_less_than_float_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_less_than_float_no_match_numbers_are_the_same(): +def test_less_than_float_no_match_numbers_are_the_same() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -774,7 +779,7 @@ def test_less_than_float_no_match_numbers_are_the_same(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_float_no_match_document_value_is_greater(): +def test_less_than_float_no_match_document_value_is_greater() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -787,7 +792,7 @@ def 
test_less_than_float_no_match_document_value_is_greater(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_datetime_match(): +def test_less_than_datetime_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -800,7 +805,7 @@ def test_less_than_datetime_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_less_than_datetime_no_match_same_time(): +def test_less_than_datetime_no_match_same_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -813,7 +818,7 @@ def test_less_than_datetime_no_match_same_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_datetime_no_match_later_time(): +def test_less_than_datetime_no_match_later_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -826,7 +831,7 @@ def test_less_than_datetime_no_match_later_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_date_match(): +def test_less_than_date_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -839,7 +844,7 @@ def test_less_than_date_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_less_than_date_no_match_same_time(): +def test_less_than_date_no_match_same_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -852,7 +857,7 @@ def test_less_than_date_no_match_same_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_less_than_date_no_match_later_time(): +def test_less_than_date_no_match_later_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -865,7 +870,7 @@ def test_less_than_date_no_match_later_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_string_match(): +def test_greater_than_string_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -878,7 +883,7 @@ def test_greater_than_string_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_string_no_match_string_is_greater_lexicographically(): +def test_greater_than_string_no_match_string_is_greater_lexicographically() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -891,7 +896,7 @@ def test_greater_than_string_no_match_string_is_greater_lexicographically(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_string_no_match_string_is_the_same(): +def test_greater_than_string_no_match_string_is_the_same() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -904,7 +909,7 @@ def test_greater_than_string_no_match_string_is_the_same(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_integer_match(): +def test_greater_than_integer_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -917,7 +922,7 @@ def test_greater_than_integer_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_integer_no_match_numbers_are_the_same(): +def test_greater_than_integer_no_match_numbers_are_the_same() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -930,7 +935,7 @@ def test_greater_than_integer_no_match_numbers_are_the_same(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_integer_no_match_document_value_is_less(): +def test_greater_than_integer_no_match_document_value_is_less() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -943,7 +948,7 @@ def test_greater_than_integer_no_match_document_value_is_less(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_float_match(): +def test_greater_than_float_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -956,7 +961,7 @@ def test_greater_than_float_match(): assert basic_rule.matches(DOCUMENT_ONE) -def 
test_greater_than_float_no_match_numbers_are_the_same(): +def test_greater_than_float_no_match_numbers_are_the_same() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -969,7 +974,7 @@ def test_greater_than_float_no_match_numbers_are_the_same(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_float_no_match_document_value_is_greater(): +def test_greater_than_float_no_match_document_value_is_greater() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -982,7 +987,7 @@ def test_greater_than_float_no_match_document_value_is_greater(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_datetime_match(): +def test_greater_than_datetime_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -995,7 +1000,7 @@ def test_greater_than_datetime_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_datetime_no_match_same_time(): +def test_greater_than_datetime_no_match_same_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1008,7 +1013,7 @@ def test_greater_than_datetime_no_match_same_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_datetime_no_match_earlier_time(): +def test_greater_than_datetime_no_match_earlier_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1021,7 +1026,7 @@ def test_greater_than_datetime_no_match_earlier_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_date_match(): +def test_greater_than_date_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1034,7 +1039,7 @@ def test_greater_than_date_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_date_no_match_same_time(): +def test_greater_than_date_no_match_same_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1047,7 +1052,7 @@ def test_greater_than_date_no_match_same_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_greater_than_date_no_match_earlier_time(): +def test_greater_than_date_no_match_earlier_time() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1060,7 +1065,7 @@ def test_greater_than_date_no_match_earlier_time(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_equals_integer_match(): +def test_equals_integer_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1073,7 +1078,7 @@ def test_equals_integer_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_equals_integer_no_match(): +def test_equals_integer_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1086,7 +1091,7 @@ def test_equals_integer_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_equals_float_match(): +def test_equals_float_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1099,7 +1104,7 @@ def test_equals_float_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_equals_float_no_match(): +def test_equals_float_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1112,7 +1117,7 @@ def test_equals_float_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_equals_string_match(): +def test_equals_string_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1125,7 +1130,7 @@ def test_equals_string_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_equals_string_no_match(): +def test_equals_string_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1138,7 +1143,7 @@ def test_equals_string_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_equals_datetime_match(): +def test_equals_datetime_match() -> None: basic_rule = BasicRule( id_=1, order=1, 
@@ -1151,7 +1156,7 @@ def test_equals_datetime_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_equals_datetime_no_match(): +def test_equals_datetime_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1164,7 +1169,7 @@ def test_equals_datetime_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_equals_date_match(): +def test_equals_date_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1177,7 +1182,7 @@ def test_equals_date_match(): assert basic_rule.matches(DOCUMENT_ONE) -def test_equals_date_no_match(): +def test_equals_date_no_match() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1190,7 +1195,7 @@ def test_equals_date_no_match(): assert not basic_rule.matches(DOCUMENT_ONE) -def test_coerce_rule_value_to_str(): +def test_coerce_rule_value_to_str() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1208,7 +1213,7 @@ def test_coerce_rule_value_to_str(): assert coerced_rule_value == "string" -def test_coerce_rule_value_to_float(): +def test_coerce_rule_value_to_float() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1224,7 +1229,7 @@ def test_coerce_rule_value_to_float(): assert coerced_rule_value == 1.0 -def test_coerce_rule_value_to_float_if_it_is_an_int(): +def test_coerce_rule_value_to_float_if_it_is_an_int() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1240,7 +1245,7 @@ def test_coerce_rule_value_to_float_if_it_is_an_int(): assert coerced_rule_value == 1.0 -def test_coerce_rule_value_to_bool(): +def test_coerce_rule_value_to_bool() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1256,7 +1261,7 @@ def test_coerce_rule_value_to_bool(): assert bool(coerced_rule_value) -def test_coerce_rule_value_to_datetime_date_if_it_is_datetime(): +def test_coerce_rule_value_to_datetime_date_if_it_is_datetime() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1274,7 +1279,7 @@ def test_coerce_rule_value_to_datetime_date_if_it_is_datetime(): assert coerced_rule_value == datetime.datetime(year=2022, month=1, day=1) -def test_coerce_rule_value_to_datetime_date_if_it_is_date(): +def test_coerce_rule_value_to_datetime_date_if_it_is_date() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1292,7 +1297,7 @@ def test_coerce_rule_value_to_datetime_date_if_it_is_date(): assert coerced_rule_value == datetime.datetime(year=2022, month=1, day=1) -def test_coerce_rule_to_default_if_type_is_not_registered(): +def test_coerce_rule_to_default_if_type_is_not_registered() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1311,7 +1316,9 @@ def test_coerce_rule_to_default_if_type_is_not_registered(): assert coerced_rule_value == "something" -def test_coerce_rule_to_default_if_doc_value_type_not_matching_rule_value_type(): +def test_coerce_rule_to_default_if_doc_value_type_not_matching_rule_value_type() -> ( + None +): basic_rule = BasicRule( id_=1, order=1, @@ -1327,7 +1334,7 @@ def test_coerce_rule_to_default_if_doc_value_type_not_matching_rule_value_type() assert coerced_rule_value == "something" -def test_is_include_for_include_policy(): +def test_is_include_for_include_policy() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1340,7 +1347,7 @@ def test_is_include_for_include_policy(): assert basic_rule.is_include() -def test_is_not_include_for_exclude_policy(): +def test_is_not_include_for_exclude_policy() -> None: basic_rule = BasicRule( id_=1, order=1, @@ -1353,7 +1360,7 @@ def test_is_not_include_for_exclude_policy(): assert not basic_rule.is_include() -def test_basic_rule_str(): +def test_basic_rule_str() -> None: basic_rule = 
BasicRule( id_=1, order=1, @@ -1369,7 +1376,7 @@ def test_basic_rule_str(): ) -def test_basic_rule_format(): +def test_basic_rule_format() -> None: basic_rule = BasicRule( id_=str(uuid.UUID.bytes), order=1, diff --git a/tests/filtering/test_validation.py b/tests/filtering/test_validation.py index 2170556b1..7eea38dda 100644 --- a/tests/filtering/test_validation.py +++ b/tests/filtering/test_validation.py @@ -4,6 +4,7 @@ # you may not use this file except in compliance with the Elastic License 2.0. # +from typing import Dict, Union from unittest.mock import AsyncMock, Mock import pytest @@ -42,7 +43,7 @@ ADVANCED_RULE_TWO_ID = 6 ADVANCED_RULE_TWO_VALIDATION_MESSAGE = "rule 6 is invalid" -RULE_ONE = { +RULE_ONE: Dict[str, Union[int, str]] = { "id": RULE_ONE_ID, "order": 1, "policy": "include", @@ -51,7 +52,7 @@ "value": "value", } -RULE_TWO = { +RULE_TWO: Dict[str, Union[int, str]] = { "id": RULE_TWO_ID, "order": 2, "policy": "include", @@ -100,11 +101,13 @@ ), ], ) -def test_sync_rule_validation_result_eq(result_one, result_two, should_be_equal): +def test_sync_rule_validation_result_eq( + result_one, result_two, should_be_equal +) -> None: assert result_one == result_two if should_be_equal else result_one != result_two -def test_sync_rule_validation_result_eq_wrong_type(): +def test_sync_rule_validation_result_eq_wrong_type() -> None: with pytest.raises(TypeError): assert ( SyncRuleValidationResult(RULE_ONE_ID, True, RULE_ONE_VALIDATION_MESSAGE) @@ -219,7 +222,9 @@ def test_sync_rule_validation_result_eq_wrong_type(): ), ], ) -def test_filtering_validation_result_eq(result_one, result_two, should_be_equal): +def test_filtering_validation_result_eq( + result_one, result_two, should_be_equal +) -> None: assert result_one == result_two if should_be_equal else result_one != result_two @@ -280,11 +285,11 @@ def test_filtering_validation_result_eq(result_one, result_two, should_be_equal) ), ], ) -def test_filter_validation_error_eq(error_one, error_two, should_be_equal): +def test_filter_validation_error_eq(error_one, error_two, should_be_equal) -> None: assert error_one == error_two if should_be_equal else error_one != error_two -def test_valid_result_sync_rule(): +def test_valid_result_sync_rule() -> None: result = SyncRuleValidationResult.valid_result(RULE_ONE_ID) assert result.rule_id == RULE_ONE_ID @@ -391,7 +396,7 @@ def test_valid_result_sync_rule(): ) def test_filtering_validation_result( sync_rule_validation_results, expected_filtering_validation_result -): +) -> None: filtering_validation_result = FilteringValidationResult() for result in sync_rule_validation_results: @@ -676,7 +681,7 @@ def test_filtering_validation_result( @pytest.mark.asyncio async def test_filtering_validator( basic_rule_validation_results, advanced_rule_validation_results, expected_result -): +) -> None: basic_rule_validators = validator_fakes(basic_rule_validation_results) advanced_rule_validators = validator_fakes( advanced_rule_validation_results, is_basic_rule_validator=False @@ -702,7 +707,9 @@ async def test_filtering_validator( @pytest.mark.asyncio -async def test_filtering_validator_validate_when_advanced_rules_empty_then_skip_validation(): +async def test_filtering_validator_validate_when_advanced_rules_empty_then_skip_validation() -> ( + None +): invalid_validation_result = SyncRuleValidationResult( rule_id=RULE_TWO_ID, is_valid=False, @@ -723,7 +730,7 @@ async def test_filtering_validator_validate_when_advanced_rules_empty_then_skip_ assert validation_result.state == FilteringValidationState.VALID -def 
validator_fakes(results, is_basic_rule_validator=True): +def validator_fakes(results, is_basic_rule_validator: bool = True): validators = [] # validator 1 returns result 1, validator 2 returns result 2 ... @@ -755,7 +762,7 @@ def validator_fakes(results, is_basic_rule_validator=True): return validators -def assert_validators_called_with(validators, payload): +def assert_validators_called_with(validators, payload) -> None: for validator in validators: if issubclass(validator, BasicRulesSetValidator): validator.validate.assert_called_with(payload) @@ -834,7 +841,7 @@ def assert_validators_called_with(validators, payload): @pytest.mark.asyncio async def test_filtering_validator_multiple_basic_rules( basic_rules_validation_results, expected_result -): +) -> None: basic_rule_validator = BasicRulesSetValidator # side_effect: return result 1 on first call, result 2 on second call ... basic_rule_validator.validate = Mock(side_effect=[basic_rules_validation_results]) @@ -848,7 +855,7 @@ async def test_filtering_validator_multiple_basic_rules( assert validation_result == expected_result -def basic_rule_json(merge_with=None, delete_keys=None): +def basic_rule_json(merge_with=None, delete_keys=None) -> Dict[str, Union[int, str]]: # Default arguments are mutable if delete_keys is None: delete_keys = [] @@ -890,7 +897,7 @@ def basic_rule_json(merge_with=None, delete_keys=None): (basic_rule_json(merge_with={"rule": "regex", "value": "(.*)"}), False), ], ) -def test_basic_rule_validate_no_match_all_regex(basic_rule, should_be_valid): +def test_basic_rule_validate_no_match_all_regex(basic_rule, should_be_valid) -> None: if should_be_valid: assert BasicRuleNoMatchAllRegexValidator.validate(basic_rule).is_valid else: @@ -947,7 +954,7 @@ def test_basic_rule_validate_no_match_all_regex(basic_rule, should_be_valid): ), ], ) -def test_basic_rule_against_schema_validation(basic_rule, should_be_valid): +def test_basic_rule_against_schema_validation(basic_rule, should_be_valid) -> None: if should_be_valid: assert BasicRuleAgainstSchemaValidator.validate(basic_rule).is_valid else: @@ -1001,7 +1008,7 @@ def test_basic_rule_against_schema_validation(basic_rule, should_be_valid): ) def test_basic_rules_set_no_conflicting_policies_validation( basic_rules, should_be_valid -): +) -> None: validation_results = BasicRulesSetSemanticValidator.validate(basic_rules) if should_be_valid: @@ -1023,12 +1030,12 @@ def test_basic_rules_set_no_conflicting_policies_validation( ("edited", FilteringValidationState.EDITED), ], ) -def test_filtering_validation_state_from_string(string, expected_state): +def test_filtering_validation_state_from_string(string, expected_state) -> None: assert FilteringValidationState(string) == expected_state @pytest.mark.asyncio -async def test_filtering_validator_validate_single_advanced_rules_validator(): +async def test_filtering_validator_validate_single_advanced_rules_validator() -> None: invalid_validation_result = SyncRuleValidationResult( rule_id=RULE_TWO_ID, is_valid=False, diff --git a/tests/protocol/test_connectors.py b/tests/protocol/test_connectors.py index 45a101198..3076d1d40 100644 --- a/tests/protocol/test_connectors.py +++ b/tests/protocol/test_connectors.py @@ -7,6 +7,7 @@ import os from copy import deepcopy from datetime import datetime, timedelta, timezone +from typing import Dict, List, Union from unittest.mock import ANY, AsyncMock, Mock, patch import pytest @@ -44,10 +45,10 @@ from connectors.utils import ACCESS_CONTROL_INDEX_PREFIX, iso_utc from tests.commons import AsyncIterator 
-HERE = os.path.dirname(__file__) -FIXTURES_DIR = os.path.abspath(os.path.join(HERE, "..", "fixtures")) +HERE: str = os.path.dirname(__file__) +FIXTURES_DIR: str = os.path.abspath(os.path.join(HERE, "..", "fixtures")) -CONFIG = os.path.join(FIXTURES_DIR, "config.yml") +CONFIG: str = os.path.join(FIXTURES_DIR, "config.yml") DEFAULT_DOMAIN = "DEFAULT" @@ -175,15 +176,19 @@ ADVANCED_RULES = {"db": {"table": "SELECT * FROM db.table"}} -ADVANCED_RULES_NON_EMPTY = {"advanced_snippet": ADVANCED_RULES} +ADVANCED_RULES_NON_EMPTY: Dict[str, Dict[str, Dict[str, str]]] = { + "advanced_snippet": ADVANCED_RULES +} RULES = [ { "id": 1, } ] -BASIC_RULES_NON_EMPTY = {"rules": RULES} -ADVANCED_AND_BASIC_RULES_NON_EMPTY = { +BASIC_RULES_NON_EMPTY: Dict[str, List[Dict[str, int]]] = {"rules": RULES} +ADVANCED_AND_BASIC_RULES_NON_EMPTY: Dict[ + str, Union[Dict[str, Dict[str, str]], List[Dict[str, int]]] +] = { "advanced_snippet": {"db": {"table": "SELECT * FROM db.table"}}, "rules": RULES, } @@ -193,14 +198,16 @@ INDEX_NAME = "index_name" -def test_utc(): +def test_utc() -> None: # All dates are in ISO 8601 UTC so we can serialize them now = datetime.utcnow() then = json.loads(json.dumps({"date": iso_utc(when=now)}))["date"] assert now.isoformat() == then -mongo = { +mongo: Dict[ + str, Union[None, Dict[str, Dict[str, str]], Dict[str, Union[bool, str]], str] +] = { "api_key_id": "", "api_key_secret_id": "", "configuration": { @@ -237,7 +244,7 @@ def test_utc(): ) async def test_supported_connectors( native_service_types, connector_ids, expected_connector_count, mock_responses -): +) -> None: config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} native_connectors_query = { "bool": { @@ -296,7 +303,7 @@ async def test_supported_connectors( @pytest.mark.asyncio -async def test_all_connectors(mock_responses): +async def test_all_connectors(mock_responses) -> None: config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} headers = {"X-Elastic-Product": "Elasticsearch"} mock_responses.post( @@ -321,7 +328,7 @@ async def test_all_connectors(mock_responses): @pytest.mark.asyncio -async def test_connector_properties(): +async def test_connector_properties() -> None: connector_src = { "_id": "test", "_source": { @@ -387,7 +394,7 @@ async def test_connector_properties(): (60, iso_utc(datetime.now(timezone.utc) - timedelta(seconds=70)), True), ], ) -async def test_heartbeat(interval, last_seen, should_send_heartbeat): +async def test_heartbeat(interval, last_seen, should_send_heartbeat) -> None: source = { "_id": "1", "_source": { @@ -439,7 +446,7 @@ async def test_heartbeat(interval, last_seen, should_send_heartbeat): ], ) @pytest.mark.asyncio -async def test_sync_starts(job_type, expected_doc_source_update): +async def test_sync_starts(job_type, expected_doc_source_update) -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -458,7 +465,7 @@ async def test_sync_starts(job_type, expected_doc_source_update): @pytest.mark.asyncio -async def test_connector_error(): +async def test_connector_error() -> None: connector_doc = {"_id": "1"} error = "something wrong" index = Mock() @@ -474,11 +481,11 @@ async def test_connector_error(): def mock_job( - status=JobStatus.COMPLETED, - job_type=JobType.FULL, + status: JobStatus = JobStatus.COMPLETED, + job_type: JobType = JobType.FULL, error=None, - terminated=True, -): + terminated: bool = True, +) -> Mock: job = Mock() job.status = status job.error = error @@ -625,7 +632,7 @@ def mock_job( ), ], ) -async def test_sync_done(job, 
expected_doc_source_update): +async def test_sync_done(job, expected_doc_source_update) -> None: connector_doc = {"_id": "1"} index = Mock() index.update = AsyncMock(return_value=1) @@ -635,7 +642,7 @@ async def test_sync_done(job, expected_doc_source_update): index.update.assert_called_with(doc_id=connector.id, doc=expected_doc_source_update) -mock_next_run = iso_utc() +mock_next_run: str = iso_utc() @pytest.mark.asyncio @@ -653,7 +660,7 @@ async def test_sync_done(job, expected_doc_source_update): @patch("connectors.protocol.connectors.next_run") async def test_connector_next_sync( next_run, scheduling_enabled, expected_next_sync, job_type -): +) -> None: connector_doc = { "_id": "1", "_source": { @@ -681,7 +688,7 @@ async def test_connector_next_sync( @pytest.mark.asyncio -async def test_sync_job_properties(): +async def test_sync_job_properties() -> None: sync_job_src = { "_id": "test", "_source": { @@ -737,7 +744,7 @@ async def test_sync_job_properties(): (JobType.ACCESS_CONTROL, False), ], ) -def test_is_content_sync(job_type, is_content_sync): +def test_is_content_sync(job_type, is_content_sync) -> None: source = {"_id": "1", "_source": {"job_type": job_type.value}} sync_job = SyncJob(elastic_index=None, doc_source=source) assert sync_job.is_content_sync() == is_content_sync @@ -761,7 +768,7 @@ def test_is_content_sync(job_type, is_content_sync): ) async def test_sync_job_validate_filtering( validation_result_state, validation_result_errors, should_raise_exception -): +) -> None: source = {"_id": "1"} index = Mock() validator = Mock() @@ -779,7 +786,7 @@ async def test_sync_job_validate_filtering( @pytest.mark.asyncio -async def test_sync_job_claim(): +async def test_sync_job_claim() -> None: source = {"_id": "1"} index = Mock() index.update = AsyncMock(return_value=1) @@ -799,7 +806,7 @@ async def test_sync_job_claim(): @pytest.mark.asyncio -async def test_sync_job_claim_with_connector_api(set_env): +async def test_sync_job_claim_with_connector_api(set_env) -> None: source = {"_id": "1"} index = Mock() index.api.connector_sync_job_claim = AsyncMock(return_value={"result": "updated"}) @@ -814,7 +821,7 @@ async def test_sync_job_claim_with_connector_api(set_env): @pytest.mark.asyncio -async def test_sync_job_claim_fails(): +async def test_sync_job_claim_fails() -> None: source = {"_id": "1"} index = Mock() api_meta = Mock() @@ -836,7 +843,7 @@ async def test_sync_job_claim_fails(): @pytest.mark.asyncio -async def test_sync_job_update_metadata(): +async def test_sync_job_update_metadata() -> None: source = {"_id": "1"} index = Mock() index.feature_use_connectors_api = False @@ -865,7 +872,7 @@ async def test_sync_job_update_metadata(): @pytest.mark.asyncio -async def test_sync_job_update_metadata_with_connector_api(): +async def test_sync_job_update_metadata_with_connector_api() -> None: source = {"_id": "1"} index = Mock() index.feature_use_connectors_api = True @@ -892,7 +899,7 @@ async def test_sync_job_update_metadata_with_connector_api(): @pytest.mark.asyncio -async def test_sync_job_done(): +async def test_sync_job_done() -> None: source = {"_id": "1"} index = Mock() index.update = AsyncMock(return_value=1) @@ -919,7 +926,7 @@ async def test_sync_job_done(): (123, "123"), ], ) -async def test_sync_job_fail(error, expected_message): +async def test_sync_job_fail(error, expected_message) -> None: source = {"_id": "1"} index = Mock() index.update = AsyncMock(return_value=1) @@ -937,7 +944,7 @@ async def test_sync_job_fail(error, expected_message): @pytest.mark.asyncio -async def 
test_sync_job_cancel(): +async def test_sync_job_cancel() -> None: source = {"_id": "1"} index = Mock() index.update = AsyncMock(return_value=1) @@ -956,7 +963,7 @@ async def test_sync_job_cancel(): @pytest.mark.asyncio -async def test_sync_job_suspend(): +async def test_sync_job_suspend() -> None: source = {"_id": "1"} index = Mock() index.update = AsyncMock(return_value=1) @@ -976,12 +983,12 @@ class Banana(BaseDataSource): """Banana""" @classmethod - def get_default_configuration(cls): + def get_default_configuration(cls) -> Dict[str, Dict[str, None]]: return {"one": {"value": None}, "two": {"value": None}} @pytest.mark.asyncio -async def test_connector_prepare_different_id_invalid_source(): +async def test_connector_prepare_different_id_invalid_source() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1012,7 +1019,9 @@ async def test_connector_prepare_different_id_invalid_source(): ], ) @pytest.mark.asyncio -async def test_connector_prepare_with_prepared_connector(main_doc_id, this_doc_id): +async def test_connector_prepare_with_prepared_connector( + main_doc_id, this_doc_id +) -> None: seq_no = 1 primary_term = 2 connector_doc = { @@ -1079,7 +1088,7 @@ async def test_connector_prepare_with_prepared_connector(main_doc_id, this_doc_i @pytest.mark.asyncio async def test_connector_prepare_with_connector_empty_config_creates_default( main_doc_id, this_doc_id -): +) -> None: seq_no = 1 primary_term = 2 connector_doc = { @@ -1124,7 +1133,7 @@ async def test_connector_prepare_with_connector_empty_config_creates_default( @pytest.mark.asyncio async def test_connector_prepare_with_connector_missing_fields_creates_them( main_doc_id, this_doc_id -): +) -> None: seq_no = 1 primary_term = 2 connector_doc = { @@ -1186,7 +1195,7 @@ async def test_connector_prepare_with_connector_missing_fields_creates_them( @pytest.mark.asyncio async def test_connector_prepare_with_connector_missing_field_properties_creates_them( main_doc_id, this_doc_id -): +) -> None: seq_no = 1 primary_term = 2 connector_doc = { @@ -1243,7 +1252,7 @@ async def test_connector_prepare_with_connector_missing_field_properties_creates @pytest.mark.asyncio -async def test_connector_prepare_with_service_type_not_configured(): +async def test_connector_prepare_with_service_type_not_configured() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1267,7 +1276,7 @@ async def test_connector_prepare_with_service_type_not_configured(): @pytest.mark.asyncio -async def test_connector_prepare_with_service_type_not_supported(): +async def test_connector_prepare_with_service_type_not_supported() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1292,7 +1301,7 @@ async def test_connector_prepare_with_service_type_not_supported(): @pytest.mark.asyncio -async def test_connector_prepare_with_data_source_error(): +async def test_connector_prepare_with_data_source_error() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1317,7 +1326,7 @@ async def test_connector_prepare_with_data_source_error(): @pytest.mark.asyncio -async def test_connector_prepare_with_different_features(): +async def test_connector_prepare_with_different_features() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1350,7 +1359,7 @@ async def test_connector_prepare_with_different_features(): @pytest.mark.asyncio -async def test_connector_prepare(): +async def test_connector_prepare() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1384,7 +1393,7 @@ async def test_connector_prepare(): @pytest.mark.asyncio -async def test_connector_prepare_with_race_condition(): 
+async def test_connector_prepare_with_race_condition() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -1443,7 +1452,7 @@ async def test_connector_prepare_with_race_condition(): @pytest.mark.asyncio async def test_connector_update_last_sync_scheduled_at_by_job_type( job_type, date_field_to_update -): +) -> None: doc_id = "2" seq_no = 2 primary_term = 1 @@ -1468,7 +1477,7 @@ async def test_connector_update_last_sync_scheduled_at_by_job_type( @pytest.mark.asyncio -async def test_connector_validate_filtering_not_edited(): +async def test_connector_validate_filtering_not_edited() -> None: index = Mock() index.update = AsyncMock() index.feature_use_connectors_api = False @@ -1484,7 +1493,7 @@ async def test_connector_validate_filtering_not_edited(): @pytest.mark.asyncio -async def test_connector_validate_filtering_invalid(): +async def test_connector_validate_filtering_invalid() -> None: doc_source = deepcopy(DOC_SOURCE_WITH_EDITED_FILTERING) index = Mock() index.update = AsyncMock() @@ -1521,7 +1530,7 @@ async def test_connector_validate_filtering_invalid(): @pytest.mark.asyncio -async def test_connector_validate_filtering_valid(): +async def test_connector_validate_filtering_valid() -> None: doc_source = deepcopy(DOC_SOURCE_WITH_EDITED_FILTERING) index = Mock() index.update = AsyncMock() @@ -1556,7 +1565,7 @@ async def test_connector_validate_filtering_valid(): @pytest.mark.asyncio -async def test_connector_validate_filtering_with_race_condition(): +async def test_connector_validate_filtering_with_race_condition() -> None: doc_source = deepcopy(DOC_SOURCE_WITH_EDITED_FILTERING) index = Mock() index.update = AsyncMock() @@ -1594,7 +1603,7 @@ async def test_connector_validate_filtering_with_race_condition(): @pytest.mark.asyncio -async def test_connector_validate_filtering_invalid_with_connector_api(set_env): +async def test_connector_validate_filtering_invalid_with_connector_api(set_env) -> None: doc_source = deepcopy(DOC_SOURCE_WITH_EDITED_FILTERING) index = Mock() index.api.connector_update_filtering_draft_validation = AsyncMock() @@ -1620,7 +1629,7 @@ async def test_connector_validate_filtering_invalid_with_connector_api(set_env): @pytest.mark.asyncio -async def test_connector_validate_filtering_valid_with_connector_api(set_env): +async def test_connector_validate_filtering_valid_with_connector_api(set_env) -> None: doc_source = deepcopy(DOC_SOURCE_WITH_EDITED_FILTERING) index = Mock() index.api.connector_update_filtering_draft_validation = AsyncMock() @@ -1649,7 +1658,7 @@ async def test_connector_validate_filtering_valid_with_connector_api(set_env): @pytest.mark.asyncio -async def test_connector_exists_returns_true_when_found(): +async def test_connector_exists_returns_true_when_found() -> None: config = { "username": "elastic", "password": "changeme", @@ -1664,7 +1673,7 @@ async def test_connector_exists_returns_true_when_found(): @pytest.mark.asyncio -async def test_connector_exists_returns_false_when_not_found(): +async def test_connector_exists_returns_false_when_not_found() -> None: config = { "username": "elastic", "password": "changeme", @@ -1686,7 +1695,7 @@ async def test_connector_exists_returns_false_when_not_found(): @pytest.mark.asyncio -async def test_connector_exists_raises_non_404_exception(): +async def test_connector_exists_raises_non_404_exception() -> None: config = { "username": "elastic", "password": "changeme", @@ -1701,7 +1710,7 @@ async def test_connector_exists_raises_non_404_exception(): @pytest.mark.asyncio -async def test_document_count(): +async def 
test_document_count() -> None: expected_count = 20 index = Mock() index.serverless = False @@ -1741,7 +1750,9 @@ async def test_document_count(): (None, ACTIVE_FILTER_STATE, NON_EXISTING_DOMAIN, EMPTY_FILTER), ], ) -def test_get_filter(filtering_json, filter_state, domain, expected_filter): +def test_get_filter( + filtering_json, filter_state: str, domain: str, expected_filter +) -> None: filtering = Filtering(filtering_json) assert filtering.get_filter(filter_state, domain) == expected_filter @@ -1754,7 +1765,7 @@ def test_get_filter(filtering_json, filter_state, domain, expected_filter): (None, ACTIVE_FILTERING_DEFAULT_DOMAIN), ], ) -def test_get_active_filter(domain, expected_filter): +def test_get_active_filter(domain: str, expected_filter) -> None: filtering = Filtering(FILTERING) if domain is not None: @@ -1770,7 +1781,7 @@ def test_get_active_filter(domain, expected_filter): (None, DRAFT_FILTERING_DEFAULT_DOMAIN), ], ) -def test_get_draft_filter(domain, expected_filter): +def test_get_draft_filter(domain: str, expected_filter) -> None: filtering = Filtering(FILTERING) if domain is not None: @@ -1795,7 +1806,7 @@ def test_get_draft_filter(domain, expected_filter): (None, {"advanced_snippet": {}, "rules": []}), ], ) -def test_transform_filtering(filtering, expected_transformed_filtering): +def test_transform_filtering(filtering, expected_transformed_filtering) -> None: assert ( Filter(filter_=filtering).transform_filtering() == expected_transformed_filtering @@ -1912,7 +1923,7 @@ def test_transform_filtering(filtering, expected_transformed_filtering): ), ], ) -def test_feature_enabled(features_json, feature_enabled): +def test_feature_enabled(features_json, feature_enabled) -> None: features = Features(features_json) assert all( @@ -1988,7 +1999,7 @@ def test_feature_enabled(features_json, feature_enabled): ({}, False), ], ) -def test_sync_rules_enabled(features_json, sync_rules_enabled): +def test_sync_rules_enabled(features_json, sync_rules_enabled) -> None: features = Features(features_json) assert features.sync_rules_enabled() == sync_rules_enabled @@ -2022,7 +2033,7 @@ def test_sync_rules_enabled(features_json, sync_rules_enabled): ({}, False), ], ) -def test_incremental_sync_enabled(features_json, incremental_sync_enabled): +def test_incremental_sync_enabled(features_json, incremental_sync_enabled) -> None: features = Features(features_json) assert features.incremental_sync_enabled() == incremental_sync_enabled @@ -2037,7 +2048,7 @@ def test_incremental_sync_enabled(features_json, incremental_sync_enabled): ], ) @patch("connectors.protocol.SyncJobIndex.index") -async def test_create_job(index_method, trigger_method, set_env): +async def test_create_job(index_method, trigger_method, set_env) -> None: connector = Mock() connector.id = "id" connector.index_name = "index_name" @@ -2079,7 +2090,7 @@ async def test_create_job(index_method, trigger_method, set_env): (JobTriggerMethod.SCHEDULED, JobType.ACCESS_CONTROL), ], ) -async def test_create_job_with_connector_api(trigger_method, job_type, set_env): +async def test_create_job_with_connector_api(trigger_method, job_type, set_env) -> None: connector = Mock() connector.id = "id" config = load_config(CONFIG) @@ -2113,7 +2124,7 @@ async def test_create_job_with_connector_api(trigger_method, job_type, set_env): ) async def test_create_jobs_with_correct_target_index( index_method, job_type, target_index_name, set_env -): +) -> None: connector = Mock() connector.index_name = INDEX_NAME config = load_config(CONFIG) @@ -2166,7 +2177,7 @@ 
async def test_create_jobs_with_correct_target_index( @patch("connectors.protocol.SyncJobIndex.get_all_docs") async def test_pending_jobs( get_all_docs, job_types, job_type_query, remote_call, set_env -): +) -> None: job = Mock() get_all_docs.return_value = AsyncIterator([job]) config = load_config(CONFIG) @@ -2208,7 +2219,7 @@ async def test_pending_jobs( @pytest.mark.asyncio @patch("connectors.protocol.SyncJobIndex.get_all_docs") -async def test_orphaned_idle_jobs(get_all_docs, set_env): +async def test_orphaned_idle_jobs(get_all_docs, set_env) -> None: job = Mock() get_all_docs.return_value = AsyncIterator([job]) config = load_config(CONFIG) @@ -2242,7 +2253,7 @@ async def test_orphaned_idle_jobs(get_all_docs, set_env): @pytest.mark.asyncio @patch("connectors.protocol.SyncJobIndex.get_all_docs") -async def test_idle_jobs(get_all_docs, set_env): +async def test_idle_jobs(get_all_docs, set_env) -> None: job = Mock() get_all_docs.return_value = AsyncIterator([job]) config = load_config(CONFIG) @@ -2283,7 +2294,7 @@ async def test_idle_jobs(get_all_docs, set_env): (None, False), ], ) -def test_advanced_rules_present(filtering, should_advanced_rules_be_present): +def test_advanced_rules_present(filtering, should_advanced_rules_be_present) -> None: assert Filter(filtering).has_advanced_rules() == should_advanced_rules_be_present @@ -2298,7 +2309,7 @@ def test_advanced_rules_present(filtering, should_advanced_rules_be_present): ) def test_has_validation_state( filtering, validation_state, has_expected_validation_state -): +) -> None: assert ( Filter(filtering).has_validation_state(validation_state) == has_expected_validation_state @@ -2314,7 +2325,7 @@ def test_has_validation_state( ("run_ml_inference", False, True), ], ) -def test_pipeline_properties(key, value, default_value): +def test_pipeline_properties(key, value, default_value) -> None: assert Pipeline({})[key] == default_value assert Pipeline({key: value})[key] == value @@ -2329,11 +2340,11 @@ def test_pipeline_properties(key, value, default_value): (None, {}), ], ) -def test_get_advanced_rules(filtering, expected_advanced_rules): +def test_get_advanced_rules(filtering, expected_advanced_rules) -> None: assert Filter(filtering).get_advanced_rules() == expected_advanced_rules -def test_updated_configuration_fields(): +def test_updated_configuration_fields() -> None: current = { "tenant_id": {"label": "Tenant ID", "order": 1, "type": "str", "value": "foo"}, "tenant_name": { @@ -2412,7 +2423,7 @@ def test_updated_configuration_fields(): @pytest.mark.asyncio -async def test_native_connector_missing_features(): +async def test_native_connector_missing_features() -> None: doc_id = "1" seq_no = 1 primary_term = 2 @@ -2452,7 +2463,7 @@ async def test_native_connector_missing_features(): @pytest.mark.asyncio -async def test_get_connector_by_index(): +async def test_get_connector_by_index() -> None: config = { "username": "elastic", "password": "changeme", @@ -2489,7 +2500,7 @@ async def test_get_connector_by_index(): ), ], ) -def test_property_as_datetime(indexed_timestamp, expected_datetime): +def test_property_as_datetime(indexed_timestamp, expected_datetime) -> None: connector = Connector( elastic_index=Mock(), doc_source={ diff --git a/tests/services/test_base.py b/tests/services/test_base.py index 0abb49e4b..3f1718fbe 100644 --- a/tests/services/test_base.py +++ b/tests/services/test_base.py @@ -8,19 +8,45 @@ import os from collections import defaultdict from copy import deepcopy +from typing import Dict, List, Optional, Type, Union from 
unittest.mock import Mock import pytest from connectors.config import load_config +from connectors.services.access_control_sync_job_execution import ( + AccessControlSyncJobExecutionService, +) from connectors.services.base import BaseService, MultiService, get_services - -HERE = os.path.dirname(__file__) -FIXTURES_DIR = os.path.abspath(os.path.join(HERE, "..", "fixtures")) -CONFIG_FILE = os.path.join(FIXTURES_DIR, "config.yml") - - -def create_service(service_klass, config=None, config_file=None, idling=None): +from connectors.services.content_sync_job_execution import ( + ContentSyncJobExecutionService, +) +from connectors.services.job_cleanup import JobCleanUpService +from connectors.services.job_scheduling import JobSchedulingService + +HERE: str = os.path.dirname(__file__) +FIXTURES_DIR: str = os.path.abspath(os.path.join(HERE, "..", "fixtures")) +CONFIG_FILE: str = os.path.join(FIXTURES_DIR, "config.yml") + + +def create_service( + service_klass: Union[ + Type[JobCleanUpService], + Type[JobSchedulingService], + Type[ContentSyncJobExecutionService], + Type[AccessControlSyncJobExecutionService], + ], + config: Optional[ + Dict[str, Union[Dict[str, str], Dict[str, int], List[str]]] + ] = None, + config_file: Optional[str] = None, + idling: None = None, +) -> Union[ + JobSchedulingService, + ContentSyncJobExecutionService, + AccessControlSyncJobExecutionService, + JobCleanUpService, +]: if config is None: config = load_config(config_file) if config_file else {} service = service_klass(config) @@ -29,7 +55,15 @@ def create_service(service_klass, config=None, config_file=None, idling=None): return service -async def run_service_with_stop_after(service, stop_after=0): +async def run_service_with_stop_after( + service: Union[ + JobSchedulingService, + ContentSyncJobExecutionService, + AccessControlSyncJobExecutionService, + JobCleanUpService, + ], + stop_after: int = 0, +) -> None: def _stop_running_service_without_cancelling(): service.running = False @@ -50,14 +84,24 @@ async def _terminate(): async def create_and_run_service( - service_klass, config=None, config_file=CONFIG_FILE, stop_after=0 -): + service_klass: Union[ + Type[ContentSyncJobExecutionService], + Type[JobSchedulingService], + Type[JobCleanUpService], + Type[AccessControlSyncJobExecutionService], + ], + config: Optional[ + Dict[str, Union[Dict[str, str], Dict[str, int], List[str]]] + ] = None, + config_file: str = CONFIG_FILE, + stop_after: int = 0, +) -> None: service = create_service(service_klass, config=config, config_file=config_file) await run_service_with_stop_after(service, stop_after) class StubService: - def __init__(self): + def __init__(self) -> None: self.running = False self.cancelled = False self.exploding = False @@ -65,7 +109,7 @@ def __init__(self): self.stopped = False self.handle_cancellation = True - async def run(self): + async def run(self) -> None: if self.handle_cancellation: try: await self._run() @@ -75,7 +119,7 @@ async def run(self): else: await self._run() - async def _run(self): + async def _run(self) -> None: self.running = True while self.running: if self.exploding: @@ -83,16 +127,16 @@ async def _run(self): raise Exception(msg) await asyncio.sleep(self.run_sleep_delay) - def stop(self): + def stop(self) -> None: self.running = False self.stopped = True - def explode(self): + def explode(self) -> None: self.exploding = True @pytest.mark.asyncio -async def test_multiservice_run_stops_all_services_when_one_raises_exception(): +async def 
test_multiservice_run_stops_all_services_when_one_raises_exception() -> None: service_1 = StubService() service_2 = StubService() service_3 = StubService() @@ -109,7 +153,7 @@ async def test_multiservice_run_stops_all_services_when_one_raises_exception(): @pytest.mark.asyncio -async def test_multiservice_run_stops_all_services_when_shutdown_happens(): +async def test_multiservice_run_stops_all_services_when_shutdown_happens() -> None: service_1 = StubService() service_2 = StubService() service_3 = StubService() @@ -128,7 +172,7 @@ async def test_multiservice_run_stops_all_services_when_shutdown_happens(): @pytest.mark.asyncio -async def test_registry(): +async def test_registry() -> None: ran = [] # creating a class using the BaseService as its base @@ -158,7 +202,9 @@ def __init__(self, config): @pytest.mark.asyncio -async def test_multiservice_stop_gracefully_stops_service_that_takes_too_long_to_run(): +async def test_multiservice_stop_gracefully_stops_service_that_takes_too_long_to_run() -> ( + None +): service_1 = StubService() service_2 = StubService() @@ -185,13 +231,13 @@ async def test_multiservice_stop_gracefully_stops_service_that_takes_too_long_to } -def test_parse_connectors_with_no_connectors(): +def test_parse_connectors_with_no_connectors() -> None: local_config = deepcopy(config) service = BaseService(local_config, "mock_service") assert not service.connectors -def test_parse_connectors(): +def test_parse_connectors() -> None: local_config = deepcopy(config) local_config["connectors"] = [ {"connector_id": "foo", "service_type": "bar"}, @@ -205,7 +251,7 @@ def test_parse_connectors(): assert service.connectors["baz"]["service_type"] == "qux" -def test_parse_connectors_with_duplicate_connectors(): +def test_parse_connectors_with_duplicate_connectors() -> None: local_config = deepcopy(config) local_config["connectors"] = [ {"connector_id": "foo", "service_type": "bar"}, @@ -218,7 +264,7 @@ def test_parse_connectors_with_duplicate_connectors(): assert service.connectors["foo"]["service_type"] == "baz" -def test_parse_connectors_with_incomplete_connector(): +def test_parse_connectors_with_incomplete_connector() -> None: local_config = deepcopy(config) local_config["connectors"] = [ {"connector_id": "foo", "service_type": "bar"}, @@ -231,7 +277,7 @@ def test_parse_connectors_with_incomplete_connector(): assert service.connectors["foo"]["service_type"] == "bar" -def test_parse_connectors_with_deprecated_config_and_new_config(): +def test_parse_connectors_with_deprecated_config_and_new_config() -> None: local_config = deepcopy(config) local_config["connectors"] = [{"connector_id": "foo", "service_type": "bar"}] local_config["connector_id"] = "deprecated" @@ -244,7 +290,7 @@ def test_parse_connectors_with_deprecated_config_and_new_config(): assert "deprecated" not in service.connectors -def test_parse_connectors_with_deprecated_config(): +def test_parse_connectors_with_deprecated_config() -> None: local_config = deepcopy(config) local_config["connector_id"] = "deprecated" local_config["service_type"] = "deprecated" @@ -255,7 +301,7 @@ def test_parse_connectors_with_deprecated_config(): assert service.connectors["deprecated"]["service_type"] == "deprecated" -def test_override_es_config(): +def test_override_es_config() -> None: connector_api_key = "connector_api_key" config = { "elasticsearch": { diff --git a/tests/services/test_job_cleanup.py b/tests/services/test_job_cleanup.py index 33048c369..f4ab4d661 100644 --- a/tests/services/test_job_cleanup.py +++ 
b/tests/services/test_job_cleanup.py @@ -27,7 +27,7 @@ } -def mock_connector(connector_id="1"): +def mock_connector(connector_id: str = "1") -> Mock: connector = Mock() connector.id = connector_id connector.sync_done = AsyncMock() @@ -35,9 +35,9 @@ def mock_connector(connector_id="1"): def mock_sync_job( - sync_job_id="1", - connector_id="1", -): + sync_job_id: str = "1", + connector_id: str = "1", +) -> Mock: job = Mock() job.job_id = sync_job_id job.connector_id = connector_id @@ -58,7 +58,7 @@ async def test_cleanup_jobs( connector_fetch_by_id, orphaned_idle_jobs, idle_jobs, -): +) -> None: connector = mock_connector() orphaned_idle_sync_job = mock_sync_job() idle_sync_job = mock_sync_job() diff --git a/tests/services/test_job_execution.py b/tests/services/test_job_execution.py index e78633aad..f91021ed8 100644 --- a/tests/services/test_job_execution.py +++ b/tests/services/test_job_execution.py @@ -50,7 +50,7 @@ def sync_job_index_mock(): yield sync_job_index_mock -def concurrent_task_mock(): +def concurrent_task_mock() -> Mock: task_mock = Mock() task_mock.try_put = Mock() task_mock.join = AsyncMock() @@ -83,10 +83,10 @@ def sync_job_runner_mock(): def mock_connector( - last_sync_status=JobStatus.COMPLETED, - last_access_control_sync_status=JobStatus.COMPLETED, - document_level_security_enabled=True, -): + last_sync_status: JobStatus = JobStatus.COMPLETED, + last_access_control_sync_status: JobStatus = JobStatus.COMPLETED, + document_level_security_enabled: bool = True, +) -> Mock: connector = Mock() connector.id = "1" connector.last_sync_status = last_sync_status @@ -99,7 +99,7 @@ def mock_connector( return connector -def mock_sync_job(service_type="fake", job_type=JobType.FULL): +def mock_sync_job(service_type: str = "fake", job_type: JobType = JobType.FULL) -> Mock: sync_job = Mock() sync_job.service_type = service_type sync_job.connector_id = "1" @@ -109,7 +109,9 @@ def mock_sync_job(service_type="fake", job_type=JobType.FULL): @pytest.mark.asyncio -async def test_no_connector(connector_index_mock, concurrent_tasks_mocks, set_env): +async def test_no_connector( + connector_index_mock, concurrent_tasks_mocks, set_env +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector_index_mock.supported_connectors.return_value = AsyncIterator([]) @@ -129,7 +131,7 @@ async def test_no_pending_jobs( concurrent_tasks_mocks, service_klass, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() @@ -153,9 +155,9 @@ async def test_job_execution_with_unsupported_source( sync_job_index_mock, concurrent_tasks_mocks, service_klass, - job_type, + job_type: JobType, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() @@ -181,9 +183,9 @@ async def test_job_execution_with_connector_not_found( concurrent_tasks_mocks, sync_job_runner_mock, service_klass, - job_type, + job_type: JobType, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() @@ -203,7 +205,7 @@ async def test_access_control_sync_job_execution_with_premium_connector( concurrent_tasks_mocks, sync_job_runner_mock, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() @@ -223,7 +225,7 @@ async def test_access_control_sync_job_execution_with_insufficient_license( concurrent_tasks_mocks, sync_job_runner_mock, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() @@ -247,7 +249,7 @@ async def 
test_access_control_sync_job_execution_with_dls_feature_flag_disabled( concurrent_tasks_mocks, sync_job_runner_mock, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector( @@ -279,9 +281,9 @@ async def test_job_execution_with_connector_still_syncing( concurrent_tasks_mocks, sync_job_runner_mock, service_klass, - job_type, + job_type: JobType, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector( @@ -311,9 +313,9 @@ async def test_job_execution( concurrent_tasks_mocks, sync_job_runner_mock, service_klass, - job_type, + job_type: JobType, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() @@ -342,9 +344,9 @@ async def test_job_execution_new_sync_job_not_blocked( concurrent_tasks_mocks, sync_job_runner_mock, service_klass, - job_type, + job_type: JobType, set_env, -): +) -> None: sync_job_pool_mock = concurrent_tasks_mocks connector = mock_connector() diff --git a/tests/services/test_job_scheduling.py b/tests/services/test_job_scheduling.py index de1fcfc4f..ea9cd531a 100644 --- a/tests/services/test_job_scheduling.py +++ b/tests/services/test_job_scheduling.py @@ -5,6 +5,7 @@ # import asyncio from datetime import datetime, timedelta, timezone +from typing import List from unittest.mock import AsyncMock, Mock, patch import pytest @@ -25,7 +26,7 @@ from tests.commons import AsyncIterator from tests.services.test_base import create_and_run_service -JOB_TYPES = [JobType.FULL, JobType.ACCESS_CONTROL] +JOB_TYPES: List[JobType] = [JobType.FULL, JobType.ACCESS_CONTROL] @pytest.fixture(autouse=True) @@ -58,18 +59,18 @@ def sync_job_index_mock(): yield sync_job_index_mock -default_next_sync = datetime.now(timezone.utc) + timedelta(hours=1) +default_next_sync: datetime = datetime.now(timezone.utc) + timedelta(hours=1) def mock_connector( - status=Status.CONNECTED, - service_type="fake", - next_sync=default_next_sync, + status: Status = Status.CONNECTED, + service_type: str = "fake", + next_sync: datetime = default_next_sync, prepare_exception=None, last_sync_scheduled_at_by_job_type=None, - document_level_security_enabled=True, - incremental_sync_enabled=False, -): + document_level_security_enabled: bool = True, + incremental_sync_enabled: bool = False, +) -> Mock: connector = Mock() connector.native = True connector.service_type = service_type @@ -101,7 +102,7 @@ def mock_connector( @pytest.mark.asyncio -async def test_no_connector(connector_index_mock, sync_job_index_mock, set_env): +async def test_no_connector(connector_index_mock, sync_job_index_mock, set_env) -> None: connector_index_mock.supported_connectors.return_value = AsyncIterator([]) await create_and_run_service(JobSchedulingService) @@ -113,7 +114,7 @@ async def test_connector_ready_to_sync( connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector(next_sync=datetime.now(timezone.utc)) connector_index_mock.supported_connectors.return_value = AsyncIterator([connector]) await create_and_run_service(JobSchedulingService) @@ -137,7 +138,7 @@ async def test_connector_ready_to_sync_with_race_condition( connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector(next_sync=datetime.now(timezone.utc)) # Do nothing in the first call(in _should_schedule_on_demand_sync) and second call(in _should_schedule_scheduled_sync), and the last_sync_scheduled_at is updated by another instance in the subsequent calls @@ -165,7 +166,7 @@ def 
_reset_last_sync_scheduled_at_by_job_type(): @pytest.mark.asyncio async def test_connector_sync_disabled( connector_index_mock, sync_job_index_mock, set_env -): +) -> None: connector = mock_connector(next_sync=None) connector_index_mock.supported_connectors.return_value = AsyncIterator([connector]) await create_and_run_service(JobSchedulingService) @@ -181,7 +182,7 @@ async def test_connector_scheduled_access_control_sync_with_dls_feature_disabled connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector( next_sync=datetime.now(timezone.utc), document_level_security_enabled=False ) @@ -206,7 +207,7 @@ async def test_connector_scheduled_access_control_sync_with_insufficient_license connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector(next_sync=datetime.now(timezone.utc)) connector_index_mock.supported_connectors.return_value = AsyncIterator([connector]) connector_index_mock.has_active_license_enabled = AsyncMock( @@ -239,13 +240,13 @@ async def test_connector_scheduled_access_control_sync_with_insufficient_license ], ) async def test_connector_scheduled_incremental_sync( - incremental_sync_enabled, - service_type, + incremental_sync_enabled: bool, + service_type: str, schedule_incremental_sync, connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector( service_type=service_type, next_sync=datetime.now(timezone.utc), @@ -280,11 +281,11 @@ async def test_connector_scheduled_incremental_sync( [Status.CREATED, Status.NEEDS_CONFIGURATION], ) async def test_connector_not_configured( - connector_status, + connector_status: Status, connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector(status=connector_status) connector_index_mock.supported_connectors.return_value = AsyncIterator([connector]) await create_and_run_service(JobSchedulingService) @@ -310,7 +311,7 @@ async def test_connector_prepare_failed( connector_index_mock, sync_job_index_mock, set_env, -): +) -> None: connector = mock_connector(prepare_exception=prepare_exception()) connector_index_mock.supported_connectors.return_value = AsyncIterator([connector]) await create_and_run_service(JobSchedulingService) @@ -324,7 +325,7 @@ async def test_connector_prepare_failed( @pytest.mark.asyncio async def test_run_when_sync_fails_then_continues_service_execution( connector_index_mock, set_env -): +) -> None: connector = mock_connector(next_sync=datetime.now(timezone.utc)) another_connector = mock_connector(next_sync=datetime.now(timezone.utc)) connector_index_mock.supported_connectors.return_value = AsyncIterator( @@ -351,7 +352,7 @@ async def test_run_when_sync_fails_then_continues_service_execution( @patch("connectors.services.job_scheduling.get_source_klass") async def test_run_when_connector_fields_are_invalid( get_source_klass_mock, connector_index_mock, set_env -): +) -> None: error_message = "Something invalid is in config!" actual_error = Exception(error_message) @@ -383,7 +384,7 @@ def _source_klass(config): @patch("connectors.services.job_scheduling.get_source_klass") async def test_run_when_connector_failed_validation_then_succeeded( get_source_klass_mock, connector_index_mock, set_env -): +) -> None: error_message = "Something invalid is in config!" 
actual_error = Exception(error_message) @@ -422,7 +423,7 @@ def _error_once(): @patch("connectors.services.job_scheduling.get_source_klass") async def test_run_when_connector_ping_fails( get_source_klass_mock, connector_index_mock, set_env -): +) -> None: error_message = "Something went wrong when trying to ping the data source!" actual_error = Exception(error_message) @@ -454,7 +455,7 @@ def _source_klass(config): @patch("connectors.services.job_scheduling.get_source_klass") async def test_run_when_connector_validate_config_fails( get_source_klass_mock, connector_index_mock, set_env -): +) -> None: data_source_mock = Mock() error = ConfigurableFieldValueError() @@ -484,7 +485,7 @@ def _source_klass(config): @pytest.mark.asyncio async def test_initial_loop_run_heartbeat_only_once( connector_index_mock, sync_job_index_mock, set_env -): +) -> None: connector = mock_connector(next_sync=None) connector_index_mock.supported_connectors.return_value = AsyncIterator( [connector, connector, connector, connector] @@ -507,7 +508,7 @@ async def test_initial_loop_run_heartbeat_only_once( @patch("connectors.services.job_scheduling.get_source_klass") async def test_run_when_validation_is_very_slow( get_source_klass_mock, connector_index_mock, set_env -): +) -> None: data_source_mock = Mock() def _source_klass(config): diff --git a/tests/sources/fixtures/azure_blob_storage/fixture.py b/tests/sources/fixtures/azure_blob_storage/fixture.py index aa527df0a..f9d704c0d 100644 --- a/tests/sources/fixtures/azure_blob_storage/fixture.py +++ b/tests/sources/fixtures/azure_blob_storage/fixture.py @@ -14,7 +14,7 @@ fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium") +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium") CONTAINERS_TO_DELETE = 1 @@ -30,11 +30,11 @@ BLOB_COUNT = 1000 -def get_num_docs(): +def get_num_docs() -> None: print((CONTAINER_COUNT - CONTAINERS_TO_DELETE) * BLOB_COUNT) -async def load(): +async def load() -> None: """Method for generating document for azurite emulator""" try: blob_service_client = BlobServiceClient.from_connection_string( @@ -57,7 +57,7 @@ async def load(): print(f"Exception: {exception}") -async def remove(): +async def remove() -> None: """Method for removing 2k document for azurite emulator""" try: blob_service_client = BlobServiceClient.from_connection_string( diff --git a/tests/sources/fixtures/box/fixture.py b/tests/sources/fixtures/box/fixture.py index 0db6406c0..66589b33b 100644 --- a/tests/sources/fixtures/box/fixture.py +++ b/tests/sources/fixtures/box/fixture.py @@ -10,26 +10,31 @@ import os import random import string +from _io import BytesIO +from typing import Any, Dict from faker import Faker from flask import Flask, make_response, request +from flask.wrappers import Response app = Flask(__name__) -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() _SIZES = {"small": 500000, "medium": 1000000, "large": 3000000} -FILE_SIZE = _SIZES[DATA_SIZE] -LARGE_DATA = "".join([random.choice(string.ascii_letters) for _ in range(FILE_SIZE)]) +FILE_SIZE: int = _SIZES[DATA_SIZE] +LARGE_DATA: str = "".join( + [random.choice(string.ascii_letters) for _ in range(FILE_SIZE)] +) fake = Faker() -def _create_data(size): +def _create_data(size) -> str: return "".join([random.choice(string.ascii_letters) for _ in range(size)]) class BoxAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.first_sync = True match DATA_SIZE: @@ -50,7 +55,7 @@ def 
__init__(self): self.get_content ) - def get_token(self): + def get_token(self) -> Response: fake_res = { "access_token": "FAKE-ACCESS-TOKEN", "refresh_token": "FAKE-REFRESH-TOKEN", @@ -60,7 +65,7 @@ def get_token(self): response.headers["status_code"] = 200 return response - def get_user(self): + def get_user(self) -> Dict[str, str]: return {"username": "demo_user"} def get_folders_entries(self, offset, limit): @@ -94,7 +99,7 @@ def get_files_entries(self, offset, limit, folder_id): ] return files_entries - def get_folder_items(self, folder_id): + def get_folder_items(self, folder_id) -> Dict[str, Any]: offset = int(request.args.get("offset")) limit = int(request.args.get("limit")) if folder_id == "0": @@ -128,7 +133,7 @@ def get_folder_items(self, folder_id): } return response - def get_content(self, file_id): + def get_content(self, file_id) -> BytesIO: return io.BytesIO(bytes(LARGE_DATA, encoding="utf-8")) diff --git a/tests/sources/fixtures/confluence/fixture.py b/tests/sources/fixtures/confluence/fixture.py index 0c4ec30a2..727d0ba72 100644 --- a/tests/sources/fixtures/confluence/fixture.py +++ b/tests/sources/fixtures/confluence/fixture.py @@ -10,6 +10,8 @@ import os import re import time +from _io import BytesIO +from typing import Dict, List, Union from flask import Flask, request @@ -17,7 +19,7 @@ fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -37,7 +39,7 @@ raise Exception(msg) -def get_num_docs(): +def get_num_docs() -> None: # 2 is multiplier cause SPACE_OBJECTs will be delivered twice: # Test returns SPACE_OBJECT_COUNT objects for each type of content # There are 2 types of content: @@ -47,7 +49,7 @@ def get_num_docs(): class ConfluenceAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.first_sync = True self.space_start_at = 0 @@ -109,7 +111,7 @@ def get_spaces(self): ) return spaces - def get_label(self, label_id): + def get_label(self, label_id) -> Dict[str, Union[List[Dict[str, str]], int]]: return { "results": [ { @@ -197,7 +199,7 @@ def get_attachments(self, content_id): attachments["results"].append(attachment) return attachments - def download(self, content_id, attachment_id): + def download(self, content_id, attachment_id) -> BytesIO: """Function to handle download calls for attachments Args: diff --git a/tests/sources/fixtures/dir/fixture.py b/tests/sources/fixtures/dir/fixture.py index 791e7f1f9..a1288c3ed 100644 --- a/tests/sources/fixtures/dir/fixture.py +++ b/tests/sources/fixtures/dir/fixture.py @@ -10,8 +10,8 @@ import urllib.request import zipfile -SYSTEM_DIR = os.path.join(os.path.dirname(__file__), "data") -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +SYSTEM_DIR: str = os.path.join(os.path.dirname(__file__), "data") +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() if DATA_SIZE == "small": REPO = "connectors-python" @@ -21,7 +21,7 @@ REPO = "kibana" -def get_num_docs(): +def get_num_docs() -> None: match os.environ.get("DATA_SIZE", "medium"): case "small": print("100") @@ -31,7 +31,7 @@ def get_num_docs(): print("300") -async def load(): +async def load() -> None: if os.path.exists(SYSTEM_DIR): teardown() print(f"Working in {SYSTEM_DIR}") @@ -51,7 +51,7 @@ async def load(): os.unlink(repo_zip) -async def remove(): +async def remove() -> None: # removing 10 files files = [] for root, __, filenames in os.walk(SYSTEM_DIR): @@ -64,5 +64,5 @@ async def 
remove(): os.unlink(files[i]) -async def teardown(): +async def teardown() -> None: shutil.rmtree(SYSTEM_DIR) diff --git a/tests/sources/fixtures/dropbox/fixture.py b/tests/sources/fixtures/dropbox/fixture.py index 42cf33322..07bca7997 100755 --- a/tests/sources/fixtures/dropbox/fixture.py +++ b/tests/sources/fixtures/dropbox/fixture.py @@ -9,14 +9,17 @@ import io import json import os +from _io import BytesIO +from typing import Dict, Union from flask import Flask, jsonify, make_response, request +from flask.wrappers import Response from tests.commons import WeightedFakeProvider fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium") +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium") match DATA_SIZE: case "small": @@ -33,7 +36,7 @@ class DropboxAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.first_sync = True @@ -63,13 +66,13 @@ def __init__(self): self.app.route("/2/files/download", methods=["POST"])(self.download_file) self.app.route("/2/files/export", methods=["POST"])(self.download_paper_file) - def get_dropbox_token(self): + def get_dropbox_token(self) -> Response: res = {"access_token": "fake-access-token", "expires_in": 3699} response = make_response(res) response.headers["status_code"] = 200 return response - def get_current_account(self): + def get_current_account(self) -> Dict[str, Union[Dict[str, str], bool, str]]: return { "account_id": "dbid:1122aabb2AogwKjDAG8RkrnCee-8Zex-e94", "name": { @@ -118,7 +121,7 @@ def get_authenticated_admin(self): } } - def files_list_folder(self): + def files_list_folder(self) -> Response: response = {"entries": [], "cursor": "fake-cursor", "has_more": True} if self.first_sync: end_files_folders = FILE_FOLDERS @@ -158,7 +161,7 @@ def files_list_folder(self): response["entries"].append(file_entry) return jsonify(response) - def files_list_folder_continue(self): + def files_list_folder_continue(self) -> Response: response = {"entries": [], "cursor": "fake-cursor", "has_more": False} for entry in range(FILE_FOLDERS, FILE_FOLDERS * 2): folder_entry = { @@ -193,7 +196,7 @@ def files_list_folder_continue(self): response["entries"].append(file_entry) return jsonify(response) - def get_received_files(self): + def get_received_files(self) -> Response: response = {"entries": [], "cursor": "fake-cursor"} for entry in range(RECEIVED_FILES_PAGE): response["entries"].append( @@ -212,7 +215,7 @@ def get_received_files(self): ) return jsonify(response) - def get_received_files_continue(self): + def get_received_files_continue(self) -> Response: response = {"entries": [], "cursor": None} for entry in range(RECEIVED_FILES_PAGE, RECEIVED_FILES_PAGE * 2): response["entries"].append( @@ -231,7 +234,7 @@ def get_received_files_continue(self): ) return jsonify(response) - def get_shared_file_metadata(self): + def get_shared_file_metadata(self) -> Response: data = request.get_data().decode("utf-8") url = json.loads(data)["url"] res = { @@ -294,13 +297,13 @@ def get_shared_file_metadata(self): } return jsonify(res) - def download_file(self): + def download_file(self) -> BytesIO: return io.BytesIO(bytes(fake_provider.get_html(), encoding="utf-8")) - def download_paper_file(self): + def download_paper_file(self) -> BytesIO: return io.BytesIO(bytes(fake_provider.get_html(), encoding="utf-8")) - def download_shared_file(self): + def download_shared_file(self) -> BytesIO: return io.BytesIO(bytes(fake_provider.get_html(), encoding="utf-8")) diff --git a/tests/sources/fixtures/fixture.py 
b/tests/sources/fixtures/fixture.py index b38e2b873..ffc0168df 100644 --- a/tests/sources/fixtures/fixture.py +++ b/tests/sources/fixtures/fixture.py @@ -15,7 +15,9 @@ import signal import sys import time -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace +from asyncio.events import AbstractEventLoop +from typing import Optional from elastic_transport import ConnectionTimeout from elasticsearch import ApiError @@ -29,11 +31,11 @@ CONNECTORS_INDEX = ".elastic-connectors" -logger = logging.getLogger("ftest") +logger: logging.Logger = logging.getLogger("ftest") set_extra_logger(logger, log_level=logging.DEBUG, prefix="FTEST") -async def wait_for_es(): +async def wait_for_es() -> None: try: es_client = _es_client() await es_client.wait() @@ -45,7 +47,7 @@ async def wait_for_es(): await es_client.close() -def _parser(): +def _parser() -> ArgumentParser: parser = ArgumentParser(prog="fixture") parser.add_argument( @@ -73,7 +75,7 @@ def _parser(): return parser -def retrying_transient_errors(retries=5): +def retrying_transient_errors(retries: int = 5): def wrapper(func): @functools.wraps(func) async def wrapped(*args, **kwargs): @@ -104,7 +106,7 @@ async def wrapped(*args, **kwargs): return wrapper -def _es_client(): +def _es_client() -> ESManagementClient: options = { "host": "http://127.0.0.1:9200", "username": "elastic", @@ -133,7 +135,7 @@ async def _fetch_connector_metadata(es_client): return (connector_id, last_synced) -async def _monitor_service(pid): +async def _monitor_service(pid: int) -> None: es_client = _es_client() sync_job_timeout = 20 * 60 # 20 minutes timeout @@ -164,7 +166,7 @@ async def _monitor_service(pid): await es_client.close() -async def _exec_shell(cmd): +async def _exec_shell(cmd) -> None: # Create subprocess proc = await asyncio.create_subprocess_shell(cmd) @@ -173,7 +175,7 @@ async def _exec_shell(cmd): logger.debug(f"Successfully executed cmd: {cmd}") -async def main(args=None): +async def main(args: Optional[Namespace] = None): parser = _parser() args = parser.parse_args(args=args) action = args.action @@ -230,5 +232,5 @@ async def main(args=None): if __name__ == "__main__": - loop = asyncio.get_event_loop() + loop: AbstractEventLoop = asyncio.get_event_loop() loop.run_until_complete(main()) diff --git a/tests/sources/fixtures/github/fixture.py b/tests/sources/fixtures/github/fixture.py index 5d7b654ff..1c3d82d46 100644 --- a/tests/sources/fixtures/github/fixture.py +++ b/tests/sources/fixtures/github/fixture.py @@ -8,14 +8,16 @@ import base64 import os +from typing import Any, Dict, List, Union from flask import Flask, make_response, request +from flask.wrappers import Response from tests.commons import WeightedFakeProvider fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() # TODO: change number of files based on DATA_SIZE match DATA_SIZE: @@ -33,12 +35,12 @@ app = Flask(__name__) -def encode_data(content): +def encode_data(content) -> str: return base64.b64encode(bytes(content, "utf-8")).decode("utf-8") class GitHubAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.file_count = FILE_COUNT self.issue_count = ISSUE_COUNT @@ -56,10 +58,10 @@ def __init__(self): self.app.route("/api/graphql", methods=["HEAD"])(self.get_scopes) self.files = {} - def encode_cursor(self, value): + def encode_cursor(self, value) -> str: return base64.b64encode(str(value).encode()).decode() - def 
decode_cursor(self, cursor): + def decode_cursor(self, cursor) -> int: return int(base64.b64decode(cursor.encode()).decode()) def get_index_metadata(self, variables, data): @@ -69,7 +71,13 @@ def get_index_metadata(self, variables, data): subset_nodes = data["nodes"][start_index:end_index] return start_index, end_index, subset_nodes - def mock_graphql_response(self): + def mock_graphql_response( + self, + ) -> Union[ + Dict[str, Dict[str, Dict[str, Dict[str, Any]]]], + Dict[str, Dict[str, Dict[str, str]]], + Dict[str, List[str]], + ]: data = request.get_json() query = data.get("query") variables = data.get("variables", {}) @@ -292,7 +300,7 @@ def mock_graphql_response(self): return mock_data - def get_tree(self): + def get_tree(self) -> Dict[str, List[Dict[str, Union[int, str]]]]: args = request.args tree_list = [] if args.get("recursive") == "1": @@ -312,7 +320,7 @@ def get_tree(self): self.file_count = 2000 return {"tree": tree_list} - def get_content(self, file_id): + def get_content(self, file_id) -> Dict[str, Any]: file = self.files[file_id] return { "name": f"dummy_file_{file_id}.md", @@ -324,7 +332,7 @@ def get_content(self, file_id): "encoding": "base64", } - def get_commits(self): + def get_commits(self) -> List[Dict[str, Dict[str, Dict[str, str]]]]: return [ { "commit": { @@ -337,7 +345,7 @@ def get_commits(self): } ] - def get_scopes(self): + def get_scopes(self) -> Response: response = make_response({}) response.headers["status_code"] = 200 response.headers["X-OAuth-Scopes"] = "repo, user, read:org" diff --git a/tests/sources/fixtures/google_cloud_storage/fixture.py b/tests/sources/fixtures/google_cloud_storage/fixture.py index 7ab9cf876..7b10df90e 100755 --- a/tests/sources/fixtures/google_cloud_storage/fixture.py +++ b/tests/sources/fixtures/google_cloud_storage/fixture.py @@ -14,12 +14,12 @@ from tests.commons import WeightedFakeProvider client_connection = None -HERE = os.path.dirname(__file__) +HERE: str = os.path.dirname(__file__) HOSTS = "/etc/hosts" fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -38,14 +38,14 @@ class PrerequisiteException(Exception): """This class is used to generate the custom error exception when prerequisites are not satisfied.""" - def __init__(self, errors): + def __init__(self, errors) -> None: super().__init__( f"Error while running e2e test for the Google Cloud Storage connector. 
\nReason: {errors}" ) self.errors = errors -def get_num_docs(): +def get_num_docs() -> None: print( FIRST_BUCKET_FILE_COUNT + SECOND_BUCKET_FILE_COUNT @@ -53,7 +53,7 @@ def get_num_docs(): ) -async def verify(): +async def verify() -> None: "Method to verify if prerequisites are satisfied for e2e or not" storage_emulator_host = os.getenv(key="STORAGE_EMULATOR_HOST", default=None) if storage_emulator_host != "http://localhost:4443": @@ -61,7 +61,7 @@ async def verify(): raise PrerequisiteException(msg) -def create_connection(): +def create_connection() -> None: """Method for creating connection to the fake Google Cloud Storage server""" try: global client_connection @@ -76,7 +76,7 @@ def create_connection(): raise -def generate_files(bucket_name, number_of_files): +def generate_files(bucket_name, number_of_files) -> None: """Method for generating files on the fake Google Cloud Storage server""" client_connection.create_bucket(bucket_name) bucket = client_connection.bucket(bucket_name) @@ -90,7 +90,7 @@ def generate_files(bucket_name, number_of_files): ) -async def load(): +async def load() -> None: create_connection() print("Started loading files on the fake Google Cloud Storage server....") if FIRST_BUCKET_FILE_COUNT: @@ -102,7 +102,7 @@ async def load(): ) -async def remove(): +async def remove() -> None: """Method for removing random blobs from the fake Google Cloud Storage server""" create_connection() print("Started removing random blobs from the fake Google Cloud Storage server....") @@ -117,5 +117,5 @@ async def remove(): ) -async def setup(): +async def setup() -> None: await verify() diff --git a/tests/sources/fixtures/google_cloud_storage/mocker.py b/tests/sources/fixtures/google_cloud_storage/mocker.py index 34d5d2ca9..2cb0467e9 100644 --- a/tests/sources/fixtures/google_cloud_storage/mocker.py +++ b/tests/sources/fixtures/google_cloud_storage/mocker.py @@ -6,13 +6,15 @@ # ruff: noqa: T201 """Module responsible for mocking POST call to Google Cloud Storage Data Source""" +from typing import Dict, Union + from flask import Flask app = Flask(__name__) @app.route("/token", methods=["POST"]) -def post_auth_token(): +def post_auth_token() -> Dict[str, Union[int, str]]: """Function to load""" return { "access_token": "XXXXXXStBkRnGyZ2mUYOLgls7QVBxOg82XhBCFo8UIT5gM", diff --git a/tests/sources/fixtures/google_drive/fixture.py b/tests/sources/fixtures/google_drive/fixture.py index a532c1609..75a411cc7 100644 --- a/tests/sources/fixtures/google_drive/fixture.py +++ b/tests/sources/fixtures/google_drive/fixture.py @@ -8,6 +8,7 @@ import os import time +from typing import Any, Dict, List, Union from flask import Flask, request @@ -15,7 +16,7 @@ fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -31,22 +32,22 @@ PRE_REQUEST_SLEEP = float(os.environ.get("PRE_REQUEST_SLEEP", "0.1")) -def get_num_docs(): +def get_num_docs() -> None: print(DOCS_COUNT) @app.before_request -def before_request(): +def before_request() -> None: time.sleep(PRE_REQUEST_SLEEP) @app.route("/drive/v3/about", methods=["GET"]) -def about_get(): +def about_get() -> Dict[str, str]: return {"kind": "drive#about"} @app.route("/drive/v3/drives", methods=["GET"]) -def drives_list(): +def drives_list() -> Dict[str, Union[List[Dict[str, str]], str]]: return { "nextPageToken": "dummyToken", "kind": "drive#driveList", @@ -58,7 +59,7 @@ def drives_list(): @app.route("/drive/v3/files", 
methods=["GET"]) -def files_list(): +def files_list() -> Dict[str, Union[List[Dict[str, Any]], str]]: files_list = [ { "kind": "drive#file", @@ -76,7 +77,7 @@ def files_list(): @app.route("/drive/v3/files/", methods=["GET"]) -def files_get(file_id): +def files_get(file_id) -> Union[Dict[str, str], str]: req_params = request.args.to_dict() # response includes the file contents in the response body @@ -93,7 +94,7 @@ def files_get(file_id): @app.route("/token", methods=["POST"]) -def post_auth_token(): +def post_auth_token() -> Dict[str, Union[int, str]]: """Function to load""" return { "access_token": "XXXXXXStBkRnGyZ2mUYOLgls7QVBxOg82XhBCFo8UIT5gM", diff --git a/tests/sources/fixtures/graphql/fixture.py b/tests/sources/fixtures/graphql/fixture.py index 917ee5bf7..d47e7e6ea 100644 --- a/tests/sources/fixtures/graphql/fixture.py +++ b/tests/sources/fixtures/graphql/fixture.py @@ -7,6 +7,7 @@ import base64 import os +from typing import Any, Dict from flask import Flask, request @@ -14,7 +15,7 @@ fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium") +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium") match DATA_SIZE: case "small": @@ -26,14 +27,14 @@ class GraphQLAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.app.route("/graphql", methods=["POST"])(self.mock_graphql_response) - def encode_cursor(self, value): + def encode_cursor(self, value) -> str: return base64.b64encode(str(value).encode()).decode() - def decode_cursor(self, cursor): + def decode_cursor(self, cursor) -> int: return int(base64.b64decode(cursor.encode()).decode()) def get_index_metadata(self, variables, data): @@ -43,7 +44,7 @@ def get_index_metadata(self, variables, data): subset_nodes = data["nodes"][start_index:end_index] return start_index, end_index, subset_nodes - def mock_graphql_response(self): + def mock_graphql_response(self) -> Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]: issue_data = { "nodes": [ { diff --git a/tests/sources/fixtures/jira/fixture.py b/tests/sources/fixtures/jira/fixture.py index d32f8a372..6a11095f1 100644 --- a/tests/sources/fixtures/jira/fixture.py +++ b/tests/sources/fixtures/jira/fixture.py @@ -8,6 +8,8 @@ import io import os +from _io import BytesIO +from typing import Dict, List, Union from flask import Flask, request @@ -15,7 +17,7 @@ fake_provider = WeightedFakeProvider() -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() PROJECT_TO_DELETE_COUNT = 100 @@ -32,7 +34,7 @@ @app.route("/rest/api/2/myself", methods=["GET"]) -def get_myself(): +def get_myself() -> Dict[str, str]: """Function to load an authenticated user's data""" myself = { "accountId": "5ff5815e34847e0069fedee3", @@ -44,7 +46,7 @@ def get_myself(): @app.route("/rest/api/2/project", methods=["GET"]) -def get_projects(): +def get_projects() -> List[Dict[str, str]]: """Function to load projects on the jira server Returns: projects (list): List of projects @@ -139,7 +141,7 @@ def get_issue(issue_id): @app.route("/rest/api/2/attachment/content/", methods=["GET"]) -def get_attachment_content(attachment_id): +def get_attachment_content(attachment_id) -> BytesIO: """Function to handle get attachment content calls Args: id (string): id of an attachment. 
@@ -150,7 +152,7 @@ def get_attachment_content(attachment_id): @app.route("/rest/api/2/field", methods=["GET"]) -def get_fields(): +def get_fields() -> List[Dict[str, Union[Dict[str, str], List[str], bool, str]]]: """Function to get all fields including default and custom fields from Jira""" return [ { diff --git a/tests/sources/fixtures/microsoft_teams/fixture.py b/tests/sources/fixtures/microsoft_teams/fixture.py index f5ba82750..f934702f4 100644 --- a/tests/sources/fixtures/microsoft_teams/fixture.py +++ b/tests/sources/fixtures/microsoft_teams/fixture.py @@ -6,6 +6,8 @@ # ruff: noqa: T201 import io import os +from _io import BytesIO +from typing import Any, Dict, List, Union from flask import Flask, request from flask_limiter import HEADERS, Limiter @@ -18,7 +20,7 @@ app = Flask(__name__) -THROTTLING = os.environ.get("THROTTLING", False) +THROTTLING: Union[bool, str] = os.environ.get("THROTTLING", False) PRE_REQUEST_SLEEP = float(os.environ.get("PRE_REQUEST_SLEEP", "0.05")) if THROTTLING: @@ -64,7 +66,7 @@ def adjust_document_id_size(doc_id): return f"{doc_id}-{addition}" -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -92,10 +94,10 @@ def adjust_document_id_size(doc_id): MESSAGES_TO_DELETE = 10 EVENTS_TO_DELETE = 1 -ROOT = os.environ.get("OVERRIDE_URL", "http://127.0.0.1:10971") +ROOT: str = os.environ.get("OVERRIDE_URL", "http://127.0.0.1:10971") -def get_num_docs(): +def get_num_docs() -> None: # I tried to do the maths, but it's not possible without diving too deep into the connector # Therefore, doing naive way - just ran connector and took the number from the test expected_count = 0 @@ -111,7 +113,7 @@ def get_num_docs(): class MicrosoftTeamsAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.app.route("/me", methods=["GET"])(self.get_myself) @@ -150,7 +152,7 @@ def __init__(self): self.app.route("/sites/list.txt", methods=["GET"])(self.download_file) - def get_myself(self): + def get_myself(self) -> Dict[str, Any]: return { "displayName": "Alex Wilber", "givenName": "Alex", @@ -159,7 +161,7 @@ def get_myself(self): "id": adjust_document_id_size("me-1"), } - def get_user_chats(self): + def get_user_chats(self) -> Dict[str, List[Dict[str, Any]]]: return { "value": [ { @@ -181,7 +183,9 @@ def get_user_chats(self): ] } - def get_user_chat_messages(self, chat_id): + def get_user_chat_messages( + self, chat_id + ) -> Union[Dict[str, str], Dict[str, Union[List[Dict[str, Any]], str]]]: global MESSAGES message_data = [] top = int(request.args.get("$top")) @@ -212,7 +216,7 @@ def get_user_chat_messages(self, chat_id): MESSAGES -= MESSAGES_TO_DELETE # performs deletion and pagination return response - def get_user_chat_tabs(self, chat_id): + def get_user_chat_tabs(self, chat_id) -> Dict[str, List[Dict[str, Any]]]: return { "value": [ { @@ -223,7 +227,7 @@ def get_user_chat_tabs(self, chat_id): ] } - def get_users(self): + def get_users(self) -> Dict[str, List[Dict[str, Any]]]: return { "value": [ { @@ -233,7 +237,9 @@ def get_users(self): ] } - def get_events(self, user_id): + def get_events( + self, user_id + ) -> Union[Dict[str, str], Dict[str, Union[List[Dict[str, Any]], str]]]: global EVENTS event_data = [] top = int(request.args.get("$top")) @@ -290,7 +296,7 @@ def get_events(self, user_id): EVENTS -= EVENTS_TO_DELETE return response - def get_teams(self): + def get_teams(self) -> Dict[str, Union[List[Dict[str, Any]], str]]: return { 
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#teams)", "value": [ @@ -321,7 +327,7 @@ def get_teams(self): ], } - def get_channels(self, team_id): + def get_channels(self, team_id) -> Dict[str, Union[List[Dict[str, Any]], str]]: channel_list = [] for channel in range(CHANNEL): channel_list.append( @@ -338,7 +344,9 @@ def get_channels(self, team_id): "value": channel_list, } - def get_channel_messages(self, team_id, channel_id): + def get_channel_messages( + self, team_id, channel_id + ) -> Dict[str, Union[List[Dict[str, Any]], str]]: message_list = [] for message in range(CHANNEL_MESSAGE): message_list.append( @@ -372,7 +380,9 @@ def get_channel_messages(self, team_id, channel_id): "value": message_list, } - def get_channel_tabs(self, team_id, channel_id): + def get_channel_tabs( + self, team_id, channel_id + ) -> Dict[str, Union[List[Dict[str, Any]], str]]: return { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#tabs)", "value": [ @@ -385,7 +395,9 @@ def get_channel_tabs(self, team_id, channel_id): ], } - def get_teams_filefolder(self, team_id, channel_id): + def get_teams_filefolder( + self, team_id, channel_id + ) -> Dict[str, Union[Dict[str, str], int, str]]: return { "id": "filfolder-1", "createdDateTime": "0001-01-01T00:00:00Z", @@ -398,7 +410,7 @@ def get_teams_filefolder(self, team_id, channel_id): }, } - def get_teams_file(self, drive_id, item_id): + def get_teams_file(self, drive_id, item_id) -> Dict[str, List[Dict[str, Any]]]: files_list = [] for file_data in range(FILES): files_list.append( @@ -419,7 +431,7 @@ def get_teams_file(self, drive_id, item_id): ) return {"value": files_list} - def download_file(self): + def download_file(self) -> BytesIO: return io.BytesIO(bytes(fake_provider.get_html(), encoding="utf-8")) diff --git a/tests/sources/fixtures/mongodb/fixture.py b/tests/sources/fixtures/mongodb/fixture.py index e3b23531c..c21d496e1 100644 --- a/tests/sources/fixtures/mongodb/fixture.py +++ b/tests/sources/fixtures/mongodb/fixture.py @@ -11,7 +11,7 @@ from faker import Faker from pymongo import MongoClient -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() _SIZES = {"small": 750, "medium": 1500, "large": 3000} NUMBER_OF_RECORDS_TO_DELETE = 50 @@ -21,7 +21,7 @@ ) -async def load(): +async def load() -> None: def _random_record(): return { "id": bson.ObjectId(), @@ -45,7 +45,7 @@ def _random_record(): collection.insert_many(data) -async def remove(): +async def remove() -> None: db = client.sample_database collection = db.sample_collection diff --git a/tests/sources/fixtures/mongodb_serverless/fixture.py b/tests/sources/fixtures/mongodb_serverless/fixture.py index ae1eb9533..b15b389f0 100644 --- a/tests/sources/fixtures/mongodb_serverless/fixture.py +++ b/tests/sources/fixtures/mongodb_serverless/fixture.py @@ -11,7 +11,7 @@ from faker import Faker from pymongo import MongoClient -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() _SIZES = {"small": 750, "medium": 1500, "large": 3000} NUMBER_OF_RECORDS_TO_DELETE = 50 @@ -20,7 +20,7 @@ OB_STORE = "/tmp/objectstore" -async def setup(): +async def setup() -> None: print(f"preparing {OB_STORE}") # creating the file storage for es if os.path.exists(OB_STORE): @@ -32,7 +32,7 @@ async def setup(): print(f"{OB_STORE} ready") -async def load(): +async def load() -> None: def _random_record(): return { "id": bson.ObjectId(), @@ -55,7 +55,7 @@ def _random_record(): 
collection.insert_many(data) -async def remove(): +async def remove() -> None: db = client.sample_database collection = db.sample_collection diff --git a/tests/sources/fixtures/mssql/fixture.py b/tests/sources/fixtures/mssql/fixture.py index 0504a4e02..38340ac35 100644 --- a/tests/sources/fixtures/mssql/fixture.py +++ b/tests/sources/fixtures/mssql/fixture.py @@ -29,7 +29,7 @@ faked = Faker() BATCH_SIZE = 1000 -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -45,11 +45,11 @@ RECORDS_TO_DELETE = 10 -def get_num_docs(): +def get_num_docs() -> None: print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE)) -def inject_lines(table, cursor, lines): +def inject_lines(table, cursor, lines) -> None: """Ingest rows in table Args: @@ -144,7 +144,7 @@ def inject_lines(table, cursor, lines): print(f"Inserted batch #{batch} of {batch_size} documents.") -async def load(): +async def load() -> None: """N tables of 10001 rows each. each row is ~ 1024*20 bytes""" database_sa = pytds.connect( @@ -171,7 +171,7 @@ async def load(): database.commit() -async def remove(): +async def remove() -> None: """Removes 10 random items per table""" database = pytds.connect(server=HOST, port=PORT, user=USER, password=PASSWORD) diff --git a/tests/sources/fixtures/mysql/fixture.py b/tests/sources/fixtures/mysql/fixture.py index 0b14a8e06..381158b5a 100644 --- a/tests/sources/fixtures/mysql/fixture.py +++ b/tests/sources/fixtures/mysql/fixture.py @@ -13,7 +13,7 @@ fake_provider = WeightedFakeProvider(weights=[0.65, 0.3, 0.05, 0]) -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -31,11 +31,11 @@ DATABASE_NAME = "customerinfo" -def get_num_docs(): +def get_num_docs() -> None: print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE)) -def inject_lines(table, cursor, lines): +def inject_lines(table, cursor, lines) -> None: batch_count = max(int(lines / BATCH_SIZE), 1) inserted = 0 @@ -51,7 +51,7 @@ def inject_lines(table, cursor, lines): print(f"Inserting batch #{batch} of {batch_size} documents.") -async def load(): +async def load() -> None: """N tables of 10001 rows each. 
each row is ~ 1024*20 bytes""" database = connect(host="127.0.0.1", port=3306, user="root", password="changeme") cursor = database.cursor() @@ -67,7 +67,7 @@ async def load(): database.commit() -async def remove(): +async def remove() -> None: """Removes 10 random items per table""" database = connect(host="127.0.0.1", port=3306, user="root", password="changeme") cursor = database.cursor() diff --git a/tests/sources/fixtures/network_drive/fixture.py b/tests/sources/fixtures/network_drive/fixture.py index c141dab0c..1f7dfa1ae 100644 --- a/tests/sources/fixtures/network_drive/fixture.py +++ b/tests/sources/fixtures/network_drive/fixture.py @@ -21,7 +21,7 @@ USERNAME = "admin" PASSWORD = "abc@123" -DATA_SIZE = os.environ.get("DATA_SIZE", "medium") +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium") match DATA_SIZE: case "small": @@ -35,7 +35,7 @@ FOLDER_COUNT = 250 -def generate_folder(): +def generate_folder() -> None: """Method for generating folder on Network Drive server""" try: print("Started loading folder on network drive server....") @@ -51,7 +51,7 @@ def generate_folder(): ) -def generate_files(): +def generate_files() -> None: """Method for generating files on Network Drive server""" try: smbclient.register_session(server=SERVER, username=USERNAME, password=PASSWORD) @@ -74,12 +74,12 @@ def generate_files(): raise -async def load(): +async def load() -> None: generate_folder() generate_files() -async def remove(): +async def remove() -> None: """Method for deleting 10 random files from Network Drive server""" try: smbclient.register_session(server=SERVER, username=USERNAME, password=PASSWORD) diff --git a/tests/sources/fixtures/notion/fixture.py b/tests/sources/fixtures/notion/fixture.py index 47a388df6..55dacc5d2 100644 --- a/tests/sources/fixtures/notion/fixture.py +++ b/tests/sources/fixtures/notion/fixture.py @@ -7,6 +7,7 @@ """Module to handle api calls received from connector.""" import os +from typing import Any, Dict, Union from flask import Flask, request @@ -15,15 +16,15 @@ app = Flask(__name__) -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() _SIZES = {"small": 5, "medium": 10, "large": 15} -NUMBER_OF_DATABASES_PAGES = _SIZES[DATA_SIZE] +NUMBER_OF_DATABASES_PAGES: int = _SIZES[DATA_SIZE] fake_provider = WeightedFakeProvider() class NotionAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.first_sync = True self.app.route("/v1/users/me", methods=["GET"])(self.get_owner) @@ -33,7 +34,9 @@ def __init__(self): self.get_block_children ) - def get_owner(self): + def get_owner( + self, + ) -> Dict[str, Union[Dict[str, Union[Dict[str, Union[bool, str]], str]], str]]: return { "object": "user", "id": "user_id", @@ -68,7 +71,9 @@ def get_users(self): } return users - def get_block_children(self, block_id): + def get_block_children( + self, block_id + ) -> Union[Dict[str, Union[None, bool, str]], Dict[str, Union[bool, str]]]: has_start_cursor = request.args.get("start_cursor") if has_start_cursor: response = { @@ -234,7 +239,7 @@ def get_page_database(self): databases_pages.extend(pages) return databases_pages - def search_by_query(self): + def search_by_query(self) -> Dict[str, Any]: return { "object": "list", "results": self.get_page_database(), diff --git a/tests/sources/fixtures/onedrive/fixture.py b/tests/sources/fixtures/onedrive/fixture.py index 8ab518c95..e95f1f41c 100644 --- a/tests/sources/fixtures/onedrive/fixture.py +++ 
b/tests/sources/fixtures/onedrive/fixture.py @@ -9,8 +9,12 @@ import io import os import time +from _io import BytesIO +from typing import Any, Dict, List, Union +from faker.proxy import Faker from flask import Flask, make_response, request +from flask.wrappers import Response from flask_limiter import HEADERS, Limiter from flask_limiter.util import get_remote_address @@ -21,7 +25,7 @@ app = Flask(__name__) -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() match DATA_SIZE: case "small": @@ -35,7 +39,7 @@ FILE_COUNT_PER_USER = 150 -THROTTLING = os.environ.get("THROTTLING", False) +THROTTLING: Union[bool, str] = os.environ.get("THROTTLING", False) PRE_REQUEST_SLEEP = float(os.environ.get("PRE_REQUEST_SLEEP", "0.05")) if THROTTLING: @@ -60,9 +64,9 @@ TOKEN_EXPIRATION_TIMEOUT = 3699 # seconds -fake = fake_provider.fake +fake: Faker = fake_provider.fake DRIVE_ID = fake.uuid4() -ROOT = os.environ.get("ROOT_HOST_URL", "http://127.0.0.1:10972") +ROOT: str = os.environ.get("ROOT_HOST_URL", "http://127.0.0.1:10972") class DataGenerator: @@ -70,12 +74,12 @@ class DataGenerator: This class is used to generate fake data for OneDrive source. """ - def __init__(self): + def __init__(self) -> None: self.users = [] self.files_per_user = {} self.files_by_id = {} - def generate(self): + def generate(self) -> None: # Generate users in Azure AD for user_id in range(1, TOTAL_USERS + 1): user = { @@ -109,7 +113,7 @@ def generate(self): self.files_per_user[user["id"]].append(item) - def get_users(self, skip=0, take=100): + def get_users(self, skip: int = 0, take: int = 100): results = [] for user in self.users[skip:][:take]: @@ -117,7 +121,9 @@ def get_users(self, skip=0, take=100): return results - def get_drive_items(self, user_id, skip=0, take=100): + def get_drive_items( + self, user_id, skip: int = 0, take: int = 100 + ) -> List[Dict[str, Any]]: results = [] for file in self.files_per_user[user_id][skip:][:take]: @@ -159,7 +165,7 @@ def get_drive_item_content(self, user_id, item_id): class OneDriveAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.first_sync = True self.data_generator = DataGenerator() @@ -182,10 +188,10 @@ def __init__(self): self.app.before_request(self.before_request) - def before_request(self): + def before_request(self) -> None: time.sleep(PRE_REQUEST_SLEEP) - def get_access_token(self, tenant_id): + def get_access_token(self, tenant_id) -> Response: res = { "access_token": f"fake-access-token-for-{tenant_id}", "expires_in": TOKEN_EXPIRATION_TIMEOUT, @@ -194,7 +200,7 @@ def get_access_token(self, tenant_id): response.headers["status_code"] = 200 return response - def batched_uris(self): + def batched_uris(self) -> Dict[str, List[Dict[str, Any]]]: payload = request.get_json() response = [] for rest_request in payload["requests"]: @@ -204,7 +210,7 @@ def batched_uris(self): ) return {"responses": response} - def get_users(self): + def get_users(self) -> Dict[str, Any]: skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 100)) users = self.data_generator.get_users(skip, take) @@ -220,7 +226,7 @@ def get_users(self): return response - def get_drive(self): + def get_drive(self) -> Dict[str, Union[Dict[str, Any], str]]: return { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#drives", "value": { @@ -235,7 +241,7 @@ def get_drive(self): }, } - def get_root_drive_delta(self, user_id): + def get_root_drive_delta(self, user_id) -> Dict[str, Any]: 
skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 100)) @@ -252,7 +258,7 @@ def get_root_drive_delta(self, user_id): return response - def download_content(self, user_id, item_id): + def download_content(self, user_id, item_id) -> BytesIO: content = self.data_generator.get_drive_item_content(user_id, item_id) return io.BytesIO(bytes(content, encoding="utf-8")) diff --git a/tests/sources/fixtures/oracle/fixture.py b/tests/sources/fixtures/oracle/fixture.py index aef266378..a4a0374d8 100644 --- a/tests/sources/fixtures/oracle/fixture.py +++ b/tests/sources/fixtures/oracle/fixture.py @@ -25,7 +25,7 @@ PASSWORD = "Password_123" DSN = "localhost:1521/FREE" -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -41,11 +41,11 @@ RECORDS_TO_DELETE = 10 -def get_num_docs(): +def get_num_docs() -> None: print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE)) -def inject_lines(table, cursor, lines): +def inject_lines(table, cursor, lines) -> None: batch_count = max(int(lines / BATCH_SIZE), 1) inserted = 0 print(f"Inserting {lines} lines in {batch_count} batches") @@ -62,7 +62,7 @@ def inject_lines(table, cursor, lines): print(f"Inserted batch #{batch} of {batch_size} documents.") -async def load(): +async def load() -> None: """Generate tables and loads table data in the oracle server.""" """N tables of RECORD_COUNT rows each""" connection = oracledb.connect(user="system", password=PASSWORD, dsn=DSN) @@ -81,7 +81,7 @@ async def load(): connection.commit() -async def remove(): +async def remove() -> None: """Removes 10 random items per table""" connection = oracledb.connect(user=USER, password=PASSWORD, dsn=DSN) cursor = connection.cursor() diff --git a/tests/sources/fixtures/postgresql/fixture.py b/tests/sources/fixtures/postgresql/fixture.py index ced41e8e9..d92b3752f 100644 --- a/tests/sources/fixtures/postgresql/fixture.py +++ b/tests/sources/fixtures/postgresql/fixture.py @@ -7,6 +7,7 @@ import asyncio import os import random +from asyncio.events import AbstractEventLoop import asyncpg @@ -16,7 +17,7 @@ CONNECTION_STRING = "postgresql://admin:Password_123@127.0.0.1:9090/xe" BATCH_SIZE = 100 -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() READONLY_USERNAME = "readonly" READONLY_PASSWORD = "foobar123" @@ -34,14 +35,14 @@ RECORDS_TO_DELETE = 10 -event_loop = asyncio.get_event_loop() +event_loop: AbstractEventLoop = asyncio.get_event_loop() -def get_num_docs(): +def get_num_docs() -> None: print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE)) -async def load(): +async def load() -> None: """Create a read-only user for use when configuring the connector, then create tables and load table data.""" @@ -95,7 +96,7 @@ async def load_rows(): await load_rows() -async def remove(): +async def remove() -> None: """Remove documents from tables""" connect = await asyncpg.connect(CONNECTION_STRING) for table in range(NUM_TABLES): diff --git a/tests/sources/fixtures/redis/fixture.py b/tests/sources/fixtures/redis/fixture.py index 6e9803106..d8174fcf2 100644 --- a/tests/sources/fixtures/redis/fixture.py +++ b/tests/sources/fixtures/redis/fixture.py @@ -11,9 +11,9 @@ from tests.commons import WeightedFakeProvider -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() _NUM_DB = {"small": 2, "medium": 4, "large": 16} -NUM_DB = _NUM_DB[DATA_SIZE] 
+NUM_DB: int = _NUM_DB[DATA_SIZE] RECORDS_TO_DELETE = 10 EACH_ROW_ITEMS = 500 ENDPOINT = "redis://localhost:6379/" @@ -21,7 +21,7 @@ fake_provider = WeightedFakeProvider(weights=[0.65, 0.3, 0.05, 0]) -async def inject_lines(redis_client, lines): +async def inject_lines(redis_client, lines) -> None: text = fake_provider.get_text() rows = {} for row_id in range(lines): @@ -30,7 +30,7 @@ async def inject_lines(redis_client, lines): await redis_client.mset(rows) -async def load(): +async def load() -> None: """N databases of 500 rows each. each row is ~ 1024*20 bytes""" redis_client = await redis.from_url(f"{ENDPOINT}") for db in range(NUM_DB): @@ -39,7 +39,7 @@ async def load(): await inject_lines(redis_client, EACH_ROW_ITEMS) -async def remove(): +async def remove() -> None: """Removes 10 random items per db""" redis_client = await redis.from_url(f"{ENDPOINT}") for db in range(NUM_DB): diff --git a/tests/sources/fixtures/s3/fixture.py b/tests/sources/fixtures/s3/fixture.py index a9bd4b0fa..8807e0771 100644 --- a/tests/sources/fixtures/s3/fixture.py +++ b/tests/sources/fixtures/s3/fixture.py @@ -16,7 +16,7 @@ REGION_NAME = "us-west-2" AWS_ENDPOINT_URL = "http://127.0.0.1" AWS_PORT = int(os.environ.get("AWS_PORT", "5001")) -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() AWS_SECRET_KEY = "dummy_secret_key" AWS_ACCESS_KEY_ID = "dummy_access_key" @@ -36,16 +36,16 @@ fake_provider = WeightedFakeProvider() -def get_num_docs(): +def get_num_docs() -> None: print(FOLDER_COUNT + FILE_COUNT - OBJECT_TO_DELETE_COUNT) -async def setup(): +async def setup() -> None: os.environ["AWS_ENDPOINT_URL"] = AWS_ENDPOINT_URL os.environ["AWS_PORT"] = str(AWS_PORT) -async def load(): +async def load() -> None: """Method for generating 10k document for aws s3 emulator""" s3_client = boto3.client( "s3", @@ -82,7 +82,7 @@ async def load(): ) -async def remove(): +async def remove() -> None: """Method for removing 15 random document from aws s3 emulator""" s3_client = boto3.client( "s3", diff --git a/tests/sources/fixtures/salesforce/fixture.py b/tests/sources/fixtures/salesforce/fixture.py index d13838cda..a68271fb4 100644 --- a/tests/sources/fixtures/salesforce/fixture.py +++ b/tests/sources/fixtures/salesforce/fixture.py @@ -10,15 +10,18 @@ import os import random import re +from _io import BytesIO +from typing import Any, Dict, List, Union +from faker.proxy import Faker from flask import Flask, request from tests.commons import WeightedFakeProvider fake_provider = WeightedFakeProvider(weights=[0.6, 0.2, 0.15, 0.05]) -fake = fake_provider.fake +fake: Faker = fake_provider.fake -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -118,7 +121,7 @@ def generate_string(size): CONTENT_DOCUMENT_IDS = [generate_string(18) for _ in range(50)] -def generate_records(table_name): +def generate_records(table_name) -> List[Dict[str, Dict[str, Any]]]: records = [] for _ in range(RECORD_COUNT): record = {"Id": generate_string(18)} @@ -131,7 +134,7 @@ def generate_records(table_name): return records -def generate_content_document_records(): +def generate_content_document_records() -> List[Dict[str, Dict[str, Any]]]: return [ { "ContentDocument": { @@ -148,7 +151,7 @@ def generate_content_document_records(): @app.route("/services/oauth2/token", methods=["POST"]) -def token(): +def token() -> Dict[str, str]: return { "access_token": 
"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", "signature": "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY", @@ -160,7 +163,7 @@ def token(): @app.route("/services/data//query", methods=["GET"]) -def query(version): +def query(version) -> Dict[str, Any]: query = request.args.get("q") table_name = re.findall(r"\bFROM\s+(\w+)", query)[-1] @@ -172,7 +175,7 @@ def query(version): @app.route("/services/data/<_version>/query/", methods=["GET"]) -def query_next(_version, table_name): +def query_next(_version, table_name) -> Dict[str, Any]: # table_name is supposed to be a followup query id. # We co-opt the value in ftests so we can track what table the queries are being run against, # as the table is not included in follow-up query params @@ -183,17 +186,17 @@ def query_next(_version, table_name): @app.route("/services/data/<_version>/sobjects", methods=["GET"]) -def describe(_version): +def describe(_version) -> Dict[str, List[Dict[str, Union[bool, str]]]]: return {"sobjects": [{"name": x, "queryable": True} for x in SOBJECTS]} @app.route("/services/data/<_version>/sobjects/<_sobject>/describe", methods=["GET"]) -def describe_sobject(_version, _sobject): +def describe_sobject(_version, _sobject) -> Dict[str, List[Dict[str, str]]]: return {"fields": [{"name": x} for x in SOBJECT_FIELDS]} @app.route("/sfc/servlet.shepherd/version/download/<_download_id>", methods=["GET"]) -def download(_download_id): +def download(_download_id) -> BytesIO: return io.BytesIO(bytes(fake_provider.get_html(), encoding="utf-8")) diff --git a/tests/sources/fixtures/sandfly/fixture.py b/tests/sources/fixtures/sandfly/fixture.py index 4a1bb6e86..4dbe59a94 100644 --- a/tests/sources/fixtures/sandfly/fixture.py +++ b/tests/sources/fixtures/sandfly/fixture.py @@ -8,16 +8,18 @@ import os from datetime import datetime, timedelta +from typing import Any, Dict, List, Union +from faker.proxy import Faker from flask import Flask from tests.commons import WeightedFakeProvider fake_provider = WeightedFakeProvider() -fake = fake_provider.fake +fake: Faker = fake_provider.fake -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -40,13 +42,13 @@ raise Exception(msg) -def get_num_docs(): +def get_num_docs() -> None: total_docs = HOSTS_COUNT + KEYS_COUNT + (RESULTS_LOOP * RESULTS_COUNT) print(total_docs) class SandflyAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.results_start = 0 self.results_stop = 0 @@ -62,7 +64,7 @@ def __init__(self): ) self.app.route("/v4/results", methods=["POST"])(self.get_results) - def do_ping(self): + def do_ping(self) -> Dict[str, Union[int, str]]: return { "data": "", "detail": "authentication failed", @@ -70,13 +72,13 @@ def do_ping(self): "title": "Unauthorized", } - def get_access_token(self): + def get_access_token(self) -> Dict[str, str]: return { "access_token": "Token#123", "refresh_token": "Refresh#123", } - def get_license(self): + def get_license(self) -> Dict[str, Union[Dict[str, Any], int]]: def _format_date(date): return date.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -89,7 +91,7 @@ def _format_date(date): "limits": {"features": ["demo", "elasticsearch_replication"]}, } - def get_hosts(self): + def get_hosts(self) -> Dict[str, List[Dict[str, Any]]]: return { "data": [ { @@ -101,13 +103,13 @@ def get_hosts(self): ], } - def get_ssh_summary(self): + def get_ssh_summary(self) -> Dict[str, 
Union[List[Dict[str, str]], bool]]: return { "more_results": False, "data": [{"id": f"{key_id}"} for key_id in range(1, KEYS_COUNT + 1)], } - def get_ssh_key(self, key_id): + def get_ssh_key(self, key_id) -> Dict[str, str]: return { "id": key_id, "friendly_name": f"key{key_id} " + fake.word() + " " + fake.word(), diff --git a/tests/sources/fixtures/servicenow/fixture.py b/tests/sources/fixtures/servicenow/fixture.py index d21adc5fe..394950c2e 100644 --- a/tests/sources/fixtures/servicenow/fixture.py +++ b/tests/sources/fixtures/servicenow/fixture.py @@ -11,19 +11,25 @@ import os import random import string +from _io import BytesIO +from typing import Dict, List, Tuple from urllib.parse import parse_qs, urlparse from flask import Flask, make_response, request +from flask.wrappers import Response +from typing_extensions import Buffer -DATA_SIZE = os.environ.get("DATA_SIZE", "small").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "small").lower() _SIZES = {"small": 500000, "medium": 1000000, "large": 3000000} -FILE_SIZE = _SIZES[DATA_SIZE] -LARGE_DATA = "".join([random.choice(string.ascii_letters) for _ in range(FILE_SIZE)]) +FILE_SIZE: int = _SIZES[DATA_SIZE] +LARGE_DATA: str = "".join( + [random.choice(string.ascii_letters) for _ in range(FILE_SIZE)] +) TABLE_FETCH_SIZE = 50 class ServiceNowAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.table_length = 500 self.get_table_length_call = 6 @@ -35,17 +41,17 @@ def __init__(self): ) self.app.route("/api/now/v1/batch", methods=["POST"])(self.get_batch_data) - def get_servicenow_formatted_data(self, response_key, response_data): + def get_servicenow_formatted_data(self, response_key, response_data) -> bytes: return bytes(str({response_key: response_data}).replace("'", '"'), "utf-8") - def get_url_data(self, url): + def get_url_data(self, url) -> Tuple[str, Dict[str, List[str]]]: parsed_url = urlparse(url) return parsed_url.path, parse_qs(parsed_url.query) - def decode_response(self, response): + def decode_response(self, response: Buffer) -> str: return base64.b64encode(response).decode() - def get_batch_data(self): + def get_batch_data(self) -> Response: batch_data = request.get_json() response = [] for rest_request in batch_data["rest_requests"]: @@ -80,7 +86,7 @@ def get_batch_data(self): batch_response.headers["Content-Type"] = "application/json" return batch_response - def get_table_length(self, table): + def get_table_length(self, table) -> Response: response = make_response(bytes(str({"Response": "Dummy"}), "utf-8")) response.headers["Content-Type"] = "application/json" if int(request.args["sysparm_limit"]) == 1: @@ -92,7 +98,7 @@ def get_table_length(self, table): self.table_length = 300 # to delete 2000 records per service in next sync return response - def get_table_data(self, table, offset): + def get_table_data(self, table, offset) -> bytes: records = [] for i in range(offset - TABLE_FETCH_SIZE, offset): records.append( @@ -106,7 +112,7 @@ def get_table_data(self, table, offset): response_key="result", response_data=records ) - def get_attachment_data(self, table_sys_id): + def get_attachment_data(self, table_sys_id) -> bytes: record = [ { "sys_id": f"attachment-{table_sys_id}", @@ -119,7 +125,7 @@ def get_attachment_data(self, table_sys_id): response_key="result", response_data=record ) - def get_attachment_content(self, sys_id): + def get_attachment_content(self, sys_id) -> BytesIO: return io.BytesIO(bytes(LARGE_DATA, encoding="utf-8")) diff --git 
a/tests/sources/fixtures/sharepoint_online/fixture.py b/tests/sources/fixtures/sharepoint_online/fixture.py index 539f5642a..54c005481 100644 --- a/tests/sources/fixtures/sharepoint_online/fixture.py +++ b/tests/sources/fixtures/sharepoint_online/fixture.py @@ -9,7 +9,9 @@ import random import string import time +from typing import Any, Dict, List, Optional, Union +from faker.proxy import Faker from flask import Flask, escape, request from flask_limiter import HEADERS, Limiter from flask_limiter.util import get_remote_address @@ -19,13 +21,13 @@ seed = 1597463007 fake_provider = WeightedFakeProvider() -fake = fake_provider.fake +fake: Faker = fake_provider.fake random.seed(seed) app = Flask(__name__) -THROTTLING = os.environ.get("THROTTLING", False) +THROTTLING: Union[bool, str] = os.environ.get("THROTTLING", False) PRE_REQUEST_SLEEP = float(os.environ.get("PRE_REQUEST_SLEEP", "0.05")) if THROTTLING: @@ -49,13 +51,13 @@ TOKEN_EXPIRATION_TIMEOUT = 3699 # seconds -ROOT = os.environ.get( +ROOT: str = os.environ.get( "ROOT_HOST_URL", "http://127.0.0.1:10337" ) # possible to override if hosting somewhere else TENANT = "functionaltest.sharepoint.fake" -DATA_SIZE = os.environ.get("DATA_SIZE", "medium") +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium") match DATA_SIZE: case "extra_small": @@ -88,22 +90,22 @@ NUMBER_OF_LIST_ITEM_ATTACHMENTS = 5 -TOTAL_RECORD_COUNT = NUMBER_OF_SITES * ( +TOTAL_RECORD_COUNT: int = NUMBER_OF_SITES * ( 1 * NUMBER_OF_DRIVE_ITEMS + NUMBER_OF_PAGES + NUMBER_OF_LISTS * NUMBER_OF_LIST_ITEMS * NUMBER_OF_LIST_ITEM_ATTACHMENTS ) -def get_num_docs(): +def get_num_docs() -> None: print(TOTAL_RECORD_COUNT) # noqa: T201 class AutoIncrement: - def __init__(self): + def __init__(self) -> None: self.val = 1 - def get(self): + def get(self) -> int: value = self.val self.val += 1 return value @@ -124,7 +126,7 @@ class RandomDataStorage: """ - def __init__(self): + def __init__(self) -> None: self.autoinc = AutoIncrement() self.tenants = [] @@ -142,7 +144,7 @@ def __init__(self): self.site_lists_by_list_name = {} self.list_item_attachment_content = {} - def generate(self): + def generate(self) -> None: self.tenants = [TENANT] for _ in range(NUMBER_OF_SITES): @@ -237,7 +239,7 @@ def generate(self): self.site_list_items[site_list["id"]].append(list_item) - def get_site_collections(self): + def get_site_collections(self) -> List[Dict[str, Union[Dict[str, Any], str]]]: results = [] for tenant in self.tenants: @@ -253,7 +255,7 @@ def get_site_collections(self): return results - def get_tenant_root_site_collections(self): + def get_tenant_root_site_collections(self) -> Dict[str, Any]: return { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#sites/$entity", "createdDateTime": "2023-12-12T12:00:00.000Z", @@ -267,7 +269,7 @@ def get_tenant_root_site_collections(self): "siteCollection": {"hostname": "example.sharepoint.com"}, } - def get_sites(self, skip=0, take=10): + def get_sites(self, skip: int = 0, take: int = 10) -> List[Dict[str, Any]]: results = [] for site in self.sites[skip:][:take]: @@ -286,7 +288,7 @@ def get_sites(self, skip=0, take=10): return results - def get_site_drives(self, site_id): + def get_site_drives(self, site_id) -> List[Dict[str, Any]]: results = [] site = self.sites_by_site_id[site_id] @@ -327,7 +329,9 @@ def get_site_drives(self, site_id): return results - def get_site_pages(self, site_name, skip=0, take=10): + def get_site_pages( + self, site_name, skip: int = 0, take: int = 10 + ) -> List[Dict[str, Any]]: results = [] site = 
self.sites_by_site_name[site_name] @@ -379,10 +383,12 @@ def get_site_pages(self, site_name, skip=0, take=10): return results - def generate_sharepoint_id(self): + def generate_sharepoint_id(self) -> str: return "".join(random.choices(string.ascii_uppercase + string.digits, k=32)) - def get_drive_items(self, drive_id, skip=0, take=100): + def get_drive_items( + self, drive_id, skip: int = 0, take: int = 100 + ) -> List[Union[Dict[str, int], Dict[str, str]]]: results = [] site = self.sites_by_drive_id[drive_id] @@ -420,7 +426,9 @@ def get_drive_items(self, drive_id, skip=0, take=100): def get_drive_item_content(self, drive_item_id): return self.drive_item_content[drive_item_id] - def get_site_lists(self, site_id, skip=0, take=100): + def get_site_lists( + self, site_id, skip: int = 0, take: int = 100 + ) -> List[Dict[str, Any]]: results = [] site = self.sites_by_site_id[site_id] @@ -455,7 +463,9 @@ def get_site_lists(self, site_id, skip=0, take=100): return results - def get_site_list_items(self, site_id, list_id, skip=0, take=100): + def get_site_list_items( + self, site_id, list_id, skip: int = 0, take: int = 100 + ) -> List[Dict[str, Any]]: results = [] list_ = self.site_lists_by_list_id[list_id] @@ -519,7 +529,9 @@ def get_site_list_items(self, site_id, list_id, skip=0, take=100): return results - def get_site_list_item_attachments(self, site_name, list_name, list_item_id): + def get_site_list_item_attachments( + self, site_name, list_name, list_item_id + ) -> Dict[str, Any]: list_ = self.site_lists_by_list_name[list_name] site = self.sites_by_site_name[site_name] list_items = self.site_list_items[list_["id"]] @@ -578,12 +590,12 @@ def get_list_item_attachment_content(self, list_id, list_item_id, file_name): @app.before_request -def before_request(): +def before_request() -> None: time.sleep(PRE_REQUEST_SLEEP) @app.route("/<tenant_id>/oauth2/v2.0/token", methods=["POST"]) -def get_graph_token(tenant_id): +def get_graph_token(tenant_id) -> Dict[str, Union[int, str]]: return { "access_token": f"fake-graph-api-token-{tenant_id}", "expires_in": TOKEN_EXPIRATION_TIMEOUT, @@ -591,7 +603,7 @@ def get_graph_token(tenant_id): @app.route("/<tenant_id>/tokens/OAuth/2", methods=["POST"]) -def get_rest_token(tenant_id): +def get_rest_token(tenant_id) -> Dict[str, Union[int, str]]: return { "access_token": f"fake-rest-api-token-{tenant_id}", "expires_in": TOKEN_EXPIRATION_TIMEOUT, @@ -599,7 +611,7 @@ def get_rest_token(tenant_id): @app.route("/common/userrealm/", methods=["GET"]) -def get_tenant(): +def get_tenant() -> Dict[str, Optional[str]]: return { "NameSpaceType": "Managed", "Login": "cj@something.onmicrosoft.com", @@ -611,7 +623,7 @@ def get_tenant(): @app.route("/sites/", methods=["GET"]) -def get_site_collections(): +def get_site_collections() -> Dict[str, Any]: # No paging as there's always one site collection return { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#sites(siteCollection,webUrl)", @@ -620,13 +632,13 @@ def get_site_collections(): @app.route("/sites/root", methods=["GET"]) -def get_tenant_root_site_collections(): +def get_tenant_root_site_collections() -> Dict[str, Any]: # No paging as there's always one site collection return data_storage.get_tenant_root_site_collections() @app.route("/sites/<site_id>/sites", methods=["GET"]) -def get_sites(site_id): +def get_sites(site_id) -> Dict[str, Any]: # Sharepoint Online does not use skip/take, but we do it here just for lazy implementation skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 10)) @@ -647,7 +659,7 @@ def get_sites(site_id): @app.route("/sites/<site_id>/drives", methods=["GET"]) -def get_site_drives(site_id): +def get_site_drives(site_id) -> Dict[str, Any]: # I don't bother to page cause it's mostly 1-2 drives return { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#drives", @@ -656,7 +668,7 @@ def get_site_drives(site_id): @app.route("/drives/<drive_id>/root/delta", methods=["GET"]) -def get_drive_root_delta(drive_id): +def get_drive_root_delta(drive_id) -> Dict[str, Any]: skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 100)) @@ -682,7 +694,7 @@ def download_drive_item(drive_id, item_id): @app.route("/sites/<site_id>/lists", methods=["GET"]) -def get_site_lists(site_id): +def get_site_lists(site_id) -> Dict[str, Any]: skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 100)) @@ -701,7 +713,7 @@ def get_site_lists(site_id): @app.route("/sites/<site_id>/lists/<list_id>/items", methods=["GET"]) -def get_site_list_items(site_id, list_id): +def get_site_list_items(site_id, list_id) -> Dict[str, Any]: skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 100)) @@ -723,7 +735,7 @@ def get_site_list_items(site_id, list_id): @app.route( "/sites/<site_name>/_api/lists/GetByTitle('<list_title>')/items(<list_item_id>)" ) -def get_list_item_attachments(site_name, list_title, list_item_id): +def get_list_item_attachments(site_name, list_title, list_item_id) -> Dict[str, Any]: expand = request.args.get(escape("$expand")) if expand and "AttachmentFiles" in expand: return data_storage.get_site_list_item_attachments( @@ -735,7 +747,7 @@ def get_list_item_attachments(site_name, list_title, list_item_id): @app.route("/sites/<site_name>/_api/web/lists/GetByTitle('Site Pages')/items") -def get_site_pages(site_name): +def get_site_pages(site_name) -> Dict[str, Any]: skip = int(request.args.get("$skip", 0)) take = int(request.args.get("$take", 100)) 
diff --git a/tests/sources/fixtures/sharepoint_server/fixture.py b/tests/sources/fixtures/sharepoint_server/fixture.py index a309c578e..098194378 100644 --- a/tests/sources/fixtures/sharepoint_server/fixture.py +++ b/tests/sources/fixtures/sharepoint_server/fixture.py @@ -8,6 +8,7 @@ import os import time +from typing import Dict, Union from flask import Flask, request from flask_limiter import HEADERS, Limiter @@ -19,7 +20,7 @@ app = Flask(__name__) -THROTTLING = os.environ.get("THROTTLING", False) +THROTTLING: Union[bool, str] = os.environ.get("THROTTLING", False) if THROTTLING: limiter = Limiter( @@ -42,7 +43,7 @@ DOC_ID_SIZE = 36 DOC_ID_FILLING_CHAR = "0" # used to fill in missing symbols for IDs -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -73,7 +74,7 @@ def adjust_document_id_size(document_id): return f"{document_id}-{addition}" -def get_num_docs(): +def get_num_docs() -> None: total_attachments = total_subsites * lists_per_site * attachments_per_list total_lists = total_subsites * lists_per_site print(total_subsites + total_lists + 2 * total_attachments) @@ -83,12 +84,12 @@ def get_num_docs(): @app.before_request -def before_request(): +def before_request() -> None: time.sleep(PRE_REQUEST_SLEEP) @app.route("/sites/<site_collections>/_api/web", methods=["get"]) -def get_site(site_collections): +def get_site(site_collections) -> Dict[str, Union[int, str]]: site = { "Created": "2023-10-24T00:24:36", "Description": site_collections, @@ -282,7 +283,9 @@ def get_list_and_items(parent_site_url, list_id): "/<parent_site_url>/_api/web/getfilebyserverrelativeurl('<file_relative_url>')", methods=["GET"], ) -def 
get_attachment_data(parent_site_url, file_relative_url): +def get_attachment_data( + parent_site_url, file_relative_url +) -> Dict[str, Union[int, str]]: """Function to fetch attachment data on the sharepoint Args: parent_site_url (str): Path of parent site @@ -307,7 +310,7 @@ def get_attachment_data(parent_site_url, file_relative_url): "/<site>/_api/web/GetFileByServerRelativeUrl('<server_url>')/$value", methods=["GET"], ) -def download(site, server_url): +def download(site, server_url) -> bytes: """Function to extract content of a attachment on the sharepoint Args: parent_url (str): Path of parent site diff --git a/tests/sources/fixtures/zoom/fixture.py b/tests/sources/fixtures/zoom/fixture.py index 96fd6b607..97db49020 100644 --- a/tests/sources/fixtures/zoom/fixture.py +++ b/tests/sources/fixtures/zoom/fixture.py @@ -8,17 +8,19 @@ import os from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Union +from faker.proxy import Faker from flask import Flask, request from tests.commons import WeightedFakeProvider fake_provider = WeightedFakeProvider() -override_url = os.environ.get("OVERRIDE_URL", "http://127.0.0.1:10971") -fake = fake_provider.fake +override_url: str = os.environ.get("OVERRIDE_URL", "http://127.0.0.1:10971") +fake: Faker = fake_provider.fake -DATA_SIZE = os.environ.get("DATA_SIZE", "medium").lower() +DATA_SIZE: str = os.environ.get("DATA_SIZE", "medium").lower() match DATA_SIZE: case "small": @@ -45,7 +47,7 @@ class ZoomAPI: - def __init__(self): + def __init__(self) -> None: self.app = Flask(__name__) self.files = {} self.total_user = USER_COUNT @@ -66,7 +68,7 @@ def __init__(self): ) self.app.route("/download/", methods=["GET"])(self.get_content) - def get_access_token(self): + def get_access_token(self) -> Dict[str, Union[int, str]]: return {"access_token": "123456789", "expires_in": 3599} def get_users(self): @@ -89,7 +91,7 @@ def get_users(self): } return res - def get_meetings(self, user): + def get_meetings(self, user) -> Dict[str, Optional[List[Dict[str, Any]]]]: meeting_type = request.args.get("type") return { "next_page_token": None, @@ -105,7 +107,7 @@ def get_meetings(self, user): ], } - def get_recordings(self, user): + def get_recordings(self, user) -> Dict[str, Optional[List[Dict[str, Any]]]]: def _format_recording_date(date): return date.strftime("%Y-%m-%d") @@ -152,7 +154,7 @@ def _format_date(date): res = {} return res - def get_channels(self, user): + def get_channels(self, user) -> Dict[str, Optional[List[Dict[str, Any]]]]: return { "next_page_token": None, "channels": [ @@ -167,7 +169,7 @@ def get_channels(self, user): ], } - def get_messages(self, user): + def get_messages(self, user) -> Dict[str, Optional[List[Dict[str, Any]]]]: message_type = request.args.get("search_type") if message_type == "message": res = { diff --git a/tests/sources/support.py b/tests/sources/support.py index 61f49633e..ddb20741d 100644 --- a/tests/sources/support.py +++ b/tests/sources/support.py @@ -4,8 +4,13 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
# from contextlib import asynccontextmanager +from typing import Type -from connectors.source import DEFAULT_CONFIGURATION, DataSourceConfiguration +from connectors.source import ( + DEFAULT_CONFIGURATION, + BaseDataSource, + DataSourceConfiguration, +) @asynccontextmanager @@ -24,7 +29,7 @@ async def create_source(klass, **extras): await source.close() -async def assert_basics(klass, field, value): +async def assert_basics(klass: Type[BaseDataSource], field: str, value: str) -> None: config = DataSourceConfiguration(klass.get_default_configuration()) assert config[field] == value async with create_source(klass) as source: diff --git a/tests/sources/test_atlassian.py b/tests/sources/test_atlassian.py index 730336694..386e063f8 100644 --- a/tests/sources/test_atlassian.py +++ b/tests/sources/test_atlassian.py @@ -91,7 +91,9 @@ ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: validation_result = await AtlassianAdvancedRulesValidator( AdvancedRulesValidator ).validate(advanced_rules) @@ -139,7 +141,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res ], ) @pytest.mark.asyncio -async def test_active_atlassian_user(user_info, result): +async def test_active_atlassian_user(user_info, result) -> None: async with create_source( JiraDataSource, jira_url="https://127.0.0.1:8080/test" ) as source: @@ -186,7 +188,7 @@ async def test_active_atlassian_user(user_info, result): @pytest.mark.asyncio -async def test_fetch_all_users(mock_responses): +async def test_fetch_all_users(mock_responses) -> None: jira_host = "https://127.0.0.1:8080" users_path = "rest/api/3/users/search" diff --git a/tests/sources/test_azure_blob_storage.py b/tests/sources/test_azure_blob_storage.py index 601c2ab7a..c580593bc 100644 --- a/tests/sources/test_azure_blob_storage.py +++ b/tests/sources/test_azure_blob_storage.py @@ -21,7 +21,7 @@ @asynccontextmanager async def create_abs_source( - use_text_extraction_service=False, + use_text_extraction_service: bool = False, ): async with create_source( AzureBlobStorageDataSource, @@ -34,7 +34,7 @@ async def create_abs_source( @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: """Test ping method of AzureBlobStorageDataSource class""" # Setup @@ -59,7 +59,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio -async def test_ping_for_failed_connection(): +async def test_ping_for_failed_connection() -> None: """Test ping method of AzureBlobStorageDataSource class with negative case""" # Setup @@ -75,7 +75,7 @@ async def test_ping_for_failed_connection(): @pytest.mark.asyncio -async def test_prepare_blob_doc(): +async def test_prepare_blob_doc() -> None: """Test prepare_blob_doc method of AzureBlobStorageDataSource Class""" # Setup @@ -115,7 +115,7 @@ async def test_prepare_blob_doc(): @pytest.mark.asyncio -async def test_get_container(): +async def test_get_container() -> None: """Test get_container method of AzureBlobStorageDataSource Class""" # Setup @@ -153,7 +153,7 @@ async def test_get_container(): @pytest.mark.asyncio -async def test_get_blob(): +async def test_get_blob() -> None: """Test get_blob method of AzureBlobStorageDataSource Class""" # Setup @@ -202,7 +202,7 @@ async def test_get_blob(): @pytest.mark.asyncio -async def test_get_blob_negative(): +async def test_get_blob_negative() -> 
None: """Test get_blob negative method of AzureBlobStorageDataSource Class""" async with create_abs_source() as source: @@ -214,7 +214,7 @@ async def test_get_blob_negative(): @pytest.mark.asyncio -async def test_get_containr_negative(): +async def test_get_containr_negative() -> None: """Test get_container negative method of AzureBlobStorageDataSource Class""" async with create_abs_source() as source: @@ -226,7 +226,7 @@ async def test_get_containr_negative(): @pytest.mark.asyncio -async def test_get_doc(): +async def test_get_doc() -> None: """Test get_doc method of AzureBlobStorageDataSource Class""" # Setup @@ -302,7 +302,7 @@ async def create_fake_coroutine(item): @pytest.mark.asyncio -async def test_get_doc_for_specific_container(): +async def test_get_doc_for_specific_container() -> None: """Test get_doc for specific container method of AzureBlobStorageDataSource Class""" # Setup @@ -371,7 +371,7 @@ async def test_get_doc_for_specific_container(): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: """Test get_content method of AzureBlobStorageDataSource Class""" # Setup @@ -417,7 +417,7 @@ async def read(self): @pytest.mark.asyncio -async def test_get_content_with_upper_extension(): +async def test_get_content_with_upper_extension() -> None: """Test get_content method of AzureBlobStorageDataSource Class""" # Setup @@ -465,7 +465,7 @@ async def read(self): @pytest.mark.asyncio -async def test_get_content_when_doit_false(): +async def test_get_content_when_doit_false() -> None: """Test get_content method when doit is false.""" # Setup @@ -493,7 +493,7 @@ async def test_get_content_when_doit_false(): @pytest.mark.asyncio -async def test_get_content_when_file_size_0b(): +async def test_get_content_when_file_size_0b() -> None: """Test get_content method when the file size is 0b""" # Setup @@ -521,7 +521,7 @@ async def test_get_content_when_file_size_0b(): @pytest.mark.asyncio -async def test_get_content_when_size_limit_exceeded(): +async def test_get_content_when_size_limit_exceeded() -> None: """Test get_content method when the file size is 10MB""" # Setup @@ -549,7 +549,7 @@ async def test_get_content_when_size_limit_exceeded(): @pytest.mark.asyncio -async def test_get_content_when_type_not_supported(): +async def test_get_content_when_type_not_supported() -> None: """Test get_content method when the file type is not supported""" # Setup @@ -577,7 +577,7 @@ async def test_get_content_when_type_not_supported(): @pytest.mark.asyncio -async def test_validate_config_no_account_name(): +async def test_validate_config_no_account_name() -> None: """Test configure connection string method of AzureBlobStorageDataSource class""" # Setup @@ -590,7 +590,7 @@ async def test_validate_config_no_account_name(): @pytest.mark.asyncio -async def test_tweak_bulk_options(): +async def test_tweak_bulk_options() -> None: """Test tweak_bulk_options method of BaseDataSource class""" # Setup @@ -603,7 +603,7 @@ async def test_tweak_bulk_options(): @pytest.mark.asyncio -async def test_validate_config_invalid_concurrent_downloads(): +async def test_validate_config_invalid_concurrent_downloads() -> None: """Test tweak_bulk_options method of BaseDataSource class with invalid concurrent downloads""" # Setup @@ -616,7 +616,7 @@ async def test_validate_config_invalid_concurrent_downloads(): @pytest.mark.asyncio -async def test_get_content_when_blob_tier_archive(): +async def test_get_content_when_blob_tier_archive() -> None: """Test get_content method when the blob tier is 
archive""" # Setup @@ -648,7 +648,7 @@ async def test_get_content_when_blob_tier_archive(): "connectors.content_extraction.ContentExtraction._check_configured", lambda *_: True, ) -async def test_get_content_with_text_extraction_enabled_adds_body(): +async def test_get_content_with_text_extraction_enabled_adds_body() -> None: mock_response = { "type": "blob", "id": "container1/blob1", @@ -704,7 +704,7 @@ async def read(self): @pytest.mark.asyncio -async def test_get_container_client_when_client_already_exists(): +async def test_get_container_client_when_client_already_exists() -> None: async with create_abs_source() as source: container_client = ContainerClient.from_connection_string( conn_str="AccountName=foo;AccountKey=bar;BlobEndpoint=https://foo.endpoint.com", @@ -715,7 +715,7 @@ async def test_get_container_client_when_client_already_exists(): @pytest.mark.asyncio -async def test_get_container_client_when_client_does_not_exist(): +async def test_get_container_client_when_client_does_not_exist() -> None: async with create_abs_source() as source: source.connection_string = ( "AccountName=foo;AccountKey=bar;BlobEndpoint=https://foo.endpoint.com" @@ -725,7 +725,7 @@ async def test_get_container_client_when_client_does_not_exist(): @pytest.mark.asyncio -async def test_close_with_connector_clients(): +async def test_close_with_connector_clients() -> None: async with create_abs_source() as source: class ContainerClientMock: diff --git a/tests/sources/test_box.py b/tests/sources/test_box.py index 261527332..440fb7bff 100644 --- a/tests/sources/test_box.py +++ b/tests/sources/test_box.py @@ -5,6 +5,7 @@ # """Tests the Box source class methods""" +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import AsyncMock, Mock, patch import aiohttp @@ -96,7 +97,9 @@ "modified_at": "2023-08-04T03:17:55-07:00", "size": 1875887, } -MOCK_RESPONSE_FETCH = [ +MOCK_RESPONSE_FETCH: List[ + Tuple[Dict[str, Union[int, str]], Optional[Dict[str, Union[int, str]]]] +] = [ ( { "type": "file", @@ -135,7 +138,7 @@ class JSONAsyncMock: - def __init__(self, json, status, *args, **kwargs): + def __init__(self, json, status, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json self.status = status @@ -145,12 +148,12 @@ async def json(self): class StreamReaderAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.content = StreamReader -def get_json_mock(mock_response, status): +def get_json_mock(mock_response, status) -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock( return_value=JSONAsyncMock(json=mock_response, status=status) @@ -158,7 +161,9 @@ def get_json_mock(mock_response, status): return async_mock -def client_get_mock_func(url, headers, params): +def client_get_mock_func( + url, headers, params +) -> Union[JSONAsyncMock, StreamReaderAsyncMock]: if params is not None: if params.get("offset") == 0 and "/2.0/folders/220376481442/items" in url: return JSONAsyncMock(json=FOLDER_ITEMS, status=200) @@ -171,7 +176,7 @@ def client_get_mock_func(url, headers, params): @pytest.mark.asyncio @pytest.mark.parametrize("field", ["client_id", "client_secret", "refresh_token"]) -async def test_validate_config_raise_on_missing_fields(field): +async def test_validate_config_raise_on_missing_fields(field) -> None: async with create_source(BoxDataSource) as source: source.configuration.set_field(name=field, value="") @@ -180,7 +185,7 @@ async def 
test_validate_config_raise_on_missing_fields(field): @pytest.mark.asyncio -async def test_get(): +async def test_get() -> None: async with create_source(BoxDataSource) as source: source.client.token._set_access_token = AsyncMock() source.client.token.access_token = "abcd#123" @@ -190,7 +195,7 @@ async def test_get(): @pytest.mark.asyncio @pytest.mark.parametrize("box_account", ["box_free", "box_enterprise"]) -async def test_set_access_token(box_account): +async def test_set_access_token(box_account) -> None: async with create_source(BoxDataSource) as source: source.client.token.is_enterprise = box_account mock_token = { @@ -208,7 +213,7 @@ async def test_set_access_token(box_account): @pytest.mark.asyncio -async def test_set_access_token_raise_token_error_on_exception(): +async def test_set_access_token_raise_token_error_on_exception() -> None: async with create_source(BoxDataSource) as source: with patch("aiohttp.ClientSession.post", side_effect=Exception): with pytest.raises(TokenError): @@ -228,7 +233,7 @@ async def test_set_access_token_raise_token_error_on_exception(): ) async def test_client_get_raise_exception_on_response_error( mock_time_to_sleep_between_retries, status_code, exception -): +) -> None: async with create_source(BoxDataSource) as source: mock_time_to_sleep_between_retries.return_value = 0 source.client.token.get = AsyncMock() @@ -248,7 +253,7 @@ async def test_client_get_raise_exception_on_response_error( @pytest.mark.asyncio -async def test_ping_with_successful_connection(): +async def test_ping_with_successful_connection() -> None: async with create_source(BoxDataSource) as source: source.client.token.get = AsyncMock() source.client._http_session.get = AsyncMock( @@ -265,7 +270,7 @@ async def test_ping_with_successful_connection(): @patch("connectors.utils.time_to_sleep_between_retries") async def test_ping_raises_on_unsuccessful_connection( mock_time_to_sleep_between_retries, -): +) -> None: async with create_source(BoxDataSource) as source: mock_time_to_sleep_between_retries.return_value = 0 source.client.token.get = AsyncMock(side_effect=Exception()) @@ -284,7 +289,7 @@ async def test_ping_raises_on_unsuccessful_connection( (MOCK_ATTACHMENT, False, None), ], ) -async def test_get_content(attachment, doit, expected_content): +async def test_get_content(attachment, doit, expected_content) -> None: async with create_source(BoxDataSource) as source: source.client.token.get = AsyncMock() source.client._http_session.get = AsyncMock( @@ -302,7 +307,7 @@ async def test_get_content(attachment, doit, expected_content): @pytest.mark.asyncio -async def test_consumer_processes_queue_items(): +async def test_consumer_processes_queue_items() -> None: mock_folder = { "type": "folder", "etag": "0", @@ -335,7 +340,7 @@ async def test_consumer_processes_queue_items(): @pytest.mark.asyncio -async def test_fetch(): +async def test_fetch() -> None: actual_response = [] expected_response = [ { @@ -373,7 +378,7 @@ async def test_fetch(): @patch("connectors.utils.time_to_sleep_between_retries") async def test_fetch_returns_none_on_client_exception( mock_time_to_sleep_between_retries, -): +) -> None: async with create_source(BoxDataSource) as source: mock_time_to_sleep_between_retries.return_value = Mock() source.client.token.get = AsyncMock(side_effect=Exception()) @@ -382,7 +387,7 @@ async def test_fetch_returns_none_on_client_exception( @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: actual_response = [] expected_response = [ { @@ -415,7 +420,7 @@ 
async def test_get_docs(): @pytest.mark.asyncio -async def test_end_signal_is_added_to_queue_in_case_of_exception(): +async def test_end_signal_is_added_to_queue_in_case_of_exception() -> None: END_SIGNAL = "FINISHED" async with create_source(BoxDataSource) as source: with patch.object( diff --git a/tests/sources/test_confluence.py b/tests/sources/test_confluence.py index 9f1b66998..0d77330d9 100644 --- a/tests/sources/test_confluence.py +++ b/tests/sources/test_confluence.py @@ -8,6 +8,7 @@ import ssl from contextlib import asynccontextmanager from copy import copy +from typing import Dict, List, Optional, Union from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -141,7 +142,7 @@ "_links": {}, } -EXPECTED_PAGE = { +EXPECTED_PAGE: Dict[str, Union[List[None], List[Dict[str, str]], str]] = { "_id": "4779", "type": "page", "_timestamp": "2023-01-24T04:07:19.672Z", @@ -224,7 +225,7 @@ "_attachment": "IyBUaGlzIGlzIHRoZSBkdW1teSBmaWxl", } -EXPECTED_CONTENT_EXTRACTED = { +EXPECTED_CONTENT_EXTRACTED: Dict[str, str] = { "_id": "att3637249", "_timestamp": "2023-01-03T09:24:50.633Z", "body": RESPONSE_CONTENT, @@ -337,7 +338,9 @@ } -EXPECTED_SEARCH_RESULT_FOR_FILTERING_CLOUD = [ +EXPECTED_SEARCH_RESULT_FOR_FILTERING_CLOUD: List[ + Union[Dict[str, str], Dict[str, Optional[str]], Dict[str, Union[int, str]]] +] = [ { "_id": "983046", "title": "Product Details", @@ -374,7 +377,9 @@ ] -EXPECTED_SEARCH_RESULT_FOR_FILTERING_DATA_CENTER = [ +EXPECTED_SEARCH_RESULT_FOR_FILTERING_DATA_CENTER: List[ + Union[Dict[str, str], Dict[str, Optional[str]], Dict[str, Union[int, str]]] +] = [ { "_id": "983046", "title": "Product Details", @@ -499,7 +504,7 @@ @asynccontextmanager async def create_confluence_source( - use_text_extraction_service=False, data_source=CONFLUENCE_SERVER + use_text_extraction_service: bool = False, data_source: str = CONFLUENCE_SERVER ): async with create_source( ConfluenceDataSource, @@ -516,7 +521,7 @@ async def create_confluence_source( class JSONAsyncMock(AsyncMock): - def __init__(self, json, *args, **kwargs): + def __init__(self, json, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json @@ -525,13 +530,13 @@ async def json(self): class StreamReaderAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.content = StreamReader @pytest.mark.asyncio -async def test_validate_configuration_with_invalid_concurrent_downloads(): +async def test_validate_configuration_with_invalid_concurrent_downloads() -> None: """Test validate configuration method of BaseDataSource class with invalid concurrent downloads""" # Setup @@ -580,7 +585,7 @@ async def test_validate_configuration_with_invalid_concurrent_downloads(): ) async def test_validate_configuration_with_invalid_dependency_fields_raises_error( configs, -): +) -> None: # Setup async with create_confluence_source() as source: for k, v in configs.items(): @@ -637,7 +642,7 @@ async def test_validate_configuration_with_invalid_dependency_fields_raises_erro ) async def test_validate_config_with_valid_dependency_fields_does_not_raise_error( configs, -): +) -> None: async with create_confluence_source() as source: source.confluence_client.ping = AsyncMock() for k, v in configs.items(): @@ -647,7 +652,9 @@ async def test_validate_config_with_valid_dependency_fields_does_not_raise_error @pytest.mark.asyncio -async def test_validate_config_when_ssl_enabled_and_ssl_ca_not_empty_does_not_raise_error(): +async 
def test_validate_config_when_ssl_enabled_and_ssl_ca_not_empty_does_not_raise_error() -> ( + None +): with patch.object(ssl, "create_default_context", return_value=MockSSL()): async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock() @@ -663,7 +670,7 @@ async def test_validate_config_when_ssl_enabled_and_ssl_ca_not_empty_does_not_ra @pytest.mark.asyncio -async def test_tweak_bulk_options(): +async def test_tweak_bulk_options() -> None: """Test tweak_bulk_options method of BaseDataSource class""" # Setup @@ -678,7 +685,7 @@ async def test_tweak_bulk_options(): @pytest.mark.asyncio -async def test_close_with_client_session(): +async def test_close_with_client_session() -> None: """Test close method for closing the existing session""" # Setup @@ -692,7 +699,7 @@ async def test_close_with_client_session(): @pytest.mark.asyncio -async def test_close_without_client_session(): +async def test_close_without_client_session() -> None: """Test close method when the session does not exist""" # Setup async with create_confluence_source() as source: @@ -703,7 +710,7 @@ async def test_close_without_client_session(): @pytest.mark.asyncio -async def test_remote_validation_when_space_keys_are_valid(): +async def test_remote_validation_when_space_keys_are_valid() -> None: async with create_confluence_source() as source: source.spaces = ["DM", "ES"] source.confluence_client._get_session().get = AsyncMock( @@ -714,7 +721,9 @@ async def test_remote_validation_when_space_keys_are_valid(): @pytest.mark.asyncio -async def test_remote_validation_when_space_keys_are_unavailable_then_raise_exception(): +async def test_remote_validation_when_space_keys_are_unavailable_then_raise_exception() -> ( + None +): async with create_confluence_source() as source: source.spaces = ["ES", "CS"] async_response = AsyncMock() @@ -734,7 +743,7 @@ async def test_remote_validation_when_space_keys_are_unavailable_then_raise_exce class MockSSL: """This class contains methods which returns dummy ssl context""" - def load_verify_locations(self, cadata): + def load_verify_locations(self, cadata) -> None: """This method verify locations""" pass @@ -752,7 +761,7 @@ def load_verify_locations(self, cadata): ], ) @pytest.mark.asyncio -async def test_validate_configuration_for_empty_fields(field, data_source): +async def test_validate_configuration_for_empty_fields(field, data_source) -> None: async with create_confluence_source() as source: source.confluence_client.configuration.get_field( "data_source" @@ -765,7 +774,7 @@ async def test_validate_configuration_for_empty_fields(field, data_source): @pytest.mark.asyncio -async def test_ping_with_ssl(): +async def test_ping_with_ssl() -> None: """Test ping method of ConfluenceDataSource class with SSL""" # Execute @@ -786,7 +795,7 @@ async def test_ping_with_ssl(): @pytest.mark.asyncio @patch("aiohttp.ClientSession.get") -async def test_ping_for_failed_connection_exception(mock_get): +async def test_ping_for_failed_connection_exception(mock_get) -> None: """Tests the ping functionality when connection can not be established to Confluence.""" # Setup @@ -800,7 +809,7 @@ async def test_ping_for_failed_connection_exception(mock_get): @pytest.mark.asyncio -async def test_validate_configuration_for_ssl_enabled(): +async def test_validate_configuration_for_ssl_enabled() -> None: """This function tests _validate_configuration when certification is empty and ssl is enabled""" # Setup async with create_confluence_source() as source: @@ -813,7 +822,7 @@ async 
def test_validate_configuration_for_ssl_enabled(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_429_status(): +async def test_get_with_429_status() -> None: initial_response = ClientResponseError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -836,7 +845,7 @@ async def test_get_with_429_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_429_status_without_retry_after_header(): +async def test_get_with_429_status_without_retry_after_header() -> None: payload = {"value": "Test rate limit"} async_mock_response = AsyncMock() async_mock_response.json.return_value = payload @@ -855,7 +864,7 @@ async def test_get_with_429_status_without_retry_after_header(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_400_status(): +async def test_get_with_400_status() -> None: error = ClientResponseError(None, None) error.status = 400 @@ -873,7 +882,7 @@ async def test_get_with_400_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_401_status(): +async def test_get_with_401_status() -> None: error = ClientResponseError(None, None) error.status = 401 @@ -891,7 +900,7 @@ async def test_get_with_401_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_403_status(): +async def test_get_with_403_status() -> None: error = ClientResponseError(None, None) error.status = 403 @@ -908,7 +917,7 @@ async def test_get_with_403_status(): @pytest.mark.asyncio -async def test_get_with_404_status(): +async def test_get_with_404_status() -> None: error = ClientResponseError(None, None) error.status = 404 @@ -926,7 +935,7 @@ async def test_get_with_404_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_500_status(): +async def test_get_with_500_status() -> None: error = ClientResponseError(None, None) error.status = 500 @@ -944,7 +953,7 @@ async def test_get_with_500_status(): @freeze_time("2023-01-24T04:07:19") @pytest.mark.asyncio -async def test_fetch_spaces(): +async def test_fetch_spaces() -> None: # Setup async with create_confluence_source() as source: async_response = AsyncMock() @@ -958,7 +967,7 @@ async def test_fetch_spaces(): @pytest.mark.asyncio -async def test_fetch_documents(): +async def test_fetch_documents() -> None: # Setup async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( @@ -972,7 +981,7 @@ async def test_fetch_documents(): @pytest.mark.asyncio -async def test_fetch_attachments(): +async def test_fetch_attachments() -> None: # Setup async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( @@ -990,7 +999,7 @@ async def test_fetch_attachments(): @pytest.mark.asyncio -async def test_search_by_query(): +async def test_search_by_query() -> None: async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( return_value=JSONAsyncMock(RESPONSE_SEARCH_RESULT) @@ -1007,7 +1016,7 @@ async def test_search_by_query(): @pytest.mark.asyncio -async def test_search_by_query_for_datacenter(): +async def test_search_by_query_for_datacenter() -> None: async 
with create_confluence_source() as source: source.confluence_client.data_source_type = "confluence_data_center" source.confluence_client._get_session().get = AsyncMock( @@ -1024,7 +1033,7 @@ async def test_search_by_query_for_datacenter(): @pytest.mark.asyncio -async def test_download_attachment(): +async def test_download_attachment() -> None: # Setup async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( @@ -1046,7 +1055,7 @@ async def test_download_attachment(): @pytest.mark.asyncio -async def test_download_attachment_with_upper_extension(): +async def test_download_attachment_with_upper_extension() -> None: # Setup async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( @@ -1071,7 +1080,7 @@ async def test_download_attachment_with_upper_extension(): @pytest.mark.asyncio -async def test_download_attachment_when_filesize_is_large_then_download_skips(): +async def test_download_attachment_when_filesize_is_large_then_download_skips() -> None: """Tests the download attachments method for file size greater than max limit.""" # Setup async with create_confluence_source() as source: @@ -1097,7 +1106,9 @@ async def test_download_attachment_when_filesize_is_large_then_download_skips(): @pytest.mark.asyncio -async def test_download_attachment_when_unsupported_filetype_used_then_fail_download_skips(): +async def test_download_attachment_when_unsupported_filetype_used_then_fail_download_skips() -> ( + None +): """Tests the download attachments method for file type is not supported""" # Setup async with create_confluence_source() as source: @@ -1126,7 +1137,7 @@ async def test_download_attachment_when_unsupported_filetype_used_then_fail_down "connectors.content_extraction.ContentExtraction._check_configured", lambda *_: True, ) -async def test_download_attachment_with_text_extraction_enabled_adds_body(): +async def test_download_attachment_with_text_extraction_enabled_adds_body() -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -1184,7 +1195,9 @@ async def test_download_attachment_with_text_extraction_enabled_adds_body(): return_value=AsyncIterator([[copy(EXPECTED_CONTENT)]]), ) @freeze_time("2024-04-02T09:53:15.818621+00:00") -async def test_get_docs(spaces_patch, pages_patch, attachment_patch, content_patch): +async def test_get_docs( + spaces_patch, pages_patch, attachment_patch, content_patch +) -> None: """Tests the get_docs method""" # Setup @@ -1209,7 +1222,7 @@ async def test_get_docs(spaces_patch, pages_patch, attachment_patch, content_pat @pytest.mark.parametrize( "data_source_type", [CONFLUENCE_CLOUD, CONFLUENCE_DATA_CENTER, CONFLUENCE_SERVER] ) -async def test_get_session(data_source_type): +async def test_get_session(data_source_type) -> None: async with create_confluence_source(data_source=data_source_type) as source: try: source.confluence_client._get_session() @@ -1220,7 +1233,7 @@ async def test_get_session(data_source_type): @pytest.mark.asyncio -async def test_get_session_multiple_calls_return_same_instance(): +async def test_get_session_multiple_calls_return_same_instance() -> None: async with create_confluence_source() as source: first_instance = source.confluence_client._get_session() second_instance = source.confluence_client._get_session() @@ -1228,14 +1241,14 @@ async def test_get_session_multiple_calls_return_same_instance(): @pytest.mark.asyncio -async def test_get_session_raise_on_invalid_data_source_type(): +async def 
test_get_session_raise_on_invalid_data_source_type() -> None: async with create_confluence_source(data_source="invalid") as source: with pytest.raises(InvalidConfluenceDataSourceTypeError): source.confluence_client._get_session() @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_confluence_source() as source: source._dls_enabled = MagicMock(return_value=False) @@ -1248,7 +1261,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19") -async def test_get_access_control_dls_enabled(): +async def test_get_access_control_dls_enabled() -> None: mock_users = [ { # Indexable: The user is active and atlassian user. @@ -1357,7 +1370,7 @@ async def test_get_access_control_dls_enabled(): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19") -async def test_get_access_control_dls_enabled_for_datacenter(): +async def test_get_access_control_dls_enabled_for_datacenter() -> None: mock_users = [ { "username": "user1", @@ -1405,7 +1418,7 @@ async def test_get_access_control_dls_enabled_for_datacenter(): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19") -async def test_get_access_control_dls_enabled_for_server(): +async def test_get_access_control_dls_enabled_for_server() -> None: mock_users = [ { "fullName": "user1", @@ -1452,7 +1465,7 @@ async def test_get_access_control_dls_enabled_for_server(): @pytest.mark.asyncio -async def test_fetch_confluence_server_users(): +async def test_fetch_confluence_server_users() -> None: async with create_confluence_source() as source: async_response = AsyncMock() async_response.json.return_value = {"start": 0, "users": []} @@ -1536,7 +1549,7 @@ async def test_fetch_confluence_server_users(): @freeze_time("2024-04-02T09:53:15.818621+00:00") async def test_get_docs_dls_enabled( spaces_patch, pages_patch, attachment_patch, content_patch -): +) -> None: async with create_confluence_source() as source: source._dls_enabled = MagicMock(return_value=True) source.confluence_client.data_source_type = "confluence_cloud" @@ -1580,7 +1593,7 @@ async def test_get_docs_dls_enabled( ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering, expected_docs): +async def test_get_docs_with_advanced_rules(filtering, expected_docs) -> None: async with create_confluence_source() as source: with patch.object( ConfluenceDataSource, @@ -1602,7 +1615,7 @@ async def test_get_docs_with_advanced_rules(filtering, expected_docs): @pytest.mark.asyncio -async def test_extract_identities_for_datacenter(): +async def test_extract_identities_for_datacenter() -> None: async with create_confluence_source() as source: source._dls_enabled = MagicMock(return_value=True) response = { @@ -1630,7 +1643,7 @@ async def test_extract_identities_for_datacenter(): @pytest.mark.asyncio -async def test_fetch_server_space_permission(): +async def test_fetch_server_space_permission() -> None: async with create_confluence_source() as source: source._dls_enabled = MagicMock(return_value=True) payload = { @@ -1664,7 +1677,7 @@ async def test_fetch_server_space_permission(): @pytest.mark.asyncio -async def test_api_call_for_exception(patch_sleep): +async def test_api_call_for_exception(patch_sleep) -> None: """This function test _api_call when credentials are incorrect""" async with create_confluence_source() as source: source.confluence_client.retry_count = 1 @@ -1676,7 +1689,7 @@ async def test_api_call_for_exception(patch_sleep): 
@pytest.mark.asyncio -async def test_get_permission(): +async def test_get_permission() -> None: async with create_confluence_source() as source: actual_permission = {"users": ["admin"], "groups": ["group"]} permisssions = source.get_permission(permission=actual_permission) @@ -1692,13 +1705,13 @@ async def test_get_permission(): ], ) @pytest.mark.asyncio -async def test_page_blog_coro(fetch_documents): +async def test_page_blog_coro(fetch_documents) -> None: async with create_confluence_source() as source: await source._page_blog_coro("api_query", "target") @pytest.mark.asyncio -async def test_end_signal_is_added_to_queue_in_case_of_exception(): +async def test_end_signal_is_added_to_queue_in_case_of_exception() -> None: END_SIGNAL = "FINISHED_TASK" async with create_confluence_source() as source: with patch.object( @@ -1712,7 +1725,7 @@ async def test_end_signal_is_added_to_queue_in_case_of_exception(): @pytest.mark.asyncio -async def test_fetch_page_blog_documents_with_labels(): +async def test_fetch_page_blog_documents_with_labels() -> None: async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( return_value=JSONAsyncMock(RESPONSE_PAGE) @@ -1745,7 +1758,7 @@ async def test_fetch_page_blog_documents_with_labels(): @pytest.mark.asyncio -async def test_fetch_documents_with_html(): +async def test_fetch_documents_with_html() -> None: async with create_confluence_source() as source: source.confluence_client._get_session().get = AsyncMock( return_value=JSONAsyncMock(RESPONSE_PAGE_WITH_HTML) diff --git a/tests/sources/test_directory.py b/tests/sources/test_directory.py index e65f7d2b4..e1bfa0f1c 100644 --- a/tests/sources/test_directory.py +++ b/tests/sources/test_directory.py @@ -10,12 +10,12 @@ @pytest.mark.asyncio -async def test_basics(): +async def test_basics() -> None: await assert_basics(DirectoryDataSource, "directory", DEFAULT_DIR) @pytest.mark.asyncio -async def test_get_docs(catch_stdout): +async def test_get_docs(catch_stdout) -> None: async with create_source(DirectoryDataSource) as source: num = 0 async for doc, dl in source.get_docs(): diff --git a/tests/sources/test_dropbox.py b/tests/sources/test_dropbox.py index 9dc400613..07eb08028 100644 --- a/tests/sources/test_dropbox.py +++ b/tests/sources/test_dropbox.py @@ -7,6 +7,7 @@ import json from contextlib import asynccontextmanager +from typing import Dict, List, Optional, Union from unittest import mock from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch @@ -195,7 +196,9 @@ "has_more": True, } -MOCK_FILES_FOLDERS_CONTINUE = { +MOCK_FILES_FOLDERS_CONTINUE: Dict[ + str, Union[None, List[Dict[str, Union[int, str]]], bool] +] = { "entries": [ { ".tag": "file", @@ -213,7 +216,12 @@ "has_more": False, } -MOCK_MEMBERS_CONTINUE = { +MOCK_MEMBERS_CONTINUE: Dict[ + str, + Union[ + None, List[Dict[str, Dict[str, Union[Dict[str, str], List[str], str]]]], bool + ], +] = { "members": [ { "profile": { @@ -297,7 +305,9 @@ "cursor": "abcd#1234", } -MOCK_SHARED_FILES_CONTINUE = { +MOCK_SHARED_FILES_CONTINUE: Dict[ + str, Optional[List[Dict[str, Union[Dict[str, str], str]]]] +] = { "entries": [ { "access_type": {".tag": "viewer"}, @@ -430,7 +440,7 @@ "_timestamp": "2023-01-01T06:06:06Z", "_attachment": "IyBUaGlzIGlzIHRoZSBkdW1teSBmaWxl", } -EXPECTED_CONTENT_EXTRACTED = { +EXPECTED_CONTENT_EXTRACTED: Dict[str, str] = { "_id": "id:1", "_timestamp": "2023-01-01T06:06:06Z", "body": RESPONSE_CONTENT, @@ -556,7 +566,7 @@ class JSONAsyncMock(AsyncMock): - def __init__(self, json, status, 
*args, **kwargs): + def __init__(self, json, status, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json self.status = status @@ -566,12 +576,12 @@ async def json(self): class StreamReaderAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.content = StreamReader -def get_json_mock(mock_response, status): +def get_json_mock(mock_response, status) -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock( return_value=JSONAsyncMock(json=mock_response, status=status) @@ -579,13 +589,13 @@ def get_json_mock(mock_response, status): return async_mock -def get_stream_reader(): +def get_stream_reader() -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock(return_value=StreamReaderAsyncMock()) return async_mock -def setup_dropbox(source): +def setup_dropbox(source) -> None: # Set up default config with default values source.configuration.get_field("app_key").value = "abc#123" source.configuration.get_field("app_secret").value = "abc#123" @@ -594,8 +604,8 @@ def setup_dropbox(source): @asynccontextmanager async def create_dropbox_source( - use_text_extraction_service=False, - mock_access_token=True, + use_text_extraction_service: bool = False, + mock_access_token: bool = True, ): async with create_source( DropboxDataSource, @@ -614,7 +624,9 @@ async def create_dropbox_source( "field", ["app_key", "app_secret", "refresh_token"], ) -async def test_validate_configuration_with_empty_fields_then_raise_exception(field): +async def test_validate_configuration_with_empty_fields_then_raise_exception( + field, +) -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.configuration.get_field(field).value = "" @@ -624,7 +636,7 @@ async def test_validate_configuration_with_empty_fields_then_raise_exception(fie @pytest.mark.asyncio -async def test_validate_configuration_with_valid_path(): +async def test_validate_configuration_with_valid_path() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.configuration.get_field("path").value = "/shared" @@ -639,7 +651,7 @@ async def test_validate_configuration_with_valid_path(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_validate_configuration_with_invalid_path_then_raise_exception(): +async def test_validate_configuration_with_invalid_path_then_raise_exception() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.path = "/abc" @@ -662,7 +674,7 @@ async def test_validate_configuration_with_invalid_path_then_raise_exception(): @pytest.mark.asyncio -async def test_set_access_token(): +async def test_set_access_token() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) with patch.object( @@ -676,7 +688,7 @@ async def test_set_access_token(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_set_access_token_with_incorrect_app_key_then_raise_exception(): +async def test_set_access_token_with_incorrect_app_key_then_raise_exception() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) @@ -696,7 +708,9 @@ async def test_set_access_token_with_incorrect_app_key_then_raise_exception(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", 
Mock(return_value=0)) -async def test_set_access_token_with_incorrect_refresh_token_then_raise_exception(): +async def test_set_access_token_with_incorrect_refresh_token_then_raise_exception() -> ( + None +): async with create_source(DropboxDataSource) as source: setup_dropbox(source) @@ -715,7 +729,7 @@ async def test_set_access_token_with_incorrect_refresh_token_then_raise_exceptio @pytest.mark.asyncio -async def test_tweak_bulk_options(): +async def test_tweak_bulk_options() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.concurrent_downloads = 10 @@ -726,7 +740,7 @@ async def test_tweak_bulk_options(): @pytest.mark.asyncio -async def test_close_with_client_session(): +async def test_close_with_client_session() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) _ = source.dropbox_client._get_session @@ -736,7 +750,7 @@ async def test_close_with_client_session(): @pytest.mark.asyncio -async def test_ping(): +async def test_ping() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -750,7 +764,7 @@ async def test_ping(): @pytest.mark.asyncio @patch("connectors.sources.dropbox.RETRY_INTERVAL", 0) -async def test_ping_when_server_timeout_error_raises(): +async def test_ping_when_server_timeout_error_raises() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -763,7 +777,7 @@ async def test_ping_when_server_timeout_error_raises(): @pytest.mark.asyncio @patch("connectors.sources.dropbox.RETRY_INTERVAL", 0) -async def test_ping_when_client_response_error_occurs(): +async def test_ping_when_client_response_error_occurs() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -785,7 +799,7 @@ async def test_ping_when_client_response_error_occurs(): @pytest.mark.asyncio @patch("connectors.sources.dropbox.RETRY_INTERVAL", 0) -async def test_ping_when_client_response_error_occur_with_unexpected_url(): +async def test_ping_when_client_response_error_occur_with_unexpected_url() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -807,7 +821,7 @@ async def test_ping_when_client_response_error_occur_with_unexpected_url(): @pytest.mark.asyncio @patch("connectors.sources.dropbox.RETRY_INTERVAL", 0) -async def test_api_call_negative(): +async def test_api_call_negative() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.retry_count = 4 @@ -839,7 +853,7 @@ async def test_api_call_negative(): @pytest.mark.asyncio -async def test_api_call(): +async def test_api_call() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -872,7 +886,7 @@ async def test_api_call(): @pytest.mark.asyncio -async def test_paginated_api_call_when_skipping_api_call(): +async def test_paginated_api_call_when_skipping_api_call() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.retry_count = 1 @@ -894,7 +908,7 @@ async def test_paginated_api_call_when_skipping_api_call(): @pytest.mark.asyncio -async def test_set_access_token_when_token_expires_at_is_str(): +async def 
test_set_access_token_when_token_expires_at_is_str() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.token_expiration_time = "2023-02-10T09:02:23.629821" @@ -916,7 +930,7 @@ def patch_default_wait_multiplier(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_api_call_when_token_is_expired(): +async def test_api_call_when_token_is_expired() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) @@ -948,7 +962,7 @@ async def test_api_call_when_token_is_expired(): @pytest.mark.asyncio -async def test_api_call_when_status_429_exception(): +async def test_api_call_when_status_429_exception() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -982,7 +996,7 @@ async def test_api_call_when_status_429_exception(): @pytest.mark.asyncio @patch("connectors.sources.dropbox.DEFAULT_RETRY_AFTER", 0) -async def test_api_call_when_status_429_exception_without_retry_after_header(): +async def test_api_call_when_status_429_exception_without_retry_after_header() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.retry_count = 1 @@ -1027,7 +1041,7 @@ async def test_api_call_when_status_429_exception_without_retry_after_header(): ) async def test_get_content_when_is_downloadable_is_true( attachment, is_shared, expected_content -): +) -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -1046,7 +1060,9 @@ async def test_get_content_when_is_downloadable_is_true( @pytest.mark.asyncio -async def test_get_content_when_is_downloadable_is_true_with_extraction_service(): +async def test_get_content_when_is_downloadable_is_true_with_extraction_service() -> ( + None +): with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -1075,7 +1091,7 @@ async def test_get_content_when_is_downloadable_is_true_with_extraction_service( @pytest.mark.asyncio @freeze_time("2023-01-01T06:06:06") -async def test_fetch_files_folders(): +async def test_fetch_files_folders() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.path = "/" @@ -1096,7 +1112,7 @@ async def test_fetch_files_folders(): @pytest.mark.asyncio @freeze_time("2023-01-01T06:06:06") -async def test_fetch_shared_files(): +async def test_fetch_shared_files() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.path = "/" @@ -1124,7 +1140,7 @@ async def test_fetch_shared_files(): @pytest.mark.asyncio @freeze_time("2023-01-01T06:06:06") -async def test_search_files(): +async def test_search_files() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) rule = { @@ -1186,7 +1202,7 @@ async def test_search_files(): "ping", return_value=JSONAsyncMock(MOCK_AUTHENTICATED_ADMIN, 200), ) -async def test_get_docs(files_folders_patch, shared_files_patch, ping_patch): +async def test_get_docs(files_folders_patch, shared_files_patch, ping_patch) -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) expected_responses = [*EXPECTED_FILES_FOLDERS, *EXPECTED_SHARED_FILES] @@ -1223,7 +1239,7 @@ async def test_get_docs(files_folders_patch, shared_files_patch, ping_patch): @pytest.mark.asyncio async def 
test_advanced_rules_validation_with_invalid_repos( advanced_rules, expected_validation_result -): +) -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.check_path = AsyncMock(side_effect=InvalidPathException()) @@ -1294,7 +1310,7 @@ async def test_advanced_rules_validation_with_invalid_repos( @pytest.mark.asyncio async def test_get_docs_with_advanced_rules( received_files_patch, files_folders_patch, ping_patch, filtering -): +) -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.get_content = Mock(return_value=EXPECTED_CONTENT) @@ -1324,7 +1340,7 @@ async def test_get_docs_with_advanced_rules( @pytest.mark.asyncio -async def test_get_team_folder_id(): +async def test_get_team_folder_id() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client.path = "/" @@ -1346,7 +1362,7 @@ async def test_get_team_folder_id(): @pytest.mark.asyncio -async def test_get_access_control(): +async def test_get_access_control() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._dls_enabled = MagicMock(return_value=True) @@ -1364,7 +1380,7 @@ async def test_get_access_control(): @pytest.mark.asyncio -async def test_ping_dls_enabled(): +async def test_ping_dls_enabled() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._dls_enabled = MagicMock(return_value=True) @@ -1378,7 +1394,7 @@ async def test_ping_dls_enabled(): @pytest.mark.asyncio -async def test_get_permission_list_for_file(): +async def test_get_permission_list_for_file() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._dls_enabled = MagicMock(return_value=True) @@ -1407,7 +1423,7 @@ async def test_get_permission_list_for_file(): @pytest.mark.asyncio -async def test_get_permission_list_for_folder(): +async def test_get_permission_list_for_folder() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._dls_enabled = MagicMock(return_value=True) @@ -1436,7 +1452,7 @@ async def test_get_permission_list_for_folder(): @pytest.mark.asyncio -async def test_get_account_details(): +async def test_get_account_details() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.dropbox_client._set_access_token = AsyncMock() @@ -1455,7 +1471,7 @@ async def create_fake_coroutine(data): @pytest.mark.asyncio -async def test_get_docs_for_dls(): +async def test_get_docs_for_dls() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._dls_enabled = MagicMock(return_value=True) @@ -1474,7 +1490,7 @@ async def test_get_docs_for_dls(): @pytest.mark.asyncio -async def test_remote_validation_with_dls(): +async def test_remote_validation_with_dls() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._dls_enabled = MagicMock(return_value=True) @@ -1491,7 +1507,7 @@ async def test_remote_validation_with_dls(): @pytest.mark.asyncio -async def test_add_document_to_list(): +async def test_add_document_to_list() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source.include_inherited_users_and_groups = True @@ -1517,7 +1533,7 @@ async def test_add_document_to_list(): @pytest.mark.asyncio -async def test_add_document_to_list_with_exclude_inherited_users_and_groups(): +async def 
test_add_document_to_list_with_exclude_inherited_users_and_groups() -> None: async with create_source(DropboxDataSource) as source: setup_dropbox(source) source._fetch_files_folders = Mock( diff --git a/tests/sources/test_generic_database.py b/tests/sources/test_generic_database.py index 6103f9bcf..6dee0729d 100644 --- a/tests/sources/test_generic_database.py +++ b/tests/sources/test_generic_database.py @@ -6,6 +6,7 @@ """Tests the Generic Database source class methods""" from functools import partial +from typing import List, Optional, Sized, Tuple, Union import pytest @@ -16,6 +17,7 @@ map_column_names, ) from connectors.sources.mssql import MSSQLQueries +from connectors.sources.oracle import OracleQueries SCHEMA = "dbo" TABLE = "emp_table" @@ -26,40 +28,40 @@ class ConnectionSync: """This Class create dummy connection with database and return dummy cursor""" - def __init__(self, query_object): + def __init__(self, query_object: Union[OracleQueries, MSSQLQueries]) -> None: """Setup dummy connection""" self.query_object = query_object - def __enter__(self): + def __enter__(self) -> "ConnectionSync": """Make a dummy database connection and return it""" return self - def __exit__(self, exception_type, exception_value, exception_traceback): + def __exit__(self, exception_type, exception_value, exception_traceback) -> None: """Make sure the dummy database connection gets closed""" pass - def execute(self, statement): + def execute(self, statement) -> "CursorSync": """This method returns dummy cursor""" return CursorSync(query_object=self.query_object, statement=statement) - def close(self): + def close(self) -> None: pass class CursorSync: """This class contains methods which returns dummy response""" - def __enter__(self): + def __enter__(self) -> "CursorSync": """Make a dummy database connection and return it""" return self - def __init__(self, query_object, *args, **kwargs): + def __init__(self, query_object, *args, **kwargs) -> None: """Setup dummy cursor""" self.first_call = True self.query = kwargs["statement"] self.query_object = query_object - def keys(self): + def keys(self) -> List[str]: """Return Columns of table Returns: @@ -67,7 +69,9 @@ def keys(self): """ return ["ids", "names"] - def fetchmany(self, size): + def fetchmany( + self, size: int + ) -> List[Union[Tuple[str], Tuple[int, str], Tuple[int]]]: """This method returns response of fetchmany Args: @@ -130,14 +134,14 @@ def fetchmany(self, size): (["table_1", "table_2", ""], ["table_1", "table_2"]), ], ) -def test_configured_tables(tables, expected_tables): +def test_configured_tables(tables, expected_tables) -> None: actual_tables = configured_tables(tables) assert actual_tables == expected_tables @pytest.mark.parametrize("tables", ["*", ["*"]]) -def test_is_wildcard(tables): +def test_is_wildcard(tables) -> None: assert is_wildcard(tables) @@ -156,7 +160,7 @@ def test_is_wildcard(tables): ("Schema", ["Table1", "Table2"], "schema_table1_table2_"), ], ) -def test_map_column_names(schema, tables, prefix): +def test_map_column_names(schema, tables: Optional[Sized], prefix) -> None: mapped_column_names = map_column_names(COLUMN_NAMES, schema, tables) for column_name, mapped_column_name in zip( @@ -165,12 +169,12 @@ def test_map_column_names(schema, tables, prefix): assert f"{prefix}{column_name}".lower() == mapped_column_name -async def get_cursor(query_object, query): +async def get_cursor(query_object, query) -> CursorSync: return CursorSync(query_object=query_object, statement=query) @pytest.mark.asyncio -async def 
test_fetch(): +async def test_fetch() -> None: query_object = MSSQLQueries() rows = [] diff --git a/tests/sources/test_github.py b/tests/sources/test_github.py index f241a53f9..bcd6ba52f 100644 --- a/tests/sources/test_github.py +++ b/tests/sources/test_github.py @@ -8,6 +8,7 @@ from contextlib import asynccontextmanager from copy import deepcopy from http import HTTPStatus +from typing import Dict, List, Tuple, Union from unittest.mock import ANY, AsyncMock, Mock, patch import gidgethub @@ -34,7 +35,7 @@ ADVANCED_SNIPPET = "advanced_snippet" -def public_repo(): +def public_repo() -> Dict[str, Union[Dict[str, int], Dict[str, str], int, str]]: return { "name": "demo_repo", "nameWithOwner": "demo_user/demo_repo", @@ -55,7 +56,73 @@ def public_repo(): } -def pull_request(): +def pull_request() -> ( + Dict[ + str, + Dict[ + str, + Dict[ + str, + Dict[ + str, + List[ + Dict[ + str, + Union[ + Dict[str, str], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, Dict[str, str]]], + ], + ], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, str]], + ], + ], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[ + Dict[ + str, + Union[ + Dict[str, str], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, str]], + ], + ], + str, + ], + ] + ], + ], + ], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, Union[Dict[str, str], str]]], + ], + ], + int, + str, + ], + ] + ], + ], + ], + ], + ] +): return { "data": { "repository": { @@ -134,7 +201,44 @@ def pull_request(): } -def issue(): +def issue() -> ( + Dict[ + str, + Dict[ + str, + Dict[ + str, + Dict[ + str, + List[ + Dict[ + str, + Union[ + None, + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, str]], + ], + ], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, Union[Dict[str, str], str]]], + ], + ], + int, + str, + ], + ] + ], + ], + ], + ], + ] +): return { "data": { "repository": { @@ -190,7 +294,7 @@ def issue(): } -def attachments(): +def attachments() -> Tuple[Dict[str, Union[int, str]], Dict[str, Union[int, str]]]: return ( { "_id": "demo_repo/source/source.md", @@ -280,7 +384,37 @@ def attachments(): }, ] -MOCK_RESPONSE_ISSUE = { +MOCK_RESPONSE_ISSUE: Dict[ + str, + Dict[ + str, + Dict[ + str, + List[ + Dict[ + str, + Union[ + None, + Dict[str, str], + Dict[ + str, + Union[Dict[str, Union[bool, str]], List[Dict[str, str]]], + ], + Dict[ + str, + Union[ + Dict[str, Union[bool, str]], + List[Dict[str, Union[Dict[str, str], str]]], + ], + ], + int, + str, + ], + ] + ], + ], + ], +] = { "repository": { "issues": { "nodes": [ @@ -325,7 +459,17 @@ def attachments(): } } } -EXPECTED_ISSUE = { +EXPECTED_ISSUE: Dict[ + str, + Union[ + None, + Dict[str, str], + List[Dict[str, str]], + List[Dict[str, Union[Dict[str, str], str]]], + int, + str, + ], +] = { "number": 1, "url": "https://github.com/demo_user/demo_repo/issues/1", "createdAt": "2023-04-18T10:12:21Z", @@ -636,7 +780,14 @@ def attachments(): } } } -EXPECTED_ACCESS_CONTROL = [ +EXPECTED_ACCESS_CONTROL: List[ + Dict[ + str, + Union[ + Dict[str, Dict[str, Union[Dict[str, List[str]], str]]], Dict[str, str], str + ], + ] +] = [ { "_id": "#123", "identity": { @@ -659,7 +810,14 @@ def attachments(): }, } ] -EXPECTED_ACCESS_CONTROL_GITHUB_APP = [ +EXPECTED_ACCESS_CONTROL_GITHUB_APP: List[ + Dict[ + str, + Union[ + Dict[str, Dict[str, Union[Dict[str, List[str]], str]]], Dict[str, str], str + ], + ] +] = [ { "_id": "#1", "identity": { @@ -811,12 +969,12 @@ def 
attachments(): @asynccontextmanager async def create_github_source( - auth_method=PERSONAL_ACCESS_TOKEN, - repo_type="other", - org_name="", - repos="*", - use_document_level_security=False, - use_text_extraction_service=False, + auth_method: str = PERSONAL_ACCESS_TOKEN, + repo_type: str = "other", + org_name: str = "", + repos: str = "*", + use_document_level_security: bool = False, + use_text_extraction_service: bool = False, ): async with create_source( GitHubDataSource, @@ -836,7 +994,7 @@ async def create_github_source( class JSONAsyncMock(AsyncMock): - def __init__(self, json, status, *args, **kwargs): + def __init__(self, json, status, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json self.status = status @@ -845,7 +1003,7 @@ async def json(self): return self._json -def get_json_mock(mock_response, status): +def get_json_mock(mock_response, status) -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock( return_value=JSONAsyncMock(json=mock_response, status=status) @@ -855,7 +1013,7 @@ def get_json_mock(mock_response, status): @pytest.mark.asyncio @pytest.mark.parametrize("field", ["repositories", "token"]) -async def test_validate_config_missing_fields_then_raise(field): +async def test_validate_config_missing_fields_then_raise(field) -> None: async with create_github_source() as source: source.configuration.get_field(field).value = "" @@ -864,7 +1022,7 @@ async def test_validate_config_missing_fields_then_raise(field): @pytest.mark.asyncio -async def test_ping_with_successful_connection(): +async def test_ping_with_successful_connection() -> None: async with create_github_source() as source: source.github_client._get_client.graphql = AsyncMock( return_value={"user": "username"} @@ -873,7 +1031,7 @@ async def test_ping_with_successful_connection(): @pytest.mark.asyncio -async def test_get_user_repos(): +async def test_get_user_repos() -> None: actual_response = [] async with create_github_source() as source: source.github_client.paginated_api_call = Mock( @@ -889,7 +1047,7 @@ async def test_get_user_repos(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_ping_with_unsuccessful_connection(): +async def test_ping_with_unsuccessful_connection() -> None: async with create_github_source() as source: with patch.object( source.github_client, @@ -905,7 +1063,7 @@ async def test_ping_with_unsuccessful_connection(): "scopes", [{}, {"repo"}, {"manage_runner:org, delete:packages, admin:public_key"}], ) -async def test_validate_config_with_insufficient_scope(scopes): +async def test_validate_config_with_insufficient_scope(scopes) -> None: async with create_github_source() as source: source.github_client.get_personal_access_token_scopes = AsyncMock( return_value=scopes @@ -918,7 +1076,7 @@ async def test_validate_config_with_insufficient_scope(scopes): @pytest.mark.asyncio -async def test_validate_config_with_extra_scopes_token(patch_logger): +async def test_validate_config_with_extra_scopes_token(patch_logger) -> None: async with create_github_source() as source: source.github_client.get_personal_access_token_scopes = AsyncMock( return_value={"user", "repo", "admin:org"} @@ -931,7 +1089,7 @@ async def test_validate_config_with_extra_scopes_token(patch_logger): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_validate_config_with_inaccessible_repositories_then_raise(): +async def 
test_validate_config_with_inaccessible_repositories_then_raise() -> None: async with create_github_source( repos="repo1m owner1/repo1, repo2, owner2/repo2" ) as source: @@ -945,7 +1103,7 @@ async def test_validate_config_with_inaccessible_repositories_then_raise(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_invalid_repos_with_max_retries(): +async def test_get_invalid_repos_with_max_retries() -> None: async with create_github_source() as source: with pytest.raises(Exception): source.github_client.graphql = AsyncMock(side_effect=Exception()) @@ -954,7 +1112,7 @@ async def test_get_invalid_repos_with_max_retries(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_response_with_rate_limit_exceeded(): +async def test_get_response_with_rate_limit_exceeded() -> None: async with create_github_source() as source: with patch.object( source.github_client._get_client, @@ -967,7 +1125,7 @@ async def test_get_response_with_rate_limit_exceeded(): @pytest.mark.asyncio -async def test_put_to_sleep(): +async def test_put_to_sleep() -> None: async with create_github_source() as source: source.github_client._get_retry_after = AsyncMock(return_value=0) with pytest.raises(Exception, match="Rate limit exceeded."): @@ -975,7 +1133,7 @@ async def test_put_to_sleep(): @pytest.mark.asyncio -async def test_get_retry_after(): +async def test_get_retry_after() -> None: async with create_github_source() as source: source.github_client._get_client.getitem = AsyncMock( return_value={ @@ -1008,7 +1166,7 @@ async def test_get_retry_after(): ), ], ) -async def test_graphql_with_BadGraphQLRequest(exceptions, raises): +async def test_graphql_with_BadGraphQLRequest(exceptions, raises) -> None: async with create_github_source() as source: source.github_client._get_client.graphql = Mock(side_effect=exceptions) with pytest.raises(raises): @@ -1049,7 +1207,7 @@ async def test_graphql_with_BadGraphQLRequest(exceptions, raises): ), ], ) -async def test_graphql_with_QueryError(exceptions, raises, is_raised): +async def test_graphql_with_QueryError(exceptions, raises, is_raised) -> None: async with create_github_source() as source: source.github_client._get_client.graphql = Mock(side_effect=exceptions) if is_raised: @@ -1069,7 +1227,7 @@ async def test_graphql_with_QueryError(exceptions, raises, is_raised): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_graphql_with_unauthorized(): +async def test_graphql_with_unauthorized() -> None: async with create_github_source() as source: source.github_client._get_client.graphql = Mock( side_effect=GraphQLAuthorizationFailure( @@ -1083,7 +1241,7 @@ async def test_graphql_with_unauthorized(): @pytest.mark.asyncio -async def test_paginated_api_call(): +async def test_paginated_api_call() -> None: expected_response = MOCK_RESPONSE_REPO async with create_github_source() as source: actual_response = [] @@ -1098,7 +1256,7 @@ async def test_paginated_api_call(): @pytest.mark.asyncio -async def test_get_invalid_repos(): +async def test_get_invalid_repos() -> None: expected_response = ["owner1/repo2", "owner2/repo2"] async with create_github_source( repos="repo1, owner1/repo2, repo2, owner2/repo2" @@ -1133,7 +1291,7 @@ async def test_get_invalid_repos(): @pytest.mark.asyncio -async def test_get_invalid_repos_organization(): +async def test_get_invalid_repos_organization() -> None: 
expected_response = ["owner1/repo2", "org1/repo3"] async with create_github_source( repos="repo1, owner1/repo2, repo3", repo_type="organization", org_name="org1" @@ -1167,7 +1325,7 @@ async def test_get_invalid_repos_organization(): ) async def test_get_invalid_repos_organization_for_github_app( repo_type, configured_repos, expected_invalid_repos -): +) -> None: async with create_github_source( auth_method=GITHUB_APP, repos=configured_repos, repo_type=repo_type ) as source: @@ -1216,7 +1374,7 @@ async def test_get_invalid_repos_organization_for_github_app( @pytest.mark.asyncio -async def test_get_content_with_md_file(): +async def test_get_content_with_md_file() -> None: expected_response = { "_id": "demo_repo/source.md", "_timestamp": "2023-04-17T12:55:01Z", @@ -1235,7 +1393,7 @@ async def test_get_content_with_md_file(): @pytest.mark.asyncio -async def test_get_content_with_md_file_with_extraction_service(): +async def test_get_content_with_md_file_with_extraction_service() -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -1271,7 +1429,7 @@ async def test_get_content_with_md_file_with_extraction_service(): (23000000, None), ], ) -async def test_get_content_with_differernt_size(size, expected_content): +async def test_get_content_with_differernt_size(size, expected_content) -> None: async with create_github_source() as source: attachment_with_size_zero = MOCK_ATTACHMENT.copy() attachment_with_size_zero["size"] = size @@ -1282,7 +1440,7 @@ async def test_get_content_with_differernt_size(size, expected_content): @pytest.mark.asyncio -async def test_fetch_repos(): +async def test_fetch_repos() -> None: async with create_github_source() as source: source.github_client.graphql = AsyncMock( return_value={"data": {"viewer": {"login": "owner1"}}} @@ -1308,7 +1466,7 @@ async def test_fetch_repos(): @pytest.mark.asyncio -async def test_fetch_repos_organization(): +async def test_fetch_repos_organization() -> None: async with create_github_source( repo_type="organization", org_name="org_1" ) as source: @@ -1323,7 +1481,7 @@ async def test_fetch_repos_organization(): @pytest.mark.asyncio -async def test_fetch_repos_when_user_repos_is_available(): +async def test_fetch_repos_when_user_repos_is_available() -> None: async with create_github_source(repos="demo_user/demo_repo, , demo_repo") as source: source.github_client.graphql = AsyncMock( side_effect=[ @@ -1351,7 +1509,7 @@ async def test_fetch_repos_when_user_repos_is_available(): "exception", [UnauthorizedException, ForbiddenException], ) -async def test_fetch_repos_with_client_exception(exception): +async def test_fetch_repos_with_client_exception(exception) -> None: async with create_github_source() as source: source.github_client.graphql = Mock(side_effect=exception()) with pytest.raises(exception): @@ -1375,7 +1533,7 @@ async def test_fetch_repos_with_client_exception(exception): ("other", "user_1/repo_3, user_2/repo_4", [MOCK_REPO_3_DOC, MOCK_REPO_4_DOC]), ], ) -async def test_fetch_repos_github_app(repo_type, repos, expected_repos): +async def test_fetch_repos_github_app(repo_type, repos, expected_repos) -> None: async with create_github_source( auth_method=GITHUB_APP, repo_type=repo_type, repos=repos ) as source: @@ -1402,7 +1560,7 @@ async def test_fetch_repos_github_app(repo_type, repos, expected_repos): @pytest.mark.asyncio -async def test_fetch_issues(): +async def test_fetch_issues() -> None: async with create_github_source() as source: source.fetch_extra_fields = AsyncMock() with patch.object( @@ 
-1424,7 +1582,7 @@ async def test_fetch_issues(): "exception", [UnauthorizedException, ForbiddenException], ) -async def test_fetch_issues_with_client_exception(exception): +async def test_fetch_issues_with_client_exception(exception) -> None: async with create_github_source() as source: source.github_client.graphql = Mock(side_effect=exception()) with pytest.raises(exception): @@ -1436,7 +1594,7 @@ async def test_fetch_issues_with_client_exception(exception): @pytest.mark.asyncio -async def test_fetch_pull_requests(): +async def test_fetch_pull_requests() -> None: async with create_github_source() as source: with patch.object( source.github_client, @@ -1462,7 +1620,7 @@ async def test_fetch_pull_requests(): "exception", [UnauthorizedException, ForbiddenException], ) -async def test_fetch_pull_requests_with_client_exception(exception): +async def test_fetch_pull_requests_with_client_exception(exception) -> None: async with create_github_source() as source: source.github_client.graphql = Mock(side_effect=exception()) with pytest.raises(exception): @@ -1474,7 +1632,7 @@ async def test_fetch_pull_requests_with_client_exception(exception): @pytest.mark.asyncio -async def test_fetch_pull_requests_with_deleted_users(): +async def test_fetch_pull_requests_with_deleted_users() -> None: async with create_github_source() as source: mock_review_deleted_user = { "repository": { @@ -1536,7 +1694,7 @@ async def test_fetch_pull_requests_with_deleted_users(): @pytest.mark.asyncio -async def test_fetch_path(): +async def test_fetch_path() -> None: async with create_github_source() as source: with patch.object( source.github_client, @@ -1550,7 +1708,7 @@ async def test_fetch_path(): @pytest.mark.asyncio -async def test_fetch_files(): +async def test_fetch_files() -> None: expected_response = ( { "name": "source.md", @@ -1590,7 +1748,7 @@ async def test_fetch_files(): "exception", [UnauthorizedException, ForbiddenException], ) -async def test_fetch_files_when_error_occurs(exception): +async def test_fetch_files_when_error_occurs(exception) -> None: async with create_github_source() as source: source.github_client.get_github_item = Mock(side_effect=exception()) with pytest.raises(exception): @@ -1599,7 +1757,7 @@ async def test_fetch_files_when_error_occurs(exception): @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: expected_response = [ PUBLIC_REPO, MOCK_RESPONSE_PULL, @@ -1624,7 +1782,9 @@ async def test_get_docs(): @pytest.mark.asyncio -async def test_get_docs_with_access_control_should_not_add_acl_for_public_repo(): +async def test_get_docs_with_access_control_should_not_add_acl_for_public_repo() -> ( + None +): public_repo_ = public_repo() pull_request_ = pull_request() issue_ = issue() @@ -1651,7 +1811,9 @@ async def test_get_docs_with_access_control_should_not_add_acl_for_public_repo() @pytest.mark.asyncio -async def test_get_docs_with_access_control_should_add_acl_for_non_public_repo(): +async def test_get_docs_with_access_control_should_add_acl_for_non_public_repo() -> ( + None +): expected_response = [ PRIVATE_REPO, MOCK_RESPONSE_PULL, @@ -1825,7 +1987,9 @@ async def test_get_docs_with_access_control_should_add_acl_for_non_public_repo() ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_github_source() as source: source.get_invalid_repos = AsyncMock(return_value=[]) @@ -1858,7 +2022,7 
@@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res @pytest.mark.asyncio async def test_advanced_rules_validation_with_invalid_repos( advanced_rules, expected_validation_result -): +) -> None: async with create_github_source() as source: source.get_invalid_repos = AsyncMock(return_value=["repo_name"]) @@ -1938,7 +2102,7 @@ async def test_advanced_rules_validation_with_invalid_repos( ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering, expected_response): +async def test_get_docs_with_advanced_rules(filtering, expected_response) -> None: actual_response = [] async with create_github_source() as source: source._get_configured_repos = Mock(return_value=AsyncIterator([PUBLIC_REPO])) @@ -1955,14 +2119,14 @@ async def test_get_docs_with_advanced_rules(filtering, expected_response): @pytest.mark.asyncio -async def test_is_previous_repo(): +async def test_is_previous_repo() -> None: async with create_github_source() as source: assert source.is_previous_repo("demo_user/demo_repo") is False assert source.is_previous_repo("demo_user/demo_repo") is True @pytest.mark.asyncio -async def test_get_access_control(): +async def test_get_access_control() -> None: async with create_github_source(repo_type="organization") as source: actual_response = [] source._dls_enabled = Mock(return_value=True) @@ -1983,7 +2147,7 @@ async def test_get_access_control(): @pytest.mark.asyncio -async def test_get_access_control_github_app(): +async def test_get_access_control_github_app() -> None: async with create_github_source( auth_method=GITHUB_APP, repo_type="organization" ) as source: @@ -2030,7 +2194,7 @@ async def test_get_access_control_github_app(): @pytest.mark.asyncio -async def test_fetch_access_control(): +async def test_fetch_access_control() -> None: async with create_github_source() as source: source.github_client.paginated_api_call = Mock( side_effect=[ @@ -2054,7 +2218,7 @@ async def test_fetch_access_control(): ("other", True, False), ], ) -async def test_dls_enabled(repo_type, use_document_level_security, dls_enabled): +async def test_dls_enabled(repo_type, use_document_level_security, dls_enabled) -> None: async with create_github_source( repo_type=repo_type, use_document_level_security=use_document_level_security ) as source: @@ -2072,7 +2236,7 @@ async def test_dls_enabled(repo_type, use_document_level_security, dls_enabled): ("repo, read:org", {"repo", "read:org"}), ], ) -async def test_get_personal_access_token_scopes(scopes, expected_scopes): +async def test_get_personal_access_token_scopes(scopes, expected_scopes) -> None: async with create_github_source() as source: source.github_client._get_client._request = AsyncMock( return_value=(200, {"X-OAuth-Scopes": scopes}, None) @@ -2104,7 +2268,9 @@ async def test_get_personal_access_token_scopes(scopes, expected_scopes): ), ], ) -async def test_get_personal_access_token_scopes_when_error_occurs(exception, raises): +async def test_get_personal_access_token_scopes_when_error_occurs( + exception, raises +) -> None: async with create_github_source() as source: source.github_client._get_client._request = AsyncMock(side_effect=exception) with pytest.raises(raises): @@ -2112,7 +2278,7 @@ async def test_get_personal_access_token_scopes_when_error_occurs(exception, rai @pytest.mark.asyncio -async def test_github_client_get_installations(): +async def test_github_client_get_installations() -> None: async with create_github_source(auth_method=GITHUB_APP) as source: mock_response = [ { @@ -2140,7 +2306,7 
@@ async def test_github_client_get_installations(): @pytest.mark.asyncio -async def test_github_app_paginated_get(): +async def test_github_app_paginated_get() -> None: async with create_github_source(auth_method=GITHUB_APP) as source: item_1 = {"id": 1} item_2 = {"id": 2} @@ -2161,7 +2327,7 @@ async def test_github_app_paginated_get(): @pytest.mark.asyncio -async def test_update_installation_id(): +async def test_update_installation_id() -> None: async with create_github_source(auth_method=GITHUB_APP) as source: jwt_response = {"token": "changeme"} installation_id = 123 @@ -2187,7 +2353,7 @@ async def test_update_installation_id(): (PERSONAL_ACCESS_TOKEN, 1, "foo"), ], ) -async def test_logged_in_user(auth_method, expected_await_count, expected_user): +async def test_logged_in_user(auth_method, expected_await_count, expected_user) -> None: async with create_github_source(auth_method=auth_method) as source: source.github_client.get_logged_in_user = AsyncMock(return_value="foo") user = await source._logged_in_user() @@ -2200,7 +2366,7 @@ async def test_logged_in_user(auth_method, expected_await_count, expected_user): @pytest.mark.asyncio -async def test_fetch_installations_personal_access_token(): +async def test_fetch_installations_personal_access_token() -> None: async with create_github_source() as source: source.github_client.get_installations = AsyncMock() await source._fetch_installations() @@ -2209,7 +2375,7 @@ async def test_fetch_installations_personal_access_token(): @pytest.mark.asyncio -async def test_fetch_installations_withp_prepopulated_installations(): +async def test_fetch_installations_withp_prepopulated_installations() -> None: prepopulated_installations = {"fake_org": {"installation_id": 1}} async with create_github_source(auth_method=GITHUB_APP) as source: source.github_client.get_installations = AsyncMock() @@ -2230,7 +2396,7 @@ async def test_fetch_installations_withp_prepopulated_installations(): ("other", {"user_1": 3, "user_2": 4}), ], ) -async def test_fetch_installations(repo_type, expected_installations): +async def test_fetch_installations(repo_type, expected_installations) -> None: async with create_github_source( auth_method=GITHUB_APP, repo_type=repo_type ) as source: @@ -2252,7 +2418,7 @@ async def test_fetch_installations(repo_type, expected_installations): (GITHUB_APP, "other", ["user_1", "user_2"]), ], ) -async def test_get_owners(auth_method, repo_type, expected_owners): +async def test_get_owners(auth_method, repo_type, expected_owners) -> None: async with create_github_source( auth_method=auth_method, repo_type=repo_type, org_name="demo_org" ) as source: @@ -2271,7 +2437,7 @@ async def test_get_owners(auth_method, repo_type, expected_owners): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_update_installation_access_token_when_error_occurs(): +async def test_update_installation_access_token_when_error_occurs() -> None: async with create_github_source() as source: source.github_client.get_installation_access_token = AsyncMock( side_effect=Exception() @@ -2304,7 +2470,7 @@ async def test_update_installation_access_token_when_error_occurs(): (Exception(), Exception), ], ) -async def test_get_github_item_when_error_occurs(exceptions, raises): +async def test_get_github_item_when_error_occurs(exceptions, raises) -> None: async with create_github_source() as source: source.github_client._get_client.getitem = Mock(side_effect=exceptions) with pytest.raises(raises): diff --git 
a/tests/sources/test_gmail.py b/tests/sources/test_gmail.py index ffec4a7ce..05026de39 100644 --- a/tests/sources/test_gmail.py +++ b/tests/sources/test_gmail.py @@ -49,7 +49,9 @@ def dls_enabled(value): @asynccontextmanager -async def create_gmail_source(dls_enabled=False, include_spam_and_trash=False): +async def create_gmail_source( + dls_enabled: bool = False, include_spam_and_trash: bool = False +): async with create_source( GMailDataSource, service_account_credentials=json.dumps(JSON_CREDENTIALS), @@ -104,7 +106,7 @@ class TestGMailAdvancedRulesValidator: ), ], ) - async def test_advanced_rules_validator(self, advanced_rules, is_valid): + async def test_advanced_rules_validator(self, advanced_rules, is_valid) -> None: validation_result = await GMailAdvancedRulesValidator().validate(advanced_rules) assert validation_result.is_valid == is_valid @@ -141,13 +143,13 @@ async def test_advanced_rules_validator(self, advanced_rules, is_valid): ], ) @freeze_time(DATE) -def test_message_doc(message, expected_doc): +def test_message_doc(message, expected_doc) -> None: assert _message_doc(message) == expected_doc async def setup_messages_and_users_apis( patch_gmail_client, patch_google_directory_client, messages, users -): +) -> None: patch_google_directory_client.users = AsyncIterator(users) patch_gmail_client.messages = AsyncIterator(messages) patch_gmail_client.message = AsyncMock(side_effect=messages) @@ -173,7 +175,7 @@ async def patch_google_directory_client(self): @pytest.mark.asyncio async def test_ping_successful( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: async with create_gmail_source() as source: patch_gmail_client.ping = AsyncMock() patch_google_directory_client.ping = AsyncMock() @@ -187,7 +189,7 @@ async def test_ping_successful( @pytest.mark.asyncio async def test_ping_gmail_client_fails( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: async with create_gmail_source() as source: patch_gmail_client.ping = AsyncMock( side_effect=Exception("Something went wrong") @@ -200,7 +202,7 @@ async def test_ping_gmail_client_fails( @pytest.mark.asyncio async def test_ping_google_directory_client_fails( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: async with create_gmail_source() as source: patch_gmail_client.ping = AsyncMock() patch_google_directory_client.ping = AsyncMock(side_effect=Exception) @@ -211,7 +213,7 @@ async def test_ping_google_directory_client_fails( @pytest.mark.asyncio async def test_validate_config_valid( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: valid_json = '{"project_id": "dummy123"}' async with create_gmail_source() as source: @@ -231,7 +233,7 @@ async def test_validate_config_valid( raise AssertionError(msg) from None @pytest.mark.asyncio - async def test_validate_config_invalid_service_account_credentials(self): + async def test_validate_config_invalid_service_account_credentials(self) -> None: async with create_gmail_source() as source: source.configuration.get_field( "service_account_credentials" @@ -241,7 +243,7 @@ async def test_validate_config_invalid_service_account_credentials(self): await source.validate_config() @pytest.mark.asyncio - async def test_validate_config_invalid_subject(self): + async def test_validate_config_invalid_subject(self) -> None: async with create_gmail_source() as source: source.configuration.get_field("subject").value = "invalid address" @@ -251,7 +253,7 @@ async def test_validate_config_invalid_subject(self): 
@pytest.mark.asyncio async def test_validate_config_invalid_gmail_auth( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: async with create_gmail_source() as source: patch_gmail_client.ping = AsyncMock( side_effect=AuthError("some auth error") @@ -267,7 +269,7 @@ async def test_validate_config_invalid_gmail_auth( @pytest.mark.asyncio async def test_validate_config_invalid_google_directory_auth( self, patch_google_directory_client - ): + ) -> None: async with create_gmail_source() as source: patch_google_directory_client.ping = AsyncMock( side_effect=AuthError("some auth error") @@ -282,7 +284,7 @@ async def test_validate_config_invalid_google_directory_auth( @pytest.mark.asyncio async def test_get_access_control_with_dls_disabled( self, patch_google_directory_client - ): + ) -> None: users = [{UserFields.EMAIL.value: "user@google.com"}] patch_google_directory_client.users = AsyncIterator(users) @@ -299,7 +301,7 @@ async def test_get_access_control_with_dls_disabled( @pytest.mark.asyncio async def test_get_access_control_with_dls_enabled( self, patch_google_directory_client - ): + ) -> None: email = "user@google.com" creation_date = iso_utc() users = [ @@ -329,7 +331,7 @@ async def test_get_access_control_with_dls_enabled( @pytest.mark.asyncio async def test_get_docs_without_dls_without_filtering( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: users = [{UserFields.EMAIL.value: "user@google.com"}] message = { MessageFields.ID.value: "1", @@ -360,7 +362,7 @@ async def test_get_docs_without_dls_without_filtering( @pytest.mark.asyncio async def test_get_docs_without_dls_with_filtering( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: users = [{UserFields.EMAIL.value: "user@google.com"}] message = { MessageFields.ID.value: "1", @@ -399,7 +401,7 @@ async def test_get_docs_without_dls_with_filtering( @pytest.mark.asyncio async def test_get_docs_with_dls_without_filtering( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: email = "user@google.com" users = [{UserFields.EMAIL.value: email}] message = { @@ -434,7 +436,7 @@ async def test_get_docs_with_dls_without_filtering( @pytest.mark.asyncio async def test_get_docs_with_dls_with_filtering( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: email = "user@google.com" users = [{UserFields.EMAIL.value: email}] message = { @@ -478,7 +480,7 @@ async def test_get_docs_with_dls_with_filtering( @pytest.mark.asyncio async def test_get_docs_without_filtering_and_include_spam_and_trash( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: email = "user@google.com" users = [{UserFields.EMAIL.value: email}] message = { @@ -513,7 +515,7 @@ async def test_get_docs_without_filtering_and_include_spam_and_trash( @pytest.mark.asyncio async def test_get_docs_with_filtering_and_include_spam_and_trash( self, patch_gmail_client, patch_google_directory_client - ): + ) -> None: email = "user@google.com" users = [{UserFields.EMAIL.value: email}] message = { @@ -562,7 +564,9 @@ async def test_get_docs_with_filtering_and_include_spam_and_trash( ], ) @pytest.mark.asyncio - async def test_dls_enabled(self, feature_enabled_, rcf_enabled_, dls_enabled_): + async def test_dls_enabled( + self, feature_enabled_, rcf_enabled_, dls_enabled_ + ) -> None: async with create_gmail_source(dls_enabled=rcf_enabled_) as source: # `dls_enabled` sets both the feature flag and the config value in create_gmail_source # -> set dls feature flag after 
instantiation again diff --git a/tests/sources/test_google.py b/tests/sources/test_google.py index 1934cdb47..a13be08f7 100644 --- a/tests/sources/test_google.py +++ b/tests/sources/test_google.py @@ -3,6 +3,7 @@ # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. # +from typing import Dict, Optional from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -24,25 +25,29 @@ SUBJECT = "subject@domain.com" -def setup_gmail_client(json_credentials=None): +def setup_gmail_client( + json_credentials: Optional[Dict[str, str]] = None, +) -> GMailClient: if json_credentials is None: json_credentials = JSON_CREDENTIALS return GMailClient(json_credentials, CUSTOMER_ID, SUBJECT) -def setup_google_directory_client(json_credentials=None): +def setup_google_directory_client( + json_credentials: Optional[Dict[str, str]] = None, +) -> GoogleDirectoryClient: if json_credentials is None: json_credentials = JSON_CREDENTIALS return GoogleDirectoryClient(json_credentials, CUSTOMER_ID, SUBJECT) -def setup_google_service_account_client(): +def setup_google_service_account_client() -> GoogleServiceAccountClient: return GoogleServiceAccountClient(JSON_CREDENTIALS, "some api", "v1", [], 60) -def test_remove_universe_domain(): +def test_remove_universe_domain() -> None: universe_domain = "universe_domain" json_credentials = {universe_domain: "some_value", "key": "value"} remove_universe_domain(json_credentials) @@ -50,7 +55,7 @@ def test_remove_universe_domain(): assert universe_domain not in json_credentials -def test_validate_service_account_json_when_valid(): +def test_validate_service_account_json_when_valid() -> None: valid_service_account_credentials = '{"project_id": "dummy123"}' try: @@ -62,7 +67,7 @@ def test_validate_service_account_json_when_valid(): raise AssertionError(msg) from None -def test_validate_service_account_json_when_invalid(): +def test_validate_service_account_json_when_invalid() -> None: invalid_service_account_credentials = '{"invalid_key": "dummy123"}' with pytest.raises(ConfigurableFieldValueError): @@ -71,7 +76,7 @@ def test_validate_service_account_json_when_invalid(): ) -def test_load_service_account_json_valid_unescaped(): +def test_load_service_account_json_valid_unescaped() -> None: valid_unescaped_service_account_credentials = '{"project_id": "dummy123"}' json_credentials = load_service_account_json( @@ -81,7 +86,7 @@ def test_load_service_account_json_valid_unescaped(): assert isinstance(json_credentials, dict) -def test_load_service_account_json_valid_escaped(): +def test_load_service_account_json_valid_escaped() -> None: valid_unescaped_service_account_credentials = '"{\\"project_id\\": \\"dummy123\\"}"' json_credentials = load_service_account_json( @@ -91,7 +96,7 @@ def test_load_service_account_json_valid_escaped(): assert isinstance(json_credentials, dict) -def test_load_service_account_json_not_valid(): +def test_load_service_account_json_not_valid() -> None: valid_unescaped_service_account_credentials = "xd" with pytest.raises(ConfigurableFieldValueError): @@ -118,7 +123,9 @@ async def patch_aiogoogle(self): yield aiogoogle_client @pytest.mark.asyncio - async def test_api_call_paged(self, patch_service_account_creds, patch_aiogoogle): + async def test_api_call_paged( + self, patch_service_account_creds, patch_aiogoogle + ) -> None: items = ["a", "b", "c"] first_page_mock = AsyncIterator(items) first_page_mock.content = items @@ -175,7 +182,7 @@ 
async def _call_api_func(*args): assert actual_items == items @pytest.mark.asyncio - async def test_api_call(self, patch_service_account_creds, patch_aiogoogle): + async def test_api_call(self, patch_service_account_creds, patch_aiogoogle) -> None: item = "a" google_service_account_client = setup_google_service_account_client() @@ -206,7 +213,7 @@ async def patch_google_service_account_client(self): yield client @pytest.mark.asyncio - async def test_ping_successful(self, patch_google_service_account_client): + async def test_ping_successful(self, patch_google_service_account_client) -> None: google_directory_client = setup_google_directory_client() patch_google_service_account_client.api_call = AsyncMock() @@ -217,7 +224,7 @@ async def test_ping_successful(self, patch_google_service_account_client): raise AssertionError(msg) from None @pytest.mark.asyncio - async def test_ping_failed(self, patch_google_service_account_client): + async def test_ping_failed(self, patch_google_service_account_client) -> None: google_directory_client = setup_google_directory_client() patch_google_service_account_client.api_call = AsyncMock( side_effect=Exception() @@ -227,7 +234,7 @@ async def test_ping_failed(self, patch_google_service_account_client): await google_directory_client.ping() @pytest.mark.asyncio - async def test_users(self, patch_google_service_account_client): + async def test_users(self, patch_google_service_account_client) -> None: google_directory_client = setup_google_directory_client() users = [ @@ -251,7 +258,7 @@ async def test_users(self, patch_google_service_account_client): def test_subject_added_to_service_account_credentials( self, patch_google_service_account_client - ): + ) -> None: json_credentials = {} setup_google_directory_client(json_credentials=json_credentials) @@ -269,7 +276,7 @@ async def patch_google_service_account_client(self): yield client @pytest.mark.asyncio - async def test_ping_successful(self, patch_google_service_account_client): + async def test_ping_successful(self, patch_google_service_account_client) -> None: gmail_client = setup_gmail_client() patch_google_service_account_client.api_call = AsyncMock() @@ -280,7 +287,7 @@ async def test_ping_successful(self, patch_google_service_account_client): raise AssertionError(msg) from None @pytest.mark.asyncio - async def test_ping_failed(self, patch_google_service_account_client): + async def test_ping_failed(self, patch_google_service_account_client) -> None: gmail_client = setup_gmail_client() patch_google_service_account_client.api_call = AsyncMock( side_effect=Exception() @@ -290,7 +297,7 @@ async def test_ping_failed(self, patch_google_service_account_client): await gmail_client.ping() @pytest.mark.asyncio - async def test_messages(self, patch_google_service_account_client): + async def test_messages(self, patch_google_service_account_client) -> None: gmail_client = setup_gmail_client() messages = [ @@ -314,7 +321,7 @@ async def test_messages(self, patch_google_service_account_client): assert actual_messages == messages[0]["messages"] @pytest.mark.asyncio - async def test_message(self, patch_google_service_account_client): + async def test_message(self, patch_google_service_account_client) -> None: gmail_client = setup_gmail_client() message = {"raw": "some content", "internalDate": "some date"} @@ -326,7 +333,7 @@ async def test_message(self, patch_google_service_account_client): def test_subject_added_to_service_account_credentials( self, patch_google_service_account_client - ): + ) -> None: json_credentials 
= {} setup_gmail_client(json_credentials=json_credentials) diff --git a/tests/sources/test_google_cloud_storage.py b/tests/sources/test_google_cloud_storage.py index ad8e9eab8..8bcf09392 100644 --- a/tests/sources/test_google_cloud_storage.py +++ b/tests/sources/test_google_cloud_storage.py @@ -25,7 +25,7 @@ @asynccontextmanager -async def create_gcs_source(use_text_extraction_service=False): +async def create_gcs_source(use_text_extraction_service: bool = False): async with create_source( GoogleCloudStorageDataSource, service_account_credentials=SERVICE_ACCOUNT_CREDENTIALS, @@ -36,7 +36,7 @@ async def create_gcs_source(use_text_extraction_service=False): @pytest.mark.asyncio -async def test_empty_configuration(): +async def test_empty_configuration() -> None: """Tests the validity of the configurations passed to the Google Cloud source class.""" configuration = DataSourceConfiguration({"service_account_credentials": ""}) @@ -50,7 +50,7 @@ async def test_empty_configuration(): @pytest.mark.asyncio -async def test_raise_on_invalid_configuration(): +async def test_raise_on_invalid_configuration() -> None: configuration = DataSourceConfiguration( {"service_account_credentials": "{'abc':'bcd','cd'}"} ) @@ -64,7 +64,7 @@ async def test_raise_on_invalid_configuration(): @pytest.mark.asyncio -async def test_ping_for_successful_connection(catch_stdout): +async def test_ping_for_successful_connection(catch_stdout) -> None: """Tests the ping functionality for ensuring connection to Google Cloud Storage.""" expected_response = { @@ -83,7 +83,7 @@ async def test_ping_for_successful_connection(catch_stdout): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_ping_for_failed_connection(catch_stdout): +async def test_ping_for_failed_connection(catch_stdout) -> None: """Tests the ping functionality when connection can not be established to Google Cloud Storage.""" async with create_gcs_source() as source: @@ -135,7 +135,9 @@ async def test_ping_for_failed_connection(catch_stdout): ) ], ) -async def test_get_blob_document(previous_documents_list, updated_documents_list): +async def test_get_blob_document( + previous_documents_list, updated_documents_list +) -> None: """Tests the function which modifies the fetched blobs and maps the values to keys. 
Args: @@ -150,7 +152,7 @@ async def test_get_blob_document(previous_documents_list, updated_documents_list @pytest.mark.asyncio -async def test_fetch_buckets(): +async def test_fetch_buckets() -> None: """Tests the method which lists the storage buckets available in Google Cloud Storage.""" async with create_gcs_source() as source: @@ -199,7 +201,7 @@ async def test_fetch_buckets(): @pytest.mark.asyncio -async def test_fetch_blobs(): +async def test_fetch_blobs() -> None: """Tests the method responsible to yield blobs from Google Cloud Storage bucket.""" async with create_gcs_source() as source: @@ -246,7 +248,7 @@ async def test_fetch_blobs(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_blobs_negative(): +async def test_fetch_blobs_negative() -> None: """Tests the method responsible to yield blobs(negative) from Google Cloud Storage bucket.""" bucket_response = { @@ -271,7 +273,7 @@ async def test_fetch_blobs_negative(): @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: """Tests the module responsible to fetch and yield blobs documents from Google Cloud Storage.""" async with create_gcs_source() as source: @@ -322,7 +324,7 @@ async def test_get_docs(): @pytest.mark.asyncio -async def test_get_docs_with_specific_bucket(): +async def test_get_docs_with_specific_bucket() -> None: """Tests the module responsible to fetch and yield blobs documents from Google Cloud Storage.""" async with create_gcs_source() as source: @@ -373,7 +375,7 @@ async def test_get_docs_with_specific_bucket(): @pytest.mark.asyncio -async def test_get_docs_when_no_buckets_present(): +async def test_get_docs_when_no_buckets_present() -> None: """Tests the module responsible to fetch and yield blobs documents from Google Cloud Storage. When Cloud storage does not have any buckets. 
""" @@ -453,7 +455,7 @@ async def test_get_docs_when_no_buckets_present(): ), ], ) -async def test_get_content(blob_document, expected_blob_document): +async def test_get_content(blob_document, expected_blob_document) -> None: """Test the module responsible for fetching the content of the file if it is extractable.""" async with create_gcs_source() as source: @@ -481,7 +483,7 @@ async def test_get_content(blob_document, expected_blob_document): "connectors.content_extraction.ContentExtraction._check_configured", lambda *_: True, ) -async def test_get_content_with_text_extraction_enabled_adds_body(): +async def test_get_content_with_text_extraction_enabled_adds_body() -> None: """Test the module responsible for fetching the content of the file if it is extractable.""" with ( @@ -537,7 +539,7 @@ async def test_get_content_with_text_extraction_enabled_adds_body(): @pytest.mark.asyncio -async def test_get_content_with_upper_extension(): +async def test_get_content_with_upper_extension() -> None: """Test the module responsible for fetching the content of the file if it is extractable.""" async with create_gcs_source() as source: @@ -583,7 +585,7 @@ async def test_get_content_with_upper_extension(): @pytest.mark.asyncio -async def test_get_content_when_type_not_supported(): +async def test_get_content_when_type_not_supported() -> None: """Test the module responsible for fetching the content of the file if it is not extractable or doit is not true.""" async with create_gcs_source() as source: @@ -625,7 +627,7 @@ async def test_get_content_when_type_not_supported(): @pytest.mark.asyncio -async def test_get_content_when_file_size_is_large(catch_stdout): +async def test_get_content_when_file_size_is_large(catch_stdout) -> None: """Test the module responsible for fetching the content of the file if it is not extractable or doit is not true.""" async with create_gcs_source() as source: @@ -668,7 +670,7 @@ async def test_get_content_when_file_size_is_large(catch_stdout): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_api_call_for_attribute_error(catch_stdout): +async def test_api_call_for_attribute_error(catch_stdout) -> None: """Tests the api_call method when resource attribute is not present in the getattr.""" async with create_gcs_source() as source: diff --git a/tests/sources/test_google_drive.py b/tests/sources/test_google_drive.py index f44b2fbcb..1c24d3dd5 100644 --- a/tests/sources/test_google_drive.py +++ b/tests/sources/test_google_drive.py @@ -33,7 +33,7 @@ SERVICE_ACCOUNT_CREDENTIALS = '{"project_id": "dummy123"}' -MORE_THAN_DEFAULT_FILE_SIZE_LIMIT = 10485760 + 1 +MORE_THAN_DEFAULT_FILE_SIZE_LIMIT: int = 10485760 + 1 @asynccontextmanager @@ -48,7 +48,7 @@ async def create_gdrive_source(**kwargs): @pytest.mark.asyncio -async def test_empty_configuration(): +async def test_empty_configuration() -> None: """Tests the validity of the configurations passed to the Google Drive source class.""" configuration = DataSourceConfiguration({"service_account_credentials": ""}) @@ -62,7 +62,7 @@ async def test_empty_configuration(): @pytest.mark.asyncio -async def test_raise_on_invalid_configuration(): +async def test_raise_on_invalid_configuration() -> None: """Test if invalid configuration raises an expected Exception""" configuration = DataSourceConfiguration( @@ -78,7 +78,7 @@ async def test_raise_on_invalid_configuration(): @pytest.mark.asyncio -async def test_raise_on_invalid_email_configuration_misformatted_email(): +async def 
test_raise_on_invalid_email_configuration_misformatted_email() -> None: """Test if invalid configuration raises an expected Exception""" configuration = DataSourceConfiguration( @@ -98,7 +98,7 @@ async def test_raise_on_invalid_email_configuration_misformatted_email(): @pytest.mark.asyncio -async def test_raise_on_invalid_email_configuration_empty_email(): +async def test_raise_on_invalid_email_configuration_empty_email() -> None: """Test if invalid configuration raises an expected Exception""" configuration = DataSourceConfiguration( @@ -118,7 +118,7 @@ async def test_raise_on_invalid_email_configuration_empty_email(): @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: """Tests the ping functionality for ensuring connection to Google Drive.""" expected_response = { @@ -136,7 +136,7 @@ async def test_ping_for_successful_connection(): @patch("connectors.utils.time_to_sleep_between_retries", mock.Mock(return_value=0)) @pytest.mark.asyncio -async def test_ping_for_failed_connection(): +async def test_ping_for_failed_connection() -> None: """Tests the ping functionality when connection can not be established to Google Drive.""" async with create_gdrive_source() as source: @@ -191,7 +191,7 @@ async def test_ping_for_failed_connection(): ], ) @pytest.mark.asyncio -async def test_prepare_files(files, expected_files): +async def test_prepare_files(files, expected_files) -> None: """Tests the function which modifies the fetched files and maps the values to keys.""" async with create_gdrive_source() as source: @@ -389,7 +389,7 @@ async def test_prepare_files(files, expected_files): ], ) @pytest.mark.asyncio -async def test_prepare_file(file, expected_file): +async def test_prepare_file(file, expected_file) -> None: """Test the method that formats the file metadata from Google Drive API""" async with create_gdrive_source() as source: @@ -413,7 +413,7 @@ async def test_prepare_file(file, expected_file): @pytest.mark.asyncio -async def test_list_drives(): +async def test_list_drives() -> None: """Tests the method which lists the shared drives from Google Drive.""" async with create_gdrive_source() as source: @@ -460,7 +460,7 @@ async def test_list_drives(): @pytest.mark.asyncio -async def test_list_folders(): +async def test_list_folders() -> None: """Tests the method which lists the folders from Google Drive.""" async with create_gdrive_source() as source: @@ -509,7 +509,7 @@ async def test_list_folders(): @pytest.mark.asyncio -async def test_resolve_paths(): +async def test_resolve_paths() -> None: """Test the method that builds a lookup between a folder id and its absolute path in Google Drive structure""" drives = { "driveId1": "Drive1", @@ -572,7 +572,7 @@ async def test_resolve_paths(): @pytest.mark.asyncio -async def test_fetch_files(): +async def test_fetch_files() -> None: """Tests the method responsible to yield files from Google Drive.""" async with create_gdrive_source() as source: @@ -621,7 +621,7 @@ async def test_fetch_files(): @pytest.mark.asyncio -async def test_get_docs_with_domain_wide_delegation(): +async def test_get_docs_with_domain_wide_delegation() -> None: """Tests the method responsible to yield files from Google Drive.""" async with create_gdrive_source( @@ -686,7 +686,7 @@ async def test_get_docs_with_domain_wide_delegation(): @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: """Tests the module responsible to fetch and yield files documents from Google Drive.""" 
async with create_gdrive_source() as source: @@ -736,7 +736,7 @@ async def test_get_docs(): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: """Test the module responsible for fetching the content of the file if it is extractable.""" async with create_gdrive_source() as source: @@ -778,7 +778,7 @@ async def test_get_content(): @pytest.mark.asyncio -async def test_get_content_doit_false(): +async def test_get_content_doit_false() -> None: """Test the module responsible for fetching the content of the file with `doit` set to False""" async with create_gdrive_source() as source: @@ -804,7 +804,7 @@ async def test_get_content_doit_false(): @pytest.mark.asyncio -async def test_get_content_google_workspace_called(): +async def test_get_content_google_workspace_called() -> None: """Test the method responsible for selecting right extraction method depending on MIME type""" async with create_gdrive_source() as source: @@ -850,7 +850,7 @@ async def test_get_content_google_workspace_called(): @pytest.mark.asyncio -async def test_get_content_generic_files_called(): +async def test_get_content_generic_files_called() -> None: """Test the method responsible for selecting right extraction method depending on MIME type""" async with create_gdrive_source() as source: @@ -896,7 +896,7 @@ async def test_get_content_generic_files_called(): @pytest.mark.asyncio -async def test_get_google_workspace_content(): +async def test_get_google_workspace_content() -> None: """Test the module responsible for fetching the content of the Google Suite document.""" async with create_gdrive_source() as source: @@ -937,7 +937,9 @@ async def test_get_google_workspace_content(): "connectors.content_extraction.ContentExtraction._check_configured", lambda *_: True, ) -async def test_get_google_workspace_content_with_text_extraction_enabled_adds_body(): +async def test_get_google_workspace_content_with_text_extraction_enabled_adds_body() -> ( + None +): """Test the module responsible for fetching the content of the Google Suite document.""" with ( patch( @@ -983,7 +985,7 @@ async def test_get_google_workspace_content_with_text_extraction_enabled_adds_bo @pytest.mark.asyncio -async def test_get_google_workspace_content_size_limit(): +async def test_get_google_workspace_content_size_limit() -> None: """Test the module responsible for fetching the content of the Google Suite document if its size is above the limit.""" @@ -1021,7 +1023,7 @@ async def test_get_google_workspace_content_size_limit(): @pytest.mark.asyncio -async def test_get_generic_file_content(): +async def test_get_generic_file_content() -> None: """Test the module responsible for fetching the content of the file if it is extractable.""" async with create_gdrive_source() as source: @@ -1062,7 +1064,9 @@ async def test_get_generic_file_content(): "connectors.content_extraction.ContentExtraction._check_configured", lambda *_: True, ) -async def test_get_generic_file_content_with_text_extraction_enabled_adds_body(): +async def test_get_generic_file_content_with_text_extraction_enabled_adds_body() -> ( + None +): """Test the module responsible for fetching the content of the file if it is extractable.""" with ( patch( @@ -1108,7 +1112,7 @@ async def test_get_generic_file_content_with_text_extraction_enabled_adds_body() @pytest.mark.asyncio -async def test_get_generic_file_content_size_limit(): +async def test_get_generic_file_content_size_limit() -> None: """Test the module responsible for fetching the content of the file size is 
above the limit.""" async with create_gdrive_source() as source: @@ -1135,7 +1139,7 @@ async def test_get_generic_file_content_size_limit(): @pytest.mark.asyncio -async def test_get_generic_file_content_empty_file(): +async def test_get_generic_file_content_empty_file() -> None: """Test the module responsible for fetching the content of the file if the file size is 0.""" async with create_gdrive_source() as source: @@ -1162,7 +1166,7 @@ async def test_get_generic_file_content_empty_file(): @pytest.mark.asyncio -async def test_get_content_when_type_not_supported(): +async def test_get_content_when_type_not_supported() -> None: """Test the module responsible for fetching the content of the file if it is not extractable.""" async with create_gdrive_source() as source: @@ -1196,7 +1200,7 @@ async def test_get_content_when_type_not_supported(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", 0) -async def test_api_call_for_attribute_error(): +async def test_api_call_for_attribute_error() -> None: """Tests the api_call method when resource attribute is not present in the getattr.""" async with create_gdrive_source() as source: @@ -1208,7 +1212,7 @@ async def test_api_call_for_attribute_error(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", 0) -async def test_api_call_http_error(): +async def test_api_call_http_error() -> None: """Test handling retries for HTTPError exception in api_call() method.""" async with create_gdrive_source() as source: with mock.patch.object( @@ -1222,7 +1226,7 @@ async def test_api_call_http_error(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", 0) -async def test_api_call_other_exception(): +async def test_api_call_other_exception() -> None: """Test handling retries for generic Exception in api_call() method.""" async with create_gdrive_source() as source: with mock.patch.object( @@ -1236,7 +1240,7 @@ async def test_api_call_other_exception(): @patch("connectors.utils.time_to_sleep_between_retries") async def test_api_call_ping_retries( mock_time_to_sleep_between_retries, mock_responses -): +) -> None: """Test handling retries for generic Exception in api_call() method.""" mock_time_to_sleep_between_retries.return_value = 0 @@ -1256,7 +1260,7 @@ async def test_api_call_ping_retries( @mock.patch("connectors.utils.time_to_sleep_between_retries") async def test_api_call_list_drives_retries( mock_time_to_sleep_between_retries, mock_responses -): +) -> None: """Test handling retries for generic Exception in api_call() method.""" mock_time_to_sleep_between_retries.return_value = 0 @@ -1365,7 +1369,7 @@ async def test_api_call_list_drives_retries( @pytest.mark.asyncio async def test_prepare_file_on_shared_drive_with_dls_enabled( file, permissions, expected_file -): +) -> None: """Test the method that formats the file metadata from Google Drive API""" async with create_gdrive_source() as source: @@ -1482,7 +1486,7 @@ async def test_prepare_file_on_shared_drive_with_dls_enabled( ], ) @pytest.mark.asyncio -async def test_prepare_file_on_my_drive_with_dls_enabled(file, expected_file): +async def test_prepare_file_on_my_drive_with_dls_enabled(file, expected_file) -> None: """Test the method that formats the file metadata from Google Drive API""" async with create_gdrive_source() as source: @@ -1537,7 +1541,7 @@ async def test_prepare_file_on_my_drive_with_dls_enabled(file, expected_file): ], ) @pytest.mark.asyncio -async def test_prepare_access_control_doc(user, groups, 
access_control_doc): +async def test_prepare_access_control_doc(user, groups, access_control_doc) -> None: """Test the method that formats the users data from Google Drive API""" async with create_gdrive_source( @@ -1604,7 +1608,7 @@ async def test_prepare_access_control_doc(user, groups, access_control_doc): @pytest.mark.asyncio async def test_prepare_access_control_documents( users_page, groups, access_control_docs -): +) -> None: """Test the method that formats the users data from Google Drive API""" async with create_gdrive_source( @@ -1630,7 +1634,7 @@ async def test_prepare_access_control_documents( @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_gdrive_source() as source: source._dls_enabled = mock.MagicMock(return_value=False) @@ -1642,7 +1646,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio -async def test_get_access_control_dls_enabled(): +async def test_get_access_control_dls_enabled() -> None: """Tests the module responsible to fetch users data from Google Drive.""" async with create_gdrive_source( @@ -1684,14 +1688,14 @@ async def test_get_access_control_dls_enabled(): @pytest.mark.asyncio -async def test_get_google_workspace_admin_email_no_dls_no_delegation(): +async def test_get_google_workspace_admin_email_no_dls_no_delegation() -> None: async with create_gdrive_source() as source: email = source._get_google_workspace_admin_email() assert email is None @pytest.mark.asyncio -async def test_get_google_workspace_admin_email_with_delegation_no_dls(): +async def test_get_google_workspace_admin_email_with_delegation_no_dls() -> None: test_email = "email@test.com" async with create_gdrive_source( google_workspace_admin_email_for_data_sync=test_email @@ -1702,7 +1706,7 @@ async def test_get_google_workspace_admin_email_with_delegation_no_dls(): @pytest.mark.asyncio -async def test_get_google_workspace_admin_email_with_dls_no_delegation(): +async def test_get_google_workspace_admin_email_with_dls_no_delegation() -> None: test_email = "email@test.com" dls_admin_email = "email1@test.com" async with create_gdrive_source( @@ -1716,7 +1720,7 @@ async def test_get_google_workspace_admin_email_with_dls_no_delegation(): @pytest.mark.asyncio -async def test_get_google_workspace_admin_email_with_dls_delegation(): +async def test_get_google_workspace_admin_email_with_dls_delegation() -> None: test_email = "email@test.com" dls_admin_email = "email1@test.com" async with create_gdrive_source( @@ -1731,14 +1735,14 @@ async def test_get_google_workspace_admin_email_with_dls_delegation(): @pytest.mark.asyncio @pytest.mark.parametrize("sync_cursor", [None, {}]) -async def test_get_docs_incrementally_with_empty_sync_cursor(sync_cursor): +async def test_get_docs_incrementally_with_empty_sync_cursor(sync_cursor) -> None: async with create_gdrive_source() as source: with pytest.raises(SyncCursorEmpty): await anext(source.get_docs_incrementally(sync_cursor=sync_cursor)) @pytest.mark.asyncio -async def test_get_docs_incrementally(): +async def test_get_docs_incrementally() -> None: """Tests the module responsible to fetch and yield files documents from Google Drive.""" async with create_gdrive_source() as source: @@ -1832,7 +1836,7 @@ async def test_get_docs_incrementally(): @pytest.mark.asyncio -async def test_get_docs_incrementally_with_domain_wide_delegation(): +async def test_get_docs_incrementally_with_domain_wide_delegation() -> None: """Tests the method responsible to yield 
files from Google Drive.""" async with create_gdrive_source( @@ -1938,7 +1942,7 @@ async def test_get_docs_incrementally_with_domain_wide_delegation(): @pytest.mark.asyncio -async def test_users(): +async def test_users() -> None: async with create_gdrive_source( google_workspace_admin_email_for_data_sync="admin@email.com" ) as source: @@ -2215,7 +2219,7 @@ async def test_users(): ) async def test_list_files_from_my_drive( fetch_permissions, last_sync_time, api_response, expected_response -): +) -> None: async with create_gdrive_source() as source: with mock.patch.object( GoogleServiceAccountClient, diff --git a/tests/sources/test_graphql.py b/tests/sources/test_graphql.py index 75e61210f..7e9096f0b 100644 --- a/tests/sources/test_graphql.py +++ b/tests/sources/test_graphql.py @@ -21,7 +21,7 @@ class JSONAsyncMock(AsyncMock): - def __init__(self, json, status, *args, **kwargs): + def __init__(self, json, status, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json self.status = status @@ -30,7 +30,7 @@ async def json(self): return self._json -def get_json_mock(mock_response, status): +def get_json_mock(mock_response, status) -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock( return_value=JSONAsyncMock(json=mock_response, status=status) @@ -42,8 +42,8 @@ def get_json_mock(mock_response, status): async def create_graphql_source( headers=None, graphql_variables=None, - graphql_query="{users {name {firstName } } }", - graphql_object_to_id_map='{"users": "id"}', + graphql_query: str = "{users {name {firstName } } }", + graphql_object_to_id_map: str = '{"users": "id"}', ): async with create_source( GraphQLDataSource, @@ -86,7 +86,7 @@ async def create_graphql_source( ), ], ) -async def test_extract_graphql_data_items(object_list, data, expected_result): +async def test_extract_graphql_data_items(object_list, data, expected_result) -> None: actual_response = [] async with create_graphql_source() as source: source.graphql_client.graphql_object_to_id_map = object_list @@ -96,7 +96,7 @@ async def test_extract_graphql_data_items(object_list, data, expected_result): @pytest.mark.asyncio -async def test_get(): +async def test_get() -> None: async with create_graphql_source() as source: source.graphql_client.session.get = Mock( return_value=get_json_mock( @@ -109,7 +109,7 @@ async def test_get(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_errors(): +async def test_get_with_errors() -> None: async with create_graphql_source() as source: source.graphql_client.session.get = Mock( return_value=get_json_mock( @@ -124,7 +124,7 @@ async def test_get_with_errors(): @pytest.mark.asyncio -async def test_post(): +async def test_post() -> None: async with create_graphql_source() as source: source.graphql_client.session.post = Mock( return_value=get_json_mock( @@ -137,7 +137,7 @@ async def test_post(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_post_with_errors(): +async def test_post_with_errors() -> None: async with create_graphql_source() as source: source.graphql_client.session.post = Mock( return_value=get_json_mock( @@ -153,7 +153,7 @@ async def test_post_with_errors(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_make_request_with_unauthorized(): +async def test_make_request_with_unauthorized() -> None: async with create_graphql_source() as 
source: source.graphql_client.session.post = Mock( side_effect=ClientResponseError( @@ -170,7 +170,7 @@ async def test_make_request_with_unauthorized(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_make_request_with_429_exception(): +async def test_make_request_with_429_exception() -> None: async with create_graphql_source() as source: source.graphql_client.session.post = Mock( side_effect=ClientResponseError( @@ -186,7 +186,7 @@ async def test_make_request_with_429_exception(): @pytest.mark.asyncio -async def test_validate_config_with_invalid_url(): +async def test_validate_config_with_invalid_url() -> None: async with create_graphql_source() as source: source.graphql_client.url = "dummy_url" with pytest.raises(ConfigurableFieldValueError): @@ -194,7 +194,7 @@ async def test_validate_config_with_invalid_url(): @pytest.mark.asyncio -async def test_validate_config_with_mutation(): +async def test_validate_config_with_mutation() -> None: async with create_graphql_source( graphql_query="""mutation { addCategory(id: 6, name: "Green Fruits", products: [8, 2, 3]) { @@ -210,21 +210,21 @@ async def test_validate_config_with_mutation(): @pytest.mark.asyncio -async def test_validate_config_with_non_json_headers(): +async def test_validate_config_with_non_json_headers() -> None: async with create_graphql_source(headers="Invalid Headers") as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() @pytest.mark.asyncio -async def test_validate_config_with_non_json_variables(): +async def test_validate_config_with_non_json_variables() -> None: async with create_graphql_source(graphql_variables="Invalid Variables") as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() @pytest.mark.asyncio -async def test_ping(): +async def test_ping() -> None: async with create_graphql_source() as source: source.graphql_client.post = AsyncMock() await source.ping() @@ -232,7 +232,7 @@ async def test_ping(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_ping_negative(): +async def test_ping_negative() -> None: async with create_graphql_source() as source: source.graphql_client.post = AsyncMock(side_effect=Exception()) with pytest.raises(Exception): @@ -240,7 +240,7 @@ async def test_ping_negative(): @pytest.mark.asyncio -async def test_fetch_data(): +async def test_fetch_data() -> None: expected_response = [{"id": "1", "name": {"firstName": "xyz"}, "_id": "1"}] actual_response = [] async with create_graphql_source() as source: @@ -257,7 +257,7 @@ async def test_fetch_data(): @pytest.mark.asyncio -async def test_fetch_data_with_pagination(): +async def test_fetch_data_with_pagination() -> None: expected_response = [ { "id": "1", @@ -305,7 +305,7 @@ async def test_fetch_data_with_pagination(): @pytest.mark.asyncio -async def test_fetch_data_without_pageinfo(): +async def test_fetch_data_without_pageinfo() -> None: async with create_graphql_source() as source: source.graphql_client.pagination_model = "cursor_pagination" source.graphql_client.graphql_object_to_id_map = {"users": id} @@ -320,7 +320,7 @@ async def test_fetch_data_without_pageinfo(): @pytest.mark.asyncio @freeze_time("2024-01-24T04:07:19") -async def test_get_docs(): +async def test_get_docs() -> None: expected_response = [ { "name": "xyz", @@ -358,7 +358,7 @@ async def test_get_docs(): @pytest.mark.asyncio @freeze_time("2024-01-24T04:07:19") -async def 
test_get_docs_with_dict_id(): +async def test_get_docs_with_dict_id() -> None: async with create_graphql_source() as source: source.fetch_data = AsyncIterator( [ @@ -371,7 +371,7 @@ async def test_get_docs_with_dict_id(): @pytest.mark.asyncio -async def test_extract_graphql_data_items_with_invalid_key(): +async def test_extract_graphql_data_items_with_invalid_key() -> None: async with create_graphql_source() as source: source.graphql_client.graphql_object_to_id_map = {"user": "id"} data = {"users": {"namexyzid": "123"}} @@ -381,7 +381,7 @@ async def test_extract_graphql_data_items_with_invalid_key(): @pytest.mark.asyncio -async def test_extract_pagination_info_with_invalid_key(): +async def test_extract_pagination_info_with_invalid_key() -> None: async with create_graphql_source() as source: source.pagination_key = ["users_data.user"] data = {"users_data": {"users": {"namexyzid": "123"}}} @@ -391,7 +391,7 @@ async def test_extract_pagination_info_with_invalid_key(): @pytest.mark.asyncio -async def test_is_query_with_mutation_query(): +async def test_is_query_with_mutation_query() -> None: async with create_graphql_source() as source: ast = parse("mutation Login($email: String!){login(email: $email) { token }}") response = source.is_query(ast) @@ -399,7 +399,7 @@ async def test_is_query_with_mutation_query(): @pytest.mark.asyncio -async def test_is_query_with_invalid_query(): +async def test_is_query_with_invalid_query() -> None: async with create_graphql_source() as source: source.graphql_client.graphql_query = "invalid_query {user {}}" with pytest.raises(ConfigurableFieldValueError): @@ -407,7 +407,7 @@ async def test_is_query_with_invalid_query(): @pytest.mark.asyncio -async def test_validate_config_with_invalid_objects(): +async def test_validate_config_with_invalid_objects() -> None: async with create_graphql_source() as source: source.graphql_client.graphql_query = ( "query {organization {repository { issues {name}}}}" @@ -420,7 +420,7 @@ async def test_validate_config_with_invalid_objects(): @pytest.mark.asyncio -async def test_validate_config_with_invalid_pagination_key(): +async def test_validate_config_with_invalid_pagination_key() -> None: async with create_graphql_source( graphql_object_to_id_map='{"organization.repository.issues": "id"}' ) as source: @@ -434,7 +434,7 @@ async def test_validate_config_with_invalid_pagination_key(): @pytest.mark.asyncio -async def test_validate_config_with_missing_config_field(): +async def test_validate_config_with_missing_config_field() -> None: async with create_graphql_source( graphql_object_to_id_map='{"organization.repository.issues": "id"}' ) as source: @@ -446,7 +446,7 @@ async def test_validate_config_with_missing_config_field(): @pytest.mark.asyncio -async def test_validate_config_with_invalid_json(): +async def test_validate_config_with_invalid_json() -> None: async with create_graphql_source( graphql_object_to_id_map='{"organization.repository.issues": "id"' ) as source: diff --git a/tests/sources/test_jira.py b/tests/sources/test_jira.py index 66121d119..18ec8d73d 100644 --- a/tests/sources/test_jira.py +++ b/tests/sources/test_jira.py @@ -8,6 +8,7 @@ import ssl from contextlib import asynccontextmanager from copy import copy +from typing import Dict, Optional from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -201,7 +202,7 @@ "_timestamp": "2023-02-01T01:02:20", "_attachment": "IyBUaGlzIGlzIHRoZSBkdW1teSBmaWxl", } -EXPECTED_CONTENT_EXTRACTED = { +EXPECTED_CONTENT_EXTRACTED: Dict[str, str] = { 
"_id": "TP-1-10001", "_timestamp": "2023-02-01T01:02:20", "body": RESPONSE_CONTENT, @@ -331,9 +332,9 @@ @asynccontextmanager async def create_jira_source( - use_text_extraction_service=False, - data_source="jira_cloud", - jira_url="http://127.0.0.1:8080", + use_text_extraction_service: bool = False, + data_source: str = "jira_cloud", + jira_url: str = "http://127.0.0.1:8080", ): async with create_source( JiraDataSource, @@ -353,13 +354,13 @@ async def create_jira_source( class MockSsl: """This class contains methods which returns dummy ssl context""" - def load_verify_locations(self, cadata): + def load_verify_locations(self, cadata) -> None: """This method verify locations""" pass class JSONAsyncMock(AsyncMock): - def __init__(self, json, *args, **kwargs): + def __init__(self, json, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json @@ -368,24 +369,24 @@ async def json(self): class StreamReaderAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.content = StreamReader -def get_json_mock(mock_response): +def get_json_mock(mock_response) -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock(return_value=JSONAsyncMock(mock_response)) return async_mock -def get_stream_reader(): +def get_stream_reader() -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock(return_value=StreamReaderAsyncMock()) return async_mock -def side_effect_function(url, ssl): +def side_effect_function(url, ssl) -> Optional[AsyncMock]: """Dynamically changing return values for API calls Args: url, ssl: Params required for get call @@ -459,7 +460,7 @@ def side_effect_function(url, ssl): ], ) @pytest.mark.asyncio -async def test_validate_configuration_for_empty_fields(field, data_source): +async def test_validate_configuration_for_empty_fields(field, data_source) -> None: async with create_jira_source() as source: source.jira_client.configuration.get_field("data_source").value = data_source source.jira_client.configuration.get_field(field).value = "" @@ -470,7 +471,7 @@ async def test_validate_configuration_for_empty_fields(field, data_source): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_api_call_negative(): +async def test_api_call_negative() -> None: """Tests the api_call function while getting an exception.""" async with create_jira_source() as source: @@ -491,7 +492,7 @@ async def test_api_call_negative(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_api_call_when_server_is_down(): +async def test_api_call_when_server_is_down() -> None: """Tests the api_call function while server gets disconnected.""" async with create_jira_source() as source: @@ -508,7 +509,7 @@ async def test_api_call_when_server_is_down(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_api_call_with_empty_response(): +async def test_api_call_with_empty_response() -> None: """Tests the api_call function when response is empty.""" async with create_jira_source() as source: @@ -524,7 +525,7 @@ async def test_api_call_with_empty_response(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_429_status(): +async def test_get_with_429_status() -> None: initial_response = ClientResponseError(None, None) 
initial_response.status = 429 initial_response.message = "rate-limited" @@ -549,7 +550,7 @@ async def test_get_with_429_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_429_status_without_retry_after_header(): +async def test_get_with_429_status_without_retry_after_header() -> None: initial_response = ClientResponseError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -573,7 +574,7 @@ async def test_get_with_429_status_without_retry_after_header(): @pytest.mark.asyncio -async def test_get_with_404_status(): +async def test_get_with_404_status() -> None: error = ClientResponseError(None, None) error.status = 404 @@ -591,7 +592,7 @@ async def test_get_with_404_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_500_status(): +async def test_get_with_500_status() -> None: error = ClientResponseError(None, None) error.status = 500 @@ -608,7 +609,7 @@ async def test_get_with_500_status(): @pytest.mark.asyncio -async def test_ping_to_custom_path_server(): +async def test_ping_to_custom_path_server() -> None: expected_url = "http://127.0.0.1:8080/test/" async with create_jira_source(jira_url=expected_url) as source: @@ -631,7 +632,7 @@ async def test_ping_to_custom_path_server(): @patch("aiohttp.ClientSession.get") async def test_ping_with_ssl( mock_get, -): +) -> None: """Test ping method of JiraDataSource class with SSL""" mock_get.return_value.__aenter__.return_value.status = 200 @@ -650,7 +651,7 @@ async def test_ping_with_ssl( @pytest.mark.asyncio @patch("aiohttp.ClientSession.get") -async def test_ping_for_failed_connection_exception(client_session_get): +async def test_ping_for_failed_connection_exception(client_session_get) -> None: """Tests the ping functionality when connection can not be established to Jira.""" async with create_jira_source() as source: @@ -662,7 +663,7 @@ async def test_ping_for_failed_connection_exception(client_session_get): @pytest.mark.asyncio -async def test_validate_config_for_ssl_enabled_when_ssl_ca_empty_raises_error(): +async def test_validate_config_for_ssl_enabled_when_ssl_ca_empty_raises_error() -> None: """This function test _validate_configuration when certification is empty when ssl is enabled""" async with create_jira_source() as source: source.configuration.get_field("ssl_enabled").value = True @@ -671,7 +672,7 @@ async def test_validate_config_for_ssl_enabled_when_ssl_ca_empty_raises_error(): @pytest.mark.asyncio -async def test_validate_config_with_invalid_concurrent_downloads(): +async def test_validate_config_with_invalid_concurrent_downloads() -> None: """Test validate_config method of BaseDataSource class with invalid concurrent downloads""" async with create_jira_source() as source: @@ -684,7 +685,7 @@ async def test_validate_config_with_invalid_concurrent_downloads(): @pytest.mark.asyncio -async def test_tweak_bulk_options(): +async def test_tweak_bulk_options() -> None: """Test tweak_bulk_options method of BaseDataSource class""" async with create_jira_source() as source: @@ -700,7 +701,7 @@ async def test_tweak_bulk_options(): @pytest.mark.parametrize( "data_source_type", [JIRA_CLOUD, JIRA_DATA_CENTER, JIRA_SERVER] ) -async def test_get_session(data_source_type): +async def test_get_session(data_source_type) -> None: async with create_jira_source(data_source=data_source_type) as source: try: source.jira_client._get_session() @@ -711,7 
+712,7 @@ async def test_get_session(data_source_type): @pytest.mark.asyncio -async def test_get_session_multiple_calls_return_same_instance(): +async def test_get_session_multiple_calls_return_same_instance() -> None: async with create_jira_source() as source: first_instance = source.jira_client._get_session() second_instance = source.jira_client._get_session() @@ -719,14 +720,14 @@ async def test_get_session_multiple_calls_return_same_instance(): @pytest.mark.asyncio -async def test_get_session_raise_on_invalid_data_source_type(): +async def test_get_session_raise_on_invalid_data_source_type() -> None: async with create_jira_source(data_source="invalid") as source: with pytest.raises(InvalidJiraDataSourceTypeError): source.jira_client._get_session() @pytest.mark.asyncio -async def test_close_with_client_session(): +async def test_close_with_client_session() -> None: async with create_jira_source() as source: source.jira_client._get_session() @@ -734,7 +735,7 @@ async def test_close_with_client_session(): @pytest.mark.asyncio -async def test_close_without_client_session(): +async def test_close_without_client_session() -> None: """Test close method when the session does not exist""" async with create_jira_source() as source: @@ -742,7 +743,7 @@ async def test_close_without_client_session(): @pytest.mark.asyncio -async def test_get_timezone(): +async def test_get_timezone() -> None: async with create_jira_source() as source: with patch.object( aiohttp.ClientSession, "get", side_effect=side_effect_function @@ -753,7 +754,7 @@ async def test_get_timezone(): @freeze_time("2023-01-24T04:07:19") @pytest.mark.asyncio -async def test_get_projects(): +async def test_get_projects() -> None: """Test _get_projects method""" async with create_jira_source() as source: @@ -772,7 +773,7 @@ async def test_get_projects(): @freeze_time("2023-01-24T04:07:19") @pytest.mark.asyncio -async def test_get_projects_for_specific_project(): +async def test_get_projects_for_specific_project() -> None: """Test _get_projects method for specific project key""" async with create_jira_source() as source: @@ -795,7 +796,7 @@ async def test_get_projects_for_specific_project(): @pytest.mark.asyncio -async def test_verify_projects(): +async def test_verify_projects() -> None: """Test verify_projects method""" async with create_jira_source() as source: source.jira_client.projects = ["TP", "DP"] @@ -805,7 +806,7 @@ async def test_verify_projects(): @pytest.mark.asyncio -async def test_verify_projects_with_unavailable_project_keys(): +async def test_verify_projects_with_unavailable_project_keys() -> None: async with create_jira_source() as source: source.jira_client.projects = ["TP", "AP"] @@ -815,7 +816,7 @@ async def test_verify_projects_with_unavailable_project_keys(): @pytest.mark.asyncio -async def test_put_issue(): +async def test_put_issue() -> None: """Test _put_issue method""" async with create_jira_source() as source: @@ -827,7 +828,7 @@ async def test_put_issue(): @pytest.mark.asyncio -async def test_get_custom_fields(): +async def test_get_custom_fields() -> None: async with create_jira_source() as source: with patch("aiohttp.ClientSession.get", side_effect=side_effect_function): custom_fields = await anext(source.jira_client.get_jira_fields()) @@ -835,7 +836,7 @@ async def test_get_custom_fields(): @pytest.mark.asyncio -async def test_put_attachment_positive(): +async def test_put_attachment_positive() -> None: """Test _put_attachment method""" async with create_jira_source() as source: @@ -855,7 +856,7 @@ async def 
test_put_attachment_positive(): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: """Tests the get content method.""" async with create_jira_source() as source: @@ -873,7 +874,7 @@ async def test_get_content(): @pytest.mark.asyncio -async def test_get_content_with_text_extraction_enabled_adds_body(): +async def test_get_content_with_text_extraction_enabled_adds_body() -> None: """Tests the get content method.""" with ( patch( @@ -902,7 +903,7 @@ async def test_get_content_with_text_extraction_enabled_adds_body(): @pytest.mark.asyncio -async def test_get_content_with_upper_extension(): +async def test_get_content_with_upper_extension() -> None: """Tests the get content method.""" async with create_jira_source() as source: @@ -923,7 +924,7 @@ async def test_get_content_with_upper_extension(): @pytest.mark.asyncio -async def test_get_content_when_filesize_is_large(): +async def test_get_content_when_filesize_is_large() -> None: """Tests the get content method for file size greater than max limit.""" async with create_jira_source() as source: @@ -944,7 +945,7 @@ async def test_get_content_when_filesize_is_large(): @pytest.mark.asyncio -async def test_get_content_for_unsupported_filetype(): +async def test_get_content_for_unsupported_filetype() -> None: """Tests the get content method for file type is not supported.""" async with create_jira_source() as source: @@ -965,7 +966,7 @@ async def test_get_content_for_unsupported_filetype(): @pytest.mark.asyncio -async def test_get_consumer(): +async def test_get_consumer() -> None: """Test _get_consumer method""" async with create_jira_source() as source: source.tasks = 3 @@ -985,7 +986,7 @@ async def test_get_consumer(): @freeze_time("2023-01-24T04:07:19") @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: """Test _get_docs method""" async with create_jira_source() as source: source.jira_client.projects = ["*"] @@ -1026,7 +1027,7 @@ async def test_get_docs(): ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering, expected_docs): +async def test_get_docs_with_advanced_rules(filtering, expected_docs) -> None: async with create_jira_source() as source: source.get_content = Mock(return_value=EXPECTED_ATTACHMENT_TYPE_BUG) @@ -1040,7 +1041,7 @@ async def test_get_docs_with_advanced_rules(filtering, expected_docs): @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_jira_source() as source: source._dls_enabled = MagicMock(return_value=False) @@ -1053,7 +1054,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19") -async def test_get_access_control_dls_enabled(): +async def test_get_access_control_dls_enabled() -> None: mock_users = [ { # Indexable: The user is active and atlassian user. 
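# --- Illustrative sketch (editor's note, not part of the patch) ---
# The hunks above only add explicit "-> None" return annotations to async
# pytest tests; the test bodies are unchanged. A minimal sketch of that
# pattern, assuming the create_jira_source helper and
# InvalidJiraDataSourceTypeError from the Jira hunks above are in scope:
import pytest


@pytest.mark.asyncio
async def test_get_session_raises_on_invalid_type() -> None:
    # The "-> None" annotation is the only change these hunks introduce.
    async with create_jira_source(data_source="invalid") as source:
        with pytest.raises(InvalidJiraDataSourceTypeError):
            source.jira_client._get_session()
# --- end of editor's sketch ---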
@@ -1158,7 +1159,7 @@ async def test_get_access_control_dls_enabled(): @freeze_time("2023-01-24T04:07:19") @pytest.mark.asyncio -async def test_get_docs_with_dls_enabled(): +async def test_get_docs_with_dls_enabled() -> None: expected_projects_with_access_controls = [ project | { diff --git a/tests/sources/test_microsoft_teams.py b/tests/sources/test_microsoft_teams.py index b9f7a8625..6d37ee95f 100644 --- a/tests/sources/test_microsoft_teams.py +++ b/tests/sources/test_microsoft_teams.py @@ -5,6 +5,7 @@ # import base64 from io import BytesIO +from typing import Dict, List, Optional, Union from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -221,7 +222,43 @@ }, ] -EVENTS = [ +EVENTS: List[ + Union[ + Dict[ + str, + Union[ + None, + Dict[str, Dict[str, str]], + Dict[str, str], + Dict[ + str, + Union[ + Dict[str, Union[List[str], int, str]], + Dict[str, Union[int, str]], + ], + ], + List[Dict[str, Dict[str, str]]], + List[Dict[str, str]], + List[str], + int, + str, + ], + ], + Dict[ + str, + Union[ + None, + Dict[str, Dict[str, str]], + Dict[str, str], + List[Dict[str, Dict[str, str]]], + List[Dict[str, str]], + List[str], + int, + str, + ], + ], + ] +] = [ { "id": "AAMkADkyYTNmOTcw", "createdDateTime": "2023-08-10T05:47:41.5466652Z", @@ -378,7 +415,7 @@ }, ] -TEAMS = [ +TEAMS: List[Dict[str, Optional[str]]] = [ { "id": "25ab782d", "createdDateTime": None, @@ -390,7 +427,7 @@ } ] -CHANNELS = [ +CHANNELS: List[Dict[str, Optional[str]]] = [ { "id": "19:36b3f1125", "createdDateTime": "2023-08-08T08:23:13.984Z", @@ -508,7 +545,7 @@ class StubAPIToken: - async def get_with_username_password(self): + async def get_with_username_password(self) -> str: return "something" @@ -541,7 +578,7 @@ async def client(): class ClientSession: """Mock Client Session Class""" - async def close(self): + async def close(self) -> None: """Close method of Mock Client Session Class""" pass @@ -554,7 +591,7 @@ async def create_fake_coroutine(item): return item -def test_get_configuration(): +def test_get_configuration() -> None: config = DataSourceConfiguration( config=MicrosoftTeamsDataSource.get_default_configuration() ) @@ -577,14 +614,14 @@ def test_get_configuration(): ) async def test_validate_configuration_with_invalid_fields_raises_error( extras, -): +) -> None: async with create_source(MicrosoftTeamsDataSource, **extras) as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: async with create_source(MicrosoftTeamsDataSource) as source: DUMMY_RESPONSE = {} source.client.fetch = Mock( @@ -595,7 +632,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio @patch("aiohttp.ClientSession.get") -async def test_ping_for_failed_connection_exception(mock_get): +async def test_ping_for_failed_connection_exception(mock_get) -> None: async with create_source(MicrosoftTeamsDataSource) as source: with patch.object( MicrosoftTeamsClient, @@ -607,7 +644,7 @@ async def test_ping_for_failed_connection_exception(mock_get): @pytest.mark.asyncio -async def test_get_docs_for_user_chats(): +async def test_get_docs_for_user_chats() -> None: async with create_source(MicrosoftTeamsDataSource) as source: source.client.get_user_chats = Mock(return_value=AsyncIterator([USER_CHATS])) source.client.get_user_chat_messages = Mock( @@ -629,7 +666,7 @@ async def test_get_docs_for_user_chats(): @pytest.mark.asyncio -async def 
test_get_docs_for_events(): +async def test_get_docs_for_events() -> None: async with create_source(MicrosoftTeamsDataSource) as source: source.client.get_user_chats = Mock(return_value=AsyncIterator([])) source.client.get_teams = Mock(return_value=AsyncIterator([])) @@ -646,7 +683,7 @@ async def test_get_docs_for_events(): @pytest.mark.asyncio -async def test_get_docs_for_teams(): +async def test_get_docs_for_teams() -> None: async with create_source(MicrosoftTeamsDataSource) as source: source.client.get_user_chats = Mock(return_value=AsyncIterator([])) source.client.users = Mock(return_value=AsyncIterator([])) @@ -682,7 +719,7 @@ async def test_get_docs_for_teams(): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: message = b"This is content of attachment" async def download_func(url, async_buffer): @@ -706,7 +743,7 @@ async def download_func(url, async_buffer): ], ) @pytest.mark.asyncio -async def test_get_content_negative(attachment, download_url): +async def test_get_content_negative(attachment, download_url) -> None: message = b"This is content of attachment" async def download_func(url, async_buffer): @@ -719,7 +756,7 @@ async def download_func(url, async_buffer): @pytest.mark.asyncio -async def test_get_channel_file(): +async def test_get_channel_file() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -732,7 +769,7 @@ async def test_get_channel_file(): @pytest.mark.asyncio -async def test_get_channel_messages(patch_scroll, client): +async def test_get_channel_messages(patch_scroll, client) -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -745,7 +782,7 @@ async def test_get_channel_messages(patch_scroll, client): @pytest.mark.asyncio -async def test_format_user_chat_messages(patch_scroll, client): +async def test_format_user_chat_messages(patch_scroll, client) -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -756,7 +793,7 @@ async def test_format_user_chat_messages(patch_scroll, client): @pytest.mark.asyncio -async def test_format_user_chat_messages_for_members(patch_scroll, client): +async def test_format_user_chat_messages_for_members(patch_scroll, client) -> None: user_chat_request = { "id": "19:2ea91886", "topic": None, @@ -793,7 +830,7 @@ async def test_format_user_chat_messages_for_members(patch_scroll, client): @pytest.mark.asyncio -async def test_get_channel_messages_for_base_class(patch_scroll, client): +async def test_get_channel_messages_for_base_class(patch_scroll, client) -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -803,7 +840,7 @@ async def test_get_channel_messages_for_base_class(patch_scroll, client): @pytest.mark.asyncio -async def test_get_messages(patch_scroll, client): +async def test_get_messages(patch_scroll, client) -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -816,7 +853,7 @@ async def test_get_messages(patch_scroll, client): @pytest.mark.asyncio -async def test_get_channel_tabs(): +async def test_get_channel_tabs() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -827,7 +864,7 @@ async def test_get_channel_tabs(): @pytest.mark.asyncio -async def test_get_team_channels(): +async def test_get_team_channels() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -838,7 +875,7 @@ async def test_get_team_channels(): @pytest.mark.asyncio -async def test_get_teams(): +async def test_get_teams() -> None: async with 
create_source( MicrosoftTeamsDataSource, ) as source: @@ -849,7 +886,7 @@ async def test_get_teams(): @pytest.mark.asyncio -async def test_get_calendars(): +async def test_get_calendars() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -864,7 +901,7 @@ async def test_get_calendars(): @pytest.mark.asyncio -async def test_get_user_chat_tabs(): +async def test_get_user_chat_tabs() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -875,7 +912,7 @@ async def test_get_user_chat_tabs(): @pytest.mark.asyncio -async def test_get_user_chat_messages(): +async def test_get_user_chat_messages() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -885,7 +922,7 @@ async def test_get_user_chat_messages(): @pytest.mark.asyncio -async def test_get_user_chats(): +async def test_get_user_chats() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -895,7 +932,7 @@ async def test_get_user_chats(): @pytest.mark.asyncio -async def test_users(): +async def test_users() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -905,7 +942,7 @@ async def test_users(): @pytest.mark.asyncio -async def test_set_internal_logger(): +async def test_set_internal_logger() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: @@ -914,7 +951,7 @@ async def test_set_internal_logger(): @pytest.mark.asyncio -async def test_scroll(microsoft_client, mock_responses): +async def test_scroll(microsoft_client, mock_responses) -> None: url = "http://localhost:1234/url" first_page = ["1", "2", "3"] @@ -941,7 +978,7 @@ async def test_scroll(microsoft_client, mock_responses): @pytest.mark.asyncio -async def test_pipe(microsoft_client, mock_responses): +async def test_pipe(microsoft_client, mock_responses) -> None: class AsyncStream: def __init__(self): self.stream = BytesIO() @@ -970,7 +1007,7 @@ def read(self): async def test_call_api_with_403( microsoft_client, mock_responses, -): +) -> None: url = "http://localhost:1234/download-some-sample-file" unauthorized_error = ClientResponseError(None, None) @@ -996,7 +1033,7 @@ async def test_call_api_with_403( async def test_call_api_with_404( microsoft_client, mock_responses, -): +) -> None: url = "http://localhost:1234/download-some-sample-file" not_found_error = ClientResponseError(None, None) @@ -1023,7 +1060,7 @@ async def test_call_api_with_os_error( microsoft_client, mock_responses, patch_sleep, -): +) -> None: url = "http://localhost:1234/download-some-sample-file" not_found_error = ClientOSError() @@ -1049,7 +1086,7 @@ async def test_call_api_with_500( microsoft_client, mock_responses, patch_sleep, -): +) -> None: url = "http://localhost:1234/download-some-sample-file" not_found_error = ClientResponseError( @@ -1074,7 +1111,7 @@ async def test_call_api_with_500( class JSONAsyncMock(AsyncMock): - def __init__(self, json, *args, **kwargs): + def __init__(self, json, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json @@ -1087,7 +1124,7 @@ async def json(self): async def test_call_api_with_429( microsoft_client, mock_responses, -): +) -> None: initial_response = ClientResponseError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -1115,7 +1152,7 @@ async def test_call_api_with_429( async def test_call_api_with_429_with_retry_after( microsoft_client, mock_responses, -): +) -> None: initial_response = ClientResponseError(None, None) initial_response.status = 429 
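# --- Illustrative sketch (editor's note, not part of the patch) ---
# The surrounding 403/404/429 hunks build aiohttp errors by hand:
# ClientResponseError(None, None) followed by .status / .message assignments.
# A hypothetical helper capturing that same pattern could look like this:
from typing import Optional

from aiohttp import ClientResponseError


def make_response_error(
    status: int, message: str = "", headers: Optional[dict] = None
) -> ClientResponseError:
    # request_info and history are unused by these tests, hence None, None.
    error = ClientResponseError(None, None)
    error.status = status
    error.message = message
    if headers is not None:
        error.headers = headers
    return error
# --- end of editor's sketch ---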
initial_response.message = "rate-limited" @@ -1142,7 +1179,7 @@ async def test_call_api_with_429_with_retry_after( @pytest.mark.asyncio async def test_call_api_with_unhandled_status( microsoft_client, mock_responses, patch_sleep -): +) -> None: url = "http://localhost:1234/download-some-sample-file" error_message = "Something went wrong" @@ -1167,7 +1204,7 @@ async def test_call_api_with_unhandled_status( @pytest.mark.asyncio -async def test_get_for_client(): +async def test_get_for_client() -> None: bearer = "hello" source = GraphAPIToken( "tenant_id", "client_id", "client_secret", "username", "password" @@ -1179,7 +1216,7 @@ async def test_get_for_client(): @pytest.mark.asyncio -async def test_get_for_username_password(): +async def test_get_for_username_password() -> None: bearer = "hello" source = GraphAPIToken( "tenant_id", "client_id", "client_secret", "username", "password" @@ -1191,7 +1228,7 @@ async def test_get_for_username_password(): @pytest.mark.asyncio -async def test_fetch(microsoft_client, mock_responses): +async def test_fetch(microsoft_client, mock_responses) -> None: with patch.object( GraphAPIToken, "_fetch_token", @@ -1208,7 +1245,7 @@ async def test_fetch(microsoft_client, mock_responses): @pytest.mark.asyncio -async def test_get_user_drive_root_children(): +async def test_get_user_drive_root_children() -> None: async with create_source( MicrosoftTeamsDataSource, ) as source: diff --git a/tests/sources/test_mongo.py b/tests/sources/test_mongo.py index ffd179b56..99d1ce6bc 100644 --- a/tests/sources/test_mongo.py +++ b/tests/sources/test_mongo.py @@ -27,12 +27,12 @@ @asynccontextmanager async def create_mongo_source( - host="mongodb://127.0.0.1:27021", - database=DEFAULT_DATABASE, - collection=DEFAULT_COLLECTION, - ssl_enabled=False, - ssl_ca="", - tls_insecure=False, + host: str = "mongodb://127.0.0.1:27021", + database: str = DEFAULT_DATABASE, + collection: str = DEFAULT_COLLECTION, + ssl_enabled: bool = False, + ssl_ca: str = "", + tls_insecure: bool = False, ): async with create_source( MongoDataSource, @@ -137,12 +137,12 @@ async def create_mongo_source( ), ], ) -async def test_advanced_rules_validator(advanced_rules, is_valid): +async def test_advanced_rules_validator(advanced_rules, is_valid) -> None: validation_result = await MongoAdvancedRulesValidator().validate(advanced_rules) assert validation_result.is_valid == is_valid -def build_resp(): +def build_resp() -> MagicMock: doc1 = {"id": "one", "tuple": (1, 2, 3), "date": datetime.now()} doc2 = {"id": "two", "dict": {"a": "b"}, "decimal": Decimal128("0.0005")} @@ -165,7 +165,7 @@ def __init__(self): @mock.patch( "pymongo.mongo_client.MongoClient._run_operation", lambda *xi, **kw: build_resp() ) -async def test_get_docs(*args): +async def test_get_docs(*args) -> None: async with create_mongo_source() as source: num = 0 async for doc, _ in source.get_docs(): @@ -182,7 +182,7 @@ async def test_get_docs(*args): @mock.patch( "pymongo.mongo_client.MongoClient._run_operation", lambda *xi, **kw: build_resp() ) -async def test_ping_when_called_then_does_not_raise(*args): +async def test_ping_when_called_then_does_not_raise(*args) -> None: admin_mock = Mock() command_mock = AsyncMock() admin_mock.command = command_mock @@ -204,7 +204,7 @@ def admin(self): @pytest.mark.asyncio -async def test_mongo_data_source_get_docs_when_advanced_rules_find_present(): +async def test_mongo_data_source_get_docs_when_advanced_rules_find_present() -> None: async with create_mongo_source() as source: filtering = Filter( { @@ -245,7 +245,9 
@@ async def test_mongo_data_source_get_docs_when_advanced_rules_find_present(): @pytest.mark.asyncio -async def test_mongo_data_source_get_docs_when_advanced_rules_aggregate_present(): +async def test_mongo_data_source_get_docs_when_advanced_rules_aggregate_present() -> ( + None +): async with create_mongo_source() as source: filtering = Filter( { @@ -298,7 +300,7 @@ def future_with_result(result): ) async def test_validate_config_when_database_name_invalid_then_raises_exception( patch_validate_collection, -): +) -> None: server_database_names = ["hello", "world"] configured_database_name = "something" @@ -326,7 +328,7 @@ async def test_validate_config_when_database_name_invalid_then_raises_exception( ) async def test_validate_config_when_collection_name_invalid_then_raises_exception( patch_validate_collection, -): +) -> None: server_database_names = ["hello"] server_collection_names = ["first", "second"] configured_database_name = "hello" @@ -370,7 +372,7 @@ async def test_validate_config_when_collection_name_invalid_then_raises_exceptio ) async def test_validate_config_when_collection_access_unauthorized( patch_validate_collection, patch_list_database_names, patch_list_collection_names -): +) -> None: configured_database_name = "hello" configured_collection_name = "second" @@ -395,7 +397,7 @@ async def test_validate_config_when_collection_access_unauthorized( ) async def test_validate_config_when_collection_access_unauthorized_and_no_admin_access( patch_validate_collection, patch_list_database_names -): +) -> None: configured_database_name = "hello" configured_collection_name = "second" @@ -416,7 +418,7 @@ async def test_validate_config_when_collection_access_unauthorized_and_no_admin_ ) async def test_validate_config_when_configuration_valid_then_does_not_raise( patch_validate_connection, -): +) -> None: configured_database_name = "hello" configured_collection_name = "second" @@ -459,7 +461,7 @@ async def test_validate_config_when_configuration_valid_then_does_not_raise( ), ], ) -async def test_serialize(raw, output): +async def test_serialize(raw, output) -> None: async with create_mongo_source() as source: assert source.serialize(raw) == output @@ -487,7 +489,7 @@ async def test_serialize(raw, output): ) async def test_ssl_successful_connection( patch_validate_collection, mock_ssl, certificate_value, tls_insecure -): +) -> None: mock_ssl.return_value = True async with create_mongo_source( ssl_enabled=True, ssl_ca=certificate_value, tls_insecure=tls_insecure @@ -513,7 +515,7 @@ async def test_ssl_successful_connection( ("mongodb://127.0.0.1:27021/?ssl=false", True), ], ) -async def test_get_client_when_pass_conflicting_values(host_value, ssl_value): +async def test_get_client_when_pass_conflicting_values(host_value, ssl_value) -> None: async with create_mongo_source(host=host_value, ssl_enabled=ssl_value) as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() diff --git a/tests/sources/test_mssql.py b/tests/sources/test_mssql.py index ae5b2deb2..6436b9ca6 100644 --- a/tests/sources/test_mssql.py +++ b/tests/sources/test_mssql.py @@ -32,7 +32,7 @@ class MockEngine: """This Class create mock engine for mssql dialect""" - def connect(self): + def connect(self) -> ConnectionSync: """Make a connection Returns: @@ -42,7 +42,7 @@ def connect(self): @pytest.mark.asyncio -async def test_ping(): +async def test_ping() -> None: async with create_source(MSSQLDataSource) as source: source.engine = MockEngine() with patch.object( @@ -53,7 +53,7 @@ async def 
test_ping(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_ping_negative(): +async def test_ping_negative() -> None: with pytest.raises(Exception): async with create_source(MSSQLDataSource) as source: with patch.object(Engine, "connect", side_effect=Exception()): @@ -62,7 +62,7 @@ async def test_ping_negative(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_documents_from_table_negative(): +async def test_fetch_documents_from_table_negative() -> None: async with create_source(MSSQLDataSource) as source: with patch.object( source.mssql_client, @@ -75,7 +75,7 @@ async def test_fetch_documents_from_table_negative(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_documents_from_query_negative(): +async def test_fetch_documents_from_query_negative() -> None: async with create_source(MSSQLDataSource) as source: with patch.object( source, @@ -89,7 +89,7 @@ async def test_fetch_documents_from_query_negative(): @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: # Setup async with create_source( MSSQLDataSource, database="xe", tables="*", schema="dbo" @@ -210,7 +210,9 @@ async def test_get_docs(): ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_source( MSSQLDataSource, database="xe", tables="*", schema="dbo" ) as source: @@ -352,7 +354,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res @pytest.mark.asyncio async def test_advanced_rules_validation_when_id_in_source_available( advanced_rules, id_in_source, expected_validation_result -): +) -> None: async with create_source( MSSQLDataSource, database="xe", tables="*", schema="dbo" ) as source: @@ -461,7 +463,7 @@ async def test_advanced_rules_validation_when_id_in_source_available( ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering, expected_response): +async def test_get_docs_with_advanced_rules(filtering, expected_response) -> None: async with create_source( MSSQLDataSource, database="xe", tables="*", schema="dbo" ) as source: @@ -477,7 +479,7 @@ async def test_get_docs_with_advanced_rules(filtering, expected_response): @pytest.mark.asyncio -async def test_create_pem_file(): +async def test_create_pem_file() -> None: async with create_source(MSSQLDataSource) as source: source.mssql_client.create_pem_file() assert ".pem" in source.mssql_client.certfile @@ -488,7 +490,7 @@ async def test_create_pem_file(): @pytest.mark.asyncio -async def test_get_tables_to_fetch(): +async def test_get_tables_to_fetch() -> None: actual_response = [] expected_response = ["table1", "table2"] async with create_source(MSSQLDataSource) as source: @@ -499,7 +501,7 @@ async def test_get_tables_to_fetch(): @pytest.mark.asyncio -async def test_yield_docs_custom_query(): +async def test_yield_docs_custom_query() -> None: async with create_source(MSSQLDataSource) as source: source.mssql_client.get_table_primary_key = AsyncMock(return_value=[]) async for _ in source._yield_docs_custom_query( diff --git a/tests/sources/test_mysql.py b/tests/sources/test_mysql.py index 08668ae29..52df108be 100644 --- a/tests/sources/test_mysql.py +++ b/tests/sources/test_mysql.py @@ -5,6 +5,7 @@ # import 
asyncio import datetime +from typing import List, Optional, Type from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import aiomysql @@ -88,7 +89,7 @@ def future_with_result(result): return future -def as_async_context_manager_mock(obj): +def as_async_context_manager_mock(obj) -> MagicMock: context_manager = MagicMock() context_manager.__aenter__.return_value = obj context_manager.__aexit__.return_value = None @@ -101,8 +102,8 @@ def mocked_mysql_client( table_cols=None, last_update_times=None, documents=None, - custom_query=False, -): + custom_query: bool = False, +) -> MagicMock: client = MagicMock() client.get_primary_key_column_names = AsyncMock(side_effect=pk_cols) @@ -153,7 +154,7 @@ async def patch_connection_pool(): class Result: """This class contains method which returns dummy response""" - def result(self): + def result(self) -> List[List[str]]: """Result method which returns dummy result""" return [["table1"], ["table2"]] @@ -165,7 +166,7 @@ async def __aenter__(self): """Make a dummy database connection and return it""" return self - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: self.first_call = True self.description = [["Database"]] @@ -175,7 +176,7 @@ def fetchall(self): futures_object.set_result([["table1"], ["table2"]]) return futures_object - async def fetchmany(self, size=1): + async def fetchmany(self, size: int = 1): """This method returns response of fetchmany""" if self.first_call: self.first_call = False @@ -195,7 +196,9 @@ def execute(self, query): futures_object.set_result(MagicMock()) return futures_object - async def __aexit__(self, exception_type, exception_value, exception_traceback): + async def __aexit__( + self, exception_type, exception_value, exception_traceback + ) -> None: """Make sure the dummy database connection gets closed""" pass @@ -207,15 +210,17 @@ async def __aenter__(self): """Make a dummy database connection and return it""" return self - async def ping(self): + async def ping(self) -> bool: """This method returns object of Result class""" return True - async def cursor(self): + async def cursor(self) -> Type[Cursor]: """This method returns object of Result class""" return Cursor - async def __aexit__(self, exception_type, exception_value, exception_traceback): + async def __aexit__( + self, exception_type, exception_value, exception_traceback + ) -> None: """Make sure the dummy database connection gets closed""" pass @@ -223,7 +228,7 @@ async def __aexit__(self, exception_type, exception_value, exception_traceback): class MockSsl: """This class contains methods which returns dummy ssl context""" - def load_verify_locations(self, cadata): + def load_verify_locations(self, cadata) -> None: """This method verify locations""" pass @@ -240,7 +245,7 @@ async def mock_mysql_response(): return mock_response -async def mock_connection(mock_cursor): +async def mock_connection(mock_cursor) -> MagicMock: mock_conn = MagicMock(spec=aiomysql.Connection) mock_conn.cursor.return_value = mock_cursor mock_conn.__aenter__.return_value = mock_conn @@ -248,7 +253,7 @@ async def mock_connection(mock_cursor): return mock_conn -def mock_cursor_fetchmany(rows_per_batch=None): +def mock_cursor_fetchmany(rows_per_batch=None) -> MagicMock: if rows_per_batch is None: rows_per_batch = [] @@ -259,7 +264,7 @@ def mock_cursor_fetchmany(rows_per_batch=None): return mock_cursor -def mock_cursor_fetchall(): +def mock_cursor_fetchall() -> MagicMock: mock_cursor = MagicMock(spec=aiomysql.Cursor) mock_cursor.fetchall.side_effect = 
AsyncMock( return_value=[DOC_ONE, DOC_TWO, DOC_THREE] @@ -270,7 +275,9 @@ def mock_cursor_fetchall(): @pytest.mark.asyncio -async def test_client_when_aexit_called_then_cancel_sleeps(patch_connection_pool): +async def test_client_when_aexit_called_then_cancel_sleeps( + patch_connection_pool, +) -> None: client = await setup_mysql_client() async with client: @@ -281,7 +288,7 @@ async def test_client_when_aexit_called_then_cancel_sleeps(patch_connection_pool @pytest.mark.asyncio -async def test_client_get_tables(patch_connection_pool): +async def test_client_get_tables(patch_connection_pool) -> None: table_1 = "table_1" table_2 = "table_2" @@ -323,7 +330,7 @@ async def test_client_get_tables(patch_connection_pool): @pytest.mark.asyncio async def test_client_get_column_names_for_table( patch_connection_pool, column_tuples, expected_column_names -): +) -> None: mock_cursor = MagicMock(spec=aiomysql.Cursor) mock_cursor.description = column_tuples mock_cursor.__aenter__.return_value = mock_cursor @@ -338,7 +345,7 @@ async def test_client_get_column_names_for_table( @pytest.mark.asyncio -async def test_client_get_column_names_for_query(patch_connection_pool): +async def test_client_get_column_names_for_query(patch_connection_pool) -> None: columns = [("id",), ("class",)] mock_cursor = MagicMock(spec=aiomysql.Cursor) @@ -357,7 +364,7 @@ async def test_client_get_column_names_for_query(patch_connection_pool): @pytest.mark.asyncio -async def test_client_get_last_update_time(patch_connection_pool): +async def test_client_get_last_update_time(patch_connection_pool) -> None: last_update_time = iso_utc() mock_cursor = MagicMock(spec=aiomysql.Cursor) @@ -373,7 +380,7 @@ async def test_client_get_last_update_time(patch_connection_pool): @pytest.mark.asyncio -async def test_client_yield_rows_for_table(patch_connection_pool): +async def test_client_yield_rows_for_table(patch_connection_pool) -> None: rows = [DOC_ONE, DOC_TWO, DOC_THREE] mock_cursor = mock_cursor_fetchall() patch_connection_pool.acquire.return_value = await mock_connection(mock_cursor) @@ -393,7 +400,7 @@ async def test_client_yield_rows_for_table(patch_connection_pool): @pytest.mark.asyncio -async def test_client_yield_rows_for_query(patch_connection_pool): +async def test_client_yield_rows_for_query(patch_connection_pool) -> None: rows = [DOC_ONE, DOC_TWO, DOC_THREE] mock_cursor = mock_cursor_fetchall() mock_cursor.fetchone = AsyncMock(return_value=(3, None)) @@ -414,7 +421,7 @@ async def test_client_yield_rows_for_query(patch_connection_pool): @pytest.mark.asyncio -async def test_client_ping(patch_logger, patch_connection_pool): +async def test_client_ping(patch_logger, patch_connection_pool) -> None: client = await setup_mysql_client() async with client: @@ -422,7 +429,7 @@ async def test_client_ping(patch_logger, patch_connection_pool): @pytest.mark.asyncio -async def test_client_ping_negative(patch_logger): +async def test_client_ping_negative(patch_logger) -> None: client = await setup_mysql_client() mock_response = asyncio.Future() @@ -437,7 +444,7 @@ async def test_client_ping_negative(patch_logger): @freeze_time(TIME) @pytest.mark.asyncio -async def test_fetch_documents(patch_connection_pool): +async def test_fetch_documents(patch_connection_pool) -> None: primary_key_col = "pk" column = "column" document = ["table1"] @@ -464,7 +471,7 @@ async def test_fetch_documents(patch_connection_pool): @pytest.mark.asyncio async def test_fetch_documents_when_used_custom_query_then_sort_pk_cols( patch_connection_pool, patch_row2doc -): +) -> 
None: primary_key_col = ["cd", "ab"] column = "column" @@ -518,7 +525,7 @@ async def test_fetch_documents_when_used_custom_query_then_sort_pk_cols( @pytest.mark.asyncio async def test_fetch_documents_when_custom_query_used_and_update_time_none( patch_connection_pool, patch_row2doc -): +) -> None: primary_key_col = ["cd", "ab"] column = "column" @@ -569,7 +576,7 @@ async def test_fetch_documents_when_custom_query_used_and_update_time_none( @pytest.mark.asyncio -async def test_get_docs(patch_connection_pool): +async def test_get_docs(patch_connection_pool) -> None: async with create_source(MySqlDataSource) as source: await setup_mysql_source(source, DATABASE) source.mysql_client = MagicMock() @@ -581,7 +588,7 @@ async def test_get_docs(patch_connection_pool): assert doc == {"a": 1, "b": 2} -async def setup_mysql_client(): +async def setup_mysql_client() -> MySQLClient: client = MySQLClient( host="host", port=123, @@ -595,7 +602,9 @@ async def setup_mysql_client(): return client -async def setup_mysql_source(source, database="", client=None): +async def setup_mysql_source( + source, database: str = "", client: Optional[MagicMock] = None +): if client is None: client = MagicMock() @@ -663,7 +672,7 @@ def setup_available_docs(advanced_snippet): ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering, expected_docs): +async def test_get_docs_with_advanced_rules(filtering, expected_docs) -> None: async with create_source(MySqlDataSource) as source: await setup_mysql_source(source) docs_in_db = setup_available_docs(filtering.get_advanced_rules()) @@ -677,7 +686,7 @@ async def test_get_docs_with_advanced_rules(filtering, expected_docs): @pytest.mark.asyncio -async def test_validate_config_when_host_empty_then_raise_error(): +async def test_validate_config_when_host_empty_then_raise_error() -> None: async with create_source(MySqlDataSource, host="") as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() @@ -875,7 +884,7 @@ async def test_advanced_rules_validation_when_id_in_source_available( id_in_source, expected_validation_result, patch_ping, -): +) -> None: async with create_source(MySqlDataSource) as source: client = await setup_mysql_source(source, DATABASE) client.get_all_table_names = AsyncMock(return_value=tables_present_in_source) @@ -1011,7 +1020,7 @@ async def test_advanced_rules_validation( advanced_rules, expected_validation_result, patch_ping, -): +) -> None: async with create_source(MySqlDataSource) as source: client = await setup_mysql_source(source, DATABASE) client.get_all_table_names = AsyncMock(return_value=tables_present_in_source) @@ -1027,7 +1036,9 @@ async def test_advanced_rules_validation( @pytest.mark.parametrize("tables", ["*", ["*"]]) @pytest.mark.asyncio -async def test_get_tables_when_wildcard_configured_then_fetch_all_tables(tables): +async def test_get_tables_when_wildcard_configured_then_fetch_all_tables( + tables, +) -> None: async with create_source(MySqlDataSource) as source: source.tables = tables @@ -1042,7 +1053,9 @@ async def test_get_tables_when_wildcard_configured_then_fetch_all_tables(tables) @pytest.mark.asyncio -async def test_validate_database_accessible_when_accessible_then_no_error_raised(): +async def test_validate_database_accessible_when_accessible_then_no_error_raised() -> ( + None +): async with create_source(MySqlDataSource) as source: source.database = "test_database" @@ -1054,7 +1067,9 @@ async def test_validate_database_accessible_when_accessible_then_no_error_raised 
@pytest.mark.asyncio -async def test_validate_database_accessible_when_not_accessible_then_error_raised(): +async def test_validate_database_accessible_when_not_accessible_then_error_raised() -> ( + None +): async with create_source(MySqlDataSource) as source: cursor = AsyncMock() cursor.execute.side_effect = aiomysql.Error("Error") @@ -1064,7 +1079,9 @@ async def test_validate_database_accessible_when_not_accessible_then_error_raise @pytest.mark.asyncio -async def test_validate_tables_accessible_when_accessible_then_no_error_raised(): +async def test_validate_tables_accessible_when_accessible_then_no_error_raised() -> ( + None +): async with create_source(MySqlDataSource) as source: source.tables = ["table_1", "table_2", "table_3"] @@ -1089,7 +1106,7 @@ async def test_validate_tables_accessible_when_accessible_then_no_error_raised() @pytest.mark.asyncio async def test_validate_tables_accessible_when_accessible_and_wildcard_then_no_error_raised( tables, -): +) -> None: async with create_source(MySqlDataSource) as source: source.tables = tables source.get_tables_to_fetch = AsyncMock( @@ -1105,7 +1122,9 @@ async def test_validate_tables_accessible_when_accessible_and_wildcard_then_no_e @pytest.mark.asyncio -async def test_validate_tables_accessible_when_not_accessible_then_error_raised(): +async def test_validate_tables_accessible_when_not_accessible_then_error_raised() -> ( + None +): async with create_source(MySqlDataSource) as source: source.tables = ["table1"] source.get_tables_to_fetch = AsyncMock(return_value=["table1"]) @@ -1142,7 +1161,7 @@ async def test_validate_tables_accessible_when_not_accessible_then_error_raised( ), ], ) -def test_generate_id(tables, row, primary_key_columns, expected_id): +def test_generate_id(tables, row, primary_key_columns, expected_id) -> None: row_id = generate_id(tables, row, primary_key_columns) assert row_id == expected_id @@ -1187,7 +1206,7 @@ def test_generate_id(tables, row, primary_key_columns, expected_id): @freeze_time(TIME) def test_row2doc( row, column_names, primary_key_columns, tables, timestamp, expected_doc -): +) -> None: doc = row2doc( row=row, column_names=column_names, @@ -1225,7 +1244,7 @@ def test_row2doc( ), ], ) -async def test_update_query_with_pagination_attributes(query, updated_query): +async def test_update_query_with_pagination_attributes(query, updated_query) -> None: client = await setup_mysql_client() expected_updated_query = client._update_query_with_pagination_attributes( query=query, offset=0, primary_key_columns=["id"] @@ -1234,7 +1253,7 @@ async def test_update_query_with_pagination_attributes(query, updated_query): @pytest.mark.asyncio -async def test_get_table_row_count_for_query(patch_connection_pool): +async def test_get_table_row_count_for_query(patch_connection_pool) -> None: table_row_count_for_query = 100 custom_query = "SELECT id, name FROM my_table WHERE marks > 100;" mock_cursor = MagicMock(spec=aiomysql.Cursor) @@ -1253,7 +1272,9 @@ async def test_get_table_row_count_for_query(patch_connection_pool): @pytest.mark.asyncio -async def test_yield_docs_custom_query_with_no_primary_key(patch_connection_pool): +async def test_yield_docs_custom_query_with_no_primary_key( + patch_connection_pool, +) -> None: async with create_source(MySqlDataSource) as source: mock_cursor = mock_cursor_fetchall() mock_cursor.fetchall = AsyncMock(return_value=[]) @@ -1269,7 +1290,7 @@ async def test_yield_docs_custom_query_with_no_primary_key(patch_connection_pool @pytest.mark.asyncio -async def 
test_get_primary_key_column_names(patch_connection_pool): +async def test_get_primary_key_column_names(patch_connection_pool) -> None: mock_cursor = mock_cursor_fetchall() mock_cursor.fetchall = AsyncMock(return_value=[("id1", None), ("id2", None)]) patch_connection_pool.acquire.return_value = await mock_connection(mock_cursor) diff --git a/tests/sources/test_network_drive.py b/tests/sources/test_network_drive.py index 49f14a5f1..8ae04ad54 100644 --- a/tests/sources/test_network_drive.py +++ b/tests/sources/test_network_drive.py @@ -8,8 +8,9 @@ import asyncio import csv import datetime +from typing import Any, Dict, Optional from unittest import mock -from unittest.mock import ANY, MagicMock +from unittest.mock import ANY, MagicMock, Mock import pytest import smbclient @@ -66,7 +67,7 @@ class PasswordMustChange(SMBResponseException): _STATUS_CODE = 3221226020 -def mock_file(name): +def mock_file(name) -> Mock: """Generates the smbprotocol object for a file Args: @@ -95,7 +96,7 @@ def mock_file(name): return mock_response -def mock_folder(name): +def mock_folder(name) -> Mock: """Generates the smbprotocol object for a folder Args: @@ -122,7 +123,7 @@ def mock_folder(name): return mock_response -def side_effect_function(MAX_CHUNK_SIZE): +def side_effect_function(MAX_CHUNK_SIZE) -> Optional[bytes]: """Dynamically changing return values during reading a file in chunks Args: MAX_CHUNK_SIZE: Maximum bytes allowed to be read at a given time @@ -134,7 +135,7 @@ def side_effect_function(MAX_CHUNK_SIZE): return b"Mock...." -def mock_permission(sid, ace): +def mock_permission(sid, ace) -> Dict[str, Any]: mock_response = {} class MockSID: @@ -153,7 +154,7 @@ def __str__(self): @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: """Tests the ping functionality for ensuring connection to the Network Drive.""" # Setup expected_response = True @@ -168,7 +169,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio @mock.patch("smbclient.register_session") -async def test_ping_for_failed_connection(session_mock): +async def test_ping_for_failed_connection(session_mock) -> None: """Tests the ping functionality when connection can not be established to Network Drive. 
Args: @@ -216,7 +217,7 @@ async def test_ping_for_failed_connection(session_mock): }, ], ) -async def test_create_connection_when_error_occurs(session_mock, side_effect): +async def test_create_connection_when_error_occurs(session_mock, side_effect) -> None: """Tests the create_connection function when an error occurs Args: @@ -235,7 +236,7 @@ async def test_create_connection_when_error_occurs(session_mock, side_effect): @mock.patch("smbclient.scandir") @pytest.mark.asyncio -async def test_traverse_diretory_with_invalid_path(dir_mock): +async def test_traverse_diretory_with_invalid_path(dir_mock) -> None: """Tests the scandir method of smbclient throws error on invalid path Args: @@ -254,7 +255,7 @@ async def test_traverse_diretory_with_invalid_path(dir_mock): @mock.patch("smbclient.scandir") @mock.patch("connectors.utils.time_to_sleep_between_retries", mock.Mock(return_value=0)) @pytest.mark.asyncio -async def test_traverse_diretory_retried_on_smb_timeout(dir_mock): +async def test_traverse_diretory_retried_on_smb_timeout(dir_mock) -> None: """Tests the scandir method of smbclient is retried on SMBConnectionClosed error Args: @@ -284,7 +285,7 @@ async def test_traverse_diretory_retried_on_smb_timeout(dir_mock): @pytest.mark.asyncio @mock.patch("smbclient.open_file") -async def test_fetch_file_when_file_is_inaccessible(file_mock, caplog): +async def test_fetch_file_when_file_is_inaccessible(file_mock, caplog) -> None: """Tests the open_file method of smbclient throws error when file cannot be accessed Args: @@ -311,7 +312,7 @@ async def create_fake_coroutine(data): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: """Test get_content method of Network Drive""" # Setup async with create_source(NASDataSource) as source: @@ -340,7 +341,7 @@ async def test_get_content(): @pytest.mark.asyncio -async def test_get_content_with_upper_extension(): +async def test_get_content_with_upper_extension() -> None: """Test get_content method of Network Drive""" # Setup async with create_source(NASDataSource) as source: @@ -369,7 +370,7 @@ async def test_get_content_with_upper_extension(): @pytest.mark.asyncio -async def test_get_content_when_doit_false(): +async def test_get_content_when_doit_false() -> None: """Test get_content method when doit is false.""" # Setup async with create_source(NASDataSource) as source: @@ -387,7 +388,7 @@ async def test_get_content_when_doit_false(): @pytest.mark.asyncio -async def test_get_content_when_file_size_is_large(): +async def test_get_content_when_file_size_is_large() -> None: """Test the module responsible for fetching the content of the file if it is not extractable""" # Setup async with create_source(NASDataSource) as source: @@ -406,7 +407,7 @@ async def test_get_content_when_file_size_is_large(): @pytest.mark.asyncio -async def test_get_content_when_file_type_not_supported(): +async def test_get_content_when_file_type_not_supported() -> None: """Test get_content method when the file content type is not supported""" # Setup async with create_source(NASDataSource) as source: @@ -428,7 +429,7 @@ async def test_get_content_when_file_type_not_supported(): @mock.patch.object(NASDataSource, "traverse_diretory", return_value=mock.MagicMock()) @mock.patch.object(NASDataSource, "fetch_groups_info", return_value=mock.AsyncMock()) @mock.patch("smbclient.register_session") -async def test_get_doc(mock_traverse_diretory, mock_fetch_groups, session): +async def test_get_doc(mock_traverse_diretory, mock_fetch_groups, session) -> None: # 
Setup async with create_source(NASDataSource) as source: # Execute @@ -439,7 +440,7 @@ async def test_get_doc(mock_traverse_diretory, mock_fetch_groups, session): @pytest.mark.asyncio @mock.patch("smbclient.open_file") -async def test_fetch_file_when_file_is_accessible(file_mock): +async def test_fetch_file_when_file_is_accessible(file_mock) -> None: """Tests the open_file method of smbclient when file can be accessed Args: @@ -460,7 +461,7 @@ async def test_fetch_file_when_file_is_accessible(file_mock): @pytest.mark.asyncio -async def test_close_without_session(): +async def test_close_without_session() -> None: async with create_source(NASDataSource) as source: await source.close() @@ -588,7 +589,7 @@ async def test_close_without_session(): @pytest.mark.asyncio async def test_advanced_rules_validation( advanced_rules, expected_validation_result, mocked_document -): +) -> None: async with create_source(NASDataSource) as source: validation_result = await NetworkDriveAdvancedRulesValidator(source).validate( advanced_rules @@ -598,7 +599,7 @@ async def test_advanced_rules_validation( @pytest.mark.asyncio -async def test_pattern_start_with_slash(): +async def test_pattern_start_with_slash() -> None: async with create_source(NASDataSource) as source: expected_result = await NetworkDriveAdvancedRulesValidator(source).validate( advanced_rules=[{"pattern": "/[abc"}] @@ -611,10 +612,10 @@ async def test_pattern_start_with_slash(): class MockFile: - def __init__(self, path): + def __init__(self, path) -> None: self.path = path - def is_dir(self): + def is_dir(self) -> bool: return False def path(self): @@ -638,7 +639,7 @@ def path(self): ], ) @mock.patch("smbclient.register_session") -async def test_get_docs_for_syncrule(session, filtering): +async def test_get_docs_for_syncrule(session, filtering) -> None: async with create_source(NASDataSource) as source: with mock.patch( "smbclient.scandir", @@ -667,7 +668,7 @@ async def test_get_docs_for_syncrule(session, filtering): ], ) @mock.patch("smbclient.register_session") -async def test_get_docs_for_syncrule_negative(session, filtering): +async def test_get_docs_for_syncrule_negative(session, filtering) -> None: async with create_source(NASDataSource) as source: with mock.patch( "smbclient.scandir", @@ -677,7 +678,7 @@ async def test_get_docs_for_syncrule_negative(session, filtering): assert document is not None -def test_parse_output(): +def test_parse_output() -> None: security_object = SecurityInfo("user", "password", "0.0.0.0") raw_output = mock.Mock() @@ -696,7 +697,7 @@ def test_parse_output(): assert formatted_result == expected_result -def test_fetch_users(): +def test_fetch_users() -> None: security_object = SecurityInfo("user", "password", "0.0.0.0") sample_data = mock.Mock() @@ -719,7 +720,7 @@ def test_fetch_users(): assert users == expected_result -def test_fetch_groups(): +def test_fetch_groups() -> None: security_object = SecurityInfo("user", "password", "0.0.0.0") sample_data = mock.Mock() @@ -742,7 +743,7 @@ def test_fetch_groups(): assert users == expected_result -def test_fetch_members(): +def test_fetch_members() -> None: security_object = SecurityInfo("user", "password", "0.0.0.0") sample_data = mock.Mock() @@ -766,7 +767,7 @@ def test_fetch_members(): @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_source(NASDataSource) as source: source._features = mock.Mock() source._features.document_level_security_enabled = 
MagicMock(return_value=False) @@ -779,7 +780,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio -async def test_get_access_control_linux_empty_csv_file_path(): +async def test_get_access_control_linux_empty_csv_file_path() -> None: async with create_source(NASDataSource) as source: source._dls_enabled = MagicMock(return_value=True) source.drive_type = LINUX @@ -789,7 +790,7 @@ async def test_get_access_control_linux_empty_csv_file_path(): @pytest.mark.asyncio -async def test_get_access_control_linux(): +async def test_get_access_control_linux() -> None: async with create_source(NASDataSource) as source: source._dls_enabled = MagicMock(return_value=True) source.drive_type = LINUX @@ -805,7 +806,7 @@ async def test_get_access_control_linux(): @pytest.mark.asyncio -async def test_fetch_groups_info(): +async def test_fetch_groups_info() -> None: mock_groups = {"Admins": "S-1-5-32-546"} mock_group_members = { "Administrator": "S-1-5-21-227823342-1368486282-703244805-500" @@ -823,7 +824,7 @@ async def test_fetch_groups_info(): @pytest.mark.asyncio -async def test_get_access_control_dls_enabled(): +async def test_get_access_control_dls_enabled() -> None: expected_user_access_control = [ [ "rid:500", @@ -856,7 +857,7 @@ async def test_get_access_control_dls_enabled(): @pytest.mark.asyncio -async def test_get_access_control_without_duplicate_ids(): +async def test_get_access_control_without_duplicate_ids() -> None: async with create_source(NASDataSource) as source: source._dls_enabled = MagicMock(return_value=True) source.drive_type = LINUX @@ -908,7 +909,9 @@ async def test_get_access_control_without_duplicate_ids(): @mock.patch.object(NASDataSource, "fetch_groups_info", return_value=mock.AsyncMock()) @mock.patch("smbclient.register_session") @pytest.mark.asyncio -async def test_get_docs_without_dls_enabled(mock_get_files, mock_fetch_groups, session): +async def test_get_docs_without_dls_enabled( + mock_get_files, mock_fetch_groups, session +) -> None: async with create_source(NASDataSource) as source: source._dls_enabled = MagicMock(return_value=False) @@ -989,7 +992,7 @@ async def test_get_docs_with_dls_enabled( mock_groups, mock_members, mock_users, -): +) -> None: async with create_source(NASDataSource) as source: source._dls_enabled = MagicMock(return_value=True) @@ -1007,7 +1010,7 @@ async def test_get_docs_with_dls_enabled( @pytest.mark.asyncio -async def test_read_csv_with_valid_data(): +async def test_read_csv_with_valid_data() -> None: async with create_source(NASDataSource) as source: with mock.patch( "builtins.open", @@ -1022,7 +1025,7 @@ async def test_read_csv_with_valid_data(): @pytest.mark.asyncio -async def test_read_csv_file_erroneous(): +async def test_read_csv_file_erroneous() -> None: async with create_source(NASDataSource) as source: with mock.patch("builtins.open", mock.mock_open(read_data="0I`00�^")): with mock.patch("csv.reader", side_effect=csv.Error): @@ -1031,7 +1034,7 @@ async def test_read_csv_file_erroneous(): @pytest.mark.asyncio -async def test_read_csv_with_empty_groups(): +async def test_read_csv_with_empty_groups() -> None: async with create_source(NASDataSource) as source: with mock.patch( "builtins.open", mock.mock_open(read_data="user1;1;\nuser2;2;") @@ -1046,7 +1049,7 @@ async def test_read_csv_with_empty_groups(): @pytest.mark.asyncio @mock.patch.object(SecurityInfo, "get_descriptor") -async def test_list_file_permissions(mock_get_descriptor): +async def test_list_file_permissions(mock_get_descriptor) -> None: with 
mock.patch("smbclient.open_file", return_value=MagicMock()) as mock_file: mock_file.fd.return_value = 2 mock_descriptor = mock.Mock() @@ -1066,7 +1069,7 @@ async def test_list_file_permissions(mock_get_descriptor): @pytest.mark.asyncio -async def test_list_file_permissions_with_inaccessible_file(): +async def test_list_file_permissions_with_inaccessible_file() -> None: with mock.patch("smbclient.open_file", return_value=MagicMock()) as mock_file: mock_file.side_effect = SMBOSError(ntstatus=0xC0000043, filename="file1.txt") @@ -1090,7 +1093,9 @@ async def test_list_file_permissions_with_inaccessible_file(): mock_permission(sid="S-1-11-10", ace=1), # Group with Deny permission ], ) -async def test_deny_permission_has_precedence_over_allow(mock_list_file_permission): +async def test_deny_permission_has_precedence_over_allow( + mock_list_file_permission, +) -> None: mock_groups_info = {"10": {"admin": "S-2-21-211-411"}} expected_result = [] async with create_source(NASDataSource) as source: @@ -1116,7 +1121,7 @@ async def test_deny_permission_has_precedence_over_allow(mock_list_file_permissi ) async def test_group_allow_ace_member1_allow_member2_deny_ace_then_member1_has_access( mock_list_file_permission, -): +) -> None: mock_groups_info = {"11": {"user-1": "S-2-21-211-411", "user-2": "S-3-23-222-221"}} expected_result = ["rid:411"] # Only User-1 should have access async with create_source(NASDataSource) as source: @@ -1130,7 +1135,7 @@ async def test_group_allow_ace_member1_allow_member2_deny_ace_then_member1_has_a assert document_permissions[ACCESS_CONTROL] == expected_result -async def test_validate_drive_path(): +async def test_validate_drive_path() -> None: async with create_source(NASDataSource) as source: source.configuration.get_field("drive_path").value = "/abc/bcd" source.configuration.get_field("username").value = "user" @@ -1144,7 +1149,7 @@ async def test_validate_drive_path(): @mock.patch("smbclient.scandir") @mock.patch("connectors.utils.time_to_sleep_between_retries", mock.Mock(return_value=0)) @pytest.mark.asyncio -async def test_traverse_diretory_smb_timeout_for_sync_rule(dir_mock): +async def test_traverse_diretory_smb_timeout_for_sync_rule(dir_mock) -> None: with mock.patch.object(SMBSession, "create_connection"): async with create_source(NASDataSource) as source: path = "some_path" @@ -1171,7 +1176,7 @@ async def test_traverse_diretory_smb_timeout_for_sync_rule(dir_mock): @mock.patch("smbclient.scandir") @pytest.mark.asyncio -async def test_traverse_diretory_with_invalid_path_for_syncrule(dir_mock): +async def test_traverse_diretory_with_invalid_path_for_syncrule(dir_mock) -> None: # Setup async with create_source(NASDataSource) as source: path = "unknown_path" diff --git a/tests/sources/test_notion.py b/tests/sources/test_notion.py index 021b2d4a9..328fe7659 100644 --- a/tests/sources/test_notion.py +++ b/tests/sources/test_notion.py @@ -3,6 +3,7 @@ # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. 
# +from typing import Dict, List, Optional, Union from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import aiohttp @@ -23,7 +24,18 @@ from tests.sources.support import create_source ADVANCED_SNIPPET = "advanced_snippet" -DATABASE = { +DATABASE: Dict[ + str, + Union[ + List[ + Dict[ + str, + Union[None, Dict[str, Optional[str]], Dict[str, Union[bool, str]], str], + ] + ], + str, + ], +] = { "object": "database", "id": "database_id", "created_time": "2021-07-13T16:48:00.000Z", @@ -48,7 +60,32 @@ } ], } -BLOCK = { +BLOCK: Dict[ + str, + Union[ + Dict[ + str, + Dict[ + str, + Union[ + List[ + Dict[ + str, + Union[ + None, + Dict[str, Optional[str]], + Dict[str, Union[bool, str]], + str, + ], + ] + ], + str, + ], + ], + ], + str, + ], +] = { "object": "page", "id": "b3a9a3e8-5a8a-4ac0-9b52-9fb62772a9bd", "created_time": "2021-06-13T16:48:00.000Z", @@ -79,7 +116,19 @@ } }, } -COMMENT = { +COMMENT: Dict[ + str, + Union[ + Dict[str, str], + List[ + Dict[ + str, + Union[None, Dict[str, Optional[str]], Dict[str, Union[bool, str]], str], + ] + ], + str, + ], +] = { "object": "comment", "id": "c8f5a3e8-5a8a-4ac0-9b52-9fb62772a9bd", "created_time": "2021-06-13T16:48:00.000Z", @@ -106,7 +155,24 @@ ], } -CHILD_BLOCK = { +CHILD_BLOCK: Dict[ + str, + Union[ + Dict[ + str, + List[ + Dict[ + str, + Union[ + None, Dict[str, Optional[str]], Dict[str, Union[bool, str]], str + ], + ] + ], + ], + bool, + str, + ], +] = { "object": "block", "id": "b8f5a3e8-5a8a-4ac0-9b52-9fb62772a9bd", "created_time": "2021-05-13T16:48:00.000Z", @@ -135,7 +201,24 @@ ] }, } -CHILD_BLOCK_WITH_CHILDREN = { +CHILD_BLOCK_WITH_CHILDREN: Dict[ + str, + Union[ + Dict[ + str, + List[ + Dict[ + str, + Union[ + None, Dict[str, Optional[str]], Dict[str, Union[bool, str]], str + ], + ] + ], + ], + bool, + str, + ], +] = { "object": "block", "id": "b8f5a3e8-5a8a-4ac0-9b52-9fb62772a9bd", "created_time": "2021-05-13T16:48:00.000Z", @@ -208,7 +291,7 @@ @pytest.mark.asyncio @patch("connectors.sources.notion.NotionClient", autospec=True) -async def test_ping(mock_notion_client): +async def test_ping(mock_notion_client) -> None: mock_notion_client.return_value.fetch_owner.return_value = None async with create_source( NotionDataSource, @@ -220,7 +303,7 @@ async def test_ping(mock_notion_client): @pytest.mark.asyncio @patch("connectors.sources.notion.NotionClient", autospec=True) -async def test_ping_negative(mock_notion_client): +async def test_ping_negative(mock_notion_client) -> None: mock_notion_client.return_value.fetch_owner.side_effect = APIResponseError( message="Invalid API key", code=401, @@ -240,7 +323,7 @@ async def test_ping_negative(mock_notion_client): @pytest.mark.asyncio -async def test_close_with_client(): +async def test_close_with_client() -> None: async with create_source( NotionDataSource, notion_secret_key="5678", @@ -269,7 +352,7 @@ async def test_close_with_client(): @patch("connectors.sources.notion.NotionClient", autospec=True) async def test_get_entities( mock_notion_client, entity_type, entity_titles, mock_search_results -): +) -> None: mock_notion_client.return_value.fetch_by_query = AsyncIterator(mock_search_results) async with create_source( NotionDataSource, @@ -291,7 +374,7 @@ async def test_get_entities( @patch("connectors.sources.notion.NotionClient") async def test_get_entities_entity_not_found( mock_notion_client, entity_type, entity_titles, configuration_key -): +) -> None: mock_search_results = {"results": []} async def mock_make_request(*args, **kwargs): @@ -307,7 +390,7 @@ async def 
mock_make_request(*args, **kwargs): @pytest.mark.asyncio -async def test_get_entities_exception(): +async def test_get_entities_exception() -> None: async with create_source(NotionDataSource) as source: with patch.object(NotionClient, "fetch_by_query", side_effect=Exception()): with pytest.raises(Exception): @@ -315,7 +398,7 @@ async def test_get_entities_exception(): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: mock_get_via_session = AsyncMock(return_value=MagicMock()) mock_download_extract = AsyncMock(return_value=MagicMock()) mock_file_metadata = AsyncMock(return_value=MagicMock()) @@ -332,7 +415,7 @@ async def test_get_content(): @pytest.mark.asyncio -async def test_get_content_when_url_is_empty(): +async def test_get_content_when_url_is_empty() -> None: async with create_source(NotionDataSource) as source: content = await source.get_content(FILE_BLOCK, None) assert content is None @@ -347,7 +430,7 @@ async def test_get_content_when_url_is_empty(): ("some/file/with/slashes"), ], ) -async def test_get_file_metadata(file_name): +async def test_get_file_metadata(file_name) -> None: async with create_source( NotionDataSource, notion_secret_key="test_get_file_metadata_key", @@ -369,7 +452,7 @@ async def test_get_file_metadata(file_name): @pytest.mark.asyncio -async def test_retrieve_and_process_blocks(): +async def test_retrieve_and_process_blocks() -> None: expected_responses_ids = [ USER.get("id"), BLOCK.get("id"), @@ -392,7 +475,7 @@ async def test_retrieve_and_process_blocks(): assert response_ids == expected_responses_ids -def test_generate_query(): +def test_generate_query() -> None: configuration = DataSourceConfiguration({}) source = NotionDataSource(configuration=configuration) source.pages = ["page1", "*"] @@ -408,7 +491,7 @@ def test_generate_query(): @pytest.mark.asyncio -async def test_fetch_users(): +async def test_fetch_users() -> None: async with create_source( NotionDataSource, notion_secret_key="test_fetch_users_key", @@ -423,7 +506,7 @@ async def test_fetch_users(): @pytest.mark.asyncio -async def test_fetch_child_blocks(): +async def test_fetch_child_blocks() -> None: async with create_source( NotionDataSource, notion_secret_key="test_fetch_child_blocks_key", @@ -443,7 +526,7 @@ async def test_fetch_child_blocks(): @pytest.mark.asyncio -async def test_fetch_comments(): +async def test_fetch_comments() -> None: async with create_source( NotionDataSource, notion_secret_key="test_fetch_comments_key", @@ -461,7 +544,7 @@ async def test_fetch_comments(): @pytest.mark.asyncio -async def test_fetch_by_query(): +async def test_fetch_by_query() -> None: async with create_source( NotionDataSource, notion_secret_key="test_fetch_by_query_key", @@ -482,7 +565,7 @@ async def test_fetch_by_query(): @pytest.mark.asyncio -async def test_query_database(): +async def test_query_database() -> None: async with create_source( NotionDataSource, notion_secret_key="test_query_database_key", @@ -501,7 +584,7 @@ async def test_query_database(): @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) @pytest.mark.asyncio -async def test_get_via_session_client_response_error(): +async def test_get_via_session_client_response_error() -> None: async with create_source( NotionDataSource, notion_secret_key="test_get_via_session_client_response_error_key", @@ -522,7 +605,7 @@ async def test_get_via_session_client_response_error(): @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) @pytest.mark.asyncio -async 
def test_get_via_session_with_429_status(): +async def test_get_via_session_with_429_status() -> None: retried_response = AsyncMock() async with create_source( @@ -547,7 +630,7 @@ async def test_get_via_session_with_429_status(): @pytest.mark.asyncio -async def test_fetch_children_recursively(): +async def test_fetch_children_recursively() -> None: async with create_source( NotionDataSource, notion_secret_key="test_fetch_children_recursively_key", @@ -598,7 +681,7 @@ async def test_fetch_children_recursively(): ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering): +async def test_get_docs_with_advanced_rules(filtering) -> None: async with create_source( NotionDataSource, notion_secret_key="secert_key", @@ -763,7 +846,9 @@ async def test_get_docs_with_advanced_rules(filtering): ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_source( NotionDataSource, notion_secret_key="secret_key" ) as source: @@ -784,7 +869,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res @pytest.mark.asyncio -async def test_async_iterate_paginated_api(): +async def test_async_iterate_paginated_api() -> None: async def mock_function(**kwargs): return { "results": [{"name": "John"}, {"name": "Alice"}], @@ -808,7 +893,7 @@ async def mock_function(**kwargs): @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) @pytest.mark.asyncio -async def test_fetch_results_rate_limit_exceeded(): +async def test_fetch_results_rate_limit_exceeded() -> None: async def mock_function_with_429(**kwargs): mock_function_with_429.call_count += 1 @@ -838,7 +923,7 @@ async def mock_function_with_429(**kwargs): @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) @pytest.mark.asyncio -async def test_fetch_results_other_errors_not_retried(): +async def test_fetch_results_other_errors_not_retried() -> None: async def mock_function_with_other_error(**kwargs): raise APIResponseError( code="object_not_found", @@ -865,7 +950,7 @@ async def mock_function_with_other_error(**kwargs): @pytest.mark.asyncio -async def test_original_async_iterate_paginated_api_not_called(): +async def test_original_async_iterate_paginated_api_not_called() -> None: with patch.object(NotionClient, "async_iterate_paginated_api"): with patch( "notion_client.helpers.async_iterate_paginated_api" @@ -878,7 +963,7 @@ async def test_original_async_iterate_paginated_api_not_called(): @pytest.mark.asyncio -async def test_fetch_child_blocks_for_external_object_instance_page(caplog): +async def test_fetch_child_blocks_for_external_object_instance_page(caplog) -> None: block_id = "block_id" caplog.set_level("WARNING") with patch( @@ -903,7 +988,7 @@ async def test_fetch_child_blocks_for_external_object_instance_page(caplog): @pytest.mark.asyncio -async def test_is_connected_property_block(): +async def test_is_connected_property_block() -> None: mocked_connected_property_block = { "object": "page", "id": "12345678-1234-1234-1234-123456789012", @@ -943,7 +1028,7 @@ async def test_is_connected_property_block(): @pytest.mark.asyncio -async def test_fetch_child_blocks_with_not_found_object(caplog): +async def test_fetch_child_blocks_with_not_found_object(caplog) -> None: block_id = "block_id" caplog.set_level("WARNING") with patch( diff --git a/tests/sources/test_onedrive.py 
b/tests/sources/test_onedrive.py index 2d34f228a..6aed2be11 100644 --- a/tests/sources/test_onedrive.py +++ b/tests/sources/test_onedrive.py @@ -6,6 +6,7 @@ """Tests the OneDrive source class methods""" from contextlib import asynccontextmanager +from typing import Dict, List, Union from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import pytest @@ -178,7 +179,9 @@ }, ] -EXPECTED_USER1_FILES = [ +EXPECTED_USER1_FILES: List[ + Union[Dict[str, Union[None, int, str]], Dict[str, Union[int, str]]] +] = [ { "created_at": "2023-05-01T09:09:19Z", "_id": "01DABHRNU2RE777OZMAZG24FV3XP24GXCO", @@ -228,7 +231,9 @@ }, ] -EXPECTED_USER2_FILES = [ +EXPECTED_USER2_FILES: List[ + Union[Dict[str, Union[None, int, str]], Dict[str, Union[int, str]]] +] = [ { "created_at": "2023-05-01T09:09:19Z", "_id": "01DABHRNU2RE777OZMAZG24FV3XP24GXCO", @@ -297,7 +302,7 @@ "_timestamp": "2023-05-01T09:10:21Z", "_attachment": "IyBUaGlzIGlzIHRoZSBkdW1teSBmaWxl", } -EXPECTED_CONTENT_EXTRACTED = { +EXPECTED_CONTENT_EXTRACTED: Dict[str, str] = { "_id": "01DABHRNUACUYC4OM3GJG2NVHDI2ABGP4E", "_timestamp": "2023-05-01T09:10:21Z", "body": RESPONSE_CONTENT, @@ -351,7 +356,12 @@ "grantedToIdentitiesV2": [{"group": {"bar": "foo"}}], } -EXPECTED_USER1_FILES_PERMISSION = [ +EXPECTED_USER1_FILES_PERMISSION: List[ + Union[ + Dict[str, Union[None, List[str], int, str]], + Dict[str, Union[List[str], int, str]], + ] +] = [ { "type": "folder", "title": "folder3", @@ -383,7 +393,12 @@ ], }, ] -EXPECTED_USER2_FILES_PERMISSION = [ +EXPECTED_USER2_FILES_PERMISSION: List[ + Union[ + Dict[str, Union[None, List[str], int, str]], + Dict[str, Union[List[str], int, str]], + ] +] = [ { "type": "folder", "title": "folder4", @@ -487,21 +502,21 @@ } -def token_retrieval_errors(message, error_code): +def token_retrieval_errors(message: str, error_code) -> ClientResponseError: error = ClientResponseError(None, None) error.status = error_code error.message = message return error -def get_stream_reader(): +def get_stream_reader() -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock(return_value=StreamReaderAsyncMock()) return async_mock class JSONAsyncMock(AsyncMock): - def __init__(self, json, *args, **kwargs): + def __init__(self, json, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json @@ -510,12 +525,12 @@ async def json(self): class StreamReaderAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.content = StreamReader -def test_get_configuration(): +def test_get_configuration() -> None: config = DataSourceConfiguration( config=OneDriveDataSource.get_default_configuration() ) @@ -524,7 +539,7 @@ def test_get_configuration(): @asynccontextmanager -async def create_onedrive_source(use_text_extraction_service=False): +async def create_onedrive_source(use_text_extraction_service: bool = False): async with create_source( OneDriveDataSource, client_id="foo", @@ -550,14 +565,14 @@ async def create_onedrive_source(use_text_extraction_service=False): ) async def test_validate_configuration_with_invalid_dependency_fields_raises_error( extras, -): +) -> None: async with create_source(OneDriveDataSource, **extras) as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() @pytest.mark.asyncio -async def test_close_with_client_session(): +async def test_close_with_client_session() -> None: async with create_onedrive_source() as source: source.client.access_token = "dummy" @@ -567,7 
+582,7 @@ async def test_close_with_client_session(): @pytest.mark.asyncio -async def test_set_access_token(): +async def test_set_access_token() -> None: async with create_onedrive_source() as source: mock_token = {"access_token": "msgraphtoken", "expires_in": "1234555"} async_response = AsyncMock() @@ -582,7 +597,7 @@ async def test_set_access_token(): @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: async with create_onedrive_source() as source: DUMMY_RESPONSE = {} source.client.get = AsyncIterator([[DUMMY_RESPONSE]]) @@ -592,7 +607,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio @patch("aiohttp.ClientSession.get") -async def test_ping_for_failed_connection_exception(mock_get): +async def test_ping_for_failed_connection_exception(mock_get) -> None: async with create_onedrive_source() as source: with patch.object( OneDriveClient, "get", side_effect=Exception("Something went wrong") @@ -602,7 +617,7 @@ async def test_ping_for_failed_connection_exception(mock_get): @pytest.mark.asyncio -async def test_get_token_raises_correct_exception_when_400(): +async def test_get_token_raises_correct_exception_when_400() -> None: klass = OneDriveDataSource config = DataSourceConfiguration(config=klass.get_default_configuration()) @@ -622,7 +637,7 @@ async def test_get_token_raises_correct_exception_when_400(): @pytest.mark.asyncio -async def test_get_token_raises_correct_exception_when_401(): +async def test_get_token_raises_correct_exception_when_401() -> None: klass = OneDriveDataSource config = DataSourceConfiguration(config=klass.get_default_configuration()) @@ -641,7 +656,7 @@ async def test_get_token_raises_correct_exception_when_401(): @pytest.mark.asyncio -async def test_get_token_raises_correct_exception_when_any_other_status(): +async def test_get_token_raises_correct_exception_when_any_other_status() -> None: klass = OneDriveDataSource config = DataSourceConfiguration(config=klass.get_default_configuration()) @@ -661,7 +676,7 @@ async def test_get_token_raises_correct_exception_when_any_other_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_429_status(): +async def test_get_with_429_status() -> None: initial_response = ClientResponseError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -687,7 +702,7 @@ async def test_get_with_429_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_429_status_without_retry_after_header(): +async def test_get_with_429_status_without_retry_after_header() -> None: initial_response = ClientResponseError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -712,7 +727,7 @@ async def test_get_with_429_status_without_retry_after_header(): @pytest.mark.asyncio -async def test_get_with_404_status(): +async def test_get_with_404_status() -> None: error = ClientResponseError(None, None) error.status = 404 @@ -731,7 +746,7 @@ async def test_get_with_404_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_with_500_status(): +async def test_get_with_500_status() -> None: error = ClientResponseError(None, None) error.status = 500 @@ -750,7 +765,7 @@ async def test_get_with_500_status(): @pytest.mark.asyncio 
@patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_post_with_429_status(): +async def test_post_with_429_status() -> None: initial_response = ClientPayloadError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -776,7 +791,7 @@ async def test_post_with_429_status(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_post_with_429_status_without_retry_after_header(): +async def test_post_with_429_status_without_retry_after_header() -> None: initial_response = ClientPayloadError(None, None) initial_response.status = 429 initial_response.message = "rate-limited" @@ -801,7 +816,7 @@ async def test_post_with_429_status_without_retry_after_header(): @pytest.mark.asyncio -async def test_list_groups(): +async def test_list_groups() -> None: async with create_onedrive_source() as source: with patch.object( OneDriveClient, @@ -816,7 +831,7 @@ async def test_list_groups(): @pytest.mark.asyncio -async def test_list_permissions(): +async def test_list_permissions() -> None: async with create_onedrive_source() as source: with patch.object( OneDriveClient, @@ -847,7 +862,7 @@ async def test_list_permissions(): @pytest.mark.asyncio async def test_get_entity_permission_when_response_missing_content( permissions, patch_logger -): +) -> None: async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=True) response = await source.get_entity_permission("user-1", "file-1") @@ -866,7 +881,7 @@ async def test_get_entity_permission_when_response_missing_content( @pytest.mark.asyncio -async def test_get_owned_files(): +async def test_get_owned_files() -> None: async with create_onedrive_source() as source: async_response = AsyncMock() async_response.__aenter__ = AsyncMock( @@ -887,7 +902,7 @@ async def test_get_owned_files(): @pytest.mark.asyncio -async def test_list_users(): +async def test_list_users() -> None: async with create_onedrive_source() as source: response = [] async_response = AsyncMock() @@ -914,7 +929,7 @@ async def test_list_users(): ) async def test_get_content_when_is_downloadable_is_true( file, download_url, expected_content -): +) -> None: async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch("aiohttp.ClientSession.get", return_value=get_stream_reader()): @@ -931,7 +946,7 @@ async def test_get_content_when_is_downloadable_is_true( @pytest.mark.asyncio -async def test_get_content_with_extraction_service(): +async def test_get_content_with_extraction_service() -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -960,7 +975,7 @@ async def test_get_content_with_extraction_service(): @pytest.mark.asyncio -async def test_prepare_doc_when_file_none(): +async def test_prepare_doc_when_file_none() -> None: async with create_onedrive_source() as source: mock_response = { "createdDateTime": "2023-05-01T09:09:19Z", @@ -988,7 +1003,7 @@ async def test_prepare_doc_when_file_none(): @pytest.mark.asyncio -async def test_lookup_request_by_id(): +async def test_lookup_request_by_id() -> None: async with create_onedrive_source() as source: requests = [ {"id": "1", "url": "user1/delta"}, @@ -1000,7 +1015,7 @@ async def test_lookup_request_by_id(): @pytest.mark.asyncio -async def test_json_batching(): +async def test_json_batching() -> None: async_response, next_page_response = AsyncMock(), AsyncMock() result = [] expected_result = [ @@ 
-1047,7 +1062,7 @@ async def test_json_batching(): ], ) @pytest.mark.asyncio -async def test_get_docs(users_patch, files_patch): +async def test_get_docs(users_patch, files_patch) -> None: async with create_onedrive_source() as source: expected_responses = [*EXPECTED_USER1_FILES, *EXPECTED_USER2_FILES] source.get_content = AsyncMock(return_value=EXPECTED_CONTENT) @@ -1156,7 +1171,9 @@ async def test_get_docs(users_patch, files_patch): ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_onedrive_source() as source: validation_result = await OneDriveAdvancedRulesValidator(source).validate( advanced_rules @@ -1187,7 +1204,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering): +async def test_get_docs_with_advanced_rules(filtering) -> None: async with create_onedrive_source() as source: with patch.object(AccessToken, "get", return_value="abc"): with patch.object( @@ -1221,7 +1238,7 @@ async def test_get_docs_with_advanced_rules(filtering): @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=False) @@ -1233,7 +1250,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio -async def test_get_access_control_dls_enabled(): +async def test_get_access_control_dls_enabled() -> None: expected_user_access_control = [ [ "email:AdeleV@w076v.onmicrosoft.com", @@ -1284,7 +1301,7 @@ async def test_get_access_control_dls_enabled(): ], ) @pytest.mark.asyncio -async def test_get_docs_without_dls_enabled(users_patch, files_patch): +async def test_get_docs_without_dls_enabled(users_patch, files_patch) -> None: async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=False) @@ -1329,7 +1346,9 @@ async def test_get_docs_without_dls_enabled(users_patch, files_patch): ], ) @pytest.mark.asyncio -async def test_get_docs_with_dls_enabled(users_patch, files_patch, permissions_patch): +async def test_get_docs_with_dls_enabled( + users_patch, files_patch, permissions_patch +) -> None: async with create_onedrive_source() as source: source._dls_enabled = MagicMock(return_value=True) diff --git a/tests/sources/test_oracle.py b/tests/sources/test_oracle.py index bc411583d..57a2e22a1 100644 --- a/tests/sources/test_oracle.py +++ b/tests/sources/test_oracle.py @@ -53,7 +53,7 @@ def oracle_client(**extras): (SERVICE_NAME, DSN_SERVICE_NAME), ], ) -def test_engine_in_thin_mode(mock_fun, connection_source, DSN): +def test_engine_in_thin_mode(mock_fun, connection_source, DSN) -> None: """Test engine method of OracleClient class in thin mode""" # Setup with oracle_client() as client: @@ -73,7 +73,7 @@ def test_engine_in_thin_mode(mock_fun, connection_source, DSN): (SERVICE_NAME, DSN_SERVICE_NAME), ], ) -def test_engine_in_thick_mode(mock_fun, connection_source, DSN): +def test_engine_in_thick_mode(mock_fun, connection_source, DSN) -> None: """Test engine method of OracleClient class in thick mode""" oracle_home = "/home/devuser" config_file_path = {"lib_dir": f"{oracle_home}/lib", "config_dir": ""} @@ -91,7 +91,7 @@ def test_engine_in_thick_mode(mock_fun, connection_source, DSN): @pytest.mark.asyncio -async def test_ping(): 
+async def test_ping() -> None: async with create_source(OracleDataSource) as source: with patch.object( Engine, "connect", return_value=ConnectionSync(OracleQueries()) @@ -100,7 +100,7 @@ async def test_ping(): @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: # Setup async with create_source( OracleDataSource, diff --git a/tests/sources/test_outlook.py b/tests/sources/test_outlook.py index 213717112..abe81c5ec 100644 --- a/tests/sources/test_outlook.py +++ b/tests/sources/test_outlook.py @@ -6,6 +6,7 @@ """Tests the Outlook source class methods""" from contextlib import asynccontextmanager +from typing import List, Optional, Union from unittest import mock from unittest.mock import AsyncMock, MagicMock, patch @@ -158,13 +159,13 @@ class MockException(Exception): - def __init__(self, status, message=None): + def __init__(self, status, message=None) -> None: super().__init__(message) self.status = status class CustomPath: - def __truediv__(self, path): + def __truediv__(self, path) -> Union["CustomPath", "MockOutlookObject"]: # Simulate hierarchy navigation and return a list of dictionaries if path == "Top of Information Store": return self @@ -178,12 +179,12 @@ def __truediv__(self, path): class MockAttachmentId: - def __init__(self, id): # noqa + def __init__(self, id) -> None: # noqa self.id = id class MailDocument: - def __init__(self): + def __init__(self) -> None: sender = MagicMock() sender.email_address = "dummy.user@gmail.com" @@ -202,7 +203,7 @@ def __init__(self): class TaskDocument: - def __init__(self): + def __init__(self) -> None: self.id = "task_1" self.last_modified_time = "2023-12-12T01:01:01Z" self.subject = "Create Test cases for Outlook" @@ -219,7 +220,7 @@ def __init__(self): class ContactDocument: - def __init__(self): + def __init__(self) -> None: contact = MagicMock() contact.email = "dummy.user@gmail.com" contact.phone_number = 99887776655 @@ -234,7 +235,7 @@ def __init__(self): class CalendarDocument: - def __init__(self): + def __init__(self) -> None: organizer = MagicMock() organizer.email_address = "dummy.user@gmail.com" organizer.mailbox.email_address = "dummy.user@gmail.com" @@ -254,10 +255,18 @@ def __init__(self): class AllObjects: - def __init__(self, object_type): + def __init__(self, object_type) -> None: self.object_type = object_type - def only(self, *args): + def only( + self, *args + ) -> Union[ + None, + List[CalendarDocument], + List[ContactDocument], + List[MailDocument], + List[TaskDocument], + ]: match self.object_type: case "mail": return [MailDocument()] @@ -270,16 +279,16 @@ def only(self, *args): class MockOutlookObject: - def __init__(self, object_type): + def __init__(self, object_type) -> None: self.object_type = object_type self.children = [self] - def all(self): # noqa + def all(self) -> AllObjects: # noqa return AllObjects(object_type=self.object_type) class MockAccount: - def __init__(self): + def __init__(self) -> None: self.default_timezone = "UTC" self.inbox = MockOutlookObject(object_type=MAIL) @@ -293,7 +302,7 @@ def __init__(self): class MockAttachment: - def __init__(self, attachment_id, name, size, last_modified_time, content): + def __init__(self, attachment_id, name, size, last_modified_time, content) -> None: self.attachment_id = attachment_id self.name = name self.size = size @@ -353,7 +362,7 @@ def __init__(self, attachment_id, name, size, last_modified_time, content): class JSONAsyncMock(AsyncMock): - def __init__(self, json, status, *args, **kwargs): + def __init__(self, json, status, 
*args, **kwargs) -> None: super().__init__(*args, **kwargs) self._json = json self.status = status @@ -363,25 +372,25 @@ async def json(self): class StreamReaderAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.content = StreamReader @asynccontextmanager async def create_outlook_source( - data_source=OUTLOOK_CLOUD, - tenant_id="foo", - client_id="bar", - client_secret="faa", - exchange_server="127.0.0.1", - active_directory_server="127.0.0.1", - username="fee", - password="fuu", - domain="outlook.com", - ssl_enabled=False, - ssl_ca="", - use_text_extraction_service=False, + data_source: str = OUTLOOK_CLOUD, + tenant_id: str = "foo", + client_id: str = "bar", + client_secret: str = "faa", + exchange_server: str = "127.0.0.1", + active_directory_server: str = "127.0.0.1", + username: str = "fee", + password: str = "fuu", + domain: str = "outlook.com", + ssl_enabled: bool = False, + ssl_ca: str = "", + use_text_extraction_service: bool = False, ): async with create_source( OutlookDataSource, @@ -401,7 +410,7 @@ async def create_outlook_source( yield source -def get_json_mock(mock_response, status): +def get_json_mock(mock_response, status) -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock( return_value=JSONAsyncMock(json=mock_response, status=status) @@ -409,13 +418,13 @@ def get_json_mock(mock_response, status): return async_mock -def get_stream_reader(): +def get_stream_reader() -> AsyncMock: async_mock = AsyncMock() async_mock.__aenter__ = AsyncMock(return_value=StreamReaderAsyncMock()) return async_mock -def side_effect_function(url, headers): +def side_effect_function(url, headers) -> Optional[AsyncMock]: """Dynamically changing return values for API calls Args: url, ssl: Params required for get call @@ -468,7 +477,7 @@ def side_effect_function(url, headers): ) async def test_validate_configuration_with_invalid_dependency_fields_raises_error( extras, -): +) -> None: # Setup async with create_outlook_source(**extras) as source: # Execute @@ -506,14 +515,14 @@ async def test_validate_configuration_with_invalid_dependency_fields_raises_erro ) async def test_validate_config_with_valid_dependency_fields_does_not_raise_error( extras, -): +) -> None: async with create_outlook_source(**extras) as source: await source.validate_config() @pytest.mark.asyncio @patch("connectors.sources.outlook.Connection") -async def test_ping_for_server(mock_connection): +async def test_ping_for_server(mock_connection) -> None: mock_connection_instance = mock_connection.return_value mock_connection_instance.search.return_value = ( True, @@ -529,7 +538,7 @@ async def test_ping_for_server(mock_connection): @pytest.mark.asyncio @patch("connectors.sources.outlook.Connection") -async def test_ping_for_server_for_failed_connection(mock_connection): +async def test_ping_for_server_for_failed_connection(mock_connection) -> None: mock_connection_instance = mock_connection.return_value mock_connection_instance.search.return_value = ( False, @@ -545,7 +554,7 @@ async def test_ping_for_server_for_failed_connection(mock_connection): @pytest.mark.asyncio -async def test_ping_for_cloud(): +async def test_ping_for_cloud() -> None: async with create_outlook_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -573,7 +582,7 @@ async def test_ping_for_cloud(): @mock.patch("connectors.utils.time_to_sleep_between_retries") async def test_ping_for_cloud_for_failed_connection( 
mock_time_to_sleep_between_retries, raised_exception, side_effect_exception -): +) -> None: mock_time_to_sleep_between_retries.return_value = 0 async with create_outlook_source() as source: with mock.patch( @@ -589,7 +598,7 @@ async def test_ping_for_cloud_for_failed_connection( @pytest.mark.asyncio -async def test_get_users_for_cloud(): +async def test_get_users_for_cloud() -> None: async with create_outlook_source() as source: users = [] with mock.patch( @@ -610,7 +619,7 @@ async def test_get_users_for_cloud(): @pytest.mark.asyncio @patch("connectors.sources.outlook.Connection") -async def test_fetch_admin_users_negative(mock_connection): +async def test_fetch_admin_users_negative(mock_connection) -> None: async with create_outlook_source() as source: mock_connection_instance = mock_connection.return_value mock_connection_instance.search.return_value = ( @@ -629,7 +638,7 @@ async def test_fetch_admin_users_negative(mock_connection): @pytest.mark.asyncio @patch("connectors.sources.outlook.Connection") -async def test_fetch_admin_users(mock_connection): +async def test_fetch_admin_users(mock_connection) -> None: async with create_outlook_source() as source: users = [] mock_connection_instance = mock_connection.return_value @@ -658,7 +667,7 @@ async def test_fetch_admin_users(mock_connection): (MOCK_ATTACHMENT_WITH_UNSUPPORTED_EXTENSION, None), ], ) -async def test_get_content(attachment, expected_content): +async def test_get_content(attachment, expected_content) -> None: async with create_outlook_source() as source: response = await source.get_content( attachment=attachment, @@ -669,7 +678,7 @@ async def test_get_content(attachment, expected_content): @pytest.mark.asyncio -async def test_get_content_with_extraction_service(): +async def test_get_content_with_extraction_service() -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -698,7 +707,7 @@ async def test_get_content_with_extraction_service(): ], ) @patch("connectors.sources.outlook.Account", return_value="account") -async def test_get_user_accounts_for_cloud(account, is_cloud, user_response): +async def test_get_user_accounts_for_cloud(account, is_cloud, user_response) -> None: async with create_outlook_source() as source: source.client.is_cloud = is_cloud source.client._get_user_instance.get_users = AsyncIterator([user_response]) @@ -710,7 +719,7 @@ async def test_get_user_accounts_for_cloud(account, is_cloud, user_response): @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: async with create_outlook_source() as source: source.client._get_user_instance.get_user_accounts = AsyncIterator( [MockAccount()] @@ -724,7 +733,7 @@ async def test_get_docs(): "is_cloud, user_response", [(True, {"value": [{"mail": "dummy.user@gmail.com"}]})], ) -async def test_get_access_control(is_cloud, user_response): +async def test_get_access_control(is_cloud, user_response) -> None: async with create_outlook_source() as source: source.client.is_cloud = is_cloud source.client._get_user_instance.get_users = AsyncIterator([user_response]) @@ -749,7 +758,7 @@ async def test_get_access_control(is_cloud, user_response): }, ], ) -async def test_get_access_control_for_server(user_response): +async def test_get_access_control_for_server(user_response) -> None: async with create_outlook_source() as source: source.configuration.get_field("data_source").value = "outlook_server" source.client._get_user_instance.get_users = AsyncIterator([user_response]) diff --git 
a/tests/sources/test_postgresql.py b/tests/sources/test_postgresql.py index db157106d..a17df2112 100644 --- a/tests/sources/test_postgresql.py +++ b/tests/sources/test_postgresql.py @@ -7,6 +7,7 @@ import ssl from contextlib import asynccontextmanager +from typing import List from unittest.mock import ANY, Mock, patch import pytest @@ -33,7 +34,7 @@ TABLE = "emp_table" CUSTOMER_TABLE = "customer" -TIME = iso_utc() +TIME: str = iso_utc() ID_ONE = "id1" ID_TWO = "id2" @@ -56,7 +57,7 @@ async def create_postgresql_source(): class MockSsl: """This class contains methods which returns dummy ssl context""" - def load_verify_locations(self, cadata): + def load_verify_locations(self, cadata) -> None: """This method verify locations""" pass @@ -68,11 +69,13 @@ async def __aenter__(self): """Make a dummy database connection and return it""" return self - async def __aexit__(self, exception_type, exception_value, exception_traceback): + async def __aexit__( + self, exception_type, exception_value, exception_traceback + ) -> None: """Make sure the dummy database connection gets closed""" pass - async def execute(self, query): + async def execute(self, query) -> "CursorAsync": """This method returns dummy cursor""" return CursorAsync(query=query) @@ -84,12 +87,12 @@ async def __aenter__(self): """Make a dummy database connection and return it""" return self - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: """Setup dummy cursor""" self.query = kw["query"] self.first_call = True - def keys(self): + def keys(self) -> List[str]: """Return Columns of table Returns: @@ -150,12 +153,14 @@ def fetchmany(self, size): ] return [] - async def __aexit__(self, exception_type, exception_value, exception_traceback): + async def __aexit__( + self, exception_type, exception_value, exception_traceback + ) -> None: """Make sure the dummy database connection gets closed""" pass -def test_get_connect_args(): +def test_get_connect_args() -> None: """This function test _get_connect_args with dummy certificate""" # Setup client = PostgreSQLClient( @@ -177,7 +182,7 @@ def test_get_connect_args(): @pytest.mark.asyncio -async def test_postgresql_ping(): +async def test_postgresql_ping() -> None: # Setup async with create_postgresql_source() as source: with patch.object(AsyncEngine, "connect", return_value=ConnectionAsync()): @@ -187,7 +192,7 @@ async def test_postgresql_ping(): @pytest.mark.asyncio -async def test_ping(): +async def test_ping() -> None: async with create_postgresql_source() as source: with patch.object(AsyncEngine, "connect", return_value=ConnectionAsync()): await source.ping() @@ -195,7 +200,7 @@ async def test_ping(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_ping_negative(): +async def test_ping_negative() -> None: with pytest.raises(Exception): async with create_source(PostgreSQLDataSource, port=5432) as source: with patch.object(AsyncEngine, "connect", side_effect=Exception()): @@ -285,7 +290,9 @@ async def test_ping_negative(): ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_source( PostgreSQLDataSource, database="xe", tables="*", schema="public", port=5432 ) as source: @@ -407,7 +414,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res @pytest.mark.asyncio async def 
test_advanced_rules_validation_when_id_in_source_available( advanced_rules, id_in_source, expected_validation_result -): +) -> None: async with create_source( PostgreSQLDataSource, database="xe", tables="*", schema="public", port=5432 ) as source: @@ -421,7 +428,7 @@ async def test_advanced_rules_validation_when_id_in_source_available( @freeze_time(TIME) @pytest.mark.asyncio -async def test_get_docs(): +async def test_get_docs() -> None: # Setup async with create_postgresql_source() as source: with patch.object(AsyncEngine, "connect", return_value=ConnectionAsync()): @@ -552,7 +559,7 @@ async def test_get_docs(): ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering, expected_response): +async def test_get_docs_with_advanced_rules(filtering, expected_response) -> None: async with create_source( PostgreSQLDataSource, database="xe", @@ -569,7 +576,7 @@ async def test_get_docs_with_advanced_rules(filtering, expected_response): assert actual_response == expected_response -def test_table_primary_key_query(): +def test_table_primary_key_query() -> None: """Test that the table_primary_key method generates the correct SQL query""" # Setup queries = PostgreSQLQueries() @@ -594,7 +601,7 @@ def test_table_primary_key_query(): assert query == expected_query -def test_table_primary_key_query_with_special_characters(): +def test_table_primary_key_query_with_special_characters() -> None: """Test that the table_primary_key method handles special characters in schema and table names""" # Setup queries = PostgreSQLQueries() @@ -620,7 +627,7 @@ def test_table_primary_key_query_with_special_characters(): @pytest.mark.asyncio -async def test_get_table_primary_key(): +async def test_get_table_primary_key() -> None: """Test that the get_table_primary_key method correctly processes query results""" # Setup async with create_postgresql_source() as source: diff --git a/tests/sources/test_redis.py b/tests/sources/test_redis.py index 97eb36c33..4554941e4 100644 --- a/tests/sources/test_redis.py +++ b/tests/sources/test_redis.py @@ -5,6 +5,7 @@ # import json from contextlib import asynccontextmanager +from typing import Dict, List, Set from unittest import mock from unittest.mock import ANY, AsyncMock, Mock @@ -38,46 +39,48 @@ class RedisClientMock: - async def execute_command(self, SELECT="JSON.GET", key="json_key"): + async def execute_command( + self, SELECT: str = "JSON.GET", key: str = "json_key" + ) -> str: return json.dumps({"1": "1", "2": "2"}) - async def zrange(self, key, start, skip, withscores=True): + async def zrange(self, key, start, skip, withscores: bool = True) -> Set[int]: return {1, 2, 3} - async def smembers(self, key): + async def smembers(self, key) -> Set[int]: return {1, 2, 3} - async def get(self, key): + async def get(self, key) -> str: return "this is value" - async def hgetall(self, key): + async def hgetall(self, key) -> str: return "hash" - async def xread(self, key): + async def xread(self, key) -> str: return "stream" - async def lrange(self, key, start, skip): + async def lrange(self, key, start, skip) -> List[int]: return [1, 2, 3] - async def config_get(self, databases): + async def config_get(self, databases) -> Dict[str, str]: return {"databases": "1"} - async def ping(self): + async def ping(self) -> bool: return False - async def aclose(self): + async def aclose(self) -> bool: return True - async def type(self, key): # NOQA + async def type(self, key) -> str: # NOQA return "string" - async def memory_usage(self, key): + async def memory_usage(self, 
key) -> int: return 10 async def scan_iter(self, match, count, _type): yield "0" - async def validate_database(self, db=0): + async def validate_database(self, db: int = 0) -> None: await self.execute_command() @@ -95,14 +98,14 @@ async def create_redis_source(): @pytest.mark.asyncio -async def test_ping_positive(): +async def test_ping_positive() -> None: async with create_redis_source() as source: source.client.ping = AsyncMock() await source.ping() @pytest.mark.asyncio -async def test_ping_negative(): +async def test_ping_negative() -> None: async with create_redis_source() as source: mocked_client = AsyncMock() with mock.patch("redis.asyncio.from_url", return_value=mocked_client): @@ -114,7 +117,7 @@ async def test_ping_negative(): @pytest.mark.asyncio -async def test_validate_config_when_database_is_not_integer(): +async def test_validate_config_when_database_is_not_integer() -> None: async with create_redis_source() as source: source.client.database = ["db123", "db456"] with mock.patch("redis.asyncio.from_url", return_value=AsyncMock()): @@ -123,7 +126,7 @@ async def test_validate_config_when_database_is_not_integer(): @pytest.mark.asyncio -async def test_validate_config_with_wrong_configuration(): +async def test_validate_config_with_wrong_configuration() -> None: async with create_redis_source() as source: mocked_client = AsyncMock() with mock.patch("redis.asyncio.from_url", return_value=mocked_client): @@ -135,7 +138,7 @@ async def test_validate_config_with_wrong_configuration(): @pytest.mark.asyncio -async def test_validate_config_when_database_is_invalid(): +async def test_validate_config_when_database_is_invalid() -> None: async with create_redis_source() as source: source.client.database = ["123"] source.client.validate_database = AsyncMock(return_value=False) @@ -146,7 +149,7 @@ async def test_validate_config_when_database_is_invalid(): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19+00:00") -async def test_get_docs(): +async def test_get_docs() -> None: async with create_redis_source() as source: source.client.database = [1] @@ -160,7 +163,7 @@ async def test_get_docs(): @pytest.mark.asyncio -async def test_get_databases_for_multiple_db(): +async def test_get_databases_for_multiple_db() -> None: async with create_redis_source() as source: source.client.database = [1, 2] async for database in source.client.get_databases(): @@ -168,7 +171,7 @@ async def test_get_databases_for_multiple_db(): @pytest.mark.asyncio -async def test_get_databases_with_asterisk(): +async def test_get_databases_with_asterisk() -> None: async with create_redis_source() as source: source.client.database = ["*"] source.client._client = RedisClientMock() @@ -177,7 +180,7 @@ async def test_get_databases_with_asterisk(): @pytest.mark.asyncio -async def test_get_databases_expect_no_databases_on_auth_error(): +async def test_get_databases_expect_no_databases_on_auth_error() -> None: async with create_redis_source() as source: source.client.database = ["*"] mocked_client = AsyncMock() @@ -202,7 +205,7 @@ async def test_get_databases_expect_no_databases_on_auth_error(): ("stream_key", "stream", "stream"), ], ) -async def test_get_key_value(key, key_type, expected_response): +async def test_get_key_value(key, key_type, expected_response) -> None: async with create_redis_source() as source: source.client.database = ["*"] source.client._client = RedisClientMock() @@ -212,7 +215,7 @@ async def test_get_key_value(key, key_type, expected_response): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19+00:00") 
-async def test_get_key_metadata(): +async def test_get_key_metadata() -> None: async with create_redis_source() as source: source.client._client = RedisClientMock() key_type, value, size = await source.client.get_key_metadata(key="0") @@ -223,7 +226,7 @@ async def test_get_key_metadata(): @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19+00:00") -async def test_get_db_records(): +async def test_get_db_records() -> None: async with create_redis_source() as source: source.client._client = RedisClientMock() source.client.get_paginated_key = AsyncIterator(["0"]) @@ -247,7 +250,7 @@ async def test_get_db_records(): ) @pytest.mark.asyncio @freeze_time("2023-01-24T04:07:19+00:00") -async def test_get_docs_with_sync_rules(filtering): +async def test_get_docs_with_sync_rules(filtering) -> None: async with create_redis_source() as source: source.client.database = ["*"] source.client._client = Mock() @@ -336,7 +339,9 @@ async def test_get_docs_with_sync_rules(filtering): ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_redis_source() as source: source.client._client = RedisClientMock() validation_result = await RedisAdvancedRulesValidator(source).validate( @@ -346,7 +351,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res @pytest.mark.asyncio -async def test_client_when_mutual_ssl_enabled(): +async def test_client_when_mutual_ssl_enabled() -> None: async with create_redis_source() as source: source.client.database = ["*"] source.client.ssl_enabled = True @@ -362,7 +367,7 @@ async def test_client_when_mutual_ssl_enabled(): @pytest.mark.asyncio -async def test_client_when_ssl_enabled(): +async def test_client_when_ssl_enabled() -> None: async with create_redis_source() as source: source.client.database = ["*"] source.client.ssl_enabled = True @@ -376,7 +381,7 @@ async def test_client_when_ssl_enabled(): @pytest.mark.asyncio -async def test_ping_when_mutual_ssl_enabled(): +async def test_ping_when_mutual_ssl_enabled() -> None: async with create_redis_source() as source: source.client.database = ["*"] source.client.ssl_enabled = True diff --git a/tests/sources/test_s3.py b/tests/sources/test_s3.py index 299a647e4..9f27782e3 100644 --- a/tests/sources/test_s3.py +++ b/tests/sources/test_s3.py @@ -5,6 +5,7 @@ # from contextlib import asynccontextmanager from datetime import datetime +from typing import Dict from unittest import mock from unittest.mock import ANY, AsyncMock, MagicMock, patch @@ -23,7 +24,7 @@ @asynccontextmanager -async def create_s3_source(use_text_extraction_service=False): +async def create_s3_source(use_text_extraction_service: bool = False): async with create_source( S3DataSource, buckets="ent-search-ingest-dev", @@ -37,7 +38,7 @@ async def create_s3_source(use_text_extraction_service=False): class Summary: """This class is used to initialize file summary""" - def __init__(self, key): + def __init__(self, key) -> None: """Setup key of file object Args: @@ -46,7 +47,7 @@ def __init__(self, key): self.key = key @property - async def size(self): + async def size(self) -> int: """Set size of file Returns: @@ -55,7 +56,7 @@ async def size(self): return 12 @property - async def last_modified(self): + async def last_modified(self) -> datetime: """Set last_modified time Returns: @@ -64,7 +65,7 @@ async def last_modified(self): return datetime.now() @property - async def 
storage_class(self): + async def storage_class(self) -> str: """Set storage_class of file Returns: @@ -73,7 +74,7 @@ async def storage_class(self): return "STANDARD" @property - async def owner(self): + async def owner(self) -> Dict[str, str]: """Set owner of file Returns: @@ -85,7 +86,7 @@ async def owner(self): class AIOResourceCollection: """Class for mock AIOResourceCollection""" - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: """Setup AIOResourceCollection""" pass @@ -112,7 +113,7 @@ def __aiter__(self): class S3Object(dict): """Class for mock S3 object""" - def __init__(self, *args, **kw): + def __init__(self, *args, **kw) -> None: """Setup document of Mock object""" self.meta = mock.MagicMock() self["Body"] = self @@ -123,7 +124,7 @@ def __init__(self, *args, **kw): ] self.called = False - async def read(self, *args): + async def read(self, *args) -> bytes: """Method returns object content Returns: @@ -138,7 +139,7 @@ async def __aenter__(self): """Make a dummy connection and return it""" return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """Make sure the dummy connection gets closed""" pass @@ -153,7 +154,7 @@ async def create_fake_coroutine(data): @pytest.mark.asyncio -async def test_ping(): +async def test_ping() -> None: """Test ping method of S3DataSource class""" async with create_s3_source() as source: with ( @@ -167,7 +168,7 @@ async def test_ping(): @pytest.mark.asyncio -async def test_ping_negative(): +async def test_ping_negative() -> None: """Test ping method of S3DataSource class with negative case""" async with create_s3_source() as source: with mock.patch.object( @@ -178,7 +179,7 @@ async def test_ping_negative(): @pytest.mark.asyncio -async def test_get_bucket_region(): +async def test_get_bucket_region() -> None: """Test get_bucket_region method of S3DataSource""" async with create_s3_source() as source: with mock.patch("aiobotocore.client.AioBaseClient", S3Object): @@ -188,7 +189,7 @@ async def test_get_bucket_region(): @pytest.mark.asyncio -async def test_get_bucket_region_negative(): +async def test_get_bucket_region_negative() -> None: """Test get_bucket_region method of S3DataSource for negative case""" async with create_s3_source() as source: with mock.patch.object( @@ -199,16 +200,16 @@ async def test_get_bucket_region_negative(): class ReadAsyncMock(AsyncMock): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - async def read(): + async def read() -> bytes: return b"test content" @mock.patch("aiobotocore.client.AioBaseClient") @pytest.mark.asyncio -async def test_get_content(s3_client): +async def test_get_content(s3_client: MagicMock) -> None: """Test get_content method of S3Client""" async with create_s3_source() as source: @@ -240,7 +241,9 @@ async def test_get_content(s3_client): @mock.patch("aiobotocore.client.AioBaseClient") @pytest.mark.asyncio -async def test_get_content_with_text_extraction_enabled_adds_body(s3_client): +async def test_get_content_with_text_extraction_enabled_adds_body( + s3_client: MagicMock, +) -> None: """Test get_content method of S3Client""" with ( patch( @@ -278,7 +281,7 @@ async def test_get_content_with_text_extraction_enabled_adds_body(s3_client): @mock.patch("aiobotocore.client.AioBaseClient") @pytest.mark.asyncio -async def test_get_content_with_upper_extension(s3_client): +async def test_get_content_with_upper_extension(s3_client: MagicMock) 
-> None: """Test get_content method of S3Client""" async with create_s3_source() as source: @@ -309,7 +312,7 @@ async def test_get_content_with_upper_extension(s3_client): @pytest.mark.asyncio -async def test_get_content_with_unsupported_file(mock_aws): +async def test_get_content_with_unsupported_file(mock_aws) -> None: """Test get_content method of S3Client for unsupported file""" async with create_s3_source() as source: with mock.patch("aiobotocore.client.AioBaseClient", S3Object): @@ -322,7 +325,7 @@ async def test_get_content_with_unsupported_file(mock_aws): @pytest.mark.asyncio -async def test_get_content_when_not_doit(mock_aws): +async def test_get_content_when_not_doit(mock_aws) -> None: """Test get_content method of S3Client when doit is none""" async with create_s3_source() as source: with mock.patch("aiobotocore.client.AioBaseClient", S3Object): @@ -333,7 +336,7 @@ async def test_get_content_when_not_doit(mock_aws): @pytest.mark.asyncio -async def test_get_content_when_size_is_large(mock_aws): +async def test_get_content_when_size_is_large(mock_aws) -> None: """Test get_content method of S3Client when size is greater than max size""" async with create_s3_source() as source: with mock.patch("aiobotocore.client.AioBaseClient", S3Object): @@ -355,7 +358,7 @@ async def get_roles(*args): @pytest.mark.asyncio -async def test_get_docs(mock_aws): +async def test_get_docs(mock_aws) -> None: """Test get_docs method of S3DataSource""" async with create_s3_source() as source: source.s3_client.get_bucket_location = mock.Mock( @@ -399,7 +402,7 @@ async def test_get_docs(mock_aws): ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering): +async def test_get_docs_with_advanced_rules(filtering) -> None: async with create_s3_source() as source: source.s3_client.get_bucket_location = mock.Mock( return_value=await create_fake_coroutine("ap-south-1") @@ -427,7 +430,7 @@ async def test_get_docs_with_advanced_rules(filtering): @pytest.mark.asyncio -async def test_get_bucket_list(): +async def test_get_bucket_list() -> None: """Test get_bucket_list method of S3Client""" async with create_s3_source() as source: source.s3_client.bucket_list = [] @@ -444,7 +447,7 @@ async def test_get_bucket_list(): @pytest.mark.asyncio -async def test_get_bucket_list_for_wildcard(): +async def test_get_bucket_list_for_wildcard() -> None: async with create_s3_source() as source: source.configuration.get_field("buckets").value = ["*"] @@ -461,7 +464,7 @@ async def test_get_bucket_list_for_wildcard(): @pytest.mark.asyncio -async def test_validate_config_for_empty_bucket_string(): +async def test_validate_config_for_empty_bucket_string() -> None: """This function test validate_configwhen buckets string is empty""" async with create_s3_source() as source: source.configuration.get_field("buckets").value = [""] @@ -472,7 +475,7 @@ async def test_validate_config_for_empty_bucket_string(): @pytest.mark.asyncio -async def test_get_content_with_clienterror(): +async def test_get_content_with_clienterror() -> None: """Test get_content method of S3Client for client error""" async with create_s3_source() as source: document = { @@ -493,7 +496,7 @@ async def test_get_content_with_clienterror(): @pytest.mark.asyncio -async def test_close_with_client_session(): +async def test_close_with_client_session() -> None: """Test close method of S3DataSource with client session""" async with create_s3_source() as source: @@ -557,7 +560,9 @@ async def test_close_with_client_session(): ], ) @pytest.mark.asyncio -async def 
test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_source(S3DataSource) as source: validation_result = await S3AdvancedRulesValidator(source).validate( advanced_rules diff --git a/tests/sources/test_salesforce.py b/tests/sources/test_salesforce.py index aa9615b6e..45e3fdf9d 100644 --- a/tests/sources/test_salesforce.py +++ b/tests/sources/test_salesforce.py @@ -8,11 +8,12 @@ import re from contextlib import asynccontextmanager from copy import deepcopy +from typing import Dict, List, Union from unittest import TestCase, mock from unittest.mock import AsyncMock, MagicMock, patch import pytest -from aioresponses import CallbackResult +from aioresponses.core import CallbackResult from connectors.access_control import DLS_QUERY from connectors.protocol import Filter @@ -39,7 +40,7 @@ CONTENT_VERSION_ID = "content_version_id" TEST_BASE_URL = f"https://{TEST_DOMAIN}.my.salesforce.com" TEST_FILE_DOWNLOAD_URL = f"{TEST_BASE_URL}/services/data/{API_VERSION}/sobjects/ContentVersion/{CONTENT_VERSION_ID}/VersionData" -TEST_QUERY_MATCH_URL = re.compile( +TEST_QUERY_MATCH_URL: re.Pattern[str] = re.compile( f"{TEST_BASE_URL}/services/data/{API_VERSION}/(tooling|query)*" ) TEST_CLIENT_ID = "1234" @@ -154,7 +155,7 @@ ], } -LEAD_RESPONSE_PAYLOAD = { +LEAD_RESPONSE_PAYLOAD: Dict[str, List[Dict[str, Union[None, Dict[str, str], str]]]] = { "records": [ { "attributes": { @@ -317,7 +318,35 @@ ] } -CASE_FEED_RESPONSE_PAYLOAD = { +CASE_FEED_RESPONSE_PAYLOAD: Dict[ + str, + List[ + Dict[ + str, + Union[ + None, + Dict[ + str, + List[ + Dict[ + str, + Union[ + Dict[str, str], + Dict[str, Union[Dict[str, str], str]], + bool, + str, + ], + ] + ], + ], + Dict[str, str], + Dict[str, Union[Dict[str, str], str]], + int, + str, + ], + ] + ], +] = { "records": [ { "attributes": { @@ -371,7 +400,24 @@ ] } -CONTENT_DOCUMENT_LINKS_PAYLOAD = { +CONTENT_DOCUMENT_LINKS_PAYLOAD: Dict[ + str, + List[ + Dict[ + str, + Union[ + Dict[str, str], + Dict[ + str, + Union[ + Dict[str, str], Dict[str, Union[Dict[str, str], str]], int, str + ], + ], + str, + ], + ] + ], +] = { "records": [ { "attributes": { @@ -455,7 +501,40 @@ ], } -SOSL_RESPONSE_PAYLOAD = { +SOSL_RESPONSE_PAYLOAD: Dict[ + str, + List[ + Dict[ + str, + Union[ + Dict[ + str, + List[ + Dict[ + str, + Union[ + Dict[str, str], + Dict[ + str, + Union[ + Dict[str, str], + Dict[str, Union[Dict[str, str], str]], + int, + str, + ], + ], + str, + ], + ] + ], + ], + Dict[str, str], + Dict[str, Union[Dict[str, str], str]], + str, + ], + ] + ], +] = { "searchRecords": [ { "attributes": { @@ -652,7 +731,9 @@ @asynccontextmanager async def create_salesforce_source( - use_text_extraction_service=False, mock_token=True, mock_queryables=True + use_text_extraction_service: bool = False, + mock_token: bool = True, + mock_queryables: bool = True, ): async with create_source( SalesforceDataSource, @@ -680,7 +761,7 @@ async def create_salesforce_source( yield source -def salesforce_query_callback(url, **kwargs): +def salesforce_query_callback(url, **kwargs) -> CallbackResult: """Dynamically returns a payload based on query and adds ContentDocumentLinks to each payload """ @@ -749,7 +830,7 @@ def generate_account_doc(identifier): } -def test_get_default_configuration(): +def test_get_default_configuration() -> None: config = DataSourceConfiguration(SalesforceDataSource.get_default_configuration()) expected_fields = [ "client_id", @@ -799,7 +880,7 @@ def 
test_get_default_configuration(): ) async def test_validate_config_missing_fields_then_raise( domain, client_id, client_secret, standard_objects_to_sync, sync_custom_objects -): +) -> None: async with create_source( SalesforceDataSource, domain=domain, @@ -813,7 +894,7 @@ async def test_validate_config_missing_fields_then_raise( @pytest.mark.asyncio -async def test_ping_with_successful_connection(mock_responses): +async def test_ping_with_successful_connection(mock_responses) -> None: async with create_salesforce_source() as source: mock_responses.head(TEST_BASE_URL, status=200) @@ -821,7 +902,7 @@ async def test_ping_with_successful_connection(mock_responses): @pytest.mark.asyncio -async def test_generate_token_with_successful_connection(mock_responses): +async def test_generate_token_with_successful_connection(mock_responses) -> None: async with create_salesforce_source() as source: response_payload = { "access_token": "foo", @@ -842,7 +923,7 @@ async def test_generate_token_with_successful_connection(mock_responses): @pytest.mark.asyncio async def test_generate_token_with_bad_domain_raises_error( patch_sleep, mock_responses, patch_cancellable_sleeps -): +) -> None: async with create_salesforce_source(mock_token=False) as source: mock_responses.post( f"{TEST_BASE_URL}/services/oauth2/token", status=500, repeat=True @@ -854,7 +935,7 @@ async def test_generate_token_with_bad_domain_raises_error( @pytest.mark.asyncio async def test_generate_token_with_bad_credentials_raises_error( patch_sleep, mock_responses, patch_cancellable_sleeps -): +) -> None: async with create_salesforce_source(mock_token=False) as source: mock_responses.post( f"{TEST_BASE_URL}/services/oauth2/token", @@ -876,7 +957,7 @@ async def test_generate_token_with_bad_credentials_raises_error( @pytest.mark.asyncio async def test_generate_token_with_unexpected_error_retries( patch_sleep, mock_responses, patch_cancellable_sleeps -): +) -> None: async with create_salesforce_source() as source: response_payload = { "access_token": "foo", @@ -913,7 +994,7 @@ async def test_generate_token_with_unexpected_error_retries( "connectors.sources.salesforce.RELEVANT_SOBJECTS", ["FooField", "BarField", "ArghField"], ) -async def test_get_queryable_sobjects(mock_responses, sobject, expected_result): +async def test_get_queryable_sobjects(mock_responses, sobject, expected_result) -> None: async with create_salesforce_source(mock_queryables=False) as source: response_payload = { "sobjects": [ @@ -944,7 +1025,7 @@ async def test_get_queryable_sobjects(mock_responses, sobject, expected_result): "connectors.sources.salesforce.RELEVANT_SOBJECT_FIELDS", ["FooField", "BarField", "ArghField"], ) -async def test_get_queryable_fields(mock_responses): +async def test_get_queryable_fields(mock_responses) -> None: async with create_salesforce_source(mock_queryables=False) as source: expected_fields = [ { @@ -971,7 +1052,7 @@ async def test_get_queryable_fields(mock_responses): @pytest.mark.asyncio -async def test_execute_non_paginated_query(mock_responses): +async def test_execute_non_paginated_query(mock_responses) -> None: async with create_salesforce_source() as source: with mock.patch.object( source.salesforce_client, "_get_json", return_value=ACCOUNT_RESPONSE_PAYLOAD @@ -983,7 +1064,7 @@ async def test_execute_non_paginated_query(mock_responses): @pytest.mark.asyncio -async def test_get_accounts_when_success(mock_responses): +async def test_get_accounts_when_success(mock_responses) -> None: async with create_salesforce_source() as source: payload = 
deepcopy(ACCOUNT_RESPONSE_PAYLOAD) expected_record = payload["records"][0] @@ -1048,7 +1129,7 @@ async def test_get_accounts_when_success(mock_responses): @pytest.mark.asyncio -async def test_get_accounts_when_paginated_yields_all_pages(mock_responses): +async def test_get_accounts_when_paginated_yields_all_pages(mock_responses) -> None: async with create_salesforce_source() as source: response_page_1 = { "done": False, @@ -1087,7 +1168,7 @@ async def test_get_accounts_when_paginated_yields_all_pages(mock_responses): @pytest.mark.asyncio -async def test_get_accounts_when_invalid_request(patch_sleep, mock_responses): +async def test_get_accounts_when_invalid_request(patch_sleep, mock_responses) -> None: async with create_salesforce_source(mock_queryables=False) as source: response_payload = [ {"message": "Unable to process query.", "errorCode": "INVALID_FIELD"} @@ -1105,7 +1186,7 @@ async def test_get_accounts_when_invalid_request(patch_sleep, mock_responses): @pytest.mark.asyncio -async def test_get_accounts_when_not_queryable_yields_nothing(mock_responses): +async def test_get_accounts_when_not_queryable_yields_nothing(mock_responses) -> None: async with create_salesforce_source() as source: source.salesforce_client._is_queryable = mock.AsyncMock(return_value=False) async for record in source.salesforce_client.get_accounts(): @@ -1113,7 +1194,7 @@ async def test_get_accounts_when_not_queryable_yields_nothing(mock_responses): @pytest.mark.asyncio -async def test_get_contacts_when_not_queryable_yields_nothing(mock_responses): +async def test_get_contacts_when_not_queryable_yields_nothing(mock_responses) -> None: async with create_salesforce_source() as source: source.salesforce_client._is_queryable = mock.AsyncMock(return_value=False) async for record in source.salesforce_client.get_contacts(): @@ -1121,7 +1202,7 @@ async def test_get_contacts_when_not_queryable_yields_nothing(mock_responses): @pytest.mark.asyncio -async def test_get_leads_when_not_queryable_yields_nothing(mock_responses): +async def test_get_leads_when_not_queryable_yields_nothing(mock_responses) -> None: async with create_salesforce_source() as source: source.salesforce_client._is_queryable = mock.AsyncMock(return_value=False) async for record in source.salesforce_client.get_leads(): @@ -1129,7 +1210,9 @@ async def test_get_leads_when_not_queryable_yields_nothing(mock_responses): @pytest.mark.asyncio -async def test_get_opportunities_when_not_queryable_yields_nothing(mock_responses): +async def test_get_opportunities_when_not_queryable_yields_nothing( + mock_responses, +) -> None: async with create_salesforce_source() as source: source.salesforce_client._is_queryable = mock.AsyncMock(return_value=False) async for record in source.salesforce_client.get_opportunities(): @@ -1137,7 +1220,7 @@ async def test_get_opportunities_when_not_queryable_yields_nothing(mock_response @pytest.mark.asyncio -async def test_get_campaigns_when_not_queryable_yields_nothing(mock_responses): +async def test_get_campaigns_when_not_queryable_yields_nothing(mock_responses) -> None: async with create_salesforce_source() as source: source.salesforce_client._is_queryable = mock.AsyncMock(return_value=False) async for record in source.salesforce_client.get_campaigns(): @@ -1145,7 +1228,7 @@ async def test_get_campaigns_when_not_queryable_yields_nothing(mock_responses): @pytest.mark.asyncio -async def test_get_cases_when_not_queryable_yields_nothing(mock_responses): +async def test_get_cases_when_not_queryable_yields_nothing(mock_responses) -> None: 
async with create_salesforce_source() as source: source.salesforce_client._is_queryable = mock.AsyncMock(return_value=False) async for record in source.salesforce_client.get_cases(): @@ -1153,7 +1236,7 @@ async def test_get_cases_when_not_queryable_yields_nothing(mock_responses): @pytest.mark.asyncio -async def test_get_opportunities_when_success(mock_responses): +async def test_get_opportunities_when_success(mock_responses) -> None: async with create_salesforce_source() as source: expected_doc = { "_id": "opportunity_id", @@ -1191,7 +1274,7 @@ async def test_get_opportunities_when_success(mock_responses): @pytest.mark.asyncio -async def test_get_contacts_when_success(mock_responses): +async def test_get_contacts_when_success(mock_responses) -> None: async with create_salesforce_source() as source: payload = deepcopy(CONTACT_RESPONSE_PAYLOAD) expected_record = payload["records"][0] @@ -1240,7 +1323,7 @@ async def test_get_contacts_when_success(mock_responses): @pytest.mark.asyncio -async def test_get_leads_when_success(mock_responses): +async def test_get_leads_when_success(mock_responses) -> None: async with create_salesforce_source() as source: payload = deepcopy(LEAD_RESPONSE_PAYLOAD) expected_record = payload["records"][0] @@ -1295,7 +1378,7 @@ async def test_get_leads_when_success(mock_responses): @pytest.mark.asyncio -async def test_get_campaigns_when_success(mock_responses): +async def test_get_campaigns_when_success(mock_responses) -> None: async with create_salesforce_source() as source: expected_doc = { "_id": "campaign_id", @@ -1344,7 +1427,7 @@ async def test_get_campaigns_when_success(mock_responses): @pytest.mark.asyncio -async def test_get_cases_when_success(mock_responses): +async def test_get_cases_when_success(mock_responses) -> None: async with create_salesforce_source() as source: payload = deepcopy(CASE_RESPONSE_PAYLOAD) expected_record = payload["records"][0] @@ -1525,8 +1608,11 @@ async def test_get_cases_when_success(mock_responses): ], ) async def test_get_all_with_content_docs_when_success( - mock_responses, response_status, response_body, expected_attachment -): + mock_responses, + response_status, + response_body, + expected_attachment: Union[List[str], int, str], +) -> None: async with create_salesforce_source() as source: expected_doc = { "_id": "content_document_id", @@ -1580,7 +1666,7 @@ async def test_get_all_with_content_docs_when_success( @pytest.mark.asyncio -async def test_get_all_with_content_docs_and_extraction_service(mock_responses): +async def test_get_all_with_content_docs_and_extraction_service(mock_responses) -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -1661,7 +1747,7 @@ async def test_get_all_with_content_docs_and_extraction_service(mock_responses): ), ], ) -async def test_modify_query(soql_query, modified_query): +async def test_modify_query(soql_query, modified_query) -> None: async with create_salesforce_source() as source: query = source.salesforce_client.modify_soql_query(soql_query) assert query == modified_query @@ -1681,7 +1767,7 @@ async def test_modify_query(soql_query, modified_query): ), ], ) -async def test_add_last_modified_date(soql_query, modified_query): +async def test_add_last_modified_date(soql_query, modified_query) -> None: async with create_salesforce_source() as source: query = source.salesforce_client._add_last_modified_date(soql_query) assert query == modified_query @@ -1698,7 +1784,7 @@ async def test_add_last_modified_date(soql_query, modified_query): ), ], ) -async def 
test_add_id(soql_query, modified_query): +async def test_add_id(soql_query, modified_query) -> None: async with create_salesforce_source() as source: query = source.salesforce_client._add_id(soql_query) assert query == modified_query @@ -1740,7 +1826,9 @@ async def test_add_id(soql_query, modified_query): ], ) @pytest.mark.asyncio -async def test_get_docs_for_soql_query(mock_responses, filtering, expected_docs): +async def test_get_docs_for_soql_query( + mock_responses, filtering, expected_docs +) -> None: async with create_salesforce_source() as source: mock_responses.get( TEST_FILE_DOWNLOAD_URL, @@ -1776,7 +1864,7 @@ async def test_get_docs_for_soql_query(mock_responses, filtering, expected_docs) ], ) @pytest.mark.asyncio -async def test_get_docs_for_sosl_query(mock_responses, filtering): +async def test_get_docs_for_sosl_query(mock_responses, filtering) -> None: async with create_salesforce_source() as source: mock_responses.get( TEST_FILE_DOWNLOAD_URL, @@ -1798,7 +1886,7 @@ async def test_get_docs_for_sosl_query(mock_responses, filtering): @pytest.mark.asyncio -async def test_remote_validation(mock_responses): +async def test_remote_validation(mock_responses) -> None: async with create_salesforce_source() as source: filtering = [{"query": "SELECT Id, Name FROM Account", "language": "SOQL"}] mock_responses.get( @@ -1816,7 +1904,7 @@ async def test_remote_validation(mock_responses): @pytest.mark.asyncio -async def test_remote_validation_negative(): +async def test_remote_validation_negative() -> None: async with create_salesforce_source() as source: filtering = [ { @@ -1831,7 +1919,7 @@ async def test_remote_validation_negative(): @pytest.mark.asyncio -async def test_prepare_sobject_cache(mock_responses): +async def test_prepare_sobject_cache(mock_responses) -> None: async with create_salesforce_source() as source: sobjects = { "records": [ @@ -1853,7 +1941,9 @@ async def test_prepare_sobject_cache(mock_responses): @pytest.mark.asyncio -async def test_request_when_token_invalid_refetches_token(patch_sleep, mock_responses): +async def test_request_when_token_invalid_refetches_token( + patch_sleep, mock_responses +) -> None: async with create_salesforce_source(mock_token=False) as source: payload = deepcopy(ACCOUNT_RESPONSE_PAYLOAD) expected_record = payload["records"][0] @@ -1894,7 +1984,9 @@ async def test_request_when_token_invalid_refetches_token(patch_sleep, mock_resp @pytest.mark.asyncio -async def test_request_when_rate_limited_raises_error_no_retries(mock_responses): +async def test_request_when_rate_limited_raises_error_no_retries( + mock_responses, +) -> None: async with create_salesforce_source() as source: response_payload = [ { @@ -1924,7 +2016,7 @@ async def test_request_when_rate_limited_raises_error_no_retries(mock_responses) ) async def test_request_when_invalid_query_raises_error_no_retries( mock_responses, error_code -): +) -> None: async with create_salesforce_source() as source: response_payload = [ { @@ -1946,7 +2038,7 @@ async def test_request_when_invalid_query_raises_error_no_retries( @pytest.mark.asyncio async def test_request_when_generic_400_raises_error_with_retries( patch_sleep, mock_responses -): +) -> None: async with create_salesforce_source() as source: mock_responses.get( TEST_QUERY_MATCH_URL, @@ -1962,7 +2054,7 @@ async def test_request_when_generic_400_raises_error_with_retries( @pytest.mark.asyncio async def test_request_when_generic_500_raises_error_with_retries( patch_sleep, mock_responses -): +) -> None: async with create_salesforce_source() as 
source: mock_responses.get( TEST_QUERY_MATCH_URL, @@ -1976,7 +2068,7 @@ async def test_request_when_generic_500_raises_error_with_retries( @pytest.mark.asyncio -async def test_build_soql_query_with_fields(): +async def test_build_soql_query_with_fields() -> None: expected_columns = [ "Id", "CreatedDate", @@ -2017,7 +2109,7 @@ async def test_build_soql_query_with_fields(): @pytest.mark.asyncio -async def test_combine_duplicate_content_docs_with_duplicates(): +async def test_combine_duplicate_content_docs_with_duplicates() -> None: async with create_salesforce_source(mock_queryables=False) as source: content_docs = [ { @@ -2044,25 +2136,25 @@ async def test_combine_duplicate_content_docs_with_duplicates(): "user, result", [("Alex Wilber", "user:Alex Wilber"), ("", None)], ) -async def test_prefix_user(user, result): +async def test_prefix_user(user, result) -> None: prefixed_user = _prefix_user(user=user) assert prefixed_user == result @pytest.mark.asyncio -async def test_prefix_user_id(): +async def test_prefix_user_id() -> None: prefixed_user_id = _prefix_user_id(user_id="ae34fad12") assert prefixed_user_id == "user_id:ae34fad12" @pytest.mark.asyncio -async def test_prefix_email(): +async def test_prefix_email() -> None: prefixed_email = _prefix_email(email="alex.wilber@gmail.com") assert prefixed_email == "email:alex.wilber@gmail.com" @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_salesforce_source() as source: source._dls_enabled = MagicMock(return_value=False) @@ -2074,7 +2166,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio -async def test_get_access_control_dls_enabled(mock_responses): +async def test_get_access_control_dls_enabled(mock_responses) -> None: expected_user_doc = { "_id": "user_id", "identity": { @@ -2105,7 +2197,7 @@ async def test_get_access_control_dls_enabled(mock_responses): @pytest.mark.asyncio -async def test_get_docs_with_dls_enabled(mock_responses): +async def test_get_docs_with_dls_enabled(mock_responses) -> None: async with create_salesforce_source() as source: source._dls_enabled = MagicMock(return_value=True) source.salesforce_client._custom_objects = AsyncMock( @@ -2125,7 +2217,7 @@ async def test_get_docs_with_dls_enabled(mock_responses): @pytest.mark.asyncio -async def test_get_docs_with_configured_list_of_sobjects(mock_responses): +async def test_get_docs_with_configured_list_of_sobjects(mock_responses) -> None: async with create_salesforce_source() as source: source.salesforce_client.standard_objects_to_sync = ["Account", "Contact"] source.salesforce_client.sync_custom_objects = False @@ -2149,7 +2241,7 @@ async def test_get_docs_with_configured_list_of_sobjects(mock_responses): @pytest.mark.asyncio -async def test_get_docs_sync_custom_objects(mock_responses): +async def test_get_docs_sync_custom_objects(mock_responses) -> None: async with create_salesforce_source() as source: source.salesforce_client._custom_objects = AsyncMock( return_value=["CustomObject", "Connector__c"] @@ -2174,7 +2266,9 @@ async def test_get_docs_sync_custom_objects(mock_responses): @pytest.mark.asyncio -async def test_queryable_sobject_fields_performance_optimization(mock_responses): +async def test_queryable_sobject_fields_performance_optimization( + mock_responses, +) -> None: """ Test the performance optimization that reduces API calls from O(n*14) to O(14) diff --git a/tests/sources/test_sandfly.py b/tests/sources/test_sandfly.py index 
0df4db4bf..dfe0d0acd 100644 --- a/tests/sources/test_sandfly.py +++ b/tests/sources/test_sandfly.py @@ -4,6 +4,7 @@ # you may not use this file except in compliance with the Elastic License 2.0. # from datetime import datetime +from typing import Dict, List, Optional, Union from unittest.mock import Mock, patch import pytest @@ -29,15 +30,15 @@ from tests.sources.support import create_source SANDFLY_SERVER_URL = "https://blackbird.sandflysecurity.com/v4" -URL_SANDFLY_LOGIN = SANDFLY_SERVER_URL + "/auth/login" -URL_SANDFLY_LICENSE = SANDFLY_SERVER_URL + "/license" -URL_SANDFLY_HOSTS = SANDFLY_SERVER_URL + "/hosts" -URL_SANDFLY_SSH_SUMMARY = SANDFLY_SERVER_URL + "/sshhunter/summary" -URL_SANDFLY_SSH_KEY1 = SANDFLY_SERVER_URL + "/sshhunter/key/1" -URL_SANDFLY_SSH_KEY2 = SANDFLY_SERVER_URL + "/sshhunter/key/2" -URL_SANDFLY_RESULTS = SANDFLY_SERVER_URL + "/results" - -configuration = { +URL_SANDFLY_LOGIN: str = SANDFLY_SERVER_URL + "/auth/login" +URL_SANDFLY_LICENSE: str = SANDFLY_SERVER_URL + "/license" +URL_SANDFLY_HOSTS: str = SANDFLY_SERVER_URL + "/hosts" +URL_SANDFLY_SSH_SUMMARY: str = SANDFLY_SERVER_URL + "/sshhunter/summary" +URL_SANDFLY_SSH_KEY1: str = SANDFLY_SERVER_URL + "/sshhunter/key/1" +URL_SANDFLY_SSH_KEY2: str = SANDFLY_SERVER_URL + "/sshhunter/key/2" +URL_SANDFLY_RESULTS: str = SANDFLY_SERVER_URL + "/results" + +configuration: Dict[str, Union[int, str]] = { "server_url": SANDFLY_SERVER_URL, "username": "elastic_api_user", "password": "elastic_api_password@@", @@ -76,7 +77,15 @@ } # Hosts Response Data -HOSTS_RESPONSE_DATA = { +HOSTS_RESPONSE_DATA: Dict[ + str, + List[ + Union[ + Dict[str, Optional[str]], + Dict[str, Union[Dict[str, Dict[str, Dict[str, str]]], str]], + ] + ], +] = { "data": [ { "host_id": "1001", @@ -171,7 +180,7 @@ async def sandfly_data_source(): @pytest.mark.asyncio -async def test_sandfly_date(sandfly_client, mock_responses): +async def test_sandfly_date(sandfly_client, mock_responses) -> None: expiry = "2025-06-23T17:35:23Z" expiry_date = extract_sandfly_date(expiry) assert type(expiry_date) is datetime @@ -187,7 +196,7 @@ async def test_sandfly_date(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_client_ping_success(sandfly_client, mock_responses): +async def test_client_ping_success(sandfly_client, mock_responses) -> None: mock_responses.head( SANDFLY_SERVER_URL, status=401, # Error code 401 Unauthorized means server is running @@ -197,7 +206,7 @@ async def test_client_ping_success(sandfly_client, mock_responses): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_client_ping_failure(sandfly_client, mock_responses): +async def test_client_ping_failure(sandfly_client, mock_responses) -> None: request_error = ClientResponseError(None, None) request_error.status = 403 request_error.message = "Forbidden" @@ -212,7 +221,7 @@ async def test_client_ping_failure(sandfly_client, mock_responses): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_client_login_failures(sandfly_client, mock_responses): +async def test_client_login_failures(sandfly_client, mock_responses) -> None: request_error = FetchTokenError(None, None) request_error.status = 403 request_error.message = "Forbidden" @@ -244,7 +253,7 @@ async def test_client_login_failures(sandfly_client, mock_responses): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def 
test_client_resource_not_found(sandfly_client, mock_responses): +async def test_client_resource_not_found(sandfly_client, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -276,7 +285,7 @@ async def test_client_resource_not_found(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_client_get_license(sandfly_client, mock_responses): +async def test_client_get_license(sandfly_client, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -297,7 +306,7 @@ async def test_client_get_license(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_client_get_hosts(sandfly_client, mock_responses): +async def test_client_get_hosts(sandfly_client, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -322,7 +331,7 @@ async def test_client_get_hosts(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_client_get_ssh_keys(sandfly_client, mock_responses): +async def test_client_get_ssh_keys(sandfly_client, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -356,7 +365,7 @@ async def test_client_get_ssh_keys(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_client_get_results_by_time(sandfly_client, mock_responses): +async def test_client_get_results_by_time(sandfly_client, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -396,7 +405,7 @@ async def test_client_get_results_by_time(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_client_get_results_by_id(sandfly_client, mock_responses): +async def test_client_get_results_by_id(sandfly_client, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -439,7 +448,7 @@ async def test_client_get_results_by_id(sandfly_client, mock_responses): @pytest.mark.asyncio -async def test_data_source_ping_success(sandfly_data_source, mock_responses): +async def test_data_source_ping_success(sandfly_data_source, mock_responses) -> None: mock_responses.head( SANDFLY_SERVER_URL, status=401, # Error code 401 Unauthorized means server is running @@ -449,7 +458,7 @@ async def test_data_source_ping_success(sandfly_data_source, mock_responses): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_data_source_ping_failure(sandfly_data_source, mock_responses): +async def test_data_source_ping_failure(sandfly_data_source, mock_responses) -> None: request_error = ClientResponseError(None, None) request_error.status = 403 request_error.message = "Forbidden" @@ -467,7 +476,7 @@ async def test_data_source_ping_failure(sandfly_data_source, mock_responses): @pytest.mark.asyncio async def test_data_source_get_docs_license_expired( sandfly_data_source, mock_responses -): +) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -487,7 +496,9 @@ async def test_data_source_get_docs_license_expired( @pytest.mark.asyncio -async def test_data_source_get_docs_not_licensed(sandfly_data_source, mock_responses): +async def test_data_source_get_docs_not_licensed( + sandfly_data_source, mock_responses +) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -507,7 +518,7 @@ async def test_data_source_get_docs_not_licensed(sandfly_data_source, mock_respo @pytest.mark.asyncio -async def test_data_source_get_docs(sandfly_data_source, mock_responses): +async def test_data_source_get_docs(sandfly_data_source, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, 
status=200, @@ -568,7 +579,7 @@ async def test_data_source_get_docs(sandfly_data_source, mock_responses): @pytest.mark.parametrize("sync_cursor", [None, {}]) async def test_data_source_get_docs_inc_empty_sync_cursor( sandfly_data_source, mock_responses, sync_cursor -): +) -> None: with pytest.raises(SyncCursorEmpty): docs = [] async for doc, _, _ in sandfly_data_source.get_docs_incrementally( @@ -580,7 +591,7 @@ async def test_data_source_get_docs_inc_empty_sync_cursor( @pytest.mark.asyncio async def test_data_source_get_docs_inc_license_expired( sandfly_data_source, mock_responses -): +) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -608,7 +619,7 @@ async def test_data_source_get_docs_inc_license_expired( @pytest.mark.asyncio async def test_data_source_get_docs_inc_not_licensed( sandfly_data_source, mock_responses -): +) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, @@ -634,7 +645,7 @@ async def test_data_source_get_docs_inc_not_licensed( @pytest.mark.asyncio -async def test_data_source_get_docs_inc(sandfly_data_source, mock_responses): +async def test_data_source_get_docs_inc(sandfly_data_source, mock_responses) -> None: mock_responses.post( URL_SANDFLY_LOGIN, status=200, diff --git a/tests/sources/test_servicenow.py b/tests/sources/test_servicenow.py index 4d6e3b71d..677161189 100644 --- a/tests/sources/test_servicenow.py +++ b/tests/sources/test_servicenow.py @@ -30,7 +30,7 @@ @asynccontextmanager -async def create_service_now_source(use_text_extraction_service=False): +async def create_service_now_source(use_text_extraction_service: bool = False): async with create_source( ServiceNowDataSource, url="http://127.0.0.1:1234", @@ -45,7 +45,7 @@ async def create_service_now_source(use_text_extraction_service=False): class MockResponse: """Mock response of aiohttp get method""" - def __init__(self, res, headers): + def __init__(self, res, headers) -> None: """Setup a response""" self._res = res self.headers = headers @@ -59,7 +59,7 @@ async def __aenter__(self): """Enters an async with block""" return self - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type, exc, tb) -> None: """Closes an async with block""" pass @@ -67,7 +67,7 @@ async def __aexit__(self, exc_type, exc, tb): class StreamerReader: """Mock Stream Reader""" - def __init__(self, res): + def __init__(self, res) -> None: """Setup a response""" self._res = res self._size = None @@ -79,7 +79,7 @@ async def iter_chunked(self, size): @pytest.mark.parametrize("field", ["username", "password", "services"]) @pytest.mark.asyncio -async def test_validate_config_missing_fields_then_raise(field): +async def test_validate_config_missing_fields_then_raise(field) -> None: async with create_service_now_source() as source: source.configuration.get_field(field).value = "" @@ -88,7 +88,7 @@ async def test_validate_config_missing_fields_then_raise(field): @pytest.mark.asyncio -async def test_validate_configuration_with_invalid_service_then_raise(): +async def test_validate_configuration_with_invalid_service_then_raise() -> None: async with create_service_now_source() as source: source.servicenow_client.services = ["label_1", "label_3"] @@ -114,7 +114,7 @@ async def test_validate_configuration_with_invalid_service_then_raise(): @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: async with create_service_now_source() as source: with mock.patch.object( ServiceNowClient, @@ -125,7 +125,7 @@ async def 
test_ping_for_successful_connection(): @pytest.mark.asyncio -async def test_ping_for_unsuccessful_connection_then_raise(): +async def test_ping_for_unsuccessful_connection_then_raise() -> None: async with create_service_now_source() as source: with mock.patch.object( ServiceNowClient, @@ -137,7 +137,7 @@ async def test_ping_for_unsuccessful_connection_then_raise(): @pytest.mark.asyncio -async def test_tweak_bulk_options(): +async def test_tweak_bulk_options() -> None: async with create_service_now_source() as source: source.concurrent_downloads = 10 options = {"concurrent_downloads": 5} @@ -147,7 +147,7 @@ async def test_tweak_bulk_options(): @pytest.mark.asyncio -async def test_get_data(): +async def test_get_data() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse( @@ -166,7 +166,7 @@ async def test_get_data(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_data_with_retry(): +async def test_get_data_with_retry() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( side_effect=ServerDisconnectedError @@ -178,7 +178,7 @@ async def test_get_data_with_retry(): @pytest.mark.asyncio -async def test_get_table_length(): +async def test_get_table_length() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse( @@ -193,7 +193,7 @@ async def test_get_table_length(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_table_length_with_retry(): +async def test_get_table_length_with_retry() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( side_effect=ServerDisconnectedError @@ -205,7 +205,7 @@ async def test_get_table_length_with_retry(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_data_with_empty_response(): +async def test_get_data_with_empty_response() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse( @@ -221,7 +221,7 @@ async def test_get_data_with_empty_response(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_get_data_with_text_response(): +async def test_get_data_with_text_response() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse( @@ -236,7 +236,7 @@ async def test_get_data_with_text_response(): @pytest.mark.asyncio -async def test_filter_services_with_exception(): +async def test_filter_services_with_exception() -> None: async with create_service_now_source() as source: source.servicenow_client.services = ["label_1", "label_3"] @@ -249,7 +249,7 @@ async def test_filter_services_with_exception(): @pytest.mark.asyncio -async def test_filter_services_when_sysparm_fields_missing(): +async def test_filter_services_when_sysparm_fields_missing() -> None: async with create_service_now_source() as source: source.servicenow_client.services = ["Incident", "Feature", "User"] @@ -275,7 +275,9 @@ async def test_filter_services_when_sysparm_fields_missing(): @pytest.mark.asyncio -async def test_filter_services_when_sysparm_fields_missing_for_unrelated_table(): 
+async def test_filter_services_when_sysparm_fields_missing_for_unrelated_table() -> ( + None +): async with create_service_now_source() as source: source.servicenow_client.services = ["Incident", "Feature"] @@ -303,7 +305,7 @@ async def test_filter_services_when_sysparm_fields_missing_for_unrelated_table() @pytest.mark.asyncio -async def test_get_docs_with_skipping_table_data(): +async def test_get_docs_with_skipping_table_data() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse( @@ -360,7 +362,9 @@ async def test_get_docs_with_skipping_table_data(): ), ], ) -async def test_get_docs_with_skipping_attachment_data(dls_enabled, expected_response): +async def test_get_docs_with_skipping_attachment_data( + dls_enabled, expected_response +) -> None: async with create_service_now_source() as source: source._dls_enabled = Mock(return_value=dls_enabled) source._fetch_access_controls = mock.AsyncMock( @@ -408,7 +412,7 @@ async def test_get_docs_with_skipping_attachment_data(dls_enabled, expected_resp @pytest.mark.asyncio -async def test_get_docs_with_configured_services(): +async def test_get_docs_with_configured_services() -> None: async with create_service_now_source() as source: source.servicenow_client.services = ["custom"] source.servicenow_client._api_call = mock.AsyncMock( @@ -481,7 +485,7 @@ async def test_get_docs_with_configured_services(): @pytest.mark.asyncio -async def test_fetch_attachment_content_with_doit(): +async def test_fetch_attachment_content_with_doit() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse(res=b"Attachment Content", headers={}) @@ -505,7 +509,7 @@ async def test_fetch_attachment_content_with_doit(): @pytest.mark.asyncio -async def test_fetch_attachment_content_with_extraction_service(): +async def test_fetch_attachment_content_with_extraction_service() -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -541,7 +545,7 @@ async def test_fetch_attachment_content_with_extraction_service(): @pytest.mark.asyncio -async def test_fetch_attachment_content_with_upper_extension(): +async def test_fetch_attachment_content_with_upper_extension() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse(res=b"Attachment Content", headers={}) @@ -565,7 +569,7 @@ async def test_fetch_attachment_content_with_upper_extension(): @pytest.mark.asyncio -async def test_fetch_attachment_content_without_doit(): +async def test_fetch_attachment_content_without_doit() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse(res=b"Attachment Content", headers={}) @@ -584,7 +588,7 @@ async def test_fetch_attachment_content_without_doit(): @pytest.mark.asyncio -async def test_fetch_attachment_content_with_exception(): +async def test_fetch_attachment_content_with_exception() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( side_effect=Exception("Something went wrong") @@ -604,7 +608,7 @@ async def test_fetch_attachment_content_with_exception(): @pytest.mark.asyncio -async def test_fetch_attachment_content_with_unsupported_extension_then_skip(): +async def test_fetch_attachment_content_with_unsupported_extension_then_skip() -> None: async with 
create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse(res=b"Attachment Content", headers={}) @@ -624,7 +628,7 @@ async def test_fetch_attachment_content_with_unsupported_extension_then_skip(): @pytest.mark.asyncio -async def test_fetch_attachment_content_without_extension_then_skip(): +async def test_fetch_attachment_content_without_extension_then_skip() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse(res=b"Attachment Content", headers={}) @@ -644,7 +648,7 @@ async def test_fetch_attachment_content_without_extension_then_skip(): @pytest.mark.asyncio -async def test_fetch_attachment_content_with_unsupported_file_size_then_skip(): +async def test_fetch_attachment_content_with_unsupported_file_size_then_skip() -> None: async with create_service_now_source() as source: source.servicenow_client._api_call = mock.AsyncMock( return_value=MockResponse(res=b"Attachment Content", headers={}) @@ -745,7 +749,9 @@ async def test_fetch_attachment_content_with_unsupported_file_size_then_skip(): ], ) @pytest.mark.asyncio -async def test_advanced_rules_validation(advanced_rules, expected_validation_result): +async def test_advanced_rules_validation( + advanced_rules, expected_validation_result +) -> None: async with create_service_now_source() as source: source.servicenow_client.get_table_length = mock.AsyncMock(return_value=2) @@ -784,7 +790,7 @@ async def test_advanced_rules_validation(advanced_rules, expected_validation_res ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules(filtering): +async def test_get_docs_with_advanced_rules(filtering) -> None: async with create_service_now_source() as source: source.servicenow_client.services = ["custom"] source.servicenow_client._api_call = mock.AsyncMock( @@ -873,7 +879,7 @@ async def test_get_docs_with_advanced_rules(filtering): ], ) @pytest.mark.asyncio -async def test_get_docs_with_advanced_rules_pagination(filtering): +async def test_get_docs_with_advanced_rules_pagination(filtering) -> None: expected_filter_apis = [ { "headers": [ @@ -946,7 +952,7 @@ async def test_get_docs_with_advanced_rules_pagination(filtering): @pytest.mark.asyncio -async def test_get_access_control(): +async def test_get_access_control() -> None: expected_response = { "_id": "id_1", "identity": { @@ -992,7 +998,7 @@ async def test_get_access_control(): @pytest.mark.asyncio -async def test_get_access_control_dls_disabled(): +async def test_get_access_control_dls_disabled() -> None: async with create_service_now_source() as source: source._dls_enabled = Mock(return_value=False) @@ -1004,7 +1010,7 @@ async def test_get_access_control_dls_disabled(): @pytest.mark.asyncio -async def test_fetch_access_control(): +async def test_fetch_access_control() -> None: async with create_service_now_source() as source: with mock.patch.object( ServiceNowDataSource, @@ -1033,7 +1039,7 @@ async def test_fetch_access_control(): @pytest.mark.asyncio -async def test_fetch_access_control_for_public(): +async def test_fetch_access_control_for_public() -> None: async with create_service_now_source() as source: with mock.patch.object( ServiceNowDataSource, @@ -1080,7 +1086,7 @@ async def test_fetch_access_control_for_public(): @pytest.mark.asyncio -async def test_end_signal_is_added_to_queue_in_case_of_exception(): +async def test_end_signal_is_added_to_queue_in_case_of_exception() -> None: END_SIGNAL = "RECORD_TASK_FINISHED" async with 
create_service_now_source() as source: with patch.object( diff --git a/tests/sources/test_sharepoint_online.py b/tests/sources/test_sharepoint_online.py index 27a2ebcc4..6d6ccf00e 100644 --- a/tests/sources/test_sharepoint_online.py +++ b/tests/sources/test_sharepoint_online.py @@ -10,6 +10,7 @@ from datetime import datetime, timedelta, timezone from functools import partial from io import BytesIO +from typing import Any, Dict, List, Union from unittest.mock import ANY, AsyncMock, MagicMock, Mock, PropertyMock, patch import aiohttp @@ -65,7 +66,7 @@ WITHOUT_PREFIX = False IDENTITY_MAIL = "mail@spo.com" IDENTITY_USER_PRINCIPAL_NAME = "some identity" -IDENTITY_WITH_MAIL_AND_PRINCIPAL_NAME = { +IDENTITY_WITH_MAIL_AND_PRINCIPAL_NAME: Dict[str, str] = { "mail": IDENTITY_MAIL, "userPrincipalName": IDENTITY_USER_PRINCIPAL_NAME, } @@ -142,7 +143,7 @@ } ] -SAMPLE_DRIVE_PERMISSIONS = [ +SAMPLE_DRIVE_PERMISSIONS: List[Dict[str, Union[Dict[str, Dict[str, str]], str]]] = [ { "id": "3", "grantedToV2": { @@ -190,7 +191,7 @@ }, ] -SAMPLE_SITE_GROUP_USERS = [ +SAMPLE_SITE_GROUP_USERS: List[Dict[str, str]] = [ { "UserPrincipalName": SITEGROUP_USER_ONE_ID, "Email": SITEGROUP_USER_ONE_EMAIL, @@ -230,17 +231,17 @@ @asynccontextmanager async def create_spo_source( - tenant_id="1", - tenant_name="test", - client_id="2", - secret_value="3", - auth_method="secret", - site_collections=WILDCARD, - use_document_level_security=False, - use_text_extraction_service=False, - fetch_drive_item_permissions=True, - fetch_unique_list_permissions=True, - enumerate_all_sites=False, + tenant_id: str = "1", + tenant_name: str = "test", + client_id: str = "2", + secret_value: str = "3", + auth_method: str = "secret", + site_collections: str = WILDCARD, + use_document_level_security: bool = False, + use_text_extraction_service: bool = False, + fetch_drive_item_permissions: bool = True, + fetch_unique_list_permissions: bool = True, + enumerate_all_sites: bool = False, ): async with create_source( SharepointOnlineDataSource, @@ -276,11 +277,11 @@ def dls_enabled(value): return value -def access_control_matches(actual, expected): +def access_control_matches(actual, expected) -> bool: return all(access_control in expected for access_control in actual) -def access_control_is_equal(actual, expected): +def access_control_is_equal(actual, expected) -> bool: return set(actual) == set(expected) @@ -308,7 +309,7 @@ async def _fetch_token(self): raise error @pytest.mark.asyncio - async def test_fetch_token_raises_not_implemented_error(self): + async def test_fetch_token_raises_not_implemented_error(self) -> None: with pytest.raises(NotImplementedError) as e: mst = MicrosoftSecurityToken(None, None, None, None) @@ -317,7 +318,7 @@ async def test_fetch_token_raises_not_implemented_error(self): assert e is not None @pytest.mark.asyncio - async def test_get_returns_results_from_fetch_token(self): + async def test_get_returns_results_from_fetch_token(self) -> None: bearer = "something" expires_in = 0.1 @@ -331,7 +332,7 @@ async def test_get_returns_results_from_fetch_token(self): @pytest.mark.asyncio @freeze_time() - async def test_get_returns_cached_value_when_token_did_not_expire(self): + async def test_get_returns_cached_value_when_token_did_not_expire(self) -> None: original_bearer = "something" updated_bearer = "another" expires_at = datetime.utcnow() + timedelta(seconds=1) @@ -349,7 +350,7 @@ async def test_get_returns_cached_value_when_token_did_not_expire(self): assert second_bearer == original_bearer @pytest.mark.asyncio - async def 
test_get_returns_new_value_when_token_expired(self): + async def test_get_returns_new_value_when_token_expired(self) -> None: original_bearer = "something" updated_bearer = "another" expires_at = datetime.utcnow() + timedelta(seconds=-1) @@ -369,7 +370,7 @@ async def test_get_returns_new_value_when_token_expired(self): assert second_bearer == updated_bearer @pytest.mark.asyncio - async def test_get_raises_correct_exception_when_400(self): + async def test_get_raises_correct_exception_when_400(self) -> None: token = TestMicrosoftSecurityToken.StubMicrosoftSecurityTokenWrongConfig(400) with pytest.raises(TokenFetchFailed) as e: @@ -381,7 +382,7 @@ async def test_get_raises_correct_exception_when_400(self): assert e.match("Client ID") @pytest.mark.asyncio - async def test_get_raises_correct_exception_when_401(self): + async def test_get_raises_correct_exception_when_401(self) -> None: token = TestMicrosoftSecurityToken.StubMicrosoftSecurityTokenWrongConfig(401) with pytest.raises(TokenFetchFailed) as e: @@ -391,7 +392,7 @@ async def test_get_raises_correct_exception_when_401(self): assert e.match("Secret Value") @pytest.mark.asyncio - async def test_get_raises_correct_exception_when_any_other_status(self): + async def test_get_raises_correct_exception_when_any_other_status(self) -> None: message = "Internal server error" token = TestMicrosoftSecurityToken.StubMicrosoftSecurityTokenWrongConfig( 500, message @@ -415,7 +416,7 @@ async def token(self): @pytest.mark.asyncio @freeze_time() - async def test_fetch_token(self, token, mock_responses): + async def test_fetch_token(self, token, mock_responses) -> None: bearer = "hello" now = datetime.utcnow() expires_in = 15 @@ -432,7 +433,9 @@ async def test_fetch_token(self, token, mock_responses): @pytest.mark.asyncio @freeze_time() - async def test_fetch_token_retries(self, token, mock_responses, patch_sleep): + async def test_fetch_token_retries( + self, token, mock_responses, patch_sleep + ) -> None: bearer = "hello" expires_in = 15 now = datetime.utcnow() @@ -464,7 +467,7 @@ async def token(self): @pytest.mark.asyncio @freeze_time() - async def test_fetch_token(self, token, mock_responses): + async def test_fetch_token(self, token, mock_responses) -> None: bearer = "hello" expires_in = 15 now = datetime.utcnow() @@ -484,7 +487,9 @@ async def test_fetch_token(self, token, mock_responses): # Then this test can be removed @pytest.mark.asyncio @freeze_time() - async def test_fetch_token_retries(self, token, mock_responses, patch_sleep): + async def test_fetch_token_retries( + self, token, mock_responses, patch_sleep + ) -> None: bearer = "hello" expires_in = 15 now = datetime.utcnow() @@ -524,7 +529,7 @@ async def token(self): await session.close() @pytest.mark.asyncio - async def test_fetch_token(self, token, mock_responses): + async def test_fetch_token(self, token, mock_responses) -> None: bearer = "hello" expires_at = datetime.utcnow().timestamp() + 30 @@ -592,11 +597,11 @@ async def microsoft_api_session(self): await session.close() @property - def scroll_field(self): + def scroll_field(self) -> str: return "next_link" @pytest.mark.asyncio - async def test_fetch(self, microsoft_api_session, mock_responses): + async def test_fetch(self, microsoft_api_session, mock_responses) -> None: url = "http://localhost:1234/url" payload = {"test": "hello world"} @@ -607,7 +612,7 @@ async def test_fetch(self, microsoft_api_session, mock_responses): assert response == payload @pytest.mark.asyncio - async def test_post(self, microsoft_api_session, 
mock_responses): + async def test_post(self, microsoft_api_session, mock_responses) -> None: url = "http://localhost:1234/url" expected_response = {"test": "hello world"} payload = {"key": "value"} @@ -621,7 +626,7 @@ async def test_post(self, microsoft_api_session, mock_responses): @pytest.mark.asyncio async def test_post_with_batch_failures( self, microsoft_api_session, mock_responses, patch_cancellable_sleeps - ): + ) -> None: url = "https://graph.microsoft.com/v1.0/$batch" first_response = BATCH_THROTTLED_RESPONSE second_response = {"responses": [{"key": "value"}]} @@ -637,7 +642,7 @@ async def test_post_with_batch_failures( @pytest.mark.asyncio async def test_post_with_consecutive_batch_failures( self, microsoft_api_session, mock_responses, patch_cancellable_sleeps - ): + ) -> None: url = "https://graph.microsoft.com/v1.0/$batch" payload = {"key": "value"} @@ -653,7 +658,7 @@ async def test_post_with_consecutive_batch_failures( @pytest.mark.asyncio async def test_fetch_with_retry( self, microsoft_api_session, mock_responses, patch_sleep - ): + ) -> None: url = "http://localhost:1234/url" payload = {"test": "hello world"} @@ -670,7 +675,7 @@ async def test_fetch_with_retry( assert response == payload @pytest.mark.asyncio - async def test_scroll(self, microsoft_api_session, mock_responses): + async def test_scroll(self, microsoft_api_session, mock_responses) -> None: url = "http://localhost:1234/url" first_page = ["1", "2", "3"] @@ -694,7 +699,7 @@ async def test_scroll(self, microsoft_api_session, mock_responses): @pytest.mark.asyncio async def test_scroll_delta_url_with_data_link( self, microsoft_api_session, mock_responses - ): + ) -> None: drive_item = { "id": "1", "size": 15, @@ -733,7 +738,7 @@ async def test_scroll_delta_url_with_data_link( assert len(pages) == len(responses) @pytest.mark.asyncio - async def test_pipe(self, microsoft_api_session, mock_responses): + async def test_pipe(self, microsoft_api_session, mock_responses) -> None: class AsyncStream: def __init__(self): self.stream = BytesIO() @@ -760,7 +765,7 @@ async def test_call_api_with_429( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" payload = {"hello": "world"} retry_after = 25 @@ -787,7 +792,7 @@ async def test_call_api_with_client_payload_error( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" payload = {"hello": "world"} retry_after = DEFAULT_RETRY_SECONDS @@ -811,7 +816,7 @@ async def test_call_api_with_429_without_retry_after( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" payload = {"hello": "world"} @@ -836,7 +841,7 @@ async def test_call_api_with_429_with_retry_after_and_backoff( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" payload = {"hello": "world"} @@ -866,7 +871,7 @@ async def test_call_api_with_403( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" # First throttle, then do not throttle @@ -893,7 +898,7 @@ async def test_call_api_with_404_with_retry_after_header( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" payload = {"hello": "world"} retry_after = 25 @@ -920,7 +925,7 @@ async def 
test_call_api_with_404_without_retry_after_header( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" # First throttle, then do not throttle @@ -947,7 +952,7 @@ async def test_call_api_with_400_without_retry_after_header( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" # First throttle, then do not throttle @@ -970,7 +975,7 @@ async def test_call_api_with_unhandled_status( mock_responses, patch_sleep, patch_cancellable_sleeps, - ): + ) -> None: url = "http://localhost:1234/download-some-sample-file" error_message = "Something went wrong" @@ -995,19 +1000,19 @@ async def test_call_api_with_unhandled_status( class TestSharepointOnlineClient: @property - def tenant_id(self): + def tenant_id(self) -> str: return "tid" @property - def tenant_name(self): + def tenant_name(self) -> str: return "tname" @property - def client_id(self): + def client_id(self) -> str: return "cid" @property - def client_secret(self): + def client_secret(self) -> str: return "csecret" @pytest_asyncio.fixture @@ -1068,7 +1073,7 @@ async def _execute_scrolling_method(self, method, patch_scroll, setup_items, *ar return returned_items - async def _test_scrolling_method_not_found(self, method, patch_scroll): + async def _test_scrolling_method_not_found(self, method, patch_scroll) -> None: patch_scroll.side_effect = NotFound() returned_items = [] @@ -1078,7 +1083,7 @@ async def _test_scrolling_method_not_found(self, method, patch_scroll): assert len(returned_items) == 0 @pytest.mark.asyncio - async def test_groups(self, client, patch_scroll): + async def test_groups(self, client, patch_scroll) -> None: actual_items = ["1", "2", "3", "4"] returned_items = await self._execute_scrolling_method( @@ -1089,7 +1094,7 @@ async def test_groups(self, client, patch_scroll): assert returned_items == actual_items @pytest.mark.asyncio - async def test_group_sites(self, client, patch_scroll): + async def test_group_sites(self, client, patch_scroll) -> None: group_id = "12345" actual_items = ["1", "2", "3", "4"] @@ -1102,7 +1107,7 @@ async def test_group_sites(self, client, patch_scroll): assert returned_items == actual_items @pytest.mark.asyncio - async def test_group_sites_not_found(self, client, patch_scroll): + async def test_group_sites_not_found(self, client, patch_scroll) -> None: group_id = "12345" await self._test_scrolling_method_not_found( @@ -1110,7 +1115,7 @@ async def test_group_sites_not_found(self, client, patch_scroll): ) @pytest.mark.asyncio - async def test_site_collections(self, client, patch_scroll): + async def test_site_collections(self, client, patch_scroll) -> None: first_site_collection = { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#sites/$entity", "createdDateTime": "2023-12-12T12:00:00.000Z", @@ -1149,7 +1154,7 @@ async def test_site_collections(self, client, patch_scroll): @pytest.mark.asyncio async def test_site_collections_when_permission_missing( self, client, patch_fetch, patch_scroll - ): + ) -> None: actual_response = { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#sites/$entity", "createdDateTime": "2023-12-12T12:00:00.000Z", @@ -1169,7 +1174,7 @@ async def test_site_collections_when_permission_missing( assert site_collection == actual_response @pytest.mark.asyncio - async def test_sites_wildcard(self, client, patch_scroll): + async def test_sites_wildcard(self, client, patch_scroll) -> None: root_site = "root" 
actual_items = [ {"name": "First"}, @@ -1186,7 +1191,7 @@ async def test_sites_wildcard(self, client, patch_scroll): assert returned_items == actual_items @pytest.mark.asyncio - async def test_sites_filter(self, client, patch_scroll): + async def test_sites_filter(self, client, patch_scroll) -> None: root_site = "root" actual_items = [ {"name": "First"}, @@ -1205,7 +1210,7 @@ async def test_sites_filter(self, client, patch_scroll): assert actual_items[2] in returned_items @pytest.mark.asyncio - async def test_sites_filter_individually(self, client, patch_fetch): + async def test_sites_filter_individually(self, client, patch_fetch) -> None: root_site = "root" actual_items = [ {"name": "First"}, @@ -1226,7 +1231,7 @@ async def test_sites_filter_individually(self, client, patch_fetch): @pytest.mark.asyncio async def test_sites_filter_individually_plus_subsites( self, client, patch_fetch, patch_scroll - ): + ) -> None: root_site = "root" root_item = { "name": "First", @@ -1277,14 +1282,16 @@ async def test_sites_filter_individually_plus_subsites( ), ], ) - async def test_all_sites_with_error(self, client, patch_scroll, exception, raises): + async def test_all_sites_with_error( + self, client, patch_scroll, exception, raises + ) -> None: sharepoint_host = "example.sharepoint.com" patch_scroll.side_effect = exception with pytest.raises(raises): await anext(client._all_sites(sharepoint_host, [])) @pytest.mark.asyncio - async def test_site_drives(self, client, patch_scroll): + async def test_site_drives(self, client, patch_scroll) -> None: site_id = "12345" actual_items = ["1", "2", "3", "4"] @@ -1297,7 +1304,9 @@ async def test_site_drives(self, client, patch_scroll): assert returned_items == actual_items @pytest.mark.asyncio - async def test_drive_items_delta(self, client, patch_fetch, patch_scroll_delta_url): + async def test_drive_items_delta( + self, client, patch_fetch, patch_scroll_delta_url + ) -> None: delta_url_input = "https://sharepoint.com/delta-link-lalal" delta_url_next_page = "https://sharepoint.com/delta-link-lalal/page-2" delta_url_next_sync = "https://sharepoint.com/delta-link-lalal/next-sync" @@ -1326,7 +1335,7 @@ async def test_drive_items_delta(self, client, patch_fetch, patch_scroll_delta_u assert returned_drive_items == items_page_1 + items_page_2 @pytest.mark.asyncio - async def test_drive_items(self, client, patch_fetch): + async def test_drive_items(self, client, patch_fetch) -> None: drive_id = "12345" delta_url_next_sync = "https://sharepoint.com/delta-link-lalal/page-2" items_page_1 = ["1", "2"] @@ -1350,7 +1359,7 @@ async def test_drive_items(self, client, patch_fetch): assert returned_items == items_page_1 + items_page_2 @pytest.mark.asyncio - async def test_download_drive_item(self, client, patch_pipe): + async def test_download_drive_item(self, client, patch_pipe) -> None: """Basic setup for the test - no recursion through directories""" drive_id = "1" item_id = "2" @@ -1361,7 +1370,7 @@ async def test_download_drive_item(self, client, patch_pipe): patch_pipe.assert_awaited_once_with(ANY, async_buffer) @pytest.mark.asyncio - async def test_site_lists(self, client, patch_scroll): + async def test_site_lists(self, client, patch_scroll) -> None: site_id = "12345" actual_items = ["1", "2", "3", "4"] @@ -1374,7 +1383,7 @@ async def test_site_lists(self, client, patch_scroll): assert returned_items == actual_items @pytest.mark.asyncio - async def test_site_list_items(self, client, patch_scroll): + async def test_site_list_items(self, client, patch_scroll) -> None: 
site_id = "12345" list_id = "54321" @@ -1390,7 +1399,7 @@ async def test_site_list_items(self, client, patch_scroll): assert returned_items == actual_items @pytest.mark.asyncio - async def test_site_list_item_attachments(self, client, patch_fetch): + async def test_site_list_item_attachments(self, client, patch_fetch) -> None: site_web_url = f"https://{self.tenant_name}.sharepoint.com" list_title = "Summer Vacation Notes" list_item_id = "1" @@ -1409,7 +1418,9 @@ async def test_site_list_item_attachments(self, client, patch_fetch): assert returned_items == actual_attachments @pytest.mark.asyncio - async def test_site_list_item_attachments_not_found(self, client, patch_fetch): + async def test_site_list_item_attachments_not_found( + self, client, patch_fetch + ) -> None: site_web_url = f"https://{self.tenant_name}.sharepoint.com" list_title = "Summer Vacation Notes" list_item_id = "1" @@ -1425,7 +1436,7 @@ async def test_site_list_item_attachments_not_found(self, client, patch_fetch): assert len(returned_items) == 0 @pytest.mark.asyncio - async def test_site_list_item_attachments_wrong_tenant(self, client): + async def test_site_list_item_attachments_wrong_tenant(self, client) -> None: invalid_tenant_name = "something" site_web_url = f"https://{invalid_tenant_name}.sharepoint.com" list_title = "Summer Vacation Notes" @@ -1443,7 +1454,7 @@ async def test_site_list_item_attachments_wrong_tenant(self, client): assert e.match(self.tenant_name) @pytest.mark.asyncio - async def test_download_attachment(self, client, patch_pipe): + async def test_download_attachment(self, client, patch_pipe) -> None: attachment_path = f"https://{self.tenant_name}.sharepoint.com/thats/a/made/up/attachment/path.jpg" async_buffer = MagicMock() @@ -1452,7 +1463,7 @@ async def test_download_attachment(self, client, patch_pipe): patch_pipe.assert_awaited_once_with(ANY, async_buffer) @pytest.mark.asyncio - async def test_download_attachment_wrong_tenant(self, client, patch_pipe): + async def test_download_attachment_wrong_tenant(self, client, patch_pipe) -> None: invalid_tenant_name = "something" attachment_path = f"https://{invalid_tenant_name}.sharepoint.com/thats/a/made/up/attachment/path.jpg" async_buffer = MagicMock() @@ -1466,7 +1477,7 @@ async def test_download_attachment_wrong_tenant(self, client, patch_pipe): assert e.match(self.tenant_name) @pytest.mark.asyncio - async def test_site_pages(self, client, patch_scroll): + async def test_site_pages(self, client, patch_scroll) -> None: page_url_path = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/page.aspx" actual_items = [{"Id": "1"}, {"Id": "2"}, {"Id": "3"}, {"Id": "4"}] @@ -1478,7 +1489,7 @@ async def test_site_pages(self, client, patch_scroll): assert [{"Id": i["Id"]} for i in returned_items] == actual_items @pytest.mark.asyncio - async def test_site_pages_not_found(self, client, patch_scroll): + async def test_site_pages_not_found(self, client, patch_scroll) -> None: page_url_path = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/page.aspx" patch_scroll.side_effect = NotFound() @@ -1490,7 +1501,9 @@ async def test_site_pages_not_found(self, client, patch_scroll): assert len(returned_items) == 0 @pytest.mark.asyncio - async def test_site_page_has_unique_role_assignments(self, client, patch_fetch): + async def test_site_page_has_unique_role_assignments( + self, client, patch_fetch + ) -> None: url = f"https://{self.tenant_name}.sharepoint.com" site_page_id = 1 @@ -1501,7 +1514,7 @@ async def 
test_site_page_has_unique_role_assignments(self, client, patch_fetch): @pytest.mark.asyncio async def test_site_page_has_unique_role_assignments_not_found( self, client, patch_fetch - ): + ) -> None: url = f"https://{self.tenant_name}.sharepoint.com" site_page_id = 1 @@ -1510,7 +1523,7 @@ async def test_site_page_has_unique_role_assignments_not_found( assert not await client.site_page_has_unique_role_assignments(url, site_page_id) @pytest.mark.asyncio - async def test_site_pages_wrong_tenant(self, client, patch_scroll): + async def test_site_pages_wrong_tenant(self, client, patch_scroll) -> None: invalid_tenant_name = "something" page_url_path = f"https://{invalid_tenant_name}.sharepoint.com/random/totally/made/up/page.aspx" @@ -1524,7 +1537,7 @@ async def test_site_pages_wrong_tenant(self, client, patch_scroll): assert e.match(self.tenant_name) @pytest.mark.asyncio - async def test_tenant_details(self, client, patch_fetch): + async def test_tenant_details(self, client, patch_fetch) -> None: http_call_result = {"hello": "world"} patch_fetch.return_value = http_call_result @@ -1534,7 +1547,7 @@ async def test_tenant_details(self, client, patch_fetch): assert http_call_result == actual_result @pytest.mark.asyncio - async def test_site_group_users(self, client, patch_scroll): + async def test_site_group_users(self, client, patch_scroll) -> None: site_group_id = 42 site_groups_url = f"https://{self.tenant_name}.sharepoint.com/_api/web/sitegroups/getbyid({site_group_id})/users" users = [ @@ -1558,7 +1571,7 @@ async def test_site_group_users(self, client, patch_scroll): assert actual_group == users @pytest.mark.asyncio - async def test_site_group_not_found(self, client, patch_scroll): + async def test_site_group_not_found(self, client, patch_scroll) -> None: site_group_id = 42 site_groups_url = f"https://{self.tenant_name}.sharepoint.com/_api/web/sitegroups/getbyid({site_group_id})/users" @@ -1571,7 +1584,7 @@ async def test_site_group_not_found(self, client, patch_scroll): assert len(returned_items) == 0 @pytest.mark.asyncio - async def test_site_users(self, client, patch_scroll): + async def test_site_users(self, client, patch_scroll) -> None: site_users_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/siteusers" users = ["user1", "user2"] @@ -1582,7 +1595,7 @@ async def test_site_users(self, client, patch_scroll): assert actual_users == users @pytest.mark.asyncio - async def test_site_users_not_found(self, client, patch_scroll): + async def test_site_users_not_found(self, client, patch_scroll) -> None: site_users_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/siteusers" patch_scroll.side_effect = NotFound() @@ -1593,7 +1606,7 @@ async def test_site_users_not_found(self, client, patch_scroll): assert len(returned_items) == 0 @pytest.mark.asyncio - async def test_drive_items_permissions_batch(self, client, patch_post): + async def test_drive_items_permissions_batch(self, client, patch_post) -> None: drive_id = 1 drive_item_ids = [1, 2, 3] batch_response = { @@ -1618,7 +1631,9 @@ async def test_drive_items_permissions_batch(self, client, patch_post): client._graph_api_client.post.assert_awaited_with(ANY, expected_batch_request) @pytest.mark.asyncio - async def test_drive_items_permissions_batch_not_found(self, client, patch_post): + async def test_drive_items_permissions_batch_not_found( + self, client, patch_post + ) -> None: drive_id = 1 drive_item_ids = [1, 2, 3] @@ -1633,7 +1648,9 @@ async def 
test_drive_items_permissions_batch_not_found(self, client, patch_post) assert len(responses) == 0 @pytest.mark.asyncio - async def test_drive_items_permissions_batch_empty(self, client, patch_post): + async def test_drive_items_permissions_batch_empty( + self, client, patch_post + ) -> None: drive_id = 1 drive_item_ids = [] @@ -1644,7 +1661,7 @@ async def test_drive_items_permissions_batch_empty(self, client, patch_post): raise Exception(msg) @pytest.mark.asyncio - async def test_site_role_assignments(self, client, patch_scroll): + async def test_site_role_assignments(self, client, patch_scroll) -> None: site_role_assignments_url = ( f"https://{self.tenant_name}.sharepoint.com/sites/test" ) @@ -1663,7 +1680,9 @@ async def test_site_role_assignments(self, client, patch_scroll): assert actual_role_assignments == role_assignments @pytest.mark.asyncio - async def test_site_list_has_unique_role_assignments(self, client, patch_fetch): + async def test_site_list_has_unique_role_assignments( + self, client, patch_fetch + ) -> None: site_list_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" site_list_name = "site_list" @@ -1676,7 +1695,7 @@ async def test_site_list_has_unique_role_assignments(self, client, patch_fetch): @pytest.mark.asyncio async def test_site_list_has_unique_role_assignments_not_found( self, client, patch_fetch - ): + ) -> None: site_list_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" site_list_name = "site_list" @@ -1687,7 +1706,7 @@ async def test_site_list_has_unique_role_assignments_not_found( ) @pytest.mark.asyncio - async def test_site_list_role_assignments(self, client, patch_scroll): + async def test_site_list_role_assignments(self, client, patch_scroll) -> None: site_list_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" site_list_name = "site_list" @@ -1706,7 +1725,9 @@ async def test_site_list_role_assignments(self, client, patch_scroll): assert actual_role_assignments == role_assignments @pytest.mark.asyncio - async def test_site_list_role_assignments_not_found(self, client, patch_scroll): + async def test_site_list_role_assignments_not_found( + self, client, patch_scroll + ) -> None: site_list_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" site_list_name = "site_list" @@ -1723,7 +1744,7 @@ async def test_site_list_role_assignments_not_found(self, client, patch_scroll): @pytest.mark.asyncio async def test_site_list_item_has_unique_role_assignments( self, client, patch_fetch - ): + ) -> None: site_list_item_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" list_title = "list_title" list_item_id = 1 @@ -1737,7 +1758,7 @@ async def test_site_list_item_has_unique_role_assignments( @pytest.mark.asyncio async def test_site_list_item_has_unique_role_assignments_not_found( self, client, patch_fetch - ): + ) -> None: site_list_item_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" list_title = "list_title" list_item_id = 1 @@ -1751,7 +1772,7 @@ async def test_site_list_item_has_unique_role_assignments_not_found( @pytest.mark.asyncio async def test_site_list_item_has_unique_role_assignments_bad_request( self, client, patch_fetch - ): + ) -> None: site_list_item_role_assignments_url = 
f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" list_title = "list_title" list_item_id = 1 @@ -1763,7 +1784,7 @@ async def test_site_list_item_has_unique_role_assignments_bad_request( ) @pytest.mark.asyncio - async def test_site_list_item_role_assignments(self, client, patch_scroll): + async def test_site_list_item_role_assignments(self, client, patch_scroll) -> None: site_list_item_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" list_title = "list_title" list_item_id = 1 @@ -1786,7 +1807,7 @@ async def test_site_list_item_role_assignments(self, client, patch_scroll): @pytest.mark.asyncio async def test_site_list_item_role_assignments_not_found( self, client, patch_scroll - ): + ) -> None: site_list_item_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" list_title = "list_title" list_item_id = 1 @@ -1802,7 +1823,7 @@ async def test_site_list_item_role_assignments_not_found( assert len(role_assignments) == 0 @pytest.mark.asyncio - async def test_site_page_role_assignments(self, client, patch_scroll): + async def test_site_page_role_assignments(self, client, patch_scroll) -> None: site_web_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" site_page_id = 1 role_assignments = [{"value": ["role"]}] @@ -1816,7 +1837,9 @@ async def test_site_page_role_assignments(self, client, patch_scroll): assert actual_role_assignments == role_assignments @pytest.mark.asyncio - async def test_site_page_role_assignments_not_found(self, client, patch_scroll): + async def test_site_page_role_assignments_not_found( + self, client, patch_scroll + ) -> None: site_page_role_assignments_url = f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/roleassignments" site_page_id = 1 @@ -1831,7 +1854,9 @@ async def test_site_page_role_assignments_not_found(self, client, patch_scroll): assert len(returned_items) == 0 @pytest.mark.asyncio - async def test_users_and_groups_for_role_assignment(self, client, patch_fetch): + async def test_users_and_groups_for_role_assignment( + self, client, patch_fetch + ) -> None: users_by_id_url = ( f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/users" ) @@ -1849,7 +1874,7 @@ async def test_users_and_groups_for_role_assignment(self, client, patch_fetch): @pytest.mark.asyncio async def test_users_and_groups_for_role_assignment_not_found( self, client, patch_fetch - ): + ) -> None: users_by_id_url = ( f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/users" ) @@ -1866,7 +1891,7 @@ async def test_users_and_groups_for_role_assignment_not_found( @pytest.mark.asyncio async def test_users_and_groups_for_role_assignment_internal_server_error( self, client, patch_fetch - ): + ) -> None: users_by_id_url = ( f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/users" ) @@ -1883,7 +1908,7 @@ async def test_users_and_groups_for_role_assignment_internal_server_error( @pytest.mark.asyncio async def test_users_and_groups_for_role_assignment_missing_principal_id( self, client, patch_fetch - ): + ) -> None: users_by_id_url = ( f"https://{self.tenant_name}.sharepoint.com/random/totally/made/up/users" ) @@ -1901,7 +1926,7 @@ async def test_users_and_groups_for_role_assignment_missing_principal_id( assert len(actual_users_and_groups) == 0 @pytest.mark.asyncio - async def test_groups_user_transitive_member_of(self, client, patch_scroll): + async def 
test_groups_user_transitive_member_of(self, client, patch_scroll) -> None: user_id = "12345" group_one = {"name": "some group"} group_two = {"name": "some other group"} @@ -1920,7 +1945,7 @@ async def test_groups_user_transitive_member_of(self, client, patch_scroll): @pytest.mark.asyncio async def test_groups_user_transitive_member_of_with_not_found_raised( self, client, patch_scroll - ): + ) -> None: user_id = "12345" patch_scroll.side_effect = NotFound() @@ -1931,7 +1956,7 @@ async def test_groups_user_transitive_member_of_with_not_found_raised( assert len(returned_groups) == 0 @pytest.mark.asyncio - async def test_group_members(self, client, patch_scroll): + async def test_group_members(self, client, patch_scroll) -> None: group_id = "12345" member_one = {"name": "some member"} member_two = {"name": "some other member"} @@ -1948,7 +1973,9 @@ async def test_group_members(self, client, patch_scroll): assert actual_members == expected_members @pytest.mark.asyncio - async def test_group_members_with_not_found_raised(self, client, patch_scroll): + async def test_group_members_with_not_found_raised( + self, client, patch_scroll + ) -> None: group_id = "12345" patch_scroll.side_effect = NotFound() @@ -1959,7 +1986,7 @@ async def test_group_members_with_not_found_raised(self, client, patch_scroll): assert len(returned_members) == 0 @pytest.mark.asyncio - async def test_group_owners(self, client, patch_scroll): + async def test_group_owners(self, client, patch_scroll) -> None: group_id = "12345" owner_one = {"name": "some owner"} owner_two = {"name": "some other owner"} @@ -1976,7 +2003,9 @@ async def test_group_owners(self, client, patch_scroll): assert actual_owners == expected_owners @pytest.mark.asyncio - async def test_group_owners_with_not_found_raised(self, client, patch_scroll): + async def test_group_owners_with_not_found_raised( + self, client, patch_scroll + ) -> None: group_id = "12345" patch_scroll.side_effect = NotFound() @@ -1989,11 +2018,11 @@ async def test_group_owners_with_not_found_raised(self, client, patch_scroll): class TestSharepointOnlineAdvancedRulesValidator: @pytest_asyncio.fixture - def validator(self): + def validator(self) -> SharepointOnlineAdvancedRulesValidator: return SharepointOnlineAdvancedRulesValidator() @pytest.mark.asyncio - async def test_validate(self, validator): + async def test_validate(self, validator) -> None: valid_rules = {"skipExtractingDriveItemsOlderThan": 15} result = await validator.validate(valid_rules) @@ -2001,7 +2030,7 @@ async def test_validate(self, validator): assert result.is_valid @pytest.mark.asyncio - async def test_validate_invalid_rule(self, validator): + async def test_validate_invalid_rule(self, validator) -> None: invalid_rules = {"skipExtractingDriveItemsOlderThan": "why is this a string"} result = await validator.validate(invalid_rules) @@ -2011,29 +2040,29 @@ async def test_validate_invalid_rule(self, validator): class TestSharepointOnlineDataSource: @property - def today(self): + def today(self) -> str: return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") @property - def day_ago(self): + def day_ago(self) -> str: return (datetime.now(timezone.utc) - timedelta(days=1)).strftime( "%Y-%m-%dT%H:%M:%SZ" ) @property - def month_ago(self): + def month_ago(self) -> str: return (datetime.now(timezone.utc) - timedelta(days=30)).strftime( "%Y-%m-%dT%H:%M:%SZ" ) @property - def two_months_ago(selfself): + def two_months_ago(selfself) -> str: return (datetime.now(timezone.utc) - timedelta(days=60)).strftime( 
"%Y-%m-%dT%H:%M:%SZ" ) @property - def site_collections(self): + def site_collections(self) -> List[Dict[str, Union[Dict[str, str], str]]]: return [ { "siteCollection": {"hostname": "test.sharepoint.com"}, @@ -2043,7 +2072,7 @@ def site_collections(self): ] @property - def sites(self): + def sites(self) -> List[Dict[str, Any]]: return [ { "id": "1", @@ -2055,11 +2084,11 @@ def sites(self): ] @property - def site_drives(self): + def site_drives(self) -> List[Dict[str, str]]: return [{"id": "2", "lastModifiedDateTime": self.day_ago}] @property - def drive_items(self): + def drive_items(self) -> List[DriveItemsPage]: return [ DriveItemsPage( items=[ @@ -2079,7 +2108,7 @@ def drive_items(self): ] @property - def site_lists(self): + def site_lists(self) -> List[Dict[str, str]]: return [ { "id": SITE_LIST_ONE_ID, @@ -2089,7 +2118,7 @@ def site_lists(self): ] @property - def site_list_has_unique_role_assignments(self): + def site_list_has_unique_role_assignments(self) -> bool: return True @property @@ -2116,14 +2145,14 @@ def site_list_items(self): ] @property - def site_list_item_attachments(self): + def site_list_item_attachments(self) -> List[Dict[str, str]]: return [ {"odata.id": "9", "name": "attachment 1.txt"}, {"odata.id": "10", "name": "attachment 2.txt"}, ] @property - def site_pages(self): + def site_pages(self) -> List[Dict[str, str]]: return [ { "Id": "4", @@ -2134,7 +2163,17 @@ def site_pages(self): ] @property - def site_role_assignments(self): + def site_role_assignments( + self, + ) -> List[ + Dict[ + str, + Union[ + Dict[str, Union[List[Dict[str, str]], str]], + List[Dict[str, Union[Dict[str, str], int, str]]], + ], + ] + ]: return [ { "Member": { @@ -2169,7 +2208,7 @@ def site_role_assignments(self): ] @property - def site_admins(self): + def site_admins(self) -> List[Dict[str, str]]: return [ { "LoginName": "c:0o.c|federateddirectoryclaimprovider|97d055cf-5cdf-4e5e-b383-f01ed3a8844d_o", @@ -2181,7 +2220,7 @@ def site_admins(self): ] @property - def group_members(self): + def group_members(self) -> List[Dict[str, str]]: return [ { "mail": MEMBER_ONE_EMAIL, @@ -2190,14 +2229,14 @@ def group_members(self): ] @property - def group_owners(self): + def group_owners(self) -> List[Dict[str, str]]: return [ {"mail": OWNER_ONE_EMAIL}, {"userPrincipalName": OWNER_TWO_USER_PRINCIPAL_NAME}, ] @property - def group(self): + def group(self) -> Dict[str, str]: return {"id": GROUP_ONE_ID} @property @@ -2210,47 +2249,47 @@ def site_users(self): ] @property - def site_list_role_assignments(self): + def site_list_role_assignments(self) -> Dict[str, List[str]]: return {"value": ["role"]} @property - def site_group_users(self): + def site_group_users(self) -> List[Dict[str, str]]: return SAMPLE_SITE_GROUP_USERS @property - def users_and_groups_for_role_assignments(self): + def users_and_groups_for_role_assignments(self) -> List[str]: return [USER_ONE_EMAIL, GROUP_ONE] @property - def site_list_item_role_assignments(self): + def site_list_item_role_assignments(self) -> Dict[str, List[str]]: return {"value": ["role"]} @property - def site_list_item_has_unique_role_assignments(self): + def site_list_item_has_unique_role_assignments(self) -> bool: return True @property - def site_page_has_unique_role_assignments(self): + def site_page_has_unique_role_assignments(self) -> bool: return True @property - def site_page_role_assignments(self): + def site_page_role_assignments(self) -> Dict[str, List[str]]: return {"value": ["role"]} @property - def graph_api_token(self): + def graph_api_token(self) -> str: 
return "graph bearer" @property - def rest_api_token(self): + def rest_api_token(self) -> str: return "rest bearer" @property - def valid_tenant(self): + def valid_tenant(self) -> Dict[str, str]: return {"NameSpaceType": "VALID"} @property - def drive_items_delta(self): + def drive_items_delta(self) -> List[DriveItemsPage]: return [ DriveItemsPage( items=[ @@ -2268,11 +2307,13 @@ def drive_items_delta(self): ] @property - def drive_item_permissions(self): + def drive_item_permissions( + self, + ) -> List[Dict[str, Union[Dict[str, Dict[str, str]], str]]]: return SAMPLE_DRIVE_PERMISSIONS @property - def drive_items_permissions_batch(self): + def drive_items_permissions_batch(self) -> List[Dict[str, Any]]: return [ { "id": drive_item_permission["id"], @@ -2342,14 +2383,16 @@ async def patch_sharepoint_client(self): yield client - def drive_items_func(self, drive_id, url=None): + def drive_items_func(self, drive_id, url=None) -> AsyncIterator: if not url: return AsyncIterator(self.drive_items) else: return AsyncIterator(self.drive_items_delta) @pytest.mark.asyncio - async def test_get_docs_without_access_control(self, patch_sharepoint_client): + async def test_get_docs_without_access_control( + self, patch_sharepoint_client + ) -> None: async with create_spo_source() as source: source._dls_enabled = Mock(return_value=False) @@ -2397,7 +2440,7 @@ async def test_get_docs_without_access_control(self, patch_sharepoint_client): ALLOW_ACCESS_CONTROL_PATCHED, ) @freeze_time(iso_utc()) - async def test_get_docs_with_access_control(self, patch_sharepoint_client): + async def test_get_docs_with_access_control(self, patch_sharepoint_client) -> None: group = "group" email = "email" user = "user" @@ -2501,7 +2544,7 @@ async def test_get_docs_with_access_control(self, patch_sharepoint_client): @pytest.mark.parametrize("sync_cursor", [None, {}]) async def test_get_docs_incrementally_with_empty_cursor( self, patch_sharepoint_client, sync_cursor - ): + ) -> None: async with create_spo_source() as source: with pytest.raises(SyncCursorEmpty): async for ( @@ -2513,7 +2556,7 @@ async def test_get_docs_incrementally_with_empty_cursor( @pytest.mark.asyncio @freeze_time(iso_utc()) - async def test_get_docs_incrementally(self, patch_sharepoint_client): + async def test_get_docs_incrementally(self, patch_sharepoint_client) -> None: async with create_spo_source() as source: source._site_access_control = AsyncMock(return_value=([], [])) # mock cache lookup @@ -2571,7 +2614,7 @@ async def test_get_docs_incrementally(self, patch_sharepoint_client): assert (operations["delete"]) == deleted @pytest.mark.asyncio - async def test_site_lists(self, patch_sharepoint_client): + async def test_site_lists(self, patch_sharepoint_client) -> None: async with create_spo_source( use_document_level_security=True, fetch_unique_list_permissions=False ) as source: @@ -2592,7 +2635,7 @@ async def test_site_lists(self, patch_sharepoint_client): @pytest.mark.asyncio async def test_site_lists_with_unique_role_assignments( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source(use_document_level_security=True) as source: site = {"id": "1", "webUrl": "some url"} site_access_control = ["some site specific access control"] @@ -2626,7 +2669,7 @@ async def test_site_lists_with_unique_role_assignments( patch_sharepoint_client.site_list_role_assignments.assert_called_once() @pytest.mark.asyncio - async def test_download_function_for_folder(self): + async def test_download_function_for_folder(self) -> None: async with 
create_spo_source() as source: drive_item = { "name": "folder", @@ -2740,7 +2783,7 @@ async def test_download_function_for_folder(self): @pytest.mark.asyncio async def test_with_drive_item_permissions( self, drive_item_permissions, site_group_users, expected_access_control - ): + ) -> None: async with create_spo_source(use_document_level_security=True) as source: drive_item = {"id": 1} @@ -2757,7 +2800,7 @@ async def test_with_drive_item_permissions( @pytest.mark.asyncio async def test_drive_items_batch_with_permissions_when_fetch_drive_item_permissions_enabled( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source(use_document_level_security=True) as source: drive_id = 1 drive_item_ids = ["1", "2"] @@ -2797,7 +2840,7 @@ async def test_drive_items_batch_with_permissions_when_fetch_drive_item_permissi @pytest.mark.asyncio async def test_drive_items_batch_with_permissions_when_fetch_drive_item_permissions_disabled( self, - ): + ) -> None: async with create_spo_source(fetch_drive_item_permissions=False) as source: drive_id = 1 drive_items_batch = [{"id": "1"}, {"id": "2"}] @@ -2820,7 +2863,7 @@ async def test_drive_items_batch_with_permissions_when_fetch_drive_item_permissi @pytest.mark.asyncio async def test_drive_items_batch_with_permissions_when_dls_disabled( self, - ): + ) -> None: async with create_spo_source(use_document_level_security=False) as source: drive_id = 1 drive_items_batch = [{"id": "1"}, {"id": "2"}] @@ -2843,7 +2886,7 @@ async def test_drive_items_batch_with_permissions_when_dls_disabled( @pytest.mark.asyncio async def test_drive_items_batch_with_permissions_for_delta_delete_operation( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source(use_document_level_security=True) as source: drive_id = 1 drive_item_ids = ["1", "2"] @@ -2887,7 +2930,7 @@ async def test_drive_items_batch_with_permissions_for_delta_delete_operation( ) async def test_drive_items_permissions_when_fetch_drive_item_permissions_enabled( self, patch_sharepoint_client - ): + ) -> None: group = _prefix_group("do-not-inherit-me") email = _prefix_email("should-not@be-inherited.com") user = _prefix_user("sorry-no-access-here") @@ -2937,7 +2980,7 @@ async def test_drive_items_permissions_when_fetch_drive_item_permissions_enabled ) async def test_site_page_permissions_when_fetch_drive_item_permissions_enabled( self, patch_sharepoint_client - ): + ) -> None: admin_email = _prefix_email("hello@iam-admin.com") admin_user = _prefix_user("admin-so-i-can-access-your-data") admin_site_access_controls = [admin_email, admin_user] @@ -2962,7 +3005,7 @@ async def test_site_page_permissions_when_fetch_drive_item_permissions_enabled( ) @pytest.mark.asyncio - async def test_download_function_for_deleted_item(self): + async def test_download_function_for_deleted_item(self) -> None: async with create_spo_source() as source: # deleted items don't have `name` property drive_item = {"id": "testid", "deleted": {"state": "deleted"}} @@ -2972,7 +3015,7 @@ async def test_download_function_for_deleted_item(self): assert download_result is None @pytest.mark.asyncio - async def test_download_function_for_unsupported_file(self): + async def test_download_function_for_unsupported_file(self) -> None: async with create_spo_source() as source: drive_item = { "id": "testid", @@ -2987,7 +3030,7 @@ async def test_download_function_for_unsupported_file(self): assert download_result is None @pytest.mark.asyncio - async def test_download_function_with_filtering_rule(self): + async def 
test_download_function_with_filtering_rule(self) -> None: async with create_spo_source() as source: max_drive_item_age = 15 drive_item = { @@ -3001,7 +3044,7 @@ async def test_download_function_with_filtering_rule(self): assert download_result is None - def test_get_default_configuration(self): + def test_get_default_configuration(self) -> None: config = SharepointOnlineDataSource.get_default_configuration() assert config is not None @@ -3009,7 +3052,7 @@ def test_get_default_configuration(self): @pytest.mark.asyncio async def test_validate_config_empty_config_with_secret_auth( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source( tenant_id="", tenant_name="", @@ -3028,7 +3071,7 @@ async def test_validate_config_empty_config_with_secret_auth( @pytest.mark.asyncio async def test_validate_config_empty_config_with_cert_auth( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source( tenant_id="", tenant_name="", client_id="", auth_method="certificate" ) as source: @@ -3042,7 +3085,7 @@ async def test_validate_config_empty_config_with_cert_auth( assert e.match("Content of private key file") @pytest.mark.asyncio - async def test_validate_config(self, patch_sharepoint_client): + async def test_validate_config(self, patch_sharepoint_client) -> None: async with create_spo_source() as source: await source.validate_config() @@ -3054,7 +3097,9 @@ async def test_validate_config(self, patch_sharepoint_client): patch_sharepoint_client.site_collections.assert_not_called() @pytest.mark.asyncio - async def test_validate_config_when_invalid_tenant(self, patch_sharepoint_client): + async def test_validate_config_when_invalid_tenant( + self, patch_sharepoint_client + ) -> None: invalid_tenant_name = "wat" async with create_spo_source( @@ -3072,7 +3117,7 @@ async def test_validate_config_when_invalid_tenant(self, patch_sharepoint_client @pytest.mark.asyncio async def test_validate_config_non_existing_collection( self, patch_sharepoint_client - ): + ) -> None: non_existing_site = "something" another_non_existing_site = "something-something" @@ -3089,7 +3134,7 @@ async def test_validate_config_non_existing_collection( @pytest.mark.asyncio async def test_validate_config_with_existing_collection_fetching_all_sites( self, patch_sharepoint_client - ): + ) -> None: existing_site = "site-1" async with create_spo_source( @@ -3100,7 +3145,7 @@ async def test_validate_config_with_existing_collection_fetching_all_sites( @pytest.mark.asyncio async def test_validate_config_with_existing_collection_fetching_individual_sites( self, patch_sharepoint_client - ): + ) -> None: existing_site = "site_1" async with create_spo_source( @@ -3109,7 +3154,7 @@ async def test_validate_config_with_existing_collection_fetching_individual_site await source.validate_config() @pytest.mark.asyncio - async def test_get_attachment_content(self, patch_sharepoint_client): + async def test_get_attachment_content(self, patch_sharepoint_client) -> None: attachment = {"odata.id": "1", "_original_filename": "file.ppt"} message = b"This is content of attachment" @@ -3126,7 +3171,7 @@ async def download_func(attachment_id, async_buffer): @pytest.mark.asyncio async def test_get_attachment_content_unsupported_file_type( self, patch_sharepoint_client - ): + ) -> None: filename = "file.unsupported_extention" attachment = { "odata.id": "1", @@ -3156,7 +3201,7 @@ async def download_func(attachment_id, async_buffer): ) async def test_get_attachment_with_text_extraction_enabled_adds_body( self, patch_sharepoint_client - 
): + ) -> None: attachment = {"odata.id": "1", "_original_filename": "file.ppt"} message = "This is the text content of drive item" @@ -3191,7 +3236,7 @@ async def download_func(attachment_id, async_buffer): ) async def test_get_attachment_with_text_extraction_enabled_but_not_configured_adds_empty_string( self, patch_sharepoint_client - ): + ) -> None: attachment = {"odata.id": "1", "_original_filename": "file.ppt"} message = "This is the text content of drive item" @@ -3229,7 +3274,7 @@ async def download_func(attachment_id, async_buffer): ) async def test_get_drive_item_content( self, patch_sharepoint_client, filesize, expect_download - ): + ) -> None: drive_item = { "id": "1", "size": filesize, @@ -3262,7 +3307,7 @@ async def download_func(drive_id, drive_item_id, async_buffer): ) async def test_get_content_with_text_extraction_enabled_adds_body( self, patch_sharepoint_client, filesize - ): + ) -> None: drive_item = { "id": "1", "size": filesize, @@ -3304,7 +3349,7 @@ async def download_func(drive_id, drive_item_id, async_buffer): ) async def test_get_content_with_text_extraction_enabled_but_not_configured_adds_empty_string( self, patch_sharepoint_client, filesize - ): + ) -> None: drive_item = { "id": "1", "size": filesize, @@ -3339,7 +3384,7 @@ async def download_func(drive_id, drive_item_id, async_buffer): assert "_attachment" not in download_result @pytest.mark.asyncio - async def test_site_access_control(self, patch_sharepoint_client): + async def test_site_access_control(self, patch_sharepoint_client) -> None: async with create_spo_source(use_document_level_security=True) as source: patch_sharepoint_client._validate_sharepoint_rest_url = Mock() @@ -3411,7 +3456,7 @@ async def test_site_access_control(self, patch_sharepoint_client): ) async def test_decorate_with_access_control( self, _dls_enabled, document, access_control, expected_decorated_document - ): + ) -> None: async with create_spo_source(use_document_level_security=True) as source: decorated_document = source._decorate_with_access_control( document, access_control @@ -3452,7 +3497,7 @@ async def test_decorate_with_access_control( @pytest.mark.asyncio async def test_dls_enabled( self, dls_feature_flag, dls_config_value, expected_dls_enabled - ): + ) -> None: async with create_spo_source() as source: source._features = Mock() source._features.document_level_security_enabled = Mock( @@ -3465,7 +3510,7 @@ async def test_dls_enabled( assert source._dls_enabled() == expected_dls_enabled @pytest.mark.asyncio - async def test_dls_disabled_with_features_missing(self): + async def test_dls_disabled_with_features_missing(self) -> None: async with create_spo_source() as source: source._features = None @@ -3477,7 +3522,7 @@ async def test_dls_disabled_with_features_missing(self): TIMESTAMP_FORMAT_PATCHED, ) @pytest.mark.asyncio - async def test_user_access_control_doc(self, patch_sharepoint_client): + async def test_user_access_control_doc(self, patch_sharepoint_client) -> None: async with create_spo_source() as source: created_at = "2023-05-25T13:30:54Z" group_one = {"id": "group-one-id"} @@ -3519,7 +3564,7 @@ async def test_user_access_control_doc(self, patch_sharepoint_client): @pytest.mark.asyncio async def test_user_access_control_doc_with_null_created_date_time( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source() as source: patch_sharepoint_client.groups_user_transitive_member_of = AsyncIterator([]) @@ -3536,7 +3581,9 @@ async def test_user_access_control_doc_with_null_created_date_time( assert 
user_doc["created_at"] is None @pytest.mark.asyncio - async def test_get_access_control_with_dls_disabled(self, patch_sharepoint_client): + async def test_get_access_control_with_dls_disabled( + self, patch_sharepoint_client + ) -> None: async with create_spo_source() as source: patch_sharepoint_client.site_collections = AsyncIterator( [{"siteCollection": {"hostname": "localhost"}}] @@ -3557,7 +3604,7 @@ async def test_get_access_control_with_dls_disabled(self, patch_sharepoint_clien @pytest.mark.asyncio async def test_get_access_control_with_dls_enabled_and_fetch_all_users( self, patch_sharepoint_client - ): + ) -> None: async with create_spo_source(use_document_level_security=True) as source: group = {"@odata.type": "#microsoft.graph.group", "id": "doop"} member_email = "member@acme.co" @@ -3590,22 +3637,22 @@ async def test_get_access_control_with_dls_enabled_and_fetch_all_users( assert len(user_access_control_docs) == 2 - def test_prefix_group(self): + def test_prefix_group(self) -> None: group = "group" assert _prefix_group(group) == "group:group" - def test_prefix_user(self): + def test_prefix_user(self) -> None: user = "user" assert _prefix_user(user) == "user:user" - def test_prefix_email(self): + def test_prefix_email(self) -> None: email = "email" assert _prefix_email(email) == "email:email" - def test_prefix_user_id(self): + def test_prefix_user_id(self) -> None: user_id = "user id" assert _prefix_user_id(user_id) == "user_id:user id" @@ -3853,7 +3900,7 @@ def test_prefix_user_id(self): @pytest.mark.asyncio async def test_get_access_control_from_role_assignment( self, role_assignment, expected_access_control - ): + ) -> None: async with create_spo_source() as source: access_control = await source._get_access_control_from_role_assignment( role_assignment @@ -3876,5 +3923,5 @@ async def test_get_access_control_from_role_assignment( (None, None), ], ) - def test_get_login_name(self, raw_login_name, expected_login_name): + def test_get_login_name(self, raw_login_name, expected_login_name) -> None: assert _get_login_name(raw_login_name) == expected_login_name diff --git a/tests/sources/test_sharepoint_server.py b/tests/sources/test_sharepoint_server.py index 03faaefc3..92bda444b 100644 --- a/tests/sources/test_sharepoint_server.py +++ b/tests/sources/test_sharepoint_server.py @@ -26,7 +26,10 @@ @asynccontextmanager async def create_sps_source( - ssl_enabled=False, ssl_ca="", retry_count=3, use_text_extraction_service=False + ssl_enabled: bool = False, + ssl_ca: str = "", + retry_count: int = 3, + use_text_extraction_service: bool = False, ): async with create_source( SharepointServerDataSource, @@ -45,7 +48,7 @@ async def create_sps_source( class MockSSL: """This class contains methods which returns dummy ssl context""" - def load_verify_locations(self, cadata): + def load_verify_locations(self, cadata) -> None: """This method verify locations""" pass @@ -53,7 +56,7 @@ def load_verify_locations(self, cadata): class MockResponse: """Mock class for ClientResponse""" - def __init__(self, json, status): + def __init__(self, json, status) -> None: self._json = json self.status = status self.is_success = True @@ -62,7 +65,7 @@ async def json(self): """This Method is used to return a json response""" return self._json - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type, exc, tb) -> None: """Closes an async with block""" pass @@ -74,11 +77,11 @@ async def __aenter__(self): class MockObjectResponse: """Class to mock object response of httpx session.get 
method""" - def __init__(self): + def __init__(self) -> None: """Setup a streamReader object""" self.content = ByteStream - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type, exc, tb) -> None: """Closes an async with block""" pass @@ -96,7 +99,7 @@ async def async_native_coroutine_generator(item): @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: """Tests the ping functionality for ensuring connection to the Sharepoint.""" async with create_sps_source() as source: @@ -109,7 +112,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio -async def test_ping_for_failed_connection_exception(): +async def test_ping_for_failed_connection_exception() -> None: """Tests the ping functionality when connection can not be established to Sharepoint.""" async with create_sps_source(retry_count=0) as source: @@ -124,7 +127,7 @@ async def test_ping_for_failed_connection_exception(): @pytest.mark.asyncio -async def test_validate_config_when_host_url_is_empty(): +async def test_validate_config_when_host_url_is_empty() -> None: """This function test validate_config when host_url is empty""" async with create_sps_source() as source: source.configuration.set_field(name="host_url", value="") @@ -134,7 +137,9 @@ async def test_validate_config_when_host_url_is_empty(): @pytest.mark.asyncio -async def test_validate_config_for_ssl_enabled_when_ssl_ca_not_empty_does_not_raise_error(): +async def test_validate_config_for_ssl_enabled_when_ssl_ca_not_empty_does_not_raise_error() -> ( + None +): """This function test validate_config when ssl is enabled and certificate is missing""" with patch.object(ssl, "create_default_context", return_value=MockSSL()): async with create_sps_source( @@ -159,14 +164,14 @@ async def test_validate_config_for_ssl_enabled_when_ssl_ca_not_empty_does_not_ra @pytest.mark.asyncio -async def test_validate_config_for_ssl_enabled_when_ssl_ca_empty_raises_error(): +async def test_validate_config_for_ssl_enabled_when_ssl_ca_empty_raises_error() -> None: async with create_sps_source(ssl_enabled=True) as source: with pytest.raises(ConfigurableFieldValueError): await source.validate_config() @pytest.mark.asyncio -async def test_api_call_for_exception(): +async def test_api_call_for_exception() -> None: """This function test _api_call when credentials are incorrect""" async with create_sps_source(retry_count=0) as source: with patch.object( @@ -177,7 +182,7 @@ async def test_api_call_for_exception(): @pytest.mark.asyncio -async def test_prepare_drive_items_doc(): +async def test_prepare_drive_items_doc() -> None: """Test the prepare drive items method""" async with create_sps_source() as source: list_items = { @@ -215,7 +220,7 @@ async def test_prepare_drive_items_doc(): @pytest.mark.asyncio -async def test_prepare_list_items_doc(): +async def test_prepare_list_items_doc() -> None: """Test the prepare list items method""" async with create_sps_source() as source: list_items = { @@ -255,7 +260,7 @@ async def test_prepare_list_items_doc(): @pytest.mark.asyncio -async def test_prepare_sites_doc(): +async def test_prepare_sites_doc() -> None: """Test the method for preparing sites document""" async with create_sps_source() as source: list_items = { @@ -284,7 +289,7 @@ async def test_prepare_sites_doc(): @pytest.mark.asyncio -async def test_get_sites_when_no_site_available(): +async def test_get_sites_when_no_site_available() -> None: """Test get sites method with valid details""" 
async with create_sps_source() as source: api_response = [] @@ -299,7 +304,7 @@ async def test_get_sites_when_no_site_available(): @pytest.mark.asyncio -async def test_get_list_items(): +async def test_get_list_items() -> None: """Test get list items method with valid details""" api_response = [ { @@ -395,7 +400,7 @@ async def test_get_list_items(): @pytest.mark.asyncio -async def test_get_drive_items(): +async def test_get_drive_items() -> None: """Test get drive items method with valid details""" api_response = [ { @@ -481,7 +486,7 @@ async def test_get_drive_items(): @pytest.mark.asyncio -async def test_get_docs_list_items(): +async def test_get_docs_list_items() -> None: """Test get docs method for list items""" site_content_response = { @@ -533,7 +538,7 @@ async def test_get_docs_list_items(): @pytest.mark.asyncio -async def test_get_docs_list_items_when_relativeurl_is_not_none(): +async def test_get_docs_list_items_when_relativeurl_is_not_none() -> None: """Test get docs method for list items""" site_content_response = { @@ -584,7 +589,7 @@ async def test_get_docs_list_items_when_relativeurl_is_not_none(): @pytest.mark.asyncio -async def test_get_docs_drive_items(): +async def test_get_docs_drive_items() -> None: """Test get docs method for drive items""" site_content_response = { @@ -637,7 +642,7 @@ async def test_get_docs_drive_items(): @pytest.mark.asyncio -async def test_get_docs_drive_items_for_web_pages(): +async def test_get_docs_drive_items_for_web_pages() -> None: site_content_response = { "Title": "ctest", "Id": "f764b597-ed44-49be-8867-f8e9ca5d0a6e", @@ -687,7 +692,7 @@ async def test_get_docs_drive_items_for_web_pages(): @pytest.mark.asyncio -async def test_get_docs_when_no_site_available(): +async def test_get_docs_when_no_site_available() -> None: """Test get docs when site is not available method""" async with create_sps_source() as source: @@ -705,7 +710,7 @@ async def test_get_docs_when_no_site_available(): @pytest.mark.asyncio -async def test_get_content(): +async def test_get_content() -> None: """Test the get content method""" response_content = "This is a dummy sharepoint body response" @@ -738,7 +743,7 @@ async def test_get_content(): @pytest.mark.asyncio -async def test_get_content_with_content_extraction(): +async def test_get_content_with_content_extraction() -> None: response_content = "This is a dummy sharepoint body response" with ( patch( @@ -782,7 +787,7 @@ class ContentResponse: @pytest.mark.asyncio -async def test_get_content_when_size_is_bigger(): +async def test_get_content_when_size_is_bigger() -> None: """Test the get content method when document size is greater than the allowed size limit.""" document = { "id": 1, @@ -802,7 +807,7 @@ async def test_get_content_when_size_is_bigger(): @pytest.mark.asyncio -async def test_get_content_when_doit_is_none(): +async def test_get_content_when_doit_is_none() -> None: """Test the get content method when doit is None""" document = { "id": 1, @@ -820,7 +825,7 @@ async def test_get_content_when_doit_is_none(): @pytest.mark.asyncio -async def test_fetch_data_with_query_sites(): +async def test_fetch_data_with_query_sites() -> None: """Test get invoke call for sites""" get_response = { "value": [ @@ -854,7 +859,7 @@ async def test_fetch_data_with_query_sites(): @pytest.mark.asyncio -async def test_fetch_data_with_query_list(): +async def test_fetch_data_with_query_list() -> None: """Test get invoke call for list""" get_response = { "value": [ @@ -888,7 +893,7 @@ async def test_fetch_data_with_query_list(): 
@pytest.mark.asyncio -async def test_fetch_data_with_next_url_items(): +async def test_fetch_data_with_next_url_items() -> None: """Test get invoke call for drive item""" get_response = { "value": [ @@ -923,7 +928,7 @@ async def test_fetch_data_with_next_url_items(): @pytest.mark.asyncio -async def test_fetch_data_with_next_url_list_items(): +async def test_fetch_data_with_next_url_list_items() -> None: """Test get invoke call when for list item""" get_response = { "value": [ @@ -960,13 +965,13 @@ async def test_fetch_data_with_next_url_list_items(): class ClientSession: """Mock Client Session Class""" - async def aclose(self): + async def aclose(self) -> None: """Close method of Mock Client Session Class""" pass @pytest.mark.asyncio -async def test_close_with_client_session(): +async def test_close_with_client_session() -> None: """Test close method of SharepointServerDataSource with client session""" async with create_sps_source() as source: @@ -976,7 +981,7 @@ async def test_close_with_client_session(): @pytest.mark.asyncio -async def test_close_without_client_session(): +async def test_close_without_client_session() -> None: """Test close method of SharepointServerDataSource without client session""" async with create_sps_source() as source: @@ -986,7 +991,7 @@ async def test_close_without_client_session(): @pytest.mark.asyncio -async def test_api_call_negative(patch_default_wait_multiplier): +async def test_api_call_negative(patch_default_wait_multiplier) -> None: """Tests the _api_call function while getting an exception.""" async with create_sps_source(retry_count=2) as source: @@ -999,7 +1004,7 @@ async def test_api_call_negative(patch_default_wait_multiplier): @pytest.mark.asyncio -async def test_api_call_successfully(): +async def test_api_call_successfully() -> None: """Tests the _api_call function.""" async with create_sps_source() as source: @@ -1032,7 +1037,7 @@ class ClientErrorException: @pytest.mark.asyncio async def test_api_call_when_status_429_exception( patch_default_wait_multiplier, caplog -): +) -> None: async with create_sps_source(retry_count=2) as source: mock_response = {"access_token": "test2344", "expires_in": "1234555"} async_response = MockResponse(mock_response, 429) @@ -1052,7 +1057,9 @@ async def test_api_call_when_status_429_exception( @pytest.mark.asyncio -async def test_api_call_when_server_is_down(patch_default_wait_multiplier, caplog): +async def test_api_call_when_server_is_down( + patch_default_wait_multiplier, caplog +) -> None: """Tests the _api_call function while server gets disconnected.""" async with create_sps_source(retry_count=2) as source: mock_response = {"access_token": "test2344", "expires_in": "1234555"} @@ -1072,7 +1079,7 @@ async def test_api_call_when_server_is_down(patch_default_wait_multiplier, caplo @pytest.mark.asyncio -async def test_get_session(): +async def test_get_session() -> None: """Test that the instance of session returned is always the same for the datasource class.""" async with create_sps_source() as source: first_instance = source.sharepoint_client._get_session() @@ -1081,7 +1088,7 @@ async def test_get_session(): @pytest.mark.asyncio -async def test_get_site_pages_content(): +async def test_get_site_pages_content() -> None: EXPECTED_ATTACHMENT = { "id": 1, "server_relative_url": "/url", @@ -1114,7 +1121,7 @@ async def coroutine_generator(item): @pytest.mark.asyncio -async def test_get_site_pages_content_when_doit_is_none(): +async def test_get_site_pages_content_when_doit_is_none() -> None: document = {"title": 
"Home.aspx", "type": "File", "size": 1000000} async with create_sps_source() as source: response_content = await source.get_site_pages_content( @@ -1127,7 +1134,7 @@ async def test_get_site_pages_content_when_doit_is_none(): @pytest.mark.asyncio -async def test_get_site_pages_content_for_canvascontent1_none(): +async def test_get_site_pages_content_for_canvascontent1_none() -> None: async with create_sps_source() as source: EXPECTED_ATTACHMENT = {"title": "Home.aspx", "type": "File", "size": "1000000"} response_content = await source.get_site_pages_content( @@ -1139,7 +1146,7 @@ async def test_get_site_pages_content_for_canvascontent1_none(): @pytest.mark.asyncio -async def test_get_list_items_with_no_extension(): +async def test_get_list_items_with_no_extension() -> None: api_response = [ { "AttachmentFiles": [ @@ -1207,7 +1214,7 @@ async def test_get_list_items_with_no_extension(): @pytest.mark.asyncio -async def test_get_list_items_with_extension_only(): +async def test_get_list_items_with_extension_only() -> None: api_response = [ { "AttachmentFiles": [ @@ -1280,7 +1287,7 @@ async def create_fake_coroutine(data): @pytest.mark.asyncio -async def test_get_access_control(): +async def test_get_access_control() -> None: async with create_sps_source() as source: source._dls_enabled = Mock(return_value=True) source.sharepoint_client.site_collections_path = ["collection1"] @@ -1334,7 +1341,7 @@ async def test_get_access_control(): @pytest.mark.asyncio -async def test_get_access_control_with_dls_disabled(): +async def test_get_access_control_with_dls_disabled() -> None: async with create_sps_source() as source: source._dls_enabled = Mock(return_value=False) access_control = [] @@ -1346,7 +1353,7 @@ async def test_get_access_control_with_dls_disabled(): @pytest.mark.asyncio -async def test_get_docs_with_dls_enabled(): +async def test_get_docs_with_dls_enabled() -> None: async with create_sps_source() as source: source.sharepoint_client.site_collections_path = ["collection1"] source.sharepoint_client.fix_relative_url = Mock(return_value="sites/site1") @@ -1559,7 +1566,7 @@ async def test_get_docs_with_dls_enabled(): @pytest.mark.asyncio -async def test_site_list_item_role_assignments(): +async def test_site_list_item_role_assignments() -> None: api_response = {"value": [{"verifying_calls_only": True}]} async with create_sps_source() as source: source.sharepoint_client._api_call = Mock( @@ -1572,7 +1579,7 @@ async def test_site_list_item_role_assignments(): @pytest.mark.asyncio -async def test_site_role_assignments_using_title(): +async def test_site_role_assignments_using_title() -> None: api_response = {"value": [{"verifying_calls_only": True}]} async with create_sps_source() as source: source.sharepoint_client._api_call = Mock( @@ -1585,7 +1592,7 @@ async def test_site_role_assignments_using_title(): @pytest.mark.asyncio -async def test_site_admins(): +async def test_site_admins() -> None: api_response = {"value": [{"verifying_calls_only": True}]} async with create_sps_source() as source: source.sharepoint_client._api_call = Mock( @@ -1596,7 +1603,7 @@ async def test_site_admins(): @pytest.mark.asyncio -async def test_site_role_assignments(): +async def test_site_role_assignments() -> None: api_response = {"value": [{"verifying_calls_only": True}]} async with create_sps_source() as source: source.sharepoint_client._api_call = Mock( @@ -1607,7 +1614,7 @@ async def test_site_role_assignments(): @pytest.mark.asyncio -async def test_site_list_has_unique_role_assignments(): +async def 
test_site_list_has_unique_role_assignments() -> None: api_response = {"value": True} async with create_sps_source() as source: source.sharepoint_client._api_call = Mock( @@ -1620,7 +1627,7 @@ async def test_site_list_has_unique_role_assignments(): @pytest.mark.asyncio -async def test_site_list_item_has_unique_role_assignments(): +async def test_site_list_item_has_unique_role_assignments() -> None: api_response = {"value": True} async with create_sps_source() as source: source.sharepoint_client._api_call = Mock( diff --git a/tests/sources/test_slack.py b/tests/sources/test_slack.py index e40f9c34d..48640495f 100644 --- a/tests/sources/test_slack.py +++ b/tests/sources/test_slack.py @@ -44,7 +44,7 @@ async def slack_data_source(): @pytest.mark.asyncio -async def test_slack_client_list_channels(slack_client, mock_responses): +async def test_slack_client_list_channels(slack_client, mock_responses) -> None: page1 = { "ok": True, "channels": [{"name": "channel1", "is_member": True}], @@ -75,7 +75,7 @@ async def test_slack_client_list_channels(slack_client, mock_responses): @pytest.mark.asyncio -async def test_slack_client_list_messages(slack_client, mock_responses): +async def test_slack_client_list_messages(slack_client, mock_responses) -> None: timestamp1 = 1690674765 timestamp2 = 1690665000 timestamp3 = 1690761165 @@ -119,7 +119,7 @@ async def test_slack_client_list_messages(slack_client, mock_responses): @pytest.mark.asyncio -async def test_slack_client_list_users(slack_client, mock_responses): +async def test_slack_client_list_users(slack_client, mock_responses) -> None: response_data = {"ok": True, "members": [{"id": "user1"}]} mock_responses.get( "https://slack.com/api/users.list?limit=200", @@ -136,7 +136,7 @@ async def test_slack_client_list_users(slack_client, mock_responses): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_handle_throttled_error(slack_client, mock_responses): +async def test_handle_throttled_error(slack_client, mock_responses) -> None: channel = {"id": 1, "name": "test"} error_response_data = {"ok": False, "error": "rate_limited"} response_data = {"ok": True, "messages": [{"text": "message", "type": "message"}]} @@ -161,7 +161,7 @@ async def test_handle_throttled_error(slack_client, mock_responses): @pytest.mark.asyncio -async def test_ping(slack_client, mock_responses): +async def test_ping(slack_client, mock_responses) -> None: response_data = {"ok": True} mock_responses.get( "https://slack.com/api/auth.test", @@ -173,7 +173,7 @@ async def test_ping(slack_client, mock_responses): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_bad_ping(slack_client, mock_responses): +async def test_bad_ping(slack_client, mock_responses) -> None: response_data = {"ok": False, "error": "not_authed"} mock_responses.get( "https://slack.com/api/auth.test", @@ -188,7 +188,7 @@ async def test_bad_ping(slack_client, mock_responses): @pytest.mark.asyncio -async def test_slack_data_source_get_docs(slack_data_source, mock_responses): +async def test_slack_data_source_get_docs(slack_data_source, mock_responses) -> None: users_response = [{"id": "user1"}] channels_response = [{"id": "1", "name": "channel1", "is_member": True}] messages_response = [{"text": "message1", "type": "message", "ts": 123456}] @@ -225,7 +225,7 @@ async def test_slack_data_source_get_docs(slack_data_source, mock_responses): @pytest.mark.asyncio -async def 
test_slack_data_source_convert_usernames(slack_data_source): +async def test_slack_data_source_convert_usernames(slack_data_source) -> None: usernames = {"USERID1": "user_one"} message = {"text": "<@USERID1> Hello, <@USERID2>", "ts": 1} channel = {"id": 12345, "name": "channel"} diff --git a/tests/sources/test_zoom.py b/tests/sources/test_zoom.py index a4cf20fad..3a08c6239 100644 --- a/tests/sources/test_zoom.py +++ b/tests/sources/test_zoom.py @@ -6,8 +6,9 @@ """Tests the Zoom source class methods""" from contextlib import asynccontextmanager +from typing import Dict, List, Optional, Union from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import aiohttp import pytest @@ -26,7 +27,7 @@ "next_page_token": "page2", "users": [{"id": "user1", "type": "user", "name": "admin"}], } -SAMPLE_USER_PAGE2 = { +SAMPLE_USER_PAGE2: Dict[str, None] = { "next_page_token": None, "users": None, } @@ -41,7 +42,7 @@ ] # Meeting document -SAMPLE_LIVE_MEETING = { +SAMPLE_LIVE_MEETING: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ { @@ -56,7 +57,7 @@ }, ], } -SAMPLE_UPCOMING_MEETING = { +SAMPLE_UPCOMING_MEETING: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ { @@ -71,7 +72,7 @@ }, ], } -SAMPLE_PREVIOUS_MEETING = { +SAMPLE_PREVIOUS_MEETING: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ { @@ -97,7 +98,7 @@ "type": "previous_meeting_detail", "created_at": "2023-03-03T00:00:00Z", } -SAMPLE_PREVIOUS_MEETING_PARTICIPANTS = { +SAMPLE_PREVIOUS_MEETING_PARTICIPANTS: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "participants": [ {"id": "participant1", "type": "participant"}, @@ -115,7 +116,9 @@ {"id": "participant3", "type": "participant"}, ], } -MEETING_EXPECTED_RESPONSE = [ +MEETING_EXPECTED_RESPONSE: List[ + Union[Dict[str, str], Dict[str, Union[None, List[Dict[str, str]], str]]] +] = [ { "id": "meeting1", "type": "live_meeting", @@ -171,25 +174,25 @@ ] # Recording document -SAMPLE_RECORDING_PAGE1 = { +SAMPLE_RECORDING_PAGE1: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ {"id": "recording1", "type": "recording", "start_time": "2023-03-01T00:00:00Z"} ], } -SAMPLE_RECORDING_PAGE2 = { +SAMPLE_RECORDING_PAGE2: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ {"id": "recording2", "type": "recording", "start_time": "2023-02-01T00:00:00Z"} ], } -SAMPLE_RECORDING_PAGE3 = { +SAMPLE_RECORDING_PAGE3: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ {"id": "recording3", "type": "recording", "start_time": "2023-01-01T00:00:00Z"} ], } -SAMPLE_RECORDING_PAGE4 = { +SAMPLE_RECORDING_PAGE4: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "meetings": [ {"id": "recording4", "type": "recording", "start_time": "2023-12-01T00:00:00Z"} @@ -227,7 +230,7 @@ ] # Channel document -SAMPLE_CHANNEL = { +SAMPLE_CHANNEL: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "channels": [ {"id": "channel1", "type": "chat", "date_time": "2023-03-09T00:00:00Z"}, @@ -252,7 +255,7 @@ ] # Chat document -SAMPLE_CHAT = { +SAMPLE_CHAT: Dict[str, Optional[List[Dict[str, str]]]] = { "next_page_token": None, "messages": [ {"id": "chat1", "type": "chat", "date_time": "2023-03-09T00:00:00Z"}, @@ -292,7 +295,7 @@ "_timestamp": "2023-03-09T00:00:00Z", "_attachment": "Q29udGVudA==", } -FILE_EXPECTED_CONTENT_EXTRACTED 
= { +FILE_EXPECTED_CONTENT_EXTRACTED: Dict[str, str] = { "_id": "file1", "_timestamp": "2023-03-09T00:00:00Z", "body": SAMPLE_CONTENT, @@ -331,7 +334,7 @@ "file_name": "file3.png", "download_url": "https://api.zoom.us/v2/download_url", } -SAMPLE_FILE = { +SAMPLE_FILE: Dict[str, Optional[List[Dict[str, Union[int, str]]]]] = { "next_page_token": None, "messages": [ FILE, @@ -389,7 +392,7 @@ class ZoomAsyncMock(mock.AsyncMock): - def __init__(self, data, *args, **kwargs): + def __init__(self, data, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._data = data @@ -400,7 +403,7 @@ async def text(self): return self._data -def get_mock(mock_response): +def get_mock(mock_response) -> AsyncMock: async_mock = mock.AsyncMock() async_mock.__aenter__ = mock.AsyncMock( return_value=ZoomAsyncMock(data=mock_response) @@ -410,7 +413,7 @@ def get_mock(mock_response): @asynccontextmanager async def create_zoom_source( - fetch_past_meeting_details=False, use_text_extraction_service=False + fetch_past_meeting_details: bool = False, use_text_extraction_service: bool = False ): async with create_source( ZoomDataSource, @@ -424,7 +427,7 @@ async def create_zoom_source( yield source -def mock_zoom_apis(url, headers): +def mock_zoom_apis(url, headers) -> AsyncMock: # Users APIS if url == "https://api.zoom.us/v2/users?page_size=300": return get_mock(mock_response=SAMPLE_USER_PAGE1) @@ -510,12 +513,12 @@ def mock_zoom_apis(url, headers): return get_mock(mock_response=None) -def mock_token_response(): +def mock_token_response() -> AsyncMock: return get_mock(mock_response=SAMPLE_ACCESS_TOKEN_RESPONSE) @pytest.mark.asyncio -async def test_fetch_for_successful_call(): +async def test_fetch_for_successful_call() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -533,7 +536,7 @@ async def test_fetch_for_successful_call(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_for_unsuccessful_call(): +async def test_fetch_for_unsuccessful_call() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -551,7 +554,7 @@ async def test_fetch_for_unsuccessful_call(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_for_unauthorized_error(): +async def test_fetch_for_unauthorized_error() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -575,7 +578,7 @@ async def test_fetch_for_unauthorized_error(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_for_notfound_error(): +async def test_fetch_for_notfound_error() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -599,7 +602,7 @@ async def test_fetch_for_notfound_error(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_fetch_for_other_client_error(): +async def test_fetch_for_other_client_error() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -622,7 +625,7 @@ async def test_fetch_for_other_client_error(): @pytest.mark.asyncio -async def test_content_for_successful_call(): +async def test_content_for_successful_call() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -640,7 +643,7 @@ 
async def test_content_for_successful_call(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_content_for_unsuccessful_call(): +async def test_content_for_unsuccessful_call() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -657,7 +660,7 @@ async def test_content_for_unsuccessful_call(): @pytest.mark.asyncio -async def test_scroll_for_successful_response(): +async def test_scroll_for_successful_response() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -676,7 +679,7 @@ async def test_scroll_for_successful_response(): @pytest.mark.asyncio -async def test_scroll_for_empty_response(): +async def test_scroll_for_empty_response() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -695,7 +698,7 @@ async def test_scroll_for_empty_response(): @pytest.mark.asyncio -async def test_validate_config(): +async def test_validate_config() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -710,7 +713,7 @@ async def test_validate_config(): @pytest.mark.asyncio @pytest.mark.parametrize("field", ["account_id", "client_id", "client_secret"]) -async def test_validate_config_missing_fields_then_raise(field): +async def test_validate_config_missing_fields_then_raise(field) -> None: async with create_zoom_source() as source: source.configuration.get_field(field).value = "" @@ -719,7 +722,7 @@ async def test_validate_config_missing_fields_then_raise(field): @pytest.mark.asyncio -async def test_ping_for_successful_connection(): +async def test_ping_for_successful_connection() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -734,7 +737,7 @@ async def test_ping_for_successful_connection(): @pytest.mark.asyncio @patch("connectors.utils.time_to_sleep_between_retries", Mock(return_value=0)) -async def test_ping_for_unsuccessful_connection(): +async def test_ping_for_unsuccessful_connection() -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -755,7 +758,7 @@ async def test_ping_for_unsuccessful_connection(): (FILE_WITH_UNSUPPORTED_EXTENSION, True, None), ], ) -async def test_get_content(attachment, doit, expected_content): +async def test_get_content(attachment, doit, expected_content) -> None: async with create_zoom_source() as source: with mock.patch( "aiohttp.ClientSession.post", @@ -773,7 +776,7 @@ async def test_get_content(attachment, doit, expected_content): @pytest.mark.asyncio -async def test_get_content_with_extraction_service(): +async def test_get_content_with_extraction_service() -> None: with ( patch( "connectors.content_extraction.ContentExtraction.extract_text", @@ -802,7 +805,7 @@ async def test_get_content_with_extraction_service(): @pytest.mark.asyncio @freeze_time("2023-03-09T00:00:00") -async def test_get_docs(): +async def test_get_docs() -> None: document_without_attachment = [] document_with_attachment = [] async with create_zoom_source(fetch_past_meeting_details=True) as source: diff --git a/tests/test_access_control.py b/tests/test_access_control.py index 6c1f34d7e..4561d4cdf 100644 --- a/tests/test_access_control.py +++ b/tests/test_access_control.py @@ -13,7 +13,7 @@ @pytest.mark.asyncio -async def test_access_control_query(): +async def test_access_control_query() -> None: access_control = ["user_1"] 
access_control_query = es_access_control_query(access_control) @@ -46,21 +46,21 @@ async def test_access_control_query(): } -def test_prefix_identity(): +def test_prefix_identity() -> None: prefix = "prefix" identity = "identity" assert prefix_identity(prefix, identity) == f"{prefix}:{identity}" -def test_prefix_identity_with_prefix_none(): +def test_prefix_identity_with_prefix_none() -> None: prefix = None identity = "identity" assert prefix_identity(prefix, identity) is None -def test_prefix_identity_with_identity_none(): +def test_prefix_identity_with_identity_none() -> None: prefix = "prefix" identity = None diff --git a/tests/test_commons.py b/tests/test_commons.py index f1a347beb..3896181c7 100644 --- a/tests/test_commons.py +++ b/tests/test_commons.py @@ -9,7 +9,7 @@ @pytest.mark.asyncio -async def test_async_generation(): +async def test_async_generation() -> None: items = [1, 2, 3] async_generator = AsyncIterator(items) @@ -22,7 +22,7 @@ async def test_async_generation(): @pytest.mark.asyncio -async def test_call_args(): +async def test_call_args() -> None: items = [1] async_generator = AsyncIterator(items) @@ -54,7 +54,7 @@ async def test_call_args(): @pytest.mark.asyncio -async def test_call_kwargs(): +async def test_call_kwargs() -> None: items = [1] async_generator = AsyncIterator(items) @@ -90,7 +90,7 @@ async def test_call_kwargs(): @pytest.mark.asyncio -async def test_assert_not_called(): +async def test_assert_not_called() -> None: items = [] async_generator = AsyncIterator(items) @@ -98,7 +98,7 @@ async def test_assert_not_called(): @pytest.mark.asyncio -async def test_assert_not_called_with_one_call(): +async def test_assert_not_called_with_one_call() -> None: items = [] async_generator = AsyncIterator(items) @@ -111,7 +111,7 @@ async def test_assert_not_called_with_one_call(): @pytest.mark.asyncio -async def test_assert_called_once(): +async def test_assert_called_once() -> None: items = [] async_generator = AsyncIterator(items) @@ -127,7 +127,7 @@ async def test_assert_called_once(): @pytest.mark.asyncio -async def test_assert_called_once_with_two_calls(): +async def test_assert_called_once_with_two_calls() -> None: items = [] async_generator = AsyncIterator(items) @@ -145,7 +145,7 @@ async def test_assert_called_once_with_two_calls(): @pytest.mark.asyncio -async def test_assert_called_once_with_one_arg(): +async def test_assert_called_once_with_one_arg() -> None: items = [1] argument = "some argument" @@ -159,7 +159,7 @@ async def test_assert_called_once_with_one_arg(): @pytest.mark.asyncio -async def test_assert_called_once_with_wrong_arg(): +async def test_assert_called_once_with_wrong_arg() -> None: items = [1] argument = "some argument" @@ -175,7 +175,7 @@ async def test_assert_called_once_with_wrong_arg(): @pytest.mark.asyncio -async def test_assert_called_once_with_one_arg_and_two_calls(): +async def test_assert_called_once_with_one_arg_and_two_calls() -> None: items = [1] argument = "some argument" @@ -195,7 +195,7 @@ async def test_assert_called_once_with_one_arg_and_two_calls(): @pytest.mark.asyncio -async def test_assert_called_once_with_three_args(): +async def test_assert_called_once_with_three_args() -> None: items = [1] argument_one = "some argument one" @@ -211,7 +211,7 @@ async def test_assert_called_once_with_three_args(): @pytest.mark.asyncio -async def test_assert_called_once_with_one_kwarg(): +async def test_assert_called_once_with_one_kwarg() -> None: items = [1] argument = "some argument" @@ -225,7 +225,7 @@ async def 
test_assert_called_once_with_one_kwarg(): @pytest.mark.asyncio -async def test_assert_called_once_with_one_kwarg_and_two_calls(): +async def test_assert_called_once_with_one_kwarg_and_two_calls() -> None: items = [1] argument = "some argument" @@ -245,7 +245,7 @@ async def test_assert_called_once_with_one_kwarg_and_two_calls(): @pytest.mark.asyncio -async def test_assert_called_once_with_wrong_kwarg(): +async def test_assert_called_once_with_wrong_kwarg() -> None: items = [1] argument = "some argument" @@ -261,7 +261,7 @@ async def test_assert_called_once_with_wrong_kwarg(): @pytest.mark.asyncio -async def test_assert_called_once_with_three_kwargs(): +async def test_assert_called_once_with_three_kwargs() -> None: items = [1] argument_one = "some argument one" @@ -285,7 +285,7 @@ async def test_assert_called_once_with_three_kwargs(): @pytest.mark.asyncio -async def test_assert_called_once_with_args_and_kwargs(): +async def test_assert_called_once_with_args_and_kwargs() -> None: items = [1] argument_one = "some argument one" diff --git a/tests/test_config.py b/tests/test_config.py index 3a3b29b9a..5918a8684 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,36 +11,36 @@ from connectors.config import _nest_configs, load_config -HERE = os.path.dirname(__file__) -FIXTURES_DIR = os.path.abspath(os.path.join(HERE, "fixtures")) +HERE: str = os.path.dirname(__file__) +FIXTURES_DIR: str = os.path.abspath(os.path.join(HERE, "fixtures")) -CONFIG_FILE = os.path.join(FIXTURES_DIR, "config.yml") -ES_CONFIG_FILE = os.path.join(FIXTURES_DIR, "entsearch.yml") -ES_CONFIG_INVALID_LOG_LEVEL_FILE = os.path.join( +CONFIG_FILE: str = os.path.join(FIXTURES_DIR, "config.yml") +ES_CONFIG_FILE: str = os.path.join(FIXTURES_DIR, "entsearch.yml") +ES_CONFIG_INVALID_LOG_LEVEL_FILE: str = os.path.join( FIXTURES_DIR, "entsearch_invalid_log_level.yml" ) -def test_bad_config_file(): +def test_bad_config_file() -> None: with pytest.raises(FileNotFoundError): load_config("BEEUUUAH") -def test_config(set_env): +def test_config(set_env) -> None: config = load_config(CONFIG_FILE) assert isinstance(config, dict) assert config["elasticsearch"]["host"] == "http://nowhere.com:9200" assert config["elasticsearch"]["user"] == "elastic" -def test_config_with_ent_search(set_env): +def test_config_with_ent_search(set_env) -> None: with mock.patch.dict(os.environ, {"ENT_SEARCH_CONFIG_PATH": ES_CONFIG_FILE}): config = load_config(CONFIG_FILE) assert config["elasticsearch"]["headers"]["X-Elastic-Auth"] == "SomeYeahValue" assert config["service"]["log_level"] == "DEBUG" -def test_config_with_invalid_log_level(set_env): +def test_config_with_invalid_log_level(set_env) -> None: with mock.patch.dict( os.environ, {"ENT_SEARCH_CONFIG_PATH": ES_CONFIG_INVALID_LOG_LEVEL_FILE} ): @@ -50,7 +50,7 @@ def test_config_with_invalid_log_level(set_env): assert e.match("Unexpected log level.*") -def test_nest_config_when_nested_field_does_not_exist(): +def test_nest_config_when_nested_field_does_not_exist() -> None: config = {} _nest_configs(config, "test.nested.property", 50) @@ -58,7 +58,7 @@ def test_nest_config_when_nested_field_does_not_exist(): assert config["test"]["nested"]["property"] == 50 -def test_nest_config_when_nested_field_exists(): +def test_nest_config_when_nested_field_exists() -> None: config = {"test": {"nested": {"property": 25}}} _nest_configs(config, "test.nested.property", 50) @@ -66,7 +66,7 @@ def test_nest_config_when_nested_field_exists(): assert config["test"]["nested"]["property"] == 50 -def 
test_nest_config_when_root_field_does_not_exist(): +def test_nest_config_when_root_field_does_not_exist() -> None: config = {} _nest_configs(config, "test", 50) @@ -74,7 +74,7 @@ def test_nest_config_when_root_field_does_not_exist(): assert config["test"] == 50 -def test_nest_config_when_root_field_does_exists(): +def test_nest_config_when_root_field_does_exists() -> None: config = {"test": 10} _nest_configs(config, "test", 50) diff --git a/tests/test_connectors_cli.py b/tests/test_connectors_cli.py index 7f7de65bc..23614958f 100644 --- a/tests/test_connectors_cli.py +++ b/tests/test_connectors_cli.py @@ -5,6 +5,7 @@ # import json import os +from typing import Sequence, Union from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -43,7 +44,7 @@ def mock_job_es_client(): @pytest.mark.parametrize("commands", [["-v"], ["--version"]]) -def test_version(commands): +def test_version(commands: Union[None, Sequence[str], str]) -> None: runner = CliRunner() result = runner.invoke(cli, commands) assert result.exit_code == 0 @@ -51,7 +52,7 @@ def test_version(commands): @pytest.mark.parametrize("commands", [["-h"], ["--help"], []]) -def test_help_page(commands): +def test_help_page(commands: Union[None, Sequence[str], str]) -> None: runner = CliRunner() result = runner.invoke(cli, commands) assert "Usage:" in result.output @@ -60,7 +61,7 @@ def test_help_page(commands): @patch("connectors.cli.auth.Auth._Auth__ping_es_client", AsyncMock(return_value=False)) -def test_login_unsuccessful(tmp_path): +def test_login_unsuccessful(tmp_path: Union[None, os.PathLike[str], str]) -> None: runner = CliRunner() with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: result = runner.invoke( @@ -72,7 +73,7 @@ def test_login_unsuccessful(tmp_path): @patch("connectors.cli.auth.Auth._Auth__ping_es_client", AsyncMock(return_value=True)) -def test_login_successful(tmp_path): +def test_login_successful(tmp_path: Union[None, os.PathLike[str], str]) -> None: runner = CliRunner() with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: result = runner.invoke( @@ -84,7 +85,9 @@ def test_login_successful(tmp_path): @patch("connectors.cli.auth.Auth._Auth__ping_es_client", AsyncMock(return_value=True)) -def test_login_successful_with_apikey_method(tmp_path): +def test_login_successful_with_apikey_method( + tmp_path: Union[None, os.PathLike[str], str], +) -> None: runner = CliRunner() api_key = "testapikey" with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: @@ -102,7 +105,9 @@ def test_login_successful_with_apikey_method(tmp_path): @patch("click.confirm") -def test_login_when_credentials_file_exists(mocked_confirm, tmp_path): +def test_login_when_credentials_file_exists( + mocked_confirm, tmp_path: Union[None, os.PathLike[str], str] +) -> None: runner = CliRunner() with runner.isolated_filesystem(temp_dir=tmp_path) as temp_dir: mocked_confirm.return_value = True @@ -119,7 +124,7 @@ def test_login_when_credentials_file_exists(mocked_confirm, tmp_path): assert mocked_confirm.called_once() -def test_connector_help_page(): +def test_connector_help_page() -> None: runner = CliRunner() result = runner.invoke(cli, ["connector", "--help"]) assert result.exit_code == 0 @@ -129,14 +134,14 @@ def test_connector_help_page(): @patch("connectors.cli.connector.Connector.list_connectors", AsyncMock(return_value=[])) -def test_connector_list_no_connectors(): +def test_connector_list_no_connectors() -> None: runner = CliRunner() result = runner.invoke(cli, ["connector", "list"]) assert result.exit_code 
== 0 assert "No connectors found" in result.output -def test_connector_list_one_connector(): +def test_connector_list_one_connector() -> None: runner = CliRunner() connector_index = MagicMock() @@ -169,7 +174,7 @@ def test_connector_list_one_connector(): "connectors.cli.index.Index.index_or_connector_exists", MagicMock(return_value=[False, False]), ) -def test_connector_create(patch_click_confirm): +def test_connector_create(patch_click_confirm) -> None: runner = CliRunner() # configuration for the MongoDB connector @@ -217,8 +222,8 @@ def test_connector_create(patch_click_confirm): MagicMock(return_value=[False, False]), ) def test_connector_create_with_native_flags( - patch_click_confirm, native_flag, input_index_name, expected_index_name -): + patch_click_confirm, native_flag: str, input_index_name, expected_index_name +) -> None: runner = CliRunner() # configuration for the MongoDB connector @@ -261,7 +266,7 @@ def test_connector_create_with_native_flags( "connectors.cli.connector.Connector._Connector__create_api_key", AsyncMock(return_value={"id": "new_api_key_id", "encoded": "encoded_api_key"}), ) -def test_connector_create_from_index(patch_click_confirm): +def test_connector_create_from_index(patch_click_confirm) -> None: runner = CliRunner() # configuration for the MongoDB connector @@ -311,7 +316,7 @@ def test_connector_create_from_index(patch_click_confirm): @patch("click.confirm") def test_connector_create_fails_when_index_or_connector_exists( patch_click_confirm, index_exists, connector_exists, from_index_flag, expected_error -): +) -> None: runner = CliRunner() # configuration for the MongoDB connector @@ -356,7 +361,7 @@ def test_connector_create_fails_when_index_or_connector_exists( "connectors.cli.index.Index.index_or_connector_exists", MagicMock(return_value=[False, False]), ) -def test_connector_create_from_file(): +def test_connector_create_from_file() -> None: runner = CliRunner() # configuration for the MongoDB connector @@ -414,7 +419,7 @@ def test_connector_create_from_file(): "connectors.cli.index.Index.index_or_connector_exists", MagicMock(return_value=[False, False]), ) -def test_connector_create_and_update_the_service_config(): +def test_connector_create_and_update_the_service_config() -> None: runner = CliRunner() connector_id = "new_connector_id" service_type = "mongodb" @@ -472,7 +477,7 @@ def test_connector_create_and_update_the_service_config(): "connectors.cli.index.Index.index_or_connector_exists", MagicMock(return_value=[True, False]), ) -def test_connector_create_native_connector(patched_confirm): +def test_connector_create_native_connector(patched_confirm) -> None: runner = CliRunner() # configuration for the MongoDB connector @@ -528,7 +533,7 @@ def test_connector_create_native_connector(patched_confirm): assert "has been created" in result.output -def test_index_help_page(): +def test_index_help_page() -> None: runner = CliRunner() result = runner.invoke(cli, ["index", "--help"]) assert result.exit_code == 0 @@ -538,14 +543,14 @@ def test_index_help_page(): @patch("connectors.cli.index.Index.list_indices", MagicMock(return_value=[])) -def test_index_list_no_indexes(): +def test_index_list_no_indexes() -> None: runner = CliRunner() result = runner.invoke(cli, ["index", "list"]) assert result.exit_code == 0 assert "No indices found" in result.output -def test_index_list_one_index(): +def test_index_list_one_index() -> None: runner = CliRunner() indices = {"test_index": {"docs_count": 10}} @@ -559,7 +564,7 @@ def test_index_list_one_index(): assert 
"test_index" in result.output -def test_index_list_one_index_in_serverless(): +def test_index_list_one_index_in_serverless() -> None: runner = CliRunner() indices = {"test_index": {"docs_count": 10}} @@ -580,7 +585,7 @@ def test_index_list_one_index_in_serverless(): @patch("click.confirm", MagicMock(return_value=True)) -def test_index_clean(): +def test_index_clean() -> None: runner = CliRunner() index_name = "test_index" with patch( @@ -595,7 +600,7 @@ def test_index_clean(): @patch("click.confirm", MagicMock(return_value=True)) -def test_index_clean_error(): +def test_index_clean_error() -> None: runner = CliRunner() index_name = "test_index" with patch( @@ -609,7 +614,7 @@ def test_index_clean_error(): @patch("click.confirm", MagicMock(return_value=True)) -def test_index_delete(): +def test_index_delete() -> None: runner = CliRunner() index_name = "test_index" with patch( @@ -624,7 +629,7 @@ def test_index_delete(): @patch("click.confirm", MagicMock(return_value=True)) -def test_delete_index_error(): +def test_delete_index_error() -> None: runner = CliRunner() index_name = "test_index" with patch( @@ -637,7 +642,7 @@ def test_delete_index_error(): assert result.exit_code == 0 -def test_job_help_page(): +def test_job_help_page() -> None: runner = CliRunner() result = runner.invoke(cli, ["job", "--help"]) assert result.exit_code == 0 @@ -646,7 +651,7 @@ def test_job_help_page(): assert "Commands:" in result.output -def test_job_help_page_without_subcommands(): +def test_job_help_page_without_subcommands() -> None: runner = CliRunner() result = runner.invoke(cli, ["job"]) assert result.exit_code == 0 @@ -656,7 +661,7 @@ def test_job_help_page_without_subcommands(): @patch("click.confirm", MagicMock(return_value=True)) -def test_job_cancel(): +def test_job_cancel() -> None: runner = CliRunner() job_id = "test_job_id" @@ -688,7 +693,7 @@ def test_job_cancel(): @patch("click.confirm", MagicMock(return_value=True)) -def test_job_cancel_error(): +def test_job_cancel_error() -> None: runner = CliRunner() job_id = "test_job_id" with patch( @@ -701,7 +706,7 @@ def test_job_cancel_error(): assert result.exit_code == 0 -def test_job_list_no_jobs(): +def test_job_list_no_jobs() -> None: runner = CliRunner() connector_id = "test_connector_id" @@ -715,7 +720,7 @@ def test_job_list_no_jobs(): @patch("click.confirm", MagicMock(return_value=True)) -def test_job_list_one_job(): +def test_job_list_one_job() -> None: runner = CliRunner() job_id = "test_job_id" connector_id = "test_connector_id" @@ -766,7 +771,7 @@ def test_job_list_one_job(): "connectors.protocol.connectors.ConnectorIndex.fetch_by_id", AsyncMock(return_value=MagicMock()), ) -def test_job_start(): +def test_job_start() -> None: runner = CliRunner() connector_id = "test_connector_id" job_id = "test_job_id" @@ -782,7 +787,7 @@ def test_job_start(): assert result.exit_code == 0 -def test_job_view(): +def test_job_view() -> None: runner = CliRunner() job_id = "test_job_id" connector_id = "test_connector_id" diff --git a/tests/test_content_extraction.py b/tests/test_content_extraction.py index 9597141b1..e1f55f9e1 100644 --- a/tests/test_content_extraction.py +++ b/tests/test_content_extraction.py @@ -11,7 +11,7 @@ from connectors.content_extraction import ContentExtraction -def test_set_and_get_configuration(): +def test_set_and_get_configuration() -> None: config = { "extraction_service": { "host": "http://localhost:8090", @@ -36,7 +36,7 @@ def test_set_and_get_configuration(): ({"extraction_service": {"not_a_host": "!!!m"}}, False), ], ) -def 
test_check_configured(mock_config, expected_result): +def test_check_configured(mock_config, expected_result) -> None: with patch( "connectors.content_extraction.ContentExtraction.get_extraction_config", return_value=mock_config.get("extraction_service", None), @@ -46,7 +46,7 @@ def test_check_configured(mock_config, expected_result): @pytest.mark.asyncio -async def test_extract_text(mock_responses, patch_logger): +async def test_extract_text(mock_responses, patch_logger) -> None: filepath = "tmp/notreal.txt" url = "http://localhost:8090/extract_text/" payload = {"extracted_text": "I've been extracted!"} @@ -73,7 +73,7 @@ async def test_extract_text(mock_responses, patch_logger): @pytest.mark.asyncio -async def test_extract_text_with_file_pointer(mock_responses, patch_logger): +async def test_extract_text_with_file_pointer(mock_responses, patch_logger) -> None: filepath = "/tmp/notreal.txt" url = "http://localhost:8090/extract_text/?local_file_path=/tmp/notreal.txt" payload = {"extracted_text": "I've been extracted from a local file!"} @@ -104,7 +104,7 @@ async def test_extract_text_with_file_pointer(mock_responses, patch_logger): @pytest.mark.asyncio -async def test_extract_text_when_host_is_none(mock_responses, patch_logger): +async def test_extract_text_when_host_is_none(mock_responses, patch_logger) -> None: filepath = "/tmp/notreal.txt" with ( @@ -134,7 +134,7 @@ async def test_extract_text_when_host_is_none(mock_responses, patch_logger): @pytest.mark.asyncio async def test_extract_text_when_response_isnt_200_logs_warning( mock_responses, patch_logger -): +) -> None: filepath = "tmp/notreal.txt" url = "http://localhost:8090/extract_text/" @@ -167,7 +167,9 @@ async def test_extract_text_when_response_isnt_200_logs_warning( @pytest.mark.asyncio -async def test_extract_text_when_response_is_error(mock_responses, patch_logger): +async def test_extract_text_when_response_is_error( + mock_responses, patch_logger +) -> None: filepath = "tmp/notreal.txt" with ( @@ -194,7 +196,9 @@ async def test_extract_text_when_response_is_error(mock_responses, patch_logger) @pytest.mark.asyncio -async def test_extract_text_when_response_is_timeout(mock_responses, patch_logger): +async def test_extract_text_when_response_is_timeout( + mock_responses, patch_logger +) -> None: filepath = "tmp/notreal.txt" with ( @@ -223,7 +227,7 @@ async def test_extract_text_when_response_is_timeout(mock_responses, patch_logge @pytest.mark.asyncio async def test_extract_text_when_response_is_200_with_error_logs_warning( mock_responses, patch_logger -): +) -> None: filepath = "tmp/notreal.txt" url = "http://localhost:8090/extract_text/" diff --git a/tests/test_kibana.py b/tests/test_kibana.py index 236bcd675..efc771d61 100644 --- a/tests/test_kibana.py +++ b/tests/test_kibana.py @@ -11,11 +11,11 @@ from connectors.es.management_client import ESManagementClient from connectors.kibana import main, upsert_index -HERE = os.path.dirname(__file__) -FIXTURES_DIR = os.path.abspath(os.path.join(HERE, "fixtures")) +HERE: str = os.path.dirname(__file__) +FIXTURES_DIR: str = os.path.abspath(os.path.join(HERE, "fixtures")) -def mock_index_creation(index, mock_responses, hidden=True): +def mock_index_creation(index, mock_responses, hidden: bool = True) -> None: url = f"http://nowhere.com:9200/{index}" headers = {"X-Elastic-Product": "Elasticsearch"} mock_responses.head( @@ -33,7 +33,7 @@ def mock_index_creation(index, mock_responses, hidden=True): @mock.patch.dict(os.environ, {"elasticsearch.password": "changeme"}) -def 
test_main(patch_logger, mock_responses): +def test_main(patch_logger, mock_responses) -> None: headers = {"X-Elastic-Product": "Elasticsearch"} mock_responses.put( @@ -73,7 +73,7 @@ def test_main(patch_logger, mock_responses): @pytest.mark.asyncio -async def test_upsert_index(mock_responses): +async def test_upsert_index(mock_responses) -> None: config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} headers = {"X-Elastic-Product": "Elasticsearch"} mock_responses.post( diff --git a/tests/test_logger.py b/tests/test_logger.py index a507cc621..78c276f47 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -27,13 +27,13 @@ def unset_logger(): connectors.logger.logger = logger -def test_logger(): +def test_logger() -> None: with unset_logger(): logger = set_logger(logging.DEBUG) assert logger.level == logging.DEBUG -def test_logger_filebeat(): +def test_logger_filebeat() -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=True) logs = [] @@ -50,7 +50,7 @@ def _w(msg): assert data["service"]["type"] == "connectors-python" -def test_tracer(): +def test_tracer() -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=True) logs = [] @@ -74,7 +74,7 @@ def traceable(): @pytest.mark.asyncio -async def test_async_tracer(): +async def test_async_tracer() -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=True) logs = [] @@ -97,7 +97,7 @@ async def traceable(): @pytest.mark.asyncio -async def test_async_tracer_slow(): +async def test_async_tracer_slow() -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=True) logs = [] @@ -159,7 +159,7 @@ async def gen(): ("critical", ColorFormatter.BOLD_RED), ], ) -def test_colored_logging(log_level, color): +def test_colored_logging(log_level, color) -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=False) logs = [] @@ -177,7 +177,7 @@ def _w(msg): # first param is UTC time, second param is offset we run with @freeze_time("2024-07-10 12:00:00", tz_offset=-7) -def test_timestamp_is_utc(): +def test_timestamp_is_utc() -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=False) logs = [] @@ -193,7 +193,7 @@ def _w(msg): ) # if the local time was respected, this would be 05:00:00 -def test_colored_logging_with_filebeat(): +def test_colored_logging_with_filebeat() -> None: with unset_logger(): logger = set_logger(logging.DEBUG, filebeat=True) logs = [] diff --git a/tests/test_preflight_check.py b/tests/test_preflight_check.py index 54c5c5ba1..30375063b 100644 --- a/tests/test_preflight_check.py +++ b/tests/test_preflight_check.py @@ -4,6 +4,7 @@ # you may not use this file except in compliance with the Elastic License 2.0. 
# from copy import deepcopy +from typing import Dict, List, Union from unittest.mock import patch import pytest @@ -14,7 +15,9 @@ connectors_version = "1.2.3.4" headers = {"X-Elastic-Product": "Elasticsearch"} host = "http://localhost:9200" -config = { +config: Dict[ + str, Union[Dict[str, float], Dict[str, Union[float, str]], List[Dict[str, str]]] +] = { "elasticsearch": { "host": host, "username": "elastic", @@ -34,11 +37,11 @@ def mock_es_info( mock_responses, - healthy=True, - repeat=False, - es_version=connectors_version, - serverless=False, -): + healthy: bool = True, + repeat: bool = False, + es_version: str = connectors_version, + serverless: bool = False, +) -> None: status = 200 if healthy else 503 payload = { "version": { @@ -51,28 +54,30 @@ def mock_es_info( ) -def mock_index_exists(mock_responses, index, exist=True, repeat=False): +def mock_index_exists( + mock_responses, index, exist: bool = True, repeat: bool = False +) -> None: status = 200 if exist else 404 mock_responses.head(f"{host}/{index}", status=status, repeat=repeat) -def mock_index(mock_responses, index, doc_id, repeat=False): +def mock_index(mock_responses, index, doc_id, repeat: bool = False) -> None: status = 200 mock_responses.put(f"{host}/{index}/_doc/{doc_id}", status=status, repeat=repeat) -def mock_create_index(mock_responses, index, repeat=False): +def mock_create_index(mock_responses, index, repeat: bool = False) -> None: status = 200 mock_responses.put(f"{host}/{index}", status=status, repeat=repeat) -def mock_delete(mock_responses, index, doc_id, repeat=False): +def mock_delete(mock_responses, index, doc_id, repeat: bool = False) -> None: status = 200 mock_responses.delete(f"{host}/{index}/_doc/{doc_id}", status=status, repeat=repeat) @pytest.mark.asyncio -async def test_es_unavailable(mock_responses): +async def test_es_unavailable(mock_responses) -> None: mock_es_info(mock_responses, healthy=False, repeat=True) preflight = PreflightCheck(config, connectors_version) result = await preflight.run() @@ -80,7 +85,7 @@ async def test_es_unavailable(mock_responses): @pytest.mark.asyncio -async def test_connectors_index_missing(mocker, mock_responses): +async def test_connectors_index_missing(mocker, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX, exist=False) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX, exist=True) @@ -93,7 +98,7 @@ async def test_connectors_index_missing(mocker, mock_responses): @pytest.mark.asyncio -async def test_jobs_index_missing(mocker, mock_responses): +async def test_jobs_index_missing(mocker, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX, exist=True) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX, exist=False) @@ -105,7 +110,7 @@ async def test_jobs_index_missing(mocker, mock_responses): @pytest.mark.asyncio -async def test_both_indices_missing(mocker, mock_responses): +async def test_both_indices_missing(mocker, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX, exist=False) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX, exist=False) @@ -118,7 +123,7 @@ async def test_both_indices_missing(mocker, mock_responses): @pytest.mark.asyncio -async def test_pass(mock_responses): +async def test_pass(mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ 
-129,7 +134,7 @@ async def test_pass(mock_responses): @pytest.mark.asyncio @patch("connectors.preflight_check.logger") -async def test_pass_serverless(patched_logger, mock_responses): +async def test_pass_serverless(patched_logger, mock_responses) -> None: mock_es_info(mock_responses, serverless=True) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -143,7 +148,9 @@ async def test_pass_serverless(patched_logger, mock_responses): @pytest.mark.asyncio @patch("connectors.preflight_check.logger") -async def test_pass_serverless_mismatched_versions(patched_logger, mock_responses): +async def test_pass_serverless_mismatched_versions( + patched_logger, mock_responses +) -> None: mock_es_info(mock_responses, es_version="2.0.0-SNAPSHOT", serverless=True) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -171,8 +178,8 @@ async def test_pass_serverless_mismatched_versions(patched_logger, mock_response ], ) async def test_fail_mismatched_version( - patched_logger, mock_responses, es_version, expected_log -): + patched_logger, mock_responses, es_version: str, expected_log +) -> None: mock_es_info(mock_responses, es_version=es_version) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -198,8 +205,8 @@ async def test_fail_mismatched_version( ], ) async def test_warn_mismatched_version( - patched_logger, mock_responses, es_version, expected_log -): + patched_logger, mock_responses, es_version: str, expected_log +) -> None: mock_es_info(mock_responses, es_version=es_version) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -252,8 +259,8 @@ async def test_warn_mismatched_version( ], ) async def test_pass_mismatched_version( - patched_logger, mock_responses, es_version, connectors_version, expected_log -): + patched_logger, mock_responses, es_version: str, connectors_version, expected_log +) -> None: mock_es_info(mock_responses, es_version=es_version) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -264,7 +271,7 @@ async def test_pass_mismatched_version( @pytest.mark.asyncio -async def test_es_transient_error(mock_responses): +async def test_es_transient_error(mock_responses) -> None: mock_es_info(mock_responses, healthy=False) mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) @@ -275,7 +282,7 @@ async def test_es_transient_error(mock_responses): @pytest.mark.asyncio -async def test_index_exist_transient_error(mock_responses): +async def test_index_exist_transient_error(mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX, exist=False) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX, repeat=True) @@ -288,7 +295,7 @@ async def test_index_exist_transient_error(mock_responses): @pytest.mark.asyncio @patch("connectors.preflight_check.logger") -async def test_native_config_is_warned(patched_logger, mock_responses): +async def test_native_config_is_warned(patched_logger, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -311,7 +318,7 @@ async def test_native_config_is_warned(patched_logger, mock_responses): @pytest.mark.asyncio 
@patch("connectors.preflight_check.logger") -async def test_native_config_is_forced(patched_logger, mock_responses): +async def test_native_config_is_forced(patched_logger, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -326,7 +333,7 @@ async def test_native_config_is_forced(patched_logger, mock_responses): @pytest.mark.asyncio @patch("connectors.preflight_check.logger") -async def test_client_config(patched_logger, mock_responses): +async def test_client_config(patched_logger, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -341,7 +348,7 @@ async def test_client_config(patched_logger, mock_responses): @pytest.mark.asyncio @patch("connectors.preflight_check.logger") -async def test_unmodified_default_config(patched_logger, mock_responses): +async def test_unmodified_default_config(patched_logger, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -358,7 +365,7 @@ async def test_unmodified_default_config(patched_logger, mock_responses): @pytest.mark.asyncio @patch("connectors.preflight_check.logger") -async def test_missing_mode_config(patched_logger, mock_responses): +async def test_missing_mode_config(patched_logger, mock_responses) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -374,7 +381,7 @@ async def test_missing_mode_config(patched_logger, mock_responses): @patch("connectors.preflight_check.logger") async def test_extraction_service_enabled_and_found_writes_info_log( patched_logger, mock_responses -): +) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -398,7 +405,7 @@ async def test_extraction_service_enabled_and_found_writes_info_log( @patch("connectors.preflight_check.logger") async def test_extraction_service_enabled_but_missing_logs_warning( patched_logger, mock_responses -): +) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) @@ -422,7 +429,7 @@ async def test_extraction_service_enabled_but_missing_logs_warning( @patch("connectors.preflight_check.logger") async def test_extraction_service_enabled_but_missing_logs_critical( patched_logger, mock_responses -): +) -> None: mock_es_info(mock_responses) mock_index_exists(mock_responses, CONCRETE_CONNECTORS_INDEX) mock_index_exists(mock_responses, CONCRETE_JOBS_INDEX) diff --git a/tests/test_service_cli.py b/tests/test_service_cli.py index aeba9091c..69b6f2b2a 100644 --- a/tests/test_service_cli.py +++ b/tests/test_service_cli.py @@ -17,15 +17,15 @@ from connectors.service_cli import _start_service, get_event_loop, main SUCCESS_EXIT_CODE = 0 -CLICK_EXCEPTION_EXIT_CODE = ClickException.exit_code -USAGE_ERROR_EXIT_CODE = UsageError.exit_code +CLICK_EXCEPTION_EXIT_CODE: int = ClickException.exit_code +USAGE_ERROR_EXIT_CODE: int = UsageError.exit_code -HERE = os.path.dirname(__file__) -FIXTURES_DIR = os.path.abspath(os.path.join(HERE, "fixtures")) -CONFIG = os.path.join(FIXTURES_DIR, "config.yml") +HERE: str = os.path.dirname(__file__) +FIXTURES_DIR: str 
= os.path.abspath(os.path.join(HERE, "fixtures")) +CONFIG: str = os.path.join(FIXTURES_DIR, "config.yml") -def test_main_exits_on_sigterm(mock_responses): +def test_main_exits_on_sigterm(mock_responses) -> None: headers = {"X-Elastic-Product": "Elasticsearch"} host = "http://localhost:9200" @@ -49,7 +49,7 @@ async def kill(): @pytest.mark.parametrize("option", ["-v", "--version"]) -def test_version_action(option): +def test_version_action(option) -> None: runner = CliRunner() result = runner.invoke(main, [option]) @@ -62,7 +62,7 @@ def test_version_action(option): @patch("connectors.service_cli.get_services") async def test_shutdown_signal_registered( patch_get_services, patch_preflight_check, set_env -): +) -> None: patch_multi_service = Mock() patch_get_services.return_value = patch_multi_service patch_multi_service.run = AsyncMock() @@ -74,7 +74,7 @@ async def test_shutdown_signal_registered( ) -def test_list_action(set_env): +def test_list_action(set_env) -> None: runner = CliRunner() config_file = CONFIG @@ -95,7 +95,7 @@ def test_list_action(set_env): assert "Bye" in output -def test_config_with_service_type_actions(set_env): +def test_config_with_service_type_actions(set_env) -> None: runner = CliRunner() config_file = CONFIG @@ -122,7 +122,7 @@ def test_config_with_service_type_actions(set_env): assert "Getting default configuration for service type fake" in output -def test_list_cannot_be_used_with_other_actions(set_env): +def test_list_cannot_be_used_with_other_actions(set_env) -> None: runner = CliRunner() config_file = CONFIG @@ -145,7 +145,7 @@ def test_list_cannot_be_used_with_other_actions(set_env): assert "Cannot use the `list` action with other actions" in result.output -def test_config_cannot_be_used_with_other_actions(set_env): +def test_config_cannot_be_used_with_other_actions(set_env) -> None: runner = CliRunner() config_file = CONFIG @@ -172,7 +172,7 @@ def test_config_cannot_be_used_with_other_actions(set_env): @patch( "connectors.service_cli.load_config", side_effect=Exception("something went wrong") ) -def test_main_with_invalid_configuration(load_config, set_logger): +def test_main_with_invalid_configuration(load_config, set_logger) -> None: runner = CliRunner() log_level = "DEBUG" # should be ignored! @@ -183,7 +183,7 @@ def test_main_with_invalid_configuration(load_config, set_logger): set_logger.assert_called_with(logging.INFO, filebeat=True) -def test_unknown_service_type(set_env): +def test_unknown_service_type(set_env) -> None: runner = CliRunner() config_file = CONFIG @@ -211,7 +211,7 @@ def test_unknown_service_type(set_env): @patch("connectors.service_cli._get_uvloop") @patch("connectors.service_cli.asyncio") -def test_uvloop_success(patched_asyncio, patched_uvloop): +def test_uvloop_success(patched_asyncio, patched_uvloop) -> None: get_event_loop(True) assert patched_asyncio.set_event_loop_policy.called_once_with( patched_uvloop.EventLoopPolicy() @@ -221,7 +221,7 @@ def test_uvloop_success(patched_asyncio, patched_uvloop): @patch("connectors.service_cli._get_uvloop", side_effect=Exception("import fails")) @patch("connectors.service_cli.asyncio") @patch("connectors.service_cli.logger") -def test_uvloop_error(patched_logger, patched_asyncio, patched_uvloop): +def test_uvloop_error(patched_logger, patched_asyncio, patched_uvloop) -> None: get_event_loop(True) patched_logger.warning.assert_any_call( "Unable to enable uvloop: import fails. 
Running with default event loop" diff --git a/tests/test_sink.py b/tests/test_sink.py index 2d9a3492e..27f5a625c 100644 --- a/tests/test_sink.py +++ b/tests/test_sink.py @@ -8,6 +8,7 @@ import itertools import json from copy import deepcopy +from typing import Any, Dict, Optional, Union from unittest import mock from unittest.mock import ANY, AsyncMock, Mock, call, patch @@ -50,16 +51,22 @@ DOC_TWO_ID = 2 DOC_THREE_ID = 3 -DOC_ONE = {"_id": DOC_ONE_ID, "_timestamp": TIMESTAMP} +DOC_ONE: Dict[str, Union[datetime.datetime, int]] = { + "_id": DOC_ONE_ID, + "_timestamp": TIMESTAMP, +} -DOC_ONE_DIFFERENT_TIMESTAMP = { +DOC_ONE_DIFFERENT_TIMESTAMP: Dict[str, Union[datetime.datetime, int]] = { "_id": DOC_ONE_ID, "_timestamp": TIMESTAMP + datetime.timedelta(days=1), } -DOC_TWO = {"_id": 2, "_timestamp": TIMESTAMP} -DOC_THREE = {"_id": 3, "_timestamp": TIMESTAMP} -DOC_FOUR = {"_id": 4, "_timestamp": TIMESTAMP} +DOC_TWO: Dict[str, Union[datetime.datetime, int]] = {"_id": 2, "_timestamp": TIMESTAMP} +DOC_THREE: Dict[str, Union[datetime.datetime, int]] = { + "_id": 3, + "_timestamp": TIMESTAMP, +} +DOC_FOUR: Dict[str, Union[datetime.datetime, int]] = {"_id": 4, "_timestamp": TIMESTAMP} BULK_ACTION_ERROR = "some error" @@ -69,7 +76,7 @@ CONTENT_EXTRACTION_DISABLED = False -def failed_bulk_action(doc_id, action, result, error=BULK_ACTION_ERROR): +def failed_bulk_action(doc_id, action, result, error: str = BULK_ACTION_ERROR): return {action: {"_id": doc_id, "result": result, "error": error}} @@ -77,15 +84,19 @@ def successful_bulk_action(doc_id, action, result): return {action: {"_id": doc_id, "result": result}} -def successful_action_log_message(doc_id, action, result): +def successful_action_log_message(doc_id, action, result) -> str: return f"Successfully executed '{action}' on document with id '{doc_id}'. Result: {result}" -def successful_operation_with_non_successful_result_log_message(doc_id, action, result): +def successful_operation_with_non_successful_result_log_message( + doc_id, action, result +) -> str: return f"Executed '{action}' on document with id '{doc_id}', but got non-successful result: {result}" -def failed_action_log_message(doc_id, action, result, error=BULK_ACTION_ERROR): +def failed_action_log_message( + doc_id, action, result, error: str = BULK_ACTION_ERROR +) -> str: return ( f"Failed to execute '{action}' on document with id '{doc_id}'. 
Error: {error}" ) @@ -95,7 +106,7 @@ def failed_action_log_message(doc_id, action, result, error=BULK_ACTION_ERROR): @pytest.mark.asyncio async def test_prepare_content_index_raise_error_when_index_creation_failed( mock_responses, -): +) -> None: index_name = "search-new-index" config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} headers = {"X-Elastic-Product": "Elasticsearch"} @@ -131,7 +142,7 @@ async def test_prepare_content_index_raise_error_when_index_creation_failed( @pytest.mark.asyncio async def test_prepare_content_index_create_index( mock_responses, -): +) -> None: index_name = "search-new-index" language_code = "jp" config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} @@ -181,7 +192,7 @@ async def test_prepare_content_index_create_index( @patch("connectors.es.sink.CANCELATION_TIMEOUT", -1) @pytest.mark.asyncio -async def test_prepare_content_index(mock_responses): +async def test_prepare_content_index(mock_responses) -> None: language_code = "en" config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} headers = {"X-Elastic-Product": "Elasticsearch"} @@ -214,7 +225,7 @@ async def test_prepare_content_index(mock_responses): create_index_mock.assert_not_called() -def set_responses(mock_responses, ts=None): +def set_responses(mock_responses, ts: Optional[str] = None) -> None: if ts is None: ts = datetime.datetime.now().isoformat() headers = {"X-Elastic-Product": "Elasticsearch"} @@ -359,7 +370,7 @@ async def _dl(doit=True, timestamp=None): await es.close() -def index_operation(doc): +def index_operation(doc) -> Dict[str, Any]: # deepcopy as get_docs mutates docs doc_copy = deepcopy(doc) doc_id = str(doc_copy.pop("_id")) @@ -368,7 +379,7 @@ def index_operation(doc): return {"_op_type": "index", "_index": INDEX, "_id": doc_id, "doc": doc_copy} -def update_operation(doc): +def update_operation(doc) -> Dict[str, Any]: # deepcopy as get_docs mutates docs doc_copy = deepcopy(doc) doc_id = str(doc_copy.pop("_id")) @@ -377,11 +388,11 @@ def update_operation(doc): return {"_op_type": "update", "_index": INDEX, "_id": doc_id, "doc": doc_copy} -def delete_operation(doc): +def delete_operation(doc) -> Dict[str, Any]: return {"_op_type": "delete", "_index": INDEX, "_id": str(doc["_id"])} -def end_docs_operation(): +def end_docs_operation() -> str: return "END_DOCS" @@ -405,7 +416,7 @@ def total_downloads(count): return count -async def queue_mock(): +async def queue_mock() -> Mock: queue = Mock() future = asyncio.Future() future.set_result(1) @@ -438,7 +449,7 @@ def queue_called_with_operations(queue, operations): ) -async def basic_rule_engine_mock(return_values): +async def basic_rule_engine_mock(return_values) -> Mock: basic_rule_engine = Mock() basic_rule_engine.should_ingest = ( Mock(side_effect=list(return_values)) @@ -448,7 +459,7 @@ async def basic_rule_engine_mock(return_values): return basic_rule_engine -async def lazy_downloads_mock(): +async def lazy_downloads_mock() -> Mock: lazy_downloads = Mock() future = asyncio.Future() future.set_result(1) @@ -459,9 +470,9 @@ async def lazy_downloads_mock(): async def setup_extractor( queue, basic_rule_engine=None, - sync_rules_enabled=False, - content_extraction_enabled=False, -): + sync_rules_enabled: bool = False, + content_extraction_enabled: bool = False, +) -> Extractor: config = { "username": "elastic", "password": "changeme", @@ -727,7 +738,7 @@ async def test_get_docs( expected_total_docs_created, expected_total_docs_deleted, expected_total_downloads, 
-): +) -> None: lazy_downloads = await lazy_downloads_mock() yield_existing_documents_metadata.return_value = AsyncIterator( @@ -908,7 +919,7 @@ async def test_get_docs_incrementally( expected_total_docs_created, expected_total_docs_deleted, expected_total_downloads, -): +) -> None: lazy_downloads = await lazy_downloads_mock() with mock.patch("connectors.utils.ConcurrentTasks", return_value=lazy_downloads): @@ -1019,7 +1030,7 @@ async def test_get_access_control_docs( expected_total_docs_updated, expected_total_docs_created, expected_total_docs_deleted, -): +) -> None: yield_existing_documents_metadata.return_value = AsyncIterator( [(str(doc["_id"]), doc["_timestamp"]) for doc in existing_docs] ) @@ -1123,7 +1134,7 @@ async def test_get_access_control_docs( ), ], ) -def test_bulk_populate_stats(res, expected_result): +def test_bulk_populate_stats(res, expected_result) -> None: sink = Sink( client=None, queue=None, @@ -1152,7 +1163,7 @@ def test_bulk_populate_stats(res, expected_result): @pytest.mark.asyncio -async def test_batch_bulk_with_retry(): +async def test_batch_bulk_with_retry() -> None: config = { "username": "elastic", "password": "changeme", @@ -1187,7 +1198,7 @@ async def test_batch_bulk_with_retry(): @pytest.mark.asyncio -async def test_batch_bulk_with_error_monitor(): +async def test_batch_bulk_with_error_monitor() -> None: config = { "username": "elastic", "password": "changeme", @@ -1249,7 +1260,7 @@ def _create_response(num_successful, num_failed): @pytest.mark.asyncio -async def test_batch_bulk_with_errors(patch_logger): +async def test_batch_bulk_with_errors(patch_logger) -> None: config = { "username": "elastic", "password": "changeme", @@ -1301,7 +1312,7 @@ async def test_batch_bulk_with_errors(patch_logger): @pytest.mark.asyncio async def test_sync_orchestrator_done_and_cleanup( extractor_task, extractor_task_done, sink_task, sink_task_done, expected_result -): +) -> None: if extractor_task is not None: extractor_task.cancel = Mock() extractor_task.done.return_value = extractor_task_done @@ -1329,7 +1340,7 @@ async def test_sync_orchestrator_done_and_cleanup( @pytest.mark.asyncio -async def test_extractor_put_doc(): +async def test_extractor_put_doc() -> None: doc = {"id": 123} queue = Mock() queue.put = AsyncMock() @@ -1346,7 +1357,7 @@ async def test_extractor_put_doc(): @mock.patch("connectors.utils.ConcurrentTasks.cancel") async def test_extractor_get_docs_when_downloads_fail( yield_existing_documents_metadata, concurrent_tasks_cancel -): +) -> None: queue = await queue_mock() yield_existing_documents_metadata.return_value = AsyncIterator([]) @@ -1368,7 +1379,7 @@ async def test_extractor_get_docs_when_downloads_fail( @pytest.mark.asyncio -async def test_force_canceled_extractor_put_doc(): +async def test_force_canceled_extractor_put_doc() -> None: doc = {"id": 123} queue = Mock() queue.put = AsyncMock() @@ -1381,7 +1392,7 @@ async def test_force_canceled_extractor_put_doc(): @pytest.mark.asyncio -async def test_force_canceled_extractor_with_other_errors(patch_logger): +async def test_force_canceled_extractor_with_other_errors(patch_logger) -> None: queue = Mock() queue.put = AsyncMock() extractor = Extractor( @@ -1400,7 +1411,7 @@ async def test_force_canceled_extractor_with_other_errors(patch_logger): @pytest.mark.asyncio -async def test_sink_fetch_doc(): +async def test_sink_fetch_doc() -> None: expected_doc = {"id": 123} queue = Mock() queue.get = AsyncMock(return_value=expected_doc) @@ -1422,7 +1433,7 @@ async def test_sink_fetch_doc(): @pytest.mark.asyncio 
-async def test_force_canceled_sink_fetch_doc(): +async def test_force_canceled_sink_fetch_doc() -> None: expected_doc = {"id": 123} queue = Mock() queue.get = AsyncMock(return_value=expected_doc) @@ -1445,7 +1456,7 @@ async def test_force_canceled_sink_fetch_doc(): @pytest.mark.asyncio -async def test_force_canceled_sink_with_other_errors(patch_logger): +async def test_force_canceled_sink_with_other_errors(patch_logger) -> None: queue = Mock() queue.get = AsyncMock(side_effect=Exception("a non-ForceCanceledError")) sink = Sink( @@ -1494,7 +1505,7 @@ async def test_force_canceled_sink_with_other_errors(patch_logger): ], ) @pytest.mark.asyncio -async def test_cancel_sync(extractor_task_done, sink_task_done, force_cancel): +async def test_cancel_sync(extractor_task_done, sink_task_done, force_cancel) -> None: config = {"host": "http://nowhere.com:9200", "user": "tarek", "password": "blah"} es = SyncOrchestrator(config) es._extractor = Mock() @@ -1525,7 +1536,7 @@ async def test_cancel_sync(extractor_task_done, sink_task_done, force_cancel): @pytest.mark.asyncio -async def test_extractor_run_when_mem_full_is_raised(): +async def test_extractor_run_when_mem_full_is_raised() -> None: docs_from_source = [ {"_id": 1}, {"_id": 2}, @@ -1559,7 +1570,7 @@ def _put_side_effect(value): @pytest.mark.asyncio async def test_should_not_log_bulk_operations_if_doc_id_tracing_is_disabled( patch_logger, -): +) -> None: action = "create" result = "created" operations = [] @@ -1647,7 +1658,7 @@ async def test_should_not_log_bulk_operations_if_doc_id_tracing_is_disabled( @pytest.mark.asyncio async def test_should_log_bulk_operations_if_doc_id_tracing_is_enabled( patch_logger, operation_results, expected_logs -): +) -> None: operations = [] client = Mock() client.bulk_insert = AsyncMock(return_value={"items": operation_results}) @@ -1671,7 +1682,7 @@ async def test_should_log_bulk_operations_if_doc_id_tracing_is_enabled( @pytest.mark.asyncio -async def test_should_log_error_when_id_is_missing(patch_logger): +async def test_should_log_error_when_id_is_missing(patch_logger) -> None: operations = [] client = Mock() # item missing id @@ -1706,7 +1717,7 @@ async def test_should_log_error_when_id_is_missing(patch_logger): @pytest.mark.asyncio -async def test_should_log_error_when_unknown_action_item_returned(patch_logger): +async def test_should_log_error_when_unknown_action_item_returned(patch_logger) -> None: operations = [] client = Mock() diff --git a/tests/test_source.py b/tests/test_source.py index 6e5e48b48..72a375070 100644 --- a/tests/test_source.py +++ b/tests/test_source.py @@ -5,6 +5,7 @@ # from datetime import datetime from decimal import Decimal +from typing import Dict, List, Union from unittest import TestCase, mock import pytest @@ -28,7 +29,14 @@ get_source_klasses, ) -CONFIG = { +CONFIG: Dict[ + str, + Union[ + Dict[str, str], + Dict[str, Union[None, int, str]], + Dict[str, Union[List[str], str]], + ], +] = { "host": { "label": "MongoDB Host", "type": "str", @@ -67,14 +75,14 @@ DATE_STRING_ISO_FORMAT = "2023-01-01T13:37:42+02:00" -def test_field(): +def test_field() -> None: # stupid holder f = Field("name") assert f.label == "name" assert f.field_type == "str" -def test_field_convert(): +def test_field_convert() -> None: assert Field("name", field_type="str").value == "" assert Field("name", value="", field_type="str").value == "" assert Field("name", value="1", field_type="str").value == "1" @@ -124,7 +132,7 @@ def test_field_convert(): assert Field("name", value="not a dict", 
field_type="dict").value == "not a dict" -def test_data_source_configuration(): +def test_data_source_configuration() -> None: c = DataSourceConfiguration(CONFIG) assert c["database"] == "sample_airbnb" assert c.get_field("database").label == "MongoDB Database" @@ -133,7 +141,7 @@ def test_data_source_configuration(): assert c["new"] == "one" -def test_default(): +def test_default() -> None: c = DataSourceConfiguration(CONFIG) assert c.get("database") == "sample_airbnb" assert c.get("dd", 1) == 1 @@ -171,7 +179,7 @@ def test_default(): ) def test_value_returns_correct_value( field_type, required, default_value, value, expected_value -): +) -> None: assert ( Field( "name", @@ -188,15 +196,15 @@ class MyConnector: id = "1" # noqa A003 service_type = "yea" - def __init__(self, *args): + def __init__(self, *args) -> None: pass -def test_get_source_klass(): +def test_get_source_klass() -> None: assert get_source_klass("test_source:MyConnector") is MyConnector -def test_get_source_klasses(): +def test_get_source_klasses() -> None: settings = { "sources": {"yea": "test_source:MyConnector", "yea2": "test_source:MyConnector"} } @@ -432,7 +440,7 @@ def test_get_source_klasses(): ), ], ) -async def test_check_valid_when_validations_succeed_no_errors_raised(config): +async def test_check_valid_when_validations_succeed_no_errors_raised(config) -> None: c = DataSourceConfiguration(config) c.check_valid() @@ -644,14 +652,14 @@ async def test_check_valid_when_validations_succeed_no_errors_raised(config): ), ], ) -async def test_check_valid_when_validations_fail_raises_error(config): +async def test_check_valid_when_validations_fail_raises_error(config) -> None: c = DataSourceConfiguration(config) with pytest.raises(ConfigurableFieldValueError): c.check_valid() @pytest.mark.asyncio -async def test_check_valid_when_dependencies_are_invalid_raises_error(): +async def test_check_valid_when_dependencies_are_invalid_raises_error() -> None: config = { "port": { "type": "int", @@ -671,7 +679,12 @@ async def test_check_valid_when_dependencies_are_invalid_raises_error(): # ABCs class DataSource(BaseDataSource): @classmethod - def get_default_configuration(cls): + def get_default_configuration( + cls, + ) -> Dict[ + str, + Union[Dict[str, str], Dict[str, Union[bool, str]], Dict[str, Union[int, str]]], + ]: return { "host": { "value": "127.0.0.1", @@ -700,7 +713,7 @@ def get_default_configuration(cls): @pytest.mark.asyncio @mock.patch("connectors.filtering.validation.FilteringValidator.validate") -async def test_validate_filter(validator_mock): +async def test_validate_filter(validator_mock) -> None: validator_mock.return_value = "valid" assert ( @@ -712,7 +725,7 @@ async def test_validate_filter(validator_mock): @pytest.mark.asyncio -async def test_invalid_configuration_raises_error(): +async def test_invalid_configuration_raises_error() -> None: configuration = {} with pytest.raises(TypeError) as e: @@ -723,7 +736,7 @@ async def test_invalid_configuration_raises_error(): @pytest.mark.asyncio -async def test_base_class(): +async def test_base_class() -> None: configuration = DataSourceConfiguration({}) with pytest.raises(NotImplementedError): @@ -872,7 +885,7 @@ async def test_base_class(): ], ) @pytest.mark.asyncio -async def test_serialize(raw_doc, expected_doc): +async def test_serialize(raw_doc, expected_doc) -> None: with mock.patch.object( BaseDataSource, "get_default_configuration", return_value={} ): @@ -888,7 +901,7 @@ async def test_serialize(raw_doc, expected_doc): @pytest.mark.asyncio -async def 
test_validate_config_fields_when_valid_no_errors_raised(): +async def test_validate_config_fields_when_valid_no_errors_raised() -> None: configuration = { "host": { "value": "127.0.0.1", @@ -912,7 +925,7 @@ async def test_validate_config_fields_when_valid_no_errors_raised(): @pytest.mark.asyncio -async def test_validate_config_fields_when_invalid_raises_error(): +async def test_validate_config_fields_when_invalid_raises_error() -> None: configuration = { "host": { "value": "127.0.0.1", diff --git a/tests/test_sync_job_runner.py b/tests/test_sync_job_runner.py index 5537966cd..3d0c6695e 100644 --- a/tests/test_sync_job_runner.py +++ b/tests/test_sync_job_runner.py @@ -4,6 +4,7 @@ # you may not use this file except in compliance with the Elastic License 2.0. # import asyncio +from typing import Dict, Optional from unittest.mock import ANY, AsyncMock, Mock, patch import pytest @@ -29,7 +30,7 @@ SYNC_CURSOR = {"foo": "bar"} -def mock_connector(): +def mock_connector() -> Mock: connector = Mock() connector.id = "1" connector.index_name = SEARCH_INDEX_NAME @@ -48,7 +49,7 @@ def mock_connector(): return connector -def mock_sync_job(job_type, index_name): +def mock_sync_job(job_type, index_name) -> Mock: sync_job = Mock() sync_job.id = "1" sync_job.configuration = {} @@ -74,16 +75,16 @@ def mock_sync_job(job_type, index_name): def create_runner( - source_changed=True, - source_available=True, + source_changed: bool = True, + source_available: bool = True, validate_config_exception=None, - job_type=JobType.FULL, - index_name=SEARCH_INDEX_NAME, - sync_cursor=SYNC_CURSOR, - connector=None, + job_type: JobType = JobType.FULL, + index_name: str = SEARCH_INDEX_NAME, + sync_cursor: Dict[str, str] = SYNC_CURSOR, + connector: Optional[Mock] = None, service_config=None, es_config=None, -): +) -> SyncJobRunner: source_klass = Mock() data_provider = Mock() data_provider.tweak_bulk_options = Mock() @@ -140,7 +141,7 @@ def sync_orchestrator_mock(): yield sync_orchestrator_mock -def create_runner_yielding_docs(docs=None): +def create_runner_yielding_docs(docs=None) -> SyncJobRunner: if docs is None: docs = [] @@ -155,7 +156,7 @@ def create_runner_yielding_docs(docs=None): @pytest.mark.asyncio -async def test_connector_content_sync_starts_fail(): +async def test_connector_content_sync_starts_fail() -> None: sync_job_runner = create_runner() # Do nothing in the first call, and the last_sync_status is set to `in_progress` by another instance in the subsequent calls @@ -184,7 +185,7 @@ def _reset_last_sync_status(): @pytest.mark.asyncio -async def test_connector_access_control_sync_starts_fail(): +async def test_connector_access_control_sync_starts_fail() -> None: sync_job_runner = create_runner(job_type=JobType.ACCESS_CONTROL) # Do nothing in the first call, and the last_access_control_sync_status is set to `in_progress` by another instance in the subsequent calls @@ -215,7 +216,7 @@ def _reset_last_sync_status(): @pytest.mark.asyncio -async def test_connector_incremental_sync_job_starts_fail(): +async def test_connector_incremental_sync_job_starts_fail() -> None: connector = mock_connector() # disable incremental sync connector.features.incremental_sync_enabled.return_value = False @@ -233,7 +234,7 @@ async def test_connector_incremental_sync_job_starts_fail(): @pytest.mark.asyncio -async def test_connector_content_claim_fails(): +async def test_connector_content_claim_fails() -> None: sync_job_runner = create_runner() # Do nothing in the first call, and the last_sync_status is set to `in_progress` by another 
instance in the subsequent calls @@ -266,8 +267,8 @@ def _reset_last_sync_status(): ) @pytest.mark.asyncio async def test_source_not_changed( - job_type, sync_cursor_to_claim, sync_cursor_to_update -): + job_type: JobType, sync_cursor_to_claim, sync_cursor_to_update +) -> None: sync_job_runner = create_runner(source_changed=False, job_type=job_type) await sync_job_runner.execute() @@ -299,7 +300,9 @@ async def test_source_not_changed( ], ) @pytest.mark.asyncio -async def test_source_invalid_config(job_type, sync_cursor): +async def test_source_invalid_config( + job_type: JobType, sync_cursor: Dict[str, str] +) -> None: sync_job_runner = create_runner( job_type=job_type, validate_config_exception=Exception(), @@ -337,7 +340,9 @@ async def test_source_invalid_config(job_type, sync_cursor): ], ) @pytest.mark.asyncio -async def test_source_not_available(job_type, sync_cursor): +async def test_source_not_available( + job_type: JobType, sync_cursor: Dict[str, str] +) -> None: sync_job_runner = create_runner( job_type=job_type, source_available=False, sync_cursor=sync_cursor ) @@ -366,7 +371,7 @@ async def test_source_not_available(job_type, sync_cursor): @pytest.mark.parametrize("job_type", [JobType.FULL, JobType.INCREMENTAL]) @pytest.mark.asyncio -async def test_invalid_filtering(job_type, sync_orchestrator_mock): +async def test_invalid_filtering(job_type: JobType, sync_orchestrator_mock) -> None: ingestion_stats = { "indexed_document_count": 0, "indexed_document_volume": 0, @@ -396,7 +401,7 @@ async def test_invalid_filtering(job_type, sync_orchestrator_mock): @pytest.mark.asyncio async def test_invalid_filtering_access_control_sync_still_executed( sync_orchestrator_mock, -): +) -> None: ingestion_stats = { "indexed_document_count": 0, "indexed_document_volume": 0, @@ -431,7 +436,9 @@ async def test_invalid_filtering_access_control_sync_still_executed( ], ) @pytest.mark.asyncio -async def test_async_bulk_error(job_type, sync_cursor, sync_orchestrator_mock): +async def test_async_bulk_error( + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: error = "something wrong" ingestion_stats = { "indexed_document_count": 0, @@ -463,7 +470,7 @@ async def test_async_bulk_error(job_type, sync_cursor, sync_orchestrator_mock): @pytest.mark.asyncio async def test_access_control_sync_fails_with_insufficient_license( sync_orchestrator_mock, -): +) -> None: ingestion_stats = { "indexed_document_count": 0, "indexed_document_volume": 0, @@ -504,7 +511,9 @@ async def test_access_control_sync_fails_with_insufficient_license( ], ) @pytest.mark.asyncio -async def test_sync_job_runner(job_type, sync_cursor, sync_orchestrator_mock): +async def test_sync_job_runner( + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 25, "indexed_document_volume": 30, @@ -538,7 +547,9 @@ async def test_sync_job_runner(job_type, sync_cursor, sync_orchestrator_mock): ], ) @pytest.mark.asyncio -async def test_sync_job_runner_suspend(job_type, sync_cursor, sync_orchestrator_mock): +async def test_sync_job_runner_suspend( + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 25, "indexed_document_volume": 30, @@ -570,7 +581,9 @@ async def test_sync_job_runner_suspend(job_type, sync_cursor, sync_orchestrator_ @patch("connectors.sync_job_runner.ES_ID_SIZE_LIMIT", 1) @pytest.mark.asyncio -async def 
test_prepare_docs_when_original_id_and_hashed_id_too_long_then_skip_doc(): +async def test_prepare_docs_when_original_id_and_hashed_id_too_long_then_skip_doc() -> ( + None +): _id_too_long = "ab" sync_job_runner = create_runner_yielding_docs(docs=[({"_id": _id_too_long}, None)]) @@ -588,7 +601,7 @@ async def test_prepare_docs_when_original_id_and_hashed_id_too_long_then_skip_do @pytest.mark.asyncio async def test_prepare_docs_when_original_id_below_limit_then_yield_doc_with_original_id( _id, -): +) -> None: sync_job_runner = create_runner_yielding_docs(docs=[({"_id": _id}, None)]) docs = [] @@ -601,7 +614,9 @@ async def test_prepare_docs_when_original_id_below_limit_then_yield_doc_with_ori @patch("connectors.sync_job_runner.ES_ID_SIZE_LIMIT", 3) @pytest.mark.asyncio -async def test_prepare_docs_when_original_id_above_limit_and_hashed_id_below_limit_then_yield_doc_with_hashed_id(): +async def test_prepare_docs_when_original_id_above_limit_and_hashed_id_below_limit_then_yield_doc_with_hashed_id() -> ( + None +): _id_too_long = "abcd" hashed_id = "a" @@ -628,8 +643,8 @@ async def test_prepare_docs_when_original_id_above_limit_and_hashed_id_below_lim @patch("connectors.sync_job_runner.JOB_REPORTING_INTERVAL", 0) @patch("connectors.sync_job_runner.JOB_CHECK_INTERVAL", 0) async def test_sync_job_runner_reporting_metadata( - job_type, sync_cursor, sync_orchestrator_mock -): + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 15, "indexed_document_volume": 230, @@ -665,7 +680,9 @@ async def test_sync_job_runner_reporting_metadata( @pytest.mark.asyncio @patch("connectors.sync_job_runner.JOB_REPORTING_INTERVAL", 0) @patch("connectors.sync_job_runner.JOB_CHECK_INTERVAL", 0) -async def test_sync_job_runner_connector_not_found(job_type, sync_orchestrator_mock): +async def test_sync_job_runner_connector_not_found( + job_type: JobType, sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 15, "indexed_document_volume": 230, @@ -707,8 +724,8 @@ def _raise_document_not_found_error(): @patch("connectors.sync_job_runner.JOB_REPORTING_INTERVAL", 0) @patch("connectors.sync_job_runner.JOB_CHECK_INTERVAL", 0) async def test_sync_job_runner_sync_job_not_found( - job_type, sync_cursor, sync_orchestrator_mock -): + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 15, "indexed_document_volume": 230, @@ -741,7 +758,9 @@ async def test_sync_job_runner_sync_job_not_found( @pytest.mark.asyncio @patch("connectors.sync_job_runner.JOB_REPORTING_INTERVAL", 0) @patch("connectors.sync_job_runner.JOB_CHECK_INTERVAL", 0) -async def test_sync_job_runner_canceled(job_type, sync_cursor, sync_orchestrator_mock): +async def test_sync_job_runner_canceled( + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 15, "indexed_document_volume": 230, @@ -783,8 +802,8 @@ def _update_job_status(): @patch("connectors.sync_job_runner.JOB_REPORTING_INTERVAL", 0) @patch("connectors.sync_job_runner.JOB_CHECK_INTERVAL", 0) async def test_sync_job_runner_not_running( - job_type, sync_cursor, sync_orchestrator_mock -): + job_type: JobType, sync_cursor: Dict[str, str], sync_orchestrator_mock +) -> None: ingestion_stats = { "indexed_document_count": 15, "indexed_document_volume": 230, @@ -817,7 +836,7 @@ def _update_job_status(): @pytest.mark.asyncio -async def 
test_sync_job_runner_sets_features_for_data_provider(): +async def test_sync_job_runner_sets_features_for_data_provider() -> None: sync_job_runner = create_runner() await sync_job_runner.execute() @@ -825,7 +844,7 @@ async def test_sync_job_runner_sets_features_for_data_provider(): assert sync_job_runner.data_provider.set_features.called -def test_skip_unchanged_documents_enabled(): +def test_skip_unchanged_documents_enabled() -> None: sync_job_runner = create_runner() class MockDataSource(BaseDataSource): @@ -841,7 +860,7 @@ def __init__(self, config): ) -def test_skip_unchanged_documents_enabled_disabled(): +def test_skip_unchanged_documents_enabled_disabled() -> None: sync_job_runner = create_runner() class MockDataSource(BaseDataSource): @@ -860,7 +879,7 @@ def get_docs_incrementally(self, sync_cursor, filtering=None): ) -def test_skip_unchanged_documents_enabled_disabled_by_full_sync(): +def test_skip_unchanged_documents_enabled_disabled_by_full_sync() -> None: sync_job_runner = create_runner() class MockDataSource(BaseDataSource): @@ -881,7 +900,7 @@ def __init__(self, config): Mock(return_value=True), ) @pytest.mark.asyncio -async def test_incremental_sync_with_skip_unchanged_documents_generator(): +async def test_incremental_sync_with_skip_unchanged_documents_generator() -> None: sync_job_runner = create_runner(job_type=JobType.INCREMENTAL) data_provider = Mock() @@ -902,7 +921,7 @@ async def test_incremental_sync_with_skip_unchanged_documents_generator(): Mock(return_value=False), ) @pytest.mark.asyncio -async def test_incremental_sync_without_skip_unchanged_documents_generator(): +async def test_incremental_sync_without_skip_unchanged_documents_generator() -> None: connector = mock_connector() connector.sync_cursor = {} @@ -922,7 +941,7 @@ async def test_incremental_sync_without_skip_unchanged_documents_generator(): ) -async def test_unsupported_job_type(): +async def test_unsupported_job_type() -> None: connector = mock_connector() connector.sync_cursor = {} @@ -969,7 +988,7 @@ async def test_unsupported_job_type(): ) def test_content_extraction_enabled( sync_job_config, pipeline_config, expected_enabled, expected_log, patch_logger -): +) -> None: sync_job_runner = create_runner() class MockDataSource(BaseDataSource): diff --git a/tests/test_utils.py b/tests/test_utils.py index d5131488d..8f0a1699f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -163,14 +163,14 @@ ), ], ) -def test_next_run(cron_statement, now, expected_next_run): +def test_next_run(cron_statement, now, expected_next_run) -> None: # can run within two minutes assert next_run(cron_statement, now).isoformat( " ", "seconds" ) == expected_next_run.isoformat(" ", "seconds") -def test_with_utc_tz_naive_timestamp(): +def test_with_utc_tz_naive_timestamp() -> None: ts_naive = datetime(2024, 5, 27, 12, 0, 0) ts_utc = with_utc_tz(ts_naive) assert ts_utc.tzinfo == timezone.utc @@ -182,7 +182,7 @@ def test_with_utc_tz_naive_timestamp(): assert ts_utc.second == 0 -def test_with_utc_tz_aware_timestamp(): +def test_with_utc_tz_aware_timestamp() -> None: ts_aware = datetime( 2024, 5, 27, 12, 0, 0, tzinfo=timezone(timedelta(hours=5)) ) # Timezone aware timestamp with +5 offset @@ -196,7 +196,7 @@ def test_with_utc_tz_aware_timestamp(): assert ts_utc.second == 0 -def test_with_utc_tz_timestamp_in_utc(): +def test_with_utc_tz_timestamp_in_utc() -> None: ts_aware = datetime(2024, 5, 27, 12, 0, 0, tzinfo=timezone.utc) ts_utc = with_utc_tz(ts_aware) assert ts_utc.tzinfo == timezone.utc @@ -208,7 +208,7 @@ def 
test_with_utc_tz_timestamp_in_utc(): assert ts_utc.second == 0 -def test_invalid_names(): +def test_invalid_names() -> None: for name in ( "index?name", "index#name", @@ -223,7 +223,7 @@ def test_invalid_names(): validate_index_name(name) -def test_mem_queue_speed(): +def test_mem_queue_speed() -> None: def mem_queue(): import asyncio @@ -260,7 +260,7 @@ async def run(): @pytest.mark.asyncio -async def test_mem_queue_race(): +async def test_mem_queue_race() -> None: item = "small stuff" queue = MemQueue( maxmemsize=get_size(item) * 2 + 1, refresh_interval=0.01, refresh_timeout=1 @@ -291,7 +291,7 @@ async def remove_one(): @pytest.mark.asyncio -async def test_mem_queue(): +async def test_mem_queue() -> None: # Initial timeout is really small so that the test is fast. # The part of the test before timeout increase will take at least refresh_timeout # seconds to execute, so if timeout is 60 seconds, then it'll take 60+ seconds. @@ -331,7 +331,7 @@ async def remove_data(): @pytest.mark.asyncio -async def test_mem_queue_too_large_item(): +async def test_mem_queue_too_large_item() -> None: """ When an item is added to the queue that is larger than the queue capacity then the item is discarded @@ -353,7 +353,7 @@ async def test_mem_queue_too_large_item(): @pytest.mark.asyncio -async def test_mem_queue_put_nowait(): +async def test_mem_queue_put_nowait() -> None: queue = MemQueue( maxsize=5, maxmemsize=1000, refresh_interval=0.1, refresh_timeout=0.5 ) @@ -362,24 +362,24 @@ async def test_mem_queue_put_nowait(): queue.put_nowait(i) with pytest.raises(asyncio.QueueFull) as e: - await queue.put_nowait("x") + queue.put_nowait("x") assert e is not None -def test_get_base64_value(): +def test_get_base64_value() -> None: """This test verify get_base64_value method and convert encoded data into base64""" expected_result = get_base64_value("dummy".encode("utf-8")) assert expected_result == "ZHVtbXk=" -def test_decode_base64_value(): +def test_decode_base64_value() -> None: """This test verify decode_base64_value and decodes base64 encoded data""" expected_result = decode_base64_value("ZHVtbXk=".encode("utf-8")) assert expected_result == b"dummy" -def test_try_acquire(): +def test_try_acquire() -> None: bound_value = 5 sem = NonBlockingBoundedSemaphore(bound_value) @@ -393,7 +393,7 @@ def test_try_acquire(): @pytest.mark.asyncio -async def test_concurrent_runner(): +async def test_concurrent_runner() -> None: results = [] def _results_callback(task): @@ -413,7 +413,7 @@ async def coroutine(i): @pytest.mark.asyncio -async def test_concurrent_runner_canceled(): +async def test_concurrent_runner_canceled() -> None: results = [] tasks = [] @@ -439,7 +439,7 @@ async def coroutine(i): @pytest.mark.asyncio -async def test_concurrent_runner_canceled_with_waiting_task(): +async def test_concurrent_runner_canceled_with_waiting_task() -> None: results = [] def _results_callback(task): @@ -467,7 +467,7 @@ async def coroutine(i, sleep_time): @pytest.mark.asyncio -async def test_concurrent_runner_fails(): +async def test_concurrent_runner_fails() -> None: results = [] def _results_callback(task): @@ -490,7 +490,7 @@ async def coroutine(i): @pytest.mark.asyncio -async def test_concurrent_runner_high_concurrency(): +async def test_concurrent_runner_high_concurrency() -> None: results = [] def _results_callback(task): @@ -528,7 +528,7 @@ def _second_callback(task): ], ) @pytest.mark.asyncio -async def test_concurrent_runner_try_put(initial_capacity, expected_result): +async def 
test_concurrent_runner_try_put(initial_capacity, expected_result) -> None: results = [] def _results_callback(task): @@ -560,7 +560,7 @@ async def coroutine(i): @pytest.mark.asyncio -async def test_concurrent_runner_join(): +async def test_concurrent_runner_join() -> None: results = [] def _results_callback(task): @@ -597,7 +597,7 @@ async def delayed_coroutine(): @pytest.mark.asyncio -async def test_concurrent_tasks_raise_any_exception(): +async def test_concurrent_tasks_raise_any_exception() -> None: async def return_1(): return 1 @@ -623,7 +623,7 @@ async def long_sleep_then_return_2(): @pytest.mark.asyncio -async def test_concurrent_tasks_join_raise_on_error(): +async def test_concurrent_tasks_join_raise_on_error() -> None: results = [] def _results_callback(task): @@ -677,7 +677,7 @@ def temp_file(converter): @pytest.mark.parametrize("converter", ["system", "py"]) -def test_convert_to_b64_inplace(converter): +def test_convert_to_b64_inplace(converter) -> None: with temp_file(converter) as (source, content): # convert in-place result = convert_to_b64(source) @@ -688,7 +688,7 @@ def test_convert_to_b64_inplace(converter): @pytest.mark.parametrize("converter", ["system", "py"]) -def test_convert_to_b64_target(converter): +def test_convert_to_b64_target(converter) -> None: with temp_file(converter) as (source, content): # convert to a specific file try: @@ -702,7 +702,7 @@ def test_convert_to_b64_target(converter): @pytest.mark.parametrize("converter", ["system", "py"]) -def test_convert_to_b64_no_overwrite(converter): +def test_convert_to_b64_no_overwrite(converter) -> None: with temp_file(converter) as (source, content): # check overwrite try: @@ -739,7 +739,9 @@ def patch_file_ops(): ("Ubuntu", None, "/usr/bin/base64 -w 0 {source} > {target}"), ], ) -def test_convert_to_b64_newer_macos(system, mac_ver, cmd_template, patch_file_ops): +def test_convert_to_b64_newer_macos( + system, mac_ver, cmd_template, patch_file_ops +) -> None: with ( patch("platform.system", return_value=system), patch("platform.mac_ver", return_value=[mac_ver]), @@ -838,7 +840,7 @@ def does_not_raise_function(): @pytest.mark.fail_slow(1) @pytest.mark.asyncio -async def test_exponential_backoff_retry_async_function(): +async def test_exponential_backoff_retry_async_function() -> None: mock_func = Mock() num_retries = 10 @@ -896,7 +898,7 @@ async def raises_async_generator(): "skipped_exceptions", [CustomException, [CustomException, RuntimeError]] ) @pytest.mark.asyncio -async def test_skipped_exceptions_retry_async_function(skipped_exceptions): +async def test_skipped_exceptions_retry_async_function(skipped_exceptions) -> None: mock_func = Mock() num_retries = 10 @@ -918,7 +920,7 @@ async def raises(): "skipped_exceptions", [CustomException, [CustomException, RuntimeError]] ) @pytest.mark.asyncio -async def test_skipped_exceptions_retry_sync_function(skipped_exceptions): +async def test_skipped_exceptions_retry_sync_function(skipped_exceptions) -> None: mock_func = Mock() num_retries = 10 @@ -936,7 +938,7 @@ def raises(): assert mock_func.call_count == 1 -def test_retryable_not_implemented_error(): +def test_retryable_not_implemented_error() -> None: with pytest.raises(NotImplementedError): @retryable() @@ -947,12 +949,12 @@ class NotSupported: class MockSSL: """This class contains methods which returns dummy ssl context""" - def load_verify_locations(self, cadata): + def load_verify_locations(self, cadata) -> None: """This method verify locations""" pass -def test_ssl_context(): +def test_ssl_context() -> None: 
"""This function test ssl_context with dummy certificate""" # Setup certificate = "-----BEGIN CERTIFICATE----- Certificate -----END CERTIFICATE-----" @@ -962,7 +964,7 @@ def test_ssl_context(): ssl_context(certificate=certificate) -def test_url_encode(): +def test_url_encode() -> None: """Test the url_encode method by passing a string""" # Execute encode_response = url_encode("http://ascii.cl?parameter='Click on URL Decode!'") @@ -973,7 +975,7 @@ def test_url_encode(): ) -def test_is_expired(): +def test_is_expired() -> None: """This method checks whether token expires or not""" # Execute expires_at = datetime.fromisoformat("2023-02-10T09:02:23.629821") @@ -983,7 +985,7 @@ def test_is_expired(): @freeze_time("2023-02-18 14:25:26.158843", tz_offset=-4) -def test_evaluate_timedelta(): +def test_evaluate_timedelta() -> None: """This method tests adding seconds to the current utc time""" # Execute expected_response = evaluate_timedelta(seconds=86399, time_skew=20) @@ -992,7 +994,7 @@ def test_evaluate_timedelta(): assert expected_response == "2023-02-19T14:25:05.158843" -def test_get_pem_format_with_postfix(): +def test_get_pem_format_with_postfix() -> None: expected_formatted_pem_key = """-----BEGIN PRIVATE KEY----- PrivateKey -----END PRIVATE KEY-----""" @@ -1004,7 +1006,7 @@ def test_get_pem_format_with_postfix(): assert formatted_private_key == expected_formatted_pem_key -def test_get_pem_format_multiline(): +def test_get_pem_format_multiline() -> None: expected_formatted_certificate = """-----BEGIN CERTIFICATE----- Certificate1 Certificate2 @@ -1015,7 +1017,7 @@ def test_get_pem_format_multiline(): assert formatted_certificate == expected_formatted_certificate -def test_get_pem_format_multiple_certificates(): +def test_get_pem_format_multiple_certificates() -> None: expected_formatted_multiple_certificates = """-----BEGIN CERTIFICATE----- Certificate1 -----END CERTIFICATE----- @@ -1029,7 +1031,7 @@ def test_get_pem_format_multiple_certificates(): assert formatted_multi_certificate == expected_formatted_multiple_certificates -def test_hash_id(): +def test_hash_id() -> None: limit = 512 random_id_too_long = "".join( random.choices(string.ascii_letters + string.digits, k=1000) @@ -1038,7 +1040,7 @@ def test_hash_id(): assert len(hash_id(random_id_too_long).encode("UTF-8")) < limit -def test_truncate_id(): +def test_truncate_id() -> None: long_id = "something-12341361361-21905128510263" truncated_id = truncate_id(long_id) @@ -1049,7 +1051,7 @@ def test_truncate_id(): "_list, should_have_duplicate", [([], False), (["abc"], False), (["abc", "def"], False), (["abc", "abc"], True)], ) -def test_has_duplicates(_list, should_have_duplicate): +def test_has_duplicates(_list, should_have_duplicate) -> None: assert has_duplicates(_list) == should_have_duplicate @@ -1074,7 +1076,7 @@ def test_has_duplicates(_list, should_have_duplicate): ), ], ) -def test_filter_nested_dict_by_keys(key_list, source_dict, expected_dict): +def test_filter_nested_dict_by_keys(key_list, source_dict, expected_dict) -> None: assert filter_nested_dict_by_keys(key_list, source_dict) == expected_dict @@ -1107,23 +1109,23 @@ def test_filter_nested_dict_by_keys(key_list, source_dict, expected_dict): ), ], ) -def test_deep_merge_dicts(base_dict, new_dict, expected_dict): +def test_deep_merge_dicts(base_dict, new_dict, expected_dict) -> None: assert deep_merge_dicts(base_dict, new_dict) == expected_dict -def test_html_to_text_with_html_with_unclosed_tag(): +def test_html_to_text_with_html_with_unclosed_tag() -> None: invalid_html = "
<div>Hello, world!<div>Next Line" assert html_to_text(invalid_html) == "Hello, world!\nNext Line" -def test_html_to_text_without_html(): +def test_html_to_text_without_html() -> None: invalid_html = "just text" assert html_to_text(invalid_html) == "just text" -def test_html_to_text_with_weird_html(): +def test_html_to_text_with_weird_html() -> None: invalid_html = "<br>just<br>
text" text = html_to_text(invalid_html) @@ -1132,11 +1134,11 @@ def test_html_to_text_with_weird_html(): assert "text" in text -def test_html_to_text_with_none(): +def test_html_to_text_with_none() -> None: assert html_to_text(None) is None -def test_html_to_text_with_lxml_exception(): +def test_html_to_text_with_lxml_exception() -> None: # Here we're just mocking it in such a way, that # using BeautifulSoup(html, "lxml") raises an error to emulate the fact # that lxml is not available. @@ -1187,7 +1189,7 @@ def batch_size(value): ([[]], batch_size(20), [[[]]]), ], ) -def test_iterable_batches_generator(iterable, batch_size_, expected_batches): +def test_iterable_batches_generator(iterable, batch_size_, expected_batches) -> None: actual_batches = [] for batch in iterable_batches_generator(iterable, batch_size_): @@ -1200,7 +1202,7 @@ def test_iterable_batches_generator(iterable, batch_size_, expected_batches): "base64url_encoded_value, base64_expected_value", [("YQ-_", "YQ+/"), ("", ""), (None, None)], ) -def test_base64url_to_base64(base64url_encoded_value, base64_expected_value): +def test_base64url_to_base64(base64url_encoded_value, base64_expected_value) -> None: assert base64url_to_base64(base64url_encoded_value) == base64_expected_value @@ -1215,7 +1217,7 @@ def test_base64url_to_base64(base64url_encoded_value, base64_expected_value): ("subject@email_address", False), ], ) -def test_validate_email_address(email_address, is_valid): +def test_validate_email_address(email_address, is_valid) -> None: assert validate_email_address(email_address) == is_valid @@ -1239,7 +1241,7 @@ def test_validate_email_address(email_address, is_valid): ("abcdefgh", 1000, "a...h"), ], ) -def test_shorten_str(original, shorten_by, shortened): +def test_shorten_str(original, shorten_by: int, shortened) -> None: assert shorten_str(original, shorten_by) == shortened @@ -1255,11 +1257,13 @@ def test_shorten_str(original, shorten_by, shortened): (RetryStrategy.EXPONENTIAL_BACKOFF, 10, 2, 100), # 10 ^ 2 = 100 ], ) -async def test_time_to_sleep_between_retries(strategy, interval, retry, expected_sleep): +async def test_time_to_sleep_between_retries( + strategy, interval, retry, expected_sleep +) -> None: assert time_to_sleep_between_retries(strategy, interval, retry) == expected_sleep -async def test_time_to_sleep_between_retries_invalid_strategy(): +async def test_time_to_sleep_between_retries_invalid_strategy() -> None: with pytest.raises(UnknownRetryStrategyError) as e: time_to_sleep_between_retries("lalala", 1, 1) @@ -1277,7 +1281,7 @@ async def test_time_to_sleep_between_retries_invalid_strategy(): ({"foo": {"bar": {"baz": "result"}}}, None, "result"), ], ) -def test_nested_get_from_dict(dictionary, default, expected): +def test_nested_get_from_dict(dictionary, default, expected) -> None: keys = ["foo", "bar", "baz"] assert nested_get_from_dict(dictionary, keys, default=default) == expected @@ -1296,11 +1300,11 @@ def test_nested_get_from_dict(dictionary, default, expected): ), ], ) -def test_parse_datetime_string_compatibility(string, parsed_datetime): +def test_parse_datetime_string_compatibility(string, parsed_datetime) -> None: assert parse_datetime_string(string) == parsed_datetime -def test_error_monitor_raises_after_too_many_errors_in_window(): +def test_error_monitor_raises_after_too_many_errors_in_window() -> None: error_monitor = ErrorMonitor(max_error_rate=0.15, error_window_size=100) for _ in range(10): @@ -1316,7 +1320,7 @@ def test_error_monitor_raises_after_too_many_errors_in_window(): 
error_monitor.track_error(InvalidIndexNameError("Can't use this name")) -def test_error_monitor_raises_when_errors_were_reported_before(): +def test_error_monitor_raises_when_errors_were_reported_before() -> None: # Regression test. # Problem fixed was that monitor incorrectly calculates max_error_ratio - it never # actually considered either ratio or window size - it was always raising an error if @@ -1369,7 +1373,7 @@ def test_error_monitor_raises_when_errors_were_reported_before(): error_monitor.track_error(InvalidIndexNameError("Can't use this name")) -def test_error_monitor_when_reports_too_many_consecutive_errors(): +def test_error_monitor_when_reports_too_many_consecutive_errors() -> None: error_monitor = ErrorMonitor(max_consecutive_errors=3) error_monitor.track_error(Exception("first")) @@ -1380,7 +1384,7 @@ def test_error_monitor_when_reports_too_many_consecutive_errors(): error_monitor.track_error(Exception("fourth")) -def test_error_monitor_when_reports_too_many_total_errors(): +def test_error_monitor_when_reports_too_many_total_errors() -> None: error_monitor = ErrorMonitor( max_total_errors=100, max_consecutive_errors=999, max_error_rate=1 ) @@ -1398,7 +1402,7 @@ def test_error_monitor_when_reports_too_many_total_errors(): error_monitor.track_error(Exception("third")) -def test_error_monitor_when_reports_too_many_errors_in_window(): +def test_error_monitor_when_reports_too_many_errors_in_window() -> None: error_monitor = ErrorMonitor(error_window_size=100, max_error_rate=0.05) # rate is 0.04 @@ -1420,7 +1424,7 @@ def test_error_monitor_when_reports_too_many_errors_in_window(): error_monitor.track_error(Exception("last")) -def test_error_monitor_when_errors_are_tracked_last_x_errors_are_stored(): +def test_error_monitor_when_errors_are_tracked_last_x_errors_are_stored() -> None: error_monitor = ErrorMonitor(error_queue_size=5) for _ in range(5): @@ -1441,7 +1445,7 @@ def test_error_monitor_when_errors_are_tracked_last_x_errors_are_stored(): assert str(errors[4]) == "second_part" -def test_error_monitor_when_disabled(): +def test_error_monitor_when_disabled() -> None: error_monitor = ErrorMonitor( enabled=False, max_total_errors=1, max_consecutive_errors=1, max_error_rate=0.01 )