
Commit

Improve dask to csv (#283)
* use chunks to avoid loading the whole dataframe into memory

* use inplace set_index for the pivoted dataframe

* fix test
danangmassandy authored Dec 5, 2024
1 parent 288cfff commit 797904b
Showing 4 changed files with 78 additions and 20 deletions.
3 changes: 2 additions & 1 deletion django_project/gap/providers/observation.py
@@ -233,7 +233,8 @@ def to_netcdf_stream(self):
df_pivot[date_coord] = pd.to_datetime(df_pivot[date_coord])

# Convert to xarray Dataset
ds = df_pivot.set_index(field_indices).to_xarray()
df_pivot.set_index(field_indices, inplace=True)
ds = df_pivot.to_xarray()

# write to netcdf
with (
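
The change above swaps the chained set_index call for an in-place one before converting the pivoted frame to xarray. A minimal sketch of the pattern, using made-up column names rather than the project's actual fields; whether pandas avoids an internal copy with inplace=True can vary by version, but the intent is to stop keeping a second full DataFrame alive during the conversion:

import pandas as pd

# Illustrative frame standing in for the pivoted query result.
df_pivot = pd.DataFrame({
    'date': pd.to_datetime(['2024-01-01', '2024-01-02']),
    'lat': [0.1, 0.1],
    'lon': [36.8, 36.8],
    'temperature': [25.3, 26.1],
})
field_indices = ['date', 'lat', 'lon']

# Chained form: set_index returns a new DataFrame, so two frames coexist briefly.
# ds = df_pivot.set_index(field_indices).to_xarray()

# In-place form used by the commit: mutate the existing frame, then convert it.
df_pivot.set_index(field_indices, inplace=True)
ds = df_pivot.to_xarray()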
7 changes: 6 additions & 1 deletion django_project/gap/tests/utils/test_netcdf.py
@@ -261,7 +261,12 @@ def test_to_csv_stream(self):
csv_stream = self.dataset_reader_value_xr.to_csv_stream()
csv_data = list(csv_stream)
self.assertIsNotNone(csv_data)
data = csv_data[1].splitlines()
data = []
for idx, d in enumerate(csv_data):
if idx == 0:
continue
data.extend(d.splitlines())

# rows without header
self.assertEqual(len(data), 40)

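
The test changes because to_csv_stream now yields the header first and then one CSV fragment per dataset chunk, so the rows have to be collected across every yielded piece instead of being read from a single second element. A small, hypothetical consumer of such a stream:

def collect_rows(csv_stream):
    """Collect data rows from a stream whose first yield is the header."""
    rows = []
    for idx, piece in enumerate(csv_stream):
        if idx == 0:
            continue  # skip the header chunk
        rows.extend(piece.splitlines())
    return rows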
20 changes: 15 additions & 5 deletions django_project/gap/utils/dask.py
@@ -12,6 +12,20 @@
from gap.models import Preferences


def get_num_of_threads(is_api=False):
"""Get number of threads for dask computation.
:param is_api: whether for API usage, defaults to False
:type is_api: bool, optional
"""
preferences = Preferences.load()
return (
preferences.dask_threads_num_api if is_api else
preferences.dask_threads_num
)



def execute_dask_compute(x: Delayed, is_api=False):
"""Execute dask computation based on number of threads config.
@@ -20,11 +34,7 @@ def execute_dask_compute(x: Delayed, is_api=False):
:param is_api: Whether the computation is in GAP API, default to False
:type is_api: bool
"""
preferences = Preferences.load()
num_of_threads = (
preferences.dask_threads_num_api if is_api else
preferences.dask_threads_num
)
num_of_threads = get_num_of_threads(is_api)
if num_of_threads <= 0:
# use everything
x.compute()
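
Extracting get_num_of_threads lets other call sites reuse the Preferences-driven thread count without going through execute_dask_compute. A hypothetical usage sketch of the existing helper (the dataset and output path below are made up, not from the project):

import xarray as xr
from gap.utils.dask import execute_dask_compute

ds = xr.Dataset({'temperature': (('time',), [25.3, 26.1])}).chunk({'time': 1})
# to_netcdf(compute=False) returns a dask Delayed; execute it with the
# thread count configured in Preferences (API profile in this example).
delayed_write = ds.to_netcdf('/tmp/example.nc', compute=False)
execute_dask_compute(delayed_write, is_api=True)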
68 changes: 55 additions & 13 deletions django_project/gap/utils/reader.py
@@ -7,6 +7,8 @@

import json
import tempfile
import dask
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import Union, List, Tuple

@@ -27,7 +29,7 @@
DatasetTimeStep,
DatasetObservationType
)
from gap.utils.dask import execute_dask_compute
from gap.utils.dask import execute_dask_compute, get_num_of_threads


class DatasetVariable:
@@ -443,6 +445,15 @@ def to_netcdf_stream(self):
break
yield chunk

def _get_chunk_indices(self, chunks):
indices = []
start = 0
for size in chunks:
stop = start + size
indices.append((start, stop))
start = stop
return indices

def to_csv_stream(self, suffix='.csv', separator=','):
"""Generate csv bytes stream.
@@ -457,6 +468,12 @@ def to_csv_stream(self, suffix='.csv', separator=','):
reordered_cols = [
attribute.attribute.variable_name for attribute in self.attributes
]
# use date chunk = 1 to order by date
rechunk = {
self.date_variable: 1,
'lat': 300,
'lon': 300
}
if 'lat' in self.xr_dataset.dims:
dim_order.append('lat')
dim_order.append('lon')
@@ -465,18 +482,43 @@
reordered_cols.insert(0, 'lat')
if 'ensemble' in self.xr_dataset.dims:
dim_order.append('ensemble')
# reordered_cols.insert(0, 'ensemble')
df = self.xr_dataset.to_dataframe(dim_order=dim_order)
df_reordered = df[reordered_cols]

# write headers
headers = dim_order + list(df_reordered.columns)
yield bytes(','.join(headers) + '\n', 'utf-8')

# Write the data in chunks
for start in range(0, len(df_reordered), self.csv_chunk_size):
chunk = df_reordered.iloc[start:start + self.csv_chunk_size]
yield chunk.to_csv(index=True, header=False, float_format='%g')
rechunk['ensemble'] = 50

# rechunk dataset
ds = self.xr_dataset.chunk(rechunk)
date_indices = self._get_chunk_indices(
ds.chunksizes[self.date_variable]
)
lat_indices = self._get_chunk_indices(ds.chunksizes['lat'])
lon_indices = self._get_chunk_indices(ds.chunksizes['lon'])
write_headers = True

# cannot use dask utils because to_dataframe is not returning
# delayed object
with dask.config.set(
pool=ThreadPoolExecutor(get_num_of_threads(is_api=True))
):
# iterate foreach chunk
for date_start, date_stop in date_indices:
for lat_start, lat_stop in lat_indices:
for lon_start, lon_stop in lon_indices:
slice_dict = {
self.date_variable: slice(date_start, date_stop),
'lat': slice(lat_start, lat_stop),
'lon': slice(lon_start, lon_stop)
}
chunk = ds.isel(**slice_dict)
chunk_df = chunk.to_dataframe(dim_order=dim_order)
chunk_df = chunk_df[reordered_cols]

if write_headers:
headers = dim_order + list(chunk_df.columns)
yield bytes(','.join(headers) + '\n', 'utf-8')
write_headers = False

yield chunk_df.to_csv(
index=True, header=False, float_format='%g'
)


class BaseDatasetReader:
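
Putting the reader.py pieces together: the dataset is rechunked, _get_chunk_indices turns each dimension's chunk sizes into (start, stop) slices, and only one chunk at a time is materialised as a DataFrame and serialised to CSV. Below is a condensed, standalone sketch of that pattern with a toy dataset (dimension names and sizes are illustrative, not the project's); the real method additionally wraps the loop in dask.config.set(pool=ThreadPoolExecutor(...)) so per-chunk computation uses the configured thread count.

import itertools
import numpy as np
import pandas as pd
import xarray as xr

def chunk_indices(chunks):
    # e.g. chunk sizes (300, 300, 142) -> [(0, 300), (300, 600), (600, 742)]
    out, start = [], 0
    for size in chunks:
        out.append((start, start + size))
        start += size
    return out

# Toy dataset standing in for self.xr_dataset.
ds = xr.Dataset(
    {'temperature': (('date', 'lat', 'lon'), np.random.rand(4, 5, 5))},
    coords={
        'date': pd.date_range('2024-01-01', periods=4),
        'lat': np.linspace(-1, 1, 5),
        'lon': np.linspace(36, 38, 5),
    },
).chunk({'date': 1, 'lat': 3, 'lon': 3})

def csv_stream(ds):
    wrote_header = False
    spans = [chunk_indices(ds.chunksizes[dim]) for dim in ('date', 'lat', 'lon')]
    for (d0, d1), (y0, y1), (x0, x1) in itertools.product(*spans):
        # isel on a chunked dataset stays lazy; to_dataframe() only
        # materialises this one chunk.
        df = ds.isel(
            date=slice(d0, d1), lat=slice(y0, y1), lon=slice(x0, x1)
        ).to_dataframe()
        if not wrote_header:
            yield ','.join(list(df.index.names) + list(df.columns)) + '\n'
            wrote_header = True
        yield df.to_csv(header=False, float_format='%g')

Streaming one CSV fragment per chunk keeps peak memory roughly proportional to a single chunk rather than the full dataframe, which is the point of the commit.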
