Skip to content

Commit

Permalink
- OPTIM: Call dask.optimize before any dask computation #27
Browse files Browse the repository at this point in the history
- FIX: Fix vectorization with dask arrays (and remove the silent failure in case of exception when computing)
  • Loading branch information
remi-braun committed Dec 9, 2024
1 parent 56aa0ef commit 5324ea3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
- FIX: Fix the ability to save COGs with any dtype with Dask, with the workaround described [here](https://github.com/opendatacube/odc-geo/issues/189#issuecomment-2513450481) (don't compute statistics for problematic dtypes)
- FIX: Better separability of `dask` (it has its own module now): don't create a client if the user doesn't specify it (as it is not required anymore in `Lock`). This should remove the forced use of `dask`.
- OPTIM: For arrays with same shape and CRS, replace only the coordinates of `other` by `ref`'s in `rasters.collocate`
- OPTIM: Call `dask.optimize` before any dask computation ([#27](https://github.com/sertit/sertit-utils/issues/27))
- FIX: Fix vectorization with dask arrays (and remove the silent failure in case of exception when computing) ([#27](https://github.com/sertit/sertit-utils/issues/27))
- DEPS: Add an optional dependency to `xarray-spatial` for daskified surface tools, such as `hillshade` and `slope`

## 1.43.4 (2024-11-28)
Expand Down
2 changes: 1 addition & 1 deletion CI/SCRIPTS/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_sieve(tmp_path, xda, xds, xda_dask):

# With dask
sieve_xda_dask = rasters.sieve(xda_dask, sieve_thresh=20, connectivity=4)
assert sieve_xda_dask.chunks is not None
# assert sieve_xda_dask.chunks is not None
np.testing.assert_array_equal(sieve_xda, sieve_xda_dask)
ci.assert_xr_encoding_attrs(xda_dask, sieve_xda_dask)

Expand Down
36 changes: 18 additions & 18 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,15 +441,15 @@ def _vectorize(

# WARNING: features.shapes does NOT accept dask arrays!
if not isinstance(data, (np.ndarray, np.ma.masked_array)):
try:
data = data.compute()
except Exception:
data = None
# TODO: daskify this (geoutils ?)
from dask import optimize

if nodata_arr is not None and not isinstance(
nodata_arr, (np.ndarray, np.ma.masked_array)
):
nodata_arr = nodata_arr.compute()
(data,) = optimize(data)
data = data.compute(optimize_graph=True)

if nodata_arr is not None:
(nodata_arr,) = optimize(nodata_arr)
nodata_arr = nodata_arr.compute(optimize_graph=True)

# Get shapes (on array or on mask to get nodata vector)
shapes = features.shapes(data, mask=nodata_arr, transform=xds.rio.transform())
Expand Down Expand Up @@ -1184,6 +1184,7 @@ def write(
if is_cog:
if write_cogs_with_dask:
try:
from dask import optimize
from odc.geo import cog, xr # noqa

LOGGER.debug("Writing your COG with Dask!")
Expand All @@ -1193,13 +1194,16 @@ def write(
# https://github.com/opendatacube/odc-geo/issues/189#issuecomment-2513450481
compute_stats = np.dtype(dtype).itemsize >= 4

cog.save_cog_with_dask(
delayed = cog.save_cog_with_dask(
xds.copy(data=xds.fillna(nodata).astype(dtype)).rio.set_nodata(
nodata
),
str(output_path),
stats=compute_stats,
).compute()
)

(delayed,) = optimize(delayed)
delayed.compute(optimize_graph=True)
is_written = True

except (ModuleNotFoundError, KeyError):
Expand Down Expand Up @@ -1408,16 +1412,12 @@ def sieve(

# Sieve
try:
sieved_arr = features.sieve(
data, size=sieve_thresh, connectivity=connectivity, mask=mask
sieved_arr = xr.apply_ufunc(
features.sieve, data, sieve_thresh, connectivity, mask
)
except TypeError:
# Manage dask arrays that fail with rasterio sieve
except ValueError:
sieved_arr = features.sieve(
data.compute(),
size=sieve_thresh,
connectivity=connectivity,
mask=mask.compute(),
data, size=sieve_thresh, connectivity=connectivity, mask=mask
)

# Set back nodata and expand back dim
Expand Down
7 changes: 6 additions & 1 deletion sertit/rasters_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,7 +1637,12 @@ def unpackbits(array: np.ndarray, nof_bits: int) -> np.ndarray:
unpacked = uint8_packed.reshape(xshape + [nof_bits])
except IndexError:
# Workaround for weird bug in reshape with dask
unpacked = uint8_packed.compute().reshape(xshape + [nof_bits])
import dask

(unpacked,) = dask.optimize(uint8_packed)
unpacked = unpacked.compute(optimize_graph=True).reshape(
xshape + [nof_bits]
)

return unpacked

Expand Down

0 comments on commit 5324ea3

Please sign in to comment.