-
Notifications
You must be signed in to change notification settings - Fork 12
/
interpolation.py
86 lines (55 loc) · 2.52 KB
/
interpolation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
from scipy import interpolate
def all_invalid(array, tol=5e-2):
    """ Checks whether an array holds too many invalid (NaN or +/-inf) values.

    Returns True when the number of invalid entries is at least `tol` times the
    total number of entries, i.e. when the ratio of invalid values is >= tol.
    Despite the name, the array does NOT have to be entirely invalid: with the
    default tol=5e-2, 5% of invalid values is already enough. An empty array
    returns True.

    :param array (numpy.array): array of any shape to inspect
    :param tol (float): tolerance ratio of invalid values, expected in [0, 1]
    :return: numpy bool, True if the ratio of invalid values reaches `tol`
    """
    # np.ma.masked_invalid may leave the mask as the scalar `nomask` when nothing
    # is invalid; getmaskarray always yields a full boolean array of the same
    # shape, so its size is the total number of entries whatever the rank.
    invalid = np.ma.getmaskarray(np.ma.masked_invalid(array))
    return np.sum(invalid) >= invalid.size * tol
def contain_invalid(masked_array):
    """Return whether at least one entry of `masked_array` is masked.

    Masked entries stand for invalid pixels (e.g. NaNs) of the swath. A scalar
    `nomask` mask is treated as "nothing masked".
    """
    channel_mask = masked_array.mask
    return np.any(channel_mask)
# ------------------------------------------------------------------------------ INTERPOLATION
def fill_channel(masked_array, xx, yy, method="nearest", fill_value=True):
    """
    Fill the masked entries of a single 2-D channel by spatial interpolation.

    The unmasked samples are used as scattered data points for
    scipy.interpolate.griddata, which is then evaluated on the full grid
    (xx, yy). The input is not modified; a new interpolated grid is returned.

    :param masked_array (numpy.ma.MaskedArray): channel of size (height, width); masked entries are the ones to fill
    :param xx (numpy.array): x coordinate of every pixel, same shape as masked_array
        (fill_all_channels passes the column indices built with np.meshgrid)
    :param yy (numpy.array): y coordinate of every pixel, same shape as masked_array
        (fill_all_channels passes the row indices built with np.meshgrid)
    :param method (string): method for the interpolation. Check scipy.interpolate.griddata for possible methods
    :param fill_value: value griddata assigns to grid points outside the convex hull of the valid samples
        (has no effect for "nearest"). The default True is kept for backward compatibility and ends up as 1.0.
    :return: numpy.array of size (height, width) with the interpolated values
    """
    # getmaskarray always returns a full boolean array. With the scalar `nomask`
    # mask (nothing invalid), indexing with `~mask` would add an axis instead of
    # selecting pixels.
    valid = ~np.ma.getmaskarray(masked_array)
    points = (xx[valid], yy[valid])
    # compressed() returns the unmasked values as a plain 1-D array, in the same
    # row-major order as the boolean selection of the coordinates above.
    values = masked_array.compressed()
    return interpolate.griddata(points, values, (xx, yy), method=method, fill_value=fill_value)
def fill_all_channels(swath, method="nearest"):
    """
    Inplace function: it fills all invalid (NaN/inf) values of `swath` by spatial
    interpolation, one channel at a time.

    :param swath (numpy.array): array of size (nb_channels, height, width); modified in place
    :param method (string): method for the interpolation. Check scipy.interpolate.griddata for possible methods
    :return: list of indices of the channels that have been filled or were already full.
        Channels whose interpolation raised an exception are left untouched and are not listed.
    """
    height, width = swath.shape[1], swath.shape[2]
    # Pixel-coordinate grids shared by every channel: xx holds column indices,
    # yy holds row indices.
    xx, yy = np.meshgrid(np.arange(0, width), np.arange(0, height))
    full_channels = []
    for i, ch_array in enumerate(swath):
        masked_array = np.ma.masked_invalid(ch_array)
        if not contain_invalid(masked_array):
            # Nothing to fill.
            full_channels.append(i)
            continue
        try:
            swath[i] = fill_channel(masked_array, xx, yy, method)
            full_channels.append(i)
        except Exception:
            # Best effort: a channel that cannot be interpolated (presumably e.g.
            # one with no valid pixel at all) is skipped instead of aborting the
            # whole swath. Catching Exception instead of a bare `except:` lets
            # KeyboardInterrupt / SystemExit propagate.
            continue
    return full_channels
if __name__ == "__main__":
    # Quick self-checks. np.nan is used rather than the np.NaN alias, which
    # NumPy 2.0 removed.
    all_inv_array = np.full((3, 4, 10), np.nan)
    assert all_invalid(all_inv_array)

    # One fully invalid channel out of three is 1/3 invalid: at or above the
    # default 5% tolerance of all_invalid, but below a 50% tolerance.
    partial_inv_array = np.zeros((3, 7, 9))
    partial_inv_array[0] = np.nan
    assert all_invalid(partial_inv_array)
    assert not all_invalid(partial_inv_array, tol=0.5)

    # 5 invalid values out of 189 (< 5%) are tolerated.
    few_inv_array = np.zeros((3, 7, 9))
    few_inv_array[0, 0, :5] = np.nan
    assert not all_invalid(few_inv_array)

    # Filling a swath in place: the NaN pixel gets interpolated and both
    # channels are reported as full.
    swath = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
    swath[1, 0, 0] = np.nan
    assert fill_all_channels(swath) == [0, 1]
    assert not np.isnan(swath).any()