From 797904b7b05513a18e49bae7d715dfcfd2fcb414 Mon Sep 17 00:00:00 2001
From: Danang
Date: Thu, 5 Dec 2024 07:12:57 +0000
Subject: [PATCH] Improve dask to csv (#283)

* use chunks to avoid loading the whole dataframe into memory

* use inplace for pivoted dataframe set_index

* fix test
---
 django_project/gap/providers/observation.py   |  3 +-
 django_project/gap/tests/utils/test_netcdf.py |  7 +-
 django_project/gap/utils/dask.py              | 20 ++++--
 django_project/gap/utils/reader.py            | 68 +++++++++++++++----
 4 files changed, 78 insertions(+), 20 deletions(-)

diff --git a/django_project/gap/providers/observation.py b/django_project/gap/providers/observation.py
index d9ff0fdb..57b8bffe 100644
--- a/django_project/gap/providers/observation.py
+++ b/django_project/gap/providers/observation.py
@@ -233,7 +233,8 @@ def to_netcdf_stream(self):
         df_pivot[date_coord] = pd.to_datetime(df_pivot[date_coord])
 
         # Convert to xarray Dataset
-        ds = df_pivot.set_index(field_indices).to_xarray()
+        df_pivot.set_index(field_indices, inplace=True)
+        ds = df_pivot.to_xarray()
 
         # write to netcdf
         with (
diff --git a/django_project/gap/tests/utils/test_netcdf.py b/django_project/gap/tests/utils/test_netcdf.py
index d912a6e8..e8e6bf73 100644
--- a/django_project/gap/tests/utils/test_netcdf.py
+++ b/django_project/gap/tests/utils/test_netcdf.py
@@ -261,7 +261,12 @@ def test_to_csv_stream(self):
         csv_stream = self.dataset_reader_value_xr.to_csv_stream()
         csv_data = list(csv_stream)
         self.assertIsNotNone(csv_data)
-        data = csv_data[1].splitlines()
+        data = []
+        for idx, d in enumerate(csv_data):
+            if idx == 0:
+                continue
+            data.extend(d.splitlines())
+
         # rows without header
         self.assertEqual(len(data), 40)
 
diff --git a/django_project/gap/utils/dask.py b/django_project/gap/utils/dask.py
index c3ff6a20..b57c970a 100644
--- a/django_project/gap/utils/dask.py
+++ b/django_project/gap/utils/dask.py
@@ -12,6 +12,20 @@
 from gap.models import Preferences
 
 
+def get_num_of_threads(is_api=False):
+    """Get number of threads for dask computation.
+
+    :param is_api: whether for API usage, defaults to False
+    :type is_api: bool, optional
+    """
+    preferences = Preferences.load()
+    return (
+        preferences.dask_threads_num_api if is_api else
+        preferences.dask_threads_num
+    )
+
+
+
 def execute_dask_compute(x: Delayed, is_api=False):
     """Execute dask computation based on number of threads config.
 
@@ -20,11 +34,7 @@ def execute_dask_compute(x: Delayed, is_api=False):
     :param is_api: Whether the computation is in GAP API, default to False
     :type is_api: bool
     """
-    preferences = Preferences.load()
-    num_of_threads = (
-        preferences.dask_threads_num_api if is_api else
-        preferences.dask_threads_num
-    )
+    num_of_threads = get_num_of_threads(is_api)
     if num_of_threads <= 0:
         # use everything
         x.compute()
diff --git a/django_project/gap/utils/reader.py b/django_project/gap/utils/reader.py
index 0dde5d1a..05f1e17b 100644
--- a/django_project/gap/utils/reader.py
+++ b/django_project/gap/utils/reader.py
@@ -7,6 +7,8 @@
 
 import json
 import tempfile
+import dask
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timedelta
 from typing import Union, List, Tuple
 
@@ -27,7 +29,7 @@
     DatasetTimeStep, DatasetObservationType
 )
 
-from gap.utils.dask import execute_dask_compute
+from gap.utils.dask import execute_dask_compute, get_num_of_threads
 
 
 class DatasetVariable:
@@ -443,6 +445,15 @@ def to_netcdf_stream(self):
                     break
                 yield chunk
 
+    def _get_chunk_indices(self, chunks):
+        indices = []
+        start = 0
+        for size in chunks:
+            stop = start + size
+            indices.append((start, stop))
+            start = stop
+        return indices
+
     def to_csv_stream(self, suffix='.csv', separator=','):
         """Generate csv bytes stream.
 
@@ -457,6 +468,12 @@ def to_csv_stream(self, suffix='.csv', separator=','):
         reordered_cols = [
            attribute.attribute.variable_name for attribute in self.attributes
        ]
+        # use date chunk = 1 to order by date
+        rechunk = {
+            self.date_variable: 1,
+            'lat': 300,
+            'lon': 300
+        }
         if 'lat' in self.xr_dataset.dims:
             dim_order.append('lat')
             dim_order.append('lon')
@@ -465,18 +482,43 @@ def to_csv_stream(self, suffix='.csv', separator=','):
             reordered_cols.insert(0, 'lat')
         if 'ensemble' in self.xr_dataset.dims:
             dim_order.append('ensemble')
-            # reordered_cols.insert(0, 'ensemble')
-        df = self.xr_dataset.to_dataframe(dim_order=dim_order)
-        df_reordered = df[reordered_cols]
-
-        # write headers
-        headers = dim_order + list(df_reordered.columns)
-        yield bytes(','.join(headers) + '\n', 'utf-8')
-
-        # Write the data in chunks
-        for start in range(0, len(df_reordered), self.csv_chunk_size):
-            chunk = df_reordered.iloc[start:start + self.csv_chunk_size]
-            yield chunk.to_csv(index=True, header=False, float_format='%g')
+            rechunk['ensemble'] = 50
+
+        # rechunk dataset
+        ds = self.xr_dataset.chunk(rechunk)
+        date_indices = self._get_chunk_indices(
+            ds.chunksizes[self.date_variable]
+        )
+        lat_indices = self._get_chunk_indices(ds.chunksizes['lat'])
+        lon_indices = self._get_chunk_indices(ds.chunksizes['lon'])
+        write_headers = True
+
+        # cannot use dask utils because to_dataframe is not returning
+        # delayed object
+        with dask.config.set(
+            pool=ThreadPoolExecutor(get_num_of_threads(is_api=True))
+        ):
+            # iterate foreach chunk
+            for date_start, date_stop in date_indices:
+                for lat_start, lat_stop in lat_indices:
+                    for lon_start, lon_stop in lon_indices:
+                        slice_dict = {
+                            self.date_variable: slice(date_start, date_stop),
+                            'lat': slice(lat_start, lat_stop),
+                            'lon': slice(lon_start, lon_stop)
+                        }
+                        chunk = ds.isel(**slice_dict)
+                        chunk_df = chunk.to_dataframe(dim_order=dim_order)
+                        chunk_df = chunk_df[reordered_cols]
+
+                        if write_headers:
+                            headers = dim_order + list(chunk_df.columns)
+                            yield bytes(','.join(headers) + '\n', 'utf-8')
+                            write_headers = False
+
+                        yield chunk_df.to_csv(
+                            index=True, header=False, float_format='%g'
+                        )
 
 
 class BaseDatasetReader:
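
The new to_csv_stream avoids materialising the whole dataset as one dataframe: it rechunks along the date/lat/lon dimensions, walks the chunk boundaries, converts one chunk at a time to a dataframe, and yields that chunk's CSV rows. The standalone sketch below illustrates the same pattern outside the GAP codebase; the dataset shape, the 'temperature' variable, the chunk sizes, and the two-thread pool are illustrative assumptions, not values taken from this patch.

    # Minimal sketch of the chunked CSV streaming pattern used in this patch.
    # Variable names, chunk sizes and thread count are illustrative only.
    from concurrent.futures import ThreadPoolExecutor

    import dask
    import numpy as np
    import pandas as pd
    import xarray as xr


    def get_chunk_indices(chunks):
        """Turn a tuple of chunk sizes, e.g. (2, 2, 1), into (start, stop) pairs."""
        indices = []
        start = 0
        for size in chunks:
            indices.append((start, start + size))
            start += size
        return indices


    def to_csv_stream(ds, date_var='date'):
        """Yield CSV bytes/strings chunk by chunk instead of one big frame."""
        # Small date chunks keep the output ordered by date and bound memory use.
        ds = ds.chunk({date_var: 1, 'lat': 2, 'lon': 2})
        date_idx = get_chunk_indices(ds.chunksizes[date_var])
        lat_idx = get_chunk_indices(ds.chunksizes['lat'])
        lon_idx = get_chunk_indices(ds.chunksizes['lon'])

        write_headers = True
        # to_dataframe() computes eagerly, so cap dask's thread pool explicitly.
        with dask.config.set(pool=ThreadPoolExecutor(2)):
            for d0, d1 in date_idx:
                for y0, y1 in lat_idx:
                    for x0, x1 in lon_idx:
                        chunk = ds.isel(
                            {date_var: slice(d0, d1),
                             'lat': slice(y0, y1),
                             'lon': slice(x0, x1)}
                        )
                        df = chunk.to_dataframe(dim_order=[date_var, 'lat', 'lon'])
                        if write_headers:
                            headers = list(df.index.names) + list(df.columns)
                            yield bytes(','.join(headers) + '\n', 'utf-8')
                            write_headers = False
                        yield df.to_csv(index=True, header=False, float_format='%g')


    if __name__ == '__main__':
        ds = xr.Dataset(
            {'temperature': (('date', 'lat', 'lon'), np.random.rand(3, 4, 4))},
            coords={
                'date': pd.date_range('2024-01-01', periods=3),
                'lat': np.linspace(-1, 1, 4),
                'lon': np.linspace(30, 33, 4),
            },
        )
        for piece in to_csv_stream(ds):
            print(piece if isinstance(piece, str) else piece.decode('utf-8'), end='')

As in the patch, the header line is emitted as bytes while pandas' to_csv returns str, so consumers of the stream should be prepared to handle both types.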