diff --git a/task_processing/plugins/kubernetes/kube_client.py b/task_processing/plugins/kubernetes/kube_client.py
index bf8f797..3853001 100644
--- a/task_processing/plugins/kubernetes/kube_client.py
+++ b/task_processing/plugins/kubernetes/kube_client.py
@@ -1,13 +1,13 @@
 import logging
 import os
 from http import HTTPStatus
+from typing import List
 from typing import Optional
 
 from kubernetes import client as kube_client
 from kubernetes import config as kube_config
 from kubernetes.client.exceptions import ApiException
 from kubernetes.client.models.v1_pod import V1Pod
-
 logger = logging.getLogger(__name__)
 
 DEFAULT_ATTEMPTS = 2
@@ -184,6 +184,12 @@ def get_pod(
         pod_name: str,
         attempts: int = DEFAULT_ATTEMPTS,
     ) -> Optional[V1Pod]:
+        """
+        Wrapper around read_namespaced_pod() in the Kubernetes client library that
+        adds retries on ApiException.
+
+        Returns the V1Pod on success, None on 404; raises ExceededMaxAttempts otherwise.
+        """
         max_attempts = attempts
         while attempts:
             try:
@@ -209,6 +215,41 @@ def get_pod(
             )
             raise
         logger.info(f"Ran out of retries attempting to fetch pod {pod_name}.")
+        raise ExceededMaxAttempts(f"Retried fetching pod {pod_name} {max_attempts} times.")
+
+    def get_pods(
+        self, namespace: str, attempts: int = DEFAULT_ATTEMPTS,
+    ) -> Optional[List[V1Pod]]:
+        """
+        Wrapper around list_namespaced_pod() in the Kubernetes client library that
+        adds retries on ApiException.
+
+        Returns a list of V1Pods on success, None on 404; raises ExceededMaxAttempts otherwise.
+        """
+        max_attempts = attempts
+        while attempts:
+            try:
+                pods = self.core.list_namespaced_pod(
+                    namespace=namespace,
+                ).items
+                return pods
+            except ApiException as e:
+                # a missing namespace surfaces as an ApiException with a 404 status
+                if e.status == 404:
+                    logger.info(f"Found no pods in namespace {namespace}.")
+                    return None
+                if not self.maybe_reload_on_exception(exception=e) and attempts:
+                    logger.debug(
+                        f"Failed to fetch pods in {namespace} due to unhandled API exception, "
+                        "retrying.",
+                        exc_info=True
+                    )
+                    attempts -= 1
+            except Exception:
+                logger.exception(
+                    f"Failed to fetch pods in {namespace} due to unhandled exception."
+                )
+                raise
+        logger.info(f"Ran out of retries attempting to fetch pods in namespace {namespace}.")
         raise ExceededMaxAttempts(
-            f"Retried fetching pod {pod_name} {max_attempts} times."
-        )
+            f"Retried fetching pods in namespace {namespace} {max_attempts} times.")
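The new `get_pods()` has three distinct outcomes: a `List[V1Pod]` on success, `None` when the namespace 404s, and an `ExceededMaxAttempts` raise once the retry budget is spent. A minimal sketch of how a caller might handle all three (`KubeClient` and `ExceededMaxAttempts` are the names used in kube_client.py; the kubeconfig path and namespace are placeholders):

```python
from task_processing.plugins.kubernetes.kube_client import (
    ExceededMaxAttempts,
    KubeClient,
)

client = KubeClient(kubeconfig_path="/path/to/kubeconfig.conf")
try:
    pods = client.get_pods(namespace="ns")
except ExceededMaxAttempts:
    pods = None  # transient ApiExceptions exhausted the retry budget

if pods is None:
    ...  # namespace missing (404) or retries exhausted
else:
    for pod in pods:
        print(pod.metadata.name, pod.status.phase)
```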
diff --git a/task_processing/plugins/kubernetes/kubernetes_pod_executor.py b/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
index 4eefbdc..5b3519c 100644
--- a/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
+++ b/task_processing/plugins/kubernetes/kubernetes_pod_executor.py
@@ -1,9 +1,14 @@
 import logging
-import queue
-import threading
 import time
-from queue import Queue
+from multiprocessing import cpu_count
+from multiprocessing import JoinableQueue
+from multiprocessing import Lock
+from multiprocessing import Process
+from multiprocessing.pool import Pool
+from queue import Empty
+from time import sleep
 from typing import Collection
+from typing import List
 from typing import Optional
 
 from kubernetes import watch as kube_watch
@@ -52,8 +57,8 @@
 
 logger = logging.getLogger(__name__)
 
-POD_WATCH_THREAD_JOIN_TIMEOUT_S = 1.0
-POD_EVENT_THREAD_JOIN_TIMEOUT_S = 1.0
+POD_WATCH_PROCESS_JOIN_TIMEOUT_S = 1.0
+POD_EVENT_PROCESS_JOIN_TIMEOUT_S = 1.0
 QUEUE_GET_TIMEOUT_S = 0.5
 SUPPORTED_POD_MODIFIED_EVENT_PHASES = {
     "Failed",
@@ -68,6 +73,9 @@
 # control plane some breathing room
 RETRY_BACKOFF_EXPONENT = 1.5
 
+REFRESH_EXECUTOR_STATE_PROCESS_GRACE = 300
+REFRESH_EXECUTOR_STATE_PROCESS_INTERVAL = 120
+
 
 class KubernetesPodExecutor(TaskExecutor):
     TASK_CONFIG_INTERFACE = KubernetesTaskConfig
@@ -107,7 +115,7 @@ def __init__(
         self.stopping = False
         self.task_metadata: PMap[str, KubernetesTaskMetadata] = pmap()
 
-        self.task_metadata_lock = threading.RLock()
+        self.task_metadata_lock = Lock()
         if task_configs:
             for task_config in task_configs:
                 self._initialize_existing_task(task_config)
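Note that swapping `threading` for `multiprocessing` changes more than the worker primitive: a `multiprocessing.Lock` is not reentrant the way the old `threading.RLock` was, and instance attributes such as `task_metadata` are no longer shared memory - each child process works on its own copy. A minimal sketch of the state-sharing pitfall (entirely illustrative, not code from this diff):

```python
from multiprocessing import Process


class Holder:
    def __init__(self):
        self.state = {}

    def mutate(self):
        # rebinds the attribute in the child process's copy only
        self.state = {"seen": True}


if __name__ == "__main__":
    holder = Holder()
    p = Process(target=holder.mutate)
    p.start()
    p.join()
    print(holder.state)  # {} - the parent never sees the child's update
```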
@@ -117,33 +125,38 @@ def __init__(
         # and we've opted to not do that processing in the Pod event watcher thread so as to keep
         # that logic for the threads that operate on them as simple as possible and to make it
         # possible to cleanly shutdown both of these.
-        self.pending_events: "Queue[PodEvent]" = Queue()
-        self.event_queue: "Queue[Event]" = Queue()
-
+        self.pending_events: "JoinableQueue[PodEvent]" = JoinableQueue()
+        self.event_queue: "JoinableQueue[Event]" = JoinableQueue()
         # TODO(TASKPROC-243): keep track of resourceVersion so that we can continue event processing
         # from where we left off on restarts
-        self.pod_event_watch_threads = []
+        self.pod_event_watch_processes: List[Process] = []
         self.watches = []
         for kube_client in [self.kube_client] + self.watcher_kube_clients:
             watch = kube_watch.Watch()
-            pod_event_watch_thread = threading.Thread(
+            pod_event_watch_process = Process(
                 target=self._pod_event_watch_loop,
                 args=(kube_client, watch),
-                # ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
+                # ideally this wouldn't be a daemon process, but a watch.Watch() only checks
                 # if it should stop after receiving an event - and it's possible that we
                 # have periods with no events so instead we'll attempt to stop the watch
                 # and then join() with a small timeout to make sure that, if we shutdown
-                # with the thread alive, we did not drop any events
+                # with the process alive, we did not drop any events
                 daemon=True,
             )
-            pod_event_watch_thread.start()
-            self.pod_event_watch_threads.append(pod_event_watch_thread)
+            pod_event_watch_process.start()
+            self.pod_event_watch_processes.append(pod_event_watch_process)
             self.watches.append(watch)
 
-        self.pending_event_processing_thread = threading.Thread(
+        self.pending_event_processing_process = Process(
             target=self._pending_event_processing_loop,
         )
-        self.pending_event_processing_thread.start()
+        self.pending_event_processing_process.start()
+
+        self.reconciliation_task_process = Process(
+            target=self._reconcile_task_loop,
+            daemon=True,
+        )
+        self.reconciliation_task_process.start()
 
     def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
         """Generates task_metadata in UNKNOWN state for an existing KubernetesTaskConfig.
@@ -468,7 +481,7 @@ def _pending_event_processing_loop(self) -> None:
             try:
                 event = self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)
                 self._process_pod_event(event)
-            except queue.Empty:
+            except Empty:
                 logger.debug(
                     f"Pending event queue remained empty after {QUEUE_GET_TIMEOUT_S} seconds.",
                 )
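`stop()` (further down in this diff) still calls `self.pending_events.join()`, and with a `JoinableQueue` that only returns once `task_done()` has been called for every item ever `put()`. A small, illustrative sketch of the contract the consumer loop has to uphold:

```python
from multiprocessing import JoinableQueue

q: "JoinableQueue[int]" = JoinableQueue()
q.put(1)

item = q.get(timeout=0.5)
try:
    pass  # process the item
finally:
    q.task_done()  # without this, q.join() would block forever

q.join()  # returns because every put() has a matching task_done()
```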
@@ -493,6 +506,46 @@ def _pending_event_processing_loop(self) -> None:
 
         logger.debug("Exiting Pod event processing - stop requested.")
 
+    def _reconcile_task_loop(self) -> None:
+        """
+        Run in a separate process to reconcile task_metadata from k8s.
+        """
+        logger.info(
+            f"Waiting {REFRESH_EXECUTOR_STATE_PROCESS_GRACE}s before doing work"
+        )
+        sleep(REFRESH_EXECUTOR_STATE_PROCESS_GRACE)
+        logger.debug("Starting Pod task config reconciliation.")
+        # allocate half of the total CPU count for multiprocessing
+        num_cpus = cpu_count() // 2 or 1
+        while not self.stopping:
+            try:
+                pods = self.kube_client.get_pods(namespace=self.namespace)
+            except Exception:
+                logger.exception(
+                    f"Hit an exception attempting to fetch pods in namespace {self.namespace}"
+                )
+                pods = None
+
+            if pods is not None:
+                # build a list of (KubernetesTaskConfig, V1Pod) tuples for every pod
+                # that is already tracked in task_metadata
+                task_configs_pods = [
+                    (self.task_metadata[pod.metadata.name].task_config, pod)
+                    for pod in pods
+                    if pod.metadata.name in self.task_metadata
+                ]
+
+                # create a process pool that uses half of the total CPUs
+                with Pool(num_cpus) as pool:
+                    # call reconcile() for each (task_config, pod) pair in parallel
+                    result = pool.starmap_async(self.reconcile, task_configs_pods)
+                    # wait for all reconcile calls to finish
+                    result.wait()
+            logger.info(f"Sleeping for {REFRESH_EXECUTOR_STATE_PROCESS_INTERVAL}s")
+            sleep(REFRESH_EXECUTOR_STATE_PROCESS_INTERVAL)
+
+        logger.debug("Exiting Pod task config reconciliation - stop requested.")
+
     def _create_container_definition(
         self,
         name: str,
@@ -644,7 +697,10 @@ def run(self, task_config: KubernetesTaskConfig) -> Optional[str]:
 
         return None
 
-    def reconcile(self, task_config: KubernetesTaskConfig) -> None:
+    def reconcile(
+        self, task_config: KubernetesTaskConfig, pod: Optional[V1Pod] = None
+    ) -> None:
         pod_name = task_config.pod_name
-        pod = None
-        for kube_client in [self.kube_client] + self.watcher_kube_clients:
+        # a pod handed to us by the reconciliation loop is used as-is; otherwise
+        # fall back to fetching it from each kube client, as before
+        kube_clients = [self.kube_client] + self.watcher_kube_clients if pod is None else []
+        for kube_client in kube_clients:
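One caveat with `pool.starmap_async(self.reconcile, ...)` above: the pool has to pickle its target, and a bound method drags the whole executor (kube clients, queues, the lock) through pickle, while any `task_metadata` updates made inside pool workers stay in those workers' memory. The pattern itself, shown with a plain module-level function (a runnable, purely illustrative sketch):

```python
from multiprocessing import Pool


def reconcile(task_name: str, phase: str) -> None:
    print(f"reconciling {task_name} (phase={phase})")


if __name__ == "__main__":
    pairs = [("task-a", "Running"), ("task-b", "Failed")]
    with Pool(2) as pool:
        result = pool.starmap_async(reconcile, pairs)
        result.wait()  # block until every reconcile call has returned
```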
@@ -751,8 +806,8 @@ def stop(self) -> None:
         # grace period to flush the current event to the pending_events queue as well as
         # any other clean-up - it's possible that after this join() the thread is still alive
         # but in that case we can be reasonably sure that we're not dropping any data.
-        for pod_event_watch_thread in self.pod_event_watch_threads:
-            pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)
+        for pod_event_watch_process in self.pod_event_watch_processes:
+            pod_event_watch_process.join(timeout=POD_WATCH_PROCESS_JOIN_TIMEOUT_S)
 
         logger.debug("Waiting for all pending PodEvents to be processed...")
         # once we've stopped updating the pending events queue, we then wait until we're done
@@ -761,11 +816,11 @@ def stop(self) -> None:
         self.pending_events.join()
         logger.debug("All pending PodEvents have been processed.")
         # and then give ourselves time to do any post-stop cleanup
-        self.pending_event_processing_thread.join(
-            timeout=POD_EVENT_THREAD_JOIN_TIMEOUT_S
+        self.pending_event_processing_process.join(
+            timeout=POD_EVENT_PROCESS_JOIN_TIMEOUT_S
         )
 
         logger.debug("Done stopping KubernetesPodExecutor!")
 
-    def get_event_queue(self) -> "Queue[Event]":
+    def get_event_queue(self) -> "JoinableQueue[Event]":
         return self.event_queue
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index a2e136b..1ad0e1c 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,3 +1,4 @@
+import multiprocessing
 import threading
 
 import mock
@@ -14,3 +15,9 @@
 def mock_Thread():
     with mock.patch.object(threading, "Thread") as mock_Thread:
         yield mock_Thread
+
+
+@pytest.fixture
+def mock_Process():
+    with mock.patch.object(multiprocessing, "Process") as mock_Process:
+        yield mock_Process
diff --git a/tests/unit/plugins/kubernetes/kube_client_test.py b/tests/unit/plugins/kubernetes/kube_client_test.py
index 389c6ec..0c67da8 100644
--- a/tests/unit/plugins/kubernetes/kube_client_test.py
+++ b/tests/unit/plugins/kubernetes/kube_client_test.py
@@ -119,3 +119,22 @@ def test_KubeClient_get_pod():
     mock_kube_client.CoreV1Api().read_namespaced_pod.assert_called_once_with(
         namespace="ns", name="pod-name"
     )
+
+
+def test_KubeClient_get_pods():
+    with mock.patch(
+        "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
+        autospec=True
+    ), mock.patch(
+        "task_processing.plugins.kubernetes.kube_client.kube_client",
+        autospec=True
+    ) as mock_kube_client, mock.patch.dict(
+        os.environ, {"KUBECONFIG": "/another/kube/config.conf"}
+    ):
+        mock_config_path = "/OVERRIDE.conf"
+        mock_kube_client.CoreV1Api().list_namespaced_pod.return_value = mock.Mock()
+        client = KubeClient(kubeconfig_path=mock_config_path)
+        client.get_pods(namespace="ns", attempts=1)
+        mock_kube_client.CoreV1Api().list_namespaced_pod.assert_called_once_with(
+            namespace="ns"
+        )
diff --git a/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py b/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
index 1d62f80..2338d7d 100644
--- a/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
+++ b/tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
@@ -41,7 +41,7 @@
 
 
 @pytest.fixture
-def k8s_executor(mock_Thread):
+def k8s_executor(mock_Process):
     with mock.patch(
         "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
         autospec=True,
@@ -90,7 +90,7 @@ def mock_task_configs():
 
 
 @pytest.fixture
-def k8s_executor_with_tasks(mock_Thread, mock_task_configs):
+def k8s_executor_with_tasks(mock_Process, mock_task_configs):
     with mock.patch(
         "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
         autospec=True,
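A note on the new `mock_Process` fixture: `mock.patch.object(multiprocessing, "Process")` replaces the attribute on the `multiprocessing` module, but kubernetes_pod_executor.py binds the name at import time via `from multiprocessing import Process`, so the executor would still start real processes. Patching the name where it is looked up is the usual fix - a sketch (fixture name hypothetical, module path taken from this diff):

```python
import mock
import pytest


@pytest.fixture
def mock_executor_Process():
    with mock.patch(
        "task_processing.plugins.kubernetes.kubernetes_pod_executor.Process",
        autospec=True,
    ) as mock_Process:
        yield mock_Process
```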