Skip to content

Commit

Permalink
ENH: Use odc.geo.xr.crop instead of rio.clip to make `rasters.cro…
Browse files Browse the repository at this point in the history
…p` dask-compatible #27
  • Loading branch information
remi-braun committed Jan 24, 2025
1 parent 3bf46dd commit 0fd7740
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 19 deletions.
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Release History

## 1.44.7 (2025-mm-dd)
## 1.45.0 (2025-mm-dd)

- **ENH: Use `odc.geo.xr.crop` instead of `rio.clip` to make `rasters.crop` dask-compatible** ([#27](https://github.com/sertit/sertit-utils/issues/27))
- FIX: Fixes when trying to write COGs with dask in `rasters.write`

## 1.44.6 (2025-01-13)
Expand Down
9 changes: 9 additions & 0 deletions ci/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,12 @@ def xml_path():

def s3_env(*args, **kwargs):
return unistra.s3_env(use_s3_env_var=CI_SERTIT_S3, *args, **kwargs) # noqa: B026


def get_output(tmp, file, debug=False):
if debug:
out_path = AnyPath(__file__).resolve().parent / "ci_output"
out_path.mkdir(parents=True, exist_ok=True)
return out_path / file
else:
return AnyPath(tmp, file)
8 changes: 5 additions & 3 deletions ci/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shapely
import xarray as xr

from ci.script_utils import KAPUT_KWARGS, dask_env, rasters_path, s3_env
from ci.script_utils import KAPUT_KWARGS, dask_env, get_output, rasters_path, s3_env
from sertit import ci, path, rasters, unistra, vectors
from sertit.rasters import (
FLOAT_NODATA,
Expand All @@ -40,6 +40,8 @@

ci.reduce_verbosity()

DEBUG = False


def test_indexes(caplog):
@s3_env
Expand Down Expand Up @@ -307,13 +309,13 @@ def test_paint(tmp_path, xda, xds, xda_dask, mask):
def test_crop(tmp_path, xda, xds, xda_dask, mask):
"""Test crop function"""
# DataArray
xda_cropped = os.path.join(tmp_path, "test_crop_xda.tif")
xda_cropped = get_output(tmp_path, "test_crop_xda.tif", DEBUG)
crop_xda = rasters.crop(xda, mask.geometry, **KAPUT_KWARGS)
rasters.write(crop_xda, xda_cropped, dtype=np.uint8)
ci.assert_xr_encoding_attrs(xda, crop_xda)

# Dataset
xds_cropped = os.path.join(tmp_path, "test_crop_xds.tif")
xds_cropped = get_output(tmp_path, "test_crop_xds.tif", DEBUG)
crop_xds = rasters.crop(xds, mask, nodata=get_nodata_value_from_xr(xds))
rasters.write(crop_xds, xds_cropped, dtype=np.uint8)
ci.assert_xr_encoding_attrs(xds, crop_xds)
Expand Down
54 changes: 39 additions & 15 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
You can use this only if you have installed sertit[full] or sertit[rasters]
"""

import contextlib
import logging
from functools import wraps
from typing import Any, Callable, Optional, Union
Expand All @@ -41,7 +42,6 @@
"Please install 'rioxarray' to use the 'rasters' package."
) from ex

import contextlib

from sertit import dask, geometry, logs, misc, path, rasters_rio, vectors
from sertit.types import AnyPathStrType, AnyPathType, AnyRasterType, AnyXrDataStructure
Expand Down Expand Up @@ -763,20 +763,44 @@ def crop(
if nodata:
xds = set_nodata(xds, nodata)

if isinstance(shapes, (gpd.GeoDataFrame, gpd.GeoSeries)):
shapes = shapes.to_crs(xds.rio.crs).geometry
try:
from odc.geo import (
Geometry,
xr, # noqa
)

# Convert the shapes in the right format
if isinstance(shapes, (gpd.GeoDataFrame, gpd.GeoSeries)):
shapes = shapes.to_crs(xds.rio.crs)

if "from_disk" not in kwargs:
kwargs["from_disk"] = True # WAY FASTER
try:
shapes = shapes.union_all()
except AttributeError:
# Geopandas < 1.1.0
shapes = shapes.unary_union

# Clip keeps encoding and attrs
return xds.rio.clip(
shapes,
**misc.select_dict(
kwargs,
["crs", "all_touched", "drop", "invert", "from_disk"],
),
)
shapes_geom = Geometry(shapes, crs=xds.rio.crs)

# Crop
cropped = xds.odc.crop(shapes_geom, apply_mask=True)

return set_metadata(cropped, xds)

except ImportError:
if isinstance(shapes, (gpd.GeoDataFrame, gpd.GeoSeries)):
shapes = shapes.to_crs(xds.rio.crs).geometry

if "from_disk" not in kwargs:
kwargs["from_disk"] = True # WAY FASTER

# Clip keeps encoding and attrs
return xds.rio.clip(
shapes,
**misc.select_dict(
kwargs,
["crs", "all_touched", "drop", "invert", "from_disk"],
),
)


def __read__any_raster_to_rio_ds(function: Callable) -> Callable:
Expand Down Expand Up @@ -1651,8 +1675,8 @@ def read_uint8_array(


def set_metadata(
naked_xda: xr.DataArray, mtd_xda: xr.DataArray, new_name=None
) -> xr.DataArray:
naked_xda: AnyXrDataStructure, mtd_xda: AnyXrDataStructure, new_name=None
) -> AnyXrDataStructure:
"""
Set metadata from a :code:`xr.DataArray` to another (including :code:`rioxarray` metadata such as encoded_nodata and crs).
Expand Down

0 comments on commit 0fd7740

Please sign in to comment.