From 0367866e3ca3785d9adb9222f0fa8e8ab0bc513b Mon Sep 17 00:00:00 2001 From: Kevin Phan <98072684+ph-kev@users.noreply.github.com> Date: Wed, 4 Dec 2024 13:49:17 -0800 Subject: [PATCH 1/2] Add interpolation routine This commit adds an interpolation routine for use in ClimaAnalysis, which seeks to replace Interpolations.jl in Var.jl. The interpolation routine supports N-dimensional linear interpolation on a grid with throw, flat, or periodic boundary conditions. Compared against Interpolations.jl, the interpolation routine does not make a struct and does not allocate anything on the heap when interpolating a point. --- src/ClimaAnalysis.jl | 1 + src/Interpolations.jl | 276 +++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/test_Interpolations.jl | 298 ++++++++++++++++++++++++++++++++++++ 4 files changed, 576 insertions(+) create mode 100644 src/Interpolations.jl create mode 100644 test/test_Interpolations.jl diff --git a/src/ClimaAnalysis.jl b/src/ClimaAnalysis.jl index b7401892..2dd9614d 100644 --- a/src/ClimaAnalysis.jl +++ b/src/ClimaAnalysis.jl @@ -5,6 +5,7 @@ include("Utils.jl") import .Utils include("Numerics.jl") +include("Interpolations.jl") include("Var.jl") @reexport using .Var diff --git a/src/Interpolations.jl b/src/Interpolations.jl new file mode 100644 index 00000000..7a888678 --- /dev/null +++ b/src/Interpolations.jl @@ -0,0 +1,276 @@ +module Interpolations + +#= +Why not use Interpolations.jl? + +One of the most expensive operation in terms of time and memory is +`Var.resampled_as(src_var, dest_var)` which uses an interpolation to resample `data` in +`src_var` to match the `dims` in `dest_var`. However, interpolating using Interpolations.jl +is expensive. One must first intialize an interpolant, whose size in memory is at least as +big as the memory of `src_var.data`. Then, each evaluation using the interpolant +costs some amount of allocation on the heap. As a result, interpolating is extremely slow +for a large number of points. + +To solve this, we write our interpolation routine in the Numerics module that support the following: +1. No extra dependencies +2. No allocation when evaluating a point +3. Support boundary conditions for periodic, flat, and throw on an irregular grid +5. Comparable or better performance to Interpolations.jl +=# + +""" + linear_interpolate(point::NTuple{N, FT1}, + axes::NTuple{N, Vector}, + data::AbstractArray{FT2, N}, + extp_conds::NTuple{N, NTuple{4, Function}}) where {N, FT1, FT2} + +Linear interpolate `data` on `axes` and return the value at `point`. Extrapolation is +handled by `extp_conds`. +""" +function linear_interpolate( + point::NTuple{N}, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + # Get a new point as determined by the extrapolation condition + point = extp_to_point(point, axes, extp_conds) + + # Find which cell contain the point + cell_indices_for_axes = find_cell_indices_for_axes(point, axes) + val = zero(FT) + + # Compute the denominator of the formula for linear interpolation + # (in 1D, this is x_1 - x_0) + bottom_term = compute_bottom_term(axes, cell_indices_for_axes) + + # Iterate through all 2^N points + @inbounds for bits in 0:(2^N - 1) + term = one(FT) + bound_indices = get_indices(cell_indices_for_axes, bits) + sign = get_sign(cell_indices_for_axes, bits) + # Weight is the value at each of the points of the cell + weight = data[get_complement_indices(cell_indices_for_axes, bits)...] + @inbounds for (dim_idx, bound_idx) in pairs(bound_indices) + val_minus_x2_or_x1 = point[dim_idx] - axes[dim_idx][bound_idx] + term *= val_minus_x2_or_x1 + end + term *= sign * weight + val += term + end + return val / bottom_term +end + +""" + linear_interpolate(point::Number, + axes, + data::AbstractArray{FT, N}, + extp_conds) where {N, FT} + +Convert a number to a tuple and linear interpolate. +""" +function linear_interpolate( + point::Number, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + point = Tuple(point...) + return linear_interpolate(point, axes, data, extp_conds) +end + +""" + linear_interpolate(point::AbstractVector, + axes, + data::AbstractArray{FT, N}, + extp_conds) where {N, FT} + +Convert a vector to a tuple and linear interpolate. +""" +function linear_interpolate( + point::AbstractVector, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + point = Tuple(coord for coord in point) + return linear_interpolate(point, axes, data, extp_conds) +end + +""" + linear_interpolate(point::Tuple, + axes, + data::AbstractArray{FT, N}, + extp_conds) where {N, FT} + +Promote a tuple and linear interpolate. +""" +function linear_interpolate( + point::Tuple, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + point = promote(point...) + return linear_interpolate(point, axes, data, extp_conds) +end + + +""" + compute_bottom_term(axes, cell_indices_for_axes::NTuple{N}) + +Compute the bottom term when linearly interpolating. + +Consider the formula for 1D linear interpolation which is + y_0 * ((x_1 - x) / (x_1 - x_0)) + y_1 * ((x - x_0) / (x_1 - x_0)) +for interpolating the point (x, y) on line between (x_0, y_0) and (x_1, y_1). This function +computes x_1 - x_0. +""" +function compute_bottom_term(axes, cell_indices_for_axes::NTuple{N}) where {N} + return reduce( + *, + ntuple( + dim_idx -> + axes[dim_idx][cell_indices_for_axes[dim_idx][end]] - + axes[dim_idx][cell_indices_for_axes[dim_idx][begin]], + N, + ), + ) +end + +""" + get_complement_indices(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I} + +Given a tuple consisting of 2-tuple, return a tuple of one element from each tuple according +to bits. The elements in the tuple are the complement of those found by `get_indices`. +""" +function get_complement_indices(indices::NTuple{N}, bits) where {N} + # Bit manipulation can be found here: + # https://github.com/parsiad/mlinterp/blob/master/mlinterp/mlinterp.hpp + # Adjusted for 1-indexing instead of 0-indexing + return ntuple(dim -> if (bits & (1 << (dim - 1)) != 0) + indices[dim][2] + else + indices[dim][1] + end, N) +end + +""" + get_indices(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I} + +Given a tuple consisting of 2-tuple, return a tuple of one element from each tuple according to bits. +""" +function get_indices(indices::NTuple{N}, bits) where {N} + # See get_complement_indices for where the Bit manipulation comes from + return ntuple(dim -> if (bits & (1 << (dim - 1)) != 0) + indices[dim][1] + else + indices[dim][2] + end, N) +end + +""" + get_sign(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I} + +Given a tuple consisting of 2-tuple, compute the appropriate sign when interpolating. +""" +function get_sign(_indices::NTuple{N}, bits) where {N} + # See get_complement_indices for where the bit manipulation comes from + return reduce(*, ntuple(dim -> if (bits & (1 << (dim - 1)) != 0) + 1 + else + -1 + end, N)) +end + +""" + find_cell_indices_for_axes(point::NTuple{N, FT}, + axes::NTuple{N, A}) where {N, FT, A <:AbstractVector} + +Given a point and axes, find the indices of the N-dimensional hyperrectangle, where the +point lives in. +""" +function find_cell_indices_for_axes(point::NTuple{N}, axes) where {N} + return ntuple( + dim_idx -> find_cell_indices_for_ax(point[dim_idx], axes[dim_idx]), + N, + ) +end + +""" + find_cell_indices_for_ax(val::FT1, ax::AbstractVector{FT2}) where {FT1, FT2} + +Given `val` and an `ax`, find the indices of the cell, where `val` lives in. +""" +function find_cell_indices_for_ax( + val::FT1, + ax::AbstractVector{FT2}, +) where {FT1, FT2} + len_of_ax = length(ax) + (val == ax[begin]) && return (1, 2) + (val == ax[end]) && return (len_of_ax - 1, len_of_ax) + idx = searchsortedfirst(ax, val) + return (idx - 1, idx) +end + +""" + extp_to_point(point::NTuple{N, FT1}, axes::NTuple{N, Vector}, extp_conds) where {N, FT1} + +Return a new point to evaluate at according to the extrapolation conditions. +""" +function extp_to_point(point::NTuple{N}, axes, extp_conds) where {N} + return ntuple( + idx -> extp_conds[idx].get_val_for_point(point[idx], axes[idx]), + N, + ) +end + +""" + extp_cond_throw() + +Return extrapolation condition for throwing an error when the point is out of bounds. + +The first and last nodes are not co-located. For example, if the axis is [1.0, 2.0, 3.0] +and the data is [4.0, 5.0, 6.0], then the value at 3.0 is 6.0 and not 4.0. +""" +function extp_cond_throw() + get_val_for_point(val, ax) = begin + val < ax[begin] && return error("Out of bounds error with $val in $ax") + val > ax[end] && return error("Out of bounds error with $val in $ax") + return val + end + return (; get_val_for_point = get_val_for_point) +end + +""" + extp_cond_flat() + +Return flat extrapolation condition. +""" +function extp_cond_flat() + get_val_for_point(val, ax) = begin + val < ax[begin] && return typeof(val)(ax[begin]) + val > ax[end] && return typeof(val)(ax[end]) + return val + end + return (; get_val_for_point) +end + +""" + extp_cond_periodic() + +Return periodic extrapolation condtion. +""" +function extp_cond_periodic() + get_val_for_point(val, ax) = begin + if (val < ax[begin]) || (val > ax[end]) + width = ax[end] - ax[begin] + new_val = mod(val - ax[begin], width) + ax[begin] + return typeof(val)(new_val) + end + return val + end + return (; get_val_for_point) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 8e40c6b4..c4a8323c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using Test @safetestset "Utils" begin @time include("test_Utils.jl") end @safetestset "Numerics" begin @time include("test_Numerics.jl") end +@safetestset "Interpolations" begin @time include("test_Interpolations.jl") end @safetestset "SimDir" begin @time include("test_Sim.jl") end @safetestset "Atmos" begin @time include("test_Atmos.jl") end @safetestset "Leaderboard" begin @time include("test_Leaderboard.jl") end diff --git a/test/test_Interpolations.jl b/test/test_Interpolations.jl new file mode 100644 index 00000000..60a3ec74 --- /dev/null +++ b/test/test_Interpolations.jl @@ -0,0 +1,298 @@ +using Test +import ClimaAnalysis + +@testset "Get indices and sign" begin + indices = ((1, 2),) + @test ClimaAnalysis.Interpolations.get_indices(indices, 0) == (2,) + @test ClimaAnalysis.Interpolations.get_indices(indices, 1) == (1,) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 0) == + (1,) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 1) == + (2,) + @test ClimaAnalysis.Interpolations.get_sign(indices, 0) == -1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 1) == 1 + + indices = ((1, 2), (3, 4)) + @test ClimaAnalysis.Interpolations.get_indices(indices, 0) == (2, 4) + @test ClimaAnalysis.Interpolations.get_indices(indices, 1) == (1, 4) + @test ClimaAnalysis.Interpolations.get_indices(indices, 2) == (2, 3) + @test ClimaAnalysis.Interpolations.get_indices(indices, 3) == (1, 3) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 0) == + (1, 3) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 1) == + (2, 3) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 2) == + (1, 4) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 3) == + (2, 4) + @test ClimaAnalysis.Interpolations.get_sign(indices, 0) == 1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 1) == -1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 2) == -1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 3) == 1 +end + +@testset "Find indices for cell" begin + val1 = 5 + ax1 = [0, 10] + @test ClimaAnalysis.Interpolations.find_cell_indices_for_ax(val1, ax1) == + (1, 2) + + val2 = 6 + ax2 = [0, 4, 10] + @test ClimaAnalysis.Interpolations.find_cell_indices_for_ax(val2, ax2) == + (2, 3) + + @test ClimaAnalysis.Interpolations.find_cell_indices_for_axes( + (val1, val2), + (ax1, ax2), + ) == ((1, 2), (2, 3)) +end + +@testset "Extrapolation conditions" begin + throw = ClimaAnalysis.Interpolations.extp_cond_throw() + flat = ClimaAnalysis.Interpolations.extp_cond_flat() + periodic = ClimaAnalysis.Interpolations.extp_cond_periodic() + + ax = [0, 1, 2, 3] + @test_throws ErrorException throw.get_val_for_point(10, ax) + @test_throws ErrorException throw.get_val_for_point(-1, ax) + @test throw.get_val_for_point(1.5, ax) == 1.5 + + @test flat.get_val_for_point(10, ax) == 3 + @test flat.get_val_for_point(-1, ax) == 0 + @test flat.get_val_for_point(1.5, ax) == 1.5 + + @test periodic.get_val_for_point(10, ax) == 1 + @test periodic.get_val_for_point(-1, ax) == 2 + @test periodic.get_val_for_point(1.5, ax) == 1.5 + @test periodic.get_val_for_point(3, ax) == 3 +end + +@testset "Extrapolate to new point" begin + throw = ClimaAnalysis.Interpolations.extp_cond_throw() + flat = ClimaAnalysis.Interpolations.extp_cond_flat() + periodic = ClimaAnalysis.Interpolations.extp_cond_periodic() + + ax1 = [0, 1, 2, 3] + @test ClimaAnalysis.Interpolations.extp_to_point((1,), (ax1,), (throw,)) == + (1,) + + ax2 = [4, 5, 6, 7] + @test ClimaAnalysis.Interpolations.extp_to_point( + (-1, 8), + (ax1, ax2), + (flat, periodic), + ) == (0, 5) +end + +@testset "Interpolation" begin + throw = ClimaAnalysis.Interpolations.extp_cond_throw() + flat = ClimaAnalysis.Interpolations.extp_cond_flat() + periodic = ClimaAnalysis.Interpolations.extp_cond_periodic() + + # 1D case + axes = ([1.0, 2.0, 3.0],) + data = [3.0, 1.0, 0.0] + + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.0,), + axes, + data, + (throw,), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (3.0,), + axes, + data, + (throw,), + ) == 0.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5,), + axes, + data, + (throw,), + ) == 2.0 + + # 1D case with extrapolation conditions + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (0.0,), + axes, + data, + (throw,), + ) + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (4.0,), + axes, + data, + (throw,), + ) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (0.0,), + axes, + data, + (flat,), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0,), + axes, + data, + (flat,), + ) == 0.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (0.0,), + axes, + data, + (periodic,), + ) == 1.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0,), + axes, + data, + (periodic,), + ) == 1.0 + + # 2D case + axes = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]) + data = reshape(1:9, (3, 3)) + + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.0, 4.0), + axes, + data, + (throw, throw), + ) == 1.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (3.0, 6.0), + axes, + data, + (throw, throw), + ) == 9.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (2.0, 5.0), + axes, + data, + (throw, throw), + ) == 5.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 4.5), + axes, + data, + (throw, throw), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 5.5), + axes, + data, + (throw, throw), + ) == 6.0 + + # 2D cases with extrapolation conditions + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (4.0, 5.0), + axes, + data, + (throw, flat), + ) + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (2.0, 7.0), + axes, + data, + (flat, throw), + ) + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (0.0, 8.0), + axes, + data, + (throw, throw), + ) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (0.0, 8.0), + axes, + data, + (flat, flat), + ) == 7.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0, 7.0), + axes, + data, + (periodic, periodic), + ) == 5.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (3.0, 6.0), + axes, + data, + (periodic, periodic), + ) == 9.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0, 7.0), + axes, + data, + (flat, periodic), + ) == 6.0 + + # 3D cases with extrapolation conditions + axes = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]) + data = reshape(1:27, (3, 3, 3)) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.0, 5.0, 7.0), + axes, + data, + (throw, throw, throw), + ) == 4.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 5.2, 7.5), + axes, + data, + (throw, throw, throw), + ) ≈ 9.6 + + # Non equispaced + axes = ([1.0, 3.0, 7.0], [4.0, 5.0, 7.0]) + data = reshape(1:9, (3, 3)) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (2.0, 4.5), + axes, + data, + (throw, throw, throw), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (5.0, 6.0), + axes, + data, + (throw, throw, throw), + ) == 7.0 + + # Different types + # Axes have different types and inputs have different types + axes = ([1.0f0, 2.0f0], [Float16(3.0), Float16(4.0)]) + data = [[1.0, 2.0] [3.0, 4.0]] + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5f0, 3.5), + axes, + data, + (throw, throw), + ) == 2.5 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 4.5f0), + axes, + data, + (flat, flat), + ) == 3.5 + @test ClimaAnalysis.Interpolations.linear_interpolate( + [1.5, 4.5f0], + axes, + data, + (flat, flat), + ) == 3.5 + + # Single number + axes = ([1.0, 2.0, 3.0],) + data = [3.0, 1.0, 0.0] + + @test ClimaAnalysis.Interpolations.linear_interpolate( + 1.0, + axes, + data, + (throw,), + ) == 3.0 +end From 2d403be86192b2cb4ce6e173d0e89d70c5ac2bab Mon Sep 17 00:00:00 2001 From: Kevin Phan <98072684+ph-kev@users.noreply.github.com> Date: Wed, 4 Dec 2024 13:52:17 -0800 Subject: [PATCH 2/2] Use new interpolation routine in Var This commit removes Interpolations.jl from Var.jl. To do this, the function `_make_interpolant` was removed. Three new functions are added which are `_check_interpolant`, `interpolate_point`, and `interpolate_points`, where the latter two functions replace the functionality of `_make_interpolant`. Furthermore, the function `_find_extp_bound_cond` was refactored to `_find_extp_bound_conds` which find multiple extrapolation condtions using `_find_extp_bound_cond` which is refactored to find the extrapolation condition for a single point. All functions that use an interpolant are updated to use the new interpolation routine. The test for computing the bias in Atmos changes to check approximately close to 0.0, due to floating point errors. The tests that check for errors when interpolating out of bounds now check for ErrorException instead of BoundsError. --- NEWS.md | 5 ++ src/Var.jl | 140 +++++++++++++++++++++++++++++---------------- test/test_Atmos.jl | 2 +- test/test_Var.jl | 53 ++++++++++------- 4 files changed, 128 insertions(+), 72 deletions(-) diff --git a/NEWS.md b/NEWS.md index 216d0929..41dad2bc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -22,6 +22,11 @@ With this release, you can remake a `OutputVar` using an already existing `Outpu is helpful if you need to construct a new `OutputVar` from an already existing one, but only need to modify one field while leaving the other fields the same. +## Add interpolation routine +With this release, any functions that rely on interpolation now uses the interpolation +routine written for ClimaAnalysis instead of Interpolations.jl. This substantially reduce +the number and size of allocations when using these functions. + v0.5.12 ------- diff --git a/src/Var.jl b/src/Var.jl index baa44baa..da0a6cef 100644 --- a/src/Var.jl +++ b/src/Var.jl @@ -9,6 +9,7 @@ import Statistics: mean import NaNStatistics: nanmean import ..Numerics +import ..Interpolations import ..Utils: nearest_index, seconds_to_prettystr, @@ -88,53 +89,96 @@ struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C} end """ - _make_interpolant(dims, data) + _check_interpolant(dims, data) -Make a linear interpolant from `dims`, a dictionary mapping dimension name to array and -`data`, an array containing data. Used in constructing a `OutputVar`. +Check if it is possible to create an interpolant. -If any element of the arrays in `dims` is a Dates.DateTime, then no interpolant is returned. -Interpolations.jl does not support interpolating on dates. If the longitudes span the entire -range and are equispaced, then a periodic boundary condition is added for the longitude -dimension. If the latitudes span the entire range and are equispaced, then a flat boundary -condition is added for the latitude dimension. In all other cases, an error is thrown when -extrapolating outside of `dim_array`. +If any element of the arrays in `dims` is a Dates.DateTime, then an error is returned. If + the longitudes span the entire range and are equispaced, then a periodic boundary condition +is added for the longitude dimension. If the latitudes span the entire range and are +equispaced, then a flat boundary condition is added for the latitude dimension. In all other +cases, an error is thrown when extrapolating outside of `dim_array`. """ -function _make_interpolant(dims, data) - # If any element is DateTime, then return nothing for the interpolant because - # Interpolations.jl do not support DateTimes +function _check_interpolant(dims) + # If any element is DateTime, then return an error + # ClimaAnalysis does not support interpolating on dates for dim_array in values(dims) - eltype(dim_array) <: Dates.DateTime && return nothing + eltype(dim_array) <: Dates.DateTime && return error( + "An interpolant cannot be made because interpolating on dates is not possible", + ) end # We can only create interpolants when we have 1D dimensions if isempty(dims) || any(d -> ndims(d) != 1 || length(d) == 1, values(dims)) - return nothing + return error( + "An interpolant cannot be made because the dimensions are not 1D", + ) end # Dimensions are all 1D, check that the knots are in increasing order (as required by - # Interpolations.jl) + # our interpolation routine) for (dim_name, dim_array) in dims if !issorted(dim_array) - @warn "Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order" - return nothing + return error( + "Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order", + ) end end + return nothing +end - # Find boundary conditions for extrapolation - extp_bound_conds = ( - _find_extp_bound_cond(dim_name, dim_array) for - (dim_name, dim_array) in dims - ) +""" + interpolate_point(point, dims, data) + +Linearly interpolate the point using `dims` and `data`. - dims_tuple = tuple(values(dims)...) - extp_bound_conds_tuple = tuple(extp_bound_conds...) - return Intp.extrapolate( - Intp.interpolate(dims_tuple, data, Intp.Gridded(Intp.Linear())), - extp_bound_conds_tuple, +Extrapolation conditions are determined by `_find_extp_bound_conds`. +""" +function interpolate_point(point, dims, data) + _check_interpolant(dims) + extp_bound_conds = _find_extp_bound_conds(dims) + return Interpolations.linear_interpolate( + point, + Tuple(values(dims)), + data, + extp_bound_conds, ) end +""" + interpolate_points(points, dims, data) + +Linearly interpolate the points using `dims` and `data`. + +Extrapolation conditions are determined by `_find_extp_bound_conds`. +""" +function interpolate_points(points, dims, data) + _check_interpolant(dims) + extp_bound_conds = _find_extp_bound_conds(dims) + dim_arrays_tuple = Tuple(values(dims)) + interpolated_arr = [ + Interpolations.linear_interpolate( + point, + dim_arrays_tuple, + data, + extp_bound_conds, + ) for point in points + ] + return interpolated_arr +end + +""" + _find_extp_bound_conds(dims) + +Find the appropriate boundary conditions given the `dims` of an `OutputVar`. +""" +function _find_extp_bound_conds(dims) + return ( + _find_extp_bound_cond(dim_name, dim_array) for + (dim_name, dim_array) in dims + ) |> Tuple +end + """ _find_extp_bound_cond(dim_name, dim_array) @@ -152,17 +196,17 @@ function _find_extp_bound_cond(dim_name, dim_array) conventional_dim_name(dim_name) == "longitude" && _isequispaced(dim_array) && isapprox(dim_size + dsize, 360.0) - ) && return Intp.Periodic() + ) && return Interpolations.extp_cond_periodic() ( conventional_dim_name(dim_name) == "longitude" && (dim_array[end] - dim_array[begin]) ≈ 360.0 - ) && return Intp.Periodic() + ) && return Interpolations.extp_cond_periodic() ( conventional_dim_name(dim_name) == "latitude" && _isequispaced(dim_array) && isapprox(dim_size + dsize, 180.0) - ) && return Intp.Flat() - return Intp.Throw() + ) && return Interpolations.extp_cond_flat() + return Interpolations.extp_cond_throw() end function OutputVar(attribs, dims, dim_attribs, data) @@ -998,11 +1042,11 @@ multilinear interpolation. Extrapolation is now allowed and will throw a `BoundsError` in most cases. If any element of the arrays of the dimensions is a Dates.DateTime, then interpolation is -not possible. Interpolations.jl do not support making interpolations for dates. If the -longitudes span the entire range and are equispaced, then a periodic boundary condition is -added for the longitude dimension. If the latitudes span the entire range and are -equispaced, then a flat boundary condition is added for the latitude dimension. In all other -cases, an error is thrown when extrapolating outside of the array of the dimension. +not possible. If the longitudes span the entire range and are equispaced, then a periodic +boundary condition is added for the longitude dimension. If the latitudes span the entire +range and are equispaced, then a flat boundary condition is added for the latitude +dimension. In all other cases, an error is thrown when extrapolating outside of the array of +the dimension. Example ======= @@ -1022,8 +1066,7 @@ julia> var2d = ClimaAnalysis.OutputVar(Dict("time" => time, "z" => z), data); va ``` """ function (x::OutputVar)(target_coord) - itp = _make_interpolant(x.dims, x.data) - return itp(target_coord...) + return interpolate_point(target_coord, x.dims, x.data) end """ @@ -1164,9 +1207,8 @@ function resampled_as(src_var::OutputVar, dest_var::OutputVar) src_var = reordered_as(src_var, dest_var) _check_dims_consistent(src_var, dest_var) - itp = _make_interpolant(src_var.dims, src_var.data) - src_resampled_data = - [itp(pt...) for pt in Base.product(values(dest_var.dims)...)] + coords = Base.product(values(dest_var.dims)...) + src_resampled_data = interpolate_points(coords, src_var.dims, src_var.data) # Construct new OutputVar to return src_var_ret_dims = empty(src_var.dims) @@ -1770,14 +1812,14 @@ function make_lonlat_mask( # Resample so that the mask match up with the grid of var # Round because linear resampling is done and we want the mask to be only ones and zeros - intp = _make_interpolant(mask_var.dims, mask_var.data) - mask_arr = - [ - intp(pt...) for pt in Base.product( - input_var.dims[longitude_name(input_var)], - input_var.dims[latitude_name(input_var)], - ) - ] .|> round + coords = [ + pt for pt in Base.product( + input_var.dims[longitude_name(input_var)], + input_var.dims[latitude_name(input_var)], + ) + ] + mask_arr = interpolate_points(coords, mask_var.dims, mask_var.data) + mask_arr .= mask_arr .|> round # Reshape data for broadcasting lon_idx = input_var.dim2index[longitude_name(input_var)] diff --git a/test/test_Atmos.jl b/test/test_Atmos.jl index 4204c231..e451a8f2 100644 --- a/test/test_Atmos.jl +++ b/test/test_Atmos.jl @@ -246,7 +246,7 @@ end sim_pressure = pressure3D, obs_pressure = pressure3D, ) - @test global_rmse_pfull == 0.0 + @test isapprox(global_rmse_pfull, 0.0, atol = 1e-11) # Test if the computation is the same as a manual computation zero_data = zeros(size(data)) diff --git a/test/test_Var.jl b/test/test_Var.jl index a131e9ee..0da7942e 100644 --- a/test/test_Var.jl +++ b/test/test_Var.jl @@ -124,35 +124,43 @@ end lon = 0.5:1.0:359.5 |> collect lat = -89.5:1.0:89.5 |> collect time = 1.0:100 |> collect - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Periodic(), Intp.Flat(), Intp.Throw()) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == ( + ClimaAnalysis.Interpolations.extp_cond_periodic(), + ClimaAnalysis.Interpolations.extp_cond_flat(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ) # Not equispaced for lon and lat lon = 0.5:1.0:359.5 |> collect |> x -> push!(x, 42.0) |> sort lat = -89.5:1.0:89.5 |> collect |> x -> push!(x, 42.0) |> sort time = 1.0:100 |> collect - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw()) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == ( + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ) # Does not span entire range for and lat lon = 0.5:1.0:350.5 |> collect lat = -89.5:1.0:80.5 |> collect time = 1.0:100 |> collect - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw()) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == ( + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ) # Lon is exactly 360 degrees lon = 0.0:1.0:360.0 |> collect - data = ones(length(lon)) dims = OrderedDict(["lon" => lon]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Periodic(),) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == (ClimaAnalysis.Interpolations.extp_cond_periodic(),) # Dates for the time dimension lon = 0.5:1.0:359.5 |> collect @@ -162,17 +170,18 @@ end Dates.DateTime(2020, 3, 1, 1, 2), Dates.DateTime(2020, 3, 1, 1, 3), ] - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test isnothing(intp) + @test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims) # 2D dimensions arb_dim = reshape(collect(range(-89.5, 89.5, 16)), (4, 4)) - data = collect(1:16) dims = OrderedDict(["arb_dim" => arb_dim]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test isnothing(intp) + @test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims) + + # Dimensions are not in increasing order + lon = [0.5, 42.0, 1.5, 110.0] + dims = OrderedDict(["lon" => lon]) + @test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims) end @testset "empty" begin @@ -529,6 +538,7 @@ end @test ClimaAnalysis.pressure_name(pressure_var) == "pfull" end +# FIX THIS @testset "Interpolation" begin # 1D interpolation with linear data, should yield correct results long = -175.0:175.0 |> collect @@ -539,7 +549,7 @@ end @test longvar.([10.5, 20.5]) == [10.5, 20.5] # Test error for data outside of range - @test_throws BoundsError longvar(200.0) + @test_throws ErrorException longvar(200.0) # 2D interpolation with linear data, should yield correct results time = 100.0:110.0 |> collect @@ -822,7 +832,7 @@ end @test src_var.data == ClimaAnalysis.resampled_as(src_var, src_var).data resampled_var = ClimaAnalysis.resampled_as(src_var, dest_var) @test resampled_var.data == reshape(1.0:(181 * 91), (181, 91))[1:91, 1:46] - @test_throws BoundsError ClimaAnalysis.resampled_as(dest_var, src_var) + @test_throws ErrorException ClimaAnalysis.resampled_as(dest_var, src_var) # BoundsError check src_long = 90.0:120.0 |> collect @@ -838,7 +848,7 @@ end dest_var = ClimaAnalysis.remake(dest_var, data = dest_data, dims = dest_dims) - @test_throws BoundsError ClimaAnalysis.resampled_as(src_var, dest_var) + @test_throws ErrorException ClimaAnalysis.resampled_as(src_var, dest_var) end @testset "Units" begin @@ -1869,7 +1879,6 @@ end attribs = Dict("long_name" => "hi") dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")]) var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data) - @test isnothing(ClimaAnalysis.Var._make_interpolant(dims, data)) reverse_var = ClimaAnalysis.reverse_dim(var, "lat") @test reverse(lat) == reverse_var.dims["lat"]