Skip to content

Commit

Permalink
Merge pull request #544 from matthewhoffman/landice/esmf_interp_fix
Browse files Browse the repository at this point in the history
Fix bugs in ESMF interpolation method

This merge fixes a number of issues with the 'esmf' method for interpolation in the MALI interpolation script:
* indexing was off by one, which made interpolation with unstructured source meshes garbage. For structured source meshes, it made interpolation shifted by one grid cell.
* support MPAS source fields with a vertical dimension when using the 'esmf' method
* refactor to use sparse matrix multiply, which speeds up interpolation a few hundred times
* add destination mesh area normalization support. This is necessary when using the ESMF 'conserve' method for destination cells that are only partly overlapped by the source cells.

Two issues remain for interpolating between two MPAS meshes with the 'esmf' method:

1. If the destination mesh is larger than the source mesh, those locations are filled with zeros. The ESMF 'conserve' method does not support extrapolation, and there is no obvious solution to this issue and it would need to be handled manually on a case by case basis.
2. Some fields are only defined on subdomains (e.g. temperature is only defined where ice thickness is nonzero) and the script currently has no mechanism for masking them. This results in garbage values getting interpolated in, e.g. temperature values near the margin will have values around 100 K because of interpolating realistic values around 250K with garbage values of 0K. This issue applies to the barycentric method as well. However, this situation can be worked around by performing extrapolation of the temperature field on the source mesh before doing interpolation between meshes.

This PR also includes a refactoring of create_SCRIP_file_from_planar_rectangular_grid.py that speeds it up by orders of magnitude for large meshes, as well as a minor update to define_cullMask.py
  • Loading branch information
matthewhoffman authored Jan 19, 2024
2 parents bcedec2 + b0e6147 commit 6c89e44
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
4 changes: 2 additions & 2 deletions landice/mesh_tools_li/define_cullMask.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
if keepCellMask[n] == 1:
keepCellMaskNew[iCell] = 1
keepCellMask = np.copy(keepCellMaskNew) # after we've looped over all cells assign the new mask to the variable we need (either for another loop around the domain or to write out)
print(' Num of cells to keep: {}'.format(sum(keepCellMask)))
print(f'Num of cells to keep: {keepCellMask.sum()}')

# Now convert the keepCellMask to the cullMask
cullCell[:] = np.absolute(keepCellMask[:]-1) # Flip the mask for which ones to cull
Expand Down Expand Up @@ -148,7 +148,7 @@
ind = np.nonzero(((xCell-xCell[iCell])**2 + (yCell-yCell[iCell])**2)**0.5 < dist)[0]
keepCellMask[ind] = 1

print(' Num of cells to keep:'.format(sum(keepCellMask)))
print(f'Num of cells to keep: {keepCellMask.sum()}')

# Now convert the keepCellMask to the cullMask
cullCell[:] = np.absolute(keepCellMask[:]-1) # Flip the mask for which ones to cull
Expand Down
33 changes: 24 additions & 9 deletions landice/mesh_tools_li/interpolate_to_mpasli_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import math
from collections import OrderedDict
import scipy.spatial
import scipy.sparse
import time
from datetime import datetime

Expand Down Expand Up @@ -64,9 +65,15 @@
S = wfile.variables['S'][:]
col = wfile.variables['col'][:]
row = wfile.variables['row'][:]
n_a = len(wfile.dimensions['n_a'])
n_b = len(wfile.dimensions['n_b'])
dst_frac = wfile.variables['frac_b'][:]
wfile.close()
#----------------------------

# convert to SciPy Compressed Sparse Row (CSR) matrix format
weights_csr = scipy.sparse.coo_array((S, (row - 1, col - 1)), shape=(n_b, n_a)).tocsr()

print('') # make a space in stdout before further output


Expand All @@ -78,15 +85,20 @@
#----------------------------

def ESMF_interp(sourceField):
# Interpolates from the sourceField to the destinationField using ESMF weights
# Interpolates from the sourceField to the destinationField using ESMF weights
destinationField = np.zeros(xCell.shape) # fields on cells only
try:
# Initialize new field to 0 - required
destinationField = np.zeros(xCell.shape) # fields on cells only
sourceFieldFlat = sourceField.flatten() # Flatten source field
for i in range(len(row)):
destinationField[row[i]-1] = destinationField[row[i]-1] + S[i] * sourceFieldFlat[col[i]]
# Convert the source field into the SciPy Compressed Sparse Row matrix format
# This needs some reshaping to get the matching dimensions
source_csr = scipy.sparse.csr_matrix(sourceField.flatten()[:, np.newaxis])
# Use SciPy CSR dot product - much faster than iterating over elements of the full matrix
destinationField = weights_csr.dot(source_csr).toarray().squeeze()
# For conserve remapping, need to normalize by destination area fraction
# It should be safe to do this for other methods
ind = np.where(dst_frac > 0.0)[0]
destinationField[ind] /= dst_frac[ind]
except:
'error in ESMF_interp'
print('error in ESMF_interp')
return destinationField

#----------------------------
Expand Down Expand Up @@ -328,7 +340,7 @@ def interpolate_field_with_layers(MPASfieldName):
if filetype=='cism':
print(' Input layer {}, layer {} min/max: {} {}'.format(z, InputFieldName, InputField[z,:,:].min(), InputField[z,:,:].max()))
elif filetype=='mpas':
print(' Input layer {}, layer {} min/max: {} {}'.format(z, InputFieldName, InputField[:,z].min(), InputField[z,:].max()))
print(' Input layer {}, layer {} min/max: {} {}'.format(z, InputFieldName, InputField[:,z].min(), InputField[:,z].max()))
# Call the appropriate routine for actually doing the interpolation
if args.interpType == 'b':
print(" ...Layer {}, Interpolating this layer to MPAS grid using built-in bilinear method...".format(z))
Expand All @@ -349,7 +361,10 @@ def interpolate_field_with_layers(MPASfieldName):
mpas_grid_input_layers[z,:] = InputField[:,z].flatten()[nn_idx_cell] # 2d cism fields need to be flattened. (Note the indices were flattened during init, so this just matches that operation for the field data itself.) 1d mpas fields do not, but the operation won't do anything because they are already flat.
elif args.interpType == 'e':
print(" ...Layer{}, Interpolating this layer to MPAS grid using ESMF-weights method...".format(z))
mpas_grid_input_layers[z,:] = ESMF_interp(InputField[z,:,:])
if filetype=='cism':
mpas_grid_input_layers[z,:] = ESMF_interp(InputField[z,:,:])
elif filetype=='mpas':
mpas_grid_input_layers[z,:] = ESMF_interp(InputField[:,z])
else:
sys.exit('ERROR: Unknown interpolation method specified')
print(' interpolated MPAS {}, layer {} min/max {} {}: '.format(MPASfieldName, z, mpas_grid_input_layers[z,:].min(), mpas_grid_input_layers[z,:].max()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,19 @@
print ('Filling in corners of each cell.')
grid_corner_lon_local = np.zeros( (nx * ny, 4) ) # It is WAYYY faster to fill in the array entry-by-entry in memory than to disk.
grid_corner_lat_local = np.zeros( (nx * ny, 4) )
for j in range(ny):
for i in range(nx):
iCell = j*nx + i

grid_corner_lon_local[iCell, 0] = stag_lon[j, i]
grid_corner_lon_local[iCell, 1] = stag_lon[j, i+1]
grid_corner_lon_local[iCell, 2] = stag_lon[j+1, i+1]
grid_corner_lon_local[iCell, 3] = stag_lon[j+1, i]
grid_corner_lat_local[iCell, 0] = stag_lat[j, i]
grid_corner_lat_local[iCell, 1] = stag_lat[j, i+1]
grid_corner_lat_local[iCell, 2] = stag_lat[j+1, i+1]
grid_corner_lat_local[iCell, 3] = stag_lat[j+1, i]

jj = np.arange(ny)
ii = np.arange(nx)
i_ind, j_ind = np.meshgrid(ii, jj)
cell_ind = j_ind * nx + i_ind
grid_corner_lon_local[cell_ind, 0] = stag_lon[j_ind, i_ind]
grid_corner_lon_local[cell_ind, 1] = stag_lon[j_ind, i_ind + 1]
grid_corner_lon_local[cell_ind, 2] = stag_lon[j_ind + 1, i_ind + 1]
grid_corner_lon_local[cell_ind, 3] = stag_lon[j_ind + 1, i_ind]
grid_corner_lat_local[cell_ind, 0] = stag_lat[j_ind, i_ind]
grid_corner_lat_local[cell_ind, 1] = stag_lat[j_ind, i_ind + 1]
grid_corner_lat_local[cell_ind, 2] = stag_lat[j_ind + 1, i_ind + 1]
grid_corner_lat_local[cell_ind, 3] = stag_lat[j_ind + 1, i_ind]

grid_corner_lon[:] = grid_corner_lon_local[:]
grid_corner_lat[:] = grid_corner_lat_local[:]
Expand Down Expand Up @@ -171,3 +172,4 @@

fin.close()
fout.close()
print('scrip file generation complete')

0 comments on commit 6c89e44

Please sign in to comment.