Skip to content

Commit

Permalink
Merge pull request #213 from Yelp/u/krall/what_if_multi_cluster_is_easy
Browse files Browse the repository at this point in the history
A quick attempt at making taskproc handle migrations from one k8s cluster to another
  • Loading branch information
jfongatyelp authored Jul 11, 2024
2 parents 8c76fed + ba70ca1 commit e4753ff
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 26 deletions.
81 changes: 55 additions & 26 deletions task_processing/plugins/kubernetes/kubernetes_pod_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Collection
from typing import Optional

from kubernetes import watch
from kubernetes import watch as kube_watch
from kubernetes.client import V1Affinity
from kubernetes.client import V1Container
from kubernetes.client import V1ContainerPort
Expand Down Expand Up @@ -72,13 +72,22 @@ def __init__(
kubeconfig_path: Optional[str] = None,
task_configs: Optional[Collection[KubernetesTaskConfig]] = [],
emit_events_without_state_transitions: bool = False,
# kubeconfigs used to continue to watch other clusters
# Used when transitioning to a new cluster in the primary kubeconfig_path to continue watching still-running pods on other clusters
watcher_kubeconfig_paths: Collection[str] = (),
) -> None:
if not version:
version = "unknown_task_processing"
user_agent = f"{namespace}/v{version}"
self.kube_client = KubeClient(
kubeconfig_path=kubeconfig_path, user_agent=user_agent
)

self.watcher_kube_clients = [
KubeClient(kubeconfig_path=watcher_kubeconfig_path, user_agent=user_agent)
for watcher_kubeconfig_path in watcher_kubeconfig_paths
]

self.namespace = namespace

# Pod modified events that did not result in a pod state transition are usually not
Expand Down Expand Up @@ -106,17 +115,23 @@ def __init__(

# TODO(TASKPROC-243): keep track of resourceVersion so that we can continue event processing
# from where we left off on restarts
self.watch = watch.Watch()
self.pod_event_watch_thread = threading.Thread(
target=self._pod_event_watch_loop,
# ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
# if it should stop after receiving an event - and it's possible that we
# have periods with no events so instead we'll attempt to stop the watch
# and then join() with a small timeout to make sure that, if we shutdown
# with the thread alive, we did not drop any events
daemon=True,
)
self.pod_event_watch_thread.start()
self.pod_event_watch_threads = []
self.watches = []
for kube_client in [self.kube_client] + self.watcher_kube_clients:
watch = kube_watch.Watch()
pod_event_watch_thread = threading.Thread(
target=self._pod_event_watch_loop,
args=(kube_client, watch),
# ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
# if it should stop after receiving an event - and it's possible that we
# have periods with no events so instead we'll attempt to stop the watch
# and then join() with a small timeout to make sure that, if we shutdown
# with the thread alive, we did not drop any events
daemon=True,
)
pod_event_watch_thread.start()
self.pod_event_watch_threads.append(pod_event_watch_thread)
self.watches.append(watch)

self.pending_event_processing_thread = threading.Thread(
target=self._pending_event_processing_loop,
Expand All @@ -143,7 +158,9 @@ def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
),
)

def _pod_event_watch_loop(self) -> None:
def _pod_event_watch_loop(
self, kube_client: KubeClient, watch: kube_watch.Watch
) -> None:
logger.debug(f"Starting watching Pod events for namespace={self.namespace}.")
# TODO(TASKPROC-243): we'll need to correctly handle resourceVersion expiration for the case
# where the gap between task_proc shutting down and coming back up is long enough for data
Expand All @@ -155,8 +172,8 @@ def _pod_event_watch_loop(self) -> None:
# see: https://github.com/kubernetes/kubernetes/issues/74022
while not self.stopping:
try:
for pod_event in self.watch.stream(
self.kube_client.core.list_namespaced_pod, self.namespace
for pod_event in watch.stream(
kube_client.core.list_namespaced_pod, self.namespace
):
# it's possible that we've received an event after we've already set the stop
# flag since Watch streams block forever, so re-check if we've stopped before
Expand All @@ -168,7 +185,7 @@ def _pod_event_watch_loop(self) -> None:
break
except ApiException as e:
if not self.stopping:
if not self.kube_client.maybe_reload_on_exception(exception=e):
if not kube_client.maybe_reload_on_exception(exception=e):
logger.exception(
"Unhandled API exception while watching Pod events - restarting watch!"
)
Expand Down Expand Up @@ -589,11 +606,18 @@ def run(self, task_config: KubernetesTaskConfig) -> Optional[str]:

def reconcile(self, task_config: KubernetesTaskConfig) -> None:
pod_name = task_config.pod_name
try:
pod = self.kube_client.get_pod(namespace=self.namespace, pod_name=pod_name)
except Exception:
logger.exception(f"Hit an exception attempting to fetch pod {pod_name}")
pod = None
pod = None
for kube_client in [self.kube_client] + self.watcher_kube_clients:
try:
pod = kube_client.get_pod(namespace=self.namespace, pod_name=pod_name)
except Exception:
logger.exception(
f"Hit an exception attempting to fetch pod {pod_name} from {kube_client.kubeconfig_path}"
)
else:
# kube_client.get_pod will return None with no exception if it sees a 404 from API
if pod:
break

if pod_name not in self.task_metadata:
self._initialize_existing_task(task_config)
Expand Down Expand Up @@ -640,9 +664,12 @@ def kill(self, task_id: str) -> bool:
This function will request that Kubernetes delete the named Pod and will return
True if the Pod termination request was successfully emitted or False otherwise.
"""
terminated = self.kube_client.terminate_pod(
namespace=self.namespace,
pod_name=task_id,
terminated = any(
kube_client.terminate_pod(
namespace=self.namespace,
pod_name=task_id,
)
for kube_client in [self.kube_client] + self.watcher_kube_clients
)
if terminated:
logger.info(
Expand Down Expand Up @@ -678,12 +705,14 @@ def stop(self) -> None:
logger.debug("Signaling Pod event Watch to stop streaming events...")
# make sure that we've stopped watching for events before calling join() - otherwise,
# join() will block until we hit the configured timeout (or forever with no timeout).
self.watch.stop()
for watch in self.watches:
watch.stop()
# timeout arbitrarily chosen - we mostly just want to make sure that we have a small
# grace period to flush the current event to the pending_events queue as well as
# any other clean-up - it's possible that after this join() the thread is still alive
# but in that case we can be reasonably sure that we're not dropping any data.
self.pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)
for pod_event_watch_thread in self.pod_event_watch_threads:
pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)

logger.debug("Waiting for all pending PodEvents to be processed...")
# once we've stopped updating the pending events queue, we then wait until we're done
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ def k8s_executor(mock_Thread):
executor.stop()


@pytest.fixture
def k8s_executor_with_watcher_clusters(mock_Thread):
    """Yield a KubernetesPodExecutor configured with one extra watcher kubeconfig.

    All kube-client construction is mocked out so that no real cluster is
    contacted; the kubeconfig paths deliberately point at nonexistent files.
    The executor is stopped after the test finishes.
    """
    with mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
        autospec=True,
    ):
        with mock.patch(
            "task_processing.plugins.kubernetes.kube_client.kube_client",
            autospec=True,
        ):
            with mock.patch.dict(
                os.environ, {"KUBECONFIG": "/this/doesnt/exist.conf"}
            ):
                multicluster_executor = KubernetesPodExecutor(
                    namespace="task_processing_tests",
                    watcher_kubeconfig_paths=["/this/also/doesnt/exist.conf"],
                )
                yield multicluster_executor
                multicluster_executor.stop()


@pytest.fixture
def mock_task_configs():
test_task_names = ["job1.action1", "job1.action2", "job2.action1", "job3.action2"]
Expand Down Expand Up @@ -86,6 +104,18 @@ def k8s_executor_with_tasks(mock_Thread, mock_task_configs):
executor.stop()


def test_init_watch_setup(k8s_executor):
assert len(k8s_executor.watches) == len(k8s_executor.pod_event_watch_threads) == 1


def test_init_watch_setup_multicluster(k8s_executor_with_watcher_clusters):
assert (
len(k8s_executor_with_watcher_clusters.watches)
== len(k8s_executor_with_watcher_clusters.pod_event_watch_threads)
== 2
)


def test_run_updates_task_metadata(k8s_executor):
task_config = KubernetesTaskConfig(
name="name", uuid="uuid", image="fake_image", command="fake_command"
Expand Down Expand Up @@ -866,6 +896,49 @@ def test_reconcile_missing_pod(
assert tm.task_state == KubernetesTaskState.TASK_LOST


def test_reconcile_multicluster(
    k8s_executor_with_watcher_clusters,
):
    """reconcile() falls back to watcher-cluster clients when the primary
    cluster does not know the pod, and adopts the state found there."""
    task_config = mock.Mock(spec=KubernetesTaskConfig)
    task_config.pod_name = "pod--name.uuid"
    task_config.name = "job-name"

    # Seed existing metadata in TASK_UNKNOWN so reconcile() updates this
    # entry rather than initializing a brand-new task.
    k8s_executor_with_watcher_clusters.task_metadata = pmap(
        {
            task_config.pod_name: KubernetesTaskMetadata(
                task_config=mock.Mock(spec=KubernetesTaskConfig),
                task_state=KubernetesTaskState.TASK_UNKNOWN,
                task_state_history=v(),
            )
        }
    )

    # The watcher cluster's client reports the pod as Running on a known node.
    mock_watcher_kube_client = mock.Mock(autospec=True)
    mock_found_pod = mock.Mock(spec=V1Pod)
    mock_found_pod.metadata.name = task_config.pod_name
    mock_found_pod.status.phase = "Running"
    mock_found_pod.status.host_ip = "1.2.3.4"
    mock_found_pod.spec.node_name = "kubenode"
    mock_watcher_kube_client.get_pod.return_value = mock_found_pod
    mock_watcher_kube_clients = [mock_watcher_kube_client]

    with mock.patch.object(
        k8s_executor_with_watcher_clusters, "kube_client", autospec=True
    ) as mock_kube_client, mock.patch.object(
        k8s_executor_with_watcher_clusters,
        "watcher_kube_clients",
        mock_watcher_kube_clients,
    ):
        # Primary cluster returns None (pod not found), forcing reconcile()
        # to consult the watcher clients.
        mock_kube_client.get_pod.return_value = None
        k8s_executor_with_watcher_clusters.reconcile(task_config)

    # The watcher client must have been queried, and the task should now be
    # tracked as RUNNING with a single emitted event.
    mock_watcher_kube_client.get_pod.assert_called()
    assert k8s_executor_with_watcher_clusters.event_queue.qsize() == 1
    assert len(k8s_executor_with_watcher_clusters.task_metadata) == 1
    tm = k8s_executor_with_watcher_clusters.task_metadata["pod--name.uuid"]
    assert tm.task_state == KubernetesTaskState.TASK_RUNNING


def test_reconcile_existing_pods(k8s_executor, mock_task_configs):
mock_pods = []
test_phases = ["Running", "Succeeded", "Failed", "Unknown"]
Expand Down

0 comments on commit e4753ff

Please sign in to comment.