This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

[flyteidl] Support attribute access on promises #439

Merged
merged 2 commits into from
Sep 21, 2023

@ByronHsu ByronHsu commented Sep 11, 2023

TL;DR

Support attribute access on output promises for List, Dict, Dataclass, and combinations of them. After this change, users can pass an attribute of a task output to another task within a workflow, for example:

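from typing import Dict

from flytekit import task, workflow

@task
def t1() -> Dict[str, str]:
  return {"a": "b"}

@task
def t2(x: str):
  return

@workflow
def wf():
  o = t1()
  # Index into the Dict output of t1 and pass only the value to t2
  t2(x=o["a"])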

Type

  • Bug Fix
  • Feature
  • Plugin

Are all requirements met?

  • Code completed
  • Smoke tested
  • Unit tests added
  • Code documentation added
  • Any pending items have an associated Issue

Complete description

1. Add attribute path on promise

  • flyteidl: Add attribute path on output reference (promise)
  • flytekit: Append attributes to the attribute path of a promise by overriding __getitem__ and __getattr__ in the Promise class
  • flytekit: Modify flyteidlIdentity to serialize attribute path to protobuf
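
As a minimal sketch of the idea in the second bullet, the class below records each subscript or attribute access in a path list. `SketchPromise`, `var`, and `attr_path` are hypothetical names chosen for illustration; this is not flytekit's actual Promise class, and serialization to protobuf is left out.

```python
from typing import List, Optional, Union

PathElement = Union[str, int]


class SketchPromise:
    """Hypothetical stand-in for a promise; not flytekit's actual class."""

    def __init__(self, var: str, attr_path: Optional[List[PathElement]] = None):
        self.var = var
        self.attr_path = list(attr_path or [])

    def _extend(self, element: PathElement) -> "SketchPromise":
        # Return a new promise so the original one stays unchanged.
        return SketchPromise(self.var, self.attr_path + [element])

    def __getitem__(self, key: PathElement) -> "SketchPromise":
        # o["a"] or o[0] appends a dict key or a list index to the path.
        return self._extend(key)

    def __getattr__(self, name: str) -> "SketchPromise":
        # Only called when normal attribute lookup fails, so `var` and
        # `attr_path` are unaffected. Dunder names are rejected so that
        # copy/pickle protocol probes do not create bogus paths.
        if name.startswith("__"):
            raise AttributeError(name)
        return self._extend(name)


o = SketchPromise("o0")
print(o["a"][0].b.attr_path)
```

Returning a new object on every access keeps `o` reusable: `o["a"]` and `o["b"]` can both be passed to downstream tasks without sharing state.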

2. Compile Validation at flytepropeller

Skip the type check if the promise contains an attribute path: when the underlying type is a dataclass, there is no way to infer the resolved type at compile time.

3. Execution

Although the resolution happens in different places, the logic is the same. It iterates over the attribute path:

  1. If the current value is a LiteralMap (dictionary), access the value by key
  2. If the current value is a LiteralCollection (List), access the value by index
  3. If the current value is a LiteralScalar (single value), break the loop

If the remaining value is a LiteralScalar that contains a dataclass, iterate over the rest of the path to extract the value from the dataclass.
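
The loop above can be sketched with plain Python stand-ins: a `dict` plays the role of LiteralMap, a `list` plays LiteralCollection, and a hypothetical `Scalar` wrapper plays LiteralScalar (holding either a primitive or a `dict` that stands in for a dataclass struct). `resolve_attr_path`, `_step`, and `AttributeResolveError` are illustrative names, not the functions used in the PR.

```python
from dataclasses import dataclass
from typing import Any, List, Union


class AttributeResolveError(Exception):
    """Hypothetical error type for this sketch."""


@dataclass
class Scalar:
    # A primitive, or a dict standing in for a dataclass struct.
    value: Any


def _step(container: Any, element: Union[str, int]) -> Any:
    """Take one step into a dict (by key) or a list (by index)."""
    if isinstance(container, dict):
        if element not in container:
            raise AttributeResolveError(f"key {element!r} not found")
        return container[element]
    if isinstance(container, list):
        if not isinstance(element, int) or not 0 <= element < len(container):
            raise AttributeResolveError(
                f"index {element!r} out of range {len(container)}"
            )
        return container[element]
    raise AttributeResolveError(f"cannot access {element!r} on a non-container")


def resolve_attr_path(value: Any, path: List[Union[str, int]]) -> Any:
    consumed = 0
    # Phase 1: walk maps and collections until a scalar is reached.
    for element in path:
        if isinstance(value, Scalar):
            break
        value = _step(value, element)
        consumed += 1
    remaining = path[consumed:]
    if not remaining:
        return value
    # Phase 2: the rest of the path must point inside a dataclass struct.
    if not isinstance(value.value, dict):
        raise AttributeResolveError(f"cannot resolve {remaining!r} on a primitive")
    inner = value.value
    for element in remaining:
        inner = _step(inner, element)
    return inner


print(resolve_attr_path({"a": Scalar({"a": "b"})}, ["a", "a"]))
```

An out-of-range index such as `resolve_attr_path([Scalar("a"), Scalar("b")], [100])` raises `AttributeResolveError` in this sketch.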

The result resolved from a dataclass can be a dataclass, a list, or a scalar.

  1. Dataclass: build a protobuf struct and wrap it as a literal scalar (generic).
  2. List: build a LiteralCollection.
  3. Scalar: build a Literal, wrapping the value as a literal scalar (primitive).
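
The three cases above can be illustrated with a small dispatch function. This is only a sketch: nested dictionaries stand in for the protobuf messages, `wrap_resolved` is a hypothetical name, and wrapping each list element recursively is an assumption about how the collection is populated.

```python
from typing import Any, Dict


def wrap_resolved(value: Any) -> Dict[str, Any]:
    """Wrap a value extracted from a dataclass into a literal-like dict."""
    if isinstance(value, dict):
        # Nested dataclass: keep it as a struct under scalar.generic.
        return {"scalar": {"generic": value}}
    if isinstance(value, list):
        # List: a collection whose items are themselves wrapped (assumption).
        return {"collection": [wrap_resolved(item) for item in value]}
    # Anything else is treated as a primitive scalar.
    return {"scalar": {"primitive": value}}


print(wrap_resolved("b"))
print(wrap_resolved(["b", {"a": 1}]))
```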

Testing

The following test file defines three workflows:

  • basic_workflow contains trivial examples that access output attributes
  • failed_workflow contains examples that cause exceptions (e.g. index out of range)
  • advanced_workflow contains examples with more complex attribute access

from dataclasses import dataclass
from typing import Dict, List, NamedTuple

from dataclasses_json import dataclass_json

from flytekit import WorkflowFailurePolicy, task, workflow


@dataclass_json
@dataclass
class foo:
    a: str


@task
def t1() -> (List[str], Dict[str, str], foo):
    return ["a", "b"], {"a": "b"}, foo(a="b")


@task
def t2(a: str) -> str:
    print("a", a)
    return a


@task
def t3() -> (Dict[str, List[str]], List[Dict[str, str]], Dict[str, foo]):
    return {"a": ["b"]}, [{"a": "b"}], {"a": foo(a="b")}


@task
def t4(a: List[str]):
    print("a", a)


@task
def t5(a: Dict[str, str]):
    print("a", a)


@workflow
def basic_workflow():
    l, d, f = t1()
    t2(a=l[0])
    t2(a=d["a"])
    t2(a=f.a)


@workflow(
    # The workflow doesn't fail when one of the nodes fails but other nodes are still executable
    failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
)
def failed_workflow():
    # This workflow is supposed to fail due to exceptions
    l, d, f = t1()
    t2(a=l[100])
    t2(a=d["b"])
    t2(a=f.b)


@workflow
def advanced_workflow():
    dl, ld, dd = t3()
    t2(a=dl["a"][0])
    t2(a=ld[0]["a"])
    t2(a=dd["a"].a)

    t4(a=dl["a"])
    t5(a=ld[0])

Local Execution

$ pyflyte run test_workflow.py basic_workflow   

a a
a b
a b

$ pyflyte run test_workflow.py failed_workflow

Failed with Exception Code: USER:PromiseAttributeResolveError
Underlying Exception: Failed to resolve attribute path [100] in promise Resolved(o0=<FlyteLiteral collection { literals { scalar { primitive { string_value: "a" } } } literals { scalar { primitive { string_value: "b" } } } }>), index 100 out of range 2
Encountered error while executing workflow 'test_workflow.failed_workflow':
  Error encountered while executing 'failed_workflow':
  Failed to resolve attribute path [100] in promise Resolved(o0=<FlyteLiteral collection { literals { scalar { primitive { string_value: "a" } } } literals { scalar { primitive { string_value: "b" } } } }>), index 100 out of range 2

$ pyflyte run test_workflow.py advanced_workflow
a b
a b
a b
a ['b']
a {'a': 'b'}

Remote Execution

  • Basic: [screenshot]
  • Failed: [screenshot]
  • Advanced: [screenshot, 2023-09-11]

Tracking Issue

flyteorg/flyte#3864

codecov bot commented Sep 11, 2023

Codecov Report

Patch coverage has no change and project coverage change: +2.55% 🎉

Comparison is base (ef37788) 75.92% compared to head (683f5a0) 78.48%.

❗ Current head 683f5a0 differs from pull request most recent head 4e0aa0b. Consider uploading reports for the commit 4e0aa0b to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #439      +/-   ##
==========================================
+ Coverage   75.92%   78.48%   +2.55%     
==========================================
  Files          18       18              
  Lines        1458     1250     -208     
==========================================
- Hits         1107      981     -126     
+ Misses        294      212      -82     
  Partials       57       57              
Flag Coverage Δ
unittests ?

see 18 files with indirect coverage changes


Signed-off-by: byhsu <[email protected]>
Signed-off-by: byhsu <[email protected]>
@ByronHsu ByronHsu merged commit b130a81 into flyteorg:master Sep 21, 2023
15 checks passed
eapolinario pushed a commit that referenced this pull request Sep 28, 2023
* init

Signed-off-by: byhsu <[email protected]>

* fix typo

Signed-off-by: byhsu <[email protected]>

---------

Signed-off-by: byhsu <[email protected]>
Co-authored-by: byhsu <[email protected]>

ByronHsu commented Oct 15, 2023

Update 2023/10/14:

After a discussion with @wild-endeavor, we want to perform strict type checking on dataclass fields. For example:

@dataclass
class Foo:
  a: int

@task
def t0() -> Foo:
  return Foo(a=1)

@task
def t1(x: str):
   ...

@workflow
def wf():
  f = t0()
  t1(x=f.a)

In the original implementation, wf passed compilation even though f.a (int) and x (str) have different types. This is because the types of dataclass fields were not stored in LiteralType, so flytepropeller could not check them.

To type check dataclass fields, I added dataclass_type to TypeStructure. It is a map<str, LiteralType> that maps each dataclass field name to its LiteralType. This also handles nested dataclasses, because a child dataclass can itself be stored as a LiteralType in the parent's dataclass_type.

On the flytekit side, dataclass_type is constructed recursively in DataclassTransformer; flytepropeller then uses dataclass_type to compare field types at compile time.
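
To illustrate the recursive field-type map and the compile-time comparison, here is a rough sketch in plain Python. It uses Python types in place of LiteralType and nested dictionaries in place of dataclass_type; `field_type_map`, `resolve_field_type`, `binding_type_matches`, and the `Inner`/`Outer` classes are hypothetical and do not reflect the actual DataclassTransformer or flytepropeller code. It also assumes field annotations are real types rather than strings.

```python
import dataclasses
from typing import Any, Dict, List


def field_type_map(cls: type) -> Dict[str, Any]:
    """Map each field name to its type, recursing into nested dataclasses."""
    result: Dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if dataclasses.is_dataclass(f.type):
            result[f.name] = field_type_map(f.type)
        else:
            result[f.name] = f.type
    return result


def resolve_field_type(type_map: Dict[str, Any], path: List[str]) -> Any:
    """Follow an attribute path through the map and return the type found."""
    current: Any = type_map
    for name in path:
        if not isinstance(current, dict) or name not in current:
            raise TypeError(f"field {name!r} not found")
        current = current[name]
    return current


def binding_type_matches(type_map: Dict[str, Any], path: List[str], expected: type) -> bool:
    """Compare the type at the end of the path with the expected input type."""
    return resolve_field_type(type_map, path) is expected


@dataclasses.dataclass
class Inner:
    b: int


@dataclasses.dataclass
class Outer:
    a: int
    inner: Inner


outer_types = field_type_map(Outer)
print(binding_type_matches(outer_types, ["a"], str))
print(binding_type_matches(outer_types, ["inner", "b"], int))
```

In this sketch, binding `Outer.a` (int) to a `str` input is reported as a mismatch, which mirrors the f.a and x example above.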

Related Code Changes:

  1. flyteidl PR: [flyteidl] Support attribute access on promise - Add dataclass_type proto flyte#4233
  2. this commit in flytekit PR: flyteorg/flytekit@722a9a6
  3. this commit in flytepropeller PR: flyteorg/flyte@849e7c8

I also ran an end-to-end test to verify the behavior. In the following example, the workflow fails at compile time because of the mismatched type.


@task
def t1() -> (List[str], Dict[str, str], foo):
    return ["a", "b"], {"a": "b"}, foo(a="b")


@task
def t6(a: int):
    return a


@workflow
def dataclass_type_check_fails_at_compile():
    l, d, f = t1()
    # f.a is str, but t6 expects int. This should fail at compile time
    t6(a=f.a)
