diff --git a/Project.toml b/Project.toml index ce49bf6d7..b0e678cfb 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,8 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +MPIRPC = "a8caf107-0824-430d-bb41-9c1c9e5c5a9f" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" @@ -77,6 +79,8 @@ Graphs = "1" JSON3 = "1" KernelAbstractions = "0.9" MacroTools = "0.5" +MPI = "0.20.22" +MPIRPC = "0.1" MemPool = "0.4.12" Metal = "1.1" NextLA = "0.2.2" diff --git a/benchmarks/check_comm_asymmetry.jl b/benchmarks/check_comm_asymmetry.jl new file mode 100644 index 000000000..684240ec5 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.jl @@ -0,0 +1,111 @@ +#!/usr/bin/env julia +# Parse MPI+Dagger logs and report communication decision asymmetry per tag. +# Asymmetry: for the same tag, one rank decides to send (local+bcast, sender+communicated, etc.) +# and another rank decides to infer (inferred, uninvolved) and never recv → deadlock. +# +# Usage: julia check_comm_asymmetry.jl < logfile +# Or: mpiexec -n 10 julia ... run_matmul.jl 2>&1 | tee matmul.log; julia check_comm_asymmetry.jl < matmul.log + +const SEND_DECISIONS = Set([ + "local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast", + "aliasing", # when followed by local+bcast we already capture local+bcast +]) +const RECV_DECISIONS = Set([ + "communicated", "receiver", "sender+communicated", # received data +]) +const INFER_DECISIONS = Set([ + "inferred", "uninvolved", # did not recv (uses inferred type) +]) + +function parse_line(line) + # Match [rank X][tag Y] then any [...] and capture the last bracket pair before space or end + rank = nothing + tag = nothing + decision = nothing + category = nothing # aliasing, execute!, remotecall_endpoint + for m in eachmatch(r"\[rank\s+(\d+)\]", line) + rank = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[tag\s+(\d+)\]", line) + tag = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + category = m.captures[1] + end + # Decision is usually in last [...] that looks like [word] or [word+word] + for m in eachmatch(r"\]\[([^\]]+)\]", line) + candidate = m.captures[1] + # Normalize: "communicated" "inferred" "local+bcast" "sender+inferred" "receiver" etc. + if occursin("inferred", candidate) && !occursin("communicated", candidate) + decision = "inferred" + break + elseif occursin("communicated", candidate) + decision = "communicated" + break + elseif occursin("local+bcast", candidate) + decision = "local+bcast" + break + elseif occursin("sender+", candidate) + decision = startswith(candidate, "sender+inferred") ? "sender+inferred" : "sender+communicated" + break + elseif candidate == "receiver" + decision = "receiver" + break + elseif candidate == "receiver+bcast" + decision = "receiver+bcast" + break + elseif candidate == "inplace_move" + decision = "inplace_move" + break + end + end + return rank, tag, category, decision +end + +function main() + # tag => Dict(rank => decision) + by_tag = Dict{Int, Dict{Int, String}}() + for line in eachline(stdin) + rank, tag, category, decision = parse_line(line) + isnothing(rank) && continue + isnothing(tag) && continue + isnothing(decision) && continue + if !haskey(by_tag, tag) + by_tag[tag] = Dict{Int, String}() + end + by_tag[tag][rank] = decision + end + + # For each tag, check: is there at least one sender and one inferrer (non-receiver)? + send_keys = Set(["local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"]) + infer_keys = Set(["inferred", "sender+inferred"]) # sender+inferred means sender didn't need to recv + recv_keys = Set(["communicated", "receiver", "sender+communicated"]) + + asymmetries = [] + for (tag, ranks) in sort(collect(by_tag), by = first) + senders = [r for (r, d) in ranks if d in send_keys] + inferrers = [r for (r, d) in ranks if d in infer_keys || d == "uninvolved"] + receivers = [r for (r, d) in ranks if d in recv_keys] + # Asymmetry: someone sends (bcast) so will send to ALL other ranks; someone chose infer and won't recv. + if !isempty(senders) && !isempty(inferrers) + push!(asymmetries, (tag, senders, inferrers, receivers, ranks)) + end + end + + if isempty(asymmetries) + println("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + end + + println("=== Communication decision asymmetry (can cause deadlock) ===\n") + for (tag, senders, inferrers, receivers, ranks) in asymmetries + println("Tag $tag:") + println(" Senders (will bcast to all others): $senders") + println(" Inferrers (did not recv): $inferrers") + println(" Receivers: $receivers") + println(" All decisions: $ranks") + println() + end +end + +main() diff --git a/benchmarks/check_comm_asymmetry.py b/benchmarks/check_comm_asymmetry.py new file mode 100644 index 000000000..31a117442 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Parse MPI+Dagger logs and report communication decision asymmetry per tag. +Asymmetry: for the same tag, one rank decides to send (local+bcast, etc.) +and another decides to infer (inferred) and never recv → deadlock. + +Usage: + # Capture full log (all ranks' Core.println from mpi.jl go to stdout): + mpiexec -n 10 julia --project=/path/to/Dagger.jl benchmarks/run_matmul.jl 2>&1 | tee matmul.log + # Then look for asymmetry (same tag: one rank sends, another infers → deadlock): + python3 check_comm_asymmetry.py < matmul.log +""" + +import re +import sys +from collections import defaultdict + +SEND_DECISIONS = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} +RECV_DECISIONS = {"communicated", "receiver", "sender+communicated"} +INFER_DECISIONS = {"inferred", "uninvolved", "sender+inferred"} + + +def parse_line(line: str): + rank = tag = category = decision = None + m = re.search(r"\[rank\s+(\d+)\]", line) + if m: + rank = int(m.group(1)) + m = re.search(r"\[tag\s+(\d+)\]", line) + if m: + tag = int(m.group(1)) + m = re.search(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + if m: + category = m.group(1) + # Capture decision from [...] blocks + for m in re.finditer(r"\]\[([^\]]+)\]", line): + candidate = m.group(1) + if "inferred" in candidate and "communicated" not in candidate: + decision = "inferred" + break + if "communicated" in candidate: + decision = "communicated" + break + if "local+bcast" in candidate: + decision = "local+bcast" + break + if candidate.startswith("sender+"): + decision = "sender+inferred" if "inferred" in candidate else "sender+communicated" + break + if candidate == "receiver": + decision = "receiver" + break + if candidate == "receiver+bcast": + decision = "receiver+bcast" + break + if candidate == "inplace_move": + decision = "inplace_move" + break + return rank, tag, category, decision + + +def main(): + by_tag = defaultdict(dict) # tag -> {rank: decision} + for line in sys.stdin: + rank, tag, category, decision = parse_line(line) + if rank is None or tag is None or decision is None: + continue + by_tag[tag][rank] = decision + + send_keys = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} + infer_keys = {"inferred", "sender+inferred", "uninvolved"} + recv_keys = {"communicated", "receiver", "sender+communicated"} + + asymmetries = [] + for tag in sorted(by_tag.keys()): + ranks = by_tag[tag] + senders = [r for r, d in ranks.items() if d in send_keys] + inferrers = [r for r, d in ranks.items() if d in infer_keys] + receivers = [r for r, d in ranks.items() if d in recv_keys] + if senders and inferrers: + asymmetries.append((tag, senders, inferrers, receivers, ranks)) + + if not asymmetries: + print("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + + print("=== Communication decision asymmetry (can cause deadlock) ===\n") + for tag, senders, inferrers, receivers, ranks in asymmetries: + print(f"Tag {tag}:") + print(f" Senders (will bcast to all others): {senders}") + print(f" Inferrers (did not recv): {inferrers}") + print(f" Receivers: {receivers}") + print(f" All decisions: {dict(ranks)}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_distribute_fetch.jl b/benchmarks/run_distribute_fetch.jl new file mode 100644 index 000000000..822e1ad2c --- /dev/null +++ b/benchmarks/run_distribute_fetch.jl @@ -0,0 +1,42 @@ +#!/usr/bin/env julia +# Create a matrix with a fixed reproducible pattern, distribute it with an +# MPI procgrid, then on each rank fetch and println the chunk(s) it owns. +# Usage (from repo root, use full path to Dagger.jl): +# mpiexec -n 4 julia --project=/path/to/Dagger.jl benchmarks/run_distribute_fetch.jl + +using MPI +using Dagger + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) + +# Fixed reproducible pattern: 6×6 matrix, M[i,j] = 10*i + j (same on all ranks) +const N = 6 +const BLOCK = 2 +A = [10 * i + j for i in 1:N, j in 1:N] + +# Procgrid: use Dagger's compatible processors so the procgrid passes validation +availprocs = collect(Dagger.compatible_processors()) +nblocks = (cld(N, BLOCK), cld(N, BLOCK)) +procgrid = reshape( + [availprocs[mod(i - 1, length(availprocs)) + 1] for i in 1:prod(nblocks)], + nblocks, +) + +# Distribute so chunk (i,j) is computed on procgrid[i,j] +D = distribute(A, Blocks(BLOCK, BLOCK), procgrid) +D_fetched = fetch(D) + +# On each rank: fetch and print only the chunk(s) this rank owns +for (idx, ch) in enumerate(D_fetched.chunks) + if ch isa Dagger.Chunk && ch.handle isa Dagger.MPIRef && ch.handle.rank == rank + data = fetch(ch) + println("rank $rank chunk $idx: ", data) + end +end diff --git a/benchmarks/run_matmul.jl b/benchmarks/run_matmul.jl new file mode 100644 index 000000000..0eb4ec0d7 --- /dev/null +++ b/benchmarks/run_matmul.jl @@ -0,0 +1,105 @@ +#!/usr/bin/env julia +# N×N matmul benchmark (Float32); block size scales with number of ranks. +# Usage (use the full path to Dagger.jl, not "..."): +# mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl +# Set CHECK_CORRECTNESS=true to collect and compare against GPU baseline: +# CHECK_CORRECTNESS=true mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl + +using MPI +using Dagger +using LinearAlgebra + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const N = 2_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +const CHECK_CORRECTNESS = parse(Bool, get(ENV, "CHECK_CORRECTNESS", "false")) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (matmul)") +end + +# Allocate and fill matrices in blocks (Float32) +A = rand(Blocks(BLOCK, BLOCK), Float32, N, N) +B = rand(Blocks(BLOCK, BLOCK), Float32, N, N) + +# Matrix multiply C = A * B +t_matmul = @elapsed begin + C = A * B +end + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + +# Optional: collect via datadeps (root=0). All ranks participate in the datadeps region. +if CHECK_CORRECTNESS + t_collect = @elapsed begin + A_full = Dagger.collect_datadeps(A; root=0) + B_full = Dagger.collect_datadeps(B; root=0) + C_dagger = Dagger.collect_datadeps(C; root=0) + end + if rank == 0 + println("Collecting result and computing baseline for correctness check (GPU)...") + using CUDA + CUDA.functional() || error("CUDA not functional; cannot compute GPU baseline. Check CUDA driver and device.") + t_upload = @elapsed begin + A_g = CUDA.cu(A_full) + B_g = CUDA.cu(B_full) + end + println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s") + + t_baseline = @elapsed begin + C_ref_g = A_g * B_g + end + println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s") + + # Require all elements within 100× machine epsilon relative error (componentwise) + C_dagger_cpu = C_dagger + C_ref_cpu = Array(C_ref_g) + eps_f = eps(Float32) + rtol = 50.0f0 * eps_f + diff = C_dagger_cpu .- C_ref_cpu + # rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero + denom = max.(abs.(C_ref_cpu), eps_f) + rel_err = abs.(diff) ./ denom + max_rel_err = Float32(maximum(rel_err)) + ok = max_rel_err <= rtol + if ok + println("Correctness: OK (max rel_err = ", max_rel_err, " <= 100×eps = ", rtol, ")") + else + println("Correctness: FAIL (max rel_err = ", max_rel_err, " > 100×eps = ", rtol, ")") + end + + # Per-block: which blocks have any element with rel_err > 100×eps + n_bi = ceil(Int, N / BLOCK) + n_bj = ceil(Int, N / BLOCK) + bad_blocks = Tuple{Int,Int,Float32}[] + for bi in 1:n_bi, bj in 1:n_bj + ri = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N) + rj = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N) + block_rel = Float32(maximum(@view(rel_err[ri, rj]))) + if block_rel > rtol + push!(bad_blocks, (bi, bj, block_rel)) + end + end + if isempty(bad_blocks) + println("Per-block: all ", n_bi * n_bj, " blocks within 100×eps rel_err.") + else + println("Per-block: ", length(bad_blocks), " block(s) exceed 100×eps rel_err (block size ", BLOCK, "×", BLOCK, "):") + sort!(bad_blocks; by = x -> -x[3]) + for (bi, bj, block_rel) in bad_blocks + println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N), + ", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " max rel_err = ", block_rel) + end + end + end +end diff --git a/benchmarks/run_qr.jl b/benchmarks/run_qr.jl new file mode 100644 index 000000000..c5915db2a --- /dev/null +++ b/benchmarks/run_qr.jl @@ -0,0 +1,46 @@ +#!/usr/bin/env julia +# 10k×10k QR + matmul benchmark; block size scales with number of ranks. +# Usage: mpiexec -n 100 julia --project=/path/to/Dagger.jl benchmarks/bench_100rank_qr_matmul.jl +# Or: bash benchmarks/run_100rank_qr_matmul.sh . + +using MPI +using Dagger +using LinearAlgebra + +Dagger.accelerate!(:mpi) + +const N = 10_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (QR + matmul)") +end + +# Allocate and fill 10k×10k matrix in 1k×1k blocks +A = rand(Blocks(BLOCK, BLOCK), Float64, N, N) +MPI.Barrier(comm) + +# QR factorization (computing Q runs the full factorization) +t_qr = @elapsed begin + qr!(A) +end +MPI.Barrier(comm) + +if rank == 0 + println("QR time: ", round(t_qr; digits=4), " s") +end + +# Matrix multiply A * A +t_matmul = @elapsed begin + C = A * A +end +MPI.Barrier(comm) + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + diff --git a/lib/MPIRPC/.gitignore b/lib/MPIRPC/.gitignore new file mode 100644 index 000000000..ba39cc531 --- /dev/null +++ b/lib/MPIRPC/.gitignore @@ -0,0 +1 @@ +Manifest.toml diff --git a/lib/MPIRPC/Project.toml b/lib/MPIRPC/Project.toml new file mode 100644 index 000000000..4c5af20d2 --- /dev/null +++ b/lib/MPIRPC/Project.toml @@ -0,0 +1,19 @@ +name = "MPIRPC" +uuid = "a8caf107-0824-430d-bb41-9c1c9e5c5a9f" +version = "0.1.0" +authors = ["MPIRPC contributors"] + +[deps] +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[compat] +MPI = "0.20" +julia = "1.9" + +[extras] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test", "Random"] diff --git a/lib/MPIRPC/README.md b/lib/MPIRPC/README.md new file mode 100644 index 000000000..43793f61e --- /dev/null +++ b/lib/MPIRPC/README.md @@ -0,0 +1,72 @@ +# MPIRPC.jl + +Standalone MPI-backed RPC for Julia, modeled after the `Distributed` stdlib remote +invocation pipeline (header + serialized body + boundary, `invokelatest` on the +handler path, `RemoteException`-style error wrapping) and Dagger's +"accelerate-once" backend selection (a single backend value held in a task-local +slot determines uniform vs. non-uniform semantics; the public API is the same in +both modes). + +* Public API mirrors `Distributed`: `remotecall`, `remotecall_fetch`, + `remotecall_wait`, `remote_do`, `bcast_remotecall`, `fetch`, `wait`, with an `MPIFuture` handle. +* Two backends: + * `UniformMPIRPCBackend`: SPMD, every rank both issues and services RPC, + every rank calls `rpc_progress!`. + * `NonUniformMPIRPCBackend`: explicit listener / client roles, optional + `MPI.Comm_split` to isolate RPC from world collectives, only listeners must + poll `rpc_progress!`. +* Transport: `MPI.jl` only. No TCP, no `Distributed`, no `Dagger` dependency. +* Multi-threaded out of the box: handlers dispatch on `Threads.@spawn`, + so user closures that themselves spawn tasks and `fetch` them do not + deadlock under `julia --threads=N`. Trade-off: handler execution is + not FIFO across messages from the same source — use `remotecall_wait` + (or `remotecall` + `fetch`) when you need a side effect committed + before the next call. See `docs/ARCHITECTURE.md` §6. +* Optional progress daemon: pass `daemon = true` to either backend + constructor and `select_mpi_rpc_backend!` will spawn a yield-only + background task that drives `rpc_progress!` for you. The daemon is + placed on Julia's `:interactive` threadpool (start Julia with + `-t N,M`, `M >= 1`), so CPU-bound user code on the `:default` pool + cannot starve the wire pump. Removes the "every rank must call + `rpc_progress!`" requirement at the cost of ≈ 100% utilization of + one OS thread (the daemon never sleeps). `shutdown!` joins the daemon + cleanly. See `docs/ARCHITECTURE.md` §6. +* Idle waits cost zero CPU under `daemon = true`: `wait` / `fetch` on + an `MPIFuture` park on a `Threads.Condition` and are woken by + `deliver!` from the daemon's reply-dispatch path. Without the daemon + (`daemon = false`), `wait` falls back to a yield-rate spin because + the calling task is the only available progress driver. +* Minimum Julia: `1.9` (stable `task_local_storage`, `Base.invokelatest`). + +See [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) for the design rationale, the +deadlock and ABBA discussion, and the side-by-side mapping to `Distributed` and +to Dagger's `accelerate!`. + +## Quick start + +```julia +using MPI, MPIRPC + +MPI.Init(; threadlevel=:multiple) +MPIRPC.select_mpi_rpc_backend!(MPIRPC.UniformMPIRPCBackend(MPI.COMM_WORLD)) + +rank = MPI.Comm_rank(MPI.COMM_WORLD) +nprocs = MPI.Comm_size(MPI.COMM_WORLD) +peer = mod(rank + 1, nprocs) + +result = MPIRPC.@with_progress MPIRPC.remotecall_fetch(+, peer, rank, 100) +@assert result == rank + 100 # the args (rank, 100) travel to `peer` which sums them + +# Fan out the same call to every other rank; collect later: +futs = MPIRPC.bcast_remotecall(MPIRPC.current_mpi_rpc_backend(), *, rank, 7) +vals = map(MPIRPC.fetch, futs) # one entry per destination rank (ascending), excluding `rank` + +MPIRPC.shutdown!() +MPI.Finalize() +``` + +## Security + +Like `Distributed`, MPIRPC deserializes function objects and arguments sent by +peers. The same trust model applies: only run MPIRPC across mutually trusted +ranks. diff --git a/lib/MPIRPC/docs/ARCHITECTURE.md b/lib/MPIRPC/docs/ARCHITECTURE.md new file mode 100644 index 000000000..40de7068c --- /dev/null +++ b/lib/MPIRPC/docs/ARCHITECTURE.md @@ -0,0 +1,500 @@ +# MPIRPC.jl architecture + +MPIRPC is a standalone, MPI-only RPC for Julia, modeled on two existing +designs: + +* the **`Distributed` stdlib remote-call pipeline** — header + serialized body + + boundary, `invokelatest` on the handler path, `RemoteException`-style + error wrapping; and +* **Dagger.jl's "accelerate-once" backend selection** — one installed value + determines the runtime semantics; the public API is identical regardless + of mode. + +There is no dependency on `Distributed` or `Dagger`. The transport is +[`MPI.jl`](https://github.com/JuliaParallel/MPI.jl); `Serialization` is the +only other (stdlib) dependency. + +## 1. Module layout + +| File | Role | +|------|------| +| [`src/MPIRPC.jl`](../src/MPIRPC.jl) | Module entry, exports. | +| [`src/protocol.jl`](../src/protocol.jl) | `MPIRRID`, `MsgHeader`, `CallMsg`, `CallWaitMsg`, `RemoteDoMsg`, `ResultMsg`, `RPCProgressHaltMsg`, frame encode/decode. | +| [`src/exceptions.jl`](../src/exceptions.jl) | `MPIRemoteException` and `run_work_thunk`. | +| [`src/config.jl`](../src/config.jl) | `AbstractMPIRPCBackend`, process-global + task-local installation, `select_mpi_rpc_backend!` / `with_mpi_rpc_backend` / `current_mpi_rpc_backend`. | +| [`src/refs.jl`](../src/refs.jl) | `MPIFuture` and waiter delivery primitives. | +| [`src/uniform.jl`](../src/uniform.jl) | `UniformMPIRPCBackend` — SPMD: every rank issues and services. | +| [`src/nonuniform.jl`](../src/nonuniform.jl) | `NonUniformMPIRPCBackend` — listener/client split. | +| [`src/remotecall.jl`](../src/remotecall.jl) | Public surface: `remotecall`, `remotecall_fetch`, `remotecall_wait`, `remote_do`, `bcast_remotecall`, `wait`/`fetch`. | +| [`src/progress.jl`](../src/progress.jl) | `rpc_progress!`, `rpc_progress_halt!`, `rpc_barrier`, `serve_listener`. | + +## 2. Backend selection (Dagger parallel) + +The selection model is a near-exact port of Dagger's +[`acceleration.jl`](../../Dagger.jl/src/acceleration.jl): + +| Dagger | MPIRPC | +|--------|--------| +| `accelerate!(::Acceleration)` | `select_mpi_rpc_backend!(::AbstractMPIRPCBackend)` | +| `current_acceleration()` | `current_mpi_rpc_backend()` | +| `_with_default_acceleration(f)` | `with_mpi_rpc_backend(f, backend)` | +| `initialize_acceleration!(::Acceleration)` | `initialize_mpi_rpc!(::AbstractMPIRPCBackend)` | +| `Acceleration` | `AbstractMPIRPCBackend` | +| `DistributedAcceleration` (default no-op) | none — no backend is installed by default | +| `MPIAcceleration` | `UniformMPIRPCBackend`, `NonUniformMPIRPCBackend` | + +The installed backend lives in **two slots**: a process-global `Ref` (set +once by `select_mpi_rpc_backend!`) and a task-local override +(`task_local_storage(:mpi_rpc_backend)`, layered by `with_mpi_rpc_backend`). +`current_mpi_rpc_backend()` checks the task-local slot first, falling back +to the global. This matches what Dagger users see in practice — Dagger +declares the `TaskLocalValue` but in practice every `@spawn`-ed task sees +the value because the package exposes a global initializer for the slot — +while keeping our public surface dependency-free. + +**Re-init is not supported in v1.** Calling `select_mpi_rpc_backend!` +twice in the same process replaces the slot but does **not** finalize the +first backend's communicators or in-flight `Isend` buffers. To switch +modes mid-job, finalize MPI and start a fresh process. + +```mermaid +flowchart LR + Init["select_mpi_rpc_backend(b)"] --> InitHook["initialize_mpi_rpc(b)"] + InitHook --> SetGlobal["GLOBAL_BACKEND := b"] + SetGlobal --> Use["remotecall, remote_do, wait, ..."] + Use --> Lookup["current_mpi_rpc_backend()"] + Lookup --> CheckTLS{"task_local_storage hit?"} + CheckTLS -- yes --> ReturnTLS[return TLS value] + CheckTLS -- no --> ReturnGlobal[return GLOBAL_BACKEND] +``` + +## 3. Wire format + +Every RPC message — request or reply — is a single MPI byte payload of the +form: + +``` ++--------------------------+--------------------------+----------------+ +| serialized MsgHeader | serialized AbstractMsg | MSG_BOUNDARY | ++--------------------------+--------------------------+----------------+ +``` + +A fresh `Serialization.Serializer` is created per frame, so the +serializer's back-reference table is bounded to one frame. This is +deliberately simpler than `Distributed.ClusterSerializer`, which keeps +inter-message state for things like worker-ref caching. Full +`ClusterSerializer` parity (anonymous-function global-binding tracking, +shared object-number tables) is out of scope for v1 and is planned as an +opt-in serializer choice later. + +The `MSG_BOUNDARY` (10 bytes) is preserved from Distributed's design +purely as a fail-fast protocol-version / corruption check; MPI is +message-oriented so we do not need the boundary for stream +re-synchronization. + +### Message types + +| Type | Direction | Effect | +|------|-----------|--------| +| `CallMsg{:call}` / `CallMsg{:call_fetch}` | client → server | Server runs `f(args...; kwargs...)`, replies with `ResultMsg(value_or_exception)`. | +| `CallWaitMsg` | client → server | Server runs the call; replies with `:OK` or an `MPIRemoteException`. | +| `RemoteDoMsg` | client → server | Server runs the call, no reply sent (fire-and-forget). | +| `ResultMsg` | server → client | Carries the value; routed to a `MPIFuture` via `MsgHeader.response_oid`. | +| `RPCProgressHaltMsg` | any → rank on `request_tag` | Control: `rpc_progress!` stops draining further requests in the current pass (`return false`); empty `MsgHeader`. | + +### OID layout + +`MPIRRID(whence::Int32, id::UInt64)` is our analog of `Distributed.RRID`, +with `whence` storing the *MPI rank* (within the backend's communicator) +that allocated the id. Each `remotecall*` allocates a fresh `MPIRRID` on +the caller, registers a `MPIFuture` in the backend's waiter table, and +puts the id in `MsgHeader.notify_oid`. The server echoes that id back in +`MsgHeader.response_oid` of the `ResultMsg`; the client routes the reply +to the matching future. + +## 4. Tag layout (uniform & non-uniform) + +Two disjoint tags carry **all** RPC traffic: + +* `request_tag` (default `0xC0DE`): every `CallMsg` / `CallWaitMsg` / + `RemoteDoMsg`, and control [`RPCProgressHaltMsg`](../src/protocol.jl) + (see [`rpc_progress_halt!`](../src/progress.jl)) — same framing as RPC, + consumed inside `rpc_progress!` without spawning a handler. +* `reply_tag` (default `0xC0DF`): every `ResultMsg`. + +Tags **do not encode call identity**. Concurrent calls between the same +pair of ranks are distinguished by `MPIRRID` in the header, not by tag. +This is the most important deadlock-avoidance choice in MPIRPC; see +section 5. + +```mermaid +sequenceDiagram + participant Client + participant ClientMPI as Client MPI + participant ServerMPI as Server MPI + participant Server + Client->>Client: encode_frame MsgHeader plus CallMsg plus boundary + Client->>ClientMPI: Isend buf request_tag dest server + ClientMPI->>ServerMPI: matched on request_tag + Server->>Server: rpc_progress Improbe Mrecv decode_frame + Server->>Server: invokelatest f args + Server->>ServerMPI: Isend ResultMsg reply_tag dest client + ServerMPI->>ClientMPI: matched on reply_tag + Client->>Client: rpc_progress Improbe Mrecv decode_frame deliver MPIFuture +``` + +## 5. Deadlock avoidance and the ABBA argument + +MPIRPC deliberately decouples three concerns that often get conflated in +hand-rolled MPI-RPC code: + +1. **Direction** (request vs reply) — encoded in the *tag*. +2. **Call identity** (which call's reply is this?) — encoded in the + `MPIRRID` carried in the *header*. +3. **Order of completion** — never derived from tag matching; only from + waiter routing. + +This decoupling immediately rules out a class of ABBA bugs: + +* **Symmetric cross-calls.** If rank `i` issues a `remotecall_fetch` to + rank `j` while rank `j` simultaneously issues one to rank `i`, no + `(peer, tag)` pair is shared between request and reply traffic, so + request `Improbe`s never accidentally match a reply or vice versa. +* **Many concurrent calls between one pair.** Two `remotecall`s from `i` + to `j` get distinct `MPIRRID`s; the server's reply is routed by header + rather than position in the queue, so the client correctly resolves + futures even when fetched in a permuted order. +* **Wire-level FIFO.** MPI guarantees in-order delivery between a pair of + ranks on a single tag; we rely on this for *delivery* of nested calls, + e.g. a handler issuing further calls. **Handler execution is not + serialized** — see §6 (threading model) — so an application that needs + ordering of side effects across two messages from the same source + must use `remotecall_wait` (or `remotecall` + `fetch`), not bare + `remote_do` + `remotecall_fetch`. + +### Mesh-shutdown deadlock and `rpc_barrier` + +A subtler hazard appears specifically at *phase boundaries*: a rank `R` +whose own primary `remotecall_fetch` has just completed may still hold an +*unprocessed inbound request* sent by another rank from inside that +rank's handler. If `R` calls `MPI.Barrier` it stops pumping RPC +progress, and the peers still waiting on `R`'s replies hang. + +[`rpc_barrier`](../src/progress.jl) addresses this with a non-blocking +`MPI.Ibarrier` whose completion is awaited *while every rank pumps* +[`rpc_progress!`](../src/progress.jl). Use it between phases of an SPMD +program and before exiting an RPC session. The +[uniform mpiexec suite](../test/uniform_mpiexec.jl) regression-tests the +exact pattern that motivated this primitive (the "nested re-entrant +remotecall_fetch" testset). + +## 6. Threading model + +MPIRPC supports `julia --threads=N` end-to-end: MPI is initialized with +`threadlevel=:multiple`, every MPI call is wrapped in a per-backend +`mpi_lock`, the OID counter is `Threads.Atomic{UInt64}`, and the waiter +table is guarded by `waiters_lock`. The two key design choices: + +1. **Handlers run on `Threads.@spawn`.** When `rpc_progress!` matches a + request, it does not run the handler synchronously on the calling + task — it spawns a fresh task whose body wraps the dispatch in + `with_mpi_rpc_backend(backend) do ... end`. The progress pump + returns immediately. The user closure runs without holding any + MPIRPC lock. +2. **No global progress lock.** Multiple threads may call + `rpc_progress!` concurrently. The only serialization is the + per-backend `mpi_lock` around individual MPI calls (`Improbe`, + `Mrecv!`, `Isend`, `Test`, `Ibarrier`); state transitions on the + waiter table use the separate `waiters_lock`. + +Together these eliminate the deadlock pattern where a handler calls +`Threads.@spawn t; fetch(t)` and `t` itself does an MPIRPC call. With a +synchronous-handler model that held a progress lock, the spawned task +could not acquire the progress lock from another OS thread; the calling +task held it while blocked on `fetch(t)`. With handlers spawned, no lock +is held during user code, and `t` is free to drive progress on its own +task. The +[`uniform / handler that Threads.@spawn-then-fetches an RPC`](../test/uniform_mpiexec.jl) +regression test locks this in. + +### Trade-off: handler execution is not FIFO + +Distributed.jl serializes message dispatch through a single per-worker +reader task; this gives the property that two messages from the same +source on the same `(src, dest, tag)` are *executed* in arrival order. +MPIRPC under multi-threading does not provide that guarantee — it +preserves only **MPI delivery FIFO**. Two requests from rank `R` to rank +`P` may be matched and dispatched on different threads of `P`, in which +case the two handlers run concurrently and the second's side effects +may become visible before the first's. + +The right primitive for "the next call must observe the previous call's +side effect" is `remotecall_wait`: + +```julia +# Wrong under multi-threading: the fetch handler may dispatch on +# another thread *before* the remote_do handler's @eval commits. +remote_do(peer, ...) do; @eval Main.X = 42 end +val = remotecall_fetch(peer) do; Main.X end + +# Right: remotecall_wait blocks the caller until the remote handler +# acknowledges completion, so the side effect is committed before the +# next message goes on the wire. +remotecall_wait(peer, ...) do; @eval Main.X = 42 end +val = remotecall_fetch(peer) do; Main.X end +``` + +The `set-then-read with remotecall_wait` testset in both mpiexec suites +documents this idiom. + +### Hazards that remain (user-error class) + +* `rpc_barrier` (and any MPI collective) must be called from at most one + thread per rank — calling it concurrently posts multiple `Ibarrier` + requests, which is incorrect by MPI semantics. +* User-level non-reentrant locks held across `remotecall_fetch` can + self-deadlock if a remote handler running on the same task re-enters + them. Use `ReentrantLock`, or do not take user locks from inside + remote handlers. +* `wait(f::MPIFuture)` *under `daemon = false`* is a busy-pumping spin: + every waiter calls `rpc_progress!` in a loop with `yield()` between + passes. This is unavoidable in the no-daemon configuration because + `wait` itself is the only progress driver — parking the caller would + also park the wire. *Under `daemon = true`* the spin is replaced by a + proper `Threads.Condition`-park (see "Cond-park under `daemon = true`" + below), so this hazard only applies to applications that opted out of + the daemon. + +### Optional progress daemon + +By default, every rank that needs to receive RPC must call +`rpc_progress!` (or block in `wait`/`fetch`/`rpc_barrier`/`serve_listener`, +each of which pumps progress internally). If that requirement is +inconvenient — for example, because the application's main loop is +already complex, or because a listener rank has long stretches of pure +computation — pass `daemon = true` to the backend constructor: + +```julia +backend = select_mpi_rpc_backend!( + UniformMPIRPCBackend(MPI.COMM_WORLD; daemon = true)) +``` + +`select_mpi_rpc_backend!` then `Threads.@spawn`s a yield-only loop that +calls `rpc_progress!` in tight rotation until `shutdown!` flips +`backend.running[]` to `false`. `shutdown!` then `wait`s on the daemon +task before returning, so the caller can rely on no further MPI calls +being issued from the daemon after `shutdown!` returns. + +The daemon never sleeps. This is a deliberate choice: a sleeping daemon +trades CPU for tail latency on inbound requests, and the same trade-off +is already available without a daemon (just don't enable it, and pump +progress on your own schedule). Enabling the daemon is a "consume one +OS thread, never wait on the wire" decision; if that is the wrong +trade-off for your workload, leave `daemon = false`. + +#### Threadpool isolation + +The daemon is spawned on Julia's `:interactive` threadpool when at +least one interactive thread is configured (`julia -t N,M` with +`M >= 1`). The interactive pool exists exactly for latency-sensitive +tasks that must not be starved by `:default`-pool work. Concretely, +this means: a CPU-bound user computation on a default-pool thread — +e.g. a tight numerical loop with no yield points, or a long `ccall` +to a synchronous C library — **cannot** prevent the daemon from +servicing inbound RPC. The `daemon survives CPU-bound default-pool +work` testset in +[`test/uniform_daemon_mpiexec.jl`](../test/uniform_daemon_mpiexec.jl) +locks this in: it saturates every default-pool thread with a +no-yield busy loop and asserts that a peer's `remotecall_fetch` +*to this rank* still completes promptly. + +Handlers (`_run_handler_task`) deliberately stay on the `:default` +pool: handlers run user code, which is the workload the user wants +done; placing handlers on `:interactive` would steal latency budget +from whoever else uses that pool (the REPL, GC threads, etc.). + +If `Threads.nthreads(:interactive) == 0` at `select_mpi_rpc_backend!` +time, the daemon falls back to `:default` and emits a one-shot `@info` +message naming the consequence. Existing application behavior is +preserved; isolation is opt-in via the launch flag. + +Two regression suites exercise this path: +[`test/uniform_daemon_mpiexec.jl`](../test/uniform_daemon_mpiexec.jl) +and +[`test/nonuniform_daemon_mpiexec.jl`](../test/nonuniform_daemon_mpiexec.jl). +Neither calls `rpc_progress!` or `serve_listener` from user code, so +they are red unless the daemon is doing all of the inbound draining. + +### Cond-park under `daemon = true` + +`MPIFuture` carries a `Threads.Condition`. `deliver!` (called from +`_dispatch_reply!` after `take_waiter!` removes the future from the +waiter table) acquires the cond, stores the value atomically with +`@atomic :release`, and `notify(cond, all=true)` so any number of +waiters on the same future wake at once. + +`wait(::MPIFuture)` reads `f.backend.daemon` and chooses one of two +paths: + +```julia +function Base.wait(f::MPIFuture) + isready(f) && return f + backend = f.backend::AbstractMPIRPCBackend + if backend.daemon + @lock f.cond begin # check-park-recheck idiom + while !isready(f) + wait(f.cond) # task fully descheduled + end + end + else + while !isready(f) # v1 spin: this task is the + rpc_progress!(backend) # only progress driver, so + isready(f) && break # parking would deadlock + yield() + end + end + return f +end +``` + +The check is on `f.backend.daemon`, not the currently-installed +backend, so a future created under a daemon-backed backend is still +park-cheap to wait on even from a task that has scoped a different +backend via `with_mpi_rpc_backend`. + +**Lock geometry.** Three locks coexist on the reply path: + +1. `backend.mpi_lock` — wraps the raw MPI calls in `_try_recv_one!` + (`Improbe`, `Imrecv!`, `Test`). It is *released* between `Test` + polls so the daemon's wait for a slow rendezvous receive does not + block other threads' MPI calls. +2. `backend.waiters_lock` — wraps `take_waiter!`, the lookup that + removes the future from the waiter table. +3. `f.cond` — wraps the value store and `notify` in `deliver!`, and + the predicate check / `wait` in the consumer. + +These are acquired strictly sequentially in `_dispatch_reply!`: +`mpi_lock` is released before `waiters_lock` is taken (the lookup +happens after the buffer is fully out of MPI), and `waiters_lock` is +released before `f.cond` is taken (`take_waiter!` returns the future, +then `deliver!` runs). At no point are two of them held at the same +time, so there is no cross-future or cross-rank ordering hazard. + +**Wakeup atomicity.** `deliver!` and the consumer use the standard +condition-variable idiom: predicate (`isready`) is read inside the +lock, and `notify` happens *after* the predicate is set, *while +holding the same lock*. This prevents the lost-wakeup race where the +consumer's `isready` check observes `false` and parks just after the +producer's `notify` fired but before the producer's store became +visible. + +### Non-blocking receive (`Imrecv!` + `Test`/`yield`) + +`_try_recv_one!` does not call `MPI.Mrecv!`. The blocking +matched-receive primitive would have held `backend.mpi_lock` for the +entire duration of the message transfer, which for a large +rendezvous-protocol payload (multi-MB) can be milliseconds — during +which no other thread on the rank could acquire `mpi_lock` to do its +own MPI call (e.g. a handler trying to `Isend` a reply for a +different RPC). + +Instead, `_try_recv_one!` runs in two phases: + +1. **Match-and-post under `mpi_lock`.** `Improbe` matches the next + message on the requested tag, the buffer is allocated to the + matched count, and `Imrecv!` is posted. All three steps are + atomic — they have to be, because the matched message handle is + only valid until consumed by `Imrecv!`. + +2. **Poll-and-yield without `mpi_lock`.** The non-blocking request + from `Imrecv!` is then `Test`'d in a loop; `mpi_lock` is acquired + only for the `Test` call itself (microseconds) and released before + `yield()`. While the daemon's task is yielded, other threads can + freely acquire `mpi_lock` and progress their own MPI work. + +For typical RPC payloads under MPI's eager threshold (often 64 KiB on +OpenMPI, 256 KiB on MPICH), `Test` succeeds on the first poll because +the eager protocol delivered the entire payload during `Improbe`'s +match; the cost vs. blocking `Mrecv!` is one extra `Test` call and +one lock acquire/release pair, single-digit microseconds. For large +payloads on rendezvous, the daemon thread becomes truly cooperative +during the receive — yielding repeatedly until the transfer +completes — and other threads on the rank can issue their own MPI +calls in the gaps. + +Combined with the threadpool isolation (daemon on `:interactive`), +this means the only places where a thread spends real wall time +inside MPI without yielding are `MPI.Init`, `MPI.Comm_dup`, and +`MPI.Finalize`, all of which are one-time startup or shutdown +events. **At runtime, no rank-level MPI call holds a thread without +yield.** + +The `large concurrent payloads via Imrecv! + Test/yield` testset in +[`test/uniform_daemon_mpiexec.jl`](../test/uniform_daemon_mpiexec.jl) +exercises this path with multiple concurrent 1-MiB round-trips that +deliberately exceed common eager thresholds. + +## 7. Uniform vs non-uniform — operational rules + +| | Uniform | Non-uniform | +|---|---|---| +| Who initiates RPC? | Any rank | Any rank | +| Who services RPC? | All ranks | Only `listener_ranks` | +| Who must call `rpc_progress!`? | Every rank, regularly *(or none, if `daemon = true`)* | Listeners (for requests) and any rank with outstanding `MPIFuture`s (for replies). `wait`/`fetch` already pumps. *(Or no rank, if `daemon = true`.)* | +| `dest_rank` validity | Any peer in `comm` | Must be in `listener_ranks` | +| Subcomm by default | Yes (`MPI.Comm_dup`) | Yes (`MPI.Comm_dup`) | +| Phase barrier | `rpc_barrier(backend)` | `rpc_barrier(backend)` | +| World collectives | Run on `MPI.COMM_WORLD` independently | Same — the duped comm isolates RPC from world traffic | + +### Collective safety in non-uniform mode + +Because the backend duplicates the user's communicator, world-level +collectives (e.g. `MPI.Barrier(MPI.COMM_WORLD)`) cannot be matched against +RPC `Isend`s. Even so, **all ranks of `COMM_WORLD` must still participate +in any world collective**. The +[non-uniform mpiexec suite](../test/nonuniform_mpiexec.jl) explicitly +exercises a `MPI.Barrier(MPI.COMM_WORLD)` interleaved with subcomm RPC to +make sure the test does not encode a collective mismatch as "pass". + +## 8. World-age and `invokelatest` + +Where `Distributed/process_messages.jl` wraps the body deserializer in +`invokelatest(deserialize_msg, ...)` and the handler in +`invokelatest(msg.f, ...)`, MPIRPC does the same: see `decode_frame` in +[`protocol.jl`](../src/protocol.jl) and `_execute_request!` in +[`uniform.jl`](../src/uniform.jl). This lets remote ranks call functions +defined after their own world age has advanced (e.g. user code defined +between two RPC phases). + +## 9. Mapping to Distributed (cheat sheet) + +| Distributed | MPIRPC | +|---|---| +| `Future` | `MPIFuture` | +| `RRID(whence::Int, id::Int)` | `MPIRRID(whence::Int32, id::UInt64)` | +| `MsgHeader(response_oid, notify_oid)` | identical | +| `CallMsg{:call}`, `CallMsg{:call_fetch}`, `CallWaitMsg`, `RemoteDoMsg`, `ResultMsg` | identical names and roles | +| `MSG_BOUNDARY` (10 bytes) | `MSG_BOUNDARY` (10 bytes; different sentinel) | +| `ClusterSerializer` over a TCP stream | `Serializer` over a per-message `IOBuffer` | +| `invokelatest(deserialize_msg, ...)` | `invokelatest(deserialize, ...)` in `decode_frame` | +| `invokelatest(msg.f, msg.args...; ...)` in `handle_msg` | identical idiom in `_execute_request!` | +| `RemoteException` wrapping `CapturedException` | `MPIRemoteException` wrapping `CapturedException` | +| `remotecall` / `remotecall_fetch` / `remotecall_wait` / `remote_do` / `fetch` / `wait` | same names | + +## 10. Security + +Same model as `Distributed`: deserializing function objects and arguments +sent by peers is equivalent to letting them run code. Only run MPIRPC +across mutually trusted ranks. + +## 11. What is intentionally not in v1 + +* No `ClusterSerializer`-equivalent global-binding propagation. +* No `RemoteChannel` (server-resident channel) — `MPIFuture` is the only + reference type. +* No FIFO of *handler execution* across messages (only MPI delivery + FIFO). See §6. +* No mid-process backend re-init. +* No transport other than `MPI.jl`. +* No retry / fault tolerance — a process death is fatal to outstanding + futures targeting that rank, just as in `Distributed`. diff --git a/lib/MPIRPC/examples/nonuniform_driver.jl b/lib/MPIRPC/examples/nonuniform_driver.jl new file mode 100644 index 000000000..d7777ddd5 --- /dev/null +++ b/lib/MPIRPC/examples/nonuniform_driver.jl @@ -0,0 +1,63 @@ +# Non-uniform driver: half the ranks are dedicated listeners ("workers"), +# the rest are clients that submit work and collect results. +# +# Run with: +# mpiexec -n 4 julia --project=. examples/nonuniform_driver.jl + +using MPI +using MPIRPC + +MPI.Init(; threadlevel=:multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 2 || error("non-uniform example needs at least 2 ranks") + +# Half the ranks are listeners. Adjust to your topology. +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; listener_ranks = LISTENER_RANKS)) + +if MPIRPC.is_listener(backend) + println("[rank $WORLD_RANK] LISTENER; serving until shutdown") + flush(stdout) + + # Listener loop: pump progress, exit when a client tells us to stop. + # `rpc_progress!` itself does not block; this loop is the listener's + # main service loop. In production code, interleave `rpc_progress!` + # with whatever else the listener does. + while backend.running[] + MPIRPC.rpc_progress!(backend) + yield() + end + println("[rank $WORLD_RANK] LISTENER exit") +else + println("[rank $WORLD_RANK] CLIENT; submitting calls to $LISTENER_RANKS") + flush(stdout) + + # Submit calls to every listener. + K = 8 + futs = MPIFuture[] + for ℓ in LISTENER_RANKS, k in 1:K + push!(futs, MPIRPC.remotecall(+, ℓ, WORLD_RANK * 100, k)) + end + for f in futs + v = fetch(f) + println("[rank $WORLD_RANK] got $v from listener (expected WORLD_RANK*100 + k)") + end + + # Tell every listener to wind down. `remote_do` is fire-and-forget; the + # listeners flip their own `running` flag and exit their service loop. + for ℓ in LISTENER_RANKS + MPIRPC.remote_do(MPIRPC.shutdown!, ℓ) + end +end + +# Both roles must reach the same barrier. `rpc_barrier` is preferred over +# `MPI.Barrier` so any in-flight RPC drains before the program exits. +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/examples/rpc_matmul_nonuniform.jl b/lib/MPIRPC/examples/rpc_matmul_nonuniform.jl new file mode 100644 index 000000000..2a1ed0370 --- /dev/null +++ b/lib/MPIRPC/examples/rpc_matmul_nonuniform.jl @@ -0,0 +1,131 @@ +# RPC-based matrix multiplication (non-uniform: listeners own B, clients own A) +# +# First half of ranks are **listeners** (service RPC); the rest are **clients** +# (issue RPC only). Each listener owns one contiguous column strip of `B`; +# each client owns one contiguous row block of `A`. Clients fetch every column +# strip from the listener ranks via `remotecall_fetch` — clients never call +# each other (matches `NonUniformMPIRPCBackend` constraints). +# +# Run from the MPIRPC package root under lib/MPIRPC (needs at least 4 ranks, same as package +# non-uniform tests): +# mpiexec -n 4 julia --threads=2,1 --project=. examples/rpc_matmul_nonuniform.jl +# +# Optional: `N=128` env sets matrix dimension (must satisfy divisibility below). +# +# We enable `daemon = true` so listeners continuously drain inbound RPC without +# a dedicated manual `serve_listener` loop while clients issue many fetches. + +using MPI +using MPIRPC + +const MAX_VERIFY_N = 256 + +a_elem(i::Int, j::Int) = sin(i) + cos(j) +b_elem(i::Int, j::Int) = tanh(i * 0.01) + tanh(j * 0.01) + +const _RPC_B_STRIP = Ref{Matrix{Float64}}() + +MPI.Init(; threadlevel = :multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 4 || error("non-uniform matmul example needs at least 4 ranks (got $NPROC)") + +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) +const CLIENT_RANKS = collect(HALF:(NPROC - 1)) +const IS_LISTENER = WORLD_RANK in LISTENER_RANKS + +const N_LISTENERS = length(LISTENER_RANKS) +const N_CLIENTS = length(CLIENT_RANKS) + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; + listener_ranks = LISTENER_RANKS, + daemon = true)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +@assert RANK == WORLD_RANK + +const N = parse(Int, get(ENV, "N", "64")) +N % N_LISTENERS == 0 || error("N ($N) must be divisible by number of listeners ($N_LISTENERS)") +N % N_CLIENTS == 0 || error("N ($N) must be divisible by number of clients ($N_CLIENTS)") + +const N_LOC_COL = N ÷ N_LISTENERS # columns per listener strip +const N_LOC_ROW = N ÷ N_CLIENTS # rows per client block + +# --- Build owned pieces ------------------------------------------------------- + +if IS_LISTENER + ℓ = findfirst(==(WORLD_RANK), LISTENER_RANKS)::Int + COL_LO = (ℓ - 1) * N_LOC_COL + 1 + COL_HI = ℓ * N_LOC_COL + B_loc = Matrix{Float64}(undef, N, N_LOC_COL) + for lj in 1:N_LOC_COL + j_g = COL_LO + lj - 1 + for i in 1:N + B_loc[i, lj] = b_elem(i, j_g) + end + end + _RPC_B_STRIP[] = B_loc +else + c = findfirst(==(WORLD_RANK), CLIENT_RANKS)::Int + ROW_LO = (c - 1) * N_LOC_ROW + 1 + ROW_HI = c * N_LOC_ROW + A_loc = Matrix{Float64}(undef, N_LOC_ROW, N) + for li in 1:N_LOC_ROW + i_g = ROW_LO + li - 1 + for j in 1:N + A_loc[li, j] = a_elem(i_g, j) + end + end +end + +MPIRPC.rpc_barrier() + +# --- Clients multiply; listeners only service RPC (daemon drives progress) ----- + +if IS_LISTENER + C_loc = nothing +else + C_loc = zeros(N_LOC_ROW, N) + for ℓ in 1:N_LISTENERS + listener = LISTENER_RANKS[ℓ] + clo = (ℓ - 1) * N_LOC_COL + 1 + chi = ℓ * N_LOC_COL + B_panel = MPIRPC.remotecall_fetch(listener) do + return copy(Main._RPC_B_STRIP[]) + end + size(B_panel) == (N, N_LOC_COL) || error("bad B_panel size from listener $listener") + C_loc[:, clo:chi] = A_loc * B_panel + end +end + +MPIRPC.rpc_barrier() + +# --- Verification on first client only (small N) ------------------------------ + +VERIFY_RANK = first(CLIENT_RANKS) +if WORLD_RANK == VERIFY_RANK && N ≤ MAX_VERIFY_N && !IS_LISTENER + c = findfirst(==(WORLD_RANK), CLIENT_RANKS)::Int + ROW_LO = (c - 1) * N_LOC_ROW + 1 + ROW_HI = c * N_LOC_ROW + A_full = Matrix{Float64}(undef, N, N) + B_full = Matrix{Float64}(undef, N, N) + for j in 1:N, i in 1:N + A_full[i, j] = a_elem(i, j) + B_full[i, j] = b_elem(i, j) + end + C_ref = A_full * B_full + ref_rows = C_ref[ROW_LO:ROW_HI, :] + err = maximum(abs.(C_loc .- ref_rows)) + println("[client rank $WORLD_RANK] max |C_loc - C_ref| on owned rows = ", err) + @assert err < 1e-10 * max(1.0, maximum(abs.(C_ref))) +elseif WORLD_RANK == VERIFY_RANK + println("[client rank $WORLD_RANK] skipping dense verification (N=$N > $MAX_VERIFY_N)") +end + +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/examples/rpc_matmul_uniform.jl b/lib/MPIRPC/examples/rpc_matmul_uniform.jl new file mode 100644 index 000000000..90d67a8f3 --- /dev/null +++ b/lib/MPIRPC/examples/rpc_matmul_uniform.jl @@ -0,0 +1,111 @@ +# RPC-based matrix multiplication (uniform SPMD) +# +# Each rank owns a row block of `A` and a column block of `B`. To form +# `C = A * B`, rank `r` needs every column strip of `B`; those strips live on +# distinct ranks, so we fetch each peer's strip via `remotecall_fetch`. No MPI +# collectives move `A` / `B` on the compute path — only point-to-point RPC. +# +# Run from the MPIRPC package root (e.g. Dagger.jl/lib/MPIRPC): +# mpiexec -n 4 julia --project=. examples/rpc_matmul_uniform.jl +# +# Optional threading (matches package tests; useful if you switch to +# `UniformMPIRPCBackend(...; daemon = true)` — daemon uses `:interactive`): +# mpiexec -n 4 julia --threads=2,1 --project=. examples/rpc_matmul_uniform.jl +# +# Limitations: +# * Matrix dimension `N` below must divide `NPROC`; keep `N` modest — each RPC +# returns an `n × (N/nproc)` Float64 strip (multi-MiB if `N` is huge). +# * Rank-0 verification builds dense `N×N` references and only runs when +# `N ≤ MAX_VERIFY_N` (memory + time). + +using MPI +using MPIRPC + +const MAX_VERIFY_N = 256 + +# Deterministic element generators (global indices i, j ∈ 1:N). +a_elem(i::Int, j::Int) = sin(i) + cos(j) +b_elem(i::Int, j::Int) = tanh(i * 0.01) + tanh(j * 0.01) + +# Filled before any RPC so the handler closure on the destination rank reads the +# correct strip via `Main._RPC_B_STRIP` (each MPI rank is its own process). +const _RPC_B_STRIP = Ref{Matrix{Float64}}() + +MPI.Init(; threadlevel = :multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +# Demo size: override with `N=128 julia ...` if desired (must divide world size). +const N = parse(Int, get(ENV, "N", "64")) +NPROC >= 1 || error("need at least 1 rank") +N % NPROC == 0 || error("N ($N) must be divisible by NPROC ($NPROC)") +const N_LOC = N ÷ NPROC + +# --- Local owned data --------------------------------------------------------- + +const ROW_LO = RANK * N_LOC + 1 +const ROW_HI = (RANK + 1) * N_LOC +const COL_LO = RANK * N_LOC + 1 +const COL_HI = (RANK + 1) * N_LOC + +A_loc = Matrix{Float64}(undef, N_LOC, N) +B_loc = Matrix{Float64}(undef, N, N_LOC) + +for li in 1:N_LOC + i_g = ROW_LO + li - 1 + for j in 1:N + A_loc[li, j] = a_elem(i_g, j) + end +end +for lj in 1:N_LOC + j_g = COL_LO + lj - 1 + for i in 1:N + B_loc[i, lj] = b_elem(i, j_g) + end +end + +_RPC_B_STRIP[] = B_loc + +MPIRPC.rpc_barrier() + +# --- Compute C_loc = A_loc * B via RPC-fetched column strips ----------------- + +C_loc = zeros(N_LOC, N) + +for q in 0:(NPROC - 1) + clo = q * N_LOC + 1 + chi = (q + 1) * N_LOC + # Closure runs on rank `q`; it captures nothing from rank `r` except what + # Julia serializes — here we only need rank q's Main._RPC_B_STRIP. + B_panel = MPIRPC.remotecall_fetch(q) do + return copy(Main._RPC_B_STRIP[]) + end + size(B_panel) == (N, N_LOC) || error("unexpected B_panel size on rank $RANK from $q") + C_loc[:, clo:chi] = A_loc * B_panel +end + +MPIRPC.rpc_barrier() + +# --- Verification on rank 0 (small N only) ------------------------------------ + +if RANK == 0 && N ≤ MAX_VERIFY_N + A_full = Matrix{Float64}(undef, N, N) + B_full = Matrix{Float64}(undef, N, N) + for j in 1:N, i in 1:N + A_full[i, j] = a_elem(i, j) + B_full[i, j] = b_elem(i, j) + end + C_ref = A_full * B_full + ref_rows = C_ref[ROW_LO:ROW_HI, :] + err = maximum(abs.(C_loc .- ref_rows)) + println("[rank 0] max |C_loc - C_ref| on owned rows = ", err) + @assert err < 1e-10 * max(1.0, maximum(abs.(C_ref))) +elseif RANK == 0 + println("[rank 0] skipping dense verification (N=$N > MAX_VERIFY_N=$MAX_VERIFY_N)") +end + +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/examples/uniform_driver.jl b/lib/MPIRPC/examples/uniform_driver.jl new file mode 100644 index 000000000..5944db9ab --- /dev/null +++ b/lib/MPIRPC/examples/uniform_driver.jl @@ -0,0 +1,41 @@ +# Minimal uniform / SPMD driver for MPIRPC. +# +# Run with: +# mpiexec -n 4 julia --project=. examples/uniform_driver.jl +# +# Every rank issues calls and services calls; progress is pumped implicitly +# by `wait`/`fetch` and explicitly via `rpc_barrier` between phases. + +using MPI +using MPIRPC + +MPI.Init(; threadlevel=:multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const RANK = MPI.Comm_rank(backend.comm) +const NPROC = MPI.Comm_size(backend.comm) + +println("[rank $RANK] up; world size = $NPROC") + +# Phase 1: every rank asks its right-hand neighbor what its rank is. +peer = mod(RANK + 1, NPROC) +neighbor_rank = MPIRPC.remotecall_fetch(peer) do + return MPI.Comm_rank(MPI.COMM_WORLD) +end +println("[rank $RANK] right neighbor reports rank=$neighbor_rank (expected $peer)") + +# Phase boundary: pump progress until every rank arrives. Without this, a +# rank that finishes early could exit while a peer's request to it is still +# unprocessed in its inbox. +MPIRPC.rpc_barrier() + +# Phase 2: scatter "compute" work all-to-all. +peers = [r for r in 0:(NPROC-1) if r != RANK] +futs = [MPIRPC.remotecall(*, p, RANK + 1, p + 1) for p in peers] +for (p, f) in zip(peers, futs) + @assert fetch(f) == (RANK + 1) * (p + 1) +end +println("[rank $RANK] all-to-all phase 2 done") + +MPIRPC.rpc_barrier() +MPIRPC.shutdown!() +MPI.Finalize() diff --git a/lib/MPIRPC/src/MPIRPC.jl b/lib/MPIRPC/src/MPIRPC.jl new file mode 100644 index 000000000..b7c42ece4 --- /dev/null +++ b/lib/MPIRPC/src/MPIRPC.jl @@ -0,0 +1,53 @@ +""" + MPIRPC + +Standalone MPI-backed RPC for Julia, modeled on the `Distributed` stdlib +(`MsgHeader` + serialized body + boundary, `invokelatest` on the handler +path, `RemoteException`-style error wrapping) and on Dagger's +"accelerate-once" backend selection (a single backend value held in a +task-local slot determines uniform vs. non-uniform semantics; the public +API is the same in both modes). + +See `docs/ARCHITECTURE.md` for the design narrative, the deadlock and +ABBA discussion, and the side-by-side mapping to `Distributed` and Dagger. +""" +module MPIRPC + +using Serialization +using MPI + +export AbstractMPIRPCBackend, + UniformMPIRPCBackend, + NonUniformMPIRPCBackend, + MPIFuture, + MPIRRID, + MPIRemoteException, + select_mpi_rpc_backend!, + current_mpi_rpc_backend, + initialize_mpi_rpc!, + with_mpi_rpc_backend, + shutdown!, + remotecall, + remotecall_fetch, + remotecall_wait, + remote_do, + bcast_remotecall, + rpc_progress!, + rpc_progress_halt!, + rpc_barrier, + serve_listener, + is_listener, + listener_ranks, + @with_progress + +include("protocol.jl") +include("exceptions.jl") +include("config.jl") +include("refs.jl") +include("dispatch.jl") +include("uniform.jl") +include("nonuniform.jl") +include("remotecall.jl") +include("progress.jl") + +end # module MPIRPC diff --git a/lib/MPIRPC/src/config.jl b/lib/MPIRPC/src/config.jl new file mode 100644 index 000000000..fbdb20bb9 --- /dev/null +++ b/lib/MPIRPC/src/config.jl @@ -0,0 +1,150 @@ +""" + AbstractMPIRPCBackend + +Marker supertype for an installed MPIRPC backend. Concrete backends +(`UniformMPIRPCBackend`, `NonUniformMPIRPCBackend`) hold the communicator, +tag layout, OID counter, waiter table, and MPI lock that drive the wire +protocol; the public RPC surface (`remotecall`, `remotecall_fetch`, +`rpc_progress!`, ...) dispatches dynamically on this type so call sites +never branch on mode. + +The selection model mirrors Dagger's acceleration: see +[`Dagger.jl/src/acceleration.jl`](../../Dagger.jl/src/acceleration.jl) +(`accelerate!`, `current_acceleration`, `_with_default_acceleration`). +[`select_mpi_rpc_backend!`](@ref) installs a backend once for the whole +process; [`with_mpi_rpc_backend`](@ref) layers a task-local override for +scoped tests. +""" +abstract type AbstractMPIRPCBackend end + +const _BACKEND_KEY = :mpi_rpc_backend + +# Process-global default backend, written once at startup by +# `select_mpi_rpc_backend!`. Reads are unsynchronized; the assumption is the +# same as Distributed's `init_parallel` / Dagger's `accelerate!`: install +# happens before any RPC traffic and there is at most one installed backend +# per process for the duration of its life. +const _GLOBAL_BACKEND = Ref{Union{AbstractMPIRPCBackend, Nothing}}(nothing) + +""" + initialize_mpi_rpc!(backend::AbstractMPIRPCBackend) + +Per-backend hook (analog of Dagger's `initialize_acceleration!`). Concrete +backends override this to do MPI initialization, communicator splits, tag +range allocation, etc. The default is a no-op. +""" +initialize_mpi_rpc!(::AbstractMPIRPCBackend) = nothing + +""" + select_mpi_rpc_backend!(backend::AbstractMPIRPCBackend) -> backend + +Install `backend` as the current backend for the whole process and run +[`initialize_mpi_rpc!`](@ref) once. After this call, every public RPC +function (`remotecall`, `remotecall_fetch`, `rpc_progress!`, ...) — from +*any* task — dispatches through this backend, unless overridden inside +[`with_mpi_rpc_backend`](@ref). + +Re-entering with a different backend is **not supported** in v1: calling +this twice in the same process replaces the slot but the on-the-wire state +(communicators, in-flight `Isend` buffers, OID counters) of the first +backend is not reset. To switch modes mid-job, call [`shutdown!`](@ref) on +the existing backend, finalize MPI if appropriate, and start a fresh process. +""" +function select_mpi_rpc_backend!(backend::AbstractMPIRPCBackend) + initialize_mpi_rpc!(backend) + _GLOBAL_BACKEND[] = backend + # Spawn the optional yield-only progress daemon. We do this *after* + # initialize_mpi_rpc! so the daemon's first `rpc_progress!` finds a + # fully constructed communicator, and *after* the global slot has been + # written so a re-entrant `current_mpi_rpc_backend()` from inside the + # daemon resolves to the right object. + # + # Threadpool placement: prefer `:interactive` so a CPU-bound user + # computation on the `:default` pool cannot starve the wire pump. + # The interactive pool exists exactly for latency-sensitive tasks + # that must keep getting CPU regardless of default-pool load. Fall + # back to `:default` only if the user did not configure any + # interactive threads (`julia -t N,0` or implicit zero); in that + # case we emit a one-time info message so the operator knows why + # tail latency might rise under heavy compute. + if backend.daemon && backend.daemon_task === nothing + if Threads.nthreads(:interactive) > 0 + backend.daemon_task = errormonitor(Threads.@spawn :interactive _daemon_loop(backend)) + else + @info """MPIRPC: no `:interactive` threads configured; the daemon will share \ + the `:default` pool with user computation. CPU-bound user code on the \ + default pool can starve the wire pump. Start Julia with `-t N,M` (M >= 1) \ + to give the daemon a dedicated interactive thread.""" + backend.daemon_task = errormonitor(Threads.@spawn _daemon_loop(backend)) + end + end + return backend +end + +""" + current_mpi_rpc_backend() -> AbstractMPIRPCBackend + +Return the backend installed for the current task. Prefers a task-local +override (set by [`with_mpi_rpc_backend`](@ref)); otherwise falls back to +the process-global slot installed by [`select_mpi_rpc_backend!`](@ref). +Throws `ArgumentError` if no backend has been installed yet. +""" +function current_mpi_rpc_backend() + local_backend = task_local_storage(_BACKEND_KEY, nothing) + if local_backend !== nothing + return local_backend::AbstractMPIRPCBackend + end + g = _GLOBAL_BACKEND[] + g === nothing && throw(ArgumentError( + "no MPIRPC backend installed; call `select_mpi_rpc_backend!` first")) + return g::AbstractMPIRPCBackend +end + +""" + with_mpi_rpc_backend(f, backend) -> f() + +Scoped variant: install `backend` for the duration of `f()` (in this task +only) and restore the previous task-local override on exit. The analog of +Dagger's `_with_default_acceleration`. This does **not** call +`initialize_mpi_rpc!`; pass an already-initialized backend, or call +`select_mpi_rpc_backend!` once at startup and use this helper only to layer +scoped overrides during tests. +""" +function with_mpi_rpc_backend(f, backend::AbstractMPIRPCBackend) + prev = task_local_storage(_BACKEND_KEY, nothing) + task_local_storage(_BACKEND_KEY, backend) + try + return f() + finally + if prev === nothing + delete!(task_local_storage(), _BACKEND_KEY) + else + task_local_storage(_BACKEND_KEY, prev) + end + end +end + +""" + shutdown!([backend]) + +Stop the listener loop on this rank by clearing the backend's `running` +flag. Outstanding `Isend` requests are reaped on a best-effort basis. +This is a *local* operation; in non-uniform mode the application should +broadcast a stop signal across listener ranks before calling this on each. + +If a progress daemon was spawned by [`select_mpi_rpc_backend!`](@ref) +(`daemon = true` on the backend), this function also `wait`s on the +daemon task so the caller can rely on no further `rpc_progress!` +running once `shutdown!` returns. The wait is bounded only by the +daemon's own loop body; because the loop is yield-only and checks the +flag every iteration, it terminates promptly. +""" +function shutdown!(backend::AbstractMPIRPCBackend = current_mpi_rpc_backend()) + backend.running[] = false + t = backend.daemon_task + if t !== nothing + wait(t) + backend.daemon_task = nothing + end + return nothing +end diff --git a/lib/MPIRPC/src/dispatch.jl b/lib/MPIRPC/src/dispatch.jl new file mode 100644 index 000000000..488056626 --- /dev/null +++ b/lib/MPIRPC/src/dispatch.jl @@ -0,0 +1,37 @@ +""" +Shared request/reply dispatch helpers for concrete backends. See +`uniform.jl` / `nonuniform.jl` for `take_waiter!` / `_send_reply!` methods. +""" + +function _reply_exception!(b::AbstractMPIRPCBackend, notify_oid::MPIRRID, rank::Int, ex) + ce = ex isa CapturedException ? ex : CapturedException(ex, catch_backtrace()) + _send_reply!(b, notify_oid, MPIRemoteException(rank, ce)) + return nothing +end + +""" +Deliver a reply-frame body deserialization failure into the client's +`MPIFuture` when `response_oid` is known. Returns `true` if a waiter +was found and delivered. +""" +function _deliver_reply_deserialize_failure!(b::AbstractMPIRPCBackend, + response_oid::MPIRRID, rank::Int, body_err) + is_null(response_oid) && return false + fut = take_waiter!(b, response_oid) + if fut === nothing + @warn "MPIRPC: reply for unknown waiter (response_oid=$(response_oid)) on rank $(rank); dropping" + return false + end + deliver!(fut, MPIRemoteException(rank, CapturedException(body_err, catch_backtrace()))) + return true +end + +function _notify_request_decode_failure!(b::AbstractMPIRPCBackend, header::MsgHeader, + rank::Int, body_err) + if !is_null(header.notify_oid) + _reply_exception!(b, header.notify_oid, rank, body_err) + else + showerror(stderr, CapturedException(body_err, catch_backtrace())) + end + return nothing +end diff --git a/lib/MPIRPC/src/exceptions.jl b/lib/MPIRPC/src/exceptions.jl new file mode 100644 index 000000000..e4def57dc --- /dev/null +++ b/lib/MPIRPC/src/exceptions.jl @@ -0,0 +1,39 @@ +""" + MPIRemoteException(rank::Int, captured::CapturedException) + +Local analogue of `Distributed.RemoteException`. Wraps the originating MPI +rank and a `Base.CapturedException` for the original error and its captured +backtrace, so end-users can inspect the remote failure without any +`Distributed` dependency. +""" +struct MPIRemoteException <: Exception + rank::Int + captured::CapturedException +end + +MPIRemoteException(captured::CapturedException) = MPIRemoteException(-1, captured) + +Base.capture_exception(ex::MPIRemoteException, _) = ex + +function Base.showerror(io::IO, re::MPIRemoteException) + print(io, "On MPI rank ", re.rank, ":\n") + showerror(io, re.captured) +end + +""" + run_work_thunk(thunk, rank; print_error::Bool=false) -> result_or_exception + +Execute `thunk()` and either return the value or wrap any thrown exception +in an `MPIRemoteException`. When `print_error` is true (as for +`Distributed.remote_do`), the captured exception is also printed to stderr, +mirroring `Distributed.run_work_thunk(thunk, print_error)`. +""" +function run_work_thunk(thunk::Function, rank::Int; print_error::Bool=false) + try + return thunk() + catch err + ce = CapturedException(err, catch_backtrace()) + print_error && showerror(stderr, ce) + return MPIRemoteException(rank, ce) + end +end diff --git a/lib/MPIRPC/src/nonuniform.jl b/lib/MPIRPC/src/nonuniform.jl new file mode 100644 index 000000000..d615a508e --- /dev/null +++ b/lib/MPIRPC/src/nonuniform.jl @@ -0,0 +1,262 @@ +""" + NonUniformMPIRPCBackend(comm; listener_ranks, ...) + +Heterogeneous backend with explicit roles. Ranks in `listener_ranks` service +inbound requests via [`rpc_progress!`](@ref); all other ranks are clients +that only initiate calls and only drain replies addressed to them. + +# Roles + +* **Listener**: must call `rpc_progress!` regularly (e.g. from a dedicated + service loop or interleaved with computation). Listeners may also issue + RPC themselves, in which case they additionally drain replies via the + same entrypoint. +* **Client-only**: never services requests from peers. Calling `remotecall*` + with a non-listener `dest_rank` raises `ArgumentError`. + +# Communicators + +If `dup_comm` is `true` (default), the backend duplicates `comm` with +`MPI.Comm_dup`. This isolates RPC traffic from collectives the user runs on +the original communicator, so for instance an `MPI.Barrier` on +`MPI.COMM_WORLD` cannot be matched against — or block — pending RPC sends. + +# Constructor + + NonUniformMPIRPCBackend(comm = MPI.COMM_WORLD; + listener_ranks::AbstractVector{<:Integer}, + request_tag = $(REQUEST_TAG_DEFAULT), + reply_tag = $(REPLY_TAG_DEFAULT), + dup_comm = true, + daemon = false) + +If `daemon` is `true`, [`select_mpi_rpc_backend!`](@ref) spawns a yield-only +progress loop on every rank (listener or client) so listener-side request +draining and client-side reply draining both happen automatically without +the user calling `rpc_progress!`. See `UniformMPIRPCBackend`'s docstring +for the threading and CPU caveats — the daemon never sleeps; it polls +continuously and uses ≈ 100% of one OS thread. +""" +mutable struct NonUniformMPIRPCBackend <: AbstractMPIRPCBackend + base_comm::MPI.Comm + comm::MPI.Comm + request_tag::Int32 + reply_tag::Int32 + dup_comm::Bool + rank::Int + size::Int + + listener_ranks::Set{Int} + is_listener::Bool + + rrid_counter::Threads.Atomic{UInt64} + waiters::Dict{MPIRRID, MPIFuture} + waiters_lock::ReentrantLock + + pending_sends::Vector{Tuple{MPI.Request, Vector{UInt8}}} + + mpi_lock::ReentrantLock + + running::Threads.Atomic{Bool} + initialized::Bool + + daemon::Bool + daemon_task::Union{Task, Nothing} + + function NonUniformMPIRPCBackend(base_comm::MPI.Comm = MPI.COMM_WORLD; + listener_ranks::AbstractVector{<:Integer}, + request_tag::Integer = REQUEST_TAG_DEFAULT, + reply_tag::Integer = REPLY_TAG_DEFAULT, + dup_comm::Bool = true, + daemon::Bool = false) + request_tag == reply_tag && throw(ArgumentError("request_tag and reply_tag must differ")) + isempty(listener_ranks) && throw(ArgumentError("listener_ranks must be non-empty")) + ls = Set{Int}(Int(r) for r in listener_ranks) + return new(base_comm, base_comm, Int32(request_tag), Int32(reply_tag), + dup_comm, -1, -1, + ls, false, + Threads.Atomic{UInt64}(1), + Dict{MPIRRID, MPIFuture}(), + ReentrantLock(), + Tuple{MPI.Request, Vector{UInt8}}[], + ReentrantLock(), + Threads.Atomic{Bool}(true), + false, + daemon, + nothing) + end +end + +function initialize_mpi_rpc!(b::NonUniformMPIRPCBackend) + b.initialized && return nothing + if !MPI.Initialized() + MPI.Init(; threadlevel=:multiple) + end + b.comm = b.dup_comm ? MPI.Comm_dup(b.base_comm) : b.base_comm + b.rank = MPI.Comm_rank(b.comm) + b.size = MPI.Comm_size(b.comm) + for r in b.listener_ranks + if r < 0 || r >= b.size + throw(ArgumentError("listener rank $r outside [0, $(b.size))")) + end + end + b.is_listener = b.rank in b.listener_ranks + b.initialized = true + return nothing +end + +""" + is_listener(backend) -> Bool + +True if the *current* rank services inbound RPC requests on this backend. +""" +is_listener(b::NonUniformMPIRPCBackend) = b.is_listener +is_listener(::UniformMPIRPCBackend) = true + +""" + listener_ranks(backend) -> Vector{Int} +""" +listener_ranks(b::NonUniformMPIRPCBackend) = sort!(collect(b.listener_ranks)) +listener_ranks(b::UniformMPIRPCBackend) = collect(0:(b.size-1)) + +next_rrid(b::NonUniformMPIRPCBackend) = + MPIRRID(Int32(b.rank), Threads.atomic_add!(b.rrid_counter, UInt64(1))) + +function register_waiter!(b::NonUniformMPIRPCBackend, fut::MPIFuture) + @lock b.waiters_lock begin + b.waiters[fut.rrid] = fut + end + return fut +end + +function take_waiter!(b::NonUniformMPIRPCBackend, rrid::MPIRRID) + @lock b.waiters_lock begin + fut = get(b.waiters, rrid, nothing) + fut === nothing && return nothing + delete!(b.waiters, rrid) + return fut + end +end + +function _post_isend!(b::NonUniformMPIRPCBackend, dest::Integer, tag::Integer, + buf::Vector{UInt8}) + @lock b.mpi_lock begin + req = MPI.Isend(buf, b.comm; dest=Int(dest), tag=Int(tag)) + push!(b.pending_sends, (req, buf)) + end + return nothing +end + +function _reap_pending_sends!(b::NonUniformMPIRPCBackend) + @lock b.mpi_lock begin + i = 1 + while i <= length(b.pending_sends) + req, _ = b.pending_sends[i] + if MPI.Test(req) + deleteat!(b.pending_sends, i) + else + i += 1 + end + end + end + return nothing +end + +# See `uniform.jl::_try_recv_one!` for the rationale; the body is identical +# because the receive path differs from the request/reply dispatch only at +# the `rpc_progress!` level (listener-vs-not gating), not here. +function _try_recv_one!(b::NonUniformMPIRPCBackend, tag::Integer) + local req::MPI.Request + local buf::Vector{UInt8} + local src::Int + @lock b.mpi_lock begin + got, m, status = MPI.Improbe(MPI.ANY_SOURCE, Int(tag), b.comm, MPI.Status) + got || return nothing + src = Int(status.MPI_SOURCE) + count = MPI.Get_count(status, UInt8) + buf = Vector{UInt8}(undef, count) + req = MPI.Imrecv!(buf, m) + end + while true + done = @lock b.mpi_lock MPI.Test(req) + done && return (src, buf) + yield() + end +end + +function rpc_progress!(b::NonUniformMPIRPCBackend) + return _rpc_progress_impl!(b) +end + +function _rpc_progress_impl!(b::NonUniformMPIRPCBackend) + _reap_pending_sends!(b) + + if b.is_listener + while true + r = _try_recv_one!(b, b.request_tag) + r === nothing && break + src, buf = r + header, msg, body_err = decode_frame(buf) + if body_err === nothing && msg isa RPCProgressHaltMsg + return false + end + # Spawn handler so the listener loop never holds any lock + # during user code. See uniform.jl for the rationale. + errormonitor(Threads.@spawn _run_handler_task(b, src, header, msg, body_err)) + end + end + + while true + r = _try_recv_one!(b, b.reply_tag) + r === nothing && break + _, buf = r + _dispatch_reply!(b, buf) + end + + return true +end + +function _execute_request!(b::NonUniformMPIRPCBackend, src::Int, header::MsgHeader, msg::AbstractMsg) + if msg isa CallMsg{:call} || msg isa CallMsg{:call_fetch} + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + _send_reply!(b, header.notify_oid, v) + end + elseif msg isa CallWaitMsg + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + ack = v isa MPIRemoteException ? v : :OK + _send_reply!(b, header.notify_oid, ack) + end + elseif msg isa RemoteDoMsg + run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank; + print_error=true) + else + throw(ProtocolError("unhandled request message $(typeof(msg)) on rank $(b.rank))")) + end + return nothing +end + +function _dispatch_reply!(b::NonUniformMPIRPCBackend, buf::Vector{UInt8}) + header, msg, body_err = decode_frame(buf) + if body_err !== nothing + _deliver_reply_deserialize_failure!(b, header.response_oid, b.rank, body_err) + return + end + msg isa ResultMsg || throw(ProtocolError("expected ResultMsg on reply tag, got $(typeof(msg))")) + fut = take_waiter!(b, header.response_oid) + if fut === nothing + @warn "MPIRPC: ResultMsg for unknown waiter (response_oid=$(header.response_oid)) on rank $(b.rank); dropping" + return + end + deliver!(fut, msg.value) + return nothing +end + +function _send_reply!(b::NonUniformMPIRPCBackend, response_oid::MPIRRID, value) + header = MsgHeader(response_oid, NULL_RRID) + body = ResultMsg(value) + buf = encode_frame(header, body) + _post_isend!(b, response_oid.whence, b.reply_tag, buf) + return nothing +end diff --git a/lib/MPIRPC/src/progress.jl b/lib/MPIRPC/src/progress.jl new file mode 100644 index 000000000..4e9ccb4d6 --- /dev/null +++ b/lib/MPIRPC/src/progress.jl @@ -0,0 +1,162 @@ +""" + rpc_progress!() = rpc_progress!(current_mpi_rpc_backend()) + +Drive one non-blocking progress pass on the currently installed backend. +See the per-backend method docstrings for who must call this. +""" +rpc_progress!() = rpc_progress!(current_mpi_rpc_backend()) + +""" + _run_handler_task(backend, src, header, msg, body_err) + +Internal entry point for the spawned task that runs an inbound request. +Two important effects: + +* The current backend is re-installed as task-local so the user closure's + calls to `current_mpi_rpc_backend()` (e.g. nested `remotecall_fetch`) + see the same backend the request arrived on, even when the parent task + was inside a `with_mpi_rpc_backend` scope that does not propagate to + spawned children. +* Request frames are decoded in the parent [`rpc_progress!`](@ref) pass; + this task receives the triple from [`decode_frame`](@ref). Body + deserialization failures are turned into `MPIRemoteException` replies when + `header.notify_oid` is set, or printed to stderr when not (e.g. + fire-and-forget). User handler failures are already wrapped by + [`run_work_thunk`](@ref); unexpected exceptions after decode (e.g. from + `_send_reply!`) trigger a best-effort `MPIRemoteException` reply when + possible, then are rethrown so the task's `errormonitor` surfaces them. +""" +function _run_handler_task(b::AbstractMPIRPCBackend, src::Int, + header::MsgHeader, msg, body_err) + return with_mpi_rpc_backend(b) do + if body_err !== nothing + _notify_request_decode_failure!(b, header, b.rank, body_err) + return nothing + end + if msg isa RPCProgressHaltMsg + return nothing + end + try + _execute_request!(b, src, header, msg::AbstractMsg) + catch e + if !is_null(header.notify_oid) + _reply_exception!(b, header.notify_oid, b.rank, e) + end + rethrow() + end + return nothing + end +end + +""" + _daemon_loop(backend) + +Background progress driver, spawned by [`select_mpi_rpc_backend!`](@ref) +when the backend was constructed with `daemon = true`. The loop runs on +its own task (typically on its own OS thread under `julia -t >= 2`) and +calls [`rpc_progress!`](@ref) in a tight, **yield-only** loop until +`backend.running[]` flips to `false` (which happens inside +[`shutdown!`](@ref)). + +There is intentionally no `sleep` form: a sleeping daemon would delay +inbound requests by up to `poll_interval` seconds, which is the worst +kind of latency-vs-cpu trade-off for low-rate workloads. If you need +that trade-off, leave `daemon = false` and pump progress on your own +schedule. + +A `try / catch` wraps the body so an unexpected MPI or framing error +does not silently kill the daemon and leave the rank unresponsive. +Errors are logged via `@error` and then rethrown so `shutdown!`'s `wait` +on the daemon task surfaces the failure to the caller. +""" +function _daemon_loop(backend::AbstractMPIRPCBackend) + try + while backend.running[] + rpc_progress!(backend) + yield() + end + rpc_progress!(backend) # final drain after shutdown! flipped the flag + catch e + @error "MPIRPC: daemon loop crashed; this rank will stop making progress" exception = (e, catch_backtrace()) + rethrow() + end + return nothing +end + +""" + serve_listener(backend = current_mpi_rpc_backend(); poll_interval=0.0) + +Blocking helper for non-uniform listeners: loop calling +[`rpc_progress!`](@ref) until [`shutdown!`](@ref) flips +`backend.running[]` to `false`. `poll_interval` is the time, in seconds, +to `sleep` between progress passes (`0` yields without sleeping). + +This is a convenience for examples and tests; production code is free to +interleave `rpc_progress!` with its own work loop instead. +""" +function serve_listener(backend::AbstractMPIRPCBackend = current_mpi_rpc_backend(); + poll_interval::Real = 0.0) + while backend.running[] + rpc_progress!(backend) + if poll_interval > 0 + sleep(poll_interval) + else + yield() + end + end + rpc_progress!(backend) # final drain so in-flight messages are reaped + return nothing +end + +""" + rpc_barrier([backend]) + +Phase boundary for MPIRPC traffic: a non-blocking `MPI.Ibarrier` whose +completion is awaited *while every rank pumps* [`rpc_progress!`](@ref). +Use this between phases of an SPMD program, or before exiting an RPC +session, to drain in-flight work from peers. + +Why a plain `MPI.Barrier` is not enough: a rank `R` whose own +`remotecall_fetch` has just completed may still hold *unprocessed +requests* sent by other ranks (e.g. nested calls those ranks issued from +inside a handler running on `R`). If `R` enters `MPI.Barrier` it stops +pumping RPC progress, and the peers waiting on `R`'s replies hang. +`rpc_barrier` solves this by pumping progress until **all** ranks have +arrived, then doing a final drain pass. + +Calling pattern is identical to `MPI.Barrier`: every rank in the backend's +communicator must call it. +""" +function rpc_barrier(backend::AbstractMPIRPCBackend = current_mpi_rpc_backend()) + req = @lock backend.mpi_lock MPI.Ibarrier(backend.comm) + while true + rpc_progress!(backend) + done = @lock backend.mpi_lock MPI.Test(req) + done && break + yield() + end + rpc_progress!(backend) # one last drain so any reply Isend posted under + # the barrier is reaped on this rank + return nothing +end + +""" + rpc_progress_halt!(backend, dest_rank::Int) -> nothing + +Enqueue a framed [`RPCProgressHaltMsg`](@ref) to `dest_rank` on the backend's +`request_tag` via the same `_post_isend!` / pending-send path as ordinary RPC. + +The **destination** rank must call [`rpc_progress!`](@ref) (or run its daemon) +to match and decode the message. When consumed, that `rpc_progress!` pass +returns `false` and does **not** spawn a handler task for further matched +requests in that pass—useful to break out of a progress loop without executing +user RPC bodies for additional queued messages. + +The halt message carries an empty [`MsgHeader`](@ref); it is not a call/reply +pair and does not touch waiter tables. +""" +function rpc_progress_halt!(backend::AbstractMPIRPCBackend, dest_rank::Int) + buf = encode_frame(MsgHeader(), RPCProgressHaltMsg()) + _post_isend!(backend, dest_rank, backend.request_tag, buf) + return nothing +end \ No newline at end of file diff --git a/lib/MPIRPC/src/protocol.jl b/lib/MPIRPC/src/protocol.jl new file mode 100644 index 000000000..a805a8972 --- /dev/null +++ b/lib/MPIRPC/src/protocol.jl @@ -0,0 +1,181 @@ +using Serialization + +""" + MPIRRID(whence::Int32, id::UInt64) + +Routing identifier for an in-flight remote call. Mirrors `Distributed.RRID`, +with `whence` storing the *MPI rank* (within the backend's communicator) that +allocated the id rather than the Distributed worker pid. + +`whence == -1` and `id == 0` is reserved as the null id (`NULL_RRID`), +analogous to Distributed's `RRID(0, 0)`. +""" +struct MPIRRID + whence::Int32 + id::UInt64 +end + +const NULL_RRID = MPIRRID(Int32(-1), UInt64(0)) + +is_null(r::MPIRRID) = r.whence == NULL_RRID.whence && r.id == NULL_RRID.id + +Base.hash(r::MPIRRID, h::UInt) = hash(r.whence, hash(r.id, hash(MPIRRID, h))) +Base.:(==)(a::MPIRRID, b::MPIRRID) = a.whence == b.whence && a.id == b.id + +""" + MsgHeader(response_oid::MPIRRID, notify_oid::MPIRRID) + +Two-OID header preceding every MPIRPC body, mirroring `Distributed.MsgHeader`. + +* `response_oid` identifies a `MPIFuture` on the receiver of a `ResultMsg` + (i.e. on the *client* of the original call) so the value can be delivered + to the right waiter. +* `notify_oid` is the OID the *server* must echo back as `response_oid` when + it produces a `ResultMsg`. It is set on outgoing `CallMsg` / `CallWaitMsg` + by the client and is null on `RemoteDoMsg` (fire-and-forget). +""" +struct MsgHeader + response_oid::MPIRRID + notify_oid::MPIRRID +end +MsgHeader() = MsgHeader(NULL_RRID, NULL_RRID) +MsgHeader(response_oid::MPIRRID) = MsgHeader(response_oid, NULL_RRID) + +abstract type AbstractMsg end + +""" + CallMsg{Mode}(f, args::Tuple, kwargs) + +`Mode` is `:call` for `remotecall` (client expects a future) or `:call_fetch` +for `remotecall_fetch` (client expects the value to be returned in a +`ResultMsg`). MPIRPC v1 actually treats both modes identically on the wire: +the server always replies with a `ResultMsg` carrying the value or a +`MPIRemoteException`. `Mode` is preserved so the server can mirror Distributed's +behavior of distinguishing `:call` from `:call_fetch` in error reporting later. +""" +struct CallMsg{Mode} <: AbstractMsg + f::Any + args::Tuple + kwargs::Any +end + +""" + CallWaitMsg(f, args, kwargs) + +`remotecall_wait`: server runs the call and replies with `:OK` (or an +exception) so the client can confirm completion without paying for marshaling +the result. +""" +struct CallWaitMsg <: AbstractMsg + f::Any + args::Tuple + kwargs::Any +end + +""" + RemoteDoMsg(f, args, kwargs) + +Fire-and-forget: server runs the call and discards the result. No reply is +sent. Errors are printed to stderr on the server rank (like +`Distributed.remote_do`); nothing is delivered to the client. +""" +struct RemoteDoMsg <: AbstractMsg + f::Any + args::Tuple + kwargs::Any +end + +""" + ResultMsg(value) + +Carries either a successful return value or an `MPIRemoteException`. The +client routes it to a waiter via `MsgHeader.response_oid`. +""" +struct ResultMsg <: AbstractMsg + value::Any +end + +""" + RPCProgressHaltMsg + +Control message on the request tag: tells [`rpc_progress!`](@ref) to stop +draining further inbound requests in the **current** progress pass (returns +`false`). Framed like other MPIRPC bodies; see [`rpc_progress_halt!`](@ref). +""" +struct RPCProgressHaltMsg <: AbstractMsg end + +""" + MSG_BOUNDARY + +Ten-byte sentinel appended after every serialized frame. MPI is +message-oriented so we do not need it for stream resynchronization, but we +keep it as a fail-fast protocol-version / corruption check, matching +Distributed's convention. +""" +const MSG_BOUNDARY = UInt8[0x4d, 0x50, 0x49, 0x52, 0x50, 0x43, 0x46, 0x52, 0x4d, 0x31] + +""" + encode_frame(header, msg) -> Vector{UInt8} + +Serialize `header` then `msg` through a single `Serializer` over a fresh +`IOBuffer`, append `MSG_BOUNDARY`, and return the byte vector ready for +`MPI.Isend`. A fresh `Serializer` per message bounds the serializer's +back-reference table to one frame, so cross-frame state cannot leak. +""" +function encode_frame(header::MsgHeader, msg::AbstractMsg) + io = IOBuffer() + s = Serializer(io) + serialize(s, header) + serialize(s, msg) + write(io, MSG_BOUNDARY) + return take!(io) +end + +""" + decode_frame(buf) -> (header, msg, body_error) + +Decode a frame previously produced by `encode_frame`. The header is decoded +first so that, on body failures, the caller can still reply to the right +waiter (mirrors Distributed's two-stage parse in `process_messages.jl`). + +`body_error === nothing` on success; otherwise `msg === nothing` and +`body_error` is the captured exception. The `MSG_BOUNDARY` is verified on +success and a `ProtocolError` is raised if it is missing or wrong, which +forces a fail-fast rather than silently delivering garbage. + +`Base.invokelatest` wraps the body deserialization so that types defined +after world-age advances are reachable, matching Distributed's +`invokelatest(deserialize_msg, ...)` call site. +""" +function decode_frame(buf::Vector{UInt8}) + io = IOBuffer(buf) + s = Serializer(io) + header = deserialize(s)::MsgHeader + msg = nothing + body_error = nothing + try + msg = invokelatest(deserialize, s)::AbstractMsg + verify_boundary!(io) + catch e + body_error = e + end + return header, msg, body_error +end + +struct ProtocolError <: Exception + msg::String +end +Base.showerror(io::IO, e::ProtocolError) = print(io, "MPIRPC.ProtocolError: ", e.msg) + +function verify_boundary!(io::IOBuffer) + n = length(MSG_BOUNDARY) + bytes_left = io.size - io.ptr + 1 + if bytes_left < n + throw(ProtocolError("frame ended before MSG_BOUNDARY (have $(bytes_left) of $(n) bytes)")) + end + tail = read(io, n) + if tail != MSG_BOUNDARY + throw(ProtocolError("MSG_BOUNDARY mismatch (got $(tail))")) + end + return nothing +end diff --git a/lib/MPIRPC/src/refs.jl b/lib/MPIRPC/src/refs.jl new file mode 100644 index 000000000..ed5345a6a --- /dev/null +++ b/lib/MPIRPC/src/refs.jl @@ -0,0 +1,55 @@ +""" + MPIFuture + +Future-like handle returned by `remotecall` and (internally) used by +`remotecall_fetch` / `remotecall_wait`. Mirrors `Distributed.Future`: + +* `where` — MPI rank running the call (the eventual source of the reply). +* `rrid::MPIRRID` — id allocated by the *caller* (`whence == my rank`). + This same id is echoed back in `MsgHeader.response_oid` of the + corresponding `ResultMsg`. +* `v::Atomic{Union{Some{Any}, Nothing}}` — set once when the reply arrives. + +A `MPIFuture` is registered in the backend's waiter table from creation +until the value is delivered or it is finalized (whichever comes first). +""" +mutable struct MPIFuture + backend::Any + where::Int + rrid::MPIRRID + @atomic v::Union{Some{Any}, Nothing} + cond::Threads.Condition + + function MPIFuture(backend, where::Integer, rrid::MPIRRID) + return new(backend, Int(where), rrid, nothing, Threads.Condition()) + end +end + +Base.show(io::IO, f::MPIFuture) = + print(io, "MPIFuture(where=", f.where, ", rrid=", f.rrid, ", ready=", isready(f), ")") + +""" + isready(f::MPIFuture) -> Bool + +True if the reply has been delivered (success or remote exception) and a +subsequent `fetch(f)` will not block. +""" +Base.isready(f::MPIFuture) = (@atomic :acquire f.v) !== nothing + +""" + deliver!(f::MPIFuture, value) + +Internal: store `value` in `f` and wake any waiters. Idempotent — repeated +deliveries (which can occur if a stale duplicate reply ever lands) are +ignored after the first. +""" +function deliver!(f::MPIFuture, value) + @lock f.cond begin + prev = (@atomic :acquire f.v) + if prev === nothing + @atomic :release f.v = Some{Any}(value) + notify(f.cond, all=true) + end + end + return nothing +end diff --git a/lib/MPIRPC/src/remotecall.jl b/lib/MPIRPC/src/remotecall.jl new file mode 100644 index 000000000..e0d3dc0bf --- /dev/null +++ b/lib/MPIRPC/src/remotecall.jl @@ -0,0 +1,266 @@ +""" + remotecall(f, dest_rank, args...; kwargs...) -> MPIFuture + +Send a call to `dest_rank` and return a [`MPIFuture`](@ref) that completes +when the reply arrives. Mirrors `Distributed.remotecall`. + +The reply carries either the function's return value or a +`MPIRemoteException`; both are delivered into the future, and +[`fetch`](@ref) rethrows the exception case while returning the value +otherwise. +""" +function remotecall(f, dest_rank::Integer, args...; kwargs...) + backend = current_mpi_rpc_backend() + return _remotecall_internal(backend, Val(:call), f, Int(dest_rank), args, kwargs) +end + +""" + remotecall_fetch(f, dest_rank, args...; kwargs...) -> result + +Send a call, wait for the reply, and return the result. Throws +`MPIRemoteException` if the remote call raised. Mirrors +`Distributed.remotecall_fetch`. Uses [`current_mpi_rpc_backend`](@ref). +""" +function remotecall_fetch(f, dest_rank::Integer, args...; kwargs...) + return remotecall_fetch(f, current_mpi_rpc_backend(), dest_rank, args...; kwargs...) +end + +""" + remotecall_fetch(f, backend::AbstractMPIRPCBackend, dest_rank, args...; kwargs...) -> result + +Send a call, wait for the reply, and return the result. Throws +`MPIRemoteException` if the remote call raised. Mirrors +`Distributed.remotecall_fetch` with an explicit backend (e.g. for Dagger). +""" +function remotecall_fetch(f, backend::AbstractMPIRPCBackend, dest_rank::Integer, args...; kwargs...) + fut = _remotecall_internal(backend, Val(:call_fetch), f, Int(dest_rank), args, kwargs) + return fetch(fut) +end + +""" + remotecall_wait(f, dest_rank, args...; kwargs...) -> MPIFuture + +Send a call, wait for the server to acknowledge completion, and return the +future (containing `:OK` or an `MPIRemoteException`). Use this when you want +back-pressure on remote completion without paying for marshaling the result. +""" +function remotecall_wait(f, dest_rank::Integer, args...; kwargs...) + backend = current_mpi_rpc_backend() + fut = _remotecall_wait_internal(backend, f, Int(dest_rank), args, kwargs) + wait(fut) + val = something((@atomic :acquire fut.v)) + val isa MPIRemoteException && throw(val) + return fut +end + +""" + remote_do(f, dest_rank, args...; kwargs...) -> nothing + +Fire-and-forget: send a call to `dest_rank` and return immediately. The +remote function's return value (and any exception) is **not** observed by +the caller. Mirrors `Distributed.remote_do`. +""" +function remote_do(f, dest_rank::Integer, args...; kwargs...) + backend = current_mpi_rpc_backend() + _remote_do_internal(backend, f, Int(dest_rank), args, kwargs) + return nothing +end + +""" + bcast_remotecall(backend::AbstractMPIRPCBackend, f, args...; kwargs...) -> Vector{MPIFuture} + +Issue one [`remotecall`](@ref)-style `CallMsg{:call}` to **every other rank** +in `backend`'s communicator: all ranks in `0:backend.size-1` **except** +`backend.rank`, in increasing order. Returns a vector of length +`max(0, backend.size - 1)` (empty when `backend.size == 1`). + +`out[i]` is the future for the `i`-th destination in that filtered list (not +indexed by global rank). Progress and `fetch`/`wait` semantics match +[`remotecall`](@ref). + +For [`NonUniformMPIRPCBackend`](@ref), destinations that are not listeners +raise from [`_validate_dest!`](@ref); use the explicit-ranks form with only +listeners if needed. +""" +function bcast_remotecall(backend::AbstractMPIRPCBackend, f, args...; kwargs...) + dests = Int[d for d in 0:(backend.size - 1) if d != backend.rank] + return _bcast_remotecall_destinations!(backend, f, dests, Tuple(args), kwargs) +end + +""" + bcast_remotecall(backend::AbstractMPIRPCBackend, f, ranks::AbstractVector{<:Integer}, args...; kwargs...) -> Vector{MPIFuture} + bcast_remotecall(f, ranks::AbstractVector{<:Integer}, args...; kwargs...) -> Vector{MPIFuture} + +Issue [`remotecall`](@ref) to each rank in `ranks` **in list order** (no +automatic skip of `backend.rank`; filter `ranks` yourself if you want to omit +the caller). The `i`-th future corresponds to `ranks[i]`. + +The no-`backend` form uses [`current_mpi_rpc_backend`](@ref). +""" +function bcast_remotecall(backend::AbstractMPIRPCBackend, f, + ranks::AbstractVector{<:Integer}, args...; kwargs...) + dests = Int[Int(r) for r in ranks] + return _bcast_remotecall_destinations!(backend, f, dests, Tuple(args), kwargs) +end + +function bcast_remotecall(f, ranks::AbstractVector{<:Integer}, args...; kwargs...) + return bcast_remotecall(current_mpi_rpc_backend(), f, ranks, args...; kwargs...) +end + +function _bcast_remotecall_destinations!(backend::AbstractMPIRPCBackend, f, + dests::Vector{Int}, args::Tuple, kwargs) + n = length(dests) + out = Vector{MPIFuture}(undef, n) + for i in 1:n + out[i] = _remotecall_internal(backend, Val(:call), f, dests[i], args, kwargs) + end + return out +end + +# --------------------------------------------------------------------------- + +function _validate_dest!(backend::AbstractMPIRPCBackend, dest::Int) + backend.initialized || throw(ArgumentError( + "MPIRPC backend not initialized; call `select_mpi_rpc_backend!` first")) + if dest < 0 || dest >= backend.size + throw(ArgumentError("dest rank $dest outside [0, $(backend.size))")) + end + if backend isa NonUniformMPIRPCBackend && !(dest in backend.listener_ranks) + throw(ArgumentError( + "dest rank $dest is not a listener; only listener ranks can service RPC " * + "in a NonUniformMPIRPCBackend (listeners=$(sort!(collect(backend.listener_ranks))))")) + end + return nothing +end + +function _remotecall_internal(backend::AbstractMPIRPCBackend, ::Val{Mode}, f, + dest::Int, args::Tuple, kwargs) where {Mode} + _validate_dest!(backend, dest) + rrid = next_rrid(backend) + fut = MPIFuture(backend, dest, rrid) + register_waiter!(backend, fut) + header = MsgHeader(NULL_RRID, rrid) + msg = CallMsg{Mode}(f, args, _kwargs_to_pairs(kwargs)) + buf = encode_frame(header, msg) + _post_isend!(backend, dest, backend.request_tag, buf) + return fut +end + +function _remotecall_wait_internal(backend::AbstractMPIRPCBackend, f, + dest::Int, args::Tuple, kwargs) + _validate_dest!(backend, dest) + rrid = next_rrid(backend) + fut = MPIFuture(backend, dest, rrid) + register_waiter!(backend, fut) + header = MsgHeader(NULL_RRID, rrid) + msg = CallWaitMsg(f, args, _kwargs_to_pairs(kwargs)) + buf = encode_frame(header, msg) + _post_isend!(backend, dest, backend.request_tag, buf) + return fut +end + +function _remote_do_internal(backend::AbstractMPIRPCBackend, f, + dest::Int, args::Tuple, kwargs) + _validate_dest!(backend, dest) + header = MsgHeader(NULL_RRID, NULL_RRID) + msg = RemoteDoMsg(f, args, _kwargs_to_pairs(kwargs)) + buf = encode_frame(header, msg) + _post_isend!(backend, dest, backend.request_tag, buf) + return nothing +end + +# Materialize kwargs as a `Vector{Pair{Symbol,Any}}` so the on-wire form is +# stable regardless of whether the caller passed `; kwargs...`, a `NamedTuple`, +# or a `Base.Pairs`. This avoids serializing the iterator object itself, which +# can drag in unexpected closures. +function _kwargs_to_pairs(kwargs) + out = Vector{Pair{Symbol, Any}}() + for (k, v) in pairs(kwargs) + push!(out, k => v) + end + return out +end + +# --------------------------------------------------------------------------- +# wait / fetch — two strategies depending on whether the backend has a +# progress daemon driving the wire on this rank. + +""" + wait(f::MPIFuture) -> f + +Block until the reply for `f` has been delivered. Two implementations, +selected at call time by inspecting the future's backend: + +* **`backend.daemon == true`** — park on `f.cond` (a `Threads.Condition`). + The waiter consumes zero CPU until [`deliver!`](@ref) (called by the + daemon's `_dispatch_reply!`) holds the same lock, sets the value, and + `notify`s. This is the same shape as `Distributed`'s `take!(fut.v)` on + a `Channel`, with one important difference covered below. + +* **`backend.daemon == false`** — fall back to the v1 spin: call + `rpc_progress!` and `yield` in a loop. We *cannot* park here because + the backend has no other progress driver: the reply that would + eventually call `deliver!` only arrives when *some task* drains the + wire, and `wait` itself is the only candidate. Parking would deadlock. + +The `backend.daemon` check is read off `f.backend`, not the +*currently installed* backend, so a future created on a daemon-backed +backend is still cheap to wait on even from a task that has scoped a +non-daemon backend via [`with_mpi_rpc_backend`](@ref). +""" +function Base.wait(f::MPIFuture) + isready(f) && return f + backend = f.backend::AbstractMPIRPCBackend + if backend.daemon + # Cond-park path. The standard CV idiom: hold the lock, recheck + # the predicate, `wait` if false, recheck on wake. `deliver!` + # holds the same lock when it sets `f.v` and `notify`s, so the + # check / park / wake transitions are atomic with delivery. + @lock f.cond begin + while !isready(f) + wait(f.cond) + end + end + else + # No daemon ⇒ this task is the progress driver. A bare + # `wait(f.cond)` here would never wake because nobody is + # receiving the reply that triggers the notify. Yield-rate + # spin is the only correct choice; it costs CPU but it + # actually makes forward progress on the wire. + while !isready(f) + rpc_progress!(backend) + isready(f) && break + yield() + end + end + return f +end + +""" + fetch(f::MPIFuture) -> value + +Block until the reply has arrived, then return the value. If the remote +function threw, that exception is rethrown locally as a `MPIRemoteException`. +Park / spin behavior follows [`wait`](@ref) — see its docstring. +""" +function Base.fetch(f::MPIFuture) + wait(f) + val = something((@atomic :acquire f.v)) + val isa MPIRemoteException && throw(val) + return val +end + +# --------------------------------------------------------------------------- + +""" + @with_progress expr + +Convenience macro: evaluate `expr` in a scope where the current MPIRPC +backend's progress engine is driven by `wait`/`fetch`. This is a no-op +today (since `wait`/`fetch` already pump progress) and is provided as a +forward-compatible attachment point for future work that might run user +code in a context without an automatic progress drain. +""" +macro with_progress(expr) + return esc(expr) +end diff --git a/lib/MPIRPC/src/uniform.jl b/lib/MPIRPC/src/uniform.jl new file mode 100644 index 000000000..b7866911a --- /dev/null +++ b/lib/MPIRPC/src/uniform.jl @@ -0,0 +1,314 @@ +using MPI + +const REQUEST_TAG_DEFAULT = Int32(0xC0DE) +const REPLY_TAG_DEFAULT = Int32(0xC0DF) + +""" + UniformMPIRPCBackend(comm; ...) + +SPMD backend: every rank both issues and services RPC. Every rank must +call [`rpc_progress!`](@ref) regularly (e.g. once per main-loop iteration); +clients block in `fetch` / `wait` by looping `rpc_progress!` themselves so +no rank ever needs a dedicated background task to receive. + +# Wire layout + +* Two disjoint MPI tags carry traffic, regardless of how many calls are in + flight: `request_tag` for `CallMsg` / `CallWaitMsg` / `RemoteDoMsg`, and + `reply_tag` for `ResultMsg`. **Tags do not encode call identity.** +* Each call carries a unique `MPIRRID` in `MsgHeader.notify_oid`. The server + echoes it back as `MsgHeader.response_oid` on the reply, and the client + routes the reply to the matching `MPIFuture` through a waiter table keyed + by `MPIRRID`. + +This layout is deliberately ABBA-safe: two ranks calling each other +simultaneously cannot collide on `(peer, tag)` because requests and replies +travel on different tags, and concurrent calls to the same peer cannot be +confused because correlation lives in the header, not the tag. See +`docs/ARCHITECTURE.md` for the full deadlock argument. + +# Constructor + + UniformMPIRPCBackend(comm = MPI.COMM_WORLD; + request_tag = $(REQUEST_TAG_DEFAULT), + reply_tag = $(REPLY_TAG_DEFAULT), + dup_comm = true, + daemon = false) + +If `dup_comm` is `true` (default), the backend duplicates `comm` with +`MPI.Comm_dup` so RPC traffic cannot interleave with collectives the user +runs on the original communicator. Pass `dup_comm = false` if MPI is not +yet initialized at construction time and you intend to call +[`select_mpi_rpc_backend!`](@ref) (which initializes MPI if necessary). + +If `daemon` is `true`, [`select_mpi_rpc_backend!`](@ref) spawns a +`Threads.@spawn`-backed task that runs [`rpc_progress!`](@ref) in a tight +yield-only loop until [`shutdown!`](@ref) is called. This removes the +"every rank must call `rpc_progress!`" requirement: as long as Julia has +≥ 2 threads, the daemon makes forward progress on inbound requests +without any user pumping. Under `julia -t 1` the daemon shares a thread +with the main task and only runs at yield points, so the requirement is +effectively unchanged. The daemon does **not** sleep — it polls +continuously — so expect ≈ 100% utilization of one OS thread. If that is +unacceptable for your workload, leave `daemon = false` and drive +progress manually via `rpc_progress!`, `serve_listener`, `wait`/`fetch`, +or `rpc_barrier`. +""" +mutable struct UniformMPIRPCBackend <: AbstractMPIRPCBackend + base_comm::MPI.Comm + comm::MPI.Comm + request_tag::Int32 + reply_tag::Int32 + dup_comm::Bool + rank::Int + size::Int + + rrid_counter::Threads.Atomic{UInt64} + waiters::Dict{MPIRRID, MPIFuture} + waiters_lock::ReentrantLock + + pending_sends::Vector{Tuple{MPI.Request, Vector{UInt8}}} + + mpi_lock::ReentrantLock + + running::Threads.Atomic{Bool} + initialized::Bool + + daemon::Bool + daemon_task::Union{Task, Nothing} + + function UniformMPIRPCBackend(base_comm::MPI.Comm = MPI.COMM_WORLD; + request_tag::Integer = REQUEST_TAG_DEFAULT, + reply_tag::Integer = REPLY_TAG_DEFAULT, + dup_comm::Bool = true, + daemon::Bool = false) + request_tag == reply_tag && throw(ArgumentError("request_tag and reply_tag must differ")) + return new(base_comm, base_comm, Int32(request_tag), Int32(reply_tag), + dup_comm, -1, -1, + Threads.Atomic{UInt64}(1), + Dict{MPIRRID, MPIFuture}(), + ReentrantLock(), + Tuple{MPI.Request, Vector{UInt8}}[], + ReentrantLock(), + Threads.Atomic{Bool}(true), + false, + daemon, + nothing) + end +end + +function initialize_mpi_rpc!(b::UniformMPIRPCBackend) + b.initialized && return nothing + if !MPI.Initialized() + MPI.Init(; threadlevel=:multiple) + end + b.comm = b.dup_comm ? MPI.Comm_dup(b.base_comm) : b.base_comm + b.rank = MPI.Comm_rank(b.comm) + b.size = MPI.Comm_size(b.comm) + b.initialized = true + return nothing +end + +""" + next_rrid(backend) -> MPIRRID + +Allocate a new id local to this rank. The caller's rank is encoded in the +`whence` field, so the same `MPIRRID` is meaningful across the whole +communicator (only the originating rank ever looks it up in the waiter +table). +""" +next_rrid(b::UniformMPIRPCBackend) = + MPIRRID(Int32(b.rank), Threads.atomic_add!(b.rrid_counter, UInt64(1))) + +function register_waiter!(b::UniformMPIRPCBackend, fut::MPIFuture) + @lock b.waiters_lock begin + b.waiters[fut.rrid] = fut + end + return fut +end + +function take_waiter!(b::UniformMPIRPCBackend, rrid::MPIRRID) + @lock b.waiters_lock begin + fut = get(b.waiters, rrid, nothing) + fut === nothing && return nothing + delete!(b.waiters, rrid) + return fut + end +end + +# --------------------------------------------------------------------------- +# Send path + +function _post_isend!(b::UniformMPIRPCBackend, dest::Integer, tag::Integer, + buf::Vector{UInt8}) + @lock b.mpi_lock begin + req = MPI.Isend(buf, b.comm; dest=Int(dest), tag=Int(tag)) + push!(b.pending_sends, (req, buf)) + end + return nothing +end + +function _reap_pending_sends!(b::UniformMPIRPCBackend) + @lock b.mpi_lock begin + i = 1 + while i <= length(b.pending_sends) + req, _ = b.pending_sends[i] + if MPI.Test(req) + deleteat!(b.pending_sends, i) + else + i += 1 + end + end + end + return nothing +end + +# --------------------------------------------------------------------------- +# Receive / dispatch path + +""" + _try_recv_one!(backend, tag) -> (src, buf) | nothing + +Try to receive one message that's already in MPI's matching engine on +`tag`. Two phases: + +1. **Match-and-post under `mpi_lock`.** `MPI.Improbe` is the only way to + atomically peek at a matched message and remove it from the queue; + the buffer must be allocated and `MPI.Imrecv!` posted while still + holding the message handle. Both steps live in a single critical + section because they are semantically one operation. + +2. **Wait outside `mpi_lock`, with yield.** `MPI.Imrecv!` returns a + non-blocking request; we then loop on `MPI.Test` with `yield()` in + between, **releasing `mpi_lock` between polls**. This is what makes + the daemon thread non-blocking even on a slow rendezvous-protocol + transfer: while the receive is in flight, other tasks (handlers + doing their own `Isend`s, user code on this rank) can acquire + `mpi_lock` and progress their own MPI calls. + +Eager-protocol messages (the typical case for RPC payloads up to +MPI's eager threshold, often 64 KiB) almost always complete `Test` +on the first poll, so the only added cost vs. blocking `Mrecv!` is +one extra Test call and one lock acquire/release pair — single-digit +microseconds. The win is for large payloads on rendezvous protocol, +where the previous `Mrecv!` would have held `mpi_lock` for the entire +network round-trip. +""" +function _try_recv_one!(b::UniformMPIRPCBackend, tag::Integer) + local req::MPI.Request + local buf::Vector{UInt8} + local src::Int + @lock b.mpi_lock begin + got, m, status = MPI.Improbe(MPI.ANY_SOURCE, Int(tag), b.comm, MPI.Status) + got || return nothing + src = Int(status.MPI_SOURCE) + count = MPI.Get_count(status, UInt8) + buf = Vector{UInt8}(undef, count) + req = MPI.Imrecv!(buf, m) + end + while true + done = @lock b.mpi_lock MPI.Test(req) + done && return (src, buf) + yield() + end +end + +""" + rpc_progress!([backend]) + +Drive one non-blocking pass of the MPIRPC engine: reap completed sends, +service every inbound request currently matched on this rank, and deliver +every inbound reply currently matched on this rank. + +For [`UniformMPIRPCBackend`](@ref), **every** rank must call this regularly +(typically wrapped in [`@with_progress`](@ref) inside `wait` / `fetch`). +For [`NonUniformMPIRPCBackend`](@ref), only listener ranks need to call +the request-draining half; clients still need to call it to drain replies, +Returns `true` if the pass completed normally, or `false` if a +[`RPCProgressHaltMsg`](@ref) was consumed on the request tag (no further +requests are drained in that pass); see [`rpc_progress_halt!`](@ref). +""" +function rpc_progress!(b::UniformMPIRPCBackend) + return _rpc_progress_impl!(b) +end + +function _rpc_progress_impl!(b::UniformMPIRPCBackend) + _reap_pending_sends!(b) + + while true + r = _try_recv_one!(b, b.request_tag) + r === nothing && break + src, buf = r + header, msg, body_err = decode_frame(buf) + @debug "Received request" header msg body_err + if body_err === nothing && msg isa RPCProgressHaltMsg + return false + end + # Run the handler on a freshly spawned task so user code never + # executes while the progress pump holds any state. Without this, + # a handler that itself blocks on `Threads.@spawn`-then-`fetch` + # would deadlock: the spawned task would need another thread to + # drive RPC progress, but every thread that called `rpc_progress!` + # would be waiting on the handler that the previous progress pass + # had begun running synchronously. See `docs/ARCHITECTURE.md` §6. + errormonitor(Threads.@spawn _run_handler_task(b, src, header, msg, body_err)) + end + + while true + r = _try_recv_one!(b, b.reply_tag) + r === nothing && break + @debug "Received reply" r + _, buf = r + # Reply dispatch only touches MPIRPC's own state (waiter table and + # future condition); it cannot block on user code, so we keep it + # inline. + _dispatch_reply!(b, buf) + end + + return true +end + +function _execute_request!(b::UniformMPIRPCBackend, src::Int, header::MsgHeader, msg::AbstractMsg) + if msg isa CallMsg{:call} || msg isa CallMsg{:call_fetch} + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + _send_reply!(b, header.notify_oid, v) + end + elseif msg isa CallWaitMsg + v = run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank) + if !is_null(header.notify_oid) + ack = v isa MPIRemoteException ? v : :OK + _send_reply!(b, header.notify_oid, ack) + end + elseif msg isa RemoteDoMsg + run_work_thunk(() -> invokelatest(msg.f, msg.args...; pairs(msg.kwargs)...), b.rank; + print_error=true) + else + throw(ProtocolError("unhandled request message $(typeof(msg)) on rank $(b.rank))")) + end + return nothing +end + +function _dispatch_reply!(b::UniformMPIRPCBackend, buf::Vector{UInt8}) + header, msg, body_err = decode_frame(buf) + if body_err !== nothing + _deliver_reply_deserialize_failure!(b, header.response_oid, b.rank, body_err) + return + end + msg isa ResultMsg || throw(ProtocolError("expected ResultMsg on reply tag, got $(typeof(msg))")) + fut = take_waiter!(b, header.response_oid) + if fut === nothing + @warn "MPIRPC: ResultMsg for unknown waiter (response_oid=$(header.response_oid)) on rank $(b.rank); dropping" + return + end + deliver!(fut, msg.value) + return nothing +end + +function _send_reply!(b::UniformMPIRPCBackend, response_oid::MPIRRID, value) + header = MsgHeader(response_oid, NULL_RRID) + body = ResultMsg(value) + buf = encode_frame(header, body) + _post_isend!(b, response_oid.whence, b.reply_tag, buf) + return nothing +end diff --git a/lib/MPIRPC/test/bcast_remotecall_mpiexec.jl b/lib/MPIRPC/test/bcast_remotecall_mpiexec.jl new file mode 100644 index 000000000..42edbc24d --- /dev/null +++ b/lib/MPIRPC/test/bcast_remotecall_mpiexec.jl @@ -0,0 +1,57 @@ +using Test +using MPI +using MPIRPC + +MPI.Init(; threadlevel=:multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +NPROC >= 3 || error("bcast_remotecall test expects at least 3 ranks (got $NPROC)") + +function rpc_barrier_local() + MPIRPC.rpc_barrier(backend) +end + +function pump_until(pred; timeout::Real = 30.0) + t0 = time() + while !pred() + MPIRPC.rpc_progress!(backend) + time() - t0 > timeout && return false + yield() + end + return true +end + +rpc_barrier_local() +if RANK == 0 + println("--- bcast_remotecall / all other ranks ---") + flush(stdout) +end +rpc_barrier_local() + +@testset "bcast_remotecall skips caller" begin + futs = MPIRPC.bcast_remotecall(backend, +, RANK, 100) + @test length(futs) == NPROC - 1 + @test pump_until(() -> all(Base.isready, futs)) + @test all(f -> MPIRPC.fetch(f) == RANK + 100, futs) +end + +rpc_barrier_local() +if RANK == 0 + println("--- bcast_remotecall / explicit ranks vector ---") + flush(stdout) +end +rpc_barrier_local() + +@testset "bcast_remotecall explicit ranks vector" begin + peer = mod(RANK + 1, NPROC) + futs = MPIRPC.bcast_remotecall(backend, +, Int[peer], 10, 5) + @test length(futs) == 1 + @test pump_until(() -> Base.isready(futs[1])) + @test MPIRPC.fetch(futs[1]) == 15 +end + +rpc_barrier_local() +MPIRPC.shutdown!() diff --git a/lib/MPIRPC/test/mpi_tests.jl b/lib/MPIRPC/test/mpi_tests.jl new file mode 100644 index 000000000..8b2210b41 --- /dev/null +++ b/lib/MPIRPC/test/mpi_tests.jl @@ -0,0 +1,96 @@ +using Test +using MPI + +# Orchestrate the two mpiexec-driven suites from the host test process. +# `Pkg.test()` runs in a single process; we spawn `mpiexec` here so the +# usual `julia --project=. -e 'using Pkg; Pkg.test()'` flow exercises the +# real MPI transport. + +const MPI_BIN = mpiexec() + +const HERE = @__DIR__ +const PROJECT = abspath(joinpath(HERE, "..")) + +function _run_mpi(script::AbstractString; nproc::Integer, + default_threads::Integer = 2, + interactive_threads::Integer = 1) + # Pass `--threads=N,M` to each spawned Julia: N default-pool threads + # to exercise concurrent handlers and concurrent client tasks, plus M + # interactive-pool threads so the MPIRPC daemon (when enabled) can be + # scheduled on its own pool, isolated from any user CPU-bound work + # running on the default pool. + # + # `default_threads = 2` is the minimum that turns `Threads.@spawn` + # into actual parallelism. `interactive_threads = 1` is the minimum + # that gives the daemon a dedicated thread; the daemon-suite scripts + # additionally assert that they are running with the daemon on + # `:interactive`, so this needs to stay >= 1. + threadspec = "$(default_threads),$(interactive_threads)" + cmd = `$(MPI_BIN) -n $(nproc) $(Base.julia_cmd()) --threads=$(threadspec) --project=$(PROJECT) $(script)` + @info "MPIRPC: running" script nproc default_threads interactive_threads cmd + rc = success(pipeline(cmd, stdout = stdout, stderr = stderr)) + return rc +end + +@testset "mpiexec / uniform backend (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec uniform suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "uniform_mpiexec.jl"); nproc = 4) + end +end + +@testset "mpiexec / non-uniform backend (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec non-uniform suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "nonuniform_mpiexec.jl"); nproc = 4) + end +end + +# The two daemon suites verify the optional progress-daemon path: the +# scripts never call `rpc_progress!` or `serve_listener` themselves, so a +# successful run is *only* possible if the yield-only `_daemon_loop` +# spawned by `select_mpi_rpc_backend!` is doing its job. We give them a +# separate testset (rather than folding them into the existing two) so a +# regression in the daemon does not get masked by the manual-pump suites. +@testset "mpiexec / uniform backend with daemon (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec uniform-daemon suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "uniform_daemon_mpiexec.jl"); nproc = 4) + end +end + +@testset "mpiexec / non-uniform backend with daemon (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec non-uniform-daemon suite (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "nonuniform_daemon_mpiexec.jl"); nproc = 4) + end +end + +# Example: RPC matmul with explicit listener / client roles (no client-to-client RPC). +@testset "mpiexec / bcast_remotecall (3 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec bcast_remotecall (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + @test _run_mpi(joinpath(HERE, "bcast_remotecall_mpiexec.jl"); nproc = 3) + end +end + +# Example: RPC matmul with explicit listener / client roles (no client-to-client RPC). +@testset "mpiexec / example rpc_matmul non-uniform (4 ranks)" begin + if get(ENV, "MPIRPC_SKIP_MPI_TESTS", "0") == "1" + @info "MPIRPC: skipping mpiexec rpc_matmul_nonuniform (MPIRPC_SKIP_MPI_TESTS=1)" + @test true + else + ex = joinpath(dirname(HERE), "examples", "rpc_matmul_nonuniform.jl") + @test _run_mpi(ex; nproc = 4) + end +end diff --git a/lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl b/lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl new file mode 100644 index 000000000..baf816942 --- /dev/null +++ b/lib/MPIRPC/test/nonuniform_daemon_mpiexec.jl @@ -0,0 +1,138 @@ +using Test +using MPI +using MPIRPC + +# Non-uniform companion to `uniform_daemon_mpiexec.jl`. Listeners receive +# inbound requests via the daemon, clients receive their own replies via +# the daemon (or via `wait`/`fetch` self-pumping — both paths must work). +# No testset in this script calls `rpc_progress!` or `serve_listener`. + +MPI.Init(; threadlevel = :multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 4 || error("non-uniform daemon tests need at least 4 ranks (got $NPROC); " * + "rerun with `mpiexec -n 4 ...`") + +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) +const CLIENT_RANKS = collect(HALF:(NPROC - 1)) +const IS_LISTENER = WORLD_RANK in LISTENER_RANKS + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; + listener_ranks = LISTENER_RANKS, + daemon = true), +) +const COMM = backend.comm + +@assert backend.daemon "daemon flag did not stick" +@assert backend.daemon_task isa Task "daemon task should have been spawned" +@assert !istaskdone(backend.daemon_task) + +# Daemon must be on the `:interactive` pool so it cannot be starved by +# CPU-bound user code on the `:default` pool. See uniform_daemon_mpiexec.jl +# for the rationale. +if Threads.nthreads(:interactive) > 0 + @assert Threads.threadpool(backend.daemon_task) === :interactive ( + "daemon task expected on :interactive, got $(Threads.threadpool(backend.daemon_task))") +end + +# `_phase!` uses `rpc_barrier`; see the uniform daemon script for why this +# is safe alongside the daemon (per-backend `mpi_lock` serializes raw MPI +# calls, the barrier request is task-local). +function _phase!(name::AbstractString) + rpc_barrier() + if WORLD_RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +# --------------------------------------------------------------------------- + +_phase!("non-uniform-daemon / smoke (no manual progress, no serve_listener)") +@testset "non-uniform-daemon / client-listener round trip" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + # Critically, this script never has the listener call + # `serve_listener` or `rpc_progress!`. The daemon on the listener + # is doing all inbound draining. + got = remotecall_fetch(+, peer, WORLD_RANK, 1) + @test got == WORLD_RANK + 1 + else + @test true + end +end + +_phase!("non-uniform-daemon / each client fans out to all listeners") +@testset "non-uniform-daemon / each client fans out to all listeners" begin + if !IS_LISTENER + results = Vector{Int}(undef, length(LISTENER_RANKS)) + for (i, l) in enumerate(LISTENER_RANKS) + results[i] = remotecall_fetch(+, l, WORLD_RANK, l) + end + @test results == [WORLD_RANK + l for l in LISTENER_RANKS] + else + @test true + end +end + +_phase!("non-uniform-daemon / multiple client threads concurrent") +@testset "non-uniform-daemon / multiple client threads concurrent" begin + if Threads.nthreads() >= 2 && !IS_LISTENER + K = 16 + peer = first(LISTENER_RANKS) + results = Vector{Vector{Int}}(undef, Threads.nthreads()) + ts = Task[] + for tid in 1:Threads.nthreads() + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch( + +, peer, WORLD_RANK * 100 + tid, k) + end + results[tid] = got + end + push!(ts, t) + end + foreach(wait, ts) + for tid in 1:Threads.nthreads() + @test results[tid] == [WORLD_RANK * 100 + tid + k for k in 1:K] + end + else + @test true + end +end + +_phase!("non-uniform-daemon / handler that spawns and fetches") +@testset "non-uniform-daemon / nested fetch from spawned handler subtask" begin + if Threads.nthreads() >= 2 && length(LISTENER_RANKS) >= 2 && !IS_LISTENER + peer = LISTENER_RANKS[1] + helper = LISTENER_RANKS[2] + result = remotecall_fetch(peer, helper, WORLD_RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 5) + return fetch(t) + 1000 + end + @test result == WORLD_RANK + 5 + 1000 + else + @test true + end +end + +_phase!("non-uniform-daemon / shutdown joins the daemon") +@testset "non-uniform-daemon / shutdown! cleanly joins the daemon" begin + @test backend.daemon_task isa Task + @test !istaskdone(backend.daemon_task) + shutdown!() + @test backend.running[] == false + @test backend.daemon_task === nothing +end + +# Final sync on the duplicate communicator. As in the uniform daemon +# script, every rank has called `shutdown!` so no rank is pumping +# progress; a plain `MPI.Barrier` on the RPC subcomm is the right tool. +MPI.Barrier(backend.comm) +MPI.Finalize() diff --git a/lib/MPIRPC/test/nonuniform_mpiexec.jl b/lib/MPIRPC/test/nonuniform_mpiexec.jl new file mode 100644 index 000000000..9a2914ca6 --- /dev/null +++ b/lib/MPIRPC/test/nonuniform_mpiexec.jl @@ -0,0 +1,242 @@ +using Test +using MPI +using MPIRPC +using Random + +Random.seed!(0xC0FFEE) + +MPI.Init(; threadlevel=:multiple) + +const WORLD_RANK = MPI.Comm_rank(MPI.COMM_WORLD) +const NPROC = MPI.Comm_size(MPI.COMM_WORLD) + +NPROC >= 4 || error("non-uniform tests need at least 4 ranks (got $NPROC); " * + "rerun with `mpiexec -n 4 ...`") + +# Roughly half the ranks are listeners, the rest are clients only. +const HALF = max(1, NPROC ÷ 2) +const LISTENER_RANKS = collect(0:(HALF - 1)) +const CLIENT_RANKS = collect(HALF:(NPROC - 1)) +const IS_LISTENER = WORLD_RANK in LISTENER_RANKS + +backend = MPIRPC.select_mpi_rpc_backend!( + NonUniformMPIRPCBackend(MPI.COMM_WORLD; listener_ranks = LISTENER_RANKS)) +const COMM = backend.comm + +@assert is_listener(backend) == IS_LISTENER +@assert listener_ranks(backend) == LISTENER_RANKS + +# A progress-pumping barrier and an announce on rank 0 — same idea as the +# uniform suite. Critical for non-uniform: clients pump replies and listeners +# pump requests, so a `rpc_barrier` after each phase prevents a listener +# from "going quiet" before clients have collected pending replies. +function _phase!(name::AbstractString) + rpc_barrier() + if WORLD_RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +function pump_until(pred; timeout::Real = 30.0, + backend = MPIRPC.current_mpi_rpc_backend()) + t0 = time() + while !pred() + rpc_progress!(backend) + time() - t0 > timeout && return false + yield() + end + return true +end + +# --------------------------------------------------------------------------- + +_phase!("non-uniform / role checks") +@testset "non-uniform / role checks" begin + @test is_listener(backend) == (WORLD_RANK in LISTENER_RANKS) + if !IS_LISTENER && length(CLIENT_RANKS) >= 2 + # Clients cannot remotecall to other clients; only listeners service RPC. + another_client = first(c for c in CLIENT_RANKS if c != WORLD_RANK) + @test_throws ArgumentError remotecall(+, another_client, 1, 2) + else + @test true + end +end + +_phase!("non-uniform / clients hammer listeners") +@testset "non-uniform / clients hammer listeners" begin + if IS_LISTENER + # Listeners do not initiate RPC in this test — they just service. + @test true + else + K = 16 + futs = MPIFuture[] + for ℓ in LISTENER_RANKS + for k in 1:K + push!(futs, remotecall(+, ℓ, WORLD_RANK, k)) + end + end + @test pump_until(() -> all(isready, futs); timeout = 60.0) + # Each call computed `WORLD_RANK + k` (in unspecified ℓ). + for ℓ in LISTENER_RANKS, k in 1:K + idx = findfirst(==(WORLD_RANK + k), [fetch(f) for f in futs]) + @test idx !== nothing + end + end +end + +_phase!("non-uniform / handler observes its own rank") +@testset "non-uniform / handler observes its own rank" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + got = remotecall_fetch(peer) do + return MPI.Comm_rank(MPI.COMM_WORLD) + end + @test got == peer + else + @test true + end +end + +_phase!("non-uniform / many OIDs from one client to one listener") +@testset "non-uniform / many OIDs from one client to one listener" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + K = 64 + futs = [remotecall(+, peer, WORLD_RANK * 1000, k) for k in 1:K] + @test pump_until(() -> all(isready, futs); timeout = 30.0) + perm = randperm(K) + for i in perm + @test fetch(futs[i]) == WORLD_RANK * 1000 + i + end + else + @test true + end +end + +_phase!("non-uniform / multi-threaded handler-spawn-fetch") +@testset "non-uniform / listener handler that Threads.@spawn-then-fetches" begin + # Regression test for the multi-threaded handler deadlock. The listener + # receiving the call dispatches the handler on `Threads.@spawn`; the + # handler in turn spawns a task that calls back into another listener + # via `remotecall_fetch`, then `fetch`es it. This worked under the + # synchronous-handler model only because of `ReentrantLock` re-entry + # on the same task; under multi-threading it required removing + # `progress_lock` and running handlers on their own tasks. + if Threads.nthreads() >= 2 && length(LISTENER_RANKS) >= 2 && !IS_LISTENER + peer = LISTENER_RANKS[1] + helper = LISTENER_RANKS[2] + result = remotecall_fetch(peer, helper, WORLD_RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 5) + return fetch(t) + 1000 + end + @test result == WORLD_RANK + 5 + 1000 + else + @test true + end +end + +_phase!("non-uniform / concurrent client threads") +@testset "non-uniform / multiple client threads issuing concurrent RPC" begin + if Threads.nthreads() >= 2 && !IS_LISTENER + K = 16 + peer = first(LISTENER_RANKS) + results = Vector{Vector{Int}}(undef, Threads.nthreads()) + ts = Task[] + for tid in 1:Threads.nthreads() + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, peer, WORLD_RANK * 100 + tid, k) + end + results[tid] = got + end + push!(ts, t) + end + for t in ts + wait(t) + end + for tid in 1:Threads.nthreads() + @test results[tid] == [WORLD_RANK * 100 + tid + k for k in 1:K] + end + else + @test true + end +end + +_phase!("non-uniform / world barrier coexists with subcomm RPC") +@testset "non-uniform / world barrier coexists with subcomm RPC" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + f = remotecall(+, peer, WORLD_RANK, 1) + # Every rank in COMM_WORLD must reach this barrier so collectives + # stay aligned. Listeners reach it from the listener phase below. + MPI.Barrier(MPI.COMM_WORLD) + @test pump_until(() -> isready(f); timeout = 30.0) + @test fetch(f) == WORLD_RANK + 1 + else + # Listener: pump while clients post their RPC, then barrier with them. + t0 = time() + while time() - t0 < 0.5 + rpc_progress!(backend) + yield() + end + MPI.Barrier(MPI.COMM_WORLD) + # Pump a bit more so any pending replies clear out before _phase!. + t0 = time() + while time() - t0 < 0.5 + rpc_progress!(backend) + yield() + end + @test true + end +end + +_phase!("non-uniform / remote exception path") +@testset "non-uniform / remote exception path" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + @test_throws MPIRemoteException remotecall_fetch(peer) do + error("planned failure") + end + else + @test true + end +end + +_phase!("non-uniform / remote_do is fire-and-forget") +@testset "non-uniform / remote_do returns nothing immediately" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + @test remote_do(peer) do; nothing end === nothing + else + @test true + end +end + +_phase!("non-uniform / set-then-read via remotecall_wait") +@testset "non-uniform / set-then-read with remotecall_wait" begin + if !IS_LISTENER + peer = first(LISTENER_RANKS) + sentinel = Symbol("_NONUNI_SET_FROM_$(WORLD_RANK)") + # Under multi-threading the spawned-handler model does not preserve + # FIFO of handler *execution* on the same (src, dest, tag) — only + # MPI delivery. `remotecall_wait` provides the explicit ack we need + # before reading back the side effect. + remotecall_wait(peer, WORLD_RANK, sentinel) do origin, name + @eval Main const $name = $origin + end + v = remotecall_fetch(peer, sentinel) do name + return getfield(Main, name) + end + @test v == WORLD_RANK + else + @test true + end +end + +_phase!("non-uniform / shutdown") +shutdown!() +rpc_barrier() +MPI.Finalize() diff --git a/lib/MPIRPC/test/protocol_tests.jl b/lib/MPIRPC/test/protocol_tests.jl new file mode 100644 index 000000000..25dd3b381 --- /dev/null +++ b/lib/MPIRPC/test/protocol_tests.jl @@ -0,0 +1,120 @@ +using Test +using Serialization +using MPIRPC +using MPIRPC: MsgHeader, MPIRRID, NULL_RRID, CallMsg, CallWaitMsg, RemoteDoMsg, + ResultMsg, RPCProgressHaltMsg, encode_frame, decode_frame, MSG_BOUNDARY, ProtocolError, + is_null + +@testset "protocol / framing (no MPI)" begin + @testset "MPIRRID identity" begin + a = MPIRRID(Int32(2), UInt64(7)) + b = MPIRRID(Int32(2), UInt64(7)) + c = MPIRRID(Int32(3), UInt64(7)) + @test a == b + @test hash(a) == hash(b) + @test a != c + @test is_null(NULL_RRID) + @test !is_null(a) + end + + @testset "MsgHeader defaults" begin + @test is_null(MsgHeader().response_oid) + @test is_null(MsgHeader().notify_oid) + @test MsgHeader(MPIRRID(Int32(1), UInt64(2))).notify_oid == NULL_RRID + end + + @testset "CallMsg roundtrip preserves Mode and args" begin + h = MsgHeader(NULL_RRID, MPIRRID(Int32(0), UInt64(99))) + m = CallMsg{:call_fetch}(+, (1, 2, 3), Pair{Symbol,Any}[:by => 10]) + buf = encode_frame(h, m) + h2, m2, err = decode_frame(buf) + @test err === nothing + @test h2 == h + @test m2 isa CallMsg{:call_fetch} + @test m2.args == (1, 2, 3) + @test m2.kwargs == Pair{Symbol,Any}[:by => 10] + end + + @testset "ResultMsg roundtrip" begin + h = MsgHeader(MPIRRID(Int32(0), UInt64(1)), NULL_RRID) + m = ResultMsg([1.0, 2.0, 3.0]) + buf = encode_frame(h, m) + h2, m2, err = decode_frame(buf) + @test err === nothing + @test h2 == h + @test m2 isa ResultMsg + @test m2.value == [1.0, 2.0, 3.0] + end + + @testset "RemoteDoMsg has null OIDs" begin + h = MsgHeader(NULL_RRID, NULL_RRID) + m = RemoteDoMsg(println, ("hi",), Pair{Symbol,Any}[]) + buf = encode_frame(h, m) + _, m2, err = decode_frame(buf) + @test err === nothing + @test m2 isa RemoteDoMsg + end + + @testset "RPCProgressHaltMsg roundtrip" begin + buf = encode_frame(MsgHeader(), RPCProgressHaltMsg()) + h2, m2, err = decode_frame(buf) + @test err === nothing + @test h2 == MsgHeader() + @test m2 isa RPCProgressHaltMsg + end + + @testset "fresh serializer per frame: state does not leak" begin + # Two frames sharing a Vector{Int} should still round-trip identically; + # if state leaked across frames, the second decode would resolve a + # stale back-reference and crash. + v = collect(1:10) + b1 = encode_frame(MsgHeader(), ResultMsg(v)) + b2 = encode_frame(MsgHeader(), ResultMsg(v)) + _, m1, _ = decode_frame(b1) + _, m2, _ = decode_frame(b2) + @test m1.value == v + @test m2.value == v + end + + @testset "MSG_BOUNDARY mismatch is fail-fast" begin + buf = encode_frame(MsgHeader(), ResultMsg(42)) + buf[end] ⊻= 0xFF # corrupt last byte of the boundary + _, _, err = decode_frame(buf) + @test err isa ProtocolError + end + + @testset "boundary missing entirely is fail-fast" begin + buf = encode_frame(MsgHeader(), ResultMsg(42)) + truncated = buf[1:end - length(MSG_BOUNDARY) - 1] + _, _, err = decode_frame(truncated) + @test err !== nothing # may be EOFError or ProtocolError, both acceptable + end + + @testset "body deserialization error is captured, not thrown" begin + hbuf = IOBuffer() + s = Serializer(hbuf) + serialize(s, MsgHeader()) + hbytes = take!(hbuf) + garbage = UInt8[0xff for _ in 1:32] + full = vcat(hbytes, garbage, MSG_BOUNDARY) + h, m, err = decode_frame(full) + @test h isa MsgHeader # header parsed successfully + @test err !== nothing + @test m === nothing + end + + @testset "run_work_thunk print_error mirrors Distributed" begin + path, io = mktemp() + close(io) + open(path, "w") do f + redirect_stderr(f) do + v = MPIRPC.run_work_thunk(() -> error("MPIRPC_planned_thunk_failure"), 99; print_error=true) + @test v isa MPIRemoteException + @test v.rank == 99 + end + end + cap = read(path, String) + @test occursin("MPIRPC_planned_thunk_failure", cap) + rm(path) + end +end diff --git a/lib/MPIRPC/test/runtests.jl b/lib/MPIRPC/test/runtests.jl new file mode 100644 index 000000000..274ec58ed --- /dev/null +++ b/lib/MPIRPC/test/runtests.jl @@ -0,0 +1,7 @@ +using Test +using MPIRPC + +@testset "MPIRPC" begin + include("protocol_tests.jl") + include("mpi_tests.jl") +end diff --git a/lib/MPIRPC/test/uniform_daemon_mpiexec.jl b/lib/MPIRPC/test/uniform_daemon_mpiexec.jl new file mode 100644 index 000000000..d9652e823 --- /dev/null +++ b/lib/MPIRPC/test/uniform_daemon_mpiexec.jl @@ -0,0 +1,326 @@ +using Test +using MPI +using MPIRPC + +# This script verifies that the *yield-only progress daemon* is enough to +# drive the uniform backend end-to-end without a single explicit call to +# `rpc_progress!` from user code. The flag we never set in this script — +# anywhere — is `rpc_progress!`. If anything inside the testsets calls it, +# the test would not be measuring what it claims to. + +MPI.Init(; threadlevel = :multiple) +backend = MPIRPC.select_mpi_rpc_backend!( + UniformMPIRPCBackend(MPI.COMM_WORLD; daemon = true), +) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +NPROC >= 2 || error("uniform daemon tests require at least 2 ranks (got $NPROC)") +@assert backend.daemon "daemon flag did not stick" +@assert backend.daemon_task isa Task "daemon task should have been spawned by select_mpi_rpc_backend!" +@assert !istaskdone(backend.daemon_task) "daemon task already exited before any RPC ran" + +# The daemon must run on the `:interactive` pool so user CPU-bound work +# on the `:default` pool cannot starve the wire pump. `mpi_tests.jl` +# launches us with `--threads=N,1` to make at least one interactive +# thread available; if for some reason it didn't, the assertion below +# fails loudly rather than letting the test pass under a weaker +# guarantee than what the design promises. +if Threads.nthreads(:interactive) > 0 + @assert Threads.threadpool(backend.daemon_task) === :interactive ( + "daemon task expected on :interactive, got $(Threads.threadpool(backend.daemon_task))") +end + +# Phase boundary: still uses `rpc_barrier` because that is a *collective* +# synchronization, not a progress driver — every rank must reach it. The +# daemon happens to also pump progress on each rank while inside the +# barrier, which is fine: `rpc_barrier` and `_daemon_loop` both go +# through the same per-backend `mpi_lock` so they cannot race on raw MPI +# calls. +function _phase!(name::AbstractString) + rpc_barrier() + if RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +# --------------------------------------------------------------------------- + +_phase!("uniform-daemon / smoke (no manual progress)") +@testset "uniform-daemon / remotecall_fetch with no manual progress" begin + peer = mod(RANK + 1, NPROC) + got = remotecall_fetch(+, peer, RANK, 100) + @test got == RANK + 100 +end + +_phase!("uniform-daemon / ABBA without manual progress") +@testset "uniform-daemon / ABBA without manual progress" begin + # Same shape as the non-daemon ABBA test, but here neither client side + # nor server side has any user-driven `rpc_progress!`. Forward progress + # on inbound *and* on reply draining for our own outstanding futures + # is entirely the daemon's responsibility. (Reply draining works + # because `wait`/`fetch` on `MPIFuture` calls `rpc_progress!` itself, + # but the more interesting half — receiving inbound requests from + # peers — has no user touchpoint.) + peer = mod(RANK + 1, NPROC) + futs = [remotecall(*, peer, RANK + 1, k) for k in 1:8] + for (k, f) in enumerate(futs) + @test fetch(f) == (RANK + 1) * k + end +end + +_phase!("uniform-daemon / fully passive rank") +@testset "uniform-daemon / a rank that never initiates RPC still serves" begin + # The cleanest demonstration that the daemon is doing useful work: + # one rank issues every RPC, others issue none. Without a daemon the + # passive ranks would never see their inbound `CallMsg` because nobody + # is calling `rpc_progress!` on them. The active rank uses `wait`, + # which itself pumps progress, but that only drives *its own* MPI + # state — the inbound matching on the passive ranks is independent. + if RANK == 0 + results = Vector{Int}(undef, NPROC - 1) + for p in 1:(NPROC - 1) + results[p] = remotecall_fetch(+, p, p, 1000) + end + @test results == [p + 1000 for p in 1:(NPROC - 1)] + else + # Passive ranks: don't touch the API at all in this testset. + @test true + end +end + +_phase!("uniform-daemon / many concurrent client tasks") +@testset "uniform-daemon / concurrent client tasks rely on daemon for inbound" begin + if Threads.nthreads() >= 2 && NPROC >= 2 + peers = [r for r in 0:(NPROC - 1) if r != RANK] + K = 16 + results = Vector{Vector{Int}}(undef, length(peers)) + ts = Task[] + for (i, p) in enumerate(peers) + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, p, RANK * 1000, k) + end + results[i] = got + end + push!(ts, t) + end + foreach(wait, ts) + for (i, p) in enumerate(peers) + @test results[i] == [RANK * 1000 + k for k in 1:K] + end + else + @test true + end +end + +_phase!("uniform-daemon / handler that spawns and fetches an RPC") +@testset "uniform-daemon / nested fetch from spawned handler subtask" begin + # Same regression as in the non-daemon suite (the deadlock that + # motivated removing `progress_lock`), but with the daemon also + # racing to drain the request queue. We expect both to coexist. + if Threads.nthreads() >= 2 && NPROC >= 3 + peer = mod(RANK + 1, NPROC) + helper = mod(RANK + 2, NPROC) + result = remotecall_fetch(peer, helper, RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 7) + return fetch(t) * 2 + end + @test result == (RANK + 7) * 2 + else + @test true + end +end + +_phase!("uniform-daemon / starvation: CPU-bound default-pool task must not block daemon") +@testset "uniform-daemon / daemon survives CPU-bound default-pool work" begin + # Pathological-case test: spawn a CPU-bound task on the `:default` + # pool that does *not* yield, and concurrently issue RPC traffic to + # a peer. Under a naive design where the daemon shares the default + # pool with user work, the busy task could starve the daemon and + # the peer's `remotecall_fetch` call to *us* would never be + # serviced. With the daemon on `:interactive`, this cannot happen: + # the interactive thread is reserved. + if Threads.nthreads(:interactive) >= 1 && NPROC >= 2 + peer = mod(RANK + 1, NPROC) + # Burn ~0.5 s of CPU on every default-pool thread, with no + # explicit yield points. If the daemon were on :default it + # would be locked out for the duration. + burners = Task[] + deadline = time() + 0.5 + for _ in 1:Threads.nthreads(:default) + t = Threads.@spawn :default begin + acc = 0 + while time() < deadline + # Inner loop deliberately has no `yield()` and no + # I/O so the cooperative scheduler does not + # preempt this task. + for i in 1:1_000_000 + acc = (acc + i * 31) % 9_973 + end + end + acc + end + push!(burners, t) + end + # While the burners hammer the default pool, the peer is going + # to call remotecall_fetch on *us* — see the symmetric arm + # below. Our daemon on :interactive must keep accepting that + # request. We measure success by completing our own + # remotecall_fetch *to* the peer within a tight bound. + t0 = time() + got = MPIRPC.remotecall_fetch(+, peer, RANK, 7) + elapsed = time() - t0 + @test got == RANK + 7 + # Generous bound — what matters is that this completes at all + # while default-pool threads are saturated. Without the + # interactive-pool placement the call would block until all + # burners finish (≈ 0.5 s); with it, the call completes in a + # few ms. We give 5 s of slack so a slow CI does not flake. + @test elapsed < 5.0 + foreach(wait, burners) + else + @test true + end +end + +_phase!("uniform-daemon / cond-park: many concurrent waiters wake exactly once") +@testset "uniform-daemon / cond-park: many concurrent waiters wake correctly" begin + # Regression test for the `Threads.Condition`-park path in + # `wait(::MPIFuture)`. Spawn a large number of client tasks, each + # blocking in `fetch` on its own future, and verify every one + # observes its expected reply. The properties we want to lock in: + # + # 1. Every parked waiter is woken — no lost-wakeup race between + # `deliver!`'s `notify` and the consumer's `wait(f.cond)`. + # 2. Each waiter wakes for *its own* future (no cross-future + # wakeup that returns a stale value to the wrong task). + # 3. The daemon's `_dispatch_reply!` path correctly transitions + # through `take_waiter!` → `deliver!` → cond `notify` while + # none of `mpi_lock`, `waiters_lock`, `f.cond` are held more + # than one at a time. + # + # The test cannot directly assert "no spinning happened" without + # peeking at task state, but a regression to the spin path would + # still pass *correctness* — so we additionally check that the + # waiters are *responsive*, by serializing many short calls and + # measuring that they all complete inside a generous bound. If + # `wait` was somehow not waking on `notify` (e.g. a bug where we + # accidentally took a stale `isready` snapshot under the lock and + # parked anyway), this would time out. + if Threads.nthreads() >= 2 + peers = [r for r in 0:(NPROC - 1) if r != RANK] + K = 64 # waiters per peer + results = Vector{Vector{Int}}(undef, length(peers)) + ts = Task[] + t_start = time() + for (i, p) in enumerate(peers) + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + # Each task spawns its own future and immediately blocks + # in `fetch`. Under daemon=true this means each task + # parks on its own `f.cond` while the daemon services + # all of the inbound replies on this rank. + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, p, RANK * 10_000, k) + end + results[i] = got + end + push!(ts, t) + end + foreach(wait, ts) + elapsed = time() - t_start + for (i, p) in enumerate(peers) + @test results[i] == [RANK * 10_000 + k for k in 1:K] + end + # Loose upper bound. On a healthy machine this completes in + # well under a second; we give it 60 s to absorb CI variance. + # The point is to catch a hang, not to enforce performance. + @test elapsed < 60.0 + else + @test true + end +end + +_phase!("uniform-daemon / Imrecv!: large concurrent payloads round-trip") +@testset "uniform-daemon / large concurrent payloads via Imrecv! + Test/yield" begin + # Regression test for the non-blocking-receive path. Each task on + # this rank issues a `remotecall_fetch` with a 1 MiB payload. On + # most MPI implementations 1 MiB exceeds the eager threshold + # (typical defaults: 64 KiB for OpenMPI, 256 KiB for MPICH), so + # the receiving side actually goes through the rendezvous + # protocol — which is the one regime where the previous + # `MPI.Mrecv!` would have held `mpi_lock` for a non-trivial + # duration. With `Imrecv!` + `Test`/`yield` this is broken up. + # + # We verify only correctness (all replies arrive, with the right + # values) and that the whole batch completes in bounded time. + # Asserting "the daemon thread did not block" directly is not + # cleanly testable from user code; the existence and correctness + # of the round-trip under concurrent pressure is the practical + # proxy. + if Threads.nthreads() >= 2 && NPROC >= 2 + peer = mod(RANK + 1, NPROC) + N = 1_048_576 # 1 MiB + K = 4 # concurrent calls + # Use the same canonical payload across tasks so we can + # cross-check sums without reconstructing per-task arrays. + payload = Vector{UInt8}(undef, N) + for i in 1:N + payload[i] = UInt8((i - 1) % 256) + end + expected = sum(Int(b) for b in payload) + ts = Task[] + for _ in 1:K + t = Threads.@spawn MPIRPC.remotecall_fetch(peer, payload) do data + # Verify on the remote side that the buffer survived + # the rendezvous round-trip intact, then return the + # sum so the client can independently verify. + length(data) == N || error("size mismatch: got $(length(data))") + acc = 0 + for b in data + acc += Int(b) + end + return acc + end + push!(ts, t) + end + t0 = time() + results = fetch.(ts) + elapsed = time() - t0 + @test all(==(expected), results) + # Loose bound: with rendezvous and 4 concurrent 1-MiB payloads + # plus their replies, we expect well under 30 s on any sane + # interconnect (loopback / shared memory / TCP localhost). The + # point of the bound is to catch a hang from a leaked request + # or a missed `Test`, not to enforce performance. + @test elapsed < 30.0 + else + @test true + end +end + +_phase!("uniform-daemon / shutdown joins the daemon task") +@testset "uniform-daemon / shutdown! cleanly joins the daemon" begin + @test backend.daemon_task isa Task + @test !istaskdone(backend.daemon_task) + # `shutdown!` flips `running[]` to `false` and waits for the daemon + # task to terminate. After it returns, the daemon must be done and + # the `daemon_task` slot must have been cleared so a (hypothetical) + # subsequent re-init does not see a stale handle. + shutdown!() + @test backend.running[] == false + @test backend.daemon_task === nothing +end + +# Final coordination on the duplicate communicator. We cannot use +# `rpc_barrier` here: the daemon is gone, and every rank has called +# `shutdown!` so no rank is pumping progress. A plain `MPI.Barrier` +# on `backend.comm` is what we want — there is no in-flight RPC at +# this point because every testset above completed via `fetch`/`wait`. +MPI.Barrier(backend.comm) +MPI.Finalize() diff --git a/lib/MPIRPC/test/uniform_mpiexec.jl b/lib/MPIRPC/test/uniform_mpiexec.jl new file mode 100644 index 000000000..4f5187534 --- /dev/null +++ b/lib/MPIRPC/test/uniform_mpiexec.jl @@ -0,0 +1,336 @@ +using Test +using MPI +using MPIRPC +using Random + +# Each rank sets the *same* seed so any per-rank randomized peer-selection +# decision is reproducible across processes. +Random.seed!(0xC0FFEE) + +MPI.Init(; threadlevel=:multiple) +backend = MPIRPC.select_mpi_rpc_backend!(UniformMPIRPCBackend(MPI.COMM_WORLD)) +const COMM = backend.comm +const RANK = MPI.Comm_rank(COMM) +const NPROC = MPI.Comm_size(COMM) + +NPROC >= 2 || error("uniform tests require at least 2 ranks (got $NPROC)") + +# Synchronisation aid: a *progress-pumping* barrier on the RPC subcomm and +# announce the next testset on rank 0 so a hang is easy to localise. We +# specifically use `rpc_barrier` rather than `MPI.Barrier` so ranks continue +# servicing inbound RPC while waiting for stragglers — without this, a rank +# whose own primary `remotecall_fetch` completed could exit with a peer's +# nested-RPC request still queued for it. +function _phase!(name::AbstractString) + rpc_barrier() + if RANK == 0 + println("--- ", name, " ---") + flush(stdout) + end + rpc_barrier() +end + +# Pump progress until `pred()` is true or `timeout` seconds elapse. Returns +# `true` if the predicate became true. Liveness assertion: misuse of the API +# (forgetting to pump progress, holding a lock across fetch, etc.) fails in +# bounded time rather than hanging the whole CI. +function pump_until(pred; timeout::Real = 30.0, + backend = MPIRPC.current_mpi_rpc_backend()) + t0 = time() + while !pred() + rpc_progress!(backend) + time() - t0 > timeout && return false + yield() + end + return true +end + +# --------------------------------------------------------------------------- + +_phase!("uniform / smoke") +@testset "uniform / smoke" begin + peer = mod(RANK + 1, NPROC) + got = remotecall_fetch(+, peer, RANK, 100) + # `+` ran on `peer` over the args (RANK, 100), so the result is RANK + 100. + @test got == RANK + 100 +end + +_phase!("uniform / handler observes its own rank") +@testset "uniform / handler observes its own rank" begin + peer = mod(RANK + 1, NPROC) + got = remotecall_fetch(peer) do + return MPI.Comm_rank(MPI.COMM_WORLD) + end + @test got == peer +end + +_phase!("uniform / ABBA cross-calls") +@testset "uniform / ABBA: every rank simultaneously calls (i+1) % N" begin + # Classic ABBA pattern: rank i sends a request to (i+1) and must service + # the inbound request from (i-1) at the same time. With one tag per kind + # of message (request vs reply), correlation by RRID, and progress on + # every rank, this must complete deterministically. + peer = mod(RANK + 1, NPROC) + futs = [remotecall(*, peer, RANK + 1, k) for k in 1:8] + @test pump_until(() -> all(isready, futs); timeout = 30.0) + if all(isready, futs) + for (k, f) in enumerate(futs) + @test fetch(f) == (RANK + 1) * k + end + end +end + +_phase!("uniform / many OIDs, permuted fetch on one peer") +@testset "uniform / many OIDs, permuted fetch on one peer" begin + peer = mod(RANK + 1, NPROC) + K = 64 + expected = [RANK * 1000 + k for k in 1:K] + futs = [remotecall(+, peer, RANK * 1000, k) for k in 1:K] + + perm = randperm(K) + @test pump_until(() -> all(isready, futs); timeout = 30.0) + if all(isready, futs) + got = Vector{Int}(undef, K) + for i in perm + got[i] = fetch(futs[i]) + end + @test got == expected + end +end + +_phase!("uniform / all-to-all burst") +@testset "uniform / all-to-all burst" begin + peers = [r for r in 0:(NPROC-1) if r != RANK] + futs = [remotecall(+, p, RANK, p) for p in peers] + @test pump_until(() -> all(isready, futs); timeout = 30.0) + if all(isready, futs) + for (p, f) in zip(peers, futs) + @test fetch(f) == RANK + p + end + end +end + +_phase!("uniform / same-pair, different correlation ids") +@testset "uniform / same-pair, different correlation ids" begin + # Two messages from the same (src, dest) at the same time, distinguished + # only by RRID. MPI guarantees in-order delivery on (src, dest, tag), but + # the *application-level* matching must use RRID rather than message order + # to avoid coupling correctness to MPI ordering of unrelated calls. + peer = mod(RANK + 1, NPROC) + f1 = remotecall(identity, peer, :first) + f2 = remotecall(identity, peer, :second) + @test pump_until(() -> isready(f1) && isready(f2); timeout = 30.0) + @test fetch(f1) == :first + @test fetch(f2) == :second +end + +_phase!("uniform / nested re-entrant remotecall_fetch") +@testset "uniform / nested re-entrant remotecall_fetch" begin + if NPROC >= 3 + peer = mod(RANK + 1, NPROC) + helper = mod(RANK + 2, NPROC) + # Inside the handler on `peer`, call back to a third rank. Exercises + # nested progress: while `peer` is servicing rank `RANK`'s request, + # its handler must be able to issue and await a fresh + # remotecall_fetch — and the originating rank must keep pumping + # progress so its outer reply can come back. + result = remotecall_fetch(peer, helper, RANK) do helper_rank, origin + inner = MPIRPC.remotecall_fetch(+, helper_rank, origin, 1) + return inner * 10 + end + @test result == (RANK + 1) * 10 + else + @test true # not enough ranks + end +end + +_phase!("uniform / multi-threaded handler-spawn-fetch") +@testset "uniform / handler that Threads.@spawn-then-fetches an RPC" begin + # Regression test for the deadlock previously documented in + # `docs/ARCHITECTURE.md` §6: with synchronous handlers, the handler ran + # on the calling task while `progress_lock` was held. A handler that + # spawned its own task and called `fetch` on it would block forever + # because the spawned task could not acquire `progress_lock` from + # another OS thread (the lock was held by the handler's task on a + # different thread). After moving handlers to `Threads.@spawn`, this + # pattern works. + if Threads.nthreads() >= 2 && NPROC >= 3 + peer = mod(RANK + 1, NPROC) + helper = mod(RANK + 2, NPROC) + result = remotecall_fetch(peer, helper, RANK) do helper_rank, origin + t = Threads.@spawn MPIRPC.remotecall_fetch(+, helper_rank, origin, 7) + return fetch(t) * 2 + end + @test result == (RANK + 7) * 2 + else + @test true # need julia -t >=2 and at least 3 ranks to exercise this + end +end + +_phase!("uniform / many handlers, threads concurrent waiters") +@testset "uniform / concurrent client threads pumping the same backend" begin + # Several client tasks on the same rank issue concurrent + # `remotecall_fetch` calls. With `progress_lock` removed, every task + # can drive progress on its own; with handlers spawned on the peer, + # several of *its* handler tasks can run on different threads. The + # combined effect is a stress test for the lock geometry under + # `julia -t >=2`. + if Threads.nthreads() >= 2 + peers = [r for r in 0:(NPROC-1) if r != RANK] + K = 16 + results = Vector{Vector{Int}}(undef, length(peers)) + client_tasks = Task[] + for (i, p) in enumerate(peers) + t = Threads.@spawn begin + local got = Vector{Int}(undef, K) + for k in 1:K + got[k] = MPIRPC.remotecall_fetch(+, p, RANK * 1000, k) + end + results[i] = got + end + push!(client_tasks, t) + end + for t in client_tasks + wait(t) + end + for (i, p) in enumerate(peers) + @test results[i] == [RANK * 1000 + k for k in 1:K] + end + else + @test true + end +end + +_phase!("uniform / remote_do is fire-and-forget") +@testset "uniform / remote_do returns nothing immediately" begin + peer = mod(RANK + 1, NPROC) + # `remote_do` is fire-and-forget: it returns `nothing` as soon as the + # request has been posted, regardless of whether the remote handler + # has run. + @test remote_do(peer) do; nothing end === nothing +end + +_phase!("uniform / remote_do failure prints to stderr on handler rank") +@testset "uniform / remote_do failure prints to stderr on handler rank" begin + # Rank 0 issues `remote_do` to rank 1; rank 1 captures stderr while its + # progress loop (spawned task) services the inbound message so + # `showerror` from `run_work_thunk(...; print_error=true)` lands in the file. + cap_mark = "MPIRPC_planned_remote_do_failure" + cap_path = joinpath(mktempdir(), "mpirpc_remote_do_stderr.txt") + cap_task = nothing + if RANK == 1 + ready = Channel{Nothing}(1) + cap_task = Threads.@spawn begin + open(cap_path, "w") do f + redirect_stderr(f) do + put!(ready, nothing) + t0 = time() + while time() - t0 < 30.0 + rpc_progress!(backend) + flush(f) + yield() + if isfile(cap_path) && filesize(cap_path) > 0 + s = read(cap_path, String) + occursin(cap_mark, s) && break + end + end + end + end + return read(cap_path, String) + end + take!(ready) + end + rpc_barrier() + if RANK == 0 + @test remote_do(1) do + error(cap_mark) + end === nothing + end + rpc_barrier() + if RANK == 1 + txt = fetch(cap_task::Task) + @test occursin(cap_mark, txt) + rm(cap_path, force=true) + end + rpc_barrier() +end + +_phase!("uniform / set-then-read via remotecall_wait") +@testset "uniform / set-then-read with remotecall_wait" begin + peer = mod(RANK + 1, NPROC) + sentinel_name = Symbol("_MPIRPC_SET_FROM_$(RANK)") + # `remotecall_wait` blocks the caller until the remote handler + # acknowledges completion, which is the correct synchronization + # primitive when a follow-up call needs to observe the side effect. + # Under multi-threading the spawned-handler model does *not* preserve + # FIFO of *handler execution* on the same `(src, dest, tag)` — only + # FIFO of message delivery — so a `remote_do` followed by a sync + # `remotecall_fetch` would race. + remotecall_wait(peer, RANK, sentinel_name) do origin, name + @eval Main const $name = $origin + end + val = remotecall_fetch(peer, sentinel_name) do name + return getfield(Main, name) + end + @test val == RANK +end + +_phase!("uniform / remote exception is wrapped") +@testset "uniform / remote exception is wrapped" begin + peer = mod(RANK + 1, NPROC) + @test_throws MPIRemoteException remotecall_fetch(peer) do + error("planned failure on remote rank") + end +end + +_phase!("uniform / dest validation") +@testset "uniform / dest validation" begin + @test_throws ArgumentError remotecall(+, NPROC, 1) + @test_throws ArgumentError remotecall(+, -1, 1) +end + +_phase!("uniform / world barrier coexists with subcomm RPC") +@testset "uniform / collective on world comm coexists with RPC subcomm" begin + # The uniform backend dups MPI.COMM_WORLD by default, so a Barrier on the + # original world communicator does not collide with pending RPC traffic. + peer = mod(RANK + 1, NPROC) + fut = remotecall(+, peer, RANK, 1) + MPI.Barrier(MPI.COMM_WORLD) + @test pump_until(() -> isready(fut); timeout = 30.0) + if isready(fut) + @test fetch(fut) == RANK + 1 + end + MPI.Barrier(MPI.COMM_WORLD) +end + +_phase!("uniform / stress: rounds of permuted all-to-all") +@testset "uniform / stress: rounds of permuted all-to-all" begin + M = 4 + failures = 0 + for round in 1:M + # Same seed across ranks ⇒ same permutation; identical scheduling + # decisions on every rank so any failure is reproducible. + Random.seed!(0xCAFE00 + round) + peers = collect(0:(NPROC-1))[randperm(NPROC)] + peers_nonself = [p for p in peers if p != RANK] + futs = [remotecall(+, p, RANK, round) for p in peers_nonself] + completed = pump_until(() -> all(isready, futs); timeout = 30.0) + if completed + for f in futs + fetch(f) == RANK + round || (failures += 1) + end + else + failures += length(peers_nonself) + end + # Progress-pumping barrier between rounds so any round-N inner + # request still in flight on a peer is drained before round N+1 + # starts, even if a few ranks raced ahead of others. + rpc_barrier() + end + @test failures == 0 +end + +_phase!("uniform / shutdown") +shutdown!() +rpc_barrier() +MPI.Finalize() diff --git a/src/Dagger.jl b/src/Dagger.jl index 2e757ebc5..1a5720784 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -53,6 +53,13 @@ import Adapt include("lib/util.jl") include("utils/dagdebug.jl") +# Type definitions (for MPI/acceleration) +include("types/processor.jl") +include("types/scope.jl") +include("types/memory-space.jl") +include("types/chunk.jl") +include("types/acceleration.jl") + # Distributed data include("utils/locked-object.jl") include("utils/tasks.jl") @@ -77,12 +84,14 @@ include("queue.jl") include("thunk.jl") include("utils/fetch.jl") include("utils/chunks.jl") +include("weakchunk.jl") include("utils/logging.jl") include("submission.jl") abstract type MemorySpace end include("utils/memory-span.jl") include("utils/interval_tree.jl") include("memory-spaces.jl") +include("acceleration.jl") # Task scheduling include("compute.jl") @@ -90,6 +99,7 @@ include("utils/clock.jl") include("utils/system_uuid.jl") include("utils/caching.jl") include("sch/Sch.jl"); using .Sch +include("tochunk.jl") # Data dependency task queue include("datadeps/aliasing.jl") @@ -157,6 +167,10 @@ function set_distributed_package!(value) @info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!" end +# MPI (mpi.jl loads MPI; mpi_mempool uses it) +include("mpi.jl") +include("mpi_mempool.jl") + # Precompilation import PrecompileTools: @compile_workload include("precompile.jl") diff --git a/src/acceleration.jl b/src/acceleration.jl new file mode 100644 index 000000000..f95236468 --- /dev/null +++ b/src/acceleration.jl @@ -0,0 +1,46 @@ +const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) + +current_acceleration() = ACCELERATION[] + +default_processor(::DistributedAcceleration) = OSProc(myid()) +default_processor(accel::DistributedAcceleration, x) = default_processor(accel) +default_processor() = default_processor(current_acceleration()) + +accelerate!(accel::Symbol) = accelerate!(Val{accel}()) +accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) + +function _with_default_acceleration(f) + old_accel = ACCELERATION[] + ACCELERATION[] = DistributedAcceleration() + result = try + f() + finally + ACCELERATION[] = old_accel + end + return result +end + +initialize_acceleration!(a::DistributedAcceleration) = nothing +function accelerate!(accel::Acceleration) + initialize_acceleration!(accel) + ACCELERATION[] = accel +end +accelerate!(::Nothing) = nothing + +accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true +accel_matches_proc(accel::DistributedAcceleration, proc) = true + +function compatible_processors(accel::Union{Acceleration,Nothing}, scope::AbstractScope, procs::Vector{<:Processor}) + comp = compatible_processors(scope, procs) + accel === nothing && return comp + return Set(p for p in comp if accel_matches_proc(accel, p)) +end + +uniform_execution(::DistributedAcceleration) = false +uniform_execution() = uniform_execution(current_acceleration()) + +default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) +default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) +default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) +default_memory_space(x) = default_memory_space(current_acceleration(), x) +default_memory_space() = default_memory_space(current_acceleration()) diff --git a/src/affinity.jl b/src/affinity.jl new file mode 100644 index 000000000..aab663a51 --- /dev/null +++ b/src/affinity.jl @@ -0,0 +1,32 @@ +export domain, UnitDomain, project, alignfirst, ArrayDomain + +import Base: isempty, getindex, intersect, ==, size, length, ndims + +""" + domain(x::T) + +Returns metadata about `x`. This metadata will be in the `domain` +field of a Chunk object when an object of type `T` is created as +the result of evaluating a Thunk. +""" +function domain end + +""" + UnitDomain + +Default domain -- has no information about the value +""" +struct UnitDomain end + +""" +If no `domain` method is defined on an object, then +we use the `UnitDomain` on it. A `UnitDomain` is indivisible. +""" +domain(x::Any) = UnitDomain() + +### ChunkIO +affinity(r::DRef) = OSProc(r.owner)=>r.size +# this previously returned a vector with all machines that had the file cached +# but now only returns the owner and size, for consistency with affinity(::DRef), +# see #295 +affinity(r::FileRef) = OSProc(1)=>r.size diff --git a/src/array/alloc.jl b/src/array/alloc.jl index aa1050210..33de3506d 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -93,14 +93,31 @@ function stage(ctx, A::AllocateArray) scope = ExactScope(A.procgrid[CartesianIndex(mod1.(Tuple(I), size(A.procgrid))...)]) end + N = ndims(A.domainchunks) + ret_type = Array{A.eltype, N} if A.want_index - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, i, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, i, size(x)) else - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, size(x)) end tasks[i] = task end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = ndims(A.domainchunks) + expected_type = Array{A.eltype, N} + Dagger.mpi_propagate_chunk_types!(tasks, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in tasks] + if allequal(chunk_types) + @info "[rank $rank] Array creation (alloc): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (alloc): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(A.eltype, A.domain, A.domainchunks, tasks, A.partitioning) end diff --git a/src/array/darray.jl b/src/array/darray.jl index 32336f95d..fc99dc75d 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -1,7 +1,7 @@ -import Base: ==, fetch +import Base: ==, fetch, length, isempty, size export DArray, DVector, DMatrix, DVecOrMat, Blocks, AutoBlocks -export distribute +export distribute, collect_datadeps ###### Array Domains ###### @@ -36,7 +36,7 @@ Base.getindex(arr::AbstractArray{T,0} where T, d::ArrayDomain{0}) = arr Base.getindex(arr::GPUArraysCore.AbstractGPUArray, d::ArrayDomain) = arr[indexes(d)...] Base.getindex(arr::GPUArraysCore.AbstractGPUArray{T,0} where T, d::ArrayDomain{0}) = arr -function intersect(a::ArrayDomain, b::ArrayDomain) +function Base.intersect(a::ArrayDomain, b::ArrayDomain) if a === b return a end @@ -83,7 +83,8 @@ isempty(a::ArrayDomain) = length(a) == 0 The domain of an array is an ArrayDomain. """ domain(x::AbstractArray) = ArrayDomain([1:l for l in size(x)]) - +# Scalar / non-array values (e.g. for Chunk of immediate data) +domain(x::Any) = ArrayDomain(()) abstract type ArrayOp{T, N} <: AbstractArray{T, N} end Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian() @@ -174,6 +175,7 @@ domain(d::DArray) = d.domain chunks(d::DArray) = d.chunks domainchunks(d::DArray) = d.subdomains size(x::DArray) = size(domain(x)) +Base.ndims(d::DArray{T,N}) where {T,N} = N stage(ctx, c::DArray) = c function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} @@ -200,6 +202,31 @@ function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} collect(treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))) end end + +""" + collect_datadeps(d::DArray; root=nothing) + +Collect a DArray to a single array by fetching every chunk on the current rank +and assembling into a full array. No datadeps scheduling or root-only assembly: +each rank that calls this gets the full matrix (useful when correctness matters +more than communication cost). +""" +function collect_datadeps(d::DArray{T,N}; root=nothing) where {T,N} + if isempty(d.chunks) + return Array{eltype(d)}(undef, size(d)...) + end + if N == 0 + return fetch(d.chunks[1]) + end + + chks = d.chunks + doms = domainchunks(d) + out = Array{T,N}(undef, size(d)) + for I in CartesianIndices(chks) + copyto!(view(out, indexes(doms[I])...), fetch(chks[I])) + end + return out +end Array{T,N}(A::DArray{S,N}) where {T,N,S} = convert(Array{T,N}, collect(A)) Base.wait(A::DArray) = foreach(wait, A.chunks) @@ -483,6 +510,21 @@ function stage(ctx::Context, d::Distribute) Dagger.@spawn compute_scope=scope identity(d.data[c]) end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = Base.ndims(d.data) + expected_type = Array{eltype(d.data), N} + Dagger.mpi_propagate_chunk_types!(cs, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in cs] + if allequal(chunk_types) + @info "[rank $rank] Array creation (distribute): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (distribute): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(eltype(d.data), domain(d.data), d.domainchunks, @@ -620,7 +662,7 @@ end mapchunk(f, chunk) = tochunk(f(poolget(chunk.handle))) function mapchunks(f, d::DArray{T,N,F}) where {T,N,F} chunks = map(d.chunks) do chunk - owner = get_parent(chunk.processor).pid + owner = root_worker_id(chunk.processor) remotecall_fetch(mapchunk, owner, f, chunk) end DArray{T,N,F}(d.domain, d.subdomains, chunks, d.concat) diff --git a/src/array/lu.jl b/src/array/lu.jl index b669100e2..7fbd87991 100644 --- a/src/array/lu.jl +++ b/src/array/lu.jl @@ -107,6 +107,7 @@ function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv_chunk: A[p,:], M[r,:] = M[r,:], A[p,:] end end + return A end @kernel function _geru_kernel!(alpha, x, y, A) @@ -156,6 +157,7 @@ function swaprows_trail!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv::Abstr end end end + return A end # Implementation of https://inria.hal.science/hal-04984070v1/file/ipdps_paper.pdf diff --git a/src/array/mul.jl b/src/array/mul.jl index 02b207641..5890473da 100644 --- a/src/array/mul.jl +++ b/src/array/mul.jl @@ -41,7 +41,7 @@ function LinearAlgebra.generic_matmatmul!( return gemm_dagger!(C, transA, transB, A, B, alpha, beta) end end -function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) +function _repartition_matmatmul(C, A, B, transA::Char, transB::Char)::Tuple{Blocks{2}, Blocks{2}, Blocks{2}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -93,6 +93,24 @@ function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) return Blocks(partC...), Blocks(partA...), Blocks(partB...) end +# Typed BLAS wrappers so that every @spawn kernel has an inferable return type +@inline function _gemm!(transA::Char, transB::Char, alpha::T, A, B, mzone, C)::Matrix{T} where {T} + BLAS.gemm!(transA, transB, alpha, A, B, mzone, C) + return C +end +@inline function _syrk!(uplo::AbstractChar, trans::AbstractChar, alpha::T, A, mzone, C)::Matrix{T} where {T} + BLAS.syrk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _herk!(uplo::AbstractChar, trans::AbstractChar, alpha::Real, A, mzone, C)::Matrix{<:Complex} + BLAS.herk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _gemv!(transA::Char, alpha::T, A, x, mzone, y)::Vector{T} where {T} + BLAS.gemv!(transA, alpha, A, x, mzone, y) + return y +end + """ Performs one of the matrix-matrix operations @@ -136,7 +154,7 @@ function gemm_dagger!( # A: NoTrans / B: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -150,7 +168,7 @@ function gemm_dagger!( # A: NoTrans / B: [Conj]Trans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -166,7 +184,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: NoTrans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -180,7 +198,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -243,7 +261,7 @@ function syrk_dagger!( for k in range(1, Ant) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, trans, real(alpha), @@ -252,7 +270,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -267,7 +285,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -283,7 +301,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -300,7 +318,7 @@ function syrk_dagger!( for k in range(1, Amt) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, transs, real(alpha), @@ -309,7 +327,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -324,7 +342,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -340,7 +358,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -393,16 +411,17 @@ end return A end -@inline function copytile!(A, B) +@inline function copytile!(A::AbstractMatrix{T}, B::AbstractMatrix{T})::Nothing where {T} m, n = size(A) C = B' for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end -@inline function copydiagtile!(A, uplo) +@inline function copydiagtile!(A::AbstractMatrix{T}, uplo::AbstractChar)::Nothing where {T} m, n = size(A) Acpy = copy(A) @@ -417,6 +436,7 @@ end for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end function LinearAlgebra.generic_matvecmul!( C::DVector{T}, @@ -440,7 +460,7 @@ function LinearAlgebra.generic_matvecmul!( return gemv_dagger!(C, transA, A, B, _alpha, _beta) end end -function _repartition_matvecmul(C, A, B, transA::Char) +function _repartition_matvecmul(C, A, B, transA::Char)::Tuple{Blocks{1}, Blocks{2}, Blocks{1}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -495,7 +515,7 @@ function gemv_dagger!( # A: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[m, k]), @@ -508,7 +528,7 @@ function gemv_dagger!( # A: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[k, m]), diff --git a/src/array/trsm.jl b/src/array/trsm.jl index 65e87c5d5..c0c025468 100644 --- a/src/array/trsm.jl +++ b/src/array/trsm.jl @@ -189,4 +189,4 @@ function trsm!(side::Char, uplo::Char, trans::Char, diag::Char, alpha::T, A::DMa end end -end \ No newline at end of file +end diff --git a/src/chunks.jl b/src/chunks.jl index 03bdfb65d..d5e7b6082 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -1,56 +1,4 @@ -export domain, UnitDomain, project, alignfirst, ArrayDomain - -import Base: isempty, getindex, intersect, ==, size, length, ndims - -""" - domain(x::T) - -Returns metadata about `x`. This metadata will be in the `domain` -field of a Chunk object when an object of type `T` is created as -the result of evaluating a Thunk. -""" -function domain end - -""" - UnitDomain - -Default domain -- has no information about the value -""" -struct UnitDomain end - -""" -If no `domain` method is defined on an object, then -we use the `UnitDomain` on it. A `UnitDomain` is indivisible. -""" -domain(x::Any) = UnitDomain() - -###### Chunk ###### - -""" - Chunk - -A reference to a piece of data located on a remote worker. `Chunk`s are -typically created with `Dagger.tochunk(data)`, and the data can then be -accessed from any worker with `collect(::Chunk)`. `Chunk`s are -serialization-safe, and use distributed refcounting (provided by -`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, -as long as a reference exists on some worker. - -Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a -sense) the processor that "owns" or contains the data. Calling -`collect(::Chunk)` will perform data movement and conversions defined by that -processor to safely serialize the data to the calling worker. - -## Constructors -See [`tochunk`](@ref). -""" -mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} - chunktype::Type{T} - domain - handle::H - processor::P - scope::S -end +###### Chunk Methods ###### domain(c::Chunk) = c.domain chunktype(c::Chunk) = c.chunktype @@ -72,20 +20,27 @@ function collect(ctx::Context, chunk::Chunk; options=nothing) elseif chunk.handle isa FileRef return poolget(chunk.handle) else - return move(chunk.processor, OSProc(), chunk.handle) + return move(chunk.processor, default_processor(), chunk.handle) end end collect(ctx::Context, ref::DRef; options=nothing) = move(OSProc(ref.owner), OSProc(), ref) collect(ctx::Context, ref::FileRef; options=nothing) = poolget(ref) # FIXME: Do move call -function Base.fetch(chunk::Chunk; raw=false) - if raw - poolget(chunk.handle) - else - collect(chunk) +@warn "Fix semantics of collect" maxlog=1 +function Base.fetch(chunk::Chunk{T}; unwrap::Bool=false, uniform::Bool=uniform_execution(), kwargs...) where T + value = fetch_handle(chunk.handle; uniform)::T + if unwrap && unwrappable(value) + return fetch(value; unwrap, uniform, kwargs...) end + return value end +fetch_handle(ref::DRef; uniform::Bool) = poolget(ref) +fetch_handle(ref::FileRef; uniform::Bool) = poolget(ref) +unwrappable(x::Chunk) = true +unwrappable(x::DRef) = true +unwrappable(x::FileRef) = true +unwrappable(x) = false # Unwrap Chunk, DRef, and FileRef by default move(from_proc::Processor, to_proc::Processor, x::Chunk) = @@ -100,32 +55,3 @@ move(to_proc::Processor, d::DRef) = move(OSProc(d.owner), to_proc, d) move(to_proc::Processor, x) = move(OSProc(), to_proc, x) - -### ChunkIO -affinity(r::DRef) = OSProc(r.owner)=>r.size -# this previously returned a vector with all machines that had the file cached -# but now only returns the owner and size, for consistency with affinity(::DRef), -# see #295 -affinity(r::FileRef) = OSProc(1)=>r.size - -struct WeakChunk - wid::Int - id::Int - x::WeakRef - function WeakChunk(c::Chunk) - return new(c.handle.owner, c.handle.id, WeakRef(c)) - end -end -unwrap_weak(c::WeakChunk) = c.x.value -function unwrap_weak_checked(c::WeakChunk) - cw = unwrap_weak(c) - @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" - return cw -end -wrap_weak(c::Chunk) = WeakChunk(c) -isweak(c::WeakChunk) = true -isweak(c::Chunk) = false -is_task_or_chunk(c::WeakChunk) = true -Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = - error("Cannot serialize a WeakChunk") -chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index c3e0ed20b..963629464 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -226,6 +226,7 @@ struct ArgumentWrapper function ArgumentWrapper(arg, dep_mod) h = hash(dep_mod) h = _identity_hash(arg, h) + check_uniform(h, arg) return new(arg, dep_mod, h) end end @@ -242,14 +243,16 @@ struct HistoryEntry end struct AliasedObjectCacheStore + accel::Acceleration keys::Vector{AbstractAliasing} derived::Dict{AbstractAliasing,AbstractAliasing} stored::Dict{MemorySpace,Set{AbstractAliasing}} values::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} originals::Set{AbstractAliasing} end -AliasedObjectCacheStore() = - AliasedObjectCacheStore(Vector{AbstractAliasing}(), +AliasedObjectCacheStore(accel::Acceleration) = + AliasedObjectCacheStore(accel, + Vector{AbstractAliasing}(), Dict{AbstractAliasing,AbstractAliasing}(), Dict{MemorySpace,Set{AbstractAliasing}}(), Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}(), @@ -278,8 +281,9 @@ function get_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::A end function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) @assert !is_stored(cache, dest_space, ainfo) "Cache already has derived ainfo $ainfo" + check_uniform(value) key = cache.derived[ainfo] - value_ainfo = aliasing(value, identity) + value_ainfo = aliasing(cache.accel, value, identity) cache.derived[value_ainfo] = key push!(get!(Set{AbstractAliasing}, cache.stored, dest_space), key) values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, dest_space) @@ -287,6 +291,7 @@ function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, va return end function set_key_stored!(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + check_uniform(value) push!(cache.keys, ainfo) cache.derived[ainfo] = ainfo push!(get!(Set{AbstractAliasing}, cache.stored, space), ainfo) @@ -296,6 +301,7 @@ function set_key_stored!(cache::AliasedObjectCacheStore, space::MemorySpace, ain end struct AliasedObjectCache + accel::Acceleration space::MemorySpace chunk::Chunk end @@ -323,7 +329,8 @@ function get_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore return get_stored(cache_raw, cache.space, ainfo) end -function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) + +function set_stored!(accel::DistributedAcceleration, cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) wid = root_worker_id(cache.chunk) if wid != myid() return remotecall_fetch(set_stored!, wid, cache, value, ainfo) @@ -332,7 +339,8 @@ function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAli set_stored!(cache_raw, cache.space, value, ainfo) return end -function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + +function set_key_stored!(accel::DistributedAcceleration, cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) wid = root_worker_id(cache.chunk) if wid != myid() return remotecall_fetch(set_key_stored!, wid, cache, space, ainfo, value) @@ -340,25 +348,29 @@ function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::A cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore set_key_stored!(cache_raw, space, ainfo, value) end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) + +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(cache.accel, x, identity)) x_space = memory_space(x) if !is_key_present(cache, x_space, ainfo) # Preserve the object's memory-space/processor pairing when inserting # the source key. Using bare `tochunk(x)` defaults to OSProc, which can # incorrectly wrap GPU-backed objects as CPU chunks. x_chunk = x isa Chunk ? x : tochunk(x, first(processors(x_space))) - set_key_stored!(cache, x_space, ainfo, x_chunk) + set_key_stored!(cache.accel, cache, x_space, ainfo, x_chunk) end if is_stored(cache, ainfo) return get_stored(cache, ainfo) else y = f(x) @assert y isa Chunk "Didn't get a Chunk from functor" + a = memory_space(y) + b = cache.space + @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" if memory_space(x) != cache.space - @assert ainfo != aliasing(y, identity) "Aliasing mismatch! $ainfo == $(aliasing(y, identity))" + @assert ainfo != aliasing(cache.accel, y, identity) "Aliasing mismatch! $ainfo == $(aliasing(cache.accel, y, identity))" end - set_stored!(cache, y, ainfo) + set_stored!(cache.accel, cache, y, ainfo) return y end end @@ -436,7 +448,10 @@ struct DataDepsState arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() - ainfo_backing_chunk = tochunk(AliasedObjectCacheStore()) + accel = current_acceleration() + ainfo_backing_chunk = _with_default_acceleration() do + tochunk(AliasedObjectCacheStore(accel)) + end supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() @@ -497,7 +512,9 @@ function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, t arg_chunk = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - arg_chunk = tochunk(arg) + arg_chunk = with(MPI_TID=>task.uid) do + tochunk(arg) + end state.raw_arg_to_chunk[arg] = arg_chunk else state.raw_arg_to_chunk[arg] = arg @@ -507,6 +524,7 @@ function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, t # Track the origin space of the argument origin_space = memory_space(arg_chunk) + check_uniform(origin_space) state.arg_origin[arg_chunk] = origin_space state.remote_arg_to_original[arg_chunk] = arg_chunk @@ -568,7 +586,7 @@ function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::Argum end # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + ainfo = AliasingWrapper(aliasing(current_acceleration(), remote_arg, arg_w.dep_mod)) # Cache the result state.ainfo_cache[remote_arg_w] = ainfo @@ -671,7 +689,10 @@ region returns. """ supports_inplace_move(x) = true supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +@warn "Fix supports_inplace_move for MPI" maxlog=1 function supports_inplace_move(c::Chunk) + # FIXME + return true # FIXME: Use MemPool.access_ref pid = root_worker_id(c.processor) if pid == myid() @@ -748,24 +769,35 @@ end isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true +@warn "Properly propagate MPI_TID and uniformity through any remotecalls" maxlog=1 function generate_slot!(state::DataDepsState, dest_space, data) # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. orig_space = memory_space(data) + check_uniform(orig_space) to_proc = first(processors(dest_space)) + check_uniform(to_proc) from_proc = first(processors(orig_space)) + check_uniform(from_proc) + check_uniform(typeof(data)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - aliased_object_cache = AliasedObjectCache(dest_space, state.ainfo_backing_chunk) + aliased_object_cache = AliasedObjectCache(current_acceleration(), dest_space, state.ainfo_backing_chunk) ctx = Sch.eager_context() id = rand(Int) @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + data_chunk = with(MPI_TID=>DATADEPS_CURRENT_TASK[].uid) do + remotecall_endpoint_toplevel(move_rewrap, current_acceleration(), aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) + end @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data + check_uniform(memory_space(dest_space_args[data])) + check_uniform(processor(dest_space_args[data])) + check_uniform(dest_space_args[data].handle) + return dest_space_args[data] end function get_or_generate_slot!(state, dest_space, data) @@ -778,64 +810,63 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) - to_w = root_worker_id(to_proc) - if to_w == myid() - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) - end - return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) + +function remotecall_fetch_fast(::DistributedAcceleration, f, target, args...; kwargs...) + wid = root_worker_id(target) + if wid == myid() + return f(args...; kwargs...) end + return remotecall_fetch(f, wid, args...; kwargs...) end -function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) - return aliased_object!(cache, x) do x - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) - end + +function is_local(accel::DistributedAcceleration, target) + return root_worker_id(target) == myid() end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) - # Unwrap so that we hit the right dispatch - wid = root_worker_id(data) - if wid != myid() - return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) + +function remotecall_endpoint_toplevel(f, accel::DistributedAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) + return remotecall_fetch_fast(root_worker_id(from_proc)) do + data_raw = unwrap(data) + return f(accel, cache, from_proc, to_proc, from_space, to_space, data_raw)::Chunk + end +end +function remotecall_endpoint_transfer(f, accel::DistributedAcceleration, from_proc, to_proc, from_space, to_space, data) + return remotecall_fetch_fast(root_worker_id(to_proc)) do + return f(accel, from_proc, to_proc, from_space, to_space, data) end - data_raw = unwrap(data) - return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - # For generic data +@warn "Replace all remotecall_fetch calls with remotecall_endpoint" maxlog=1 +move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) = + remotecall_endpoint(move_rewrap, accel, cache, from_proc, to_proc, from_space, to_space, data) +function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + # Generic data, do the transfer return aliased_object!(cache, data) do data - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, data) do accel, from_proc, to_proc, from_space, to_space, data + return tochunk(move(from_proc, to_proc, data), to_proc, to_space) + end end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) +function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) + check_uniform(p_chunk.handle) inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, p_chunk) do accel, from_proc, to_proc, from_space, to_space, p_chunk p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) - return tochunk(v_new, to_proc) + return tochunk(v_new, to_proc, to_space) end end # FIXME: Do this programmatically via recursive dispatch for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) - @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) - to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + @eval function move_rewrap(accel, cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + p_chunk = move_rewrap(accel, cache, from_proc, to_proc, from_space, to_space, parent(v)) + return remotecall_endpoint_transfer(accel, from_proc, to_proc, from_space, to_space, p_chunk) do accel, from_proc, to_proc, from_space, to_space, p_chunk p_new = move(from_proc, to_proc, p_chunk) v_new = $(wrapper)(p_new) - return tochunk(v_new, to_proc) + return tochunk(v_new, to_proc, to_space) end end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) - return aliased_object!(cache, v) do v - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, v) - end -end #= FIXME: Make this work so we can automatically move-rewrap recursive objects function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T if isstructtype(T) diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 1c2aa600f..418987124 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -31,11 +31,12 @@ end Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) -function aliasing(x::ChunkView{N}) where N +function aliasing(accel::Acceleration, x::ChunkView{N}, dep_mod) where N + @assert dep_mod === identity "Dependency modifiers not yet supported for ChunkView: $dep_mod" return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices x = unwrap(x) v = view(x, slices...) - return aliasing(v) + return aliasing(accel, v, dep_mod) end end memory_space(x::ChunkView) = memory_space(x.chunk) @@ -64,4 +65,4 @@ function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) end end -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 96112cf15..3b3ed5185 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -25,6 +25,8 @@ function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) append!(queue.seen_tasks, pairs) end +const DATADEPS_CURRENT_TASK = TaskLocalValue{Union{DTask,Nothing}}(Returns(nothing)) + """ spawn_datadeps(f::Base.Callable) @@ -88,6 +90,7 @@ end const DATADEPS_SCHEDULER = ScopedValue{Union{DataDepsScheduler,Nothing}}(nothing) const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) +@warn "Add reliable, uniform-safe Processor sorting" maxlog=1 function distribute_tasks!(queue::DataDepsTaskQueue) #= TODO: Improvements to be made: # - Support for copying non-AbstractArray arguments @@ -98,20 +101,25 @@ function distribute_tasks!(queue::DataDepsTaskQueue) =# # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) + accel = current_acceleration() + accel_procs = filter(procs(Dagger.Sch.eager_context())) do proc + Dagger.accel_matches_proc(accel, proc) end + all_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) + # FIXME: This is an unreliable way to ensure processor uniformity + sort!(all_procs, by=short_name) + scope = get_compute_scope() filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end + if uniform_execution(accel) + for proc in all_procs + check_uniform(proc) + end + end all_scope = UnionScope(map(ExactScope, all_procs)) exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end # Round-robin assign tasks to processors upper_queue = get_options(:task_queue) @@ -128,7 +136,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Copy args from remote to local # N.B. We sort the keys to ensure a deterministic order for uniformity + check_uniform(length(state.arg_owner)) for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + check_uniform(arg_w) arg = arg_w.arg origin_space = state.arg_origin[arg] remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) @@ -199,11 +209,15 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr fargs::Vector{Argument} end + DATADEPS_CURRENT_TASK[] = task + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) scheduler = queue.scheduler our_proc = datadeps_schedule_task(scheduler, state, all_procs, all_scope, task_scope, spec, task) @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) + check_uniform(our_proc) + check_uniform(our_space) # Find the scope for this task (and its copies) task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) @@ -308,6 +322,9 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr if spec.options.syncdeps === nothing spec.options.syncdeps = Set{ThunkSyncdep}() end + if spec.options.tag === nothing + spec.options.tag = to_tag() + end syncdeps = spec.options.syncdeps map_or_ntuple(task_arg_ws) do idx arg_ws = task_arg_ws[idx] @@ -342,7 +359,9 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr new_spec = DTaskSpec(new_fargs, spec.options) new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope - new_spec.options.occupancy = Dict(Any=>0) + if uniform_execution() + new_spec.options.occupancy = Dict(Any=>0) + end ctx = Sch.eager_context() @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) @@ -370,5 +389,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr write_num += 1 + DATADEPS_CURRENT_TASK[] = nothing + return write_num end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 2c2c49920..ee1b060db 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -98,7 +98,7 @@ function compute_remainder_for_arg!(state::DataDepsState, for entry in state.arg_history[arg_w] push!(spaces_set, entry.space) end - spaces = collect(spaces_set) + spaces = sort(collect(spaces_set), by=short_name) N = length(spaces) # Lookup all memory spans for arg_w in these spaces @@ -118,6 +118,8 @@ function compute_remainder_for_arg!(state::DataDepsState, @goto restart end end + check_uniform(spaces) + check_uniform(target_ainfos) # We may only need to schedule a full copy from the origin space to the # target space if this is the first time we've written to `arg_w` @@ -159,6 +161,8 @@ function compute_remainder_for_arg!(state::DataDepsState, other_ainfo = aliasing!(state, owner_space, arg_w) other_space = owner_space end + check_uniform(other_ainfo) + check_uniform(other_space) # Lookup all memory spans for arg_w in these spaces other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) @@ -174,6 +178,7 @@ function compute_remainder_for_arg!(state::DataDepsState, foreach(other_many_spans) do span verify_span(span) end + check_uniform(other_many_spans) if other_space == target_space # Only subtract, this data is already up-to-date in target_space @@ -250,7 +255,9 @@ Enqueues a copy operation to update the remainder regions of an object before a function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, f, idx, dest_scope, task, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) end end @@ -304,7 +311,9 @@ Enqueues a copy operation to update the remainder regions of an object back to t function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, dest_scope, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) end end @@ -537,4 +546,4 @@ function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" return A.rowval -end \ No newline at end of file +end diff --git a/src/dtask.jl b/src/dtask.jl index e94803502..c9e9e811f 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -11,14 +11,14 @@ Base.wait(t::ThunkFuture) = Dagger.Sch.thunk_yield() do wait(t.future) return end -function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false) +function Base.fetch(t::ThunkFuture; proc=OSProc(), raw=false, move_value=!raw, unwrap=!raw, uniform=uniform_execution()) error, value = Dagger.Sch.thunk_yield() do fetch(t.future) end if error throw(value) end - if raw + if !move_value return value else return move(proc, value) @@ -65,11 +65,11 @@ function Base.wait(t::DTask) wait(t.future) return end -function Base.fetch(t::DTask; raw=false) +function Base.fetch(t::DTask; raw=false, move_value=!raw, unwrap=!raw, uniform=false) if !istaskstarted(t) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `DTask`")) end - return fetch(t.future; raw) + return fetch(t.future; move_value, unwrap, uniform) end function waitany(tasks::Vector{DTask}) if isempty(tasks) diff --git a/src/lib/domain-blocks.jl b/src/lib/domain-blocks.jl index 2a0854e3b..95e5c360f 100644 --- a/src/lib/domain-blocks.jl +++ b/src/lib/domain-blocks.jl @@ -6,6 +6,8 @@ struct DomainBlocks{N} <: AbstractArray{ArrayDomain{N, NTuple{N, UnitRange{Int}} end Base.@deprecate_binding BlockedDomains DomainBlocks +ndims(::DomainBlocks{N}) where N = N + size(x::DomainBlocks) = map(length, x.cumlength) function _getindex(x::DomainBlocks{N}, idx::Tuple) where N starts = map((vec, i) -> i == 0 ? 0 : getindex(vec,i), x.cumlength, map(x->x-1, idx)) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index eb4f7ad5b..4aeebc058 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -4,18 +4,10 @@ end CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner -memory_space(x) = CPURAMMemorySpace(myid()) -function memory_space(x::Chunk) - proc = processor(x) - if proc isa OSProc - # TODO: This should probably be programmable - return CPURAMMemorySpace(proc.pid) - else - return only(memory_spaces(proc)) - end -end -memory_space(x::DTask) = - memory_space(fetch(x; raw=true)) +memory_space(x, proc::Processor=default_processor()) = first(memory_spaces(proc)) +memory_space(x::Processor) = first(memory_spaces(x)) +memory_space(x::Chunk) = x.space +memory_space(x::DTask) = memory_space(fetch(x; move_value=false, unwrap=false)) memory_spaces(::P) where {P<:Processor} = throw(ArgumentError("Must define `memory_spaces` for `$P`")) @@ -28,9 +20,10 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement -function unwrap(x::Chunk) - @assert x.handle.owner == myid() - MemPool.poolget(x.handle) +unwrap(x::Chunk) = unwrap(x.handle) +function unwrap(handle::DRef) + @assert root_worker_id(handle) == myid() "DRef $handle is not owned by this process: $(root_worker_id(handle)) != $(myid())" + return MemPool.poolget(handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = throw(ArgumentError("No `move!` implementation defined for $F -> $T")) @@ -69,6 +62,16 @@ function move!(::Type{<:Tridiagonal}, to_space::MemorySpace, from_space::MemoryS return end +# FIXME: Take MemorySpace instead +function move_type(from_proc::Processor, to_proc::Processor, ::Type{T}) where T + if from_proc == to_proc + return T + end + return Base._return_type(move, Tuple{typeof(from_proc), typeof(to_proc), T}) +end +move_type(from_proc::Processor, to_proc::Processor, ::Type{<:Chunk{T}}) where T = + move_type(from_proc, to_proc, T) + ### Aliasing and Memory Spans type_may_alias(::Type{String}) = false @@ -355,6 +358,7 @@ function memory_spans(oa::ObjectAliasing{S}) where S return [span] end +aliasing(accel::Acceleration, x, T) = aliasing(x, T) function aliasing(x, dep_mod) if dep_mod isa Symbol return aliasing(getfield(x, dep_mod)) @@ -391,19 +395,25 @@ aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() function aliasing(x::Chunk, T) - @assert x.handle isa DRef if root_worker_id(x.processor) == myid() return aliasing(unwrap(x), T) end + @assert x.handle isa DRef return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T aliasing(unwrap(x), T) end end -aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x - aliasing(unwrap(x)) +function aliasing(x::Chunk) + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x)) + end + @assert x.handle isa DRef + return remotecall_fetch(root_worker_id(x.processor), x) do x + aliasing(unwrap(x)) + end end -aliasing(x::DTask, T) = aliasing(fetch(x; raw=true), T) -aliasing(x::DTask) = aliasing(fetch(x; raw=true)) +aliasing(x::DTask, T) = aliasing(fetch(x; move_value=false, unwrap=false), T) +aliasing(x::DTask) = aliasing(fetch(x; move_value=false, unwrap=false)) function aliasing(x::Base.RefValue{T}) where T addr = UInt(Base.pointer_from_objref(x) + fieldoffset(typeof(x), 1)) @@ -611,5 +621,5 @@ unsafe_free!(x::Chunk) = remotecall_fetch(root_worker_id(x), x) do x unsafe_free!(unwrap(x)) return end -unsafe_free!(x::DTask) = unsafe_free!(fetch(x; raw=true)) +unsafe_free!(x::DTask) = unsafe_free!(fetch(x; move_value=false, unwrap=false)) unsafe_free!(x) = nothing # Do nothing by default diff --git a/src/mpi.jl b/src/mpi.jl new file mode 100644 index 000000000..8daaedd18 --- /dev/null +++ b/src/mpi.jl @@ -0,0 +1,1110 @@ +@warn "Move to MPIExt.jl" maxlog=1 + +using MPI +import MPIRPC + +const CHECK_UNIFORMITY = Ref{Bool}(false) +function check_uniformity!(check::Bool=true) + CHECK_UNIFORMITY[] = check +end +function check_uniform(value::Integer, original=value) + CHECK_UNIFORMITY[] && uniform_execution() || return true + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + matched = compare_all(value, comm) + if !matched + if rank == 0 + Core.print("[$rank] Found non-uniform value!\n") + end + Core.print("[$rank] value=$value, original=$original\n") + throw(ArgumentError("Non-uniform value")) + end + MPI.Barrier(comm) + return matched +end +function check_uniform(value, original=value) + CHECK_UNIFORMITY[] && uniform_execution() || return true + return check_uniform(hash(value), original) +end + +# MPI tag for `compare_all` / `check_uniform` only. Must not collide with other +# Dagger P2P on the same `comm` (e.g. `remotecall_endpoint_toplevel` uses tag `0` +# for broadcast metadata), or ranks can steal each other's messages and hang +# until `mpi_deadlock_detect` fires. +const COMPARE_ALL_MPI_TAG = UInt32(7243) + +function compare_all(value, comm) + rank = MPI.Comm_rank(comm) + size = MPI.Comm_size(comm) + for i in 0:(size-1) + if i != rank + send_yield(value, comm, i, COMPARE_ALL_MPI_TAG) + end + end + match = true + for i in 0:(size-1) + if i != rank + other_value = recv_yield(comm, i, COMPARE_ALL_MPI_TAG) + if value != other_value + match = false + end + end + end + return match +end + +struct MPIAcceleration <: Acceleration + comm::MPI.Comm +end +MPIAcceleration() = MPIAcceleration(MPI.COMM_WORLD) + +function aliasing(accel::MPIAcceleration, x::Chunk, T) + handle = x.handle::MPIRef + @assert accel.comm == handle.comm "MPIAcceleration comm mismatch" + tag = to_tag() + check_uniform(tag) + rank = MPI.Comm_rank(accel.comm) + if handle.rank == rank + ainfo = _with_default_acceleration() do + aliasing(x, T) + end + @opcounter :aliasing_bcast_send_yield + bcast_send_yield(ainfo, accel.comm, handle.rank, tag) + else + ainfo = recv_yield(accel.comm, handle.rank, tag) + end + check_uniform(ainfo) + return ainfo +end + +default_processor(accel::MPIAcceleration) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x::Chunk) = MPIOSProc(x.handle.comm, x.handle.rank) +default_processor(accel::MPIAcceleration, x::Function) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +default_processor(accel::MPIAcceleration, T::Type) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +uniform_execution(accel::MPIAcceleration) = true + +@warn "Add a lock to MPIClusterProcChildren" maxlog=1 +const MPIClusterProcChildren = Dict{MPI.Comm, Set{Processor}}() + +struct MPIClusterProc <: Processor + comm::MPI.Comm + function MPIClusterProc(comm::MPI.Comm) + populate_children!(comm) + return new(comm) + end +end + +Sch.init_proc(state, proc::MPIClusterProc, log_sink) = Sch.init_proc(state, MPIOSProc(proc.comm), log_sink) + +MPIClusterProc() = MPIClusterProc(MPI.COMM_WORLD) + +function populate_children!(comm::MPI.Comm) + children = get_processors(OSProc()) + MPIClusterProcChildren[comm] = children +end + +struct MPIOSProc <: Processor + comm::MPI.Comm + rank::Int +end + +function MPIOSProc(comm::MPI.Comm) + rank = MPI.Comm_rank(comm) + return MPIOSProc(comm, rank) +end + +function MPIOSProc() + return MPIOSProc(MPI.COMM_WORLD) +end + +ProcessScope(p::MPIOSProc) = ProcessScope(myid()) + +function check_uniform(proc::MPIOSProc, original=proc) + return check_uniform(hash(MPIOSProc), original) && + check_uniform(proc.rank, original) +end + +function memory_spaces(proc::MPIOSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end + +struct MPIProcessScope <: AbstractScope + comm::MPI.Comm + rank::Int +end + +Base.isless(::MPIProcessScope, ::MPIProcessScope) = false +Base.isless(::MPIProcessScope, ::NodeScope) = true +Base.isless(::MPIProcessScope, ::UnionScope) = true +Base.isless(::MPIProcessScope, ::TaintScope) = true +Base.isless(::MPIProcessScope, ::AnyScope) = true +constrain(x::MPIProcessScope, y::MPIProcessScope) = + x == y ? y : InvalidScope(x, y) +constrain(x::NodeScope, y::MPIProcessScope) = + x == y.parent ? y : InvalidScope(x, y) + +Base.isless(::ExactScope, ::MPIProcessScope) = true +constrain(x::MPIProcessScope, y::ExactScope) = + x == y.parent ? y : InvalidScope(x, y) + +function enclosing_scope(proc::MPIOSProc) + return MPIProcessScope(proc.comm, proc.rank) +end + +function Dagger.to_scope(::Val{:mpi_rank}, sc::NamedTuple) + if sc.mpi_rank == Colon() + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=Colon()))) + else + @assert sc.mpi_rank isa Integer "Expected a single GPU device ID for :mpi_rank, got $(sc.mpi_rank)\nConsider using :mpi_ranks instead." + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=[sc.mpi_rank]))) + end +end +Dagger.scope_key_precedence(::Val{:mpi_rank}) = 2 +function Dagger.to_scope(::Val{:mpi_ranks}, sc::NamedTuple) + comm = get(sc, :mpi_comm, MPI.COMM_WORLD) + if sc.ranks != Colon() + ranks = sc.ranks + else + ranks = MPI.Comm_size(comm) + end + inner_sc = NamedTuple(filter(kv->kv[1] != :mpi_ranks, Base.pairs(sc))...) + # FIXME: What to do here? + inner_scope = Dagger.to_scope(inner_sc) + scopes = Dagger.ExactScope[] + for rank in ranks + procs = Dagger.get_processors(Dagger.MPIOSProc(comm, rank)) + rank_scope = MPIProcessScope(comm, rank) + for proc in procs + proc_scope = Dagger.ExactScope(proc) + constrain(proc_scope, rank_scope) isa Dagger.InvalidScope && continue + push!(scopes, proc_scope) + end + end + return Dagger.UnionScope(scopes) +end +Dagger.scope_key_precedence(::Val{:mpi_ranks}) = 2 + +struct MPIProcessor{P<:Processor} <: Processor + innerProc::P + comm::MPI.Comm + rank::Int +end +proc_in_scope(proc::Processor, scope::MPIProcessScope) = false +proc_in_scope(proc::MPIProcessor, scope::MPIProcessScope) = + proc.comm == scope.comm && proc.rank == scope.rank + +function check_uniform(proc::MPIProcessor, original=proc) + return check_uniform(hash(MPIProcessor), original) && + check_uniform(proc.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(proc.innerProc), original) +end + +Dagger.iscompatible_func(::MPIProcessor, opts, ::Any) = true +Dagger.iscompatible_arg(::MPIProcessor, opts, ::Any) = true + +default_enabled(proc::MPIProcessor) = default_enabled(proc.innerProc) + +root_worker_id(proc::MPIProcessor) = myid() +root_worker_id(proc::MPIOSProc) = myid() +root_worker_id(proc::MPIClusterProc) = myid() + +get_parent(proc::MPIClusterProc) = proc +get_parent(proc::MPIOSProc) = MPIClusterProc(proc.comm) +get_parent(proc::MPIProcessor) = MPIOSProc(proc.comm, proc.rank) + +short_name(proc::MPIProcessor) = "(MPI: $(proc.rank), $(short_name(proc.innerProc)))" + +function get_processors(mosProc::MPIOSProc) + populate_children!(mosProc.comm) + children = MPIClusterProcChildren[mosProc.comm] + mpiProcs = Set{Processor}() + for proc in children + push!(mpiProcs, MPIProcessor(proc, mosProc.comm, mosProc.rank)) + end + return mpiProcs +end + +#TODO: non-uniform ranking through MPI groups +#TODO: use a lazy iterator +function get_processors(proc::MPIClusterProc) + children = Set{Processor}() + for i in 0:(MPI.Comm_size(proc.comm)-1) + for innerProc in MPIClusterProcChildren[proc.comm] + push!(children, MPIProcessor(innerProc, proc.comm, i)) + end + end + return children +end + +struct MPIMemorySpace{S<:MemorySpace} <: MemorySpace + innerSpace::S + comm::MPI.Comm + rank::Int +end + +function Base.:(==)(a::MPIMemorySpace, b::MPIMemorySpace) + return a.innerSpace == b.innerSpace && + a.comm == b.comm && + a.rank == b.rank +end +Base.hash(space::MPIMemorySpace, h::UInt=UInt(0)) = + hash(space.innerSpace, hash(space.comm.val, hash(space.rank, hash(MPIMemorySpace, h)))) + +function check_uniform(space::MPIMemorySpace, original=space) + return check_uniform(space.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(space.innerSpace), original) +end + +default_processor(space::MPIMemorySpace) = MPIOSProc(space.comm, space.rank) +default_memory_space(accel::MPIAcceleration) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) + +default_memory_space(accel::MPIAcceleration, x) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) +default_memory_space(accel::MPIAcceleration, x::Chunk) = MPIMemorySpace(CPURAMMemorySpace(myid()), x.handle.comm, x.handle.rank) +default_memory_space(accel::MPIAcceleration, x::Function) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) +default_memory_space(accel::MPIAcceleration, T::Type) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) + +function memory_spaces(proc::MPIClusterProc) + rawMemSpace = Set{MemorySpace}() + for rnk in 0:(MPI.Comm_size(proc.comm) - 1) + for innerSpace in memory_spaces(OSProc()) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, rnk)) + end + end + return rawMemSpace +end + +function memory_spaces(proc::MPIProcessor) + rawMemSpace = Set{MemorySpace}() + for innerSpace in memory_spaces(proc.innerProc) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, proc.rank)) + end + return rawMemSpace +end + +root_worker_id(mem_space::MPIMemorySpace) = myid() + +function processors(memSpace::MPIMemorySpace) + rawProc = Set{Processor}() + for innerProc in processors(memSpace.innerSpace) + push!(rawProc, MPIProcessor(innerProc, memSpace.comm, memSpace.rank)) + end + return rawProc +end + +struct MPIRefID + tid::UInt32 + generic::Bool + id::UInt32 + function MPIRefID(tid, generic, id) + @assert tid > 0 || generic "Invalid MPIRefID: tid=$tid, generic=$generic, id=$id" + return new(tid, generic, id) + end +end +Base.hash(id::MPIRefID, h::UInt=UInt(0)) = + hash(id.tid, hash(id.generic, hash(id.id, hash(MPIRefID, h)))) + +function check_uniform(ref::MPIRefID, original=ref) + return check_uniform(ref.tid, original) && + check_uniform(ref.generic, original) && + check_uniform(ref.id, original) +end + +function to_tag() + if Dagger.in_task() + # Tag is already assigned + opts = Dagger.get_tls().task_spec.options + tag = opts.tag + return tag + end + + # Generate a tag based on the TID + @assert !Sch.SCHED_MOVE[] "We should not create a tag during Sch move" + return to_tag(take_ref_id!()) +end +to_tag(id::MPIRefID) = id.generic ? id.id : id.tid + +# Semi-public internal value for passing TID to MPIRefID generation +const MPI_TID = ScopedValue{Int64}(0) +# Private internal value for tracking TID-based ID generations +#const _MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() +# Private internal value for tracking non-TID (uniform) ID generations +#const _MPIREF_GENERIC = Threads.Atomic{Int}(1) + +mutable struct MPIRef + comm::MPI.Comm + rank::Int + size::Int + innerRef::Union{DRef, Nothing} + id::MPIRefID +end +Base.hash(ref::MPIRef, h::UInt=UInt(0)) = hash(ref.id, hash(MPIRef, h)) +root_worker_id(ref::MPIRef) = myid() + +function check_uniform(ref::MPIRef, original=ref) + return check_uniform(ref.rank, original) && + check_uniform(ref.id, original) +end + +function unwrap(handle::MPIRef) + @assert handle.rank == MPI.Comm_rank(handle.comm) "MPIRef $handle is not owned by this rank: $(handle.rank) != $(MPI.Comm_rank(handle.comm))" + return unwrap(handle.innerRef) +end + +to_tag(ref::MPIRef) = to_tag(ref.id) + +move(from_proc::Processor, to_proc::Processor, x::MPIRef) = + move(from_proc, to_proc, poolget(x; uniform=uniform_execution())) + +function affinity(x::MPIRef) + if x.innerRef === nothing + return MPIOSProc(x.comm, x.rank)=>0 + else + return MPIOSProc(x.comm, x.rank)=>x.innerRef.size + end +end + +function take_ref_id!() + tid = 0 + generic = 0 + id = 0 + if Dagger.in_task() + tid = sch_handle().thunk_id.id + #counter = get!(_MPIREF_TID, tid, Threads.Atomic{Int}(1)) + #id = Threads.atomic_add!(counter, 1) + id = tid + elseif MPI_TID[] != 0 + tid = MPI_TID[] + #counter = get!(_MPIREF_TID, tid, Threads.Atomic{Int}(1)) + #id = Threads.atomic_add!(counter, 1) + id = tid + else + if current_task() !== Base.roottask + throw(ConcurrencyViolationError("Attempted to generate generic MPIRefID in a multi-threaded context")) + end + generic = true + #id = Threads.atomic_add!(_MPIREF_GENERIC, 1) + id = next_id() # Abuse the TID counter for generic IDs + check_uniform(id) + end + @assert id < MPI.tag_ub() + return MPIRefID(tid, generic, id) +end + +#TODO: partitioned scheduling with comm bifurcation +function tochunk_pset(x, space::MPIMemorySpace; device=nothing, force_nonlocal=false, kwargs...) + @assert space.comm == MPI.COMM_WORLD "$(space.comm) != $(MPI.COMM_WORLD)" + local_rank = MPI.Comm_rank(space.comm) + Mid = take_ref_id!() + if local_rank != space.rank || force_nonlocal + return MPIRef(space.comm, space.rank, 0, nothing, Mid) + else + # type= is for Chunk metadata only; MemPool.poolset does not accept it + pset_kw = (; (k => v for (k, v) in pairs(kwargs) if k !== :type)...) + return MPIRef(space.comm, space.rank, sizeof(x), poolset(x; device, pset_kw...), Mid) + end +end + +const DEADLOCK_DETECT = TaskLocalValue{Bool}(()->true) +const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0) +const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->120.0) +const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}()) + +@warn "Rename and make generic these in-place structs" maxlog=1 +struct InplaceInfo + type::DataType + shape::Tuple +end +struct InplaceSparseInfo + type::DataType + m::Int + n::Int + colptr::Int + rowval::Int + nzval::Int +end + +function supports_inplace_mpi(value) + if value isa DenseArray && isbitstype(eltype(value)) + return true + else + return false + end +end +function recv_yield!(buffer, comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))") + if !supports_inplace_mpi(buffer) + return recv_yield(comm, src, tag), false + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + + buffer = recv_yield_inplace!(buffer, comm, rank, src, tag) + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + + return buffer, true + +end + +function recv_yield(comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Receiving...") + + type = nothing + @label receive + value = recv_yield_serialized(comm, rank, src, tag) + if value isa InplaceInfo || value isa InplaceSparseInfo + value = recv_yield_inplace(value, comm, rank, src, tag) + end + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + return value +end + +function recv_yield_inplace!(array, comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + @assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count" + buf = MPI.Buffer(array) + req = MPI.Imrecv!(buf, msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return array + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays" + array = Array{eltype(T)}(undef, _value.shape) + return recv_yield_inplace!(array, comm, my_rank, their_rank, tag) +end + +function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: SparseMatrixCSC "recv_yield_inplace only supports inplace MPI transfers of SparseMatrixCSC" + + colptr = recv_yield_inplace!(Vector{Int64}(undef, _value.colptr), comm, my_rank, their_rank, tag) + rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag) + nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag) + + return SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval) +end + +function recv_yield_serialized(comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + buf = Array{UInt8}(undef, count) + req = MPI.Imrecv!(MPI.Buffer(buf), msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return MPI.deserialize(buf) + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +const SEEN_TAGS = Dict{Int32, Type}() +send_yield!(value, comm, dest, tag) = + _send_yield(value, comm, dest, tag; inplace=true) +send_yield(value, comm, dest, tag) = + _send_yield(value, comm, dest, tag; inplace=false) +function _send_yield(value, comm, dest, tag; inplace::Bool) + rank = MPI.Comm_rank(comm) + + #= + if CHECK_UNIFORMITY[] && haskey(SEEN_TAGS, tag) && SEEN_TAGS[tag] !== typeof(value) + @error "[rank $(MPI.Comm_rank(comm))][tag $tag] Already seen tag (previous type: $(SEEN_TAGS[tag]), new type: $(typeof(value)))" exception=(InterruptException(),backtrace()) + end + if CHECK_UNIFORMITY[] + SEEN_TAGS[tag] = typeof(value) + end + =# + + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting send to [$dest]: $(typeof(value)), is support inplace? $(supports_inplace_mpi(value))") + if inplace && supports_inplace_mpi(value) + send_yield_inplace(value, comm, rank, dest, tag) + else + send_yield_serialized(value, comm, rank, dest, tag) + end +end + +function send_yield_inplace(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_inplace + req = MPI.Isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") +end + +function send_yield_serialized(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_serialized + if value isa Array && isbitstype(eltype(value)) + send_yield_serialized(InplaceInfo(typeof(value), size(value)), comm, my_rank, their_rank, tag) + send_yield_inplace(value, comm, my_rank, their_rank, tag) + elseif value isa SparseMatrixCSC && isbitstype(eltype(value)) + send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag) + send_yield_inplace(value.colptr, comm, my_rank, their_rank, tag) + send_yield_inplace(value.rowval, comm, my_rank, their_rank, tag) + send_yield_inplace(value.nzval, comm, my_rank, their_rank, tag) + else + req = MPI.isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") + end +end + +function __wait_for_request(req, comm, my_rank, their_rank, tag, fn::String, kind::String) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + while true + finish, status = MPI.Test(req, MPI.Status) + if finish + if MPI.Get_error(status) != MPI.SUCCESS + error("$fn failed with error $(MPI.Get_error(status))") + end + return + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, kind, their_rank) + yield() + end +end + +function bcast_send_yield(value, comm, root, tag) + @opcounter :bcast_send_yield + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + for other_rank in 0:(sz-1) + rank == other_rank && continue + send_yield(value, comm, other_rank, tag) + end +end + +#= Maybe can be worth it to implement this +function bcast_send_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + + for other_rank in 0:(sz-1) + rank == other_rank && continue + #println("[rank $rank] Sending to rank $other_rank") + send_yield!(value, comm, other_rank, tag) + end +end + +function bcast_recv_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + #println("[rank $rank] receive from rank $root") + recv_yield!(value, comm, root, tag) +end +=# +function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest) + time_elapsed = (time_ns() - time_start) + if detect && time_elapsed > warn_period + @warn "[rank $rank][tag $tag] Hit probable hang on $kind (dest: $srcdest)" + return typemax(UInt64) + end + if detect && time_elapsed > timeout_period + error("[rank $rank][tag $tag] Hit hang on $kind (dest: $srcdest)") + end + return warn_period +end + +#discuss this with julian +@warn "Fix this WeakChunk method" maxlog=1 +WeakChunk(c::Chunk{T,H}) where {T,H<:MPIRef} = WeakChunk(c.handle.rank, c.handle.id.id, WeakRef(c)) + +function MemPool.poolget(ref::MPIRef; uniform::Bool=uniform_execution()) + @assert uniform || ref.rank == MPI.Comm_rank(ref.comm) "MPIRef rank mismatch: $(ref.rank) != $(MPI.Comm_rank(ref.comm))" + if uniform + tag = to_tag() + if ref.rank == MPI.Comm_rank(ref.comm) + value = poolget(ref.innerRef) + bcast_send_yield(value, ref.comm, ref.rank, tag) + return value + else + return recv_yield(ref.comm, ref.rank, tag) + end + else + return poolget(ref.innerRef) + end +end +fetch_handle(ref::MPIRef; uniform::Bool=uniform_execution()) = poolget(ref; uniform) + +function move!(dep_mod, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) + @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" + @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch" + @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch" + local_rank = MPI.Comm_rank(from.handle.comm) + if to_space.rank == from_space.rank == local_rank + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) + else + @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" + tag = to_tag(from.handle) + if local_rank == from_space.rank + send_yield!(poolget(from.handle; uniform=false), to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + #@dagdebug nothing :mpi "[$local_rank][$tag] Receiving from rank $(from_space.rank) with tag $tag, type of buffer: $(typeof(poolget(to.handle; uniform=false)))" + to_val = poolget(to.handle; uniform=false) + val, inplace = recv_yield!(to_val, from_space.comm, from_space.rank, tag) + if !inplace + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to_val, val) + end + end + end + @dagdebug nothing :mpi "[$local_rank][$tag] Finished moving from $(from_space.rank) to $(to_space.rank) successfuly\n" +end +function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) + @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" + @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch" + @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch" + local_rank = MPI.Comm_rank(from.handle.comm) + if to_space.rank == from_space.rank == local_rank + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) + else + tag = to_tag(from.handle) + @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" + if local_rank == from_space.rank + # Get the source data for each span + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + offset = 1 + for (from_span, _) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copies, offset)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + offset += from_span.len + #end + end + + # Send the spans + #send_yield(len, to_space.comm, to_space.rank, tag) + send_yield!(copies, to_space.comm, to_space.rank, tag) + #send_yield(copies, to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + # Receive the spans + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + recv_yield!(copies, from_space.comm, from_space.rank, tag) + #copies = recv_yield(from_space.comm, from_space.rank, tag) + + # Copy the data into the destination object + #for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + offset = 1 + for (_, to_span) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copies, offset)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + offset += to_span.len + #end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + end + end + + return +end + +const ALIASED_OBJECT_CACHE = ScopedValue{Union{AliasedObjectCache, Nothing}}(nothing) + +move(::MPIOSProc, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIOSProc, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +#TODO: out of place MPI move +function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk) + @assert src.comm == dst.comm "Multi comm move not supported" + if Sch.SCHED_MOVE[] + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permited" + @assert src.rank == x.handle.rank == dst.rank + return poolget(x.handle) + end +end + +function is_local(accel::MPIAcceleration, target) + return target.rank == MPI.Comm_rank(accel.comm) +end + +function remotecall_endpoint_toplevel(f, accel::MPIAcceleration, cache::AliasedObjectCache, from_proc, to_proc, from_space, to_space, data::Chunk) + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + if is_local(accel, from_proc) + data_raw = unwrap(data) + res = f(accel, cache, from_proc, to_proc, from_space, to_space, data_raw)::Chunk + csz = MPI.Comm_size(accel.comm) + for target_rank in 0:(csz-1) + if target_rank == from_proc.rank + continue + end + MPIRPC.rpc_progress_halt!(backend, target_rank) + end + @assert res.handle isa MPIRef "Expected MPIRef handle for MPI broadcast" + meta = (res.handle.id, res.handle.size, chunktype(res), res.space.innerSpace, res.space.rank) + bcast_send_yield(meta, accel.comm, from_proc.rank, 0) + return res + else + with(ALIASED_OBJECT_CACHE=>cache) do + while MPIRPC.rpc_progress!(backend) + yield() + end + end + id, size, T, inner_space, rank = recv_yield(accel.comm, from_proc.rank, 0) + space = MPIMemorySpace(inner_space, accel.comm, rank) + handle = MPIRef(accel.comm, rank, size, nothing, id) + return Chunk{T, typeof(handle), typeof(to_proc), AnyScope, typeof(space)}( + T, domain(nothing), handle, to_proc, AnyScope(), space) + end +end + +function remotecall_endpoint_transfer(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data) + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + + return MPIRPC.remotecall_fetch(backend, to_proc.rank) do + ACCELERATION[] = accel + return f(accel, from_proc, to_proc, from_space, to_space, data) + end +end + +function set_stored_mpi!(space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) + cache = ALIASED_OBJECT_CACHE[] + ACCELERATION[] = cache.accel + set_stored!(unwrap(cache.chunk)::AliasedObjectCacheStore, cache.space, value, ainfo) + return +end + +function mpi_nonlocal_chunk(value::Chunk) + return tochunk(nothing, value.processor, value.space, value.scope; type=chunktype(value), force_nonlocal=true) +end + +function set_stored!(accel::MPIAcceleration, cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + MPIRPC.bcast_remotecall(backend, set_stored_mpi!, cache.space, mpi_nonlocal_chunk(value), ainfo) + set_stored!(cache_raw, cache.space, value, ainfo) + return +end + +function set_key_stored_mpi!(space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + cache = unwrap(ALIASED_OBJECT_CACHE[].chunk)::AliasedObjectCacheStore + ACCELERATION[] = cache.accel + set_key_stored!(cache, space, ainfo, value) + return +end + +function set_key_stored!(accel::MPIAcceleration, cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) + cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore + backend = lock(MPIRC_BACKEND) do backends + backends[accel.comm] + end + MPIRPC.bcast_remotecall(backend, set_key_stored_mpi!, space, ainfo, mpi_nonlocal_chunk(value)) + set_key_stored!(cache_raw, space, ainfo, value) +end + + +#= +function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) + loc_rank = MPI.Comm_rank(accel.comm) + if loc_rank == from_proc.rank + # FIXME: Descend via move_rewrap, and send data to to_proc + elseif loc_rank == to_proc.rank + # FIXME: Listen for data from from_proc to locally wrap as Chunk + while true + value = recv_yield(accel.comm, from_proc.rank, tag) + end + bcast_recv_yield(data_new, accel.comm, to_proc.rank, tag) + else + # Wait for final Chunk + return recv_yield(accel.comm, to_proc.rank, tag) + end +end +=# +#= +function remotecall_endpoint(f, accel::MPIAcceleration, from_proc, to_proc, from_space, to_space, data::Chunk) + loc_rank = MPI.Comm_rank(accel.comm) + task = DATADEPS_CURRENT_TASK[] + return with(MPI_UID=>task.uid) do + space = memory_space(data) + tag = to_tag() + T = move_type(from_proc.innerProc, to_proc.innerProc, chunktype(data)) + T_new = f !== identity ? Base._return_type(f, Tuple{T}) : T + need_bcast = !isconcretetype(T_new) || T_new === Union{} || T_new === Nothing || T_new === Any + + if space.rank != from_proc.rank + # Data is already at destination (to_proc.rank) + @assert space.rank == to_proc.rank + if space.rank == loc_rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + end + end + + # Data is on the source rank + @assert space.rank == from_proc.rank + if loc_rank == from_proc.rank == to_proc.rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + return tochunk(data_converted, to_proc, to_space; type=typeof(data_converted)) + end + + if loc_rank == from_proc.rank + value = poolget(data.handle) + data_moved = move(from_proc.innerProc, to_proc.innerProc, value) + Dagger.send_yield(data_moved, accel.comm, to_proc.rank, tag) + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + elseif loc_rank == to_proc.rank + data_moved = Dagger.recv_yield(accel.comm, from_space.rank, tag) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, data_moved)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, type_tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + end + end +end +=# + + + +# Chunk may be MPI-backed (MPIRef) but labeled with OSProc; treat source as the owning rank +function move(src::OSProc, dst::MPIProcessor, x::Chunk) + if x.handle isa MPIRef + return move(MPIOSProc(x.handle.comm, x.handle.rank), dst, x) + end + error("MPI move not supported") +end + +move(src::Processor, dst::MPIProcessor, x::Chunk) = error("MPI move not supported") +move(to_proc::MPIProcessor, chunk::Chunk) = + move(chunk.processor, to_proc, chunk) +move(to_proc::Processor, d::MPIRef) = + move(MPIOSProc(d.rank), to_proc, d) +move(to_proc::MPIProcessor, x) = + move(MPIOSProc(), to_proc, x) + +move(::MPIProcessor, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIProcessor, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +@warn "Is this uniform logic valuable to have?" maxlog=1 +function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk) + uniform = uniform_execution() + @assert uniform || src.rank == dst.rank "Unwrapping not permitted" + if Sch.SCHED_MOVE[] + # We can either unwrap locally, or return nothing + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + # Either we're uniform (so everyone cooperates), or we're unwrapping locally + if !uniform + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permitted" + @assert src.rank == x.handle.rank == dst.rank + end + return poolget(x.handle; uniform) + end +end + +gpu_kernel_backend(::MPIProcessor) = KernelAbstractions.CPU() + +#FIXME:try to think of a better move! scheme +function execute!(proc::MPIProcessor, f, args...; kwargs...) + local_rank = MPI.Comm_rank(proc.comm) + islocal = local_rank == proc.rank + inplace_move = f === move! + result = nothing + tag = to_tag() + + if islocal || inplace_move + result = execute!(proc.innerProc, f, args...; kwargs...) + end + + if inplace_move + space = memory_space(nothing, proc)::MPIMemorySpace + dest_type = chunktype(args[4]) + return tochunk(nothing, proc, space; type=dest_type) + end + + # Infer return type; only bcast when inference is not concrete + fname = nameof(f) + arg_types = map(chunktype, args) + inferred_type = Base.promote_op(f, arg_types...) + + need_bcast = !isconcretetype(inferred_type) || inferred_type === Union{} || inferred_type === Nothing || inferred_type === Any + + if islocal + T = typeof(result) + if T === Nothing + @warn "[rank $local_rank] Gave $T result for $fname, args types: $arg_types; treating as Any for broadcast" + end + space = memory_space(result, proc)::MPIMemorySpace + if need_bcast + @opcounter :execute_bcast_send_yield + bcast_send_yield((T, space.innerSpace), proc.comm, proc.rank, tag) + end + return tochunk(result, proc, space; type=T) + else + if need_bcast + T, innerSpace = recv_yield(proc.comm, proc.rank, tag) + space = MPIMemorySpace(innerSpace, proc.comm, proc.rank) + return tochunk(nothing, proc, space; type=T) + else + space = memory_space(nothing, proc)::MPIMemorySpace + return tochunk(nothing, proc, space; type=inferred_type) + end + end +end + +accelerate!(::Val{:mpi}) = accelerate!(MPIAcceleration()) + +const MPIRC_BACKEND = LockedObject(Dict{MPI.Comm, MPIRPC.AbstractMPIRPCBackend}()) + +function initialize_acceleration!(a::MPIAcceleration) + if !MPI.Initialized() + MPI.Init(;threadlevel=:multiple) + end + backend = MPIRPC.select_mpi_rpc_backend!(MPIRPC.UniformMPIRPCBackend(a.comm)) + lock(MPIRC_BACKEND) do backends + backends[a.comm] = backend + end + ctx = Dagger.Sch.eager_context() + sz = MPI.Comm_size(a.comm) + for i in 0:(sz-1) + push!(ctx.procs, MPIOSProc(a.comm, i)) + end + unique!(ctx.procs) +end + +""" + mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + +Ensure all ranks use the same concrete type for the given tasks by setting +each task's options.return_type to expected_type when it is concrete. +This allows chunktype(task) to return the concrete type on every rank +without an MPI allgather of actual result types. +""" +function mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + isconcretetype(expected_type) || return + for t in tasks + if t isa Thunk + if t.options !== nothing + t.options.return_type = expected_type + else + t.options = Options(return_type=expected_type) + end + end + end + return +end + +accel_matches_proc(accel::MPIAcceleration, proc::MPIOSProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIClusterProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIProcessor) = true +accel_matches_proc(accel::MPIAcceleration, proc) = false + +function distribute(accel::MPIAcceleration, A::AbstractArray{T,N}, dist::Blocks{N}) where {T,N} + comm = accel.comm + rank = MPI.Comm_rank(comm) + + DA = view(A, dist) + DB = DArray{T,N}(undef, dist, size(A)) + copyto!(DB, DA) + + return DB +end diff --git a/src/mpi_mempool.jl b/src/mpi_mempool.jl new file mode 100644 index 000000000..149c7900a --- /dev/null +++ b/src/mpi_mempool.jl @@ -0,0 +1,36 @@ +# Mempool for received MPI message data only (no envelopes). +# Key: (comm, source, tag). Used when a message is received but not the one the caller was waiting for. +# Included from mpi.jl; runs in Dagger module scope. + +const MPI_RECV_MEMPOOL = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Vector{Any}}()) + +function mpi_mempool_put!(comm::MPI.Comm, source::Integer, tag::Integer, data::Any) + key = (comm, Int(source), Int(tag)) + ref = poolset(data) + lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) + pool[key] = Any[] + end + push!(pool[key], ref) + end + return nothing +end + +function mpi_mempool_take!(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + ref = lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) || isempty(pool[key]) + return nothing + end + popfirst!(pool[key]) + end + ref === nothing && return nothing + return poolget(ref) +end + +function mpi_mempool_has(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + return lock(MPI_RECV_MEMPOOL) do pool + haskey(pool, key) && !isempty(pool[key]) + end +end diff --git a/src/mutable.jl b/src/mutable.jl new file mode 100644 index 000000000..1f48ead53 --- /dev/null +++ b/src/mutable.jl @@ -0,0 +1,41 @@ +function _mutable_inner(@nospecialize(f), proc, scope) + result = f() + return Ref(Dagger.tochunk(result, proc, scope)) +end + +""" + mutable(f::Base.Callable; worker, processor, scope) -> Chunk + +Calls `f()` on the specified worker or processor, returning a `Chunk` +referencing the result with the specified scope `scope`. +""" +function mutable(@nospecialize(f); worker=nothing, processor=nothing, scope=nothing) + if processor === nothing + if worker === nothing + processor = OSProc() + else + processor = OSProc(worker) + end + else + @assert worker === nothing "mutable: Can't mix worker and processor" + end + if scope === nothing + scope = processor isa OSProc ? ProcessScope(processor) : ExactScope(processor) + end + return fetch(Dagger.@spawn scope=scope _mutable_inner(f, processor, scope))[] +end + +""" + @mutable [worker=1] [processor=OSProc()] [scope=ProcessorScope()] f() + +Helper macro for [`mutable()`](@ref). +""" +macro mutable(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $mutable(f; $(opts...)) + end + end +end diff --git a/src/options.jl b/src/options.jl index eca59fbc9..09067da51 100644 --- a/src/options.jl +++ b/src/options.jl @@ -26,6 +26,7 @@ Stores per-task options to be passed to the scheduler. - `storage_leaf_tag::Union{MemPool.Tag,Nothing}=nothing`: If not `nothing`, specifies the MemPool storage leaf tag to associate with the task's result. This tag can be used by MemPool's storage devices to manipulate their behavior, such as the file name used to store data on disk." - `storage_retain::Union{Bool,Nothing}=nothing`: The value of `retain` to pass to `MemPool.poolset` when constructing the result `Chunk`. `nothing` defaults to `false`. - `name::Union{String,Nothing}=nothing`: If not `nothing`, annotates the task with a name for logging purposes. +- `tag::Union{UInt32,Nothing}=nothing`: (Data-deps/MPI) MPI message tag for this task; assigned automatically if `nothing`. - `stream_input_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the input buffer of the task. Defaults to 1. - `stream_output_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the output buffer of the task. Defaults to 1. - `stream_buffer_type::Union{Type,Nothing}=nothing`: (Streaming only) Specifies the type of buffer to use for the input and output buffers of the task. Defaults to `Dagger.ProcessRingBuffer`. @@ -61,10 +62,16 @@ Base.@kwdef mutable struct Options name::Union{String,Nothing} = nothing + tag::Union{UInt32,Nothing} = nothing + stream_input_buffer_amount::Union{Int,Nothing} = nothing stream_output_buffer_amount::Union{Int,Nothing} = nothing stream_buffer_type::Union{Type, Nothing} = nothing stream_max_evals::Union{Int,Nothing} = nothing + + acceleration::Union{Acceleration,Nothing} = nothing + + return_type::Union{Type,Nothing} = nothing end Options(::Nothing) = Options() function Options(old_options::NamedTuple) diff --git a/src/processor.jl b/src/processor.jl index ac2e74f14..4944dc083 100644 --- a/src/processor.jl +++ b/src/processor.jl @@ -2,16 +2,6 @@ export OSProc, Context, addprocs!, rmprocs! import Base: @invokelatest -""" - Processor - -An abstract type representing a processing device and associated memory, where -data can be stored and operated on. Subtypes should be immutable, and -instances should compare equal if they represent the same logical processing -device/memory. Subtype instances should be serializable between different -nodes. Subtype instances may contain a "parent" `Processor` to make it easy to -transfer data to/from other types of `Processor` at runtime. -""" abstract type Processor end const PROCESSOR_CALLBACKS = Dict{Symbol,Any}() @@ -150,3 +140,20 @@ iscompatible_arg(proc::OSProc, opts, args...) = "Returns a very brief `String` representation of `proc`." short_name(proc::Processor) = string(proc) short_name(p::OSProc) = "W: $(p.pid)" + +"Returns true if the processor is on the local worker (for MPI/ordering)." +is_local_processor(proc::Processor) = (root_worker_id(proc) == myid()) + +"Ordering key for task firing (used by MPI to avoid deadlock)." +fire_order_key(proc::Processor) = (root_worker_id(proc), 0) + +@doc """ + Processor + +An abstract type representing a processing device and associated memory, where +data can be stored and operated on. Subtypes should be immutable, and +instances should compare equal if they represent the same logical processing +device/memory. Subtype instances should be serializable between different +nodes. Subtype instances may contain a "parent" `Processor` to make it easy to +transfer data to/from other types of `Processor` at runtime. +""" Processor diff --git a/src/queue.jl b/src/queue.jl index 37947a0ac..c1e264c06 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -125,7 +125,7 @@ function wait_all(f; check_errors::Bool=false) result = with_options(f; task_queue=queue) for task in queue.tasks if check_errors - fetch(task; raw=true) + fetch(task; move_value=false, unwrap=false) else wait(task) end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 58aed6dc5..3c8353c59 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -15,7 +15,7 @@ import Base: @invokelatest import ..Dagger import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, ThunkID, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc! +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, root_worker_id, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc!, is_local_processor, fire_order_key, short_name import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek @@ -25,7 +25,7 @@ import ..Dagger: @reusable, @reusable_dict, @reusable_vector, @reusable_tasks, @ import TimespanLogging import TaskLocalValues: TaskLocalValue -import ScopedValues: @with +import ScopedValues: ScopedValue, @with, with const OneToMany = Dict{Thunk, Set{Thunk}} @@ -56,7 +56,7 @@ Fields: - `cache::WeakKeyDict{Thunk, Any}` - Maps from a finished `Thunk` to it's cached result, often a DRef - `valid::WeakKeyDict{Thunk, Nothing}` - Tracks all `Thunk`s that are in a valid scheduling state - `running::Set{Thunk}` - The set of currently-running `Thunk`s -- `running_on::Dict{Thunk,OSProc}` - Map from `Thunk` to the OS process executing it +- `running_on::Dict{Thunk,Processor}` - Map from `Thunk` to the OS process executing it - `thunk_dict::Dict{Int, WeakThunk}` - Maps from thunk IDs to a `Thunk` - `node_order::Any` - Function that returns the order of a thunk - `equiv_chunks::WeakKeyDict{DRef,Chunk}` - Cache mapping from `DRef` to a `Chunk` which contains it @@ -82,18 +82,18 @@ struct ComputeState ready::Vector{Thunk} valid::Dict{Thunk, Nothing} running::Set{Thunk} - running_on::Dict{Thunk,OSProc} + running_on::Dict{Thunk,Processor} thunk_dict::Dict{Int, WeakThunk} node_order::Any - equiv_chunks::WeakKeyDict{DRef,Chunk} - worker_time_pressure::Dict{Int,Dict{Processor,UInt64}} - worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_loadavg::Dict{Int,NTuple{3,Float64}} - worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} + equiv_chunks::WeakKeyDict{Any,Chunk} + worker_time_pressure::Dict{Processor,Dict{Processor,UInt64}} + worker_storage_pressure::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_storage_capacity::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_loadavg::Dict{Processor,NTuple{3,Float64}} + worker_chans::Dict{Int,Tuple{RemoteChannel,RemoteChannel}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} - worker_transfer_rate::Dict{Int,Dict{Processor,UInt64}} + worker_transfer_rate::Dict{Processor,Dict{Processor,UInt64}} halt::Base.Event lock::ReentrantLock futures::Dict{Thunk, Vector{ThunkFuture}} @@ -111,18 +111,18 @@ function start_state(deps::Dict, node_order, chan) Vector{Thunk}(undef, 0), Dict{Thunk, Nothing}(), Set{Thunk}(), - Dict{Thunk,OSProc}(), + Dict{Thunk,Processor}(), Dict{Int, WeakThunk}(), node_order, - WeakKeyDict{DRef,Chunk}(), - Dict{Int,Dict{Processor,UInt64}}(), - Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), - Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), - Dict{Int,NTuple{3,Float64}}(), - Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(), + WeakKeyDict{Any,Chunk}(), + Dict{Processor,Dict{Processor,UInt64}}(), + Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}}(), + Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}}(), + Dict{Processor,NTuple{3,Float64}}(), + Dict{Processor,Tuple{RemoteChannel,RemoteChannel}}(), Dict{Signature,UInt64}(), Dict{Signature,UInt64}(), - Dict{Int,Dict{Processor,UInt64}}(), + Dict{Processor,Dict{Processor,UInt64}}(), Base.Event(), ReentrantLock(), Dict{Thunk, Vector{ThunkFuture}}(), @@ -152,30 +152,29 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + pid = Dagger.root_worker_id(p) + @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) # Initialize pressure and capacity - gproc = OSProc(p.pid) lock(state.lock) do - state.worker_time_pressure[p.pid] = Dict{Processor,UInt64}() - state.worker_transfer_rate[p.pid] = Dict{Processor,UInt64}() - - state.worker_storage_pressure[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() - state.worker_storage_capacity[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_transfer_rate[p] = Dict{Processor,UInt64}() + state.worker_time_pressure[p] = Dict{Processor,UInt64}() + state.worker_storage_pressure[p] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_capacity[p] = Dict{Union{StorageResource,Nothing},UInt64}() #= FIXME for storage in get_storage_resources(gproc) - pressure, capacity = remotecall_fetch(gproc.pid, storage) do storage + pressure, capacity = remotecall_fetch(root_worker_id(gproc), storage) do storage storage_pressure(storage), storage_capacity(storage) end - state.worker_storage_pressure[p.pid][storage] = pressure - state.worker_storage_capacity[p.pid][storage] = capacity + state.worker_storage_pressure[p][storage] = pressure + state.worker_storage_capacity[p][storage] = capacity end =# - state.worker_loadavg[p.pid] = (0.0, 0.0, 0.0) + state.worker_loadavg[p] = (0.0, 0.0, 0.0) end - if p.pid != 1 + if pid != 1 lock(WORKER_MONITOR_LOCK) do - wid = p.pid + wid = pid if !haskey(WORKER_MONITOR_TASKS, wid) t = Threads.@spawn begin try @@ -209,16 +208,16 @@ function init_proc(state, p, log_sink) end # Setup worker-to-scheduler channels - inp_chan = RemoteChannel(p.pid) - out_chan = RemoteChannel(p.pid) + inp_chan = RemoteChannel(pid) + out_chan = RemoteChannel(pid) lock(state.lock) do - state.worker_chans[p.pid] = (inp_chan, out_chan) + state.worker_chans[pid] = (inp_chan, out_chan) end # Setup dynamic listener - dynamic_listener!(ctx, state, p.pid) + dynamic_listener!(ctx, state, pid) - @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! @@ -236,7 +235,7 @@ function _cleanup_proc(uid, log_sink) end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - wid = p.pid + wid = root_worker_id(p) @maybelog ctx timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) @@ -299,7 +298,7 @@ function compute_dag(ctx::Context, d::Thunk, options=SchedulerOptions()) node_order = x -> -get(ord, x, 0) state = start_state(deps, node_order, chan) - master = OSProc(myid()) + master = Dagger.default_processor() @maybelog ctx timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) try @@ -394,8 +393,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt res = tresult.result @dagdebug thunk_id :take "Got finished task" - gproc = OSProc(pid) safepoint(state) + gproc = proc != nothing ? get_parent(proc) : OSProc(pid) lock(state.lock) do thunk_failed = false if res isa Exception @@ -422,11 +421,11 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt node = unwrap_weak_checked(state.thunk_dict[thunk_id])::Thunk metadata = tresult.metadata if metadata !== nothing - state.worker_time_pressure[pid][proc] = metadata.time_pressure + state.worker_time_pressure[gproc][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) #state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - #state.worker_loadavg[pid] = metadata.loadavg + #state.worker_loadavg[gproc] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -440,8 +439,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt end end if res isa Chunk - if !haskey(state.equiv_chunks, res) - state.equiv_chunks[res.handle::DRef] = res + if !haskey(state.equiv_chunks, res.handle) + state.equiv_chunks[res.handle] = res end end store_result!(state, node, res; error=thunk_failed) @@ -528,7 +527,7 @@ end const CHUNK_CACHE = Dict{Chunk,Dict{Processor,Any}}() struct ScheduleTaskLocation - gproc::OSProc + gproc::Processor proc::Processor end struct ScheduleTaskSpec @@ -538,6 +537,25 @@ struct ScheduleTaskSpec est_alloc_util::UInt64 est_occupancy::UInt32 end + +"Ordering key for task locations when using MPI acceleration (deterministic across ranks)." +function _mpi_fire_order_key(loc::ScheduleTaskLocation) + g = loc.gproc + p = loc.proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + +"Ordering key for a single Processor when using MPI acceleration (deterministic across ranks)." +function _mpi_proc_rank(proc::Processor) + g = get_parent(proc) + p = proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + @reuse_scope function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)) lock(state.lock) do safepoint(state) @@ -552,6 +570,7 @@ end to_fire_cleanup = @reuse_defer_cleanup empty!(to_fire) failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32 failed_scheduling_cleanup = @reuse_defer_cleanup empty!(failed_scheduling) + # Select a new task and get its options task = nothing @label pop_task @@ -626,9 +645,9 @@ end end @label scope_computed - input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 + input_procs = @reusable_vector :schedule!_input_procs Union{Processor,Nothing} nothing 32 input_procs_cleanup = @reuse_defer_cleanup empty!(input_procs) - for proc in Dagger.compatible_processors(scope, procs) + for proc in Dagger.compatible_processors(options.acceleration, scope, procs) if !(proc in input_procs) push!(input_procs, proc) end @@ -660,7 +679,7 @@ end can_use, scope = can_use_proc(state, task, gproc, proc, options, scope) if can_use has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, proc, gproc.pid, options.time_util, options.alloc_util, options.occupancy, sig) + has_capacity(state, proc, gproc, options.time_util, options.alloc_util, options.occupancy, sig) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util @@ -669,10 +688,10 @@ end Vector{ScheduleTaskSpec}() end push!(proc_tasks, ScheduleTaskSpec(task, scope, est_time_util, est_alloc_util, est_occupancy)) - state.worker_time_pressure[gproc.pid][proc] = - get(state.worker_time_pressure[gproc.pid], proc, 0) + + state.worker_time_pressure[gproc][proc] = + get(state.worker_time_pressure[gproc], proc, 0) + est_time_util - @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc.pid][proc]))" + @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc][proc]))" sorted_procs_cleanup() costs_cleanup() @goto pop_task @@ -739,14 +758,14 @@ function monitor_procs_changed!(ctx, state, options) end function remove_dead_proc!(ctx, state, proc, options) - @assert options.single !== proc.pid "Single worker failed, cannot continue." + @assert options.single !== root_worker_id(proc) "Single worker failed, cannot continue." rmprocs!(ctx, [proc]) - delete!(state.worker_time_pressure, proc.pid) - delete!(state.worker_transfer_rate, proc.pid) - delete!(state.worker_storage_pressure, proc.pid) - delete!(state.worker_storage_capacity, proc.pid) - delete!(state.worker_loadavg, proc.pid) - delete!(state.worker_chans, proc.pid) + delete!(state.worker_transfer_rate, proc) + delete!(state.worker_time_pressure, proc) + delete!(state.worker_storage_pressure, proc) + delete!(state.worker_storage_capacity, proc) + delete!(state.worker_loadavg, proc) + delete!(state.worker_chans, root_worker_id(proc)) end function finish_task!(ctx, state, node, thunk_failed) @@ -789,7 +808,7 @@ end function evict_all_chunks!(ctx, options, to_evict) if !isempty(to_evict) - @sync for w in map(p->p.pid, procs_to_use(ctx, options)) + @sync for w in map(p->root_worker_id(p), procs_to_use(ctx, options)) Threads.@spawn remote_do(evict_chunks!, w, ctx.log_sink, to_evict) end end @@ -860,9 +879,10 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end Tf = chunktype(first(args)) - @assert (options.single === nothing) || (gproc.pid == options.single) + pid = root_worker_id(gproc) + @assert (options.single === nothing) || (pid == options.single) # TODO: Set `sch_handle.tid.ref` to the right `DRef` - sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) + sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) push!(to_send, TaskSpec( @@ -874,7 +894,7 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end if !isempty(to_send) - if Dagger.root_worker_id(gproc) == myid() + if root_worker_id(gproc) == myid() @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, to_send) else # N.B. We don't batch these because we might get a deserialization @@ -1080,7 +1100,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re proc_occupancy = istate.proc_occupancy time_pressure = istate.time_pressure - wid = get_parent(to_proc).pid + wid = root_worker_id(to_proc) work_to_do = false while isopen(return_queue) # Wait for new tasks @@ -1138,7 +1158,6 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Try to steal from local queues randomly # TODO: Prioritize stealing from busiest processors states = proc_states_values(uid) - # TODO: Try to pre-allocate this P = randperm(length(states)) for state in getindex.(Ref(states), P) other_istate = state.state @@ -1155,7 +1174,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end task, occupancy = peek(queue) scope = task.scope - if Dagger.proc_in_scope(to_proc, scope) + accel = something(task.options.acceleration, Dagger.DistributedAcceleration()) + if Dagger.proc_in_scope(to_proc, scope) && Dagger.accel_matches_proc(accel, to_proc) typemax(UInt32) - proc_occupancy_cached >= occupancy # Compatible, steal this task return dequeue_pair!(queue) @@ -1362,6 +1382,8 @@ function do_tasks(to_proc, return_queue, tasks) @dagdebug nothing :processor "Kicked processors" end +const SCHED_MOVE = ScopedValue{Bool}(false) + """ do_task(to_proc, task::TaskSpec) -> Any @@ -1373,13 +1395,15 @@ Executes a single task specified by `task` on `to_proc`. ctx_vars = task.ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) - from_proc = OSProc() + options = task.options + Dagger.accelerate!(options.acceleration) + + from_proc = Dagger.default_processor() data = task.data Tf = task.Tf f = isdefined(Tf, :instance) ? Tf.instance : nothing # Wait for required resources to become available - options = task.options propagated = get_propagated_options(options) to_storage = options.storage !== nothing ? fetch(options.storage) : MemPool.GLOBAL_DEVICE[] #to_storage_name = nameof(typeof(to_storage)) @@ -1447,7 +1471,7 @@ Executes a single task specified by `task` on `to_proc`. @maybelog ctx timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) =# - @dagdebug thunk_id :execute "Moving data" + @dagdebug thunk_id :execute "Moving data for $Tf" # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) @@ -1470,7 +1494,9 @@ Executes a single task specified by `task` on `to_proc`. Some{Any}(get!(CHUNK_CACHE[x], to_proc) do # Convert from cached value # TODO: Choose "closest" processor of same type first - some_proc = first(keys(CHUNK_CACHE[x])) + cache_procs = keys(CHUNK_CACHE[x]) + some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ? + minimum(cache_procs, by=_mpi_proc_rank) : first(cache_procs) some_x = CHUNK_CACHE[x][some_proc] @dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x" @invokelatest move(some_proc, to_proc, some_x) @@ -1505,13 +1531,23 @@ Executes a single task specified by `task` on `to_proc`. end else =# - new_value = @invokelatest move(to_proc, value) + new_value = with(SCHED_MOVE=>true) do + @invokelatest move(to_proc, value) + end #end - if new_value !== value - @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + # Preserve Chunk reference when move returns nothing (placeholder on this rank). This keeps + # type information correct at all ranks: chunktype(Chunk) is concrete even when Chunk holds no data. + # So execute! sees correct arg_types. Materializing the value (for the kernel) must happen in + # execute! and may require lazy recv from the executor if this rank has a placeholder. + if new_value === nothing && (value isa Dagger.Chunk || value isa Dagger.WeakChunk) + arg.value = value + else + if new_value !== value + @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + end + arg.value = new_value end - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=new_value); tasks=[Base.current_task()]) - arg.value = new_value + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=Dagger.value(arg)); tasks=[Base.current_task()]) return end end @@ -1550,7 +1586,7 @@ Executes a single task specified by `task` on `to_proc`. # FIXME #gcnum_start = Base.gc_num() - @dagdebug thunk_id :execute "Executing $(typeof(f))" + @dagdebug thunk_id :execute "Executing $Tf" logging_enabled = !(ctx.log_sink isa TimespanLogging.NoOpLog) @@ -1613,7 +1649,7 @@ Executes a single task specified by `task` on `to_proc`. notify(TASK_SYNC) end - @dagdebug thunk_id :execute "Returning" + @dagdebug thunk_id :execute "Returning $Tf with $(typeof(result_meta))" # TODO: debug_storage("Releasing $to_storage_name") metadata = ( diff --git a/src/sch/util.jl b/src/sch/util.jl index 11706382a..164685195 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -383,7 +383,7 @@ function signature(f, args) value = Dagger.value(arg) if value isa Dagger.DTask # Only occurs via manual usage of signature - value = fetch(value; raw=true) + value = fetch(value; move_value=false, unwrap=false) end if istask(value) throw(ConcurrencyViolationError("Must call `collect_task_inputs!(state, task)` before calling `signature`")) @@ -440,8 +440,8 @@ function can_use_proc(state, task, gproc, proc, opts, scope) # Check against single if opts.single !== nothing @warn "The `single` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1 - if gproc.pid != opts.single - @dagdebug task :scope "Rejected $proc: gproc.pid ($(gproc.pid)) != single ($(opts.single))" + if root_worker_id(gproc) != opts.single + @dagdebug task :scope "Rejected $proc: gproc root_worker_id ($(root_worker_id(gproc))) != single ($(opts.single))" return false, scope end scope = constrain(scope, Dagger.ProcessScope(opts.single)) @@ -593,9 +593,10 @@ const DEFAULT_TRANSFER_RATE = UInt64(1_000_000) # Add fixed cost for cross-worker task transfer (esimated at 1ms) # TODO: Actually estimate/benchmark this - task_xfer_cost = gproc.pid != myid() ? 1_000_000 : 0 # 1ms + task_xfer_cost = root_worker_id(gproc) != myid() ? 1_000_000 : 0 # 1ms + pid = Dagger.root_worker_id(gproc) - tx_rate = get(get(state.worker_transfer_rate, gproc.pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) + tx_rate = get(get(state.worker_transfer_rate, pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE) costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost end chunks_cleanup() diff --git a/src/scopes.jl b/src/scopes.jl index 79190c292..28aa8fa00 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -101,7 +101,7 @@ struct ExactScope <: AbstractScope parent::ProcessScope processor::Processor end -ExactScope(proc) = ExactScope(ProcessScope(get_parent(proc).pid), proc) +ExactScope(proc) = ExactScope(ProcessScope(root_worker_id(get_parent(proc))), proc) proc_in_scope(proc::Processor, scope::ExactScope) = proc == scope.processor "Indicates that the applied scopes `x` and `y` are incompatible." diff --git a/src/shard.jl b/src/shard.jl new file mode 100644 index 000000000..ecd0ee570 --- /dev/null +++ b/src/shard.jl @@ -0,0 +1,89 @@ +""" +Maps a value to one of multiple distributed "mirror" values automatically when +used as a thunk argument. Construct using `@shard` or `shard`. +""" +struct Shard + chunks::Dict{Processor,Chunk} +end + +""" + shard(f; kwargs...) -> Chunk{Shard} + +Executes `f` on all workers in `workers`, wrapping the result in a +process-scoped `Chunk`, and constructs a `Chunk{Shard}` containing all of these +`Chunk`s on the current worker. + +Keyword arguments: +- `procs` -- The list of processors to create pieces on. May be any iterable container of `Processor`s. +- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s. +- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker. +""" +function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false) + if procs === nothing + if workers !== nothing + procs = [OSProc(w) for w in workers] + else + procs = lock(Sch.eager_context()) do + copy(Sch.eager_context().procs) + end + end + if per_thread + _procs = ThreadProc[] + for p in procs + append!(_procs, filter(p->p isa ThreadProc, get_processors(p))) + end + procs = _procs + end + else + if workers !== nothing + throw(ArgumentError("Cannot combine `procs` and `workers`")) + elseif per_thread + throw(ArgumentError("Cannot combine `procs` and `per_thread=true`")) + end + end + isempty(procs) && throw(ArgumentError("Cannot create empty Shard")) + shard_running_dict = Dict{Processor,DTask}() + for proc in procs + scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc) + thunk = Dagger.@spawn scope=scope _mutable_inner(f, proc, scope) + shard_running_dict[proc] = thunk + end + shard_dict = Dict{Processor,Chunk}() + for proc in procs + shard_dict[proc] = fetch(shard_running_dict[proc])[] + end + return Shard(shard_dict) +end + +"Creates a `Shard`. See [`Dagger.shard`](@ref) for details." +macro shard(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $shard(f; $(opts...)) + end + end +end + +function move(from_proc::Processor, to_proc::Processor, shard::Shard) + # Match either this proc or some ancestor + # N.B. This behavior may bypass the piece's scope restriction + proc = to_proc + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + parent = Dagger.get_parent(proc) + while parent != proc + proc = parent + parent = Dagger.get_parent(proc) + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + end + + throw(KeyError(to_proc)) +end +Base.iterate(s::Shard) = iterate(values(s.chunks)) +Base.iterate(s::Shard, state) = iterate(values(s.chunks), state) +Base.length(s::Shard) = length(s.chunks) diff --git a/src/submission.jl b/src/submission.jl index 4ff4f2294..d3102eacf 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -285,7 +285,13 @@ function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function DTaskMetadata(spec::DTaskSpec) + rt = spec.options.return_type + if rt !== nothing && isconcretetype(rt) && rt !== Any + return DTaskMetadata(rt) + end + return DTaskMetadata(eager_metadata(spec.fargs)) +end function eager_metadata(fargs) f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f @@ -298,6 +304,10 @@ function eager_spawn(spec::DTaskSpec) uid = eager_next_id() future = ThunkFuture() metadata = DTaskMetadata(spec) + # Propagate inferred return type to options + if isconcretetype(metadata.return_type) + spec.options.return_type = metadata.return_type + end return DTask(uid, future, metadata) end @@ -320,10 +330,16 @@ function eager_launch!(pair::DTaskPair) end end + # Propagate DTask return_type into options so the created Thunk has chunktype for downstream inference + options = spec.options + if isconcretetype(task.metadata.return_type) + options = copy(options) + options.return_type = task.metadata.return_type + end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - fargs, spec.options, true)) + fargs, options, true)) task.thunk_ref = thunk_id.ref end # FIXME: Don't convert Tuple to Vector{Argument} @@ -353,7 +369,13 @@ function eager_launch!(pairs::Vector{DTaskPair}) end end end - all_options = Options[pair.spec.options for pair in pairs] + # Propagate DTask return_type into options so created Thunks have chunktype for downstream inference + all_options = Options[ + let opts = pair.spec.options + isconcretetype(pair.task.metadata.return_type) ? (o = copy(opts); o.return_type = pair.task.metadata.return_type; o) : opts + end + for pair in pairs + ] # Submit the tasks #=FIXME:REALLOC=# diff --git a/src/thunk.jl b/src/thunk.jl index e13e299f0..c24e0c329 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -247,6 +247,14 @@ isweak(t) = false Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) +# Use options.return_type when set (e.g. from mpi_propagate_chunk_types! or eager_metadata) +# so that Thunk arguments propagate type to downstream eager_metadata/execute! +function chunktype(t::Thunk) + if t.options !== nothing && t.options.return_type !== nothing && isconcretetype(t.options.return_type) + return t.options.return_type + end + return typeof(t) +end Base.convert(::Type{ThunkSyncdep}, t::WeakThunk) = ThunkSyncdep(nothing, t) ThunkSyncdep(t::WeakThunk) = ThunkSyncdep(nothing, t) @@ -462,7 +470,7 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) end args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) - if !isempty(kwargs) + if !Base.isempty(kwargs) kwargs = only(kwargs).args end if body !== nothing @@ -530,7 +538,7 @@ function spawn(f, args...; kwargs...) @nospecialize f args kwargs # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] @@ -545,7 +553,7 @@ function spawn(f, args...; kwargs...) end function typed_spawn(f, args...; kwargs...) # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] diff --git a/src/tochunk.jl b/src/tochunk.jl new file mode 100644 index 000000000..ff15e426e --- /dev/null +++ b/src/tochunk.jl @@ -0,0 +1,119 @@ +@warn "Update tochunk docstring" maxlog=1 +""" + tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk + +Create a chunk from data `x` which resides on `proc` and which has scope +`scope`. + +`device` specifies a `MemPool.StorageDevice` (which is itself wrapped in a +`Chunk`) which will be used to manage the reference contained in the `Chunk` +generated by this function. If `device` is `nothing` (the default), the data +will be inspected to determine if it's safe to serialize; if so, the default +MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will +be used. + +`type` can be specified manually to force the type to be `Chunk{type}`. + +If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a +new `Chunk`. + +All other kwargs are passed directly to `MemPool.poolset`. +""" +tochunk(x::X, proc::P, space::M; kwargs...) where {X,P<:Processor,M<:MemorySpace} = + tochunk(x, proc, space, AnyScope(); kwargs...) +function tochunk(x::X, proc::P, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S,M<:MemorySpace} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if x isa Chunk + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +# Disambiguate: Chunk-specific 3-arg so kwcall(tochunk, Chunk, Processor, Scope) is not ambiguous with utils/chunks.jl +function tochunk(x::Chunk, proc::P, scope::S; rewrap=false, kwargs...) where {P<:Processor,S} + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end +function tochunk(x::X, proc::P, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + space = x.space + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + space = default_memory_space(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +function tochunk(x::X, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,M<:MemorySpace,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + proc = x.processor + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + proc = default_processor(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),typeof(proc),S,M}(type, domain(x), ref, proc, scope, space) +end +# 2-arg: avoid overwriting utils/chunks.jl's tochunk(Any, Any) and tochunk(Any); only add Processor/MemorySpace variants +# Chunk + Processor: disambiguate vs utils/chunks.jl's tochunk(x::Chunk, proc; ...) +tochunk(x::Chunk, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, space::MemorySpace; kwargs...) = tochunk(x, space, AnyScope(); kwargs...) + +check_proc_space(x, proc, space) = nothing +function check_proc_space(x::Chunk, proc, space) + if x.space !== space + throw(ArgumentError("Memory space mismatch: Chunk=$(x.space) != Requested=$space")) + end +end +function check_proc_space(x::Thunk, proc, space) + # FIXME: Validate +end +function maybe_rewrap(x, proc, space, scope; type, rewrap) + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end + +tochunk_pset(x, space::MemorySpace; device=nothing, type=nothing, kwargs...) = poolset(x; device, kwargs...) + +# savechunk: defined in utils/chunks.jl (fork Chunk has space field; do not duplicate here) diff --git a/src/types/acceleration.jl b/src/types/acceleration.jl new file mode 100644 index 000000000..f9aa1d86f --- /dev/null +++ b/src/types/acceleration.jl @@ -0,0 +1,3 @@ +abstract type Acceleration end + +struct DistributedAcceleration <: Acceleration end diff --git a/src/types/chunk.jl b/src/types/chunk.jl new file mode 100644 index 000000000..9b8102a6d --- /dev/null +++ b/src/types/chunk.jl @@ -0,0 +1,27 @@ +""" + Chunk + +A reference to a piece of data located on a remote worker. `Chunk`s are +typically created with `Dagger.tochunk(data)`, and the data can then be +accessed from any worker with `collect(::Chunk)`. `Chunk`s are +serialization-safe, and use distributed refcounting (provided by +`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, +as long as a reference exists on some worker. + +Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a +sense) the processor that "owns" or contains the data. Calling +`collect(::Chunk)` will perform data movement and conversions defined by that +processor to safely serialize the data to the calling worker. + +## Constructors +See [`tochunk`](@ref). +""" + +mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope, M<:MemorySpace} + chunktype::Type{T} + domain + handle::H + processor::P + scope::S + space::M +end diff --git a/src/types/memory-space.jl b/src/types/memory-space.jl new file mode 100644 index 000000000..247ceccb0 --- /dev/null +++ b/src/types/memory-space.jl @@ -0,0 +1 @@ +abstract type MemorySpace end \ No newline at end of file diff --git a/src/types/processor.jl b/src/types/processor.jl new file mode 100644 index 000000000..1e333413f --- /dev/null +++ b/src/types/processor.jl @@ -0,0 +1,2 @@ +# Docstring for Processor is attached in src/processor.jl after OSProc is defined (avoids "Replacing docs" warning). +abstract type Processor end \ No newline at end of file diff --git a/src/types/scope.jl b/src/types/scope.jl new file mode 100644 index 000000000..0197fddf9 --- /dev/null +++ b/src/types/scope.jl @@ -0,0 +1 @@ +abstract type AbstractScope end \ No newline at end of file diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 9f0c3b487..1300a5a1d 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -161,7 +161,8 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); device=nothing, re end end ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope) + space = memory_space(proc) + Chunk{X,typeof(ref),P,S,typeof(space)}(X, domain(x), ref, proc, scope, space) end function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) if rewrap @@ -185,5 +186,6 @@ function savechunk(data, dir, f) fr = FileRef(f, sz) proc = OSProc() scope = AnyScope() # FIXME: Scoped to this node - Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope)}(typeof(data), domain(data), fr, proc, scope, true) + space = memory_space(proc) + Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope),typeof(space)}(typeof(data), domain(data), fr, proc, scope, space) end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 873e47e79..678445051 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -59,4 +59,7 @@ macro opcounter(category, count=1) end end) end -opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] + +# No-op debug helper for tracking largest values (used alongside @opcounter) +largest_value_update!(::Any) = nothing \ No newline at end of file diff --git a/src/weakchunk.jl b/src/weakchunk.jl new file mode 100644 index 000000000..e31070536 --- /dev/null +++ b/src/weakchunk.jl @@ -0,0 +1,23 @@ +struct WeakChunk + wid::Int + id::Int + x::WeakRef +end + +function WeakChunk(c::Chunk) + return WeakChunk(c.handle.owner, c.handle.id, WeakRef(c)) +end + +unwrap_weak(c::WeakChunk) = c.x.value +function unwrap_weak_checked(c::WeakChunk) + cw = unwrap_weak(c) + @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" + return cw +end +wrap_weak(c::Chunk) = WeakChunk(c) +isweak(c::WeakChunk) = true +isweak(c::Chunk) = false +is_task_or_chunk(c::WeakChunk) = true +Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = + error("Cannot serialize a WeakChunk") +chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/test/mpi.jl b/test/mpi.jl new file mode 100644 index 000000000..7d71e801e --- /dev/null +++ b/test/mpi.jl @@ -0,0 +1,72 @@ +using Dagger, MPI, LinearAlgebra + +Dagger.accelerate!(:mpi) +Dagger.check_uniformity!(true) +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +sz = MPI.Comm_size(comm) + +mpidagger_all_results = [] + +# Define constants +# You need to define the MPI workers before running the benchmark +# Example: mpirun -n 4 julia --project benchmarks/DaggerMPI_Weak_scale.jl +datatype = [Float32, Float64] +datasize = 40 +try + for T in datatype + A = rand(T, datasize, datasize) + A = A * A' + A[diagind(A)] .+= size(A, 1) + B = copy(A) + @assert ishermitian(B) + DA = zeros(Blocks(20,20), T, datasize, datasize) + for chunk in DA.chunks + Dagger.check_uniform(fetch(chunk; move_value=false, unwrap=false).space) + end + copyto!(DA, A) + DB = zeros(Blocks(20,20), T, datasize, datasize) + for chunk in DB.chunks + Dagger.check_uniform(fetch(chunk; move_value=false, unwrap=false).space) + end + copyto!(DB, B) + LinearAlgebra._chol!(DA, UpperTriangular) + elapsed_time = @elapsed chol_DB = LinearAlgebra._chol!(DB, UpperTriangular) + + # Store results + result = ( + procs = sz, + dtype = T, + size = datasize, + time = elapsed_time, + gflops = (datasize^3 / 3) / (elapsed_time * 1e9) + ) + push!(mpidagger_all_results, result) + end +catch + if rank == 0 + Core.print("Rank 0:\n") + rethrow() + elseif rank == 1 + Core.print("Rank 1:\n") + sleep(1) + rethrow() + end +finally + MPI.Barrier(comm) +end +if rank == 0 + #= Write results to CSV + mkpath("benchmarks/results") + if !isempty(mpidagger_all_results) + df = DataFrame(mpidagger_all_results) + CSV.write("benchmarks/results/DaggerMPI_Weak_scale_results.csv", df) + + end + =# + # Summary statistics + for result in mpidagger_all_results + println(result.procs, ",", result.dtype, ",", result.size, ",", result.time, ",", result.gflops) + end + #println("\nAll Cholesky tests completed!") +end