diff --git a/NEWS.md b/NEWS.md index 216d0929..41dad2bc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -22,6 +22,11 @@ With this release, you can remake a `OutputVar` using an already existing `Outpu is helpful if you need to construct a new `OutputVar` from an already existing one, but only need to modify one field while leaving the other fields the same. +## Add interpolation routine +With this release, any functions that rely on interpolation now uses the interpolation +routine written for ClimaAnalysis instead of Interpolations.jl. This substantially reduce +the number and size of allocations when using these functions. + v0.5.12 ------- diff --git a/src/ClimaAnalysis.jl b/src/ClimaAnalysis.jl index b7401892..2dd9614d 100644 --- a/src/ClimaAnalysis.jl +++ b/src/ClimaAnalysis.jl @@ -5,6 +5,7 @@ include("Utils.jl") import .Utils include("Numerics.jl") +include("Interpolations.jl") include("Var.jl") @reexport using .Var diff --git a/src/Interpolations.jl b/src/Interpolations.jl new file mode 100644 index 00000000..7a888678 --- /dev/null +++ b/src/Interpolations.jl @@ -0,0 +1,276 @@ +module Interpolations + +#= +Why not use Interpolations.jl? + +One of the most expensive operation in terms of time and memory is +`Var.resampled_as(src_var, dest_var)` which uses an interpolation to resample `data` in +`src_var` to match the `dims` in `dest_var`. However, interpolating using Interpolations.jl +is expensive. One must first intialize an interpolant, whose size in memory is at least as +big as the memory of `src_var.data`. Then, each evaluation using the interpolant +costs some amount of allocation on the heap. As a result, interpolating is extremely slow +for a large number of points. + +To solve this, we write our interpolation routine in the Numerics module that support the following: +1. No extra dependencies +2. No allocation when evaluating a point +3. Support boundary conditions for periodic, flat, and throw on an irregular grid +5. Comparable or better performance to Interpolations.jl +=# + +""" + linear_interpolate(point::NTuple{N, FT1}, + axes::NTuple{N, Vector}, + data::AbstractArray{FT2, N}, + extp_conds::NTuple{N, NTuple{4, Function}}) where {N, FT1, FT2} + +Linear interpolate `data` on `axes` and return the value at `point`. Extrapolation is +handled by `extp_conds`. +""" +function linear_interpolate( + point::NTuple{N}, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + # Get a new point as determined by the extrapolation condition + point = extp_to_point(point, axes, extp_conds) + + # Find which cell contain the point + cell_indices_for_axes = find_cell_indices_for_axes(point, axes) + val = zero(FT) + + # Compute the denominator of the formula for linear interpolation + # (in 1D, this is x_1 - x_0) + bottom_term = compute_bottom_term(axes, cell_indices_for_axes) + + # Iterate through all 2^N points + @inbounds for bits in 0:(2^N - 1) + term = one(FT) + bound_indices = get_indices(cell_indices_for_axes, bits) + sign = get_sign(cell_indices_for_axes, bits) + # Weight is the value at each of the points of the cell + weight = data[get_complement_indices(cell_indices_for_axes, bits)...] + @inbounds for (dim_idx, bound_idx) in pairs(bound_indices) + val_minus_x2_or_x1 = point[dim_idx] - axes[dim_idx][bound_idx] + term *= val_minus_x2_or_x1 + end + term *= sign * weight + val += term + end + return val / bottom_term +end + +""" + linear_interpolate(point::Number, + axes, + data::AbstractArray{FT, N}, + extp_conds) where {N, FT} + +Convert a number to a tuple and linear interpolate. +""" +function linear_interpolate( + point::Number, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + point = Tuple(point...) + return linear_interpolate(point, axes, data, extp_conds) +end + +""" + linear_interpolate(point::AbstractVector, + axes, + data::AbstractArray{FT, N}, + extp_conds) where {N, FT} + +Convert a vector to a tuple and linear interpolate. +""" +function linear_interpolate( + point::AbstractVector, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + point = Tuple(coord for coord in point) + return linear_interpolate(point, axes, data, extp_conds) +end + +""" + linear_interpolate(point::Tuple, + axes, + data::AbstractArray{FT, N}, + extp_conds) where {N, FT} + +Promote a tuple and linear interpolate. +""" +function linear_interpolate( + point::Tuple, + axes, + data::AbstractArray{FT, N}, + extp_conds, +) where {N, FT} + point = promote(point...) + return linear_interpolate(point, axes, data, extp_conds) +end + + +""" + compute_bottom_term(axes, cell_indices_for_axes::NTuple{N}) + +Compute the bottom term when linearly interpolating. + +Consider the formula for 1D linear interpolation which is + y_0 * ((x_1 - x) / (x_1 - x_0)) + y_1 * ((x - x_0) / (x_1 - x_0)) +for interpolating the point (x, y) on line between (x_0, y_0) and (x_1, y_1). This function +computes x_1 - x_0. +""" +function compute_bottom_term(axes, cell_indices_for_axes::NTuple{N}) where {N} + return reduce( + *, + ntuple( + dim_idx -> + axes[dim_idx][cell_indices_for_axes[dim_idx][end]] - + axes[dim_idx][cell_indices_for_axes[dim_idx][begin]], + N, + ), + ) +end + +""" + get_complement_indices(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I} + +Given a tuple consisting of 2-tuple, return a tuple of one element from each tuple according +to bits. The elements in the tuple are the complement of those found by `get_indices`. +""" +function get_complement_indices(indices::NTuple{N}, bits) where {N} + # Bit manipulation can be found here: + # https://github.com/parsiad/mlinterp/blob/master/mlinterp/mlinterp.hpp + # Adjusted for 1-indexing instead of 0-indexing + return ntuple(dim -> if (bits & (1 << (dim - 1)) != 0) + indices[dim][2] + else + indices[dim][1] + end, N) +end + +""" + get_indices(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I} + +Given a tuple consisting of 2-tuple, return a tuple of one element from each tuple according to bits. +""" +function get_indices(indices::NTuple{N}, bits) where {N} + # See get_complement_indices for where the Bit manipulation comes from + return ntuple(dim -> if (bits & (1 << (dim - 1)) != 0) + indices[dim][1] + else + indices[dim][2] + end, N) +end + +""" + get_sign(indices::NTuple{N, Tuple{I, I}}, bits) where {N, I} + +Given a tuple consisting of 2-tuple, compute the appropriate sign when interpolating. +""" +function get_sign(_indices::NTuple{N}, bits) where {N} + # See get_complement_indices for where the bit manipulation comes from + return reduce(*, ntuple(dim -> if (bits & (1 << (dim - 1)) != 0) + 1 + else + -1 + end, N)) +end + +""" + find_cell_indices_for_axes(point::NTuple{N, FT}, + axes::NTuple{N, A}) where {N, FT, A <:AbstractVector} + +Given a point and axes, find the indices of the N-dimensional hyperrectangle, where the +point lives in. +""" +function find_cell_indices_for_axes(point::NTuple{N}, axes) where {N} + return ntuple( + dim_idx -> find_cell_indices_for_ax(point[dim_idx], axes[dim_idx]), + N, + ) +end + +""" + find_cell_indices_for_ax(val::FT1, ax::AbstractVector{FT2}) where {FT1, FT2} + +Given `val` and an `ax`, find the indices of the cell, where `val` lives in. +""" +function find_cell_indices_for_ax( + val::FT1, + ax::AbstractVector{FT2}, +) where {FT1, FT2} + len_of_ax = length(ax) + (val == ax[begin]) && return (1, 2) + (val == ax[end]) && return (len_of_ax - 1, len_of_ax) + idx = searchsortedfirst(ax, val) + return (idx - 1, idx) +end + +""" + extp_to_point(point::NTuple{N, FT1}, axes::NTuple{N, Vector}, extp_conds) where {N, FT1} + +Return a new point to evaluate at according to the extrapolation conditions. +""" +function extp_to_point(point::NTuple{N}, axes, extp_conds) where {N} + return ntuple( + idx -> extp_conds[idx].get_val_for_point(point[idx], axes[idx]), + N, + ) +end + +""" + extp_cond_throw() + +Return extrapolation condition for throwing an error when the point is out of bounds. + +The first and last nodes are not co-located. For example, if the axis is [1.0, 2.0, 3.0] +and the data is [4.0, 5.0, 6.0], then the value at 3.0 is 6.0 and not 4.0. +""" +function extp_cond_throw() + get_val_for_point(val, ax) = begin + val < ax[begin] && return error("Out of bounds error with $val in $ax") + val > ax[end] && return error("Out of bounds error with $val in $ax") + return val + end + return (; get_val_for_point = get_val_for_point) +end + +""" + extp_cond_flat() + +Return flat extrapolation condition. +""" +function extp_cond_flat() + get_val_for_point(val, ax) = begin + val < ax[begin] && return typeof(val)(ax[begin]) + val > ax[end] && return typeof(val)(ax[end]) + return val + end + return (; get_val_for_point) +end + +""" + extp_cond_periodic() + +Return periodic extrapolation condtion. +""" +function extp_cond_periodic() + get_val_for_point(val, ax) = begin + if (val < ax[begin]) || (val > ax[end]) + width = ax[end] - ax[begin] + new_val = mod(val - ax[begin], width) + ax[begin] + return typeof(val)(new_val) + end + return val + end + return (; get_val_for_point) +end + +end diff --git a/src/Var.jl b/src/Var.jl index baa44baa..da0a6cef 100644 --- a/src/Var.jl +++ b/src/Var.jl @@ -9,6 +9,7 @@ import Statistics: mean import NaNStatistics: nanmean import ..Numerics +import ..Interpolations import ..Utils: nearest_index, seconds_to_prettystr, @@ -88,53 +89,96 @@ struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C} end """ - _make_interpolant(dims, data) + _check_interpolant(dims, data) -Make a linear interpolant from `dims`, a dictionary mapping dimension name to array and -`data`, an array containing data. Used in constructing a `OutputVar`. +Check if it is possible to create an interpolant. -If any element of the arrays in `dims` is a Dates.DateTime, then no interpolant is returned. -Interpolations.jl does not support interpolating on dates. If the longitudes span the entire -range and are equispaced, then a periodic boundary condition is added for the longitude -dimension. If the latitudes span the entire range and are equispaced, then a flat boundary -condition is added for the latitude dimension. In all other cases, an error is thrown when -extrapolating outside of `dim_array`. +If any element of the arrays in `dims` is a Dates.DateTime, then an error is returned. If + the longitudes span the entire range and are equispaced, then a periodic boundary condition +is added for the longitude dimension. If the latitudes span the entire range and are +equispaced, then a flat boundary condition is added for the latitude dimension. In all other +cases, an error is thrown when extrapolating outside of `dim_array`. """ -function _make_interpolant(dims, data) - # If any element is DateTime, then return nothing for the interpolant because - # Interpolations.jl do not support DateTimes +function _check_interpolant(dims) + # If any element is DateTime, then return an error + # ClimaAnalysis does not support interpolating on dates for dim_array in values(dims) - eltype(dim_array) <: Dates.DateTime && return nothing + eltype(dim_array) <: Dates.DateTime && return error( + "An interpolant cannot be made because interpolating on dates is not possible", + ) end # We can only create interpolants when we have 1D dimensions if isempty(dims) || any(d -> ndims(d) != 1 || length(d) == 1, values(dims)) - return nothing + return error( + "An interpolant cannot be made because the dimensions are not 1D", + ) end # Dimensions are all 1D, check that the knots are in increasing order (as required by - # Interpolations.jl) + # our interpolation routine) for (dim_name, dim_array) in dims if !issorted(dim_array) - @warn "Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order" - return nothing + return error( + "Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order", + ) end end + return nothing +end - # Find boundary conditions for extrapolation - extp_bound_conds = ( - _find_extp_bound_cond(dim_name, dim_array) for - (dim_name, dim_array) in dims - ) +""" + interpolate_point(point, dims, data) + +Linearly interpolate the point using `dims` and `data`. - dims_tuple = tuple(values(dims)...) - extp_bound_conds_tuple = tuple(extp_bound_conds...) - return Intp.extrapolate( - Intp.interpolate(dims_tuple, data, Intp.Gridded(Intp.Linear())), - extp_bound_conds_tuple, +Extrapolation conditions are determined by `_find_extp_bound_conds`. +""" +function interpolate_point(point, dims, data) + _check_interpolant(dims) + extp_bound_conds = _find_extp_bound_conds(dims) + return Interpolations.linear_interpolate( + point, + Tuple(values(dims)), + data, + extp_bound_conds, ) end +""" + interpolate_points(points, dims, data) + +Linearly interpolate the points using `dims` and `data`. + +Extrapolation conditions are determined by `_find_extp_bound_conds`. +""" +function interpolate_points(points, dims, data) + _check_interpolant(dims) + extp_bound_conds = _find_extp_bound_conds(dims) + dim_arrays_tuple = Tuple(values(dims)) + interpolated_arr = [ + Interpolations.linear_interpolate( + point, + dim_arrays_tuple, + data, + extp_bound_conds, + ) for point in points + ] + return interpolated_arr +end + +""" + _find_extp_bound_conds(dims) + +Find the appropriate boundary conditions given the `dims` of an `OutputVar`. +""" +function _find_extp_bound_conds(dims) + return ( + _find_extp_bound_cond(dim_name, dim_array) for + (dim_name, dim_array) in dims + ) |> Tuple +end + """ _find_extp_bound_cond(dim_name, dim_array) @@ -152,17 +196,17 @@ function _find_extp_bound_cond(dim_name, dim_array) conventional_dim_name(dim_name) == "longitude" && _isequispaced(dim_array) && isapprox(dim_size + dsize, 360.0) - ) && return Intp.Periodic() + ) && return Interpolations.extp_cond_periodic() ( conventional_dim_name(dim_name) == "longitude" && (dim_array[end] - dim_array[begin]) ≈ 360.0 - ) && return Intp.Periodic() + ) && return Interpolations.extp_cond_periodic() ( conventional_dim_name(dim_name) == "latitude" && _isequispaced(dim_array) && isapprox(dim_size + dsize, 180.0) - ) && return Intp.Flat() - return Intp.Throw() + ) && return Interpolations.extp_cond_flat() + return Interpolations.extp_cond_throw() end function OutputVar(attribs, dims, dim_attribs, data) @@ -998,11 +1042,11 @@ multilinear interpolation. Extrapolation is now allowed and will throw a `BoundsError` in most cases. If any element of the arrays of the dimensions is a Dates.DateTime, then interpolation is -not possible. Interpolations.jl do not support making interpolations for dates. If the -longitudes span the entire range and are equispaced, then a periodic boundary condition is -added for the longitude dimension. If the latitudes span the entire range and are -equispaced, then a flat boundary condition is added for the latitude dimension. In all other -cases, an error is thrown when extrapolating outside of the array of the dimension. +not possible. If the longitudes span the entire range and are equispaced, then a periodic +boundary condition is added for the longitude dimension. If the latitudes span the entire +range and are equispaced, then a flat boundary condition is added for the latitude +dimension. In all other cases, an error is thrown when extrapolating outside of the array of +the dimension. Example ======= @@ -1022,8 +1066,7 @@ julia> var2d = ClimaAnalysis.OutputVar(Dict("time" => time, "z" => z), data); va ``` """ function (x::OutputVar)(target_coord) - itp = _make_interpolant(x.dims, x.data) - return itp(target_coord...) + return interpolate_point(target_coord, x.dims, x.data) end """ @@ -1164,9 +1207,8 @@ function resampled_as(src_var::OutputVar, dest_var::OutputVar) src_var = reordered_as(src_var, dest_var) _check_dims_consistent(src_var, dest_var) - itp = _make_interpolant(src_var.dims, src_var.data) - src_resampled_data = - [itp(pt...) for pt in Base.product(values(dest_var.dims)...)] + coords = Base.product(values(dest_var.dims)...) + src_resampled_data = interpolate_points(coords, src_var.dims, src_var.data) # Construct new OutputVar to return src_var_ret_dims = empty(src_var.dims) @@ -1770,14 +1812,14 @@ function make_lonlat_mask( # Resample so that the mask match up with the grid of var # Round because linear resampling is done and we want the mask to be only ones and zeros - intp = _make_interpolant(mask_var.dims, mask_var.data) - mask_arr = - [ - intp(pt...) for pt in Base.product( - input_var.dims[longitude_name(input_var)], - input_var.dims[latitude_name(input_var)], - ) - ] .|> round + coords = [ + pt for pt in Base.product( + input_var.dims[longitude_name(input_var)], + input_var.dims[latitude_name(input_var)], + ) + ] + mask_arr = interpolate_points(coords, mask_var.dims, mask_var.data) + mask_arr .= mask_arr .|> round # Reshape data for broadcasting lon_idx = input_var.dim2index[longitude_name(input_var)] diff --git a/test/runtests.jl b/test/runtests.jl index 8e40c6b4..c4a8323c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using Test @safetestset "Utils" begin @time include("test_Utils.jl") end @safetestset "Numerics" begin @time include("test_Numerics.jl") end +@safetestset "Interpolations" begin @time include("test_Interpolations.jl") end @safetestset "SimDir" begin @time include("test_Sim.jl") end @safetestset "Atmos" begin @time include("test_Atmos.jl") end @safetestset "Leaderboard" begin @time include("test_Leaderboard.jl") end diff --git a/test/test_Atmos.jl b/test/test_Atmos.jl index 4204c231..e451a8f2 100644 --- a/test/test_Atmos.jl +++ b/test/test_Atmos.jl @@ -246,7 +246,7 @@ end sim_pressure = pressure3D, obs_pressure = pressure3D, ) - @test global_rmse_pfull == 0.0 + @test isapprox(global_rmse_pfull, 0.0, atol = 1e-11) # Test if the computation is the same as a manual computation zero_data = zeros(size(data)) diff --git a/test/test_Interpolations.jl b/test/test_Interpolations.jl new file mode 100644 index 00000000..60a3ec74 --- /dev/null +++ b/test/test_Interpolations.jl @@ -0,0 +1,298 @@ +using Test +import ClimaAnalysis + +@testset "Get indices and sign" begin + indices = ((1, 2),) + @test ClimaAnalysis.Interpolations.get_indices(indices, 0) == (2,) + @test ClimaAnalysis.Interpolations.get_indices(indices, 1) == (1,) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 0) == + (1,) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 1) == + (2,) + @test ClimaAnalysis.Interpolations.get_sign(indices, 0) == -1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 1) == 1 + + indices = ((1, 2), (3, 4)) + @test ClimaAnalysis.Interpolations.get_indices(indices, 0) == (2, 4) + @test ClimaAnalysis.Interpolations.get_indices(indices, 1) == (1, 4) + @test ClimaAnalysis.Interpolations.get_indices(indices, 2) == (2, 3) + @test ClimaAnalysis.Interpolations.get_indices(indices, 3) == (1, 3) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 0) == + (1, 3) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 1) == + (2, 3) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 2) == + (1, 4) + @test ClimaAnalysis.Interpolations.get_complement_indices(indices, 3) == + (2, 4) + @test ClimaAnalysis.Interpolations.get_sign(indices, 0) == 1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 1) == -1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 2) == -1 + @test ClimaAnalysis.Interpolations.get_sign(indices, 3) == 1 +end + +@testset "Find indices for cell" begin + val1 = 5 + ax1 = [0, 10] + @test ClimaAnalysis.Interpolations.find_cell_indices_for_ax(val1, ax1) == + (1, 2) + + val2 = 6 + ax2 = [0, 4, 10] + @test ClimaAnalysis.Interpolations.find_cell_indices_for_ax(val2, ax2) == + (2, 3) + + @test ClimaAnalysis.Interpolations.find_cell_indices_for_axes( + (val1, val2), + (ax1, ax2), + ) == ((1, 2), (2, 3)) +end + +@testset "Extrapolation conditions" begin + throw = ClimaAnalysis.Interpolations.extp_cond_throw() + flat = ClimaAnalysis.Interpolations.extp_cond_flat() + periodic = ClimaAnalysis.Interpolations.extp_cond_periodic() + + ax = [0, 1, 2, 3] + @test_throws ErrorException throw.get_val_for_point(10, ax) + @test_throws ErrorException throw.get_val_for_point(-1, ax) + @test throw.get_val_for_point(1.5, ax) == 1.5 + + @test flat.get_val_for_point(10, ax) == 3 + @test flat.get_val_for_point(-1, ax) == 0 + @test flat.get_val_for_point(1.5, ax) == 1.5 + + @test periodic.get_val_for_point(10, ax) == 1 + @test periodic.get_val_for_point(-1, ax) == 2 + @test periodic.get_val_for_point(1.5, ax) == 1.5 + @test periodic.get_val_for_point(3, ax) == 3 +end + +@testset "Extrapolate to new point" begin + throw = ClimaAnalysis.Interpolations.extp_cond_throw() + flat = ClimaAnalysis.Interpolations.extp_cond_flat() + periodic = ClimaAnalysis.Interpolations.extp_cond_periodic() + + ax1 = [0, 1, 2, 3] + @test ClimaAnalysis.Interpolations.extp_to_point((1,), (ax1,), (throw,)) == + (1,) + + ax2 = [4, 5, 6, 7] + @test ClimaAnalysis.Interpolations.extp_to_point( + (-1, 8), + (ax1, ax2), + (flat, periodic), + ) == (0, 5) +end + +@testset "Interpolation" begin + throw = ClimaAnalysis.Interpolations.extp_cond_throw() + flat = ClimaAnalysis.Interpolations.extp_cond_flat() + periodic = ClimaAnalysis.Interpolations.extp_cond_periodic() + + # 1D case + axes = ([1.0, 2.0, 3.0],) + data = [3.0, 1.0, 0.0] + + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.0,), + axes, + data, + (throw,), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (3.0,), + axes, + data, + (throw,), + ) == 0.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5,), + axes, + data, + (throw,), + ) == 2.0 + + # 1D case with extrapolation conditions + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (0.0,), + axes, + data, + (throw,), + ) + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (4.0,), + axes, + data, + (throw,), + ) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (0.0,), + axes, + data, + (flat,), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0,), + axes, + data, + (flat,), + ) == 0.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (0.0,), + axes, + data, + (periodic,), + ) == 1.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0,), + axes, + data, + (periodic,), + ) == 1.0 + + # 2D case + axes = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]) + data = reshape(1:9, (3, 3)) + + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.0, 4.0), + axes, + data, + (throw, throw), + ) == 1.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (3.0, 6.0), + axes, + data, + (throw, throw), + ) == 9.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (2.0, 5.0), + axes, + data, + (throw, throw), + ) == 5.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 4.5), + axes, + data, + (throw, throw), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 5.5), + axes, + data, + (throw, throw), + ) == 6.0 + + # 2D cases with extrapolation conditions + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (4.0, 5.0), + axes, + data, + (throw, flat), + ) + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (2.0, 7.0), + axes, + data, + (flat, throw), + ) + @test_throws ErrorException ClimaAnalysis.Interpolations.linear_interpolate( + (0.0, 8.0), + axes, + data, + (throw, throw), + ) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (0.0, 8.0), + axes, + data, + (flat, flat), + ) == 7.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0, 7.0), + axes, + data, + (periodic, periodic), + ) == 5.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (3.0, 6.0), + axes, + data, + (periodic, periodic), + ) == 9.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (4.0, 7.0), + axes, + data, + (flat, periodic), + ) == 6.0 + + # 3D cases with extrapolation conditions + axes = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]) + data = reshape(1:27, (3, 3, 3)) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.0, 5.0, 7.0), + axes, + data, + (throw, throw, throw), + ) == 4.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 5.2, 7.5), + axes, + data, + (throw, throw, throw), + ) ≈ 9.6 + + # Non equispaced + axes = ([1.0, 3.0, 7.0], [4.0, 5.0, 7.0]) + data = reshape(1:9, (3, 3)) + @test ClimaAnalysis.Interpolations.linear_interpolate( + (2.0, 4.5), + axes, + data, + (throw, throw, throw), + ) == 3.0 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (5.0, 6.0), + axes, + data, + (throw, throw, throw), + ) == 7.0 + + # Different types + # Axes have different types and inputs have different types + axes = ([1.0f0, 2.0f0], [Float16(3.0), Float16(4.0)]) + data = [[1.0, 2.0] [3.0, 4.0]] + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5f0, 3.5), + axes, + data, + (throw, throw), + ) == 2.5 + @test ClimaAnalysis.Interpolations.linear_interpolate( + (1.5, 4.5f0), + axes, + data, + (flat, flat), + ) == 3.5 + @test ClimaAnalysis.Interpolations.linear_interpolate( + [1.5, 4.5f0], + axes, + data, + (flat, flat), + ) == 3.5 + + # Single number + axes = ([1.0, 2.0, 3.0],) + data = [3.0, 1.0, 0.0] + + @test ClimaAnalysis.Interpolations.linear_interpolate( + 1.0, + axes, + data, + (throw,), + ) == 3.0 +end diff --git a/test/test_Var.jl b/test/test_Var.jl index a131e9ee..0da7942e 100644 --- a/test/test_Var.jl +++ b/test/test_Var.jl @@ -124,35 +124,43 @@ end lon = 0.5:1.0:359.5 |> collect lat = -89.5:1.0:89.5 |> collect time = 1.0:100 |> collect - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Periodic(), Intp.Flat(), Intp.Throw()) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == ( + ClimaAnalysis.Interpolations.extp_cond_periodic(), + ClimaAnalysis.Interpolations.extp_cond_flat(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ) # Not equispaced for lon and lat lon = 0.5:1.0:359.5 |> collect |> x -> push!(x, 42.0) |> sort lat = -89.5:1.0:89.5 |> collect |> x -> push!(x, 42.0) |> sort time = 1.0:100 |> collect - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw()) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == ( + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ) # Does not span entire range for and lat lon = 0.5:1.0:350.5 |> collect lat = -89.5:1.0:80.5 |> collect time = 1.0:100 |> collect - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw()) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == ( + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ClimaAnalysis.Interpolations.extp_cond_throw(), + ) # Lon is exactly 360 degrees lon = 0.0:1.0:360.0 |> collect - data = ones(length(lon)) dims = OrderedDict(["lon" => lon]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test intp.et == (Intp.Periodic(),) + extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims) + @test extp_conds == (ClimaAnalysis.Interpolations.extp_cond_periodic(),) # Dates for the time dimension lon = 0.5:1.0:359.5 |> collect @@ -162,17 +170,18 @@ end Dates.DateTime(2020, 3, 1, 1, 2), Dates.DateTime(2020, 3, 1, 1, 3), ] - data = ones(length(lon), length(lat), length(time)) dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test isnothing(intp) + @test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims) # 2D dimensions arb_dim = reshape(collect(range(-89.5, 89.5, 16)), (4, 4)) - data = collect(1:16) dims = OrderedDict(["arb_dim" => arb_dim]) - intp = ClimaAnalysis.Var._make_interpolant(dims, data) - @test isnothing(intp) + @test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims) + + # Dimensions are not in increasing order + lon = [0.5, 42.0, 1.5, 110.0] + dims = OrderedDict(["lon" => lon]) + @test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims) end @testset "empty" begin @@ -529,6 +538,7 @@ end @test ClimaAnalysis.pressure_name(pressure_var) == "pfull" end +# FIX THIS @testset "Interpolation" begin # 1D interpolation with linear data, should yield correct results long = -175.0:175.0 |> collect @@ -539,7 +549,7 @@ end @test longvar.([10.5, 20.5]) == [10.5, 20.5] # Test error for data outside of range - @test_throws BoundsError longvar(200.0) + @test_throws ErrorException longvar(200.0) # 2D interpolation with linear data, should yield correct results time = 100.0:110.0 |> collect @@ -822,7 +832,7 @@ end @test src_var.data == ClimaAnalysis.resampled_as(src_var, src_var).data resampled_var = ClimaAnalysis.resampled_as(src_var, dest_var) @test resampled_var.data == reshape(1.0:(181 * 91), (181, 91))[1:91, 1:46] - @test_throws BoundsError ClimaAnalysis.resampled_as(dest_var, src_var) + @test_throws ErrorException ClimaAnalysis.resampled_as(dest_var, src_var) # BoundsError check src_long = 90.0:120.0 |> collect @@ -838,7 +848,7 @@ end dest_var = ClimaAnalysis.remake(dest_var, data = dest_data, dims = dest_dims) - @test_throws BoundsError ClimaAnalysis.resampled_as(src_var, dest_var) + @test_throws ErrorException ClimaAnalysis.resampled_as(src_var, dest_var) end @testset "Units" begin @@ -1869,7 +1879,6 @@ end attribs = Dict("long_name" => "hi") dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")]) var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data) - @test isnothing(ClimaAnalysis.Var._make_interpolant(dims, data)) reverse_var = ClimaAnalysis.reverse_dim(var, "lat") @test reverse(lat) == reverse_var.dims["lat"]