Merge pull request #5 from OCHA-DAP/final-tweaks
Final tweaks
hannahker authored Oct 9, 2024
2 parents 7873de9 + e8b8fa1 commit 97a006e
Showing 14 changed files with 540 additions and 96 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/run_tests.yml
@@ -0,0 +1,27 @@
name: Run Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.12'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
      - name: Run tests
        run: |
          python -m pytest tests/
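The new workflow installs the pinned requirements and runs the test suite on every push and pull request against `main`. The CI step can be reproduced locally, for example from Python (a minimal sketch, assuming the dependencies from `requirements.txt` are already installed in the active environment):

```python
# Run the suite the same way CI does (`python -m pytest tests/`).
import sys

import pytest

exit_code = pytest.main(["tests/"])  # returns a pytest.ExitCode
sys.exit(int(exit_code))
```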
3 changes: 2 additions & 1 deletion README.md
@@ -40,7 +40,8 @@ pip install -r requirements-dev.txt
# Connection to Azure blob storage
DSCI_AZ_SAS_DEV=<provided-on-request>
DSCI_AZ_SAS_PROD=<provided-on-request>
AZURE_DB_PW=<provided-on-request>
AZURE_DB_PW_DEV=<provided-on-request>
AZURE_DB_PW_PROD=<provided-on-request>
```

### Pre-Commit
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,3 +12,4 @@ python-dotenv==1.0.1
psycopg2_binary==2.9.9
sqlalchemy==2.0.33
rasterio==1.3.10
pytest==8.3.3
156 changes: 110 additions & 46 deletions run_raster_stats.py
@@ -1,32 +1,121 @@
import logging
import sys
import tempfile
import traceback
from multiprocessing import Pool, current_process

import coloredlogs
import geopandas as gpd
import pandas as pd
from sqlalchemy import create_engine

from src.config.settings import LOG_LEVEL, load_pipeline_config, parse_pipeline_config
from src.utils.cog_utils import stack_cogs
from src.utils.database_utils import (
create_dataset_table,
create_qa_table,
db_engine,
db_engine_url,
insert_qa_table,
postgres_upsert,
)
from src.utils.general_utils import split_date_range
from src.utils.inputs import cli_args
from src.utils.iso3_utils import create_iso3_df, get_iso3_data, load_shp
from src.utils.iso3_utils import create_iso3_df, get_iso3_data, load_shp_from_azure
from src.utils.raster_utils import fast_zonal_stats_runner, prep_raster

logger = logging.getLogger(__name__)
coloredlogs.install(level=LOG_LEVEL, logger=logger)


def setup_logger(name, level=logging.INFO):
    """Function to setup a logger that prints to console"""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    coloredlogs.install(level=level, logger=logger)

    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


def process_chunk(start, end, dataset, mode, df_iso3s, engine_url):
    process_name = current_process().name
    logger = setup_logger(f"{process_name}: {dataset}_{start}")
    logger.info(f"Starting processing for {dataset} from {start} to {end}")

    engine = create_engine(engine_url)
    ds = stack_cogs(start, end, dataset, mode)

    try:
        for _, row in df_iso3s.iterrows():
            iso3 = row["iso_3"]
            # shp_url = row["o_shp"]
            max_adm = row["max_adm_level"]
            logger.info(f"Processing data for {iso3}...")

            with tempfile.TemporaryDirectory() as td:
                load_shp_from_azure(iso3, td, mode)
                gdf = gpd.read_file(f"{td}/{iso3.lower()}_adm0.shp")
                try:
                    ds_clipped = prep_raster(ds, gdf, logger=logger)
                except Exception as e:
                    logger.error(f"Error preparing raster for {iso3}: {e}")
                    stack_trace = traceback.format_exc()
                    insert_qa_table(iso3, None, dataset, e, stack_trace, engine)
                    continue

                try:
                    all_results = []
                    for adm_level in range(max_adm + 1):
                        gdf = gpd.read_file(f"{td}/{iso3.lower()}_adm{adm_level}.shp")
                        logger.info(f"Computing stats for adm{adm_level}...")
                        df_results = fast_zonal_stats_runner(
                            ds_clipped,
                            gdf,
                            adm_level,
                            iso3,
                            save_to_database=False,
                            engine=None,
                            dataset=dataset,
                            logger=logger,
                        )
                        if df_results is not None:
                            all_results.append(df_results)
                    df_all_results = pd.concat(all_results, ignore_index=True)
                    logger.info(f"Writing {len(df_all_results)} rows to database...")
                    df_all_results.to_sql(
                        f"{dataset}",
                        con=engine,
                        if_exists="append",
                        index=False,
                        method=postgres_upsert,
                    )
                except Exception as e:
                    logger.error(f"Error calculating stats for {iso3}: {e}")
                    stack_trace = traceback.format_exc()
                    insert_qa_table(iso3, adm_level, dataset, e, stack_trace, engine)
                    continue
                # Clear memory
                del ds_clipped

    finally:
        engine.dispose()


if __name__ == "__main__":
    args = cli_args()
    dataset = args.dataset
    logger.info(f"Updating data for {dataset}...")

    engine = db_engine(args.mode)
    engine_url = db_engine_url(args.mode)
    engine = create_engine(engine_url)

    create_qa_table(engine)
    settings = load_pipeline_config(dataset)
    start, end, is_forecast = parse_pipeline_config(settings, args.test)
@@ -37,49 +126,24 @@

    sel_iso3s = settings["test"]["iso3s"] if args.test else None
    df_iso3s = get_iso3_data(sel_iso3s, engine)
    date_ranges = split_date_range(start, end)

    logger.info(f"Creating stack of COGs from {start} to {end}...")
    ds = stack_cogs(start, end, dataset, args.mode)

    # Loop through each country
    logger.info(f"Calculating raster stats for {len(df_iso3s)} ISO3s...")
    for idx, row in df_iso3s.iterrows():
        iso3 = df_iso3s.loc[idx, "iso_3"]
        shp_url = df_iso3s.loc[idx, "o_shp"]
        max_adm = df_iso3s.loc[idx, "max_adm_level"]
        logger.info(f"Processing data for {iso3}...")

        with tempfile.TemporaryDirectory() as td:
            load_shp(shp_url, td, iso3)
            gdf = gpd.read_file(f"{td}/{iso3.lower()}_adm0.shp")
            try:
                ds_clipped = prep_raster(ds, gdf)
            except Exception as e:
                logger.error(f"Error preparing raster for {iso3}: {e}")
                stack_trace = traceback.format_exc()
                insert_qa_table(iso3, None, dataset, e, stack_trace, engine)
                continue

            # Loop through each adm
            for adm_level in list(range(0, max_adm + 1)):
                try:
                    gdf = gpd.read_file(f"{td}/{iso3.lower()}_adm{adm_level}.shp")
                    logger.info(f"Computing stats for adm{adm_level}...")
                    fast_zonal_stats_runner(
                        ds_clipped,
                        gdf,
                        adm_level,
                        iso3,
                        save_to_database=False,
                        engine=engine,
                        dataset=dataset,
                    )
                except Exception as e:
                    logger.error(
                        f"Error calculating stats for {iso3} at {adm_level}: {e}"
                    )
                    stack_trace = traceback.format_exc()
                    insert_qa_table(iso3, adm_level, dataset, e, stack_trace, engine)
                    continue
    if len(date_ranges) > 1:
        num_processes = 5
        logger.info(
            f"Processing {len(date_ranges)} chunks with {num_processes} processes"
        )

        process_args = [
            (start, end, dataset, args.mode, df_iso3s, engine_url)
            for start, end in date_ranges
        ]

        with Pool(num_processes) as pool:
            pool.starmap(process_chunk, process_args)

    else:
        logger.info("Processing entire date range in a single chunk")
        process_chunk(start, end, dataset, args.mode, df_iso3s, engine_url)

    logger.info("Done calculating and saving stats.")
6 changes: 3 additions & 3 deletions src/config/era5.yml
@@ -1,8 +1,8 @@
blob_prefix: era5/monthly/processed/precip_reanalysis_v
start_date: 1981-01-01
end_date: 2024-07-30
end_date: 2024-10-30
forecast: False
test:
  start_date: 2020-01-01
  end_date: 2020-05-01
  iso3s: ["AFG", "BRA"]
  end_date: 2020-02-01
  iso3s: ["AFG"]
8 changes: 4 additions & 4 deletions src/config/imerg.yml
@@ -1,8 +1,8 @@
blob_prefix: imerg/daily/late/v7/processed/imerg-daily-late-
start_date: 2000-06-01
end_date: 2024-07-30
end_date: 2024-10-30
forecast: False
test:
  start_date: 2020-01-01
  end_date: 2020-01-15
  iso3s: ["AFG"]
  start_date: 2000-06-01
  end_date: 2002-01-15
  iso3s: ["AFG", "ETH", "HTI", "LBN"]
8 changes: 4 additions & 4 deletions src/config/seas5.yml
@@ -1,8 +1,8 @@
blob_prefix: seas5/monthly/processed/precip_em_i
start_date: 1981-01-01
end_date: 2024-07-30
end_date: 2024-10-30
forecast: True
test:
  start_date: 2020-01-01
  end_date: 2020-05-01
  iso3s: ["AFG"]
  start_date: 1981-01-01
  end_date: 2024-10-01
  iso3s: ["ETH", "AFG", "CMR"]
8 changes: 4 additions & 4 deletions src/config/settings.py
@@ -6,13 +6,13 @@
load_dotenv()


MAX_ADM = 2
LOG_LEVEL = "INFO"
AZURE_DB_PW = os.getenv("AZURE_DB_PW")
AZURE_DB_PW_DEV = os.getenv("AZURE_DB_PW_DEV")
AZURE_DB_PW_PROD = os.getenv("AZURE_DB_PW_PROD")
DATABASES = {
    "local": "sqlite:///chd-rasterstats-local.db",
    "dev": f"postgresql+psycopg2://chdadmin:{AZURE_DB_PW}@chd-rasterstats-dev.postgres.database.azure.com/postgres", # noqa
    "prod": f"postgresql+psycopg2://chdadmin:{AZURE_DB_PW}@chd-rasterstats-dev.postgres.database.azure.com/postgres", # noqa
    "dev": f"postgresql+psycopg2://chdadmin:{AZURE_DB_PW_DEV}@chd-rasterstats-dev.postgres.database.azure.com/postgres", # noqa
    "prod": f"postgresql+psycopg2://chdadmin:{AZURE_DB_PW_PROD}@chd-rasterstats-prod.postgres.database.azure.com/postgres", # noqa
}


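With this change the dev and prod passwords come from separate environment variables (matching the `.env` entries added to the README above), and the prod URL now points at the prod server instead of dev. A minimal sketch of how the mode-specific URL might be turned into an engine; note that `db_engine_url` actually lives in `src/utils/database_utils.py` and is not part of this excerpt, so its body here is an assumption:

```python
# Hedged sketch: shows how the DATABASES mapping could be consumed.
from sqlalchemy import create_engine

# importing settings runs load_dotenv(), which reads AZURE_DB_PW_DEV / AZURE_DB_PW_PROD
from src.config.settings import DATABASES


def db_engine_url(mode: str) -> str:
    # mode is "local", "dev" or "prod" (assumed to come from the CLI mode argument)
    return DATABASES[mode]


engine = create_engine(db_engine_url("dev"))
```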
6 changes: 5 additions & 1 deletion src/utils/cog_utils.py
@@ -174,7 +174,11 @@ def stack_cogs(start_date, end_date, dataset="era5", mode="dev"):
        raise Exception("No COGs found to process")

    das = []
    for cog in tqdm.tqdm(cogs_list):

    # Only show progress bar if running in interactive mode (ie. running locally)
    cogs_list = tqdm.tqdm(cogs_list) if mode == "local" else cogs_list

    for cog in cogs_list:
        if dataset == "era5":
            da_in = process_era5(cog, mode)
        elif dataset == "seas5":