-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.py
1752 lines (1481 loc) · 71.6 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Module-wide imports and plotting defaults for the NEMO-BAMHBI helper library.
#Packages needed
#GENERAL
import glob #for getting filepaths
from types import SimpleNamespace
# Suppress all runtime warnings
# NOTE(review): this silences *every* warning module-wide (division warnings,
# deprecations, ...). Consider narrowing to specific categories.
import warnings
warnings.filterwarnings("ignore")
import pdb
def debug(): pdb.set_trace() #shortcut: drop into the debugger from anywhere
import os
#GRIDDED DATA HANDLING
import xarray as xr
xr.set_options(keep_attrs=True) #VERY IMPORTANT HERE, BUT BEWARE OF CONSEQUENCES WHEN USING UNITS!
import numpy as np
#GENERAL PLOTTING
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mticker
import matplotlib.dates as md
import seaborn as sns #for the colorpalette
#CARTOPY
import cartopy.crs as ccrs
projection=ccrs.PlateCarree()  # module-level default map projection
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
#MISCELLANEA
# Widen Jupyter notebook cells to the full browser width (no effect outside notebooks)
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
"""
#enable smooth scrolling (only works in jupyter notebook)
%%javascript
Jupyter.keyboard_manager.command_shortcuts.remove_shortcut("up");
Jupyter.keyboard_manager.command_shortcuts.remove_shortcut("down");
"""
#set reasonable matplotlib fontsize settings from the start
mpl.rcParams.update({
'font.size': 12, # Controls default text sizes
'axes.labelsize': 14, # Fontsize of the x and y labels
'axes.titlesize': 16, # Fontsize of the axes title
'xtick.labelsize': 12, # Fontsize of the tick labels
'ytick.labelsize': 12, # Fontsize of the tick labels
'xtick.minor.width': 1.0,# Thickness of the minor x-axis ticks
'ytick.minor.width': 1.0,# Thickness of the minor y-axis ticks
'xtick.major.size': 8, # Length of the major x-axis ticks in points
'ytick.major.size': 8, # Length of the major y-axis ticks in points
'xtick.minor.size': 4, # Length of the minor x-axis ticks in points
'ytick.minor.size': 4, # Length of the minor y-axis ticks in points
'axes.grid': True, # Show grid or not
'grid.color': 'gray', # Color of the grid lines
'grid.linestyle': '--', # Style of the grid lines
'grid.alpha': 0.5, # Transparency of the grid lines
'legend.fontsize': 12, # Fontsize of the legend
'legend.title_fontsize': 14,# Fontsize of the legend title
'legend.frameon': True, # Draw a frame around the legend
'legend.edgecolor': 'black',# Color of the legend edge
'legend.facecolor': 'white',# Background color of the legend
'figure.figsize': (8, 5), # Figure size in inches
'figure.dpi': 300 # Dots per inch
})
#for resetting values
#mpl.rcParams.update(matplotlib.rcParamsDefault)
"""
PREPROCESSING
"""
def convert_xy(ds,y='nav_lat',x='nav_lon',grid=None):
    """
    Convert latitudes and longitudes from NEMO's nested 2D form to plain 1D coordinates.

    The raw model output carries 2D nav_lat/nav_lon fields, which are
    inconvenient for xarray's .sel/.plot machinery; this collapses them to
    unique 1D 'lat'/'lon' index coordinates.

    Input:
        ds: xarray Dataset (any shape)
        y: name of the latitude coordinate variable
        x: name of the longitude coordinate variable
        grid: optional grid suffix (e.g. 'grid_T') used to locate the x/y dimension names
    Output:
        ds: Dataset with simple 1D 'lat'/'lon' index coordinates.
    """
    latitudes = ds[y].values
    longitudes = ds[x].values
    # Flatten, deduplicate and drop the -1 fill values that sometimes appear in model data.
    # NOTE(review): this assumes a regular grid where each row/column shares a single
    # unique latitude/longitude — verify before applying to curvilinear grids.
    flat_latitudes = np.setdiff1d(np.unique(latitudes.flatten()),[-1])
    flat_longitudes = np.setdiff1d(np.unique(longitudes.flatten()),[-1])
    # replace the 2D coordinate variables by their 1D equivalents ...
    ds[y]=('lat',flat_latitudes)
    ds[x]=('lon',flat_longitudes)
    # ... and rename the coordinate variables themselves
    ds=ds.rename({x:'lon',y:'lat'})
    # the underlying dimensions are called 'x'/'y', optionally with a grid suffix
    if grid is None:
        x='x';y='y'
    else:
        x='x_'+grid; y='y_'+grid;
    ds=ds.rename({x:'lon',y:'lat'})
    #rename 'inner' grid variables (dimensions of variables defined on inner grid points)
    for a in [x for x in list(ds.dims) if 'inner' in x]:
        if 'y' in a:
            ds=ds.rename({a:'lat'})
        elif 'x' in a:
            ds=ds.rename({a:'lon'})
    # promote lat/lon to proper indexes so that .sel()/.loc work
    ds=ds.set_index({'lat':'lat','lon':'lon'})
    return ds
def mask_data(ds,ls_mask):
    """
    Apply a land-sea mask: every position where ls_mask equals 0 becomes NaN.

    The mask's lat/lon coordinate values are copied onto the data first,
    because fields on the U/V grids can carry coordinates that differ
    slightly from the mask's and would otherwise fail to align.
    """
    for coord in ('lat', 'lon'):
        ds[coord] = ls_mask[coord].values
    return xr.where(ls_mask, ds, np.nan)
"""
FUNCTIONS TO LOAD NEMO-BAMHBI SPECIFIC DATA
"""
def load_domain(
        domain_path=None,       # file path to the domain file of the model run
        y='nav_lat',            # name of the y (latitude) coordinate variable
        x='nav_lon',            # name of the x (longitude) coordinate variable
        depth_var='deptht',     # target name for the depth variable (rename later for U/V/W grids if needed)
        top_level='top_level',  # name of the land-sea mask variable in the domain file
):
    """
    Preprocessing wrapper for the domain file.

    1. Load
    2. Convert lat/lon to plain 1D coordinates
    3. Extract the land-sea mask (top_level)

    Side effects: sets the module globals `domain` and `ls_mask`, which other
    helpers (e.g. mask_depth) rely on.

    Returns
    -------
    (domain, ls_mask) : (xr.Dataset, xr.DataArray)

    Raises
    ------
    ValueError
        If no path is given. (The original code only printed a message and
        then crashed with a NameError on the return statement.)
    """
    if domain_path is None:
        raise ValueError('Nothing loaded. Please indicate File path to domain file of model run')
    global domain  # such that it can be accessed for sure
    domain = xr.open_dataset(domain_path).squeeze().rename({'nav_lev': depth_var})
    # BUG FIX: forward the coordinate names; the original ignored the y/x parameters
    domain = convert_xy(domain, y=y, x=x)
    global ls_mask
    ls_mask = domain[top_level]
    return domain, ls_mask
def load_data(data_paths,
              load=True,        # set False if data bigger than RAM; loading early is still better than at plot time
              domain_data=None, # xr.Dataset, path string, or None (no masking)
              grid=None,        # e.g. 'grid_T'; usually deduced from the file name below
              time_var='time_counter',
              ben_lev='benlvl', # also keep benthic-level variables
              variabs=None,     # optional list of variables to keep at load time
              domain_kwargs=None,
              ):
    """
    Preprocessing wrapper for model output files.

    1. Load (optionally restricted to `variabs`)
    1B. Select relevant variables
    2. Convert lat/lon to plain 1D coordinates
    3. Keep only 2D/3D/4D (and benthic-level) variables
    4. Apply the land-sea mask

    Parameters
    ----------
    data_paths : str or list
        Path(s) or glob pattern(s) accepted by xr.open_mfdataset.
    domain_data : xr.Dataset, str or None
        Domain dataset used for masking, or a path to a domain file.
    domain_kwargs : dict or None
        Coordinate/variable names forwarded to load_domain when domain_data
        is a path; defaults to the T-grid names.

    Returns
    -------
    xr.Dataset
    """
    # BUG FIX: a mutable default dict is shared between calls; build it here instead
    if domain_kwargs is None:
        domain_kwargs = {
            'y': 'nav_lat',
            'x': 'nav_lon',
            'depth_var': 'deptht',
            'top_level': 'top_level',
        }
    # deduce the grid-specific dimension names from one of the file names
    if isinstance(data_paths, list):
        p = os.path.splitext(os.path.basename(data_paths[0]))[0]
    else:
        p = os.path.splitext(os.path.basename(data_paths))[0]
    # defaults — also used when the file name carries no recognized grid tag
    # (the original left y/x/depth_var unbound in that case)
    y = 'nav_lat'
    x = 'nav_lon'
    depth_var = 'deptht'
    if ('btrc_T' in p) or ('ptrc_T' in p):
        grid = None
    elif 'grid_T' in p:
        y = 'nav_lat_grid_T'
        x = 'nav_lon_grid_T'
        grid = 'grid_T'
    elif 'grid_U' in p:
        y = 'nav_lat_grid_U'
        x = 'nav_lon_grid_U'
        depth_var = 'depthu'
        grid = 'grid_U'
    elif 'grid_V' in p:
        y = 'nav_lat_grid_V'
        x = 'nav_lon_grid_V'
        depth_var = 'depthv'
        grid = 'grid_V'
    elif 'grid_W' in p:
        depth_var = 'depthw'
        grid = None
    #LOAD (optionally selecting variables while opening, saves memory)
    if variabs is not None:
        preprocess = lambda ds: ds[variabs]
    else:
        preprocess = None
    data = xr.open_mfdataset(data_paths, preprocess=preprocess)
    if load:
        data = data.load()
    #CONVERT LAT/LON
    data = convert_xy(data, y=y, x=x, grid=grid)
    #SELECT 2D/3D/4D Variables (drops bounds/auxiliary variables)
    lat_var = 'lat'
    lon_var = 'lon'
    d2 = sorted([lat_var, lon_var])
    d3 = sorted([time_var, lat_var, lon_var])
    d4 = sorted([time_var, lat_var, lon_var, depth_var])
    d4b = sorted([time_var, lat_var, lon_var, ben_lev])
    to_keep = [var for var in data.data_vars if sorted(data[var].dims) in [d2, d3, d4, d4b]]
    data = data[to_keep]
    #MASK DATA
    if isinstance(domain_data, xr.Dataset):
        data = mask_data(data, domain_data[domain_kwargs['top_level']])
    elif isinstance(domain_data, str):
        # BUG FIX: load_domain returns (domain, ls_mask); the original indexed
        # the tuple like a dataset, which raised a TypeError
        domain2, mask2 = load_domain(domain_data, y=domain_kwargs['y'], x=domain_kwargs['x'],
                                     depth_var=domain_kwargs['depth_var'],
                                     top_level=domain_kwargs['top_level'])
        data = mask_data(data, mask2)
    return data
#river discharge data: read from the forcing data (lat,lon,time). I provide scripts to calculate river discharge for high and low resolution
#discharge. The indices come from analysing where the forcing data is not nan
#river discharge data: read from the forcing data (lat,lon,time). I provide scripts to calculate river discharge for high and low resolution
#discharge. The indices come from analysing where the forcing data is not nan
def load_rivers(path_prefix,domain,mode='LR',lim=None,time=['1960','2022']):
    """
    Load monthly river-runoff forcing and convert it to discharge per river.

    Input:
        path_prefix: glob pattern matching the yearly runoff files,
            e.g. '/gpfs/scratch/acad/bsmfc/mchoblet/BSFS_BIO/GEO_LR/runoff/runoff_y*.nc'
            or '/gpfs/scratch/acad/bsmfc/mchoblet/BSFS_BIO/GEO/runoff/mast_runoff_y*.nc'
        domain:
            domain file used to get the cell surface area (needed for unit conversion)
        mode (referring to the spatial model resolution that you use):
            'LR': low resolution forcing
            'HR': high resolution forcing
        lim:
            integer, restrict river data to 'lim' files
            NOTE(review): currently unused in the function body — verify intent.
        time:
            [start, end] year strings
            NOTE(review): also unused; the time axis is derived from the file names.
    Output:
        Dictionary with one discharge time series (m³/s) per river, plus the
        aggregates 'all', 'nws' (north-western shelf) and, in HR mode,
        'dniepr' and 'danube'.
    """
    # (lat_index, lon_index) of each river mouth on the HR grid
    rivers_hr_dict={
        'danube_1':[183,97],
        'danube_2':[178,95],
        'danube_3':[172,95],
        'danube_4':[169,92],
        'danube_5':[159,90],
        'dnestr':[207,124],
        'dnepr_1':[227,163],
        'dnepr_2':[228,163],
        'kizil':[34,343],
        'rioni':[52,569],
        'sakarya':[8,143],
    }
    # (lat_index, lon_index) of each river mouth on the LR grid
    rivers_lr_dict={
        'danube_1':[29, 13],
        'danube_2':[31, 13],
        'danube_3':[34, 13],
        'dnestr':[39, 18],
        'dnepr': [41, 22],
        'kizil':[7, 46],
        'rioni':[10, 77],
        'sakarya': [3, 17],
    }
    if mode.lower()=='hr':
        rivers_dict=rivers_hr_dict
        danubes=['danube_'+str(i) for i in range(1,6)]
        dneprs=['dnepr_1','dnepr_2']
    else:
        rivers_dict=rivers_lr_dict
        danubes=['danube_'+str(i) for i in range(1,4)]
        dneprs=['dnepr']
    #get rivers from North western shelf in a list
    nws=[n for nn in [danubes,dneprs,['dnestr']] for n in nn]
    #Get paths and combine data
    rivers=glob.glob(path_prefix)
    rivers.sort()
    data=xr.open_mfdataset(rivers,concat_dim='time',combine='nested').load()
    # build the monthly time axis from the (4-digit) years embedded in the first/last file names
    start=int(''.join(filter(str.isdigit, rivers[0])))
    end=int(''.join(filter(str.isdigit, rivers[-1])))+1
    times=xr.cftime_range(start=str(start),end=str(end),freq='MS')[:-1]
    data['time']=times.to_datetimeindex()
    #compute cell area (m^2) for the unit conversion below
    area=area_2d(domain)
    #Put Data into dict
    final_dic={}
    sum_l=[]
    for k,v in rivers_dict.items():
        # sorunoff is a flux per area; multiply by cell area, divide by 1000
        # NOTE(review): presumably kg/m²/s → m³/s via the water density of 1000 kg/m³ — verify units of 'sorunoff'.
        d=(data['sorunoff'].isel(lat=v[0],lon=v[1])*area.isel(lat=v[0],lon=v[1]))/1000
        d.attrs['unit']='m³/s'
        sum_l.append(d.rename(k).drop(('lat','lon','time_counter')))
        final_dic[k]=d
    #also do the sum for total river discharge
    sum_rivers=xr.merge(sum_l).to_array(dim='new').sum('new')
    final_dic['all']=sum_rivers
    #do north western shelf sum (DDD: Danube + Dniepr + Dniestr)
    final_dic['nws']=xr.merge([final_dic[k].drop(('lat','lon')).rename(k) for k in nws]).to_array(dim='new').sum('new')
    #do danubes and dniepr sums (only needed in HR mode, where they are split into branches)
    if mode.lower()=='hr':
        final_dic['dniepr']=xr.merge([final_dic[k].drop(('lat','lon')).rename(k) for k in dneprs]).to_array(dim='new').sum('new')
        final_dic['danube']=xr.merge([final_dic[k].drop(('lat','lon')).rename(k) for k in danubes]).to_array(dim='new').sum('new')
    ##To-Do: Add BGC variables
    return final_dic
"""
FUNCTIONS TO MASK DATA (DEPTH, LATITUDE, LONGITUDE).
The original shape of data is kept, just masking out values.
"""
def mask_depth(data,domain,d0=0,d1=100,mode='data',depth_var='deptht'):
    """
    Mask all positions whose bottom depth is not in [d0, d1] with NaN.
    Can be used to focus on the North Western Shelf for instance.

    Input:
        data: 2D/3D/4D data
        domain: domain dataset providing depth_var and bottom_level
        d0, d1: depth bounds (same units as depth_var)
        mode: 'data' returns the masked data, 'mask' returns the 0/1 mask itself
        depth_var: name of the depth coordinate
    Output:
        Masked 2D/3D/4D data, or the mask, according to `mode`.
        Relies on the module-global `ls_mask` set by load_domain.
    """
    if depth_var in data.dims:
        # 3D/4D data: take the value in the bottom cell of each column
        dep=get_bottom(data,domain)
    elif depth_var not in data.coords:
        # 2D data without a depth coordinate: attach bottom depth from the domain file.
        # NOTE(review): bottom_level is used as a 1-based index elsewhere
        # (get_bottom subtracts 1); indexing without -1 here may be one level
        # too deep — verify against the domain file conventions.
        dep=domain[depth_var].values[domain['bottom_level'].values]
        data[depth_var]=(('lat','lon'),dep)
        dep=data[depth_var].to_dataset()
    else:
        dep=data
    # land points to NaN, then build the 0/1 in-range mask
    depth=mask_data(dep[depth_var],ls_mask)
    mask=xr.where((depth>=d0) & (depth<=d1),1,0)
    mask=mask_data(mask,ls_mask)
    if mode=='data':
        data=xr.where(mask,data,np.nan)
        out=data
    elif mode=='mask':
        out=mask
    return out
def mask_latlon(data,lats,lons,ls_mask):
    """
    Keep only the region inside the given latitude/longitude bounds; everything
    else (and land) becomes NaN. Intended for regional mean computations — for
    regional *plots* simply slice instead, e.g. the Black Sea NW shelf:
    .sel(lat=slice(42.2,47),lon=slice(27.5,34.1))

    Input:
        data: regular 2D/3D/4D data or a plain 2D mask
        lats: (lat_min, lat_max) bounds of the region to keep
        lons: (lon_min, lon_max) bounds of the region to keep
        ls_mask: land-sea mask applied after the regional cut
    Output:
        masked data (or mask)
    """
    lat_lo, lat_hi = lats
    lon_lo, lon_hi = lons
    inside = (
        (data.lat >= lat_lo) & (data.lat <= lat_hi)
        & (data.lon >= lon_lo) & (data.lon <= lon_hi)
    )
    return mask_data(data.where(inside), ls_mask)
def mask_margin(mask):
    """
    Placeholder: return the margin (outline) of a mask — relevant for
    transport computations. Not implemented yet; always returns None.
    """
    return None
def mask_closest_coastline(data):
    """
    Placeholder: the masking in mask_margin can create lines disconnected from
    the coast; this function is meant to add straight connections to the
    nearest coastline. Not implemented yet; always returns None.
    """
    return None
def get_bottom(data,domain,depth_var='deptht'):
    """
    Select the bottom grid cell of every water column.

    Uses the (1-based) 'bottom_level' index from the domain file; the index is
    reduced by one for 0-based selection.
    """
    levels = domain['bottom_level'].drop('time_counter')
    # land columns carry bottom_level 0; map them to level 1 so the index
    # below stays valid instead of wrapping around to the deepest layer (-1)
    levels = xr.where(levels == 0, 1, levels)
    return data.isel({depth_var: levels - 1})
def cell_volume(domain,grid='T'):
    """
    Grid-cell volume from the NEMO scaling factors e1/e2/e3_0 (output in m^3).

    grid: 'T'/'U'/'V'/'W' — selects the grid-specific factor names
    (e.g. e1t, e2t, e3t_0 for the T grid).
    """
    suffix = grid.lower()
    e1, e2, e3 = 'e1' + suffix, 'e2' + suffix, 'e3' + suffix + '_0'
    return domain[e1] * domain[e2] * domain[e3]
def area_2d(domain,grid='T'):
    """
    Horizontal grid-cell surface area from the e1/e2 scaling factors (m^2).

    The x/y cell sizes are tacitly assumed to be depth-independent, which is
    usually the case.
    grid: 'T'/'U'/'V'/'W'
    """
    suffix = grid.lower()
    return domain['e1' + suffix] * domain['e2' + suffix]
"""
HELPERS FOR PLOTTING FUNCTIONS
reduce repeating parts
"""
def spatial_stat(d,stats):
    """
    Evaluate several spatial statistics over the ('lat','lon') axes.

    Input:
        d: xarray object with 'lat'/'lon' dimensions
        stats: iterable of names from {'mean','std','median','min','max'}
    Output:
        1D numpy array with one value per requested statistic.
    """
    # each supported statistic is exactly the reduction method of the same
    # name on d, so dispatch via getattr instead of an if/elif ladder
    collected = [
        getattr(d, name)(('lat', 'lon'), skipna=True).values
        for name in stats
    ]
    return np.concatenate([collected])
def alignYaxes(axes, align_values=None,nticks=None):
    '''
    Align the ticks of multiple y axes.

    By stackoverflow user Jason: https://stackoverflow.com/questions/26752464/how-do-i-align-gridlines-for-two-y-axis-scales
    Modification:
    - Slightly modified (log2/2**) instead of 10 to get an acceptable range for rivers. Also else I would get negative values.
    - To-Do: Investigate better what this function is actually doing

    Args:
        axes (list): list of axes objects whose yaxis ticks are to be aligned.
    Keyword Args:
        align_values (None or list/tuple): if not None, should be a list/tuple
            of floats with same length as <axes>. Values in <align_values>
            define where the corresponding axes should be aligned up. E.g.
            [0, 100, -22.5] means the 0 in axes[0], 100 in axes[1] and -22.5
            in axes[2] would be aligned up. If None, align (approximately)
            the lowest ticks in all axes.
        nticks (None or int): if given, also used as the minimum tick count.
    Returns:
        new_ticks (list): a list of new ticks for each axis in <axes>.
        A new sets of ticks are computed for each axis in <axes> but with equal
        length.
    '''
    from matplotlib.pyplot import MaxNLocator
    nax=len(axes)
    ticks=[aii.get_yticks() for aii in axes]
    if align_values is None:
        # align on the lowest tick of every axis
        aligns=[ticks[ii][0] for ii in range(nax)]
        #aligns=[np.median(ticks[ii]) for ii in range(nax)]
    else:
        if len(align_values) != nax:
            raise Exception("Length of <axes> doesn't equal that of <align_values>.")
        aligns=align_values
    bounds=[aii.get_ylim() for aii in axes]
    # align at some points: shift each tick set so the align value sits at 0
    ticks_align=[ticks[ii]-aligns[ii] for ii in range(nax)]
    # scale the range to a comparable magnitude (base-2 variant of the original base-10 trick)
    ranges=[tii[-1]-tii[0] for tii in ticks]
    #lgs=[-np.log10(rii)+2. for rii in ranges]
    lgs=[-np.log2(rii)+2. for rii in ranges]
    igs=[np.floor(ii) for ii in lgs]
    #log_ticks=[ticks_align[ii]*(10.**igs[ii]) for ii in range(nax)]
    log_ticks=[ticks_align[ii]*(2.**igs[ii]) for ii in range(nax)]
    # put all axes ticks into a single array, then compute new ticks for all
    comb_ticks=np.concatenate(log_ticks)
    comb_ticks.sort()
    steps=[1, 2, 2.5, 3, 4, 5, 8, 10]
    # NOTE(review): this immediately discards the candidate step list above
    # (dead assignment) — MaxNLocator then falls back to its defaults. Verify intent.
    steps=None
    if nticks==None:
        nticks='auto'
        locator=MaxNLocator(nbins=nticks, steps=steps,integer=True)
    else:
        min_ticks=nticks
        locator=MaxNLocator(nbins=nticks, steps=steps,integer=True,min_n_ticks=min_ticks)
    new_ticks=locator.tick_values(comb_ticks[0], comb_ticks[-1])
    # undo the scaling and the align shift for each axis
    #new_ticks=[new_ticks/10.**igs[ii] for ii in range(nax)]
    new_ticks=[new_ticks/2.**igs[ii] for ii in range(nax)]
    new_ticks=[new_ticks[ii]+aligns[ii] for ii in range(nax)]
    # find the lower bound: last tick index still below every axis' lower limit
    idx_l=0
    for i in range(len(new_ticks[0])):
        if any([new_ticks[jj][i] > bounds[jj][0] for jj in range(nax)]):
            idx_l=i-1
            break
    # find the upper bound: first tick index above every axis' upper limit
    idx_r=0
    for i in range(len(new_ticks[0])):
        if all([new_ticks[jj][i] > bounds[jj][1] for jj in range(nax)]):
            idx_r=i
            break
    # trim tick lists by bounds
    new_ticks=[tii[idx_l:idx_r+1] for tii in new_ticks]
    # set ticks for each axis
    for axii, tii in zip(axes, new_ticks):
        axii.set_yticks(tii)
    return new_ticks
"""
Statistic functions/wrapper.
"""
def temp_resampling_stats(data,time_name='time_counter',averaging='all',mode='mean',diff_modus=None,specific_time=None,mask=None):
    """
    Temporally resample/group `data` and compute one statistic per period.

    Parameters
    ----------
    data : xr.DataArray or xr.Dataset
    time_name : str
        'time_counter' for NEMO-BAMHBI, 'time' for WAM.
    averaging : str
        Period over which to compute the statistic:
        'all' | 'seasonal' | 'seasonal_t' | 'yearly' | 'monthly' | 'monthly_t'
        (the *_t variants resample along the time axis instead of grouping,
        i.e. every individual season/month is kept).
    mode : str
        Statistic: 'mean','median','std','sum','min','max','quant_5','quant_95',
        or the counting statistics 'nan_count','zero_count',
        'nan_count_percentage','zero_count_percentage'.
    diff_modus : str or None
        None: plain statistic.
        'diff'/'diff_rel': subtract the climatological statistic of the matching
        season/month (relative difference in % for the _rel variant).
        'diff_all'/'diff_all_rel': subtract the overall (all-time) statistic.
    specific_time : int, str or None
        Select a single period AFTER resampling/differencing: integer position
        or time string ('2012-02', nearest match). For seasonal resampling
        remember the period is labelled by its start month (December for DJF).
    mask : unused
        Kept for interface compatibility.

    Returns
    -------
    Statistic with a 'time_label' attribute (list of labels) for plotting.

    Raises
    ------
    NameError for unknown `averaging` or `diff_modus`.
    """
    special_stats_1=['nan_count','zero_count']
    special_stats_2=['nan_count_percentage','zero_count_percentage']
    special_stats=special_stats_1 + special_stats_2
    def stats(data,mode):
        """
        Standard reduction along the time axis.
        data is either resampled or grouped, that way the dimension is automatically selected right.
        """
        if mode=='mean':
            data=data.mean(time_name,skipna=True)
        elif mode=='median':
            data=data.median(time_name,skipna=True)
        elif mode=='std':
            data=data.std(time_name,skipna=True)
        elif mode=='sum':
            data=data.sum(time_name,skipna=True)
        elif mode=='min':
            data=data.min(time_name,skipna=True)
        elif mode=='max':
            data=data.max(time_name,skipna=True)
        elif mode=='quant_5':
            data=data.quantile(0.05,dim=time_name,skipna=True)
        elif mode=='quant_95':
            data=data.quantile(0.95,dim=time_name,skipna=True)
        return data
    def stats_spec(data,mode,averaging):
        """
        Special counting statistics (nan/zero counting). The boolean test has
        to be applied before the temporal resampling, hence the own resampling
        logic here.
        """
        data=data.transpose(time_name,...) #bring time axis to first dimension to have an easier life...
        #first operation: turn the data into a boolean field to be counted
        if (mode=='nan_count') or (mode=='nan_count_percentage'):
            data=np.isnan(data)
        elif (mode=='zero_count') or (mode=='zero_count_percentage'):
            data=(data==0)
        if averaging=='all':
            data=data
        elif averaging=='seasonal':
            data=data.groupby(data[time_name].dt.season)
        elif averaging=='seasonal_t':
            data=data.resample({time_name:'QS-DEC'})
        elif averaging=='yearly':
            data=data.resample({time_name:'YS'})
        elif averaging=='monthly':
            data=data.groupby(data[time_name].dt.month)
        elif averaging=='monthly_t':
            data=data.resample({time_name:'MS'})
        length=data.count(time_name)
        data=data.sum(time_name)
        if mode=='nan_count_percentage' or mode=='zero_count_percentage':
            data=data/length*100
        return data
    if averaging=='all':
        if mode in special_stats:
            mean=stats_spec(data,mode,averaging)
        else:
            mean=stats(data,mode)
        mean=mean.assign_attrs({time_name:['Average']})
        # a climatology difference of the overall average is trivially zero -> NaN placeholder
        if (diff_modus=='diff') or (diff_modus=='diff_rel'):
            mean=xr.full_like(mean,np.nan)
    elif averaging=='seasonal':
        if mode in special_stats:
            mean=stats_spec(data,mode,averaging)
        else:
            mean=stats(data.groupby(data[time_name].dt.season),mode)
        mean=mean.sortby(xr.DataArray(['DJF','MAM','JJA', 'SON'],dims=['season'])) #put that into right order
        mean=mean.rename({'season':time_name})
    elif averaging=='seasonal_t':
        if mode in special_stats:
            mean=stats_spec(data,mode,averaging)
        else:
            mean=stats(data.resample({time_name:'QS-DEC'}),mode)
        if diff_modus=='diff':
            if mode in special_stats:
                mean=stats_spec(data,mode,averaging)
            else:
                # subtract the seasonal climatology from every individual season
                seasons=stats(data.groupby(data[time_name].dt.season),mode)
                for i,t in enumerate(mean[time_name]):
                    s=t.dt.season.values
                    mean.loc[{time_name:t}]=mean.isel({time_name:i})-seasons.sel(season=s).values
        elif diff_modus=='diff_rel':
            if mode in special_stats:
                mean=stats_spec(data,mode,averaging)
            else:
                seasons=stats(data.groupby(data[time_name].dt.season),mode)
                for i,t in enumerate(mean[time_name]):
                    s=t.dt.season.values
                    mean.loc[{time_name:t}]=(mean.isel({time_name:i})-seasons.sel(season=s).values)/seasons.sel(season=s).values*100
    elif averaging=='yearly':
        if mode in special_stats:
            mean=stats_spec(data,mode,averaging)
        else:
            mean=stats(data.resample({time_name:'YS'}),mode)
        label=mean[time_name].dt.year.values
        if diff_modus=='diff':
            alls=stats(data,mode)
            for i,t in enumerate(mean[time_name]):
                mean.loc[{time_name:t}]=mean.isel({time_name:i})-alls.values
        elif diff_modus=='diff_rel':
            alls=stats(data,mode)
            for i,t in enumerate(mean[time_name]):
                mean.loc[{time_name:t}]=(mean.isel({time_name:i})-alls.values)/alls.values*100
    elif averaging=='monthly':
        if mode in special_stats:
            mean=stats_spec(data,mode,averaging)
        else:
            mean=stats(data.groupby(data[time_name].dt.month),mode)
        label=mean['month'].values
        mean=mean.rename({'month':time_name})
    elif averaging=='monthly_t':
        if mode in special_stats:
            mean=stats_spec(data,mode,averaging)
        else:
            mean=stats(data.resample({time_name:'MS'}),mode)
        # BUG FIX: use time_name instead of the hard-coded 'time_counter'
        months=mean[time_name].dt.month
        years=mean[time_name].dt.year
        label=[str(m.values)+'-'+str(months[i].values) for i,m in enumerate(years)]
        if diff_modus=='diff':
            if mode in special_stats:
                mean=stats_spec(data,mode,averaging)
            else:
                monthly=stats(data.groupby(data[time_name].dt.month),mode)
                for i,t in enumerate(mean[time_name]):
                    s=t.dt.month.values
                    mean.loc[{time_name:t}]=mean.isel({time_name:i})-monthly.sel(month=s).values
        elif diff_modus=='diff_rel':
            if mode in special_stats:
                monthly=stats_spec(data,mode,averaging)
            else:
                monthly=stats(data.groupby(data[time_name].dt.month),mode)
            for i,t in enumerate(mean[time_name]):
                s=t.dt.month.values
                mean.loc[{time_name:t}]=(mean.isel({time_name:i})-monthly.sel(month=s).values)/monthly.sel(month=s).values*100
    else:
        raise NameError("Temporal mode unknown. Select from 'all','seasonal','seasonal_t','yearly','monthly','monthly_t")
    if diff_modus=='diff_all':
        if averaging=='all':
            mean=xr.full_like(mean,np.nan)
        else:
            if mode in special_stats:
                # BUG FIX: special_stats is the list of statistic names; the
                # counting statistic itself is computed by stats_spec
                vals=stats_spec(data,mode,'all')
            else:
                vals=stats(data,mode)
            mean=mean-vals
    elif diff_modus=='diff_all_rel':
        if averaging=='all':
            mean=xr.full_like(mean,np.nan)
        else:
            if mode in special_stats:
                # BUG FIX: same as above — call stats_spec, not the name list
                s=stats_spec(data,mode,'all')
            else:
                s=stats(data,mode)
            mean=(mean-s)/s*100
    elif diff_modus=='diff_rel':
        pass
    elif diff_modus=='diff':
        pass
    elif diff_modus is None:
        pass
    else:
        raise NameError("Difference modus unknown")
    #select specific time after the difference computation !
    if time_name in mean.dims:
        if isinstance(specific_time,int):
            mean=mean.isel({time_name:specific_time})
            lab_t=True
        elif isinstance(specific_time,str):
            mean=mean.sel({time_name:specific_time},method='nearest')
            lab_t=True
        else:
            lab_t=False
    else:
        lab_t=False
    #add labels after time selection!
    if averaging=='seasonal_t':
        seasons=mean[time_name].dt.season
        years=mean[time_name].dt.year
        label=[str(m.values)+'-'+str(seasons[i].values) for i,m in enumerate(years)]
    elif averaging=='yearly':
        label=mean[time_name].dt.year.values
    elif averaging=='monthly_t':
        # BUG FIX: use time_name instead of the hard-coded 'time_counter'
        months=mean[time_name].dt.month
        years=mean[time_name].dt.year
        label=[str(m.values)+'-'+str(months[i].values) for i,m in enumerate(years)]
    elif averaging=='seasonal' or averaging=='monthly':
        label=mean[time_name].values
    if averaging=='all':
        label=['Average']
    #trigger to avoid 0-dim label
    if lab_t: label=[label.item()]
    mean=mean.assign_attrs({'time_label':label})
    # the counting statistics bypass the normal masking, so re-apply the land-sea mask
    if mode in special_stats:
        if 'ls_mask' in globals():
            mean=mask_data(mean,ls_mask)
    return mean
def spatial_mean(data,area=None,domain=None,depth_var='deptht',grid='T'):
    """
    Spatial (optionally weighted) mean of 2D or 3D fields.

    If a domain is given, the weights are computed from it: cell volumes for
    3D data, cell areas for 2D data (overriding any `area` passed in).
    If neither area nor domain is given, an unweighted mean is taken.

    Input:
        data: field with ('lat','lon') and optionally depth_var dimensions
        area: precomputed weights (must contain depth_var for 3D data)
        domain: domain dataset used to derive the weights
        grid: 'T'/'U'/'V'/'W', forwarded to the weight computation
    """
    has_depth = depth_var in data.dims
    if has_depth:
        reduce_dims = ('lat', 'lon', depth_var)
        if area is not None and depth_var not in area.dims:
            raise TypeError('Area needs to contain depth')
    else:
        reduce_dims = ('lat', 'lon')
    if domain is not None:
        # derive weights from the domain (volume for 3D, area for 2D)
        area = cell_volume(domain, grid=grid) if has_depth else area_2d(domain, grid=grid)
    if area is None:
        return data.mean(reduce_dims, skipna=True)
    return data.weighted(area).mean(reduce_dims, skipna=True)
def periodic_corr(x, y):
    """
    Periodic (circular) cross-correlation of two equal-length real sequences,
    computed via the FFT; scipy's usual correlation routines lack a periodic option.

    By Stack overflow user Warren Wackesser: https://stackoverflow.com/questions/28284257/circular-cross-correlation-python
    See also: https://www.ocean.washington.edu/courses/ess522/lectures/08_xcorr.pdf

    Returns
    -------
    (lag, max_corr): position of the correlation maximum and its value.
    """
    spec_x = np.fft.fft(x)
    spec_y = np.fft.fft(y)
    # circular covariance via the correlation theorem
    circ_cov = np.fft.ifft(spec_x.conj() * spec_y).real
    # normalization: sqrt(sum(x^2) * sum(y^2)) — the DC bins of fft(x*x)/fft(y*y)
    norm = np.sqrt(np.fft.fft(x * x)[0] * np.fft.fft(y * y)[0])
    corr = (circ_cov / norm).real
    best_lag = np.argmax(corr)
    return best_lag, corr[best_lag]
"""
PLOTTING FUNCTIONS
long list of arguments due to some important options to improve look of plot. for first impressions you can throw away most of them and use default config.
PLOT LIST.
SPATIAL MAP PLOTS
0. Show mask to visualize how you masked your data.
1. Single variable with multiple statistics in time
2. Multiple variables in Time
3. Compare 2 Model runs for a variable in time
4. Compare 2 Model runs for multiple variables in time
TIME SERIES PLOTS
1. 2D/3D Spatial average in time (mask data before to get a specific region) of multiple variables.
Includes options to plot river outflow, compare different model runs and variables visually and also quantitatively via (cross) correlations.
"""
def show_mask(da,time_name='time_counter',ls_mask=None,domain=None,title='',depth_var='deptht',ret=False,cartopy=True,cbar_kwargs={'label': 'mask','shrink':0.5},cmap='GnBu',
              lat_step=2,lon_step=4,bathymetry=True):
    """
    Visualize which grid points of `da` carry data (i.e. how the data is masked).

    A point counts as valid if it is non-NaN at any time step. Optionally draws
    a cartopy map with land, rivers and bathymetry contours on top.

    Input:
        da: xr.Dataset (first variable is used) or xr.DataArray
        ls_mask: land-sea mask (required)
        domain: domain dataset; only needed for the bathymetry contours
        ret: if True, return the computed validity field instead of None
    """
    if depth_var in da.dims:
        #just select first level ...
        da=da.isel({depth_var:0})
    if type(da)==xr.core.dataset.Dataset:
        # for datasets: visualize the mask of the first data variable
        some_var=list(da.data_vars)[0]
        da=da[some_var]
    elif type(da)==xr.core.dataarray.DataArray:
        pass
    else:
        raise TypeError('please provide xarray object for <da>')
    if ls_mask is None:
        raise TypeError('provide land sea mask "ls_mask"')
    # True where the point has data in at least one time step
    if time_name in da.dims:
        is_nan_axis = ~np.all(np.isnan(da.transpose(time_name,...)), axis=0)
    else:
        is_nan_axis = ~np.isnan(da)
    if cartopy:
        #mask and show rivers
        projection=ccrs.PlateCarree()
        fig,ax=plt.subplots(subplot_kw={'projection':projection})
        is_nan_axis=mask_data(is_nan_axis,ls_mask)
        is_nan_axis.plot(ax=ax,cbar_kwargs=cbar_kwargs,cmap=cmap,vmin=0)
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.RIVERS,edgecolor='blue',linewidth=1.5)
        #ax.add_feature(cfeature.COASTLINE,linewidth=1.5 # coatline can look a bit weird because it doesn't match
        ax.set_title(title)
        gl = ax.gridlines(crs=projection, draw_labels=True,linewidth=1, color='gray', alpha=0.5, linestyle='--')
        # NOTE(review): xlabels_top/ylabels_right were removed in newer cartopy
        # versions in favour of top_labels/right_labels — verify the installed version.
        gl.xlabels_top = False; gl.ylabels_right = False
        #FIX GRIDLINE SPACING
        min_lat=np.round(da['lat'].min().values); max_lat=np.round(da['lat'].max().values)
        min_lon=np.round(da['lon'].min().values); max_lon=np.round(da['lon'].max().values)
        lats=np.arange(min_lat,max_lat+lat_step,lat_step); lons=np.arange(min_lon,max_lon+lon_step,lon_step)
        gl.xlocator = mticker.FixedLocator(lons); gl.xformatter = LONGITUDE_FORMATTER;
        gl.ylocator = mticker.FixedLocator(lats); gl.yformatter = LATITUDE_FORMATTER
        gl.xlabel_style = {'size': 8}; gl.ylabel_style = {'size': 8}
        if bathymetry:
            # overlay labelled depth contours from the domain's bathymetry
            # (note: rebinds `cmap` from the string argument to a colormap object)
            cmap = mpl.cm.GnBu
            bounds=[0,50,100,500,1000,2000]
            norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
            CS = domain['bathy_metry'].plot.contour(kwargs=dict(inline=True),levels=bounds,
                                                    ax=ax,cmap=cmap,norm=norm,add_colorbar=False,alpha=0.9)
            ax.clabel(CS,colors='dimgray')
            #ax.clabel(CS)
            ax.set_title(title)
    else:
        # plain (non-map) plot without land/river features
        is_nan_axis.plot(cbar_kwargs=cbar_kwargs,cmap=cmap,vmin=0)
        plt.title(title)
    # `ret` doubles as flag and return value: return the validity field if requested
    if ret==True:
        ret=is_nan_axis
    else:
        ret=None
    return ret
def plot1_singlevar(data,
avgs=['all','seasonal','seasonal_t','yearly','monthly','monthly_t'],
stats=['mean','sum','std','min','max','quant_5','quant_95','nan_count','zero_count','nan_count_percentage','zero_count_percentage'],
diff_modus=None, #can also be diff, diff_rel, diff_all, diff_all,rel #repeated for each stats entry.
#this way we can display normal mean, NOTE THAT THIS DOES NOT WORK when 'all' in avgs.
specific_time=None, #can be used to reduce temporal data to a specific time (e.g. a specific year and season for seasonal_t or for monthly_t'
#therefore provide a time string (e.g. "2012-02") and closest moment is selected. the selection is done after time resampling/subtracting time means
#else one could just restrict the data to the start. for selection after seasonal resampling remember to select the startmonth (e.g. december if you want a djf average, february would give you MAM because time has been resampled
time_name='time_counter',
cmap='Reds',
vmax=None,vmin=None,
title='Bottom Oxygen',
height=1.8, width=5.2, y=1, labelsize=None, titlesize=None, left_pos=-0.5,
grid_kwargs={'on':True,'top':False,'right':False,'bottom':True,'left':True,
'draw_labels':True, 'lon_step':4,'lat_step':2,'size':10},
cartopy_kwargs={'land':True, 'coastline': False,'rivers': True,'linewidth': 0.5,},
colorbar_kwargs={ 'levels':11, #levels in the plot every second level has a plot
'nbins':5, #only show every 'nth label.
'shrink':0.9, 'pad':0.02},
statsbox_kwargs={'use':False,'stats':['mean','std'],'pos':[0.75,0.95],'prec': '.1f',#precision when printing numbers
'fontsize':8},#'stats':['mean','std','median','min','max'
unit=None
):
"""
Plot panel:
x-axis: statistics for single variable
y-axis: time
To select: One input variable
"""
##### DATA PREPARATION ####
if diff_modus is None: modus=[None for _ in range(len(stats))]
if type(diff_modus)!=list: diff_modus=[diff_modus]
if len(diff_modus)!=len(stats): diff_modus=[diff_modus[0] for _ in range(len(stats))]
#make a list if not a list
if type(avgs)!=list: avgs=[avgs]
if type(stats)!=list: stats=[stats]
time_labels=[];all_data=[]
#store data in a double layered list/dictionary first layer is the statistics, second is the type of statistic
for i,s in enumerate(stats):
da={}
for ii,a in enumerate(avgs):
res=temp_resampling_stats(data,averaging=a,mode=s,diff_modus=diff_modus[i],time_name=time_name,specific_time=specific_time)
if i==0: time_labels.append(res.attrs['time_label'])
da[a]=res
all_data.append(da)
#unpack list of time labels. lentgth needed for setting up plot
#try:
# #unpack list of time labels. lentgth needed for setting up plot
time_labels=[tt for t in time_labels for tt in t]