From 9dc7c6a0fb1c707e8cc8a39c9edcb33e99b3ace9 Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Tue, 10 Mar 2026 17:14:17 -0700 Subject: [PATCH] Enable checkpointing for CMIP Clean up file-paths for restart tests Author: Julia Sloan , Akshay Sridhar --- .buildkite/pipeline.yml | 22 ++- docs/src/checkpointer.md | 28 ++- experiments/test/amip_test.yml | 2 + .../test/{compare.jl => compare_amip.jl} | 0 experiments/test/compare_cmip.jl | 175 ++++++++++++++++++ .../test/{restart.jl => restart_amip.jl} | 4 +- experiments/test/restart_cmip.jl | 127 +++++++++++++ experiments/test/restart_cmip.yml | 41 ++++ experiments/test/restart_state_only.jl | 6 +- ext/ClimaCouplerCMIPExt.jl | 6 +- ext/ClimaCouplerCMIPExt/clima_seaice.jl | 13 -- ext/ClimaCouplerCMIPExt/climaocean_helpers.jl | 73 ++++++++ ext/ClimaCouplerCMIPExt/oceananigans.jl | 18 +- src/Checkpointer.jl | 68 ++++--- 14 files changed, 512 insertions(+), 71 deletions(-) rename experiments/test/{compare.jl => compare_amip.jl} (100%) create mode 100644 experiments/test/compare_cmip.jl rename experiments/test/{restart.jl => restart_amip.jl} (98%) create mode 100644 experiments/test/restart_cmip.jl create mode 100644 experiments/test/restart_cmip.yml diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 8dcf362354..348a989282 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -101,7 +101,7 @@ steps: - label: "MPI restarts" key: "mpi_restarts" - command: "srun julia --color=yes --project=experiments/AMIP/ experiments/test/restart.jl" + command: "srun julia --color=yes --project=experiments/AMIP/ experiments/test/restart_amip.jl" retry: *retry_policy env: CLIMACOMMS_CONTEXT: "MPI" @@ -113,24 +113,34 @@ steps: slurm_ntasks: 2 slurm_mem: 32GB - - label: "GPU restarts (state and cache)" - command: "julia -O0 --color=yes --project=experiments/AMIP/ experiments/test/restart.jl" + - label: "GPU AMIP restarts (state and cache)" + command: "julia -O0 --color=yes --project=experiments/AMIP/ 
experiments/test/restart_amip.jl" + retry: *retry_policy env: CLIMACOMMS_DEVICE: "CUDA" - retry: *retry_policy agents: slurm_gpus: 1 slurm_mem: 32GB - - label: "GPU restarts (state only)" + - label: "GPU AMIP restarts (state only)" command: "julia -O0 --color=yes --project=experiments/AMIP/ experiments/test/restart_state_only.jl" + retry: *retry_policy env: CLIMACOMMS_DEVICE: "CUDA" - retry: *retry_policy agents: slurm_gpus: 1 slurm_mem: 32GB + - label: "GPU CMIP restarts (state and cache)" + command: "julia -O0 --color=yes --project=experiments/CMIP/ experiments/test/restart_cmip.jl" + retry: *retry_policy + env: + CLIMACOMMS_DEVICE: "CUDA" + agents: + slurm_ntasks: 1 + slurm_gres: "gpu:1" + slurm_mem: 32GB + - group: "Integration Tests (single column)" steps: - label: "SCM: slabplanet aqua (atmos + slab ocean column)" diff --git a/docs/src/checkpointer.md b/docs/src/checkpointer.md index 5b654a44eb..f87227c292 100644 --- a/docs/src/checkpointer.md +++ b/docs/src/checkpointer.md @@ -108,13 +108,35 @@ forward, but there are still several challenges that need to be solved: Point 3. adds significant amount of code and requires component models to specify how their cache has to be restored. -If you are adding a component model, you have to extend the methods. +### Adding checkpointing to a new component model + +There are two ways to add checkpoint/restart support for a new component model: + +**Path A (ClimaCore-based models):** extend `get_model_prog_state` to return the +prognostic state as a `ClimaCore.FieldVector`. The default +`checkpoint_model_state` and `restart_model_state!` implementations will handle +HDF5 I/O via `ClimaCore.InputOutput` automatically. This path is intended for +models whose prognostic state is a `ClimaCore.FieldVector`; models that do not +use ClimaCore should use Path B instead. + ``` Checkpointer.get_model_prog_state Checkpointer.get_model_cache Checkpointer.restore_cache! 
``` +**Path B (custom checkpoint format):** override `checkpoint_model_state` and +`restart_model_state!` directly for full control over the checkpoint format. This +is the approach used by `OceananigansSimulation`, which writes JLD2 checkpoints +via Oceananigans' native `checkpoint` and restores them with `Oceananigans.set!`. + +``` +Checkpointer.checkpoint_model_state +Checkpointer.checkpoint_model_cache # optional; no-op if cache checkpointing is not supported +Checkpointer.restart_model_state! +Checkpointer.restart_model_cache! # optional; warn or no-op if cache restore is not supported +``` + `ClimaCoupler` moves objects to the CPU with `Adapt(Array, x)`. `Adapt` traverses the object recursively, and proper `Adapt` methods have to be defined for every object involved in the chain. The easiest way to do this is using the @@ -159,7 +181,11 @@ This approach allows for a signficant reducation in the file size of the cache. Checkpointer.get_model_prog_state Checkpointer.get_model_cache Checkpointer.get_model_cache_to_checkpoint + Checkpointer.checkpoint_model_state + Checkpointer.checkpoint_model_cache Checkpointer.restart! + Checkpointer.restart_model_state! + Checkpointer.restart_model_cache! Checkpointer.checkpoint_sims Checkpointer.t_start_from_checkpoint Checkpointer.restore! 
diff --git a/experiments/test/amip_test.yml b/experiments/test/amip_test.yml index 689344149f..b669935ce9 100644 --- a/experiments/test/amip_test.yml +++ b/experiments/test/amip_test.yml @@ -15,7 +15,9 @@ mode_name: "amip" netcdf_interpolation_num_points: [90, 45, 31] netcdf_output_at_levels: true output_default_diagnostics: true +orographic_gravity_wave: ~ rayleigh_sponge: true +reproducible_restart: true start_date: "20100101" surface_setup: "PrescribedSurface" t_end: "540secs" diff --git a/experiments/test/compare.jl b/experiments/test/compare_amip.jl similarity index 100% rename from experiments/test/compare.jl rename to experiments/test/compare_amip.jl diff --git a/experiments/test/compare_cmip.jl b/experiments/test/compare_cmip.jl new file mode 100644 index 0000000000..f1337bd43a --- /dev/null +++ b/experiments/test/compare_cmip.jl @@ -0,0 +1,175 @@ +# compare_cmip.jl provides functions to recursively compare complex objects while also +# allowing for some numerical tolerance. + +import ClimaComms +import ClimaAtmos as CA +import ClimaCore as CC +import Oceananigans as OC +import ClimaSeaIce as CSI +using ClimaSeaIce.SeaIceThermodynamics.HeatBoundaryConditions: IceWaterThermalEquilibrium +import NCDatasets + +""" + _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + +We compute the error in this way: +- when the absolute value of `arr1` is larger than ABS_TOL, we use the relative error +- in the other cases, we use the absolute error +""" +function _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) + # There are some parameters, e.g. Obukhov length, for which Inf + # is a reasonable value (implying a stability parameter in the neutral boundary layer + # regime, for instance). We account for such instances with the `isfinite` function. 
+ arr1 = Array(arr1) .* isfinite.(Array(arr1)) + arr2 = Array(arr2) .* isfinite.(Array(arr2)) + diff = abs.(arr1 .- arr2) + denominator = abs.(arr1) + error = ifelse.(denominator .> ABS_TOL, diff ./ denominator, diff) + return error +end + +""" + compare(v1, v2; name = "", ignore = Set([:rc])) + +Return whether `v1` and `v2` are the same (up to floating point errors). +`compare` walks through all the properties in `v1` and `v2` until it finds +that there are no more properties. At that point, `compare` tries to match the +resulting objects. When such objects are arrays with floating point, `compare` +defines a notion of `error` that is the following: when the absolute value is +less than `100eps(eltype)`, `error = absolute_error`, otherwise it is relative +error. The `error` is then compared against a tolerance. +Keyword arguments +================= +- `name` is used to collect the name of the property while we go recursively + over all the properties. You can pass a base name. +- `ignore` is a collection of `Symbol`s that identify properties that are + ignored when walking through the tree. This is useful for properties that + are known to be different (e.g., `output_dir`). 
+`:rc` is some CUDA/CuArray internal object that we don't care about +""" +function compare( + v1::T1, + v2::T2; + name = "", + ignore = Set([:rc]), +) where { + T1 <: Union{ + CC.Fields.FieldVector, + CC.Spaces.AbstractSpace, + NamedTuple, + CA.AtmosCache, + OC.Models.HydrostaticFreeSurfaceModels.HydrostaticFreeSurfaceModel, + CSI.SeaIceModel, + }, + T2 <: Union{ + CC.Fields.FieldVector, + CC.Spaces.AbstractSpace, + NamedTuple, + CA.AtmosCache, + OC.Models.HydrostaticFreeSurfaceModels.HydrostaticFreeSurfaceModel, + CSI.SeaIceModel, + }, +} + pass = true + return _compare(pass, v1, v2; name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T} + properties = filter(x -> !(x in ignore), propertynames(v1)) + if isempty(properties) + pass &= _compare(v1, v2; name, ignore) + else + # Recursive case + for p in properties + pass &= _compare( + pass, + getproperty(v1, p), + getproperty(v2, p); + name = "$(name).$(p)", + ignore, + ) + end + end + return pass +end + +# ClimaSeaIce `IceWaterThermalEquilibrium` and `FluxFunction` use reference `==` at compile time +# (see `@code_typed ==(iwte1, iwte2)`), and `FluxFunction.func` seem to hold non-comparable closures. 
+function _compare( + pass, + v1::IceWaterThermalEquilibrium, + v2::IceWaterThermalEquilibrium; + name, + ignore, +) + pass &= _compare(pass, v1.salinity, v2.salinity; name = "$(name).salinity", ignore) + return pass +end + +function _compare(pass, v1::CSI.FluxFunction, v2::CSI.FluxFunction; name, ignore) + pass &= + _compare(pass, v1.parameters, v2.parameters; name = "$(name).parameters", ignore) + return pass +end + +function _compare(v1::T, v2::T; name, ignore) where {T} + return print_maybe(v1 == v2, "$name differs") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Union{AbstractString, Symbol}} + # What we can safely print without filling STDOUT + return print_maybe(v1 == v2, "$name differs: $v1 vs $v2") +end + +function _compare(v1::T, v2::T; name, ignore) where {T <: Number} + # We check with triple equal so that we also catch NaNs being equal + return print_maybe(v1 === v2, "$name differs: $v1 vs $v2") +end + +# We ignore NCDatasets. They contain a lot of state-ful information +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset} + return pass +end + +function _compare( + v1::T, + v2::T; + name, + ignore, +) where {T <: CC.Fields.Field{<:CC.DataLayouts.AbstractData{<:Real}}} + return _compare(parent(v1), parent(v2); name, ignore) +end + +function _compare(pass, v1::T, v2::T; name, ignore) where {T <: CC.DataLayouts.AbstractData} + return pass && _compare(parent(v1), parent(v2); name, ignore) +end + +# Handle views +function _compare( + pass, + v1::SubArray{FT}, + v2::SubArray{FT}; + name, + ignore, +) where {FT <: AbstractFloat} + return pass && _compare(collect(v1), collect(v2); name, ignore) +end + +function _compare( + v1::AbstractArray{FT}, + v2::AbstractArray{FT}; + name, + ignore, +) where {FT <: AbstractFloat} + error = maximum(_error(v1, v2); init = zero(eltype(v1))) + return print_maybe(error <= 100eps(eltype(v1)), "$name error: $error") +end + +function _compare(pass, v1::T1, v2::T2; name, ignore) 
where {T1, T2} + error("v1 and v2 have different types") +end + +function print_maybe(exp, what) + exp || println(what) + return exp +end diff --git a/experiments/test/restart.jl b/experiments/test/restart_amip.jl similarity index 98% rename from experiments/test/restart.jl rename to experiments/test/restart_amip.jl index 139ad3cd05..a741eddf65 100644 --- a/experiments/test/restart.jl +++ b/experiments/test/restart_amip.jl @@ -21,7 +21,7 @@ using Test # Uncomment the following for cleaner output (but more difficult debugging) # Logging.disable_logging(Logging.Warn) -include("compare.jl") +include("compare_amip.jl") include("../AMIP/code_loading.jl") comms_ctx = ClimaComms.context() @@ -65,7 +65,7 @@ four_steps_reading["job_id"] = "four_steps_reading" Input.update_t_start_for_restarts!(four_steps_reading) cs_four_steps_reading = setup_and_run(four_steps_reading) -@testset "Restarts from command line arguments" begin +@testset "AMIP restarts (state and cache)" begin @test cs_four_steps_reading.tspan[1] == cs_four_steps.tspan[2] end diff --git a/experiments/test/restart_cmip.jl b/experiments/test/restart_cmip.jl new file mode 100644 index 0000000000..da8340cb89 --- /dev/null +++ b/experiments/test/restart_cmip.jl @@ -0,0 +1,127 @@ +# This test sets up a small CMIP simulation four times. +# +# - The first time the simulation is run for four steps +# - The second time the simulation is restarted from the first and run one more step +# - The third time the simulation is run for two steps; the fourth time it is only +# initialized (not advanced) by restarting from the third simulation +# +# After all these simulations are set up, we compare the third and the fourth. Their +# states should agree within the tolerance used by `compare` (see compare_cmip.jl). +# +# The content of the simulation is not the most important, but it helps if it +# has all of the complexity possible. 
+ +import ClimaComms +ClimaComms.@import_required_backends +import ClimaUtilities.OutputPathGenerator: maybe_wait_filesystem +import YAML +import Logging +using Test + +# Uncomment the following for cleaner output (but more difficult debugging) +# Logging.disable_logging(Logging.Warn) + +include("compare_cmip.jl") +include("../CMIP/code_loading.jl") + +comms_ctx = ClimaComms.context() +@info "Context: $(comms_ctx)" +ClimaComms.init(comms_ctx) + +# Make sure that all MPI processes agree on the output_loc +tmpdir = ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : "" +tmpdir = ClimaComms.bcast(comms_ctx, tmpdir) +# Sometimes the shared filesystem doesn't work properly and the folder is not +# synced across MPI processes. Let's add an additional check here. +maybe_wait_filesystem(ClimaComms.context(), tmpdir) + +# Parse the input config file as a dictionary +config_file = joinpath(@__DIR__, "restart_cmip.yml") +config_dict = Input.get_coupler_config_dict(config_file) + +# Four steps +four_steps = deepcopy(config_dict) + +four_steps["dt_cpl"] = "360secs" +four_steps["dt_ocean"] = "360secs" +four_steps["dt_seaice"] = "360secs" +four_steps["t_end"] = "1440secs" +four_steps["coupler_output_dir"] = tmpdir +four_steps["checkpoint_dt"] = "1440secs" +four_steps["job_id"] = "four_steps" + +println("Simulating four steps") +cs_four_steps = setup_and_run(four_steps) + +# Check that we can pick up a simulation by providing t_restart and restart_dir +println("Simulating four steps, options from command line") +four_steps_reading = deepcopy(four_steps) + +four_steps_reading["t_end"] = "1800secs" +four_steps_reading["detect_restart_files"] = true +four_steps_reading["restart_dir"] = cs_four_steps.dir_paths.checkpoints_dir +four_steps_reading["restart_t"] = 1440 +four_steps_reading["job_id"] = "four_steps_reading" +Input.update_t_start_for_restarts!(four_steps_reading) + +cs_four_steps_reading = setup_and_run(four_steps_reading) +@testset "CMIP restarts (state and cache)" begin + 
@test cs_four_steps_reading.tspan[1] == cs_four_steps.tspan[2] +end + +# Two steps + two steps (2 × 360 s = 720 s each half) +two_steps = deepcopy(config_dict) + +two_steps["dt_cpl"] = "360secs" +two_steps["dt_ocean"] = "360secs" +two_steps["dt_seaice"] = "360secs" +two_steps["t_end"] = "720secs" +two_steps["coupler_output_dir"] = tmpdir +# restart_cmip.yml sets dt_nogw/dt_ogw to 360 s so checkpoint_dt (360 s here, 1440 s in +# four_steps) is an integer multiple of the GW callback periods (ClimaAtmos checks +# checkpoint_dt / dt_nogw and checkpoint_dt / dt_ogw). +two_steps["checkpoint_dt"] = "360secs" +two_steps["job_id"] = "two_steps" + +# Copying since setup_and_run changes its content +println("Simulating two steps") +cs_two_steps1 = setup_and_run(two_steps) + +println("Restarting from checkpoint, initialization only") +# Construct a restarted CoupledSimulation at t = 720s, but do not advance it. +restart_init = deepcopy(two_steps) +restart_init["t_end"] = "720secs" # equal to t_start after update_t_start_for_restarts! +restart_init["detect_restart_files"] = true +restart_init["restart_dir"] = cs_two_steps1.dir_paths.checkpoints_dir +restart_init["restart_t"] = 720 +restart_init["restart_cache"] = true +restart_init["job_id"] = "two_steps_restart_init_only" +Input.update_t_start_for_restarts!(restart_init) +cs_two_steps_restart_init = Interfacer.CoupledSimulation(restart_init) + +@testset "Restart initialization matches checkpointed state" begin + @test cs_two_steps_restart_init.tspan[1] == cs_two_steps1.tspan[2] + + # Compare prognostic states after restart initialization (including coupler flux init) + # to the end-of-run state from the pre-restart segment (the checkpointed step). 
+ @test compare( + cs_two_steps1.model_sims.atmos_sim.integrator.u, + cs_two_steps_restart_init.model_sims.atmos_sim.integrator.u, + ) + @test compare( + cs_two_steps1.model_sims.land_sim.integrator.u, + cs_two_steps_restart_init.model_sims.land_sim.integrator.u, + ) + # Sea-ice structural compare: see `compare_cmip.jl` (`IceWaterThermalEquilibrium`, `FluxFunction`). + @test compare( + cs_two_steps1.model_sims.ice_sim.ice.model, + cs_two_steps_restart_init.model_sims.ice_sim.ice.model, + ignore = [:clock, :parent, :ptr], + ) + @test compare( + cs_two_steps1.model_sims.ocean_sim.ocean.model, + cs_two_steps_restart_init.model_sims.ocean_sim.ocean.model, + # No cache restore from JLD2; timestepper working state (e.g. implicit_solver.t) can differ. + ignore = [:clock, :parent, :ptr, :timestepper], + ) +end diff --git a/experiments/test/restart_cmip.yml b/experiments/test/restart_cmip.yml new file mode 100644 index 0000000000..9c178b5036 --- /dev/null +++ b/experiments/test/restart_cmip.yml @@ -0,0 +1,41 @@ +FLOAT_TYPE: "Float32" +albedo_model: "CouplerAlbedo" +atmos_config_file: "config/atmos_configs/climaatmos_edonly.yml" +coupler_toml: ["toml/amip_edonly.toml"] +dt_nogw: "360secs" +dt_ogw: "360secs" +dt_atmos: "120secs" +dt_cpl: "240secs" +dt_land: "120secs" +dt_ocean: "240secs" +dt_seaice: "240secs" +dt_rad: "120secs" +dz_bottom: 100.0 +energy_check: false +h_elem: 6 +checkpoint_dt: "480secs" +ice_model: "clima_seaice" +land_model: "integrated" +land_spun_up_ic: false +mode_name: "cmip" +netcdf_output_at_levels: true +ocean_model: "oceananigans" +output_default_diagnostics: true +rayleigh_sponge: true +simple_ocean: true +save_cache: true +start_date: "20100101" +surface_setup: "PrescribedSurface" +t_end: "480secs" +topo_smoothing: true +topography: "Earth" +turbconv: ~ +# Disable gravity wave parameterizations for restart reproducibility +non_orographic_gravity_wave: false +orographic_gravity_wave: ~ +# Enable reproducible restart mode in ClimaAtmos 
+reproducible_restart: true +vert_diff: "DecayWithHeightDiffusion" +viscous_sponge: false +z_elem: 16 +z_max: 50000.0 diff --git a/experiments/test/restart_state_only.jl b/experiments/test/restart_state_only.jl index a7b07d9b23..b95b8b3591 100644 --- a/experiments/test/restart_state_only.jl +++ b/experiments/test/restart_state_only.jl @@ -1,4 +1,4 @@ -# This test runs a small AMIP simulation twice times. +# This test runs a small AMIP simulation twice. # # - The first time the simulation is run for two steps # - The second time the simulation is run for two steps, but restarting from the @@ -21,7 +21,7 @@ using Test # Uncomment the following for cleaner output (but more difficult debugging) # Logging.disable_logging(Logging.Warn) -include("compare.jl") +include("compare_amip.jl") include("../AMIP/code_loading.jl") comms_ctx = ClimaComms.context() @@ -77,6 +77,6 @@ two_steps_reading["save_cache"] = false Input.update_t_start_for_restarts!(two_steps_reading) cs_two_steps_reading = setup_and_run(two_steps_reading) -@testset "Restarts from command line arguments" begin +@testset "AMIP restarts (state only)" begin @test cs_two_steps_reading.tspan[1] == cs_two_steps.tspan[2] end diff --git a/ext/ClimaCouplerCMIPExt.jl b/ext/ClimaCouplerCMIPExt.jl index d55d6ad39a..3eb248a545 100644 --- a/ext/ClimaCouplerCMIPExt.jl +++ b/ext/ClimaCouplerCMIPExt.jl @@ -30,15 +30,13 @@ import ClimaCore as CC import ClimaParams as CP using KernelAbstractions: @kernel, @index, @inbounds -# Include helper functions first (used by both oceananigans.jl and clima_seaice.jl) -include("ClimaCouplerCMIPExt/climaocean_helpers.jl") - # Include skin temperature utilities include("ClimaCouplerCMIPExt/skin_temperature.jl") -# Include the model files +# Include the model files first so their types are available to climaocean_helpers.jl include("ClimaCouplerCMIPExt/oceananigans.jl") include("ClimaCouplerCMIPExt/clima_seaice.jl") +include("ClimaCouplerCMIPExt/climaocean_helpers.jl") 
include("ClimaCouplerCMIPExt/ocean_diagnostics.jl") include("ClimaCouplerCMIPExt/seaice_diagnostics.jl") diff --git a/ext/ClimaCouplerCMIPExt/clima_seaice.jl b/ext/ClimaCouplerCMIPExt/clima_seaice.jl index bae92b984d..4b1f4d3cc5 100644 --- a/ext/ClimaCouplerCMIPExt/clima_seaice.jl +++ b/ext/ClimaCouplerCMIPExt/clima_seaice.jl @@ -612,19 +612,6 @@ Arguments: ρτyio[i, j, 1] * ρₒ⁻¹ * OC.Operators.ℑyᵃᶠᵃ(i, j, 1, grid, ice_concentration) end -""" - get_model_prog_state(sim::ClimaSeaIceSimulation) - -Returns the model state of a simulation as a `ClimaCore.FieldVector`. -It's okay to leave this unimplemented for now, but we won't be able to use the -restart system. - -TODO extend this for non-ClimaCore states. -""" -function Checkpointer.get_model_prog_state(sim::ClimaSeaIceSimulation) - @warn "get_model_prog_state not implemented for ClimaSeaIceSimulation" -end - # Additional ClimaSeaIceSimulation getter methods for plotting debug fields Interfacer.get_field(sim::ClimaSeaIceSimulation, ::Val{:u}) = sim.ice.model.velocities.u Interfacer.get_field(sim::ClimaSeaIceSimulation, ::Val{:v}) = sim.ice.model.velocities.v diff --git a/ext/ClimaCouplerCMIPExt/climaocean_helpers.jl b/ext/ClimaCouplerCMIPExt/climaocean_helpers.jl index e947b02867..40e5609a0b 100644 --- a/ext/ClimaCouplerCMIPExt/climaocean_helpers.jl +++ b/ext/ClimaCouplerCMIPExt/climaocean_helpers.jl @@ -186,3 +186,76 @@ function unit_basis_vector_data(::Type{V}, local_geometry) where {V} FT = CC.Geometry.undertype(typeof(local_geometry)) return FT(1) / CC.Geometry._norm(V(FT(1)), local_geometry) end + + +""" + get_oc_sim(sim) + +Return the underlying `Oceananigans.Simulation` object for component models +that use Oceananigans under the hood. 
+""" +get_oc_sim(sim::OceananigansSimulation) = sim.ocean +get_oc_sim(sim::ClimaSeaIceSimulation) = sim.ice + + +""" + Checkpointer.checkpoint_model_state(sim, comms_ctx, t, prev_checkpoint_t; output_dir) + +Save the state of an Oceananigans-backed simulation to a JLD2 file at time `t` +(in seconds) using `Oceananigans.checkpoint`. + +If a previous checkpoint exists, it is removed to avoid accumulating files. +A value of -1 for `prev_checkpoint_t` indicates there is no previous checkpoint. +""" +function Checkpointer.checkpoint_model_state( + sim::Union{OceananigansSimulation, ClimaSeaIceSimulation}, + comms_ctx::ClimaComms.AbstractCommsContext, + t::Int, + prev_checkpoint_t::Int; + output_dir = "output", +) + day = floor(Int, t / (60 * 60 * 24)) + sec = floor(Int, t % (60 * 60 * 24)) + @info "Saving checkpoint $(nameof(sim)) model state to JLD2 on day $day second $sec" + output_file = joinpath(output_dir, "checkpoint_$(nameof(sim))_$t.jld2") + prev_checkpoint_file = + joinpath(output_dir, "checkpoint_$(nameof(sim))_$(prev_checkpoint_t).jld2") + Checkpointer.remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx) + OC.checkpoint(get_oc_sim(sim); filepath = output_file) + return nothing +end + +""" + Checkpointer.restart_model_state!(sim, input_file, comms_ctx) + +Restore the state of an Oceananigans-backed simulation from a JLD2 checkpoint +file using `Oceananigans.set!`. + +The coupler constructs `input_file` with a `.hdf5` extension; this method +replaces it with `.jld2` to match the format written by `checkpoint_model_state`. 
+""" +function Checkpointer.restart_model_state!( + sim::Union{OceananigansSimulation, ClimaSeaIceSimulation}, + input_file, + comms_ctx, +) + jld2_file = first(splitext(input_file)) * ".jld2" + ispath(jld2_file) || error("Oceananigans checkpoint file not found: $jld2_file") + OC.set!(get_oc_sim(sim); checkpoint = jld2_file) + return nothing +end + +""" + Checkpointer.restart_model_cache!(sim, input_file) + +Emit a warning and restore nothing for Oceananigans-backed simulations: this method +does not read `input_file`; the model state is restored via `restart_model_state!`. +""" +function Checkpointer.restart_model_cache!( + sim::Union{OceananigansSimulation, ClimaSeaIceSimulation}, + input_file, +) + @warn "$(nameof(sim)) does not support restoring the model cache from a checkpoint. " * + "The simulation cache will not be restored." + return nothing +end diff --git a/ext/ClimaCouplerCMIPExt/oceananigans.jl b/ext/ClimaCouplerCMIPExt/oceananigans.jl index 48a155d6f4..f58779d308 100644 --- a/ext/ClimaCouplerCMIPExt/oceananigans.jl +++ b/ext/ClimaCouplerCMIPExt/oceananigans.jl @@ -150,8 +150,6 @@ function OceananigansSimulation( # Simpler setup @info "Using simpler ocean setup; to be used for software testing only." 
free_surface = OC.SplitExplicitFreeSurface(grid; substeps = 70) - momentum_advection = OC.WENOVectorInvariant(order = 5) - horizontal_viscosity = OC.HorizontalScalarDiffusivity(ν = 1e4) tracer_advection = OC.WENO(order = 5) vertical_mixing = OC.ConvectiveAdjustmentVerticalDiffusivity( background_κz = 1e-5, @@ -159,7 +157,8 @@ function OceananigansSimulation( background_νz = 1e-4, convective_νz = 0.1, ) - + momentum_advection = OC.WENOVectorInvariant(order = 5) + horizontal_viscosity = OC.HorizontalScalarDiffusivity(ν = 1e4) closure = (horizontal_viscosity, vertical_mixing) end @@ -625,19 +624,6 @@ function FieldExchanger.update_sim!(sim::OceananigansSimulation, csf) return nothing end -""" - get_model_prog_state(sim::OceananigansSimulation) - -Returns the model state of a simulation as a `ClimaCore.FieldVector`. -It's okay to leave this unimplemented for now, but we won't be able to use the -restart system. - -TODO extend this for non-ClimaCore states. -""" -function Checkpointer.get_model_prog_state(sim::OceananigansSimulation) - @warn "get_model_prog_state not implemented for OceananigansSimulation" -end - # Additional OceananigansSimulation getter methods for plotting debug fields Interfacer.get_field(sim::OceananigansSimulation, ::Val{:salinity}) = sim.ocean.model.tracers.S diff --git a/src/Checkpointer.jl b/src/Checkpointer.jl index e2dbc54520..a092333b8e 100644 --- a/src/Checkpointer.jl +++ b/src/Checkpointer.jl @@ -42,7 +42,13 @@ get_model_cache(sim::Interfacer.AbstractComponentSimulation) = nothing prev_checkpoint_t::Int; output_dir = "output") -Checkpoint the model state of a simulation to a HDF5 file at a given time, t (in seconds). +Checkpoint the model state of a simulation at time `t` (in seconds). + +The default implementation uses `get_model_prog_state(sim)` to obtain a +`ClimaCore.FieldVector` and writes it to an HDF5 file via `ClimaCore.InputOutput`. +If `get_model_prog_state` returns `nothing`, this function does nothing. 
+ +Component models that do not use ClimaCore can override this method to use their own checkpointing. If a previous checkpoint exists, it is removed. This is to avoid accumulating many checkpoint files in the output directory. A value of -1 for `prev_checkpoint_t` @@ -56,6 +62,7 @@ function checkpoint_model_state( output_dir = "output", ) Y = get_model_prog_state(sim) + isnothing(Y) && return nothing day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) @info "Saving checkpoint $(nameof(sim)) model state to HDF5 on day $day second $sec" @@ -83,6 +90,12 @@ end Checkpoint the model cache to N JLD2 files at a given time, t (in seconds), where N is the number of MPI ranks. +The default implementation uses `get_model_cache(sim)` to obtain the cache. +If `get_model_cache` returns `nothing`, this function does nothing. + +Component models that do not use ClimaCore can override this method +to use their own checkpointing. + Objects are saved to JLD2 files because caches are generally not ClimaCore objects (and ClimaCore.InputOutput can only save `Field`s or `FieldVector`s). 
@@ -97,7 +110,7 @@ function checkpoint_model_cache( prev_checkpoint_t::Int; output_dir = "output", ) - # Move p to CPU (because we cannot save CUArrays) + isnothing(get_model_cache(sim)) && return nothing p = get_model_cache_to_checkpoint(sim) day = floor(Int, t / (60 * 60 * 24)) sec = floor(Int, t % (60 * 60 * 24)) @@ -226,16 +239,14 @@ function checkpoint_sims(cs::Interfacer.CoupledSimulation) comms_ctx = ClimaComms.context(cs) for sim in cs.model_sims - if !isnothing(Checkpointer.get_model_prog_state(sim)) - Checkpointer.checkpoint_model_state( - sim, - comms_ctx, - time, - prev_checkpoint_t; - output_dir, - ) - end - if !isnothing(Checkpointer.get_model_cache(sim)) && cs.save_cache + Checkpointer.checkpoint_model_state( + sim, + comms_ctx, + time, + prev_checkpoint_t; + output_dir, + ) + if cs.save_cache Checkpointer.checkpoint_model_cache( sim, comms_ctx, @@ -276,15 +287,10 @@ function restart!(cs, checkpoint_dir, checkpoint_t, restart_cache) @info "Restarting from time $(checkpoint_t) and directory $(checkpoint_dir)" pid = ClimaComms.mypid(ClimaComms.context(cs)) for sim in cs.model_sims - if !isnothing(Checkpointer.get_model_prog_state(sim)) - input_file_state = - output_file = joinpath( - checkpoint_dir, - "checkpoint_$(nameof(sim))_$(checkpoint_t).hdf5", - ) - restart_model_state!(sim, input_file_state, ClimaComms.context(cs)) - end - if !isnothing(Checkpointer.get_model_cache(sim)) && restart_cache + input_file_state = + joinpath(checkpoint_dir, "checkpoint_$(nameof(sim))_$(checkpoint_t).hdf5") + restart_model_state!(sim, input_file_state, ClimaComms.context(cs)) + if restart_cache input_file_cache = joinpath( checkpoint_dir, "checkpoint_cache_$(pid)_$(nameof(sim))_$(checkpoint_t).jld2", @@ -301,14 +307,18 @@ end """ restart_model_cache!(sim, input_file) -Overwrite the content of `sim` with the cache from the `input_file`. +Restore the cache of `sim` from `input_file`. 
+ +The default implementation uses `get_model_cache(sim)` to check whether the +simulation has a cache. If `get_model_cache` returns `nothing`, this function +does nothing. It relies on `restore_cache!(sim, old_cache)`, which has to be implemented by the component models that have a cache. """ function restart_model_cache!(sim, input_file) + isnothing(get_model_cache(sim)) && return nothing ispath(input_file) || error("File $(input_file) not found") - # Component models are responsible for defining a method for this JLD2.jldopen(input_file) do file restore_cache!(sim, file["cache"]) end @@ -317,12 +327,18 @@ end """ restart_model_state!(sim, input_file, comms_ctx) -Overwrite the content of `sim` with the state from the `input_file`. +Restore the prognostic state of `sim` from `input_file`. + +The default implementation reads a `ClimaCore.FieldVector` from an HDF5 file +written by the default `checkpoint_model_state`. If `get_model_prog_state` +returns `nothing`, this function does nothing. + +Component models that do not use ClimaCore can override this method to use their own checkpointing. """ function restart_model_state!(sim, input_file, comms_ctx) - ispath(input_file) || error("File $(input_file) not found") Y = get_model_prog_state(sim) - # open file and read + isnothing(Y) && return nothing + ispath(input_file) || error("File $(input_file) not found") CC.InputOutput.HDF5Reader(input_file, comms_ctx) do restart_reader Y_new = CC.InputOutput.read_field(restart_reader, "model_state") # set new state