
Create polygon metadata table #8

Merged · 9 commits · Nov 1, 2024
27 changes: 21 additions & 6 deletions run_raster_stats.py
@@ -9,7 +9,12 @@
import pandas as pd
from sqlalchemy import create_engine

from src.config.settings import LOG_LEVEL, load_pipeline_config, parse_pipeline_config
from src.config.settings import (
LOG_LEVEL,
UPSAMPLED_RESOLUTION,
load_pipeline_config,
parse_pipeline_config,
)
from src.utils.cog_utils import stack_cogs
from src.utils.database_utils import (
create_dataset_table,
@@ -21,6 +26,7 @@
from src.utils.general_utils import split_date_range
from src.utils.inputs import cli_args
from src.utils.iso3_utils import create_iso3_df, get_iso3_data, load_shp_from_azure
from src.utils.metadata_utils import process_polygon_metadata
from src.utils.raster_utils import fast_zonal_stats_runner, prep_raster

logger = logging.getLogger(__name__)
@@ -110,19 +116,28 @@ def process_chunk(start, end, dataset, mode, df_iso3s, engine_url):

if __name__ == "__main__":
args = cli_args()
dataset = args.dataset
logger.info(f"Updating data for {dataset}...")

engine_url = db_engine_url(args.mode)
engine = create_engine(engine_url)

if args.update_metadata:
logger.info("Updating metadata in Postgres database...")
create_iso3_df(engine)
process_polygon_metadata(
engine,
args.mode,
upsampled_resolution=UPSAMPLED_RESOLUTION,
sel_iso3s=None,
)
sys.exit(0)
Contributor: Why this sys.exit here?

Collaborator (Author): Just because the program can terminate here after updating the metadata!


dataset = args.dataset
logger.info(f"Updating data for {dataset}...")

create_qa_table(engine)
settings = load_pipeline_config(dataset)
start, end, is_forecast = parse_pipeline_config(settings, args.test)
create_dataset_table(dataset, engine, is_forecast)
if args.build_iso3:
logger.info("Creating ISO3 table in Postgres database...")
create_iso3_df(engine)

sel_iso3s = settings["test"]["iso3s"] if args.test else None
df_iso3s = get_iso3_data(sel_iso3s, engine)
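A usage sketch of the new path (the invocation is illustrative; the available --mode values are assumed from context):

    python run_raster_stats.py --mode dev --update-metadata

This rebuilds the iso3 and polygon metadata tables and then exits via sys.exit(0) before any dataset processing, which is why the positional dataset argument becomes optional (see the nargs="?" change in src/utils/inputs.py below).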
2 changes: 1 addition & 1 deletion src/config/settings.py
@@ -5,7 +5,7 @@

load_dotenv()


UPSAMPLED_RESOLUTION = 0.05
LOG_LEVEL = "INFO"
AZURE_DB_PW_DEV = os.getenv("AZURE_DB_PW_DEV")
AZURE_DB_PW_PROD = os.getenv("AZURE_DB_PW_PROD")
50 changes: 50 additions & 0 deletions src/utils/database_utils.py
@@ -131,6 +131,56 @@ def create_iso3_table(engine):
metadata.create_all(engine)


def create_polygon_table(engine, datasets):
"""
Create a table for storing polygon metadata in the database.

Parameters
----------
engine : sqlalchemy.engine.Engine
The SQLAlchemy engine object used to connect to the database.
datasets : list of str
List of dataset names to include in the table structure.

Returns
-------
None
"""
columns = [
Column("pcode", String),
Column("iso3", CHAR(3)),
Column("adm_level", Integer),
Column("name", String),
Column("name_language", String),
Column("area", REAL),
Column("standard", Boolean),
]

for dataset in datasets:
columns.extend(
[
Column(f"{dataset}_n_intersect_raw_pixels", Integer),
Column(f"{dataset}_frac_raw_pixels", REAL),
Column(f"{dataset}_n_upsampled_pixels", Integer),
]
)

metadata = MetaData()
Table(
"polygon",
metadata,
*columns,
UniqueConstraint(
*["pcode", "iso3", "adm_level"],
name="polygon_valid_date_leadtime_pcode_key",
postgresql_nulls_not_distinct=True,
),
)

metadata.create_all(engine)
return


def insert_qa_table(iso3, adm_level, dataset, error, stack_trace, engine):
"""
Insert an error record into the 'qa' table in the database.
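For orientation, a minimal sketch of how create_polygon_table might be called; the connection string is a placeholder, not part of the PR (the pipeline builds the real one via db_engine_url(mode)):

    from sqlalchemy import create_engine

    from src.utils.database_utils import create_polygon_table

    # Placeholder URL for illustration only.
    engine = create_engine("postgresql+psycopg2://user:password@localhost:5432/postgres")

    # One trio of columns is generated per dataset, e.g. for "era5":
    #   era5_n_intersect_raw_pixels  (Integer)
    #   era5_frac_raw_pixels         (REAL)
    #   era5_n_upsampled_pixels      (Integer)
    create_polygon_table(engine, ["seas5", "era5", "imerg"])

The unique constraint on (pcode, iso3, adm_level) is what lets later writes upsert rather than duplicate rows.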
5 changes: 3 additions & 2 deletions src/utils/inputs.py
@@ -11,6 +11,7 @@ def cli_args():
help="Dataset for which to calculate raster stats",
choices=["seas5", "era5", "imerg"],
default=None,
nargs="?",
)
parser.add_argument(
"--mode",
@@ -26,8 +27,8 @@
action="store_true",
)
parser.add_argument(
"--build-iso3",
help="""Builds the `iso3` table in Postgres""",
"--update-metadata",
help="Update the iso3 and polygon metadata tables.",
action="store_true",
)
return parser.parse_args()
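The nargs="?" addition is what lets --update-metadata run without a dataset. A minimal standalone sketch of the behavior (toy parser, not the project's full CLI):

    import argparse

    parser = argparse.ArgumentParser()
    # Without nargs="?", argparse treats the positional as required and
    # errors out when it is omitted.
    parser.add_argument("dataset", nargs="?", choices=["seas5", "era5", "imerg"], default=None)
    parser.add_argument("--update-metadata", action="store_true")

    print(parser.parse_args(["--update-metadata"]))
    # Namespace(dataset=None, update_metadata=True)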
4 changes: 3 additions & 1 deletion src/utils/iso3_utils.py
@@ -187,9 +187,11 @@ def create_iso3_df(engine):
df["max_adm_level"] = df.apply(determine_max_adm_level, axis=1)
df["stats_last_updated"] = None

# TODO: This list seems to have some inconsistencies when compared against the
# contents of all polygons
# Also need global p-codes list from https://fieldmaps.io/data/cod
# We want to get the total number of pcodes per iso3, across each admin level
df_pcodes = pd.read_csv("data/global-pcodes.csv")
df_pcodes = pd.read_csv("data/global-pcodes.csv", low_memory=False)
df_pcodes.drop(df_pcodes.index[0], inplace=True)
df_counts = (
df_pcodes.groupby(["Location", "Admin Level"])["P-Code"]
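The diff is truncated mid-chain above. A hypothetical completion, for illustration only (the .nunique() reducer and the output column name are assumptions, not taken from the PR):

    import pandas as pd

    # Toy stand-in for data/global-pcodes.csv.
    df_pcodes = pd.DataFrame(
        {
            "Location": ["AFG", "AFG", "AFG"],
            "Admin Level": [1, 1, 2],
            "P-Code": ["AF01", "AF02", "AF0101"],
        }
    )

    # Hypothetical completion: count distinct P-Codes per location and admin level.
    df_counts = (
        df_pcodes.groupby(["Location", "Admin Level"])["P-Code"]
        .nunique()
        .reset_index(name="total_pcodes")
    )
    print(df_counts)
    #   Location  Admin Level  total_pcodes
    # 0      AFG            1             2
    # 1      AFG            2             1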
169 changes: 169 additions & 0 deletions src/utils/metadata_utils.py
@@ -0,0 +1,169 @@
import logging
import tempfile
from pathlib import Path

import coloredlogs
import geopandas as gpd
import numpy as np
import pandas as pd
import rioxarray as rxr

from src.config.settings import LOG_LEVEL, load_pipeline_config
from src.utils.cloud_utils import get_container_client
from src.utils.cog_utils import get_cog_url
from src.utils.database_utils import create_polygon_table, postgres_upsert
from src.utils.iso3_utils import get_iso3_data, load_shp_from_azure
from src.utils.raster_utils import fast_zonal_stats, prep_raster, rasterize_admin

logger = logging.getLogger(__name__)
coloredlogs.install(level=LOG_LEVEL, logger=logger)


def get_available_datasets():
"""
Get list of available datasets from config directory.

Returns
-------
List[str]
List of dataset names (based on config file names)
"""
config_dir = Path("src") / "config"
return [f.stem for f in config_dir.glob("*.yml")]


def select_name_column(df, adm_level):
"""
Select the appropriate name column from the administrative boundary GeoDataFrame.

Parameters
----------
df : geopandas.GeoDataFrame
The GeoDataFrame containing administrative boundary data.
adm_level : int
The administrative level to find the name column for.

Returns
-------
str
The name of the selected column.
"""
pattern = f"^ADM{adm_level}_[A-Z]{{2}}$"
adm_columns = df.filter(regex=pattern).columns
return adm_columns[0]


def get_single_cog(dataset, mode):
container_client = get_container_client(mode, "raster")
config = load_pipeline_config(dataset)
prefix = config["blob_prefix"]
cogs_list = [x.name for x in container_client.list_blobs(name_starts_with=prefix)]
cog_url = get_cog_url(mode, cogs_list[0])
return rxr.open_rasterio(cog_url, chunks="auto")


def process_polygon_metadata(engine, mode, upsampled_resolution, sel_iso3s=None):
"""
Process and store polygon metadata for all administrative levels and datasets.

Parameters
----------
engine : sqlalchemy.engine.Engine
The SQLAlchemy engine object used to connect to the database.
mode : str
The mode to run in ('dev', 'prod', etc.).
upsampled_resolution : float
The desired output resolution for the upsampled raster data.
sel_iso3s : list of str, optional
List of ISO3 codes to process. If None, processes all available.

Returns
-------
None
"""
datasets = get_available_datasets()
create_polygon_table(engine, datasets)
df_iso3s = get_iso3_data(sel_iso3s, engine)

with tempfile.TemporaryDirectory() as td:
for _, row in df_iso3s.iterrows():
iso3 = row["iso3"]
logger.info(f"Processing polygon metadata for {iso3}...")
max_adm = row["max_adm_level"]
load_shp_from_azure(iso3, td, mode)
try:
for i in range(0, max_adm + 1):
gdf = gpd.read_file(f"{td}/{iso3.lower()}_adm{i}.shp")
for dataset in datasets:
da = get_single_cog(dataset, mode)
input_resolution = da.rio.resolution()
gdf_adm0 = gpd.read_file(f"{td}/{iso3.lower()}_adm0.shp")
# We want all values to be unique, so that we can count the total
# number of unique cells from the raw source that contribute to the stats
da.values = np.arange(da.size).reshape(da.shape)
da = da.astype(np.float32)

da_clipped = prep_raster(da, gdf_adm0, logger=logger)
output_resolution = da_clipped.rio.resolution()
upscale_factor = input_resolution[0] / output_resolution[0]

src_transform = da_clipped.rio.transform()
src_width = da_clipped.rio.width
src_height = da_clipped.rio.height

admin_raster = rasterize_admin(
gdf, src_width, src_height, src_transform, all_touched=False
)
adm_ids = gdf[f"ADM{i}_PCODE"]
n_adms = len(adm_ids)

results = fast_zonal_stats(
da_clipped.values[0],
admin_raster,
n_adms,
stats=["count", "unique"],
rast_fill=np.nan,
)
df_results = pd.DataFrame.from_dict(results)
df_results[f"{dataset}_frac_raw_pixels"] = df_results[
"count"
] / (upscale_factor**2)
df_results = df_results.rename(
columns={
"unique": f"{dataset}_n_intersect_raw_pixels",
"count": f"{dataset}_n_upsampled_pixels",
}
)
gdf = gdf.join(df_results)

gdf = gdf.to_crs("ESRI:54009")
gdf["area"] = gdf.geometry.area / 1_000_000

name_column = select_name_column(gdf, i)
extract_cols = [f"ADM{i}_PCODE", name_column, "area"]
dataset_cols = gdf.columns[
gdf.columns.str.contains(
"_n_intersect_raw_pixels|"
"_frac_raw_pixels|"
"_n_upsampled_pixels"
)
]

df = gdf[extract_cols + dataset_cols.tolist()]
df = df.rename(
columns={f"ADM{i}_PCODE": "pcode", name_column: "name"}
)
df["adm_level"] = i
df["name_language"] = name_column[-2:]
df["iso3"] = iso3
df["standard"] = True

df.to_sql(
"polygon",
con=engine,
if_exists="append",
index=False,
method=postgres_upsert,
)
except Exception as e:
logger.error(f"Error: {e}")
9 changes: 6 additions & 3 deletions src/utils/raster_utils.py
@@ -8,7 +8,7 @@
from rasterio.enums import Resampling
from rasterio.features import rasterize

from src.config.settings import LOG_LEVEL
from src.config.settings import LOG_LEVEL, UPSAMPLED_RESOLUTION
from src.utils.database_utils import postgres_upsert
from src.utils.general_utils import add_months_to_date

@@ -193,6 +193,9 @@ def fast_zonal_stats(
"sum": np.nansum,
"std": np.nanstd,
"count": lambda x, axis: np.sum(~np.isnan(x), axis=axis),
"unique": lambda x, axis: np.array(
[len(np.unique(row[~np.isnan(row)])) for row in x]
),
}

for stat in stats:
@@ -204,7 +207,7 @@
return feature_stats


def upsample_raster(ds, resampled_resolution=0.05, logger=None):
def upsample_raster(ds, resampled_resolution=UPSAMPLED_RESOLUTION, logger=None):
"""
Upsample a raster to a higher resolution using nearest neighbor resampling,
via the `Resampling.nearest` method from `rasterio`.
@@ -214,7 +217,7 @@
ds : xarray.Dataset
The raster data set to upsample. Must not have >4 dimensions.
resampled_resolution : float, optional
The desired resolution for the upsampled raster. Default is 0.05.
The desired resolution for the upsampled raster.

Returns
-------
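The new "unique" reducer added to fast_zonal_stats above counts distinct non-NaN values per zone (the axis argument is accepted for interface parity with the numpy reducers but the rows are iterated directly). A quick sketch of the lambda in isolation, on toy data:

    import numpy as np

    # Each row is one zone's NaN-padded pixel values.
    unique = lambda x, axis: np.array(
        [len(np.unique(row[~np.isnan(row)])) for row in x]
    )

    zones = np.array(
        [
            [1.0, 1.0, 2.0, np.nan],        # two distinct values
            [3.0, np.nan, np.nan, np.nan],  # one distinct value
        ]
    )
    print(unique(zones, axis=1))  # [2 1]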