From 8562fd9249a868ead7b4264712325a53a176df7c Mon Sep 17 00:00:00 2001 From: Noah Jackson Date: Thu, 13 Jun 2024 23:32:12 -0700 Subject: [PATCH] Add identity to task execution metadata (#2315) Signed-off-by: noahjax Signed-off-by: ddl-rliu Co-authored-by: ddl-rliu --- flytekit/models/security.py | 3 +++ flytekit/models/task.py | 9 +++++++++ pyproject.toml | 2 +- tests/flytekit/unit/extend/test_agent.py | 2 ++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 748b4d09eb..e210c910b7 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -92,12 +92,14 @@ class Identity(_common.FlyteIdlEntity): iam_role: Optional[str] = None k8s_service_account: Optional[str] = None oauth2_client: Optional[OAuth2Client] = None + execution_identity: Optional[str] = None def to_flyte_idl(self) -> _sec.Identity: return _sec.Identity( iam_role=self.iam_role if self.iam_role else None, k8s_service_account=self.k8s_service_account if self.k8s_service_account else None, oauth2_client=self.oauth2_client.to_flyte_idl() if self.oauth2_client else None, + execution_identity=self.execution_identity if self.execution_identity else None, ) @classmethod @@ -108,6 +110,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Identity) -> "Identity": oauth2_client=OAuth2Client.from_flyte_idl(pb2_object.oauth2_client) if pb2_object.oauth2_client and pb2_object.oauth2_client.ByteSize() else None, + execution_identity=pb2_object.execution_identity if pb2_object.execution_identity else None, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 5072f03757..0532b276e2 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -528,6 +528,7 @@ def __init__( annotations, k8s_service_account, environment_variables, + identity, ): """ Runtime task execution metadata. @@ -539,6 +540,7 @@ def __init__( :param dict[str, str] annotations: Annotations to use for the execution of this task. :param Text k8s_service_account: Service account to use for execution of this task. :param dict[str, str] environment_variables: Environment variables for this task. + :param flytekit.models.security.Identity identity: Identity of user executing this task """ self._task_execution_id = task_execution_id self._namespace = namespace @@ -546,6 +548,7 @@ def __init__( self._annotations = annotations self._k8s_service_account = k8s_service_account self._environment_variables = environment_variables + self._identity = identity @property def task_execution_id(self): @@ -571,6 +574,10 @@ def k8s_service_account(self): def environment_variables(self): return self._environment_variables + @property + def identity(self): + return self._identity + def to_flyte_idl(self): """ :rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata @@ -584,6 +591,7 @@ def to_flyte_idl(self): environment_variables={k: v for k, v in self.environment_variables.items()} if self.labels is not None else None, + identity=self.identity.to_flyte_idl() if self.identity else None, ) return task_execution_metadata @@ -604,6 +612,7 @@ def from_flyte_idl(cls, pb2_object): environment_variables={k: v for k, v in pb2_object.environment_variables.items()} if pb2_object.environment_variables is not None else None, + identity=_sec.Identity.from_flyte_idl(pb2_object.identity) if pb2_object.identity else None, ) diff --git a/pyproject.toml b/pyproject.toml index e0fd189bd6..5aece34595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.11.0b1", + "flyteidl>=1.12.0", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 17db5c2788..3226313079 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -47,6 +47,7 @@ WorkflowExecutionIdentifier, ) from flytekit.models.literals import LiteralMap +from flytekit.models.security import Identity from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable @@ -159,6 +160,7 @@ def simple_task(i: int): annotations={"annotation_key": "annotation_val"}, k8s_service_account="k8s service account", environment_variables={"env_var_key": "env_var_val"}, + identity=Identity(execution_identity="task executor"), )