diff --git a/opensafely/__init__.py b/opensafely/__init__.py index b2a5bc03..037b6324 100644 --- a/opensafely/__init__.py +++ b/opensafely/__init__.py @@ -7,7 +7,6 @@ from pathlib import Path from opensafely import ( # noqa: E402 - check, clean, codelists, execute, @@ -129,7 +128,6 @@ def add_subcommand(cmd, module): add_subcommand("codelists", codelists) add_subcommand("pull", pull) add_subcommand("upgrade", upgrade) - add_subcommand("check", check) add_subcommand("unzip", unzip) add_subcommand("extract-stats", extract_stats) add_subcommand("info", info) @@ -160,10 +158,6 @@ def add_subcommand(cmd, module): print(str(exc), file=sys.stderr) success = False - # if `run`ning locally, run `check` in warn mode - if function == local_run.main and "format-output-for-github" not in kwargs: - check.main(continue_on_error=True) - # allow functions to return True/False, or an explicit exit code if success is False: exit_code = 1 diff --git a/opensafely/check.py b/opensafely/check.py deleted file mode 100644 index 75f5bb57..00000000 --- a/opensafely/check.py +++ /dev/null @@ -1,253 +0,0 @@ -import configparser -import glob -import os -import re -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import List - -from opensafely._vendor import requests -from opensafely._vendor.ruyaml import YAML - - -DESCRIPTION = "Check the opensafely project for correctness" - - -@dataclass(frozen=True) -class RestrictedDataset: - name: str - cohort_extractor_function_names: List[str] - ehrql_names: List[str] - - -RESTRICTED_DATASETS = [ - RestrictedDataset( - name="icnarc", - cohort_extractor_function_names=[ - "admitted_to_icu", - ], - ehrql_names=[], - ), - RestrictedDataset( - name="isaric", - cohort_extractor_function_names=[ - "with_an_isaric_record", - ], - ehrql_names=["isaric_new"], - ), - RestrictedDataset( - name="ukrr", - cohort_extractor_function_names=[ - "with_record_in_ukrr", - ], - ehrql_names=["ukrr"], - ), - RestrictedDataset( - name="icnarc", - cohort_extractor_function_names=[ - "admitted_to_icu", - ], - ehrql_names=[], - ), - RestrictedDataset( - name="open_prompt", - cohort_extractor_function_names=[], - ehrql_names=["open_prompt"], - ), - RestrictedDataset( - name="wl_clockstops", - cohort_extractor_function_names=[], - ehrql_names=["wl_clockstops", "wl_clockstops_raw"], - ), - RestrictedDataset( - name="wl_openpathways", - cohort_extractor_function_names=[], - ehrql_names=["wl_openpathways", "wl_openpathways_raw"], - ), - RestrictedDataset( - name="appointments", - cohort_extractor_function_names=["with_gp_consultations"], - ehrql_names=["appointments"], - ), -] - -PERMISSIONS_URL = "https://raw.githubusercontent.com/opensafely-core/opensafely-cli/main/repository_permissions.yaml" - - -def add_arguments(parser): - pass - - -def main(continue_on_error=False): - permissions_url = os.environ.get("OPENSAFELY_PERMISSIONS_URL") or PERMISSIONS_URL - repo_name = get_repository_name(continue_on_error) - if not repo_name and not continue_on_error: - sys.exit("Unable to find repository name") - permissions = get_datasource_permissions(permissions_url) - allowed_datasets = get_allowed_datasets(repo_name, permissions) - - files_to_check = glob.glob("**/*.py", recursive=True) - - datasets_to_check = [ - dataset - for dataset in RESTRICTED_DATASETS - if dataset.name not in allowed_datasets - ] - - found_cohort_datasets = check_cohort_datasets(files_to_check, datasets_to_check) - found_ehrql_datasets = check_ehrql_datasets(files_to_check, datasets_to_check) - - violations = [] - - if found_ehrql_datasets: - violations.extend(list(format_ehrql_violations(found_ehrql_datasets))) - - if found_cohort_datasets: - violations.extend(list(format_cohort_violations(found_cohort_datasets))) - - if violations: - violations_text = "\n".join(violations) - if not continue_on_error: - sys.exit(violations_text) - print("*** WARNING ***\n") - print(violations_text) - else: - if not continue_on_error: - print("Success") - - -def check_cohort_datasets(files_to_check, datasets_to_check): - return { - dataset.name: dataset_check - for dataset in datasets_to_check - if ( - dataset_check := check_restricted_names( - restricted_names=dataset.cohort_extractor_function_names, - # Check for the use of `.function_name`. - regex_template=r"\.{name}\(", - files_to_check=files_to_check, - ) - ) - } - - -def check_ehrql_datasets(files_to_check, datasets_to_check): - return { - dataset.name: dataset_check - for dataset in datasets_to_check - if ( - dataset_check := check_restricted_names( - restricted_names=dataset.ehrql_names, - # Check for the use of `name.` or `name(` - regex_template=r"\b{name}(\.|\()", - files_to_check=files_to_check, - ) - ) - } - - -def format_cohort_violations(found_datasets): - yield "Usage of restricted datasets found:\n" - for d, functions in found_datasets.items(): - yield f"{d}: https://docs.opensafely.org/study-def-variables/#{d}" - for fn, files in functions.items(): - yield f"- {fn}" - for f, lines in files.items(): - yield f" - {f}:" - for ln, line in lines.items(): - yield f" line {ln}: {line}" - - -def format_ehrql_violations(found_datasets): - # Unlike for cohort-extractor, - # there is no specific reference we can currently link to for restricted tables. - # We may be able to add such a link in future, - # which might make it more reasonable to unify this function - # with the analogous function for cohort-extractor. - yield "Usage of restricted tables found:\n" - for d, tables in found_datasets.items(): - for table, files in tables.items(): - yield f"{table}" - for f, lines in files.items(): - yield f" - {f}:" - for ln, line in lines.items(): - yield f" line {ln}: {line}" - - -def check_restricted_names(restricted_names, regex_template, files_to_check): - found_names = {} - for name in restricted_names: - regex = re.compile(regex_template.format(name=name)) - found_files = {} - for f in files_to_check: - matches = check_file(f, regex) - if matches: - found_files[f] = matches - if found_files: - found_names[name] = found_files - return found_names - - -def check_file(filename, regex): - found_lines = {} - with open(filename, encoding="utf8", errors="ignore") as f: - for ln, line in enumerate(f, start=1): - if line.lstrip().startswith("#"): - continue - if regex.search(line): - found_lines[ln] = line - return found_lines - - -def get_datasource_permissions(permissions_url): - resp = requests.get(permissions_url) - if resp.status_code != 200: - raise requests.RequestException( - f"Error {resp.status_code} getting {permissions_url}" - ) - yaml = YAML() - permissions = yaml.load(resp.text) - return permissions - - -def get_local_permissions(): - path = Path(Path(PERMISSIONS_URL).name) - yaml = YAML() - permissions = yaml.load(path.read_text()) - return permissions - - -def get_repository_name(continue_on_error): - if "GITHUB_REPOSITORY" in os.environ: - return os.environ["GITHUB_REPOSITORY"] - else: - git_config_path = Path(".git", "config") - if not git_config_path.is_file(): - if not continue_on_error: - print("Git config file not found") - return - config = configparser.ConfigParser() - try: - config.read(git_config_path) - except Exception as e: - if not continue_on_error: - print(f"Unable to read git config.\n{str(e)}") - return - if 'remote "origin"' not in config.sections(): - if not continue_on_error: - print("Remote 'origin' not defined in git config.") - return - url = config['remote "origin"']["url"] - return ( - url.replace("https://github.com/", "") - .replace("git@github.com:", "") - .replace(".git", "") - .strip() - ) - - -def get_allowed_datasets(repository_name, permissions): - if not repository_name or repository_name not in permissions: - return [] - return permissions[repository_name]["allow"] diff --git a/tests/test_check.py b/tests/test_check.py deleted file mode 100644 index ef3ded23..00000000 --- a/tests/test_check.py +++ /dev/null @@ -1,444 +0,0 @@ -import itertools -import os -import subprocess -import textwrap -from collections import Counter -from enum import Enum -from pathlib import Path - -import pytest -from opensafely._vendor import requests -from opensafely._vendor.requests.exceptions import RequestException -from opensafely._vendor.ruyaml import YAML -from opensafely._vendor.ruyaml.comments import CommentedMap -from requests_mock import mocker - -from opensafely import check - - -# Because we're using a vendored version of requests we need to monkeypatch the -# requests_mock library so it references our vendored library instead -mocker.requests = requests -mocker._original_send = requests.Session.send - - -def flatten_list(nested_list): - return [x for sublist in nested_list for x in sublist] - - -class Protocol(Enum): - HTTPS = 1 - SSH = 2 - ENVIRON = 3 - - -UNRESTRICTED_FUNCTION = "with_these_medications" -UNRESTRICTED_TABLE = "clinical_events" - - -@pytest.fixture -def repo_path(tmp_path): - prev_dir = os.getcwd() - os.chdir(tmp_path) - yield tmp_path - os.chdir(prev_dir) - - -def get_permissions_fixture_data(): - permissions_file = ( - Path(__file__).parent - / "fixtures" - / "permissions" - / "repository_permissions.yaml" - ) - permissions_text = permissions_file.read_text() - permissions_dict = YAML().load(permissions_text) - return permissions_text, permissions_dict - - -def all_test_repos(): - _, permissions = get_permissions_fixture_data() - unknown_repo = "opensafely/dummy" - assert unknown_repo not in permissions - return [*permissions, unknown_repo, None] - - -def format_function_call(func): - return ( - f"patients.{func}(" - "between=['2021-01-01','2022-02-02'], " - "find_first_match_in_period=True, " - "returning='binary_flag')" - ) - - -def write_study_def(path, include_restricted): - filename_part = "restricted" if include_restricted else "unrestricted" - all_restricted_functions = flatten_list( - [ - dataset.cohort_extractor_function_names - for dataset in check.RESTRICTED_DATASETS - ] - ) - - for a in [1, 2]: - # generate the filename; we make 2 versions to test that all study defs are checked - filepath = path / f"study_definition_{filename_part}_{a}.py" - - # Build the function calls for the test's study definition. We name each variable with - # the function name itself, to make checking the outputs easier - - # if we're included restricted functions, create a function call for each one - # these will cause check fails depending on the test repo's permissions - if include_restricted: - restricted = [ - f"{name}_name={format_function_call(name)}," - for name in all_restricted_functions - ] - else: - restricted = [] - restricted_lines = "\n".join(restricted) - # create a commented-out function call for each restricted function - # include these in all test study defs; always allowed - restricted_commented_lines = "\n".join( - [ - f"#{name}_commented={format_function_call(name)}," - for name in all_restricted_functions - ] - ) - # create a function call an unrestricted function; - # include in all test study defs; this is always allowed - unrestricted = ( - f"{UNRESTRICTED_FUNCTION}={format_function_call(UNRESTRICTED_FUNCTION)}," - ) - - filepath.write_text( - textwrap.dedent( - f""" - from cohortextractor import StudyDefinition, patients - - study = StudyDefinition ( - {restricted_lines} - {restricted_commented_lines} - {unrestricted} - )""" - ) - ) - - -def write_dataset_def(path, include_restricted): - filename_part = "restricted" if include_restricted else "unrestricted" - all_restricted_tables = flatten_list( - [dataset.ehrql_names for dataset in check.RESTRICTED_DATASETS] - ) - - for a in [1, 2]: - # generate the filename; we make 2 versions to test that all dataset defs are checked - filepath = path / f"dataset_definition_{filename_part}_{a}.py" - - # Build the function calls for the test's dataset definition. We name each variable with - # the function name itself, to make checking the outputs easier - - # if we're included restricted functions, create an import for each one - # these will cause check fails depending on the test repo's permissions - if include_restricted: - restricted = [ - f"{name}_column = {name}.column" for name in all_restricted_tables - ] - else: - restricted = [] - restricted_lines = "\n".join(restricted) - # create a commented-out import for each restricted function - # include these in all test dataset defs; always allowed - restricted_commented_lines = "\n".join( - [f"#{name}_commented = {name}.column" for name in all_restricted_tables] - ) - # create an unrestricted table import; - # include in all test dataset defs; this is always allowed - unrestricted = f"{UNRESTRICTED_TABLE}_column = {UNRESTRICTED_TABLE}.column" - - filepath.write_text( - textwrap.dedent( - f""" - from ehrql import Dataset - - {restricted_lines} - {restricted_commented_lines} - {unrestricted} - )""" - ) - ) - - -def git_init(url): - subprocess.run(["git", "init"]) - subprocess.run(["git", "remote", "add", "origin", url]) - - -def validate_pass(capsys, continue_on_error): - check.main(continue_on_error) - stdout, stderr = capsys.readouterr() - if not continue_on_error: - assert stderr == "" - assert stdout == "Success\n" - else: - assert stdout == "" - - -def validate_fail(capsys, continue_on_error, permissions): - def validate_fail_output(stdout, stderr): - assert stdout != "Success\n" - assert "Usage of restricted datasets found:" in stderr - assert "Usage of restricted tables found:" in stderr - - for dataset in check.RESTRICTED_DATASETS: - for function_name in dataset.cohort_extractor_function_names: - # commented out functions are never in error output, even if restricted - assert f"#{function_name}_commented" not in stderr - if dataset.name in permissions: - assert dataset.name not in stderr - assert f"{function_name}_name" not in stderr - else: - assert dataset.name in stderr, permissions - assert f"{function_name}_name" in stderr - - for table_name in dataset.ehrql_names: - # commented out tables are never in error output, even if restricted - assert f"#{table_name}_commented" not in stderr - if dataset.name in permissions: - assert dataset.name not in stderr - assert f"{table_name}.column" not in stderr - else: - assert dataset.name in stderr, permissions - assert f"{table_name}.column" in stderr - - # unrestricted functions and tables are never in error output - assert UNRESTRICTED_FUNCTION not in stderr - assert UNRESTRICTED_TABLE not in stderr - # Both study definition files are reported - assert "study_definition_restricted_1.py" in stderr - assert "study_definition_restricted_2.py" in stderr - # Both dataset definition files are reported - assert "dataset_definition_restricted_1.py" in stderr - assert "dataset_definition_restricted_2.py" in stderr - - if not continue_on_error: - with pytest.raises(SystemExit): - check.main(continue_on_error) - stdout, stderr = capsys.readouterr() - validate_fail_output(stdout, stderr) - - else: - check.main(continue_on_error) - stdout, stderr = capsys.readouterr() - validate_fail_output(stdout, stdout) - - -def validate_norepo(capsys, continue_on_error): - if not continue_on_error: - with pytest.raises(SystemExit): - check.main(continue_on_error) - stdout, stderr = capsys.readouterr() - assert "git config" in stdout.lower() - assert "Unable to find repository name" in stderr - else: - check.main(continue_on_error) - stdout, stderr = capsys.readouterr() - assert stderr == "" - assert stdout == "" - - -def test_permissions_fixture_data_complete(): - """ - This test is just to test the permissions test fixture, to ensure: - 1) that we've included all the restricted datasets - 2) that at least one test repo has access to all restricted datasets - """ - _, permissions_dict = get_permissions_fixture_data() - - restricted_datasets = set(dataset.name for dataset in check.RESTRICTED_DATASETS) - - all_allowed_repo = None - # find repo with all restricted datasets - for repo, allowed_dict in permissions_dict.items(): - allowed = set(allowed_dict.get("allow", [])) - if not (restricted_datasets - allowed): - all_allowed_repo = repo - break - - assert ( - all_allowed_repo is not None - ), """ - No repo found with access to all restricted datasets. - If you added a new restricted dataset, make sure - tests/fixtures/permissions/repository-permissions.yaml has been updated. - """ - - flattened_permitted_datasets = flatten_list( - [ - dataset_permissions["allow"] - for dataset_permissions in permissions_dict.values() - ] - ) - permitted_dataset_counts = Counter(flattened_permitted_datasets) - - for dataset in restricted_datasets: - assert ( - permitted_dataset_counts[dataset] > 1 - ), f"No part-restricted repo found for restricted dataset {dataset}" - - -@pytest.mark.parametrize( - "repo, protocol, include_restricted, continue_on_error", - itertools.chain( - itertools.product( - all_test_repos(), list(Protocol), [True, False], [True, False] - ), - itertools.product([None], [None], [True, False], [True, False]), - ), -) -def test_check( - repo_path, - capsys, - monkeypatch, - requests_mock, - repo, - protocol, - include_restricted, - continue_on_error, -): - if "GITHUB_REPOSITORY" in os.environ: - monkeypatch.delenv("GITHUB_REPOSITORY") - - # Mock the call to the permissions URL to return the contents of our test permissions file - permissions_text, permissions_dict = get_permissions_fixture_data() - requests_mock.get(check.PERMISSIONS_URL, text=permissions_text) - - write_study_def(repo_path, include_restricted) - write_dataset_def(repo_path, include_restricted) - - if repo: - if protocol == Protocol.ENVIRON: - monkeypatch.setenv("GITHUB_REPOSITORY", repo) - else: - if protocol == Protocol.SSH: - url = f"git@github.com:{repo}.git" - elif protocol == Protocol.HTTPS: - url = f"https://github.com/{repo}" - else: - url = "" - git_init(url) - - repo_permissions = permissions_dict.get(repo, {}).get("allow", []) - # are the restricted datasets all in repo's permitted dataset? - # Some repos in the test fixtures list "ons", which is an allowed dataset; - # ignore any datasets listed in the repo's permissions that are not restricted - all_allowed = not ( - set(dataset.name for dataset in check.RESTRICTED_DATASETS) - - set(repo_permissions) - ) - - if not repo and not include_restricted: - validate_norepo(capsys, continue_on_error) - elif include_restricted and not all_allowed: - validate_fail( - capsys, continue_on_error, permissions_dict.get(repo, {}).get("allow", []) - ) - else: - validate_pass(capsys, continue_on_error) - - -def get_datasource_permissions(): - try: - permissions = check.get_datasource_permissions(check.PERMISSIONS_URL) - return permissions - except RequestException as e: - # This test should always pass on main, but if we've renamed the file - # on the branch, it will fail before it's merged - branch = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) - if branch != "main" and "Error 404" in str(e): - pytest.xfail("Permissions file does not exist on main yet") - - -@pytest.mark.parametrize( - "get_permissions", [get_datasource_permissions, check.get_local_permissions] -) -def test_repository_permissions_yaml(get_permissions): - permissions = get_permissions() - assert permissions, "empty permissions file" - assert type(permissions) == CommentedMap, "invalid permissions file" - for k, v in permissions.items(): - assert len(v.keys()) == 1, f"multiple keys specified for {k}" - assert "allow" in v.keys(), f"allow key not present for {k}" - - -@pytest.mark.parametrize( - "target_string,restricted_name,expected_match", - [ - ("bad_name", "bad_name", False), - ("bad_name.", "bad_name", False), - ("bad_name()", "bad_name", False), - ("patients.bad_name()", "bad_name", True), - ("patients.not_a_bad_name()", "bad_name", False), - ("# patients.bad_name()", "bad_name", False), - ], -) -def test_check_cohort_datasets( - tmp_path, target_string, restricted_name, expected_match -): - contents = textwrap.dedent( - f"""\ - from cohortextractor import patients - {target_string} - patients.age_as_of("2020-01-01") - """ - ) - test_file = tmp_path / "test.py" - test_file.write_text(contents) - restricted_dataset = check.RestrictedDataset( - name="test", - cohort_extractor_function_names=[restricted_name], - ehrql_names=[], - ) - results = check.check_cohort_datasets([test_file], [restricted_dataset]) - if expected_match: - assert restricted_dataset.name in results - else: - assert results == {} - - -@pytest.mark.parametrize( - "target_string,restricted_name,expected_match", - [ - ("bad_name", "bad_name", False), - ("bad_name.", "bad_name", True), - ("bad_name()", "bad_name", True), - ("module.bad_name.", "bad_name", True), - ("module.bad_name()", "bad_name", True), - ("not_a_bad_name()", "bad_name", False), - ("# bad_name.", "bad_name", False), - ], -) -def test_check_ehrql_datasets(tmp_path, target_string, restricted_name, expected_match): - contents = "from erhql import create_dataset\ndataset = create_dataset()\n# foo" - contents = textwrap.dedent( - f"""\ - from ehrql import create_dataset - from ehrql.tables.core import patients - dataset = create_dataset() - {target_string} - dataset.age = patients.age_on("2020-01-01") - """ - ) - test_file = tmp_path / "test.py" - test_file.write_text(contents) - restricted_dataset = check.RestrictedDataset( - name="test", - cohort_extractor_function_names=[], - ehrql_names=[restricted_name], - ) - results = check.check_ehrql_datasets([test_file], [restricted_dataset]) - if expected_match: - assert restricted_dataset.name in results - else: - assert results == {}