
Commit

Merge remote-tracking branch 'upstream/master' into remote_jupyter
Mecoli1219 committed Aug 13, 2024
2 parents 7c8232f + bc2e000 commit 0a61ec4
Showing 44 changed files with 672 additions and 219 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -27,6 +27,7 @@ RUN apt-get update && apt-get install build-essential -y \
&& apt-get clean autoclean \
&& apt-get autoremove --yes \
&& rm -rf /var/lib/{apt,dpkg,cache,log}/ \
+    && rm -rf /root/.cache/pip \
&& useradd -u 1000 flytekit \
&& chown flytekit: /root \
&& chown flytekit: /home \
146 changes: 114 additions & 32 deletions flytekit/bin/entrypoint.py
@@ -44,7 +44,9 @@
from flytekit.models.core import errors as _error_models
from flytekit.models.core import execution as _execution_models
from flytekit.models.core import identifier as _identifier
-from flytekit.tools.fast_registration import download_distribution as _download_distribution
+from flytekit.tools.fast_registration import (
+    download_distribution as _download_distribution,
+)
from flytekit.tools.module_loader import load_object_from_module


@@ -66,7 +68,9 @@ def _compute_array_job_index():
if os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"):
offset = int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"))
if os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"):
-        return offset + int(os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
+        return offset + int(
+            os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))
+        )
return offset


@@ -91,13 +95,17 @@ def _dispatch_execute(
# Step1
local_inputs_file = os.path.join(ctx.execution_state.working_dir, "inputs.pb")
ctx.file_access.get_data(inputs_path, local_inputs_file)
-        input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file)
+        input_proto = utils.load_proto_from_file(
+            _literals_pb2.LiteralMap, local_inputs_file
+        )
idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto)

# Step2
# Decorate the dispatch execute function before calling it, this wraps all exceptions into one
# of the FlyteScopedExceptions
-        outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(ctx, idl_input_literals)
+        outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(
+            ctx, idl_input_literals
+        )
if inspect.iscoroutine(outputs):
# Handle eager-mode (async) tasks
logger.info("Output is a coroutine")
@@ -106,7 +114,9 @@ def _dispatch_execute(
# Step3a
if isinstance(outputs, VoidPromise):
logger.warning("Task produces no outputs")
-            output_file_dict = {_constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})}
+            output_file_dict = {
+                _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap(literals={})
+            }
elif isinstance(outputs, _literal_models.LiteralMap):
output_file_dict = {_constants.OUTPUT_FILE_NAME: outputs}
elif isinstance(outputs, _dynamic_job.DynamicJobSpec):
@@ -125,11 +135,16 @@ def _dispatch_execute(
# Handle user-scoped errors
except _scoped_exceptions.FlyteScopedUserException as e:
if isinstance(e.value, IgnoreOutputs):
-            logger.warning(f"User-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!")
+            logger.warning(
+                f"User-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!"
+            )
return
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
_error_models.ContainerError(
-                e.error_code, e.verbose_message, e.kind, _execution_models.ExecutionError.ErrorKind.USER
+                e.error_code,
+                e.verbose_message,
+                e.kind,
+                _execution_models.ExecutionError.ErrorKind.USER,
)
)
logger.error("!! Begin User Error Captured by Flyte !!")
@@ -139,11 +154,16 @@ def _dispatch_execute(
# Handle system-scoped errors
except _scoped_exceptions.FlyteScopedSystemException as e:
if isinstance(e.value, IgnoreOutputs):
-            logger.warning(f"System-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!")
+            logger.warning(
+                f"System-scoped IgnoreOutputs received! Outputs.pb will not be uploaded. reason {e}!!"
+            )
return
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
_error_models.ContainerError(
-                e.error_code, e.verbose_message, e.kind, _execution_models.ExecutionError.ErrorKind.SYSTEM
+                e.error_code,
+                e.verbose_message,
+                e.kind,
+                _execution_models.ExecutionError.ErrorKind.SYSTEM,
)
)
logger.error("!! Begin System Error Captured by Flyte !!")
@@ -163,23 +183,34 @@ def _dispatch_execute(
_execution_models.ExecutionError.ErrorKind.SYSTEM,
)
)
-        logger.error(f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}")
+        logger.error(
+            f"Exception when executing task {task_def.name or task_def.id.name}, reason {str(e)}"
+        )
logger.error("!! Begin Unknown System Error Captured by Flyte !!")
logger.error(exc_str)
logger.error("!! End Error Captured by Flyte !!")

for k, v in output_file_dict.items():
-        utils.write_proto_to_file(v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k))
+        utils.write_proto_to_file(
+            v.to_flyte_idl(), os.path.join(ctx.execution_state.engine_dir, k)
+        )

-    ctx.file_access.put_data(ctx.execution_state.engine_dir, output_prefix, is_multipart=True)
-    logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")
+    ctx.file_access.put_data(
+        ctx.execution_state.engine_dir, output_prefix, is_multipart=True
+    )
+    logger.info(
+        f"Engine folder written successfully to the output prefix {output_prefix}"
+    )

if not getattr(task_def, "disable_deck", True):
_output_deck(task_def.name.split(".")[-1], ctx.user_space_params)

logger.debug("Finished _dispatch_execute")

-    if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict:
+    if (
+        os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true"
+        and _constants.ERROR_FILE_NAME in output_file_dict
+    ):
# This env is set by the flytepropeller
# AWS batch job get the status from the exit code, so once we catch the error,
# we should return the error code here
@@ -254,8 +285,12 @@ def setup_execution(

checkpointer = None
if checkpoint_path is not None:
-        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
-        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")
+        checkpointer = SyncCheckpoint(
+            checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint
+        )
+        logger.debug(
+            f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}"
+        )

execution_parameters = ExecutionParameters(
execution_id=_identifier.WorkflowExecutionIdentifier(
@@ -283,7 +318,9 @@ def setup_execution(
raw_output_prefix=raw_output_data_prefix,
output_metadata_prefix=output_metadata_prefix,
checkpoint=checkpointer,
-        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
+        task_id=_identifier.Identifier(
+            _identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version
+        ),
)

metadata = {
@@ -300,7 +337,9 @@
execution_metadata=metadata,
)
except TypeError: # would be thrown from DataPersistencePlugins.find_plugin
-        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
+        logger.error(
+            f"No data plugin found for raw output prefix {raw_output_data_prefix}"
+        )
raise

ctx = ctx.new_builder().with_file_access(file_access).build()
@@ -387,7 +426,7 @@ def _execute_task(
:return:
"""
if not pickled and len(resolver_args) < 1:
-        raise Exception("cannot be <1")
+        raise ValueError("cannot be <1")

with setup_execution(
raw_output_data_prefix,
@@ -452,10 +491,14 @@ def _execute_map_task(
:return:
"""
if len(resolver_args) < 1:
-        raise Exception(f"Resolver args cannot be <1, got {resolver_args}")
+        raise ValueError(f"Resolver args cannot be <1, got {resolver_args}")

with setup_execution(
-        raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
+        raw_output_data_prefix,
+        checkpoint_path,
+        prev_checkpoint,
+        dynamic_addl_distro,
+        dynamic_dest_dir,
) as ctx:
task_index = _compute_array_job_index()

@@ -468,7 +511,9 @@
map_task = cloudpickle.load(f)
else:
mtr = load_object_from_module(resolver)()
-            map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency)
+            map_task = mtr.load_task(
+                loader_args=resolver_args, max_concurrency=max_concurrency
+            )
# Special case for the map task resolver, we need to append the task index to the output prefix.
# TODO: (https://github.com/flyteorg/flyte/issues/5011) Remove legacy map task
if mtr.name() == "flytekit.core.legacy_map_task.MapTaskResolver":
@@ -486,7 +531,9 @@


def normalize_inputs(
-    raw_output_data_prefix: Optional[str], checkpoint_path: Optional[str], prev_checkpoint: Optional[str]
+    raw_output_data_prefix: Optional[str],
+    checkpoint_path: Optional[str],
+    prev_checkpoint: Optional[str],
):
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
# template string, so let's explicitly set it to None so that the downstream functions will know to fall back
@@ -495,7 +542,11 @@ def normalize_inputs(
raw_output_data_prefix = None
if checkpoint_path == "{{.checkpointOutputPrefix}}":
checkpoint_path = None
-    if prev_checkpoint == "{{.prevCheckpointPrefix}}" or prev_checkpoint == "" or prev_checkpoint == '""':
+    if (
+        prev_checkpoint == "{{.prevCheckpointPrefix}}"
+        or prev_checkpoint == ""
+        or prev_checkpoint == '""'
+    ):
prev_checkpoint = None

return raw_output_data_prefix, checkpoint_path, prev_checkpoint
@@ -516,8 +567,15 @@ def _pass_through():
@click.option("--dynamic-addl-distro", required=False)
@click.option("--dynamic-dest-dir", required=False)
@click.option("--resolver", required=False)
-@click.option("--pickled", is_flag=True, default=False, help="Use this to mark if the distribution is pickled.")
-@click.option("--pkl-file", required=False, help="Location where pickled file can be found.")
+@click.option(
+    "--pickled",
+    is_flag=True,
+    default=False,
+    help="Use this to mark if the distribution is pickled.",
+)
+@click.option(
+    "--pkl-file", required=False, help="Location where pickled file can be found."
+)
@click.argument(
"resolver-args",
type=click.UNPROCESSED,
@@ -569,9 +627,19 @@ def execute_task_cmd(
@_pass_through.command("pyflyte-fast-execute")
@click.option("--additional-distribution", required=False)
@click.option("--dest-dir", required=False)
-@click.option("--pickled", is_flag=True, default=False, help="Use this to mark if the distribution is pickled.")
+@click.option(
+    "--pickled",
+    is_flag=True,
+    default=False,
+    help="Use this to mark if the distribution is pickled.",
+)
@click.argument("task-execute-cmd", nargs=-1, type=click.UNPROCESSED)
-def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, pickled: bool, task_execute_cmd: List[str]):
+def fast_execute_task_cmd(
+    additional_distribution: str,
+    dest_dir: str,
+    pickled: bool,
+    task_execute_cmd: List[str],
+):
"""
Downloads a compressed code distribution specified by additional-distribution and then calls the underlying
task execute command for the updated code.
@@ -580,13 +648,20 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, pickled:
if pickled:
click.secho("Received pickled object")
dest_file = os.path.join(os.getcwd(), "pickled.tar.gz")
-        FlyteContextManager.current_context().file_access.get_data(additional_distribution, dest_file)
+        FlyteContextManager.current_context().file_access.get_data(
+            additional_distribution, dest_file
+        )
cmd_extend = ["--pickled", "--pkl-file", dest_file]
else:
if not dest_dir:
dest_dir = os.getcwd()
_download_distribution(additional_distribution, dest_dir)
-        cmd_extend = ["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir]
+        cmd_extend = [
+            "--dynamic-addl-distro",
+            additional_distribution,
+            "--dynamic-dest-dir",
+            dest_dir,
+        ]

# Insert the call to fast before the unbounded resolver args
cmd = []
@@ -619,8 +694,15 @@ def handle_sigterm(signum, frame):
@click.option("--resolver", required=True)
@click.option("--checkpoint-path", required=False)
@click.option("--prev-checkpoint", required=False)
-@click.option("--pickled", is_flag=True, default=False, help="Use this to mark if the distribution is pickled.")
-@click.option("--pkl-file", required=False, help="Location where pickled file can be found.")
+@click.option(
+    "--pickled",
+    is_flag=True,
+    default=False,
+    help="Use this to mark if the distribution is pickled.",
+)
+@click.option(
+    "--pkl-file", required=False, help="Location where pickled file can be found."
+)
@click.argument(
"resolver-args",
type=click.UNPROCESSED,
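
Two notes on the entrypoint.py hunks above. First, `_compute_array_job_index` resolves a map-task shard index through an env-var indirection: `BATCH_JOB_ARRAY_INDEX_VAR_NAME` holds the *name* of the variable that actually carries the index. Second, per the in-code comment, the new `FLYTE_FAIL_ON_ERROR` block (set by flytepropeller) makes the container exit non-zero when an error document was produced, so exit-code-based schedulers such as AWS Batch see the failure. A minimal sketch of the index resolution, with illustrative values; only the two `BATCH_JOB_ARRAY_INDEX_*` variables are read by the real entrypoint, and the AWS Batch-style name is an assumption for the example:

```python
import os

# Simulate the environment a batch plugin might set up (illustrative values).
os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = "3"
os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"] = "AWS_BATCH_JOB_ARRAY_INDEX"
os.environ["BATCH_JOB_ARRAY_INDEX_OFFSET"] = "10"


def compute_array_job_index() -> int:
    # Mirrors _compute_array_job_index from the hunk above.
    offset = 0
    if os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"):
        offset = int(os.environ["BATCH_JOB_ARRAY_INDEX_OFFSET"])
    if os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"):
        # Double lookup: the value of the first variable names the variable
        # that holds this shard's index.
        return offset + int(os.environ[os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"]])
    return offset


assert compute_array_job_index() == 13  # offset 10 + shard index 3
```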
2 changes: 1 addition & 1 deletion flytekit/clients/auth/auth_client.py
@@ -332,7 +332,7 @@ def _request_access_token(self, auth_code) -> Credentials:
if resp.status_code != _StatusCodes.OK:
# TODO: handle expected (?) error cases:
# https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses
-            raise Exception(
+            raise RuntimeError(
"Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content)
)
return self._credentials_from_response(resp)
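
The auth_client.py change narrows the raised type without altering the failure path. A minimal sketch of the check, assuming `resp` is any HTTP response object exposing `status_code` and `content`:

```python
def check_token_response(resp) -> None:
    # Mirrors the narrowed error in _request_access_token: callers can catch
    # RuntimeError specifically instead of a bare Exception.
    if resp.status_code != 200:  # _StatusCodes.OK in the real module
        raise RuntimeError(
            "Failed to request access token with response: [{}] {}".format(
                resp.status_code, resp.content
            )
        )
```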
9 changes: 7 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
@@ -44,7 +44,7 @@
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow, remote_fs
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.tools import module_loader
-from flytekit.tools.script_mode import _find_project_root, compress_scripts
+from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules
from flytekit.tools.translator import Options


@@ -379,6 +379,10 @@ def to_click_option(
This handles converting workflow input types to supported click parameters with callbacks to initialize
the input values to their expected types.
"""
+    if input_name != input_name.lower():
+        # Click does not support uppercase option names: https://github.com/pallets/click/issues/837
+        raise ValueError(f"Workflow input name must be lowercase: {input_name!r}")
+
run_level_params: RunLevelParams = ctx.obj

literal_converter = FlyteLiteralConverter(
@@ -493,7 +497,8 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder:
if output_prefix and ctx.file_access.is_remote(output_prefix):
with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
-            compress_scripts(params.computed_params.project_root, str(archive_fname), params.computed_params.module)
+            modules = get_all_modules(params.computed_params.project_root, params.computed_params.module)
+            compress_scripts(params.computed_params.project_root, str(archive_fname), modules)
remote_dir = file_access.get_random_remote_directory()
remote_archive_fname = f"{remote_dir}/script_mode.tar.gz"
file_access.put_data(str(archive_fname), remote_archive_fname)
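
Two behavioral notes on the run.py hunks: `to_click_option` now rejects uppercase workflow input names up front, since click cannot handle uppercase option names (pallets/click#837), and script-mode packaging now derives its module list via `get_all_modules` before calling `compress_scripts`, rather than passing only the single entry module. A standalone sketch of the new guard; the helper name here is hypothetical:

```python
def validate_input_name(input_name: str) -> None:
    # Hypothetical helper mirroring the guard added to to_click_option.
    # Click does not support uppercase option names:
    # https://github.com/pallets/click/issues/837
    if input_name != input_name.lower():
        raise ValueError(f"Workflow input name must be lowercase: {input_name!r}")


validate_input_name("my_input")  # passes silently
try:
    validate_input_name("MyInput")
except ValueError as e:
    print(e)  # Workflow input name must be lowercase: 'MyInput'
```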
4 changes: 2 additions & 2 deletions flytekit/core/array_node.py
@@ -77,9 +77,9 @@ def __init__(
if isinstance(metadata, _workflow_model.NodeMetadata):
self.metadata = metadata
else:
-            raise TypeError("Invalid metadata for LaunchPlan. Should be NodeMetadata.")
else:
-            raise Exception("Only LaunchPlans are supported for now.")
+            raise ValueError(f"Only LaunchPlans are supported for now, but got {type(target)}")

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
# Part of SupportsNodeCreation interface
2 changes: 1 addition & 1 deletion flytekit/core/base_sql_task.py
@@ -48,7 +48,7 @@ def query_template(self) -> str:
return self._query_template

def execute(self, **kwargs) -> Any:
-        raise Exception("Cannot run a SQL Task natively, please mock.")
+        raise NotImplementedError("Cannot run a SQL Task natively, please mock.")

def get_query(self, **kwargs) -> str:
return self.interpolate_query(self.query_template, **kwargs)
2 changes: 1 addition & 1 deletion flytekit/core/base_task.py
@@ -365,7 +365,7 @@ def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Pro
return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore

def compile(self, ctx: FlyteContext, *args, **kwargs):
-        raise Exception("not implemented")
+        raise NotImplementedError

def get_container(self, settings: SerializationSettings) -> Optional[_task_model.Container]:
"""
2 changes: 1 addition & 1 deletion flytekit/core/class_based_resolver.py
@@ -38,6 +38,6 @@ def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTas
This is responsible for turning an instance of a task into args that the load_task function can reconstitute.
"""
if t not in self.mapping:
-            raise Exception("no such task")
+            raise ValueError("no such task")

return [f"{self.mapping.index(t)}"]
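
A recurring theme across these files is replacing bare `Exception` with the most specific built-in type (`TypeError`, `ValueError`, `NotImplementedError`, `RuntimeError`). A minimal sketch of what this buys callers, simplified from the class_based_resolver.py change above:

```python
class ResolverSketch:
    """Simplified stand-in for the class-based task resolver in the diff."""

    def __init__(self):
        self.mapping = []

    def loader_args(self, t):
        if t not in self.mapping:
            raise ValueError("no such task")  # was: raise Exception("no such task")
        return [f"{self.mapping.index(t)}"]


resolver = ResolverSketch()
try:
    resolver.loader_args("unregistered-task")
except ValueError as err:
    # Callers can now target this specific failure mode without also
    # swallowing unrelated bugs that surface as other exception types.
    print(f"lookup failed: {err}")
```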