Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ steps:

- label: "MPI restarts"
key: "mpi_restarts"
command: "srun julia --color=yes --project=experiments/AMIP/ experiments/test/restart.jl"
command: "srun julia --color=yes --project=experiments/AMIP/ experiments/test/restart_amip.jl"
retry: *retry_policy
env:
CLIMACOMMS_CONTEXT: "MPI"
Expand All @@ -113,24 +113,34 @@ steps:
slurm_ntasks: 2
slurm_mem: 32GB

- label: "GPU restarts (state and cache)"
command: "julia -O0 --color=yes --project=experiments/AMIP/ experiments/test/restart.jl"
- label: "GPU AMIP restarts (state and cache)"
command: "julia -O0 --color=yes --project=experiments/AMIP/ experiments/test/restart_amip.jl"
retry: *retry_policy
env:
CLIMACOMMS_DEVICE: "CUDA"
retry: *retry_policy
agents:
slurm_gpus: 1
slurm_mem: 32GB

- label: "GPU restarts (state only)"
- label: "GPU AMIP restarts (state only)"
command: "julia -O0 --color=yes --project=experiments/AMIP/ experiments/test/restart_state_only.jl"
retry: *retry_policy
env:
CLIMACOMMS_DEVICE: "CUDA"
retry: *retry_policy
agents:
slurm_gpus: 1
slurm_mem: 32GB

- label: "GPU CMIP restarts (state and cache)"
command: "julia -O0 --color=yes --project=experiments/CMIP/ experiments/test/restart_cmip.jl"
retry: *retry_policy
env:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_ntasks: 1
slurm_gres: "gpu:1"
slurm_mem: 32GB

- group: "Integration Tests (single column)"
steps:
- label: "SCM: slabplanet aqua (atmos + slab ocean column)"
Expand Down
28 changes: 27 additions & 1 deletion docs/src/checkpointer.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,35 @@ forward, but there are still several challenges that need to be solved:
Point 3. adds significant amount of code and requires component models to
specify how their cache has to be restored.

If you are adding a component model, you have to extend the methods.
### Adding checkpointing to a new component model

There are two ways to add checkpoint/restart support for a new component model:

**Path A (ClimaCore-based models):** extend `get_model_prog_state` to return the
prognostic state as a `ClimaCore.FieldVector`. The default
`checkpoint_model_state` and `restart_model_state!` implementations will handle
HDF5 I/O via `ClimaCore.InputOutput` automatically. This path is intended for
models whose prognostic state is a `ClimaCore.FieldVector`; models that do not
use ClimaCore should use Path B instead.

```
Checkpointer.get_model_prog_state
Checkpointer.get_model_cache
Checkpointer.restore_cache!
```

**Path B (custom checkpoint format):** override `checkpoint_model_state` and
`restart_model_state!` directly for full control over the checkpoint format. This
is the approach used by `OceananigansSimulation`, which writes JLD2 checkpoints
via Oceananigans' native `checkpoint` and restores them with `Oceananigans.set!`.

```
Checkpointer.checkpoint_model_state
Checkpointer.checkpoint_model_cache # optional; no-op if cache checkpointing is not supported
Checkpointer.restart_model_state!
Checkpointer.restart_model_cache! # optional; warn or no-op if cache restore is not supported
```

`ClimaCoupler` moves objects to the CPU with `Adapt(Array, x)`. `Adapt`
traverses the object recursively, and proper `Adapt` methods have to be defined
for every object involved in the chain. The easiest way to do this is using the
Expand Down Expand Up @@ -159,7 +181,11 @@ This approach allows for a signficant reducation in the file size of the cache.
Checkpointer.get_model_prog_state
Checkpointer.get_model_cache
Checkpointer.get_model_cache_to_checkpoint
Checkpointer.checkpoint_model_state
Checkpointer.checkpoint_model_cache
Checkpointer.restart!
Checkpointer.restart_model_state!
Checkpointer.restart_model_cache!
Checkpointer.checkpoint_sims
Checkpointer.t_start_from_checkpoint
Checkpointer.restore!
Expand Down
2 changes: 2 additions & 0 deletions experiments/test/amip_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ mode_name: "amip"
netcdf_interpolation_num_points: [90, 45, 31]
netcdf_output_at_levels: true
output_default_diagnostics: true
orographic_gravity_wave: ~
rayleigh_sponge: true
reproducible_restart: true
Comment thread
akshaysridhar marked this conversation as resolved.
start_date: "20100101"
surface_setup: "PrescribedSurface"
t_end: "540secs"
Expand Down
File renamed without changes.
175 changes: 175 additions & 0 deletions experiments/test/compare_cmip.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# compare.jl provides function to recursively compare complex objects while also
# allowing for some numerical tolerance.

import ClimaComms
import ClimaAtmos as CA
import ClimaCore as CC
import Oceananigans as OC
import ClimaSeaIce as CSI
using ClimaSeaIce.SeaIceThermodynamics.HeatBoundaryConditions: IceWaterThermalEquilibrium
import NCDatasets

"""
_error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1)))

We compute the error in this way:
- when the absolute value is larger than ABS_TOL, we use the absolute error
- in the other cases, we compare the relative errors
"""
function _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1)))
# There are some parameters, e.g. Obukhov length, for which Inf
# is a reasonable value (implying a stability parameter in the neutral boundary layer
# regime, for instance). We account for such instances with the `isfinite` function.
arr1 = Array(arr1) .* isfinite.(Array(arr1))
arr2 = Array(arr2) .* isfinite.(Array(arr2))
diff = abs.(arr1 .- arr2)
denominator = abs.(arr1)
error = ifelse.(denominator .> ABS_TOL, diff ./ denominator, diff)
return error
end

"""
compare(v1, v2; name = "", ignore = Set([:rc]))

Return whether `v1` and `v2` are the same (up to floating point errors).
`compare` walks through all the properties in `v1` and `v2` until it finds
that there are no more properties. At that point, `compare` tries to match the
resulting objects. When such objects are arrays with floating point, `compare`
defines a notion of `error` that is the following: when the absolute value is
less than `100eps(eltype)`, `error = absolute_error`, otherwise it is relative
error. The `error` is then compared against a tolerance.
Keyword arguments
=================
- `name` is used to collect the name of the property while we go recursively
over all the properties. You can pass a base name.
- `ignore` is a collection of `Symbol`s that identify properties that are
ignored when walking through the tree. This is useful for properties that
are known to be different (e.g., `output_dir`).
`:rc` is some CUDA/CuArray internal object that we don't care about
"""
function compare(
v1::T1,
v2::T2;
name = "",
ignore = Set([:rc]),
) where {
T1 <: Union{
CC.Fields.FieldVector,
CC.Spaces.AbstractSpace,
NamedTuple,
CA.AtmosCache,
OC.Models.HydrostaticFreeSurfaceModels.HydrostaticFreeSurfaceModel,
CSI.SeaIceModel,
},
T2 <: Union{
CC.Fields.FieldVector,
CC.Spaces.AbstractSpace,
NamedTuple,
CA.AtmosCache,
OC.Models.HydrostaticFreeSurfaceModels.HydrostaticFreeSurfaceModel,
CSI.SeaIceModel,
},
}
pass = true
return _compare(pass, v1, v2; name, ignore)
end

function _compare(pass, v1::T, v2::T; name, ignore) where {T}
properties = filter(x -> !(x in ignore), propertynames(v1))
if isempty(properties)
pass &= _compare(v1, v2; name, ignore)
else
# Recursive case
for p in properties
pass &= _compare(
pass,
getproperty(v1, p),
getproperty(v2, p);
name = "$(name).$(p)",
ignore,
)
end
end
return pass
end

# ClimaSeaIce `IceWaterThermalEquilibrium` and `FluxFunction` use reference `==` at compile time
# (see `@code_typed ==(iwte1, iwte2)`), and `FluxFunction.func` seem to hold non-comparable closures.
function _compare(
pass,
v1::IceWaterThermalEquilibrium,
v2::IceWaterThermalEquilibrium;
name,
ignore,
)
pass &= _compare(pass, v1.salinity, v2.salinity; name = "$(name).salinity", ignore)
return pass
end

function _compare(pass, v1::CSI.FluxFunction, v2::CSI.FluxFunction; name, ignore)
pass &=
_compare(pass, v1.parameters, v2.parameters; name = "$(name).parameters", ignore)
return pass
end

function _compare(v1::T, v2::T; name, ignore) where {T}
return print_maybe(v1 == v2, "$name differs")
end

function _compare(v1::T, v2::T; name, ignore) where {T <: Union{AbstractString, Symbol}}
# What we can safely print without filling STDOUT
return print_maybe(v1 == v2, "$name differs: $v1 vs $v2")
end

function _compare(v1::T, v2::T; name, ignore) where {T <: Number}
# We check with triple equal so that we also catch NaNs being equal
return print_maybe(v1 === v2, "$name differs: $v1 vs $v2")
end

# We ignore NCDatasets. They contain a lot of state-ful information
function _compare(pass, v1::T, v2::T; name, ignore) where {T <: NCDatasets.NCDataset}
return pass
end

function _compare(
v1::T,
v2::T;
name,
ignore,
) where {T <: CC.Fields.Field{<:CC.DataLayouts.AbstractData{<:Real}}}
return _compare(parent(v1), parent(v2); name, ignore)
end

function _compare(pass, v1::T, v2::T; name, ignore) where {T <: CC.DataLayouts.AbstractData}
return pass && _compare(parent(v1), parent(v2); name, ignore)
end

# Handle views
function _compare(
pass,
v1::SubArray{FT},
v2::SubArray{FT};
name,
ignore,
) where {FT <: AbstractFloat}
return pass && _compare(collect(v1), collect(v2); name, ignore)
end

function _compare(
v1::AbstractArray{FT},
v2::AbstractArray{FT};
name,
ignore,
) where {FT <: AbstractFloat}
error = maximum(_error(v1, v2); init = zero(eltype(v1)))
return print_maybe(error <= 100eps(eltype(v1)), "$name error: $error")
end

function _compare(pass, v1::T1, v2::T2; name, ignore) where {T1, T2}
error("v1 and v2 have different types")
end

function print_maybe(exp, what)
exp || println(what)
return exp
end
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using Test
# Uncomment the following for cleaner output (but more difficult debugging)
# Logging.disable_logging(Logging.Warn)

include("compare.jl")
include("compare_amip.jl")
include("../AMIP/code_loading.jl")

comms_ctx = ClimaComms.context()
Expand Down Expand Up @@ -65,7 +65,7 @@ four_steps_reading["job_id"] = "four_steps_reading"
Input.update_t_start_for_restarts!(four_steps_reading)

cs_four_steps_reading = setup_and_run(four_steps_reading)
@testset "Restarts from command line arguments" begin
@testset "AMIP restarts (state and cache)" begin
@test cs_four_steps_reading.tspan[1] == cs_four_steps.tspan[2]
end

Expand Down
Loading
Loading