diff --git a/acto/__main__.py b/acto/__main__.py index 0ca73a0166..818d027409 100644 --- a/acto/__main__.py +++ b/acto/__main__.py @@ -87,6 +87,24 @@ help="Only generate test cases without executing them", ) parser.add_argument("--checkonly", action="store_true") +parser.add_argument( + "--num-alarms", + dest="num_alarms", + type=int, + help="Number of alarms to early stop running", +) +parser.add_argument( + "--time-duration", + dest="time_duration", + type=int, + help="Approximate running time (minutes) to early stop", +) +parser.add_argument( + "--hard-time-bound", + dest="hard_time_bound", + action="store_true", + help="Use hard time bound to early stop", +) args = parser.parse_args() @@ -149,6 +167,9 @@ apply_testcase_f=apply_testcase_f, delta_from=None, focus_fields=config.focus_fields, + num_alarms=args.num_alarms, + time_duration=args.time_duration, + hard_time_bound = args.hard_time_bound, ) generation_time = datetime.now() logger.info("Acto initialization finished in %s", generation_time - start_time) diff --git a/acto/engine.py b/acto/engine.py index 5ad5021346..bc6204d486 100644 --- a/acto/engine.py +++ b/acto/engine.py @@ -42,9 +42,12 @@ from acto.serialization import ActoEncoder, ContextEncoder from acto.snapshot import Snapshot from acto.utils import ( + AlarmCounter, delete_operator_pod, + get_early_stop_time, get_yaml_existing_namespace, process_crd, + terminate_threads, update_preload_images, ) from acto.utils.thread_logger import get_thread_logger, set_thread_logger_prefix @@ -236,6 +239,8 @@ def __init__( apply_testcase_f: Callable, acto_namespace: int, additional_exclude_paths: Optional[list[str]] = None, + alarm_counter: AlarmCounter = None, + early_stop_time: time.time = None, ) -> None: self.context = context self.workdir = workdir @@ -275,6 +280,9 @@ def __init__( self.apply_testcase_f = apply_testcase_f self.curr_trial = 0 + self.alarm_counter = alarm_counter + self.early_stop_time = early_stop_time + def run( self, errors: 
list[Optional[OracleResults]], @@ -296,7 +304,15 @@ def run( logger.info("Test finished") break + if self.alarm_counter is not None and self.alarm_counter.judge(self.worker_id): + logger.info("Test finshed for reaching the number of alarms") + break + trial_start_time = time.time() + if self.early_stop_time is not None and self.early_stop_time < trial_start_time: + logger.info("Test finshed for reaching early stop time") + break + self.cluster.restart_cluster(self.cluster_name, self.kubeconfig) apiclient = kubernetes_client(self.kubeconfig, self.context_name) self.cluster.load_images(self.images_archive, self.cluster_name) @@ -549,6 +565,13 @@ def run_trial( generation, ) if run_result.oracle_result.is_error(): + + # is alarm, alarm count plus one + if self.alarm_counter is not None and not run_result.is_invalid_input(): + self.alarm_counter.increment() + logger.info(f"Alarm count plus one, current count is {self.alarm_counter.get_count()}") + + # before return, run the recovery test case logger.info("Error result, running recovery") run_result.oracle_result.differential = self.run_recovery( @@ -765,6 +788,9 @@ def __init__( mount: Optional[list] = None, focus_fields: Optional[list] = None, acto_namespace: int = 0, + num_alarms: int = None, + time_duration: int = None, + hard_time_bound: bool = False, ) -> None: logger = get_thread_logger(with_prefix=False) @@ -813,6 +839,10 @@ def __init__( self.runner_type = Runner self.checker_type = CheckerSet + self.num_alarms = num_alarms + self.time_duration = time_duration + self.hard_time_bound = hard_time_bound + self.__learn( context_file=context_file, helper_crd=helper_crd, @@ -1060,7 +1090,9 @@ def run( check=True, ) + alarm_counter = None if self.num_alarms is None else AlarmCounter(self.num_alarms) start_time = time.time() + early_stop_time = get_early_stop_time(start_time, self.time_duration, self.hard_time_bound) errors: list[OracleResults] = [] runners: list[TrialRunner] = [] @@ -1083,6 +1115,8 @@ def run( 
"""Early-stop helpers: a shared alarm counter and time-bound utilities."""

import ctypes
import logging
import threading
import time
from typing import Optional

# Underscore-prefixed so that ``from .early_stop import *`` does not export it.
_logger = logging.getLogger(__name__)


class AlarmCounter:
    """Counter shared between threads that reports when a bound is reached.

    Args:
        bound: Number of alarms at which ``judge`` starts returning True.
    """

    def __init__(self, bound):
        self.count = 0
        self.bound = bound
        self.lock = threading.Lock()

    def increment(self, value=1):
        """Add ``value`` (default 1) to the count while holding the lock."""
        with self.lock:
            self.count += value

    def get_count(self):
        """Return the current count, read under the same lock as writes."""
        with self.lock:
            return self.count

    def judge(self, work_id):
        """Return True once the count has reached or exceeded the bound.

        Args:
            work_id: Identifier of the calling worker. It is not used in the
                decision; it is kept so existing callers keep working.
        """
        with self.lock:
            return self.count >= self.bound


def terminate_threads(threads: list[threading.Thread]):
    """Ask every live thread in ``threads`` to exit by raising SystemExit in it.

    The exception is scheduled asynchronously through the CPython C API, so a
    thread only sees it when it next runs Python bytecode; a thread blocked
    inside a long C call is not interrupted until that call returns.

    Args:
        threads: Threads to terminate. Threads that were never started or
            that have already finished are skipped.
    """
    set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
    for thread in threads:
        thread_id = thread.ident
        if thread_id is None or not thread.is_alive():
            continue

        # The C signature is (unsigned long id, PyObject *exc); wrap the id
        # explicitly so ctypes does not pass it as a default C int.
        affected = set_async_exc(
            ctypes.c_ulong(thread_id), ctypes.py_object(SystemExit)
        )
        if affected == 0:
            # The thread finished between the liveness check and the call.
            _logger.debug(
                "Thread %s exited before it could be terminated", thread_id
            )
        elif affected > 1:
            # More than one thread state was modified: undo the request by
            # clearing the pending exception (NULL, i.e. None, as ``exc``).
            set_async_exc(ctypes.c_ulong(thread_id), None)
            _logger.error("Exception raise failure in kill timeout threads")


def get_early_stop_time(
    start_time: float, time_duration: Optional[int], hard_time_bound: bool
) -> Optional[float]:
    """Compute the timestamp after which no new work should be started.

    Args:
        start_time: Start timestamp in seconds (e.g. from ``time.time()``).
        time_duration: Allowed running time in minutes, or None for no limit.
        hard_time_bound: When True no soft deadline is produced; the limit is
            expected to be enforced by other means (see ``terminate_threads``).

    Returns:
        ``start_time`` plus the duration converted to seconds, or None when
        there is no duration or a hard time bound is requested.
    """
    if time_duration is None or hard_time_bound is True:
        return None
    return start_time + time_duration * 60