Floodscan stats hannah #14

Merged · 2 commits · Nov 6, 2024
1 change: 0 additions & 1 deletion run_raster_stats.py
@@ -56,7 +56,6 @@ def process_chunk(start, end, dataset, mode, df_iso3s, engine_url):
     try:
         for _, row in df_iso3s.iterrows():
             iso3 = row["iso3"]
-            # shp_url = row["o_shp"]
             max_adm = row["max_adm_level"]
             logger.info(f"Processing data for {iso3}...")

4 changes: 3 additions & 1 deletion src/config/seas5.yml
@@ -2,7 +2,9 @@ blob_prefix: seas5/monthly/processed/precip_em_i
 start_date: 1981-01-01
 end_date: Null
 forecast: True
+extra_dims:
+  - leadtime
 test:
   start_date: 2024-01-01
-  end_date: 2024-10-01
+  end_date: 2024-02-01
   iso3s: ["AFG"]
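
The new `extra_dims` key is how a dataset's config declares dimensions beyond `x`/`y`/`date`. As a minimal sketch of consuming it, the snippet below reads the YAML directly with PyYAML and falls back to an empty list when the key is absent; the pipeline itself presumably goes through `load_pipeline_config` (seen in `cog_utils.py` below), so the direct file read is for illustration only.

```python
import yaml

# Sketch only: read the dataset config and pull out the new key.
with open("src/config/seas5.yml") as f:
    config = yaml.safe_load(f)

# Datasets without the key (e.g. era5, imerg) get an empty list.
extra_dims = config.get("extra_dims", [])
print(extra_dims)  # -> ['leadtime']
```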
6 changes: 4 additions & 2 deletions src/utils/cog_utils.py
@@ -141,7 +141,7 @@ def stack_cogs(start_date, end_date, dataset="era5", mode="dev"):
     end_date : str or datetime-like
         The end date of the date range for stacking the COGs. This can be a string or a datetime object.
     dataset : str, optional
-        The name of the dataset to retrieve COGs from. Options include "era5", "imerg", and "seas5".
+        The name of the dataset to retrieve COGs from. Options include "floodscan", "era5", "imerg", and "seas5".
         Default is "era5".
     mode : str, optional
         The environment mode to use when accessing the cloud storage container. May be "dev", "prod", or "local".
@@ -166,7 +166,9 @@ def stack_cogs(start_date, end_date, dataset="era5", mode="dev"):
         config = load_pipeline_config(dataset)
         prefix = config["blob_prefix"]
     except Exception:
-        logger.error("Input `dataset` must be one of `era5`, `seas5`, or `imerg`.")
+        logger.error(
+            "Input `dataset` must be one of `floodscan`, `era5`, `seas5`, or `imerg`."
+        )

     cogs_list = [
         x.name
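
With `floodscan` now accepted, a call that follows the docstring above would look roughly like this; the import path mirrors this repo's layout, and the return value (an xarray stack of COGs over the date range) is taken from the docstring rather than verified here.

```python
from src.utils.cog_utils import stack_cogs  # path per this repo's layout

# Stack Floodscan COGs for January 2024 from the dev container.
ds = stack_cogs(
    start_date="2024-01-01",
    end_date="2024-02-01",
    dataset="floodscan",  # newly recognized in this PR
    mode="dev",
)
```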
8 changes: 5 additions & 3 deletions src/utils/database_utils.py
@@ -35,7 +35,7 @@ def db_engine_url(mode):
     return DATABASES[mode]


-def create_dataset_table(dataset, engine, is_forecast=False, extra_dims=None):
+def create_dataset_table(dataset, engine, is_forecast=False, extra_dims=[]):
     """
     Create a table for storing dataset statistics in the database.

@@ -48,8 +48,10 @@ def create_dataset_table(dataset, engine, is_forecast=False, extra_dims=None):
     is_forecast : Bool
         Whether or not the dataset is a forecast. Will include `leadtime` and
         `issued_date` columns if so.
-    extra_dims: List
-        List containing the name of the extra dimensions
+    extra_dims : List
+        A list of the names of any extra dimensions that need to be added to the
+        dataset table.

     Returns
     -------
     None
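
Given the updated signature, table creation might be invoked as sketched below. `db_engine_url` comes from this same file; whether `leadtime` is also passed via `extra_dims` for seas5 is inferred from its config above, and `band` as Floodscan's extra dimension is a guess based on the `'band'` handling elsewhere in this PR, not something this diff states.

```python
from sqlalchemy import create_engine

from src.utils.database_utils import create_dataset_table, db_engine_url

engine = create_engine(db_engine_url("dev"))

# Forecast dataset: per the docstring, forecast tables also get
# `leadtime` and `issued_date` columns.
create_dataset_table("seas5", engine, is_forecast=True, extra_dims=["leadtime"])

# Non-forecast dataset; `band` as the extra dimension is an assumption.
create_dataset_table("floodscan", engine, extra_dims=["band"])
```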
84 changes: 43 additions & 41 deletions src/utils/raster_utils.py
@@ -16,6 +16,17 @@
 coloredlogs.install(level=LOG_LEVEL, logger=logger)


+def validate_dimensions(ds):
+    required_dims = {"x", "y", "date"}
+    missing_dims = required_dims - set(ds.dims)
+    if missing_dims:
+        raise ValueError(f"Dataset missing required dimensions: {missing_dims}")
+    # Get the fourth dimension if it exists (not x, y, or date)
+    dims = list(ds.dims)
+    fourth_dim = next((dim for dim in dims if dim not in {"x", "y", "date"}), None)
+    return fourth_dim
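
A quick demonstration of the new helper's contract on synthetic data (the import path follows this diff's file location):

```python
import numpy as np
import pandas as pd
import xarray as xr

from src.utils.raster_utils import validate_dimensions  # added in this PR

# Synthetic 4D dataset: "leadtime" is the one dimension beyond x/y/date.
ds = xr.Dataset(
    {"data": (("date", "leadtime", "y", "x"), np.zeros((2, 3, 4, 5)))},
    coords={
        "date": pd.date_range("2024-01-01", periods=2),
        "leadtime": [1, 2, 3],
        "y": np.arange(4.0),
        "x": np.arange(5.0),
    },
)
print(validate_dimensions(ds))  # -> "leadtime"
print(validate_dimensions(ds.sel(leadtime=1, drop=True)))  # 3D -> None
# A dataset missing "date" (or "x"/"y") raises ValueError instead.
```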


 def fast_zonal_stats_runner(
     ds,
     gdf,
@@ -35,7 +46,8 @@ def fast_zonal_stats_runner(
     Parameters
     ----------
     ds : xarray.Dataset
-        The input raster dataset. Should have the following dimensions: `x`, `y`, `date`, `leadtime` (optional).
+        The input raster dataset. Must have dimensions 'x', 'y', and 'date',
+        with an optional fourth dimension (e.g., 'band' or 'leadtime').
     gdf : geopandas.GeoDataFrame
         A GeoDataFrame containing the administrative boundaries.
     adm_level : int
@@ -65,7 +77,8 @@ def fast_zonal_stats_runner(
     logger = logging.getLogger(__name__)
     logger.addHandler(logging.NullHandler())

-    # TODO: Pre-compute and save
+    fourth_dim = validate_dimensions(ds)
+
     # Rasterize the adm bounds
     src_transform = ds.rio.transform()
     src_width = ds.rio.width
@@ -77,27 +90,29 @@
     n_adms = len(adm_ids)

     outputs = []
-    # TODO: Can this be vectorized further?
     for date in ds.date.values:
         logger.debug(f"Calculating for {date}...")
         ds_sel = ds.sel(date=date)
-        if "leadtime" in ds_sel.dims:
-            for lt in ds_sel.leadtime.values:
-                ds__ = ds_sel.sel(leadtime=lt)
-                # Some leadtime/date combos are invalid and so don't have any data
+
+        if fourth_dim:  # 4D case
+            for val in ds_sel[fourth_dim].values:
+                ds__ = ds_sel.sel({fourth_dim: val})
+                # Skip if all values are NaN
                 if bool(np.all(np.isnan(ds__.values))):
                     continue
                 results = fast_zonal_stats(
                     ds__.values, admin_raster, n_adms, stats=stats, rast_fill=rast_fill
                 )
                 for i, result in enumerate(results):
                     result["valid_date"] = date
-                    result["issued_date"] = add_months_to_date(date, -lt)
+                    # Special handling for leadtime dimension
+                    if fourth_dim == "leadtime":
+                        result["issued_date"] = add_months_to_date(date, -val)
                     result["pcode"] = adm_ids[i]
                     result["adm_level"] = adm_level
-                    result["leadtime"] = lt
+                    result[fourth_dim] = val  # Store the fourth dimension value
                 outputs.extend(results)
-        else:
+        else:  # 3D case
             results = fast_zonal_stats(
                 ds_sel.values, admin_raster, n_adms, stats=stats, rast_fill=rast_fill
             )
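
The enabler for this generalization is xarray's dict form of `.sel()`: the keyword form hardcodes the dimension name, while the dict form resolves it from a variable, so the same loop handles `leadtime`, `band`, or any other fourth dimension. A self-contained illustration on synthetic data:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((2, 3)),
    coords={"date": [0, 1], "leadtime": [1, 2, 3]},
    dims=("date", "leadtime"),
)

fourth_dim = "leadtime"
# Both select the same slice; only the dict form works when the
# dimension name is known only at runtime.
assert da.sel(leadtime=2).equals(da.sel({fourth_dim: 2}))
```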
@@ -215,7 +230,8 @@ def upsample_raster(ds, resampled_resolution=UPSAMPLED_RESOLUTION, logger=None):
     Parameters
     ----------
     ds : xarray.Dataset
-        The raster data set to upsample. Must not have >4 dimensions.
+        The raster data set to upsample. Must have dimensions 'x', 'y', and 'date',
+        with an optional fourth dimension (e.g., 'band' or 'leadtime').
     resampled_resolution : float, optional
         The desired resolution for the upsampled raster.

@@ -228,6 +244,8 @@ def upsample_raster(ds, resampled_resolution=UPSAMPLED_RESOLUTION, logger=None):
         logger = logging.getLogger(__name__)
         logger.addHandler(logging.NullHandler())

+    fourth_dim = validate_dimensions(ds)
+
     # Assuming square resolution
     input_resolution = ds.rio.resolution()[0]
     upscale_factor = input_resolution / resampled_resolution
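
For intuition on the arithmetic above: with an assumed 0.25-degree input grid and a hypothetical `UPSAMPLED_RESOLUTION` of 0.05 degrees (the constant's actual value is not shown in this diff), the upscale factor and output shape work out as follows.

```python
# Illustration only; both resolutions here are assumptions.
input_resolution = 0.25      # e.g. a 0.25-degree source grid
resampled_resolution = 0.05  # hypothetical UPSAMPLED_RESOLUTION
upscale_factor = input_resolution / resampled_resolution  # -> 5.0

# A 720 x 1440 grid at 0.25 degrees would be reprojected to 3600 x 7200.
new_height = int(720 * upscale_factor)   # 3600
new_width = int(1440 * upscale_factor)   # 7200
```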
@@ -249,45 +267,29 @@
     )
     ds = ds.rio.write_crs("EPSG:4326")

-    # Forecast data will have 4 dims, since we have a leadtime
-    nd = len(list(ds.dims))
-    if nd == 4:
+    if fourth_dim:  # 4D case
         resampled_arrays = []

-        # TODO: fix the rioxarray.exceptions.NoDataInBounds: Unable to determine bounds from coordinates. here
-        if 'band' in ds.dims:
-            for band in ds.band.values:
-                ds_ = ds.sel(band=band)
-                ds_ = ds_.rio.reproject(
-                    ds_.rio.crs,
-                    shape=(new_height, new_width),
-                    resampling=Resampling.nearest,
-                    nodata=np.nan,
-                )
-                ds_ = ds_.expand_dims(["band"])
-                resampled_arrays.append(ds_)
-        else:
-            for lt in ds.leadtime.values:
-                ds_ = ds.sel(leadtime=lt)
-                ds_ = ds_.rio.reproject(
-                    ds_.rio.crs,
-                    shape=(new_height, new_width),
-                    resampling=Resampling.nearest,
-                    nodata=np.nan,
-                )
-                ds_ = ds_.expand_dims(["leadtime"])
-                resampled_arrays.append(ds_)
-
+        for val in ds[fourth_dim].values:
+            ds_ = ds.sel({fourth_dim: val})
+            ds_ = ds_.rio.reproject(
+                ds_.rio.crs,
+                shape=(new_height, new_width),
+                resampling=Resampling.nearest,
+                nodata=np.nan,
+            )
+            # Expand along the fourth dimension
+            ds_ = ds_.expand_dims([fourth_dim])
+            resampled_arrays.append(ds_)

         ds_resampled = xr.combine_by_coords(resampled_arrays, combine_attrs="drop")
-    elif (nd == 2) or (nd == 3):
+    else:  # 3D case (x, y, date)
         ds_resampled = ds.rio.reproject(
             ds.rio.crs,
             shape=(new_height, new_width),
             resampling=Resampling.nearest,
             nodata=np.nan,
         )
-    else:
-        raise Exception("Input Dataset must have 2, 3, or 4 dimensions.")

     return ds_resampled
