From b663bc14bde35562ed2f8dfbc908c50965a0c950 Mon Sep 17 00:00:00 2001 From: AmanKashyap0807 Date: Sat, 7 Feb 2026 23:05:42 +0530 Subject: [PATCH] Add write_to_netcdf support for OutputVar --- docs/src/api.md | 1 + docs/src/var.md | 13 +++++ src/Var.jl | 84 +++++++++++++++++++++++++++++++ test/test_Var.jl | 125 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 223 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index c14b8199..3fc7fcdc 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -31,6 +31,7 @@ Catalog.available_vars(catalog::NCCatalog) ```@docs Var.OutputVar Var.read_var +Var.write_to_netcdf Var.is_z_1D Base.isempty(var::OutputVar) Var.short_name diff --git a/docs/src/var.md b/docs/src/var.md index e8080835..1b4304ef 100644 --- a/docs/src/var.md +++ b/docs/src/var.md @@ -19,6 +19,19 @@ import ClimaAnalysis: OutputVar myfile = OutputVar("my_netcdf_file.nc", "myvar") ``` +## Writing to NetCDF + +`OutputVar`s can be written to NetCDF files using [`write_to_netcdf`](@ref). + +```julia +ClimaAnalysis.write_to_netcdf("output.nc", var) +``` + +This creates a NetCDF file containing the data, dimensions, and attributes of the +`OutputVar`. The name of the variable in the NetCDF file is taken from the `short_name` +attribute of the `OutputVar`. If a file already exists at the target path, it will be +overwritten. + ## Physical units `OutputVar`s can contain information about their physical units. For diff --git a/src/Var.jl b/src/Var.jl index ba1e6576..09a2b8b5 100644 --- a/src/Var.jl +++ b/src/Var.jl @@ -25,6 +25,7 @@ import ..Utils: export OutputVar, read_var, + write_to_netcdf, average_lat, weighted_average_lat, average_lon, @@ -476,6 +477,89 @@ function read_var(paths::Vector{String}; short_name = nothing) end end +""" + write_to_netcdf(path::AbstractString, var::OutputVar) + +Write the `OutputVar` to a NetCDF file at `path`, overwriting any existing +file. The variable name is taken from `short_name(var)`. If not present, an +error is thrown. The resulting NetCDF file can be read back using +[`read_var`](@ref). + +Example +========= + +```julia +var = OutputVar(attribs, dims, dim_attribs, data) +write_to_netcdf("output.nc", var) + +# Read it back +var_read = read_var("output.nc") +``` +""" +function write_to_netcdf(path::AbstractString, var::OutputVar) + # check short_name exists + var_name = short_name(var) + isempty(var_name) && + error("OutputVar must have a short_name to be written to NetCDF") + + # check dimensions match data + var_dim_names = collect(keys(var.dims)) + ndims(var.data) == length(var.dims) || error( + "Number of dimensions in data ($(ndims(var.data))) does not match number of dims ($(length(var.dims)))", + ) + for (i, (dim_name, dim_val)) in enumerate(var.dims) + if ndims(dim_val) == 1 + expected_len = length(dim_val) + actual_len = size(var.data, i) + expected_len == actual_len || error( + "Dimension $dim_name has length $expected_len but data has size $actual_len along dimension $i", + ) + else + error( + "Dimension $dim_name must be 1D to be written to NetCDF (ndims=$(ndims(dim_val)))", + ) + end + end + + NCDatasets.NCDataset(path, "c") do nc + # Define dimensions and coordinate variables + for (i, (dim_name, dim_val)) in enumerate(var.dims) + dim_var = NCDatasets.defVar(nc, dim_name, dim_val, (dim_name,)) + if haskey(var.dim_attributes, dim_name) + for (k, v) in var.dim_attributes[dim_name] + dim_var.attrib[k] = _netcdf_safe_attrib(v) + end + end + end + + nc_var = + NCDatasets.defVar(nc, var_name, eltype(var.data), var_dim_names) + + nc_var[:] = var.data + + # Write attributes with type safety (overwrite existing keys) + for (k, v) in var.attributes + nc_var.attrib[k] = _netcdf_safe_attrib(v) + end + end + return nothing +end + +""" + _netcdf_safe_attrib(v) + +Convert attribute value `v` to a NetCDF-safe type. +NetCDF attributes must be basic scalar types or strings. +""" +function _netcdf_safe_attrib(v) + v isa AbstractString && return v + v isa Number && return v + v isa AbstractArray && eltype(v) <: Union{Number, AbstractString} && + return v + # For anything else (Unitful, custom types), convert to string + return string(v) +end + """ short_name(var::OutputVar) diff --git a/test/test_Var.jl b/test/test_Var.jl index 34af7861..057f9699 100644 --- a/test/test_Var.jl +++ b/test/test_Var.jl @@ -3613,3 +3613,128 @@ end @test sprint(show, var) == "Attributes:\n long_name => hi\nDimension attributes:\n lat:\n units => deg\nData defined over:\n lat with 0 element" end + +@testset "Write to NetCDF" begin + mktempdir() do tmpdir + nc_path = joinpath(tmpdir, "test_file.nc") + + # Test 1: Round-trip write and read with small arrays + lon = collect(range(-60.0, 60.0, 8)) + lat = collect(range(-30.0, 30.0, 6)) + data = reshape(collect(1:(length(lon) * length(lat))), (length(lon), length(lat))) + dims = OrderedDict("lon" => lon, "lat" => lat) + attribs = Dict( + "short_name" => "test_var", + "long_name" => "Test Variable", + "units" => "m", + ) + dim_attribs = OrderedDict( + "lon" => Dict("units" => "deg"), + "lat" => Dict("units" => "deg"), + ) + + var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data) + ClimaAnalysis.write_to_netcdf(nc_path, var) + + var_read = ClimaAnalysis.read_var(nc_path; short_name = "test_var") + + @test var_read.attributes["short_name"] == attribs["short_name"] + @test var_read.attributes["long_name"] == attribs["long_name"] + @test var_read.attributes["units"] == attribs["units"] + @test var_read.dims["lon"] == lon + @test var_read.dims["lat"] == lat + @test var_read.data ≈ data + @test var_read.dim_attributes["lon"]["units"] == "deg" + @test var_read.dim_attributes["lat"]["units"] == "deg" + + # Test 2: Overwrite behavior (writing to same path replaces file) + lon2 = collect(range(0.0, 10.0, 5)) + data2 = collect(range(0.0, step = 1.0, length = length(lon2))) + dims2 = OrderedDict("lon" => lon2) + attribs2 = Dict("short_name" => "new_var", "units" => "K") + dim_attribs2 = OrderedDict("lon" => Dict("units" => "m")) + var2 = ClimaAnalysis.OutputVar(attribs2, dims2, dim_attribs2, data2) + + ClimaAnalysis.write_to_netcdf(nc_path, var2) + + var2_read = ClimaAnalysis.read_var(nc_path; short_name = "new_var") + @test var2_read.attributes["short_name"] == "new_var" + @test var2_read.data ≈ data2 + @test var2_read.dims["lon"] == lon2 + + # Test 3: Error when short_name is missing + attribs_bad = Dict("long_name" => "No short name") + var_bad = ClimaAnalysis.OutputVar(attribs_bad, dims, dim_attribs, data) + @test_throws ErrorException ClimaAnalysis.write_to_netcdf( + nc_path, + var_bad, + ) + + # Test 4: Attribute type safety (non-string/number attributes) + attribs_mixed = Dict( + "short_name" => "mixed_attrs", + "units" => "kg", + "custom_array" => [1, 2, 3], + ) + var_mixed = + ClimaAnalysis.OutputVar(attribs_mixed, dims2, dim_attribs2, data2) + ClimaAnalysis.write_to_netcdf(nc_path, var_mixed) + var_mixed_read = ClimaAnalysis.read_var( + nc_path; + short_name = "mixed_attrs", + ) + @test var_mixed_read.attributes["short_name"] == "mixed_attrs" + @test var_mixed_read.attributes["custom_array"] == [1, 2, 3] + + # Test 5: Float32 round-trip + float_path = joinpath(tmpdir, "float_file.nc") + lonf = Float32[0, 1] + latf = Float32[10, 20] + dataf = reshape(Float32[1, 2, 3, 4], (length(lonf), length(latf))) + dimsf = OrderedDict("lon" => lonf, "lat" => latf) + attribsf = Dict("short_name" => "f32", "units" => "1") + dim_attribsf = OrderedDict("lon" => Dict("units" => "deg"), "lat" => Dict()) + + var_f32 = ClimaAnalysis.OutputVar(attribsf, dimsf, dim_attribsf, dataf) + ClimaAnalysis.write_to_netcdf(float_path, var_f32) + var_f32_read = ClimaAnalysis.read_var(float_path; short_name = "f32") + + @test eltype(var_f32_read.data) == Float32 + @test var_f32_read.data == dataf + @test var_f32_read.dims == dimsf + + # Test 6: NaN round-trip + nan_path = joinpath(tmpdir, "nan_file.nc") + xs = [0.0, 1.0, 2.0] + data_nan = [1.0, NaN, 3.0] + dims_nan = OrderedDict("x" => xs) + attribs_nan = Dict("short_name" => "nan_var") + dim_attribs_nan = OrderedDict("x" => Dict()) + + var_nan = + ClimaAnalysis.OutputVar(attribs_nan, dims_nan, dim_attribs_nan, data_nan) + ClimaAnalysis.write_to_netcdf(nan_path, var_nan) + var_nan_read = ClimaAnalysis.read_var(nan_path; short_name = "nan_var") + + @test var_nan_read.data[1] == 1.0 + @test isnan(var_nan_read.data[2]) + @test var_nan_read.data[3] == 3.0 + @test var_nan_read.dims == dims_nan + + # # Test 7: Scalar (0D) round-trip + # scalar_path = joinpath(tmpdir, "scalar_file.nc") + # data_scalar = reshape([42.0], ()) + # dims_scalar = OrderedDict{String, Vector{Float64}}() + # attribs_scalar = Dict("short_name" => "scalar") + # dim_attribs_scalar = OrderedDict{String, Dict{Any, Any}}() + + # var_scalar = + # ClimaAnalysis.OutputVar(attribs_scalar, dims_scalar, dim_attribs_scalar, data_scalar) + # ClimaAnalysis.write_to_netcdf(scalar_path, var_scalar) + # var_scalar_read = ClimaAnalysis.read_var(scalar_path; short_name = "scalar") + + # @test length(var_scalar_read.dims) == 0 + # @test ndims(var_scalar_read.data) == 0 + # @test var_scalar_read.data[] == 42.0 + end +end