-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutilities.py
More file actions
340 lines (271 loc) · 13.1 KB
/
utilities.py
File metadata and controls
340 lines (271 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
# import packages that are needed to run the modules
import xarray as xr
import numpy as np
import os
from haversine import haversine
from tqdm import tqdm
from scipy.stats import halfnorm
# print utility
def print_iteration(i,every_ith=100):
    """Echo the iteration counter ``i`` once every ``every_ith`` iterations."""
    # guard clause: stay silent unless i is a multiple of every_ith
    if i % every_ith:
        return
    print(i)
# Function to calculate weighted global mean
def xr_global_mean_weighted(ds,lon='lon',lat='lat'):
    """Area-weighted global mean of ``ds`` over its lon/lat dimensions.

    Uses cos(latitude) as the relative grid-cell area weight and collapses
    both horizontal dimensions, returning the weighted mean.
    """
    # cosine of latitude (in radians) approximates relative cell area
    area_weights = np.cos(np.deg2rad(ds[lat]))
    # reduce over both horizontal dimensions with those weights applied
    return ds.weighted(area_weights).mean((lon, lat))
# define functions to use xarray to convert longitudes from 0-360E to 180W-180E, and vice-versa
# From solution at https://stackoverflow.com/questions/53345442/about-changing-longitude-array-from-0-360-to-180-to-180-with-python-xarray
def xr_convert_lon_to180(ds,lonkey='lon'):
    """Remap the ``lonkey`` coordinate from [0, 360) to [-180, 180) and sort by it."""
    shifted = ds.copy(deep=True)
    # tiny epsilon keeps exactly-180 longitudes from wrapping to -180
    shifted.coords[lonkey] = (ds[lonkey] - 1e-6 + 180) % 360 - 180
    # re-sort so the longitude axis is monotonically increasing
    return shifted.sortby(shifted[lonkey])
# Function to calculate (for xarray) the anomaly relative to a slice over time, but more general
def xr_anomaly_slice(ds,dim_key='time',slice_in=slice(1991,2020)):
    """Anomaly of ``ds`` relative to its mean over ``slice_in`` along ``dim_key``."""
    # baseline climatology: mean over the reference slice
    baseline = ds.sel({dim_key: slice_in}).mean(dim_key)
    return ds - baseline
# Function to calculate the weighted average along an array's index
def nanaverage(a, axis=None, weights=None, returned=False):
    """
    Compute the weighted average along ``axis``, skipping NaN values in ``a``.

    Parameters:
    - a : array_like
        Data to average.
    - axis : int, optional
        Axis along which to reduce. Default averages the flattened array.
    - weights : array_like, optional
        Weights with the same shape as ``a``. Defaults to equal weights.
    - returned : bool, optional
        If True, also return the sum of the (non-NaN) weights.

    Returns:
    - weighted average : ndarray or scalar
    - sum of weights : scalar or ndarray, only when ``returned`` is True.

    Raises:
    - ValueError: if ``weights`` is given and its shape differs from ``a``.
    """
    a = np.asarray(a)
    if weights is None:
        # no weights supplied: every sample counts equally
        weights = np.ones_like(a)
    else:
        weights = np.asarray(weights)
        if weights.shape != a.shape:
            raise ValueError("weights should have the same shape as a.")
    # zero out NaN samples in both data and weights so they vanish
    # from the numerator and denominator sums alike
    valid = ~np.isnan(a)
    data = np.where(valid, a, 0)
    wts = np.where(valid, weights, 0)
    wts_total = np.sum(wts, axis=axis)
    avg = np.sum(data * wts, axis=axis) / wts_total
    return (avg, wts_total) if returned else avg
# Function to take a slice between two years with dt64 type
def dt64_yrslice(lower_yr=1985,upper_yr=2015):
    """Build a date-string slice spanning whole years ``lower_yr``..``upper_yr``."""
    # Jan 1 of the first year through Dec 31 of the last, as ISO date strings
    start = f"{int(lower_yr)}-01-01"
    stop = f"{int(upper_yr)}-12-31"
    return slice(start, stop)
# Function to load all files in a directory by extension
def list_load_files_by_extension(path=".", extension=".nc"):
    """
    List all regular files with a given extension in a directory.

    Parameters:
    path (str): Directory to search. Defaults to the current directory (".").
    extension (str): File extension to filter by. Defaults to ".nc".

    Returns:
    list: File names matching the extension; an empty list when the
    directory is missing or unreadable (an error message is printed).
    """
    # normalize the extension so both "nc" and ".nc" behave the same
    if not extension.startswith("."):
        extension = "." + extension
    # keep the try body minimal: only the directory listing can raise here
    try:
        entries = os.listdir(path)
    except FileNotFoundError:
        print(f"Error: The directory '{path}' does not exist.")
        return []
    except PermissionError:
        print(f"Error: Permission denied to access the directory '{path}'.")
        return []
    # filter to regular files that carry the requested extension
    return [name for name in entries
            if name.endswith(extension) and os.path.isfile(os.path.join(path, name))]
# function that checks to see whether a specific path exists locally
# returns True or False boolean
def path_exists_locally(path: str) -> bool:
    """
    Check if a path exists locally.

    Note: `os.path.exists` accepts any filesystem path (file or directory),
    not only directories.

    Parameters:
    path (str): Path to check.
    Returns:
    bool: True if the path exists, False otherwise.
    """
    return os.path.exists(path)
# Function to calculate the distance between two global locations
# using the haversine formula
def distance_between_locations(lon1,lat1,lon2,lat2):
    """
    Pairwise great-circle distances (km) between two sets of locations.

    Parameters:
    - lon1, lat1 : pandas Series of longitudes/latitudes for the first set
    - lon2, lat2 : pandas Series of longitudes/latitudes for the second set

    Returns:
    - distance : (len(lon1), len(lon2)) float numpy array of distances in km

    Raises:
    - ValueError: if the lon/lat lengths within either set disagree
    """
    # make sure that we have the right length input datasets
    if (len(lon1) != len(lat1)) or (len(lon2) != len(lat2)):
        # BUG FIX: previously called an undefined `error()` helper, which
        # surfaced as a confusing NameError; raise an explicit exception
        raise ValueError('Station lengths input to function are not internally consistent')
    nloc1, nloc2 = len(lon1), len(lon2)
    # initialize the output array with NaN so unfilled cells are visible
    distance = np.full((nloc1, nloc2), np.nan, dtype='float')
    # loop over each of the first location indices
    for jj in range(nloc1):
        # geocoordinates for the jjth location, (lat, lon) order for haversine
        jjloc = (lat1.iloc[jj], lon1.iloc[jj])
        # loop over the second set of locations
        for kk in range(nloc2):
            kkloc = (lat2.iloc[kk], lon2.iloc[kk])
            # haversine returns the great-circle distance in kilometers
            distance[jj, kk] = haversine(jjloc, kkloc)
    # go back to the above program level
    return distance
# Vectorized function to calculate the distance using xarray
# and the haversine formula
def distance_between_locations_xarray(lon1, lat1, lon2, lat2):
    """
    Apply the haversine function over an xarray grid of lon1, lat1 and a
    single location (lon2, lat2).

    Returns a grid of distances (km, per the haversine default) the same
    shape as lon1, lat1.

    Parameters:
    - lon1, lat1: xarray DataArrays (longitude and latitude grids)
    - lon2, lat2: single longitude and latitude point (float)
    """
    # Wrapper to apply the haversine function element-wise; closes over the
    # fixed target point (lat2, lon2). Note haversine takes (lat, lon) pairs.
    def haversine_vectorized(lat1, lon1):
        return haversine((lat1, lon1), (lat2, lon2))
    # Apply the haversine function over the grid using apply_ufunc
    distance = xr.apply_ufunc(
        haversine_vectorized, # The function to apply
        lat1, lon1, # Input variables (lat1 and lon1 grids)
        vectorize=True, # Enable vectorization (numpy-level loop over elements)
        dask="parallelized", # If you are working with Dask, enable parallelization
        output_dtypes=[float] # Specify output data type
    )
    # Return the resulting distance grid as xarray DataArray
    return distance
# Function to interpolate data from a grid to Tide Gauge stations
def calculate_location_from_grid(datin,stn_filename_list,tg_latlon,latd,lond,nearkm=78.,close_enough_km=0.1):
    """
    Interpolate a lat/lon gridded field onto tide-gauge station locations.

    For each station: slice a small lat/lon window around it, compute
    haversine distances from every cell to the station, then either
    (a) use cells within `close_enough_km` directly, or (b) form a
    half-normal distance-weighted average of cells within `nearkm`, or
    (c) record NaN when nothing is near enough.

    Parameters:
    - datin : xarray object with 'lat'/'lon' coordinates (a single field;
      one time slice when called from apply_over_years)
    - stn_filename_list : iterable of station keys into `tg_latlon`
    - tg_latlon : mapping station -> (lat-array, lon-array); only the first
      element of each is used (tg_latlon[stn][0][0] / [1][0])
    - latd, lond : half-widths (degrees) of the window sliced around a station
    - nearkm : radius (km) within which grid cells contribute (default 78.)
    - close_enough_km : distance (km) treated as co-located (default 0.1)

    Returns:
    - dict with keys 'final', 'nearest', 'weighted' — each a float array of
      length len(stn_filename_list); 'final' is 'nearest' with its NaNs
      filled from 'weighted'.
    """
    # define the arrays we will populate (NaN-initialized float vectors)
    # NOTE(review): stn_weight_avg_out is never used below
    stn_weight_avg_out = np.zeros((len(stn_filename_list),),dtype='float')*np.nan
    stn_weighted_out = np.zeros((len(stn_filename_list),),dtype='float')*np.nan
    stn_nearest_out = np.zeros((len(stn_filename_list),),dtype='float')*np.nan
    # calculate the area-weighting: half-normal kernel whose sigma is derived
    # from nearkm/2 via the FWHM relation 2*sqrt(2*ln 2)
    hwwm_sigma=(nearkm/2)/(2*np.sqrt(2*np.log(2))/2)
    ydist=halfnorm(loc=0, scale=hwwm_sigma)
    # loop over GESLA locations and find the weighted average of the nearby grid
    for si,stnname in enumerate(stn_filename_list):
        # find the subset slice of the grid that we need
        stnlon = tg_latlon[stnname][1][0]
        stnlat = tg_latlon[stnname][0][0]
        # 0.001-degree pad keeps boundary cells inside the slice
        lonslice = slice(stnlon - lond-0.001, stnlon + lond+0.001)
        latslice = slice(stnlat - latd-0.001, stnlat + latd+0.001)
        # slice out the station's nearby data
        subset = datin.sel(lon=lonslice,lat=latslice)
        # calculate the distances across the lat/lon slice
        distances_subset = distance_between_locations_xarray(
            subset['lon'], subset['lat'],
            stnlon, stnlat)
        # if any values are within close_enough_km, take those as the actual value
        superclose_subset = xr.where(distances_subset <= close_enough_km,subset,np.nan,keep_attrs=True)
        superclose_distances = xr.where(distances_subset <= close_enough_km,distances_subset,np.nan,keep_attrs=True)
        if np.sum(~np.isnan(superclose_subset))>0:
            weighted_avg = nanaverage(superclose_subset.values.ravel(),
                weights=ydist.sf(superclose_distances.values.ravel()))
            # co-located cell(s): weighted and nearest estimates coincide
            stn_weighted_out[si,] = weighted_avg
            stn_nearest_out[si,] = weighted_avg
            # print('superclose ----------------------------')
            continue
        # find the locations within nearkm kilometers
        withinkm_subset = xr.where(distances_subset <= nearkm,subset,np.nan)
        withinkm_distances = xr.where(distances_subset <= nearkm,distances_subset,np.nan)
        # calculate the NaN-ignoring weighted average within nearkm km
        if np.sum(~np.isnan(withinkm_subset))>0:
            weighted_avg = nanaverage(withinkm_subset.values.ravel(),
                weights=ydist.sf(withinkm_distances.values.ravel()))
            stn_weighted_out[si,] = weighted_avg
            # nearest: value of the closest valid grid cell to the station
            stn_nearest_out[si,] = subset.where(~np.isnan(withinkm_subset), drop=True).sel(lon = stnlon, lat = stnlat, method='nearest')
            # print('withinXX ----------------------------')
        # if we aren't close, set to missing
        else:
            stn_weighted_out[si,] = np.nan
            stn_nearest_out[si,] = np.nan
            # print('else ----------------------------')
    # store the results and return to the top level
    # 'final' prefers the nearest-cell value, falling back to the weighted one
    final_values_out = stn_nearest_out
    final_values_out[np.isnan(final_values_out)] = stn_weighted_out[np.isnan(final_values_out)]
    out = {}
    out['final'] = final_values_out
    out['nearest'] = stn_nearest_out
    out['weighted'] = stn_weighted_out
    return(out)
# Function to apply regridding over all years over the SLR gridded data
def apply_over_years(datin, stn_filename_list, tg_latlon, latd, lond, nearkm=78, close_enough_km=0.1,interp_style='final'):
    """
    Apply calculate_location_from_grid over all years for each station and
    return the results as an xarray Dataset.

    Parameters:
    - datin : xarray object with a 'time' coordinate plus lat/lon grids
    - stn_filename_list : iterable of station keys into `tg_latlon`
    - tg_latlon : mapping station -> (lat-array, lon-array)
    - latd, lond : lat/lon half-widths (degrees), passed through
    - nearkm, close_enough_km : distance thresholds (km), passed through
    - interp_style : which estimate to keep per station/year
      ('final', 'nearest' or 'weighted')

    Returns:
    - xarray.Dataset with 'results' (station x time) plus per-station
      name/latitude/longitude variables, indexed by station number and time.
    """
    # Create lists to store the final output
    station_indices = np.arange(len(stn_filename_list))
    station_names = []
    station_lats = []
    station_lons = []
    results_list = []
    # Loop over each station in the list (tqdm shows progress)
    for si, stnname in tqdm(enumerate(stn_filename_list)):
        print_iteration(si, every_ith=1000)
        # Get the station latitude and longitude
        stn_lat = tg_latlon[stnname][0][0]
        stn_lon = tg_latlon[stnname][1][0]
        # Store station metadata
        station_names.append(stnname)
        station_lats.append(stn_lat)
        station_lons.append(stn_lon)
        # Initialize an empty list to store results for this station across all years
        station_results = []
        # Loop over each time step (year) in the dataset
        for year in datin.time:
            # print(f"Processing year {year.values}...")
            # Select the data for the current year
            datin_year = datin.sel(time=year)
            # Apply the original function to calculate results for this station and year
            result = calculate_location_from_grid(datin_year, [stnname], tg_latlon, latd, lond, nearkm, close_enough_km)
            # Append the result for this year ([0]: single-station call)
            station_results.append(result[interp_style][0])
        # Append the station results to the list
        results_list.append(station_results)
    # Convert lists to numpy arrays (shape: station x time)
    results_array = np.array(results_list)
    # Create an xarray dataset to hold the results
    ds = xr.Dataset(
        {
            'results': (('station', 'time'), results_array), # Main results array
            'station_names': (('station'), np.array(station_names)), # Station names
            'station_latitudes': (('station'), np.array(station_lats)), # Station latitudes
            'station_longitudes': (('station'), np.array(station_lons)) # Station longitudes
        },
        coords={
            'station': station_indices, # Station index
            'time': datin.time.values # Time (years)
        }
    )
    return(ds)
# Function to calculate increments in an array over time
def get_ds_time_increments(ds,timekey='time'):
    """
    Step-to-step increments of ``ds`` along ``timekey``.

    The first element is a zero slice (same shape as one time step), so the
    output aligns one-to-one with the input along ``timekey``.

    Parameters:
    - ds : xarray object with a ``timekey`` dimension
    - timekey : name of the dimension to difference along (default 'time')

    Returns:
    - xarray object of increments, same shape/coords as ``ds``.
    """
    ds_inc = ds.diff(dim=timekey, label="upper")
    # BUG FIX: the leading zero slice previously used the hard-coded
    # dimension name (ds.isel(time=0)), which broke for any other
    # `timekey`; index by the parameter via a dict instead
    first_zero = ds.isel({timekey: 0}) * 0
    return xr.concat([first_zero, ds_inc], dim=timekey)