diff --git a/run_raster_stats.py b/run_raster_stats.py index ee17e61..41b8da3 100644 --- a/run_raster_stats.py +++ b/run_raster_stats.py @@ -9,7 +9,12 @@ import pandas as pd from sqlalchemy import create_engine -from src.config.settings import LOG_LEVEL, load_pipeline_config, parse_pipeline_config +from src.config.settings import ( + LOG_LEVEL, + UPSAMPLED_RESOLUTION, + load_pipeline_config, + parse_pipeline_config, +) from src.utils.cog_utils import stack_cogs from src.utils.database_utils import ( create_dataset_table, @@ -21,6 +26,7 @@ from src.utils.general_utils import split_date_range from src.utils.inputs import cli_args from src.utils.iso3_utils import create_iso3_df, get_iso3_data, load_shp_from_azure +from src.utils.metadata_utils import process_polygon_metadata from src.utils.raster_utils import fast_zonal_stats_runner, prep_raster logger = logging.getLogger(__name__) @@ -110,19 +116,28 @@ def process_chunk(start, end, dataset, mode, df_iso3s, engine_url): if __name__ == "__main__": args = cli_args() - dataset = args.dataset - logger.info(f"Updating data for {dataset}...") engine_url = db_engine_url(args.mode) engine = create_engine(engine_url) + if args.update_metadata: + logger.info("Updating metadata in Postgres database...") + create_iso3_df(engine) + process_polygon_metadata( + engine, + args.mode, + upsampled_resolution=UPSAMPLED_RESOLUTION, + sel_iso3s=None, + ) + sys.exit(0) + + dataset = args.dataset + logger.info(f"Updating data for {dataset}...") + create_qa_table(engine) settings = load_pipeline_config(dataset) start, end, is_forecast = parse_pipeline_config(settings, args.test) create_dataset_table(dataset, engine, is_forecast) - if args.build_iso3: - logger.info("Creating ISO3 table in Postgres database...") - create_iso3_df(engine) sel_iso3s = settings["test"]["iso3s"] if args.test else None df_iso3s = get_iso3_data(sel_iso3s, engine) diff --git a/src/config/settings.py b/src/config/settings.py index 2538726..46a6e57 100644 --- 
a/src/config/settings.py +++ b/src/config/settings.py @@ -5,7 +5,7 @@ load_dotenv() - +UPSAMPLED_RESOLUTION = 0.05 LOG_LEVEL = "INFO" AZURE_DB_PW_DEV = os.getenv("AZURE_DB_PW_DEV") AZURE_DB_PW_PROD = os.getenv("AZURE_DB_PW_PROD") diff --git a/src/utils/database_utils.py b/src/utils/database_utils.py index b4d97a5..3442cbc 100644 --- a/src/utils/database_utils.py +++ b/src/utils/database_utils.py @@ -131,6 +131,56 @@ def create_iso3_table(engine): metadata.create_all(engine) +def create_polygon_table(engine, datasets): + """ + Create a table for storing polygon metadata in the database. + + Parameters + ---------- + engine : sqlalchemy.engine.Engine + The SQLAlchemy engine object used to connect to the database. + datasets : list of str + List of dataset names to include in the table structure. + + Returns + ------- + None + """ + columns = [ + Column("pcode", String), + Column("iso3", CHAR(3)), + Column("adm_level", Integer), + Column("name", String), + Column("name_language", String), + Column("area", REAL), + Column("standard", Boolean), + ] + + for dataset in datasets: + columns.extend( + [ + Column(f"{dataset}_n_intersect_raw_pixels", Integer), + Column(f"{dataset}_frac_raw_pixels", REAL), + Column(f"{dataset}_n_upsampled_pixels", Integer), + ] + ) + + metadata = MetaData() + Table( + "polygon", + metadata, + *columns, + UniqueConstraint( + *["pcode", "iso3", "adm_level"], + name="polygon_pcode_iso3_adm_level_key", + postgresql_nulls_not_distinct=True, + ), + ) + + metadata.create_all(engine) + return + + def insert_qa_table(iso3, adm_level, dataset, error, stack_trace, engine): """ Insert an error record into the 'qa' table in the database.
diff --git a/src/utils/inputs.py b/src/utils/inputs.py index 0ab8585..42d7c69 100644 --- a/src/utils/inputs.py +++ b/src/utils/inputs.py @@ -11,6 +11,7 @@ def cli_args(): help="Dataset for which to calculate raster stats", choices=["seas5", "era5", "imerg"], default=None, + nargs="?", ) parser.add_argument( "--mode", @@ -26,8 +27,8 @@ def cli_args(): action="store_true", ) parser.add_argument( - "--build-iso3", - help="""Builds the `iso3` table in Postgres""", + "--update-metadata", + help="Update the iso3 and polygon metadata tables.", action="store_true", ) return parser.parse_args() diff --git a/src/utils/iso3_utils.py b/src/utils/iso3_utils.py index 9aa9d82..a7989a0 100644 --- a/src/utils/iso3_utils.py +++ b/src/utils/iso3_utils.py @@ -187,9 +187,11 @@ def create_iso3_df(engine): df["max_adm_level"] = df.apply(determine_max_adm_level, axis=1) df["stats_last_updated"] = None + # TODO: This list seems to have some inconsistencies when compared against the + # contents of all polygons # Also need global p-codes list from https://fieldmaps.io/data/cod # We want to get the total number of pcodes per iso3, across each admin level - df_pcodes = pd.read_csv("data/global-pcodes.csv") + df_pcodes = pd.read_csv("data/global-pcodes.csv", low_memory=False) df_pcodes.drop(df_pcodes.index[0], inplace=True) df_counts = ( df_pcodes.groupby(["Location", "Admin Level"])["P-Code"] diff --git a/src/utils/metadata_utils.py b/src/utils/metadata_utils.py new file mode 100644 index 0000000..13c25b5 --- /dev/null +++ b/src/utils/metadata_utils.py @@ -0,0 +1,169 @@ +import logging +import tempfile +from pathlib import Path + +import coloredlogs +import geopandas as gpd +import numpy as np +import pandas as pd +import rioxarray as rxr + +from src.config.settings import LOG_LEVEL, load_pipeline_config +from src.utils.cloud_utils import get_container_client +from src.utils.cog_utils import get_cog_url +from src.utils.database_utils import create_polygon_table, postgres_upsert +from 
src.utils.iso3_utils import get_iso3_data, load_shp_from_azure +from src.utils.raster_utils import fast_zonal_stats, prep_raster, rasterize_admin + +logger = logging.getLogger(__name__) +coloredlogs.install(level=LOG_LEVEL, logger=logger) + + +def get_available_datasets(): + """ + Get list of available datasets from config directory. + + Returns + ------- + List[str] + List of dataset names (based on config file names) + """ + config_dir = Path("src") / "config" + return [f.stem for f in config_dir.glob("*.yml")] + + +def select_name_column(df, adm_level): + """ + Select the appropriate name column from the administrative boundary GeoDataFrame. + + Parameters + ---------- + df : geopandas.GeoDataFrame + The GeoDataFrame containing administrative boundary data. + adm_level : int + The administrative level to find the name column for. + + Returns + ------- + str + The name of the selected column. + """ + pattern = f"^ADM{adm_level}_[A-Z]{{2}}$" + adm_columns = df.filter(regex=pattern).columns + return adm_columns[0] + + +def get_single_cog(dataset, mode): + container_client = get_container_client(mode, "raster") + config = load_pipeline_config(dataset) + prefix = config["blob_prefix"] + cogs_list = [x.name for x in container_client.list_blobs(name_starts_with=prefix)] + cog_url = get_cog_url(mode, cogs_list[0]) + return rxr.open_rasterio(cog_url, chunks="auto") + + +def process_polygon_metadata(engine, mode, upsampled_resolution, sel_iso3s=None): + """ + Process and store polygon metadata for all administrative levels and datasets. + + Parameters + ---------- + engine : sqlalchemy.engine.Engine + The SQLAlchemy engine object used to connect to the database. + mode : str + The mode to run in ('dev', 'prod', etc.). + upsampled_resolution : float, optional + The desired output resolution for raster data. + sel_iso3s : list of str, optional + List of ISO3 codes to process. If None, processes all available. 
+ + Returns + ------- + None + """ + datasets = get_available_datasets() + create_polygon_table(engine, datasets) + df_iso3s = get_iso3_data(sel_iso3s, engine) + + with tempfile.TemporaryDirectory() as td: + for _, row in df_iso3s.iterrows(): + iso3 = row["iso3"] + logger.info(f"Processing polygon metadata for {iso3}...") + max_adm = row["max_adm_level"] + load_shp_from_azure(iso3, td, mode) + try: + for i in range(0, max_adm + 1): + gdf = gpd.read_file(f"{td}/{iso3.lower()}_adm{i}.shp") + for dataset in datasets: + da = get_single_cog(dataset, mode) + input_resolution = da.rio.resolution() + gdf_adm0 = gpd.read_file(f"{td}/{iso3.lower()}_adm0.shp") + # We want all values to be unique, so that we can count the total + # number of unique cells from the raw source that contribute to the stats + da.values = np.arange(da.size).reshape(da.shape) + da = da.astype(np.float32) + + da_clipped = prep_raster(da, gdf_adm0, logger=logger) + output_resolution = da_clipped.rio.resolution() + upscale_factor = input_resolution[0] / output_resolution[0] + + src_transform = da_clipped.rio.transform() + src_width = da_clipped.rio.width + src_height = da_clipped.rio.height + + admin_raster = rasterize_admin( + gdf, src_width, src_height, src_transform, all_touched=False + ) + adm_ids = gdf[f"ADM{i}_PCODE"] + n_adms = len(adm_ids) + + results = fast_zonal_stats( + da_clipped.values[0], + admin_raster, + n_adms, + stats=["count", "unique"], + rast_fill=np.nan, + ) + df_results = pd.DataFrame.from_dict(results) + df_results[f"{dataset}_frac_raw_pixels"] = df_results[ + "count" + ] / (upscale_factor**2) + df_results = df_results.rename( + columns={ + "unique": f"{dataset}_n_intersect_raw_pixels", + "count": f"{dataset}_n_upsampled_pixels", + } + ) + gdf = gdf.join(df_results) + + gdf = gdf.to_crs("ESRI:54009") + gdf["area"] = gdf.geometry.area / 1_000_000 + + name_column = select_name_column(gdf, i) + extract_cols = [f"ADM{i}_PCODE", name_column, "area"] + dataset_cols = gdf.columns[ + 
gdf.columns.str.contains( + "_n_intersect_raw_pixels|" + "_frac_raw_pixels|" + "_n_upsampled_pixels" + ) + ] + + df = gdf[extract_cols + dataset_cols.tolist()] + df = df.rename( + columns={f"ADM{i}_PCODE": "pcode", name_column: "name"} + ) + df["adm_level"] = i + df["name_language"] = name_column[-2:] + df["iso3"] = iso3 + df["standard"] = True + + df.to_sql( + "polygon", + con=engine, + if_exists="append", + index=False, + method=postgres_upsert, + ) + except Exception as e: + logger.error(f"Error: {e}") diff --git a/src/utils/raster_utils.py b/src/utils/raster_utils.py index 53da3ab..781cece 100644 --- a/src/utils/raster_utils.py +++ b/src/utils/raster_utils.py @@ -8,7 +8,7 @@ from rasterio.enums import Resampling from rasterio.features import rasterize -from src.config.settings import LOG_LEVEL +from src.config.settings import LOG_LEVEL, UPSAMPLED_RESOLUTION from src.utils.database_utils import postgres_upsert from src.utils.general_utils import add_months_to_date @@ -193,6 +193,9 @@ def fast_zonal_stats( "sum": np.nansum, "std": np.nanstd, "count": lambda x, axis: np.sum(~np.isnan(x), axis=axis), + "unique": lambda x, axis: np.array( + [len(np.unique(row[~np.isnan(row)])) for row in x] + ), } for stat in stats: @@ -204,7 +207,7 @@ def fast_zonal_stats( return feature_stats -def upsample_raster(ds, resampled_resolution=0.05, logger=None): +def upsample_raster(ds, resampled_resolution=UPSAMPLED_RESOLUTION, logger=None): """ Upsample a raster to a higher resolution using nearest neighbor resampling, via the `Resampling.nearest` method from `rasterio`. @@ -214,7 +217,7 @@ def upsample_raster(ds, resampled_resolution=0.05, logger=None): ds : xarray.Dataset The raster data set to upsample. Must not have >4 dimensions. resampled_resolution : float, optional - The desired resolution for the upsampled raster. Default is 0.05. + The desired resolution for the upsampled raster. Returns -------