From 581740cc111c44b4069605c0930bb0763c4514f7 Mon Sep 17 00:00:00 2001 From: Eman Elsabban Date: Tue, 21 May 2024 10:20:04 -0700 Subject: [PATCH 1/2] Revert "Parallelizing execution of restoring in Tron - TRON-2161 (#950)" This reverts commit 419b353a65fc35d71bc7b00ff8afd80220bf3380. --- tests/mcp_test.py | 6 +- tests/serialize/runstate/statemanager_test.py | 22 +++--- tron/core/job_collection.py | 4 +- tron/core/job_scheduler.py | 4 +- tron/kubernetes.py | 1 - tron/mcp.py | 35 +++------- .../runstate/dynamodb_state_store.py | 67 ++++++------------- tron/serialize/runstate/statemanager.py | 38 ++++------- 8 files changed, 64 insertions(+), 113 deletions(-) diff --git a/tests/mcp_test.py b/tests/mcp_test.py index 13ae2c830..667c7f6d0 100644 --- a/tests/mcp_test.py +++ b/tests/mcp_test.py @@ -152,14 +152,18 @@ def teardown_mcp(self): shutil.rmtree(self.working_dir) shutil.rmtree(self.config_path) - def test_restore_state(self): + @mock.patch("tron.mcp.MesosClusterRepository", autospec=True) + def test_restore_state(self, mock_cluster_repo): job_state_data = {"1": "things", "2": "things"} + mesos_state_data = {"3": "things", "4": "things"} state_data = { + "mesos_state": mesos_state_data, "job_state": job_state_data, } self.mcp.state_watcher.restore.return_value = state_data action_runner = mock.Mock() self.mcp.restore_state(action_runner) + mock_cluster_repo.restore_state.assert_called_with(mesos_state_data) self.mcp.jobs.restore_state.assert_called_with(job_state_data, action_runner) diff --git a/tests/serialize/runstate/statemanager_test.py b/tests/serialize/runstate/statemanager_test.py index 8f1ad534f..07f4653aa 100644 --- a/tests/serialize/runstate/statemanager_test.py +++ b/tests/serialize/runstate/statemanager_test.py @@ -121,12 +121,15 @@ def test_restore(self): "one": {"key": "val1"}, "two": {"key": "val2"}, }, + # _restore_dicts for MESOS_STATE + {"frameworks": "clusters"}, ] restored_state = self.manager.restore(job_names) mock_restore_metadata.assert_called_once_with() assert mock_restore_dicts.call_args_list == [ mock.call(runstate.JOB_STATE, job_names), + mock.call(runstate.MESOS_STATE, ["frameworks"]), ] assert len(mock_restore_runs.call_args_list) == 2 assert restored_state == { @@ -134,6 +137,7 @@ def test_restore(self): "one": {"key": "val1", "runs": mock_restore_runs.return_value}, "two": {"key": "val2", "runs": mock_restore_runs.return_value}, }, + runstate.MESOS_STATE: {"frameworks": "clusters"}, } def test_restore_runs_for_job(self): @@ -143,13 +147,14 @@ def test_restore_runs_for_job(self): "_restore_dicts", autospec=True, ) as mock_restore_dicts: - mock_restore_dicts.side_effect = [ - {"job_a.2": {"job_name": "job_a", "run_num": 2}, "job_a.3": {"job_name": "job_a", "run_num": 3}} - ] + mock_restore_dicts.side_effect = [{"job_a.2": "two"}, {"job_a.3": "three"}] runs = self.manager._restore_runs_for_job("job_a", job_state) - assert mock_restore_dicts.call_args_list == [mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"])] - assert runs == [{"job_name": "job_a", "run_num": 3}, {"job_name": "job_a", "run_num": 2}] + assert mock_restore_dicts.call_args_list == [ + mock.call(runstate.JOB_RUN_STATE, ["job_a.2"]), + mock.call(runstate.JOB_RUN_STATE, ["job_a.3"]), + ] + assert runs == ["two", "three"] def test_restore_runs_for_job_one_missing(self): job_state = {"run_nums": [2, 3], "enabled": True} @@ -158,13 +163,14 @@ def test_restore_runs_for_job_one_missing(self): "_restore_dicts", autospec=True, ) as mock_restore_dicts: - mock_restore_dicts.side_effect = [{"job_a.3": 
{"job_name": "job_a", "run_num": 3}, "job_b": {}}] + mock_restore_dicts.side_effect = [{}, {"job_a.3": "three"}] runs = self.manager._restore_runs_for_job("job_a", job_state) assert mock_restore_dicts.call_args_list == [ - mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"]), + mock.call(runstate.JOB_RUN_STATE, ["job_a.2"]), + mock.call(runstate.JOB_RUN_STATE, ["job_a.3"]), ] - assert runs == [{"job_name": "job_a", "run_num": 3}] + assert runs == ["three"] def test_restore_dicts(self): names = ["namea", "nameb"] diff --git a/tron/core/job_collection.py b/tron/core/job_collection.py index 16b3b8404..7c066f450 100644 --- a/tron/core/job_collection.py +++ b/tron/core/job_collection.py @@ -69,8 +69,8 @@ def update(self, new_job_scheduler): def restore_state(self, job_state_data, config_action_runner): """ - Loops through the jobs and their runs in order to load their - state for each run. As we load the state, we will also schedule the next + Loops through the jobs and their runs in order to restore + state for each run. As we restore state, we will also schedule the next runs for each job """ for name, state in job_state_data.items(): diff --git a/tron/core/job_scheduler.py b/tron/core/job_scheduler.py index 6f37f6c49..dba0cde1f 100644 --- a/tron/core/job_scheduler.py +++ b/tron/core/job_scheduler.py @@ -25,15 +25,13 @@ def __init__(self, job): self.watch(job) def restore_state(self, job_state_data, config_action_runner): - """Load the job state and schedule any JobRuns.""" + """Restore the job state and schedule any JobRuns.""" job_runs = self.job.get_job_runs_from_state(job_state_data) for run in job_runs: self.job.watch(run) self.job.runs.runs.extend(job_runs) log.info(f"{self} restored") - # Tron will recover any action run that has UNKNOWN status - # and will start connecting to task_proc recovery.launch_recovery_actionruns_for_job_runs( job_runs=job_runs, ) diff --git a/tron/kubernetes.py b/tron/kubernetes.py index 31abf57fa..0d7cbbcc9 100644 --- a/tron/kubernetes.py +++ b/tron/kubernetes.py @@ -639,7 +639,6 @@ def get_cluster(cls, kubeconfig_path: Optional[str] = None) -> Optional[Kubernet kubeconfig_path = cls.kubeconfig_path if kubeconfig_path not in cls.clusters: - # will create the task_proc executor cluster = KubernetesCluster( kubeconfig_path=kubeconfig_path, enabled=cls.kubernetes_enabled, default_volumes=cls.default_volumes ) diff --git a/tron/mcp.py b/tron/mcp.py index 7b0e31ad1..3cf643590 100644 --- a/tron/mcp.py +++ b/tron/mcp.py @@ -1,6 +1,5 @@ import logging import time -from contextlib import contextmanager from tron import actioncommand from tron import command_context @@ -18,18 +17,6 @@ log = logging.getLogger(__name__) -@contextmanager -def timer(function_name: str): - start = time.time() - try: - yield - except Exception: - pass - finally: - end = time.time() - log.info(f"Execution time for function {function_name}: {end-start}") - - def apply_master_configuration(mapping, master_config): def get_config_value(seq): return [getattr(master_config, item) for item in seq] @@ -84,12 +71,11 @@ def initial_setup(self): # The job schedule factories will be created in the function below self._load_config() # Jobs will also get scheduled (internally) once the state for action runs are restored in restore_state - with timer("self.restore_state"): - self.restore_state( - actioncommand.create_action_runner_factory_from_config( - self.config.load().get_master().action_runner, - ), - ) + self.restore_state( + actioncommand.create_action_runner_factory_from_config( + 
self.config.load().get_master().action_runner, + ), + ) # Any job with existing state would have been scheduled already. Jobs # without any state will be scheduled here. self.jobs.run_queue_schedule() @@ -173,19 +159,16 @@ def get_config_manager(self): return self.config def restore_state(self, action_runner): - """Use the state manager to retrieve the persisted state from dynamodb and apply it + """Use the state manager to retrieve to persisted state and apply it to the configured Jobs. """ log.info("Restoring from DynamoDB") - with timer("restore"): - # restores the state of the jobs and their runs from DynamoDB - states = self.state_watcher.restore(self.jobs.get_names()) + states = self.state_watcher.restore(self.jobs.get_names()) + MesosClusterRepository.restore_state(states.get("mesos_state", {})) log.info( f"Tron will start restoring state for the jobs and will start scheduling them! Time elapsed since Tron started {time.time() - self.boot_time}" ) - # loads the runs' state and schedule the next run for each job - with timer("self.jobs.restore_state"): - self.jobs.restore_state(states.get("job_state", {}), action_runner) + self.jobs.restore_state(states.get("job_state", {}), action_runner) log.info( f"Tron completed restoring state for the jobs. Time elapsed since Tron started {time.time() - self.boot_time}" ) diff --git a/tron/serialize/runstate/dynamodb_state_store.py b/tron/serialize/runstate/dynamodb_state_store.py index a5af5479d..57c5363ea 100644 --- a/tron/serialize/runstate/dynamodb_state_store.py +++ b/tron/serialize/runstate/dynamodb_state_store.py @@ -1,5 +1,3 @@ -import concurrent.futures -import copy import logging import math import os @@ -9,9 +7,6 @@ from collections import defaultdict from collections import OrderedDict from typing import DefaultDict -from typing import List -from typing import Sequence -from typing import TypeVar import boto3 # type: ignore @@ -19,9 +14,7 @@ OBJECT_SIZE = 400000 MAX_SAVE_QUEUE = 500 -MAX_ATTEMPTS = 10 log = logging.getLogger(__name__) -T = TypeVar("T") class DynamoDBStateStore: @@ -54,48 +47,31 @@ def restore(self, keys) -> dict: vals = self._merge_items(first_items, remaining_items) return vals - def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]: - """Generates a list of chunks of keys to be used to read from DynamoDB""" - # have a for loop here for all the key chunks we want to go over - cand_keys_chunks = [] - for i in range(0, len(keys), 100): - # chunks of at most 100 keys will be in this list as there could be smaller chunks - cand_keys_chunks.append(keys[i : min(len(keys), i + 100)]) - return cand_keys_chunks - - def _get_items(self, table_keys: list) -> object: + def _get_items(self, keys: list) -> object: items = [] - # let's avoid potentially mutating our input :) - cand_keys_list = copy.copy(table_keys) - attempts_to_retrieve_keys = 0 - while len(cand_keys_list) != 0: - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - responses = [ - executor.submit( - self.client.batch_get_item, - RequestItems={self.name: {"Keys": chunked_keys, "ConsistentRead": True}}, - ) - for chunked_keys in self.chunk_keys(cand_keys_list) - ] - # let's wipe the state so that we can loop back around - # if there are any un-processed keys - # NOTE: we'll re-chunk when submitting to the threadpool - # since it's possible that we've had several chunks fail - # enough keys that we'd otherwise send > 100 keys in a - # request otherwise - cand_keys_list = [] - for resp in concurrent.futures.as_completed(responses): - 
items.extend(resp.result()["Responses"][self.name]) - # add any potential unprocessed keys to the thread pool - if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS: - cand_keys_list.append(resp.result()["UnprocessedKeys"][self.name]["Keys"]) - elif attempts_to_retrieve_keys >= MAX_ATTEMPTS: - failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"] + for i in range(0, len(keys), 100): + count = 0 + cand_keys = keys[i : min(len(keys), i + 100)] + while True: + resp = self.client.batch_get_item( + RequestItems={ + self.name: { + "Keys": cand_keys, + "ConsistentRead": True, + }, + }, + ) + items.extend(resp["Responses"][self.name]) + if resp["UnprocessedKeys"].get(self.name) and count < 10: + cand_keys = resp["UnprocessedKeys"][self.name]["Keys"] + count += 1 + elif count >= 10: error = Exception( - f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}" + f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys}\n from dynamodb\n{resp}" ) raise error - attempts_to_retrieve_keys += 1 + else: + break return items def _get_first_partitions(self, keys: list): @@ -103,7 +79,6 @@ def _get_first_partitions(self, keys: list): return self._get_items(new_keys) def _get_remaining_partitions(self, items: list): - """Get items in the remaining partitions: N = 1 and beyond""" keys_for_remaining_items = [] for item in items: remaining_items = [ diff --git a/tron/serialize/runstate/statemanager.py b/tron/serialize/runstate/statemanager.py index aa330a3be..982a18e15 100644 --- a/tron/serialize/runstate/statemanager.py +++ b/tron/serialize/runstate/statemanager.py @@ -1,5 +1,3 @@ -import concurrent.futures -import copy import itertools import logging import time @@ -147,41 +145,29 @@ def restore(self, job_names, skip_validation=False): if not skip_validation: self._restore_metadata() - # First, restore the jobs themselves jobs = self._restore_dicts(runstate.JOB_STATE, job_names) # jobs should be a dictionary that contains job name and number of runs # {'MASTER.k8s': {'run_nums':[0], 'enabled': True}, 'MASTER.cits_test_frequent_1': {'run_nums': [1,0], 'enabled': True}} - - # second, restore the runs for each of the jobs restored above - with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: - # start the threads and mark each future with it's job name - # this is useful so that we can index the job name later to add the runs to the jobs dictionary - results = { - executor.submit(self._restore_runs_for_job, job_name, job_state): job_name - for job_name, job_state in jobs.items() - } - for result in concurrent.futures.as_completed(results): - jobs[results[result]]["runs"] = result.result() + for job_name, job_state in jobs.items(): + job_state["runs"] = self._restore_runs_for_job(job_name, job_state) + frameworks = self._restore_dicts(runstate.MESOS_STATE, ["frameworks"]) state = { runstate.JOB_STATE: jobs, + runstate.MESOS_STATE: frameworks, } return state def _restore_runs_for_job(self, job_name, job_state): - """Restore the state for the runs of each job""" run_nums = job_state["run_nums"] - keys = [jobrun.get_job_run_id(job_name, run_num) for run_num in run_nums] - job_runs_restored_states = self._restore_dicts(runstate.JOB_RUN_STATE, keys) - runs = copy.copy(job_runs_restored_states) - for run_id, state in runs.items(): - if state == {}: - log.error(f"Failed to restore {run_id}, no state found for it!") - job_runs_restored_states.pop(run_id) - - runs = 
list(job_runs_restored_states.values()) - # We need to sort below otherwise the runs will not be in order - runs.sort(key=lambda x: x["run_num"], reverse=True) + runs = [] + for run_num in run_nums: + key = jobrun.get_job_run_id(job_name, run_num) + run_state = list(self._restore_dicts(runstate.JOB_RUN_STATE, [key]).values()) + if not run_state: + log.error(f"Failed to restore {key}, no state found for it") + else: + runs.append(run_state[0]) return runs def _restore_metadata(self): From ccd7e30ee890efbf6d7eea35b7f22fe9b4f0ebbc Mon Sep 17 00:00:00 2001 From: Eman Elsabban Date: Tue, 21 May 2024 11:33:54 -0700 Subject: [PATCH 2/2] deletes the mesos code that the tests are complaining about --- tests/mcp_test.py | 6 +----- tests/serialize/runstate/statemanager_test.py | 4 ---- tron/mcp.py | 1 - tron/serialize/runstate/statemanager.py | 2 -- 4 files changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/mcp_test.py b/tests/mcp_test.py index 667c7f6d0..13ae2c830 100644 --- a/tests/mcp_test.py +++ b/tests/mcp_test.py @@ -152,18 +152,14 @@ def teardown_mcp(self): shutil.rmtree(self.working_dir) shutil.rmtree(self.config_path) - @mock.patch("tron.mcp.MesosClusterRepository", autospec=True) - def test_restore_state(self, mock_cluster_repo): + def test_restore_state(self): job_state_data = {"1": "things", "2": "things"} - mesos_state_data = {"3": "things", "4": "things"} state_data = { - "mesos_state": mesos_state_data, "job_state": job_state_data, } self.mcp.state_watcher.restore.return_value = state_data action_runner = mock.Mock() self.mcp.restore_state(action_runner) - mock_cluster_repo.restore_state.assert_called_with(mesos_state_data) self.mcp.jobs.restore_state.assert_called_with(job_state_data, action_runner) diff --git a/tests/serialize/runstate/statemanager_test.py b/tests/serialize/runstate/statemanager_test.py index 07f4653aa..c159db62d 100644 --- a/tests/serialize/runstate/statemanager_test.py +++ b/tests/serialize/runstate/statemanager_test.py @@ -121,15 +121,12 @@ def test_restore(self): "one": {"key": "val1"}, "two": {"key": "val2"}, }, - # _restore_dicts for MESOS_STATE - {"frameworks": "clusters"}, ] restored_state = self.manager.restore(job_names) mock_restore_metadata.assert_called_once_with() assert mock_restore_dicts.call_args_list == [ mock.call(runstate.JOB_STATE, job_names), - mock.call(runstate.MESOS_STATE, ["frameworks"]), ] assert len(mock_restore_runs.call_args_list) == 2 assert restored_state == { @@ -137,7 +134,6 @@ def test_restore(self): "one": {"key": "val1", "runs": mock_restore_runs.return_value}, "two": {"key": "val2", "runs": mock_restore_runs.return_value}, }, - runstate.MESOS_STATE: {"frameworks": "clusters"}, } def test_restore_runs_for_job(self): diff --git a/tron/mcp.py b/tron/mcp.py index 3cf643590..f0335a8d8 100644 --- a/tron/mcp.py +++ b/tron/mcp.py @@ -164,7 +164,6 @@ def restore_state(self, action_runner): """ log.info("Restoring from DynamoDB") states = self.state_watcher.restore(self.jobs.get_names()) - MesosClusterRepository.restore_state(states.get("mesos_state", {})) log.info( f"Tron will start restoring state for the jobs and will start scheduling them! 
Time elapsed since Tron started {time.time() - self.boot_time}" ) diff --git a/tron/serialize/runstate/statemanager.py b/tron/serialize/runstate/statemanager.py index 982a18e15..c398aafd1 100644 --- a/tron/serialize/runstate/statemanager.py +++ b/tron/serialize/runstate/statemanager.py @@ -150,11 +150,9 @@ def restore(self, job_names, skip_validation=False): # {'MASTER.k8s': {'run_nums':[0], 'enabled': True}, 'MASTER.cits_test_frequent_1': {'run_nums': [1,0], 'enabled': True}} for job_name, job_state in jobs.items(): job_state["runs"] = self._restore_runs_for_job(job_name, job_state) - frameworks = self._restore_dicts(runstate.MESOS_STATE, ["frameworks"]) state = { runstate.JOB_STATE: jobs, - runstate.MESOS_STATE: frameworks, } return state
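
For reference, a minimal standalone sketch of the sequential batch_get_item retry pattern that the reverted _get_items in tron/serialize/runstate/dynamodb_state_store.py goes back to (chunks of at most 100 keys, retrying UnprocessedKeys up to a fixed attempt cap). The table name, key shape, client construction, and helper name below are assumptions made for this example, not code from the patch.

import boto3  # type: ignore

MAX_UNPROCESSED_RETRIES = 10  # attempt cap, analogous to the hard-coded count in the reverted loop


def get_items(table_name: str, keys: list) -> list:
    """Fetch DynamoDB items sequentially, 100 keys per batch_get_item call."""
    client = boto3.client("dynamodb")  # assumed default session/region
    items = []
    for i in range(0, len(keys), 100):
        # batch_get_item accepts at most 100 keys per request
        cand_keys = keys[i : i + 100]
        attempts = 0
        while True:
            resp = client.batch_get_item(
                RequestItems={table_name: {"Keys": cand_keys, "ConsistentRead": True}},
            )
            items.extend(resp["Responses"][table_name])
            unprocessed = resp["UnprocessedKeys"].get(table_name)
            if not unprocessed:
                break
            if attempts >= MAX_UNPROCESSED_RETRIES:
                raise Exception(
                    f"failed to retrieve items with keys {unprocessed['Keys']} from dynamodb"
                )
            # DynamoDB can leave some keys unprocessed under throttling; retry only those keys.
            cand_keys = unprocessed["Keys"]
            attempts += 1
    return items


# Hypothetical usage, assuming a string partition key named "key" and a numeric
# sort key named "index" (the actual Tron key schema is not shown in this patch):
# get_items("tron_state", [{"key": {"S": "job_state MASTER.foo"}, "index": {"N": "0"}}])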