Skip to content

Commit

Permalink
Improve function name retrieval (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 17, 2025
1 parent 84c2f62 commit 4fc0bc2
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Iterable,
Sequence,
Expand All @@ -37,7 +38,7 @@
from ._pipeline import Pipeline
from ._utils import create_task

__all__ = ["PipelineFailure", "PipelineBuilder"]
__all__ = ["PipelineFailure", "PipelineBuilder", "_get_op_name"]

_LG = logging.getLogger(__name__)

Expand Down Expand Up @@ -115,6 +116,12 @@ async def _put_eof_when_done(queue):
################################################################################


def _get_op_name(op: Callable) -> str:
if isinstance(op, partial):
return _get_op_name(op.func)
return getattr(op, "__name__", op.__class__.__name__)


def _pipe(
input_queue: AsyncQueue[T],
op: Callables[T, U],
Expand Down Expand Up @@ -619,7 +626,7 @@ def pipe(
"when `output_order` is 'input'."
)

name = name or getattr(op, "__name__", op.__class__.__name__)
name = name or _get_op_name(op)

if kwargs:
# pyre-ignore
Expand Down

0 comments on commit 4fc0bc2

Please sign in to comment.