Revert "Parallelizing execution of restoring in Tron - TRON-2161 (#950)" #968

Closed · wants to merge 2 commits
18 changes: 10 additions & 8 deletions tests/serialize/runstate/statemanager_test.py
@@ -143,13 +143,14 @@ def test_restore_runs_for_job(self):
"_restore_dicts",
autospec=True,
) as mock_restore_dicts:
mock_restore_dicts.side_effect = [
{"job_a.2": {"job_name": "job_a", "run_num": 2}, "job_a.3": {"job_name": "job_a", "run_num": 3}}
]
mock_restore_dicts.side_effect = [{"job_a.2": "two"}, {"job_a.3": "three"}]
runs = self.manager._restore_runs_for_job("job_a", job_state)

assert mock_restore_dicts.call_args_list == [mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"])]
assert runs == [{"job_name": "job_a", "run_num": 3}, {"job_name": "job_a", "run_num": 2}]
assert mock_restore_dicts.call_args_list == [
mock.call(runstate.JOB_RUN_STATE, ["job_a.2"]),
mock.call(runstate.JOB_RUN_STATE, ["job_a.3"]),
]
assert runs == ["two", "three"]

def test_restore_runs_for_job_one_missing(self):
job_state = {"run_nums": [2, 3], "enabled": True}
@@ -158,13 +159,14 @@ def test_restore_runs_for_job_one_missing(self):
"_restore_dicts",
autospec=True,
) as mock_restore_dicts:
mock_restore_dicts.side_effect = [{"job_a.3": {"job_name": "job_a", "run_num": 3}, "job_b": {}}]
mock_restore_dicts.side_effect = [{}, {"job_a.3": "three"}]
runs = self.manager._restore_runs_for_job("job_a", job_state)

assert mock_restore_dicts.call_args_list == [
mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"]),
mock.call(runstate.JOB_RUN_STATE, ["job_a.2"]),
mock.call(runstate.JOB_RUN_STATE, ["job_a.3"]),
]
assert runs == [{"job_name": "job_a", "run_num": 3}]
assert runs == ["three"]

def test_restore_dicts(self):
names = ["namea", "nameb"]
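With the revert, `_restore_dicts` is expected to be called once per run key instead of once with the full key list, which is why the mock's `side_effect` is now a list of per-call return values. A minimal, self-contained illustration of that `side_effect` pattern (the class and method names below are hypothetical stand-ins, not tron's real API):

```python
from unittest import mock


class FakeStore:
    """Hypothetical stand-in for the state manager's per-key lookup."""

    def fetch(self, key):
        raise NotImplementedError


def restore_runs(store, keys):
    # One fetch per key, mirroring the sequential behaviour this revert restores.
    return [store.fetch(key) for key in keys]


def test_restore_runs_fetches_each_key():
    store = FakeStore()
    with mock.patch.object(store, "fetch", side_effect=["two", "three"]) as mock_fetch:
        runs = restore_runs(store, ["job_a.2", "job_a.3"])

    # side_effect hands out one value per call, so both call order and call count are verified.
    assert mock_fetch.call_args_list == [mock.call("job_a.2"), mock.call("job_a.3")]
    assert runs == ["two", "three"]
```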
4 changes: 2 additions & 2 deletions tron/core/job_collection.py
@@ -69,8 +69,8 @@ def update(self, new_job_scheduler):

def restore_state(self, job_state_data, config_action_runner):
"""
Loops through the jobs and their runs in order to load their
state for each run. As we load the state, we will also schedule the next
Loops through the jobs and their runs in order to restore
state for each run. As we restore state, we will also schedule the next
runs for each job
"""
for name, state in job_state_data.items():
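The docstring above describes the restore flow: iterate over the persisted jobs and hand each job's state to its scheduler. A rough, self-contained sketch of that loop, with a hypothetical scheduler class standing in for tron's `JobScheduler`:

```python
from typing import Any, Dict


class FakeJobScheduler:
    """Hypothetical stand-in for JobScheduler, just to show the flow."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.restored_runs: list = []

    def restore_state(self, job_state: Dict[str, Any], action_runner: Any) -> None:
        # In tron this also watches each run and schedules the next one.
        self.restored_runs = job_state.get("runs", [])


def restore_all(schedulers: Dict[str, FakeJobScheduler], job_state_data: Dict[str, Any]) -> None:
    # One restore per job, in the order the state was returned.
    for name, state in job_state_data.items():
        schedulers[name].restore_state(state, action_runner=None)


schedulers = {"MASTER.k8s": FakeJobScheduler("MASTER.k8s")}
restore_all(schedulers, {"MASTER.k8s": {"run_nums": [0], "enabled": True, "runs": [{"run_num": 0}]}})
```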
4 changes: 1 addition & 3 deletions tron/core/job_scheduler.py
@@ -25,15 +25,13 @@ def __init__(self, job):
self.watch(job)

def restore_state(self, job_state_data, config_action_runner):
"""Load the job state and schedule any JobRuns."""
"""Restore the job state and schedule any JobRuns."""
job_runs = self.job.get_job_runs_from_state(job_state_data)
for run in job_runs:
self.job.watch(run)
self.job.runs.runs.extend(job_runs)
log.info(f"{self} restored")

# Tron will recover any action run that has UNKNOWN status
# and will start connecting to task_proc
recovery.launch_recovery_actionruns_for_job_runs(
job_runs=job_runs,
)
1 change: 0 additions & 1 deletion tron/kubernetes.py
@@ -639,7 +639,6 @@ def get_cluster(cls, kubeconfig_path: Optional[str] = None) -> Optional[Kubernet
kubeconfig_path = cls.kubeconfig_path

if kubeconfig_path not in cls.clusters:
# will create the task_proc executor
cluster = KubernetesCluster(
kubeconfig_path=kubeconfig_path, enabled=cls.kubernetes_enabled, default_volumes=cls.default_volumes
)
34 changes: 8 additions & 26 deletions tron/mcp.py
@@ -1,6 +1,5 @@
import logging
import time
from contextlib import contextmanager

from tron import actioncommand
from tron import command_context
@@ -18,18 +17,6 @@
log = logging.getLogger(__name__)


@contextmanager
def timer(function_name: str):
start = time.time()
try:
yield
except Exception:
pass
finally:
end = time.time()
log.info(f"Execution time for function {function_name}: {end-start}")


def apply_master_configuration(mapping, master_config):
def get_config_value(seq):
return [getattr(master_config, item) for item in seq]
@@ -84,12 +71,11 @@ def initial_setup(self):
# The job schedule factories will be created in the function below
self._load_config()
# Jobs will also get scheduled (internally) once the state for action runs are restored in restore_state
with timer("self.restore_state"):
self.restore_state(
actioncommand.create_action_runner_factory_from_config(
self.config.load().get_master().action_runner,
),
)
self.restore_state(
actioncommand.create_action_runner_factory_from_config(
self.config.load().get_master().action_runner,
),
)
# Any job with existing state would have been scheduled already. Jobs
# without any state will be scheduled here.
self.jobs.run_queue_schedule()
@@ -173,19 +159,15 @@ def get_config_manager(self):
return self.config

def restore_state(self, action_runner):
"""Use the state manager to retrieve the persisted state from dynamodb and apply it
"""Use the state manager to retrieve to persisted state and apply it
to the configured Jobs.
"""
log.info("Restoring from DynamoDB")
with timer("restore"):
# restores the state of the jobs and their runs from DynamoDB
states = self.state_watcher.restore(self.jobs.get_names())
states = self.state_watcher.restore(self.jobs.get_names())
log.info(
f"Tron will start restoring state for the jobs and will start scheduling them! Time elapsed since Tron started {time.time() - self.boot_time}"
)
# loads the runs' state and schedule the next run for each job
with timer("self.jobs.restore_state"):
self.jobs.restore_state(states.get("job_state", {}), action_runner)
self.jobs.restore_state(states.get("job_state", {}), action_runner)
log.info(
f"Tron completed restoring state for the jobs. Time elapsed since Tron started {time.time() - self.boot_time}"
)
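For reference, the `timer` helper deleted above is a standard `contextlib.contextmanager` timing wrapper. A self-contained sketch of the same idea follows; unlike the removed helper, this version lets exceptions propagate instead of swallowing them with `except Exception: pass`:

```python
import logging
import time
from contextlib import contextmanager

log = logging.getLogger(__name__)


@contextmanager
def timer(function_name: str):
    """Log how long the wrapped block took, even if it raises."""
    start = time.time()
    try:
        yield
    finally:
        log.info(f"Execution time for function {function_name}: {time.time() - start}")


# Usage, as it appeared in mcp.py before this revert:
with timer("self.restore_state"):
    time.sleep(0.1)  # placeholder for the timed work
```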
67 changes: 21 additions & 46 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -1,5 +1,3 @@
import concurrent.futures
import copy
import logging
import math
import os
@@ -9,19 +7,14 @@
from collections import defaultdict
from collections import OrderedDict
from typing import DefaultDict
from typing import List
from typing import Sequence
from typing import TypeVar

import boto3 # type: ignore

from tron.metrics import timer

OBJECT_SIZE = 400000
MAX_SAVE_QUEUE = 500
MAX_ATTEMPTS = 10
log = logging.getLogger(__name__)
T = TypeVar("T")


class DynamoDBStateStore:
@@ -54,56 +47,38 @@ def restore(self, keys) -> dict:
vals = self._merge_items(first_items, remaining_items)
return vals

def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]:
"""Generates a list of chunks of keys to be used to read from DynamoDB"""
# have a for loop here for all the key chunks we want to go over
cand_keys_chunks = []
for i in range(0, len(keys), 100):
# chunks of at most 100 keys will be in this list as there could be smaller chunks
cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
return cand_keys_chunks

def _get_items(self, table_keys: list) -> object:
def _get_items(self, keys: list) -> object:
items = []
# let's avoid potentially mutating our input :)
cand_keys_list = copy.copy(table_keys)
attempts_to_retrieve_keys = 0
while len(cand_keys_list) != 0:
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
responses = [
executor.submit(
self.client.batch_get_item,
RequestItems={self.name: {"Keys": chunked_keys, "ConsistentRead": True}},
)
for chunked_keys in self.chunk_keys(cand_keys_list)
]
# let's wipe the state so that we can loop back around
# if there are any un-processed keys
# NOTE: we'll re-chunk when submitting to the threadpool
# since it's possible that we've had several chunks fail
# enough keys that we'd otherwise send > 100 keys in a
# request otherwise
cand_keys_list = []
for resp in concurrent.futures.as_completed(responses):
items.extend(resp.result()["Responses"][self.name])
# add any potential unprocessed keys to the thread pool
if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
cand_keys_list.append(resp.result()["UnprocessedKeys"][self.name]["Keys"])
elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
for i in range(0, len(keys), 100):
count = 0
cand_keys = keys[i : min(len(keys), i + 100)]
while True:
resp = self.client.batch_get_item(
RequestItems={
self.name: {
"Keys": cand_keys,
"ConsistentRead": True,
},
},
)
items.extend(resp["Responses"][self.name])
if resp["UnprocessedKeys"].get(self.name) and count < 10:
cand_keys = resp["UnprocessedKeys"][self.name]["Keys"]
count += 1
elif count >= 10:
error = Exception(
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys}\n from dynamodb\n{resp}"
)
raise error
attempts_to_retrieve_keys += 1
else:
break
return items

def _get_first_partitions(self, keys: list):
new_keys = [{"key": {"S": key}, "index": {"N": "0"}} for key in keys]
return self._get_items(new_keys)

def _get_remaining_partitions(self, items: list):
"""Get items in the remaining partitions: N = 1 and beyond"""
keys_for_remaining_items = []
for item in items:
remaining_items = [
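The restored `_get_items` walks the keys in chunks of 100 (the `batch_get_item` limit) and retries any `UnprocessedKeys` up to ten times per chunk before giving up. A self-contained sketch of that boto3 pattern, with an assumed table name and region for illustration:

```python
import boto3

MAX_ATTEMPTS = 10
TABLE_NAME = "tron-state"  # assumed name, for illustration only

client = boto3.client("dynamodb", region_name="us-west-1")


def get_items(keys: list) -> list:
    """Fetch items in chunks of 100, retrying whatever DynamoDB leaves unprocessed."""
    items = []
    for i in range(0, len(keys), 100):
        pending = keys[i : i + 100]
        attempts = 0
        while pending:
            resp = client.batch_get_item(
                RequestItems={TABLE_NAME: {"Keys": pending, "ConsistentRead": True}},
            )
            items.extend(resp["Responses"][TABLE_NAME])
            # DynamoDB may return only a subset; retry the rest.
            pending = resp.get("UnprocessedKeys", {}).get(TABLE_NAME, {}).get("Keys", [])
            attempts += 1
            if pending and attempts >= MAX_ATTEMPTS:
                raise Exception(f"tron_dynamodb_restore_failure: failed to retrieve keys {pending}")
    return items
```

The keys here are the partitioned items built by helpers such as `_get_first_partitions`, e.g. `{"key": {"S": ...}, "index": {"N": "0"}}`.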
36 changes: 10 additions & 26 deletions tron/serialize/runstate/statemanager.py
@@ -1,5 +1,3 @@
import concurrent.futures
import copy
import itertools
import logging
import time
@@ -147,41 +145,27 @@ def restore(self, job_names, skip_validation=False):
if not skip_validation:
self._restore_metadata()

# First, restore the jobs themselves
jobs = self._restore_dicts(runstate.JOB_STATE, job_names)
# jobs should be a dictionary that contains job name and number of runs
# {'MASTER.k8s': {'run_nums':[0], 'enabled': True}, 'MASTER.cits_test_frequent_1': {'run_nums': [1,0], 'enabled': True}}

# second, restore the runs for each of the jobs restored above
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
# start the threads and mark each future with it's job name
# this is useful so that we can index the job name later to add the runs to the jobs dictionary
results = {
executor.submit(self._restore_runs_for_job, job_name, job_state): job_name
for job_name, job_state in jobs.items()
}
for result in concurrent.futures.as_completed(results):
jobs[results[result]]["runs"] = result.result()
for job_name, job_state in jobs.items():
job_state["runs"] = self._restore_runs_for_job(job_name, job_state)

state = {
runstate.JOB_STATE: jobs,
}
return state

def _restore_runs_for_job(self, job_name, job_state):
"""Restore the state for the runs of each job"""
run_nums = job_state["run_nums"]
keys = [jobrun.get_job_run_id(job_name, run_num) for run_num in run_nums]
job_runs_restored_states = self._restore_dicts(runstate.JOB_RUN_STATE, keys)
runs = copy.copy(job_runs_restored_states)
for run_id, state in runs.items():
if state == {}:
log.error(f"Failed to restore {run_id}, no state found for it!")
job_runs_restored_states.pop(run_id)

runs = list(job_runs_restored_states.values())
# We need to sort below otherwise the runs will not be in order
runs.sort(key=lambda x: x["run_num"], reverse=True)
runs = []
for run_num in run_nums:
key = jobrun.get_job_run_id(job_name, run_num)
run_state = list(self._restore_dicts(runstate.JOB_RUN_STATE, [key]).values())
if not run_state:
log.error(f"Failed to restore {key}, no state found for it")
else:
runs.append(run_state[0])
return runs

def _restore_metadata(self):
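For context, the removed parallel path fanned the per-job restore out over a small thread pool and tagged each future with its job name so the results could be attached to the right entry. A generic, self-contained sketch of that `concurrent.futures` pattern (the worker function is a stand-in for `_restore_runs_for_job`):

```python
import concurrent.futures


def restore_runs_for_job(job_name: str, job_state: dict) -> list:
    # Stand-in worker: return the restored runs for one job.
    return [{"run_num": n} for n in job_state["run_nums"]]


jobs = {
    "MASTER.k8s": {"run_nums": [0], "enabled": True},
    "MASTER.cits_test_frequent_1": {"run_nums": [1, 0], "enabled": True},
}

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    # Map each future back to its job name.
    futures = {
        executor.submit(restore_runs_for_job, name, state): name
        for name, state in jobs.items()
    }
    for future in concurrent.futures.as_completed(futures):
        jobs[futures[future]]["runs"] = future.result()

# Every job now has a "runs" entry, regardless of completion order.
```

The revert replaces this with the plain sequential loop shown in the added lines above.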