diff --git a/Project.toml b/Project.toml index f1c2198f..10d9a169 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.10.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" SolverCore = "ff4d7338-4cf1-434d-91df-b86cb86fb843" diff --git a/ext/ExaModelsMetal.jl b/ext/ExaModelsMetal.jl index 03a1f773..58f60bf8 100644 --- a/ext/ExaModelsMetal.jl +++ b/ext/ExaModelsMetal.jl @@ -5,14 +5,10 @@ import ExaModels, Metal ExaModels.sort!(array::A; lt = isless) where {A<:Metal.MtlArray} = copyto!(array, sort!(Array(array); lt = lt)) -replace_float_64(a::NamedTuple{names,T}) where {names, T} = NamedTuple{names}(replace_float_64.(Tuple(a))) -replace_float_64(a::Tuple) = replace_float_64.(a) -replace_float_64(x::Float64) = Float32(x) -replace_float_64(x) = x - -ExaModels.convert_array(v, ::Metal.MetalBackend) = ExaModels.adapt(Metal.MetalBackend(), replace_float_64.(v)) +ExaModels.convert_array(v::Metal.MtlArray, ::Metal.MetalBackend) = v +ExaModels.convert_array(v, ::Metal.MetalBackend) = + Metal.MtlArray(ExaModels.replace_float_64.(v)) ExaModels.default_T(::Metal.MetalBackend) = Float32 end # module - diff --git a/ext/ExaModelsOneAPI.jl b/ext/ExaModelsOneAPI.jl index c8f8096b..8b72ff14 100644 --- a/ext/ExaModelsOneAPI.jl +++ b/ext/ExaModelsOneAPI.jl @@ -2,9 +2,13 @@ module ExaModelsOneAPI import ExaModels, oneAPI -if pkgversion(oneAPI) < v"2.6" +ExaModels.convert_array(v::oneAPI.oneArray, ::oneAPI.oneAPIBackend) = v +ExaModels.convert_array(v, ::oneAPI.oneAPIBackend) = + oneAPI.oneArray(ExaModels.replace_float_64.(v)) + +ExaModels.default_T(::oneAPI.oneAPIBackend) = Float32 - ExaModels.convert_array(v, backend::oneAPI.oneAPIBackend) = oneAPI.oneArray(v) +if pkgversion(oneAPI) < v"2.6" ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneArray} = copyto!(array, sort!(Array(array); lt = lt)) diff --git a/src/ExaModels.jl b/src/ExaModels.jl index dd65e3af..5c9e6ea0 100644 --- a/src/ExaModels.jl +++ b/src/ExaModels.jl @@ -27,6 +27,7 @@ For more information, please visit https://github.com/exanauts/ExaModels.jl module ExaModels import Adapt: adapt +import ConstructionBase import NLPModels: NLPModels, obj, diff --git a/src/templates.jl b/src/templates.jl index a3dd5ce2..b37f8149 100644 --- a/src/templates.jl +++ b/src/templates.jl @@ -2,6 +2,24 @@ convert_array(v, ::Nothing) = v convert_array(v, backend) = adapt(backend, v) +# Recursively replace Float64 with Float32 in scalars, containers, and the +# fields of arbitrary structs. Used by backend extensions (Metal, oneAPI) that +# either reject Float64 outright or perform poorly on it. Type-stable via +# multiple dispatch on leaf types and Val(fieldcount(T)) on the generic struct +# path. +replace_float_64(x::Float64) = Float32(x) +replace_float_64(x::Tuple) = map(replace_float_64, x) +replace_float_64(x::NamedTuple) = map(replace_float_64, x) +replace_float_64(x::AbstractArray{Float64}) = Float32.(x) +replace_float_64(x::AbstractArray) = replace_float_64.(x) +@inline replace_float_64(x::T) where {T} = _rebuild_float_32(x, Val(fieldcount(T))) +@inline _rebuild_float_32(x, ::Val{0}) = x +@inline function _rebuild_float_32(x::T, ::Val{N}) where {T, N} + ConstructionBase.constructorof(T)( + ntuple(i -> replace_float_64(getfield(x, i)), Val(N))... + ) +end + # to avoid type privacy sort!(array; kwargs...) = Base.sort!(array; kwargs...) diff --git a/test/UtilsTest/UtilsTest.jl b/test/UtilsTest/UtilsTest.jl index 4b3b10c5..f646e271 100644 --- a/test/UtilsTest/UtilsTest.jl +++ b/test/UtilsTest/UtilsTest.jl @@ -8,6 +8,20 @@ UTIL_MODELS = [ExaModels.TimedNLPModel, ExaModels.CompressedNLPModel] FIELDS = [:solution, :multipliers, :multipliers_L, :multipliers_U, :objective] +# Test struct for replace_float_64 coverage +struct _RFTestBus{T} + id::Int + vmax::T + vmin::T + name::Symbol +end + +struct _RFTestNetwork{T} + baseMVA::T + bus::Vector{_RFTestBus{T}} + coefs::NTuple{3, T} +end + function runtests() @testset "Utils tests" begin m, ~ = _exa_luksan_vlcek_model(nothing, 3) @@ -26,6 +40,49 @@ function runtests() end end end + + @testset "replace_float_64" begin + # Scalar leaf + @test ExaModels.replace_float_64(1.5) === 1.5f0 + @test ExaModels.replace_float_64(1.5f0) === 1.5f0 + @test ExaModels.replace_float_64(3) === 3 + @test ExaModels.replace_float_64(:foo) === :foo + + # Tuple / NamedTuple recursion + @test ExaModels.replace_float_64((1.0, 2, "x")) === (1.0f0, 2, "x") + @test ExaModels.replace_float_64((a = 1.0, b = 2)) === (a = 1.0f0, b = 2) + + # Array of Float64 + @test ExaModels.replace_float_64([1.0, 2.0]) == Float32[1.0, 2.0] + @test eltype(ExaModels.replace_float_64([1.0, 2.0])) === Float32 + + # Mixed-eltype array routed elementwise + arr = Any[1.0, 2] + out = ExaModels.replace_float_64(arr) + @test out[1] === 1.0f0 && out[2] === 2 + + # Parametric struct rebuild + b = _RFTestBus{Float64}(1, 1.1, 0.9, :b1) + b32 = ExaModels.replace_float_64(b) + @test b32 isa _RFTestBus{Float32} + @test b32.id === 1 + @test b32.vmax === 1.1f0 + @test b32.name === :b1 + + # Nested struct + array field + net = _RFTestNetwork(100.0, + [_RFTestBus{Float64}(1, 1.1, 0.9, :b1), + _RFTestBus{Float64}(2, 1.2, 0.8, :b2)], + (0.1, 0.2, 0.3)) + net32 = ExaModels.replace_float_64(net) + @test net32 isa _RFTestNetwork{Float32} + @test net32.baseMVA === 100.0f0 + @test eltype(net32.bus) === _RFTestBus{Float32} + @test net32.coefs === (0.1f0, 0.2f0, 0.3f0) + + # Identity on type with no Float64 anywhere + @test ExaModels.replace_float_64((1, 2, 3)) === (1, 2, 3) + end end end