From 3d4404e9d9a9b2a3327f8aee664a8e71ac1f18b8 Mon Sep 17 00:00:00 2001 From: Rehan Sohail Durrani Date: Wed, 1 Jun 2022 06:19:32 -0700 Subject: [PATCH] FEAT-#4412: Add Batch Pipeline API to Modin (#4452) Co-authored-by: Yaroslav Igoshev Co-authored-by: Mahesh Vashishtha Signed-off-by: Rehan Durrani --- .github/workflows/ci.yml | 2 + docs/development/architecture.rst | 3 +- docs/flow/modin/experimental/batch.rst | 12 + docs/flow/modin/experimental/experimental.rst | 2 +- docs/flow/modin/experimental/index.rst | 2 + docs/release_notes/release_notes-0.15.0.rst | 3 + docs/release_notes/release_notes-template.rst | 1 + docs/usage_guide/advanced_usage/batch.rst | 349 +++++++++++ docs/usage_guide/advanced_usage/index.rst | 7 + .../partitioning/virtual_partition.py | 15 +- modin/experimental/batch/__init__.py | 19 + modin/experimental/batch/pipeline.py | 395 ++++++++++++ .../experimental/batch/test/test_pipeline.py | 581 ++++++++++++++++++ 13 files changed, 1386 insertions(+), 5 deletions(-) create mode 100644 docs/flow/modin/experimental/batch.rst create mode 100644 docs/usage_guide/advanced_usage/batch.rst create mode 100644 modin/experimental/batch/__init__.py create mode 100644 modin/experimental/batch/pipeline.py create mode 100644 modin/experimental/batch/test/test_pipeline.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 56c8c245295..c708f34e874 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -119,6 +119,7 @@ jobs: modin/experimental/core/execution/native/implementations/omnisci_on_native/omnisci_worker.py \ - run: python scripts/doc_checker.py modin/experimental/core/storage_formats/omnisci - run: python scripts/doc_checker.py modin/experimental/core/execution/native/implementations/omnisci_on_native/exchange/dataframe_protocol + - run: python scripts/doc_checker.py modin/experimental/batch/pipeline.py - run: python scripts/doc_checker.py modin/logging lint-flake8: @@ -462,6 +463,7 @@ jobs: if: matrix.engine == 'ray' - run: pytest -n 2 modin/experimental/xgboost/test/test_dmatrix.py if: matrix.engine == 'ray' + - run: pytest -n 2 modin/experimental/batch/test/test_pipeline.py - run: pytest -n 2 modin/pandas/test/dataframe/test_binary.py - run: pytest -n 2 modin/pandas/test/dataframe/test_default.py - run: pytest -n 2 modin/pandas/test/dataframe/test_indexing.py diff --git a/docs/development/architecture.rst b/docs/development/architecture.rst index a5646f04573..ce3c8b3619f 100644 --- a/docs/development/architecture.rst +++ b/docs/development/architecture.rst @@ -338,7 +338,8 @@ details. The documentation covers most modules, with more docs being added every │ │ ├─── :doc:`sklearn ` │ │ ├───spreadsheet │ │ ├───sql - │ │ └─── :doc:`xgboost ` + │ │ ├─── :doc:`xgboost ` + │ │ └─── :doc:`batch ` │ └───pandas │ ├─── :doc:`dataframe ` │ └─── :doc:`series ` diff --git a/docs/flow/modin/experimental/batch.rst b/docs/flow/modin/experimental/batch.rst new file mode 100644 index 00000000000..8ed1e956297 --- /dev/null +++ b/docs/flow/modin/experimental/batch.rst @@ -0,0 +1,12 @@ +Batch Pipeline API +"""""""""""""""""" + +This API exposes the ability to pipeline row-parallel batch queries on a Modin DataFrame. Currently, +this feature is only supported for the ``PandasOnRay`` execution. + +API +''' + +.. 
automodule:: modin.experimental.batch.pipeline + :members: + diff --git a/docs/flow/modin/experimental/experimental.rst b/docs/flow/modin/experimental/experimental.rst index cb7d22bfb09..56a1f7cfbcd 100644 --- a/docs/flow/modin/experimental/experimental.rst +++ b/docs/flow/modin/experimental/experimental.rst @@ -7,5 +7,5 @@ In some cases Modin can give the user the opportunity to extend (not modify) typ API or to try new functionality in order to get more flexibility. Depending on the exact experimental feature user may need to install additional packages, change configurations or replace the standard Modin import statement ``import modin.pandas as pd`` with modified version -``import modin.experimental.pandas as pd``. For concreate experimental feature example, please +``import modin.experimental.pandas as pd``. For concrete examples of experimental features, please refer to the desired link from the :ref:`directory tree `. diff --git a/docs/flow/modin/experimental/index.rst b/docs/flow/modin/experimental/index.rst index 7d70131bd91..6b62a36736a 100644 --- a/docs/flow/modin/experimental/index.rst +++ b/docs/flow/modin/experimental/index.rst @@ -8,6 +8,7 @@ and provides a limited set of functionality: * :doc:`xgboost ` * :doc:`sklearn ` +* :doc:`batch ` .. toctree:: @@ -15,3 +16,4 @@ and provides a limited set of functionality: sklearn xgboost + batch diff --git a/docs/release_notes/release_notes-0.15.0.rst b/docs/release_notes/release_notes-0.15.0.rst index a93d3254013..b24efe1df39 100644 --- a/docs/release_notes/release_notes-0.15.0.rst +++ b/docs/release_notes/release_notes-0.15.0.rst @@ -59,6 +59,8 @@ Key Features and Updates * FIX-#4390: Add `redis` to Modin dependencies (#4396) * FIX-#3689: Add black and flake8 into development environment files (#4480) * TEST-#4516: Add numpydoc to developer requirements (#4517) +* New Features + * FEAT-#4412: Add Batch Pipeline API to Modin (#4452) Contributors ------------ @@ -76,3 +78,4 @@ Contributors @jrsacher @orcahmlee @naren-ponder +@RehanSD diff --git a/docs/release_notes/release_notes-template.rst b/docs/release_notes/release_notes-template.rst index 70aed81213c..3b248d6d20d 100644 --- a/docs/release_notes/release_notes-template.rst +++ b/docs/release_notes/release_notes-template.rst @@ -27,6 +27,7 @@ Key Features and Updates * * Dependencies * +* New Features Contributors ------------ diff --git a/docs/usage_guide/advanced_usage/batch.rst b/docs/usage_guide/advanced_usage/batch.rst new file mode 100644 index 00000000000..ce8f494368d --- /dev/null +++ b/docs/usage_guide/advanced_usage/batch.rst @@ -0,0 +1,349 @@ +Batch Pipline API Usage Guide +============================= + +Modin provides an experimental batching feature that pipelines row-parallel queries. This feature +is currently only supported for the ``PandasOnRay`` engine. Please note that this feature is experimental +and behavior or interfaces could be changed. + +Usage examples +-------------- + +In examples below we build and run some pipelines. It is important to note that the queries passed to +the pipeline operate on Modin DataFrame partitions, which are backed by ``pandas``. When using ``pandas``- +module level functions, please make sure to import and use ``pandas`` rather than ``modin.pandas``. + +Simple Batch Pipelining +^^^^^^^^^^^^^^^^^^^^^^^ + +This example walks through a simple batch pipeline in order to familiarize the user with the API. + +.. 
code-block:: python + + from modin.experimental.batch import PandasQueryPipeline + import modin.pandas as pd + import numpy as np + + df = pd.DataFrame( + np.random.randint(0, 100, (100, 100)), + columns=[f"col {i}" for i in range(1, 101)], + ) # Build the dataframe we will pipeline. + pipeline = PandasQueryPipeline(df) # Build the pipeline. + pipeline.add_query(lambda df: df + 1, is_output=True) # Add the first query and specify that + # it is an output query. + pipeline.add_query( + lambda df: df.rename(columns={f"col {i}":f"col {i-1}" for i in range(1, 101)}) + ) # Add a second query. + pipeline.add_query( + lambda df: df.drop(columns=['col 99']), + is_output=True, + ) # Add a third query and specify that it is an output query. + new_df = pd.DataFrame( + np.ones((100, 100)), + columns=[f"col {i}" for i in range(1, 101)], + ) # Build a second dataframe that we will pipeline now instead. + pipeline.update_df(new_df) # Update the dataframe that we will pipeline to be `new_df` + # instead of `df`. + result_dfs = pipeline.compute_batch() # Begin batch processing. + + # Print pipeline results + print(f"Result of Query 1:\n{result_dfs[0]}") + print(f"Result of Query 2:\n{result_dfs[1]}") + # Output IDs can also be specified + pipeline = PandasQueryPipeline(df) # Build the pipeline. + pipeline.add_query( + lambda df: df + 1, + is_output=True, + output_id=1, + ) # Add the first query, specify that it is an output query, as well as specify an output id. + pipeline.add_query( + lambda df: df.rename(columns={f"col {i}":f"col {i-1}" for i in range(1, 101)}) + ) # Add a second query. + pipeline.add_query( + lambda df: df.drop(columns=['col 99']), + is_output=True, + output_id=2, + ) # Add a third query, specify that it is an output query, and specify an output_id. + result_dfs = pipeline.compute_batch() # Begin batch processing. + + # Print pipeline results - should be a dictionary mapping Output IDs to resulting dataframes: + print(f"Mapping of Output ID to dataframe:\n{result_dfs}") + # Print results + for query_id, res_df in result_dfs.items(): + print(f"Query {query_id} resulted in\n{res_df}") + +Batch Pipelining with Postprocessing +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A postprocessing function can also be provided when calling ``pipeline.compute_batch``. The example +below runs a similar pipeline as above, but the postprocessing function writes the output dfs to +a parquet file. + +.. code-block:: python + + from modin.experimental.batch import PandasQueryPipeline + import modin.pandas as pd + import numpy as np + import os + import shutil + + df = pd.DataFrame( + np.random.randint(0, 100, (100, 100)), + columns=[f"col {i}" for i in range(1, 101)], + ) # Build the dataframe we will pipeline. + pipeline = PandasQueryPipeline(df) # Build the pipeline. + pipeline.add_query( + lambda df: df + 1, + is_output=True, + output_id=1, + ) # Add the first query, specify that it is an output query, as well as specify an output id. + pipeline.add_query( + lambda df: df.rename(columns={f"col {i}":f"col {i-1}" for i in range(1, 101)}) + ) # Add a second query. + pipeline.add_query( + lambda df: df.drop(columns=['col 99']), + is_output=True, + output_id=2, + ) # Add a third query, specify that it is an output query, and specify an output_id. 
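    # The postprocessing function defined below receives each output partition as a
    # pandas DataFrame, together with the output ID and partition ID (requested via
    # `pass_output_id=True` and `pass_partition_id=True` in the `compute_batch` call),
    # and writes that partition to its own parquet file under a `query_<output_id>/` directory.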
+ def postprocessing_func(df, output_id, partition_id): + filepath = f"query_{output_id}/" + os.makedirs(filepath, exist_ok=True) + filepath += f"part-{partition_id:04d}.parquet" + df.to_parquet(filepath) + return df + result_dfs = pipeline.compute_batch( + postprocessor=postprocessing_func, + pass_partition_id=True, + pass_output_id=True, + ) # Begin computation, pass in a postprocessing function, and specify that partition ID and + # output ID should be passed to that postprocessing function. + + print(os.system("ls query_1/")) # Should show `NPartitions.get()` parquet files - which + # correspond to partitions of the output of query 1. + print(os.system("ls query_2/")) # Should show `NPartitions.get()` parquet files - which + # correspond to partitions of the output of query 2. + + for query_id, res_df in result_dfs.items(): + written_df = pd.read_parquet(f"query_{query_id}/") + shutil.rmtree(f"query_{query_id}/") # Clean up + print(f"Written and Computed DF are " + + f"{'equal' if res_df.equals(written_df) else 'not equal'} for query {query_id}") + +Batch Pipelining with Fan Out +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If the input dataframe to a query is small (consisting of only one partition), it is possible to +induce additional parallelism using the ``fan_out`` argument. The ``fan_out`` argument replicates +the input partition, applies the query to each replica, and then coalesces all of the replicas back +to one partition using the ``reduce_fn`` that must also be specified when ``fan_out`` is ``True``. + +It is possible to control the parallelism via the ``num_partitions`` parameter passed to the +constructor of the ``PandasQueryPipeline``. This parameter designates the desired number of partitions, +and defaults to ``NPartitions.get()`` when not specified. During fan out, the input partition is replicated +``num_partitions`` times. In the previous examples, ``num_partitions`` was not specified, and so defaulted +to ``NPartitions.get()``. + +The example below demonstrates the usage of ``fan_out`` and ``num_partitions``. We first demonstrate +an example of a function that would benefit from this computation pattern: + +.. 
code-block:: python + + import glob + from PIL import Image + import torchvision.transforms as T + import torchvision + + transforms = T.Compose([T.ToTensor()]) + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + model.eval() + COCO_INSTANCE_CATEGORY_NAMES = [ + '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', + 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', + 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' + ] + + def contains_cat(image, model): + image = transforms(image) + labels = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in model([image])[0]['labels']] + return 'cat' in labels + + def serial_query(df): + """ + This function takes as input a dataframe with a single row corresponding to a folder + containing images to parse. Each image in the folder is passed through a neural network + that detects whether it contains a cat, in serial, and a new column is computed for the + dataframe that counts the number of images containing cats. + + Parameters + ---------- + df : a dataframe + The dataframe to process + + Returns + ------- + The same dataframe as before, with an additional column containing the count of images + containing cats. + """ + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + model.eval() + img_folder = df['images'][0] + images = sorted(glob.glob(f"{img_folder}/*.jpg")) + cats = 0 + for img in images: + cats = cats + 1 if contains_cat(Image.open(img), model) else cats + df['cat_count'] = cats + return df + +To download the image files to test out this code, run the following bash script, which downloads +the images from the fast-ai-coco S3 bucket to a folder called ``images`` in your current working +directory: + +.. code-block:: shell + + aws s3 cp s3://fast-ai-coco/coco_tiny.tgz . --no-sign-request; tar -xf coco_tiny.tgz; mkdir \ + images; mv coco_tiny/train/* images/; rm -rf coco_tiny; rm -rf coco_tiny.tgz + +We can pipeline that code like so: + +.. code-block:: python + + import modin.pandas as pd + from modin.experimental.batch import PandasQueryPipeline + from time import time + df = pd.DataFrame([['images']], columns=['images']) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(serial_query, is_output=True) + serial_start = time() + df_with_cat_count = pipeline.compute_batch()[0] + serial_end = time() + print(f"Result of pipeline:\n{df_with_cat_count}") + +We can induce `8x` parallelism into the pipeline above by combining the ``fan_out`` and ``num_partitions`` parameters like so: + +.. 
code-block:: python + + import modin.pandas as pd + from modin.experimental.batch import PandasQueryPipeline + import shutil + from time import time + df = pd.DataFrame([['images']], columns=['images']) + desired_num_partitions = 8 + def parallel_query(df, partition_id): + """ + This function takes as input a dataframe with a single row corresponding to a folder + containing images to parse. It parses `total_images/desired_num_partitions` images every + time it is called. A new column is computed for the dataframe that counts the number of + images containing cats. + + Parameters + ---------- + df : a dataframe + The dataframe to process + partition_id : int + The partition id of the dataframe that this function runs on. + + Returns + ------- + The same dataframe as before, with an additional column containing the count of images + containing cats. + """ + model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + model.eval() + img_folder = df['images'][0] + images = sorted(glob.glob(f"{img_folder}/*.jpg")) + total_images = len(images) + cats = 0 + start_index = partition_id * (total_images // desired_num_partitions) + if partition_id == desired_num_partitions - 1: # Last partition must parse to end of list + images = images[start_index:] + else: + end_index = (partition_id + 1) * (total_images // desired_num_partitions) + images = images[start_index:end_index] + for img in images: + cats = cats + 1 if contains_cat(Image.open(img), model) else cats + df['cat_count'] = cats + return df + + def reduce_fn(dfs): + """ + Coalesce the results of fanning out the `parallel_query` query. + + Parameters + ---------- + dfs : a list of dataframes + The resulting dataframes from fanning out `parallel_query` + + Returns + ------- + A new dataframe whose `cat_count` column is the sum of the `cat_count` column of all + dataframes in `dfs` + """ + df = dfs[0] + cat_count = df['cat_count'][0] + for dataframe in dfs[1:]: + cat_count += dataframe['cat_count'][0] + df['cat_count'] = cat_count + return df + pipeline = PandasQueryPipeline(df, desired_num_partitions) + pipeline.add_query( + parallel_query, + fan_out=True, + reduce_fn=reduce_fn, + is_output=True, + pass_partition_id=True + ) + parallel_start = time() + df_with_cat_count = pipeline.compute_batch()[0] + parallel_end = time() + print(f"Result of pipeline:\n{df_with_cat_count}") + print(f"Total Time in Serial: {serial_end - serial_start}") + print(f"Total time with induced parallelism: {parallel_end - parallel_start}") + shutil.rmtree("images/") # Clean up + +Batch Pipelining with Dynamic Repartitioning +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similarly, it is also possible to hint to the Pipeline API to repartition after a node completes +computation. This is currently only supported if the input dataframe consists of only one partition. +The number of partitions after repartitioning is controlled by the ``num_partitions`` parameter +passed to the constructor of the ``PandasQueryPipeline``. + +The following example demonstrates how to use the ``repartition_after`` parameter. + +.. 
code-block:: python + + import modin.pandas as pd + from modin.experimental.batch import PandasQueryPipeline + import numpy as np + + small_df = pd.DataFrame([[1, 2, 3]]) # Create a small dataframe + + def increase_dataframe_size(df): + import pandas + new_df = pandas.concat([df] * 1000) + new_df = new_df.reset_index(drop=True) # Get a new range index that isn't duplicated + return new_df + + desired_num_partitions = 24 # We will repartition to 24 partitions + + def add_partition_id_to_df(df, partition_id): + import pandas + new_col = pandas.Series([partition_id]*len(df), name="partition_id", index=df.index) + return pandas.concat([df, new_col], axis=1) + + pipeline = PandasQueryPipeline(small_df, desired_num_partitions) + pipeline.add_query(increase_dataframe_size, repartition_after=True) + pipeline.add_query(add_partition_id_to_df, pass_partition_id=True, is_output=True) + result_df = pipeline.compute_batch()[0] + print(f"Number of partitions passed to second query: " + + f"{len(np.unique(result_df['partition_id'].values))}") + print(f"Result of pipeline:\n{result_df}") + diff --git a/docs/usage_guide/advanced_usage/index.rst b/docs/usage_guide/advanced_usage/index.rst index 7ee4bc48521..0287bb37f09 100644 --- a/docs/usage_guide/advanced_usage/index.rst +++ b/docs/usage_guide/advanced_usage/index.rst @@ -12,6 +12,7 @@ Advanced Usage modin_xgboost modin_in_the_cloud modin_logging + batch .. meta:: :description lang=en: @@ -93,6 +94,12 @@ and system memory. Logging is disabled by default, but when it is enabled, log f at the same directory level as the notebook/script used to run Modin. See our :doc:`Logging with Modin documentation ` for usage information. +Batch Pipeline API +------------------ +Modin provides an experimental batched API that pipelines row parallel queries. See our :doc:`Batch Pipline API Usage Guide ` +for a walkthrough on how to use this feature, as well as :doc:`Batch Pipeline API documentation ` +for more information about the API. + .. _`blog post`: https://medium.com/riselab/why-every-data-scientist-using-pandas-needs-modin-bringing-sql-to-dataframes-3b216b29a7c0 .. _`Modin SQL documentation`: modin_sql.html .. _`Modin Spreadsheet API documentation`: spreadsheets_api.html diff --git a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py index db3fb62bdee..28a3dac0da4 100644 --- a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py +++ b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py @@ -381,15 +381,24 @@ def width(self): self._width_cache = self.list_of_partitions_to_combine[0].width() return self._width_cache - def drain_call_queue(self): - """Execute all operations stored in this partition's call queue.""" + def drain_call_queue(self, num_splits=None): + """ + Execute all operations stored in this partition's call queue. + + Parameters + ---------- + num_splits : int, default: None + The number of times to split the result object. 
+ """ def drain(df): for func, args, kwargs in self.call_queue: df = func(df, *args, **kwargs) return df - drained = super(PandasOnRayDataframeVirtualPartition, self).apply(drain) + drained = super(PandasOnRayDataframeVirtualPartition, self).apply( + drain, num_splits=num_splits + ) self.list_of_partitions_to_combine = drained self.call_queue = [] diff --git a/modin/experimental/batch/__init__.py b/modin/experimental/batch/__init__.py new file mode 100644 index 00000000000..08eaab3867c --- /dev/null +++ b/modin/experimental/batch/__init__.py @@ -0,0 +1,19 @@ +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +from .pipeline import PandasQueryPipeline + + +__all__ = [ + "PandasQueryPipeline", +] diff --git a/modin/experimental/batch/pipeline.py b/modin/experimental/batch/pipeline.py new file mode 100644 index 00000000000..21d3ab17437 --- /dev/null +++ b/modin/experimental/batch/pipeline.py @@ -0,0 +1,395 @@ +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +"""Module houses ``PandasQueryPipeline`` and ``PandasQuery`` classes, that implement a batch pipeline protocol for Modin Dataframes.""" + +from typing import Callable, Optional +import numpy as np + +import modin.pandas as pd +from modin.core.storage_formats.pandas import PandasQueryCompiler +from modin.error_message import ErrorMessage +from modin.core.execution.ray.implementations.pandas_on_ray.dataframe.dataframe import ( + PandasOnRayDataframe, +) +from modin.config import NPartitions +from modin.utils import get_current_execution + + +class PandasQuery(object): + """ + Internal representation of a single query in a pipeline. + + This object represents a single function to be pipelined in a batch pipeline. + + Parameters + ---------- + func : Callable + The function to apply to the dataframe. + is_output : bool, default: False + Whether this query is an output query and should be passed both to the next query, and + directly to postprocessing. + repartition_after : bool, default: False + Whether to repartition after this query is computed. 
Currently, repartitioning is only + supported if there is 1 partition prior to repartitioning. + fan_out : bool, default: False + Whether to fan out this node. If True and only 1 partition is passed as input, the partition + is replicated `PandasQueryPipeline.num_partitions` (default: `NPartitions.get`) times, and + the function is called on each. The `reduce_fn` must also be specified. + pass_partition_id : bool, default: False + Whether to pass the numerical partition id to the query. + reduce_fn : Callable, default: None + The reduce function to apply if `fan_out` is set to True. This takes the + `PandasQueryPipeline.num_partitions` (default: `NPartitions.get`) partitions that result from + this query, and combines them into 1 partition. + output_id : int, default: None + An id to assign to this node if it is an output. + + Notes + ----- + `func` must be a function that is applied along an axis of the dataframe. + + Use `pandas` for any module level functions inside `func` since it operates directly on + partitions. + """ + + def __init__( + self, + func: Callable, + is_output: bool = False, + repartition_after: bool = False, + fan_out: bool = False, + pass_partition_id: bool = False, + reduce_fn: Optional[Callable] = None, + output_id: Optional[int] = None, + ): + self.func = func + self.is_output = is_output + self.repartition_after = repartition_after + self.fan_out = fan_out + self.pass_partition_id = pass_partition_id + self.reduce_fn = reduce_fn + self.output_id = output_id + # List of sub-queries to feed into this query, if this query is an output node. + self.operators = None + + +class PandasQueryPipeline(object): + """ + Internal representation of a query pipeline. + + This object keeps track of the functions that compose to form a query pipeline. + + Parameters + ---------- + df : modin.pandas.Dataframe + The dataframe to perform this pipeline on. + num_partitions : int, optional + The number of partitions to maintain for the batched dataframe. + If not specified, the value is assumed equal to ``NPartitions.get()``. + + Notes + ----- + Only row-parallel pipelines are supported. All queries will be applied along the row axis. + """ + + def __init__(self, df, num_partitions: Optional[int] = None): + if get_current_execution() != "PandasOnRay" or ( + not isinstance(df._query_compiler._modin_frame, PandasOnRayDataframe) + ): # pragma: no cover + ErrorMessage.not_implemented( + "Batch Pipeline API is only implemented for `PandasOnRay` execution." + ) + ErrorMessage.single_warning( + "The Batch Pipeline API is an experimental feature and still under development in Modin." + ) + self.df = df + self.num_partitions = num_partitions if num_partitions else NPartitions.get() + self.outputs = [] # List of output queries. + self.query_list = [] # List of all queries. + self.is_output_id_specified = ( + False # Flag to indicate that `output_id` has been specified for a node. + ) + + def update_df(self, df): + """ + Update the dataframe to perform this pipeline on. + + Parameters + ---------- + df : modin.pandas.DataFrame + The new dataframe to perform this pipeline on. + """ + if get_current_execution() != "PandasOnRay" or ( + not isinstance(df._query_compiler._modin_frame, PandasOnRayDataframe) + ): # pragma: no cover + ErrorMessage.not_implemented( + "Batch Pipeline API is only implemented for `PandasOnRay` execution." 
+ ) + self.df = df + + def add_query( + self, + func: Callable, + is_output: bool = False, + repartition_after: bool = False, + fan_out: bool = False, + pass_partition_id: bool = False, + reduce_fn: Optional[Callable] = None, + output_id: Optional[int] = None, + ): + """ + Add a query to the current pipeline. + + Parameters + ---------- + func : Callable + DataFrame query to perform. + is_output : bool, default: False + Whether this query should be designated as an output query. If `True`, the output of + this query is passed both to the next query and directly to postprocessing. + repartition_after : bool, default: False + Whether the dataframe should be repartitioned after this query. Currently, + repartitioning is only supported if there is 1 partition prior. + fan_out : bool, default: False + Whether to fan out this node. If True and only 1 partition is passed as input, the + partition is replicated `self.num_partitions` (default: `NPartitions.get`) times, + and the function is called on each. The `reduce_fn` must also be specified. + pass_partition_id : bool, default: False + Whether to pass the numerical partition id to the query. + reduce_fn : Callable, default: None + The reduce function to apply if `fan_out` is set to True. This takes the + `self.num_partitions` (default: `NPartitions.get`) partitions that result from this + query, and combines them into 1 partition. + output_id : int, default: None + An id to assign to this node if it is an output. + + Notes + ----- + Use `pandas` for any module level functions inside `func` since it operates directly on + partitions. + """ + if not is_output and output_id is not None: + raise ValueError("Output ID cannot be specified for non-output node.") + if is_output: + if not self.is_output_id_specified and output_id is not None: + if len(self.outputs) != 0: + raise ValueError("Output ID must be specified for all nodes.") + if output_id is None and self.is_output_id_specified: + raise ValueError("Output ID must be specified for all nodes.") + self.query_list.append( + PandasQuery( + func, + is_output, + repartition_after, + fan_out, + pass_partition_id, + reduce_fn, + output_id, + ) + ) + if is_output: + self.outputs.append(self.query_list[-1]) + if output_id is not None: + self.is_output_id_specified = True + self.outputs[-1].operators = self.query_list[:-1] + self.query_list = [] + + def _complete_nodes(self, list_of_nodes, partitions): + """ + Run a sub-query end to end. + + Parameters + ---------- + list_of_nodes : list of PandasQuery + The functions that compose this query. + partitions : list of PandasOnRayDataframeVirtualPartition + The partitions that compose the dataframe that is input to this sub-query. + + Returns + ------- + list of PandasOnRayDataframeVirtualPartition + The partitions that result from computing the functions represented by `list_of_nodes`. + """ + for node in list_of_nodes: + if node.fan_out: + if len(partitions) > 1: + ErrorMessage.not_implemented( + "Fan out is only supported with DataFrames with 1 partition." 
+ ) + partitions[0] = partitions[0].force_materialization() + partition_list = partitions[0].list_of_partitions_to_combine + partitions[0] = partitions[0].add_to_apply_calls(node.func, 0) + partitions[0].drain_call_queue(num_splits=1) + new_dfs = [] + for i in range(1, self.num_partitions): + new_dfs.append( + type(partitions[0])( + partition_list, + full_axis=partitions[0].full_axis, + ).add_to_apply_calls(node.func, i) + ) + new_dfs[-1].drain_call_queue(num_splits=1) + + def reducer(df): + df_inputs = [df] + for df in new_dfs: + df_inputs.append(df.to_pandas()) + return node.reduce_fn(df_inputs) + + partitions = [partitions[0].add_to_apply_calls(reducer)] + elif node.repartition_after: + if len(partitions) > 1: + ErrorMessage.not_implemented( + "Dynamic repartitioning is currently only supported for DataFrames with 1 partition." + ) + partitions[0] = ( + partitions[0].add_to_apply_calls(node.func).force_materialization() + ) + new_dfs = [] + + def mask_partition(df, i): + new_length = len(df.index) // self.num_partitions + if i == self.num_partitions - 1: + return df.iloc[i * new_length :] + return df.iloc[i * new_length : (i + 1) * new_length] + + for i in range(self.num_partitions): + new_dfs.append( + type(partitions[0])( + partitions[0].list_of_partitions_to_combine, + full_axis=partitions[0].full_axis, + ).add_to_apply_calls(mask_partition, i) + ) + partitions = new_dfs + else: + if node.pass_partition_id: + partitions = [ + part.add_to_apply_calls(node.func, i) + for i, part in enumerate(partitions) + ] + else: + partitions = [ + part.add_to_apply_calls(node.func) for part in partitions + ] + return partitions + + def compute_batch( + self, + postprocessor: Optional[Callable] = None, + pass_partition_id: Optional[bool] = False, + pass_output_id: Optional[bool] = False, + ): + """ + Run the completed pipeline + any postprocessing steps end to end. + + Parameters + ---------- + postprocessor : Callable, default: None + A postprocessing function to be applied to each output partition. + The order of arguments passed is `df` (the partition), `output_id` + (if `pass_output_id=True`), and `partition_id` (if `pass_partition_id=True`). + pass_partition_id : bool, default: False + Whether or not to pass the numerical partition id to the postprocessing function. + pass_output_id : bool, default: False + Whether or not to pass the output ID associated with output queries to the + postprocessing function. + + Returns + ------- + list or dict or DataFrame + If output ids are specified, a dictionary mapping output id to the resulting dataframe + is returned, otherwise, a list of the resulting dataframes is returned. + """ + if len(self.outputs) == 0: + ErrorMessage.single_warning( + "No outputs to compute. Returning an empty list. Please specify outputs by calling `add_query` with `is_output=True`." + ) + return [] + if not self.is_output_id_specified and pass_output_id: + raise ValueError( + "`pass_output_id` is set to True, but output ids have not been specified. 
" + + "To pass output ids, please specify them using the `output_id` kwarg with pipeline.add_query" + ) + if self.is_output_id_specified: + outs = {} + else: + outs = [] + modin_frame = self.df._query_compiler._modin_frame + partitions = modin_frame._partition_mgr_cls.row_partitions( + modin_frame._partitions + ) + for node in self.outputs: + partitions = self._complete_nodes(node.operators + [node], partitions) + for part in partitions: + part.drain_call_queue(num_splits=1) + if postprocessor: + output_partitions = [] + for partition_id, partition in enumerate(partitions): + args = [] + if pass_output_id: + args.append(node.output_id) + if pass_partition_id: + args.append(partition_id) + output_partitions.append( + partition.add_to_apply_calls(postprocessor, *args) + ) + else: + output_partitions = [ + part.add_to_apply_calls(lambda df: df) for part in partitions + ] + [ + part.drain_call_queue(num_splits=self.num_partitions) + for part in output_partitions + ] # Ensures our result df is block partitioned. + if not self.is_output_id_specified: + outs.append(output_partitions) + else: + outs[node.output_id] = output_partitions + if not self.is_output_id_specified: + final_results = [] + for df in outs: + partitions = [] + for row_partition in df: + partitions.append(row_partition.list_of_partitions_to_combine) + partitions = np.array(partitions) + partition_mgr_class = PandasOnRayDataframe._partition_mgr_cls + index = partition_mgr_class.get_indices( + 0, partitions, lambda df: df.axes[0] + ) + columns = partition_mgr_class.get_indices( + 1, partitions, lambda df: df.axes[1] + ) + result_modin_frame = PandasOnRayDataframe(partitions, index, columns) + query_compiler = PandasQueryCompiler(result_modin_frame) + result_df = pd.DataFrame(query_compiler=query_compiler) + final_results.append(result_df) + else: + final_results = {} + for id, df in outs.items(): + partitions = [] + for row_partition in df: + partitions.append(row_partition.list_of_partitions_to_combine) + partitions = np.array(partitions) + partition_mgr_class = PandasOnRayDataframe._partition_mgr_cls + index = partition_mgr_class.get_indices( + 0, partitions, lambda df: df.axes[0] + ) + columns = partition_mgr_class.get_indices( + 1, partitions, lambda df: df.axes[1] + ) + result_modin_frame = PandasOnRayDataframe(partitions, index, columns) + query_compiler = PandasQueryCompiler(result_modin_frame) + result_df = pd.DataFrame(query_compiler=query_compiler) + final_results[id] = result_df + return final_results diff --git a/modin/experimental/batch/test/test_pipeline.py b/modin/experimental/batch/test/test_pipeline.py new file mode 100644 index 00000000000..3248ca7b94f --- /dev/null +++ b/modin/experimental/batch/test/test_pipeline.py @@ -0,0 +1,581 @@ +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
+ +import pytest +import numpy as np +import pandas +import ray + +import modin.pandas as pd +from modin.config import Engine, NPartitions +from modin.distributed.dataframe.pandas.partitions import from_partitions +from modin.experimental.batch.pipeline import PandasQueryPipeline +from modin.pandas.test.utils import df_equals + + +@pytest.mark.skipif( + Engine.get() != "Ray", + reason="Only Ray supports the Batch Pipeline API", +) +class TestPipelineRayEngine: + def test_warnings(self): + """Ensure that creating a Pipeline object raises the correct warnings.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + # Ensure that building a pipeline warns users that it is an experimental feature + with pytest.warns( + UserWarning, + match="The Batch Pipeline API is an experimental feature and still under development in Modin.", + ): + pipeline = PandasQueryPipeline(df) + with pytest.warns( + UserWarning, + match="No outputs to compute. Returning an empty list. Please specify outputs by calling `add_query` with `is_output=True`.", + ): + output = pipeline.compute_batch() + assert output == [], "Empty pipeline did not return an empty list." + + def test_pipeline_simple(self): + """Create a simple pipeline and ensure that it runs end to end correctly.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + + def add_col(df): + df["new_col"] = df.sum(axis=1) + return df + + # Build pipeline + pipeline = PandasQueryPipeline(df) + pipeline.add_query(add_col) + pipeline.add_query(lambda df: df * -30) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}) + ) + + def add_row_to_partition(df): + return pandas.concat([df, df.iloc[[-1]]]) + + pipeline.add_query(add_row_to_partition, is_output=True) + new_df = pipeline.compute_batch()[0] + # Build df without pipelining to ensure correctness + correct_df = add_col(pd.DataFrame(arr)) + correct_df *= -30 + correct_df = pd.DataFrame( + correct_df.rename(columns={i: f"col {i}" for i in range(1000)})._to_pandas() + ) + correct_modin_frame = correct_df._query_compiler._modin_frame + partitions = correct_modin_frame._partition_mgr_cls.row_partitions( + correct_modin_frame._partitions + ) + partitions = [ + partition.add_to_apply_calls(add_row_to_partition) + for partition in partitions + ] + [partition.drain_call_queue() for partition in partitions] + partitions = [partition.list_of_blocks for partition in partitions] + correct_df = from_partitions(partitions, axis=None) + # Compare pipelined and non-pipelined df + df_equals(correct_df, new_df) + # Ensure that setting `num_partitions` when creating a pipeline does not change `NPartitions` + num_partitions = NPartitions.get() + PandasQueryPipeline(df, num_partitions=(num_partitions - 1)) + assert ( + NPartitions.get() == num_partitions + ), "Pipeline did not change NPartitions.get()" + + def test_update_df(self): + """Ensure that `update_df` updates the df that the pipeline runs on.""" + df = pd.DataFrame([[1, 2, 3], [4, 5, 6]]) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df + 3, is_output=True) + new_df = df * -1 + pipeline.update_df(new_df) + output_df = pipeline.compute_batch()[0] + df_equals((df * -1) + 3, output_df) + + def test_multiple_outputs(self): + """Create a pipeline with multiple outputs, and check that all are computed correctly.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, 
is_output=True) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + ) + pipeline.add_query(lambda df: df + 30, is_output=True) + new_dfs = pipeline.compute_batch() + assert len(new_dfs) == 3, "Pipeline did not return all outputs" + correct_df = pd.DataFrame(arr) * -30 + df_equals(correct_df, new_dfs[0]) # First output computed correctly + correct_df = correct_df.rename(columns={i: f"col {i}" for i in range(1000)}) + df_equals(correct_df, new_dfs[1]) # Second output computed correctly + correct_df += 30 + df_equals(correct_df, new_dfs[2]) # Third output computed correctly + + def test_output_id(self): + """Ensure `output_id` is handled correctly when passed.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df, 0) + pipeline.add_query(lambda df: df * -30, is_output=True, output_id=20) + with pytest.raises( + ValueError, match="Output ID must be specified for all nodes." + ): + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + ) + assert ( + len(pipeline.query_list) == 0 and len(pipeline.outputs) == 1 + ), "Invalid `add_query` incorrectly added a node to the pipeline." + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, is_output=True) + with pytest.raises( + ValueError, match="Output ID must be specified for all nodes." + ): + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + output_id=20, + ) + assert ( + len(pipeline.query_list) == 0 and len(pipeline.outputs) == 1 + ), "Invalid `add_query` incorrectly added a node to the pipeline." + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df, is_output=True) + with pytest.raises( + ValueError, + match=( + "`pass_output_id` is set to True, but output ids have not been specified. " + + "To pass output ids, please specify them using the `output_id` kwarg with pipeline.add_query" + ), + ): + pipeline.compute_batch(postprocessor=lambda df: df, pass_output_id=True) + with pytest.raises( + ValueError, + match="Output ID cannot be specified for non-output node.", + ): + pipeline.add_query(lambda df: df, output_id=22) + assert ( + len(pipeline.query_list) == 0 and len(pipeline.outputs) == 1 + ), "Invalid `add_query` incorrectly added a node to the pipeline." 
+ + def test_output_id_multiple_outputs(self): + """Ensure `output_id` is handled correctly when multiple outputs are computed.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, is_output=True, output_id=20) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + output_id=21, + ) + pipeline.add_query(lambda df: df + 30, is_output=True, output_id=22) + new_dfs = pipeline.compute_batch() + assert isinstance( + new_dfs, dict + ), "Pipeline did not return a dictionary mapping output_ids to dfs" + assert 20 in new_dfs, "Output ID 1 not cached correctly" + assert 21 in new_dfs, "Output ID 2 not cached correctly" + assert 22 in new_dfs, "Output ID 3 not cached correctly" + assert len(new_dfs) == 3, "Pipeline did not return all outputs" + correct_df = pd.DataFrame(arr) * -30 + df_equals(correct_df, new_dfs[20]) # First output computed correctly + correct_df = correct_df.rename(columns={i: f"col {i}" for i in range(1000)}) + df_equals(correct_df, new_dfs[21]) # Second output computed correctly + correct_df += 30 + df_equals(correct_df, new_dfs[22]) # Third output computed correctly + + def test_postprocessing(self): + """Check that the `postprocessor` argument to `_compute_batch` is handled correctly.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, is_output=True) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + ) + pipeline.add_query(lambda df: df + 30, is_output=True) + + def new_col_adder(df): + df["new_col"] = df.iloc[:, -1] + return df + + new_dfs = pipeline.compute_batch(postprocessor=new_col_adder) + assert len(new_dfs) == 3, "Pipeline did not return all outputs" + correct_df = pd.DataFrame(arr) * -30 + correct_df["new_col"] = correct_df.iloc[:, -1] + df_equals(correct_df, new_dfs[0]) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df = correct_df.rename(columns={i: f"col {i}" for i in range(1000)}) + correct_df["new_col"] = correct_df.iloc[:, -1] + df_equals(correct_df, new_dfs[1]) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df += 30 + correct_df["new_col"] = correct_df.iloc[:, -1] + df_equals(correct_df, new_dfs[2]) + + def test_postprocessing_with_output_id(self): + """Check that the `postprocessor` argument is correctly handled when `output_id` is specified.""" + + def new_col_adder(df): + df["new_col"] = df.iloc[:, -1] + return df + + arr = np.random.randint(0, 1000, (1000, 1000)) + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, is_output=True, output_id=20) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + output_id=21, + ) + pipeline.add_query(lambda df: df + 30, is_output=True, output_id=22) + new_dfs = pipeline.compute_batch(postprocessor=new_col_adder) + assert len(new_dfs) == 3, "Pipeline did not return all outputs" + + def test_postprocessing_with_output_id_passed(self): + """Check that the `postprocessor` argument is correctly passed `output_id` when `pass_output_id` is `True`.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + + def new_col_adder(df, o_id): + df["new_col"] = o_id + return df + + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, 
is_output=True, output_id=20) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + output_id=21, + ) + pipeline.add_query(lambda df: df + 30, is_output=True, output_id=22) + new_dfs = pipeline.compute_batch( + postprocessor=new_col_adder, pass_output_id=True + ) + correct_df = pd.DataFrame(arr) * -30 + correct_df["new_col"] = 20 + df_equals(correct_df, new_dfs[20]) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df = correct_df.rename(columns={i: f"col {i}" for i in range(1000)}) + correct_df["new_col"] = 21 + df_equals(correct_df, new_dfs[21]) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df += 30 + correct_df["new_col"] = 22 + df_equals(correct_df, new_dfs[22]) + + def test_postprocessing_with_partition_id(self): + """Check that the postprocessing is correctly handled when `partition_id` is passed.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + + def new_col_adder(df, partition_id): + df["new_col"] = partition_id + return df + + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, is_output=True, output_id=20) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + output_id=21, + ) + new_dfs = pipeline.compute_batch( + postprocessor=new_col_adder, pass_partition_id=True + ) + correct_df = pd.DataFrame(arr) * -30 + correct_modin_frame = correct_df._query_compiler._modin_frame + partitions = correct_modin_frame._partition_mgr_cls.row_partitions( + correct_modin_frame._partitions + ) + partitions = [ + partition.add_to_apply_calls(new_col_adder, i) + for i, partition in enumerate(partitions) + ] + [partition.drain_call_queue() for partition in partitions] + partitions = [partition.list_of_blocks for partition in partitions] + correct_df = from_partitions(partitions, axis=None) + df_equals(correct_df, new_dfs[20]) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df = pd.DataFrame( + correct_df.rename(columns={i: f"col {i}" for i in range(1000)})._to_pandas() + ) + correct_modin_frame = correct_df._query_compiler._modin_frame + partitions = correct_modin_frame._partition_mgr_cls.row_partitions( + correct_modin_frame._partitions + ) + partitions = [ + partition.add_to_apply_calls(new_col_adder, i) + for i, partition in enumerate(partitions) + ] + [partition.drain_call_queue() for partition in partitions] + partitions = [partition.list_of_blocks for partition in partitions] + correct_df = from_partitions(partitions, axis=None) + df_equals(correct_df, new_dfs[21]) + + def test_postprocessing_with_all_metadata(self): + """Check that postprocessing is correctly handled when `partition_id` and `output_id` are passed.""" + arr = np.random.randint(0, 1000, (1000, 1000)) + + def new_col_adder(df, o_id, partition_id): + df["new_col"] = f"{o_id} {partition_id}" + return df + + df = pd.DataFrame(arr) + pipeline = PandasQueryPipeline(df) + pipeline.add_query(lambda df: df * -30, is_output=True, output_id=20) + pipeline.add_query( + lambda df: df.rename(columns={i: f"col {i}" for i in range(1000)}), + is_output=True, + output_id=21, + ) + new_dfs = pipeline.compute_batch( + postprocessor=new_col_adder, pass_partition_id=True, pass_output_id=True + ) + correct_df = pd.DataFrame(arr) * -30 + correct_modin_frame = correct_df._query_compiler._modin_frame + partitions = correct_modin_frame._partition_mgr_cls.row_partitions( + correct_modin_frame._partitions + ) + partitions = [ + 
partition.add_to_apply_calls(new_col_adder, 20, i) + for i, partition in enumerate(partitions) + ] + [partition.drain_call_queue() for partition in partitions] + partitions = [partition.list_of_blocks for partition in partitions] + correct_df = from_partitions(partitions, axis=None) + df_equals(correct_df, new_dfs[20]) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df = pd.DataFrame( + correct_df.rename(columns={i: f"col {i}" for i in range(1000)})._to_pandas() + ) + correct_modin_frame = correct_df._query_compiler._modin_frame + partitions = correct_modin_frame._partition_mgr_cls.row_partitions( + correct_modin_frame._partitions + ) + partitions = [ + partition.add_to_apply_calls(new_col_adder, 21, i) + for i, partition in enumerate(partitions) + ] + [partition.drain_call_queue() for partition in partitions] + partitions = [partition.list_of_blocks for partition in partitions] + correct_df = from_partitions(partitions, axis=None) + df_equals(correct_df, new_dfs[21]) + + def test_repartition_after(self): + """Check that the `repartition_after` argument is appropriately handled.""" + df = pd.DataFrame([list(range(1000))]) + pipeline = PandasQueryPipeline(df) + pipeline.add_query( + lambda df: pandas.concat([df] * 1000), repartition_after=True + ) + + def new_col_adder(df, partition_id): + df["new_col"] = partition_id + return df + + pipeline.add_query(new_col_adder, is_output=True, pass_partition_id=True) + new_dfs = pipeline.compute_batch() + # new_col_adder should set `new_col` to the partition ID + # throughout the dataframe. We expect there to be + # NPartitions.get() partitions by the time new_col_adder runs, + # because the previous step has repartitioned. + assert len(new_dfs[0]["new_col"].unique()) == NPartitions.get() + # Test that `repartition_after=True` raises an error when the result has more than + # one partition. + partition1 = ray.put(pandas.DataFrame([[0, 1, 2]])) + partition2 = ray.put(pandas.DataFrame([[3, 4, 5]])) + df = from_partitions([partition1, partition2], 0) + pipeline = PandasQueryPipeline(df, 0) + pipeline.add_query(lambda df: df, repartition_after=True, is_output=True) + + with pytest.raises( + NotImplementedError, + match="Dynamic repartitioning is currently only supported for DataFrames with 1 partition.", + ): + new_dfs = pipeline.compute_batch() + + def test_fan_out(self): + """Check that the fan_out argument is appropriately handled.""" + df = pd.DataFrame([[0, 1, 2]]) + + def new_col_adder(df, partition_id): + df["new_col"] = partition_id + return df + + def reducer(dfs): + new_cols = "".join([str(df["new_col"].values[0]) for df in dfs]) + dfs[0]["new_col1"] = new_cols + return dfs[0] + + pipeline = PandasQueryPipeline(df) + pipeline.add_query( + new_col_adder, + fan_out=True, + reduce_fn=reducer, + pass_partition_id=True, + is_output=True, + ) + new_df = pipeline.compute_batch()[0] + correct_df = pd.DataFrame([[0, 1, 2]]) + correct_df["new_col"] = 0 + correct_df["new_col1"] = "".join([str(i) for i in range(NPartitions.get())]) + df_equals(correct_df, new_df) + # Test that `fan_out=True` raises an error when the input has more than + # one partition. 
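        # Below, two single-row pandas DataFrames are placed in the Ray object store and
        # combined along axis 0 with `from_partitions`, producing a DataFrame with two partitions.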
+ partition1 = ray.put(pandas.DataFrame([[0, 1, 2]])) + partition2 = ray.put(pandas.DataFrame([[3, 4, 5]])) + df = from_partitions([partition1, partition2], 0) + pipeline = PandasQueryPipeline(df) + pipeline.add_query( + new_col_adder, + fan_out=True, + reduce_fn=reducer, + pass_partition_id=True, + is_output=True, + ) + with pytest.raises( + NotImplementedError, + match="Fan out is only supported with DataFrames with 1 partition.", + ): + new_df = pipeline.compute_batch()[0] + + def test_pipeline_complex(self): + """Create a complex pipeline with both `fan_out`, `repartition_after` and postprocessing and ensure that it runs end to end correctly.""" + from os.path import exists + from os import remove + from time import sleep + + df = pd.DataFrame([[0, 1, 2]]) + + def new_col_adder(df, partition_id): + sleep(60) + df["new_col"] = partition_id + return df + + def reducer(dfs): + new_cols = "".join([str(df["new_col"].values[0]) for df in dfs]) + dfs[0]["new_col1"] = new_cols + return dfs[0] + + desired_num_partitions = 24 + pipeline = PandasQueryPipeline(df, num_partitions=desired_num_partitions) + pipeline.add_query( + new_col_adder, + fan_out=True, + reduce_fn=reducer, + pass_partition_id=True, + is_output=True, + output_id=20, + ) + pipeline.add_query( + lambda df: pandas.concat([df] * 1000), + repartition_after=True, + ) + + def to_csv(df, partition_id): + df = df.drop(columns=["new_col"]) + df.to_csv(f"{partition_id}.csv") + return df + + pipeline.add_query(to_csv, is_output=True, output_id=21, pass_partition_id=True) + + def post_proc(df, o_id, partition_id): + df["new_col_proc"] = f"{o_id} {partition_id}" + return df + + new_dfs = pipeline.compute_batch( + postprocessor=post_proc, + pass_partition_id=True, + pass_output_id=True, + ) + correct_df = pd.DataFrame([[0, 1, 2]]) + correct_df["new_col"] = 0 + correct_df["new_col1"] = "".join( + [str(i) for i in range(desired_num_partitions)] + ) + correct_df["new_col_proc"] = "20 0" + df_equals(correct_df, new_dfs[20]) + correct_df = pd.concat([correct_df] * 1000) + correct_df = correct_df.drop(columns=["new_col"]) + correct_df["new_col_proc"] = "21 0" + new_length = len(correct_df.index) // desired_num_partitions + for i in range(desired_num_partitions): + if i == desired_num_partitions - 1: + correct_df.iloc[i * new_length :, -1] = f"21 {i}" + else: + correct_df.iloc[i * new_length : (i + 1) * new_length, -1] = f"21 {i}" + df_equals(correct_df, new_dfs[21]) + correct_df = correct_df.drop(columns=["new_col_proc"]) + for i in range(desired_num_partitions): + if i == desired_num_partitions - 1: + correct_partition = correct_df.iloc[i * new_length :] + else: + correct_partition = correct_df.iloc[ + i * new_length : (i + 1) * new_length + ] + assert exists( + f"{i}.csv" + ), "CSV File for Partition {i} does not exist, even though dataframe should have been repartitioned." + df_equals( + correct_partition, + pd.read_csv(f"{i}.csv", index_col="Unnamed: 0").rename( + columns={"0": 0, "1": 1, "2": 2} + ), + ) + remove(f"{i}.csv") + + +@pytest.mark.skipif( + Engine.get() == "Ray", + reason="Ray supports the Batch Pipeline API", +) +def test_pipeline_unsupported_engine(): + """Ensure that trying to use the Pipeline API with an unsupported Engine raises errors.""" + # Check that pipeline does not allow `Engine` to not be Ray. 
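    # This test only runs when the current engine is not Ray (see the `skipif` marker above),
    # so constructing a pipeline here is expected to raise `NotImplementedError`.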
+ df = pd.DataFrame([[1]]) + with pytest.raises( + NotImplementedError, + match="Batch Pipeline API is only implemented for `PandasOnRay` execution.", + ): + PandasQueryPipeline(df) + + eng = Engine.get() + Engine.put("Ray") + # Check that even if Engine is Ray, if the df is not backed by Ray, the Pipeline does not allow initialization. + with pytest.raises( + NotImplementedError, + match="Batch Pipeline API is only implemented for `PandasOnRay` execution.", + ): + PandasQueryPipeline(df, 0) + df_on_ray_engine = pd.DataFrame([[1]]) + pipeline = PandasQueryPipeline(df_on_ray_engine) + # Check that even if Engine is Ray, if the new df is not backed by Ray, the Pipeline does not allow an update. + with pytest.raises( + NotImplementedError, + match="Batch Pipeline API is only implemented for `PandasOnRay` execution.", + ): + pipeline.update_df(df) + Engine.put(eng) + # Check that pipeline does not allow an update when `Engine` is not Ray. + with pytest.raises( + NotImplementedError, + match="Batch Pipeline API is only implemented for `PandasOnRay` execution.", + ): + pipeline.update_df(df)