From 3b435a46afb25ba62697dcbcee9687d912e0db45 Mon Sep 17 00:00:00 2001 From: yanzin00 Date: Mon, 29 Sep 2025 14:59:44 +0000 Subject: [PATCH 1/6] DArray: Restrict copyto! scope to destination --- src/array/copy.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/array/copy.jl b/src/array/copy.jl index 7d92566e6..cb81b4c2d 100644 --- a/src/array/copy.jl +++ b/src/array/copy.jl @@ -131,7 +131,14 @@ function darray_copyto!(B::DArray{TB,NB}, A::DArray{TA,NA}, Binds=parentindices( Arange_local = Arange_global_clamped .- CartesianIndex(Arange_start) .+ CartesianIndex{Nmax}(1) # Perform local view copy - Dagger.@spawn copyto_view!(Out(Bpart), Brange_local, In(Apart), Arange_local) + space = (Bpart isa DTask ? fetch(Bpart; move_value=false, unwrap=false) : Bpart).space + procs = processors(space) + scope = UnionScope([ExactScope(proc) for proc in procs]) + check_uniform(space) + for proc in procs + check_uniform(proc) + end + Dagger.@spawn scope = scope copyto_view!(Out(Bpart), Brange_local, In(Apart), Arange_local) end end end From 0baaafd6f6512ab8ae8263854ba0a27f94f345d2 Mon Sep 17 00:00:00 2001 From: Felipe Tome Date: Fri, 10 Jan 2025 11:33:22 -0300 Subject: [PATCH 2/6] DaggerMPI: Initial implementation Co-authored-by: Julian P Samaroo Co-authored-by: yanzin00 --- LocalPreferences.toml | 10 + Project.toml | 2 + benchmarks/check_comm_asymmetry.jl | 111 ++++ benchmarks/check_comm_asymmetry.py | 97 +++ benchmarks/run_distribute_fetch.jl | 42 ++ benchmarks/run_matmul.jl | 105 ++++ benchmarks/run_qr.jl | 46 ++ src/Dagger.jl | 13 + src/affinity.jl | 32 + src/array/alloc.jl | 21 +- src/array/copy.jl | 9 +- src/array/darray.jl | 50 +- src/array/mul.jl | 56 +- src/array/trsm.jl | 2 +- src/chunks.jl | 102 +--- src/datadeps/aliasing.jl | 515 ++++++++-------- src/datadeps/chunkview.jl | 50 +- src/datadeps/queue.jl | 366 +++++++++-- src/datadeps/remainders.jl | 283 ++++----- src/datadeps/scheduling.jl | 6 +- src/dtask.jl | 6 +- src/lib/domain-blocks.jl | 2 + src/memory-spaces.jl | 363 +++-------- src/mpi.jl | 948 +++++++++++++++++++++++++++++ src/mpi_mempool.jl | 36 ++ src/mutable.jl | 41 ++ src/options.jl | 7 + src/processor.jl | 27 +- src/sch/Sch.jl | 219 ++++--- src/sch/util.jl | 12 +- src/scopes.jl | 2 +- src/shard.jl | 89 +++ src/submission.jl | 28 +- src/thunk.jl | 14 +- src/tochunk.jl | 119 ++++ src/types/acceleration.jl | 1 + src/types/chunk.jl | 27 + src/types/memory-space.jl | 1 + src/types/processor.jl | 2 + src/types/scope.jl | 1 + src/utils/chunks.jl | 6 +- src/utils/dagdebug.jl | 5 +- src/weakchunk.jl | 23 + test/mpi.jl | 70 +++ 44 files changed, 2952 insertions(+), 1015 deletions(-) create mode 100644 LocalPreferences.toml create mode 100644 benchmarks/check_comm_asymmetry.jl create mode 100644 benchmarks/check_comm_asymmetry.py create mode 100644 benchmarks/run_distribute_fetch.jl create mode 100644 benchmarks/run_matmul.jl create mode 100644 benchmarks/run_qr.jl create mode 100644 src/affinity.jl create mode 100644 src/mpi.jl create mode 100644 src/mpi_mempool.jl create mode 100644 src/mutable.jl create mode 100644 src/shard.jl create mode 100644 src/tochunk.jl create mode 100644 src/types/acceleration.jl create mode 100644 src/types/chunk.jl create mode 100644 src/types/memory-space.jl create mode 100644 src/types/processor.jl create mode 100644 src/types/scope.jl create mode 100644 src/weakchunk.jl create mode 100644 test/mpi.jl diff --git a/LocalPreferences.toml b/LocalPreferences.toml new file mode 100644 index 000000000..3a11c113f --- /dev/null +++ b/LocalPreferences.toml @@ -0,0 +1,10 @@ +# When using system MPI, run once in the environment where you run MPI jobs (with MPI module loaded): +# julia --project=Dagger.jl -e 'using MPIPreferences; MPIPreferences.use_system_binary()' +# That populates abi, libmpi, mpiexec and avoids "Unknown MPI ABI nothing". +[MPIPreferences] +_format = "1.0" +abi = "MPICH" +binary = "system" +libmpi = "libmpi" +mpiexec = "mpiexec" +preloads = [] diff --git a/Project.toml b/Project.toml index ce49bf6d7..69163e027 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" @@ -77,6 +78,7 @@ Graphs = "1" JSON3 = "1" KernelAbstractions = "0.9" MacroTools = "0.5" +MPI = "0.20.22" MemPool = "0.4.12" Metal = "1.1" NextLA = "0.2.2" diff --git a/benchmarks/check_comm_asymmetry.jl b/benchmarks/check_comm_asymmetry.jl new file mode 100644 index 000000000..684240ec5 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.jl @@ -0,0 +1,111 @@ +#!/usr/bin/env julia +# Parse MPI+Dagger logs and report communication decision asymmetry per tag. +# Asymmetry: for the same tag, one rank decides to send (local+bcast, sender+communicated, etc.) +# and another rank decides to infer (inferred, uninvolved) and never recv → deadlock. +# +# Usage: julia check_comm_asymmetry.jl < logfile +# Or: mpiexec -n 10 julia ... run_matmul.jl 2>&1 | tee matmul.log; julia check_comm_asymmetry.jl < matmul.log + +const SEND_DECISIONS = Set([ + "local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast", + "aliasing", # when followed by local+bcast we already capture local+bcast +]) +const RECV_DECISIONS = Set([ + "communicated", "receiver", "sender+communicated", # received data +]) +const INFER_DECISIONS = Set([ + "inferred", "uninvolved", # did not recv (uses inferred type) +]) + +function parse_line(line) + # Match [rank X][tag Y] then any [...] and capture the last bracket pair before space or end + rank = nothing + tag = nothing + decision = nothing + category = nothing # aliasing, execute!, remotecall_endpoint + for m in eachmatch(r"\[rank\s+(\d+)\]", line) + rank = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[tag\s+(\d+)\]", line) + tag = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + category = m.captures[1] + end + # Decision is usually in last [...] that looks like [word] or [word+word] + for m in eachmatch(r"\]\[([^\]]+)\]", line) + candidate = m.captures[1] + # Normalize: "communicated" "inferred" "local+bcast" "sender+inferred" "receiver" etc. + if occursin("inferred", candidate) && !occursin("communicated", candidate) + decision = "inferred" + break + elseif occursin("communicated", candidate) + decision = "communicated" + break + elseif occursin("local+bcast", candidate) + decision = "local+bcast" + break + elseif occursin("sender+", candidate) + decision = startswith(candidate, "sender+inferred") ? "sender+inferred" : "sender+communicated" + break + elseif candidate == "receiver" + decision = "receiver" + break + elseif candidate == "receiver+bcast" + decision = "receiver+bcast" + break + elseif candidate == "inplace_move" + decision = "inplace_move" + break + end + end + return rank, tag, category, decision +end + +function main() + # tag => Dict(rank => decision) + by_tag = Dict{Int, Dict{Int, String}}() + for line in eachline(stdin) + rank, tag, category, decision = parse_line(line) + isnothing(rank) && continue + isnothing(tag) && continue + isnothing(decision) && continue + if !haskey(by_tag, tag) + by_tag[tag] = Dict{Int, String}() + end + by_tag[tag][rank] = decision + end + + # For each tag, check: is there at least one sender and one inferrer (non-receiver)? + send_keys = Set(["local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"]) + infer_keys = Set(["inferred", "sender+inferred"]) # sender+inferred means sender didn't need to recv + recv_keys = Set(["communicated", "receiver", "sender+communicated"]) + + asymmetries = [] + for (tag, ranks) in sort(collect(by_tag), by = first) + senders = [r for (r, d) in ranks if d in send_keys] + inferrers = [r for (r, d) in ranks if d in infer_keys || d == "uninvolved"] + receivers = [r for (r, d) in ranks if d in recv_keys] + # Asymmetry: someone sends (bcast) so will send to ALL other ranks; someone chose infer and won't recv. + if !isempty(senders) && !isempty(inferrers) + push!(asymmetries, (tag, senders, inferrers, receivers, ranks)) + end + end + + if isempty(asymmetries) + println("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + end + + println("=== Communication decision asymmetry (can cause deadlock) ===\n") + for (tag, senders, inferrers, receivers, ranks) in asymmetries + println("Tag $tag:") + println(" Senders (will bcast to all others): $senders") + println(" Inferrers (did not recv): $inferrers") + println(" Receivers: $receivers") + println(" All decisions: $ranks") + println() + end +end + +main() diff --git a/benchmarks/check_comm_asymmetry.py b/benchmarks/check_comm_asymmetry.py new file mode 100644 index 000000000..31a117442 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Parse MPI+Dagger logs and report communication decision asymmetry per tag. +Asymmetry: for the same tag, one rank decides to send (local+bcast, etc.) +and another decides to infer (inferred) and never recv → deadlock. + +Usage: + # Capture full log (all ranks' Core.println from mpi.jl go to stdout): + mpiexec -n 10 julia --project=/path/to/Dagger.jl benchmarks/run_matmul.jl 2>&1 | tee matmul.log + # Then look for asymmetry (same tag: one rank sends, another infers → deadlock): + python3 check_comm_asymmetry.py < matmul.log +""" + +import re +import sys +from collections import defaultdict + +SEND_DECISIONS = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} +RECV_DECISIONS = {"communicated", "receiver", "sender+communicated"} +INFER_DECISIONS = {"inferred", "uninvolved", "sender+inferred"} + + +def parse_line(line: str): + rank = tag = category = decision = None + m = re.search(r"\[rank\s+(\d+)\]", line) + if m: + rank = int(m.group(1)) + m = re.search(r"\[tag\s+(\d+)\]", line) + if m: + tag = int(m.group(1)) + m = re.search(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + if m: + category = m.group(1) + # Capture decision from [...] blocks + for m in re.finditer(r"\]\[([^\]]+)\]", line): + candidate = m.group(1) + if "inferred" in candidate and "communicated" not in candidate: + decision = "inferred" + break + if "communicated" in candidate: + decision = "communicated" + break + if "local+bcast" in candidate: + decision = "local+bcast" + break + if candidate.startswith("sender+"): + decision = "sender+inferred" if "inferred" in candidate else "sender+communicated" + break + if candidate == "receiver": + decision = "receiver" + break + if candidate == "receiver+bcast": + decision = "receiver+bcast" + break + if candidate == "inplace_move": + decision = "inplace_move" + break + return rank, tag, category, decision + + +def main(): + by_tag = defaultdict(dict) # tag -> {rank: decision} + for line in sys.stdin: + rank, tag, category, decision = parse_line(line) + if rank is None or tag is None or decision is None: + continue + by_tag[tag][rank] = decision + + send_keys = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} + infer_keys = {"inferred", "sender+inferred", "uninvolved"} + recv_keys = {"communicated", "receiver", "sender+communicated"} + + asymmetries = [] + for tag in sorted(by_tag.keys()): + ranks = by_tag[tag] + senders = [r for r, d in ranks.items() if d in send_keys] + inferrers = [r for r, d in ranks.items() if d in infer_keys] + receivers = [r for r, d in ranks.items() if d in recv_keys] + if senders and inferrers: + asymmetries.append((tag, senders, inferrers, receivers, ranks)) + + if not asymmetries: + print("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + + print("=== Communication decision asymmetry (can cause deadlock) ===\n") + for tag, senders, inferrers, receivers, ranks in asymmetries: + print(f"Tag {tag}:") + print(f" Senders (will bcast to all others): {senders}") + print(f" Inferrers (did not recv): {inferrers}") + print(f" Receivers: {receivers}") + print(f" All decisions: {dict(ranks)}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_distribute_fetch.jl b/benchmarks/run_distribute_fetch.jl new file mode 100644 index 000000000..822e1ad2c --- /dev/null +++ b/benchmarks/run_distribute_fetch.jl @@ -0,0 +1,42 @@ +#!/usr/bin/env julia +# Create a matrix with a fixed reproducible pattern, distribute it with an +# MPI procgrid, then on each rank fetch and println the chunk(s) it owns. +# Usage (from repo root, use full path to Dagger.jl): +# mpiexec -n 4 julia --project=/path/to/Dagger.jl benchmarks/run_distribute_fetch.jl + +using MPI +using Dagger + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) + +# Fixed reproducible pattern: 6×6 matrix, M[i,j] = 10*i + j (same on all ranks) +const N = 6 +const BLOCK = 2 +A = [10 * i + j for i in 1:N, j in 1:N] + +# Procgrid: use Dagger's compatible processors so the procgrid passes validation +availprocs = collect(Dagger.compatible_processors()) +nblocks = (cld(N, BLOCK), cld(N, BLOCK)) +procgrid = reshape( + [availprocs[mod(i - 1, length(availprocs)) + 1] for i in 1:prod(nblocks)], + nblocks, +) + +# Distribute so chunk (i,j) is computed on procgrid[i,j] +D = distribute(A, Blocks(BLOCK, BLOCK), procgrid) +D_fetched = fetch(D) + +# On each rank: fetch and print only the chunk(s) this rank owns +for (idx, ch) in enumerate(D_fetched.chunks) + if ch isa Dagger.Chunk && ch.handle isa Dagger.MPIRef && ch.handle.rank == rank + data = fetch(ch) + println("rank $rank chunk $idx: ", data) + end +end diff --git a/benchmarks/run_matmul.jl b/benchmarks/run_matmul.jl new file mode 100644 index 000000000..0eb4ec0d7 --- /dev/null +++ b/benchmarks/run_matmul.jl @@ -0,0 +1,105 @@ +#!/usr/bin/env julia +# N×N matmul benchmark (Float32); block size scales with number of ranks. +# Usage (use the full path to Dagger.jl, not "..."): +# mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl +# Set CHECK_CORRECTNESS=true to collect and compare against GPU baseline: +# CHECK_CORRECTNESS=true mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl + +using MPI +using Dagger +using LinearAlgebra + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const N = 2_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +const CHECK_CORRECTNESS = parse(Bool, get(ENV, "CHECK_CORRECTNESS", "false")) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (matmul)") +end + +# Allocate and fill matrices in blocks (Float32) +A = rand(Blocks(BLOCK, BLOCK), Float32, N, N) +B = rand(Blocks(BLOCK, BLOCK), Float32, N, N) + +# Matrix multiply C = A * B +t_matmul = @elapsed begin + C = A * B +end + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + +# Optional: collect via datadeps (root=0). All ranks participate in the datadeps region. +if CHECK_CORRECTNESS + t_collect = @elapsed begin + A_full = Dagger.collect_datadeps(A; root=0) + B_full = Dagger.collect_datadeps(B; root=0) + C_dagger = Dagger.collect_datadeps(C; root=0) + end + if rank == 0 + println("Collecting result and computing baseline for correctness check (GPU)...") + using CUDA + CUDA.functional() || error("CUDA not functional; cannot compute GPU baseline. Check CUDA driver and device.") + t_upload = @elapsed begin + A_g = CUDA.cu(A_full) + B_g = CUDA.cu(B_full) + end + println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s") + + t_baseline = @elapsed begin + C_ref_g = A_g * B_g + end + println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s") + + # Require all elements within 100× machine epsilon relative error (componentwise) + C_dagger_cpu = C_dagger + C_ref_cpu = Array(C_ref_g) + eps_f = eps(Float32) + rtol = 50.0f0 * eps_f + diff = C_dagger_cpu .- C_ref_cpu + # rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero + denom = max.(abs.(C_ref_cpu), eps_f) + rel_err = abs.(diff) ./ denom + max_rel_err = Float32(maximum(rel_err)) + ok = max_rel_err <= rtol + if ok + println("Correctness: OK (max rel_err = ", max_rel_err, " <= 100×eps = ", rtol, ")") + else + println("Correctness: FAIL (max rel_err = ", max_rel_err, " > 100×eps = ", rtol, ")") + end + + # Per-block: which blocks have any element with rel_err > 100×eps + n_bi = ceil(Int, N / BLOCK) + n_bj = ceil(Int, N / BLOCK) + bad_blocks = Tuple{Int,Int,Float32}[] + for bi in 1:n_bi, bj in 1:n_bj + ri = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N) + rj = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N) + block_rel = Float32(maximum(@view(rel_err[ri, rj]))) + if block_rel > rtol + push!(bad_blocks, (bi, bj, block_rel)) + end + end + if isempty(bad_blocks) + println("Per-block: all ", n_bi * n_bj, " blocks within 100×eps rel_err.") + else + println("Per-block: ", length(bad_blocks), " block(s) exceed 100×eps rel_err (block size ", BLOCK, "×", BLOCK, "):") + sort!(bad_blocks; by = x -> -x[3]) + for (bi, bj, block_rel) in bad_blocks + println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N), + ", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " max rel_err = ", block_rel) + end + end + end +end diff --git a/benchmarks/run_qr.jl b/benchmarks/run_qr.jl new file mode 100644 index 000000000..c5915db2a --- /dev/null +++ b/benchmarks/run_qr.jl @@ -0,0 +1,46 @@ +#!/usr/bin/env julia +# 10k×10k QR + matmul benchmark; block size scales with number of ranks. +# Usage: mpiexec -n 100 julia --project=/path/to/Dagger.jl benchmarks/bench_100rank_qr_matmul.jl +# Or: bash benchmarks/run_100rank_qr_matmul.sh . + +using MPI +using Dagger +using LinearAlgebra + +Dagger.accelerate!(:mpi) + +const N = 10_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (QR + matmul)") +end + +# Allocate and fill 10k×10k matrix in 1k×1k blocks +A = rand(Blocks(BLOCK, BLOCK), Float64, N, N) +MPI.Barrier(comm) + +# QR factorization (computing Q runs the full factorization) +t_qr = @elapsed begin + qr!(A) +end +MPI.Barrier(comm) + +if rank == 0 + println("QR time: ", round(t_qr; digits=4), " s") +end + +# Matrix multiply A * A +t_matmul = @elapsed begin + C = A * A +end +MPI.Barrier(comm) + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + diff --git a/src/Dagger.jl b/src/Dagger.jl index 2e757ebc5..1b3791274 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -53,6 +53,13 @@ import Adapt include("lib/util.jl") include("utils/dagdebug.jl") +# Type definitions (for MPI/acceleration) +include("types/processor.jl") +include("types/scope.jl") +include("types/memory-space.jl") +include("types/chunk.jl") +include("types/acceleration.jl") + # Distributed data include("utils/locked-object.jl") include("utils/tasks.jl") @@ -77,6 +84,7 @@ include("queue.jl") include("thunk.jl") include("utils/fetch.jl") include("utils/chunks.jl") +include("weakchunk.jl") include("utils/logging.jl") include("submission.jl") abstract type MemorySpace end @@ -90,6 +98,7 @@ include("utils/clock.jl") include("utils/system_uuid.jl") include("utils/caching.jl") include("sch/Sch.jl"); using .Sch +include("tochunk.jl") # Data dependency task queue include("datadeps/aliasing.jl") @@ -157,6 +166,10 @@ function set_distributed_package!(value) @info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!" end +# MPI (mpi.jl loads MPI; mpi_mempool uses it) +include("mpi.jl") +include("mpi_mempool.jl") + # Precompilation import PrecompileTools: @compile_workload include("precompile.jl") diff --git a/src/affinity.jl b/src/affinity.jl new file mode 100644 index 000000000..aab663a51 --- /dev/null +++ b/src/affinity.jl @@ -0,0 +1,32 @@ +export domain, UnitDomain, project, alignfirst, ArrayDomain + +import Base: isempty, getindex, intersect, ==, size, length, ndims + +""" + domain(x::T) + +Returns metadata about `x`. This metadata will be in the `domain` +field of a Chunk object when an object of type `T` is created as +the result of evaluating a Thunk. +""" +function domain end + +""" + UnitDomain + +Default domain -- has no information about the value +""" +struct UnitDomain end + +""" +If no `domain` method is defined on an object, then +we use the `UnitDomain` on it. A `UnitDomain` is indivisible. +""" +domain(x::Any) = UnitDomain() + +### ChunkIO +affinity(r::DRef) = OSProc(r.owner)=>r.size +# this previously returned a vector with all machines that had the file cached +# but now only returns the owner and size, for consistency with affinity(::DRef), +# see #295 +affinity(r::FileRef) = OSProc(1)=>r.size diff --git a/src/array/alloc.jl b/src/array/alloc.jl index aa1050210..33de3506d 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -93,14 +93,31 @@ function stage(ctx, A::AllocateArray) scope = ExactScope(A.procgrid[CartesianIndex(mod1.(Tuple(I), size(A.procgrid))...)]) end + N = ndims(A.domainchunks) + ret_type = Array{A.eltype, N} if A.want_index - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, i, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, i, size(x)) else - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, size(x)) end tasks[i] = task end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = ndims(A.domainchunks) + expected_type = Array{A.eltype, N} + Dagger.mpi_propagate_chunk_types!(tasks, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in tasks] + if allequal(chunk_types) + @info "[rank $rank] Array creation (alloc): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (alloc): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(A.eltype, A.domain, A.domainchunks, tasks, A.partitioning) end diff --git a/src/array/copy.jl b/src/array/copy.jl index cb81b4c2d..7d92566e6 100644 --- a/src/array/copy.jl +++ b/src/array/copy.jl @@ -131,14 +131,7 @@ function darray_copyto!(B::DArray{TB,NB}, A::DArray{TA,NA}, Binds=parentindices( Arange_local = Arange_global_clamped .- CartesianIndex(Arange_start) .+ CartesianIndex{Nmax}(1) # Perform local view copy - space = (Bpart isa DTask ? fetch(Bpart; move_value=false, unwrap=false) : Bpart).space - procs = processors(space) - scope = UnionScope([ExactScope(proc) for proc in procs]) - check_uniform(space) - for proc in procs - check_uniform(proc) - end - Dagger.@spawn scope = scope copyto_view!(Out(Bpart), Brange_local, In(Apart), Arange_local) + Dagger.@spawn copyto_view!(Out(Bpart), Brange_local, In(Apart), Arange_local) end end end diff --git a/src/array/darray.jl b/src/array/darray.jl index 32336f95d..7e723acd0 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -1,7 +1,7 @@ -import Base: ==, fetch +import Base: ==, fetch, length, isempty, size export DArray, DVector, DMatrix, DVecOrMat, Blocks, AutoBlocks -export distribute +export distribute, collect_datadeps ###### Array Domains ###### @@ -83,7 +83,8 @@ isempty(a::ArrayDomain) = length(a) == 0 The domain of an array is an ArrayDomain. """ domain(x::AbstractArray) = ArrayDomain([1:l for l in size(x)]) - +# Scalar / non-array values (e.g. for Chunk of immediate data) +domain(x::Any) = ArrayDomain(()) abstract type ArrayOp{T, N} <: AbstractArray{T, N} end Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian() @@ -174,6 +175,7 @@ domain(d::DArray) = d.domain chunks(d::DArray) = d.chunks domainchunks(d::DArray) = d.subdomains size(x::DArray) = size(domain(x)) +Base.ndims(d::DArray{T,N}) where {T,N} = N stage(ctx, c::DArray) = c function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} @@ -200,6 +202,31 @@ function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} collect(treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))) end end + +""" + collect_datadeps(d::DArray; root=nothing) + +Collect a DArray to a single array by fetching every chunk on the current rank +and assembling into a full array. No datadeps scheduling or root-only assembly: +each rank that calls this gets the full matrix (useful when correctness matters +more than communication cost). +""" +function collect_datadeps(d::DArray{T,N}; root=nothing) where {T,N} + if isempty(d.chunks) + return Array{eltype(d)}(undef, size(d)...) + end + if N == 0 + return fetch(d.chunks[1]) + end + + chks = d.chunks + doms = domainchunks(d) + out = Array{T,N}(undef, size(d)) + for I in CartesianIndices(chks) + copyto!(view(out, indexes(doms[I])...), fetch(chks[I])) + end + return out +end Array{T,N}(A::DArray{S,N}) where {T,N,S} = convert(Array{T,N}, collect(A)) Base.wait(A::DArray) = foreach(wait, A.chunks) @@ -483,6 +510,21 @@ function stage(ctx::Context, d::Distribute) Dagger.@spawn compute_scope=scope identity(d.data[c]) end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = Base.ndims(d.data) + expected_type = Array{eltype(d.data), N} + Dagger.mpi_propagate_chunk_types!(cs, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in cs] + if allequal(chunk_types) + @info "[rank $rank] Array creation (distribute): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (distribute): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(eltype(d.data), domain(d.data), d.domainchunks, @@ -620,7 +662,7 @@ end mapchunk(f, chunk) = tochunk(f(poolget(chunk.handle))) function mapchunks(f, d::DArray{T,N,F}) where {T,N,F} chunks = map(d.chunks) do chunk - owner = get_parent(chunk.processor).pid + owner = root_worker_id(chunk.processor) remotecall_fetch(mapchunk, owner, f, chunk) end DArray{T,N,F}(d.domain, d.subdomains, chunks, d.concat) diff --git a/src/array/mul.jl b/src/array/mul.jl index 02b207641..5890473da 100644 --- a/src/array/mul.jl +++ b/src/array/mul.jl @@ -41,7 +41,7 @@ function LinearAlgebra.generic_matmatmul!( return gemm_dagger!(C, transA, transB, A, B, alpha, beta) end end -function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) +function _repartition_matmatmul(C, A, B, transA::Char, transB::Char)::Tuple{Blocks{2}, Blocks{2}, Blocks{2}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -93,6 +93,24 @@ function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) return Blocks(partC...), Blocks(partA...), Blocks(partB...) end +# Typed BLAS wrappers so that every @spawn kernel has an inferable return type +@inline function _gemm!(transA::Char, transB::Char, alpha::T, A, B, mzone, C)::Matrix{T} where {T} + BLAS.gemm!(transA, transB, alpha, A, B, mzone, C) + return C +end +@inline function _syrk!(uplo::AbstractChar, trans::AbstractChar, alpha::T, A, mzone, C)::Matrix{T} where {T} + BLAS.syrk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _herk!(uplo::AbstractChar, trans::AbstractChar, alpha::Real, A, mzone, C)::Matrix{<:Complex} + BLAS.herk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _gemv!(transA::Char, alpha::T, A, x, mzone, y)::Vector{T} where {T} + BLAS.gemv!(transA, alpha, A, x, mzone, y) + return y +end + """ Performs one of the matrix-matrix operations @@ -136,7 +154,7 @@ function gemm_dagger!( # A: NoTrans / B: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -150,7 +168,7 @@ function gemm_dagger!( # A: NoTrans / B: [Conj]Trans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -166,7 +184,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: NoTrans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -180,7 +198,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -243,7 +261,7 @@ function syrk_dagger!( for k in range(1, Ant) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, trans, real(alpha), @@ -252,7 +270,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -267,7 +285,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -283,7 +301,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -300,7 +318,7 @@ function syrk_dagger!( for k in range(1, Amt) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, transs, real(alpha), @@ -309,7 +327,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -324,7 +342,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -340,7 +358,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -393,16 +411,17 @@ end return A end -@inline function copytile!(A, B) +@inline function copytile!(A::AbstractMatrix{T}, B::AbstractMatrix{T})::Nothing where {T} m, n = size(A) C = B' for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end -@inline function copydiagtile!(A, uplo) +@inline function copydiagtile!(A::AbstractMatrix{T}, uplo::AbstractChar)::Nothing where {T} m, n = size(A) Acpy = copy(A) @@ -417,6 +436,7 @@ end for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end function LinearAlgebra.generic_matvecmul!( C::DVector{T}, @@ -440,7 +460,7 @@ function LinearAlgebra.generic_matvecmul!( return gemv_dagger!(C, transA, A, B, _alpha, _beta) end end -function _repartition_matvecmul(C, A, B, transA::Char) +function _repartition_matvecmul(C, A, B, transA::Char)::Tuple{Blocks{1}, Blocks{2}, Blocks{1}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -495,7 +515,7 @@ function gemv_dagger!( # A: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[m, k]), @@ -508,7 +528,7 @@ function gemv_dagger!( # A: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[k, m]), diff --git a/src/array/trsm.jl b/src/array/trsm.jl index 65e87c5d5..c0c025468 100644 --- a/src/array/trsm.jl +++ b/src/array/trsm.jl @@ -189,4 +189,4 @@ function trsm!(side::Char, uplo::Char, trans::Char, diag::Char, alpha::T, A::DMa end end -end \ No newline at end of file +end diff --git a/src/chunks.jl b/src/chunks.jl index 03bdfb65d..0defc1ff6 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -1,56 +1,4 @@ -export domain, UnitDomain, project, alignfirst, ArrayDomain - -import Base: isempty, getindex, intersect, ==, size, length, ndims - -""" - domain(x::T) - -Returns metadata about `x`. This metadata will be in the `domain` -field of a Chunk object when an object of type `T` is created as -the result of evaluating a Thunk. -""" -function domain end - -""" - UnitDomain - -Default domain -- has no information about the value -""" -struct UnitDomain end - -""" -If no `domain` method is defined on an object, then -we use the `UnitDomain` on it. A `UnitDomain` is indivisible. -""" -domain(x::Any) = UnitDomain() - -###### Chunk ###### - -""" - Chunk - -A reference to a piece of data located on a remote worker. `Chunk`s are -typically created with `Dagger.tochunk(data)`, and the data can then be -accessed from any worker with `collect(::Chunk)`. `Chunk`s are -serialization-safe, and use distributed refcounting (provided by -`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, -as long as a reference exists on some worker. - -Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a -sense) the processor that "owns" or contains the data. Calling -`collect(::Chunk)` will perform data movement and conversions defined by that -processor to safely serialize the data to the calling worker. - -## Constructors -See [`tochunk`](@ref). -""" -mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} - chunktype::Type{T} - domain - handle::H - processor::P - scope::S -end +###### Chunk Methods ###### domain(c::Chunk) = c.domain chunktype(c::Chunk) = c.chunktype @@ -72,20 +20,27 @@ function collect(ctx::Context, chunk::Chunk; options=nothing) elseif chunk.handle isa FileRef return poolget(chunk.handle) else - return move(chunk.processor, OSProc(), chunk.handle) + return move(chunk.processor, default_processor(), chunk.handle) end end collect(ctx::Context, ref::DRef; options=nothing) = move(OSProc(ref.owner), OSProc(), ref) collect(ctx::Context, ref::FileRef; options=nothing) = poolget(ref) # FIXME: Do move call -function Base.fetch(chunk::Chunk; raw=false) - if raw - poolget(chunk.handle) - else - collect(chunk) +@warn "Fix semantics of collect" maxlog=1 +function Base.fetch(chunk::Chunk{T}; unwrap::Bool=false, uniform::Bool=false, kwargs...) where T + value = fetch_handle(chunk.handle; uniform)::T + if unwrap && unwrappable(value) + return fetch(value; unwrap, uniform, kwargs...) end + return value end +fetch_handle(ref::DRef; uniform::Bool=false) = poolget(ref) +fetch_handle(ref::FileRef; uniform::Bool=false) = poolget(ref) +unwrappable(x::Chunk) = true +unwrappable(x::DRef) = true +unwrappable(x::FileRef) = true +unwrappable(x) = false # Unwrap Chunk, DRef, and FileRef by default move(from_proc::Processor, to_proc::Processor, x::Chunk) = @@ -100,32 +55,3 @@ move(to_proc::Processor, d::DRef) = move(OSProc(d.owner), to_proc, d) move(to_proc::Processor, x) = move(OSProc(), to_proc, x) - -### ChunkIO -affinity(r::DRef) = OSProc(r.owner)=>r.size -# this previously returned a vector with all machines that had the file cached -# but now only returns the owner and size, for consistency with affinity(::DRef), -# see #295 -affinity(r::FileRef) = OSProc(1)=>r.size - -struct WeakChunk - wid::Int - id::Int - x::WeakRef - function WeakChunk(c::Chunk) - return new(c.handle.owner, c.handle.id, WeakRef(c)) - end -end -unwrap_weak(c::WeakChunk) = c.x.value -function unwrap_weak_checked(c::WeakChunk) - cw = unwrap_weak(c) - @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" - return cw -end -wrap_weak(c::Chunk) = WeakChunk(c) -isweak(c::WeakChunk) = true -isweak(c::Chunk) = false -is_task_or_chunk(c::WeakChunk) = true -Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = - error("Cannot serialize a WeakChunk") -chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index c3e0ed20b..e9ff24a79 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -8,7 +8,7 @@ export In, Out, InOut, Deps, spawn_datadeps ============================================================================== This file implements the data dependencies system for Dagger tasks, which allows -tasks to access their arguments in a controlled manner. The system maintains +tasks to write to their arguments in a controlled manner. The system maintains data coherency across distributed workers by tracking aliasing relationships and orchestrating data movement operations. @@ -25,59 +25,26 @@ KEY CONCEPTS: 1. ALIASING ANALYSIS: - Every mutable argument is analyzed for its memory access pattern - Memory spans are computed to determine which bytes in memory are accessed - - Arguments that access overlapping memory spans are considered "aliasing" + - Objects that access overlapping memory spans are considered "aliasing" - Examples: An array A and view(A, 2:3, 2:3) alias each other 2. DATA LOCALITY TRACKING: - The system tracks where the "source of truth" for each piece of data lives - As tasks execute and modify data, the source of truth may move between workers - - Each argument can have its own independent source of truth location + - Each aliasing region can have its own independent source of truth location 3. ALIASED OBJECT MANAGEMENT: - When copying arguments between workers, the system tracks "aliased objects" - This ensures that if both an array and its view need to be copied to a worker, only one copy of the underlying array is made, with the view pointing to it - - The aliased_object!() and move_rewrap() functions manage this sharing - -ALIASING INFO: --------------- - -The system uses different types of aliasing info to represent different types of -aliasing relationships: - -- ContiguousAliasing: Single contiguous memory region (e.g., full array) -- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) -- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) -- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) - -Any two aliasing objects can be compared using the will_alias function to -determine if they overlap. Additionally, any aliasing object can be converted to -a vector of memory spans, which represents the contiguous regions of memory that -the aliasing object covers. - -DATA MOVEMENT FUNCTIONS: ------------------------- - -move!(dep_mod, to_space, from_space, to, from): -- The core in-place data movement function -- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) -- Supports partial copies via RemainderAliasing dependency modifiers - -move_rewrap(...): -- Handles copying of wrapped objects (SubArrays, ChunkViews) -- Ensures aliased objects are reused on destination worker - -read/write_remainder!(...): -- Read/write a span of memory from an object to/from a buffer -- Used by move! to copy the remainder of an aliased object + - The aliased_object!() functions manage this sharing THE DISTRIBUTED ALIASING PROBLEM: --------------------------------- In a multithreaded environment, aliasing "just works" because all tasks operate -on the user-provided memory. However, in a distributed environment, arguments -must be copied between workers, which breaks aliasing relationships if care is -not taken. +on the same memory. However, in a distributed environment, arguments must be +copied between workers, which breaks aliasing relationships. Consider this scenario: ```julia @@ -96,9 +63,11 @@ MULTITHREADED BEHAVIOR (WORKS): - Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Tasks may be scheduled on different workers - Each argument must be copied to the destination worker -- Without special handling, we would copy A and vA independently to another worker -- This creates two separate arrays, breaking the aliasing relationship between A and vA +- Without special handling, we would copy A to worker1 and vA to worker2 +- This creates two separate arrays, breaking the aliasing relationship +- Updates to the view on worker2 don't affect the array on worker1 THE SOLUTION - PARTIAL DATA MOVEMENT: ------------------------------------- @@ -112,13 +81,12 @@ The datadeps system solves this by: 2. PARTIAL DATA TRANSFER: - Instead of copying entire objects, only transfer the "dirty" regions - - This prevents overwrites of data that has already been updated by another task - - This also minimizes network traffic and overall copy time - - Uses the move!(dep_mod, ...) function with RemainderAliasing dependency modifiers + - This minimizes network traffic and maximizes parallelism + - Uses the move!(dep_mod, ...) function with dependency modifiers 3. REMAINDER TRACKING: - - When a task needs the full object, copy partial regions as needed - When a partial region is updated, track what parts still need updating + - Before a task needs the full object, copy the remaining "clean" regions - This preserves all updates while avoiding overwrites EXAMPLE EXECUTION FLOW: @@ -140,24 +108,69 @@ Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) - T2 needs vA, but vA aliases with A (which was modified by T1) - Copy vA-region of A from worker1 to worker2 - This is a PARTIAL copy - only the 2:3, 2:3 region - - Create vA on worker2 pointing to the appropriate region of A + - Create vA on worker2 pointing to the appropriate region - T2 executes, modifying vA region on worker2 - Update: vA's data_locality = worker2 4. FINAL SYNCHRONIZATION: - - Need to copy-back A and vA to worker0 - - A needs to be assembled from: worker1 (non-vA regions of A) + worker2 (vA region of A) - - REMAINDER COPY: Copy non-vA regions from worker1 to worker0 - - REMAINDER COPY: Copy vA region from worker2 to worker0 + - Some future task needs the complete A + - A needs to be assembled from: worker1 (non-vA regions) + worker2 (vA region) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker2 + - OR INVERSE: Copy vA-region from worker2 to worker1, then copy full A + +MEMORY SPAN COMPUTATION: +------------------------ -REMAINDER COMPUTATION: ----------------------- +The system uses memory spans to determine aliasing and compute remainders: + +- ContiguousAliasing: Single contiguous memory region (e.g., full array) +- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) +- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) +- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) Remainder computation involves: 1. Computing memory spans for all overlapping aliasing objects 2. Finding the set difference: full_object_spans - updated_spans -3. Creating a RemainderAliasing object representing the difference between spans -4. Performing one or more move! calls with this RemainderAliasing object to copy only needed data +3. Creating a "remainder aliasing" object representing the not-yet-updated regions +4. Performing move! with this remainder object to copy only needed data + +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) +- Supports partial copies via dependency modifiers + +move_rewrap(): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +enqueue_copy_to!(): +- Schedules data movement tasks before user tasks +- Ensures data is up-to-date on the worker where a task will run + +CURRENT LIMITATIONS AND TODOS: +------------------------------- + +1. REMAINDER COMPUTATION: + - The system currently handles simple overlaps but needs sophisticated + remainder calculation for complex aliasing patterns + - Need functions to compute span set differences + +2. ORDERING DEPENDENCIES: + - Need to ensure remainder copies happen in correct order + - Must not overwrite more recent updates with stale data + +3. COMPLEX ALIASING PATTERNS: + - Multiple overlapping views of the same array + - Nested aliasing structures (views of views) + - Mixed aliasing types (diagonal + triangular regions) + +4. PERFORMANCE OPTIMIZATION: + - Minimize number of copy operations + - Batch compatible transfers + - Optimize for common access patterns =# "Specifies a read-only dependency." @@ -179,11 +192,6 @@ struct Deps{T,DT<:Tuple} end Deps(x, deps...) = Deps(x, deps) -chunktype(::In{T}) where T = T -chunktype(::Out{T}) where T = T -chunktype(::InOut{T}) where T = T -chunktype(::Deps{T,DT}) where {T,DT} = T - function unwrap_inout(arg) readdep = false writedep = false @@ -218,6 +226,7 @@ _identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) _identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) _identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) +@warn "Dispatch bcast behavior on acceleration" maxlog=1 struct ArgumentWrapper arg dep_mod @@ -226,6 +235,7 @@ struct ArgumentWrapper function ArgumentWrapper(arg, dep_mod) h = hash(dep_mod) h = _identity_hash(arg, h) + check_uniform(h, arg) return new(arg, dep_mod, h) end end @@ -340,7 +350,7 @@ function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::A cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore set_key_stored!(cache_raw, space, ainfo, value) end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(current_acceleration(), x, identity)) x_space = memory_space(x) if !is_key_present(cache, x_space, ainfo) # Preserve the object's memory-space/processor pairing when inserting @@ -356,13 +366,14 @@ function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, iden @assert y isa Chunk "Didn't get a Chunk from functor" @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" if memory_space(x) != cache.space - @assert ainfo != aliasing(y, identity) "Aliasing mismatch! $ainfo == $(aliasing(y, identity))" + @assert ainfo != aliasing(current_acceleration(), y, identity) "Aliasing mismatch! $ainfo == $(aliasing(current_acceleration(), y, identity))" end set_stored!(cache, y, ainfo) return y end end +@warn "Switch ArgumentWrapper to contain just the argument, and add DependencyWrapper" maxlog=1 struct DataDepsState # The mapping of original raw argument to its Chunk raw_arg_to_chunk::IdDict{Any,Chunk} @@ -378,13 +389,10 @@ struct DataDepsState # The mapping of remote argument to original argument remote_arg_to_original::IdDict{Any,Any} - # The mapping of original argument wrapper to remote argument wrapper - remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} - # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B. This is a mapping for remote argument copies - ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} + ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to # Updated when a new write happens on an overlapping ainfo @@ -402,7 +410,7 @@ struct DataDepsState # The mapping of, for a given memory space, the backing Chunks that an ainfo references # Used by slot generation to replace the backing Chunks during move - ainfo_backing_chunk::Chunk{AliasedObjectCacheStore} + ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} # Cache of argument's supports_inplace_move query result supports_inplace_cache::IdDict{Any,Bool} @@ -411,10 +419,6 @@ struct DataDepsState # N.B. This is a mapping for remote argument copies ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - # The oracle for aliasing lookups - # Used to populate ainfos_overlaps efficiently - ainfos_lookup::AliasingLookup - # The overlapping ainfos for each ainfo # Incrementally updated as new ainfos are created # Used for fast will_alias lookups @@ -426,32 +430,60 @@ struct DataDepsState ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - function DataDepsState() + function DataDepsState(aliasing::Bool) + if !aliasing + @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 + end + arg_to_chunk = IdDict{Any,Chunk}() arg_origin = IdDict{Any,MemorySpace}() remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() - remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() - ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() - arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() - ainfo_backing_chunk = tochunk(AliasedObjectCacheStore()) + ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - ainfos_lookup = AliasingLookup() ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_history, arg_owner, arg_overlaps, ainfo_backing_chunk, - supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + supports_inplace_cache, ainfo_cache, ainfos_overlaps, ainfos_owner, ainfos_readers) end end +# N.B. arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(current_acceleration(), remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + state.ainfo_arg[ainfo] = remote_arg_w + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end + function supports_inplace_move(state::DataDepsState, arg) return get!(state.supports_inplace_cache, arg) do return supports_inplace_move(arg) @@ -465,72 +497,70 @@ function is_writedep(arg, deps, task::DTask) end # Aliasing state setup -function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) - # Track the task's arguments and access patterns - return map_or_ntuple(task_args) do idx - _arg = task_args[idx] - - # Unwrap the argument - _arg_with_deps = value(_arg) - pos = _arg.pos +# Internal: iterate over task args and call callback(arg, pos, may_alias, inplace_move, deps) for each tracked arg. +function _populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask, callback) + for (idx, _arg) in enumerate(spec.fargs) + arg_pos = _arg.pos # ArgPosition for this argument (Argument/TypedArgument have .pos) + arg = value(_arg) # Unwrap In/InOut/Out wrappers and record dependencies - arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg_pre_unwrap isa DTask ? fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap - - # Skip non-aliasing arguments or arguments that don't support in-place move - may_alias = type_may_alias(typeof(arg)) - inplace_move = may_alias && supports_inplace_move(state, arg) - if !may_alias || !inplace_move - arg_w = ArgumentWrapper(arg, identity) - if is_typed(spec) - return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) - else - return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) - end + arg, deps = unwrap_inout(arg) + + # Unwrap the Chunk underlying any DTask arguments only when already ready. + # Fetching an unready DTask here would deadlock: distribute_tasks! runs before + # the scheduler, so dependent tasks have not run yet. Skip aliasing for unready + # DTasks so we pass them through; the worker will fetch at execution time (may block on MPI). + if arg isa DTask + isready(arg) || continue + arg = fetch(arg; move_value=false, unwrap=false) end + # Skip non-aliasing arguments + type_may_alias(typeof(arg)) || continue + + # Skip arguments not supporting in-place move + supports_inplace_move(state, arg) || continue + # Generate a Chunk for the argument if necessary if haskey(state.raw_arg_to_chunk, arg) - arg_chunk = state.raw_arg_to_chunk[arg] + arg = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - arg_chunk = tochunk(arg) - state.raw_arg_to_chunk[arg] = arg_chunk + new_arg = with(MPI_UID=>task.uid) do + tochunk(arg) + end + state.raw_arg_to_chunk[arg] = new_arg + arg = new_arg else state.raw_arg_to_chunk[arg] = arg - arg_chunk = arg end end # Track the origin space of the argument - origin_space = memory_space(arg_chunk) - state.arg_origin[arg_chunk] = origin_space - state.remote_arg_to_original[arg_chunk] = arg_chunk + origin_space = memory_space(arg) + check_uniform(origin_space) + state.arg_origin[arg] = origin_space + state.remote_arg_to_original[arg] = arg + + may_alias = true + inplace_move = true + callback(arg, arg_pos, may_alias, inplace_move, deps) # Populate argument info for all aliasing dependencies - # And return the argument, dependencies, and ArgumentWrappers - if is_typed(spec) - deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) - map_or_ntuple(deps) do dep_idx - dep = deps[dep_idx] - # Populate argument info - populate_argument_info!(state, dep.arg_w, origin_space) - end - return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) - else - deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] - map_or_ntuple(deps) do dep_idx - dep = deps[dep_idx] - # Populate argument info - populate_argument_info!(state, dep.arg_w, origin_space) - end - return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + for (dep_mod, _, _) in deps + # Generate an ArgumentWrapper for the argument + aw = ArgumentWrapper(arg, dep_mod) + + # Populate argument info + populate_argument_info!(state, aw, origin_space) end end end + +function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns (callback only for state updates) + _populate_task_info!(state, spec, task, (arg, pos, may_alias, inplace_move, deps) -> nothing) +end function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) # Initialize ownership and history if !haskey(state.arg_owner, arg_w) @@ -550,56 +580,23 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o # Calculate the ainfo (which will populate ainfo structures and merge history) aliasing!(state, origin_space, arg_w) end -# N.B. arg_w must be the original argument wrapper, not a remote copy -function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) - if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) - remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] - remote_arg = remote_arg_w.arg - else - # Grab the remote copy of the argument, and calculate the ainfo - remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) - remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) - get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w - end - - # Check if we already have the result cached - if haskey(state.ainfo_cache, remote_arg_w) - return state.ainfo_cache[remote_arg_w] - end - - # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) - - # Cache the result - state.ainfo_cache[remote_arg_w] = ainfo - - # Update the mapping of ainfo to argument and dep_mod - if !haskey(state.ainfo_arg, ainfo) - state.ainfo_arg[ainfo] = Set{ArgumentWrapper}([remote_arg_w]) - end - push!(state.ainfo_arg[ainfo], remote_arg_w) - - # Populate info for the new ainfo - populate_ainfo!(state, arg_w, ainfo, target_space) - - return ainfo -end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) + # Initialize owner and readers if !haskey(state.ainfos_owner, target_ainfo) - # Add ourselves to the lookup oracle - ainfo_idx = push!(state.ainfos_lookup, target_ainfo) - - # Find overlapping ainfos overlaps = Set{AliasingWrapper}() push!(overlaps, target_ainfo) - for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) + other_ainfos = (Dagger.current_acceleration() isa Dagger.MPIAcceleration + ? sort(collect(keys(state.ainfos_owner)), by=hash) + : keys(state.ainfos_owner)) + for other_ainfo in other_ainfos target_ainfo == other_ainfo && continue - # Mark us and them as overlapping - push!(overlaps, other_ainfo) - push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + if will_alias(target_ainfo, other_ainfo) + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) - # Add overlapping history to our own - for other_remote_arg_w in state.ainfo_arg[other_ainfo] + # Add overlapping history to our own + other_remote_arg_w = state.ainfo_arg[other_ainfo] other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) push!(state.arg_overlaps[original_arg_w], other_arg_w) @@ -608,16 +605,13 @@ function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, end end state.ainfos_overlaps[target_ainfo] = overlaps - - # Initialize owner and readers state.ainfos_owner[target_ainfo] = nothing state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] end end function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper) history = state.arg_history[arg_w] - @opcounter :merge_history - @opcounter :merge_history_complexity length(history) + largest_value_update!(length(history)) origin_space = state.arg_origin[other_arg_w.arg] for other_entry in state.arg_history[other_arg_w] write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) @@ -646,13 +640,10 @@ function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_ end end function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper) - # FIXME: Do this continuously if possible if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000 origin_space = state.arg_origin[arg_w.arg] - @opcounter :truncate_history _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false) if last_idx > 0 - @opcounter :truncate_history_removed last_idx deleteat!(state.arg_history[arg_w], 1:last_idx) end end @@ -670,8 +661,11 @@ use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` region returns. """ supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; move_value=false, unwrap=false)) +@warn "Fix this to work with MPI (can't call poolget on the wrong rank)" maxlog=1 function supports_inplace_move(c::Chunk) + # FIXME + return true # FIXME: Use MemPool.access_ref pid = root_worker_id(c.processor) if pid == myid() @@ -743,12 +737,19 @@ function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::M push!(state.ainfos_readers[ainfo], task=>write_num) end +# FIXME: These should go in MPIExt.jl +const MPI_TID = ScopedValue{Int64}(0) +const MPI_UID = ScopedValue{Int64}(0) + # Make a copy of each piece of data on each worker # memory_space => {arg => copy_of_arg} isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true function generate_slot!(state::DataDepsState, dest_space, data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. @@ -756,16 +757,30 @@ function generate_slot!(state::DataDepsState, dest_space, data) to_proc = first(processors(dest_space)) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - aliased_object_cache = AliasedObjectCache(dest_space, state.ainfo_backing_chunk) - ctx = Sch.eager_context() - id = rand(Int) - @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) + # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping + task = DATADEPS_CURRENT_TASK[] + data_chunk = with(MPI_UID=>task.uid) do + tochunk(data, from_proc) + end + else + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + end @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data + ALIASED_OBJECT_CACHE[] = nothing + + check_uniform(memory_space(dest_space_args[data])) + check_uniform(processor(dest_space_args[data])) + check_uniform(dest_space_args[data].handle) + return dest_space_args[data] end function get_or_generate_slot!(state, dest_space, data) @@ -778,82 +793,86 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) - to_w = root_worker_id(to_proc) - if to_w == myid() - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) - end - return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + return aliased_object!(data) do data + return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, data) end end -function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) - return aliased_object!(cache, x) do x - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) +function remotecall_endpoint(f, ::Dagger.DistributedAcceleration, from_proc, to_proc, orig_space, dest_space, data) + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) - # Unwrap so that we hit the right dispatch - wid = root_worker_id(data) - if wid != myid() - return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) - end - data_raw = unwrap(data) - return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) +const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) + +# Explicit cache for move_rewrap (used by haloarray, tests) +struct AliasedObjectCacheStore end +struct AliasedObjectCache + dest_space::MemorySpace + backing::Chunk + cache::Dict{AbstractAliasing,Chunk} + AliasedObjectCache(dest_space::MemorySpace, backing::Chunk) = new(dest_space, backing, Dict{AbstractAliasing,Chunk}()) end function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - # For generic data - return aliased_object!(cache, data) do data - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) + old = ALIASED_OBJECT_CACHE[] + ALIASED_OBJECT_CACHE[] = cache.cache + try + return move_rewrap(from_proc, to_proc, from_space, to_space, data) + finally + ALIASED_OBJECT_CACHE[] = old end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) - to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) - inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds - p_new = move(from_proc, to_proc, p_chunk) - v_new = view(p_new, inds...) - return tochunk(v_new, to_proc) - end -end -# FIXME: Do this programmatically via recursive dispatch -for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) - @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) - to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk - p_new = move(from_proc, to_proc, p_chunk) - v_new = $(wrapper)(p_new) - return tochunk(v_new, to_proc) - end - end + +@warn "Document these public methods" maxlog=1 +# TODO: Use state to cache aliasing() results +function declare_aliased_object!(x; ainfo=aliasing(current_acceleration(), x, identity)) + cache = ALIASED_OBJECT_CACHE[] + cache[ainfo] = x end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) - return aliased_object!(cache, v) do v - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, v) +function aliased_object!(x; ainfo=aliasing(current_acceleration(), x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" + cache[ainfo] = x + y = x end + return y end -#= FIXME: Make this work so we can automatically move-rewrap recursive objects -function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T - if isstructtype(T) - # Check all object fields (recursive) - for field in fieldnames(T) - value = getfield(x, field) - new_value = aliased_object!(cache, value) do value - return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) - end - setfield!(x, field, new_value) - end - return x +function aliased_object!(f, x; ainfo=aliasing(current_acceleration(), x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] else - @warn "Cannot move-rewrap object of type $T" - return x + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y + end + return y +end +function aliased_object_unwrap!(x::Chunk) + y = unwrap(x) + ainfo = aliasing(current_acceleration(), y, identity) + return unwrap(aliased_object!(x; ainfo)) +end + +struct DataDepsSchedulerState + task_to_spec::Dict{DTask,DTaskSpec} + assignments::Dict{DTask,MemorySpace} + dependencies::Dict{DTask,Set{DTask}} + task_completions::Dict{DTask,UInt64} + space_completions::Dict{MemorySpace,UInt64} + capacities::Dict{MemorySpace,Int} + + function DataDepsSchedulerState() + return new(Dict{DTask,DTaskSpec}(), + Dict{DTask,MemorySpace}(), + Dict{DTask,Set{DTask}}(), + Dict{DTask,UInt64}(), + Dict{MemorySpace,UInt64}(), + Dict{MemorySpace,Int}()) end end -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x -=# diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 1c2aa600f..6e2a21dfd 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -3,10 +3,6 @@ struct ChunkView{N} slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} end -function _identity_hash(arg::ChunkView, h::UInt=UInt(0)) - return hash(arg.slices, _identity_hash(arg.chunk, h)) -end - function Base.view(c::Chunk, slices...) if c.domain isa ArrayDomain nd, sz = ndims(c.domain), size(c.domain) @@ -29,39 +25,31 @@ function Base.view(c::Chunk, slices...) return ChunkView(c, slices) end -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) +Base.view(c::DTask, slices...) = view(fetch(c; move_value=false, unwrap=false), slices...) -function aliasing(x::ChunkView{N}) where N - return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end -end +aliasing(x::ChunkView) = + throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) - to_w = root_worker_id(to_proc) - # N.B. We use move_rewrap (not rewrap_aliased_object!) so that if the inner - # chunk is a SubArray, it goes through the SubArray-aware path which shares - # the parent array via the aliased object cache. Using rewrap_aliased_object! - # would simply serialize the entire SubArray, creating a new parent copy on - # the destination, breaking aliasing with other views of the same parent. - p_chunk = move_rewrap(cache, from_proc, to_proc, from_space, to_space, slice.chunk) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds - p_new = move(from_proc, to_proc, p_chunk) - v_new = view(p_new, inds...) - return tochunk(v_new, to_proc) +# This definition is here because it's so similar to ChunkView +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) + p_chunk = aliased_object!(parent(v)) do p_chunk + return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) + end + inds = parentindices(v) + return remotecall_endpoint(current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) do p_new + return view(p_new, inds...) end end -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - to_w = root_worker_id(to_proc) - return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices - chunk_new = move(from_proc, to_proc, chunk) - v_new = view(chunk_new, slices...) - return tochunk(v_new, to_proc) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) + p_chunk = aliased_object!(slice.chunk) do p_chunk + return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) + end + inds = slice.slices + return remotecall_endpoint(current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) do p_new + return view(p_new, inds...) end end -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 96112cf15..70e7543eb 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -1,4 +1,21 @@ -struct DataDepsTaskQueue{Scheduler<:DataDepsScheduler} <: AbstractTaskQueue + +const TAG_WAITING = Base.Lockable(Ref{UInt32}(1)) +function to_tag() + intask = Dagger.in_task() + if intask + opts = Dagger.get_tls().task_spec.options + tag = opts.tag + return tag + end + lock(TAG_WAITING) do counter_ref + @assert Sch.SCHED_MOVE[] == false "We should not create a tag on the scheduler unwrap move" + tag = counter_ref[] + counter_ref[] = tag + 1 > MPI.tag_ub() ? 1 : tag + 1 + return tag + end +end + +struct DataDepsTaskQueue <: AbstractTaskQueue # The queue above us upper_queue::AbstractTaskQueue # The set of tasks that have already been seen @@ -7,14 +24,24 @@ struct DataDepsTaskQueue{Scheduler<:DataDepsScheduler} <: AbstractTaskQueue g::Union{SimpleDiGraph{Int},Nothing} # The mapping from task to graph ID task_to_id::Union{Dict{DTask,Int},Nothing} + # How to traverse the dependency graph when launching tasks + traversal::Symbol # Which scheduler to use to assign tasks to processors - scheduler::Scheduler + scheduler::Symbol + + # Whether aliasing across arguments is possible + # The fields following only apply when aliasing==true + aliasing::Bool - function DataDepsTaskQueue(upper_queue; scheduler::DataDepsScheduler) + function DataDepsTaskQueue(upper_queue; + traversal::Symbol=:inorder, + scheduler::Symbol=:naive, + aliasing::Bool=true) seen_tasks = DTaskPair[] g = SimpleDiGraph() task_to_id = Dict{DTask,Int}() - return new{typeof(scheduler)}(upper_queue, seen_tasks, g, task_to_id, scheduler) + return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, + aliasing) end end @@ -25,8 +52,10 @@ function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) append!(queue.seen_tasks, pairs) end +const DATADEPS_CURRENT_TASK = TaskLocalValue{Union{DTask,Nothing}}(Returns(nothing)) + """ - spawn_datadeps(f::Base.Callable) + spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) Constructs a "datadeps" (data dependencies) region and calls `f` within it. Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or @@ -53,41 +82,46 @@ appropriately. At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks to complete, rethrowing the first error, if any. The result of `f` will be returned from `spawn_datadeps`. + +The keyword argument `traversal` controls the order that tasks are launched by +the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling +or Depth-First Scheduling, respectively. All traversal orders respect the +dependencies and ordering of the launched tasks, but may provide better or +worse performance for a given set of datadeps tasks. This argument is +experimental and subject to change. """ function spawn_datadeps(f::Base.Callable; static::Bool=true, traversal::Symbol=:inorder, - scheduler::Union{DataDepsScheduler,Nothing}=nothing, + scheduler::Union{Symbol,Nothing}=nothing, aliasing::Bool=true, launch_wait::Union{Bool,Nothing}=nothing) if !static throw(ArgumentError("Dynamic scheduling is no longer available")) end - if traversal != :inorder - throw(ArgumentError("Traversal order is no longer configurable, and always :inorder")) - end - if !aliasing - throw(ArgumentError("Aliasing analysis is no longer optional")) - end wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler()) + scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool if launch_wait result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); scheduler) + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) with_options(f; task_queue=queue) distribute_tasks!(queue) end else - queue = DataDepsTaskQueue(get_options(:task_queue); scheduler) + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) result = with_options(f; task_queue=queue) distribute_tasks!(queue) end + DATADEPS_CURRENT_TASK[] = nothing return result end end -const DATADEPS_SCHEDULER = ScopedValue{Union{DataDepsScheduler,Nothing}}(nothing) +const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) +@warn "Don't blindly set occupancy=0, only do for MPI" maxlog=1 function distribute_tasks!(queue::DataDepsTaskQueue) #= TODO: Improvements to be made: # - Support for copying non-AbstractArray arguments @@ -98,37 +132,96 @@ function distribute_tasks!(queue::DataDepsTaskQueue) =# # Get the set of all processors to be scheduled on - all_procs = Processor[] scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) + accel = current_acceleration() + accel_procs = filter(procs(Dagger.Sch.eager_context())) do proc + Dagger.accel_matches_proc(accel, proc) end + all_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) + # FIXME: This is an unreliable way to ensure processor uniformity + sort!(all_procs, by=short_name) filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end - all_scope = UnionScope(map(ExactScope, all_procs)) exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) + #=if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 + end=# + for proc in all_procs + check_uniform(proc) end # Round-robin assign tasks to processors upper_queue = get_options(:task_queue) + traversal = queue.traversal + if traversal == :inorder + # As-is + task_order = Colon() + elseif traversal == :bfs + # BFS + task_order = Int[1] + to_walk = Int[1] + seen = Set{Int}([1]) + while !isempty(to_walk) + # N.B. next_root has already been seen + next_root = popfirst!(to_walk) + for v in outneighbors(queue.g, next_root) + if !(v in seen) + push!(task_order, v) + push!(seen, v) + push!(to_walk, v) + end + end + end + elseif traversal == :dfs + # DFS (modified with backtracking) + task_order = Int[] + to_walk = Int[1] + seen = Set{Int}() + while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) + next_root = popfirst!(to_walk) + if !(next_root in seen) + iv = inneighbors(queue.g, next_root) + if all(v->v in seen, iv) + push!(task_order, next_root) + push!(seen, next_root) + ov = outneighbors(queue.g, next_root) + prepend!(to_walk, ov) + else + push!(to_walk, next_root) + end + end + end + else + throw(ArgumentError("Invalid traversal mode: $traversal")) + end + + state = DataDepsState(queue.aliasing) + sstate = DataDepsSchedulerState() + for proc in all_procs + space = only(memory_spaces(proc)) + get!(()->0, sstate.capacities, space) + sstate.capacities[space] += 1 + end + # Start launching tasks and necessary copies - state = DataDepsState() write_num = 1 + proc_idx = 1 + #pressures = Dict{Processor,Int}() proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for pair in queue.seen_tasks + for pair in queue.seen_tasks[task_order] spec = pair.spec task = pair.task - write_num = distribute_task!(queue, state, all_procs, all_scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num) + write_num, proc_idx = distribute_task!(queue, state, all_procs, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) end # Copy args from remote to local # N.B. We sort the keys to ensure a deterministic order for uniformity + check_uniform(length(state.arg_owner)) for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + check_uniform(arg_w) arg = arg_w.arg origin_space = state.arg_origin[arg] remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) @@ -141,10 +234,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;)) - @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) end end write_num += 1 @@ -189,29 +278,174 @@ struct TypedDataDepsTaskArgument{T,N} deps::NTuple{N,DataDepsTaskDependency} end map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) -@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N)) -function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, all_scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int) where typed +map_or_ntuple(f, xs::Tuple) = ntuple(f, length(xs)) + +# 4-arg version: side effects + returns Vector/Tuple of DataDepsTaskArgument for distribute_task! +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) + result = DataDepsTaskArgument[] + _populate_task_info!(state, spec, task, (arg, pos, may_alias, inplace_move, deps) -> begin + dep_infos = DataDepsTaskDependency[DataDepsTaskDependency(arg, d) for d in deps] + push!(result, DataDepsTaskArgument(arg, pos, may_alias, inplace_move, dep_infos)) + end) + return spec.fargs isa Tuple ? (result...,) : result +end + +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed @specialize spec fargs + DATADEPS_CURRENT_TASK[] = task + if typed fargs::Tuple else fargs::Vector{Argument} end - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) scheduler = queue.scheduler - our_proc = datadeps_schedule_task(scheduler, state, all_procs, all_scope, task_scope, spec, task) + if scheduler == :naive + raw_args = map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) + end + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) + end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered (skip when MPI for deterministic tie-breaking) + procs = if current_acceleration() isa Dagger.MPIAcceleration + collect(all_procs) + else + P = randperm(length(all_procs)) + getindex.(Ref(all_procs), P) + end + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure + end + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) + end + end + + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) + + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end + + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, sstate.task_completions[task]) + end + spaces_completed[space] = completed + end + + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue + end + our_proc = if current_acceleration() isa Dagger.MPIAcceleration + first(sort(collect(our_space_procs), by=short_name)) + else + rand(our_space_procs) + end + break + end + + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] + else + error("Invalid scheduler: $sched") + end @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) # Find the scope for this task (and its copies) task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - if task_scope == all_scope + if task_scope == scope # Optimize for the common case, cache the proc=>scope mapping our_scope = get!(proc_to_scope_lfu, our_proc) do our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), all_scope) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) end else # Use the provided scope and constrain it to the available processors @@ -221,12 +455,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr if our_scope isa InvalidScope throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) end + check_uniform(our_proc) + check_uniform(our_space) f = spec.fargs[1] - tid = task.uid # FIXME: May not be correct to move this under uniformity #f.value = move(default_processor(), our_proc, value(f)) - @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis # N.B. Used later for checking dependencies @@ -253,13 +488,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Is the data written previously or now? if !arg_ws.may_alias - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" return arg end # Is the data writeable? if !arg_ws.inplace_move - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" return arg end @@ -276,7 +511,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end end return arg_remote @@ -295,9 +530,6 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end # Check that any mutable and written arguments are already in the correct space - # N.B. We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move arg_space = memory_space(arg) @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" @@ -306,8 +538,12 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Calculate this task's syncdeps if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() + spec.options.syncdeps = Set{Any}() end + if spec.options.tag === nothing + spec.options.tag = to_tag() + end + syncdeps = spec.options.syncdeps map_or_ntuple(task_arg_ws) do idx arg_ws = task_arg_ws[idx] @@ -320,33 +556,46 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" get_write_deps!(state, our_space, ainfo, write_num, syncdeps) else - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end return end - @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - - # Launch user's task - new_fargs = map_or_ntuple(task_arg_ws) do idx - if is_typed(spec) - return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) - else - return Argument(task_arg_ws[idx].pos, remote_args[idx]) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task: preserve full argument list (spec.fargs); use remote values only for tracked args + new_fargs = if spec.fargs isa Tuple + ntuple(length(spec.fargs)) do i + arg = spec.fargs[i] + pos = arg.pos + j = findfirst(w -> w.pos == pos, task_arg_ws) + if j !== nothing + val = remote_args[j] + is_typed(spec) ? TypedArgument(pos, val) : Argument(pos, val) + else + copy(arg) + end end + else + [let arg = spec.fargs[i], pos = arg.pos + j = findfirst(w -> w.pos == pos, task_arg_ws) + if j !== nothing + val = remote_args[j] + is_typed(spec) ? TypedArgument(pos, val) : Argument(pos, val) + else + copy(arg) + end + end for i in 1:length(spec.fargs)] end new_spec = DTaskSpec(new_fargs, spec.options) new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope new_spec.options.occupancy = Dict(Any=>0) - ctx = Sch.eager_context() - @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) - @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, deps=task_arg_ws, args=remote_args)) # Update read/write tracking for arguments map_or_ntuple(task_arg_ws) do idx @@ -359,7 +608,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" add_writer!(state, arg_w, our_space, ainfo, task, write_num) else add_reader!(state, arg_w, our_space, ainfo, task, write_num) @@ -369,6 +618,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) - return write_num + return write_num, proc_idx end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 2c2c49920..af4b8a13c 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -9,11 +9,10 @@ This is used to perform partial data copies that only update the "remainder" reg struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing space::S spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} - ainfos::Vector{AliasingWrapper} syncdeps::Set{ThunkSyncdep} end -RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, ainfos::Vector{AliasingWrapper}, syncdeps::Set{ThunkSyncdep}) where S = - RemainderAliasing{S}(space, spans, ainfos, syncdeps) +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) memory_spans(ra::RemainderAliasing) = ra.spans @@ -43,6 +42,42 @@ memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders).. Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders +#= FIXME: Integrate with main documentation +Problem statement: + +Remainder copy calculation needs to ensure that, for a given argument and +dependency modifier, and for a given target memory space, any data not yet +updated (whether through this arg or through another that aliases) is added to +the remainder, while any data that has been updated is not in the remainder. +Remainder copies may be multi-part, as data may be spread across multiple other +memory spaces. + +Ainfo is not alone sufficient to identify the combination of argument and +dependency modifier, as ainfo is specific to an allocation in a given memory +space. Thus, this combination needs to be tracked together, and separately from +memory space. However, information may span multiple memory spaces (and thus +multiple ainfos), so we should try to make queries of cross-memory space +information fast, as they will need to be performed for every task, for every +combination. + +Game Plan: + +- Use ArgumentWrapper to track this combination throughout the codebase, ideally generated just once +- Maintain the keying of remote_args only on argument, as the dependency modifier doesn’t affect the argument being passed into the task, so it should not factor into generating and tracking remote argument copies +- Add a structure to track the mapping from ArgumentWrapper to memory space to ainfo, as a quick way to lookup all ainfos needing to be considered +- When considering a remainder copy, only look at a single memory space’s ainfos at a time, as the ainfos should overlap exactly the same way on any memory space, and this allows us to use ainfo_overlaps to track overlaps +- Remainder copies will need to separately consider the source memory space, and the destination memory space when acquiring spans to copy to/from +- Memory spans for ainfos generated from the same ArgumentWrapper should be assumed to be paired in the same order, regardless of memory space, to ensure we can perform the translation from source to destination span address + - Alternatively, we might provide an API to take source and destination ainfos, and desired remainder memory spans, which then performs the copy for us +- When a task or copy writes to arguments, we should record this happening for all overlapping ainfos, in a manner that will be efficient to query from another memory space. We can probably walk backwards and attach this to a structure keyed on ArgumentWrapper, as that will be very efficient for later queries (because the history will now be linearized in one vector). +- Remainder copies will need to know, for all overlapping ainfos of the ArgumentWrapper ainfo at the target memory space, how recently that ainfo was updated relative to other ainfos, and relative to how recently the target ainfo was written. + - The last time the target ainfo was written is the furthest back we need to consider, as the target data must have been fully up-to-date when that write completed. + - Consideration of updates should start at most recent first, walking backwards in time, as the most recent updates contain the up-to-date data. + - For each span under consideration, we should subtract from it the current remainder set, to ensure we only copy up-to-date data. + - We must add that span portion to the remainder set no matter what, but if it was updated on the target memory space, we don’t need to schedule a copy for it, since it’s already where it needs to be. + - Even before the last target write is seen, we are allowed to stop searching if we find that our target ainfo is fully covered (because this implies that the target ainfo is fully out-of-date). +=# + struct FullCopy end """ @@ -87,14 +122,13 @@ function compute_remainder_for_arg!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper, write_num::Int; compute_syncdeps::Bool=true) + @label restart + + # Determine all memory spaces of the history spaces_set = Set{MemorySpace}() push!(spaces_set, target_space) owner_space = state.arg_owner[arg_w] push!(spaces_set, owner_space) - - @label restart - - # Determine all memory spaces of the history for entry in state.arg_history[arg_w] push!(spaces_set, entry.space) end @@ -109,15 +143,21 @@ function compute_remainder_for_arg!(state::DataDepsState, push!(target_ainfos, LocalMemorySpan.(spans)) end nspans = length(first(target_ainfos)) - @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces" # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) +<<<<<<< HEAD + for (_, space, _) in state.arg_history[arg_w] + if !in(space, spaces) +======= for entry in state.arg_history[arg_w] if !in(entry.space, spaces) @opcounter :compute_remainder_for_arg_restart +>>>>>>> 85e0b801 (MPI: Optimizations and fix some uniformity issues) @goto restart end end + check_uniform(spaces) + check_uniform(target_ainfos) # We may only need to schedule a full copy from the origin space to the # target space if this is the first time we've written to `arg_w` @@ -130,14 +170,10 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Create our remainder as an interval tree over all target ainfos - VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) - for span in remainder - verify_span(span) - end # Create our tracker - tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Vector{AliasingWrapper},Set{ThunkSyncdep}}}() + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() # Walk backwards through the history of writes to this target # other_ainfo is the overlapping ainfo that was written to @@ -159,9 +195,11 @@ function compute_remainder_for_arg!(state::DataDepsState, other_ainfo = aliasing!(state, owner_space, arg_w) other_space = owner_space end + check_uniform(other_ainfo) + check_uniform(other_space) # Lookup all memory spans for arg_w in these spaces - other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) + other_remote_arg_w = state.ainfo_arg[other_ainfo] other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) other_ainfos = Vector{Vector{LocalMemorySpan}}() for space in spaces @@ -171,15 +209,14 @@ function compute_remainder_for_arg!(state::DataDepsState, end nspans = length(first(other_ainfos)) other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] - foreach(other_many_spans) do span - verify_span(span) - end + + check_uniform(other_many_spans) + check_uniform(spaces) if other_space == target_space # Only subtract, this data is already up-to-date in target_space # N.B. We don't add to syncdeps here, because we'll see this ainfo # in get_write_deps! - @opcounter :compute_remainder_for_arg_subtract subtract_spans!(remainder, other_many_spans) continue end @@ -188,19 +225,16 @@ function compute_remainder_for_arg!(state::DataDepsState, other_space_idx = something(findfirst(==(other_space), spaces)) target_space_idx = something(findfirst(==(target_space), spaces)) tracker_other_space = get!(tracker, other_space) do - (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Vector{AliasingWrapper}(), Set{ThunkSyncdep}()) + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) end - @opcounter :compute_remainder_for_arg_schedule - has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) - if compute_syncdeps && has_overlap + schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" - get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[3]) - push!(tracker_other_space[2], other_ainfo) + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) end end - VERIFY_SPAN_CURRENT_OBJECT[] = nothing - if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker)) + if isempty(tracker) return NoAliasing(), 0 end @@ -208,13 +242,12 @@ function compute_remainder_for_arg!(state::DataDepsState, mra = MultiRemainderAliasing() for space in spaces if haskey(tracker, space) - spans, ainfos, syncdeps = tracker[space] + spans, syncdeps = tracker[space] if !isempty(spans) - push!(mra.remainders, RemainderAliasing(space, spans, ainfos, syncdeps)) + push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) end end end - @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))" return mra, last_idx end @@ -230,13 +263,12 @@ copy from `other_many_spans` to the subtraced portion of `remainder`. function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N diff = Vector{ManyMemorySpan{N}}() subtract_spans!(remainder, other_many_spans, diff) + for span in diff source_span = span.spans[source_space_idx] dest_span = span.spans[dest_space_idx] - @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end - return !isempty(diff) end ### Remainder copy functions @@ -250,7 +282,9 @@ Enqueues a copy operation to update the remainder regions of an object before a function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, f, idx, dest_scope, task, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) end end @@ -263,7 +297,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac # overwritten by more recent partial updates source_space = remainder_aliasing.space - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -276,23 +310,16 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac push!(remainder_syncdeps, syncdep) end empty!(remainder_aliasing.syncdeps) # We can't bring these to move! - source_ainfos = copy(remainder_aliasing.ainfos) - empty!(remainder_aliasing.ainfos) get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the sources and writes to the target - for ainfo in source_ainfos - add_reader!(state, arg_w, source_space, ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) end + + # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end """ @@ -304,7 +331,9 @@ Enqueues a copy operation to update the remainder regions of an object back to t function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, dest_scope, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) end end @@ -330,23 +359,16 @@ function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySp push!(remainder_syncdeps, syncdep) end empty!(remainder_aliasing.syncdeps) # We can't bring these to move! - source_ainfos = copy(remainder_aliasing.ainfos) - empty!(remainder_aliasing.ainfos) get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the sources and writes to the target - for ainfo in source_ainfos - add_reader!(state, arg_w, source_space, ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) end + + # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end @@ -357,7 +379,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: source_space = state.arg_owner[arg_w] target_ainfo = aliasing!(state, dest_space, arg_w) - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -370,17 +392,12 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the source and writes to the target - add_reader!(state, arg_w, source_space, source_ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + end add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, @@ -405,47 +422,38 @@ function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the source and writes to the target - add_reader!(state, arg_w, source_space, source_ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + end + + # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end # Main copy function for RemainderAliasing -function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S - # TODO: Support direct copy between GPU memory spaces - - # Copy the data from the source object - copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from - len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) - copies = Vector{UInt8}(undef, len) - from_raw = unwrap(from) - offset = UInt64(1) - with_context!(from_space) - GC.@preserve copies begin - for (from_span, _) in dep_mod.spans - read_remainder!(copies, offset, from_raw, from_span.ptr, from_span.len) - offset += from_span.len +function move!(dep_mod::RemainderAliasing, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) + # Get the source data for each span + copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = Vector{UInt8}[] + for (from_span, _) in dep_mod.spans + copy = Vector{UInt8}(undef, from_span.len) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copy)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) end + push!(copies, copy) end - @assert offset == len+UInt64(1) return copies end # Copy the data into the destination object - offset = UInt64(1) - to_raw = unwrap(to) - GC.@preserve copies begin - for (_, to_span) in dep_mod.spans - write_remainder!(copies, offset, to_raw, to_span.ptr, to_span.len) - offset += to_span.len + for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) end - @assert offset == length(copies)+UInt64(1) end # Ensure that the data is visible @@ -453,88 +461,3 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: return end - -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Array, from_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(from)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) - from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} - # unsafe_wrap(Array, ...) doesn't like unaligned memory - unsafe_copyto!(Ptr{eltype(from)}(pointer(copies, copies_offset)), pointer(from_vec, from_offset_n), n) -end -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(from)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) - from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} - copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) - copyto!(copies_typed, 1, from_vec, Int(from_offset_n), Int(n)) -end -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from, from_ptr::UInt64, n::UInt64) - real_from = find_object_holding_ptr(from, from_ptr) - return read_remainder!(copies, copies_offset, real_from, from_ptr, n) -end - -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Array, to_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(to)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) - to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} - # unsafe_wrap(Array, ...) doesn't like unaligned memory - unsafe_copyto!(pointer(to_vec, to_offset_n), Ptr{eltype(to)}(pointer(copies, copies_offset)), n) -end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(to)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) - to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} - copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) - copyto!(to_vec, Int(to_offset_n), copies_typed, 1, Int(n)) -end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to, to_ptr::UInt64, n::UInt64) - real_to = find_object_holding_ptr(to, to_ptr) - return write_remainder!(copies, copies_offset, real_to, to_ptr, n) -end - -# Remainder copies for common objects -for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular, SubArray) - @eval function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::$wrapper, from_ptr::UInt64, n::UInt64) - read_remainder!(copies, copies_offset, parent(from), from_ptr, n) - end - @eval function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::$wrapper, to_ptr::UInt64, n::UInt64) - write_remainder!(copies, copies_offset, parent(to), to_ptr, n) - end -end - -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Base.RefValue, from_ptr::UInt64, n::UInt64) - if from_ptr == UInt64(Base.pointer_from_objref(from) + fieldoffset(typeof(from), 1)) - unsafe_copyto!(pointer(copies, copies_offset), Ptr{UInt8}(from_ptr), n) - else - read_remainder!(copies, copies_offset, from[], from_ptr, n) - end -end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Base.RefValue, to_ptr::UInt64, n::UInt64) - if to_ptr == UInt64(Base.pointer_from_objref(to) + fieldoffset(typeof(to), 1)) - unsafe_copyto!(Ptr{UInt8}(to_ptr), pointer(copies, copies_offset), n) - else - write_remainder!(copies, copies_offset, to[], to_ptr, n) - end -end - -function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) - span = LocalMemorySpan(pointer(A.nzval), length(A.nzval)*sizeof(eltype(A.nzval))) - if span_start(span) <= ptr <= span_end(span) - return A.nzval - end - span = LocalMemorySpan(pointer(A.colptr), length(A.colptr)*sizeof(eltype(A.colptr))) - if span_start(span) <= ptr <= span_end(span) - return A.colptr - end - span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) - @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" - return A.rowval -end \ No newline at end of file diff --git a/src/datadeps/scheduling.jl b/src/datadeps/scheduling.jl index 0bf9818f6..b2bcaca7b 100644 --- a/src/datadeps/scheduling.jl +++ b/src/datadeps/scheduling.jl @@ -111,7 +111,11 @@ function datadeps_schedule_task(sched::UltraScheduler, state::DataDepsState, all delete!(spaces_completed, our_space) continue end - our_proc = rand(our_space_procs) + our_proc = if Dagger.current_acceleration() isa Dagger.MPIAcceleration + first(sort(collect(our_space_procs), by=Dagger.short_name)) + else + rand(our_space_procs) + end break end diff --git a/src/dtask.jl b/src/dtask.jl index e94803502..13e66cafe 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -65,11 +65,13 @@ function Base.wait(t::DTask) wait(t.future) return end -function Base.fetch(t::DTask; raw=false) +function Base.fetch(t::DTask; raw=false, move_value=nothing, unwrap=nothing) if !istaskstarted(t) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `DTask`")) end - return fetch(t.future; raw) + # Datadeps/aliasing API: move_value=false => don't move => raw=true + raw_eff = move_value !== nothing ? !move_value : raw + return fetch(t.future; raw=raw_eff) end function waitany(tasks::Vector{DTask}) if isempty(tasks) diff --git a/src/lib/domain-blocks.jl b/src/lib/domain-blocks.jl index 2a0854e3b..95e5c360f 100644 --- a/src/lib/domain-blocks.jl +++ b/src/lib/domain-blocks.jl @@ -6,6 +6,8 @@ struct DomainBlocks{N} <: AbstractArray{ArrayDomain{N, NTuple{N, UnitRange{Int}} end Base.@deprecate_binding BlockedDomains DomainBlocks +ndims(::DomainBlocks{N}) where N = N + size(x::DomainBlocks) = map(length, x.cumlength) function _getindex(x::DomainBlocks{N}, idx::Tuple) where N starts = map((vec, i) -> i == 0 ? 0 : getindex(vec,i), x.cumlength, map(x->x-1, idx)) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index eb4f7ad5b..39bfa7ccc 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,24 +1,63 @@ +struct DistributedAcceleration <: Acceleration end + +const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) + +current_acceleration() = ACCELERATION[] + +default_processor(::DistributedAcceleration) = OSProc(myid()) +default_processor(accel::DistributedAcceleration, x) = default_processor(accel) +default_processor() = default_processor(current_acceleration()) + +accelerate!(accel::Symbol) = accelerate!(Val{accel}()) +accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) + +initialize_acceleration!(a::DistributedAcceleration) = nothing +function accelerate!(accel::Acceleration) + initialize_acceleration!(accel) + ACCELERATION[] = accel +end +accelerate!(::Nothing) = nothing + +accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true +accel_matches_proc(accel::DistributedAcceleration, proc) = true + +function compatible_processors(accel::Union{Acceleration,Nothing}, scope::AbstractScope, procs::Vector{<:Processor}) + comp = compatible_processors(scope, procs) + accel === nothing && return comp + return Set(p for p in comp if accel_matches_proc(accel, p)) +end + struct CPURAMMemorySpace <: MemorySpace owner::Int end -CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner -memory_space(x) = CPURAMMemorySpace(myid()) -function memory_space(x::Chunk) - proc = processor(x) - if proc isa OSProc - # TODO: This should probably be programmable - return CPURAMMemorySpace(proc.pid) - else - return only(memory_spaces(proc)) - end -end -memory_space(x::DTask) = - memory_space(fetch(x; raw=true)) +CPURAMMemorySpace() = CPURAMMemorySpace(myid()) + +default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) +default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) +default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) +default_memory_space(x) = default_memory_space(current_acceleration(), x) +default_memory_space() = default_memory_space(current_acceleration()) + +memory_space(x, proc::Processor=default_processor()) = first(memory_spaces(proc)) +memory_space(x::Processor) = first(memory_spaces(x)) +memory_space(x::Chunk) = x.space +memory_space(x::DTask) = memory_space(fetch(x; move_value=false, unwrap=false)) memory_spaces(::P) where {P<:Processor} = throw(ArgumentError("Must define `memory_spaces` for `$P`")) + +function memory_spaces(proc::OSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end memory_spaces(proc::ThreadProc) = Set([CPURAMMemorySpace(proc.owner)]) processors(::S) where {S<:MemorySpace} = @@ -28,9 +67,12 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement -function unwrap(x::Chunk) - @assert x.handle.owner == myid() - MemPool.poolget(x.handle) +function unwrap(x::Chunk; uniform::Bool=false) + @assert root_worker_id(x.handle) == myid() "Chunk $x is not owned by this process: $(root_worker_id(x.handle)) != $(myid())" + if x.handle isa DRef + return MemPool.poolget(x.handle) + end + return MemPool.poolget(x.handle; uniform) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = throw(ArgumentError("No `move!` implementation defined for $F -> $T")) @@ -69,6 +111,16 @@ function move!(::Type{<:Tridiagonal}, to_space::MemorySpace, from_space::MemoryS return end +# FIXME: Take MemorySpace instead +function move_type(from_proc::Processor, to_proc::Processor, ::Type{T}) where T + if from_proc == to_proc + return T + end + return Base._return_type(move, Tuple{typeof(from_proc), typeof(to_proc), T}) +end +move_type(from_proc::Processor, to_proc::Processor, ::Type{<:Chunk{T}}) where T = + move_type(from_proc, to_proc, T) + ### Aliasing and Memory Spans type_may_alias(::Type{String}) = false @@ -88,20 +140,20 @@ function type_may_alias(::Type{T}) where T return false end -may_alias(::MemorySpace, ::MemorySpace) = false -may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2 +may_alias(::MemorySpace, ::MemorySpace) = true may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner +# RemotePtr and MemorySpan are defined in utils/memory-span.jl (included earlier). + abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) -### Type-generic aliasing info wrapper - -mutable struct AliasingWrapper <: AbstractAliasing +struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 + AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -110,204 +162,8 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) - -### Small dictionary type - -struct SmallDict{K,V} <: AbstractDict{K,V} - keys::Vector{K} - vals::Vector{V} -end -SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) -function Base.getindex(d::SmallDict{K,V}, key) where {K,V} - key_idx = findfirst(==(convert(K, key)), d.keys) - if key_idx === nothing - throw(KeyError(key)) - end - return @inbounds d.vals[key_idx] -end -function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} - key_conv = convert(K, key) - key_idx = findfirst(==(key_conv), d.keys) - if key_idx === nothing - push!(d.keys, key_conv) - push!(d.vals, convert(V, val)) - else - d.vals[key_idx] = convert(V, val) - end - return val -end -Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) -Base.keys(d::SmallDict) = d.keys -Base.length(d::SmallDict) = length(d.keys) -Base.iterate(d::SmallDict) = iterate(d, 1) -Base.iterate(d::SmallDict, state) = state > length(d.keys) ? nothing : (d.keys[state] => d.vals[state], state+1) - -### Type-stable lookup structure for AliasingWrappers - -struct AliasingLookup - # The set of memory spaces that are being tracked - spaces::Vector{MemorySpace} - # The set of AliasingWrappers that are being tracked - # One entry for each AliasingWrapper - ainfos::Vector{AliasingWrapper} - # The memory spaces for each AliasingWrapper - # One entry for each AliasingWrapper - ainfos_spaces::Vector{Vector{Int}} - # The spans for each AliasingWrapper in each memory space - # One entry for each AliasingWrapper - spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} - # The set of AliasingWrappers that only exist in a single memory space - # One entry for each AliasingWrapper - ainfos_only_space::Vector{Int} - # The bounding span for each AliasingWrapper in each memory space - # One entry for each AliasingWrapper - bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} - # The interval tree of the bounding spans for each AliasingWrapper - # One entry for each MemorySpace - bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} - - AliasingLookup() = new(MemorySpace[], - AliasingWrapper[], - Vector{Int}[], - SmallDict{Int,Vector{LocalMemorySpan}}[], - Int[], - SmallDict{Int,LocalMemorySpan}[], - IntervalTree{LocatorMemorySpan{Int},UInt64}[]) -end -function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) - # Update the set of memory spaces and spans, - # and find the bounding spans for this AliasingWrapper - spaces_set = Set{MemorySpace}(lookup.spaces) - self_spaces_set = Set{Int}() - spans = SmallDict{Int,Vector{LocalMemorySpan}}() - for span in memory_spans(ainfo) - space = span.ptr.space - if !in(space, spaces_set) - push!(spaces_set, space) - push!(lookup.spaces, space) - push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) - end - space_idx = findfirst(==(space), lookup.spaces) - push!(self_spaces_set, space_idx) - spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx) - push!(spans_in_space, LocalMemorySpan(span)) - end - push!(lookup.ainfos_spaces, collect(self_spaces_set)) - push!(lookup.spans, spans) - - # Update the set of AliasingWrappers - push!(lookup.ainfos, ainfo) - ainfo_idx = length(lookup.ainfos) - - # Check if the AliasingWrapper only exists in a single memory space - if length(self_spaces_set) == 1 - space_idx = only(self_spaces_set) - push!(lookup.ainfos_only_space, space_idx) - else - push!(lookup.ainfos_only_space, 0) - end - - # Add the bounding spans for this AliasingWrapper - bounding_spans = SmallDict{Int,LocalMemorySpan}() - for space_idx in keys(spans) - space_spans = spans[space_idx] - bound_start = minimum(span_start, space_spans) - bound_end = maximum(span_end, space_spans) - bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) - bounding_spans[space_idx] = bounding_span - insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) - end - push!(lookup.bounding_spans, bounding_spans) - - return ainfo_idx -end -struct AliasingLookupFinder - lookup::AliasingLookup - ainfo::AliasingWrapper - ainfo_idx::Int - spaces_idx::Vector{Int} - to_consider::Vector{Int} -end -Base.eltype(::AliasingLookupFinder) = AliasingWrapper -Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() -# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search -function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) - if ainfo_idx === nothing - ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) - end - spaces_idx = lookup.ainfos_spaces[ainfo_idx] - to_consider_spans = LocatorMemorySpan{Int}[] - for space_idx in spaces_idx - bounding_spans_tree = lookup.bounding_spans_tree[space_idx] - self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) - find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false) - end - to_consider = Int[locator.owner for locator in to_consider_spans] - @assert all(to_consider .> 0) - return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) -end -Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) -function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) - ainfo_spaces = nothing - cursor_space_idx = 1 - - # New ainfos enter here - @label ainfo_restart - - # Check if we've exhausted all ainfos - if cursor_ainfo_idx > length(finder.to_consider) - return nothing - end - ainfo_idx = finder.to_consider[cursor_ainfo_idx] - - # Find the appropriate memory spaces for this ainfo - if ainfo_spaces === nothing - ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] - end - - # New memory spaces (for the same ainfo) enter here - @label space_restart - - # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo - if cursor_space_idx > length(ainfo_spaces) - cursor_ainfo_idx += 1 - ainfo_spaces = nothing - cursor_space_idx = 1 - @goto ainfo_restart - end - - # Find the currently considered memory space for this ainfo - space_idx = ainfo_spaces[cursor_space_idx] - - # Check if this memory space is part of our target ainfo's spaces - if !(space_idx in finder.spaces_idx) - cursor_space_idx += 1 - @goto space_restart - end - - # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space - other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] - self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] - if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) - cursor_space_idx += 1 - @goto space_restart - end - - # We have a overlapping bounds in the same memory space, so check if the ainfos are aliasing - # This is the slow path! - other_ainfo = finder.lookup.ainfos[ainfo_idx] - aliasing = will_alias(finder.ainfo, other_ainfo) - if !aliasing - cursor_ainfo_idx += 1 - ainfo_spaces = nothing - cursor_space_idx = 1 - @goto ainfo_restart - end - - # We overlap, so return the ainfo and the next ainfo index - return other_ainfo, cursor_ainfo_idx+1 -end +will_alias(x::AliasingWrapper, y::AliasingWrapper) = + will_alias(x.inner, y.inner) struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -324,11 +180,8 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - if length(ca.sub_ainfos) == 0 - return MemorySpan{CPURAMMemorySpace}[] - end - all_spans = memory_spans(ca.sub_ainfos[1]) - for sub_a in ca.sub_ainfos[2:end] + all_spans = MemorySpan{CPURAMMemorySpace}[] + for sub_a in ca.sub_ainfos append!(all_spans, memory_spans(sub_a)) end return all_spans @@ -338,23 +191,23 @@ Base.:(==)(ca1::CombinedAliasing, ca2::CombinedAliasing) = Base.hash(ca1::CombinedAliasing, h::UInt) = hash(ca1.sub_ainfos, hash(CombinedAliasing, h)) -struct ObjectAliasing{S<:MemorySpace} <: AbstractAliasing - ptr::RemotePtr{Cvoid,S} +struct ObjectAliasing <: AbstractAliasing + ptr::Ptr{Cvoid} sz::UInt end -ObjectAliasing(ptr::RemotePtr{Cvoid,S}, sz::Integer) where {S<:MemorySpace} = - ObjectAliasing{S}(ptr, UInt(sz)) function ObjectAliasing(x::T) where T @nospecialize x - ptr = RemotePtr{Cvoid}(pointer_from_objref(x)) + ptr = pointer_from_objref(x) sz = sizeof(T) return ObjectAliasing(ptr, sz) end -function memory_spans(oa::ObjectAliasing{S}) where S - span = MemorySpan{S}(oa.ptr, oa.sz) +function memory_spans(oa::ObjectAliasing) + rptr = RemotePtr{Cvoid}(oa.ptr) + span = MemorySpan{CPURAMMemorySpace}(rptr, oa.sz) return [span] end +aliasing(accel::Acceleration, x, T) = aliasing(x, T) function aliasing(x, dep_mod) if dep_mod isa Symbol return aliasing(getfield(x, dep_mod)) @@ -390,31 +243,16 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -function aliasing(x::Chunk, T) +aliasing(x::DTask, T) = aliasing(fetch(x; move_value=false, unwrap=false), T) +aliasing(x::DTask) = aliasing(fetch(x; move_value=false, unwrap=false)) +function aliasing(accel::DistributedAcceleration, x::Chunk, T) @assert x.handle isa DRef - if root_worker_id(x.processor) == myid() - return aliasing(unwrap(x), T) - end return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T aliasing(unwrap(x), T) end end -aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x - aliasing(unwrap(x)) -end -aliasing(x::DTask, T) = aliasing(fetch(x; raw=true), T) -aliasing(x::DTask) = aliasing(fetch(x; raw=true)) - -function aliasing(x::Base.RefValue{T}) where T - addr = UInt(Base.pointer_from_objref(x) + fieldoffset(typeof(x), 1)) - ptr = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) - ainfo = ObjectAliasing(ptr, sizeof(x)) - if isassigned(x) && type_may_alias(T) && type_may_alias(typeof(x[])) - return CombinedAliasing([ainfo, aliasing(x[])]) - else - return CombinedAliasing([ainfo]) - end -end +aliasing(x::Chunk, T) = aliasing(unwrap(x), T) +aliasing(x::Chunk) = aliasing(unwrap(x)) struct ContiguousAliasing{S} <: AbstractAliasing span::MemorySpan{S} @@ -467,22 +305,13 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N}) where {T,N} +function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} if isbitstype(T) - p = parent(x) - space = memory_space(p) - S = typeof(space) - parent_ptr = RemotePtr{Cvoid}(UInt64(pointer(p)), space) - ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) - NA = ndims(p) - raw_inds = parentindices(x) - inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) - sz = ntuple(i->length(inds[i]), NA) - return StridedAliasing{T,NA,S}(parent_ptr, - ptr, - inds, - sz, - strides(p)) + S = CPURAMMemorySpace + return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), + RemotePtr{Cvoid}(pointer(x)), + parentindices(x), + size(x), strides(x)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -599,7 +428,7 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. $(y_span.ptr.space)" + @assert x_span.ptr.space == y_span.ptr.space x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end diff --git a/src/mpi.jl b/src/mpi.jl new file mode 100644 index 000000000..1b84a7b9d --- /dev/null +++ b/src/mpi.jl @@ -0,0 +1,948 @@ +using MPI + +const CHECK_UNIFORMITY = Ref{Bool}(false) +function check_uniformity!(check::Bool=true) + CHECK_UNIFORMITY[] = check +end +function check_uniform(value::Integer, original=value) + CHECK_UNIFORMITY[] || return true + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + matched = compare_all(value, comm) + if !matched + if rank == 0 + Core.print("[$rank] Found non-uniform value!\n") + end + Core.print("[$rank] value=$value, original=$original") + throw(ArgumentError("Non-uniform value")) + end + MPI.Barrier(comm) + return matched +end +function check_uniform(value, original=value) + CHECK_UNIFORMITY[] || return true + return check_uniform(hash(value), original) +end + +function compare_all(value, comm) + rank = MPI.Comm_rank(comm) + size = MPI.Comm_size(comm) + for i in 0:(size-1) + if i != rank + send_yield(value, comm, i, UInt32(0); check_seen=false) + end + end + match = true + for i in 0:(size-1) + if i != rank + other_value = recv_yield(comm, i, UInt32(0)) + if value != other_value + match = false + end + end + end + return match +end + +struct MPIAcceleration <: Acceleration + comm::MPI.Comm +end +MPIAcceleration() = MPIAcceleration(MPI.COMM_WORLD) + +function aliasing(accel::MPIAcceleration, x::Chunk, T) + handle = x.handle::MPIRef + @assert accel.comm == handle.comm "MPIAcceleration comm mismatch" + tag = to_tag() + check_uniform(tag) + rank = MPI.Comm_rank(accel.comm) + if handle.rank == rank + ainfo = aliasing(x, T) + #Core.print("[$rank] aliasing: $ainfo, sending\n") + @opcounter :aliasing_bcast_send_yield + bcast_send_yield(ainfo, accel.comm, handle.rank, tag) + else + #Core.print("[$rank] aliasing: receiving from $(handle.rank)\n") + ainfo = recv_yield(accel.comm, handle.rank, tag) + #Core.print("[$rank] aliasing: received $ainfo\n") + end + check_uniform(ainfo) + return ainfo +end +default_processor(accel::MPIAcceleration) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x::Chunk) = MPIOSProc(x.handle.comm, x.handle.rank) +default_processor(accel::MPIAcceleration, x::Function) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +default_processor(accel::MPIAcceleration, T::Type) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) + +#TODO: Add a lock +const MPIClusterProcChildren = Dict{MPI.Comm, Set{Processor}}() + +struct MPIClusterProc <: Processor + comm::MPI.Comm + function MPIClusterProc(comm::MPI.Comm) + populate_children(comm) + return new(comm) + end +end + +Sch.init_proc(state, proc::MPIClusterProc, log_sink) = Sch.init_proc(state, MPIOSProc(proc.comm), log_sink) + +MPIClusterProc() = MPIClusterProc(MPI.COMM_WORLD) + +function populate_children(comm::MPI.Comm) + children = get_processors(OSProc()) + MPIClusterProcChildren[comm] = children +end + +struct MPIOSProc <: Processor + comm::MPI.Comm + rank::Int +end + +function MPIOSProc(comm::MPI.Comm) + rank = MPI.Comm_rank(comm) + return MPIOSProc(comm, rank) +end + +function MPIOSProc() + return MPIOSProc(MPI.COMM_WORLD) +end + +ProcessScope(p::MPIOSProc) = ProcessScope(myid()) + +function check_uniform(proc::MPIOSProc, original=proc) + return check_uniform(hash(MPIOSProc), original) && + check_uniform(proc.rank, original) +end + +function memory_spaces(proc::MPIOSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end + +struct MPIProcessScope <: AbstractScope + comm::MPI.Comm + rank::Int +end + +Base.isless(::MPIProcessScope, ::MPIProcessScope) = false +Base.isless(::MPIProcessScope, ::NodeScope) = true +Base.isless(::MPIProcessScope, ::UnionScope) = true +Base.isless(::MPIProcessScope, ::TaintScope) = true +Base.isless(::MPIProcessScope, ::AnyScope) = true +constrain(x::MPIProcessScope, y::MPIProcessScope) = + x == y ? y : InvalidScope(x, y) +constrain(x::NodeScope, y::MPIProcessScope) = + x == y.parent ? y : InvalidScope(x, y) + +Base.isless(::ExactScope, ::MPIProcessScope) = true +constrain(x::MPIProcessScope, y::ExactScope) = + x == y.parent ? y : InvalidScope(x, y) + +function enclosing_scope(proc::MPIOSProc) + return MPIProcessScope(proc.comm, proc.rank) +end + +function Dagger.to_scope(::Val{:mpi_rank}, sc::NamedTuple) + if sc.mpi_rank == Colon() + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=Colon()))) + else + @assert sc.mpi_rank isa Integer "Expected a single GPU device ID for :mpi_rank, got $(sc.mpi_rank)\nConsider using :mpi_ranks instead." + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=[sc.mpi_rank]))) + end +end +Dagger.scope_key_precedence(::Val{:mpi_rank}) = 2 +function Dagger.to_scope(::Val{:mpi_ranks}, sc::NamedTuple) + comm = get(sc, :mpi_comm, MPI.COMM_WORLD) + if sc.ranks != Colon() + ranks = sc.ranks + else + ranks = MPI.Comm_size(comm) + end + inner_sc = NamedTuple(filter(kv->kv[1] != :mpi_ranks, Base.pairs(sc))...) + # FIXME: What to do here? + inner_scope = Dagger.to_scope(inner_sc) + scopes = Dagger.ExactScope[] + for rank in ranks + procs = Dagger.get_processors(Dagger.MPIOSProc(comm, rank)) + rank_scope = MPIProcessScope(comm, rank) + for proc in procs + proc_scope = Dagger.ExactScope(proc) + constrain(proc_scope, rank_scope) isa Dagger.InvalidScope && continue + push!(scopes, proc_scope) + end + end + return Dagger.UnionScope(scopes) +end +Dagger.scope_key_precedence(::Val{:mpi_ranks}) = 2 + +struct MPIProcessor{P<:Processor} <: Processor + innerProc::P + comm::MPI.Comm + rank::Int +end +proc_in_scope(proc::Processor, scope::MPIProcessScope) = false +proc_in_scope(proc::MPIProcessor, scope::MPIProcessScope) = + proc.comm == scope.comm && proc.rank == scope.rank + +function check_uniform(proc::MPIProcessor, original=proc) + return check_uniform(hash(MPIProcessor), original) && + check_uniform(proc.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(proc.innerProc), original) +end + +Dagger.iscompatible_func(::MPIProcessor, opts, ::Any) = true +Dagger.iscompatible_arg(::MPIProcessor, opts, ::Any) = true + +default_enabled(proc::MPIProcessor) = default_enabled(proc.innerProc) + +root_worker_id(proc::MPIProcessor) = myid() +root_worker_id(proc::MPIOSProc) = myid() +root_worker_id(proc::MPIClusterProc) = myid() + +get_parent(proc::MPIClusterProc) = proc +get_parent(proc::MPIOSProc) = MPIClusterProc(proc.comm) +get_parent(proc::MPIProcessor) = MPIOSProc(proc.comm, proc.rank) + +short_name(proc::MPIProcessor) = "(MPI: $(proc.rank), $(short_name(proc.innerProc)))" + +function get_processors(mosProc::MPIOSProc) + populate_children(mosProc.comm) + children = MPIClusterProcChildren[mosProc.comm] + mpiProcs = Set{Processor}() + for proc in children + push!(mpiProcs, MPIProcessor(proc, mosProc.comm, mosProc.rank)) + end + return mpiProcs +end + +#TODO: non-uniform ranking through MPI groups +#TODO: use a lazy iterator +function get_processors(proc::MPIClusterProc) + children = Set{Processor}() + for i in 0:(MPI.Comm_size(proc.comm)-1) + for innerProc in MPIClusterProcChildren[proc.comm] + push!(children, MPIProcessor(innerProc, proc.comm, i)) + end + end + return children +end + +struct MPIMemorySpace{S<:MemorySpace} <: MemorySpace + innerSpace::S + comm::MPI.Comm + rank::Int +end + +function check_uniform(space::MPIMemorySpace, original=space) + return check_uniform(space.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(space.innerSpace), original) +end + +default_processor(space::MPIMemorySpace) = MPIOSProc(space.comm, space.rank) +default_memory_space(accel::MPIAcceleration) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) + +default_memory_space(accel::MPIAcceleration, x) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) +default_memory_space(accel::MPIAcceleration, x::Chunk) = MPIMemorySpace(CPURAMMemorySpace(myid()), x.handle.comm, x.handle.rank) +default_memory_space(accel::MPIAcceleration, x::Function) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) +default_memory_space(accel::MPIAcceleration, T::Type) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) + +function memory_spaces(proc::MPIClusterProc) + rawMemSpace = Set{MemorySpace}() + for rnk in 0:(MPI.Comm_size(proc.comm) - 1) + for innerSpace in memory_spaces(OSProc()) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, rnk)) + end + end + return rawMemSpace +end + +function memory_spaces(proc::MPIProcessor) + rawMemSpace = Set{MemorySpace}() + for innerSpace in memory_spaces(proc.innerProc) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, proc.rank)) + end + return rawMemSpace +end + +root_worker_id(mem_space::MPIMemorySpace) = myid() + +function processors(memSpace::MPIMemorySpace) + rawProc = Set{Processor}() + for innerProc in processors(memSpace.innerSpace) + push!(rawProc, MPIProcessor(innerProc, memSpace.comm, memSpace.rank)) + end + return rawProc +end + +struct MPIRefID + tid::Int + uid::UInt + id::Int + function MPIRefID(tid, uid, id) + @assert tid > 0 || uid > 0 "Invalid MPIRefID: tid=$tid, uid=$uid, id=$id" + return new(tid, uid, id) + end +end +Base.hash(id::MPIRefID, h::UInt=UInt(0)) = + hash(id.tid, hash(id.uid, hash(id.id, hash(MPIRefID, h)))) + +function check_uniform(ref::MPIRefID, original=ref) + return check_uniform(ref.tid, original) && + check_uniform(ref.uid, original) && + check_uniform(ref.id, original) +end + +const MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() +const MPIREF_UID = Dict{Int, Threads.Atomic{Int}}() + +mutable struct MPIRef + comm::MPI.Comm + rank::Int + size::Int + innerRef::Union{DRef, Nothing} + id::MPIRefID +end +Base.hash(ref::MPIRef, h::UInt=UInt(0)) = hash(ref.id, hash(MPIRef, h)) +root_worker_id(ref::MPIRef) = myid() + +function check_uniform(ref::MPIRef, original=ref) + return check_uniform(ref.rank, original) && + check_uniform(ref.id, original) +end + +move(from_proc::Processor, to_proc::Processor, x::MPIRef) = + move(from_proc, to_proc, poolget(x; uniform=FETCH_UNIFORM[])) + +function affinity(x::MPIRef) + if x.innerRef === nothing + return MPIOSProc(x.comm, x.rank)=>0 + else + return MPIOSProc(x.comm, x.rank)=>x.innerRef.size + end +end + +function take_ref_id!() + tid = 0 + uid = 0 + id = 0 + if Dagger.in_task() + tid = sch_handle().thunk_id.id + uid = 0 + counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) + id = Threads.atomic_add!(counter, 1) + elseif MPI_TID[] != 0 + tid = MPI_TID[] + uid = 0 + counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) + id = Threads.atomic_add!(counter, 1) + elseif MPI_UID[] != 0 + tid = 0 + uid = MPI_UID[] + counter = get!(MPIREF_UID, uid, Threads.Atomic{Int}(1)) + id = Threads.atomic_add!(counter, 1) + end + return MPIRefID(tid, uid, id) +end + +#TODO: partitioned scheduling with comm bifurcation +function tochunk_pset(x, space::MPIMemorySpace; device=nothing, kwargs...) + @assert space.comm == MPI.COMM_WORLD "$(space.comm) != $(MPI.COMM_WORLD)" + local_rank = MPI.Comm_rank(space.comm) + Mid = take_ref_id!() + if local_rank != space.rank + return MPIRef(space.comm, space.rank, 0, nothing, Mid) + else + # type= is for Chunk metadata only; MemPool.poolset does not accept it + pset_kw = (; (k => v for (k, v) in pairs(kwargs) if k !== :type)...) + return MPIRef(space.comm, space.rank, sizeof(x), poolset(x; device, pset_kw...), Mid) + end +end + +const DEADLOCK_DETECT = TaskLocalValue{Bool}(()->true) +const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0) +const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->120.0) +const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}()) + +struct InplaceInfo + type::DataType + shape::Tuple +end +struct InplaceSparseInfo + type::DataType + m::Int + n::Int + colptr::Int + rowval::Int + nzval::Int +end + +function supports_inplace_mpi(value) + if value isa DenseArray && isbitstype(eltype(value)) + return true + else + return false + end +end +function recv_yield!(buffer, comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))") + if !supports_inplace_mpi(buffer) + return recv_yield(comm, src, tag), false + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + + buffer = recv_yield_inplace!(buffer, comm, rank, src, tag) + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + + return buffer, true + +end + +function recv_yield(comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Receiving...") + + type = nothing + @label receive + value = recv_yield_serialized(comm, rank, src, tag) + if value isa InplaceInfo || value isa InplaceSparseInfo + value = recv_yield_inplace(value, comm, rank, src, tag) + end + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + return value +end + +function recv_yield_inplace!(array, comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + @assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count" + buf = MPI.Buffer(array) + req = MPI.Imrecv!(buf, msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return array + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays" + array = Array{eltype(T)}(undef, _value.shape) + return recv_yield_inplace!(array, comm, my_rank, their_rank, tag) +end + +function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: SparseMatrixCSC "recv_yield_inplace only supports inplace MPI transfers of SparseMatrixCSC" + + colptr = recv_yield_inplace!(Vector{Int64}(undef, _value.colptr), comm, my_rank, their_rank, tag) + rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag) + nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag) + + return SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval) +end + +function recv_yield_serialized(comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + buf = Array{UInt8}(undef, count) + req = MPI.Imrecv!(MPI.Buffer(buf), msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return MPI.deserialize(buf) + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +const SEEN_TAGS = Dict{Int32, Type}() +send_yield!(value, comm, dest, tag; check_seen::Bool=true) = + _send_yield(value, comm, dest, tag; check_seen, inplace=true) +send_yield(value, comm, dest, tag; check_seen::Bool=true) = + _send_yield(value, comm, dest, tag; check_seen, inplace=false) +function _send_yield(value, comm, dest, tag; check_seen::Bool=true, inplace::Bool) + rank = MPI.Comm_rank(comm) + + if check_seen && haskey(SEEN_TAGS, tag) && SEEN_TAGS[tag] !== typeof(value) + @error "[rank $(MPI.Comm_rank(comm))][tag $tag] Already seen tag (previous type: $(SEEN_TAGS[tag]), new type: $(typeof(value)))" exception=(InterruptException(),backtrace()) + end + if check_seen + SEEN_TAGS[tag] = typeof(value) + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting send to [$dest]: $(typeof(value)), is support inplace? $(supports_inplace_mpi(value))") + if inplace && supports_inplace_mpi(value) + send_yield_inplace(value, comm, rank, dest, tag) + else + send_yield_serialized(value, comm, rank, dest, tag) + end +end + +function send_yield_inplace(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_inplace + req = MPI.Isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") +end + +function send_yield_serialized(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_serialized + if value isa Array && isbitstype(eltype(value)) + send_yield_serialized(InplaceInfo(typeof(value), size(value)), comm, my_rank, their_rank, tag) + send_yield_inplace(value, comm, my_rank, their_rank, tag) + elseif value isa SparseMatrixCSC && isbitstype(eltype(value)) + send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag) + send_yield_inplace(value.colptr, comm, my_rank, their_rank, tag) + send_yield_inplace(value.rowval, comm, my_rank, their_rank, tag) + send_yield_inplace(value.nzval, comm, my_rank, their_rank, tag) + else + req = MPI.isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") + end +end + +function __wait_for_request(req, comm, my_rank, their_rank, tag, fn::String, kind::String) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + while true + finish, status = MPI.Test(req, MPI.Status) + if finish + if MPI.Get_error(status) != MPI.SUCCESS + error("$fn failed with error $(MPI.Get_error(status))") + end + return + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, kind, their_rank) + yield() + end +end + +function bcast_send_yield(value, comm, root, tag) + @opcounter :bcast_send_yield + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + for other_rank in 0:(sz-1) + rank == other_rank && continue + send_yield(value, comm, other_rank, tag) + end +end + +#= Maybe can be worth it to implement this +function bcast_send_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + + for other_rank in 0:(sz-1) + rank == other_rank && continue + #println("[rank $rank] Sending to rank $other_rank") + send_yield!(value, comm, other_rank, tag) + end +end + +function bcast_recv_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + #println("[rank $rank] receive from rank $root") + recv_yield!(value, comm, root, tag) +end +=# +function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest) + time_elapsed = (time_ns() - time_start) + if detect && time_elapsed > warn_period + @warn "[rank $rank][tag $tag] Hit probable hang on $kind (dest: $srcdest)" + return typemax(UInt64) + end + if detect && time_elapsed > timeout_period + error("[rank $rank][tag $tag] Hit hang on $kind (dest: $srcdest)") + end + return warn_period +end + +#discuss this with julian +WeakChunk(c::Chunk{T,H}) where {T,H<:MPIRef} = WeakChunk(c.handle.rank, c.handle.id.id, WeakRef(c)) + +function MemPool.poolget(ref::MPIRef; uniform::Bool=false) + @assert uniform || ref.rank == MPI.Comm_rank(ref.comm) "MPIRef rank mismatch: $(ref.rank) != $(MPI.Comm_rank(ref.comm))" + if uniform + tag = to_tag() + if ref.rank == MPI.Comm_rank(ref.comm) + value = poolget(ref.innerRef) + @opcounter :poolget_bcast_send_yield + bcast_send_yield(value, ref.comm, ref.rank, tag) + return value + else + return recv_yield(ref.comm, ref.rank, tag) + end + else + return poolget(ref.innerRef) + end +end +fetch_handle(ref::MPIRef; uniform::Bool=false) = poolget(ref; uniform) + +function move!(dep_mod, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) + @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" + @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch" + @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch" + local_rank = MPI.Comm_rank(from.handle.comm) + if to_space.rank == from_space.rank == local_rank + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) + else + @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" + tag = to_tag() + if local_rank == from_space.rank + send_yield!(poolget(from.handle; uniform=false), to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + #@dagdebug nothing :mpi "[$local_rank][$tag] Receiving from rank $(from_space.rank) with tag $tag, type of buffer: $(typeof(poolget(to.handle; uniform=false)))" + to_val = poolget(to.handle; uniform=false) + val, inplace = recv_yield!(to_val, from_space.comm, from_space.rank, tag) + if !inplace + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to_val, val) + end + end + end + @dagdebug nothing :mpi "[$local_rank][$tag] Finished moving from $(from_space.rank) to $(to_space.rank) successfuly\n" +end +function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) + @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" + @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch" + @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch" + local_rank = MPI.Comm_rank(from.handle.comm) + if to_space.rank == from_space.rank == local_rank + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) + else + tag = to_tag() + @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" + if local_rank == from_space.rank + # Get the source data for each span + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + offset = 1 + for (from_span, _) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copies, offset)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + offset += from_span.len + #end + end + + # Send the spans + #send_yield(len, to_space.comm, to_space.rank, tag) + send_yield!(copies, to_space.comm, to_space.rank, tag; check_seen=false) + #send_yield(copies, to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + # Receive the spans + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + recv_yield!(copies, from_space.comm, from_space.rank, tag) + #copies = recv_yield(from_space.comm, from_space.rank, tag) + + # Copy the data into the destination object + #for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + offset = 1 + for (_, to_span) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copies, offset)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + offset += to_span.len + #end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + end + end + + return +end + + +move(::MPIOSProc, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIOSProc, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +#TODO: out of place MPI move +function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk) + @assert src.comm == dst.comm "Multi comm move not supported" + if Sch.SCHED_MOVE[] + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permited" + @assert src.rank == x.handle.rank == dst.rank + return poolget(x.handle) + end +end + +const MPI_UNIFORM = ScopedValue{Bool}(false) +# When true, move(_, _, MPIRef) uses poolget(; uniform=true) so the owner bcasts and the fetcher recv (e.g. rank 0 collecting). +const FETCH_UNIFORM = ScopedValue{Bool}(true) + +function remotecall_endpoint(f, accel::Dagger.MPIAcceleration, from_proc, to_proc, from_space, to_space, data) + loc_rank = MPI.Comm_rank(accel.comm) + task = DATADEPS_CURRENT_TASK[] + return with(MPI_UID=>task.uid, MPI_UNIFORM=>true) do + @assert data isa Chunk "Expected Chunk, got $(typeof(data))" + space = memory_space(data) + tag = to_tag() + type_tag = to_tag() + T = move_type(from_proc.innerProc, to_proc.innerProc, chunktype(data)) + T_new = f !== identity ? Base._return_type(f, Tuple{T}) : T + need_bcast = !isconcretetype(T_new) || T_new === Union{} || T_new === Nothing || T_new === Any + + if space.rank != from_proc.rank + # Data is already at destination (to_proc.rank) + @assert space.rank == to_proc.rank + if space.rank == loc_rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, type_tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + end + end + + # Data is on the source rank + @assert space.rank == from_proc.rank + if loc_rank == from_proc.rank == to_proc.rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + return tochunk(data_converted, to_proc, to_space; type=typeof(data_converted)) + end + + if loc_rank == from_proc.rank + value = poolget(data.handle) + data_moved = move(from_proc.innerProc, to_proc.innerProc, value) + Dagger.send_yield(data_moved, accel.comm, to_proc.rank, tag) + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + elseif loc_rank == to_proc.rank + data_moved = Dagger.recv_yield(accel.comm, from_space.rank, tag) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, data_moved)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, type_tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + end + end +end + +# Chunk may be MPI-backed (MPIRef) but labeled with OSProc; treat source as the owning rank +function move(src::OSProc, dst::MPIProcessor, x::Chunk) + if x.handle isa MPIRef + return move(MPIOSProc(x.handle.comm, x.handle.rank), dst, x) + end + error("MPI move not supported") +end + +move(src::Processor, dst::MPIProcessor, x::Chunk) = error("MPI move not supported") +move(to_proc::MPIProcessor, chunk::Chunk) = + move(chunk.processor, to_proc, chunk) +move(to_proc::Processor, d::MPIRef) = + move(MPIOSProc(d.rank), to_proc, d) +move(to_proc::MPIProcessor, x) = + move(MPIOSProc(), to_proc, x) + +move(::MPIProcessor, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIProcessor, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +@warn "Is this uniform logic valuable to have?" maxlog=1 +function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk) + uniform = false #uniform = MPI_UNIFORM[] + @assert uniform || src.rank == dst.rank "Unwrapping not permitted" + if Sch.SCHED_MOVE[] + # We can either unwrap locally, or return nothing + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + # Either we're uniform (so everyone cooperates), or we're unwrapping locally + if !uniform + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permitted" + @assert src.rank == x.handle.rank == dst.rank + end + return poolget(x.handle; uniform) + end +end + + +#FIXME:try to think of a better move! scheme +function execute!(proc::MPIProcessor, f, args...; kwargs...) + local_rank = MPI.Comm_rank(proc.comm) + islocal = local_rank == proc.rank + inplace_move = f === move! + result = nothing + tag = to_tag() + + if islocal || inplace_move + result = execute!(proc.innerProc, f, args...; kwargs...) + end + + if inplace_move + space = memory_space(nothing, proc)::MPIMemorySpace + dest_type = chunktype(args[4]) + return tochunk(nothing, proc, space; type=dest_type) + end + + # Infer return type; only bcast when inference is not concrete + fname = nameof(f) + arg_types = map(chunktype, args) + inferred_type = Base.promote_op(f, arg_types...) + + need_bcast = !isconcretetype(inferred_type) || inferred_type === Union{} || inferred_type === Nothing || inferred_type === Any + + if islocal + T = typeof(result) + space = memory_space(result, proc)::MPIMemorySpace + if need_bcast + @opcounter :execute_bcast_send_yield + bcast_send_yield((T, space.innerSpace), proc.comm, proc.rank, tag) + end + return tochunk(result, proc, space; type=T) + else + if need_bcast + T, innerSpace = recv_yield(proc.comm, proc.rank, tag) + space = MPIMemorySpace(innerSpace, proc.comm, proc.rank) + return tochunk(nothing, proc, space; type=T) + else + space = memory_space(nothing, proc)::MPIMemorySpace + return tochunk(nothing, proc, space; type=inferred_type) + end + end +end + +accelerate!(::Val{:mpi}) = accelerate!(MPIAcceleration()) + +function initialize_acceleration!(a::MPIAcceleration) + if !MPI.Initialized() + MPI.Init(;threadlevel=:multiple) + end + ctx = Dagger.Sch.eager_context() + sz = MPI.Comm_size(a.comm) + for i in 0:(sz-1) + push!(ctx.procs, MPIOSProc(a.comm, i)) + end + unique!(ctx.procs) +end + +""" + mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + +Ensure all ranks use the same concrete type for the given tasks by setting +each task's options.return_type to expected_type when it is concrete. +This allows chunktype(task) to return the concrete type on every rank +without an MPI allgather of actual result types. +""" +function mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + isconcretetype(expected_type) || return + for t in tasks + if t isa Thunk + if t.options !== nothing + t.options.return_type = expected_type + else + t.options = Options(return_type=expected_type) + end + end + end + return +end + +accel_matches_proc(accel::MPIAcceleration, proc::MPIOSProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIClusterProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIProcessor) = true +accel_matches_proc(accel::MPIAcceleration, proc) = false + +function distribute(accel::MPIAcceleration, A::AbstractArray{T,N}, dist::Blocks{N}) where {T,N} + comm = accel.comm + rank = MPI.Comm_rank(comm) + + DA = view(A, dist) + DB = DArray{T,N}(undef, dist, size(A)) + copyto!(DB, DA) + + return DB +end diff --git a/src/mpi_mempool.jl b/src/mpi_mempool.jl new file mode 100644 index 000000000..149c7900a --- /dev/null +++ b/src/mpi_mempool.jl @@ -0,0 +1,36 @@ +# Mempool for received MPI message data only (no envelopes). +# Key: (comm, source, tag). Used when a message is received but not the one the caller was waiting for. +# Included from mpi.jl; runs in Dagger module scope. + +const MPI_RECV_MEMPOOL = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Vector{Any}}()) + +function mpi_mempool_put!(comm::MPI.Comm, source::Integer, tag::Integer, data::Any) + key = (comm, Int(source), Int(tag)) + ref = poolset(data) + lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) + pool[key] = Any[] + end + push!(pool[key], ref) + end + return nothing +end + +function mpi_mempool_take!(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + ref = lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) || isempty(pool[key]) + return nothing + end + popfirst!(pool[key]) + end + ref === nothing && return nothing + return poolget(ref) +end + +function mpi_mempool_has(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + return lock(MPI_RECV_MEMPOOL) do pool + haskey(pool, key) && !isempty(pool[key]) + end +end diff --git a/src/mutable.jl b/src/mutable.jl new file mode 100644 index 000000000..1f48ead53 --- /dev/null +++ b/src/mutable.jl @@ -0,0 +1,41 @@ +function _mutable_inner(@nospecialize(f), proc, scope) + result = f() + return Ref(Dagger.tochunk(result, proc, scope)) +end + +""" + mutable(f::Base.Callable; worker, processor, scope) -> Chunk + +Calls `f()` on the specified worker or processor, returning a `Chunk` +referencing the result with the specified scope `scope`. +""" +function mutable(@nospecialize(f); worker=nothing, processor=nothing, scope=nothing) + if processor === nothing + if worker === nothing + processor = OSProc() + else + processor = OSProc(worker) + end + else + @assert worker === nothing "mutable: Can't mix worker and processor" + end + if scope === nothing + scope = processor isa OSProc ? ProcessScope(processor) : ExactScope(processor) + end + return fetch(Dagger.@spawn scope=scope _mutable_inner(f, processor, scope))[] +end + +""" + @mutable [worker=1] [processor=OSProc()] [scope=ProcessorScope()] f() + +Helper macro for [`mutable()`](@ref). +""" +macro mutable(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $mutable(f; $(opts...)) + end + end +end diff --git a/src/options.jl b/src/options.jl index eca59fbc9..09067da51 100644 --- a/src/options.jl +++ b/src/options.jl @@ -26,6 +26,7 @@ Stores per-task options to be passed to the scheduler. - `storage_leaf_tag::Union{MemPool.Tag,Nothing}=nothing`: If not `nothing`, specifies the MemPool storage leaf tag to associate with the task's result. This tag can be used by MemPool's storage devices to manipulate their behavior, such as the file name used to store data on disk." - `storage_retain::Union{Bool,Nothing}=nothing`: The value of `retain` to pass to `MemPool.poolset` when constructing the result `Chunk`. `nothing` defaults to `false`. - `name::Union{String,Nothing}=nothing`: If not `nothing`, annotates the task with a name for logging purposes. +- `tag::Union{UInt32,Nothing}=nothing`: (Data-deps/MPI) MPI message tag for this task; assigned automatically if `nothing`. - `stream_input_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the input buffer of the task. Defaults to 1. - `stream_output_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the output buffer of the task. Defaults to 1. - `stream_buffer_type::Union{Type,Nothing}=nothing`: (Streaming only) Specifies the type of buffer to use for the input and output buffers of the task. Defaults to `Dagger.ProcessRingBuffer`. @@ -61,10 +62,16 @@ Base.@kwdef mutable struct Options name::Union{String,Nothing} = nothing + tag::Union{UInt32,Nothing} = nothing + stream_input_buffer_amount::Union{Int,Nothing} = nothing stream_output_buffer_amount::Union{Int,Nothing} = nothing stream_buffer_type::Union{Type, Nothing} = nothing stream_max_evals::Union{Int,Nothing} = nothing + + acceleration::Union{Acceleration,Nothing} = nothing + + return_type::Union{Type,Nothing} = nothing end Options(::Nothing) = Options() function Options(old_options::NamedTuple) diff --git a/src/processor.jl b/src/processor.jl index ac2e74f14..4944dc083 100644 --- a/src/processor.jl +++ b/src/processor.jl @@ -2,16 +2,6 @@ export OSProc, Context, addprocs!, rmprocs! import Base: @invokelatest -""" - Processor - -An abstract type representing a processing device and associated memory, where -data can be stored and operated on. Subtypes should be immutable, and -instances should compare equal if they represent the same logical processing -device/memory. Subtype instances should be serializable between different -nodes. Subtype instances may contain a "parent" `Processor` to make it easy to -transfer data to/from other types of `Processor` at runtime. -""" abstract type Processor end const PROCESSOR_CALLBACKS = Dict{Symbol,Any}() @@ -150,3 +140,20 @@ iscompatible_arg(proc::OSProc, opts, args...) = "Returns a very brief `String` representation of `proc`." short_name(proc::Processor) = string(proc) short_name(p::OSProc) = "W: $(p.pid)" + +"Returns true if the processor is on the local worker (for MPI/ordering)." +is_local_processor(proc::Processor) = (root_worker_id(proc) == myid()) + +"Ordering key for task firing (used by MPI to avoid deadlock)." +fire_order_key(proc::Processor) = (root_worker_id(proc), 0) + +@doc """ + Processor + +An abstract type representing a processing device and associated memory, where +data can be stored and operated on. Subtypes should be immutable, and +instances should compare equal if they represent the same logical processing +device/memory. Subtype instances should be serializable between different +nodes. Subtype instances may contain a "parent" `Processor` to make it easy to +transfer data to/from other types of `Processor` at runtime. +""" Processor diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 58aed6dc5..3fcd87070 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -15,7 +15,7 @@ import Base: @invokelatest import ..Dagger import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, ThunkID, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc! +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, root_worker_id, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc!, is_local_processor, fire_order_key, short_name import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek @@ -25,7 +25,7 @@ import ..Dagger: @reusable, @reusable_dict, @reusable_vector, @reusable_tasks, @ import TimespanLogging import TaskLocalValues: TaskLocalValue -import ScopedValues: @with +import ScopedValues: ScopedValue, @with, with const OneToMany = Dict{Thunk, Set{Thunk}} @@ -56,7 +56,7 @@ Fields: - `cache::WeakKeyDict{Thunk, Any}` - Maps from a finished `Thunk` to it's cached result, often a DRef - `valid::WeakKeyDict{Thunk, Nothing}` - Tracks all `Thunk`s that are in a valid scheduling state - `running::Set{Thunk}` - The set of currently-running `Thunk`s -- `running_on::Dict{Thunk,OSProc}` - Map from `Thunk` to the OS process executing it +- `running_on::Dict{Thunk,Processor}` - Map from `Thunk` to the OS process executing it - `thunk_dict::Dict{Int, WeakThunk}` - Maps from thunk IDs to a `Thunk` - `node_order::Any` - Function that returns the order of a thunk - `equiv_chunks::WeakKeyDict{DRef,Chunk}` - Cache mapping from `DRef` to a `Chunk` which contains it @@ -82,15 +82,15 @@ struct ComputeState ready::Vector{Thunk} valid::Dict{Thunk, Nothing} running::Set{Thunk} - running_on::Dict{Thunk,OSProc} + running_on::Dict{Thunk,Processor} thunk_dict::Dict{Int, WeakThunk} node_order::Any - equiv_chunks::WeakKeyDict{DRef,Chunk} - worker_time_pressure::Dict{Int,Dict{Processor,UInt64}} - worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_loadavg::Dict{Int,NTuple{3,Float64}} - worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} + equiv_chunks::WeakKeyDict{Any,Chunk} + worker_time_pressure::Dict{Processor,Dict{Processor,UInt64}} + worker_storage_pressure::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_storage_capacity::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_loadavg::Dict{Processor,NTuple{3,Float64}} + worker_chans::Dict{Int,Tuple{RemoteChannel,RemoteChannel}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} worker_transfer_rate::Dict{Int,Dict{Processor,UInt64}} @@ -111,10 +111,10 @@ function start_state(deps::Dict, node_order, chan) Vector{Thunk}(undef, 0), Dict{Thunk, Nothing}(), Set{Thunk}(), - Dict{Thunk,OSProc}(), + Dict{Thunk,Processor}(), Dict{Int, WeakThunk}(), node_order, - WeakKeyDict{DRef,Chunk}(), + WeakKeyDict{Any,Chunk}(), Dict{Int,Dict{Processor,UInt64}}(), Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), @@ -152,30 +152,29 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + pid = Dagger.root_worker_id(p) + @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) # Initialize pressure and capacity - gproc = OSProc(p.pid) lock(state.lock) do - state.worker_time_pressure[p.pid] = Dict{Processor,UInt64}() - state.worker_transfer_rate[p.pid] = Dict{Processor,UInt64}() - - state.worker_storage_pressure[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() - state.worker_storage_capacity[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_transfer_rate[p] = Dict{Processor,UInt64}() + state.worker_time_pressure[p] = Dict{Processor,UInt64}() + state.worker_storage_pressure[p] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_capacity[p] = Dict{Union{StorageResource,Nothing},UInt64}() #= FIXME for storage in get_storage_resources(gproc) - pressure, capacity = remotecall_fetch(gproc.pid, storage) do storage + pressure, capacity = remotecall_fetch(root_worker_id(gproc), storage) do storage storage_pressure(storage), storage_capacity(storage) end - state.worker_storage_pressure[p.pid][storage] = pressure - state.worker_storage_capacity[p.pid][storage] = capacity + state.worker_storage_pressure[p][storage] = pressure + state.worker_storage_capacity[p][storage] = capacity end =# - state.worker_loadavg[p.pid] = (0.0, 0.0, 0.0) + state.worker_loadavg[p] = (0.0, 0.0, 0.0) end - if p.pid != 1 + if pid != 1 lock(WORKER_MONITOR_LOCK) do - wid = p.pid + wid = pid if !haskey(WORKER_MONITOR_TASKS, wid) t = Threads.@spawn begin try @@ -209,16 +208,16 @@ function init_proc(state, p, log_sink) end # Setup worker-to-scheduler channels - inp_chan = RemoteChannel(p.pid) - out_chan = RemoteChannel(p.pid) + inp_chan = RemoteChannel(pid) + out_chan = RemoteChannel(pid) lock(state.lock) do - state.worker_chans[p.pid] = (inp_chan, out_chan) + state.worker_chans[pid] = (inp_chan, out_chan) end # Setup dynamic listener - dynamic_listener!(ctx, state, p.pid) + dynamic_listener!(ctx, state, pid) - @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! @@ -236,7 +235,7 @@ function _cleanup_proc(uid, log_sink) end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - wid = p.pid + wid = root_worker_id(p) @maybelog ctx timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) @@ -299,7 +298,7 @@ function compute_dag(ctx::Context, d::Thunk, options=SchedulerOptions()) node_order = x -> -get(ord, x, 0) state = start_state(deps, node_order, chan) - master = OSProc(myid()) + master = Dagger.default_processor() @maybelog ctx timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) try @@ -394,8 +393,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt res = tresult.result @dagdebug thunk_id :take "Got finished task" - gproc = OSProc(pid) safepoint(state) + gproc = proc != nothing ? get_parent(proc) : OSProc(pid) lock(state.lock) do thunk_failed = false if res isa Exception @@ -422,11 +421,11 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt node = unwrap_weak_checked(state.thunk_dict[thunk_id])::Thunk metadata = tresult.metadata if metadata !== nothing - state.worker_time_pressure[pid][proc] = metadata.time_pressure + state.worker_time_pressure[gproc][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) #state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - #state.worker_loadavg[pid] = metadata.loadavg + #state.worker_loadavg[gproc] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -440,8 +439,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt end end if res isa Chunk - if !haskey(state.equiv_chunks, res) - state.equiv_chunks[res.handle::DRef] = res + if !haskey(state.equiv_chunks, res.handle) + state.equiv_chunks[res.handle] = res end end store_result!(state, node, res; error=thunk_failed) @@ -528,7 +527,7 @@ end const CHUNK_CACHE = Dict{Chunk,Dict{Processor,Any}}() struct ScheduleTaskLocation - gproc::OSProc + gproc::Processor proc::Processor end struct ScheduleTaskSpec @@ -538,6 +537,25 @@ struct ScheduleTaskSpec est_alloc_util::UInt64 est_occupancy::UInt32 end + +"Ordering key for task locations when using MPI acceleration (deterministic across ranks)." +function _mpi_fire_order_key(loc::ScheduleTaskLocation) + g = loc.gproc + p = loc.proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + +"Ordering key for a single Processor when using MPI acceleration (deterministic across ranks)." +function _mpi_proc_rank(proc::Processor) + g = get_parent(proc) + p = proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + @reuse_scope function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)) lock(state.lock) do safepoint(state) @@ -552,6 +570,7 @@ end to_fire_cleanup = @reuse_defer_cleanup empty!(to_fire) failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32 failed_scheduling_cleanup = @reuse_defer_cleanup empty!(failed_scheduling) + # Select a new task and get its options task = nothing @label pop_task @@ -626,9 +645,9 @@ end end @label scope_computed - input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 + input_procs = @reusable_vector :schedule!_input_procs Union{Processor,Nothing} nothing 32 input_procs_cleanup = @reuse_defer_cleanup empty!(input_procs) - for proc in Dagger.compatible_processors(scope, procs) + for proc in Dagger.compatible_processors(options.acceleration, scope, procs) if !(proc in input_procs) push!(input_procs, proc) end @@ -660,7 +679,7 @@ end can_use, scope = can_use_proc(state, task, gproc, proc, options, scope) if can_use has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, proc, gproc.pid, options.time_util, options.alloc_util, options.occupancy, sig) + has_capacity(state, proc, gproc, options.time_util, options.alloc_util, options.occupancy, sig) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util @@ -669,10 +688,10 @@ end Vector{ScheduleTaskSpec}() end push!(proc_tasks, ScheduleTaskSpec(task, scope, est_time_util, est_alloc_util, est_occupancy)) - state.worker_time_pressure[gproc.pid][proc] = - get(state.worker_time_pressure[gproc.pid], proc, 0) + + state.worker_time_pressure[gproc][proc] = + get(state.worker_time_pressure[gproc], proc, 0) + est_time_util - @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc.pid][proc]))" + @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc][proc]))" sorted_procs_cleanup() costs_cleanup() @goto pop_task @@ -687,10 +706,21 @@ end costs_cleanup() @goto pop_task - # Fire all newly-scheduled tasks + # Fire all newly-scheduled tasks (owner/local first, then by fire_order_key to avoid MPI execute! deadlock) @label fire_tasks - for (task_loc, task_spec) in to_fire - fire_tasks!(ctx, task_loc, task_spec, state) + task_locs = collect(keys(to_fire)) + if Dagger.current_acceleration() isa Dagger.MPIAcceleration + sort!(task_locs, by=_mpi_fire_order_key) + end + rank = try + M = parentmodule(@__MODULE__) + (isdefined(M, :MPI) && M.MPI.Initialized()) ? Int(M.MPI.Comm_rank(M.MPI.COMM_WORLD)) : nothing + catch + nothing + end + for (i, task_loc) in enumerate(task_locs) + #Core.println("fire_order rank=", rank, " [", i, "/", length(task_locs), "] task_loc=", task_loc) + fire_tasks!(ctx, task_loc, to_fire[task_loc], state) end to_fire_cleanup() @@ -739,14 +769,14 @@ function monitor_procs_changed!(ctx, state, options) end function remove_dead_proc!(ctx, state, proc, options) - @assert options.single !== proc.pid "Single worker failed, cannot continue." + @assert options.single !== root_worker_id(proc) "Single worker failed, cannot continue." rmprocs!(ctx, [proc]) - delete!(state.worker_time_pressure, proc.pid) - delete!(state.worker_transfer_rate, proc.pid) - delete!(state.worker_storage_pressure, proc.pid) - delete!(state.worker_storage_capacity, proc.pid) - delete!(state.worker_loadavg, proc.pid) - delete!(state.worker_chans, proc.pid) + delete!(state.worker_transfer_rate, proc) + delete!(state.worker_time_pressure, proc) + delete!(state.worker_storage_pressure, proc) + delete!(state.worker_storage_capacity, proc) + delete!(state.worker_loadavg, proc) + delete!(state.worker_chans, root_worker_id(proc)) end function finish_task!(ctx, state, node, thunk_failed) @@ -789,7 +819,7 @@ end function evict_all_chunks!(ctx, options, to_evict) if !isempty(to_evict) - @sync for w in map(p->p.pid, procs_to_use(ctx, options)) + @sync for w in map(p->root_worker_id(p), procs_to_use(ctx, options)) Threads.@spawn remote_do(evict_chunks!, w, ctx.log_sink, to_evict) end end @@ -860,9 +890,10 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end Tf = chunktype(first(args)) - @assert (options.single === nothing) || (gproc.pid == options.single) + pid = root_worker_id(gproc) + @assert (options.single === nothing) || (pid == options.single) # TODO: Set `sch_handle.tid.ref` to the right `DRef` - sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) + sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) push!(to_send, TaskSpec( @@ -874,7 +905,7 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end if !isempty(to_send) - if Dagger.root_worker_id(gproc) == myid() + if root_worker_id(gproc) == myid() @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, to_send) else # N.B. We don't batch these because we might get a deserialization @@ -1080,7 +1111,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re proc_occupancy = istate.proc_occupancy time_pressure = istate.time_pressure - wid = get_parent(to_proc).pid + wid = root_worker_id(to_proc) work_to_do = false while isopen(return_queue) # Wait for new tasks @@ -1135,12 +1166,15 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Try to steal a task @maybelog ctx timespan_start(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) - # Try to steal from local queues randomly + # Try to steal from local queues randomly (deterministic order when MPI to avoid deadlocks) # TODO: Prioritize stealing from busiest processors states = proc_states_values(uid) - # TODO: Try to pre-allocate this - P = randperm(length(states)) - for state in getindex.(Ref(states), P) + order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration + sort(1:length(states), by=i->_mpi_proc_rank(states[i].state.proc)) + else + randperm(length(states)) + end + for state in getindex.(Ref(states), order) other_istate = state.state if other_istate.proc === to_proc continue @@ -1155,7 +1189,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end task, occupancy = peek(queue) scope = task.scope - if Dagger.proc_in_scope(to_proc, scope) + accel = something(task.options.acceleration, Dagger.DistributedAcceleration()) + if Dagger.proc_in_scope(to_proc, scope) && Dagger.accel_matches_proc(accel, to_proc) typemax(UInt32) - proc_occupancy_cached >= occupancy # Compatible, steal this task return dequeue_pair!(queue) @@ -1348,11 +1383,15 @@ function do_tasks(to_proc, return_queue, tasks) end notify(istate.reschedule) - # Kick other processors to make them steal + # Kick other processors to make them steal (deterministic order when MPI to avoid deadlocks) # TODO: Alternatively, automatically balance work instead of blindly enqueueing states = proc_states_values(uid) - P = randperm(length(states)) - for other_state in getindex.(Ref(states), P) + order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration + sort(1:length(states), by=i->_mpi_proc_rank(states[i].state.proc)) + else + randperm(length(states)) + end + for other_state in getindex.(Ref(states), order) other_istate = other_state.state if other_istate.proc === to_proc continue @@ -1361,6 +1400,8 @@ function do_tasks(to_proc, return_queue, tasks) end @dagdebug nothing :processor "Kicked processors" end + +const SCHED_MOVE = ScopedValue{Bool}(false) """ do_task(to_proc, task::TaskSpec) -> Any @@ -1373,13 +1414,15 @@ Executes a single task specified by `task` on `to_proc`. ctx_vars = task.ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) - from_proc = OSProc() + options = task.options + Dagger.accelerate!(options.acceleration) + + from_proc = Dagger.default_processor() data = task.data Tf = task.Tf f = isdefined(Tf, :instance) ? Tf.instance : nothing # Wait for required resources to become available - options = task.options propagated = get_propagated_options(options) to_storage = options.storage !== nothing ? fetch(options.storage) : MemPool.GLOBAL_DEVICE[] #to_storage_name = nameof(typeof(to_storage)) @@ -1447,7 +1490,7 @@ Executes a single task specified by `task` on `to_proc`. @maybelog ctx timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) =# - @dagdebug thunk_id :execute "Moving data" + @dagdebug thunk_id :execute "Moving data for $Tf" # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) @@ -1466,11 +1509,13 @@ Executes a single task specified by `task` on `to_proc`. #= FIXME: This isn't valid if x is written to x = if x isa Chunk value = lock(TASK_SYNC) do - if haskey(CHUNK_CACHE, x) - Some{Any}(get!(CHUNK_CACHE[x], to_proc) do - # Convert from cached value - # TODO: Choose "closest" processor of same type first - some_proc = first(keys(CHUNK_CACHE[x])) + if haskey(CHUNK_CACHE, x) + Some{Any}(get!(CHUNK_CACHE[x], to_proc) do + # Convert from cached value + # TODO: Choose "closest" processor of same type first + cache_procs = keys(CHUNK_CACHE[x]) + some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ? + minimum(cache_procs, by=_mpi_proc_rank) : first(cache_procs) some_x = CHUNK_CACHE[x][some_proc] @dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x" @invokelatest move(some_proc, to_proc, some_x) @@ -1505,13 +1550,23 @@ Executes a single task specified by `task` on `to_proc`. end else =# - new_value = @invokelatest move(to_proc, value) + new_value = with(SCHED_MOVE=>true) do + @invokelatest move(to_proc, value) + end #end - if new_value !== value - @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + # Preserve Chunk reference when move returns nothing (placeholder on this rank). This keeps + # type information correct at all ranks: chunktype(Chunk) is concrete even when Chunk holds no data. + # So execute! sees correct arg_types. Materializing the value (for the kernel) must happen in + # execute! and may require lazy recv from the executor if this rank has a placeholder. + if new_value === nothing && (value isa Dagger.Chunk || value isa Dagger.WeakChunk) + arg.value = value + else + if new_value !== value + @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + end + arg.value = new_value end - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=new_value); tasks=[Base.current_task()]) - arg.value = new_value + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=Dagger.value(arg)); tasks=[Base.current_task()]) return end end @@ -1550,7 +1605,7 @@ Executes a single task specified by `task` on `to_proc`. # FIXME #gcnum_start = Base.gc_num() - @dagdebug thunk_id :execute "Executing $(typeof(f))" + @dagdebug thunk_id :execute "Executing $Tf" logging_enabled = !(ctx.log_sink isa TimespanLogging.NoOpLog) @@ -1613,7 +1668,7 @@ Executes a single task specified by `task` on `to_proc`. notify(TASK_SYNC) end - @dagdebug thunk_id :execute "Returning" + @dagdebug thunk_id :execute "Returning $Tf with $(typeof(result_meta))" # TODO: debug_storage("Releasing $to_storage_name") metadata = ( diff --git a/src/sch/util.jl b/src/sch/util.jl index 11706382a..8e36c3576 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -440,8 +440,8 @@ function can_use_proc(state, task, gproc, proc, opts, scope) # Check against single if opts.single !== nothing @warn "The `single` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1 - if gproc.pid != opts.single - @dagdebug task :scope "Rejected $proc: gproc.pid ($(gproc.pid)) != single ($(opts.single))" + if root_worker_id(gproc) != opts.single + @dagdebug task :scope "Rejected $proc: gproc root_worker_id ($(root_worker_id(gproc))) != single ($(opts.single))" return false, scope end scope = constrain(scope, Dagger.ProcessScope(opts.single)) @@ -593,19 +593,21 @@ const DEFAULT_TRANSFER_RATE = UInt64(1_000_000) # Add fixed cost for cross-worker task transfer (esimated at 1ms) # TODO: Actually estimate/benchmark this - task_xfer_cost = gproc.pid != myid() ? 1_000_000 : 0 # 1ms + task_xfer_cost = root_worker_id(gproc) != myid() ? 1_000_000 : 0 # 1ms tx_rate = get(get(state.worker_transfer_rate, gproc.pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost end chunks_cleanup() - # Shuffle procs around, so equally-costly procs are equally considered + # Shuffle procs around, so equally-costly procs are equally considered (skip shuffle when MPI for deterministic tie-breaking) np = length(procs) @reusable :estimate_task_costs_P Vector{Int} 0 4 np P begin resize!(P, np) copyto!(P, 1:np) - randperm!(P) + if !(Dagger.current_acceleration() isa Dagger.MPIAcceleration) + randperm!(P) + end for idx in 1:np sorted_procs[idx] = procs[P[idx]] end diff --git a/src/scopes.jl b/src/scopes.jl index 79190c292..28aa8fa00 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -101,7 +101,7 @@ struct ExactScope <: AbstractScope parent::ProcessScope processor::Processor end -ExactScope(proc) = ExactScope(ProcessScope(get_parent(proc).pid), proc) +ExactScope(proc) = ExactScope(ProcessScope(root_worker_id(get_parent(proc))), proc) proc_in_scope(proc::Processor, scope::ExactScope) = proc == scope.processor "Indicates that the applied scopes `x` and `y` are incompatible." diff --git a/src/shard.jl b/src/shard.jl new file mode 100644 index 000000000..ecd0ee570 --- /dev/null +++ b/src/shard.jl @@ -0,0 +1,89 @@ +""" +Maps a value to one of multiple distributed "mirror" values automatically when +used as a thunk argument. Construct using `@shard` or `shard`. +""" +struct Shard + chunks::Dict{Processor,Chunk} +end + +""" + shard(f; kwargs...) -> Chunk{Shard} + +Executes `f` on all workers in `workers`, wrapping the result in a +process-scoped `Chunk`, and constructs a `Chunk{Shard}` containing all of these +`Chunk`s on the current worker. + +Keyword arguments: +- `procs` -- The list of processors to create pieces on. May be any iterable container of `Processor`s. +- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s. +- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker. +""" +function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false) + if procs === nothing + if workers !== nothing + procs = [OSProc(w) for w in workers] + else + procs = lock(Sch.eager_context()) do + copy(Sch.eager_context().procs) + end + end + if per_thread + _procs = ThreadProc[] + for p in procs + append!(_procs, filter(p->p isa ThreadProc, get_processors(p))) + end + procs = _procs + end + else + if workers !== nothing + throw(ArgumentError("Cannot combine `procs` and `workers`")) + elseif per_thread + throw(ArgumentError("Cannot combine `procs` and `per_thread=true`")) + end + end + isempty(procs) && throw(ArgumentError("Cannot create empty Shard")) + shard_running_dict = Dict{Processor,DTask}() + for proc in procs + scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc) + thunk = Dagger.@spawn scope=scope _mutable_inner(f, proc, scope) + shard_running_dict[proc] = thunk + end + shard_dict = Dict{Processor,Chunk}() + for proc in procs + shard_dict[proc] = fetch(shard_running_dict[proc])[] + end + return Shard(shard_dict) +end + +"Creates a `Shard`. See [`Dagger.shard`](@ref) for details." +macro shard(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $shard(f; $(opts...)) + end + end +end + +function move(from_proc::Processor, to_proc::Processor, shard::Shard) + # Match either this proc or some ancestor + # N.B. This behavior may bypass the piece's scope restriction + proc = to_proc + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + parent = Dagger.get_parent(proc) + while parent != proc + proc = parent + parent = Dagger.get_parent(proc) + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + end + + throw(KeyError(to_proc)) +end +Base.iterate(s::Shard) = iterate(values(s.chunks)) +Base.iterate(s::Shard, state) = iterate(values(s.chunks), state) +Base.length(s::Shard) = length(s.chunks) diff --git a/src/submission.jl b/src/submission.jl index 4ff4f2294..fffcc577d 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -285,7 +285,13 @@ function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function DTaskMetadata(spec::DTaskSpec) + rt = spec.options.return_type + if rt !== nothing && isconcretetype(rt) && rt !== Any + return DTaskMetadata(rt) + end + return DTaskMetadata(eager_metadata(spec.fargs)) +end function eager_metadata(fargs) f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f @@ -298,6 +304,10 @@ function eager_spawn(spec::DTaskSpec) uid = eager_next_id() future = ThunkFuture() metadata = DTaskMetadata(spec) + # Propagate inferred return type to options so execute! can skip MPI bcast + if isconcretetype(metadata.return_type) + spec.options.return_type = metadata.return_type + end return DTask(uid, future, metadata) end @@ -320,10 +330,16 @@ function eager_launch!(pair::DTaskPair) end end + # Propagate DTask return_type into options so the created Thunk has chunktype for downstream inference + options = spec.options + if isconcretetype(task.metadata.return_type) + options = copy(options) + options.return_type = task.metadata.return_type + end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - fargs, spec.options, true)) + fargs, options, true)) task.thunk_ref = thunk_id.ref end # FIXME: Don't convert Tuple to Vector{Argument} @@ -353,7 +369,13 @@ function eager_launch!(pairs::Vector{DTaskPair}) end end end - all_options = Options[pair.spec.options for pair in pairs] + # Propagate DTask return_type into options so created Thunks have chunktype for downstream inference + all_options = Options[ + let opts = pair.spec.options + isconcretetype(pair.task.metadata.return_type) ? (o = copy(opts); o.return_type = pair.task.metadata.return_type; o) : opts + end + for pair in pairs + ] # Submit the tasks #=FIXME:REALLOC=# diff --git a/src/thunk.jl b/src/thunk.jl index e13e299f0..c24e0c329 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -247,6 +247,14 @@ isweak(t) = false Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) +# Use options.return_type when set (e.g. from mpi_propagate_chunk_types! or eager_metadata) +# so that Thunk arguments propagate type to downstream eager_metadata/execute! +function chunktype(t::Thunk) + if t.options !== nothing && t.options.return_type !== nothing && isconcretetype(t.options.return_type) + return t.options.return_type + end + return typeof(t) +end Base.convert(::Type{ThunkSyncdep}, t::WeakThunk) = ThunkSyncdep(nothing, t) ThunkSyncdep(t::WeakThunk) = ThunkSyncdep(nothing, t) @@ -462,7 +470,7 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) end args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) - if !isempty(kwargs) + if !Base.isempty(kwargs) kwargs = only(kwargs).args end if body !== nothing @@ -530,7 +538,7 @@ function spawn(f, args...; kwargs...) @nospecialize f args kwargs # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] @@ -545,7 +553,7 @@ function spawn(f, args...; kwargs...) end function typed_spawn(f, args...; kwargs...) # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] diff --git a/src/tochunk.jl b/src/tochunk.jl new file mode 100644 index 000000000..ff15e426e --- /dev/null +++ b/src/tochunk.jl @@ -0,0 +1,119 @@ +@warn "Update tochunk docstring" maxlog=1 +""" + tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk + +Create a chunk from data `x` which resides on `proc` and which has scope +`scope`. + +`device` specifies a `MemPool.StorageDevice` (which is itself wrapped in a +`Chunk`) which will be used to manage the reference contained in the `Chunk` +generated by this function. If `device` is `nothing` (the default), the data +will be inspected to determine if it's safe to serialize; if so, the default +MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will +be used. + +`type` can be specified manually to force the type to be `Chunk{type}`. + +If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a +new `Chunk`. + +All other kwargs are passed directly to `MemPool.poolset`. +""" +tochunk(x::X, proc::P, space::M; kwargs...) where {X,P<:Processor,M<:MemorySpace} = + tochunk(x, proc, space, AnyScope(); kwargs...) +function tochunk(x::X, proc::P, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S,M<:MemorySpace} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if x isa Chunk + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +# Disambiguate: Chunk-specific 3-arg so kwcall(tochunk, Chunk, Processor, Scope) is not ambiguous with utils/chunks.jl +function tochunk(x::Chunk, proc::P, scope::S; rewrap=false, kwargs...) where {P<:Processor,S} + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end +function tochunk(x::X, proc::P, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + space = x.space + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + space = default_memory_space(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +function tochunk(x::X, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,M<:MemorySpace,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + proc = x.processor + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + proc = default_processor(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),typeof(proc),S,M}(type, domain(x), ref, proc, scope, space) +end +# 2-arg: avoid overwriting utils/chunks.jl's tochunk(Any, Any) and tochunk(Any); only add Processor/MemorySpace variants +# Chunk + Processor: disambiguate vs utils/chunks.jl's tochunk(x::Chunk, proc; ...) +tochunk(x::Chunk, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, space::MemorySpace; kwargs...) = tochunk(x, space, AnyScope(); kwargs...) + +check_proc_space(x, proc, space) = nothing +function check_proc_space(x::Chunk, proc, space) + if x.space !== space + throw(ArgumentError("Memory space mismatch: Chunk=$(x.space) != Requested=$space")) + end +end +function check_proc_space(x::Thunk, proc, space) + # FIXME: Validate +end +function maybe_rewrap(x, proc, space, scope; type, rewrap) + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end + +tochunk_pset(x, space::MemorySpace; device=nothing, type=nothing, kwargs...) = poolset(x; device, kwargs...) + +# savechunk: defined in utils/chunks.jl (fork Chunk has space field; do not duplicate here) diff --git a/src/types/acceleration.jl b/src/types/acceleration.jl new file mode 100644 index 000000000..b647dd303 --- /dev/null +++ b/src/types/acceleration.jl @@ -0,0 +1 @@ +abstract type Acceleration end \ No newline at end of file diff --git a/src/types/chunk.jl b/src/types/chunk.jl new file mode 100644 index 000000000..9b8102a6d --- /dev/null +++ b/src/types/chunk.jl @@ -0,0 +1,27 @@ +""" + Chunk + +A reference to a piece of data located on a remote worker. `Chunk`s are +typically created with `Dagger.tochunk(data)`, and the data can then be +accessed from any worker with `collect(::Chunk)`. `Chunk`s are +serialization-safe, and use distributed refcounting (provided by +`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, +as long as a reference exists on some worker. + +Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a +sense) the processor that "owns" or contains the data. Calling +`collect(::Chunk)` will perform data movement and conversions defined by that +processor to safely serialize the data to the calling worker. + +## Constructors +See [`tochunk`](@ref). +""" + +mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope, M<:MemorySpace} + chunktype::Type{T} + domain + handle::H + processor::P + scope::S + space::M +end diff --git a/src/types/memory-space.jl b/src/types/memory-space.jl new file mode 100644 index 000000000..247ceccb0 --- /dev/null +++ b/src/types/memory-space.jl @@ -0,0 +1 @@ +abstract type MemorySpace end \ No newline at end of file diff --git a/src/types/processor.jl b/src/types/processor.jl new file mode 100644 index 000000000..1e333413f --- /dev/null +++ b/src/types/processor.jl @@ -0,0 +1,2 @@ +# Docstring for Processor is attached in src/processor.jl after OSProc is defined (avoids "Replacing docs" warning). +abstract type Processor end \ No newline at end of file diff --git a/src/types/scope.jl b/src/types/scope.jl new file mode 100644 index 000000000..0197fddf9 --- /dev/null +++ b/src/types/scope.jl @@ -0,0 +1 @@ +abstract type AbstractScope end \ No newline at end of file diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 9f0c3b487..1300a5a1d 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -161,7 +161,8 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); device=nothing, re end end ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope) + space = memory_space(proc) + Chunk{X,typeof(ref),P,S,typeof(space)}(X, domain(x), ref, proc, scope, space) end function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) if rewrap @@ -185,5 +186,6 @@ function savechunk(data, dir, f) fr = FileRef(f, sz) proc = OSProc() scope = AnyScope() # FIXME: Scoped to this node - Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope)}(typeof(data), domain(data), fr, proc, scope, true) + space = memory_space(proc) + Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope),typeof(space)}(typeof(data), domain(data), fr, proc, scope, space) end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 873e47e79..678445051 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -59,4 +59,7 @@ macro opcounter(category, count=1) end end) end -opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] + +# No-op debug helper for tracking largest values (used alongside @opcounter) +largest_value_update!(::Any) = nothing \ No newline at end of file diff --git a/src/weakchunk.jl b/src/weakchunk.jl new file mode 100644 index 000000000..e31070536 --- /dev/null +++ b/src/weakchunk.jl @@ -0,0 +1,23 @@ +struct WeakChunk + wid::Int + id::Int + x::WeakRef +end + +function WeakChunk(c::Chunk) + return WeakChunk(c.handle.owner, c.handle.id, WeakRef(c)) +end + +unwrap_weak(c::WeakChunk) = c.x.value +function unwrap_weak_checked(c::WeakChunk) + cw = unwrap_weak(c) + @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" + return cw +end +wrap_weak(c::Chunk) = WeakChunk(c) +isweak(c::WeakChunk) = true +isweak(c::Chunk) = false +is_task_or_chunk(c::WeakChunk) = true +Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = + error("Cannot serialize a WeakChunk") +chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/test/mpi.jl b/test/mpi.jl new file mode 100644 index 000000000..a84ffdce1 --- /dev/null +++ b/test/mpi.jl @@ -0,0 +1,70 @@ +using Dagger +using MPI +using LinearAlgebra +using SparseArrays + +Dagger.accelerate!(:mpi) + +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +size = MPI.Comm_size(comm) + +# Use a large array (adjust size as needed for your RAM) +N = 100 +tag = 123 + +if rank == 0 + arr = sprand(N, N, 0.6) +else + arr = spzeros(N, N) +end + +# --- Out-of-place broadcast --- +function bcast_outofplace() + MPI.Barrier(comm) + if rank == 0 + Dagger.bcast_send_yield(arr, comm, 0, tag+1) + else + Dagger.bcast_recv_yield(comm, 0, tag+1) + end + MPI.Barrier(comm) +end +# --- In-place broadcast --- + +function bcast_inplace() + MPI.Barrier(comm) + if rank == 0 + Dagger.bcast_send_yield!(arr, comm, 0, tag) + else + Dagger.bcast_recv_yield!(arr, comm, 0, tag) + end + MPI.Barrier(comm) +end + +function bcast_inplace_metadata() + MPI.Barrier(comm) + if rank == 0 + Dagger.bcast_send_yield_metadata(arr, comm, 0) + end + MPI.Barrier(comm) +end + + +inplace = @time bcast_inplace() + + +MPI.Barrier(comm) +MPI.Finalize() + + + + +#= +A = rand(Blocks(2,2), 4, 4) +Ac = collect(A) +println(Ac) + + +move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2]) +=# + From f01909cd8ec93b2f74dc46cc53d3e6fcedecfcdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Guimar=C3=A3es?= Date: Thu, 23 Apr 2026 19:30:09 -0300 Subject: [PATCH 3/6] add MPI test --- src/datadeps/remainders.jl | 6 --- src/sch/Sch.jl | 18 +++---- src/sch/util.jl | 3 +- test/mpi.jl | 104 ++++++++++++++++--------------------- 4 files changed, 57 insertions(+), 74 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index af4b8a13c..88201c621 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -145,14 +145,8 @@ function compute_remainder_for_arg!(state::DataDepsState, nspans = length(first(target_ainfos)) # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) -<<<<<<< HEAD - for (_, space, _) in state.arg_history[arg_w] - if !in(space, spaces) -======= for entry in state.arg_history[arg_w] if !in(entry.space, spaces) - @opcounter :compute_remainder_for_arg_restart ->>>>>>> 85e0b801 (MPI: Optimizations and fix some uniformity issues) @goto restart end end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 3fcd87070..3b8688a16 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -93,7 +93,7 @@ struct ComputeState worker_chans::Dict{Int,Tuple{RemoteChannel,RemoteChannel}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} - worker_transfer_rate::Dict{Int,Dict{Processor,UInt64}} + worker_transfer_rate::Dict{Processor,Dict{Processor,UInt64}} halt::Base.Event lock::ReentrantLock futures::Dict{Thunk, Vector{ThunkFuture}} @@ -115,14 +115,14 @@ function start_state(deps::Dict, node_order, chan) Dict{Int, WeakThunk}(), node_order, WeakKeyDict{Any,Chunk}(), - Dict{Int,Dict{Processor,UInt64}}(), - Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), - Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), - Dict{Int,NTuple{3,Float64}}(), - Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(), + Dict{Processor,Dict{Processor,UInt64}}(), + Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}}(), + Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}}(), + Dict{Processor,NTuple{3,Float64}}(), + Dict{Processor,Tuple{RemoteChannel,RemoteChannel}}(), Dict{Signature,UInt64}(), Dict{Signature,UInt64}(), - Dict{Int,Dict{Processor,UInt64}}(), + Dict{Processor,Dict{Processor,UInt64}}(), Base.Event(), ReentrantLock(), Dict{Thunk, Vector{ThunkFuture}}(), @@ -1189,7 +1189,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end task, occupancy = peek(queue) scope = task.scope - accel = something(task.options.acceleration, Dagger.DistributedAcceleration()) + accel = something(task.options.acceleration, Dagger.DistributedAcceleration()) if Dagger.proc_in_scope(to_proc, scope) && Dagger.accel_matches_proc(accel, to_proc) typemax(UInt32) - proc_occupancy_cached >= occupancy # Compatible, steal this task @@ -1400,7 +1400,7 @@ function do_tasks(to_proc, return_queue, tasks) end @dagdebug nothing :processor "Kicked processors" end - + const SCHED_MOVE = ScopedValue{Bool}(false) """ diff --git a/src/sch/util.jl b/src/sch/util.jl index 8e36c3576..38b767588 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -594,8 +594,9 @@ const DEFAULT_TRANSFER_RATE = UInt64(1_000_000) # Add fixed cost for cross-worker task transfer (esimated at 1ms) # TODO: Actually estimate/benchmark this task_xfer_cost = root_worker_id(gproc) != myid() ? 1_000_000 : 0 # 1ms + pid = Dagger.root_worker_id(gproc) - tx_rate = get(get(state.worker_transfer_rate, gproc.pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) + tx_rate = get(get(state.worker_transfer_rate, pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost end chunks_cleanup() diff --git a/test/mpi.jl b/test/mpi.jl index a84ffdce1..c6d2cbae3 100644 --- a/test/mpi.jl +++ b/test/mpi.jl @@ -1,70 +1,58 @@ -using Dagger -using MPI -using LinearAlgebra -using SparseArrays - +using Dagger, MPI, LinearAlgebra Dagger.accelerate!(:mpi) - comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) -size = MPI.Comm_size(comm) - -# Use a large array (adjust size as needed for your RAM) -N = 100 -tag = 123 +sz = MPI.Comm_size(comm) + +mpidagger_all_results = [] + +# Define constants +# You need to define the MPI workers before running the benchmark +# Example: mpirun -n 4 julia --project benchmarks/DaggerMPI_Weak_scale.jl +datatype = [Float32, Float64] +datasize = 40 +try + for T in datatype + A = rand(T, datasize, datasize) + A = A * A' + A[diagind(A)] .+= size(A, 1) + B = copy(A) + @assert ishermitian(B) + DA = distribute(A, Blocks(20,20)) + DB = distribute(B, Blocks(20,20)) + + LinearAlgebra._chol!(DA, UpperTriangular) + elapsed_time = @elapsed chol_DB = LinearAlgebra._chol!(DB, UpperTriangular) + + # Store results + result = ( + procs = sz, + dtype = T, + size = datasize, + time = elapsed_time, + gflops = (datasize^3 / 3) / (elapsed_time * 1e9) + ) + push!(mpidagger_all_results, result) -if rank == 0 - arr = sprand(N, N, 0.6) -else - arr = spzeros(N, N) -end -# --- Out-of-place broadcast --- -function bcast_outofplace() - MPI.Barrier(comm) - if rank == 0 - Dagger.bcast_send_yield(arr, comm, 0, tag+1) - else - Dagger.bcast_recv_yield(comm, 0, tag+1) end - MPI.Barrier(comm) -end -# --- In-place broadcast --- - -function bcast_inplace() - MPI.Barrier(comm) +catch e if rank == 0 - Dagger.bcast_send_yield!(arr, comm, 0, tag) - else - Dagger.bcast_recv_yield!(arr, comm, 0, tag) + showerror(stdout, e) end - MPI.Barrier(comm) end +if rank == 0 + #= Write results to CSV + mkpath("benchmarks/results") + if !isempty(mpidagger_all_results) + df = DataFrame(mpidagger_all_results) + CSV.write("benchmarks/results/DaggerMPI_Weak_scale_results.csv", df) -function bcast_inplace_metadata() - MPI.Barrier(comm) - if rank == 0 - Dagger.bcast_send_yield_metadata(arr, comm, 0) end - MPI.Barrier(comm) + =# + # Summary statistics + for result in mpidagger_all_results + println(result.procs, ",", result.dtype, ",", result.size, ",", result.time, ",", result.gflops) + end + #println("\nAll Cholesky tests completed!") end - - -inplace = @time bcast_inplace() - - -MPI.Barrier(comm) -MPI.Finalize() - - - - -#= -A = rand(Blocks(2,2), 4, 4) -Ac = collect(A) -println(Ac) - - -move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2]) -=# - From 63fa20122b03e93f98d7f127d8e7549a85c95b17 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 29 Apr 2026 10:18:00 -0700 Subject: [PATCH 4/6] MPI fixups --- LocalPreferences.toml | 10 - src/Dagger.jl | 1 + src/acceleration.jl | 46 ++++ src/array/darray.jl | 2 +- src/chunks.jl | 6 +- src/datadeps/aliasing.jl | 529 +++++++++++++++++++------------------ src/datadeps/chunkview.jl | 49 ++-- src/datadeps/queue.jl | 371 +++++--------------------- src/datadeps/remainders.jl | 272 ++++++++++++------- src/datadeps/scheduling.jl | 6 +- src/dtask.jl | 10 +- src/memory-spaces.jl | 345 ++++++++++++++++++------ src/mpi.jl | 170 ++++++++---- src/queue.jl | 2 +- src/sch/Sch.jl | 51 ++-- src/sch/util.jl | 8 +- src/submission.jl | 2 +- src/types/acceleration.jl | 4 +- test/mpi.jl | 28 +- 19 files changed, 1027 insertions(+), 885 deletions(-) delete mode 100644 LocalPreferences.toml create mode 100644 src/acceleration.jl diff --git a/LocalPreferences.toml b/LocalPreferences.toml deleted file mode 100644 index 3a11c113f..000000000 --- a/LocalPreferences.toml +++ /dev/null @@ -1,10 +0,0 @@ -# When using system MPI, run once in the environment where you run MPI jobs (with MPI module loaded): -# julia --project=Dagger.jl -e 'using MPIPreferences; MPIPreferences.use_system_binary()' -# That populates abi, libmpi, mpiexec and avoids "Unknown MPI ABI nothing". -[MPIPreferences] -_format = "1.0" -abi = "MPICH" -binary = "system" -libmpi = "libmpi" -mpiexec = "mpiexec" -preloads = [] diff --git a/src/Dagger.jl b/src/Dagger.jl index 1b3791274..1a5720784 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -91,6 +91,7 @@ abstract type MemorySpace end include("utils/memory-span.jl") include("utils/interval_tree.jl") include("memory-spaces.jl") +include("acceleration.jl") # Task scheduling include("compute.jl") diff --git a/src/acceleration.jl b/src/acceleration.jl new file mode 100644 index 000000000..f95236468 --- /dev/null +++ b/src/acceleration.jl @@ -0,0 +1,46 @@ +const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) + +current_acceleration() = ACCELERATION[] + +default_processor(::DistributedAcceleration) = OSProc(myid()) +default_processor(accel::DistributedAcceleration, x) = default_processor(accel) +default_processor() = default_processor(current_acceleration()) + +accelerate!(accel::Symbol) = accelerate!(Val{accel}()) +accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) + +function _with_default_acceleration(f) + old_accel = ACCELERATION[] + ACCELERATION[] = DistributedAcceleration() + result = try + f() + finally + ACCELERATION[] = old_accel + end + return result +end + +initialize_acceleration!(a::DistributedAcceleration) = nothing +function accelerate!(accel::Acceleration) + initialize_acceleration!(accel) + ACCELERATION[] = accel +end +accelerate!(::Nothing) = nothing + +accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true +accel_matches_proc(accel::DistributedAcceleration, proc) = true + +function compatible_processors(accel::Union{Acceleration,Nothing}, scope::AbstractScope, procs::Vector{<:Processor}) + comp = compatible_processors(scope, procs) + accel === nothing && return comp + return Set(p for p in comp if accel_matches_proc(accel, p)) +end + +uniform_execution(::DistributedAcceleration) = false +uniform_execution() = uniform_execution(current_acceleration()) + +default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) +default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) +default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) +default_memory_space(x) = default_memory_space(current_acceleration(), x) +default_memory_space() = default_memory_space(current_acceleration()) diff --git a/src/array/darray.jl b/src/array/darray.jl index 7e723acd0..fc99dc75d 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -36,7 +36,7 @@ Base.getindex(arr::AbstractArray{T,0} where T, d::ArrayDomain{0}) = arr Base.getindex(arr::GPUArraysCore.AbstractGPUArray, d::ArrayDomain) = arr[indexes(d)...] Base.getindex(arr::GPUArraysCore.AbstractGPUArray{T,0} where T, d::ArrayDomain{0}) = arr -function intersect(a::ArrayDomain, b::ArrayDomain) +function Base.intersect(a::ArrayDomain, b::ArrayDomain) if a === b return a end diff --git a/src/chunks.jl b/src/chunks.jl index 0defc1ff6..d5e7b6082 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -28,15 +28,15 @@ collect(ctx::Context, ref::DRef; options=nothing) = collect(ctx::Context, ref::FileRef; options=nothing) = poolget(ref) # FIXME: Do move call @warn "Fix semantics of collect" maxlog=1 -function Base.fetch(chunk::Chunk{T}; unwrap::Bool=false, uniform::Bool=false, kwargs...) where T +function Base.fetch(chunk::Chunk{T}; unwrap::Bool=false, uniform::Bool=uniform_execution(), kwargs...) where T value = fetch_handle(chunk.handle; uniform)::T if unwrap && unwrappable(value) return fetch(value; unwrap, uniform, kwargs...) end return value end -fetch_handle(ref::DRef; uniform::Bool=false) = poolget(ref) -fetch_handle(ref::FileRef; uniform::Bool=false) = poolget(ref) +fetch_handle(ref::DRef; uniform::Bool) = poolget(ref) +fetch_handle(ref::FileRef; uniform::Bool) = poolget(ref) unwrappable(x::Chunk) = true unwrappable(x::DRef) = true unwrappable(x::FileRef) = true diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index e9ff24a79..848443e8e 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -8,7 +8,7 @@ export In, Out, InOut, Deps, spawn_datadeps ============================================================================== This file implements the data dependencies system for Dagger tasks, which allows -tasks to write to their arguments in a controlled manner. The system maintains +tasks to access their arguments in a controlled manner. The system maintains data coherency across distributed workers by tracking aliasing relationships and orchestrating data movement operations. @@ -25,26 +25,59 @@ KEY CONCEPTS: 1. ALIASING ANALYSIS: - Every mutable argument is analyzed for its memory access pattern - Memory spans are computed to determine which bytes in memory are accessed - - Objects that access overlapping memory spans are considered "aliasing" + - Arguments that access overlapping memory spans are considered "aliasing" - Examples: An array A and view(A, 2:3, 2:3) alias each other 2. DATA LOCALITY TRACKING: - The system tracks where the "source of truth" for each piece of data lives - As tasks execute and modify data, the source of truth may move between workers - - Each aliasing region can have its own independent source of truth location + - Each argument can have its own independent source of truth location 3. ALIASED OBJECT MANAGEMENT: - When copying arguments between workers, the system tracks "aliased objects" - This ensures that if both an array and its view need to be copied to a worker, only one copy of the underlying array is made, with the view pointing to it - - The aliased_object!() functions manage this sharing + - The aliased_object!() and move_rewrap() functions manage this sharing + +ALIASING INFO: +-------------- + +The system uses different types of aliasing info to represent different types of +aliasing relationships: + +- ContiguousAliasing: Single contiguous memory region (e.g., full array) +- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) +- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) +- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) + +Any two aliasing objects can be compared using the will_alias function to +determine if they overlap. Additionally, any aliasing object can be converted to +a vector of memory spans, which represents the contiguous regions of memory that +the aliasing object covers. + +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) +- Supports partial copies via RemainderAliasing dependency modifiers + +move_rewrap(...): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +read/write_remainder!(...): +- Read/write a span of memory from an object to/from a buffer +- Used by move! to copy the remainder of an aliased object THE DISTRIBUTED ALIASING PROBLEM: --------------------------------- In a multithreaded environment, aliasing "just works" because all tasks operate -on the same memory. However, in a distributed environment, arguments must be -copied between workers, which breaks aliasing relationships. +on the user-provided memory. However, in a distributed environment, arguments +must be copied between workers, which breaks aliasing relationships if care is +not taken. Consider this scenario: ```julia @@ -63,11 +96,9 @@ MULTITHREADED BEHAVIOR (WORKS): - Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) DISTRIBUTED BEHAVIOR (THE PROBLEM): -- Tasks may be scheduled on different workers - Each argument must be copied to the destination worker -- Without special handling, we would copy A to worker1 and vA to worker2 -- This creates two separate arrays, breaking the aliasing relationship -- Updates to the view on worker2 don't affect the array on worker1 +- Without special handling, we would copy A and vA independently to another worker +- This creates two separate arrays, breaking the aliasing relationship between A and vA THE SOLUTION - PARTIAL DATA MOVEMENT: ------------------------------------- @@ -81,12 +112,13 @@ The datadeps system solves this by: 2. PARTIAL DATA TRANSFER: - Instead of copying entire objects, only transfer the "dirty" regions - - This minimizes network traffic and maximizes parallelism - - Uses the move!(dep_mod, ...) function with dependency modifiers + - This prevents overwrites of data that has already been updated by another task + - This also minimizes network traffic and overall copy time + - Uses the move!(dep_mod, ...) function with RemainderAliasing dependency modifiers 3. REMAINDER TRACKING: + - When a task needs the full object, copy partial regions as needed - When a partial region is updated, track what parts still need updating - - Before a task needs the full object, copy the remaining "clean" regions - This preserves all updates while avoiding overwrites EXAMPLE EXECUTION FLOW: @@ -108,69 +140,24 @@ Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) - T2 needs vA, but vA aliases with A (which was modified by T1) - Copy vA-region of A from worker1 to worker2 - This is a PARTIAL copy - only the 2:3, 2:3 region - - Create vA on worker2 pointing to the appropriate region + - Create vA on worker2 pointing to the appropriate region of A - T2 executes, modifying vA region on worker2 - Update: vA's data_locality = worker2 4. FINAL SYNCHRONIZATION: - - Some future task needs the complete A - - A needs to be assembled from: worker1 (non-vA regions) + worker2 (vA region) - - REMAINDER COPY: Copy non-vA regions from worker1 to worker2 - - OR INVERSE: Copy vA-region from worker2 to worker1, then copy full A - -MEMORY SPAN COMPUTATION: ------------------------- - -The system uses memory spans to determine aliasing and compute remainders: + - Need to copy-back A and vA to worker0 + - A needs to be assembled from: worker1 (non-vA regions of A) + worker2 (vA region of A) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker0 + - REMAINDER COPY: Copy vA region from worker2 to worker0 -- ContiguousAliasing: Single contiguous memory region (e.g., full array) -- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) -- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) -- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) +REMAINDER COMPUTATION: +---------------------- Remainder computation involves: 1. Computing memory spans for all overlapping aliasing objects 2. Finding the set difference: full_object_spans - updated_spans -3. Creating a "remainder aliasing" object representing the not-yet-updated regions -4. Performing move! with this remainder object to copy only needed data - -DATA MOVEMENT FUNCTIONS: ------------------------- - -move!(dep_mod, to_space, from_space, to, from): -- The core in-place data movement function -- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) -- Supports partial copies via dependency modifiers - -move_rewrap(): -- Handles copying of wrapped objects (SubArrays, ChunkViews) -- Ensures aliased objects are reused on destination worker - -enqueue_copy_to!(): -- Schedules data movement tasks before user tasks -- Ensures data is up-to-date on the worker where a task will run - -CURRENT LIMITATIONS AND TODOS: -------------------------------- - -1. REMAINDER COMPUTATION: - - The system currently handles simple overlaps but needs sophisticated - remainder calculation for complex aliasing patterns - - Need functions to compute span set differences - -2. ORDERING DEPENDENCIES: - - Need to ensure remainder copies happen in correct order - - Must not overwrite more recent updates with stale data - -3. COMPLEX ALIASING PATTERNS: - - Multiple overlapping views of the same array - - Nested aliasing structures (views of views) - - Mixed aliasing types (diagonal + triangular regions) - -4. PERFORMANCE OPTIMIZATION: - - Minimize number of copy operations - - Batch compatible transfers - - Optimize for common access patterns +3. Creating a RemainderAliasing object representing the difference between spans +4. Performing one or more move! calls with this RemainderAliasing object to copy only needed data =# "Specifies a read-only dependency." @@ -192,6 +179,11 @@ struct Deps{T,DT<:Tuple} end Deps(x, deps...) = Deps(x, deps) +chunktype(::In{T}) where T = T +chunktype(::Out{T}) where T = T +chunktype(::InOut{T}) where T = T +chunktype(::Deps{T,DT}) where {T,DT} = T + function unwrap_inout(arg) readdep = false writedep = false @@ -226,7 +218,6 @@ _identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) _identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) _identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) -@warn "Dispatch bcast behavior on acceleration" maxlog=1 struct ArgumentWrapper arg dep_mod @@ -252,6 +243,7 @@ struct HistoryEntry end struct AliasedObjectCacheStore + accel::Acceleration keys::Vector{AbstractAliasing} derived::Dict{AbstractAliasing,AbstractAliasing} stored::Dict{MemorySpace,Set{AbstractAliasing}} @@ -259,7 +251,8 @@ struct AliasedObjectCacheStore originals::Set{AbstractAliasing} end AliasedObjectCacheStore() = - AliasedObjectCacheStore(Vector{AbstractAliasing}(), + AliasedObjectCacheStore(current_acceleration(), + Vector{AbstractAliasing}(), Dict{AbstractAliasing,AbstractAliasing}(), Dict{MemorySpace,Set{AbstractAliasing}}(), Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}(), @@ -289,7 +282,7 @@ end function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) @assert !is_stored(cache, dest_space, ainfo) "Cache already has derived ainfo $ainfo" key = cache.derived[ainfo] - value_ainfo = aliasing(value, identity) + value_ainfo = aliasing(cache.accel, value, identity) cache.derived[value_ainfo] = key push!(get!(Set{AbstractAliasing}, cache.stored, dest_space), key) values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, dest_space) @@ -306,6 +299,7 @@ function set_key_stored!(cache::AliasedObjectCacheStore, space::MemorySpace, ain end struct AliasedObjectCache + accel::Acceleration space::MemorySpace chunk::Chunk end @@ -350,7 +344,7 @@ function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::A cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore set_key_stored!(cache_raw, space, ainfo, value) end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(current_acceleration(), x, identity)) +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(cache.accel, x, identity)) x_space = memory_space(x) if !is_key_present(cache, x_space, ainfo) # Preserve the object's memory-space/processor pairing when inserting @@ -366,14 +360,13 @@ function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(current @assert y isa Chunk "Didn't get a Chunk from functor" @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" if memory_space(x) != cache.space - @assert ainfo != aliasing(current_acceleration(), y, identity) "Aliasing mismatch! $ainfo == $(aliasing(current_acceleration(), y, identity))" + @assert ainfo != aliasing(caache.accel, y, identity) "Aliasing mismatch! $ainfo == $(aliasing(cache.accel, y, identity))" end set_stored!(cache, y, ainfo) return y end end -@warn "Switch ArgumentWrapper to contain just the argument, and add DependencyWrapper" maxlog=1 struct DataDepsState # The mapping of original raw argument to its Chunk raw_arg_to_chunk::IdDict{Any,Chunk} @@ -389,10 +382,13 @@ struct DataDepsState # The mapping of remote argument to original argument remote_arg_to_original::IdDict{Any,Any} + # The mapping of original argument wrapper to remote argument wrapper + remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} + # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B. This is a mapping for remote argument copies - ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} + ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to # Updated when a new write happens on an overlapping ainfo @@ -410,7 +406,7 @@ struct DataDepsState # The mapping of, for a given memory space, the backing Chunks that an ainfo references # Used by slot generation to replace the backing Chunks during move - ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} + ainfo_backing_chunk::Chunk{AliasedObjectCacheStore} # Cache of argument's supports_inplace_move query result supports_inplace_cache::IdDict{Any,Bool} @@ -419,6 +415,10 @@ struct DataDepsState # N.B. This is a mapping for remote argument copies ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + # The oracle for aliasing lookups + # Used to populate ainfos_overlaps efficiently + ainfos_lookup::AliasingLookup + # The overlapping ainfos for each ainfo # Incrementally updated as new ainfos are created # Used for fast will_alias lookups @@ -430,58 +430,32 @@ struct DataDepsState ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - function DataDepsState(aliasing::Bool) - if !aliasing - @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 - end - + function DataDepsState() arg_to_chunk = IdDict{Any,Chunk}() arg_origin = IdDict{Any,MemorySpace}() remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() - ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() + remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() + ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() - ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() - arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + ainfo_backing_chunk = _with_default_acceleration() do + tochunk(AliasedObjectCacheStore()) + end supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + ainfos_lookup = AliasingLookup() ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, - supports_inplace_cache, ainfo_cache, ainfos_overlaps, ainfos_owner, ainfos_readers) - end -end - -# N.B. arg_w must be the original argument wrapper, not a remote copy -function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) - # Grab the remote copy of the argument, and calculate the ainfo - remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) - remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) - - # Check if we already have the result cached - if haskey(state.ainfo_cache, remote_arg_w) - return state.ainfo_cache[remote_arg_w] + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_history, arg_owner, arg_overlaps, ainfo_backing_chunk, + supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) end - - # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(current_acceleration(), remote_arg, arg_w.dep_mod)) - - # Cache the result - state.ainfo_cache[remote_arg_w] = ainfo - - # Update the mapping of ainfo to argument and dep_mod - state.ainfo_arg[ainfo] = remote_arg_w - - # Populate info for the new ainfo - populate_ainfo!(state, arg_w, ainfo, target_space) - - return ainfo end function supports_inplace_move(state::DataDepsState, arg) @@ -497,70 +471,75 @@ function is_writedep(arg, deps, task::DTask) end # Aliasing state setup -# Internal: iterate over task args and call callback(arg, pos, may_alias, inplace_move, deps) for each tracked arg. -function _populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask, callback) - for (idx, _arg) in enumerate(spec.fargs) - arg_pos = _arg.pos # ArgPosition for this argument (Argument/TypedArgument have .pos) - arg = value(_arg) +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns + return map_or_ntuple(task_args) do idx + _arg = task_args[idx] + + # Unwrap the argument + _arg_with_deps = value(_arg) + pos = _arg.pos # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(arg) - - # Unwrap the Chunk underlying any DTask arguments only when already ready. - # Fetching an unready DTask here would deadlock: distribute_tasks! runs before - # the scheduler, so dependent tasks have not run yet. Skip aliasing for unready - # DTasks so we pass them through; the worker will fetch at execution time (may block on MPI). - if arg isa DTask - isready(arg) || continue - arg = fetch(arg; move_value=false, unwrap=false) + arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) + + # Unwrap the Chunk underlying any DTask arguments + arg = arg_pre_unwrap isa DTask ? fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap + + # Skip non-aliasing arguments or arguments that don't support in-place move + may_alias = type_may_alias(typeof(arg)) + inplace_move = may_alias && supports_inplace_move(state, arg) + if !may_alias || !inplace_move + arg_w = ArgumentWrapper(arg, identity) + if is_typed(spec) + return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) + else + return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) + end end - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Skip arguments not supporting in-place move - supports_inplace_move(state, arg) || continue - # Generate a Chunk for the argument if necessary if haskey(state.raw_arg_to_chunk, arg) - arg = state.raw_arg_to_chunk[arg] + arg_chunk = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - new_arg = with(MPI_UID=>task.uid) do + arg_chunk = with(MPI_TID=>task.uid) do tochunk(arg) end - state.raw_arg_to_chunk[arg] = new_arg - arg = new_arg + state.raw_arg_to_chunk[arg] = arg_chunk else state.raw_arg_to_chunk[arg] = arg + arg_chunk = arg end end # Track the origin space of the argument - origin_space = memory_space(arg) + origin_space = memory_space(arg_chunk) check_uniform(origin_space) - state.arg_origin[arg] = origin_space - state.remote_arg_to_original[arg] = arg - - may_alias = true - inplace_move = true - callback(arg, arg_pos, may_alias, inplace_move, deps) + state.arg_origin[arg_chunk] = origin_space + state.remote_arg_to_original[arg_chunk] = arg_chunk # Populate argument info for all aliasing dependencies - for (dep_mod, _, _) in deps - # Generate an ArgumentWrapper for the argument - aw = ArgumentWrapper(arg, dep_mod) - - # Populate argument info - populate_argument_info!(state, aw, origin_space) + # And return the argument, dependencies, and ArgumentWrappers + if is_typed(spec) + deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + else + deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) end end end - -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) - # Track the task's arguments and access patterns (callback only for state updates) - _populate_task_info!(state, spec, task, (arg, pos, may_alias, inplace_move, deps) -> nothing) -end function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) # Initialize ownership and history if !haskey(state.arg_owner, arg_w) @@ -580,23 +559,56 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o # Calculate the ainfo (which will populate ainfo structures and merge history) aliasing!(state, origin_space, arg_w) end +# N.B. arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) + remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] + remote_arg = remote_arg_w.arg + else + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w + end + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(current_acceleration(), remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + if !haskey(state.ainfo_arg, ainfo) + state.ainfo_arg[ainfo] = Set{ArgumentWrapper}([remote_arg_w]) + end + push!(state.ainfo_arg[ainfo], remote_arg_w) + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) - # Initialize owner and readers if !haskey(state.ainfos_owner, target_ainfo) + # Add ourselves to the lookup oracle + ainfo_idx = push!(state.ainfos_lookup, target_ainfo) + + # Find overlapping ainfos overlaps = Set{AliasingWrapper}() push!(overlaps, target_ainfo) - other_ainfos = (Dagger.current_acceleration() isa Dagger.MPIAcceleration - ? sort(collect(keys(state.ainfos_owner)), by=hash) - : keys(state.ainfos_owner)) - for other_ainfo in other_ainfos + for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) target_ainfo == other_ainfo && continue - if will_alias(target_ainfo, other_ainfo) - # Mark us and them as overlapping - push!(overlaps, other_ainfo) - push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) - # Add overlapping history to our own - other_remote_arg_w = state.ainfo_arg[other_ainfo] + # Add overlapping history to our own + for other_remote_arg_w in state.ainfo_arg[other_ainfo] other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) push!(state.arg_overlaps[original_arg_w], other_arg_w) @@ -605,13 +617,16 @@ function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, end end state.ainfos_overlaps[target_ainfo] = overlaps + + # Initialize owner and readers state.ainfos_owner[target_ainfo] = nothing state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] end end function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper) history = state.arg_history[arg_w] - largest_value_update!(length(history)) + @opcounter :merge_history + @opcounter :merge_history_complexity length(history) origin_space = state.arg_origin[other_arg_w.arg] for other_entry in state.arg_history[other_arg_w] write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) @@ -640,10 +655,13 @@ function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_ end end function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper) + # FIXME: Do this continuously if possible if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000 origin_space = state.arg_origin[arg_w.arg] + @opcounter :truncate_history _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false) if last_idx > 0 + @opcounter :truncate_history_removed last_idx deleteat!(state.arg_history[arg_w], 1:last_idx) end end @@ -661,8 +679,8 @@ use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` region returns. """ supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; move_value=false, unwrap=false)) -@warn "Fix this to work with MPI (can't call poolget on the wrong rank)" maxlog=1 +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +@warn "Fix supports_inplace_move for MPI" maxlog=1 function supports_inplace_move(c::Chunk) # FIXME return true @@ -737,46 +755,39 @@ function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::M push!(state.ainfos_readers[ainfo], task=>write_num) end -# FIXME: These should go in MPIExt.jl -const MPI_TID = ScopedValue{Int64}(0) -const MPI_UID = ScopedValue{Int64}(0) - # Make a copy of each piece of data on each worker # memory_space => {arg => copy_of_arg} isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true +@warn "Properly propagate MPI_TID and uniformity through any remotecalls" maxlog=1 function generate_slot!(state::DataDepsState, dest_space, data) - if data isa DTask - data = fetch(data; move_value=false, unwrap=false) - end # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. orig_space = memory_space(data) + check_uniform(orig_space) to_proc = first(processors(dest_space)) + check_uniform(to_proc) from_proc = first(processors(orig_space)) - dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping - task = DATADEPS_CURRENT_TASK[] - data_chunk = with(MPI_UID=>task.uid) do - tochunk(data, from_proc) - end - else - ctx = Sch.eager_context() - id = rand(Int) - @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + check_uniform(from_proc) + if MPI.Comm_rank(MPI.COMM_WORLD) == 0 + display(typeof(data)) end + check_uniform(typeof(data)) + dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) + aliased_object_cache = AliasedObjectCache(current_acceleration(), dest_space, state.ainfo_backing_chunk) + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = with(MPI_TID=>DATADEPS_CURRENT_TASK[].uid) do + remotecall_endpoint(move_rewrap, current_acceleration(), aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + end + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data - ALIASED_OBJECT_CACHE[] = nothing - check_uniform(memory_space(dest_space_args[data])) check_uniform(processor(dest_space_args[data])) check_uniform(dest_space_args[data].handle) @@ -793,86 +804,78 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - return aliased_object!(data) do data - return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, data) + +function remotecall_fetch_fast(f, wid::Integer, args...; kwargs...) + if wid == myid() + return f(args...; kwargs...) end + return remotecall_fetch(f, wid, args...; kwargs...) end -function remotecall_endpoint(f, ::Dagger.DistributedAcceleration, from_proc, to_proc, orig_space, dest_space, data) - to_w = root_worker_id(to_proc) - return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc, dest_space) +function remotecall_endpoint(f, accel::DistributedAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) + from_w = root_worker_id(from_proc) + return remotecall_fetch_fast(from_w) do + data_raw = unwrap(data) + return f(accel, cache, from_proc, to_proc, from_space, to_space, data_raw)::Chunk end end -const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) - -# Explicit cache for move_rewrap (used by haloarray, tests) -struct AliasedObjectCacheStore end -struct AliasedObjectCache - dest_space::MemorySpace - backing::Chunk - cache::Dict{AbstractAliasing,Chunk} - AliasedObjectCache(dest_space::MemorySpace, backing::Chunk) = new(dest_space, backing, Dict{AbstractAliasing,Chunk}()) -end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - old = ALIASED_OBJECT_CACHE[] - ALIASED_OBJECT_CACHE[] = cache.cache - try - return move_rewrap(from_proc, to_proc, from_space, to_space, data) - finally - ALIASED_OBJECT_CACHE[] = old - end -end - -@warn "Document these public methods" maxlog=1 -# TODO: Use state to cache aliasing() results -function declare_aliased_object!(x; ainfo=aliasing(current_acceleration(), x, identity)) - cache = ALIASED_OBJECT_CACHE[] - cache[ainfo] = x -end -function aliased_object!(x; ainfo=aliasing(current_acceleration(), x, identity)) - cache = ALIASED_OBJECT_CACHE[] - if haskey(cache, ainfo) - y = cache[ainfo] - else - @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" - cache[ainfo] = x - y = x +function remotecall_endpoint_transfer(f, accel::DistributedAcceleration, from_proc, to_proc, from_space, to_space, data) + to_w = root_worker_id(to_proc) + return remotecall_fetch_fast(to_w) do + return f(accel, from_proc, to_proc, from_space, to_space, data) + end +end +@warn "Replace all remotecall_fetch calls with remotecall_endpoint" maxlog=1 +move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) = + remotecall_endpoint(move_rewrap, accel, cache, from_proc, to_proc, from_space, to_space, data) +function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + # Generic data, do the transfer + return aliased_object!(cache, data) do data + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, data) do accel, from_proc, to_proc, from_space, to_space, data + return tochunk(move(from_proc, to_proc, data), to_proc) + end end - return y end -function aliased_object!(f, x; ainfo=aliasing(current_acceleration(), x, identity)) - cache = ALIASED_OBJECT_CACHE[] - if haskey(cache, ainfo) - y = cache[ainfo] - else - y = f(x) - @assert y isa Chunk "Didn't get a Chunk from functor" - cache[ainfo] = y +function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) + to_w = root_worker_id(to_proc) + p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) + check_uniform(p_chunk.handle) + inds = parentindices(v) + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, p_chunk) do accel, from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) + end +end +# FIXME: Do this programmatically via recursive dispatch +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + to_w = root_worker_id(to_proc) + p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) + return remotecall_fetch_fast(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = $(wrapper)(p_new) + return tochunk(v_new, to_proc) + end end - return y -end -function aliased_object_unwrap!(x::Chunk) - y = unwrap(x) - ainfo = aliasing(current_acceleration(), y, identity) - return unwrap(aliased_object!(x; ainfo)) end - -struct DataDepsSchedulerState - task_to_spec::Dict{DTask,DTaskSpec} - assignments::Dict{DTask,MemorySpace} - dependencies::Dict{DTask,Set{DTask}} - task_completions::Dict{DTask,UInt64} - space_completions::Dict{MemorySpace,UInt64} - capacities::Dict{MemorySpace,Int} - - function DataDepsSchedulerState() - return new(Dict{DTask,DTaskSpec}(), - Dict{DTask,MemorySpace}(), - Dict{DTask,Set{DTask}}(), - Dict{DTask,UInt64}(), - Dict{MemorySpace,UInt64}(), - Dict{MemorySpace,Int}()) +#= FIXME: Make this work so we can automatically move-rewrap recursive objects +function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x + else + @warn "Cannot move-rewrap object of type $T" + return x end end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 6e2a21dfd..418987124 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -3,6 +3,10 @@ struct ChunkView{N} slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} end +function _identity_hash(arg::ChunkView, h::UInt=UInt(0)) + return hash(arg.slices, _identity_hash(arg.chunk, h)) +end + function Base.view(c::Chunk, slices...) if c.domain isa ArrayDomain nd, sz = ndims(c.domain), size(c.domain) @@ -25,30 +29,39 @@ function Base.view(c::Chunk, slices...) return ChunkView(c, slices) end -Base.view(c::DTask, slices...) = view(fetch(c; move_value=false, unwrap=false), slices...) +Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) -aliasing(x::ChunkView) = - throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) +function aliasing(accel::Acceleration, x::ChunkView{N}, dep_mod) where N + @assert dep_mod === identity "Dependency modifiers not yet supported for ChunkView: $dep_mod" + return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices + x = unwrap(x) + v = view(x, slices...) + return aliasing(accel, v, dep_mod) + end +end memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true -# This definition is here because it's so similar to ChunkView -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) - p_chunk = aliased_object!(parent(v)) do p_chunk - return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) - end - inds = parentindices(v) - return remotecall_endpoint(current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) do p_new - return view(p_new, inds...) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) + to_w = root_worker_id(to_proc) + # N.B. We use move_rewrap (not rewrap_aliased_object!) so that if the inner + # chunk is a SubArray, it goes through the SubArray-aware path which shares + # the parent array via the aliased object cache. Using rewrap_aliased_object! + # would simply serialize the entire SubArray, creating a new parent copy on + # the destination, breaking aliasing with other views of the same parent. + p_chunk = move_rewrap(cache, from_proc, to_proc, from_space, to_space, slice.chunk) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) end end -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) - p_chunk = aliased_object!(slice.chunk) do p_chunk - return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) - end - inds = slice.slices - return remotecall_endpoint(current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) do p_new - return view(p_new, inds...) +function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices + chunk_new = move(from_proc, to_proc, chunk) + v_new = view(chunk_new, slices...) + return tochunk(v_new, to_proc) end end diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 70e7543eb..3b3ed5185 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -1,21 +1,4 @@ - -const TAG_WAITING = Base.Lockable(Ref{UInt32}(1)) -function to_tag() - intask = Dagger.in_task() - if intask - opts = Dagger.get_tls().task_spec.options - tag = opts.tag - return tag - end - lock(TAG_WAITING) do counter_ref - @assert Sch.SCHED_MOVE[] == false "We should not create a tag on the scheduler unwrap move" - tag = counter_ref[] - counter_ref[] = tag + 1 > MPI.tag_ub() ? 1 : tag + 1 - return tag - end -end - -struct DataDepsTaskQueue <: AbstractTaskQueue +struct DataDepsTaskQueue{Scheduler<:DataDepsScheduler} <: AbstractTaskQueue # The queue above us upper_queue::AbstractTaskQueue # The set of tasks that have already been seen @@ -24,24 +7,14 @@ struct DataDepsTaskQueue <: AbstractTaskQueue g::Union{SimpleDiGraph{Int},Nothing} # The mapping from task to graph ID task_to_id::Union{Dict{DTask,Int},Nothing} - # How to traverse the dependency graph when launching tasks - traversal::Symbol # Which scheduler to use to assign tasks to processors - scheduler::Symbol + scheduler::Scheduler - # Whether aliasing across arguments is possible - # The fields following only apply when aliasing==true - aliasing::Bool - - function DataDepsTaskQueue(upper_queue; - traversal::Symbol=:inorder, - scheduler::Symbol=:naive, - aliasing::Bool=true) + function DataDepsTaskQueue(upper_queue; scheduler::DataDepsScheduler) seen_tasks = DTaskPair[] g = SimpleDiGraph() task_to_id = Dict{DTask,Int}() - return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, - aliasing) + return new{typeof(scheduler)}(upper_queue, seen_tasks, g, task_to_id, scheduler) end end @@ -55,7 +28,7 @@ end const DATADEPS_CURRENT_TASK = TaskLocalValue{Union{DTask,Nothing}}(Returns(nothing)) """ - spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) + spawn_datadeps(f::Base.Callable) Constructs a "datadeps" (data dependencies) region and calls `f` within it. Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or @@ -82,46 +55,42 @@ appropriately. At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks to complete, rethrowing the first error, if any. The result of `f` will be returned from `spawn_datadeps`. - -The keyword argument `traversal` controls the order that tasks are launched by -the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling -or Depth-First Scheduling, respectively. All traversal orders respect the -dependencies and ordering of the launched tasks, but may provide better or -worse performance for a given set of datadeps tasks. This argument is -experimental and subject to change. """ function spawn_datadeps(f::Base.Callable; static::Bool=true, traversal::Symbol=:inorder, - scheduler::Union{Symbol,Nothing}=nothing, + scheduler::Union{DataDepsScheduler,Nothing}=nothing, aliasing::Bool=true, launch_wait::Union{Bool,Nothing}=nothing) if !static throw(ArgumentError("Dynamic scheduling is no longer available")) end + if traversal != :inorder + throw(ArgumentError("Traversal order is no longer configurable, and always :inorder")) + end + if !aliasing + throw(ArgumentError("Aliasing analysis is no longer optional")) + end wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol + scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler()) launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool if launch_wait result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) + queue = DataDepsTaskQueue(get_options(:task_queue); scheduler) with_options(f; task_queue=queue) distribute_tasks!(queue) end else - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) + queue = DataDepsTaskQueue(get_options(:task_queue); scheduler) result = with_options(f; task_queue=queue) distribute_tasks!(queue) end - DATADEPS_CURRENT_TASK[] = nothing return result end end -const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) +const DATADEPS_SCHEDULER = ScopedValue{Union{DataDepsScheduler,Nothing}}(nothing) const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) -@warn "Don't blindly set occupancy=0, only do for MPI" maxlog=1 +@warn "Add reliable, uniform-safe Processor sorting" maxlog=1 function distribute_tasks!(queue::DataDepsTaskQueue) #= TODO: Improvements to be made: # - Support for copying non-AbstractArray arguments @@ -132,7 +101,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue) =# # Get the set of all processors to be scheduled on - scope = get_compute_scope() accel = current_acceleration() accel_procs = filter(procs(Dagger.Sch.eager_context())) do proc Dagger.accel_matches_proc(accel, proc) @@ -140,81 +108,30 @@ function distribute_tasks!(queue::DataDepsTaskQueue) all_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) # FIXME: This is an unreliable way to ensure processor uniformity sort!(all_procs, by=short_name) + scope = get_compute_scope() filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end - exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - #=if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end=# - for proc in all_procs - check_uniform(proc) + if uniform_execution(accel) + for proc in all_procs + check_uniform(proc) + end end + all_scope = UnionScope(map(ExactScope, all_procs)) + exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) # Round-robin assign tasks to processors upper_queue = get_options(:task_queue) - traversal = queue.traversal - if traversal == :inorder - # As-is - task_order = Colon() - elseif traversal == :bfs - # BFS - task_order = Int[1] - to_walk = Int[1] - seen = Set{Int}([1]) - while !isempty(to_walk) - # N.B. next_root has already been seen - next_root = popfirst!(to_walk) - for v in outneighbors(queue.g, next_root) - if !(v in seen) - push!(task_order, v) - push!(seen, v) - push!(to_walk, v) - end - end - end - elseif traversal == :dfs - # DFS (modified with backtracking) - task_order = Int[] - to_walk = Int[1] - seen = Set{Int}() - while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) - next_root = popfirst!(to_walk) - if !(next_root in seen) - iv = inneighbors(queue.g, next_root) - if all(v->v in seen, iv) - push!(task_order, next_root) - push!(seen, next_root) - ov = outneighbors(queue.g, next_root) - prepend!(to_walk, ov) - else - push!(to_walk, next_root) - end - end - end - else - throw(ArgumentError("Invalid traversal mode: $traversal")) - end - - state = DataDepsState(queue.aliasing) - sstate = DataDepsSchedulerState() - for proc in all_procs - space = only(memory_spaces(proc)) - get!(()->0, sstate.capacities, space) - sstate.capacities[space] += 1 - end - # Start launching tasks and necessary copies + state = DataDepsState() write_num = 1 - proc_idx = 1 - #pressures = Dict{Processor,Int}() proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for pair in queue.seen_tasks[task_order] + for pair in queue.seen_tasks spec = pair.spec task = pair.task - write_num, proc_idx = distribute_task!(queue, state, all_procs, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) + write_num = distribute_task!(queue, state, all_procs, all_scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num) end # Copy args from remote to local @@ -234,6 +151,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;)) + @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) end end write_num += 1 @@ -278,174 +199,33 @@ struct TypedDataDepsTaskArgument{T,N} deps::NTuple{N,DataDepsTaskDependency} end map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) -map_or_ntuple(f, xs::Tuple) = ntuple(f, length(xs)) - -# 4-arg version: side effects + returns Vector/Tuple of DataDepsTaskArgument for distribute_task! -function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) - result = DataDepsTaskArgument[] - _populate_task_info!(state, spec, task, (arg, pos, may_alias, inplace_move, deps) -> begin - dep_infos = DataDepsTaskDependency[DataDepsTaskDependency(arg, d) for d in deps] - push!(result, DataDepsTaskArgument(arg, pos, may_alias, inplace_move, dep_infos)) - end) - return spec.fargs isa Tuple ? (result...,) : result -end - -function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed +@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N)) +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, all_scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int) where typed @specialize spec fargs - DATADEPS_CURRENT_TASK[] = task - if typed fargs::Tuple else fargs::Vector{Argument} end - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(value(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] - end - f_chunk = tochunk(value(spec.fargs[1])) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered (skip when MPI for deterministic tie-breaking) - procs = if current_acceleration() isa Dagger.MPIAcceleration - collect(all_procs) - else - P = randperm(length(all_procs)) - getindex.(Ref(all_procs), P) - end - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure - end - end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; move_value=false, unwrap=false) - end - return pos => tochunk(data) - end - f_chunk = tochunk(value(spec.fargs[1])) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end - end - - # FIXME: Copy deps are computed eagerly - deps = @something(spec.options.syncdeps, Set{Any}()) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end - - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end - - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = if current_acceleration() isa Dagger.MPIAcceleration - first(sort(collect(our_space_procs), by=short_name)) - else - rand(our_space_procs) - end - break - end + DATADEPS_CURRENT_TASK[] = task - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - else - error("Invalid scheduler: $sched") - end + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + our_proc = datadeps_schedule_task(scheduler, state, all_procs, all_scope, task_scope, spec, task) @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) + check_uniform(our_proc) + check_uniform(our_space) # Find the scope for this task (and its copies) task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - if task_scope == scope + if task_scope == all_scope # Optimize for the common case, cache the proc=>scope mapping our_scope = get!(proc_to_scope_lfu, our_proc) do our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) + return constrain(UnionScope(map(ExactScope, our_procs)...), all_scope) end else # Use the provided scope and constrain it to the available processors @@ -455,13 +235,12 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr if our_scope isa InvalidScope throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) end - check_uniform(our_proc) - check_uniform(our_space) f = spec.fargs[1] + tid = task.uid # FIXME: May not be correct to move this under uniformity #f.value = move(default_processor(), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis # N.B. Used later for checking dependencies @@ -488,13 +267,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Is the data written previously or now? if !arg_ws.may_alias - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" return arg end # Is the data writeable? if !arg_ws.inplace_move - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" return arg end @@ -511,7 +290,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end end return arg_remote @@ -530,6 +309,9 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end # Check that any mutable and written arguments are already in the correct space + # N.B. We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move arg_space = memory_space(arg) @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" @@ -538,12 +320,11 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Calculate this task's syncdeps if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{Any}() + spec.options.syncdeps = Set{ThunkSyncdep}() end if spec.options.tag === nothing - spec.options.tag = to_tag() + spec.options.tag = to_tag() end - syncdeps = spec.options.syncdeps map_or_ntuple(task_arg_ws) do idx arg_ws = task_arg_ws[idx] @@ -556,46 +337,35 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" get_write_deps!(state, our_space, ainfo, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end return end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - - # Launch user's task: preserve full argument list (spec.fargs); use remote values only for tracked args - new_fargs = if spec.fargs isa Tuple - ntuple(length(spec.fargs)) do i - arg = spec.fargs[i] - pos = arg.pos - j = findfirst(w -> w.pos == pos, task_arg_ws) - if j !== nothing - val = remote_args[j] - is_typed(spec) ? TypedArgument(pos, val) : Argument(pos, val) - else - copy(arg) - end + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task + new_fargs = map_or_ntuple(task_arg_ws) do idx + if is_typed(spec) + return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) + else + return Argument(task_arg_ws[idx].pos, remote_args[idx]) end - else - [let arg = spec.fargs[i], pos = arg.pos - j = findfirst(w -> w.pos == pos, task_arg_ws) - if j !== nothing - val = remote_args[j] - is_typed(spec) ? TypedArgument(pos, val) : Argument(pos, val) - else - copy(arg) - end - end for i in 1:length(spec.fargs)] end new_spec = DTaskSpec(new_fargs, spec.options) new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope - new_spec.options.occupancy = Dict(Any=>0) + if uniform_execution() + new_spec.options.occupancy = Dict(Any=>0) + end + ctx = Sch.eager_context() + @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, deps=task_arg_ws, args=remote_args)) # Update read/write tracking for arguments map_or_ntuple(task_arg_ws) do idx @@ -608,7 +378,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" add_writer!(state, arg_w, our_space, ainfo, task, write_num) else add_reader!(state, arg_w, our_space, ainfo, task, write_num) @@ -618,7 +388,8 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) - return write_num, proc_idx + DATADEPS_CURRENT_TASK[] = nothing + + return write_num end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 88201c621..ee1b060db 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -9,10 +9,11 @@ This is used to perform partial data copies that only update the "remainder" reg struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing space::S spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} + ainfos::Vector{AliasingWrapper} syncdeps::Set{ThunkSyncdep} end -RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = - RemainderAliasing{S}(space, spans, syncdeps) +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, ainfos::Vector{AliasingWrapper}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, ainfos, syncdeps) memory_spans(ra::RemainderAliasing) = ra.spans @@ -42,42 +43,6 @@ memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders).. Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders -#= FIXME: Integrate with main documentation -Problem statement: - -Remainder copy calculation needs to ensure that, for a given argument and -dependency modifier, and for a given target memory space, any data not yet -updated (whether through this arg or through another that aliases) is added to -the remainder, while any data that has been updated is not in the remainder. -Remainder copies may be multi-part, as data may be spread across multiple other -memory spaces. - -Ainfo is not alone sufficient to identify the combination of argument and -dependency modifier, as ainfo is specific to an allocation in a given memory -space. Thus, this combination needs to be tracked together, and separately from -memory space. However, information may span multiple memory spaces (and thus -multiple ainfos), so we should try to make queries of cross-memory space -information fast, as they will need to be performed for every task, for every -combination. - -Game Plan: - -- Use ArgumentWrapper to track this combination throughout the codebase, ideally generated just once -- Maintain the keying of remote_args only on argument, as the dependency modifier doesn’t affect the argument being passed into the task, so it should not factor into generating and tracking remote argument copies -- Add a structure to track the mapping from ArgumentWrapper to memory space to ainfo, as a quick way to lookup all ainfos needing to be considered -- When considering a remainder copy, only look at a single memory space’s ainfos at a time, as the ainfos should overlap exactly the same way on any memory space, and this allows us to use ainfo_overlaps to track overlaps -- Remainder copies will need to separately consider the source memory space, and the destination memory space when acquiring spans to copy to/from -- Memory spans for ainfos generated from the same ArgumentWrapper should be assumed to be paired in the same order, regardless of memory space, to ensure we can perform the translation from source to destination span address - - Alternatively, we might provide an API to take source and destination ainfos, and desired remainder memory spans, which then performs the copy for us -- When a task or copy writes to arguments, we should record this happening for all overlapping ainfos, in a manner that will be efficient to query from another memory space. We can probably walk backwards and attach this to a structure keyed on ArgumentWrapper, as that will be very efficient for later queries (because the history will now be linearized in one vector). -- Remainder copies will need to know, for all overlapping ainfos of the ArgumentWrapper ainfo at the target memory space, how recently that ainfo was updated relative to other ainfos, and relative to how recently the target ainfo was written. - - The last time the target ainfo was written is the furthest back we need to consider, as the target data must have been fully up-to-date when that write completed. - - Consideration of updates should start at most recent first, walking backwards in time, as the most recent updates contain the up-to-date data. - - For each span under consideration, we should subtract from it the current remainder set, to ensure we only copy up-to-date data. - - We must add that span portion to the remainder set no matter what, but if it was updated on the target memory space, we don’t need to schedule a copy for it, since it’s already where it needs to be. - - Even before the last target write is seen, we are allowed to stop searching if we find that our target ainfo is fully covered (because this implies that the target ainfo is fully out-of-date). -=# - struct FullCopy end """ @@ -122,17 +87,18 @@ function compute_remainder_for_arg!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper, write_num::Int; compute_syncdeps::Bool=true) - @label restart - - # Determine all memory spaces of the history spaces_set = Set{MemorySpace}() push!(spaces_set, target_space) owner_space = state.arg_owner[arg_w] push!(spaces_set, owner_space) + + @label restart + + # Determine all memory spaces of the history for entry in state.arg_history[arg_w] push!(spaces_set, entry.space) end - spaces = collect(spaces_set) + spaces = sort(collect(spaces_set), by=short_name) N = length(spaces) # Lookup all memory spans for arg_w in these spaces @@ -143,10 +109,12 @@ function compute_remainder_for_arg!(state::DataDepsState, push!(target_ainfos, LocalMemorySpan.(spans)) end nspans = length(first(target_ainfos)) + @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces" # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) for entry in state.arg_history[arg_w] if !in(entry.space, spaces) + @opcounter :compute_remainder_for_arg_restart @goto restart end end @@ -164,10 +132,14 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Create our remainder as an interval tree over all target ainfos + VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + for span in remainder + verify_span(span) + end # Create our tracker - tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Vector{AliasingWrapper},Set{ThunkSyncdep}}}() # Walk backwards through the history of writes to this target # other_ainfo is the overlapping ainfo that was written to @@ -193,7 +165,7 @@ function compute_remainder_for_arg!(state::DataDepsState, check_uniform(other_space) # Lookup all memory spans for arg_w in these spaces - other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) other_ainfos = Vector{Vector{LocalMemorySpan}}() for space in spaces @@ -203,14 +175,16 @@ function compute_remainder_for_arg!(state::DataDepsState, end nspans = length(first(other_ainfos)) other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] - + foreach(other_many_spans) do span + verify_span(span) + end check_uniform(other_many_spans) - check_uniform(spaces) if other_space == target_space # Only subtract, this data is already up-to-date in target_space # N.B. We don't add to syncdeps here, because we'll see this ainfo # in get_write_deps! + @opcounter :compute_remainder_for_arg_subtract subtract_spans!(remainder, other_many_spans) continue end @@ -219,16 +193,19 @@ function compute_remainder_for_arg!(state::DataDepsState, other_space_idx = something(findfirst(==(other_space), spaces)) target_space_idx = something(findfirst(==(target_space), spaces)) tracker_other_space = get!(tracker, other_space) do - (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Vector{AliasingWrapper}(), Set{ThunkSyncdep}()) end - schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) - if compute_syncdeps + @opcounter :compute_remainder_for_arg_schedule + has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps && has_overlap @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" - get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[3]) + push!(tracker_other_space[2], other_ainfo) end end + VERIFY_SPAN_CURRENT_OBJECT[] = nothing - if isempty(tracker) + if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker)) return NoAliasing(), 0 end @@ -236,12 +213,13 @@ function compute_remainder_for_arg!(state::DataDepsState, mra = MultiRemainderAliasing() for space in spaces if haskey(tracker, space) - spans, syncdeps = tracker[space] + spans, ainfos, syncdeps = tracker[space] if !isempty(spans) - push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) + push!(mra.remainders, RemainderAliasing(space, spans, ainfos, syncdeps)) end end end + @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))" return mra, last_idx end @@ -257,12 +235,13 @@ copy from `other_many_spans` to the subtraced portion of `remainder`. function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N diff = Vector{ManyMemorySpan{N}}() subtract_spans!(remainder, other_many_spans, diff) - for span in diff source_span = span.spans[source_space_idx] dest_span = span.spans[dest_space_idx] + @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end + return !isempty(diff) end ### Remainder copy functions @@ -291,7 +270,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac # overwritten by more recent partial updates source_space = remainder_aliasing.space - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -304,16 +283,23 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac push!(remainder_syncdeps, syncdep) end empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + source_ainfos = copy(remainder_aliasing.ainfos) + empty!(remainder_aliasing.ainfos) get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task - copy_task = Dagger.with_options(; tag=to_tag()) do - Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task reads the sources and writes to the target + for ainfo in source_ainfos + add_reader!(state, arg_w, source_space, ainfo, copy_task, write_num) end - - # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end """ @@ -353,16 +339,23 @@ function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySp push!(remainder_syncdeps, syncdep) end empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + source_ainfos = copy(remainder_aliasing.ainfos) + empty!(remainder_aliasing.ainfos) get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task - copy_task = Dagger.with_options(; tag=to_tag()) do - Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task reads the sources and writes to the target + for ainfo in source_ainfos + add_reader!(state, arg_w, source_space, ainfo, copy_task, write_num) end - - # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end @@ -373,7 +366,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: source_space = state.arg_owner[arg_w] target_ainfo = aliasing!(state, dest_space, arg_w) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -386,12 +379,17 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task - copy_task = Dagger.with_options(; tag=to_tag()) do - Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) - end + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task reads the source and writes to the target + add_reader!(state, arg_w, source_space, source_ainfo, copy_task, write_num) add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, @@ -416,38 +414,47 @@ function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task - copy_task = Dagger.with_options(; tag=to_tag()) do - Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) - end - - # This copy task becomes a new writer for the target region + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) + + # This copy task reads the source and writes to the target + add_reader!(state, arg_w, source_space, source_ainfo, copy_task, write_num) add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end # Main copy function for RemainderAliasing -function move!(dep_mod::RemainderAliasing, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) - # Get the source data for each span - copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod - copies = Vector{UInt8}[] - for (from_span, _) in dep_mod.spans - copy = Vector{UInt8}(undef, from_span.len) - GC.@preserve copy begin - from_ptr = Ptr{UInt8}(from_span.ptr) - to_ptr = Ptr{UInt8}(pointer(copy)) - unsafe_copyto!(to_ptr, from_ptr, from_span.len) +function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # TODO: Support direct copy between GPU memory spaces + + # Copy the data from the source object + copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + from_raw = unwrap(from) + offset = UInt64(1) + with_context!(from_space) + GC.@preserve copies begin + for (from_span, _) in dep_mod.spans + read_remainder!(copies, offset, from_raw, from_span.ptr, from_span.len) + offset += from_span.len end - push!(copies, copy) end + @assert offset == len+UInt64(1) return copies end # Copy the data into the destination object - for (copy, (_, to_span)) in zip(copies, dep_mod.spans) - GC.@preserve copy begin - from_ptr = Ptr{UInt8}(pointer(copy)) - to_ptr = Ptr{UInt8}(to_span.ptr) - unsafe_copyto!(to_ptr, from_ptr, to_span.len) + offset = UInt64(1) + to_raw = unwrap(to) + GC.@preserve copies begin + for (_, to_span) in dep_mod.spans + write_remainder!(copies, offset, to_raw, to_span.ptr, to_span.len) + offset += to_span.len end + @assert offset == length(copies)+UInt64(1) end # Ensure that the data is visible @@ -455,3 +462,88 @@ function move!(dep_mod::RemainderAliasing, to_space::MemorySpace, from_space::Me return end + +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Array, from_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(from)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + # unsafe_wrap(Array, ...) doesn't like unaligned memory + unsafe_copyto!(Ptr{eltype(from)}(pointer(copies, copies_offset)), pointer(from_vec, from_offset_n), n) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(from)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) + copyto!(copies_typed, 1, from_vec, Int(from_offset_n), Int(n)) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from, from_ptr::UInt64, n::UInt64) + real_from = find_object_holding_ptr(from, from_ptr) + return read_remainder!(copies, copies_offset, real_from, from_ptr, n) +end + +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Array, to_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(to)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + # unsafe_wrap(Array, ...) doesn't like unaligned memory + unsafe_copyto!(pointer(to_vec, to_offset_n), Ptr{eltype(to)}(pointer(copies, copies_offset)), n) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, len::UInt64) + elsize = sizeof(eltype(to)) + @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" + n = UInt64(len / elsize) + to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) + copyto!(to_vec, Int(to_offset_n), copies_typed, 1, Int(n)) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to, to_ptr::UInt64, n::UInt64) + real_to = find_object_holding_ptr(to, to_ptr) + return write_remainder!(copies, copies_offset, real_to, to_ptr, n) +end + +# Remainder copies for common objects +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular, SubArray) + @eval function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::$wrapper, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, parent(from), from_ptr, n) + end + @eval function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::$wrapper, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, parent(to), to_ptr, n) + end +end + +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Base.RefValue, from_ptr::UInt64, n::UInt64) + if from_ptr == UInt64(Base.pointer_from_objref(from) + fieldoffset(typeof(from), 1)) + unsafe_copyto!(pointer(copies, copies_offset), Ptr{UInt8}(from_ptr), n) + else + read_remainder!(copies, copies_offset, from[], from_ptr, n) + end +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Base.RefValue, to_ptr::UInt64, n::UInt64) + if to_ptr == UInt64(Base.pointer_from_objref(to) + fieldoffset(typeof(to), 1)) + unsafe_copyto!(Ptr{UInt8}(to_ptr), pointer(copies, copies_offset), n) + else + write_remainder!(copies, copies_offset, to[], to_ptr, n) + end +end + +function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) + span = LocalMemorySpan(pointer(A.nzval), length(A.nzval)*sizeof(eltype(A.nzval))) + if span_start(span) <= ptr <= span_end(span) + return A.nzval + end + span = LocalMemorySpan(pointer(A.colptr), length(A.colptr)*sizeof(eltype(A.colptr))) + if span_start(span) <= ptr <= span_end(span) + return A.colptr + end + span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) + @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" + return A.rowval +end diff --git a/src/datadeps/scheduling.jl b/src/datadeps/scheduling.jl index b2bcaca7b..0bf9818f6 100644 --- a/src/datadeps/scheduling.jl +++ b/src/datadeps/scheduling.jl @@ -111,11 +111,7 @@ function datadeps_schedule_task(sched::UltraScheduler, state::DataDepsState, all delete!(spaces_completed, our_space) continue end - our_proc = if Dagger.current_acceleration() isa Dagger.MPIAcceleration - first(sort(collect(our_space_procs), by=Dagger.short_name)) - else - rand(our_space_procs) - end + our_proc = rand(our_space_procs) break end diff --git a/src/dtask.jl b/src/dtask.jl index 13e66cafe..c9e9e811f 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -11,14 +11,14 @@ Base.wait(t::ThunkFuture) = Dagger.Sch.thunk_yield() do wait(t.future) return end -function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false) +function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false, move_value=!raw, unwrap=!raw, uniform=uniform_execution()) error, value = Dagger.Sch.thunk_yield() do fetch(t.future) end if error throw(value) end - if raw + if !move_value return value else return move(proc, value) @@ -65,13 +65,11 @@ function Base.wait(t::DTask) wait(t.future) return end -function Base.fetch(t::DTask; raw=false, move_value=nothing, unwrap=nothing) +function Base.fetch(t::DTask; raw=false, move_value=!raw, unwrap=!raw, uniform=false) if !istaskstarted(t) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `DTask`")) end - # Datadeps/aliasing API: move_value=false => don't move => raw=true - raw_eff = move_value !== nothing ? !move_value : raw - return fetch(t.future; raw=raw_eff) + return fetch(t.future; move_value, unwrap, uniform) end function waitany(tasks::Vector{DTask}) if isempty(tasks) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 39bfa7ccc..a531509cf 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,44 +1,8 @@ -struct DistributedAcceleration <: Acceleration end - -const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) - -current_acceleration() = ACCELERATION[] - -default_processor(::DistributedAcceleration) = OSProc(myid()) -default_processor(accel::DistributedAcceleration, x) = default_processor(accel) -default_processor() = default_processor(current_acceleration()) - -accelerate!(accel::Symbol) = accelerate!(Val{accel}()) -accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) - -initialize_acceleration!(a::DistributedAcceleration) = nothing -function accelerate!(accel::Acceleration) - initialize_acceleration!(accel) - ACCELERATION[] = accel -end -accelerate!(::Nothing) = nothing - -accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true -accel_matches_proc(accel::DistributedAcceleration, proc) = true - -function compatible_processors(accel::Union{Acceleration,Nothing}, scope::AbstractScope, procs::Vector{<:Processor}) - comp = compatible_processors(scope, procs) - accel === nothing && return comp - return Set(p for p in comp if accel_matches_proc(accel, p)) -end - struct CPURAMMemorySpace <: MemorySpace owner::Int end -root_worker_id(space::CPURAMMemorySpace) = space.owner - CPURAMMemorySpace() = CPURAMMemorySpace(myid()) - -default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) -default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) -default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) -default_memory_space(x) = default_memory_space(current_acceleration(), x) -default_memory_space() = default_memory_space(current_acceleration()) +root_worker_id(space::CPURAMMemorySpace) = space.owner memory_space(x, proc::Processor=default_processor()) = first(memory_spaces(proc)) memory_space(x::Processor) = first(memory_spaces(x)) @@ -47,17 +11,6 @@ memory_space(x::DTask) = memory_space(fetch(x; move_value=false, unwrap=false)) memory_spaces(::P) where {P<:Processor} = throw(ArgumentError("Must define `memory_spaces` for `$P`")) - -function memory_spaces(proc::OSProc) - children = get_processors(proc) - spaces = Set{MemorySpace}() - for proc in children - for space in memory_spaces(proc) - push!(spaces, space) - end - end - return spaces -end memory_spaces(proc::ThreadProc) = Set([CPURAMMemorySpace(proc.owner)]) processors(::S) where {S<:MemorySpace} = @@ -67,12 +20,10 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement -function unwrap(x::Chunk; uniform::Bool=false) - @assert root_worker_id(x.handle) == myid() "Chunk $x is not owned by this process: $(root_worker_id(x.handle)) != $(myid())" - if x.handle isa DRef - return MemPool.poolget(x.handle) - end - return MemPool.poolget(x.handle; uniform) +unwrap(x::Chunk) = unwrap(x.handle) +function unwrap(handle::DRef) + @assert root_worker_id(handle) == myid() "DRef $handle is not owned by this process: $(root_worker_id(handle)) != $(myid())" + return MemPool.poolget(x.handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = throw(ArgumentError("No `move!` implementation defined for $F -> $T")) @@ -140,20 +91,20 @@ function type_may_alias(::Type{T}) where T return false end -may_alias(::MemorySpace, ::MemorySpace) = true +may_alias(::MemorySpace, ::MemorySpace) = false +may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2 may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner -# RemotePtr and MemorySpan are defined in utils/memory-span.jl (included earlier). - abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) -struct AliasingWrapper <: AbstractAliasing +### Type-generic aliasing info wrapper + +mutable struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 - AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -162,8 +113,204 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = - will_alias(x.inner, y.inner) +will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) + +### Small dictionary type + +struct SmallDict{K,V} <: AbstractDict{K,V} + keys::Vector{K} + vals::Vector{V} +end +SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) +function Base.getindex(d::SmallDict{K,V}, key) where {K,V} + key_idx = findfirst(==(convert(K, key)), d.keys) + if key_idx === nothing + throw(KeyError(key)) + end + return @inbounds d.vals[key_idx] +end +function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} + key_conv = convert(K, key) + key_idx = findfirst(==(key_conv), d.keys) + if key_idx === nothing + push!(d.keys, key_conv) + push!(d.vals, convert(V, val)) + else + d.vals[key_idx] = convert(V, val) + end + return val +end +Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) +Base.keys(d::SmallDict) = d.keys +Base.length(d::SmallDict) = length(d.keys) +Base.iterate(d::SmallDict) = iterate(d, 1) +Base.iterate(d::SmallDict, state) = state > length(d.keys) ? nothing : (d.keys[state] => d.vals[state], state+1) + +### Type-stable lookup structure for AliasingWrappers + +struct AliasingLookup + # The set of memory spaces that are being tracked + spaces::Vector{MemorySpace} + # The set of AliasingWrappers that are being tracked + # One entry for each AliasingWrapper + ainfos::Vector{AliasingWrapper} + # The memory spaces for each AliasingWrapper + # One entry for each AliasingWrapper + ainfos_spaces::Vector{Vector{Int}} + # The spans for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} + # The set of AliasingWrappers that only exist in a single memory space + # One entry for each AliasingWrapper + ainfos_only_space::Vector{Int} + # The bounding span for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} + # The interval tree of the bounding spans for each AliasingWrapper + # One entry for each MemorySpace + bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} + + AliasingLookup() = new(MemorySpace[], + AliasingWrapper[], + Vector{Int}[], + SmallDict{Int,Vector{LocalMemorySpan}}[], + Int[], + SmallDict{Int,LocalMemorySpan}[], + IntervalTree{LocatorMemorySpan{Int},UInt64}[]) +end +function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) + # Update the set of memory spaces and spans, + # and find the bounding spans for this AliasingWrapper + spaces_set = Set{MemorySpace}(lookup.spaces) + self_spaces_set = Set{Int}() + spans = SmallDict{Int,Vector{LocalMemorySpan}}() + for span in memory_spans(ainfo) + space = span.ptr.space + if !in(space, spaces_set) + push!(spaces_set, space) + push!(lookup.spaces, space) + push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) + end + space_idx = findfirst(==(space), lookup.spaces) + push!(self_spaces_set, space_idx) + spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx) + push!(spans_in_space, LocalMemorySpan(span)) + end + push!(lookup.ainfos_spaces, collect(self_spaces_set)) + push!(lookup.spans, spans) + + # Update the set of AliasingWrappers + push!(lookup.ainfos, ainfo) + ainfo_idx = length(lookup.ainfos) + + # Check if the AliasingWrapper only exists in a single memory space + if length(self_spaces_set) == 1 + space_idx = only(self_spaces_set) + push!(lookup.ainfos_only_space, space_idx) + else + push!(lookup.ainfos_only_space, 0) + end + + # Add the bounding spans for this AliasingWrapper + bounding_spans = SmallDict{Int,LocalMemorySpan}() + for space_idx in keys(spans) + space_spans = spans[space_idx] + bound_start = minimum(span_start, space_spans) + bound_end = maximum(span_end, space_spans) + bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) + bounding_spans[space_idx] = bounding_span + insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) + end + push!(lookup.bounding_spans, bounding_spans) + + return ainfo_idx +end +struct AliasingLookupFinder + lookup::AliasingLookup + ainfo::AliasingWrapper + ainfo_idx::Int + spaces_idx::Vector{Int} + to_consider::Vector{Int} +end +Base.eltype(::AliasingLookupFinder) = AliasingWrapper +Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() +# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search +function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) + if ainfo_idx === nothing + ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) + end + spaces_idx = lookup.ainfos_spaces[ainfo_idx] + to_consider_spans = LocatorMemorySpan{Int}[] + for space_idx in spaces_idx + bounding_spans_tree = lookup.bounding_spans_tree[space_idx] + self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) + find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false) + end + to_consider = Int[locator.owner for locator in to_consider_spans] + @assert all(to_consider .> 0) + return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) +end +Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) +function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) + ainfo_spaces = nothing + cursor_space_idx = 1 + + # New ainfos enter here + @label ainfo_restart + + # Check if we've exhausted all ainfos + if cursor_ainfo_idx > length(finder.to_consider) + return nothing + end + ainfo_idx = finder.to_consider[cursor_ainfo_idx] + + # Find the appropriate memory spaces for this ainfo + if ainfo_spaces === nothing + ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] + end + + # New memory spaces (for the same ainfo) enter here + @label space_restart + + # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo + if cursor_space_idx > length(ainfo_spaces) + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # Find the currently considered memory space for this ainfo + space_idx = ainfo_spaces[cursor_space_idx] + + # Check if this memory space is part of our target ainfo's spaces + if !(space_idx in finder.spaces_idx) + cursor_space_idx += 1 + @goto space_restart + end + + # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space + other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] + self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] + if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) + cursor_space_idx += 1 + @goto space_restart + end + + # We have a overlapping bounds in the same memory space, so check if the ainfos are aliasing + # This is the slow path! + other_ainfo = finder.lookup.ainfos[ainfo_idx] + aliasing = will_alias(finder.ainfo, other_ainfo) + if !aliasing + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # We overlap, so return the ainfo and the next ainfo index + return other_ainfo, cursor_ainfo_idx+1 +end struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -180,8 +327,11 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - all_spans = MemorySpan{CPURAMMemorySpace}[] - for sub_a in ca.sub_ainfos + if length(ca.sub_ainfos) == 0 + return MemorySpan{CPURAMMemorySpace}[] + end + all_spans = memory_spans(ca.sub_ainfos[1]) + for sub_a in ca.sub_ainfos[2:end] append!(all_spans, memory_spans(sub_a)) end return all_spans @@ -191,19 +341,20 @@ Base.:(==)(ca1::CombinedAliasing, ca2::CombinedAliasing) = Base.hash(ca1::CombinedAliasing, h::UInt) = hash(ca1.sub_ainfos, hash(CombinedAliasing, h)) -struct ObjectAliasing <: AbstractAliasing - ptr::Ptr{Cvoid} +struct ObjectAliasing{S<:MemorySpace} <: AbstractAliasing + ptr::RemotePtr{Cvoid,S} sz::UInt end +ObjectAliasing(ptr::RemotePtr{Cvoid,S}, sz::Integer) where {S<:MemorySpace} = + ObjectAliasing{S}(ptr, UInt(sz)) function ObjectAliasing(x::T) where T @nospecialize x - ptr = pointer_from_objref(x) + ptr = RemotePtr{Cvoid}(pointer_from_objref(x)) sz = sizeof(T) return ObjectAliasing(ptr, sz) end -function memory_spans(oa::ObjectAliasing) - rptr = RemotePtr{Cvoid}(oa.ptr) - span = MemorySpan{CPURAMMemorySpace}(rptr, oa.sz) +function memory_spans(oa::ObjectAliasing{S}) where S + span = MemorySpan{S}(oa.ptr, oa.sz) return [span] end @@ -243,16 +394,37 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -aliasing(x::DTask, T) = aliasing(fetch(x; move_value=false, unwrap=false), T) -aliasing(x::DTask) = aliasing(fetch(x; move_value=false, unwrap=false)) -function aliasing(accel::DistributedAcceleration, x::Chunk, T) +function aliasing(x::Chunk, T) + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x), T) + end @assert x.handle isa DRef return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T aliasing(unwrap(x), T) end end -aliasing(x::Chunk, T) = aliasing(unwrap(x), T) -aliasing(x::Chunk) = aliasing(unwrap(x)) +function aliasing(x::Chunk) + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x)) + end + @assert x.handle isa DRef + return remotecall_fetch(root_worker_id(x.processor), x) do x + aliasing(unwrap(x)) + end +end +aliasing(x::DTask, T) = aliasing(fetch(x; move_value=false, unwrap=false), T) +aliasing(x::DTask) = aliasing(fetch(x; move_value=false, unwrap=false)) + +function aliasing(x::Base.RefValue{T}) where T + addr = UInt(Base.pointer_from_objref(x) + fieldoffset(typeof(x), 1)) + ptr = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) + ainfo = ObjectAliasing(ptr, sizeof(x)) + if isassigned(x) && type_may_alias(T) && type_may_alias(typeof(x[])) + return CombinedAliasing([ainfo, aliasing(x[])]) + else + return CombinedAliasing([ainfo]) + end +end struct ContiguousAliasing{S} <: AbstractAliasing span::MemorySpan{S} @@ -305,13 +477,22 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} +function aliasing(x::SubArray{T,N}) where {T,N} if isbitstype(T) - S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), - RemotePtr{Cvoid}(pointer(x)), - parentindices(x), - size(x), strides(x)) + p = parent(x) + space = memory_space(p) + S = typeof(space) + parent_ptr = RemotePtr{Cvoid}(UInt64(pointer(p)), space) + ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) + NA = ndims(p) + raw_inds = parentindices(x) + inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) + sz = ntuple(i->length(inds[i]), NA) + return StridedAliasing{T,NA,S}(parent_ptr, + ptr, + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -428,7 +609,7 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space + @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. $(y_span.ptr.space)" x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end @@ -440,5 +621,5 @@ unsafe_free!(x::Chunk) = remotecall_fetch(root_worker_id(x), x) do x unsafe_free!(unwrap(x)) return end -unsafe_free!(x::DTask) = unsafe_free!(fetch(x; raw=true)) +unsafe_free!(x::DTask) = unsafe_free!(fetch(x; move_value=false, unwrap=false)) unsafe_free!(x) = nothing # Do nothing by default diff --git a/src/mpi.jl b/src/mpi.jl index 1b84a7b9d..b5723e6a2 100644 --- a/src/mpi.jl +++ b/src/mpi.jl @@ -1,3 +1,5 @@ +@warn "Move to MPIExt.jl" maxlog=1 + using MPI const CHECK_UNIFORMITY = Ref{Bool}(false) @@ -5,7 +7,7 @@ function check_uniformity!(check::Bool=true) CHECK_UNIFORMITY[] = check end function check_uniform(value::Integer, original=value) - CHECK_UNIFORMITY[] || return true + CHECK_UNIFORMITY[] && uniform_execution() || return true comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) matched = compare_all(value, comm) @@ -13,14 +15,14 @@ function check_uniform(value::Integer, original=value) if rank == 0 Core.print("[$rank] Found non-uniform value!\n") end - Core.print("[$rank] value=$value, original=$original") + Core.print("[$rank] value=$value, original=$original\n") throw(ArgumentError("Non-uniform value")) end MPI.Barrier(comm) return matched end function check_uniform(value, original=value) - CHECK_UNIFORMITY[] || return true + CHECK_UNIFORMITY[] && uniform_execution() || return true return check_uniform(hash(value), original) end @@ -29,7 +31,7 @@ function compare_all(value, comm) size = MPI.Comm_size(comm) for i in 0:(size-1) if i != rank - send_yield(value, comm, i, UInt32(0); check_seen=false) + send_yield(value, comm, i, UInt32(0)) end end match = true @@ -56,7 +58,9 @@ function aliasing(accel::MPIAcceleration, x::Chunk, T) check_uniform(tag) rank = MPI.Comm_rank(accel.comm) if handle.rank == rank - ainfo = aliasing(x, T) + ainfo = _with_default_acceleration() do + aliasing(x, T) + end #Core.print("[$rank] aliasing: $ainfo, sending\n") @opcounter :aliasing_bcast_send_yield bcast_send_yield(ainfo, accel.comm, handle.rank, tag) @@ -68,19 +72,21 @@ function aliasing(accel::MPIAcceleration, x::Chunk, T) check_uniform(ainfo) return ainfo end + default_processor(accel::MPIAcceleration) = MPIOSProc(accel.comm, 0) default_processor(accel::MPIAcceleration, x) = MPIOSProc(accel.comm, 0) default_processor(accel::MPIAcceleration, x::Chunk) = MPIOSProc(x.handle.comm, x.handle.rank) default_processor(accel::MPIAcceleration, x::Function) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) default_processor(accel::MPIAcceleration, T::Type) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +uniform_execution(accel::MPIAcceleration) = true -#TODO: Add a lock +@warn "Add a lock to MPIClusterProcChildren" maxlog=1 const MPIClusterProcChildren = Dict{MPI.Comm, Set{Processor}}() struct MPIClusterProc <: Processor comm::MPI.Comm function MPIClusterProc(comm::MPI.Comm) - populate_children(comm) + populate_children!(comm) return new(comm) end end @@ -89,7 +95,7 @@ Sch.init_proc(state, proc::MPIClusterProc, log_sink) = Sch.init_proc(state, MPIO MPIClusterProc() = MPIClusterProc(MPI.COMM_WORLD) -function populate_children(comm::MPI.Comm) +function populate_children!(comm::MPI.Comm) children = get_processors(OSProc()) MPIClusterProcChildren[comm] = children end @@ -214,7 +220,7 @@ get_parent(proc::MPIProcessor) = MPIOSProc(proc.comm, proc.rank) short_name(proc::MPIProcessor) = "(MPI: $(proc.rank), $(short_name(proc.innerProc)))" function get_processors(mosProc::MPIOSProc) - populate_children(mosProc.comm) + populate_children!(mosProc.comm) children = MPIClusterProcChildren[mosProc.comm] mpiProcs = Set{Processor}() for proc in children @@ -284,25 +290,43 @@ function processors(memSpace::MPIMemorySpace) end struct MPIRefID - tid::Int - uid::UInt - id::Int - function MPIRefID(tid, uid, id) - @assert tid > 0 || uid > 0 "Invalid MPIRefID: tid=$tid, uid=$uid, id=$id" - return new(tid, uid, id) + tid::UInt32 + generic::Bool + id::UInt32 + function MPIRefID(tid, generic, id) + @assert tid > 0 || generic "Invalid MPIRefID: tid=$tid, generic=$generic, id=$id" + return new(tid, generic, id) end end Base.hash(id::MPIRefID, h::UInt=UInt(0)) = - hash(id.tid, hash(id.uid, hash(id.id, hash(MPIRefID, h)))) + hash(id.tid, hash(id.generic, hash(id.id, hash(MPIRefID, h)))) function check_uniform(ref::MPIRefID, original=ref) return check_uniform(ref.tid, original) && - check_uniform(ref.uid, original) && + check_uniform(ref.generic, original) && check_uniform(ref.id, original) end -const MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() -const MPIREF_UID = Dict{Int, Threads.Atomic{Int}}() +function to_tag() + if Dagger.in_task() + # Tag is already assigned + opts = Dagger.get_tls().task_spec.options + tag = opts.tag + return tag + end + + # Generate a tag based on the TID + @assert !Sch.SCHED_MOVE[] "We should not create a tag during Sch move" + return to_tag(take_ref_id!()) +end +to_tag(id::MPIRefID) = id.generic ? id.id : id.tid + +# Semi-public internal value for passing TID to MPIRefID generation +const MPI_TID = ScopedValue{Int64}(0) +# Private internal value for tracking TID-based ID generations +#const _MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() +# Private internal value for tracking non-TID (uniform) ID generations +#const _MPIREF_GENERIC = Threads.Atomic{Int}(1) mutable struct MPIRef comm::MPI.Comm @@ -319,8 +343,15 @@ function check_uniform(ref::MPIRef, original=ref) check_uniform(ref.id, original) end +function unwrap(handle::MPIRef) + @assert handle.rank == MPI.Comm_rank(handle.comm) "MPIRef $handle is not owned by this rank: $(handle.rank) != $(MPI.Comm_rank(handle.comm))" + return unwrap(handle.innerRef) +end + +to_tag(ref::MPIRef) = to_tag(ref.id) + move(from_proc::Processor, to_proc::Processor, x::MPIRef) = - move(from_proc, to_proc, poolget(x; uniform=FETCH_UNIFORM[])) + move(from_proc, to_proc, poolget(x; uniform=uniform_execution())) function affinity(x::MPIRef) if x.innerRef === nothing @@ -332,25 +363,29 @@ end function take_ref_id!() tid = 0 - uid = 0 + generic = 0 id = 0 if Dagger.in_task() tid = sch_handle().thunk_id.id - uid = 0 - counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) - id = Threads.atomic_add!(counter, 1) + #counter = get!(_MPIREF_TID, tid, Threads.Atomic{Int}(1)) + #id = Threads.atomic_add!(counter, 1) + id = tid elseif MPI_TID[] != 0 tid = MPI_TID[] - uid = 0 - counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) - id = Threads.atomic_add!(counter, 1) - elseif MPI_UID[] != 0 - tid = 0 - uid = MPI_UID[] - counter = get!(MPIREF_UID, uid, Threads.Atomic{Int}(1)) - id = Threads.atomic_add!(counter, 1) + #counter = get!(_MPIREF_TID, tid, Threads.Atomic{Int}(1)) + #id = Threads.atomic_add!(counter, 1) + id = tid + else + if current_task() !== Base.roottask + throw(ConcurrencyViolationError("Attempted to generate generic MPIRefID in a multi-threaded context")) + end + generic = true + #id = Threads.atomic_add!(_MPIREF_GENERIC, 1) + id = next_id() # Abuse the TID counter for generic IDs + check_uniform(id) end - return MPIRefID(tid, uid, id) + @assert id < MPI.tag_ub() + return MPIRefID(tid, generic, id) end #TODO: partitioned scheduling with comm bifurcation @@ -372,6 +407,7 @@ const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0) const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->120.0) const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}()) +@warn "Rename and make generic these in-place structs" maxlog=1 struct InplaceInfo type::DataType shape::Tuple @@ -530,19 +566,22 @@ function recv_yield_serialized(comm, my_rank, their_rank, tag) end const SEEN_TAGS = Dict{Int32, Type}() -send_yield!(value, comm, dest, tag; check_seen::Bool=true) = - _send_yield(value, comm, dest, tag; check_seen, inplace=true) -send_yield(value, comm, dest, tag; check_seen::Bool=true) = - _send_yield(value, comm, dest, tag; check_seen, inplace=false) -function _send_yield(value, comm, dest, tag; check_seen::Bool=true, inplace::Bool) +send_yield!(value, comm, dest, tag) = + _send_yield(value, comm, dest, tag; inplace=true) +send_yield(value, comm, dest, tag) = + _send_yield(value, comm, dest, tag; inplace=false) +function _send_yield(value, comm, dest, tag; inplace::Bool) rank = MPI.Comm_rank(comm) - if check_seen && haskey(SEEN_TAGS, tag) && SEEN_TAGS[tag] !== typeof(value) + #= + if CHECK_UNIFORMITY[] && haskey(SEEN_TAGS, tag) && SEEN_TAGS[tag] !== typeof(value) @error "[rank $(MPI.Comm_rank(comm))][tag $tag] Already seen tag (previous type: $(SEEN_TAGS[tag]), new type: $(typeof(value)))" exception=(InterruptException(),backtrace()) end - if check_seen + if CHECK_UNIFORMITY[] SEEN_TAGS[tag] = typeof(value) end + =# + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting send to [$dest]: $(typeof(value)), is support inplace? $(supports_inplace_mpi(value))") if inplace && supports_inplace_mpi(value) send_yield_inplace(value, comm, rank, dest, tag) @@ -633,15 +672,15 @@ function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, ra end #discuss this with julian +@warn "Fix this WeakChunk method" maxlog=1 WeakChunk(c::Chunk{T,H}) where {T,H<:MPIRef} = WeakChunk(c.handle.rank, c.handle.id.id, WeakRef(c)) -function MemPool.poolget(ref::MPIRef; uniform::Bool=false) +function MemPool.poolget(ref::MPIRef; uniform::Bool=uniform_execution()) @assert uniform || ref.rank == MPI.Comm_rank(ref.comm) "MPIRef rank mismatch: $(ref.rank) != $(MPI.Comm_rank(ref.comm))" if uniform tag = to_tag() if ref.rank == MPI.Comm_rank(ref.comm) value = poolget(ref.innerRef) - @opcounter :poolget_bcast_send_yield bcast_send_yield(value, ref.comm, ref.rank, tag) return value else @@ -651,7 +690,7 @@ function MemPool.poolget(ref::MPIRef; uniform::Bool=false) return poolget(ref.innerRef) end end -fetch_handle(ref::MPIRef; uniform::Bool=false) = poolget(ref; uniform) +fetch_handle(ref::MPIRef; uniform::Bool=uniform_execution()) = poolget(ref; uniform) function move!(dep_mod, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" @@ -662,7 +701,7 @@ function move!(dep_mod, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) else @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" - tag = to_tag() + tag = to_tag(from.handle) if local_rank == from_space.rank send_yield!(poolget(from.handle; uniform=false), to_space.comm, to_space.rank, tag) elseif local_rank == to_space.rank @@ -684,7 +723,7 @@ function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemory if to_space.rank == from_space.rank == local_rank move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) else - tag = to_tag() + tag = to_tag(from.handle) @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" if local_rank == from_space.rank # Get the source data for each span @@ -702,7 +741,7 @@ function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemory # Send the spans #send_yield(len, to_space.comm, to_space.rank, tag) - send_yield!(copies, to_space.comm, to_space.rank, tag; check_seen=false) + send_yield!(copies, to_space.comm, to_space.rank, tag) #send_yield(copies, to_space.comm, to_space.rank, tag) elseif local_rank == to_space.rank # Receive the spans @@ -749,18 +788,35 @@ function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk) end end -const MPI_UNIFORM = ScopedValue{Bool}(false) -# When true, move(_, _, MPIRef) uses poolget(; uniform=true) so the owner bcasts and the fetcher recv (e.g. rank 0 collecting). -const FETCH_UNIFORM = ScopedValue{Bool}(true) - -function remotecall_endpoint(f, accel::Dagger.MPIAcceleration, from_proc, to_proc, from_space, to_space, data) +#= +function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) + loc_rank = MPI.Comm_rank(accel.comm) + if loc_rank == from_proc.rank + # FIXME: Descend via move_rewrap, and send data to to_proc + elseif loc_rank == to_proc.rank + # FIXME: Listen for data from from_proc to locally wrap as Chunk + while true + value = recv_yield(accel.comm, from_proc.rank, tag) + end + bcast_recv_yield(data_new, accel.comm, to_proc.rank, tag) + else + # Wait for final Chunk + return recv_yield(accel.comm, to_proc.rank, tag) + end +end +function remotecall_endpoint_transfer(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data) + loc_rank = MPI.Comm_rank(accel.comm) + if loc_rank == from_proc.rank + elseif loc_rank == to_proc.rank + end +end +=# +function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) loc_rank = MPI.Comm_rank(accel.comm) task = DATADEPS_CURRENT_TASK[] - return with(MPI_UID=>task.uid, MPI_UNIFORM=>true) do - @assert data isa Chunk "Expected Chunk, got $(typeof(data))" + return with(MPI_UID=>task.uid) do space = memory_space(data) tag = to_tag() - type_tag = to_tag() T = move_type(from_proc.innerProc, to_proc.innerProc, chunktype(data)) T_new = f !== identity ? Base._return_type(f, Tuple{T}) : T need_bcast = !isconcretetype(T_new) || T_new === Union{} || T_new === Nothing || T_new === Any @@ -773,11 +829,11 @@ function remotecall_endpoint(f, accel::Dagger.MPIAcceleration, from_proc, to_pro data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) T_actual = typeof(data_converted) if need_bcast - bcast_send_yield(T_actual, accel.comm, to_proc.rank, type_tag) + bcast_send_yield(T_actual, accel.comm, to_proc.rank, tag) end return tochunk(data_converted, to_proc, to_space; type=T_actual) else - T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, tag) : T_new return tochunk(nothing, to_proc, to_space; type=T_actual) end end @@ -832,7 +888,7 @@ move(::MPIProcessor, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget @warn "Is this uniform logic valuable to have?" maxlog=1 function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk) - uniform = false #uniform = MPI_UNIFORM[] + uniform = uniform_execution() @assert uniform || src.rank == dst.rank "Unwrapping not permitted" if Sch.SCHED_MOVE[] # We can either unwrap locally, or return nothing diff --git a/src/queue.jl b/src/queue.jl index 37947a0ac..c1e264c06 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -125,7 +125,7 @@ function wait_all(f; check_errors::Bool=false) result = with_options(f; task_queue=queue) for task in queue.tasks if check_errors - fetch(task; raw=true) + fetch(task; move_value=false, unwrap=false) else wait(task) end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 3b8688a16..3c8353c59 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -706,21 +706,10 @@ end costs_cleanup() @goto pop_task - # Fire all newly-scheduled tasks (owner/local first, then by fire_order_key to avoid MPI execute! deadlock) + # Fire all newly-scheduled tasks @label fire_tasks - task_locs = collect(keys(to_fire)) - if Dagger.current_acceleration() isa Dagger.MPIAcceleration - sort!(task_locs, by=_mpi_fire_order_key) - end - rank = try - M = parentmodule(@__MODULE__) - (isdefined(M, :MPI) && M.MPI.Initialized()) ? Int(M.MPI.Comm_rank(M.MPI.COMM_WORLD)) : nothing - catch - nothing - end - for (i, task_loc) in enumerate(task_locs) - #Core.println("fire_order rank=", rank, " [", i, "/", length(task_locs), "] task_loc=", task_loc) - fire_tasks!(ctx, task_loc, to_fire[task_loc], state) + for (task_loc, task_spec) in to_fire + fire_tasks!(ctx, task_loc, task_spec, state) end to_fire_cleanup() @@ -1166,15 +1155,11 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Try to steal a task @maybelog ctx timespan_start(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) - # Try to steal from local queues randomly (deterministic order when MPI to avoid deadlocks) + # Try to steal from local queues randomly # TODO: Prioritize stealing from busiest processors states = proc_states_values(uid) - order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration - sort(1:length(states), by=i->_mpi_proc_rank(states[i].state.proc)) - else - randperm(length(states)) - end - for state in getindex.(Ref(states), order) + P = randperm(length(states)) + for state in getindex.(Ref(states), P) other_istate = state.state if other_istate.proc === to_proc continue @@ -1383,15 +1368,11 @@ function do_tasks(to_proc, return_queue, tasks) end notify(istate.reschedule) - # Kick other processors to make them steal (deterministic order when MPI to avoid deadlocks) + # Kick other processors to make them steal # TODO: Alternatively, automatically balance work instead of blindly enqueueing states = proc_states_values(uid) - order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration - sort(1:length(states), by=i->_mpi_proc_rank(states[i].state.proc)) - else - randperm(length(states)) - end - for other_state in getindex.(Ref(states), order) + P = randperm(length(states)) + for other_state in getindex.(Ref(states), P) other_istate = other_state.state if other_istate.proc === to_proc continue @@ -1509,13 +1490,13 @@ Executes a single task specified by `task` on `to_proc`. #= FIXME: This isn't valid if x is written to x = if x isa Chunk value = lock(TASK_SYNC) do - if haskey(CHUNK_CACHE, x) - Some{Any}(get!(CHUNK_CACHE[x], to_proc) do - # Convert from cached value - # TODO: Choose "closest" processor of same type first - cache_procs = keys(CHUNK_CACHE[x]) - some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ? - minimum(cache_procs, by=_mpi_proc_rank) : first(cache_procs) + if haskey(CHUNK_CACHE, x) + Some{Any}(get!(CHUNK_CACHE[x], to_proc) do + # Convert from cached value + # TODO: Choose "closest" processor of same type first + cache_procs = keys(CHUNK_CACHE[x]) + some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ? + minimum(cache_procs, by=_mpi_proc_rank) : first(cache_procs) some_x = CHUNK_CACHE[x][some_proc] @dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x" @invokelatest move(some_proc, to_proc, some_x) diff --git a/src/sch/util.jl b/src/sch/util.jl index 38b767588..164685195 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -383,7 +383,7 @@ function signature(f, args) value = Dagger.value(arg) if value isa Dagger.DTask # Only occurs via manual usage of signature - value = fetch(value; raw=true) + value = fetch(value; move_value=false, unwrap=false) end if istask(value) throw(ConcurrencyViolationError("Must call `collect_task_inputs!(state, task)` before calling `signature`")) @@ -601,14 +601,12 @@ const DEFAULT_TRANSFER_RATE = UInt64(1_000_000) end chunks_cleanup() - # Shuffle procs around, so equally-costly procs are equally considered (skip shuffle when MPI for deterministic tie-breaking) + # Shuffle procs around, so equally-costly procs are equally considered np = length(procs) @reusable :estimate_task_costs_P Vector{Int} 0 4 np P begin resize!(P, np) copyto!(P, 1:np) - if !(Dagger.current_acceleration() isa Dagger.MPIAcceleration) - randperm!(P) - end + randperm!(P) for idx in 1:np sorted_procs[idx] = procs[P[idx]] end diff --git a/src/submission.jl b/src/submission.jl index fffcc577d..d3102eacf 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -304,7 +304,7 @@ function eager_spawn(spec::DTaskSpec) uid = eager_next_id() future = ThunkFuture() metadata = DTaskMetadata(spec) - # Propagate inferred return type to options so execute! can skip MPI bcast + # Propagate inferred return type to options if isconcretetype(metadata.return_type) spec.options.return_type = metadata.return_type end diff --git a/src/types/acceleration.jl b/src/types/acceleration.jl index b647dd303..f9aa1d86f 100644 --- a/src/types/acceleration.jl +++ b/src/types/acceleration.jl @@ -1 +1,3 @@ -abstract type Acceleration end \ No newline at end of file +abstract type Acceleration end + +struct DistributedAcceleration <: Acceleration end diff --git a/test/mpi.jl b/test/mpi.jl index c6d2cbae3..7d71e801e 100644 --- a/test/mpi.jl +++ b/test/mpi.jl @@ -1,5 +1,7 @@ using Dagger, MPI, LinearAlgebra + Dagger.accelerate!(:mpi) +Dagger.check_uniformity!(true) comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) sz = MPI.Comm_size(comm) @@ -18,9 +20,16 @@ try A[diagind(A)] .+= size(A, 1) B = copy(A) @assert ishermitian(B) - DA = distribute(A, Blocks(20,20)) - DB = distribute(B, Blocks(20,20)) - + DA = zeros(Blocks(20,20), T, datasize, datasize) + for chunk in DA.chunks + Dagger.check_uniform(fetch(chunk; move_value=false, unwrap=false).space) + end + copyto!(DA, A) + DB = zeros(Blocks(20,20), T, datasize, datasize) + for chunk in DB.chunks + Dagger.check_uniform(fetch(chunk; move_value=false, unwrap=false).space) + end + copyto!(DB, B) LinearAlgebra._chol!(DA, UpperTriangular) elapsed_time = @elapsed chol_DB = LinearAlgebra._chol!(DB, UpperTriangular) @@ -33,13 +42,18 @@ try gflops = (datasize^3 / 3) / (elapsed_time * 1e9) ) push!(mpidagger_all_results, result) - - end -catch e +catch if rank == 0 - showerror(stdout, e) + Core.print("Rank 0:\n") + rethrow() + elseif rank == 1 + Core.print("Rank 1:\n") + sleep(1) + rethrow() end +finally + MPI.Barrier(comm) end if rank == 0 #= Write results to CSV From 2f50523fc084cf5a9d9a5b213ee06e6ea39792db Mon Sep 17 00:00:00 2001 From: Felipe Tome Date: Mon, 18 May 2026 19:49:48 -0300 Subject: [PATCH 5/6] MPI: inclusion of MPIRPC and the interfaces with datadeps --- lib/MPIRPC/.gitignore | 1 + lib/MPIRPC/Project.toml | 19 + lib/MPIRPC/README.md | 72 +++ lib/MPIRPC/docs/ARCHITECTURE.md | 500 +++++++++++++++++++ lib/MPIRPC/examples/nonuniform_driver.jl | 63 +++ lib/MPIRPC/examples/rpc_matmul_nonuniform.jl | 131 +++++ lib/MPIRPC/examples/rpc_matmul_uniform.jl | 111 ++++ lib/MPIRPC/examples/uniform_driver.jl | 41 ++ lib/MPIRPC/src/MPIRPC.jl | 53 ++ lib/MPIRPC/src/config.jl | 150 ++++++ lib/MPIRPC/src/dispatch.jl | 37 ++ lib/MPIRPC/src/exceptions.jl | 39 ++ lib/MPIRPC/src/nonuniform.jl | 262 ++++++++++ lib/MPIRPC/src/progress.jl | 162 ++++++ lib/MPIRPC/src/protocol.jl | 181 +++++++ lib/MPIRPC/src/refs.jl | 55 ++ lib/MPIRPC/src/remotecall.jl | 266 ++++++++++ lib/MPIRPC/src/uniform.jl | 314 ++++++++++++ lib/MPIRPC/test/bcast_remotecall_mpiexec.jl | 57 +++ lib/MPIRPC/test/mpi_tests.jl | 96 ++++ lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl | 138 +++++ lib/MPIRPC/test/nonuniform_mpiexec.jl | 242 +++++++++ lib/MPIRPC/test/protocol_tests.jl | 120 +++++ lib/MPIRPC/test/runtests.jl | 7 + lib/MPIRPC/test/uniform_daemon_mpiexec.jl | 326 ++++++++++++ lib/MPIRPC/test/uniform_mpiexec.jl | 336 +++++++++++++ src/array/lu.jl | 2 + src/datadeps/aliasing.jl | 57 ++- src/memory-spaces.jl | 2 +- src/mpi.jl | 132 ++++- 30 files changed, 3934 insertions(+), 38 deletions(-) create mode 100644 lib/MPIRPC/.gitignore create mode 100644 lib/MPIRPC/Project.toml create mode 100644 lib/MPIRPC/README.md create mode 100644 lib/MPIRPC/docs/ARCHITECTURE.md create mode 100644 lib/MPIRPC/examples/nonuniform_driver.jl create mode 100644 lib/MPIRPC/examples/rpc_matmul_nonuniform.jl create mode 100644 lib/MPIRPC/examples/rpc_matmul_uniform.jl create mode 100644 lib/MPIRPC/examples/uniform_driver.jl create mode 100644 lib/MPIRPC/src/MPIRPC.jl create mode 100644 lib/MPIRPC/src/config.jl create mode 100644 lib/MPIRPC/src/dispatch.jl create mode 100644 lib/MPIRPC/src/exceptions.jl create mode 100644 lib/MPIRPC/src/nonuniform.jl create mode 100644 lib/MPIRPC/src/progress.jl create mode 100644 lib/MPIRPC/src/protocol.jl create mode 100644 lib/MPIRPC/src/refs.jl create mode 100644 lib/MPIRPC/src/remotecall.jl create mode 100644 lib/MPIRPC/src/uniform.jl create mode 100644 lib/MPIRPC/test/bcast_remotecall_mpiexec.jl create mode 100644 lib/MPIRPC/test/mpi_tests.jl create mode 100644 lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl create mode 100644 lib/MPIRPC/test/nonuniform_mpiexec.jl create mode 100644 lib/MPIRPC/test/protocol_tests.jl create mode 100644 lib/MPIRPC/test/runtests.jl create mode 100644 lib/MPIRPC/test/uniform_daemon_mpiexec.jl create mode 100644 lib/MPIRPC/test/uniform_mpiexec.jl diff --git a/lib/MPIRPC/.gitignore b/lib/MPIRPC/.gitignore new file mode 100644 index 000000000..ba39cc531 --- /dev/null +++ b/lib/MPIRPC/.gitignore @@ -0,0 +1 @@ +Manifest.toml diff --git a/lib/MPIRPC/Project.toml b/lib/MPIRPC/Project.toml new file mode 100644 index 000000000..4c5af20d2 --- /dev/null +++ b/lib/MPIRPC/Project.toml @@ -0,0 +1,19 @@ +name = "MPIRPC" +uuid = "a8caf107-0824-430d-bb41-9c1c9e5c5a9f" +version = "0.1.0" +authors = ["MPIRPC contributors"] + +[deps] +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[compat] +MPI = "0.20" +julia = "1.9" + +[extras] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test", "Random"] diff --git a/lib/MPIRPC/README.md b/lib/MPIRPC/README.md new file mode 100644 index 000000000..43793f61e --- /dev/null +++ b/lib/MPIRPC/README.md @@ -0,0 +1,72 @@ +# MPIRPC.jl + +Standalone MPI-backed RPC for Julia, modeled after the `Distributed` stdlib remote +invocation pipeline (header + serialized body + boundary, `invokelatest` on the +handler path, `RemoteException`-style error wrapping) and Dagger's +"accelerate-once" backend selection (a single backend value held in a task-local +slot determines uniform vs. non-uniform semantics; the public API is the same in +both modes). + +* Public API mirrors `Distributed`: `remotecall`, `remotecall_fetch`, + `remotecall_wait`, `remote_do`, `bcast_remotecall`, `fetch`, `wait`, with an `MPIFuture` handle. +* Two backends: + * `UniformMPIRPCBackend`: SPMD, every rank both issues and services RPC, + every rank calls `rpc_progress!`. + * `NonUniformMPIRPCBackend`: explicit listener / client roles, optional + `MPI.Comm_split` to isolate RPC from world collectives, only listeners must + poll `rpc_progress!`. +* Transport: `MPI.jl` only. No TCP, no `Distributed`, no `Dagger` dependency. +* Multi-threaded out of the box: handlers dispatch on `Threads.@spawn`, + so user closures that themselves spawn tasks and `fetch` them do not + deadlock under `julia --threads=N`. Trade-off: handler execution is + not FIFO across messages from the same source — use `remotecall_wait` + (or `remotecall` + `fetch`) when you need a side effect committed + before the next call. See `docs/ARCHITECTURE.md` §6. +* Optional progress daemon: pass `daemon = true` to either backend + constructor and `select_mpi_rpc_backend!` will spawn a yield-only + background task that drives `rpc_progress!` for you. The daemon is + placed on Julia's `:interactive` threadpool (start Julia with + `-t N,M`, `M >= 1`), so CPU-bound user code on the `:default` pool + cannot starve the wire pump. Removes the "every rank must call + `rpc_progress!`" requirement at the cost of ≈ 100% utilization of + one OS thread (the daemon never sleeps). `shutdown!` joins the daemon + cleanly. See `docs/ARCHITECTURE.md` §6. +* Idle waits cost zero CPU under `daemon = true`: `wait` / `fetch` on + an `MPIFuture` park on a `Threads.Condition` and are woken by + `deliver!` from the daemon's reply-dispatch path. Without the daemon + (`daemon = false`), `wait` falls back to a yield-rate spin because + the calling task is the only available progress driver. +* Minimum Julia: `1.9` (stable `task_local_storage`, `Base.invokelatest`). + +See [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) for the design rationale, the +deadlock and ABBA discussion, and the side-by-side mapping to `Distributed` and +to Dagger's `accelerate!`. + +## Quick start + +```julia +using MPI, MPIRPC + +MPI.Init(; threadlevel=:multiple) +MPIRPC.select_mpi_rpc_backend!(MPIRPC.UniformMPIRPCBackend(MPI.COMM_WORLD)) + +rank = MPI.Comm_rank(MPI.COMM_WORLD) +nprocs = MPI.Comm_size(MPI.COMM_WORLD) +peer = mod(rank + 1, nprocs) + +result = MPIRPC.@with_progress MPIRPC.remotecall_fetch(+, peer, rank, 100) +@assert result == rank + 100 # the args (rank, 100) travel to `peer` which sums them + +# Fan out the same call to every other rank; collect later: +futs = MPIRPC.bcast_remotecall(MPIRPC.current_mpi_rpc_backend(), *, rank, 7) +vals = map(MPIRPC.fetch, futs) # one entry per destination rank (ascending), excluding `rank` + +MPIRPC.shutdown!() +MPI.Finalize() +``` + +## Security + +Like `Distributed`, MPIRPC deserializes function objects and arguments sent by +peers. The same trust model applies: only run MPIRPC across mutually trusted +ranks. diff --git a/lib/MPIRPC/docs/ARCHITECTURE.md b/lib/MPIRPC/docs/ARCHITECTURE.md new file mode 100644 index 000000000..40de7068c --- /dev/null +++ b/lib/MPIRPC/docs/ARCHITECTURE.md @@ -0,0 +1,500 @@ +# MPIRPC.jl architecture + +MPIRPC is a standalone, MPI-only RPC for Julia, modeled on two existing +designs: + +* the **`Distributed` stdlib remote-call pipeline** — header + serialized body + + boundary, `invokelatest` on the handler path, `RemoteException`-style + error wrapping; and +* **Dagger.jl's "accelerate-once" backend selection** — one installed value + determines the runtime semantics; the public API is identical regardless + of mode. + +There is no dependency on `Distributed` or `Dagger`. The transport is +[`MPI.jl`](https://github.com/JuliaParallel/MPI.jl); `Serialization` is the +only other (stdlib) dependency. + +## 1. Module layout + +| File | Role | +|------|------| +| [`src/MPIRPC.jl`](../src/MPIRPC.jl) | Module entry, exports. | +| [`src/protocol.jl`](../src/protocol.jl) | `MPIRRID`, `MsgHeader`, `CallMsg`, `CallWaitMsg`, `RemoteDoMsg`, `ResultMsg`, `RPCProgressHaltMsg`, frame encode/decode. | +| [`src/exceptions.jl`](../src/exceptions.jl) | `MPIRemoteException` and `run_work_thunk`. | +| [`src/config.jl`](../src/config.jl) | `AbstractMPIRPCBackend`, process-global + task-local installation, `select_mpi_rpc_backend!` / `with_mpi_rpc_backend` / `current_mpi_rpc_backend`. | +| [`src/refs.jl`](../src/refs.jl) | `MPIFuture` and waiter delivery primitives. | +| [`src/uniform.jl`](../src/uniform.jl) | `UniformMPIRPCBackend` — SPMD: every rank issues and services. | +| [`src/nonuniform.jl`](../src/nonuniform.jl) | `NonUniformMPIRPCBackend` — listener/client split. | +| [`src/remotecall.jl`](../src/remotecall.jl) | Public surface: `remotecall`, `remotecall_fetch`, `remotecall_wait`, `remote_do`, `bcast_remotecall`, `wait`/`fetch`. | +| [`src/progress.jl`](../src/progress.jl) | `rpc_progress!`, `rpc_progress_halt!`, `rpc_barrier`, `serve_listener`. | + +## 2. Backend selection (Dagger parallel) + +The selection model is a near-exact port of Dagger's +[`acceleration.jl`](../../Dagger.jl/src/acceleration.jl): + +| Dagger | MPIRPC | +|--------|--------| +| `accelerate!(::Acceleration)` | `select_mpi_rpc_backend!(::AbstractMPIRPCBackend)` | +| `current_acceleration()` | `current_mpi_rpc_backend()` | +| `_with_default_acceleration(f)` | `with_mpi_rpc_backend(f, backend)` | +| `initialize_acceleration!(::Acceleration)` | `initialize_mpi_rpc!(::AbstractMPIRPCBackend)` | +| `Acceleration` | `AbstractMPIRPCBackend` | +| `DistributedAcceleration` (default no-op) | none — no backend is installed by default | +| `MPIAcceleration` | `UniformMPIRPCBackend`, `NonUniformMPIRPCBackend` | + +The installed backend lives in **two slots**: a process-global `Ref` (set +once by `select_mpi_rpc_backend!`) and a task-local override +(`task_local_storage(:mpi_rpc_backend)`, layered by `with_mpi_rpc_backend`). +`current_mpi_rpc_backend()` checks the task-local slot first, falling back +to the global. This matches what Dagger users see in practice — Dagger +declares the `TaskLocalValue` but in practice every `@spawn`-ed task sees +the value because the package exposes a global initializer for the slot — +while keeping our public surface dependency-free. + +**Re-init is not supported in v1.** Calling `select_mpi_rpc_backend!` +twice in the same process replaces the slot but does **not** finalize the +first backend's communicators or in-flight `Isend` buffers. To switch +modes mid-job, finalize MPI and start a fresh process. + +```mermaid +flowchart LR + Init["select_mpi_rpc_backend(b)"] --> InitHook["initialize_mpi_rpc(b)"] + InitHook --> SetGlobal["GLOBAL_BACKEND := b"] + SetGlobal --> Use["remotecall, remote_do, wait, ..."] + Use --> Lookup["current_mpi_rpc_backend()"] + Lookup --> CheckTLS{"task_local_storage hit?"} + CheckTLS -- yes --> ReturnTLS[return TLS value] + CheckTLS -- no --> ReturnGlobal[return GLOBAL_BACKEND] +``` + +## 3. Wire format + +Every RPC message — request or reply — is a single MPI byte payload of the +form: + +``` ++--------------------------+--------------------------+----------------+ +| serialized MsgHeader | serialized AbstractMsg | MSG_BOUNDARY | ++--------------------------+--------------------------+----------------+ +``` + +A fresh `Serialization.Serializer` is created per frame, so the +serializer's back-reference table is bounded to one frame. This is +deliberately simpler than `Distributed.ClusterSerializer`, which keeps +inter-message state for things like worker-ref caching. Full +`ClusterSerializer` parity (anonymous-function global-binding tracking, +shared object-number tables) is out of scope for v1 and is planned as an +opt-in serializer choice later. + +The `MSG_BOUNDARY` (10 bytes) is preserved from Distributed's design +purely as a fail-fast protocol-version / corruption check; MPI is +message-oriented so we do not need the boundary for stream +re-synchronization. + +### Message types + +| Type | Direction | Effect | +|------|-----------|--------| +| `CallMsg{:call}` / `CallMsg{:call_fetch}` | client → server | Server runs `f(args...; kwargs...)`, replies with `ResultMsg(value_or_exception)`. | +| `CallWaitMsg` | client → server | Server runs the call; replies with `:OK` or an `MPIRemoteException`. | +| `RemoteDoMsg` | client → server | Server runs the call, no reply sent (fire-and-forget). | +| `ResultMsg` | server → client | Carries the value; routed to a `MPIFuture` via `MsgHeader.response_oid`. | +| `RPCProgressHaltMsg` | any → rank on `request_tag` | Control: `rpc_progress!` stops draining further requests in the current pass (`return false`); empty `MsgHeader`. | + +### OID layout + +`MPIRRID(whence::Int32, id::UInt64)` is our analog of `Distributed.RRID`, +with `whence` storing the *MPI rank* (within the backend's communicator) +that allocated the id. Each `remotecall*` allocates a fresh `MPIRRID` on +the caller, registers a `MPIFuture` in the backend's waiter table, and +puts the id in `MsgHeader.notify_oid`. The server echoes that id back in +`MsgHeader.response_oid` of the `ResultMsg`; the client routes the reply +to the matching future. + +## 4. Tag layout (uniform & non-uniform) + +Two disjoint tags carry **all** RPC traffic: + +* `request_tag` (default `0xC0DE`): every `CallMsg` / `CallWaitMsg` / + `RemoteDoMsg`, and control [`RPCProgressHaltMsg`](../src/protocol.jl) + (see [`rpc_progress_halt!`](../src/progress.jl)) — same framing as RPC, + consumed inside `rpc_progress!` without spawning a handler. +* `reply_tag` (default `0xC0DF`): every `ResultMsg`. + +Tags **do not encode call identity**. Concurrent calls between the same +pair of ranks are distinguished by `MPIRRID` in the header, not by tag. +This is the most important deadlock-avoidance choice in MPIRPC; see +section 5. + +```mermaid +sequenceDiagram + participant Client + participant ClientMPI as Client MPI + participant ServerMPI as Server MPI + participant Server + Client->>Client: encode_frame MsgHeader plus CallMsg plus boundary + Client->>ClientMPI: Isend buf request_tag dest server + ClientMPI->>ServerMPI: matched on request_tag + Server->>Server: rpc_progress Improbe Mrecv decode_frame + Server->>Server: invokelatest f args + Server->>ServerMPI: Isend ResultMsg reply_tag dest client + ServerMPI->>ClientMPI: matched on reply_tag + Client->>Client: rpc_progress Improbe Mrecv decode_frame deliver MPIFuture +``` + +## 5. Deadlock avoidance and the ABBA argument + +MPIRPC deliberately decouples three concerns that often get conflated in +hand-rolled MPI-RPC code: + +1. **Direction** (request vs reply) — encoded in the *tag*. +2. **Call identity** (which call's reply is this?) — encoded in the + `MPIRRID` carried in the *header*. +3. **Order of completion** — never derived from tag matching; only from + waiter routing. + +This decoupling immediately rules out a class of ABBA bugs: + +* **Symmetric cross-calls.** If rank `i` issues a `remotecall_fetch` to + rank `j` while rank `j` simultaneously issues one to rank `i`, no + `(peer, tag)` pair is shared between request and reply traffic, so + request `Improbe`s never accidentally match a reply or vice versa. +* **Many concurrent calls between one pair.** Two `remotecall`s from `i` + to `j` get distinct `MPIRRID`s; the server's reply is routed by header + rather than position in the queue, so the client correctly resolves + futures even when fetched in a permuted order. +* **Wire-level FIFO.** MPI guarantees in-order delivery between a pair of + ranks on a single tag; we rely on this for *delivery* of nested calls, + e.g. a handler issuing further calls. **Handler execution is not + serialized** — see §6 (threading model) — so an application that needs + ordering of side effects across two messages from the same source + must use `remotecall_wait` (or `remotecall` + `fetch`), not bare + `remote_do` + `remotecall_fetch`. + +### Mesh-shutdown deadlock and `rpc_barrier` + +A subtler hazard appears specifically at *phase boundaries*: a rank `R` +whose own primary `remotecall_fetch` has just completed may still hold an +*unprocessed inbound request* sent by another rank from inside that +rank's handler. If `R` calls `MPI.Barrier` it stops pumping RPC +progress, and the peers still waiting on `R`'s replies hang. + +[`rpc_barrier`](../src/progress.jl) addresses this with a non-blocking +`MPI.Ibarrier` whose completion is awaited *while every rank pumps* +[`rpc_progress!`](../src/progress.jl). Use it between phases of an SPMD +program and before exiting an RPC session. The +[uniform mpiexec suite](../test/uniform_mpiexec.jl) regression-tests the +exact pattern that motivated this primitive (the "nested re-entrant +remotecall_fetch" testset). + +## 6. Threading model + +MPIRPC supports `julia --threads=N` end-to-end: MPI is initialized with +`threadlevel=:multiple`, every MPI call is wrapped in a per-backend +`mpi_lock`, the OID counter is `Threads.Atomic{UInt64}`, and the waiter +table is guarded by `waiters_lock`. The two key design choices: + +1. **Handlers run on `Threads.@spawn`.** When `rpc_progress!` matches a + request, it does not run the handler synchronously on the calling + task — it spawns a fresh task whose body wraps the dispatch in + `with_mpi_rpc_backend(backend) do ... end`. The progress pump + returns immediately. The user closure runs without holding any + MPIRPC lock. +2. **No global progress lock.** Multiple threads may call + `rpc_progress!` concurrently. The only serialization is the + per-backend `mpi_lock` around individual MPI calls (`Improbe`, + `Mrecv!`, `Isend`, `Test`, `Ibarrier`); state transitions on the + waiter table use the separate `waiters_lock`. + +Together these eliminate the deadlock pattern where a handler calls +`Threads.@spawn t; fetch(t)` and `t` itself does an MPIRPC call. With a +synchronous-handler model that held a progress lock, the spawned task +could not acquire the progress lock from another OS thread; the calling +task held it while blocked on `fetch(t)`. With handlers spawned, no lock +is held during user code, and `t` is free to drive progress on its own +task. The +[`uniform / handler that Threads.@spawn-then-fetches an RPC`](../test/uniform_mpiexec.jl) +regression test locks this in. + +### Trade-off: handler execution is not FIFO + +Distributed.jl serializes message dispatch through a single per-worker +reader task; this gives the property that two messages from the same +source on the same `(src, dest, tag)` are *executed* in arrival order. +MPIRPC under multi-threading does not provide that guarantee — it +preserves only **MPI delivery FIFO**. Two requests from rank `R` to rank +`P` may be matched and dispatched on different threads of `P`, in which +case the two handlers run concurrently and the second's side effects +may become visible before the first's. + +The right primitive for "the next call must observe the previous call's +side effect" is `remotecall_wait`: + +```julia +# Wrong under multi-threading: the fetch handler may dispatch on +# another thread *before* the remote_do handler's @eval commits. +remote_do(peer, ...) do; @eval Main.X = 42 end +val = remotecall_fetch(peer) do; Main.X end + +# Right: remotecall_wait blocks the caller until the remote handler +# acknowledges completion, so the side effect is committed before the +# next message goes on the wire. +remotecall_wait(peer, ...) do; @eval Main.X = 42 end +val = remotecall_fetch(peer) do; Main.X end +``` + +The `set-then-read with remotecall_wait` testset in both mpiexec suites +documents this idiom. + +### Hazards that remain (user-error class) + +* `rpc_barrier` (and any MPI collective) must be called from at most one + thread per rank — calling it concurrently posts multiple `Ibarrier` + requests, which is incorrect by MPI semantics. +* User-level non-reentrant locks held across `remotecall_fetch` can + self-deadlock if a remote handler running on the same task re-enters + them. Use `ReentrantLock`, or do not take user locks from inside + remote handlers. +* `wait(f::MPIFuture)` *under `daemon = false`* is a busy-pumping spin: + every waiter calls `rpc_progress!` in a loop with `yield()` between + passes. This is unavoidable in the no-daemon configuration because + `wait` itself is the only progress driver — parking the caller would + also park the wire. *Under `daemon = true`* the spin is replaced by a + proper `Threads.Condition`-park (see "Cond-park under `daemon = true`" + below), so this hazard only applies to applications that opted out of + the daemon. + +### Optional progress daemon + +By default, every rank that needs to receive RPC must call +`rpc_progress!` (or block in `wait`/`fetch`/`rpc_barrier`/`serve_listener`, +each of which pumps progress internally). If that requirement is +inconvenient — for example, because the application's main loop is +already complex, or because a listener rank has long stretches of pure +computation — pass `daemon = true` to the backend constructor: + +```julia +backend = select_mpi_rpc_backend!( + UniformMPIRPCBackend(MPI.COMM_WORLD; daemon = true)) +``` + +`select_mpi_rpc_backend!` then `Threads.@spawn`s a yield-only loop that +calls `rpc_progress!` in tight rotation until `shutdown!` flips +`backend.running[]` to `false`. `shutdown!` then `wait`s on the daemon +task before returning, so the caller can rely on no further MPI calls +being issued from the daemon after `shutdown!` returns. + +The daemon never sleeps. This is a deliberate choice: a sleeping daemon +trades CPU for tail latency on inbound requests, and the same trade-off +is already available without a daemon (just don't enable it, and pump +progress on your own schedule). Enabling the daemon is a "consume one +OS thread, never wait on the wire" decision; if that is the wrong +trade-off for your workload, leave `daemon = false`. + +#### Threadpool isolation + +The daemon is spawned on Julia's `:interactive` threadpool when at +least one interactive thread is configured (`julia -t N,M` with +`M >= 1`). The interactive pool exists exactly for latency-sensitive +tasks that must not be starved by `:default`-pool work. Concretely, +this means: a CPU-bound user computation on a default-pool thread — +e.g. a tight numerical loop with no yield points, or a long `ccall` +to a synchronous C library — **cannot** prevent the daemon from +servicing inbound RPC. The `daemon survives CPU-bound default-pool +work` testset in +[`test/uniform_daemon_mpiexec.jl`](../test/uniform_daemon_mpiexec.jl) +locks this in: it saturates every default-pool thread with a +no-yield busy loop and asserts that a peer's `remotecall_fetch` +*to this rank* still completes promptly. + +Handlers (`_run_handler_task`) deliberately stay on the `:default` +pool: handlers run user code, which is the workload the user wants +done; placing handlers on `:interactive` would steal latency budget +from whoever else uses that pool (the REPL, GC threads, etc.). + +If `Threads.nthreads(:interactive) == 0` at `select_mpi_rpc_backend!` +time, the daemon falls back to `:default` and emits a one-shot `@info` +message naming the consequence. Existing application behavior is +preserved; isolation is opt-in via the launch flag. + +Two regression suites exercise this path: +[`test/uniform_daemon_mpiexec.jl`](../test/uniform_daemon_mpiexec.jl) +and +[`test/nonuniform_daemon_mpiexec.jl`](../test/nonuniform_daemon_mpiexec.jl). +Neither calls `rpc_progress!` or `serve_listener` from user code, so +they are red unless the daemon is doing all of the inbound draining. + +### Cond-park under `daemon = true` + +`MPIFuture` carries a `Threads.Condition`. `deliver!` (called from +`_dispatch_reply!` after `take_waiter!` removes the future from the +waiter table) acquires the cond, stores the value atomically with +`@atomic :release`, and `notify(cond, all=true)` so any number of +waiters on the same future wake at once. + +`wait(::MPIFuture)` reads `f.backend.daemon` and chooses one of two +paths: + +```julia +function Base.wait(f::MPIFuture) + isready(f) && return f + backend = f.backend::AbstractMPIRPCBackend + if backend.daemon + @lock f.cond begin # check-park-recheck idiom + while !isready(f) + wait(f.cond) # task fully descheduled + end + end + else + while !isready(f) # v1 spin: this task is the + rpc_progress!(backend) # only progress driver, so + isready(f) && break # parking would deadlock + yield() + end + end + return f +end +``` + +The check is on `f.backend.daemon`, not the currently-installed +backend, so a future created under a daemon-backed backend is still +park-cheap to wait on even from a task that has scoped a different +backend via `with_mpi_rpc_backend`. + +**Lock geometry.** Three locks coexist on the reply path: + +1. `backend.mpi_lock` — wraps the raw MPI calls in `_try_recv_one!` + (`Improbe`, `Imrecv!`, `Test`). It is *released* between `Test` + polls so the daemon's wait for a slow rendezvous receive does not + block other threads' MPI calls. +2. `backend.waiters_lock` — wraps `take_waiter!`, the lookup that + removes the future from the waiter table. +3. `f.cond` — wraps the value store and `notify` in `deliver!`, and + the predicate check / `wait` in the consumer. + +These are acquired strictly sequentially in `_dispatch_reply!`: +`mpi_lock` is released before `waiters_lock` is taken (the lookup +happens after the buffer is fully out of MPI), and `waiters_lock` is +released before `f.cond` is taken (`take_waiter!` returns the future, +then `deliver!` runs). At no point are two of them held at the same +time, so there is no cross-future or cross-rank ordering hazard. + +**Wakeup atomicity.** `deliver!` and the consumer use the standard +condition-variable idiom: predicate (`isready`) is read inside the +lock, and `notify` happens *after* the predicate is set, *while +holding the same lock*. This prevents the lost-wakeup race where the +consumer's `isready` check observes `false` and parks just after the +producer's `notify` fired but before the producer's store became +visible. + +### Non-blocking receive (`Imrecv!` + `Test`/`yield`) + +`_try_recv_one!` does not call `MPI.Mrecv!`. The blocking +matched-receive primitive would have held `backend.mpi_lock` for the +entire duration of the message transfer, which for a large +rendezvous-protocol payload (multi-MB) can be milliseconds — during +which no other thread on the rank could acquire `mpi_lock` to do its +own MPI call (e.g. a handler trying to `Isend` a reply for a +different RPC). + +Instead, `_try_recv_one!` runs in two phases: + +1. **Match-and-post under `mpi_lock`.** `Improbe` matches the next + message on the requested tag, the buffer is allocated to the + matched count, and `Imrecv!` is posted. All three steps are + atomic — they have to be, because the matched message handle is + only valid until consumed by `Imrecv!`. + +2. **Poll-and-yield without `mpi_lock`.** The non-blocking request + from `Imrecv!` is then `Test`'d in a loop; `mpi_lock` is acquired + only for the `Test` call itself (microseconds) and released before + `yield()`. While the daemon's task is yielded, other threads can + freely acquire `mpi_lock` and progress their own MPI work. + +For typical RPC payloads under MPI's eager threshold (often 64 KiB on +OpenMPI, 256 KiB on MPICH), `Test` succeeds on the first poll because +the eager protocol delivered the entire payload during `Improbe`'s +match; the cost vs. blocking `Mrecv!` is one extra `Test` call and +one lock acquire/release pair, single-digit microseconds. For large +payloads on rendezvous, the daemon thread becomes truly cooperative +during the receive — yielding repeatedly until the transfer +completes — and other threads on the rank can issue their own MPI +calls in the gaps. + +Combined with the threadpool isolation (daemon on `:interactive`), +this means the only places where a thread spends real wall time +inside MPI without yielding are `MPI.Init`, `MPI.Comm_dup`, and +`MPI.Finalize`, all of which are one-time startup or shutdown +events. **At runtime, no rank-level MPI call holds a thread without +yield.** + +The `large concurrent payloads via Imrecv! + Test/yield` testset in +[`test/uniform_daemon_mpiexec.jl`](../test/uniform_daemon_mpiexec.jl) +exercises this path with multiple concurrent 1-MiB round-trips that +deliberately exceed common eager thresholds. + +## 7. Uniform vs non-uniform — operational rules + +| | Uniform | Non-uniform | +|---|---|---| +| Who initiates RPC? | Any rank | Any rank | +| Who services RPC? | All ranks | Only `listener_ranks` | +| Who must call `rpc_progress!`? | Every rank, regularly *(or none, if `daemon = true`)* | Listeners (for requests) and any rank with outstanding `MPIFuture`s (for replies). `wait`/`fetch` already pumps. *(Or no rank, if `daemon = true`.)* | +| `dest_rank` validity | Any peer in `comm` | Must be in `listener_ranks` | +| Subcomm by default | Yes (`MPI.Comm_dup`) | Yes (`MPI.Comm_dup`) | +| Phase barrier | `rpc_barrier(backend)` | `rpc_barrier(backend)` | +| World collectives | Run on `MPI.COMM_WORLD` independently | Same — the duped comm isolates RPC from world traffic | + +### Collective safety in non-uniform mode + +Because the backend duplicates the user's communicator, world-level +collectives (e.g. `MPI.Barrier(MPI.COMM_WORLD)`) cannot be matched against +RPC `Isend`s. Even so, **all ranks of `COMM_WORLD` must still participate +in any world collective**. The +[non-uniform mpiexec suite](../test/nonuniform_mpiexec.jl) explicitly +exercises a `MPI.Barrier(MPI.COMM_WORLD)` interleaved with subcomm RPC to +make sure the test does not encode a collective mismatch as "pass". + +## 8. World-age and `invokelatest` + +Where `Distributed/process_messages.jl` wraps the body deserializer in +`invokelatest(deserialize_msg, ...)` and the handler in +`invokelatest(msg.f, ...)`, MPIRPC does the same: see `decode_frame` in +[`protocol.jl`](../src/protocol.jl) and `_execute_request!` in +[`uniform.jl`](../src/uniform.jl). This lets remote ranks call functions +defined after their own world age has advanced (e.g. user code defined +between two RPC phases). + +## 9. Mapping to Distributed (cheat sheet) + +| Distributed | MPIRPC | +|---|---| +| `Future` | `MPIFuture` | +| `RRID(whence::Int, id::Int)` | `MPIRRID(whence::Int32, id::UInt64)` | +| `MsgHeader(response_oid, notify_oid)` | identical | +| `CallMsg{:call}`, `CallMsg{:call_fetch}`, `CallWaitMsg`, `RemoteDoMsg`, `ResultMsg` | identical names and roles | +| `MSG_BOUNDARY` (10 bytes) | `MSG_BOUNDARY` (10 bytes; different sentinel) | +| `ClusterSerializer` over a TCP stream | `Serializer` over a per-message `IOBuffer` | +| `invokelatest(deserialize_msg, ...)` | `invokelatest(deserialize, ...)` in `decode_frame` | +| `invokelatest(msg.f, msg.args...; ...)` in `handle_msg` | identical idiom in `_execute_request!` | +| `RemoteException` wrapping `CapturedException` | `MPIRemoteException` wrapping `CapturedException` | +| `remotecall` / `remotecall_fetch` / `remotecall_wait` / `remote_do` / `fetch` / `wait` | same names | + +## 10. Security + +Same model as `Distributed`: deserializing function objects and arguments +sent by peers is equivalent to letting them run code. Only run MPIRPC +across mutually trusted ranks. + +## 11. What is intentionally not in v1 + +* No `ClusterSerializer`-equivalent global-binding propagation. +* No `RemoteChannel` (server-resident channel) — `MPIFuture` is the only + reference type. +* No FIFO of *handler execution* across messages (only MPI delivery + FIFO). See §6. +* No mid-process backend re-init. +* No transport other than `MPI.jl`. +* No retry / fault tolerance — a process death is fatal to outstanding + futures targeting that rank, just as in `Distributed`. diff --git a/lib/MPIRPC/examples/nonuniform_driver.jl b/lib/MPIRPC/examples/nonuniform_driver.jl new file mode 100644 index 000000000..d7777ddd5 --- /dev/null +++ b/lib/MPIRPC/examples/nonuniform_driver.jl @@ -0,0 +1,63 @@ +# Non-uniform driver: half the ranks are dedicated listeners ("workers"), +# the rest are clients that submit work and collect results. +# +# Run with: +# mpiexec -n 4 julia --project=. examples/nonuniform_driver.jl + +using MPI +using MPIRPC + +MPI.Init(; threadlevel=:multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 2 || error("non-uniform example needs at least 2 ranks") + +# Half the ranks are listeners. Adjust to your topology. +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; listener_ranks = LISTENER_RANKS)) + +if MPIRPC.is_listener(backend) + println("[rank $WORLD_RANK] LISTENER; serving until shutdown") + flush(stdout) + + # Listener loop: pump progress, exit when a client tells us to stop. + # `rpc_progress!` itself does not block; this loop is the listener's + # main service loop. In production code, interleave `rpc_progress!` + # with whatever else the listener does. + while backend.running[] + MPIRPC.rpc_progress!(backend) + yield() + end + println("[rank $WORLD_RANK] LISTENER exit") +else + println("[rank $WORLD_RANK] CLIENT; submitting calls to $LISTENER_RANKS") + flush(stdout) + + # Submit calls to every listener. + K = 8 + futs = MPIFuture[] + for ℓ in LISTENER_RANKS, k in 1:K + push!(futs, MPIRPC.remotecall(+, ℓ, WORLD_RANK * 100, k)) + end + for f in futs + v = fetch(f) + println("[rank $WORLD_RANK] got $v from listener (expected WORLD_RANK*100 + k)") + end + + # Tell every listener to wind down. `remote_do` is fire-and-forget; the + # listeners flip their own `running` flag and exit their service loop. + for ℓ in LISTENER_RANKS + MPIRPC.remote_do(MPIRPC.shutdown!, ℓ) + end +end + +# Both roles must reach the same barrier. `rpc_barrier` is preferred over +# `MPI.Barrier` so any in-flight RPC drains before the program exits. +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/examples/rpc_matmul_nonuniform.jl b/lib/MPIRPC/examples/rpc_matmul_nonuniform.jl new file mode 100644 index 000000000..2a1ed0370 --- /dev/null +++ b/lib/MPIRPC/examples/rpc_matmul_nonuniform.jl @@ -0,0 +1,131 @@ +# RPC-based matrix multiplication (non-uniform: listeners own B, clients own A) +# +# First half of ranks are **listeners** (service RPC); the rest are **clients** +# (issue RPC only). Each listener owns one contiguous column strip of `B`; +# each client owns one contiguous row block of `A`. Clients fetch every column +# strip from the listener ranks via `remotecall_fetch` — clients never call +# each other (matches `NonUniformMPIRPCBackend` constraints). +# +# Run from the MPIRPC package root under lib/MPIRPC (needs at least 4 ranks, same as package +# non-uniform tests): +# mpiexec -n 4 julia --threads=2,1 --project=. examples/rpc_matmul_nonuniform.jl +# +# Optional: `N=128` env sets matrix dimension (must satisfy divisibility below). +# +# We enable `daemon = true` so listeners continuously drain inbound RPC without +# a dedicated manual `serve_listener` loop while clients issue many fetches. + +using MPI +using MPIRPC + +const MAX_VERIFY_N = 256 + +a_elem(i::Int, j::Int) = sin(i) + cos(j) +b_elem(i::Int, j::Int) = tanh(i * 0.01) + tanh(j * 0.01) + +const _RPC_B_STRIP = Ref{Matrix{Float64}}() + +MPI.Init(; threadlevel = :multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 4 || error("non-uniform matmul example needs at least 4 ranks (got $NPROC)") + +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) +const CLIENT_RANKS = collect(HALF:(NPROC - 1)) +const IS_LISTENER = WORLD_RANK in LISTENER_RANKS + +const N_LISTENERS = length(LISTENER_RANKS) +const N_CLIENTS = length(CLIENT_RANKS) + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; + listener_ranks = LISTENER_RANKS, + daemon = true)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +@assert RANK == WORLD_RANK + +const N = parse(Int, get(ENV, "N", "64")) +N % N_LISTENERS == 0 || error("N ($N) must be divisible by number of listeners ($N_LISTENERS)") +N % N_CLIENTS == 0 || error("N ($N) must be divisible by number of clients ($N_CLIENTS)") + +const N_LOC_COL = N ÷ N_LISTENERS # columns per listener strip +const N_LOC_ROW = N ÷ N_CLIENTS # rows per client block + +# --- Build owned pieces ------------------------------------------------------- + +if IS_LISTENER + ℓ = findfirst(==(WORLD_RANK), LISTENER_RANKS)::Int + COL_LO = (ℓ - 1) * N_LOC_COL + 1 + COL_HI = ℓ * N_LOC_COL + B_loc = Matrix{Float64}(undef, N, N_LOC_COL) + for lj in 1:N_LOC_COL + j_g = COL_LO + lj - 1 + for i in 1:N + B_loc[i, lj] = b_elem(i, j_g) + end + end + _RPC_B_STRIP[] = B_loc +else + c = findfirst(==(WORLD_RANK), CLIENT_RANKS)::Int + ROW_LO = (c - 1) * N_LOC_ROW + 1 + ROW_HI = c * N_LOC_ROW + A_loc = Matrix{Float64}(undef, N_LOC_ROW, N) + for li in 1:N_LOC_ROW + i_g = ROW_LO + li - 1 + for j in 1:N + A_loc[li, j] = a_elem(i_g, j) + end + end +end + +MPIRPC.rpc_barrier() + +# --- Clients multiply; listeners only service RPC (daemon drives progress) ----- + +if IS_LISTENER + C_loc = nothing +else + C_loc = zeros(N_LOC_ROW, N) + for ℓ in 1:N_LISTENERS + listener = LISTENER_RANKS[ℓ] + clo = (ℓ - 1) * N_LOC_COL + 1 + chi = ℓ * N_LOC_COL + B_panel = MPIRPC.remotecall_fetch(listener) do + return copy(Main._RPC_B_STRIP[]) + end + size(B_panel) == (N, N_LOC_COL) || error("bad B_panel size from listener $listener") + C_loc[:, clo:chi] = A_loc * B_panel + end +end + +MPIRPC.rpc_barrier() + +# --- Verification on first client only (small N) ------------------------------ + +VERIFY_RANK = first(CLIENT_RANKS) +if WORLD_RANK == VERIFY_RANK && N ≤ MAX_VERIFY_N && !IS_LISTENER + c = findfirst(==(WORLD_RANK), CLIENT_RANKS)::Int + ROW_LO = (c - 1) * N_LOC_ROW + 1 + ROW_HI = c * N_LOC_ROW + A_full = Matrix{Float64}(undef, N, N) + B_full = Matrix{Float64}(undef, N, N) + for j in 1:N, i in 1:N + A_full[i, j] = a_elem(i, j) + B_full[i, j] = b_elem(i, j) + end + C_ref = A_full * B_full + ref_rows = C_ref[ROW_LO:ROW_HI, :] + err = maximum(abs.(C_loc .- ref_rows)) + println("[client rank $WORLD_RANK] max |C_loc - C_ref| on owned rows = ", err) + @assert err < 1e-10 * max(1.0, maximum(abs.(C_ref))) +elseif WORLD_RANK == VERIFY_RANK + println("[client rank $WORLD_RANK] skipping dense verification (N=$N > $MAX_VERIFY_N)") +end + +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/examples/rpc_matmul_uniform.jl b/lib/MPIRPC/examples/rpc_matmul_uniform.jl new file mode 100644 index 000000000..90d67a8f3 --- /dev/null +++ b/lib/MPIRPC/examples/rpc_matmul_uniform.jl @@ -0,0 +1,111 @@ +# RPC-based matrix multiplication (uniform SPMD) +# +# Each rank owns a row block of `A` and a column block of `B`. To form +# `C = A * B`, rank `r` needs every column strip of `B`; those strips live on +# distinct ranks, so we fetch each peer's strip via `remotecall_fetch`. No MPI +# collectives move `A` / `B` on the compute path — only point-to-point RPC. +# +# Run from the MPIRPC package root (e.g. Dagger.jl/lib/MPIRPC): +# mpiexec -n 4 julia --project=. examples/rpc_matmul_uniform.jl +# +# Optional threading (matches package tests; useful if you switch to +# `UniformMPIRPCBackend(...; daemon = true)` — daemon uses `:interactive`): +# mpiexec -n 4 julia --threads=2,1 --project=. examples/rpc_matmul_uniform.jl +# +# Limitations: +# * Matrix dimension `N` below must divide `NPROC`; keep `N` modest — each RPC +# returns an `n × (N/nproc)` Float64 strip (multi-MiB if `N` is huge). +# * Rank-0 verification builds dense `N×N` references and only runs when +# `N ≤ MAX_VERIFY_N` (memory + time). + +using MPI +using MPIRPC + +const MAX_VERIFY_N = 256 + +# Deterministic element generators (global indices i, j ∈ 1:N). +a_elem(i::Int, j::Int) = sin(i) + cos(j) +b_elem(i::Int, j::Int) = tanh(i * 0.01) + tanh(j * 0.01) + +# Filled before any RPC so the handler closure on the destination rank reads the +# correct strip via `Main._RPC_B_STRIP` (each MPI rank is its own process). +const _RPC_B_STRIP = Ref{Matrix{Float64}}() + +MPI.Init(; threadlevel = :multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +# Demo size: override with `N=128 julia ...` if desired (must divide world size). +const N = parse(Int, get(ENV, "N", "64")) +NPROC >= 1 || error("need at least 1 rank") +N % NPROC == 0 || error("N ($N) must be divisible by NPROC ($NPROC)") +const N_LOC = N ÷ NPROC + +# --- Local owned data --------------------------------------------------------- + +const ROW_LO = RANK * N_LOC + 1 +const ROW_HI = (RANK + 1) * N_LOC +const COL_LO = RANK * N_LOC + 1 +const COL_HI = (RANK + 1) * N_LOC + +A_loc = Matrix{Float64}(undef, N_LOC, N) +B_loc = Matrix{Float64}(undef, N, N_LOC) + +for li in 1:N_LOC + i_g = ROW_LO + li - 1 + for j in 1:N + A_loc[li, j] = a_elem(i_g, j) + end +end +for lj in 1:N_LOC + j_g = COL_LO + lj - 1 + for i in 1:N + B_loc[i, lj] = b_elem(i, j_g) + end +end + +_RPC_B_STRIP[] = B_loc + +MPIRPC.rpc_barrier() + +# --- Compute C_loc = A_loc * B via RPC-fetched column strips ----------------- + +C_loc = zeros(N_LOC, N) + +for q in 0:(NPROC - 1) + clo = q * N_LOC + 1 + chi = (q + 1) * N_LOC + # Closure runs on rank `q`; it captures nothing from rank `r` except what + # Julia serializes — here we only need rank q's Main._RPC_B_STRIP. + B_panel = MPIRPC.remotecall_fetch(q) do + return copy(Main._RPC_B_STRIP[]) + end + size(B_panel) == (N, N_LOC) || error("unexpected B_panel size on rank $RANK from $q") + C_loc[:, clo:chi] = A_loc * B_panel +end + +MPIRPC.rpc_barrier() + +# --- Verification on rank 0 (small N only) ------------------------------------ + +if RANK == 0 && N ≤ MAX_VERIFY_N + A_full = Matrix{Float64}(undef, N, N) + B_full = Matrix{Float64}(undef, N, N) + for j in 1:N, i in 1:N + A_full[i, j] = a_elem(i, j) + B_full[i, j] = b_elem(i, j) + end + C_ref = A_full * B_full + ref_rows = C_ref[ROW_LO:ROW_HI, :] + err = maximum(abs.(C_loc .- ref_rows)) + println("[rank 0] max |C_loc - C_ref| on owned rows = ", err) + @assert err < 1e-10 * max(1.0, maximum(abs.(C_ref))) +elseif RANK == 0 + println("[rank 0] skipping dense verification (N=$N > MAX_VERIFY_N=$MAX_VERIFY_N)") +end + +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/examples/uniform_driver.jl b/lib/MPIRPC/examples/uniform_driver.jl new file mode 100644 index 000000000..5944db9ab --- /dev/null +++ b/lib/MPIRPC/examples/uniform_driver.jl @@ -0,0 +1,41 @@ +# Minimal uniform / SPMD driver for MPIRPC. +# +# Run with: +# mpiexec -n 4 julia --project=. examples/uniform_driver.jl +# +# Every rank issues calls and services calls; progress is pumped implicitly +# by `wait`/`fetch` and explicitly via `rpc_barrier` between phases. + +using MPI +using MPIRPC + +MPI.Init(; threadlevel=:multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const RANK = MPI.Comm_rank(backend.comm) +const NPROC = MPI.Comm_size(backend.comm) + +println("[rank $RANK] up; world size = $NPROC") + +# Phase 1: every rank asks its right-hand neighbor what its rank is. +peer = mod(RANK + 1, NPROC) +neighbor_rank = MPIRPC.remotecall_fetch(peer) do + return MPI.Comm_rank(MPI.COMM_WORLD) +end +println("[rank $RANK] right neighbor reports rank=$neighbor_rank (expected $peer)") + +# Phase boundary: pump progress until every rank arrives. Without this, a +# rank that finishes early could exit while a peer's request to it is still +# unprocessed in its inbox. +MPIRPC.rpc_barrier() + +# Phase 2: scatter "compute" work all-to-all. +peers = [r for r in 0:(NPROC-1) if r != RANK] +futs = [MPIRPC.remotecall(*, p, RANK + 1, p + 1) for p in peers] +for (p, f) in zip(peers, futs) + @assert fetch(f) == (RANK + 1) * (p + 1) +end +println("[rank $RANK] all-to-all phase 2 done") + +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/src/MPIRPC.jl b/lib/MPIRPC/src/MPIRPC.jl new file mode 100644 index 000000000..b7c42ece4 --- /dev/null +++ b/lib/MPIRPC/src/MPIRPC.jl @@ -0,0 +1,53 @@ +""" + MPIRPC + +Standalone MPI-backed RPC for Julia, modeled on the `Distributed` stdlib +(`MsgHeader` + serialized body + boundary, `invokelatest` on the handler +path, `RemoteException`-style error wrapping) and on Dagger's +"accelerate-once" backend selection (a single backend value held in a +task-local slot determines uniform vs. non-uniform semantics; the public +API is the same in both modes). + +See `docs/ARCHITECTURE.md` for the design narrative, the deadlock and +ABBA discussion, and the side-by-side mapping to `Distributed` and Dagger. +""" +module MPIRPC + +using Serialization +using MPI + +export AbstractMPIRPCBackend, + UniformMPIRPCBackend, + NonUniformMPIRPCBackend, + MPIFuture, + MPIRRID, + MPIRemoteException, + select_mpi_rpc_backend!, + current_mpi_rpc_backend, + initialize_mpi_rpc!, + with_mpi_rpc_backend, + shutdown!, + remotecall, + remotecall_fetch, + remotecall_wait, + remote_do, + bcast_remotecall, + rpc_progress!, + rpc_progress_halt!, + rpc_barrier, + serve_listener, + is_listener, + listener_ranks, + @with_progress + +include("protocol.jl") +include("exceptions.jl") +include("config.jl") +include("refs.jl") +include("dispatch.jl") +include("uniform.jl") +include("nonuniform.jl") +include("remotecall.jl") +include("progress.jl") + +end # module MPIRPC diff --git a/lib/MPIRPC/src/config.jl b/lib/MPIRPC/src/config.jl new file mode 100644 index 000000000..fbdb20bb9 --- /dev/null +++ b/lib/MPIRPC/src/config.jl @@ -0,0 +1,150 @@ +""" + AbstractMPIRPCBackend + +Marker supertype for an installed MPIRPC backend. Concrete backends +(`UniformMPIRPCBackend`, `NonUniformMPIRPCBackend`) hold the communicator, +tag layout, OID counter, waiter table, and MPI lock that drive the wire +protocol; the public RPC surface (`remotecall`, `remotecall_fetch`, +`rpc_progress!`, ...) dispatches dynamically on this type so call sites +never branch on mode. + +The selection model mirrors Dagger's acceleration: see +[`Dagger.jl/src/acceleration.jl`](../../Dagger.jl/src/acceleration.jl) +(`accelerate!`, `current_acceleration`, `_with_default_acceleration`). +[`select_mpi_rpc_backend!`](@ref) installs a backend once for the whole +process; [`with_mpi_rpc_backend`](@ref) layers a task-local override for +scoped tests. +""" +abstract type AbstractMPIRPCBackend end + +const _BACKEND_KEY = :mpi_rpc_backend + +# Process-global default backend, written once at startup by +# `select_mpi_rpc_backend!`. Reads are unsynchronized; the assumption is the +# same as Distributed's `init_parallel` / Dagger's `accelerate!`: install +# happens before any RPC traffic and there is at most one installed backend +# per process for the duration of its life. +const _GLOBAL_BACKEND = Ref{Union{AbstractMPIRPCBackend, Nothing}}(nothing) + +""" + initialize_mpi_rpc!(backend::AbstractMPIRPCBackend) + +Per-backend hook (analog of Dagger's `initialize_acceleration!`). Concrete +backends override this to do MPI initialization, communicator splits, tag +range allocation, etc. The default is a no-op. +""" +initialize_mpi_rpc!(::AbstractMPIRPCBackend) = nothing + +""" + select_mpi_rpc_backend!(backend::AbstractMPIRPCBackend) -> backend + +Install `backend` as the current backend for the whole process and run +[`initialize_mpi_rpc!`](@ref) once. After this call, every public RPC +function (`remotecall`, `remotecall_fetch`, `rpc_progress!`, ...) — from +*any* task — dispatches through this backend, unless overridden inside +[`with_mpi_rpc_backend`](@ref). + +Re-entering with a different backend is **not supported** in v1: calling +this twice in the same process replaces the slot but the on-the-wire state +(communicators, in-flight `Isend` buffers, OID counters) of the first +backend is not reset. To switch modes mid-job, call [`shutdown!`](@ref) on +the existing backend, finalize MPI if appropriate, and start a fresh process. +""" +function select_mpi_rpc_backend!(backend::AbstractMPIRPCBackend) + initialize_mpi_rpc!(backend) + _GLOBAL_BACKEND[] = backend + # Spawn the optional yield-only progress daemon. We do this *after* + # initialize_mpi_rpc! so the daemon's first `rpc_progress!` finds a + # fully constructed communicator, and *after* the global slot has been + # written so a re-entrant `current_mpi_rpc_backend()` from inside the + # daemon resolves to the right object. + # + # Threadpool placement: prefer `:interactive` so a CPU-bound user + # computation on the `:default` pool cannot starve the wire pump. + # The interactive pool exists exactly for latency-sensitive tasks + # that must keep getting CPU regardless of default-pool load. Fall + # back to `:default` only if the user did not configure any + # interactive threads (`julia -t N,0` or implicit zero); in that + # case we emit a one-time info message so the operator knows why + # tail latency might rise under heavy compute. + if backend.daemon && backend.daemon_task === nothing + if Threads.nthreads(:interactive) > 0 + backend.daemon_task = errormonitor(Threads.@spawn :interactive _daemon_loop(backend)) + else + @info """MPIRPC: no `:interactive` threads configured; the daemon will share \ + the `:default` pool with user computation. CPU-bound user code on the \ + default pool can starve the wire pump. Start Julia with `-t N,M` (M >= 1) \ + to give the daemon a dedicated interactive thread.""" + backend.daemon_task = errormonitor(Threads.@spawn _daemon_loop(backend)) + end + end + return backend +end + +""" + current_mpi_rpc_backend() -> AbstractMPIRPCBackend + +Return the backend installed for the current task. Prefers a task-local +override (set by [`with_mpi_rpc_backend`](@ref)); otherwise falls back to +the process-global slot installed by [`select_mpi_rpc_backend!`](@ref). +Throws `ArgumentError` if no backend has been installed yet. +""" +function current_mpi_rpc_backend() + local_backend = task_local_storage(_BACKEND_KEY, nothing) + if local_backend !== nothing + return local_backend::AbstractMPIRPCBackend + end + g = _GLOBAL_BACKEND[] + g === nothing && throw(ArgumentError( + "no MPIRPC backend installed; call `select_mpi_rpc_backend!` first")) + return g::AbstractMPIRPCBackend +end + +""" + with_mpi_rpc_backend(f, backend) -> f() + +Scoped variant: install `backend` for the duration of `f()` (in this task +only) and restore the previous task-local override on exit. The analog of +Dagger's `_with_default_acceleration`. This does **not** call +`initialize_mpi_rpc!`; pass an already-initialized backend, or call +`select_mpi_rpc_backend!` once at startup and use this helper only to layer +scoped overrides during tests. +""" +function with_mpi_rpc_backend(f, backend::AbstractMPIRPCBackend) + prev = task_local_storage(_BACKEND_KEY, nothing) + task_local_storage(_BACKEND_KEY, backend) + try + return f() + finally + if prev === nothing + delete!(task_local_storage(), _BACKEND_KEY) + else + task_local_storage(_BACKEND_KEY, prev) + end + end +end + +""" + shutdown!([backend]) + +Stop the listener loop on this rank by clearing the backend's `running` +flag. Outstanding `Isend` requests are reaped on a best-effort basis. +This is a *local* operation; in non-uniform mode the application should +broadcast a stop signal across listener ranks before calling this on each. + +If a progress daemon was spawned by [`select_mpi_rpc_backend!`](@ref) +(`daemon = true` on the backend), this function also `wait`s on the +daemon task so the caller can rely on no further `rpc_progress!` +running once `shutdown!` returns. The wait is bounded only by the +daemon's own loop body; because the loop is yield-only and checks the +flag every iteration, it terminates promptly. +""" +function shutdown!(backend::AbstractMPIRPCBackend = current_mpi_rpc_backend()) + backend.running[] = false + t = backend.daemon_task + if t !== nothing + wait(t) + backend.daemon_task = nothing + end + return nothing +end diff --git a/lib/MPIRPC/src/dispatch.jl b/lib/MPIRPC/src/dispatch.jl new file mode 100644 index 000000000..488056626 --- /dev/null +++ b/lib/MPIRPC/src/dispatch.jl @@ -0,0 +1,37 @@ +""" +Shared request/reply dispatch helpers for concrete backends. See +`uniform.jl` / `nonuniform.jl` for `take_waiter!` / `_send_reply!` methods. +""" + +function _reply_exception!(b::AbstractMPIRPCBackend, notify_oid::MPIRRID, rank::Int, ex) + ce = ex isa CapturedException ? ex : CapturedException(ex, catch_backtrace()) + _send_reply!(b, notify_oid, MPIRemoteException(rank, ce)) + return nothing +end + +""" +Deliver a reply-frame body deserialization failure into the client's +`MPIFuture` when `response_oid` is known. Returns `true` if a waiter +was found and delivered. +""" +function _deliver_reply_deserialize_failure!(b::AbstractMPIRPCBackend, + response_oid::MPIRRID, rank::Int, body_err) + is_null(response_oid) && return false + fut = take_waiter!(b, response_oid) + if fut === nothing + @warn "MPIRPC: reply for unknown waiter (response_oid=$(response_oid)) on rank $(rank); dropping" + return false + end + deliver!(fut, MPIRemoteException(rank, CapturedException(body_err, catch_backtrace()))) + return true +end + +function _notify_request_decode_failure!(b::AbstractMPIRPCBackend, header::MsgHeader, + rank::Int, body_err) + if !is_null(header.notify_oid) + _reply_exception!(b, header.notify_oid, rank, body_err) + else + showerror(stderr, CapturedException(body_err, catch_backtrace())) + end + return nothing +end diff --git a/lib/MPIRPC/src/exceptions.jl b/lib/MPIRPC/src/exceptions.jl new file mode 100644 index 000000000..e4def57dc --- /dev/null +++ b/lib/MPIRPC/src/exceptions.jl @@ -0,0 +1,39 @@ +""" + MPIRemoteException(rank::Int, captured::CapturedException) + +Local analogue of `Distributed.RemoteException`. Wraps the originating MPI +rank and a `Base.CapturedException` for the original error and its captured +backtrace, so end-users can inspect the remote failure without any +`Distributed` dependency. +""" +struct MPIRemoteException <: Exception + rank::Int + captured::CapturedException +end + +MPIRemoteException(captured::CapturedException) = MPIRemoteException(-1, captured) + +Base.capture_exception(ex::MPIRemoteException, _) = ex + +function Base.showerror(io::IO, re::MPIRemoteException) + print(io, "On MPI rank ", re.rank, ":\n") + showerror(io, re.captured) +end + +""" + run_work_thunk(thunk, rank; print_error::Bool=false) -> result_or_exception + +Execute `thunk()` and either return the value or wrap any thrown exception +in an `MPIRemoteException`. When `print_error` is true (as for +`Distributed.remote_do`), the captured exception is also printed to stderr, +mirroring `Distributed.run_work_thunk(thunk, print_error)`. +""" +function run_work_thunk(thunk::Function, rank::Int; print_error::Bool=false) + try + return thunk() + catch err + ce = CapturedException(err, catch_backtrace()) + print_error && showerror(stderr, ce) + return MPIRemoteException(rank, ce) + end +end diff --git a/lib/MPIRPC/src/nonuniform.jl b/lib/MPIRPC/src/nonuniform.jl new file mode 100644 index 000000000..d615a508e --- /dev/null +++ b/lib/MPIRPC/src/nonuniform.jl @@ -0,0 +1,262 @@ +""" + NonUniformMPIRPCBackend(comm; listener_ranks, ...) + +Heterogeneous backend with explicit roles. Ranks in `listener_ranks` service +inbound requests via [`rpc_progress!`](@ref); all other ranks are clients +that only initiate calls and only drain replies addressed to them. + +# Roles + +* **Listener**: must call `rpc_progress!` regularly (e.g. from a dedicated + service loop or interleaved with computation). Listeners may also issue + RPC themselves, in which case they additionally drain replies via the + same entrypoint. +* **Client-only**: never services requests from peers. Calling `remotecall*` + with a non-listener `dest_rank` raises `ArgumentError`. + +# Communicators + +If `dup_comm` is `true` (default), the backend duplicates `comm` with +`MPI.Comm_dup`. This isolates RPC traffic from collectives the user runs on +the original communicator, so for instance an `MPI.Barrier` on +`MPI.COMM_WORLD` cannot be matched against — or block — pending RPC sends. + +# Constructor + + NonUniformMPIRPCBackend(comm = MPI.COMM_WORLD; + listener_ranks::AbstractVector{<:Integer}, + request_tag = $(REQUEST_TAG_DEFAULT), + reply_tag = $(REPLY_TAG_DEFAULT), + dup_comm = true, + daemon = false) + +If `daemon` is `true`, [`select_mpi_rpc_backend!`](@ref) spawns a yield-only +progress loop on every rank (listener or client) so listener-side request +draining and client-side reply draining both happen automatically without +the user calling `rpc_progress!`. See `UniformMPIRPCBackend`'s docstring +for the threading and CPU caveats — the daemon never sleeps; it polls +continuously and uses ≈ 100% of one OS thread. +""" +mutable struct NonUniformMPIRPCBackend <: AbstractMPIRPCBackend + base_comm::MPI.Comm + comm::MPI.Comm + request_tag::Int32 + reply_tag::Int32 + dup_comm::Bool + rank::Int + size::Int + + listener_ranks::Set{Int} + is_listener::Bool + + rrid_counter::Threads.Atomic{UInt64} + waiters::Dict{MPIRRID, MPIFuture} + waiters_lock::ReentrantLock + + pending_sends::Vector{Tuple{MPI.Request, Vector{UInt8}}} + + mpi_lock::ReentrantLock + + running::Threads.Atomic{Bool} + initialized::Bool + + daemon::Bool + daemon_task::Union{Task, Nothing} + + function NonUniformMPIRPCBackend(base_comm::MPI.Comm = MPI.COMM_WORLD; + listener_ranks::AbstractVector{<:Integer}, + request_tag::Integer = REQUEST_TAG_DEFAULT, + reply_tag::Integer = REPLY_TAG_DEFAULT, + dup_comm::Bool = true, + daemon::Bool = false) + request_tag == reply_tag && throw(ArgumentError("request_tag and reply_tag must differ")) + isempty(listener_ranks) && throw(ArgumentError("listener_ranks must be non-empty")) + ls = Set{Int}(Int(r) for r in listener_ranks) + return new(base_comm, base_comm, Int32(request_tag), Int32(reply_tag), + dup_comm, -1, -1, + ls, false, + Threads.Atomic{UInt64}(1), + Dict{MPIRRID, MPIFuture}(), + ReentrantLock(), + Tuple{MPI.Request, Vector{UInt8}}[], + ReentrantLock(), + Threads.Atomic{Bool}(true), + false, + daemon, + nothing) + end +end + +function initialize_mpi_rpc!(b::NonUniformMPIRPCBackend) + b.initialized && return nothing + if !MPI.Initialized() + MPI.Init(; threadlevel=:multiple) + end + b.comm = b.dup_comm ? MPI.Comm_dup(b.base_comm) : b.base_comm + b.rank = MPI.Comm_rank(b.comm) + b.size = MPI.Comm_size(b.comm) + for r in b.listener_ranks + if r < 0 || r >= b.size + throw(ArgumentError("listener rank $r outside [0, $(b.size))")) + end + end + b.is_listener = b.rank in b.listener_ranks + b.initialized = true + return nothing +end + +""" + is_listener(backend) -> Bool + +True if the *current* rank services inbound RPC requests on this backend. +""" +is_listener(b::NonUniformMPIRPCBackend) = b.is_listener +is_listener(::UniformMPIRPCBackend) = true + +""" + listener_ranks(backend) -> Vector{Int} +""" +listener_ranks(b::NonUniformMPIRPCBackend) = sort!(collect(b.listener_ranks)) +listener_ranks(b::UniformMPIRPCBackend) = collect(0:(b.size-1)) + +next_rrid(b::NonUniformMPIRPCBackend) = + MPIRRID(Int32(b.rank), Threads.atomic_add!(b.rrid_counter, UInt64(1))) + +function register_waiter!(b::NonUniformMPIRPCBackend, fut::MPIFuture) + @lock b.waiters_lock begin + b.waiters[fut.rrid] = fut + end + return fut +end + +function take_waiter!(b::NonUniformMPIRPCBackend, rrid::MPIRRID) + @lock b.waiters_lock begin + fut = get(b.waiters, rrid, nothing) + fut === nothing && return nothing + delete!(b.waiters, rrid) + return fut + end +end + +function _post_isend!(b::NonUniformMPIRPCBackend, dest::Integer, tag::Integer, + buf::Vector{UInt8}) + @lock b.mpi_lock begin + req = MPI.Isend(buf, b.comm; dest=Int(dest), tag=Int(tag)) + push!(b.pending_sends, (req, buf)) + end + return nothing +end + +function _reap_pending_sends!(b::NonUniformMPIRPCBackend) + @lock b.mpi_lock begin + i = 1 + while i <= length(b.pending_sends) + req, _ = b.pending_sends[i] + if MPI.Test(req) + deleteat!(b.pending_sends, i) + else + i += 1 + end + end + end + return nothing +end + +# See `uniform.jl::_try_recv_one!` for the rationale; the body is identical +# because the receive path differs from the request/reply dispatch only at +# the `rpc_progress!` level (listener-vs-not gating), not here. +function _try_recv_one!(b::NonUniformMPIRPCBackend, tag::Integer) + local req::MPI.Request + local buf::Vector{UInt8} + local src::Int + @lock b.mpi_lock begin + got, m, status = MPI.Improbe(MPI.ANY_SOURCE, Int(tag), b.comm, MPI.Status) + got || return nothing + src = Int(status.MPI_SOURCE) + count = MPI.Get_count(status, UInt8) + buf = Vector{UInt8}(undef, count) + req = MPI.Imrecv!(buf, m) + end + while true + done = @lock b.mpi_lock MPI.Test(req) + done && return (src, buf) + yield() + end +end + +function rpc_progress!(b::NonUniformMPIRPCBackend) + return _rpc_progress_impl!(b) +end + +function _rpc_progress_impl!(b::NonUniformMPIRPCBackend) + _reap_pending_sends!(b) + + if b.is_listener + while true + r = _try_recv_one!(b, b.request_tag) + r === nothing && break + src, buf = r + header, msg, body_err = decode_frame(buf) + if body_err === nothing && msg isa RPCProgressHaltMsg + return false + end + # Spawn handler so the listener loop never holds any lock + # during user code. See uniform.jl for the rationale. + errormonitor(Threads.@spawn _run_handler_task(b, src, header, msg, body_err)) + end + end + + while true + r = _try_recv_one!(b, b.reply_tag) + r === nothing && break + _, buf = r + _dispatch_reply!(b, buf) + end + + return true +end + +function _execute_request!(b::NonUniformMPIRPCBackend, src::Int, header::MsgHeader, msg::AbstractMsg) + if msg isa CallMsg{:call} || msg isa CallMsg{:call_fetch} + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + _send_reply!(b, header.notify_oid, v) + end + elseif msg isa CallWaitMsg + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + ack = v isa MPIRemoteException ? v : :OK + _send_reply!(b, header.notify_oid, ack) + end + elseif msg isa RemoteDoMsg + run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank; + print_error=true) + else + throw(ProtocolError("unhandled request message $(typeof(msg)) on rank $(b.rank))")) + end + return nothing +end + +function _dispatch_reply!(b::NonUniformMPIRPCBackend, buf::Vector{UInt8}) + header, msg, body_err = decode_frame(buf) + if body_err !== nothing + _deliver_reply_deserialize_failure!(b, header.response_oid, b.rank, body_err) + return + end + msg isa ResultMsg || throw(ProtocolError("expected ResultMsg on reply tag, got $(typeof(msg))")) + fut = take_waiter!(b, header.response_oid) + if fut === nothing + @warn "MPIRPC: ResultMsg for unknown waiter (response_oid=$(header.response_oid)) on rank $(b.rank); dropping" + return + end + deliver!(fut, msg.value) + return nothing +end + +function _send_reply!(b::NonUniformMPIRPCBackend, response_oid::MPIRRID, value) + header = MsgHeader(response_oid, NULL_RRID) + body = ResultMsg(value) + buf = encode_frame(header, body) + _post_isend!(b, response_oid.whence, b.reply_tag, buf) + return nothing +end diff --git a/lib/MPIRPC/src/progress.jl b/lib/MPIRPC/src/progress.jl new file mode 100644 index 000000000..4e9ccb4d6 --- /dev/null +++ b/lib/MPIRPC/src/progress.jl @@ -0,0 +1,162 @@ +""" + rpc_progress!() = rpc_progress!(current_mpi_rpc_backend()) + +Drive one non-blocking progress pass on the currently installed backend. +See the per-backend method docstrings for who must call this. +""" +rpc_progress!() = rpc_progress!(current_mpi_rpc_backend()) + +""" + _run_handler_task(backend, src, header, msg, body_err) + +Internal entry point for the spawned task that runs an inbound request. +Two important effects: + +* The current backend is re-installed as task-local so the user closure's + calls to `current_mpi_rpc_backend()` (e.g. nested `remotecall_fetch`) + see the same backend the request arrived on, even when the parent task + was inside a `with_mpi_rpc_backend` scope that does not propagate to + spawned children. +* Request frames are decoded in the parent [`rpc_progress!`](@ref) pass; + this task receives the triple from [`decode_frame`](@ref). Body + deserialization failures are turned into `MPIRemoteException` replies when + `header.notify_oid` is set, or printed to stderr when not (e.g. + fire-and-forget). User handler failures are already wrapped by + [`run_work_thunk`](@ref); unexpected exceptions after decode (e.g. from + `_send_reply!`) trigger a best-effort `MPIRemoteException` reply when + possible, then are rethrown so the task's `errormonitor` surfaces them. +""" +function _run_handler_task(b::AbstractMPIRPCBackend, src::Int, + header::MsgHeader, msg, body_err) + return with_mpi_rpc_backend(b) do + if body_err !== nothing + _notify_request_decode_failure!(b, header, b.rank, body_err) + return nothing + end + if msg isa RPCProgressHaltMsg + return nothing + end + try + _execute_request!(b, src, header, msg::AbstractMsg) + catch e + if !is_null(header.notify_oid) + _reply_exception!(b, header.notify_oid, b.rank, e) + end + rethrow() + end + return nothing + end +end + +""" + _daemon_loop(backend) + +Background progress driver, spawned by [`select_mpi_rpc_backend!`](@ref) +when the backend was constructed with `daemon = true`. The loop runs on +its own task (typically on its own OS thread under `julia -t >= 2`) and +calls [`rpc_progress!`](@ref) in a tight, **yield-only** loop until +`backend.running[]` flips to `false` (which happens inside +[`shutdown!`](@ref)). + +There is intentionally no `sleep` form: a sleeping daemon would delay +inbound requests by up to `poll_interval` seconds, which is the worst +kind of latency-vs-cpu trade-off for low-rate workloads. If you need +that trade-off, leave `daemon = false` and pump progress on your own +schedule. + +A `try / catch` wraps the body so an unexpected MPI or framing error +does not silently kill the daemon and leave the rank unresponsive. +Errors are logged via `@error` and then rethrown so `shutdown!`'s `wait` +on the daemon task surfaces the failure to the caller. +""" +function _daemon_loop(backend::AbstractMPIRPCBackend) + try + while backend.running[] + rpc_progress!(backend) + yield() + end + rpc_progress!(backend) # final drain after shutdown! flipped the flag + catch e + @error "MPIRPC: daemon loop crashed; this rank will stop making progress" exception = (e, catch_backtrace()) + rethrow() + end + return nothing +end + +""" + serve_listener(backend = current_mpi_rpc_backend(); poll_interval=0.0) + +Blocking helper for non-uniform listeners: loop calling +[`rpc_progress!`](@ref) until [`shutdown!`](@ref) flips +`backend.running[]` to `false`. `poll_interval` is the time, in seconds, +to `sleep` between progress passes (`0` yields without sleeping). + +This is a convenience for examples and tests; production code is free to +interleave `rpc_progress!` with its own work loop instead. +""" +function serve_listener(backend::AbstractMPIRPCBackend = current_mpi_rpc_backend(); + poll_interval::Real = 0.0) + while backend.running[] + rpc_progress!(backend) + if poll_interval > 0 + sleep(poll_interval) + else + yield() + end + end + rpc_progress!(backend) # final drain so in-flight messages are reaped + return nothing +end + +""" + rpc_barrier([backend]) + +Phase boundary for MPIRPC traffic: a non-blocking `MPI.Ibarrier` whose +completion is awaited *while every rank pumps* [`rpc_progress!`](@ref). +Use this between phases of an SPMD program, or before exiting an RPC +session, to drain in-flight work from peers. + +Why a plain `MPI.Barrier` is not enough: a rank `R` whose own +`remotecall_fetch` has just completed may still hold *unprocessed +requests* sent by other ranks (e.g. nested calls those ranks issued from +inside a handler running on `R`). If `R` enters `MPI.Barrier` it stops +pumping RPC progress, and the peers waiting on `R`'s replies hang. +`rpc_barrier` solves this by pumping progress until **all** ranks have +arrived, then doing a final drain pass. + +Calling pattern is identical to `MPI.Barrier`: every rank in the backend's +communicator must call it. +""" +function rpc_barrier(backend::AbstractMPIRPCBackend = current_mpi_rpc_backend()) + req = @lock backend.mpi_lock MPI.Ibarrier(backend.comm) + while true + rpc_progress!(backend) + done = @lock backend.mpi_lock MPI.Test(req) + done && break + yield() + end + rpc_progress!(backend) # one last drain so any reply Isend posted under + # the barrier is reaped on this rank + return nothing +end + +""" + rpc_progress_halt!(backend, dest_rank::Int) -> nothing + +Enqueue a framed [`RPCProgressHaltMsg`](@ref) to `dest_rank` on the backend's +`request_tag` via the same `_post_isend!` / pending-send path as ordinary RPC. + +The **destination** rank must call [`rpc_progress!`](@ref) (or run its daemon) +to match and decode the message. When consumed, that `rpc_progress!` pass +returns `false` and does **not** spawn a handler task for further matched +requests in that pass—useful to break out of a progress loop without executing +user RPC bodies for additional queued messages. + +The halt message carries an empty [`MsgHeader`](@ref); it is not a call/reply +pair and does not touch waiter tables. +""" +function rpc_progress_halt!(backend::AbstractMPIRPCBackend, dest_rank::Int) + buf = encode_frame(MsgHeader(), RPCProgressHaltMsg()) + _post_isend!(backend, dest_rank, backend.request_tag, buf) + return nothing +end \ No newline at end of file diff --git a/lib/MPIRPC/src/protocol.jl b/lib/MPIRPC/src/protocol.jl new file mode 100644 index 000000000..a805a8972 --- /dev/null +++ b/lib/MPIRPC/src/protocol.jl @@ -0,0 +1,181 @@ +using Serialization + +""" + MPIRRID(whence::Int32, id::UInt64) + +Routing identifier for an in-flight remote call. Mirrors `Distributed.RRID`, +with `whence` storing the *MPI rank* (within the backend's communicator) that +allocated the id rather than the Distributed worker pid. + +`whence == -1` and `id == 0` is reserved as the null id (`NULL_RRID`), +analogous to Distributed's `RRID(0, 0)`. +""" +struct MPIRRID + whence::Int32 + id::UInt64 +end + +const NULL_RRID = MPIRRID(Int32(-1), UInt64(0)) + +is_null(r::MPIRRID) = r.whence == NULL_RRID.whence && r.id == NULL_RRID.id + +Base.hash(r::MPIRRID, h::UInt) = hash(r.whence, hash(r.id, hash(MPIRRID, h))) +Base.:(==)(a::MPIRRID, b::MPIRRID) = a.whence == b.whence && a.id == b.id + +""" + MsgHeader(response_oid::MPIRRID, notify_oid::MPIRRID) + +Two-OID header preceding every MPIRPC body, mirroring `Distributed.MsgHeader`. + +* `response_oid` identifies a `MPIFuture` on the receiver of a `ResultMsg` + (i.e. on the *client* of the original call) so the value can be delivered + to the right waiter. +* `notify_oid` is the OID the *server* must echo back as `response_oid` when + it produces a `ResultMsg`. It is set on outgoing `CallMsg` / `CallWaitMsg` + by the client and is null on `RemoteDoMsg` (fire-and-forget). +""" +struct MsgHeader + response_oid::MPIRRID + notify_oid::MPIRRID +end +MsgHeader() = MsgHeader(NULL_RRID, NULL_RRID) +MsgHeader(response_oid::MPIRRID) = MsgHeader(response_oid, NULL_RRID) + +abstract type AbstractMsg end + +""" + CallMsg{Mode}(f, args::Tuple, kwargs) + +`Mode` is `:call` for `remotecall` (client expects a future) or `:call_fetch` +for `remotecall_fetch` (client expects the value to be returned in a +`ResultMsg`). MPIRPC v1 actually treats both modes identically on the wire: +the server always replies with a `ResultMsg` carrying the value or a +`MPIRemoteException`. `Mode` is preserved so the server can mirror Distributed's +behavior of distinguishing `:call` from `:call_fetch` in error reporting later. +""" +struct CallMsg{Mode} <: AbstractMsg + f::Any + args::Tuple + kwargs::Any +end + +""" + CallWaitMsg(f, args, kwargs) + +`remotecall_wait`: server runs the call and replies with `:OK` (or an +exception) so the client can confirm completion without paying for marshaling +the result. +""" +struct CallWaitMsg <: AbstractMsg + f::Any + args::Tuple + kwargs::Any +end + +""" + RemoteDoMsg(f, args, kwargs) + +Fire-and-forget: server runs the call and discards the result. No reply is +sent. Errors are printed to stderr on the server rank (like +`Distributed.remote_do`); nothing is delivered to the client. +""" +struct RemoteDoMsg <: AbstractMsg + f::Any + args::Tuple + kwargs::Any +end + +""" + ResultMsg(value) + +Carries either a successful return value or an `MPIRemoteException`. The +client routes it to a waiter via `MsgHeader.response_oid`. +""" +struct ResultMsg <: AbstractMsg + value::Any +end + +""" + RPCProgressHaltMsg + +Control message on the request tag: tells [`rpc_progress!`](@ref) to stop +draining further inbound requests in the **current** progress pass (returns +`false`). Framed like other MPIRPC bodies; see [`rpc_progress_halt!`](@ref). +""" +struct RPCProgressHaltMsg <: AbstractMsg end + +""" + MSG_BOUNDARY + +Ten-byte sentinel appended after every serialized frame. MPI is +message-oriented so we do not need it for stream resynchronization, but we +keep it as a fail-fast protocol-version / corruption check, matching +Distributed's convention. +""" +const MSG_BOUNDARY = UInt8[0x4d, 0x50, 0x49, 0x52, 0x50, 0x43, 0x46, 0x52, 0x4d, 0x31] + +""" + encode_frame(header, msg) -> Vector{UInt8} + +Serialize `header` then `msg` through a single `Serializer` over a fresh +`IOBuffer`, append `MSG_BOUNDARY`, and return the byte vector ready for +`MPI.Isend`. A fresh `Serializer` per message bounds the serializer's +back-reference table to one frame, so cross-frame state cannot leak. +""" +function encode_frame(header::MsgHeader, msg::AbstractMsg) + io = IOBuffer() + s = Serializer(io) + serialize(s, header) + serialize(s, msg) + write(io, MSG_BOUNDARY) + return take!(io) +end + +""" + decode_frame(buf) -> (header, msg, body_error) + +Decode a frame previously produced by `encode_frame`. The header is decoded +first so that, on body failures, the caller can still reply to the right +waiter (mirrors Distributed's two-stage parse in `process_messages.jl`). + +`body_error === nothing` on success; otherwise `msg === nothing` and +`body_error` is the captured exception. The `MSG_BOUNDARY` is verified on +success and a `ProtocolError` is raised if it is missing or wrong, which +forces a fail-fast rather than silently delivering garbage. + +`Base.invokelatest` wraps the body deserialization so that types defined +after world-age advances are reachable, matching Distributed's +`invokelatest(deserialize_msg, ...)` call site. +""" +function decode_frame(buf::Vector{UInt8}) + io = IOBuffer(buf) + s = Serializer(io) + header = deserialize(s)::MsgHeader + msg = nothing + body_error = nothing + try + msg = invokelatest(deserialize, s)::AbstractMsg + verify_boundary!(io) + catch e + body_error = e + end + return header, msg, body_error +end + +struct ProtocolError <: Exception + msg::String +end +Base.showerror(io::IO, e::ProtocolError) = print(io, "MPIRPC.ProtocolError: ", e.msg) + +function verify_boundary!(io::IOBuffer) + n = length(MSG_BOUNDARY) + bytes_left = io.size - io.ptr + 1 + if bytes_left < n + throw(ProtocolError("frame ended before MSG_BOUNDARY (have $(bytes_left) of $(n) bytes)")) + end + tail = read(io, n) + if tail != MSG_BOUNDARY + throw(ProtocolError("MSG_BOUNDARY mismatch (got $(tail))")) + end + return nothing +end diff --git a/lib/MPIRPC/src/refs.jl b/lib/MPIRPC/src/refs.jl new file mode 100644 index 000000000..ed5345a6a --- /dev/null +++ b/lib/MPIRPC/src/refs.jl @@ -0,0 +1,55 @@ +""" + MPIFuture + +Future-like handle returned by `remotecall` and (internally) used by +`remotecall_fetch` / `remotecall_wait`. Mirrors `Distributed.Future`: + +* `where` — MPI rank running the call (the eventual source of the reply). +* `rrid::MPIRRID` — id allocated by the *caller* (`whence == my rank`). + This same id is echoed back in `MsgHeader.response_oid` of the + corresponding `ResultMsg`. +* `v::Atomic{Union{Some{Any}, Nothing}}` — set once when the reply arrives. + +A `MPIFuture` is registered in the backend's waiter table from creation +until the value is delivered or it is finalized (whichever comes first). +""" +mutable struct MPIFuture + backend::Any + where::Int + rrid::MPIRRID + @atomic v::Union{Some{Any}, Nothing} + cond::Threads.Condition + + function MPIFuture(backend, where::Integer, rrid::MPIRRID) + return new(backend, Int(where), rrid, nothing, Threads.Condition()) + end +end + +Base.show(io::IO, f::MPIFuture) = + print(io, "MPIFuture(where=", f.where, ", rrid=", f.rrid, ", ready=", isready(f), ")") + +""" + isready(f::MPIFuture) -> Bool + +True if the reply has been delivered (success or remote exception) and a +subsequent `fetch(f)` will not block. +""" +Base.isready(f::MPIFuture) = (@atomic :acquire f.v) !== nothing + +""" + deliver!(f::MPIFuture, value) + +Internal: store `value` in `f` and wake any waiters. Idempotent — repeated +deliveries (which can occur if a stale duplicate reply ever lands) are +ignored after the first. +""" +function deliver!(f::MPIFuture, value) + @lock f.cond begin + prev = (@atomic :acquire f.v) + if prev === nothing + @atomic :release f.v = Some{Any}(value) + notify(f.cond, all=true) + end + end + return nothing +end diff --git a/lib/MPIRPC/src/remotecall.jl b/lib/MPIRPC/src/remotecall.jl new file mode 100644 index 000000000..e0d3dc0bf --- /dev/null +++ b/lib/MPIRPC/src/remotecall.jl @@ -0,0 +1,266 @@ +""" + remotecall(f, dest_rank, args...; kwargs...) -> MPIFuture + +Send a call to `dest_rank` and return a [`MPIFuture`](@ref) that completes +when the reply arrives. Mirrors `Distributed.remotecall`. + +The reply carries either the function's return value or a +`MPIRemoteException`; both are delivered into the future, and +[`fetch`](@ref) rethrows the exception case while returning the value +otherwise. +""" +function remotecall(f, dest_rank::Integer, args...; kwargs...) + backend = current_mpi_rpc_backend() + return _remotecall_internal(backend, Val(:call), f, Int(dest_rank), args, kwargs) +end + +""" + remotecall_fetch(f, dest_rank, args...; kwargs...) -> result + +Send a call, wait for the reply, and return the result. Throws +`MPIRemoteException` if the remote call raised. Mirrors +`Distributed.remotecall_fetch`. Uses [`current_mpi_rpc_backend`](@ref). +""" +function remotecall_fetch(f, dest_rank::Integer, args...; kwargs...) + return remotecall_fetch(f, current_mpi_rpc_backend(), dest_rank, args...; kwargs...) +end + +""" + remotecall_fetch(f, backend::AbstractMPIRPCBackend, dest_rank, args...; kwargs...) -> result + +Send a call, wait for the reply, and return the result. Throws +`MPIRemoteException` if the remote call raised. Mirrors +`Distributed.remotecall_fetch` with an explicit backend (e.g. for Dagger). +""" +function remotecall_fetch(f, backend::AbstractMPIRPCBackend, dest_rank::Integer, args...; kwargs...) + fut = _remotecall_internal(backend, Val(:call_fetch), f, Int(dest_rank), args, kwargs) + return fetch(fut) +end + +""" + remotecall_wait(f, dest_rank, args...; kwargs...) -> MPIFuture + +Send a call, wait for the server to acknowledge completion, and return the +future (containing `:OK` or an `MPIRemoteException`). Use this when you want +back-pressure on remote completion without paying for marshaling the result. +""" +function remotecall_wait(f, dest_rank::Integer, args...; kwargs...) + backend = current_mpi_rpc_backend() + fut = _remotecall_wait_internal(backend, f, Int(dest_rank), args, kwargs) + wait(fut) + val = something((@atomic :acquire fut.v)) + val isa MPIRemoteException && throw(val) + return fut +end + +""" + remote_do(f, dest_rank, args...; kwargs...) -> nothing + +Fire-and-forget: send a call to `dest_rank` and return immediately. The +remote function's return value (and any exception) is **not** observed by +the caller. Mirrors `Distributed.remote_do`. +""" +function remote_do(f, dest_rank::Integer, args...; kwargs...) + backend = current_mpi_rpc_backend() + _remote_do_internal(backend, f, Int(dest_rank), args, kwargs) + return nothing +end + +""" + bcast_remotecall(backend::AbstractMPIRPCBackend, f, args...; kwargs...) -> Vector{MPIFuture} + +Issue one [`remotecall`](@ref)-style `CallMsg{:call}` to **every other rank** +in `backend`'s communicator: all ranks in `0:backend.size-1` **except** +`backend.rank`, in increasing order. Returns a vector of length +`max(0, backend.size - 1)` (empty when `backend.size == 1`). + +`out[i]` is the future for the `i`-th destination in that filtered list (not +indexed by global rank). Progress and `fetch`/`wait` semantics match +[`remotecall`](@ref). + +For [`NonUniformMPIRPCBackend`](@ref), destinations that are not listeners +raise from [`_validate_dest!`](@ref); use the explicit-ranks form with only +listeners if needed. +""" +function bcast_remotecall(backend::AbstractMPIRPCBackend, f, args...; kwargs...) + dests = Int[d for d in 0:(backend.size - 1) if d != backend.rank] + return _bcast_remotecall_destinations!(backend, f, dests, Tuple(args), kwargs) +end + +""" + bcast_remotecall(backend::AbstractMPIRPCBackend, f, ranks::AbstractVector{<:Integer}, args...; kwargs...) -> Vector{MPIFuture} + bcast_remotecall(f, ranks::AbstractVector{<:Integer}, args...; kwargs...) -> Vector{MPIFuture} + +Issue [`remotecall`](@ref) to each rank in `ranks` **in list order** (no +automatic skip of `backend.rank`; filter `ranks` yourself if you want to omit +the caller). The `i`-th future corresponds to `ranks[i]`. + +The no-`backend` form uses [`current_mpi_rpc_backend`](@ref). +""" +function bcast_remotecall(backend::AbstractMPIRPCBackend, f, + ranks::AbstractVector{<:Integer}, args...; kwargs...) + dests = Int[Int(r) for r in ranks] + return _bcast_remotecall_destinations!(backend, f, dests, Tuple(args), kwargs) +end + +function bcast_remotecall(f, ranks::AbstractVector{<:Integer}, args...; kwargs...) + return bcast_remotecall(current_mpi_rpc_backend(), f, ranks, args...; kwargs...) +end + +function _bcast_remotecall_destinations!(backend::AbstractMPIRPCBackend, f, + dests::Vector{Int}, args::Tuple, kwargs) + n = length(dests) + out = Vector{MPIFuture}(undef, n) + for i in 1:n + out[i] = _remotecall_internal(backend, Val(:call), f, dests[i], args, kwargs) + end + return out +end + +# --------------------------------------------------------------------------- + +function _validate_dest!(backend::AbstractMPIRPCBackend, dest::Int) + backend.initialized || throw(ArgumentError( + "MPIRPC backend not initialized; call `select_mpi_rpc_backend!` first")) + if dest < 0 || dest >= backend.size + throw(ArgumentError("dest rank $dest outside [0, $(backend.size))")) + end + if backend isa NonUniformMPIRPCBackend && !(dest in backend.listener_ranks) + throw(ArgumentError( + "dest rank $dest is not a listener; only listener ranks can service RPC " * + "in a NonUniformMPIRPCBackend (listeners=$(sort!(collect(backend.listener_ranks))))")) + end + return nothing +end + +function _remotecall_internal(backend::AbstractMPIRPCBackend, ::Val{Mode}, f, + dest::Int, args::Tuple, kwargs) where {Mode} + _validate_dest!(backend, dest) + rrid = next_rrid(backend) + fut = MPIFuture(backend, dest, rrid) + register_waiter!(backend, fut) + header = MsgHeader(NULL_RRID, rrid) + msg = CallMsg{Mode}(f, args, _kwargs_to_pairs(kwargs)) + buf = encode_frame(header, msg) + _post_isend!(backend, dest, backend.request_tag, buf) + return fut +end + +function _remotecall_wait_internal(backend::AbstractMPIRPCBackend, f, + dest::Int, args::Tuple, kwargs) + _validate_dest!(backend, dest) + rrid = next_rrid(backend) + fut = MPIFuture(backend, dest, rrid) + register_waiter!(backend, fut) + header = MsgHeader(NULL_RRID, rrid) + msg = CallWaitMsg(f, args, _kwargs_to_pairs(kwargs)) + buf = encode_frame(header, msg) + _post_isend!(backend, dest, backend.request_tag, buf) + return fut +end + +function _remote_do_internal(backend::AbstractMPIRPCBackend, f, + dest::Int, args::Tuple, kwargs) + _validate_dest!(backend, dest) + header = MsgHeader(NULL_RRID, NULL_RRID) + msg = RemoteDoMsg(f, args, _kwargs_to_pairs(kwargs)) + buf = encode_frame(header, msg) + _post_isend!(backend, dest, backend.request_tag, buf) + return nothing +end + +# Materialize kwargs as a `Vector{Pair{Symbol,Any}}` so the on-wire form is +# stable regardless of whether the caller passed `; kwargs...`, a `NamedTuple`, +# or a `Base.Pairs`. This avoids serializing the iterator object itself, which +# can drag in unexpected closures. +function _kwargs_to_pairs(kwargs) + out = Vector{Pair{Symbol, Any}}() + for (k, v) in pairs(kwargs) + push!(out, k => v) + end + return out +end + +# --------------------------------------------------------------------------- +# wait / fetch — two strategies depending on whether the backend has a +# progress daemon driving the wire on this rank. + +""" + wait(f::MPIFuture) -> f + +Block until the reply for `f` has been delivered. Two implementations, +selected at call time by inspecting the future's backend: + +* **`backend.daemon == true`** — park on `f.cond` (a `Threads.Condition`). + The waiter consumes zero CPU until [`deliver!`](@ref) (called by the + daemon's `_dispatch_reply!`) holds the same lock, sets the value, and + `notify`s. This is the same shape as `Distributed`'s `take!(fut.v)` on + a `Channel`, with one important difference covered below. + +* **`backend.daemon == false`** — fall back to the v1 spin: call + `rpc_progress!` and `yield` in a loop. We *cannot* park here because + the backend has no other progress driver: the reply that would + eventually call `deliver!` only arrives when *some task* drains the + wire, and `wait` itself is the only candidate. Parking would deadlock. + +The `backend.daemon` check is read off `f.backend`, not the +*currently installed* backend, so a future created on a daemon-backed +backend is still cheap to wait on even from a task that has scoped a +non-daemon backend via [`with_mpi_rpc_backend`](@ref). +""" +function Base.wait(f::MPIFuture) + isready(f) && return f + backend = f.backend::AbstractMPIRPCBackend + if backend.daemon + # Cond-park path. The standard CV idiom: hold the lock, recheck + # the predicate, `wait` if false, recheck on wake. `deliver!` + # holds the same lock when it sets `f.v` and `notify`s, so the + # check / park / wake transitions are atomic with delivery. + @lock f.cond begin + while !isready(f) + wait(f.cond) + end + end + else + # No daemon ⇒ this task is the progress driver. A bare + # `wait(f.cond)` here would never wake because nobody is + # receiving the reply that triggers the notify. Yield-rate + # spin is the only correct choice; it costs CPU but it + # actually makes forward progress on the wire. + while !isready(f) + rpc_progress!(backend) + isready(f) && break + yield() + end + end + return f +end + +""" + fetch(f::MPIFuture) -> value + +Block until the reply has arrived, then return the value. If the remote +function threw, that exception is rethrown locally as a `MPIRemoteException`. +Park / spin behavior follows [`wait`](@ref) — see its docstring. +""" +function Base.fetch(f::MPIFuture) + wait(f) + val = something((@atomic :acquire f.v)) + val isa MPIRemoteException && throw(val) + return val +end + +# --------------------------------------------------------------------------- + +""" + @with_progress expr + +Convenience macro: evaluate `expr` in a scope where the current MPIRPC +backend's progress engine is driven by `wait`/`fetch`. This is a no-op +today (since `wait`/`fetch` already pump progress) and is provided as a +forward-compatible attachment point for future work that might run user +code in a context without an automatic progress drain. +""" +macro with_progress(expr) + return esc(expr) +end diff --git a/lib/MPIRPC/src/uniform.jl b/lib/MPIRPC/src/uniform.jl new file mode 100644 index 000000000..b7866911a --- /dev/null +++ b/lib/MPIRPC/src/uniform.jl @@ -0,0 +1,314 @@ +using MPI + +const REQUEST_TAG_DEFAULT = Int32(0xC0DE) +const REPLY_TAG_DEFAULT = Int32(0xC0DF) + +""" + UniformMPIRPCBackend(comm; ...) + +SPMD backend: every rank both issues and services RPC. Every rank must +call [`rpc_progress!`](@ref) regularly (e.g. once per main-loop iteration); +clients block in `fetch` / `wait` by looping `rpc_progress!` themselves so +no rank ever needs a dedicated background task to receive. + +# Wire layout + +* Two disjoint MPI tags carry traffic, regardless of how many calls are in + flight: `request_tag` for `CallMsg` / `CallWaitMsg` / `RemoteDoMsg`, and + `reply_tag` for `ResultMsg`. **Tags do not encode call identity.** +* Each call carries a unique `MPIRRID` in `MsgHeader.notify_oid`. The server + echoes it back as `MsgHeader.response_oid` on the reply, and the client + routes the reply to the matching `MPIFuture` through a waiter table keyed + by `MPIRRID`. + +This layout is deliberately ABBA-safe: two ranks calling each other +simultaneously cannot collide on `(peer, tag)` because requests and replies +travel on different tags, and concurrent calls to the same peer cannot be +confused because correlation lives in the header, not the tag. See +`docs/ARCHITECTURE.md` for the full deadlock argument. + +# Constructor + + UniformMPIRPCBackend(comm = MPI.COMM_WORLD; + request_tag = $(REQUEST_TAG_DEFAULT), + reply_tag = $(REPLY_TAG_DEFAULT), + dup_comm = true, + daemon = false) + +If `dup_comm` is `true` (default), the backend duplicates `comm` with +`MPI.Comm_dup` so RPC traffic cannot interleave with collectives the user +runs on the original communicator. Pass `dup_comm = false` if MPI is not +yet initialized at construction time and you intend to call +[`select_mpi_rpc_backend!`](@ref) (which initializes MPI if necessary). + +If `daemon` is `true`, [`select_mpi_rpc_backend!`](@ref) spawns a +`Threads.@spawn`-backed task that runs [`rpc_progress!`](@ref) in a tight +yield-only loop until [`shutdown!`](@ref) is called. This removes the +"every rank must call `rpc_progress!`" requirement: as long as Julia has +≥ 2 threads, the daemon makes forward progress on inbound requests +without any user pumping. Under `julia -t 1` the daemon shares a thread +with the main task and only runs at yield points, so the requirement is +effectively unchanged. The daemon does **not** sleep — it polls +continuously — so expect ≈ 100% utilization of one OS thread. If that is +unacceptable for your workload, leave `daemon = false` and drive +progress manually via `rpc_progress!`, `serve_listener`, `wait`/`fetch`, +or `rpc_barrier`. +""" +mutable struct UniformMPIRPCBackend <: AbstractMPIRPCBackend + base_comm::MPI.Comm + comm::MPI.Comm + request_tag::Int32 + reply_tag::Int32 + dup_comm::Bool + rank::Int + size::Int + + rrid_counter::Threads.Atomic{UInt64} + waiters::Dict{MPIRRID, MPIFuture} + waiters_lock::ReentrantLock + + pending_sends::Vector{Tuple{MPI.Request, Vector{UInt8}}} + + mpi_lock::ReentrantLock + + running::Threads.Atomic{Bool} + initialized::Bool + + daemon::Bool + daemon_task::Union{Task, Nothing} + + function UniformMPIRPCBackend(base_comm::MPI.Comm = MPI.COMM_WORLD; + request_tag::Integer = REQUEST_TAG_DEFAULT, + reply_tag::Integer = REPLY_TAG_DEFAULT, + dup_comm::Bool = true, + daemon::Bool = false) + request_tag == reply_tag && throw(ArgumentError("request_tag and reply_tag must differ")) + return new(base_comm, base_comm, Int32(request_tag), Int32(reply_tag), + dup_comm, -1, -1, + Threads.Atomic{UInt64}(1), + Dict{MPIRRID, MPIFuture}(), + ReentrantLock(), + Tuple{MPI.Request, Vector{UInt8}}[], + ReentrantLock(), + Threads.Atomic{Bool}(true), + false, + daemon, + nothing) + end +end + +function initialize_mpi_rpc!(b::UniformMPIRPCBackend) + b.initialized && return nothing + if !MPI.Initialized() + MPI.Init(; threadlevel=:multiple) + end + b.comm = b.dup_comm ? MPI.Comm_dup(b.base_comm) : b.base_comm + b.rank = MPI.Comm_rank(b.comm) + b.size = MPI.Comm_size(b.comm) + b.initialized = true + return nothing +end + +""" + next_rrid(backend) -> MPIRRID + +Allocate a new id local to this rank. The caller's rank is encoded in the +`whence` field, so the same `MPIRRID` is meaningful across the whole +communicator (only the originating rank ever looks it up in the waiter +table). +""" +next_rrid(b::UniformMPIRPCBackend) = + MPIRRID(Int32(b.rank), Threads.atomic_add!(b.rrid_counter, UInt64(1))) + +function register_waiter!(b::UniformMPIRPCBackend, fut::MPIFuture) + @lock b.waiters_lock begin + b.waiters[fut.rrid] = fut + end + return fut +end + +function take_waiter!(b::UniformMPIRPCBackend, rrid::MPIRRID) + @lock b.waiters_lock begin + fut = get(b.waiters, rrid, nothing) + fut === nothing && return nothing + delete!(b.waiters, rrid) + return fut + end +end + +# --------------------------------------------------------------------------- +# Send path + +function _post_isend!(b::UniformMPIRPCBackend, dest::Integer, tag::Integer, + buf::Vector{UInt8}) + @lock b.mpi_lock begin + req = MPI.Isend(buf, b.comm; dest=Int(dest), tag=Int(tag)) + push!(b.pending_sends, (req, buf)) + end + return nothing +end + +function _reap_pending_sends!(b::UniformMPIRPCBackend) + @lock b.mpi_lock begin + i = 1 + while i <= length(b.pending_sends) + req, _ = b.pending_sends[i] + if MPI.Test(req) + deleteat!(b.pending_sends, i) + else + i += 1 + end + end + end + return nothing +end + +# --------------------------------------------------------------------------- +# Receive / dispatch path + +""" + _try_recv_one!(backend, tag) -> (src, buf) | nothing + +Try to receive one message that's already in MPI's matching engine on +`tag`. Two phases: + +1. **Match-and-post under `mpi_lock`.** `MPI.Improbe` is the only way to + atomically peek at a matched message and remove it from the queue; + the buffer must be allocated and `MPI.Imrecv!` posted while still + holding the message handle. Both steps live in a single critical + section because they are semantically one operation. + +2. **Wait outside `mpi_lock`, with yield.** `MPI.Imrecv!` returns a + non-blocking request; we then loop on `MPI.Test` with `yield()` in + between, **releasing `mpi_lock` between polls**. This is what makes + the daemon thread non-blocking even on a slow rendezvous-protocol + transfer: while the receive is in flight, other tasks (handlers + doing their own `Isend`s, user code on this rank) can acquire + `mpi_lock` and progress their own MPI calls. + +Eager-protocol messages (the typical case for RPC payloads up to +MPI's eager threshold, often 64 KiB) almost always complete `Test` +on the first poll, so the only added cost vs. blocking `Mrecv!` is +one extra Test call and one lock acquire/release pair — single-digit +microseconds. The win is for large payloads on rendezvous protocol, +where the previous `Mrecv!` would have held `mpi_lock` for the entire +network round-trip. +""" +function _try_recv_one!(b::UniformMPIRPCBackend, tag::Integer) + local req::MPI.Request + local buf::Vector{UInt8} + local src::Int + @lock b.mpi_lock begin + got, m, status = MPI.Improbe(MPI.ANY_SOURCE, Int(tag), b.comm, MPI.Status) + got || return nothing + src = Int(status.MPI_SOURCE) + count = MPI.Get_count(status, UInt8) + buf = Vector{UInt8}(undef, count) + req = MPI.Imrecv!(buf, m) + end + while true + done = @lock b.mpi_lock MPI.Test(req) + done && return (src, buf) + yield() + end +end + +""" + rpc_progress!([backend]) + +Drive one non-blocking pass of the MPIRPC engine: reap completed sends, +service every inbound request currently matched on this rank, and deliver +every inbound reply currently matched on this rank. + +For [`UniformMPIRPCBackend`](@ref), **every** rank must call this regularly +(typically wrapped in [`@with_progress`](@ref) inside `wait` / `fetch`). +For [`NonUniformMPIRPCBackend`](@ref), only listener ranks need to call +the request-draining half; clients still need to call it to drain replies, +Returns `true` if the pass completed normally, or `false` if a +[`RPCProgressHaltMsg`](@ref) was consumed on the request tag (no further +requests are drained in that pass); see [`rpc_progress_halt!`](@ref). +""" +function rpc_progress!(b::UniformMPIRPCBackend) + return _rpc_progress_impl!(b) +end + +function _rpc_progress_impl!(b::UniformMPIRPCBackend) + _reap_pending_sends!(b) + + while true + r = _try_recv_one!(b, b.request_tag) + r === nothing && break + src, buf = r + header, msg, body_err = decode_frame(buf) + @debug "Received request" header msg body_err + if body_err === nothing && msg isa RPCProgressHaltMsg + return false + end + # Run the handler on a freshly spawned task so user code never + # executes while the progress pump holds any state. Without this, + # a handler that itself blocks on `Threads.@spawn`-then-`fetch` + # would deadlock: the spawned task would need another thread to + # drive RPC progress, but every thread that called `rpc_progress!` + # would be waiting on the handler that the previous progress pass + # had begun running synchronously. See `docs/ARCHITECTURE.md` §6. + errormonitor(Threads.@spawn _run_handler_task(b, src, header, msg, body_err)) + end + + while true + r = _try_recv_one!(b, b.reply_tag) + r === nothing && break + @debug "Received reply" r + _, buf = r + # Reply dispatch only touches MPIRPC's own state (waiter table and + # future condition); it cannot block on user code, so we keep it + # inline. + _dispatch_reply!(b, buf) + end + + return true +end + +function _execute_request!(b::UniformMPIRPCBackend, src::Int, header::MsgHeader, msg::AbstractMsg) + if msg isa CallMsg{:call} || msg isa CallMsg{:call_fetch} + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + _send_reply!(b, header.notify_oid, v) + end + elseif msg isa CallWaitMsg + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + ack = v isa MPIRemoteException ? v : :OK + _send_reply!(b, header.notify_oid, ack) + end + elseif msg isa RemoteDoMsg + run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank; + print_error=true) + else + throw(ProtocolError("unhandled request message $(typeof(msg)) on rank $(b.rank))")) + end + return nothing +end + +function _dispatch_reply!(b::UniformMPIRPCBackend, buf::Vector{UInt8}) + header, msg, body_err = decode_frame(buf) + if body_err !== nothing + _deliver_reply_deserialize_failure!(b, header.response_oid, b.rank, body_err) + return + end + msg isa ResultMsg || throw(ProtocolError("expected ResultMsg on reply tag, got $(typeof(msg))")) + fut = take_waiter!(b, header.response_oid) + if fut === nothing + @warn "MPIRPC: ResultMsg for unknown waiter (response_oid=$(header.response_oid)) on rank $(b.rank); dropping" + return + end + deliver!(fut, msg.value) + return nothing +end + +function _send_reply!(b::UniformMPIRPCBackend, response_oid::MPIRRID, value) + header = MsgHeader(response_oid, NULL_RRID) + body = ResultMsg(value) + buf = encode_frame(header, body) + _post_isend!(b, response_oid.whence, b.reply_tag, buf) + return nothing +end diff --git a/lib/MPIRPC/test/bcast_remotecall_mpiexec.jl b/lib/MPIRPC/test/bcast_remotecall_mpiexec.jl new file mode 100644 index 000000000..42edbc24d --- /dev/null +++ b/lib/MPIRPC/test/bcast_remotecall_mpiexec.jl @@ -0,0 +1,57 @@ +using Test +using MPI +using MPIRPC + +MPI.Init(; threadlevel=:multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +NPROC >= 3 || error("bcast_remotecall test expects at least 3 ranks (got $NPROC)") + +function rpc_barrier_local() + MPIRPC.rpc_barrier(backend) +end + +function pump_until(pred; timeout::Real = 30.0) + t0 = time() + while !pred() + MPIRPC.rpc_progress!(backend) + time() - t0 > timeout && return false + yield() + end + return true +end + +rpc_barrier_local() +if RANK == 0 + println("--- bcast_remotecall / all other ranks ---") + flush(stdout) +end +rpc_barrier_local() + +@testset "bcast_remotecall skips caller" begin + futs = MPIRPC.bcast_remotecall(backend, +, RANK, 100) + @test length(futs) == NPROC - 1 + @test pump_until(() -> all(Base.isready, futs)) + @test all(f -> MPIRPC.fetch(f) == RANK + 100, futs) +end + +rpc_barrier_local() +if RANK == 0 + println("--- bcast_remotecall / explicit ranks vector ---") + flush(stdout) +end +rpc_barrier_local() + +@testset "bcast_remotecall explicit ranks vector" begin + peer = mod(RANK + 1, NPROC) + futs = MPIRPC.bcast_remotecall(backend, +, Int[peer], 10, 5) + @test length(futs) == 1 + @test pump_until(() -> Base.isready(futs[1])) + @test MPIRPC.fetch(futs[1]) == 15 +end + +rpc_barrier_local() +MPIRPC.shutdown!() diff --git a/lib/MPIRPC/test/mpi_tests.jl b/lib/MPIRPC/test/mpi_tests.jl new file mode 100644 index 000000000..8b2210b41 --- /dev/null +++ b/lib/MPIRPC/test/mpi_tests.jl @@ -0,0 +1,96 @@ +using Test +using MPI + +# Orchestrate the two mpiexec-driven suites from the host test process. +# `Pkg.test()` runs in a single process; we spawn `mpiexec` here so the +# usual `julia --project=. -e 'using Pkg; Pkg.test()'` flow exercises the +# real MPI transport. + +const MPI_BIN = mpiexec() + +const HERE = @__DIR__ +const PROJECT = abspath(joinpath(HERE, "..")) + +function _run_mpi(script::AbstractString; nproc::Integer, + default_threads::Integer = 2, + interactive_threads::Integer = 1) + # Pass `--threads=N,M` to each spawned Julia: N default-pool threads + # to exercise concurrent handlers and concurrent client tasks, plus M + # interactive-pool threads so the MPIRPC daemon (when enabled) can be + # scheduled on its own pool, isolated from any user CPU-bound work + # running on the default pool. + # + # `default_threads = 2` is the minimum that turns `Threads.@spawn` + # into actual parallelism. `interactive_threads = 1` is the minimum + # that gives the daemon a dedicated thread; the daemon-suite scripts + # additionally assert that they are running with the daemon on + # `:interactive`, so this needs to stay >= 1. + threadspec = "$(default_threads),$(interactive_threads)" + cmd = `$(MPI_BIN) -n $(nproc) $(Base.julia_cmd()) --threads=$(threadspec) --project=$(PROJECT) $(script)` + @info "MPIRPC: running" script nproc default_threads interactive_threads cmd + rc = success(pipeline(cmd, stdout = stdout, stderr = stderr)) + return rc +end + +@testset "mpiexec / uniform backend (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec uniform suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "uniform_mpiexec.jl"); nproc = 4) + end +end + +@testset "mpiexec / non-uniform backend (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec non-uniform suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "nonuniform_mpiexec.jl"); nproc = 4) + end +end + +# The two daemon suites verify the optional progress-daemon path: the +# scripts never call `rpc_progress!` or `serve_listener` themselves, so a +# successful run is *only* possible if the yield-only `_daemon_loop` +# spawned by `select_mpi_rpc_backend!` is doing its job. We give them a +# separate testset (rather than folding them into the existing two) so a +# regression in the daemon does not get masked by the manual-pump suites. +@testset "mpiexec / uniform backend with daemon (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec uniform-daemon suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "uniform_daemon_mpiexec.jl"); nproc = 4) + end +end + +@testset "mpiexec / non-uniform backend with daemon (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec non-uniform-daemon suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "nonuniform_daemon_mpiexec.jl"); nproc = 4) + end +end + +# Example: RPC matmul with explicit listener / client roles (no client-to-client RPC). +@testset "mpiexec / bcast_remotecall (3 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec bcast_remotecall (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "bcast_remotecall_mpiexec.jl"); nproc = 3) + end +end + +# Example: RPC matmul with explicit listener / client roles (no client-to-client RPC). +@testset "mpiexec / example rpc_matmul non-uniform (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec rpc_matmul_nonuniform (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + ex = joinpath(dirname(HERE), "examples", "rpc_matmul_nonuniform.jl") + @test _run_mpi(ex; nproc = 4) + end +end diff --git a/lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl b/lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl new file mode 100644 index 000000000..baf816942 --- /dev/null +++ b/lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl @@ -0,0 +1,138 @@ +using Test +using MPI +using MPIRPC + +# Non-uniform companion to `uniform_daemon_mpiexec.jl`. Listeners receive +# inbound requests via the daemon, clients receive their own replies via +# the daemon (or via `wait`/`fetch` self-pumping — both paths must work). +# No testset in this script calls `rpc_progress!` or `serve_listener`. + +MPI.Init(; threadlevel = :multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 4 || error("non-uniform daemon tests need at least 4 ranks (got $NPROC); " * + "rerun with `mpiexec -n 4 ...`") + +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) +const CLIENT_RANKS = collect(HALF:(NPROC - 1)) +const IS_LISTENER = WORLD_RANK in LISTENER_RANKS + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; + listener_ranks = LISTENER_RANKS, + daemon = true), +) +const COMM = backend.comm + +@assert backend.daemon "daemon flag did not stick" +@assert backend.daemon_task isa Task "daemon task should have been spawned" +@assert !istaskdone(backend.daemon_task) + +# Daemon must be on the `:interactive` pool so it cannot be starved by +# CPU-bound user code on the `:default` pool. See uniform_daemon_mpiexec.jl +# for the rationale. +if Threads.nthreads(:interactive) > 0 + @assert Threads.threadpool(backend.daemon_task) === :interactive ( + "daemon task expected on :interactive, got $(Threads.threadpool(backend.daemon_task))") +end + +# `_phase!` uses `rpc_barrier`; see the uniform daemon script for why this +# is safe alongside the daemon (per-backend `mpi_lock` serializes raw MPI +# calls, the barrier request is task-local). +function _phase!(name::AbstractString) + rpc_barrier() + if WORLD_RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +# --------------------------------------------------------------------------- + +_phase!("non-uniform-daemon / smoke (no manual progress, no serve_listener)") +@testset "non-uniform-daemon / client-listener round trip" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + # Critically, this script never has the listener call + # `serve_listener` or `rpc_progress!`. The daemon on the listener + # is doing all inbound draining. + got = remotecall_fetch(+, peer, WORLD_RANK, 1) + @test got == WORLD_RANK + 1 + else + @test true + end +end + +_phase!("non-uniform-daemon / each client fans out to all listeners") +@testset "non-uniform-daemon / each client fans out to all listeners" begin + if !IS_LISTENER + results = Vector{Int}(undef, length(LISTENER_RANKS)) + for (i, l) in enumerate(LISTENER_RANKS) + results[i] = remotecall_fetch(+, l, WORLD_RANK, l) + end + @test results == [WORLD_RANK + l for l in LISTENER_RANKS] + else + @test true + end +end + +_phase!("non-uniform-daemon / multiple client threads concurrent") +@testset "non-uniform-daemon / multiple client threads concurrent" begin + if Threads.nthreads() >= 2 && !IS_LISTENER + K = 16 + peer = first(LISTENER_RANKS) + results = Vector{Vector{Int}}(undef, Threads.nthreads()) + ts = Task[] + for tid in 1:Threads.nthreads() + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch( + +, peer, WORLD_RANK * 100 + tid, k) + end + results[tid] = got + end + push!(ts, t) + end + foreach(wait, ts) + for tid in 1:Threads.nthreads() + @test results[tid] == [WORLD_RANK * 100 + tid + k for k in 1:K] + end + else + @test true + end +end + +_phase!("non-uniform-daemon / handler that spawns and fetches") +@testset "non-uniform-daemon / nested fetch from spawned handler subtask" begin + if Threads.nthreads() >= 2 && length(LISTENER_RANKS) >= 2 && !IS_LISTENER + peer = LISTENER_RANKS[1] + helper = LISTENER_RANKS[2] + result = remotecall_fetch(peer, helper, WORLD_RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 5) + return fetch(t) + 1000 + end + @test result == WORLD_RANK + 5 + 1000 + else + @test true + end +end + +_phase!("non-uniform-daemon / shutdown joins the daemon") +@testset "non-uniform-daemon / shutdown! cleanly joins the daemon" begin + @test backend.daemon_task isa Task + @test !istaskdone(backend.daemon_task) + shutdown!() + @test backend.running[] == false + @test backend.daemon_task === nothing +end + +# Final sync on the duplicate communicator. As in the uniform daemon +# script, every rank has called `shutdown!` so no rank is pumping +# progress; a plain `MPI.Barrier` on the RPC subcomm is the right tool. +MPI.Barrier(backend.comm) +MPI.Finalize() diff --git a/lib/MPIRPC/test/nonuniform_mpiexec.jl b/lib/MPIRPC/test/nonuniform_mpiexec.jl new file mode 100644 index 000000000..9a2914ca6 --- /dev/null +++ b/lib/MPIRPC/test/nonuniform_mpiexec.jl @@ -0,0 +1,242 @@ +using Test +using MPI +using MPIRPC +using Random + +Random.seed!(0xC0FFEE) + +MPI.Init(; threadlevel=:multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 4 || error("non-uniform tests need at least 4 ranks (got $NPROC); " * + "rerun with `mpiexec -n 4 ...`") + +# Roughly half the ranks are listeners, the rest are clients only. +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) +const CLIENT_RANKS = collect(HALF:(NPROC - 1)) +const IS_LISTENER = WORLD_RANK in LISTENER_RANKS + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; listener_ranks = LISTENER_RANKS)) +const COMM = backend.comm + +@assert is_listener(backend) == IS_LISTENER +@assert listener_ranks(backend) == LISTENER_RANKS + +# A progress-pumping barrier and an announce on rank 0 — same idea as the +# uniform suite. Critical for non-uniform: clients pump replies and listeners +# pump requests, so a `rpc_barrier` after each phase prevents a listener +# from "going quiet" before clients have collected pending replies. +function _phase!(name::AbstractString) + rpc_barrier() + if WORLD_RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +function pump_until(pred; timeout::Real = 30.0, + backend = MPIRPC.current_mpi_rpc_backend()) + t0 = time() + while !pred() + rpc_progress!(backend) + time() - t0 > timeout && return false + yield() + end + return true +end + +# --------------------------------------------------------------------------- + +_phase!("non-uniform / role checks") +@testset "non-uniform / role checks" begin + @test is_listener(backend) == (WORLD_RANK in LISTENER_RANKS) + if !IS_LISTENER && length(CLIENT_RANKS) >= 2 + # Clients cannot remotecall to other clients; only listeners service RPC. + another_client = first(c for c in CLIENT_RANKS if c != WORLD_RANK) + @test_throws ArgumentError remotecall(+, another_client, 1, 2) + else + @test true + end +end + +_phase!("non-uniform / clients hammer listeners") +@testset "non-uniform / clients hammer listeners" begin + if IS_LISTENER + # Listeners do not initiate RPC in this test — they just service. + @test true + else + K = 16 + futs = MPIFuture[] + for ℓ in LISTENER_RANKS + for k in 1:K + push!(futs, remotecall(+, ℓ, WORLD_RANK, k)) + end + end + @test pump_until(() -> all(isready, futs); timeout = 60.0) + # Each call computed `WORLD_RANK + k` (in unspecified ℓ). + for ℓ in LISTENER_RANKS, k in 1:K + idx = findfirst(==(WORLD_RANK + k), [fetch(f) for f in futs]) + @test idx !== nothing + end + end +end + +_phase!("non-uniform / handler observes its own rank") +@testset "non-uniform / handler observes its own rank" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + got = remotecall_fetch(peer) do + return MPI.Comm_rank(MPI.COMM_WORLD) + end + @test got == peer + else + @test true + end +end + +_phase!("non-uniform / many OIDs from one client to one listener") +@testset "non-uniform / many OIDs from one client to one listener" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + K = 64 + futs = [remotecall(+, peer, WORLD_RANK * 1000, k) for k in 1:K] + @test pump_until(() -> all(isready, futs); timeout = 30.0) + perm = randperm(K) + for i in perm + @test fetch(futs[i]) == WORLD_RANK * 1000 + i + end + else + @test true + end +end + +_phase!("non-uniform / multi-threaded handler-spawn-fetch") +@testset "non-uniform / listener handler that Threads.@spawn-then-fetches" begin + # Regression test for the multi-threaded handler deadlock. The listener + # receiving the call dispatches the handler on `Threads.@spawn`; the + # handler in turn spawns a task that calls back into another listener + # via `remotecall_fetch`, then `fetch`es it. This worked under the + # synchronous-handler model only because of `ReentrantLock` re-entry + # on the same task; under multi-threading it required removing + # `progress_lock` and running handlers on their own tasks. + if Threads.nthreads() >= 2 && length(LISTENER_RANKS) >= 2 && !IS_LISTENER + peer = LISTENER_RANKS[1] + helper = LISTENER_RANKS[2] + result = remotecall_fetch(peer, helper, WORLD_RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 5) + return fetch(t) + 1000 + end + @test result == WORLD_RANK + 5 + 1000 + else + @test true + end +end + +_phase!("non-uniform / concurrent client threads") +@testset "non-uniform / multiple client threads issuing concurrent RPC" begin + if Threads.nthreads() >= 2 && !IS_LISTENER + K = 16 + peer = first(LISTENER_RANKS) + results = Vector{Vector{Int}}(undef, Threads.nthreads()) + ts = Task[] + for tid in 1:Threads.nthreads() + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, peer, WORLD_RANK * 100 + tid, k) + end + results[tid] = got + end + push!(ts, t) + end + for t in ts + wait(t) + end + for tid in 1:Threads.nthreads() + @test results[tid] == [WORLD_RANK * 100 + tid + k for k in 1:K] + end + else + @test true + end +end + +_phase!("non-uniform / world barrier coexists with subcomm RPC") +@testset "non-uniform / world barrier coexists with subcomm RPC" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + f = remotecall(+, peer, WORLD_RANK, 1) + # Every rank in COMM_WORLD must reach this barrier so collectives + # stay aligned. Listeners reach it from the listener phase below. + MPI.Barrier(MPI.COMM_WORLD) + @test pump_until(() -> isready(f); timeout = 30.0) + @test fetch(f) == WORLD_RANK + 1 + else + # Listener: pump while clients post their RPC, then barrier with them. + t0 = time() + while time() - t0 < 0.5 + rpc_progress!(backend) + yield() + end + MPI.Barrier(MPI.COMM_WORLD) + # Pump a bit more so any pending replies clear out before _phase!. + t0 = time() + while time() - t0 < 0.5 + rpc_progress!(backend) + yield() + end + @test true + end +end + +_phase!("non-uniform / remote exception path") +@testset "non-uniform / remote exception path" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + @test_throws MPIRemoteException remotecall_fetch(peer) do + error("planned failure") + end + else + @test true + end +end + +_phase!("non-uniform / remote_do is fire-and-forget") +@testset "non-uniform / remote_do returns nothing immediately" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + @test remote_do(peer) do; nothing end === nothing + else + @test true + end +end + +_phase!("non-uniform / set-then-read via remotecall_wait") +@testset "non-uniform / set-then-read with remotecall_wait" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + sentinel = Symbol("_NONUNI_SET_FROM_$(WORLD_RANK)") + # Under multi-threading the spawned-handler model does not preserve + # FIFO of handler *execution* on the same (src, dest, tag) — only + # MPI delivery. `remotecall_wait` provides the explicit ack we need + # before reading back the side effect. + remotecall_wait(peer, WORLD_RANK, sentinel) do origin, name + @eval Main const $name = $origin + end + v = remotecall_fetch(peer, sentinel) do name + return getfield(Main, name) + end + @test v == WORLD_RANK + else + @test true + end +end + +_phase!("non-uniform / shutdown") +shutdown!() +rpc_barrier() +MPI.Finalize() diff --git a/lib/MPIRPC/test/protocol_tests.jl b/lib/MPIRPC/test/protocol_tests.jl new file mode 100644 index 000000000..25dd3b381 --- /dev/null +++ b/lib/MPIRPC/test/protocol_tests.jl @@ -0,0 +1,120 @@ +using Test +using Serialization +using MPIRPC +using MPIRPC: MsgHeader, MPIRRID, NULL_RRID, CallMsg, CallWaitMsg, RemoteDoMsg, + ResultMsg, RPCProgressHaltMsg, encode_frame, decode_frame, MSG_BOUNDARY, ProtocolError, + is_null + +@testset "protocol / framing (no MPI)" begin + @testset "MPIRRID identity" begin + a = MPIRRID(Int32(2), UInt64(7)) + b = MPIRRID(Int32(2), UInt64(7)) + c = MPIRRID(Int32(3), UInt64(7)) + @test a == b + @test hash(a) == hash(b) + @test a != c + @test is_null(NULL_RRID) + @test !is_null(a) + end + + @testset "MsgHeader defaults" begin + @test is_null(MsgHeader().response_oid) + @test is_null(MsgHeader().notify_oid) + @test MsgHeader(MPIRRID(Int32(1), UInt64(2))).notify_oid == NULL_RRID + end + + @testset "CallMsg roundtrip preserves Mode and args" begin + h = MsgHeader(NULL_RRID, MPIRRID(Int32(0), UInt64(99))) + m = CallMsg{:call_fetch}(+, (1, 2, 3), Pair{Symbol,Any}[:by => 10]) + buf = encode_frame(h, m) + h2, m2, err = decode_frame(buf) + @test err === nothing + @test h2 == h + @test m2 isa CallMsg{:call_fetch} + @test m2.args == (1, 2, 3) + @test m2.kwargs == Pair{Symbol,Any}[:by => 10] + end + + @testset "ResultMsg roundtrip" begin + h = MsgHeader(MPIRRID(Int32(0), UInt64(1)), NULL_RRID) + m = ResultMsg([1.0, 2.0, 3.0]) + buf = encode_frame(h, m) + h2, m2, err = decode_frame(buf) + @test err === nothing + @test h2 == h + @test m2 isa ResultMsg + @test m2.value == [1.0, 2.0, 3.0] + end + + @testset "RemoteDoMsg has null OIDs" begin + h = MsgHeader(NULL_RRID, NULL_RRID) + m = RemoteDoMsg(println, ("hi",), Pair{Symbol,Any}[]) + buf = encode_frame(h, m) + _, m2, err = decode_frame(buf) + @test err === nothing + @test m2 isa RemoteDoMsg + end + + @testset "RPCProgressHaltMsg roundtrip" begin + buf = encode_frame(MsgHeader(), RPCProgressHaltMsg()) + h2, m2, err = decode_frame(buf) + @test err === nothing + @test h2 == MsgHeader() + @test m2 isa RPCProgressHaltMsg + end + + @testset "fresh serializer per frame: state does not leak" begin + # Two frames sharing a Vector{Int} should still round-trip identically; + # if state leaked across frames, the second decode would resolve a + # stale back-reference and crash. + v = collect(1:10) + b1 = encode_frame(MsgHeader(), ResultMsg(v)) + b2 = encode_frame(MsgHeader(), ResultMsg(v)) + _, m1, _ = decode_frame(b1) + _, m2, _ = decode_frame(b2) + @test m1.value == v + @test m2.value == v + end + + @testset "MSG_BOUNDARY mismatch is fail-fast" begin + buf = encode_frame(MsgHeader(), ResultMsg(42)) + buf[end] ⊻= 0xFF # corrupt last byte of the boundary + _, _, err = decode_frame(buf) + @test err isa ProtocolError + end + + @testset "boundary missing entirely is fail-fast" begin + buf = encode_frame(MsgHeader(), ResultMsg(42)) + truncated = buf[1:end - length(MSG_BOUNDARY) - 1] + _, _, err = decode_frame(truncated) + @test err !== nothing # may be EOFError or ProtocolError, both acceptable + end + + @testset "body deserialization error is captured, not thrown" begin + hbuf = IOBuffer() + s = Serializer(hbuf) + serialize(s, MsgHeader()) + hbytes = take!(hbuf) + garbage = UInt8[0xff for _ in 1:32] + full = vcat(hbytes, garbage, MSG_BOUNDARY) + h, m, err = decode_frame(full) + @test h isa MsgHeader # header parsed successfully + @test err !== nothing + @test m === nothing + end + + @testset "run_work_thunk print_error mirrors Distributed" begin + path, io = mktemp() + close(io) + open(path, "w") do f + redirect_stderr(f) do + v = MPIRPC.run_work_thunk(() -> error("MPIRPC_planned_thunk_failure"), 99; print_error=true) + @test v isa MPIRemoteException + @test v.rank == 99 + end + end + cap = read(path, String) + @test occursin("MPIRPC_planned_thunk_failure", cap) + rm(path) + end +end diff --git a/lib/MPIRPC/test/runtests.jl b/lib/MPIRPC/test/runtests.jl new file mode 100644 index 000000000..274ec58ed --- /dev/null +++ b/lib/MPIRPC/test/runtests.jl @@ -0,0 +1,7 @@ +using Test +using MPIRPC + +@testset "MPIRPC" begin + include("protocol_tests.jl") + include("mpi_tests.jl") +end diff --git a/lib/MPIRPC/test/uniform_daemon_mpiexec.jl b/lib/MPIRPC/test/uniform_daemon_mpiexec.jl new file mode 100644 index 000000000..d9652e823 --- /dev/null +++ b/lib/MPIRPC/test/uniform_daemon_mpiexec.jl @@ -0,0 +1,326 @@ +using Test +using MPI +using MPIRPC + +# This script verifies that the *yield-only progress daemon* is enough to +# drive the uniform backend end-to-end without a single explicit call to +# `rpc_progress!` from user code. The flag we never set in this script — +# anywhere — is `rpc_progress!`. If anything inside the testsets calls it, +# the test would not be measuring what it claims to. + +MPI.Init(; threadlevel = :multiple) +backend = MPIRPC.select_mpi_rpc_backend!( + UniformMPIRPCBackend(MPI.COMM_WORLD; daemon = true), +) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +NPROC >= 2 || error("uniform daemon tests require at least 2 ranks (got $NPROC)") +@assert backend.daemon "daemon flag did not stick" +@assert backend.daemon_task isa Task "daemon task should have been spawned by select_mpi_rpc_backend!" +@assert !istaskdone(backend.daemon_task) "daemon task already exited before any RPC ran" + +# The daemon must run on the `:interactive` pool so user CPU-bound work +# on the `:default` pool cannot starve the wire pump. `mpi_tests.jl` +# launches us with `--threads=N,1` to make at least one interactive +# thread available; if for some reason it didn't, the assertion below +# fails loudly rather than letting the test pass under a weaker +# guarantee than what the design promises. +if Threads.nthreads(:interactive) > 0 + @assert Threads.threadpool(backend.daemon_task) === :interactive ( + "daemon task expected on :interactive, got $(Threads.threadpool(backend.daemon_task))") +end + +# Phase boundary: still uses `rpc_barrier` because that is a *collective* +# synchronization, not a progress driver — every rank must reach it. The +# daemon happens to also pump progress on each rank while inside the +# barrier, which is fine: `rpc_barrier` and `_daemon_loop` both go +# through the same per-backend `mpi_lock` so they cannot race on raw MPI +# calls. +function _phase!(name::AbstractString) + rpc_barrier() + if RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +# --------------------------------------------------------------------------- + +_phase!("uniform-daemon / smoke (no manual progress)") +@testset "uniform-daemon / remotecall_fetch with no manual progress" begin + peer = mod(RANK + 1, NPROC) + got = remotecall_fetch(+, peer, RANK, 100) + @test got == RANK + 100 +end + +_phase!("uniform-daemon / ABBA without manual progress") +@testset "uniform-daemon / ABBA without manual progress" begin + # Same shape as the non-daemon ABBA test, but here neither client side + # nor server side has any user-driven `rpc_progress!`. Forward progress + # on inbound *and* on reply draining for our own outstanding futures + # is entirely the daemon's responsibility. (Reply draining works + # because `wait`/`fetch` on `MPIFuture` calls `rpc_progress!` itself, + # but the more interesting half — receiving inbound requests from + # peers — has no user touchpoint.) + peer = mod(RANK + 1, NPROC) + futs = [remotecall(*, peer, RANK + 1, k) for k in 1:8] + for (k, f) in enumerate(futs) + @test fetch(f) == (RANK + 1) * k + end +end + +_phase!("uniform-daemon / fully passive rank") +@testset "uniform-daemon / a rank that never initiates RPC still serves" begin + # The cleanest demonstration that the daemon is doing useful work: + # one rank issues every RPC, others issue none. Without a daemon the + # passive ranks would never see their inbound `CallMsg` because nobody + # is calling `rpc_progress!` on them. The active rank uses `wait`, + # which itself pumps progress, but that only drives *its own* MPI + # state — the inbound matching on the passive ranks is independent. + if RANK == 0 + results = Vector{Int}(undef, NPROC - 1) + for p in 1:(NPROC - 1) + results[p] = remotecall_fetch(+, p, p, 1000) + end + @test results == [p + 1000 for p in 1:(NPROC - 1)] + else + # Passive ranks: don't touch the API at all in this testset. + @test true + end +end + +_phase!("uniform-daemon / many concurrent client tasks") +@testset "uniform-daemon / concurrent client tasks rely on daemon for inbound" begin + if Threads.nthreads() >= 2 && NPROC >= 2 + peers = [r for r in 0:(NPROC - 1) if r != RANK] + K = 16 + results = Vector{Vector{Int}}(undef, length(peers)) + ts = Task[] + for (i, p) in enumerate(peers) + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, p, RANK * 1000, k) + end + results[i] = got + end + push!(ts, t) + end + foreach(wait, ts) + for (i, p) in enumerate(peers) + @test results[i] == [RANK * 1000 + k for k in 1:K] + end + else + @test true + end +end + +_phase!("uniform-daemon / handler that spawns and fetches an RPC") +@testset "uniform-daemon / nested fetch from spawned handler subtask" begin + # Same regression as in the non-daemon suite (the deadlock that + # motivated removing `progress_lock`), but with the daemon also + # racing to drain the request queue. We expect both to coexist. + if Threads.nthreads() >= 2 && NPROC >= 3 + peer = mod(RANK + 1, NPROC) + helper = mod(RANK + 2, NPROC) + result = remotecall_fetch(peer, helper, RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 7) + return fetch(t) * 2 + end + @test result == (RANK + 7) * 2 + else + @test true + end +end + +_phase!("uniform-daemon / starvation: CPU-bound default-pool task must not block daemon") +@testset "uniform-daemon / daemon survives CPU-bound default-pool work" begin + # Pathological-case test: spawn a CPU-bound task on the `:default` + # pool that does *not* yield, and concurrently issue RPC traffic to + # a peer. Under a naive design where the daemon shares the default + # pool with user work, the busy task could starve the daemon and + # the peer's `remotecall_fetch` call to *us* would never be + # serviced. With the daemon on `:interactive`, this cannot happen: + # the interactive thread is reserved. + if Threads.nthreads(:interactive) >= 1 && NPROC >= 2 + peer = mod(RANK + 1, NPROC) + # Burn ~0.5 s of CPU on every default-pool thread, with no + # explicit yield points. If the daemon were on :default it + # would be locked out for the duration. + burners = Task[] + deadline = time() + 0.5 + for _ in 1:Threads.nthreads(:default) + t = Threads.@spawn :default begin + acc = 0 + while time() < deadline + # Inner loop deliberately has no `yield()` and no + # I/O so the cooperative scheduler does not + # preempt this task. + for i in 1:1_000_000 + acc = (acc + i * 31) % 9_973 + end + end + acc + end + push!(burners, t) + end + # While the burners hammer the default pool, the peer is going + # to call remotecall_fetch on *us* — see the symmetric arm + # below. Our daemon on :interactive must keep accepting that + # request. We measure success by completing our own + # remotecall_fetch *to* the peer within a tight bound. + t0 = time() + got = MPIRPC.remotecall_fetch(+, peer, RANK, 7) + elapsed = time() - t0 + @test got == RANK + 7 + # Generous bound — what matters is that this completes at all + # while default-pool threads are saturated. Without the + # interactive-pool placement the call would block until all + # burners finish (≈ 0.5 s); with it, the call completes in a + # few ms. We give 5 s of slack so a slow CI does not flake. + @test elapsed < 5.0 + foreach(wait, burners) + else + @test true + end +end + +_phase!("uniform-daemon / cond-park: many concurrent waiters wake exactly once") +@testset "uniform-daemon / cond-park: many concurrent waiters wake correctly" begin + # Regression test for the `Threads.Condition`-park path in + # `wait(::MPIFuture)`. Spawn a large number of client tasks, each + # blocking in `fetch` on its own future, and verify every one + # observes its expected reply. The properties we want to lock in: + # + # 1. Every parked waiter is woken — no lost-wakeup race between + # `deliver!`'s `notify` and the consumer's `wait(f.cond)`. + # 2. Each waiter wakes for *its own* future (no cross-future + # wakeup that returns a stale value to the wrong task). + # 3. The daemon's `_dispatch_reply!` path correctly transitions + # through `take_waiter!` → `deliver!` → cond `notify` while + # none of `mpi_lock`, `waiters_lock`, `f.cond` are held more + # than one at a time. + # + # The test cannot directly assert "no spinning happened" without + # peeking at task state, but a regression to the spin path would + # still pass *correctness* — so we additionally check that the + # waiters are *responsive*, by serializing many short calls and + # measuring that they all complete inside a generous bound. If + # `wait` was somehow not waking on `notify` (e.g. a bug where we + # accidentally took a stale `isready` snapshot under the lock and + # parked anyway), this would time out. + if Threads.nthreads() >= 2 + peers = [r for r in 0:(NPROC - 1) if r != RANK] + K = 64 # waiters per peer + results = Vector{Vector{Int}}(undef, length(peers)) + ts = Task[] + t_start = time() + for (i, p) in enumerate(peers) + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + # Each task spawns its own future and immediately blocks + # in `fetch`. Under daemon=true this means each task + # parks on its own `f.cond` while the daemon services + # all of the inbound replies on this rank. + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, p, RANK * 10_000, k) + end + results[i] = got + end + push!(ts, t) + end + foreach(wait, ts) + elapsed = time() - t_start + for (i, p) in enumerate(peers) + @test results[i] == [RANK * 10_000 + k for k in 1:K] + end + # Loose upper bound. On a healthy machine this completes in + # well under a second; we give it 60 s to absorb CI variance. + # The point is to catch a hang, not to enforce performance. + @test elapsed < 60.0 + else + @test true + end +end + +_phase!("uniform-daemon / Imrecv!: large concurrent payloads round-trip") +@testset "uniform-daemon / large concurrent payloads via Imrecv! + Test/yield" begin + # Regression test for the non-blocking-receive path. Each task on + # this rank issues a `remotecall_fetch` with a 1 MiB payload. On + # most MPI implementations 1 MiB exceeds the eager threshold + # (typical defaults: 64 KiB for OpenMPI, 256 KiB for MPICH), so + # the receiving side actually goes through the rendezvous + # protocol — which is the one regime where the previous + # `MPI.Mrecv!` would have held `mpi_lock` for a non-trivial + # duration. With `Imrecv!` + `Test`/`yield` this is broken up. + # + # We verify only correctness (all replies arrive, with the right + # values) and that the whole batch completes in bounded time. + # Asserting "the daemon thread did not block" directly is not + # cleanly testable from user code; the existence and correctness + # of the round-trip under concurrent pressure is the practical + # proxy. + if Threads.nthreads() >= 2 && NPROC >= 2 + peer = mod(RANK + 1, NPROC) + N = 1_048_576 # 1 MiB + K = 4 # concurrent calls + # Use the same canonical payload across tasks so we can + # cross-check sums without reconstructing per-task arrays. + payload = Vector{UInt8}(undef, N) + for i in 1:N + payload[i] = UInt8((i - 1) % 256) + end + expected = sum(Int(b) for b in payload) + ts = Task[] + for _ in 1:K + t = Threads.@spawn MPIRPC.remotecall_fetch(peer, payload) do data + # Verify on the remote side that the buffer survived + # the rendezvous round-trip intact, then return the + # sum so the client can independently verify. + length(data) == N || error("size mismatch: got $(length(data))") + acc = 0 + for b in data + acc += Int(b) + end + return acc + end + push!(ts, t) + end + t0 = time() + results = fetch.(ts) + elapsed = time() - t0 + @test all(==(expected), results) + # Loose bound: with rendezvous and 4 concurrent 1-MiB payloads + # plus their replies, we expect well under 30 s on any sane + # interconnect (loopback / shared memory / TCP localhost). The + # point of the bound is to catch a hang from a leaked request + # or a missed `Test`, not to enforce performance. + @test elapsed < 30.0 + else + @test true + end +end + +_phase!("uniform-daemon / shutdown joins the daemon task") +@testset "uniform-daemon / shutdown! cleanly joins the daemon" begin + @test backend.daemon_task isa Task + @test !istaskdone(backend.daemon_task) + # `shutdown!` flips `running[]` to `false` and waits for the daemon + # task to terminate. After it returns, the daemon must be done and + # the `daemon_task` slot must have been cleared so a (hypothetical) + # subsequent re-init does not see a stale handle. + shutdown!() + @test backend.running[] == false + @test backend.daemon_task === nothing +end + +# Final coordination on the duplicate communicator. We cannot use +# `rpc_barrier` here: the daemon is gone, and every rank has called +# `shutdown!` so no rank is pumping progress. A plain `MPI.Barrier` +# on `backend.comm` is what we want — there is no in-flight RPC at +# this point because every testset above completed via `fetch`/`wait`. +MPI.Barrier(backend.comm) +MPI.Finalize() diff --git a/lib/MPIRPC/test/uniform_mpiexec.jl b/lib/MPIRPC/test/uniform_mpiexec.jl new file mode 100644 index 000000000..4f5187534 --- /dev/null +++ b/lib/MPIRPC/test/uniform_mpiexec.jl @@ -0,0 +1,336 @@ +using Test +using MPI +using MPIRPC +using Random + +# Each rank sets the *same* seed so any per-rank randomized peer-selection +# decision is reproducible across processes. +Random.seed!(0xC0FFEE) + +MPI.Init(; threadlevel=:multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +NPROC >= 2 || error("uniform tests require at least 2 ranks (got $NPROC)") + +# Synchronisation aid: a *progress-pumping* barrier on the RPC subcomm and +# announce the next testset on rank 0 so a hang is easy to localise. We +# specifically use `rpc_barrier` rather than `MPI.Barrier` so ranks continue +# servicing inbound RPC while waiting for stragglers — without this, a rank +# whose own primary `remotecall_fetch` completed could exit with a peer's +# nested-RPC request still queued for it. +function _phase!(name::AbstractString) + rpc_barrier() + if RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +# Pump progress until `pred()` is true or `timeout` seconds elapse. Returns +# `true` if the predicate became true. Liveness assertion: misuse of the API +# (forgetting to pump progress, holding a lock across fetch, etc.) fails in +# bounded time rather than hanging the whole CI. +function pump_until(pred; timeout::Real = 30.0, + backend = MPIRPC.current_mpi_rpc_backend()) + t0 = time() + while !pred() + rpc_progress!(backend) + time() - t0 > timeout && return false + yield() + end + return true +end + +# --------------------------------------------------------------------------- + +_phase!("uniform / smoke") +@testset "uniform / smoke" begin + peer = mod(RANK + 1, NPROC) + got = remotecall_fetch(+, peer, RANK, 100) + # `+` ran on `peer` over the args (RANK, 100), so the result is RANK + 100. + @test got == RANK + 100 +end + +_phase!("uniform / handler observes its own rank") +@testset "uniform / handler observes its own rank" begin + peer = mod(RANK + 1, NPROC) + got = remotecall_fetch(peer) do + return MPI.Comm_rank(MPI.COMM_WORLD) + end + @test got == peer +end + +_phase!("uniform / ABBA cross-calls") +@testset "uniform / ABBA: every rank simultaneously calls (i+1) % N" begin + # Classic ABBA pattern: rank i sends a request to (i+1) and must service + # the inbound request from (i-1) at the same time. With one tag per kind + # of message (request vs reply), correlation by RRID, and progress on + # every rank, this must complete deterministically. + peer = mod(RANK + 1, NPROC) + futs = [remotecall(*, peer, RANK + 1, k) for k in 1:8] + @test pump_until(() -> all(isready, futs); timeout = 30.0) + if all(isready, futs) + for (k, f) in enumerate(futs) + @test fetch(f) == (RANK + 1) * k + end + end +end + +_phase!("uniform / many OIDs, permuted fetch on one peer") +@testset "uniform / many OIDs, permuted fetch on one peer" begin + peer = mod(RANK + 1, NPROC) + K = 64 + expected = [RANK * 1000 + k for k in 1:K] + futs = [remotecall(+, peer, RANK * 1000, k) for k in 1:K] + + perm = randperm(K) + @test pump_until(() -> all(isready, futs); timeout = 30.0) + if all(isready, futs) + got = Vector{Int}(undef, K) + for i in perm + got[i] = fetch(futs[i]) + end + @test got == expected + end +end + +_phase!("uniform / all-to-all burst") +@testset "uniform / all-to-all burst" begin + peers = [r for r in 0:(NPROC-1) if r != RANK] + futs = [remotecall(+, p, RANK, p) for p in peers] + @test pump_until(() -> all(isready, futs); timeout = 30.0) + if all(isready, futs) + for (p, f) in zip(peers, futs) + @test fetch(f) == RANK + p + end + end +end + +_phase!("uniform / same-pair, different correlation ids") +@testset "uniform / same-pair, different correlation ids" begin + # Two messages from the same (src, dest) at the same time, distinguished + # only by RRID. MPI guarantees in-order delivery on (src, dest, tag), but + # the *application-level* matching must use RRID rather than message order + # to avoid coupling correctness to MPI ordering of unrelated calls. + peer = mod(RANK + 1, NPROC) + f1 = remotecall(identity, peer, :first) + f2 = remotecall(identity, peer, :second) + @test pump_until(() -> isready(f1) && isready(f2); timeout = 30.0) + @test fetch(f1) == :first + @test fetch(f2) == :second +end + +_phase!("uniform / nested re-entrant remotecall_fetch") +@testset "uniform / nested re-entrant remotecall_fetch" begin + if NPROC >= 3 + peer = mod(RANK + 1, NPROC) + helper = mod(RANK + 2, NPROC) + # Inside the handler on `peer`, call back to a third rank. Exercises + # nested progress: while `peer` is servicing rank `RANK`'s request, + # its handler must be able to issue and await a fresh + # remotecall_fetch — and the originating rank must keep pumping + # progress so its outer reply can come back. + result = remotecall_fetch(peer, helper, RANK) do helper_rank, origin + inner = MPIRPC.remotecall_fetch(+, helper_rank, origin, 1) + return inner * 10 + end + @test result == (RANK + 1) * 10 + else + @test true # not enough ranks + end +end + +_phase!("uniform / multi-threaded handler-spawn-fetch") +@testset "uniform / handler that Threads.@spawn-then-fetches an RPC" begin + # Regression test for the deadlock previously documented in + # `docs/ARCHITECTURE.md` §6: with synchronous handlers, the handler ran + # on the calling task while `progress_lock` was held. A handler that + # spawned its own task and called `fetch` on it would block forever + # because the spawned task could not acquire `progress_lock` from + # another OS thread (the lock was held by the handler's task on a + # different thread). After moving handlers to `Threads.@spawn`, this + # pattern works. + if Threads.nthreads() >= 2 && NPROC >= 3 + peer = mod(RANK + 1, NPROC) + helper = mod(RANK + 2, NPROC) + result = remotecall_fetch(peer, helper, RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 7) + return fetch(t) * 2 + end + @test result == (RANK + 7) * 2 + else + @test true # need julia -t >=2 and at least 3 ranks to exercise this + end +end + +_phase!("uniform / many handlers, threads concurrent waiters") +@testset "uniform / concurrent client threads pumping the same backend" begin + # Several client tasks on the same rank issue concurrent + # `remotecall_fetch` calls. With `progress_lock` removed, every task + # can drive progress on its own; with handlers spawned on the peer, + # several of *its* handler tasks can run on different threads. The + # combined effect is a stress test for the lock geometry under + # `julia -t >=2`. + if Threads.nthreads() >= 2 + peers = [r for r in 0:(NPROC-1) if r != RANK] + K = 16 + results = Vector{Vector{Int}}(undef, length(peers)) + client_tasks = Task[] + for (i, p) in enumerate(peers) + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, p, RANK * 1000, k) + end + results[i] = got + end + push!(client_tasks, t) + end + for t in client_tasks + wait(t) + end + for (i, p) in enumerate(peers) + @test results[i] == [RANK * 1000 + k for k in 1:K] + end + else + @test true + end +end + +_phase!("uniform / remote_do is fire-and-forget") +@testset "uniform / remote_do returns nothing immediately" begin + peer = mod(RANK + 1, NPROC) + # `remote_do` is fire-and-forget: it returns `nothing` as soon as the + # request has been posted, regardless of whether the remote handler + # has run. + @test remote_do(peer) do; nothing end === nothing +end + +_phase!("uniform / remote_do failure prints to stderr on handler rank") +@testset "uniform / remote_do failure prints to stderr on handler rank" begin + # Rank 0 issues `remote_do` to rank 1; rank 1 captures stderr while its + # progress loop (spawned task) services the inbound message so + # `showerror` from `run_work_thunk(...; print_error=true)` lands in the file. + cap_mark = "MPIRPC_planned_remote_do_failure" + cap_path = joinpath(mktempdir(), "mpirpc_remote_do_stderr.txt") + cap_task = nothing + if RANK == 1 + ready = Channel{Nothing}(1) + cap_task = Threads.@spawn begin + open(cap_path, "w") do f + redirect_stderr(f) do + put!(ready, nothing) + t0 = time() + while time() - t0 < 30.0 + rpc_progress!(backend) + flush(f) + yield() + if isfile(cap_path) && filesize(cap_path) > 0 + s = read(cap_path, String) + occursin(cap_mark, s) && break + end + end + end + end + return read(cap_path, String) + end + take!(ready) + end + rpc_barrier() + if RANK == 0 + @test remote_do(1) do + error(cap_mark) + end === nothing + end + rpc_barrier() + if RANK == 1 + txt = fetch(cap_task::Task) + @test occursin(cap_mark, txt) + rm(cap_path, force=true) + end + rpc_barrier() +end + +_phase!("uniform / set-then-read via remotecall_wait") +@testset "uniform / set-then-read with remotecall_wait" begin + peer = mod(RANK + 1, NPROC) + sentinel_name = Symbol("_MPIRPC_SET_FROM_$(RANK)") + # `remotecall_wait` blocks the caller until the remote handler + # acknowledges completion, which is the correct synchronization + # primitive when a follow-up call needs to observe the side effect. + # Under multi-threading the spawned-handler model does *not* preserve + # FIFO of *handler execution* on the same `(src, dest, tag)` — only + # FIFO of message delivery — so a `remote_do` followed by a sync + # `remotecall_fetch` would race. + remotecall_wait(peer, RANK, sentinel_name) do origin, name + @eval Main const $name = $origin + end + val = remotecall_fetch(peer, sentinel_name) do name + return getfield(Main, name) + end + @test val == RANK +end + +_phase!("uniform / remote exception is wrapped") +@testset "uniform / remote exception is wrapped" begin + peer = mod(RANK + 1, NPROC) + @test_throws MPIRemoteException remotecall_fetch(peer) do + error("planned failure on remote rank") + end +end + +_phase!("uniform / dest validation") +@testset "uniform / dest validation" begin + @test_throws ArgumentError remotecall(+, NPROC, 1) + @test_throws ArgumentError remotecall(+, -1, 1) +end + +_phase!("uniform / world barrier coexists with subcomm RPC") +@testset "uniform / collective on world comm coexists with RPC subcomm" begin + # The uniform backend dups MPI.COMM_WORLD by default, so a Barrier on the + # original world communicator does not collide with pending RPC traffic. + peer = mod(RANK + 1, NPROC) + fut = remotecall(+, peer, RANK, 1) + MPI.Barrier(MPI.COMM_WORLD) + @test pump_until(() -> isready(fut); timeout = 30.0) + if isready(fut) + @test fetch(fut) == RANK + 1 + end + MPI.Barrier(MPI.COMM_WORLD) +end + +_phase!("uniform / stress: rounds of permuted all-to-all") +@testset "uniform / stress: rounds of permuted all-to-all" begin + M = 4 + failures = 0 + for round in 1:M + # Same seed across ranks ⇒ same permutation; identical scheduling + # decisions on every rank so any failure is reproducible. + Random.seed!(0xCAFE00 + round) + peers = collect(0:(NPROC-1))[randperm(NPROC)] + peers_nonself = [p for p in peers if p != RANK] + futs = [remotecall(+, p, RANK, round) for p in peers_nonself] + completed = pump_until(() -> all(isready, futs); timeout = 30.0) + if completed + for f in futs + fetch(f) == RANK + round || (failures += 1) + end + else + failures += length(peers_nonself) + end + # Progress-pumping barrier between rounds so any round-N inner + # request still in flight on a peer is drained before round N+1 + # starts, even if a few ranks raced ahead of others. + rpc_barrier() + end + @test failures == 0 +end + +_phase!("uniform / shutdown") +shutdown!() +rpc_barrier() +MPI.Finalize() diff --git a/src/array/lu.jl b/src/array/lu.jl index b669100e2..7fbd87991 100644 --- a/src/array/lu.jl +++ b/src/array/lu.jl @@ -107,6 +107,7 @@ function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv_chunk: A[p,:], M[r,:] = M[r,:], A[p,:] end end + return A end @kernel function _geru_kernel!(alpha, x, y, A) @@ -156,6 +157,7 @@ function swaprows_trail!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv::Abstr end end end + return A end # Implementation of https://inria.hal.science/hal-04984070v1/file/ipdps_paper.pdf diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 848443e8e..963629464 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -250,8 +250,8 @@ struct AliasedObjectCacheStore values::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} originals::Set{AbstractAliasing} end -AliasedObjectCacheStore() = - AliasedObjectCacheStore(current_acceleration(), +AliasedObjectCacheStore(accel::Acceleration) = + AliasedObjectCacheStore(accel, Vector{AbstractAliasing}(), Dict{AbstractAliasing,AbstractAliasing}(), Dict{MemorySpace,Set{AbstractAliasing}}(), @@ -281,6 +281,7 @@ function get_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::A end function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) @assert !is_stored(cache, dest_space, ainfo) "Cache already has derived ainfo $ainfo" + check_uniform(value) key = cache.derived[ainfo] value_ainfo = aliasing(cache.accel, value, identity) cache.derived[value_ainfo] = key @@ -290,6 +291,7 @@ function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, va return end function set_key_stored!(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + check_uniform(value) push!(cache.keys, ainfo) cache.derived[ainfo] = ainfo push!(get!(Set{AbstractAliasing}, cache.stored, space), ainfo) @@ -327,7 +329,8 @@ function get_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore return get_stored(cache_raw, cache.space, ainfo) end -function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) + +function set_stored!(accel::DistributedAcceleration, cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) wid = root_worker_id(cache.chunk) if wid != myid() return remotecall_fetch(set_stored!, wid, cache, value, ainfo) @@ -336,7 +339,8 @@ function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAli set_stored!(cache_raw, cache.space, value, ainfo) return end -function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + +function set_key_stored!(accel::DistributedAcceleration, cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) wid = root_worker_id(cache.chunk) if wid != myid() return remotecall_fetch(set_key_stored!, wid, cache, space, ainfo, value) @@ -344,6 +348,7 @@ function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::A cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore set_key_stored!(cache_raw, space, ainfo, value) end + function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(cache.accel, x, identity)) x_space = memory_space(x) if !is_key_present(cache, x_space, ainfo) @@ -351,18 +356,21 @@ function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(cache.a # the source key. Using bare `tochunk(x)` defaults to OSProc, which can # incorrectly wrap GPU-backed objects as CPU chunks. x_chunk = x isa Chunk ? x : tochunk(x, first(processors(x_space))) - set_key_stored!(cache, x_space, ainfo, x_chunk) + set_key_stored!(cache.accel, cache, x_space, ainfo, x_chunk) end if is_stored(cache, ainfo) return get_stored(cache, ainfo) else y = f(x) @assert y isa Chunk "Didn't get a Chunk from functor" + a = memory_space(y) + b = cache.space + @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" if memory_space(x) != cache.space - @assert ainfo != aliasing(caache.accel, y, identity) "Aliasing mismatch! $ainfo == $(aliasing(cache.accel, y, identity))" + @assert ainfo != aliasing(cache.accel, y, identity) "Aliasing mismatch! $ainfo == $(aliasing(cache.accel, y, identity))" end - set_stored!(cache, y, ainfo) + set_stored!(cache.accel, cache, y, ainfo) return y end end @@ -440,8 +448,9 @@ struct DataDepsState arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() + accel = current_acceleration() ainfo_backing_chunk = _with_default_acceleration() do - tochunk(AliasedObjectCacheStore()) + tochunk(AliasedObjectCacheStore(accel)) end supports_inplace_cache = IdDict{Any,Bool}() @@ -771,9 +780,6 @@ function generate_slot!(state::DataDepsState, dest_space, data) check_uniform(to_proc) from_proc = first(processors(orig_space)) check_uniform(from_proc) - if MPI.Comm_rank(MPI.COMM_WORLD) == 0 - display(typeof(data)) - end check_uniform(typeof(data)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) aliased_object_cache = AliasedObjectCache(current_acceleration(), dest_space, state.ainfo_backing_chunk) @@ -781,7 +787,7 @@ function generate_slot!(state::DataDepsState, dest_space, data) id = rand(Int) @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) data_chunk = with(MPI_TID=>DATADEPS_CURRENT_TASK[].uid) do - remotecall_endpoint(move_rewrap, current_acceleration(), aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + remotecall_endpoint_toplevel(move_rewrap, current_acceleration(), aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) end @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" @@ -805,22 +811,26 @@ function get_or_generate_slot!(state, dest_space, data) return state.remote_args[dest_space][data] end -function remotecall_fetch_fast(f, wid::Integer, args...; kwargs...) +function remotecall_fetch_fast(::DistributedAcceleration, f, target, args...; kwargs...) + wid = root_worker_id(target) if wid == myid() return f(args...; kwargs...) end return remotecall_fetch(f, wid, args...; kwargs...) end -function remotecall_endpoint(f, accel::DistributedAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) - from_w = root_worker_id(from_proc) - return remotecall_fetch_fast(from_w) do + +function is_local(accel::DistributedAcceleration, target) + return root_worker_id(target) == myid() +end + +function remotecall_endpoint_toplevel(f, accel::DistributedAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) + return remotecall_fetch_fast(root_worker_id(from_proc)) do data_raw = unwrap(data) return f(accel, cache, from_proc, to_proc, from_space, to_space, data_raw)::Chunk - end + end end function remotecall_endpoint_transfer(f, accel::DistributedAcceleration, from_proc, to_proc, from_space, to_space, data) - to_w = root_worker_id(to_proc) - return remotecall_fetch_fast(to_w) do + return remotecall_fetch_fast(root_worker_id(to_proc)) do return f(accel, from_proc, to_proc, from_space, to_space, data) end end @@ -831,7 +841,7 @@ function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_ # Generic data, do the transfer return aliased_object!(cache, data) do data return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, data) do accel, from_proc, to_proc, from_space, to_space, data - return tochunk(move(from_proc, to_proc, data), to_proc) + return tochunk(move(from_proc, to_proc, data), to_proc, to_space) end end end @@ -843,18 +853,17 @@ function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_ return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, p_chunk) do accel, from_proc, to_proc, from_space, to_space, p_chunk p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) - return tochunk(v_new, to_proc) + return tochunk(v_new, to_proc, to_space) end end # FIXME: Do this programmatically via recursive dispatch for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) @eval function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) - to_w = root_worker_id(to_proc) p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) - return remotecall_fetch_fast(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, p_chunk) do accel, from_proc, to_proc, from_space, to_space, p_chunk p_new = move(from_proc, to_proc, p_chunk) v_new = $(wrapper)(p_new) - return tochunk(v_new, to_proc) + return tochunk(v_new, to_proc, to_space) end end end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index a531509cf..4aeebc058 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -23,7 +23,7 @@ processors(space::CPURAMMemorySpace) = unwrap(x::Chunk) = unwrap(x.handle) function unwrap(handle::DRef) @assert root_worker_id(handle) == myid() "DRef $handle is not owned by this process: $(root_worker_id(handle)) != $(myid())" - return MemPool.poolget(x.handle) + return MemPool.poolget(handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = throw(ArgumentError("No `move!` implementation defined for $F -> $T")) diff --git a/src/mpi.jl b/src/mpi.jl index b5723e6a2..8daaedd18 100644 --- a/src/mpi.jl +++ b/src/mpi.jl @@ -1,6 +1,7 @@ @warn "Move to MPIExt.jl" maxlog=1 using MPI +import MPIRPC const CHECK_UNIFORMITY = Ref{Bool}(false) function check_uniformity!(check::Bool=true) @@ -26,18 +27,24 @@ function check_uniform(value, original=value) return check_uniform(hash(value), original) end +# MPI tag for `compare_all` / `check_uniform` only. Must not collide with other +# Dagger P2P on the same `comm` (e.g. `remotecall_endpoint_toplevel` uses tag `0` +# for broadcast metadata), or ranks can steal each other's messages and hang +# until `mpi_deadlock_detect` fires. +const COMPARE_ALL_MPI_TAG = UInt32(7243) + function compare_all(value, comm) rank = MPI.Comm_rank(comm) size = MPI.Comm_size(comm) for i in 0:(size-1) if i != rank - send_yield(value, comm, i, UInt32(0)) + send_yield(value, comm, i, COMPARE_ALL_MPI_TAG) end end match = true for i in 0:(size-1) if i != rank - other_value = recv_yield(comm, i, UInt32(0)) + other_value = recv_yield(comm, i, COMPARE_ALL_MPI_TAG) if value != other_value match = false end @@ -61,13 +68,10 @@ function aliasing(accel::MPIAcceleration, x::Chunk, T) ainfo = _with_default_acceleration() do aliasing(x, T) end - #Core.print("[$rank] aliasing: $ainfo, sending\n") @opcounter :aliasing_bcast_send_yield bcast_send_yield(ainfo, accel.comm, handle.rank, tag) else - #Core.print("[$rank] aliasing: receiving from $(handle.rank)\n") ainfo = recv_yield(accel.comm, handle.rank, tag) - #Core.print("[$rank] aliasing: received $ainfo\n") end check_uniform(ainfo) return ainfo @@ -247,6 +251,14 @@ struct MPIMemorySpace{S<:MemorySpace} <: MemorySpace rank::Int end +function Base.:(==)(a::MPIMemorySpace, b::MPIMemorySpace) + return a.innerSpace == b.innerSpace && + a.comm == b.comm && + a.rank == b.rank +end +Base.hash(space::MPIMemorySpace, h::UInt=UInt(0)) = + hash(space.innerSpace, hash(space.comm.val, hash(space.rank, hash(MPIMemorySpace, h)))) + function check_uniform(space::MPIMemorySpace, original=space) return check_uniform(space.rank, original) && # TODO: Not always valid (if pointer is embedded, say for GPUs) @@ -389,11 +401,11 @@ function take_ref_id!() end #TODO: partitioned scheduling with comm bifurcation -function tochunk_pset(x, space::MPIMemorySpace; device=nothing, kwargs...) +function tochunk_pset(x, space::MPIMemorySpace; device=nothing, force_nonlocal=false, kwargs...) @assert space.comm == MPI.COMM_WORLD "$(space.comm) != $(MPI.COMM_WORLD)" local_rank = MPI.Comm_rank(space.comm) Mid = take_ref_id!() - if local_rank != space.rank + if local_rank != space.rank || force_nonlocal return MPIRef(space.comm, space.rank, 0, nothing, Mid) else # type= is for Chunk metadata only; MemPool.poolset does not accept it @@ -770,6 +782,7 @@ function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemory return end +const ALIASED_OBJECT_CACHE = ScopedValue{Union{AliasedObjectCache, Nothing}}(nothing) move(::MPIOSProc, ::MPIProcessor, x::Union{Function,Type}) = x move(::MPIOSProc, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) @@ -788,6 +801,91 @@ function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk) end end +function is_local(accel::MPIAcceleration, target) + return target.rank == MPI.Comm_rank(accel.comm) +end + +function remotecall_endpoint_toplevel(f, accel::MPIAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + if is_local(accel, from_proc) + data_raw = unwrap(data) + res = f(accel, cache, from_proc, to_proc, from_space, to_space, data_raw)::Chunk + csz = MPI.Comm_size(accel.comm) + for target_rank in 0:(csz-1) + if target_rank == from_proc.rank + continue + end + MPIRPC.rpc_progress_halt!(backend, target_rank) + end + @assert res.handle isa MPIRef "Expected MPIRef handle for MPI broadcast" + meta = (res.handle.id, res.handle.size, chunktype(res), res.space.innerSpace, res.space.rank) + bcast_send_yield(meta, accel.comm, from_proc.rank, 0) + return res + else + with(ALIASED_OBJECT_CACHE=>cache) do + while MPIRPC.rpc_progress!(backend) + yield() + end + end + id, size, T, inner_space, rank = recv_yield(accel.comm, from_proc.rank, 0) + space = MPIMemorySpace(inner_space, accel.comm, rank) + handle = MPIRef(accel.comm, rank, size, nothing, id) + return Chunk{T, typeof(handle), typeof(to_proc), AnyScope, typeof(space)}( + T, domain(nothing), handle, to_proc, AnyScope(), space) + end +end + +function remotecall_endpoint_transfer(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data) + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + + return MPIRPC.remotecall_fetch(backend, to_proc.rank) do + ACCELERATION[] = accel + return f(accel, from_proc, to_proc, from_space, to_space, data) + end +end + +function set_stored_mpi!(space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) + cache = ALIASED_OBJECT_CACHE[] + ACCELERATION[] = cache.accel + set_stored!(unwrap(cache.chunk)::AliasedObjectCacheStore, cache.space, value, ainfo) + return +end + +function mpi_nonlocal_chunk(value::Chunk) + return tochunk(nothing, value.processor, value.space, value.scope; type=chunktype(value), force_nonlocal=true) +end + +function set_stored!(accel::MPIAcceleration, cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + MPIRPC.bcast_remotecall(backend, set_stored_mpi!, cache.space, mpi_nonlocal_chunk(value), ainfo) + set_stored!(cache_raw, cache.space, value, ainfo) + return +end + +function set_key_stored_mpi!(space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + cache = unwrap(ALIASED_OBJECT_CACHE[].chunk)::AliasedObjectCacheStore + ACCELERATION[] = cache.accel + set_key_stored!(cache, space, ainfo, value) + return +end + +function set_key_stored!(accel::MPIAcceleration, cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + MPIRPC.bcast_remotecall(backend, set_key_stored_mpi!, space, ainfo, mpi_nonlocal_chunk(value)) + set_key_stored!(cache_raw, space, ainfo, value) +end + + #= function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) loc_rank = MPI.Comm_rank(accel.comm) @@ -804,13 +902,8 @@ function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from return recv_yield(accel.comm, to_proc.rank, tag) end end -function remotecall_endpoint_transfer(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data) - loc_rank = MPI.Comm_rank(accel.comm) - if loc_rank == from_proc.rank - elseif loc_rank == to_proc.rank - end -end =# +#= function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) loc_rank = MPI.Comm_rank(accel.comm) task = DATADEPS_CURRENT_TASK[] @@ -866,6 +959,9 @@ function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from end end end +=# + + # Chunk may be MPI-backed (MPIRef) but labeled with OSProc; treat source as the owning rank function move(src::OSProc, dst::MPIProcessor, x::Chunk) @@ -905,6 +1001,7 @@ function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk) end end +gpu_kernel_backend(::MPIProcessor) = KernelAbstractions.CPU() #FIXME:try to think of a better move! scheme function execute!(proc::MPIProcessor, f, args...; kwargs...) @@ -933,6 +1030,9 @@ function execute!(proc::MPIProcessor, f, args...; kwargs...) if islocal T = typeof(result) + if T === Nothing + @warn "[rank $local_rank] Gave $T result for $fname, args types: $arg_types; treating as Any for broadcast" + end space = memory_space(result, proc)::MPIMemorySpace if need_bcast @opcounter :execute_bcast_send_yield @@ -953,10 +1053,16 @@ end accelerate!(::Val{:mpi}) = accelerate!(MPIAcceleration()) +const MPIRC_BACKEND = LockedObject(Dict{MPI.Comm, MPIRPC.AbstractMPIRPCBackend}()) + function initialize_acceleration!(a::MPIAcceleration) if !MPI.Initialized() MPI.Init(;threadlevel=:multiple) end + backend = MPIRPC.select_mpi_rpc_backend!(MPIRPC.UniformMPIRPCBackend(a.comm)) + lock(MPIRC_BACKEND) do backends + backends[a.comm] = backend + end ctx = Dagger.Sch.eager_context() sz = MPI.Comm_size(a.comm) for i in 0:(sz-1) From 3f9b4d9f576a602309af322f8c1d4d78731ea184 Mon Sep 17 00:00:00 2001 From: Felipe Tome Date: Mon, 18 May 2026 19:50:34 -0300 Subject: [PATCH 6/6] MPI: MPIRPC included as a dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 69163e027..b0e678cfb 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +MPIRPC = "a8caf107-0824-430d-bb41-9c1c9e5c5a9f" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" @@ -79,6 +80,7 @@ JSON3 = "1" KernelAbstractions = "0.9" MacroTools = "0.5" MPI = "0.20.22" +MPIRPC = "0.1" MemPool = "0.4.12" Metal = "1.1" NextLA = "0.2.2"