From 35397b1799c18130f8936167758224705560751b Mon Sep 17 00:00:00 2001 From: Sungho Shin Date: Tue, 19 May 2026 11:44:12 -0400 Subject: [PATCH 1/2] =?UTF-8?q?GPU=20extensions=20(Metal,=20oneAPI):=20uni?= =?UTF-8?q?fy=20Float64=20=E2=86=92=20Float32=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Metal and oneAPI backend extensions both need to recursively convert Float64 host data to Float32 before upload, because - Metal cannot compile Float64 IR at all, - oneAPI Arc-class devices reject Float64 allocations outright (Iris Xe / Data Center GPU Max accept them but the user usually wants fp32 anyway). Two things were wrong before this change: 1. The Metal extension's `replace_float_64` only handled scalars, Tuples, and NamedTuples — its catch-all `replace_float_64(x) = x` silently let plain structs through. So Float64 fields inside structs like ExaPowerIO's `BusData`/`GenData`/`BranchData` reached the device, and the JIT either emitted invalid Float64 IR (best case) or crashed Apple's Metal compiler with an XPC interrupt (typical case for second-order AD kernels on OPF). 2. The oneAPI extension only defined `convert_array` behind `if pkgversion(oneAPI) < v\"2.6\"`. On current oneAPI (≥ v2.6) the branch was skipped, falling back to `adapt(backend, v)` which preserved Float64 element type. Every fp32 OPF/LV/COPS run on Arc A770 bombed at host-to-device transfer with "Float64 is not supported on this device". This commit: - Hoists `replace_float_64` into `src/templates.jl` (single implementation, both extensions reuse it). - Implements the struct recursion with pure multiple dispatch (no try/catch, no `@generated`). Specialized methods cover Float64, Tuple, NamedTuple, AbstractArray{Float64}, and AbstractArray; the generic struct path uses `Val(fieldcount(T))` to keep `ntuple` type-stable and `ConstructionBase.constructorof(T)` to rebuild the struct with Float32-typed fields. Adds ConstructionBase as a dep. - In the Metal extension: removes the bespoke (incomplete) `replace_float_64` and adds the `MtlArray` identity overload to `convert_array`. Both extensions now read `convert_array(v::DeviceArray, ::Backend) = v` and `convert_array(v, ::Backend) = Backend.Array(ExaModels.replace_float_64.(v))`. - In the oneAPI extension: drops the `< v2.6` guard from the `convert_array` definition so the dispatch is unconditional. Keeps the older `sort!`/`findall` shims inside the version guard since those remain version-specific. - Adds a `replace_float_64` test block to UtilsTest covering: scalar leaf, mixed Tuple/NamedTuple, Vector{Float64} eltype check, mixed-Any array, parametric struct rebuild, and nested struct with array fields. Validated manually: - Apple M1 (Metal, local): OPF case118 polar with T=Float32 — all five NLPModels callbacks succeed (obj, cons!, grad!, jac_coord!, hess_coord!). LV.rosenrock with T=Float32 — same. - Apple M2 Pro (Metal, pro): 231/240-case OPF+LV+COPS benchmark with T=Float32; remaining 9 failures are rocket/glider Metal-compiler XPC crashes on hess_coord! and dirichlet/henon/lane_emden which use a separate PDEProblem struct in ExaModelsPower that needs its own parameterization on T (out of scope here). - Intel Arc A770 (oneAPI, shin-compute-002): was 0/240 before this change; smoke test on case118 polar now PASS. Full 240-case run is in progress as of this commit. Co-Authored-By: Claude Opus 4.7 --- Project.toml | 1 + ext/ExaModelsMetal.jl | 10 ++----- ext/ExaModelsOneAPI.jl | 6 ++-- src/ExaModels.jl | 1 + src/templates.jl | 18 ++++++++++++ test/UtilsTest/UtilsTest.jl | 57 +++++++++++++++++++++++++++++++++++++ 6 files changed, 84 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index f1c2198fb..10d9a169a 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.10.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" SolverCore = "ff4d7338-4cf1-434d-91df-b86cb86fb843" diff --git a/ext/ExaModelsMetal.jl b/ext/ExaModelsMetal.jl index 03a1f7735..58f60bf8e 100644 --- a/ext/ExaModelsMetal.jl +++ b/ext/ExaModelsMetal.jl @@ -5,14 +5,10 @@ import ExaModels, Metal ExaModels.sort!(array::A; lt = isless) where {A<:Metal.MtlArray} = copyto!(array, sort!(Array(array); lt = lt)) -replace_float_64(a::NamedTuple{names,T}) where {names, T} = NamedTuple{names}(replace_float_64.(Tuple(a))) -replace_float_64(a::Tuple) = replace_float_64.(a) -replace_float_64(x::Float64) = Float32(x) -replace_float_64(x) = x - -ExaModels.convert_array(v, ::Metal.MetalBackend) = ExaModels.adapt(Metal.MetalBackend(), replace_float_64.(v)) +ExaModels.convert_array(v::Metal.MtlArray, ::Metal.MetalBackend) = v +ExaModels.convert_array(v, ::Metal.MetalBackend) = + Metal.MtlArray(ExaModels.replace_float_64.(v)) ExaModels.default_T(::Metal.MetalBackend) = Float32 end # module - diff --git a/ext/ExaModelsOneAPI.jl b/ext/ExaModelsOneAPI.jl index c8f8096b0..407811128 100644 --- a/ext/ExaModelsOneAPI.jl +++ b/ext/ExaModelsOneAPI.jl @@ -2,9 +2,11 @@ module ExaModelsOneAPI import ExaModels, oneAPI -if pkgversion(oneAPI) < v"2.6" +ExaModels.convert_array(v::oneAPI.oneArray, ::oneAPI.oneAPIBackend) = v +ExaModels.convert_array(v, ::oneAPI.oneAPIBackend) = + oneAPI.oneArray(ExaModels.replace_float_64.(v)) - ExaModels.convert_array(v, backend::oneAPI.oneAPIBackend) = oneAPI.oneArray(v) +if pkgversion(oneAPI) < v"2.6" ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneArray} = copyto!(array, sort!(Array(array); lt = lt)) diff --git a/src/ExaModels.jl b/src/ExaModels.jl index dd65e3af1..5c9e6ea0c 100644 --- a/src/ExaModels.jl +++ b/src/ExaModels.jl @@ -27,6 +27,7 @@ For more information, please visit https://github.com/exanauts/ExaModels.jl module ExaModels import Adapt: adapt +import ConstructionBase import NLPModels: NLPModels, obj, diff --git a/src/templates.jl b/src/templates.jl index a3dd5ce2d..b37f81495 100644 --- a/src/templates.jl +++ b/src/templates.jl @@ -2,6 +2,24 @@ convert_array(v, ::Nothing) = v convert_array(v, backend) = adapt(backend, v) +# Recursively replace Float64 with Float32 in scalars, containers, and the +# fields of arbitrary structs. Used by backend extensions (Metal, oneAPI) that +# either reject Float64 outright or perform poorly on it. Type-stable via +# multiple dispatch on leaf types and Val(fieldcount(T)) on the generic struct +# path. +replace_float_64(x::Float64) = Float32(x) +replace_float_64(x::Tuple) = map(replace_float_64, x) +replace_float_64(x::NamedTuple) = map(replace_float_64, x) +replace_float_64(x::AbstractArray{Float64}) = Float32.(x) +replace_float_64(x::AbstractArray) = replace_float_64.(x) +@inline replace_float_64(x::T) where {T} = _rebuild_float_32(x, Val(fieldcount(T))) +@inline _rebuild_float_32(x, ::Val{0}) = x +@inline function _rebuild_float_32(x::T, ::Val{N}) where {T, N} + ConstructionBase.constructorof(T)( + ntuple(i -> replace_float_64(getfield(x, i)), Val(N))... + ) +end + # to avoid type privacy sort!(array; kwargs...) = Base.sort!(array; kwargs...) diff --git a/test/UtilsTest/UtilsTest.jl b/test/UtilsTest/UtilsTest.jl index 4b3b10c59..f646e2713 100644 --- a/test/UtilsTest/UtilsTest.jl +++ b/test/UtilsTest/UtilsTest.jl @@ -8,6 +8,20 @@ UTIL_MODELS = [ExaModels.TimedNLPModel, ExaModels.CompressedNLPModel] FIELDS = [:solution, :multipliers, :multipliers_L, :multipliers_U, :objective] +# Test struct for replace_float_64 coverage +struct _RFTestBus{T} + id::Int + vmax::T + vmin::T + name::Symbol +end + +struct _RFTestNetwork{T} + baseMVA::T + bus::Vector{_RFTestBus{T}} + coefs::NTuple{3, T} +end + function runtests() @testset "Utils tests" begin m, ~ = _exa_luksan_vlcek_model(nothing, 3) @@ -26,6 +40,49 @@ function runtests() end end end + + @testset "replace_float_64" begin + # Scalar leaf + @test ExaModels.replace_float_64(1.5) === 1.5f0 + @test ExaModels.replace_float_64(1.5f0) === 1.5f0 + @test ExaModels.replace_float_64(3) === 3 + @test ExaModels.replace_float_64(:foo) === :foo + + # Tuple / NamedTuple recursion + @test ExaModels.replace_float_64((1.0, 2, "x")) === (1.0f0, 2, "x") + @test ExaModels.replace_float_64((a = 1.0, b = 2)) === (a = 1.0f0, b = 2) + + # Array of Float64 + @test ExaModels.replace_float_64([1.0, 2.0]) == Float32[1.0, 2.0] + @test eltype(ExaModels.replace_float_64([1.0, 2.0])) === Float32 + + # Mixed-eltype array routed elementwise + arr = Any[1.0, 2] + out = ExaModels.replace_float_64(arr) + @test out[1] === 1.0f0 && out[2] === 2 + + # Parametric struct rebuild + b = _RFTestBus{Float64}(1, 1.1, 0.9, :b1) + b32 = ExaModels.replace_float_64(b) + @test b32 isa _RFTestBus{Float32} + @test b32.id === 1 + @test b32.vmax === 1.1f0 + @test b32.name === :b1 + + # Nested struct + array field + net = _RFTestNetwork(100.0, + [_RFTestBus{Float64}(1, 1.1, 0.9, :b1), + _RFTestBus{Float64}(2, 1.2, 0.8, :b2)], + (0.1, 0.2, 0.3)) + net32 = ExaModels.replace_float_64(net) + @test net32 isa _RFTestNetwork{Float32} + @test net32.baseMVA === 100.0f0 + @test eltype(net32.bus) === _RFTestBus{Float32} + @test net32.coefs === (0.1f0, 0.2f0, 0.3f0) + + # Identity on type with no Float64 anywhere + @test ExaModels.replace_float_64((1, 2, 3)) === (1, 2, 3) + end end end From 9de159c5d2cf21436a425a178daa1ef03d3e8852 Mon Sep 17 00:00:00 2001 From: Sungho Shin Date: Tue, 19 May 2026 13:16:47 -0400 Subject: [PATCH 2/2] oneAPI ext: set default_T = Float32 The convert_array dispatch normalizes host data to Float32 before upload, so the default float type for an oneAPI backend should match. Without this, OracleTest's _atol(backend) helper still picks the Float64 tolerance (1e-10), which is tighter than Float32 machine epsilon (~1e-7), causing 15 false test failures on the oneapi CI runner. This mirrors the existing Metal extension which already declares default_T(::MetalBackend) = Float32. --- ext/ExaModelsOneAPI.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/ExaModelsOneAPI.jl b/ext/ExaModelsOneAPI.jl index 407811128..8b72ff14f 100644 --- a/ext/ExaModelsOneAPI.jl +++ b/ext/ExaModelsOneAPI.jl @@ -6,6 +6,8 @@ ExaModels.convert_array(v::oneAPI.oneArray, ::oneAPI.oneAPIBackend) = v ExaModels.convert_array(v, ::oneAPI.oneAPIBackend) = oneAPI.oneArray(ExaModels.replace_float_64.(v)) +ExaModels.default_T(::oneAPI.oneAPIBackend) = Float32 + if pkgversion(oneAPI) < v"2.6" ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneArray} =