Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/physics_checkpointer #351

Open
wants to merge 64 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
beefed4
fix variable mismatch
elynnwu Aug 30, 2022
7ad0486
reinitialize dycore to reset temp storages
elynnwu Sep 1, 2022
e873534
Merge branch 'main' into fix/checkpointer-validation
elynnwu Sep 1, 2022
8119251
Merge branch 'main' into fix/checkpointer-validation
elynnwu Sep 8, 2022
6788d77
fix var name
elynnwu Sep 9, 2022
70c044e
fvdynamics-out tracers in compute domain only
elynnwu Sep 9, 2022
deef15c
update data version
elynnwu Sep 9, 2022
2c0fbdc
revert comment
elynnwu Sep 12, 2022
f91c22e
fix driver subset
elynnwu Sep 12, 2022
534a887
clean up translate test
elynnwu Sep 13, 2022
ab5992f
Merge branch 'main' into fix/checkpointer-validation
elynnwu Sep 13, 2022
9154218
add comment
elynnwu Sep 14, 2022
5842e3e
Merge branch 'main' into current
elynnwu Sep 14, 2022
3a33976
Bump to 8.1.3
elynnwu Sep 15, 2022
cd8907b
add physics and driver savepoint tests
elynnwu Sep 16, 2022
36573ac
specify experiement name
elynnwu Sep 16, 2022
e0beac8
add python version to caching
elynnwu Sep 16, 2022
7dbb1b9
separate test, increase threshold for circleci
elynnwu Sep 16, 2022
ef9a475
update test names
elynnwu Sep 16, 2022
889e19f
update config
elynnwu Sep 16, 2022
04cc70a
add first attempt at unified savepoints target
mcgibbon Sep 16, 2022
52d89ad
Merge branch 'feature/add-circleci-physics-driver-test' of github.com…
mcgibbon Sep 16, 2022
89df5ce
attempt to put most of the test logic in a command
mcgibbon Sep 16, 2022
e36a10a
fix config to validate
mcgibbon Sep 16, 2022
343a1c2
fix test_driver, lint, and dycore_savepoints_mpi
mcgibbon Sep 16, 2022
74d0e1d
add remaining circleci plans
mcgibbon Sep 16, 2022
1941476
activate venv correctly in driver test
mcgibbon Sep 16, 2022
0ead115
use independent keys for gt caching
mcgibbon Sep 16, 2022
c6011ca
downgrade serial savepoint tests to medium resource type
mcgibbon Sep 16, 2022
3e9d080
add h5netcdf to requirement
elynnwu Sep 19, 2022
42cfc9b
specify xarray engine
elynnwu Sep 19, 2022
d877737
fix bug in checkpointer test, add more savepoints, update thresholds
mcgibbon Sep 19, 2022
d4949e4
Merge branch 'feature/add-circleci-physics-driver-test' into feature/…
mcgibbon Sep 19, 2022
2d207f0
add circleci execution of checkpoint tests, label as savepoint tests,…
mcgibbon Sep 19, 2022
859d2a5
fix config validation
mcgibbon Sep 19, 2022
e743e62
define google application credentials env var for top level savepoint…
mcgibbon Sep 19, 2022
915e6e3
use machine worker for test_savepoints
mcgibbon Sep 19, 2022
516f60f
Merge branch 'main' into feature/checkpointer_circleci
mcgibbon Sep 19, 2022
c79eec4
run the 54rank test on an xlarge worker
mcgibbon Sep 19, 2022
f96cb35
delete duplicated test definitions
mcgibbon Sep 19, 2022
e0e6e79
fix path for savepoint tests
mcgibbon Sep 19, 2022
6f88f0b
move directory to match moved path in last commit
mcgibbon Sep 19, 2022
4becd03
remove driver 54rank compiled test from CircleCI
mcgibbon Sep 20, 2022
8589707
increase driver_savepoints_mpi resource type to xlarge
mcgibbon Sep 20, 2022
19fbcf1
Merge branch 'main' into feature/checkpointer_circleci
mcgibbon Sep 20, 2022
fcecc1a
checkpointer test works with remapping and tracer advection
mcgibbon Sep 21, 2022
92bb68e
Merge branch 'main' into feature/checkpointer_circleci
mcgibbon Sep 23, 2022
bdda2bf
Merge branch 'main' into feature/checkpointer_circleci
elynnwu Sep 26, 2022
5c59895
physics checkpointer init
elynnwu Sep 28, 2022
5ea06df
update phys and driver out
elynnwu Sep 28, 2022
538d943
physics checkpoint working for 1 step
elynnwu Sep 30, 2022
ddea611
physics checkpoint working
elynnwu Oct 4, 2022
19b9944
2 step driver threshold
elynnwu Oct 4, 2022
61a24c5
Merge branch 'main' into feature/physics_checkpointer
elynnwu Oct 4, 2022
480a16a
delete unused checkpoint
elynnwu Oct 4, 2022
90c26d9
lint
elynnwu Oct 5, 2022
b394f02
Merge branch 'feature/physics_checkpointer'
elynnwu Oct 6, 2022
4c85dea
resolve circular dependency
elynnwu Oct 7, 2022
62d48b5
Merge branch 'feature/physics_checkpointer' of github.com:ai2cm/pace …
elynnwu Oct 7, 2022
9fdd382
lint on the right python version
elynnwu Oct 7, 2022
3a379ba
Merge branch 'main' into feature/physics_checkpointer
elynnwu Oct 10, 2022
87f5115
remove unchecked variable
elynnwu Oct 10, 2022
7855319
Merge branch 'main'
elynnwu Nov 17, 2022
7832131
Merge branch 'main' into feature/physics_checkpointer
mcgibbon Dec 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion dsl/pace/dsl/dace/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List
from typing import Dict, List, Tuple

import dace
from dace.transformation.helpers import get_parent_map
Expand Down Expand Up @@ -37,6 +37,100 @@ def __exit__(self, _type, _val, _traceback):
DaCeProgress.log(self.prefix, f"{self.label}...{elapsed}s.")


def sdfg_nan_checker(sdfg: dace.SDFG):
    """
    Instrument an SDFG so that every top-level map output is scanned for NaN.

    For each access node written directly by the exit of a map that has no
    parent map, a "nancheck" mapped tasklet is spliced in right behind it.
    The tasklet iterates over the range of the writing memlet and prints a
    message through printf for every element that is not equal to itself.

    Outputs that are views, outputs whose data name contains "diss_estd" and
    outputs written through a dynamic memlet are left untouched.

    In the current pipeline this is meant to be inserted after
    sdfg.simplify(...).

    Args:
        sdfg: the SDFG to instrument; it is modified in place.
    """
    import copy

    import sympy as sp
    from dace import data as dt
    from dace import symbolic
    from dace.sdfg import graph as gr
    from dace.sdfg import utils as sdutil

    def _resolve(expr):
        # Swap the symbolic "int_floor" placeholder for dace's implementation
        return expr.subs({sp.Function("int_floor"): symbolic.int_floor})

    # Pass 1: collect every (state, access node, edge) to instrument.
    # Nothing is modified here, so the graph can be traversed safely.
    checks: List[
        Tuple[dace.SDFGState, dace.nodes.AccessNode, gr.MultiConnectorEdge[dace.Memlet]]
    ] = []
    top_level_maps = [
        (entry, owner)
        for entry, owner in sdfg.all_nodes_recursive()
        if isinstance(entry, dace.nodes.MapEntry)
        and get_parent_map(owner, entry) is None
    ]
    for entry, owner in top_level_maps:
        for edge in owner.out_edges(owner.exit_node(entry)):
            if not isinstance(edge.dst, dace.nodes.AccessNode):
                continue
            if isinstance(edge.dst.desc(owner.parent), dt.View):
                # Skip views for now
                continue
            target = sdutil.get_last_view_node(owner, edge.dst)
            if "diss_estd" in target.data:
                continue
            if owner.memlet_path(edge)[0].data.dynamic:
                # Skip dynamic (region) outputs
                continue
            checks.append((owner, target, edge))

    # Pass 2: splice a checking map between each collected node and a copy
    # of it that inherits the original consumers.
    for owner, target, edge in checks:
        checked: dace.nodes.AccessNode = copy.deepcopy(target)
        for out_edge in list(owner.out_edges(target)):
            owner.remove_edge(out_edge)
            owner.add_edge(
                checked,
                out_edge.src_conn,
                out_edge.dst,
                out_edge.dst_conn,
                out_edge.data,
            )

        array_desc = owner.parent.arrays[checked.data]
        ndim = len(array_desc.shape)
        index_expr = ", ".join(f"__i{dim}" for dim in range(ndim))
        index_printf = ", ".join(["%d"] * ndim)

        # Infer the schedule from where the checked array is stored
        if array_desc.storage in (
            dace.StorageType.GPU_Global,
            dace.StorageType.GPU_Shared,
        ):
            schedule_type = dace.ScheduleType.GPU_Device
        else:
            schedule_type = dace.ScheduleType.Default

        # Take the range from the memlet (which may not span the whole
        # array); evaluating resolves views & actively read/write domains
        ranges = [
            (f"__i{dim}", (_resolve(begin), _resolve(end), _resolve(step)))
            for dim, (begin, end, step) in enumerate(edge.data.subset)
        ]

        # The tasklet text below is runtime content handed to the C++
        # backend and is kept verbatim.
        owner.add_mapped_tasklet(
            name="nancheck",
            map_ranges=ranges,
            inputs={"__inp": dace.Memlet.simple(checked.data, index_expr)},
            code=f"""
if (__inp != __inp) {{
printf("NaN value found at {checked.data}, line %d, index {index_printf}\\n", __LINE__, {index_expr});
}}
""",  # noqa: E501
            schedule=schedule_type,
            language=dace.Language.CPP,
            outputs={
                "__out": dace.Memlet.simple(checked.data, index_expr, num_accesses=-1)
            },
            input_nodes={target.data: target},
            output_nodes={checked.data: checked},
            external_edges=True,
        )
    logger.info(f"Added {len(checks)} NaN checks")


def _is_ref(sd: dace.sdfg.SDFG, aname: str):
found = False
for node, state in sd.all_nodes_recursive():
Expand Down
11 changes: 9 additions & 2 deletions fv3core/pace/fv3core/stencils/dyn_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,11 @@ def _get_da_min(self) -> float:

def _checkpoint_csw(self, state: DycoreState, tag: str):
if self.call_checkpointer:
if tag == "In":
# need this to compare with Fortran
self._ut[:] = 0.0
self._vt[:] = 0.0
self._divgd.data[:] = 0.0
self.checkpointer(
f"C_SW-{tag}",
delpd=state.delp,
Expand All @@ -625,6 +630,8 @@ def _checkpoint_csw(self, state: DycoreState, tag: str):

def _checkpoint_dsw_in(self, state: DycoreState):
if self.call_checkpointer:
self._xfx[:] = 0.0
self._yfx[:] = 0.0
self.checkpointer(
"D_SW-In",
ucd=state.uc,
Expand All @@ -638,8 +645,8 @@ def _checkpoint_dsw_in(self, state: DycoreState):
ptd=state.pt,
uad=state.ua,
vad=state.va,
zhd=self._zh,
divgdd=self._divgd,
# zhd=self._zh,
# divgdd=self._divgd,
xfxd=self._xfx,
yfxd=self._yfx,
mfxd=state.mfxd,
Expand Down
53 changes: 53 additions & 0 deletions fv3core/pace/fv3core/stencils/remapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,36 @@ def __call__(
mdt (in) : Remap time step
bdt (in): Timestep
"""

if self._call_checkpointer:
self._checkpointer(
"Remapping-In",
pt=pt,
delp=delp,
delz=delz,
peln=peln.transpose(
[X_DIM, Z_INTERFACE_DIM, Y_DIM]
), # [x, z, y] fortran data
u=u,
v=v,
w=w,
ua=ua,
va=va,
cappa=cappa,
pkz=pkz,
pk=pk,
pe=pe.transpose(
[X_DIM, Z_INTERFACE_DIM, Y_DIM]
), # [x, z, y] fortran data
phis=hs,
te_2d=te0_2d,
ps=ps,
wsd=wsd,
omga=omga,
ak=ak,
bk=bk,
dp1=dp1,
)
# TODO: remove unused arguments (and commented code that references them)
# TODO: can we trim ps or make it a temporary
# TODO: pe is copied into pe1 and pe2 for vectorization reasons in the Fortran,
Expand Down Expand Up @@ -693,3 +723,26 @@ def __call__(
else:
# converts virtual temperature back to virtual potential temperature
self._basic_adjust_divide_stencil(pkz, pt)

if self._call_checkpointer:
self._checkpointer(
"Remapping-Out",
pt=pt,
delp=delp,
delz=delz,
peln=peln.transpose(
[X_DIM, Z_INTERFACE_DIM, Y_DIM]
), # [x, z, y] fortran data
u=u,
v=v,
w=w,
cappa=cappa,
pkz=pkz,
pk=pk,
pe=pe.transpose(
[X_DIM, Z_INTERFACE_DIM, Y_DIM]
), # [x, z, y] fortran data
te_2d=te0_2d,
omga=omga,
dp1=dp1,
)
8 changes: 7 additions & 1 deletion fv3core/pace/fv3core/stencils/tracer_2d_1l.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Dict
from typing import Dict, Optional

import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region
Expand Down Expand Up @@ -189,6 +189,10 @@ def __init__(
config=stencil_factory.config.dace_config,
dace_compiletime_args=["tracers"],
)
self._checkpointer = checkpointer
# this is only computed in init because Dace does not yet support
# this operation
self._call_checkpointer = checkpointer is not None
grid_indexing = stencil_factory.grid_indexing
self.grid_indexing = grid_indexing # needed for selective validation
self._tracer_count = len(tracers)
Expand Down Expand Up @@ -285,6 +289,7 @@ def __call__(
x_courant (inout): accumulated courant number in x-direction
y_courant (inout): accumulated courant number in y-direction
"""
self._checkpoint_input(tracers, dp1, mfxd, mfyd, cxd, cyd)
# DaCe parsing issue
# if len(tracers) != self._tracer_count:
# raise ValueError(
Expand Down Expand Up @@ -390,3 +395,4 @@ def __call__(
# we can't use variable assignment to avoid a data copy
# because of current dace limitations
self._swap_dp(dp1, dp2)
self._checkpoint_output(tracers, dp1, mfxd, mfyd, cxd, cyd)
43 changes: 41 additions & 2 deletions physics/pace/physics/stencils/physics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from typing import List

import gt4py.cartesian.gtscript as gtscript
Expand All @@ -12,6 +13,7 @@
)
from typing_extensions import Literal

import pace.dsl.gt4py_utils as utils
import pace.util
import pace.util.constants as constants
from pace.dsl.dace.orchestration import orchestrate
Expand All @@ -21,6 +23,7 @@
from pace.physics.stencils.get_phi_fv3 import get_phi_fv3
from pace.physics.stencils.get_prs_fv3 import get_prs_fv3
from pace.physics.stencils.microphysics import Microphysics
from pace.stencils.testing.translate import reshape_pace_variable_to_fortran_format
from pace.util import X_DIM, Y_DIM, Z_DIM
from pace.util.grid import GridData

Expand Down Expand Up @@ -209,13 +212,18 @@ def __init__(
grid_data: GridData,
namelist: PhysicsConfig,
active_packages: List[Literal[PHYSICS_PACKAGES]],
checkpointer: typing.Optional[pace.util.Checkpointer] = None,
):
self._checkpointer = checkpointer
# this is only computed in init because Dace does not yet support
# this operation
self._call_checkpointer = checkpointer is not None
orchestrate(
obj=self,
config=stencil_factory.config.dace_config,
dace_compiletime_args=["physics_state"],
)

self.grid_indexing = stencil_factory.grid_indexing
grid_indexing = stencil_factory.grid_indexing
self._setup_statein()
self._ptop = grid_data.ptop
Expand Down Expand Up @@ -276,7 +284,31 @@ def _setup_statein(self):
self._p00 = 1.0e5

def __call__(self, physics_state: PhysicsState, timestep: float):

if self._call_checkpointer:
self._checkpointer(
"GFSPhysicsDriver-In",
qvapor=physics_state.qvapor,
qliquid=physics_state.qliquid,
qrain=physics_state.qrain,
qsnow=physics_state.qsnow,
qice=physics_state.qice,
qgraupel=physics_state.qgraupel,
qo3mr=physics_state.qo3mr,
qsgs_tke=physics_state.qsgs_tke,
qcld=physics_state.qcld,
pt=physics_state.pt,
delp=physics_state.delp,
delz=physics_state.delz,
ua=physics_state.ua,
va=physics_state.va,
w=physics_state.w,
omga=physics_state.omga,
)
if self._call_checkpointer:
self._checkpointer(
"AtmosPhysDriverStatein-In",
pt=physics_state.pt,
)
self._atmos_phys_driver_statein(
self._prsik,
physics_state.phii,
Expand All @@ -295,6 +327,13 @@ def __call__(self, physics_state: PhysicsState, timestep: float):
physics_state.pt,
self._dm3d,
)
if self._call_checkpointer:
self._checkpointer(
"AtmosPhysDriverStatein-Out",
IPD_tgrs=reshape_pace_variable_to_fortran_format(
physics_state.pt, self.grid_indexing
),
)
self._get_prs_fv3(
physics_state.phii,
physics_state.prsi,
Expand Down
34 changes: 33 additions & 1 deletion stencils/pace/stencils/fv_update_phys.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import FORWARD, PARALLEL, computation, exp, interval, log

Expand All @@ -10,7 +12,7 @@
from pace.dsl.typing import Float, FloatField, FloatFieldIJ
from pace.stencils.c2l_ord import CubedToLatLon
from pace.stencils.update_dwind_phys import AGrid2DGridPhysics
from pace.util import X_DIM, Y_DIM
from pace.util import X_DIM, Y_DIM, Z_INTERFACE_DIM
from pace.util.grid import DriverGridData, GridData


Expand Down Expand Up @@ -92,7 +94,12 @@ def __init__(
state: fv3core.DycoreState,
u_dt: pace.util.Quantity,
v_dt: pace.util.Quantity,
checkpointer: typing.Optional[pace.util.Checkpointer] = None,
):
self._checkpointer = checkpointer
# this is only computed in init because Dace does not yet support
# this operation
self._call_checkpointer = checkpointer is not None
orchestrate(
obj=self,
config=stencil_factory.config.dace_config,
Expand Down Expand Up @@ -154,6 +161,31 @@ def __call__(
t_dt,
dt: float,
):
if self._call_checkpointer:
self._checkpointer(
"FVUpdatePhys-In",
u_dt=u_dt,
v_dt=v_dt,
t_dt=t_dt,
ua=state.ua,
va=state.va,
u=state.u,
v=state.v,
qvapor=state.qvapor,
qliquid=state.qliquid,
qice=state.qice,
qrain=state.qrain,
qsnow=state.qsnow,
qgraupel=state.qgraupel,
peln=state.peln.transpose(
[X_DIM, Z_INTERFACE_DIM, Y_DIM]
), # [x, z, y] fortran data,
delp=state.delp,
pt=state.pt,
ps=state.ps,
pe=state.pe.transpose([X_DIM, Z_INTERFACE_DIM, Y_DIM]),
pk=state.pk,
)
self._moist_cv(
state.qvapor,
state.qliquid,
Expand Down
Loading