Add IC ensemble ability to evaluator and update inference aggregators for ensembles#709
Add IC ensemble ability to evaluator and update inference aggregators for ensembles#709Arcomano1234 wants to merge 84 commits intomainfrom
Conversation
… for inference config
…nsemble" This reverts commit f998b9f.
| gen_data_norm=gen_data_norm, | ||
| i_time_start=self._n_timesteps_seen, | ||
| ) | ||
| if self.n_ensemble_per_ic > 1: |
There was a problem hiding this comment.
Check out fme/ace/aggregator/one_step/main.py and see if you like the structure there for how it separates ensemble and deterministic aggregators. I think what's here is fine too, maybe preferable.
There was a problem hiding this comment.
Either works but I did have a slight preference for the structure I proposed vs in fme/ace/aggregator/one_step/main.py to better allow more complex operations with ensemble aggs
| "a": torch.ones([batch_size, n_timesteps, nx, ny], device=get_device()), | ||
| "b": torch.ones([batch_size, n_timesteps, nx, ny], device=get_device()) * 3, | ||
| "c": torch.ones([batch_size, n_timesteps, nx, ny], device=get_device()) * 4, | ||
| "a": torch.ones( |
There was a problem hiding this comment.
Suggestion (optional): Like you do in test_inference_evaluator_aggregator_ensemble, make two BatchData, call .broadcast_ensemble on them, and then use PairedData.from_batch_data instead of constructing an ensemble PairedData at a low level. That would avoid coupling this test to the low-level implementation of the internals of BatchData/PairedData, making it easier to change the way ensembles are handled later if we need to.
| horizontal_dims=self.horizontal_dims, | ||
| epoch=self.epoch, | ||
| labels=self.labels.to(device) if self.labels is not None else None, | ||
| n_ensemble=self.n_ensemble, |
There was a problem hiding this comment.
Could separate these changes into their own PR, it looks like they're fixing an existing bug. Would make it easier to tell if you updated the tests for this (if you did great, if you didn't please do).
Maybe for some of these methods we should have a more general "test all attributes that don't start with underscore are identical, except for .device." type test that will always catch these issues, even if we add new attributes.
| data=repeat_interleave_batch_dim(self.data, n_ensemble), | ||
| time=xr.concat([self.time] * n_ensemble, dim="sample"), | ||
| labels=labels, | ||
| epoch=self.epoch, |
There was a problem hiding this comment.
It looks like epoch=self.epoch got deleted. Can you remove the changes to this method, which seem to be stylistic? I normally try to construct things before the final return-setting-many-attributes, in the case of larger inits like this one.
fme/ace/data_loading/batch_data.py
Outdated
| def target(self) -> TensorMapping: | ||
| return {k: v for k, v in self.reference.items() if k in self.prediction} | ||
|
|
||
| def broadcast_ensemble(self) -> tuple[EnsembleTensorDict, EnsembleTensorDict]: |
There was a problem hiding this comment.
Issue: I would expect .broadcast_ensemble to return the same type, like it does for BatchData. I'd also expect it to be a light wrapper that broadcasts the internal BatchData. Is it possible to do that here? If you do, the test that got refactored to low-level-constrict PairedData could have its changes limited to calling the .broadcast_ensemble method on the one it was already making.
To break up #709 into manageable PRs, this PR adds the ensemble mean aggregator to `one_step/ensemble.py`. Changes: - Add `EnsembleMeanRMSEMetric` class `one_step/ensemble.py` - [x] Tests added
Adds support for same initial condition (IC) ensembles for stochastic models during evaluation and inline inference. Also adds / extends ensemble-based aggregators for step 20 during inference.
Changes:
symbol (e.g.
fme.core.my_function) or script and concise description of changes or added featureCan group multiple related symbols on a single bullet
Tests added