diff --git a/src/spdl/pipeline/_builder.py b/src/spdl/pipeline/_builder.py index d36a6481..021b25eb 100644 --- a/src/spdl/pipeline/_builder.py +++ b/src/spdl/pipeline/_builder.py @@ -15,6 +15,7 @@ AsyncIterable, AsyncIterator, Awaitable, + Callable, Coroutine, Iterable, Sequence, @@ -37,7 +38,7 @@ from ._pipeline import Pipeline from ._utils import create_task -__all__ = ["PipelineFailure", "PipelineBuilder"] +__all__ = ["PipelineFailure", "PipelineBuilder", "_get_op_name"] _LG = logging.getLogger(__name__) @@ -115,6 +116,12 @@ async def _put_eof_when_done(queue): ################################################################################ +def _get_op_name(op: Callable) -> str: + if isinstance(op, partial): + return _get_op_name(op.func) + return getattr(op, "__name__", op.__class__.__name__) + + def _pipe( input_queue: AsyncQueue[T], op: Callables[T, U], @@ -619,7 +626,7 @@ def pipe( "when `output_order` is 'input'." ) - name = name or getattr(op, "__name__", op.__class__.__name__) + name = name or _get_op_name(op) if kwargs: # pyre-ignore