Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ regions/Norway/data/

logs/
venv/
publish/

*.parquet
*.pt
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ build:
- pip install poetry
- poetry config virtualenvs.create false
post_install:
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --with docs
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --with docs --without publish,viz,advanced


# Build documentation in the "docs/" directory with Sphinx
Expand Down
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@

templates_path = ["_templates"]

# Print timing report
duration_show_summary = True

# -- Options for EPUB output
epub_show_urls = "footnote"

Expand Down
12 changes: 11 additions & 1 deletion massbalancemachine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

sys.path.append(os.path.dirname(os.path.realpath(__file__)))

import importlib

__all__ = [
"dataloader",
"models",
Expand All @@ -24,6 +26,14 @@
import training
import plots
import metrics
import sampling

# import sampling # Do not import by default since this is an advanced feature
import utils
from .config import * # Load config at the top level of the package


# Import only if the user asks for it
def __getattr__(name):
if name == "sampling":
return importlib.import_module(".sampling", __name__)
raise AttributeError(f"module {__name__} has no attribute {name}")
3 changes: 3 additions & 0 deletions massbalancemachine/data_preprocessing/wgms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def build_monthly_data(data, cfg, rgi_region=None):
~((data.RGIId == "RGI60-01.23646") & (data.MONTH_DIFF == 24))
] # West Yakutat

# Discard points before 1950 since ERA5 Land does not cover this period
data = data[data.YEAR > 1950]

region_name = get_region_name(rgi_region)
dataset = Dataset(cfg, data=data, region_name=region_name, region_id=rgi_region)

Expand Down
9 changes: 5 additions & 4 deletions massbalancemachine/dataloader/DataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ def set_train_test_split(
*,
test_size: float = None,
type_fold: str = "group-meas-id",
random_state: bool = False,
use_random_seed: bool = True,
) -> Tuple[Iterator[Any], Iterator[Any]]:
"""
Split the dataset into training and testing sets.

Args:
test_size (float): Proportion of the dataset to include in the test split.
type_fold (str): Type of splitting between train and test sets. Options are 'group-rgi','group-c_region' or 'group-meas-id'.
use_random_seed (bool): Whether the random seed should be used or not. This allows having reproducible splits.

Returns:
Tuple[Iterator[Any], Iterator[Any]]: Iterators for training and testing indices.
Expand All @@ -96,13 +97,13 @@ def set_train_test_split(
X, y, glacier_ids, stake_meas_id, regions = self._prepare_data_for_cv(
self.data, self.meta_data_columns
)
if random_state == False:
if use_random_seed:
gss = GroupShuffleSplit(
n_splits=1,
test_size=test_size,
random_state=self.random_seed, # commenting this improve randomness
random_state=self.random_seed,
)
elif random_state == True:
else:
gss = GroupShuffleSplit(
n_splits=1,
test_size=test_size,
Expand Down
13 changes: 10 additions & 3 deletions massbalancemachine/models/TorchNeuralNetworkRegressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@

def createModel(cfg, modelParams):
nInp = len(cfg.featureColumns)
dropout = modelParams.get("dropout", 0.0)
if modelParams["type"] == "sequential":
assert len(modelParams["layers"]) > 0
l = [nn.Linear(nInp, modelParams["layers"][0])]
for i in range(len(modelParams["layers"]) - 1):
l.append(nn.ReLU())
if dropout > 0:
l.append(nn.Dropout(dropout))
l.append(nn.Linear(modelParams["layers"][i], modelParams["layers"][i + 1]))
l.append(nn.ReLU())
if dropout > 0:
l.append(nn.Dropout(dropout))
l.append(nn.Linear(modelParams["layers"][-1], 1))
network = nn.Sequential(*l)
return network
Expand Down Expand Up @@ -178,12 +183,14 @@ def cumulative_pred(self):
# TODO: implement
pass

def evaluate_group_pred(self, geodataloader):
def evaluate_group_pred(self, geodataloader, val=False):
grouped_ids = pd.DataFrame()
with torch.no_grad():
for g in geodataloader.glaciers():
iterator = geodataloader.glaciersVal if val else geodataloader.glaciers
stakeMethod = geodataloader.stakesVal if val else geodataloader.stakes
for g in iterator():
# Get input features, metadata and ground truth
stakes, metadata, point_balance = geodataloader.stakes(g)
stakes, metadata, point_balance = stakeMethod(g)
idAggr = metadata["ID"].values

# Make prediction
Expand Down
1 change: 1 addition & 0 deletions massbalancemachine/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from plots.profile_plots import profilePerGlacier
from plots.temporal_plots import cumulatedMassChange
from plots.input_plot import histogram_mb, scatterplot_mb
from plots.map_plots import mapGlacier
from plots.train_plots import plot_training_history
from plots.style import use_mbm_style, COLOR_ANNUAL, COLOR_WINTER
68 changes: 68 additions & 0 deletions massbalancemachine/plots/map_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import xarray as xr
import pyproj
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import data_processing


def mapGlacier(df, rgi_id, year, cfg, ax=None, max_abs=None, title=None, gdir=None):

df_glacier_year = df[(df.RGIId == rgi_id) & (df.YEAR == year)]

if gdir is None:
# Initialize the OGGM Config
data_processing.oggm_utils._initialize_oggm_config("")
gdir = data_processing.oggm_utils._initialize_glacier_directories(
[rgi_id], cfg
)[0]

with xr.open_dataset(gdir.get_filepath("gridded_data")) as ds:
ds = ds.load()

# Coordinate transformation from WGS84 to the projection of OGGM data
transf = pyproj.Transformer.from_proj(
pyproj.CRS.from_user_input("EPSG:4326"),
pyproj.CRS.from_user_input(ds.pyproj_srs),
always_xy=True,
)
lon = df_glacier_year["POINT_LON"].to_numpy()
lat = df_glacier_year["POINT_LAT"].to_numpy()
x, y = transf.transform(lon, lat)

# Convert projected coordinates to nearest grid indices
col = np.round((x - ds.x.values[0]) / (ds.x.values[1] - ds.x.values[0])).astype(int)
row = np.round((y - ds.y.values[0]) / (ds.y.values[1] - ds.y.values[0])).astype(int)

# Make an empty grid matching OGGM's gridded_data
heat = xr.full_like(ds.topo, np.nan)
valid = (row >= 0) & (row < ds.sizes["y"]) & (col >= 0) & (col < ds.sizes["x"])
assert all(valid), "Some of the projected points fall outside of the OGGM grid"
heat.values[row[valid], col[valid]] = df_glacier_year.loc[valid, "pred"].to_numpy()

# Get background topography
smap = ds.salem.get_map(countries=False)
smap.set_shapefile(gdir.read_shapefile("outlines"))
smap.set_topography(ds.topo.data)

# Build color normalization (white is MB=0)
max_abs = max_abs or df_glacier_year["pred"].abs().max()
norm = mcolors.TwoSlopeNorm(vmin=-max_abs, vcenter=0, vmax=max_abs)

if ax is None:
fig, ax = plt.subplots(figsize=(9, 9))
else:
fig = None

# Plot annual MB
smap.set_cmap("RdBu")
smap.set_norm(norm)
smap.set_data(heat)
smap.plot(ax=ax)
smap.append_colorbar(ax=ax, label="Annual MB (m.w.e.)")
ax.set_title(title or f"{rgi_id} year {year}")

plt.tight_layout()

return fig
6 changes: 4 additions & 2 deletions massbalancemachine/plots/perf_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,10 @@ def predVSTruthGlacierWide(
geoErr,
ax=None,
title="Glacier wide MB",
ax_xlim=(-1.5, 0.5),
ax_xlim=(-1.5, 1.0),
ax_ylim=(-1.5, 1.0),
color="orange",
legend=False,
):

if ax is None:
Expand All @@ -305,7 +306,8 @@ def predVSTruthGlacierWide(
ax.errorbar(
geoTarget[g], geoPred[g], xerr=2 * geoErr[g], label=g, fmt="o", color=color
)
plt.text(geoTarget[g] + 0.02, geoPred[g] + 0.02, g, fontsize=10)
if legend:
plt.text(geoTarget[g] + 0.02, geoPred[g] + 0.02, g, fontsize=10)

# Diagonal line
pt = (0, 0)
Expand Down
23 changes: 21 additions & 2 deletions massbalancemachine/plots/temporal_plots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from calendar import month_abbr
import time

Expand Down Expand Up @@ -68,7 +69,7 @@ def cumulatedMassChange(
first_year = np.sort(monthly_df.YEAR.unique())[0]
t = np.concatenate([[first_year], t])
y = np.concatenate([[0.0], y])
ax.plot(t, np.cumsum(y), color=color_pred)
(line,) = ax.plot(t, np.cumsum(y), color=color_pred)

nyear = monthly_df.YEAR.nunique()
if geo is not None and test_gl in geo:
Expand All @@ -95,6 +96,24 @@ def cumulatedMassChange(
glacier_title = titles.get(test_gl) if titles is not None else None
ax.set_title(glacier_title or test_gl.capitalize(), fontsize=20)

ax.tick_params(axis="x", labelsize=12)
ax.tick_params(axis="y", labelsize=12)
step_years_xticks = nyear // 10
ax.set_xticks(
np.arange(
first_year, first_year + nyear + step_years_xticks, step_years_xticks
)
)
ax.xaxis.set_major_formatter(FormatStrFormatter("%.0f"))

# Remove unused axes
for i in range(len(custom_order), len(axs)):
if isinstance(axs, list):
ax = axs[i]
else:
ax = axs.flatten()[i]
ax.set_visible(False)

# # Set axes limits
# if ax_xlim is not None:
# ax.set_xlim(ax_xlim)
Expand All @@ -103,4 +122,4 @@ def cumulatedMassChange(

plt.tight_layout()

return fig
return fig, line
1 change: 1 addition & 0 deletions massbalancemachine/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
loadBestModel,
compute_stake_loss,
assessOnTest,
assessOnVal,
eval_geodetic,
)
Loading
Loading