From 620ed2a065efcf9ba5a298eee946feccd1bbb5bd Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Tue, 25 Feb 2025 18:33:01 +0100 Subject: [PATCH 01/12] created new branch and brought changes from old cleanup branch --- climatem/model/train_model.py | 60 +++++++++++++------------- climatem/plotting/plot_model_output.py | 30 ++++++------- configs/single_param_file.json | 20 ++++----- scripts/main_picabu.py | 0 scripts/run_single_jsonfile.sh | 4 +- 5 files changed, 58 insertions(+), 56 deletions(-) mode change 100644 => 100755 scripts/main_picabu.py mode change 100644 => 100755 scripts/run_single_jsonfile.sh diff --git a/climatem/model/train_model.py b/climatem/model/train_model.py index ad816a8..09f44d8 100644 --- a/climatem/model/train_model.py +++ b/climatem/model/train_model.py @@ -356,7 +356,7 @@ def trace_handler(p): # Todo propagate the path! if not self.plot_params.savar: self.plotter.save_coordinates_and_adjacency_matrices(self) - torch.save(self.model.state_dict(), self.save_path / "model.pth") + torch.save(self.model.module.state_dict(), self.save_path / "model.pth") # try to use the accelerator.save function here self.accelerator.save_state(output_dir=self.save_path) @@ -490,9 +490,9 @@ def train_step(self): # also make the proper prediction, not the reconstruction as we do above # we have to take care here to make sure that we have the right tensors with requires_grad - y_pred, y_spare, z_spare, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y_spare, z_spare, pz_mu, pz_std = self.model.module.predict(x, y) # I was hoping to do this with no_grad, but I do actually need it for the crps loss. - px_mu, px_std = self.model.predict_pxmu_pxstd(x, y) + px_mu, px_std = self.model.module.predict_pxmu_pxstd(x, y) # compute regularisations constraints/penalties (sparsity and connectivity) if self.optim_params.use_sparsity_constraint: @@ -511,7 +511,7 @@ def train_step(self): h_acyclic = torch.tensor([0.0]) if self.instantaneous and not self.converged: h_acyclic = self.get_acyclicity_violation() - h_ortho = self.get_ortho_violation(self.model.autoencoder.get_w_decoder()) + h_ortho = self.get_ortho_violation(self.model.module.autoencoder.get_w_decoder()) # compute total loss - here we are removing the sparsity regularisation as we are using the constraint here. loss = nll + connect_reg + sparsity_reg @@ -547,10 +547,10 @@ def train_step(self): ), self.train_params.lr # projection of the gradient for w - if self.model.autoencoder.use_grad_project and not self.no_w_constraint: + if self.model.module.autoencoder.use_grad_project and not self.no_w_constraint: with torch.no_grad(): - self.model.autoencoder.get_w_decoder().clamp_(min=0.0) - assert torch.min(self.model.autoencoder.get_w_decoder()) >= 0.0 + self.model.module.autoencoder.get_w_decoder().clamp_(min=0.0) + assert torch.min(self.model.module.autoencoder.get_w_decoder()) >= 0.0 self.train_loss = loss.item() self.train_nll = nll.item() @@ -679,7 +679,7 @@ def train_step(self): # Validation step here. 
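# Note: the recurring `self.model.X` -> `self.model.module.X` change in this patch is the
# usual pattern once the model is wrapped (e.g. by DistributedDataParallel through
# Accelerate), which exposes the underlying network under `.module`. A minimal sketch of a
# more defensive accessor, assuming the wrapper may or may not be present (hypothetical
# helper, not part of this patch):
#
#     def unwrapped(model):
#         # Return the underlying module if wrapped, otherwise the model itself.
#         return getattr(model, "module", model)
#
#     unwrapped(self.model).eval()
#     torch.save(unwrapped(self.model).state_dict(), self.save_path / "model.pth")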
def valid_step(self): - self.model.eval() + self.model.module.eval() with torch.no_grad(): # sample data @@ -713,7 +713,7 @@ def valid_step(self): # h_ortho = torch.tensor([0.]) if self.instantaneous and not self.converged: h_acyclic = self.get_acyclicity_violation() - h_ortho = self.get_ortho_violation(self.model.autoencoder.get_w_decoder()) + h_ortho = self.get_ortho_violation(self.model.module.autoencoder.get_w_decoder()) h_sparsity = self.get_sparsity_violation( lower_threshold=0.05, upper_threshold=self.optim_params.sparsity_upper_threshold @@ -855,8 +855,8 @@ def threshold(self): Convert it to a binary graph and fix it. """ with torch.no_grad(): - thresholded_adj = (self.model.get_adj() > 0.5).type(torch.Tensor) - self.model.mask.fix(thresholded_adj) + thresholded_adj = (self.model.module.get_adj() > 0.5).type(torch.Tensor) + self.model.module.mask.fix(thresholded_adj) self.thresholded = True print("Thresholding ================") @@ -907,16 +907,16 @@ def log_losses(self): self.gamma_sparsity_list.append(self.ALM_sparsity.gamma) self.adj_tt[int(self.iteration / self.train_params.valid_freq)] = ( - self.model.get_adj() + self.model.module.get_adj() ) # .cpu().detach().numpy() - w = self.model.autoencoder.get_w_decoder() # .cpu().detach().numpy() + w = self.model.module.autoencoder.get_w_decoder() # .cpu().detach().numpy() if not self.no_gt: self.adj_w_tt[int(self.iteration / self.train_params.valid_freq)] = w # here we just plot the first element of the logvar_decoder and logvar_encoder - self.logvar_decoder_tt.append(self.model.autoencoder.logvar_decoder[0].item()) - self.logvar_encoder_tt.append(self.model.autoencoder.logvar_encoder[0].item()) - self.logvar_transition_tt.append(self.model.transition_model.logvar[0, 0].item()) + self.logvar_decoder_tt.append(self.model.module.autoencoder.logvar_decoder[0].item()) + self.logvar_encoder_tt.append(self.model.module.autoencoder.logvar_encoder[0].item()) + self.logvar_transition_tt.append(self.model.module.transition_model.logvar[0, 0].item()) def print_results(self): """Print values of many variable: losses, constraint violation, etc. 
@@ -955,7 +955,7 @@ def get_nll(self, x, y, z=None) -> torch.Tensor: def get_regularisation(self) -> float: if self.iteration > self.optim_params.schedule_reg: - adj = self.model.get_adj() + adj = self.model.module.get_adj() reg = self.optim_params.reg_coeff * torch.norm(adj, p=1) # reg /= adj.numel() else: @@ -965,7 +965,7 @@ def get_regularisation(self) -> float: def get_acyclicity_violation(self) -> torch.Tensor: if self.iteration > 0: - adj = self.model.get_adj()[-1].view(self.d * self.d_z, self.d * self.d_z) + adj = self.model.module.get_adj()[-1].view(self.d * self.d_z, self.d * self.d_z) h = compute_dag_constraint(adj) / self.acyclic_constraint_normalization else: h = torch.tensor([0.0]) @@ -1007,7 +1007,7 @@ def get_sparsity_violation(self, lower_threshold, upper_threshold) -> float: if self.iteration > self.optim_params.schedule_sparsity: # first get the adj - adj = self.model.get_adj() + adj = self.model.module.get_adj() sum_of_connections = torch.norm(adj, p=1) / self.sparsity_normalization # print('constraint value, before I subtract a threshold from it:', sum_of_connections) @@ -1204,7 +1204,7 @@ def connectivity_reg_complete(self): Not used yet - could be interesting :) """ c = torch.tensor([0.0]) - w = self.model.autoencoder.get_w_encoder() + w = self.model.module.autoencoder.get_w_encoder() d = self.data.distances for i in self.d: for k in self.d_z: @@ -1214,7 +1214,7 @@ def connectivity_reg_complete(self): def connectivity_reg(self, ratio: float = 0.0005): """Calculate a connectivity regularisation only on a subsample of the complete data.""" c = torch.tensor([0.0]) - w = self.model.autoencoder.get_w_encoder() + w = self.model.module.autoencoder.get_w_encoder() n = int(self.d_x * ratio) points = np.random.choice(np.arange(self.d_x), n) @@ -1248,7 +1248,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = timesteps: int, the number of timesteps to predict into the future autoregressively """ - self.model.eval() + self.model.module.eval() if not valid: @@ -1269,11 +1269,11 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # ensure these are correct with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) # Here we predict, but taking 100 samples from the latents # TODO: make this into an argument - samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 10) + samples_from_xs, samples_from_zs, y = self.model.module.predict_sample(x, y, 10) # append the first prediction predictions.append(y_pred) @@ -1314,7 +1314,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # then predict the next timestep # y at this point is pointless!!! 
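# Schematic of the autoregressive rollout this method performs: predict one step ahead,
# append the prediction, then slide the input window forward and repeat for `timesteps`
# steps. The sketch below is illustrative only; tensor shapes and the exact window update
# are assumptions, not the repository's API.
#
#     window = x  # last `tau` observed timesteps
#     for _ in range(timesteps):
#         y_hat, *_ = self.model.module.predict(window, y)   # one-step-ahead prediction
#         predictions.append(y_hat)
#         window = torch.cat([window[:, 1:], y_hat.unsqueeze(1)], dim=1)  # drop oldest step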
with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) # append the prediction predictions.append(y_pred) @@ -1390,7 +1390,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # save the model in its current state print("Saving the model, since the spatial spectra score is the best we have seen for all variables.") - torch.save(self.model.state_dict(), self.save_path / "best_model_for_average_spectra.pth") + torch.save(self.model.module.state_dict(), self.save_path / "best_model_for_average_spectra.pth") else: @@ -1412,10 +1412,10 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # swap with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) # predict and take 100 samples too - samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 100) + samples_from_xs, samples_from_zs, y = self.model.module.predict_sample(x, y, 100) # make a copy of y_pred, which is a tensor x_original = x.clone().detach() @@ -1445,7 +1445,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = with torch.no_grad(): # then predict the next timestep - y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) np.save(self.save_path / f"val_x_ar_{i}.npy", x.detach().cpu().numpy()) np.save(self.save_path / f"val_y_ar_{i}.npy", y.detach().cpu().numpy()) @@ -1528,7 +1528,7 @@ def particle_filter(self, x, y, num_particles, timesteps=120): for _ in range(timesteps): # Prediction # make all the new predictions, taking samples from the latents - _, samples_from_zs, y = self.model.predict_sample(x, y, 100) + _, samples_from_zs, y = self.model.module.predict_sample(x, y, 100) # then calculate the score of each of the samples # Update the weights, where we want the weights to increase as the score improves diff --git a/climatem/plotting/plot_model_output.py b/climatem/plotting/plot_model_output.py index 71327fe..ae37b8e 100644 --- a/climatem/plotting/plot_model_output.py +++ b/climatem/plotting/plot_model_output.py @@ -49,13 +49,13 @@ def save(self, learner): if learner.latent: # save matrix W of the decoder and encoder print("Saving the decoder, encoder and graphs.") - w_decoder = learner.model.autoencoder.get_w_decoder().cpu().detach().numpy() + w_decoder = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() np.save(learner.plots_path / "w_decoder.npy", w_decoder) - w_encoder = learner.model.autoencoder.get_w_encoder().cpu().detach().numpy() + w_encoder = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() np.save(learner.plots_path / "w_encoder.npy", w_encoder) # save the graphs G - adj = learner.model.get_adj().cpu().detach().numpy() + adj = learner.model.module.get_adj().cpu().detach().numpy() np.save(learner.plots_path / "graphs.npy", adj) def load(self, exp_path, data_loader): @@ -168,12 +168,12 @@ def plot(self, learner, save=False): ) # plot the adjacency matrix (learned vs ground-truth) - adj = learner.model.get_adj().cpu().detach().numpy() + adj = learner.model.module.get_adj().cpu().detach().numpy() if not learner.no_gt: if learner.latent: # for latent models, find the right permutation of the latent - adj_w = learner.model.autoencoder.get_w_decoder().cpu().detach().numpy() - adj_w2 = 
learner.model.autoencoder.get_w_encoder().cpu().detach().numpy() + adj_w = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() + adj_w2 = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() # variables using MCC if learner.debug_gt_z: gt_dag = learner.gt_dag @@ -210,8 +210,8 @@ def plot(self, learner, save=False): gt_w = None # for latent models, find the right permutation of the latent - adj_w = learner.model.autoencoder.get_w_decoder().cpu().detach().numpy() - adj_w2 = learner.model.autoencoder.get_w_encoder().cpu().detach().numpy() + adj_w = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() + adj_w2 = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() # this is where this was before, but I have now added the argument names for myself if learner.plot_params.savar: @@ -368,12 +368,12 @@ def plot_sparsity(self, learner, save=False): # plot the adjacency matrix (learned vs ground-truth) # Here if SAVAR, learner should have GT and gt_dag should be the SAVAR GT - adj = learner.model.get_adj().cpu().detach().numpy() + adj = learner.model.module.get_adj().cpu().detach().numpy() if not learner.no_gt: if learner.latent: # for latent models, find the right permutation of the latent - adj_w = learner.model.autoencoder.get_w_decoder().cpu().detach().numpy() - adj_w2 = learner.model.autoencoder.get_w_encoder().cpu().detach().numpy() + adj_w = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() + adj_w2 = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() # variables using MCC if learner.debug_gt_z: gt_dag = learner.gt_dag @@ -410,8 +410,8 @@ def plot_sparsity(self, learner, save=False): gt_w = None # for latent models, find the right permutation of the latent - adj_w = learner.model.autoencoder.get_w_decoder().cpu().detach().numpy() - adj_w2 = learner.model.autoencoder.get_w_encoder().cpu().detach().numpy() + adj_w = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() + adj_w2 = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() # this is where this was before, but I have now added the argument names for myself if learner.plot_params.savar: @@ -1076,8 +1076,8 @@ def save_coordinates_and_adjacency_matrices(self, learner): if not os.path.exists(learner.plots_path / "coordinates.npy"): np.save(learner.plots_path / "coordinates.npy", learner.coordinates) - adj_w = learner.model.autoencoder.get_w_decoder().cpu().detach().numpy() - adj_encoder_w = learner.model.autoencoder.get_w_encoder().cpu().detach().numpy() + adj_w = learner.model.module.autoencoder.get_w_decoder().cpu().detach().numpy() + adj_encoder_w = learner.model.module.autoencoder.get_w_encoder().cpu().detach().numpy() np.save(learner.plots_path / f"adj_encoder_w_{learner.iteration}.npy", adj_encoder_w) np.save(learner.plots_path / f"adj_w_{learner.iteration}.npy", adj_w) diff --git a/configs/single_param_file.json b/configs/single_param_file.json index 78757fd..c99a011 100644 --- a/configs/single_param_file.json +++ b/configs/single_param_file.json @@ -1,12 +1,12 @@ { "exp_params": { - "exp_path": "/network/scratch/j/julien.boussard/results/climatem_spectral/var_ts_picontrol", + "exp_path": "/home/kit/iti/qa4548/my_projects/climatem/results/climatem_spectral/var_ts_picontrol", "_target_": "emulator.src.datamodules.climate_datamodule.ClimateDataModule", "latent": true, - "d_z": 90, - "d_x": 6250, - "lon": 144, - "lat": 96, + "d_z": 4, + "d_x": 400, + "lon": 20, + "lat": 20, "tau": 5, 
"random_seed": 1, "gpu": true, @@ -15,11 +15,11 @@ "verbose": true }, "data_params": { - "data_dir": "/network/scratch/j/julien.boussard/data/Climateset_DATA_TEST", - "climateset_data": "/network/scratch/j/julien.boussard/data/icosahedral_data/structured/picontrol/24_ni", + "data_dir": "/home/kit/iti/qa4548/my_projects/climatem/data/Climateset_DATA_TEST", + "climateset_data": "/home/kit/iti/qa4548/my_projects/climatem/data/icosahedral_data/picontrol/24_ni", "reload_climate_set_data": false, - "icosahedral_coordinates_path": "/home/mila/j/julien.boussard/causal_model/climatem/mappings/vertex_lonlat_mapping.npy", - "in_var_ids": ["ts"], + "icosahedral_coordinates_path": "/home/kit/iti/qa4548/my_projects/climatem/mappings/vertex_lonlat_mapping.npy", + "in_var_ids": ["savar"], "out_var_ids": ["ts"], "num_levels": 1, "temp_res": "mon", @@ -34,7 +34,7 @@ "data_format": "numpy", "ishdf5": false, "global_normalization": true, - "seasonality_removal": true, + "seasonality_removal": false, "batch_size": 128, "eval_batch_size": 128, "channels_last": false, diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py old mode 100644 new mode 100755 diff --git a/scripts/run_single_jsonfile.sh b/scripts/run_single_jsonfile.sh old mode 100644 new mode 100755 index e4d8504..5e17683 --- a/scripts/run_single_jsonfile.sh +++ b/scripts/run_single_jsonfile.sh @@ -15,9 +15,11 @@ module purge # 1. Load the required modules -module --quiet load python/3.10 +module load compiler/intel/2021.4.0 +module load devel/python/3.10.5_intel_2021.4.0 # 2. Load your environment +source $HOME/my_projects/climatem/env_emulator_climatem/bin/activate source $HOME/env_climatem/bin/activate From 01e0c19b1fb8876e61949434967fcfac9711f0e2 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Wed, 26 Feb 2025 22:13:12 +0100 Subject: [PATCH 02/12] added permutation for each plotting interval --- climatem/model/train_model.py | 4 +++ climatem/plotting/plot_model_output.py | 46 +++++++++++++++++++++++++- scripts/main_picabu.py | 2 ++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/climatem/model/train_model.py b/climatem/model/train_model.py index 09f44d8..006d405 100644 --- a/climatem/model/train_model.py +++ b/climatem/model/train_model.py @@ -20,12 +20,14 @@ def __init__( self, model, datamodule, + data_params, exp_params, gt_params, model_params, train_params, optim_params, plot_params, + savar_params, save_path, plots_path, best_metrics, @@ -43,10 +45,12 @@ def __init__( self.data_loader_train = iter(datamodule.train_dataloader(accelerator=accelerator)) self.data_loader_val = iter(datamodule.val_dataloader()) self.coordinates = datamodule.coordinates + self.data_params = data_params self.exp_params = exp_params self.train_params = train_params self.optim_params = optim_params self.plot_params = plot_params + self.savar_params = savar_params self.best_metrics = best_metrics self.save_path = save_path self.plots_path = plots_path diff --git a/climatem/plotting/plot_model_output.py b/climatem/plotting/plot_model_output.py index ae37b8e..798f01a 100644 --- a/climatem/plotting/plot_model_output.py +++ b/climatem/plotting/plot_model_output.py @@ -363,6 +363,12 @@ def plot_sparsity(self, learner, save=False): path=learner.plots_path, ) + # Load gt mode weights + savar_folder = learner.data_params.data_dir + savar_fname = 
f"modes_{learner.d_z}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}" + # Get the gt mode weights + modes_gt = np.load(savar_folder + f"/{savar_fname}_mode_weights.npy") + # TODO: plot the prediction vs gt # plot_compare_prediction(x, x_hat) @@ -416,21 +422,29 @@ def plot_sparsity(self, learner, save=False): # this is where this was before, but I have now added the argument names for myself if learner.plot_params.savar: self.plot_adjacency_matrix( + learner, mat1=adj, # Below savar dag mat2=learner.datamodule.savar_gt_adj, + modes_gt=modes_gt, + modes_inferred=adj_w, path=learner.plots_path, name_suffix="transition", + savar=True, no_gt=False, iteration=learner.iteration, plot_through_time=learner.plot_params.plot_through_time, ) else: self.plot_adjacency_matrix( + learner, mat1=adj, mat2=gt_dag, + modes_gt=gt_w, + modes_inferred=adj_w, path=learner.plots_path, name_suffix="transition", + savar=False, no_gt=learner.no_gt, iteration=learner.iteration, plot_through_time=learner.plot_params.plot_through_time, @@ -892,7 +906,7 @@ def plot_savar_feature_maps( ): grid_shape = (learner.lat, learner.lon) - + print("SAVAR plot is called") w_adj = w_adj[0] # Now w_adj_mean should be (lat*lon, num_latents) d_z = w_adj.shape[1] @@ -1085,10 +1099,14 @@ def save_coordinates_and_adjacency_matrices(self, learner): # simply follow the lead of the function above, and try to plot through time. def plot_adjacency_matrix( self, + learner, mat1: np.ndarray, mat2: np.ndarray, + modes_gt, + modes_inferred, path, name_suffix: str, + savar, no_gt: bool = False, iteration: int = 0, plot_through_time: bool = True, @@ -1102,8 +1120,34 @@ def plot_adjacency_matrix( name_suffix: suffix for the name of the plot no_gt: if True, does not use the ground-truth graph """ + + lat = learner.lat + lon = learner.lon tau = mat1.shape[0] + if savar and modes_gt is not None and modes_inferred is not None: + # Find the permutation + modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) + + # Get the flat index of the maximum for each mode + idx_gt_flat = np.argmax(modes_gt.reshape(modes_gt.shape[0], -1), axis=1) # shape: (n_modes,) + idx_inferred_flat = np.argmax(modes_inferred.reshape(modes_inferred.shape[0], -1), axis=1) # shape: (n_modes,) + + # Convert flat indices to 2D coordinates (row, col) + idx_gt = np.array([np.unravel_index(i, (lat, lon)) for i in idx_gt_flat]) # shape: (n_modes, 2) + idx_inferred = np.array([np.unravel_index(i, (lat, lon)) for i in idx_inferred_flat]) # shape: (n_modes, 2) + + # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix + permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) + print("permutation_list:", permutation_list) + + # Permute + for k in range(tau): + mat1[k] = mat1[k][np.ix_(permutation_list, permutation_list)] + + print("PERMUTED THE MATRICES") + + subfig_names = [ f"Learned, latent dimensions = {mat1.shape[1], mat1.shape[2]}", "Ground Truth", diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py index f30113d..71d48e9 100755 --- a/scripts/main_picabu.py +++ b/scripts/main_picabu.py @@ -204,12 +204,14 @@ def main( trainer = TrainingLatent( model, datamodule, + data_params, experiment_params, gt_params, model_params, 
train_params, optim_params, plot_params, + savar_params, save_path, plots_path, best_metrics, From 8aa42bee0bd7aebc301b08aba2595599dbbfd98b Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Thu, 27 Feb 2025 23:05:05 +0100 Subject: [PATCH 03/12] fixed time-inverted gt matrix --- climatem/plotting/plot_model_output.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/climatem/plotting/plot_model_output.py b/climatem/plotting/plot_model_output.py index 798f01a..5923357 100644 --- a/climatem/plotting/plot_model_output.py +++ b/climatem/plotting/plot_model_output.py @@ -365,7 +365,8 @@ def plot_sparsity(self, learner, save=False): # Load gt mode weights savar_folder = learner.data_params.data_dir - savar_fname = f"modes_{learner.d_z}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}" + n_modes = learner.savar_params.n_per_col ** 2 + savar_fname = f"modes_{n_modes}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}" # Get the gt mode weights modes_gt = np.load(savar_folder + f"/{savar_fname}_mode_weights.npy") @@ -1222,7 +1223,7 @@ def plot_adjacency_matrix( elif row == 1: sns.heatmap( - mat2[tau - i - 1], + mat2[i], ax=axes[i], cbar=False, vmin=-1, @@ -1233,7 +1234,7 @@ def plot_adjacency_matrix( ) elif row == 2: sns.heatmap( - mat1[tau - i - 1] - mat2[tau - i - 1], + mat1[tau - i - 1] - mat2[i], ax=axes[i], cbar=False, vmin=-1, From 6b5be9b16d93ae17ef88f2cb64c7525553d4343e Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Mon, 3 Mar 2025 17:56:15 +0100 Subject: [PATCH 04/12] added overlap of modes from 0 to 1 --- .../synthetic_data/generate_savar_datasets.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index 4f76b29..de3efc5 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -102,7 +102,7 @@ def generate_save_savar_data( n_per_col=2, # Number of components N = n_per_col**2 difficulty="easy", seasonality=False, - overlap=False, + overlap=0, is_forced=False, plotting=True, ): @@ -110,26 +110,40 @@ def generate_save_savar_data( # Setup spatial weights of underlying processes ny = nx = n_per_col * comp_size N = n_per_col**2 # Number of components + + if not (0 <= overlap <= 1): raise ValueError("overlap must be between 0 and 1") + noise_weights = np.zeros((N, nx, ny)) modes_weights = np.zeros((N, nx, ny)) - if overlap: - raise ValueError("SAVAR data with overlapping modes not implemented yet") # Specify the path where you want to save the data npy_name = f"{name}.npy" save_path = save_dir_path / npy_name + # Center starting position (for fully overlapping modes) + center_x_start = (nx - comp_size) // 2 + center_y_start = (ny - comp_size) // 2 + # Create modes weights for k in range(n_per_col): for j in range(n_per_col): + idx = k * n_per_col + j + # Original starting position (no overlap) + orig_x_start = k * comp_size + orig_y_start = j * comp_size + # New starting positions (interpolated between original and central) + new_x_start = int((1 - overlap) * 
orig_x_start + overlap * center_x_start) + new_y_start = int((1 - overlap) * orig_y_start + overlap * center_y_start) + new_x_end = new_x_start + comp_size + new_y_end = new_y_start + comp_size modes_weights[ - k * n_per_col + j, k * comp_size : (k + 1) * comp_size, j * comp_size : (j + 1) * comp_size + idx, new_x_start : new_x_end, new_y_start : new_y_end ] = create_random_mode((comp_size, comp_size), random=True) - for k in range(n_per_col): - for j in range(n_per_col): + #for k in range(n_per_col): + # for j in range(n_per_col): noise_weights[ - k * n_per_col + j, k * comp_size : (k + 1) * comp_size, j * comp_size : (j + 1) * comp_size + idx, new_x_start : new_x_end, new_y_start : new_y_end ] = create_random_mode((comp_size, comp_size), random=True) # This is the probabiliity of having a link between latent k and j, with k different from j. latents always have one link with themselves at a previous time. From 216792393e82f34cf8f1f1e242b86f68ed72e4bf Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Mon, 3 Mar 2025 19:18:49 +0100 Subject: [PATCH 05/12] added forcings, propagated them, added quadratic and exponential forcing, added torch funs at forcing method --- .../synthetic_data/generate_savar_datasets.py | 19 ++++--- climatem/synthetic_data/savar.py | 57 ++++++++++++------- scripts/main_picabu.py | 9 ++- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index de3efc5..567df42 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -104,6 +104,11 @@ def generate_save_savar_data( seasonality=False, overlap=0, is_forced=False, + f_1=1, + f_2=2, + f_time_1=4000, + f_time_2=8000, + ramp_type="linear", plotting=True, ): @@ -163,8 +168,7 @@ def generate_save_savar_data( check_stability(links_coeffs) if is_forced: - raise ValueError("SAVAR data with forcings not implemented yet") - f_1, f_2, f_time_1, f_time_2 = 1, 2, 4000, 8000 # turn off forcing by setting the time to the last time step + # turn off forcing by setting the time to the last time step w_f = modes_weights # A very simple method for adding a focring term (bias on the mean of the noise term) forcing_dict = { @@ -174,6 +178,7 @@ def generate_save_savar_data( "f_time_1": f_time_1, # The period one goes from t=0 to t=f_time_1 "f_time_2": f_time_2, # The period two goes from t= f_time_2 to the end. Between the two periods, the forcing is risen linearly "time_len": time_len, + "ramp_type": ramp_type, } if seasonality: raise ValueError("SAVAR data with seasonality not implemented yet") @@ -223,11 +228,11 @@ def generate_save_savar_data( "T": time_len, "N": N, "links_coeffs": links_coeffs, - # "f_1": f_1, - # "f_2": f_2, - # "f_time_1": f_time_1, - # "f_time_2": f_time_2, - # "time_len": time_len, + "f_1": f_1, + "f_2": f_2, + "f_time_1": f_time_1, + "f_time_2": f_time_2, + "ramp_type": ramp_type, # "season_dict": season_dict, # "seasonality" : True, } diff --git a/climatem/synthetic_data/savar.py b/climatem/synthetic_data/savar.py index d55a8a9..bf79c4d 100644 --- a/climatem/synthetic_data/savar.py +++ b/climatem/synthetic_data/savar.py @@ -239,41 +239,56 @@ def _add_seasonality_forcing(self): self.data_field += seasonal_data_field def _add_external_forcing(self): - - # TODO Make this a torch function - + """ + Adds external forcing to the data field using PyTorch tensors for GPU acceleration. + Allows for both linear and nonlinear ramps. 
+ """ if self.forcing_dict is None: raise TypeError("Forcing dict is empty") w_f = deepcopy(self.forcing_dict.get("w_f")) - f_1 = deepcopy(self.forcing_dict.get("f_1")) - f_2 = deepcopy(self.forcing_dict.get("f_2")) - f_time_1 = deepcopy(self.forcing_dict.get("f_time_1")) - f_time_2 = deepcopy(self.forcing_dict.get("f_time_2")) + f_1 = self.forcing_dict.get("f_1", 0) + f_2 = self.forcing_dict.get("f_2", 0) + f_time_1 = self.forcing_dict.get("f_time_1", 0) + f_time_2 = self.forcing_dict.get("f_time_2", self.time_length) + ramp_type = self.forcing_dict.get("ramp_type", "linear") # Default to linear + # Default w_f to mode weights if not provided if w_f is None: w_f = deepcopy(self.mode_weights) - w_f = w_f.astype(bool).astype(int) # Converts non-zero elements of the weight into 1. + w_f = (w_f != 0).astype(int) # Convert non-zero elements to 1 - w_f_sum = w_f.sum(axis=0) + w_f_sum = torch.tensor(w_f.sum(axis=0), dtype=torch.float32, device="cuda") f_time_1 += self.transient f_time_2 += self.transient - - # Check time_length = self.time_length + self.transient - trend = np.concatenate( - ( - np.repeat([f_1], f_time_1), - np.linspace(f_1, f_2, f_time_2 - f_time_1), - np.repeat([f_2], time_length - f_time_2), - ) - ).reshape((1, time_length)) - forcing_field = (w_f_sum.reshape(1, -1) * trend.transpose()).transpose() - self.forcing_data_field = forcing_field + # Generate the forcing trend using torch tensors + if ramp_type == "linear": + ramp = torch.linspace(f_1, f_2, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + elif ramp_type == "quadratic": + t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + ramp = f_1 + (f_2 - f_1) * t**2 + elif ramp_type == "exponential": + t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + ramp = f_1 + (f_2 - f_1) * (torch.exp(t) - 1) / (torch.exp(torch.tensor(1.0)) - 1) + else: + raise ValueError("Unsupported ramp type. Choose from 'linear', 'quadratic', or 'exponential'.") + + # Generate the forcing trend using torch tensors + trend = torch.cat([ + torch.full((f_time_1,), f_1, dtype=torch.float32, device="cuda"), + ramp, + torch.full((time_length - f_time_2,), f_2, dtype=torch.float32, device="cuda") + ]).reshape(1, time_length) + + # Compute the forcing field on GPU + forcing_field = (w_f_sum.reshape(1, -1) * trend.T).T + self.forcing_data_field = forcing_field.cpu().numpy() # Add it to the data field. 
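# For reference, the three ramps constructed above all interpolate the forcing from f_1 to
# f_2 over the window [f_time_1, f_time_2), with t normalised to [0, 1]:
#   linear:      f(t) = f_1 + (f_2 - f_1) * t
#   quadratic:   f(t) = f_1 + (f_2 - f_1) * t**2
#   exponential: f(t) = f_1 + (f_2 - f_1) * (exp(t) - 1) / (e - 1)
# Outside the window the forcing is held constant at f_1 (before) and f_2 (after).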
- self.data_field += forcing_field + self.data_field += self.forcing_data_field + def _create_linear(self): """Weights N \times L data_field L \times T.""" diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py index 71d48e9..23a702a 100755 --- a/scripts/main_picabu.py +++ b/scripts/main_picabu.py @@ -66,6 +66,7 @@ def main( # generate data and split train/test if experiment_params.gpu and torch.cuda.is_available(): device = "cuda" + print("CUDACUDACUDA") else: device = "cpu" @@ -117,6 +118,11 @@ def main( seasonality=savar_params.seasonality, overlap=savar_params.overlap, is_forced=savar_params.is_forced, + f_1=savar_params.f_1, + f_2=savar_params.f_2, + f_time_1=savar_params.f_time_1, + f_time_2=savar_params.f_time_2, + ramp_type=savar_params.ramp_type, plot_original_data=savar_params.plot_original_data, ) datamodule.setup() @@ -175,7 +181,7 @@ def main( .translate({ord(","): None}) .translate({ord(" "): None}) ) - name = f"var_{data_var_ids_str}_scenarios_{data_params.train_scenarios[0]}_nonlinear_{model_params.nonlinear_mixing}_tau_{experiment_params.tau}_z_{experiment_params.d_z}_lr_{train_params.lr}_bs_{data_params.batch_size}_spreg_{optim_params.reg_coeff}_ormuinit_{optim_params.ortho_mu_init}_spmuinit_{optim_params.sparsity_mu_init}_spthres_{optim_params.sparsity_upper_threshold}_fixed_{model_params.fixed}_num_ensembles_{data_params.num_ensembles}_instantaneous_{model_params.instantaneous}_crpscoef_{optim_params.crps_coeff}_spcoef_{optim_params.spectral_coeff}_tempspcoef_{optim_params.temporal_spectral_coeff}" + name = f"var_{data_var_ids_str}_scenarios_{data_params.train_scenarios[0]}_nonlinear_{model_params.nonlinear_mixing}_tau_{experiment_params.tau}_z_{experiment_params.d_z}_lr_{train_params.lr}_bs_{data_params.batch_size}_spreg_{optim_params.reg_coeff}_ormuinit_{optim_params.ortho_mu_init}_spmuinit_{optim_params.sparsity_mu_init}_spthres_{optim_params.sparsity_upper_threshold}_fixed_{model_params.fixed}_num_ensembles_{data_params.num_ensembles}_instantaneous_{model_params.instantaneous}_crpscoef_{optim_params.crps_coeff}_spcoef_{optim_params.spectral_coeff}_tempspcoef_{optim_params.temporal_spectral_coeff}_overlap_{savar_params.overlap}_forcing_{savar_params.is_forced}" exp_path = exp_path / name os.makedirs(exp_path, exist_ok=True) @@ -192,6 +198,7 @@ def main( hp["train_params"] = train_params.__dict__ hp["model_params"] = model_params.__dict__ hp["optim_params"] = optim_params.__dict__ + hp["savar_params"] = savar_params.__dict__ with open(exp_path / "params.json", "w") as file: json.dump(hp, file, indent=4) From d10862e35b3a57848ec57982e93522aa99591898 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Thu, 3 Apr 2025 19:44:45 +0200 Subject: [PATCH 06/12] nonlinearity and forcings --- climatem/config.py | 14 + climatem/data_loader/causal_datamodule.py | 7 + climatem/data_loader/climate_datamodule.py | 2 + climatem/data_loader/savar_dataset.py | 23 +- climatem/plotting/plot_model_output.py | 55 +++- .../synthetic_data/generate_savar_datasets.py | 11 +- climatem/synthetic_data/savar.py | 242 ++++++++++++++++-- scripts/main_picabu.py | 2 + 8 files changed, 329 insertions(+), 27 deletions(-) diff --git a/climatem/config.py b/climatem/config.py index 47aad0a..aaddcf3 100644 --- a/climatem/config.py +++ b/climatem/config.py @@ -267,6 +267,13 @@ def __init__( seasonality: bool = False, # Seasonality in synthetic data overlap: bool = False, # Modes overlap is_forced: bool = False, # Forcings in synthetic data + f_1: int = 1, + f_2: int = 2, + f_time_1: int = 4000, + 
f_time_2: int = 8000, + ramp_type: str = "linear", + linearity: str = "linear", + poly_degrees: List[int] = [2], plot_original_data: bool = True, ): self.time_len = time_len @@ -277,4 +284,11 @@ def __init__( self.seasonality = seasonality self.overlap = overlap self.is_forced = is_forced + self.f_1 = f_1 + self.f_2 = f_2 + self.f_time_1 = f_time_1 + self.f_time_2 = f_time_2 + self.ramp_type = ramp_type + self.linearity = linearity + self.poly_degrees = poly_degrees self.plot_original_data = plot_original_data diff --git a/climatem/data_loader/causal_datamodule.py b/climatem/data_loader/causal_datamodule.py index 3c3e24f..3a4aaa5 100644 --- a/climatem/data_loader/causal_datamodule.py +++ b/climatem/data_loader/causal_datamodule.py @@ -96,6 +96,13 @@ def setup(self, stage: Optional[str] = None): seasonality=self.hparams.seasonality, overlap=self.hparams.overlap, is_forced=self.hparams.is_forced, + f_1=self.hparams.f_1, + f_2=self.hparams.f_2, + f_time_1=self.hparams.f_time_1, + f_time_2=self.hparams.f_time_2, + ramp_type=self.hparams.ramp_type, + linearity=self.hparams.linearity, + poly_degrees=self.hparams.poly_degrees, plot_original_data=self.hparams.plot_original_data, ) elif ( diff --git a/climatem/data_loader/climate_datamodule.py b/climatem/data_loader/climate_datamodule.py index 5dfd9e1..ed1c01d 100644 --- a/climatem/data_loader/climate_datamodule.py +++ b/climatem/data_loader/climate_datamodule.py @@ -71,6 +71,8 @@ def __init__( seasonality: bool = False, overlap: bool = False, is_forced: bool = False, + linearity: str = "linear", + poly_degrees: List[int] = [2], plot_original_data: bool = True, ): """ diff --git a/climatem/data_loader/savar_dataset.py b/climatem/data_loader/savar_dataset.py index 50aa5f5..886117c 100644 --- a/climatem/data_loader/savar_dataset.py +++ b/climatem/data_loader/savar_dataset.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch @@ -27,6 +27,13 @@ def __init__( seasonality: bool = False, overlap: bool = False, is_forced: bool = False, + f_1: int = 1, + f_2: int = 2, + f_time_1: int = 4000, + f_time_2: int = 8000, + ramp_type: str = "linear", + linearity: str = "linear", + poly_degrees: List[int] = [2,3], plot_original_data: bool = True, ): super().__init__() @@ -51,6 +58,13 @@ def __init__( self.seasonality = seasonality self.overlap = overlap self.is_forced = is_forced + self.f_1 = f_1 + self.f_2 = f_2 + self.f_time_1 = f_time_1 + self.f_time_2 = f_time_2 + self.ramp_type = ramp_type + self.linearity = linearity + self.poly_degrees = poly_degrees self.plot_original_data = plot_original_data if self.reload_climate_set_data: @@ -172,6 +186,13 @@ def get_causal_data( self.seasonality, self.overlap, self.is_forced, + self.f_1, + self.f_2, + self.f_time_1, + self.f_time_2, + self.ramp_type, + self.linearity, + self.poly_degrees, self.plot_original_data, ) time_steps = data.shape[1] diff --git a/climatem/plotting/plot_model_output.py b/climatem/plotting/plot_model_output.py index 5923357..e9d9bcc 100644 --- a/climatem/plotting/plot_model_output.py +++ b/climatem/plotting/plot_model_output.py @@ -10,6 +10,7 @@ import torch from mpl_toolkits.axes_grid1 import make_axes_locatable from mpl_toolkits.basemap import Basemap +import matplotlib.animation as animation from climatem.model.metrics import mcc_latent @@ -364,11 +365,16 @@ def plot_sparsity(self, learner, save=False): ) # Load gt mode weights - savar_folder = learner.data_params.data_dir - n_modes = 
learner.savar_params.n_per_col ** 2 - savar_fname = f"modes_{n_modes}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}" - # Get the gt mode weights - modes_gt = np.load(savar_folder + f"/{savar_fname}_mode_weights.npy") + if learner.plot_params.savar: + savar_folder = learner.data_params.data_dir + n_modes = learner.savar_params.n_per_col ** 2 + savar_fname = f"modes_{n_modes}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}" + # Get the gt mode weights + modes_gt = np.load(savar_folder + f"/{savar_fname}_mode_weights.npy") + if learner.iteration==500: + savar_data = np.load(savar_folder + f"/{savar_fname}.npy") + savar_anim_path = savar_folder + f"/{savar_fname}_original_savar_data.gif" + self.plot_original_savar(savar_data, learner.lat, learner.lon, savar_anim_path) # TODO: plot the prediction vs gt # plot_compare_prediction(x, x_hat) @@ -1429,14 +1435,45 @@ def save_mcc_and_assignement(self, exp_path): fig.savefig(exp_path / "mcc.png") fig.clf() + def plot_original_savar(self, data, lat, lon, path): + """Plotting the original savar data.""" + print(f"data shape {data.shape}") + # Get the dimensions + time_steps = data.shape[1] + data_reshaped = data.T.reshape((time_steps, lat, lon)) + + # Calculate the average over the time axis + avg_data = np.mean(data_reshaped, axis=0) + + # Determine the global min and max from the averaged data for consistent color scaling + vmin = np.min(avg_data) + vmax = np.max(avg_data) + + fig, ax = plt.subplots(figsize=(lon / 10, lat / 10)) + cax = ax.imshow(data_reshaped[0], aspect="auto", cmap="viridis", vmin=vmin, vmax=vmax) + cbar = fig.colorbar(cax, ax=ax) + + def animate(i): + cax.set_data(data_reshaped[i]) + ax.set_title(f"Time step: {i+1}") + return (cax,) + + # Create an animation + ani = animation.FuncAnimation(fig, animate, frames=100, blit=True) + + # Save the animation as a video file + ani.save(path, writer="pillow", fps=10) + + plt.close() + + # # Below are functions used for plotting savar results / metrics. 
Not used yet but could be useful / integrated into the savar pipeline - # def plot_original_savar(self, path, lon, lat, savar_path): + # def plot_original_savar(self, data, lon, lat, path): # """Plotting the original savar data.""" - # data = np.load(f"{savar_path}.npy") - + # print(f"data shape {data.shape}") # # Get the dimensions - # time_steps = data.shape[1] + # time_steps = data.shape[0] # data_reshaped = data.T.reshape((time_steps, lat, lon)) # # Calculate the average over the time axis diff --git a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index 567df42..966f1af 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -5,6 +5,7 @@ import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.stats import beta +import matplotlib.animation as animation from climatem.synthetic_data.savar import SAVAR from climatem.synthetic_data.utils import check_stability, create_random_mode @@ -109,6 +110,8 @@ def generate_save_savar_data( f_time_1=4000, f_time_2=8000, ramp_type="linear", + linearity="polynomial", + poly_degrees=[2,3], plotting=True, ): @@ -233,6 +236,8 @@ def generate_save_savar_data( "f_time_1": f_time_1, "f_time_2": f_time_2, "ramp_type": ramp_type, + "linearity": linearity, + "poly_degrees": poly_degrees, # "season_dict": season_dict, # "seasonality" : True, } @@ -274,6 +279,8 @@ def generate_save_savar_data( noise_strength=noise_val, # How to play with this parameter? # season_dict=season_dict, #turn off by commenting out # forcing_dict=forcing_dict #turn off by commenting out + linearity=linearity, + poly_degrees=poly_degrees, ) else: savar_model = SAVAR( @@ -281,11 +288,13 @@ def generate_save_savar_data( time_length=time_len, mode_weights=modes_weights, noise_strength=noise_val, - # season_dict=season_dict, #turn off by commenting out forcing_dict=forcing_dict, # turn off by commenting out + linearity=linearity, + poly_degrees=poly_degrees, ) savar_model.generate_data() # Remember to generate data, otherwise the data field will be empty np.save(save_path, savar_model.data_field) + print(f"{name} DONE!") return savar_model.data_field diff --git a/climatem/synthetic_data/savar.py b/climatem/synthetic_data/savar.py index bf79c4d..58ff24c 100644 --- a/climatem/synthetic_data/savar.py +++ b/climatem/synthetic_data/savar.py @@ -5,15 +5,17 @@ 2022 The main difference with the provided code is the torch/GPU implementation which considerably speeds up the data generation process """ - +from typing import List +import seaborn as sns import itertools as it from copy import deepcopy from math import pi, sin - +import matplotlib.pyplot as plt import numpy as np import torch from torch.distributions.multivariate_normal import MultivariateNormal from tqdm.auto import tqdm +import torch.nn as nn def dict_to_matrix(links_coeffs, default=0): @@ -37,9 +39,6 @@ def dict_to_matrix(links_coeffs, default=0): return graph -### - - class SAVAR: """Main class containing SAVAR model.""" @@ -64,8 +63,10 @@ class SAVAR: "seasonal_data_field", "forcing_data_field", "linearity", + "poly_degrees", "verbose", "model_seed", + "nnar_model", ] def __init__( @@ -81,12 +82,13 @@ def __init__( latent_noise_cov: np.ndarray = None, fast_cov: np.ndarray = None, forcing_dict: dict = None, + linearity: str = "linear", + poly_degrees: List[int] = [2], season_dict: dict = None, data_field: np.ndarray = None, noise_data_field: np.ndarray = None, seasonal_data_field: 
np.ndarray = None, forcing_data_field: np.ndarray = None, - linearity: str = "linear", verbose: bool = False, model_seed: int = None, ): @@ -106,10 +108,11 @@ def __init__( self.forcing_dict = forcing_dict self.season_dict = season_dict + self.linearity = linearity + self.poly_degrees = poly_degrees self.data_field = data_field - self.linearity = linearity self.verbose = verbose self.model_seed = model_seed @@ -119,7 +122,7 @@ def __init__( self.tau_max = max(abs(lag) for (_, lag), _ in it.chain.from_iterable(self.links_coeffs.values())) self.spatial_resolution = deepcopy(self.mode_weights.reshape(self.n_vars, -1).shape[1]) print("spatial-resolution done") - + if self.noise_weights is None: self.noise_weights = deepcopy(self.mode_weights) if self.latent_noise_cov is None: @@ -136,7 +139,7 @@ def __init__( if np.random is not None: np.random.seed(model_seed) - def generate_data(self) -> None: + def generate_data(self, train_nnar=True) -> None: """Generates the data of savar :return:""" # Prepare the datafield if self.data_field is None: @@ -165,7 +168,12 @@ def generate_data(self) -> None: if self.forcing_dict is not None: if self.verbose: print("Adding external forcing") + initial_data = self.data_field.copy() self._add_external_forcing() + diff = self.data_field - initial_data + print(f"Max change in data field: {diff.max()}") + print(f"Mean change in data field: {diff.mean()}") + print(f"Sample values after forcing applied:\n{diff[:, :5]}") else: print("No forcing") @@ -174,8 +182,17 @@ def generate_data(self) -> None: if self.verbose: print("Creating linear data") self._create_linear() + elif self.linearity == "polynomial": + if self.verbose: + print("Creating polynomial data") + self._create_polynomial() else: - raise NotImplementedError("Now, only linear methods are implemented") + if self.verbose: + print("Creating nonlinear data") + if train_nnar: + print("Training NNAR model before data generation...") + self.train_nnar(num_epochs=50, learning_rate=0.001, batch_size=32) + self._create_nonlinear() def generate_cov_noise_matrix(self) -> np.ndarray: """ @@ -247,22 +264,28 @@ def _add_external_forcing(self): raise TypeError("Forcing dict is empty") w_f = deepcopy(self.forcing_dict.get("w_f")) - f_1 = self.forcing_dict.get("f_1", 0) - f_2 = self.forcing_dict.get("f_2", 0) + f_1 = float(self.forcing_dict.get("f_1", 0)) + f_2 = float(self.forcing_dict.get("f_2", 0)) f_time_1 = self.forcing_dict.get("f_time_1", 0) f_time_2 = self.forcing_dict.get("f_time_2", self.time_length) ramp_type = self.forcing_dict.get("ramp_type", "linear") # Default to linear - # Default w_f to mode weights if not provided if w_f is None: w_f = deepcopy(self.mode_weights) w_f = (w_f != 0).astype(int) # Convert non-zero elements to 1 - w_f_sum = torch.tensor(w_f.sum(axis=0), dtype=torch.float32, device="cuda") + print(self.mode_weights.shape) + #w_f = w_f / (w_f.max() + 1e-8) # Normalize to range [0,1] + + # Merge last two dims first => shape (d_z, lat*lon) + temp = w_f.reshape(w_f.shape[0], w_f.shape[1]*w_f.shape[2]) + # sum over dim=0 => shape (lat*lon,) + w_f_sum = torch.tensor(temp.sum(axis=0), dtype=torch.float32, device="cuda") f_time_1 += self.transient f_time_2 += self.transient time_length = self.time_length + self.transient + # Generate the forcing trend using torch tensors if ramp_type == "linear": ramp = torch.linspace(f_1, f_2, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") @@ -272,8 +295,14 @@ def _add_external_forcing(self): elif ramp_type == "exponential": t = torch.linspace(0, 1, f_time_2 
- f_time_1, dtype=torch.float32, device="cuda") ramp = f_1 + (f_2 - f_1) * (torch.exp(t) - 1) / (torch.exp(torch.tensor(1.0)) - 1) + elif ramp_type == "sigmoid": + t = torch.linspace(-6, 6, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + ramp = f_1 + (f_2 - f_1) * (1 / (1 + torch.exp(-t))) + elif ramp_type == "sinusoidal": + t = torch.linspace(0, pi, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + ramp = f_1 + (f_2 - f_1) * (0.5 * (1 - torch.cos(t))) else: - raise ValueError("Unsupported ramp type. Choose from 'linear', 'quadratic', or 'exponential'.") + raise ValueError("Unsupported ramp type. Choose from 'linear', 'quadratic', 'exponential', 'sigmoid', or 'sinusoidal'.") # Generate the forcing trend using torch tensors trend = torch.cat([ @@ -282,13 +311,69 @@ def _add_external_forcing(self): torch.full((time_length - f_time_2,), f_2, dtype=torch.float32, device="cuda") ]).reshape(1, time_length) + if w_f_sum.dim() == 2: + w_f_sum = w_f_sum.sum(dim=0, keepdim=True) # Sum across the correct dimension + # Compute the forcing field on GPU forcing_field = (w_f_sum.reshape(1, -1) * trend.T).T self.forcing_data_field = forcing_field.cpu().numpy() - # Add it to the data field. + print(f"Using {ramp_type} ramp: f_1={f_1}, f_2={f_2}, f_time_1={f_time_1}, f_time_2={f_time_2}") + + print(f"Forcing data field mean: {self.forcing_data_field.mean()}") + + print(f"Before addition - Data field mean: {self.data_field.mean()}") + + data_field_before = self.data_field.copy() + self.data_field += self.forcing_data_field + data_field_after = self.data_field + + print(f"After addition - Data field mean: {self.data_field.mean()}") + + # # Convert tensors to numpy for plotting if necessary + # if isinstance(w_f_sum, torch.Tensor): + # w_f_sum = w_f_sum.cpu().numpy() + # if isinstance(forcing_field, torch.Tensor): + # forcing_field = forcing_field.cpu().numpy() + # if isinstance(data_field_before, torch.Tensor): + # data_field_before = data_field_before.cpu().numpy() + # if isinstance(data_field_after, torch.Tensor): + # data_field_after = data_field_after.cpu().numpy() + + # # Compute mean values over spatial dimensions + # mean_forcing = forcing_field.mean(axis=0) + # mean_data_before = data_field_before.mean(axis=0) + # mean_data_after = data_field_after.mean(axis=0) + + # # Plot 1: Mean Forcing over Time + # plt.figure(figsize=(10, 4)) + # plt.plot(range(time_length), mean_forcing, label="Mean Forcing", color="blue") + # plt.axvline(x=f_time_1, linestyle="--", color="gray", label="Start Forcing") + # plt.axvline(x=f_time_2, linestyle="--", color="gray", label="End Forcing") + # plt.xlabel("Time Steps") + # plt.ylabel("Forcing Intensity") + # plt.title("Evolution of External Forcing Over Time") + # plt.legend() + # plt.grid() + # plt.savefig(f"mean_forcing_over_time_{f_1}_{f_2}_{ramp_type}.png") # Save to a file + # plt.close() + + # # Plot 2: Mean Data Before and After Forcing + # plt.figure(figsize=(10, 4)) + # plt.plot(range(time_length), mean_data_before, label="Data Before Forcing", color="red", linestyle="dashed") + # plt.plot(range(time_length), mean_data_after, label="Data After Forcing", color="green") + # plt.axvline(x=f_time_1, linestyle="--", color="gray", label="Start Forcing") + # plt.axvline(x=f_time_2, linestyle="--", color="gray", label="End Forcing") + # plt.xlabel("Time Steps") + # plt.ylabel("Mean Data Value") + # plt.title("Effect of Forcing on Data Field") + # plt.legend() + # plt.grid() + # plt.savefig(f"mean_data_before_after_forcing_{f_1}_{f_2}_{ramp_type}.png") 
# Save to a file + # plt.close() + def _create_linear(self): """Weights N \times L data_field L \times T.""" @@ -312,3 +397,128 @@ def _create_linear(self): # data_field[..., t:t + 1] += torch.matmul(torch.matmul(torch.matmul(weights_inv, phi[..., i]), weights), data_field[..., t - 1 - i:t - i]) self.data_field = data_field[..., self.transient :].detach().cpu().numpy() + + + def train_nnar(self, num_epochs=50, learning_rate=0.001, batch_size=32): + """ + method for training a very simple single-layer neural network with + sigmoid activation (one neuron). We train it here on pairs (past_values, future_value), + but this can be adapted as needed. + """ + + # A trivial net: data_in -> [Linear] -> [Sigmoid] -> data_out + self.nnar_model = nn.Sequential( + nn.Linear(self.spatial_resolution, self.spatial_resolution), + nn.Sigmoid() + ).to("cuda") + + optimizer = torch.optim.Adam(self.nnar_model.parameters(), lr=learning_rate) + loss_fn = nn.MSELoss() + + # Create a training dataset from self.data_field: each sample is (X_t, X_{t+1}), + # (we might later incorporate more lags) + + # collect input-output pairs: + X = torch.from_numpy(self.data_field[:, :-1].T).float().to("cuda") + Y = torch.from_numpy(self.data_field[:, 1:].T).float().to("cuda") + dataset_size = X.shape[0] + + # Simple mini-batch loop + for epoch in range(num_epochs): + perm = torch.randperm(dataset_size, device="cuda") + batch_losses = [] + + for i in range(0, dataset_size, batch_size): + idx = perm[i : i + batch_size] + x_batch = X[idx] + y_batch = Y[idx] + + # forward pass + pred = self.nnar_model(x_batch) + loss = loss_fn(pred, y_batch) + batch_losses.append(loss.item()) + + # backward + update + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (epoch + 1) % 5 == 0: + print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {sum(batch_losses)/len(batch_losses):.6f}") + + print("Training of single-layer NNAR model completed.") + + + def _create_nonlinear(self): + """ + Generates nonlinear data by applying a (trained or simple) nonlinearity + at each time step. This method uses the same logic as _create_linear to step forward in time + and adds the nonlinearity (sigmoid) before adding to data_field. + + If train_nnar=True was set, we assume self.nnar_model was trained in generate_data(). + Otherwise, we can do a direct inline "torch.sigmoid(...)" approach. 
+ Can be increased in complexity if needed + """ + + weights = torch.Tensor(np.linalg.pinv(self.mode_weights.reshape(self.n_vars, -1))).to("cuda") + phi = torch.Tensor(dict_to_matrix(self.links_coeffs)).to("cuda") + mode_weights_tensor = torch.Tensor(self.mode_weights.reshape(self.n_vars, -1)).to("cuda") + data_field = torch.Tensor(self.data_field).to("cuda") + + time_len = self.time_length + self.transient + tau_max = self.tau_max + + print("create_nonlinear (single-layer net + sigmoid)") + + for t in tqdm(range(tau_max, time_len)): + # Sum up influences from each lag + nonlinear_contrib = 0.0 + for i in range(tau_max): + # get linear combination as in _create_linear + lincombo = weights @ phi[..., i] @ mode_weights_tensor @ data_field[..., (t-1 - i):(t - i)] + # Apply a sigmoid (or feed it through the small neural net if you want more complexity) + lincombo_nl = torch.sigmoid(lincombo) + # accumulate + nonlinear_contrib += lincombo_nl.squeeze(-1) + + # Add the (nonlinear) effect to the data field at time t + data_field[:, t] += nonlinear_contrib + + self.data_field = data_field[:, self.transient :].detach().cpu().numpy() + + + def _create_polynomial(self): + """ + Example polynomial autoregression, e.g. x^2 for poly_degree=2. + """ + w_np = np.linalg.pinv(self.mode_weights.reshape(self.n_vars, -1)) + phi_np = dict_to_matrix(self.links_coeffs) + + w_torch = torch.Tensor(w_np).to("cuda") + phi_torch = torch.Tensor(phi_np).to("cuda") + mw_torch = torch.Tensor(self.mode_weights.reshape(self.n_vars, -1)).to("cuda") + data_field = torch.Tensor(self.data_field).to("cuda") + + time_len = self.time_length + self.transient + tau_max = self.tau_max + + print(f"create_polynomial with degrees={self.poly_degrees}") + + for t in tqdm(range(tau_max, time_len)): + # For each time step, sum over the contributions of all lags + for i in range(tau_max): + lincombo = ( + w_torch + @ phi_torch[..., i] + @ mw_torch + @ data_field[..., (t - 1 - i) : (t - i)] + ) + + # For each requested polynomial degree, add its effect + poly_sum = 0.0 + for deg in self.poly_degrees: + poly_sum += lincombo ** deg + + data_field[:, t] += poly_sum.squeeze(-1) + + self.data_field = data_field[:, self.transient :].detach().cpu().numpy() \ No newline at end of file diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py index 23a702a..c30bb59 100755 --- a/scripts/main_picabu.py +++ b/scripts/main_picabu.py @@ -123,6 +123,8 @@ def main( f_time_1=savar_params.f_time_1, f_time_2=savar_params.f_time_2, ramp_type=savar_params.ramp_type, + linearity=savar_params.linearity, + poly_degrees=savar_params.poly_degrees, plot_original_data=savar_params.plot_original_data, ) datamodule.setup() From 2093b951c11c13e3b2a6143b6310a410ac7f08d0 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Mon, 21 Apr 2025 20:10:24 +0200 Subject: [PATCH 07/12] commented savar params and added new savar params --- configs/single_param_file_savar.json | 27 ++++++++++++-------- configs/single_param_file_with_comments.json | 18 +++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/configs/single_param_file_savar.json b/configs/single_param_file_savar.json index d57b849..9e78f8a 100644 --- a/configs/single_param_file_savar.json +++ b/configs/single_param_file_savar.json @@ -1,12 +1,12 @@ { "exp_params": { - "exp_path": "/home/mila/j/julien.boussard/scratch/results/SAVAR_DATA_TEST", + "exp_path": "/pfs/work7/workspace/scratch/qa4548-results/SAVAR_DATA_TEST", "_target_": "emulator.src.datamodules.climate_datamodule.ClimateDataModule", 
"latent": true, - "d_z": 4, + "d_z": 9, "d_x": 400, - "lon": 20, - "lat": 20, + "lon": 30, + "lat": 30, "tau": 5, "random_seed": 1, "gpu": true, @@ -15,10 +15,10 @@ "verbose": true }, "data_params": { - "data_dir": "/network/scratch/j/julien.boussard/data/SAVAR_DATA_TEST", - "climateset_data": "/network/scratch/j/julien.boussard/data/icml_processed_data/picontrol/24_ni", + "data_dir": "/pfs/work7/workspace/scratch/qa4548-data/SAVAR_DATA_TEST", + "climateset_data": "/home/kit/iti/qa4548/my_projects/climatem/data/icosahedral_data/picontrol/24_ni", "reload_climate_set_data": false, - "icosahedral_coordinates_path": "/network/scratch/j/julien.boussard/data/vertex_lonlat_mapping.npy", + "icosahedral_coordinates_path": "/home/kit/iti/qa4548/my_projects/climatem/mappings/vertex_lonlat_mapping.npy", "in_var_ids": ["savar"], "out_var_ids": ["savar"], "num_levels": 1, @@ -124,11 +124,18 @@ "time_len": 10000, "comp_size": 10, "noise_val": 0.2, - "n_per_col": 2, + "n_per_col": 3, "difficulty": "easy", "seasonality": false, - "overlap": false, - "is_forced": false, + "overlap": 0.3, + "is_forced": true, + "f_1": 0, + "f_2": 3, + "f_time_1": 3000, + "f_time_2": 8000, + "ramp_type": "sinusoidal", + "linearity": "polynomial", + "poly_degrees": [2,3], "plot_original_data": true } diff --git a/configs/single_param_file_with_comments.json b/configs/single_param_file_with_comments.json index ee6daf5..ad1aa25 100644 --- a/configs/single_param_file_with_comments.json +++ b/configs/single_param_file_with_comments.json @@ -114,6 +114,24 @@ "plot_freq": 500, "plot_through_time": true, "print_freq": 500 + }, + "savar_params": { + "time_len": 10000, //number of steps for the data + "comp_size": 10, //size of components (modes) = comp_size x comp_size + "noise_val": 0.2, //noise value + "n_per_col": 3, //total number of modes on the grid = n_per_col x n_per_col + "difficulty": "easy", //complexity of causal links between latents = probabilty of there being a link between two modes + "seasonality": false, //seasonality - false per default, not implemented completely + "overlap": 0.3, //overlap of modes = 0 - modes have no overlap, 1 = all modes overlap at the center of the grid, one on another + "is_forced": true, //true if forcing included false if not + "f_1": 0, //initial forcing value + "f_2": 3, //final forcing value + "f_time_1": 3000, //start of forcing increase - forcing value at and before f_time_1 = f_1 + "f_time_2": 8000, //end of forcing increase - forcing value at and after f_time_2 = f_2 + "ramp_type": "sinusoidal", //form of forcing increase + "linearity": "polynomial", //linearity of data (linear, nonlinear, polynomial) + "poly_degrees": [2,3], //degrees of polynomial used for creating polynomial data + "plot_original_data": true //true if an animation of savar data is to be plotted, false otherwise } } From e8d0bf6ec1f2b3281338dbaa8ddadb20f99bb314 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Sun, 18 May 2025 22:54:23 +0200 Subject: [PATCH 08/12] refactor: update SAVAR parameters and enhance adjacency matrix evaluation --- climatem/plotting/plot_model_output.py | 3 + configs/single_param_file_savar.json | 3 +- scripts/varimax_pcmci_savar_evaluation.py | 156 ++++++++++++++-------- 3 files changed, 104 insertions(+), 58 deletions(-) diff --git a/climatem/plotting/plot_model_output.py b/climatem/plotting/plot_model_output.py index e9d9bcc..931acc0 100644 --- a/climatem/plotting/plot_model_output.py +++ b/climatem/plotting/plot_model_output.py @@ -96,6 +96,7 @@ def load(self, exp_path, data_loader): 
self.gt_w = data_loader.gt_w self.gt_graph = data_loader.gt_dag + def plot(self, learner, save=False): """ Main plotting function. @@ -1154,6 +1155,8 @@ def plot_adjacency_matrix( print("PERMUTED THE MATRICES") + #np.save(learner.plots_path / f"adjacency_{name_suffix}_permuted.npy", mat1) + subfig_names = [ f"Learned, latent dimensions = {mat1.shape[1], mat1.shape[2]}", diff --git a/configs/single_param_file_savar.json b/configs/single_param_file_savar.json index 9e78f8a..2d30e54 100644 --- a/configs/single_param_file_savar.json +++ b/configs/single_param_file_savar.json @@ -123,7 +123,7 @@ "savar_params": { "time_len": 10000, "comp_size": 10, - "noise_val": 0.2, + "noise_val": 0.25, "n_per_col": 3, "difficulty": "easy", "seasonality": false, @@ -138,6 +138,5 @@ "poly_degrees": [2,3], "plot_original_data": true } - } diff --git a/scripts/varimax_pcmci_savar_evaluation.py b/scripts/varimax_pcmci_savar_evaluation.py index 6c2ff52..81ba8c8 100644 --- a/scripts/varimax_pcmci_savar_evaluation.py +++ b/scripts/varimax_pcmci_savar_evaluation.py @@ -132,82 +132,126 @@ def varimax(Phi, gamma=1, q=20, tol=1e-6): if __name__ == "__main__": - ### Here set SAVAR paths and load data #### - difficulty = "easy" - tau = 5 - n_modes = 25 - comp_size = 25 - time_len = 10_000 + + # load your existing JSON config + config_path = Path("configs/single_param_file_savar.json") + with open(config_path, "r") as f: + cfg = json.load(f) + + exp = cfg["exp_params"] + data = cfg["data_params"] + savar = cfg["savar_params"] + + # pull out exactly the bits you used to hard-code + tau = exp["tau"] + n_modes = exp["d_z"] # latent dim = number of modes + comp_size = savar["comp_size"] + time_len = savar["time_len"] + is_forced = savar["is_forced"] + seasonality = savar["seasonality"] + overlap = savar["overlap"] + difficulty = savar["difficulty"] lat = lon = int(np.sqrt(n_modes)) * comp_size - noisestrength = 1 + noise_val = savar["noise_val"] var_names = [] for k in range(n_modes): var_names.append(rf"$X^{k}$") - savar_folder = Path("$HOME/savar_data") - savar_fname = f"m_{n_modes_gt}_{difficulty}_savar_name" + savar_folder = "/home/ka/ka_iti/ka_qa4548/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-data/SAVAR_DATA_TEST" + # Load gt mode weights + savar_fname = f"modes_{n_modes}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}" + # Get the gt mode weights + modes_gt = np.load(savar_folder + f"/{savar_fname}_mode_weights.npy") - savar_data = np.load(savar_folder / savar_fname) - params_file = savar_folder / f"{savar_fname[:-4]}_parameters.npy" + #savar_data = np.load(savar_folder / savar_fname) + params_file = savar_folder + f"/{savar_fname}_parameters.npy" params = np.load(params_file, allow_pickle=True).item() links_coeffs = params["links_coeffs"] - modes_gt = np.load(savar_folder / f"{savar_fname[:-4]}_mode_weights.npy") - modes_gt -= modes_gt.mean() - modes_gt /= modes_gt.std() + # modes_gt = np.load(savar_folder / f"{savar_fname[:-4]}_mode_weights.npy") + # modes_gt -= modes_gt.mean() + # modes_gt /= modes_gt.std() + + adj_gt = extract_adjacency_matrix(links_coeffs, n_modes, tau) + n_gt_connections = (np.array(adj_gt) > 0).sum() + + # load CDSD results (already permuted / aligned) + cdsd_adj_inferred_path = 
Path("/home/ka/ka_iti/ka_qa4548/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-results/SAVAR_DATA_TEST/var_savar_scenarios_piControl_nonlinear_False_tau_5_z_9_lr_0.001_bs_256_spreg_0_ormuinit_100000.0_spmuinit_0.1_spthres_0.05_fixed_False_num_ensembles_1_instantaneous_False_crpscoef_1_spcoef_0_tempspcoef_0_overlap_0.3_forcing_True/plots/graphs.npy") + cdsd_modes_inferred_path = Path("/home/ka/ka_iti/ka_qa4548/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-results/SAVAR_DATA_TEST/var_savar_scenarios_piControl_nonlinear_False_tau_5_z_9_lr_0.001_bs_256_spreg_0_ormuinit_100000.0_spmuinit_0.1_spthres_0.05_fixed_False_num_ensembles_1_instantaneous_False_crpscoef_1_spcoef_0_tempspcoef_0_overlap_0.3_forcing_True/plots/w_decoder.npy") + modes_inferred = np.load(cdsd_modes_inferred_path) + adj_w = np.load(cdsd_adj_inferred_path) - gt_adj_list = extract_adjacency_matrix(links_coeffs, n_modes, tau) - n_gt_connections = (np.array(gt_adj_list) > 0).sum() ############################ - # Fit PCA + varimax - pca_model = PCA(n_modes).fit(savar_data.T) - latent_data = pca_model.transform(savar_data.T) - varimaxpcs, varimax_rotation = varimax(latent_data) + # # Fit PCA + varimax + # pca_model = PCA(n_modes).fit(savar_data.T) + # latent_data = pca_model.transform(savar_data.T) + # varimaxpcs, varimax_rotation = varimax(latent_data) - # To recover which mode is which and permute accordingly when evaluating - inverse_varimax = dot(latent_data, np.linalg.pinv(varimax_rotation)) - reverted_data = pca_model.inverse_transform(inverse_varimax) + # # To recover which mode is which and permute accordingly when evaluating + # inverse_varimax = dot(latent_data, np.linalg.pinv(varimax_rotation)) + # reverted_data = pca_model.inverse_transform(inverse_varimax) - dataframe = pp.DataFrame(varimaxpcs, datatime={0: np.arange(len(varimaxpcs))}, var_names=var_names) - # Run PCMCI - pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=1) + # dataframe = pp.DataFrame(varimaxpcs, datatime={0: np.arange(len(varimaxpcs))}, var_names=var_names) + # # Run PCMCI + # pcmci = PCMCI(dataframe=dataframe, cond_ind_test=parcorr, verbosity=1) - results = pcmci.run_pcmci(tau_min=1, tau_max=5, pc_alpha=None, alpha_level=0.001) + # results = pcmci.run_pcmci(tau_min=1, tau_max=5, pc_alpha=None, alpha_level=0.001) # Permute accordingly before evaluating learned graph. 
- individual_modes = np.zeros((n_modes, time_len, lat, lon)) - for k in range(n_modes): - latent_data_bis = np.zeros(latent_data.shape) - latent_data_bis[:, k] = latent_data[:, k] - inverse_varimax = dot(latent_data_bis, np.linalg.pinv(varimax_rotation)) - reverted_data = pca_model.inverse_transform(inverse_varimax) - individual_modes[k] = reverted_data.reshape((-1, lat, lon)) - individual_modes = individual_modes.std(1) - individual_modes -= individual_modes.mean() - individual_modes /= individual_modes.std() - - permutation_list = ((modes_gt[:, None] - individual_modes[None]) ** 2).sum((2, 3)).argmin(1) - - # Get adjacency matrix from PCMCI graph - graph = results["graph"] - graph[ - results["val_matrix"] - < np.abs(results["val_matrix"].flatten()[results["val_matrix"].flatten().argsort()[::-1][n_gt_connections - 1]]) - ] = "" - - adj_matrix_inferred = np.zeros((tau, n_modes, n_modes)) - for k in range(n_modes): - graph_k = graph[k] - for j in range(n_modes): - adj_matrix_inferred[:, k, j] = graph_k[j][1:] == "-->" - + # individual_modes = np.zeros((n_modes, time_len, lat, lon)) + # for k in range(n_modes): + # latent_data_bis = np.zeros(latent_data.shape) + # latent_data_bis[:, k] = latent_data[:, k] + # inverse_varimax = dot(latent_data_bis, np.linalg.pinv(varimax_rotation)) + # reverted_data = pca_model.inverse_transform(inverse_varimax) + # individual_modes[k] = reverted_data.reshape((-1, lat, lon)) + # individual_modes = individual_modes.std(1) + # individual_modes -= individual_modes.mean() + # individual_modes /= individual_modes.std() + + # permutation_list = ((modes_gt[:, None] - individual_modes[None]) ** 2).sum((2, 3)).argmin(1) + + # # Get adjacency matrix from PCMCI graph + # graph = results["graph"] + # graph[ + # results["val_matrix"] + # < np.abs(results["val_matrix"].flatten()[results["val_matrix"].flatten().argsort()[::-1][n_gt_connections - 1]]) + # ] = "" + + # adj_matrix_inferred = np.zeros((tau, n_modes, n_modes)) + # for k in range(n_modes): + # graph_k = graph[k] + # for j in range(n_modes): + # adj_matrix_inferred[:, k, j] = graph_k[j][1:] == "-->" + + # for k in range(tau): + # adj_matrix_inferred[k] = adj_matrix_inferred[k][np.ix_(permutation_list, permutation_list)] + # adj_matrix_inferred = adj_matrix_inferred.transpose((0, 2, 1)) + + # Find the permutation + modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) + + # Get the flat index of the maximum for each mode + idx_gt_flat = np.argmax(modes_gt.reshape(modes_gt.shape[0], -1), axis=1) # shape: (n_modes,) + idx_inferred_flat = np.argmax(modes_inferred.reshape(modes_inferred.shape[0], -1), axis=1) # shape: (n_modes,) + + # Convert flat indices to 2D coordinates (row, col) + idx_gt = np.array([np.unravel_index(i, (lat, lon)) for i in idx_gt_flat]) # shape: (n_modes, 2) + idx_inferred = np.array([np.unravel_index(i, (lat, lon)) for i in idx_inferred_flat]) # shape: (n_modes, 2) + + # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix + permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) + print("permutation_list:", permutation_list) + + # Permute for k in range(tau): - adj_matrix_inferred[k] = adj_matrix_inferred[k][np.ix_(permutation_list, permutation_list)] - adj_matrix_inferred = adj_matrix_inferred.transpose((0, 2, 1)) + adj_w[k] = adj_w[k][np.ix_(permutation_list, permutation_list)] + + print("PERMUTED THE MATRICES") - precision, recall, f1, shd = 
evaluate_adjacency_matrix(adj_matrix_inferred, gt_adj_list, 0.9) + precision, recall, f1, shd = evaluate_adjacency_matrix(adj_w, adj_gt, 0.9) print(f"difficuly {difficulty} results:") print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}, SHD: {shd}") From cc1706ed712e786e4c8ebc3d021a77e657eced99 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Fri, 4 Jul 2025 21:01:49 +0200 Subject: [PATCH 09/12] Refactor code for improved readability and consistency across multiple modules, used poetry - Updated parameter formatting in expParams and dataParams classes for better alignment and clarity. - Enhanced comments and documentation for various parameters in expParams, dataParams, and other classes. - Cleaned up unnecessary whitespace and improved code formatting in savar_dataset.py, plot_model_output.py, generate_savar_datasets.py, and graph_evaluation.py. - Modified load_and_permute_all_matrices function to accept explicit parameters instead of a list of CSV files, improving clarity on input requirements. - Adjusted the main execution block in varimax_pcmci_savar_evaluation.py to streamline loading and processing of inferred modes and adjacency matrices. - Ensured consistent use of spacing and line breaks throughout the codebase for better readability. --- climatem/data_loader/climate_dataset.py | 140 +++++++----------- climatem/data_loader/cmip6_dataset.py | 1 + climatem/data_loader/input4mip_dataset.py | 3 +- climatem/data_loader/savar_dataset.py | 2 +- climatem/plotting/plot_model_output.py | 6 +- .../synthetic_data/generate_savar_datasets.py | 59 ++++++-- climatem/synthetic_data/graph_evaluation.py | 122 +++++++++------ climatem/synthetic_data/savar.py | 133 ++++++++--------- scripts/main_picabu.py | 20 +-- 9 files changed, 254 insertions(+), 232 deletions(-) diff --git a/climatem/data_loader/climate_dataset.py b/climatem/data_loader/climate_dataset.py index 1ae587d..d7a96bd 100644 --- a/climatem/data_loader/climate_dataset.py +++ b/climatem/data_loader/climate_dataset.py @@ -1,26 +1,24 @@ -# NOTE: as of 14th Oct, I am also trying to get this to work for multiple variables. 
- -import glob -import os +# import glob +# import os +# from datetime import datetime, timedelta +import itertools import zipfile from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union # Tuple import numpy as np import torch import xarray as xr -from climatem.constants import ( # INPUT4MIPS_NOM_RES,; INPUT4MIPS_TEMP_RES, - AVAILABLE_MODELS_FIRETYPE, - CMIP6_NOM_RES, - CMIP6_TEMP_RES, - NO_OPENBURNING_VARS, - OPENBURNING_MODEL_MAPPING, -) - # from climatem.plotting.plot_data import plot_species, plot_species_anomaly from climatem.utils import get_logger +# from climatem.constants import ( # INPUT4MIPS_NOM_RES,; INPUT4MIPS_TEMP_RES,; CMIP6_NOM_RES,; CMIP6_TEMP_RES,; NO_OPENBURNING_VARS, +# AVAILABLE_MODELS_FIRETYPE, +# OPENBURNING_MODEL_MAPPING, +# ) + + log = get_logger() @@ -112,44 +110,24 @@ def __init__( self.global_normalization = global_normalization self.seasonality_removal = seasonality_removal - if climate_model in AVAILABLE_MODELS_FIRETYPE: - openburning_specs = OPENBURNING_MODEL_MAPPING[climate_model] - else: - openburning_specs = OPENBURNING_MODEL_MAPPING["other"] - - ds_kwargs = dict( - scenarios=scenarios, - years=self.years, - historical_years=self.historical_years, - channels_last=channels_last, - openburning_specs=openburning_specs, - mode=mode, - output_save_dir=self.output_save_dir, - reload_climate_set_data=self.reload_climate_set_data, - seq_to_seq=seq_to_seq, - global_normalization=self.global_normalization, - seasonality_removal=self.seasonality_removal, - ) + # if climate_model in AVAILABLE_MODELS_FIRETYPE: + # openburning_specs = OPENBURNING_MODEL_MAPPING[climate_model] + # else: + # openburning_specs = OPENBURNING_MODEL_MAPPING["other"] + self.seq_len = seq_len self.lat = lat self.lon = lon self.icosahedral_coordinates_path = icosahedral_coordinates_path - # TODO: is that needed? - # creates on cmip and on input4mip dataset - # print("creating input4mips") - self.input4mips_ds = Input4MipsDataset(variables=in_variables, **ds_kwargs) - # print("creating cmip6") - # self.cmip6_ds=self.input4mips_ds - self.cmip6_ds = CMIP6Dataset( - climate_model=climate_model, num_ensembles=num_ensembles, variables=out_variables, **ds_kwargs - ) - - # NOTE:() changing this so it can deal with with grib files and netcdf files - # this operates variable wise now.... #TODO: sizes for input4mips / adapt to mulitple vars def load_into_mem( - self, paths: List[List[str]], num_vars: int, channels_last=True, seq_to_seq=True - ): # -> np.ndarray(): + self, + paths: List[List[str]], + num_vars: int, + channels_last=True, + seq_to_seq=True, + get_years=None, + ): """ Take a file structure of netcdf or grib files and load them into memory. @@ -161,65 +139,58 @@ def load_into_mem( """ array_list = [] - # print("paths:", paths) - # print("length paths", len(paths)) - # I need to check here that it is doing the right thing for vlist in paths: - # print("length_paths_list", len(vlist)) - # print the last three characters of the first element of vlist - # NOTE:() assert that they are either .nc or .grib - and print an error! 
if vlist[0][-3:] == ".nc": - temp_data = xr.open_mfdataset( - vlist, concat_dim="time", combine="nested" - ).compute() # .compute is not necessary but eh, doesn't hurt - # ignore the bnds dimension - temp_data = temp_data.drop_dims("bnds") - # print("Temp data at the point of reading it in:", temp_data) - elif vlist[0][-5:] == ".grib": - # need to install cfgrib, eccodes and likely ecmwflibs to make sure this cfgrib engine works and is available + temp_data = xr.open_mfdataset(vlist, concat_dim="time", combine="nested").compute() + temp_data = temp_data.drop_dims("bnds", errors="ignore") + + elif vlist[0].endswith(".grib"): temp_data = xr.open_mfdataset(vlist, engine="cfgrib", concat_dim="time", combine="nested").compute() - # print("Temp data at the point of reading it in:", temp_data) - # then get rid of this with some assert ^ see above + # TODO : handle gribs together + elif vlist[0].endswith(".grib2"): + # TODO: not all data will have this name to remove leap days + we should remove feb 29? + filtered_vlist = list(itertools.chain(*vlist)) + filtered_vlist = [item for item in vlist if "000366.grib2" not in item] + temp_data = xr.open_mfdataset( + filtered_vlist, engine="cfgrib", concat_dim="time", combine="nested" + ).compute() + else: print("File extension not recognized, please use either .nc or .grib") + continue - temp_data = temp_data.to_array().to_numpy() # Should be of shape (vars, 1036*num_scenarios, 96, 144) - - # print("Temp data shape:", temp_data.shape) - # temp_data = temp_data.squeeze() # (1036*num_scanarios, 96, 144) + temp_data = temp_data.to_array().to_numpy() # Should be of shape (vars, time, lat, lon) array_list.append(temp_data) - # print("length of the array list:", len(array_list)) temp_data = np.concatenate(array_list, axis=0) - # print("Temp data shape after concatenation:", temp_data.shape) - - # this is not very neat, but it calc - if paths[0][0][-5:] == ".grib": + if paths[0][0].endswith(".grib"): years = len(paths[0]) temp_data = temp_data.reshape(num_vars, years, self.seq_len, -1) - # print("temp data shape", temp_data.shape) + + elif paths[0][0].endswith(".grib2"): + # Use self.seq_len = 365 (post-leap-day-removal) + filtered_vlist = [f for f in vlist if int(f[-10:-6]) <= 365] + vlist = filtered_vlist + years = len(vlist) // self.seq_len + temp_data = temp_data.reshape(num_vars, years, self.seq_len, -1) else: years = len(paths[0]) temp_data = temp_data.reshape(num_vars, years, self.seq_len, self.lon, self.lat) - # print("temp data shape", temp_data.shape) - - # create a new array with the first 3 columns, and then tuple(lon, lat) if seq_to_seq is False: temp_data = temp_data[:, :, -1, :, :] # only take last time step temp_data = np.expand_dims(temp_data, axis=2) - # print("seq to 1 temp data shape", temp_data.shape) + if channels_last: temp_data = temp_data.transpose((1, 2, 3, 4, 0)) - elif paths[0][0][-5:] == ".grib": - # print("In elif paths[0][0][-5:] == '.grib'") + elif paths[0][0][-5:] in [".grib", "grib2"]: temp_data = temp_data.transpose((1, 2, 0, 3)) else: temp_data = temp_data.transpose((1, 2, 0, 3, 4)) - # print("final temp data shape", temp_data.shape) + return temp_data # (86*num_scenarios!, 12, vars, 96, 144). Desired shape where 86*num_scenaiors can be the batch dimension. 
Can get items of shape (batch_size, 12, 96, 144) -> #TODO: confirm that one item should be one year of one scenario @@ -236,10 +207,13 @@ def load_coordinates_into_mem(self, paths: List[List[str]]) -> np.ndarray: Returns: np.ndarray: coordinates """ + print("self.icosahedral_coordinates_path", self.icosahedral_coordinates_path) print("length paths", len(paths)) if paths[0][0][-5:] == ".grib": # we have no lat and lon in grib files, so we need to fill it up from elsewhere, from the mapping.txt file: coordinates = np.load(self.icosahedral_coordinates_path) + elif paths[0][0][-5:] == "grib2": + coordinates = np.loadtxt(self.icosahedral_coordinates_path, skiprows=1, usecols=(1, 2)) else: for vlist in [paths[0]]: # print("I am in the else of load_coordinates_into_mem") @@ -375,7 +349,7 @@ def get_causal_data( # print("Trying to regrid to lon, lat if we have regular data...") # data = data.reshape(num_scenarios, num_years, num_vars, LON, LAT) - data = data.reshape(num_scenarios, num_years * 12, num_vars, self.lon, self.lat) + data = data.reshape(num_scenarios, num_years * self.seq_len, num_vars, self.lon, self.lat) except ValueError: print( @@ -389,7 +363,7 @@ def get_causal_data( # 26/08/24 # Now we don't split up the ensemble members - data = data.reshape(1, num_years * 12, num_vars, -1) + data = data.reshape(1, num_years * self.seq_len, num_vars, -1) # print("Data shape after reshaping:", data.shape) if isinstance(num_months_aggregated, (int, np.integer)) and num_months_aggregated > 1: @@ -730,15 +704,15 @@ def remove_seasonality(self, data): mean = np.nanmean(data, axis=0) std = np.nanstd(data, axis=0) - # make a numpy array containing the mean and std for each month: remove_season_stats = np.array([mean, std]) np.save(self.output_save_dir / "remove_season_stats", remove_season_stats, allow_pickle=True) print("Just about to return the data after removing seasonality.") - - return (data - mean[None]) / std[None] + std_safe = np.where(std == 0, 1, std) + deseasonalized = (data - mean[None]) / std_safe[None] + return deseasonalized def write_dataset_statistics(self, fname, stats): # fname = fname.replace('.npz.npy', '.npy') diff --git a/climatem/data_loader/cmip6_dataset.py b/climatem/data_loader/cmip6_dataset.py index 19eba20..b2f748c 100644 --- a/climatem/data_loader/cmip6_dataset.py +++ b/climatem/data_loader/cmip6_dataset.py @@ -12,6 +12,7 @@ # from climatem.plotting.plot_data import plot_species, plot_species_anomaly from climatem.utils import get_logger + from .climate_dataset import ClimateDataset log = get_logger() diff --git a/climatem/data_loader/input4mip_dataset.py b/climatem/data_loader/input4mip_dataset.py index a1bd11d..d083ba2 100644 --- a/climatem/data_loader/input4mip_dataset.py +++ b/climatem/data_loader/input4mip_dataset.py @@ -1,8 +1,6 @@ # NOTE: as of 14th Oct, I am also trying to get this to work for multiple variables. -import glob import os -import zipfile from pathlib import Path from typing import List, Optional, Tuple, Union @@ -24,6 +22,7 @@ # input4mips data set: same per model # from datamodule create one of these per train/test/val + class Input4MipsDataset(ClimateDataset): """ Loads all scenarios for a given var / for all vars. 
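Illustrative aside (not part of the committed diff): the remove_seasonality change above now guards against grid cells with zero variance before z-scoring each calendar timestep. A minimal standalone sketch of the same idea, with a hypothetical helper name:

import numpy as np

def remove_seasonality_sketch(data):
    # data: (years, timesteps_per_year, ...); climatology is computed per timestep of year
    mean = np.nanmean(data, axis=0)
    std = np.nanstd(data, axis=0)
    std_safe = np.where(std == 0, 1, std)  # avoid division by zero for constant cells
    return (data - mean[None]) / std_safe[None]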
diff --git a/climatem/data_loader/savar_dataset.py b/climatem/data_loader/savar_dataset.py index 4612544..0da88f2 100644 --- a/climatem/data_loader/savar_dataset.py +++ b/climatem/data_loader/savar_dataset.py @@ -38,7 +38,7 @@ def __init__( ): super().__init__() self.output_save_dir = Path(output_save_dir) - self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}" + self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}_f1_{f_1}_f2_{f_2}_ft1_{f_time_1}_ft2_{f_time_2}_ramp_{ramp_type}_linearity_{linearity}_polydegs_{poly_degrees}" self.savar_path = self.output_save_dir / f"{self.savar_name}.npy" self.global_normalization = global_normalization diff --git a/climatem/plotting/plot_model_output.py b/climatem/plotting/plot_model_output.py index d583831..ebc7e01 100644 --- a/climatem/plotting/plot_model_output.py +++ b/climatem/plotting/plot_model_output.py @@ -217,7 +217,6 @@ def plot(self, learner, save=False): if learner.plot_params.savar: self.plot_adjacency_matrix( mat1=adj, - # Below savar dag mat2=learner.datamodule.savar_gt_adj, path=learner.plots_path, name_suffix="transition", @@ -368,13 +367,14 @@ def plot_sparsity(self, learner, save=False): if learner.plot_params.savar: savar_folder = learner.data_params.data_dir n_modes = learner.savar_params.n_per_col**2 - savar_fname = f"modes_{n_modes}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}" + savar_fname = f"modes_{n_modes}_tl_{learner.savar_params.time_len}_isforced_{learner.savar_params.is_forced}_difficulty_{learner.savar_params.difficulty}_noisestrength_{learner.savar_params.noise_val}_seasonality_{learner.savar_params.seasonality}_overlap_{learner.savar_params.overlap}_f1_{learner.savar_params.f_1}_f2_{learner.savar_params.f_2}_ft1_{learner.savar_params.f_time_1}_ft2_{learner.savar_params.f_time_2}_ramp_{learner.savar_params.ramp_type}_linearity_{learner.savar_params.linearity}_polydegs_{learner.savar_params.poly_degrees}" # Get the gt mode weights modes_gt = np.load(f"{savar_folder}/{savar_fname}_mode_weights.npy") if learner.iteration == 500: savar_data = np.load(f"{savar_folder}/{savar_fname}.npy") savar_anim_path = f"{savar_folder}/{savar_fname}_original_savar_data.gif" self.plot_original_savar(savar_data, learner.lat, learner.lon, savar_anim_path) + print(modes_gt) # TODO: plot the prediction vs gt # plot_compare_prediction(x, x_hat) @@ -1471,7 +1471,7 @@ def animate(i): return (cax,) # Create an animation - ani = animation.FuncAnimation(fig, animate, frames=100, blit=True) + ani = animation.FuncAnimation(fig, animate, frames=100, blit=False) # Save the animation as a video file ani.save(path, writer="pillow", fps=10) diff --git a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index 966f1af..430728a 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -1,11 +1,11 @@ import csv import json +import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.stats import beta -import 
matplotlib.animation as animation from climatem.synthetic_data.savar import SAVAR from climatem.synthetic_data.utils import check_stability, create_random_mode @@ -111,7 +111,7 @@ def generate_save_savar_data( f_time_2=8000, ramp_type="linear", linearity="polynomial", - poly_degrees=[2,3], + poly_degrees=[2, 3], plotting=True, ): @@ -119,18 +119,18 @@ def generate_save_savar_data( ny = nx = n_per_col * comp_size N = n_per_col**2 # Number of components - if not (0 <= overlap <= 1): raise ValueError("overlap must be between 0 and 1") + if not (0 <= overlap <= 1): + raise ValueError("overlap must be between 0 and 1") noise_weights = np.zeros((N, nx, ny)) modes_weights = np.zeros((N, nx, ny)) - # Specify the path where you want to save the data npy_name = f"{name}.npy" save_path = save_dir_path / npy_name # Center starting position (for fully overlapping modes) - center_x_start = (nx - comp_size) // 2 + center_x_start = (nx - comp_size) // 2 center_y_start = (ny - comp_size) // 2 # Create modes weights @@ -145,14 +145,14 @@ def generate_save_savar_data( new_y_start = int((1 - overlap) * orig_y_start + overlap * center_y_start) new_x_end = new_x_start + comp_size new_y_end = new_y_start + comp_size - modes_weights[ - idx, new_x_start : new_x_end, new_y_start : new_y_end - ] = create_random_mode((comp_size, comp_size), random=True) - #for k in range(n_per_col): - # for j in range(n_per_col): - noise_weights[ - idx, new_x_start : new_x_end, new_y_start : new_y_end - ] = create_random_mode((comp_size, comp_size), random=True) + modes_weights[idx, new_x_start:new_x_end, new_y_start:new_y_end] = create_random_mode( + (comp_size, comp_size), random=True + ) + # for k in range(n_per_col): + # for j in range(n_per_col): + noise_weights[idx, new_x_start:new_x_end, new_y_start:new_y_end] = create_random_mode( + (comp_size, comp_size), random=True + ) # This is the probabiliity of having a link between latent k and j, with k different from j. latents always have one link with themselves at a previous time. 
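Illustrative aside (not part of the committed diff): the overlap handling a few lines above linearly interpolates each mode's corner toward the grid centre. A quick arithmetic check with the values used in the SAVAR config earlier in this series (comp_size=10, n_per_col=3, so nx=30 and center_x_start=10) and overlap=0.3:

comp_size, n_per_col, overlap = 10, 3, 0.3
nx = n_per_col * comp_size              # 30
center = (nx - comp_size) // 2          # 10
for k in range(n_per_col):
    orig = k * comp_size                # 0, 10, 20
    new = int((1 - overlap) * orig + overlap * center)
    print(orig, "->", new)              # 0 -> 3, 10 -> 10, 20 -> 17

With overlap=0 the modes keep their original tiling; with overlap=1 every mode is placed at the centre block.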
if difficulty == "easy": @@ -297,4 +297,37 @@ def generate_save_savar_data( np.save(save_path, savar_model.data_field) print(f"{name} DONE!") + + plot_original_savar(savar_model.data_field, nx, ny, save_dir_path / f"{name}_original_savar_data2.gif") return savar_model.data_field + + +def plot_original_savar(data, lat, lon, path): + """Plotting the original savar data.""" + print(f"data shape {data.shape}") + # Get the dimensions + time_steps = data.shape[1] + data_reshaped = data.T.reshape((time_steps, lat, lon)) + + avg_data = np.mean(data_reshaped, axis=0) + vmin_avg, vmax_avg = avg_data.min(), avg_data.max() + vmin_all, vmax_all = data_reshaped.min(), data_reshaped.max() + print(f"avg_data range: {vmin_avg:.3e} to {vmax_avg:.3e}") + print(f"full data range: {vmin_all:.3e} to {vmax_all:.3e}") + + fig, ax = plt.subplots(figsize=(lon / 10, lat / 10)) + cax = ax.imshow(data_reshaped[0], aspect="auto", cmap="viridis", vmin=vmin_avg, vmax=vmax_avg) + # cbar = fig.colorbar(cax, ax=ax) + + def animate(i): + cax.set_data(data_reshaped[i]) + ax.set_title(f"Time step: {i+1}") + return (cax,) + + # Create an animation + ani = animation.FuncAnimation(fig, animate, frames=100, blit=False) + + # Save the animation as a video file + ani.save(path, writer="pillow", fps=10) + + plt.close() diff --git a/climatem/synthetic_data/graph_evaluation.py b/climatem/synthetic_data/graph_evaluation.py index 24c6088..ed870d5 100644 --- a/climatem/synthetic_data/graph_evaluation.py +++ b/climatem/synthetic_data/graph_evaluation.py @@ -68,7 +68,7 @@ def permute_matrix(matrix, permutation): return permuted_matrix -def load_and_permute_all_matrices(csv_files, permutation, remove_modes=[]): +def load_and_permute_all_matrices(modes_inferred, modes_gt, adj_w, adj_gt, lat, lon, tau): """ Loads and permutes multiple adjacency matrices, one for each time lag. @@ -80,24 +80,28 @@ def load_and_permute_all_matrices(csv_files, permutation, remove_modes=[]): np.ndarray: A 3D NumPy array containing all permuted adjacency matrices where the shape is (number_of_time_lags, n, n). 
""" - permuted_matrices = [] + # Find the permutation + modes_inferred = modes_inferred.reshape((lat, lon, modes_inferred.shape[-1])).transpose((2, 0, 1)) - for csv_file in csv_files: - # Load the adjacency matrix - adjacency_matrix = load_adjacency_matrix(csv_file) + # Get the flat index of the maximum for each mode + idx_gt_flat = np.argmax(modes_gt.reshape(modes_gt.shape[0], -1), axis=1) # shape: (n_modes,) + idx_inferred_flat = np.argmax(modes_inferred.reshape(modes_inferred.shape[0], -1), axis=1) # shape: (n_modes,) - if len(remove_modes): - adjacency_matrix = np.delete(adjacency_matrix, remove_modes, 0) - adjacency_matrix = np.delete(adjacency_matrix, remove_modes, 1) + # Convert flat indices to 2D coordinates (row, col) + idx_gt = np.array([np.unravel_index(i, (lat, lon)) for i in idx_gt_flat]) # shape: (n_modes, 2) + idx_inferred = np.array([np.unravel_index(i, (lat, lon)) for i in idx_inferred_flat]) # shape: (n_modes, 2) - # Permute the adjacency matrix - permuted_matrix = permute_matrix(adjacency_matrix, permutation) + # Compute error matrix using squared Euclidean distance between indices which yields an (n_modes x n_modes) matrix + permutation_list = ((idx_gt[:, None, :] - idx_inferred[None, :, :]) ** 2).sum(axis=2).argmin(axis=1) + print("permutation_list:", permutation_list) - # Append the permuted matrix to the list - permuted_matrices.append(permuted_matrix) + # Permute + for k in range(tau): + adj_w[k] = adj_w[k][np.ix_(permutation_list, permutation_list)] - # Convert the list of permuted matrices to a NumPy array - return np.array(permuted_matrices) + print("PERMUTED THE MATRICES") + + return adj_w def binarize_matrix(A, threshold=0.5): @@ -319,63 +323,85 @@ def save_equations_to_json(equations, filename): # Example usage: if __name__ == "__main__": - # Set parameters here - home_path = Path("$HOME") - tau = 5 - threshold = 0.75 - n_modes_gt = 25 - difficulty = "easy" - iteration = 2999 - comp_size = 25 - - n_modes = n_modes_gt + threshold = 0.5 + + # load your existing JSON config + config_path = Path("configs/single_param_file_savar.json") + with open(config_path, "r") as f: + cfg = json.load(f) + + exp = cfg["exp_params"] + data = cfg["data_params"] + savar = cfg["savar_params"] + + # pull out exactly the bits you used to hard-code + tau = exp["tau"] + n_modes = exp["d_z"] # latent dim = number of modes + comp_size = savar["comp_size"] + time_len = savar["time_len"] + is_forced = savar["is_forced"] + seasonality = savar["seasonality"] + overlap = savar["overlap"] + difficulty = savar["difficulty"] lat = lon = int(np.sqrt(n_modes)) * comp_size - folder_results = home_path / f"predictions_{n_modes_gt}_{difficulty}" - savar_folder = home_path / Path("savar_data") - savar_fname = f"m_{n_modes_gt}_{difficulty}_savar_name" - run_name = "model_results_folder" - results_path = folder_results / f"{savar_fname}_{run_name}" - csv_files = [results_path / f"adjacency_transition_time_{i}_iteration_{iteration}.csv" for i in np.arange(5, 0, -1)] - - # Get the permuted adjacency matrices for all time lags - modes_gt = np.load(savar_folder / f"{savar_fname}_mode_weights.npy") - mat_adj_w = np.load(results_path / f"adj_w_iteration_{iteration}.npy")[0] - - if n_modes == 100: - # With lots of modes some modes are equal and the other function breaks. This function works for the specifics params of the 100 modes dataset. 
- permutation_list = get_permutation_list_hardcoded_100(mat_adj_w, modes_gt, lat, lon) - else: - permutation_list = get_permutation_list(mat_adj_w, modes_gt, lat, lon) - permuted_matrices = np.array(load_and_permute_all_matrices(csv_files, permutation_list)) + noise_val = savar["noise_val"] + + home_path = str(Path.home()) + savar_path = "/my_projects/climatem/workspace/pfs7wor9/ka_qa4548-data/SAVAR_DATA_TEST" + results_path = Path( + "my_projects/climatem/workspace/pfs7wor9/ka_qa4548-results/SAVAR_DATA_TEST/var_savar_scenarios_piControl_nonlinear_False_tau_5_z_9_lr_0.001_bs_256_spreg_0_ormuinit_100000.0_spmuinit_0.1_spthres_0.05_fixed_False_num_ensembles_1_instantaneous_False_crpscoef_1_spcoef_0_tempspcoef_0_overlap_0.3_forcing_True" + ) + + # Load ground truthh modes + savar_folder = home_path + savar_path + savar_fname = f"modes_{n_modes}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}" + # modes_gt_path = savar_folder / Path(f"/{savar_fname}_mode_weights.npy") + modes_gt = np.load(f"{savar_folder}/{savar_fname}_mode_weights.npy") + + result_folder = home_path / results_path + # load CDSD results + cdsd_adj_inferred_path = result_folder / Path("plots/graphs.npy") + cdsd_modes_inferred_path = result_folder / Path("plots/w_decoder.npy") + modes_inferred = np.load(cdsd_modes_inferred_path) + adjacency_inferred = np.load(cdsd_adj_inferred_path) + + # if n_modes == 100: + # # With lots of modes some modes are equal and the other function breaks. This function works for the specifics params of the 100 modes dataset. + # permutation_list = get_permutation_list(mat_adj_w, modes_gt, lat, lon) + # else: + # permutation_list = get_permutation_list(mat_adj_w, modes_gt, lat, lon) + permuted_matrices = np.array( + load_and_permute_all_matrices(modes_inferred, modes_gt, adjacency_inferred, adjacency_inferred, lat, lon, tau) + ) # Load parameters from npy file - params_file = savar_folder / f"{savar_fname}_parameters.npy" + params_file = f"{savar_folder}/{savar_fname}_parameters.npy" params = np.load(params_file, allow_pickle=True).item() links_coeffs = params["links_coeffs"] - gt_adj_list = extract_adjacency_matrix(links_coeffs, n_modes_gt, tau) + gt_adj_list = extract_adjacency_matrix(links_coeffs, n_modes, tau) plot_adjacency_matrix( mat1=binarize_matrix(permuted_matrices, threshold), mat2=gt_adj_list, mat3=gt_adj_list, - path=results_path, + path=result_folder, name=f"permuted_adjacency_thr_{threshold}", no_gt=False, - iteration=iteration, + iteration=20000, plot_through_time=True, ) - save_equations_to_json(extract_latent_equations(links_coeffs), results_path / "gt_eq") + save_equations_to_json(extract_latent_equations(links_coeffs), result_folder / "gt_eq") save_equations_to_json( extract_equations_from_adjacency(binarize_matrix(permuted_matrices, threshold)), - results_path / f"thr_{threshold}_results_eq", + result_folder / f"thr_{threshold}_results_eq", ) precision, recall, f1, shd = evaluate_adjacency_matrix(permuted_matrices, gt_adj_list, threshold) print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}, SHD: {shd}") results = {"precision": precision, "recall": recall, "f1_score": f1, "shd": shd} # Save results as a JSON file - json_filename = results_path / f"thr_{threshold}_evaluation_results.json" + json_filename = result_folder / f"thr_{threshold}_evaluation_results.json" with open(json_filename, "w") as json_file: json.dump(results, json_file) diff --git a/climatem/synthetic_data/savar.py 
b/climatem/synthetic_data/savar.py index 58ff24c..8c8b2c8 100644 --- a/climatem/synthetic_data/savar.py +++ b/climatem/synthetic_data/savar.py @@ -5,17 +5,17 @@ 2022 The main difference with the provided code is the torch/GPU implementation which considerably speeds up the data generation process """ -from typing import List -import seaborn as sns + import itertools as it from copy import deepcopy from math import pi, sin -import matplotlib.pyplot as plt +from typing import List + import numpy as np import torch +import torch.nn as nn from torch.distributions.multivariate_normal import MultivariateNormal from tqdm.auto import tqdm -import torch.nn as nn def dict_to_matrix(links_coeffs, default=0): @@ -122,7 +122,7 @@ def __init__( self.tau_max = max(abs(lag) for (_, lag), _ in it.chain.from_iterable(self.links_coeffs.values())) self.spatial_resolution = deepcopy(self.mode_weights.reshape(self.n_vars, -1).shape[1]) print("spatial-resolution done") - + if self.noise_weights is None: self.noise_weights = deepcopy(self.mode_weights) if self.latent_noise_cov is None: @@ -220,8 +220,10 @@ def _add_noise_field(self): # Generate noise from cov print("Generate noise_data_field multivariate random") - mean_torch = torch.Tensor(np.zeros(self.spatial_resolution)).to(device="cuda") - cov = torch.Tensor(self.noise_cov).to(device="cuda") + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + mean_torch = torch.zeros(self.spatial_resolution, device=dev, dtype=dtype) + cov = torch.tensor(self.noise_cov, device=dev, dtype=dtype) distrib = MultivariateNormal(loc=mean_torch, covariance_matrix=cov) # . to(device="cuda") noise_data_field = distrib.sample(sample_shape=torch.Size([self.time_length + self.transient])) self.noise_data_field = noise_data_field.detach().cpu().numpy().transpose() @@ -258,6 +260,7 @@ def _add_seasonality_forcing(self): def _add_external_forcing(self): """ Adds external forcing to the data field using PyTorch tensors for GPU acceleration. + Allows for both linear and nonlinear ramps. 
""" if self.forcing_dict is None: @@ -275,41 +278,47 @@ def _add_external_forcing(self): w_f = (w_f != 0).astype(int) # Convert non-zero elements to 1 print(self.mode_weights.shape) - #w_f = w_f / (w_f.max() + 1e-8) # Normalize to range [0,1] + # w_f = w_f / (w_f.max() + 1e-8) # Normalize to range [0,1] # Merge last two dims first => shape (d_z, lat*lon) - temp = w_f.reshape(w_f.shape[0], w_f.shape[1]*w_f.shape[2]) + temp = w_f.reshape(w_f.shape[0], w_f.shape[1] * w_f.shape[2]) # sum over dim=0 => shape (lat*lon,) - w_f_sum = torch.tensor(temp.sum(axis=0), dtype=torch.float32, device="cuda") + + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + w_f_sum = torch.tensor(temp.sum(axis=0), dtype=dtype, device=dev) f_time_1 += self.transient f_time_2 += self.transient time_length = self.time_length + self.transient - # Generate the forcing trend using torch tensors if ramp_type == "linear": - ramp = torch.linspace(f_1, f_2, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + ramp = torch.linspace(f_1, f_2, f_time_2 - f_time_1, dtype=dtype, device=dev) elif ramp_type == "quadratic": - t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * t**2 elif ramp_type == "exponential": - t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(0, 1, f_time_2 - f_time_1, dtype=torch.float32, device=dev) ramp = f_1 + (f_2 - f_1) * (torch.exp(t) - 1) / (torch.exp(torch.tensor(1.0)) - 1) elif ramp_type == "sigmoid": - t = torch.linspace(-6, 6, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(-6, 6, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * (1 / (1 + torch.exp(-t))) elif ramp_type == "sinusoidal": - t = torch.linspace(0, pi, f_time_2 - f_time_1, dtype=torch.float32, device="cuda") + t = torch.linspace(0, pi, f_time_2 - f_time_1, dtype=dtype, device=dev) ramp = f_1 + (f_2 - f_1) * (0.5 * (1 - torch.cos(t))) else: - raise ValueError("Unsupported ramp type. Choose from 'linear', 'quadratic', 'exponential', 'sigmoid', or 'sinusoidal'.") + raise ValueError( + "Unsupported ramp type. Choose from 'linear', 'quadratic', 'exponential', 'sigmoid', or 'sinusoidal'." 
+ ) # Generate the forcing trend using torch tensors - trend = torch.cat([ - torch.full((f_time_1,), f_1, dtype=torch.float32, device="cuda"), - ramp, - torch.full((time_length - f_time_2,), f_2, dtype=torch.float32, device="cuda") - ]).reshape(1, time_length) + trend = torch.cat( + [ + torch.full((f_time_1,), f_1, dtype=dtype, device=dev), + ramp, + torch.full((time_length - f_time_2,), f_2, dtype=dtype, device=dev), + ] + ).reshape(1, time_length) if w_f_sum.dim() == 2: w_f_sum = w_f_sum.sum(dim=0, keepdim=True) # Sum across the correct dimension @@ -324,12 +333,8 @@ def _add_external_forcing(self): print(f"Before addition - Data field mean: {self.data_field.mean()}") - data_field_before = self.data_field.copy() - self.data_field += self.forcing_data_field - data_field_after = self.data_field - print(f"After addition - Data field mean: {self.data_field.mean()}") # # Convert tensors to numpy for plotting if necessary @@ -346,7 +351,7 @@ def _add_external_forcing(self): # mean_forcing = forcing_field.mean(axis=0) # mean_data_before = data_field_before.mean(axis=0) # mean_data_after = data_field_after.mean(axis=0) - + # # Plot 1: Mean Forcing over Time # plt.figure(figsize=(10, 4)) # plt.plot(range(time_length), mean_forcing, label="Mean Forcing", color="blue") @@ -374,21 +379,21 @@ def _add_external_forcing(self): # plt.savefig(f"mean_data_before_after_forcing_{f_1}_{f_2}_{ramp_type}.png") # Save to a file # plt.close() - def _create_linear(self): """Weights N \times L data_field L \times T.""" weights = deepcopy(self.mode_weights.reshape(self.n_vars, -1)) # weights_inv = np.linalg.pinv(weights) - weights_inv = torch.Tensor(np.linalg.pinv(weights)).to(device="cuda") - weights = torch.Tensor(weights).to(device="cuda") + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + weights_inv = torch.Tensor(np.linalg.pinv(weights)).to(device=dev) + weights = torch.Tensor(weights).to(device=dev) time_len = deepcopy(self.time_length) time_len += self.transient tau_max = self.tau_max # phi = dict_to_matrix(self.links_coeffs) - phi = torch.Tensor(dict_to_matrix(self.links_coeffs)).to(device="cuda") + phi = torch.Tensor(dict_to_matrix(self.links_coeffs)).to(device=dev) # data_field = deepcopy(self.data_field) - data_field = torch.Tensor(self.data_field).to(device="cuda") + data_field = torch.Tensor(self.data_field).to(device=dev) print("create_linear") for t in tqdm(range(tau_max, time_len)): @@ -398,19 +403,17 @@ def _create_linear(self): self.data_field = data_field[..., self.transient :].detach().cpu().numpy() - def train_nnar(self, num_epochs=50, learning_rate=0.001, batch_size=32): """ - method for training a very simple single-layer neural network with - sigmoid activation (one neuron). We train it here on pairs (past_values, future_value), - but this can be adapted as needed. + Method for training a very simple single-layer neural network with sigmoid activation (one neuron). + + We train it here on pairs (past_values, future_value), but this can be adapted as needed. 
""" # A trivial net: data_in -> [Linear] -> [Sigmoid] -> data_out - self.nnar_model = nn.Sequential( - nn.Linear(self.spatial_resolution, self.spatial_resolution), - nn.Sigmoid() - ).to("cuda") + self.nnar_model = nn.Sequential(nn.Linear(self.spatial_resolution, self.spatial_resolution), nn.Sigmoid()).to( + "cuda" + ) optimizer = torch.optim.Adam(self.nnar_model.parameters(), lr=learning_rate) loss_fn = nn.MSELoss() @@ -419,8 +422,8 @@ def train_nnar(self, num_epochs=50, learning_rate=0.001, batch_size=32): # (we might later incorporate more lags) # collect input-output pairs: - X = torch.from_numpy(self.data_field[:, :-1].T).float().to("cuda") - Y = torch.from_numpy(self.data_field[:, 1:].T).float().to("cuda") + X = torch.from_numpy(self.data_field[:, :-1].T).float().to("cuda") + Y = torch.from_numpy(self.data_field[:, 1:].T).float().to("cuda") dataset_size = X.shape[0] # Simple mini-batch loop @@ -430,11 +433,11 @@ def train_nnar(self, num_epochs=50, learning_rate=0.001, batch_size=32): for i in range(0, dataset_size, batch_size): idx = perm[i : i + batch_size] - x_batch = X[idx] + x_batch = X[idx] y_batch = Y[idx] # forward pass - pred = self.nnar_model(x_batch) + pred = self.nnar_model(x_batch) loss = loss_fn(pred, y_batch) batch_losses.append(loss.item()) @@ -448,12 +451,11 @@ def train_nnar(self, num_epochs=50, learning_rate=0.001, batch_size=32): print("Training of single-layer NNAR model completed.") - def _create_nonlinear(self): """ - Generates nonlinear data by applying a (trained or simple) nonlinearity - at each time step. This method uses the same logic as _create_linear to step forward in time - and adds the nonlinearity (sigmoid) before adding to data_field. + Generates nonlinear data by applying a (trained or simple) nonlinearity at each time step. This method uses the + same logic as _create_linear to step forward in time and adds the nonlinearity (sigmoid) before adding to + data_field. If train_nnar=True was set, we assume self.nnar_model was trained in generate_data(). Otherwise, we can do a direct inline "torch.sigmoid(...)" approach. @@ -475,7 +477,7 @@ def _create_nonlinear(self): nonlinear_contrib = 0.0 for i in range(tau_max): # get linear combination as in _create_linear - lincombo = weights @ phi[..., i] @ mode_weights_tensor @ data_field[..., (t-1 - i):(t - i)] + lincombo = weights @ phi[..., i] @ mode_weights_tensor @ data_field[..., (t - 1 - i) : (t - i)] # Apply a sigmoid (or feed it through the small neural net if you want more complexity) lincombo_nl = torch.sigmoid(lincombo) # accumulate @@ -486,18 +488,22 @@ def _create_nonlinear(self): self.data_field = data_field[:, self.transient :].detach().cpu().numpy() - def _create_polynomial(self): - """ - Example polynomial autoregression, e.g. x^2 for poly_degree=2. - """ + """Example polynomial autoregression, e.g. 
x^2 for poly_degree=2.""" w_np = np.linalg.pinv(self.mode_weights.reshape(self.n_vars, -1)) phi_np = dict_to_matrix(self.links_coeffs) - w_torch = torch.Tensor(w_np).to("cuda") - phi_torch = torch.Tensor(phi_np).to("cuda") - mw_torch = torch.Tensor(self.mode_weights.reshape(self.n_vars, -1)).to("cuda") - data_field = torch.Tensor(self.data_field).to("cuda") + # choose GPU if available, else CPU — and use float32 everywhere + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + w_torch = torch.tensor(w_np, device=dev, dtype=dtype) + phi_torch = torch.tensor(phi_np, device=dev, dtype=dtype) + mw_torch = torch.tensor( + self.mode_weights.reshape(self.n_vars, -1), + device=dev, + dtype=dtype, + ) + data_field = torch.tensor(self.data_field, device=dev, dtype=dtype) time_len = self.time_length + self.transient tau_max = self.tau_max @@ -507,18 +513,13 @@ def _create_polynomial(self): for t in tqdm(range(tau_max, time_len)): # For each time step, sum over the contributions of all lags for i in range(tau_max): - lincombo = ( - w_torch - @ phi_torch[..., i] - @ mw_torch - @ data_field[..., (t - 1 - i) : (t - i)] - ) + lincombo = w_torch @ phi_torch[..., i] @ mw_torch @ data_field[..., (t - 1 - i) : (t - i)] # For each requested polynomial degree, add its effect - poly_sum = 0.0 + poly_sum = torch.zeros_like(lincombo) for deg in self.poly_degrees: - poly_sum += lincombo ** deg + poly_sum += lincombo**deg data_field[:, t] += poly_sum.squeeze(-1) - self.data_field = data_field[:, self.transient :].detach().cpu().numpy() \ No newline at end of file + self.data_field = data_field[:, self.transient :].detach().cpu().numpy() diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py index 3ab7c67..9178aaa 100755 --- a/scripts/main_picabu.py +++ b/scripts/main_picabu.py @@ -61,15 +61,7 @@ def main( # os.makedirs(args.exp_path) # generate data and split train/test -<<<<<<< HEAD - if experiment_params.gpu and torch.cuda.is_available(): - device = "cuda" - print("CUDACUDACUDA") - else: - device = "cpu" -======= device = torch.device("cuda" if (torch.cuda.is_available() and experiment_params.gpu) else "cpu") ->>>>>>> main if data_params.data_format == "hdf5": print("IS HDF5") @@ -156,10 +148,11 @@ def main( reduce_encoding_pos_dim=model_params.reduce_encoding_pos_dim, coeff_kl=optim_params.coeff_kl, d=d, + #Here, everything hardcoded to gaussian because GEV leads to Nan... 
TBD distr_z0="gaussian", distr_encoder="gaussian", distr_transition="gaussian", - distr_decoder="gev", + distr_decoder="gaussian", d_x=experiment_params.d_x, d_z=experiment_params.d_z, tau=experiment_params.tau, @@ -188,11 +181,7 @@ def main( .translate({ord(","): None}) .translate({ord(" "): None}) ) -<<<<<<< HEAD - name = f"var_{data_var_ids_str}_scenarios_{data_params.train_scenarios[0]}_nonlinear_{model_params.nonlinear_mixing}_tau_{experiment_params.tau}_z_{experiment_params.d_z}_lr_{train_params.lr}_bs_{data_params.batch_size}_spreg_{optim_params.reg_coeff}_ormuinit_{optim_params.ortho_mu_init}_spmuinit_{optim_params.sparsity_mu_init}_spthres_{optim_params.sparsity_upper_threshold}_fixed_{model_params.fixed}_num_ensembles_{data_params.num_ensembles}_instantaneous_{model_params.instantaneous}_crpscoef_{optim_params.crps_coeff}_spcoef_{optim_params.spectral_coeff}_tempspcoef_{optim_params.temporal_spectral_coeff}_overlap_{savar_params.overlap}_forcing_{savar_params.is_forced}" -======= name = f"var_{data_var_ids_str}_scen_{data_params.train_scenarios[0]}_nlinmix_{model_params.nonlinear_mixing}_nlindyn_{model_params.nonlinear_dynamics}_tau_{experiment_params.tau}_z_{experiment_params.d_z}_futt_{experiment_params.future_timesteps}_ldec_{optim_params.loss_decay_future_timesteps}_lr_{train_params.lr}_bs_{data_params.batch_size}_ormuin_{optim_params.ortho_mu_init}_spmuin_{optim_params.sparsity_mu_init}_spth_{optim_params.sparsity_upper_threshold}_nens_{data_params.num_ensembles}_inst_{model_params.instantaneous}_crpscoef_{optim_params.crps_coeff}_sspcoef_{optim_params.spectral_coeff}_tspcoef_{optim_params.temporal_spectral_coeff}_fracnhiwn_{optim_params.fraction_highest_wavenumbers}_nummix_{model_params.num_hidden_mixing}_numhid_{model_params.num_hidden}_embdim_{model_params.position_embedding_dim}" ->>>>>>> main exp_path = exp_path / name os.makedirs(exp_path, exist_ok=True) @@ -209,7 +198,6 @@ def main( hp["train_params"] = train_params.__dict__ hp["model_params"] = model_params.__dict__ hp["optim_params"] = optim_params.__dict__ - hp["savar_params"] = savar_params.__dict__ with open(exp_path / "params.json", "w") as file: json.dump(hp, file, indent=4) @@ -237,6 +225,7 @@ def main( accelerator, wandbname=name, profiler=False, + profiler_path="./log", ) # where is the model at this point? @@ -357,7 +346,7 @@ def assert_args( params = json.load(f) config_obj_list = update_config_withparse(params, args) - # get user's scratch directory on Mila cluster: + # get user's scratch directory: scratch_path = os.getenv("SCRATCH") params["data_params"]["data_dir"] = params["data_params"]["data_dir"].replace("$SCRATCH", scratch_path) print ("new data path:", params["data_params"]["data_dir"]) @@ -395,4 +384,3 @@ def assert_args( ) main(experiment_params, data_params, gt_params, train_params, model_params, optim_params, plot_params, savar_params) - From 29480716e74c2377babd070923daffbfd92e2ae8 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Wed, 9 Jul 2025 15:26:36 +0200 Subject: [PATCH 10/12] removed module --- climatem/model/train_model.py | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/climatem/model/train_model.py b/climatem/model/train_model.py index 65bbcf1..4ec8790 100644 --- a/climatem/model/train_model.py +++ b/climatem/model/train_model.py @@ -395,7 +395,7 @@ def trace_handler(p): # Todo propagate the path! 
if not self.plot_params.savar: self.plotter.save_coordinates_and_adjacency_matrices(self) - torch.save(self.model.module.state_dict(), self.save_path / "model.pth") + torch.save(self.model.state_dict(), self.save_path / "model.pth") # try to use the accelerator.save function here self.accelerator.save_state(output_dir=self.save_path) @@ -567,7 +567,7 @@ def train_step(self): # noqa: C901 h_acyclic = torch.as_tensor([0.0]) if self.instantaneous and not self.converged: h_acyclic = self.get_acyclicity_violation() - h_ortho = self.get_ortho_violation(self.model.module.autoencoder.get_w_decoder()) + h_ortho = self.get_ortho_violation(self.model.autoencoder.get_w_decoder()) # compute total loss - here we are removing the sparsity regularisation as we are using the constraint here. loss = nll + connect_reg + sparsity_reg @@ -635,7 +635,7 @@ def train_step(self): # noqa: C901 self.optimizer.step() if self.optim_params.optimizer == "rmsprop" else self.optimizer.step() ), self.train_params.lr # projection of the gradient for w - if self.model.module.autoencoder.use_grad_project and not self.no_w_constraint: + if self.model.autoencoder.use_grad_project and not self.no_w_constraint: with torch.no_grad(): self.model.autoencoder.get_w_decoder().clamp_(min=0.0) @@ -773,7 +773,7 @@ def train_step(self): # noqa: C901 # Validation step here. def valid_step(self): - self.model.module.eval() + self.model.eval() with torch.no_grad(): # sample data @@ -825,7 +825,7 @@ def valid_step(self): # h_ortho = torch.tensor([0.]) if self.instantaneous and not self.converged: h_acyclic = self.get_acyclicity_violation() - h_ortho = self.get_ortho_violation(self.model.module.autoencoder.get_w_decoder()) + h_ortho = self.get_ortho_violation(self.model.autoencoder.get_w_decoder()) h_sparsity = self.get_sparsity_violation( lower_threshold=0.05, upper_threshold=self.optim_params.sparsity_upper_threshold @@ -969,8 +969,8 @@ def threshold(self): Convert it to a binary graph and fix it. """ with torch.no_grad(): - thresholded_adj = (self.model.module.get_adj() > 0.5).type(torch.Tensor) - self.model.module.mask.fix(thresholded_adj) + thresholded_adj = (self.model.get_adj() > 0.5).type(torch.Tensor) + self.model.mask.fix(thresholded_adj) self.thresholded = True print("Thresholding ================") @@ -1028,9 +1028,9 @@ def log_losses(self): self.adj_tt[int(self.iteration / self.train_params.valid_freq)] = self.model.get_adj().item() # here we just plot the first element of the logvar_decoder and logvar_encoder - self.logvar_decoder_tt.append(self.model.module.autoencoder.logvar_decoder[0].item()) - self.logvar_encoder_tt.append(self.model.module.autoencoder.logvar_encoder[0].item()) - self.logvar_transition_tt.append(self.model.module.transition_model.logvar[0, 0].item()) + self.logvar_decoder_tt.append(self.model.autoencoder.logvar_decoder[0].item()) + self.logvar_encoder_tt.append(self.model.autoencoder.logvar_encoder[0].item()) + self.logvar_transition_tt.append(self.model.transition_model.logvar[0, 0].item()) def print_results(self): """Print values of many variable: losses, constraint violation, etc. 
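Illustrative aside (not part of the committed diff): this commit swaps self.model.module.* back to self.model.*, which is only correct while the model is not wrapped, since wrappers such as torch.nn.parallel.DistributedDataParallel expose the original object under .module. If the trainer has to run both wrapped and unwrapped, a small hypothetical unwrap helper keeps the call sites stable; Hugging Face Accelerate offers accelerator.unwrap_model for the same purpose. A sketch:

import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    # DistributedDataParallel / DataParallel keep the user's module under .module
    return model.module if hasattr(model, "module") else model

# e.g. torch.save(unwrap(self.model).state_dict(), self.save_path / "model.pth")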
@@ -1068,7 +1068,7 @@ def get_nll(self, x, y, z=None) -> torch.Tensor: def get_regularisation(self) -> float: if self.iteration > self.optim_params.schedule_reg: - adj = self.model.module.get_adj() + adj = self.model.get_adj() reg = self.optim_params.reg_coeff * torch.norm(adj, p=1) # reg /= adj.numel() else: @@ -1078,7 +1078,7 @@ def get_regularisation(self) -> float: def get_acyclicity_violation(self) -> torch.Tensor: if self.iteration > 0: - adj = self.model.module.get_adj()[-1].view(self.d * self.d_z, self.d * self.d_z) + adj = self.model.get_adj()[-1].view(self.d * self.d_z, self.d * self.d_z) h = compute_dag_constraint(adj) / self.acyclic_constraint_normalization else: h = torch.as_tensor([0.0]) @@ -1126,7 +1126,7 @@ def get_sparsity_violation(self, lower_threshold, upper_threshold) -> float: if self.iteration > self.optim_params.schedule_sparsity: # first get the adj - adj = self.model.module.get_adj() + adj = self.model.get_adj() sum_of_connections = torch.norm(adj, p=1) / self.sparsity_normalization # print('constraint value, before I subtract a threshold from it:', sum_of_connections) @@ -1438,7 +1438,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = timesteps: int, the number of timesteps to predict into the future autoregressively """ - self.model.module.eval() + self.model.eval() if not valid: @@ -1459,11 +1459,11 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # ensure these are correct with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) # Here we predict, but taking 100 samples from the latents # TODO: make this into an argument - samples_from_xs, samples_from_zs, y = self.model.module.predict_sample(x, y, 10) + samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 10) # append the first prediction predictions.append(y_pred) @@ -1504,7 +1504,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # then predict the next timestep # y at this point is pointless!!! 
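Patch 10 strips the `.module` indirection throughout the trainer; that only works when `self.model` is not wrapped by `DistributedDataParallel`. A minimal sketch, assuming the trainer's `self.accelerator` is a Hugging Face `Accelerator` (the `unwrap` helper below is illustrative, not repo code), of an access pattern that works whether or not the model is wrapped:

    def unwrap(model, accelerator=None):
        # Accelerate removes DDP/compile wrappers; plain DDP keeps the inner
        # network on `.module`; an unwrapped model is returned unchanged.
        if accelerator is not None:
            return accelerator.unwrap_model(model)
        return getattr(model, "module", model)

    # usage sketch:
    # torch.save(unwrap(self.model, self.accelerator).state_dict(), self.save_path / "model.pth")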
with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) # append the prediction predictions.append(y_pred) @@ -1580,7 +1580,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # save the model in its current state print("Saving the model, since the spatial spectra score is the best we have seen for all variables.") - torch.save(self.model.module.state_dict(), self.save_path / "best_model_for_average_spectra.pth") + torch.save(self.model.state_dict(), self.save_path / "best_model_for_average_spectra.pth") else: @@ -1602,10 +1602,10 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = # swap with torch.no_grad(): - y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) # predict and take 100 samples too - samples_from_xs, samples_from_zs, y = self.model.module.predict_sample(x, y, 100) + samples_from_xs, samples_from_zs, y = self.model.predict_sample(x, y, 100) # make a copy of y_pred, which is a tensor x_original = x.clone().detach() @@ -1635,7 +1635,7 @@ def autoregress_prediction_original(self, valid: bool = False, timesteps: int = with torch.no_grad(): # then predict the next timestep - y_pred, y, z, pz_mu, pz_std = self.model.module.predict(x, y) + y_pred, y, z, pz_mu, pz_std = self.model.predict(x, y) np.save(self.save_path / f"val_x_ar_{i}.npy", x.detach().cpu().numpy()) np.save(self.save_path / f"val_y_ar_{i}.npy", y.detach().cpu().numpy()) @@ -1718,7 +1718,7 @@ def particle_filter(self, x, y, num_particles, timesteps=120): for _ in range(timesteps): # Prediction # make all the new predictions, taking samples from the latents - _, samples_from_zs, y = self.model.module.predict_sample(x, y, 100) + _, samples_from_zs, y = self.model.predict_sample(x, y, 100) # then calculate the score of each of the samples # Update the weights, where we want the weights to increase as the score improves From 792c8091f262a08164345a2af0119c8e7882f44e Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Wed, 23 Jul 2025 00:26:22 +0200 Subject: [PATCH 11/12] added seasonality to savar feat: add seasonality parameters to savarParams and SavarDataset, update data generation to support seasonal effects --- climatem/config.py | 10 ++ climatem/data_loader/savar_dataset.py | 101 ++++++++---- .../synthetic_data/generate_savar_datasets.py | 153 ++++++++++++++++-- climatem/synthetic_data/savar.py | 70 +++++--- scripts/main_picabu.py | 5 + 5 files changed, 283 insertions(+), 56 deletions(-) diff --git a/climatem/config.py b/climatem/config.py index a27d9ca..2bd3d86 100644 --- a/climatem/config.py +++ b/climatem/config.py @@ -288,6 +288,11 @@ def __init__( n_per_col: int = 2, # square grid, equivalent of lat/lon difficulty: str = "easy", # easy, med_easy, med_hard, hard: difficulty of the graph seasonality: bool = False, # Seasonality in synthetic data + periods: List[float] = [365, 182.5, 60], # Periods of the seasonality in days + amplitudes: List[float] = [0.06, 0.02, 0.01], # Amplitudes of the seasonality + phases: List[float] = [0.0, 0.7853981634, 1.5707963268], # Phases of the seasonality in radians + yearly_jitter_amp: float = 0.05, # Amplitude of the yearly jitter + yearly_jitter_phase: float = 0.10, # Phase of the yearly overlap: bool = False, # Modes overlap is_forced: bool = False, # Forcings in synthetic data f_1: int = 1, @@ -305,6 +310,11 @@ def __init__( self.n_per_col = n_per_col 
self.difficulty = difficulty self.seasonality = seasonality + self.periods = periods + self.amplitudes = amplitudes + self.phases = phases + self.yearly_jitter_amp = yearly_jitter_amp + self.yearly_jitter_phase = yearly_jitter_phase self.overlap = overlap self.is_forced = is_forced self.f_1 = f_1 diff --git a/climatem/data_loader/savar_dataset.py b/climatem/data_loader/savar_dataset.py index 0da88f2..2b42886 100644 --- a/climatem/data_loader/savar_dataset.py +++ b/climatem/data_loader/savar_dataset.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Sequence import numpy as np import torch @@ -25,6 +25,11 @@ def __init__( n_per_col: int = 2, difficulty: str = "easy", seasonality: bool = False, + periods: List[float] = [365, 182.5, 60], + amplitudes: List[float] = [0.06, 0.02, 0.01], + phases: List[float] = [0.0, 0.7853981634, 1.5707963268], + yearly_jitter_amp: float = 0.05, + yearly_jitter_phase: float = 0.10, overlap: bool = False, is_forced: bool = False, f_1: int = 1, @@ -56,6 +61,11 @@ def __init__( self.n_per_col = n_per_col self.difficulty = difficulty self.seasonality = seasonality + self.periods = periods + self.amplitudes = amplitudes + self.phases = phases + self.yearly_jitter_amp = yearly_jitter_amp + self.yearly_jitter_phase = yearly_jitter_phase self.overlap = overlap self.is_forced = is_forced self.f_1 = f_1 @@ -185,6 +195,11 @@ def get_causal_data( self.n_per_col, self.difficulty, self.seasonality, + self.periods, + self.amplitudes, + self.phases, + self.yearly_jitter_amp, + self.yearly_jitter_phase, self.overlap, self.is_forced, self.f_1, @@ -212,7 +227,14 @@ def get_causal_data( if self.global_normalization: data = (data - data.mean()) / data.std() if self.seasonality_removal: - self.norm_data = self.remove_seasonality(self.norm_data) + data = self.remove_seasonality( + data, + periods=self.periods, # already a constructor arg (e.g. 12) + demean=True, + normalise=False, + rolling=True, + w=10, # 10 years ≈ 120 steps @ monthly + ) print(f"data is {data.dtype}") @@ -392,35 +414,60 @@ def get_min_max(self, data): return vars_min, vars_max - # important? - # NOTE:(seb) I need to check the axis is correct here? - def remove_seasonality(self, data): + def remove_seasonality( + self, + data: np.ndarray, + periods: int | Sequence[int] | Sequence[float] = (12, 6, 3), + demean: bool = True, + normalise: bool = False, + rolling: bool = True, # ← default TRUE because of jitter + w: int = 10, # (10 years ≈ 120 steps @ monthly) + ): """ - Function to remove seasonality from the data There are various different options to do this These are just - different methods of removing seasonality. + Remove deterministic periodic seasonality from a [time, …] array. - e.g. - monthly - remove seasonality on a per month basis - rolling monthly - remove seasonality on a per month basis but using a rolling window, - removing only the average from the months that have preceded this month - linear - remove seasonality using a linear model to predict seasonality - - or trend removal - emissions - remove the trend using the emissions data, such as cumulative CO2 + Parameters + ---------- + period single cycle length **or** list/tuple of lengths + (e.g. [12, 6] for annual + semi-annual) + … """ - mean = np.nanmean(data, axis=0) - std = np.nanstd(data, axis=0) - - # return data - - # NOTE: SH - do we not do this above? 
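The folding logic that replaces the old month-wise standardisation is spelled out in the hunk that follows; a minimal standalone sketch of the same idea, assuming a monthly series and a single period of 12, looks like this:

    import numpy as np

    def remove_annual_cycle(x: np.ndarray, period: int = 12, w: int = 10) -> np.ndarray:
        # Sketch only: trim to whole cycles, fold into (cycles, period, ...),
        # and subtract the climatology of the last `w` cycles from every cycle.
        t = x.shape[0] - (x.shape[0] % period)
        folded = x[:t].reshape(t // period, period, *x.shape[1:])
        clim = np.nanmean(folded[-min(w, folded.shape[0]):], axis=0)
        return (folded - clim).reshape(t, *x.shape[1:])

    # e.g. a 50-year monthly field of shape (600, n_lat, n_lon):
    # anomalies = remove_annual_cycle(np.random.randn(600, 4, 8))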
- # standardise - I hope this is doing by month, to check - - return (data - mean[None]) / std[None] - - # now just divide by std... - # return data / std[None] + def _remove_one(x: np.ndarray, p: int) -> np.ndarray: + """Inner helper that handles a single period length.""" + t = x.shape[0] + rem = t % p + if rem: + x = x[:-rem] + t -= rem + folded = x.reshape((t // p, p) + x.shape[1:]) + if rolling: + k = min(w, folded.shape[0]) + mean = np.nanmean(folded[-k:], axis=0) + std = np.nanstd(folded[-k:], axis=0) + else: + mean = np.nanmean(folded, axis=0) + std = np.nanstd(folded, axis=0) + mean_full = np.tile(mean, (t // p, *[1] * (x.ndim - 1))) + std_full = np.tile(std, (t // p, *[1] * (x.ndim - 1))) + out = x.copy() + if demean: + out -= mean_full + if normalise: + out /= np.where(std_full == 0, 1, std_full) + return out.astype(np.float32) + + # handle one or many cycle lengths + if isinstance(periods, (list, tuple, np.ndarray)): + # remove the longest cycle first to avoid leakage + _periods = sorted([int(round(p)) for p in periods], reverse=True) + else: # single scalar + _periods = [int(round(periods))] + + out = data.astype(np.float32) + for p in _periods: + out = _remove_one(out, p) + return out def write_dataset_statistics(self, fname, stats): # fname = fname.replace('.npz.npy', '.npy') diff --git a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index 430728a..fa264b3 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -1,3 +1,4 @@ +import copy import csv import json @@ -102,7 +103,12 @@ def generate_save_savar_data( noise_val=0.2, n_per_col=2, # Number of components N = n_per_col**2 difficulty="easy", - seasonality=False, + seasonality=True, + periods=[12, 6, 3], + amplitudes=[0.1, 0.05, 0.02], + phases=[0.0, 0.7853981634, 1.5707963268], # [0, π/4, π/2] radians + yearly_jitter_amp: float = 0.05, + yearly_jitter_phase: float = 0.10, overlap=0, is_forced=False, f_1=1, @@ -183,10 +189,29 @@ def generate_save_savar_data( "time_len": time_len, "ramp_type": ramp_type, } + + season_dict = None if seasonality: - raise ValueError("SAVAR data with seasonality not implemented yet") - # We could introduce seasonality if we would wish - # season_dict = {"amplitude": 0.08, "period": 12} + lat = np.linspace(-90, 90, nx) # vary along rows + lat2d = np.repeat(lat[:, None], ny, axis=1) # shape (nx, ny) + season_weight = np.abs(np.sin(np.deg2rad(lat2d))).ravel() + + if phases is None: + phases = [0.0] * len(amplitudes) + + if not (len(amplitudes) == len(periods) == len(phases)): + raise ValueError("season_amplitudes, season_periods, season_phases must have identical lengths.") + + season_dict = { + "amplitudes": amplitudes, # e.g. [0.06, 0.02, 0.01] + "periods": periods, # e.g. [365, 182.5, 60] + "phases": phases, # radian offsets + "season_weight": season_weight, + "yearly_jitter": { + "amplitude": yearly_jitter_amp, # e.g. 0.05 + "phase": yearly_jitter_phase, # e.g. 
0.10 + }, + } if plotting: # Plot the sum of mode weights @@ -238,10 +263,13 @@ def generate_save_savar_data( "ramp_type": ramp_type, "linearity": linearity, "poly_degrees": poly_degrees, - # "season_dict": season_dict, - # "seasonality" : True, + "season_dict": season_dict, + "seasonality": True, } + parameters_copy = copy.deepcopy(parameters) + convert_ndarray_to_list(parameters_copy) # safe to mutate + # Specify the path to save the parameters param_names = f"{name}_parameters.npy" params_path = save_dir_path / param_names @@ -250,7 +278,7 @@ def generate_save_savar_data( param_names = f"{name}_parameters.csv" params_path = save_dir_path / param_names - save_parameters_to_csv(params_path, parameters) + save_parameters_to_csv(params_path, parameters_copy) param_names = f"{name}_links_coeffs.csv" params_path = save_dir_path / param_names save_links_coeffs_to_csv(params_path, parameters["links_coeffs"]) @@ -259,7 +287,6 @@ def generate_save_savar_data( np.save(params_path, modes_weights) # Create a copy of the parameters to modify - parameters_copy = parameters.copy() convert_ndarray_to_list(parameters_copy) # Specify the path to save the parameters @@ -277,7 +304,7 @@ def generate_save_savar_data( time_length=time_len, mode_weights=modes_weights, noise_strength=noise_val, # How to play with this parameter? - # season_dict=season_dict, #turn off by commenting out + season_dict=season_dict, # turn off by commenting out # forcing_dict=forcing_dict #turn off by commenting out linearity=linearity, poly_degrees=poly_degrees, @@ -288,6 +315,7 @@ def generate_save_savar_data( time_length=time_len, mode_weights=modes_weights, noise_strength=noise_val, + season_dict=season_dict, # turn off by commenting out forcing_dict=forcing_dict, # turn off by commenting out linearity=linearity, poly_degrees=poly_degrees, @@ -299,6 +327,26 @@ def generate_save_savar_data( print(f"{name} DONE!") plot_original_savar(savar_model.data_field, nx, ny, save_dir_path / f"{name}_original_savar_data2.gif") + + plot_seasonality_only(savar_model.seasonal_data_field, nx, ny, save_dir_path / f"{name}_seasonality_only.gif") + + # Check the seasonality + season_std = savar_model.seasonal_data_field.std() + total_std = savar_model.data_field.std() + ratio = season_std / total_std + print(f"[diag] σ_season / σ_total = {ratio:.3f}") + + if ratio < 0.05: + print("[diag] Seasonality is tiny – raise amplitudes or lower noise_val.") + elif ratio < 0.15: + print("[diag] Seasonality is subtle – visible but easy to miss.") + else: + print("[diag] Seasonality should be visually obvious.") + + hovmoller_seasonality( + savar_model.seasonal_data_field, nx, ny, n_steps=100, path=save_dir_path / f"{name}_seasonality_hovmoller.png" + ) + return savar_model.data_field @@ -331,3 +379,90 @@ def animate(i): ani.save(path, writer="pillow", fps=10) plt.close() + + +# Save an animation of the pure seasonality component +def plot_seasonality_only(seasonal_field, nx, ny, path, n_frames=100): + """ + Save an animation (GIF) of the pure seasonal component. 
+ + Parameters + ---------- + seasonal_field : (L, T) ndarray + nx, ny : grid dims (L must equal nx*ny) + path : output filename (.gif) + n_frames : how many timesteps to animate (default 100) + """ + import matplotlib.animation as animation + import matplotlib.pyplot as plt + import numpy as np + + T = seasonal_field.shape[1] + n_frames = min(n_frames, T) + + vmax = np.abs(seasonal_field).max() + vmin = -vmax + + field_2d = seasonal_field[:, :n_frames].T.reshape(n_frames, nx, ny) + + fig, ax = plt.subplots(figsize=(ny / 6, nx / 10)) + im = ax.imshow(field_2d[0], cmap="RdBu_r", vmin=vmin, vmax=vmax) + ax.set_xticks([]) + ax.set_yticks([]) + + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="3%", pad=0.02) + cbar = fig.colorbar(im, cax=cax) + cbar.ax.tick_params(labelsize=8) + cbar.set_label("seasonal anomaly") + + def animate(i): + im.set_data(field_2d[i]) + ax.set_title(f"Seasonality – t={i+1}") + return (im,) + + ani = animation.FuncAnimation(fig, animate, frames=n_frames, blit=False) + + ani.save(path, writer="pillow", fps=10) + plt.close() + + +def hovmoller_seasonality(seasonal_field, nx, ny, n_steps=100, path="seasonality_hovmoller.png"): + """ + Hovmöller-style plot that compresses the whole seasonal signal into one figure. + + Parameters + ---------- + seasonal_field : ndarray shape (L, T) + nx, ny : grid dimensions (L must equal nx × ny) + n_steps : how many time-steps to include on the x-axis + path : PNG filename to save + """ + import matplotlib.pyplot as plt + import numpy as np + + L, T = seasonal_field.shape + n_steps = min(n_steps, T) + + # reshape to [time, lat, lon] + field_3d = seasonal_field.T.reshape(T, nx, ny) + + # zonal mean (average over lon axis) + hov = field_3d[:n_steps].mean(axis=2) # shape (n_steps, nx) + + # transpose so y-axis is latitude, x-axis is time + hov = hov.T # shape (nx, n_steps) + + vmax = np.abs(hov).max() # symmetric colour scale + vmin = -vmax + + plt.figure(figsize=(8, 4)) + im = plt.imshow(hov, aspect="auto", origin="lower", cmap="RdBu_r", vmin=vmin, vmax=vmax) # south pole at bottom + plt.colorbar(im, label="seasonal anomaly") + plt.xlabel("time-step") + plt.ylabel("latitude index") + plt.title("Seasonality • zonal mean (first {:d} steps)".format(n_steps)) + plt.tight_layout() + plt.savefig(path, dpi=150) + plt.close() + print("saved:", path) diff --git a/climatem/synthetic_data/savar.py b/climatem/synthetic_data/savar.py index 8c8b2c8..083d9e2 100644 --- a/climatem/synthetic_data/savar.py +++ b/climatem/synthetic_data/savar.py @@ -7,8 +7,9 @@ """ import itertools as it +import math from copy import deepcopy -from math import pi, sin +from math import pi from typing import List import numpy as np @@ -235,27 +236,56 @@ def _add_noise_field(self): def _add_seasonality_forcing(self): - # A*sin((2pi/lambda)*x) A = amplitude, lambda = period - amplitude = self.season_dict["amplitude"] - period = self.season_dict["period"] - season_weight = self.season_dict.get("season_weight", None) + periods = self.season_dict["periods"] # e.g. 
[12, 6, 3] for year, half-year, quarter-year + amplitudes = self.season_dict["amplitudes"] # same length as periods + phases = self.season_dict.get("phases", [0.0] * len(periods)) - seasonal_trend = np.asarray( - [amplitude * sin((2 * pi / period) * x) for x in range(self.time_length + self.transient)] - ) - - seasonal_data_field = np.ones_like(self.data_field) - seasonal_data_field *= seasonal_trend.reshape(1, -1) - - # Apply seasonal weights - if season_weight is not None: - season_weight = season_weight.sum(axis=0).reshape(self.spatial_resolution) # vector dim L - seasonal_data_field *= season_weight[:, None] # L times T - - self.seasonal_data_field = seasonal_data_field + # year-to-year amplitude / phase jitter + jitter_cfg = self.season_dict.get("yearly_jitter") # None or dict + base_P = periods[0] # assume first is annual (12 months) + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 - # Add it to the data field. - self.data_field += seasonal_data_field + T = self.time_length + self.transient + ncy = math.ceil(T / base_P) # # of whole cycles + + L = self.data_field.shape[0] + T = self.time_length + self.transient + t = torch.arange(T, device=dev, dtype=dtype) + seasonal = torch.zeros((L, T), device=dev, dtype=dtype) + + σ_A = jitter_cfg["amplitude"] if jitter_cfg else 0.0 + σ_φ = jitter_cfg["phase"] if jitter_cfg else 0.0 + + # allow vector inputs, default to identical values otherwise + σ_Ak = torch.as_tensor(σ_A).expand(len(periods)).to(dtype=dtype, device=dev) + σ_φk = torch.as_tensor(σ_φ).expand(len(periods)).to(dtype=dtype, device=dev) + + for k, (A, P, φ) in enumerate(zip(amplitudes, periods, phases)): + # one jitter draw *per year* for this harmonic + amp_noise_k = 1 + σ_Ak[k] * torch.randn(ncy, device=dev, dtype=dtype) + phase_noise_k = σ_φk[k] * torch.randn(ncy, device=dev, dtype=dtype) + + amp_series_k = amp_noise_k.repeat_interleave(base_P)[:T] # (T,) + phase_series_k = phase_noise_k.repeat_interleave(base_P)[:T] # (T,) + + seasonal += amp_series_k * A * torch.sin(2 * math.pi / P * (t + phase_series_k) + φ) + + w = self.season_dict.get("season_weight") + if w is not None: + if not torch.is_tensor(w): + w = torch.as_tensor(w, dtype=dtype, device=dev) + else: + w = w.to(device=dev, dtype=dtype) + if w.ndim > 1: + w = w.reshape(-1) + if w.numel() != L: + raise ValueError(f"season_weight has length {w.numel()} but grid has {L} points") + seasonal *= w.reshape(L, 1) + + seasonal_np = seasonal.cpu().numpy() + self.seasonal_data_field = seasonal_np + self.data_field += seasonal_np def _add_external_forcing(self): """ diff --git a/scripts/main_picabu.py b/scripts/main_picabu.py index 9178aaa..aac82fc 100755 --- a/scripts/main_picabu.py +++ b/scripts/main_picabu.py @@ -110,6 +110,11 @@ def main( n_per_col=savar_params.n_per_col, difficulty=savar_params.difficulty, seasonality=savar_params.seasonality, + periods=savar_params.periods, + amplitudes=savar_params.amplitudes, + phases=savar_params.phases, + yearly_jitter_amp=savar_params.yearly_jitter_amp, + yearly_jitter_phase=savar_params.yearly_jitter_phase, overlap=savar_params.overlap, is_forced=savar_params.is_forced, f_1=savar_params.f_1, From 0d7e27252d6f2d05c54b782e48de4c9e130b7cc3 Mon Sep 17 00:00:00 2001 From: IlijaTrajkovic Date: Wed, 23 Jul 2025 15:15:01 +0200 Subject: [PATCH 12/12] changed max seasonality to mid latitudes --- .../synthetic_data/generate_savar_datasets.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git 
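The seasonal forcing built in `_add_seasonality_forcing` above is a sum of sinusoidal harmonics whose amplitude and phase are re-drawn once per year. A compact NumPy sketch of the same construction, with shapes and names simplified to a single grid cell and no latitude weighting:

    import numpy as np

    rng = np.random.default_rng(0)
    T, base_P = 360, 12                     # 30 years of monthly steps, annual cycle of 12
    periods, amps, phases = [12, 6, 3], [0.06, 0.02, 0.01], [0.0, np.pi / 4, np.pi / 2]
    sigma_amp, sigma_phase = 0.05, 0.10     # yearly jitter strengths

    t = np.arange(T)
    n_years = int(np.ceil(T / base_P))
    seasonal = np.zeros(T)
    for A, P, phi in zip(amps, periods, phases):
        # one amplitude/phase draw per year, repeated over that year's steps
        amp_j = np.repeat(1 + sigma_amp * rng.standard_normal(n_years), base_P)[:T]
        pha_j = np.repeat(sigma_phase * rng.standard_normal(n_years), base_P)[:T]
        seasonal += amp_j * A * np.sin(2 * np.pi / P * (t + pha_j) + phi)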
a/climatem/synthetic_data/generate_savar_datasets.py b/climatem/synthetic_data/generate_savar_datasets.py index fa264b3..1997e15 100644 --- a/climatem/synthetic_data/generate_savar_datasets.py +++ b/climatem/synthetic_data/generate_savar_datasets.py @@ -194,7 +194,7 @@ def generate_save_savar_data( if seasonality: lat = np.linspace(-90, 90, nx) # vary along rows lat2d = np.repeat(lat[:, None], ny, axis=1) # shape (nx, ny) - season_weight = np.abs(np.sin(np.deg2rad(lat2d))).ravel() + season_weight = np.abs(np.sin(2 * np.deg2rad(lat2d))).ravel() if phases is None: phases = [0.0] * len(amplitudes) @@ -328,24 +328,29 @@ def generate_save_savar_data( plot_original_savar(savar_model.data_field, nx, ny, save_dir_path / f"{name}_original_savar_data2.gif") - plot_seasonality_only(savar_model.seasonal_data_field, nx, ny, save_dir_path / f"{name}_seasonality_only.gif") - - # Check the seasonality - season_std = savar_model.seasonal_data_field.std() - total_std = savar_model.data_field.std() - ratio = season_std / total_std - print(f"[diag] σ_season / σ_total = {ratio:.3f}") - - if ratio < 0.05: - print("[diag] Seasonality is tiny – raise amplitudes or lower noise_val.") - elif ratio < 0.15: - print("[diag] Seasonality is subtle – visible but easy to miss.") - else: - print("[diag] Seasonality should be visually obvious.") - - hovmoller_seasonality( - savar_model.seasonal_data_field, nx, ny, n_steps=100, path=save_dir_path / f"{name}_seasonality_hovmoller.png" - ) + if seasonality: + plot_seasonality_only(savar_model.seasonal_data_field, nx, ny, save_dir_path / f"{name}_seasonality_only.gif") + + # Check the seasonality + season_std = savar_model.seasonal_data_field.std() + total_std = savar_model.data_field.std() + ratio = season_std / total_std + print(f"[diag] σ_season / σ_total = {ratio:.3f}") + + if ratio < 0.05: + print("[diag] Seasonality is tiny – raise amplitudes or lower noise_val.") + elif ratio < 0.15: + print("[diag] Seasonality is subtle – visible but easy to miss.") + else: + print("[diag] Seasonality should be visually obvious.") + + hovmoller_seasonality( + savar_model.seasonal_data_field, + nx, + ny, + n_steps=100, + path=save_dir_path / f"{name}_seasonality_hovmoller.png", + ) return savar_model.data_field
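Patch 12 replaces the |sin(lat)| weighting, which peaks at the poles, with |sin(2·lat)|, which peaks at ±45° and vanishes at the equator and the poles. A quick check of that behaviour, independent of the repo code:

    import numpy as np

    lat = np.linspace(-90, 90, 181)                  # one-degree grid
    old_w = np.abs(np.sin(np.deg2rad(lat)))          # maximum at the poles
    new_w = np.abs(np.sin(2 * np.deg2rad(lat)))      # maximum at +/-45 degrees

    print(lat[np.argmax(old_w)], lat[np.argmax(new_w)])   # -90.0 -45.0
    print(new_w[lat == 0.0], new_w[lat == 90.0])          # [0.] [~0.]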