|
| 1 | +const std = @import("std"); |
| 2 | +const return_value = @import("../return_value/return_value.zig").return_value; |
| 3 | +const thread_info = @import("../thread_info/thread_info.zig").thread_info; |
| 4 | +const error_sets = @import("../error_sets/error_sets.zig"); |
| 5 | + |
pub const operation = struct {
    /// Handle for an in-flight kernel launch. Returned by `launch`; the
    /// caller must call `sync` exactly once to join the coordinator thread
    /// and release `result`. Dropping an `operation` without calling `sync`
    /// leaks `result` and abandons the coordinator thread.
    thread: std.Thread,
    /// Heap slot the coordinator writes its aggregate outcome into.
    /// Owned by this struct; freed in `sync`.
    result: *return_value,
    allocator: std.mem.Allocator,

    /// Coordinator body, run on the thread spawned by `launch`.
    ///
    /// Spawns `nthreads` workers each running `kernel`, joins them all, and
    /// stores the aggregate outcome in `result.*`:
    ///   - `internal_error` if `nthreads == 0`, an allocation fails, or a
    ///     spawn fails (already-started workers are joined first);
    ///   - `thread_error` if any worker reported a non-success return;
    ///   - `success` otherwise.
    /// `result.*` is only read by `sync` after this thread is joined, so the
    /// plain stores here need no extra synchronization.
    fn primary_thread(allocator: std.mem.Allocator, result: *return_value, nthreads: usize, nkernels: u128, comptime kernel: anytype, args: anytype) void {
        result.* = return_value.success;
        if (nthreads == 0) {
            result.* = return_value.internal_error;
            return;
        }
        const threads = allocator.alloc(std.Thread, nthreads) catch {
            result.* = return_value.internal_error;
            return;
        };
        defer allocator.free(threads);
        const thread_returns = allocator.alloc(return_value, nthreads) catch {
            result.* = return_value.internal_error;
            return;
        };
        defer allocator.free(thread_returns);

        const thread_info_gen = thread_info.generator.init(nthreads, nkernels, thread_returns);
        for (0.., thread_returns, threads) |nthreads_launched, *this_thread_return, *this_thread| {
            this_thread_return.* = return_value.success;
            this_thread.* = std.Thread.spawn(.{}, kernel, .{thread_info_gen.gen(nthreads_launched)} ++ args) catch {
                // Spawn failed mid-launch: join the workers that did start
                // so none are leaked, then report the failure.
                for (threads[0..nthreads_launched]) |thread| {
                    thread.join();
                }
                result.* = return_value.internal_error;
                return;
            };
        }

        for (0.., threads) |i, thread| {
            thread.join();
            // BUGFIX: read the worker's return slot only AFTER join().
            // The previous by-value loop capture copied thread_returns[i]
            // at the top of the iteration — before the join — racing with
            // the worker's write and potentially observing a stale value.
            // join() establishes the happens-before edge we need here.
            if (thread_returns[i] != return_value.success) {
                result.* = return_value.thread_error;
                // BUGFIX: threads[i] was already joined above, so cleanup
                // must start at i + 1 — the old `threads[i..]` joined the
                // same thread twice, which is illegal for std.Thread.
                for (threads[i + 1 ..]) |thread_quickjoin| {
                    thread_quickjoin.join();
                }
                return;
            }
        }
    }

    /// Asynchronously launches `nkernels` invocations of `kernel` spread
    /// across `nthreads` worker threads, coordinated by a dedicated thread.
    ///
    /// Returns an `operation` handle; the caller owns it and must call
    /// `sync` to await completion and observe the outcome. `args` is
    /// appended to the per-worker `thread_info` argument when invoking
    /// `kernel`. Fails with `kernelmaster_internal_error` if the result
    /// slot cannot be allocated or the coordinator cannot be spawned.
    pub fn launch(allocator: std.mem.Allocator, nthreads: usize, nkernels: u128, comptime kernel: anytype, args: anytype) error_sets.kernelmaster_error!operation {
        const result: *return_value = allocator.create(return_value) catch return error.kernelmaster_internal_error;
        // If the spawn below fails we must not leak the result slot.
        errdefer allocator.destroy(result);
        return .{
            .thread = std.Thread.spawn(
                .{},
                primary_thread,
                .{
                    allocator,
                    result,
                    nthreads,
                    nkernels,
                    kernel,
                    args,
                },
            ) catch return error.kernelmaster_internal_error,
            .result = result,
            .allocator = allocator,
        };
    }

    /// Blocks until the launch completes, releases the result slot, and
    /// translates the coordinator's outcome into an error:
    ///   - `success`       => returns normally;
    ///   - `thread_error`  => `error.kernelmaster_thread_error`;
    ///   - anything else   => `error.kernelmaster_internal_error`.
    /// Must be called exactly once per `operation`.
    pub fn sync(op: operation) !void {
        // join() makes the coordinator's final write to op.result visible.
        std.Thread.join(op.thread);
        const r: return_value = op.result.*;
        op.allocator.destroy(op.result);
        switch (r) {
            return_value.success => return,
            return_value.thread_error => return error.kernelmaster_thread_error,
            else => return error.kernelmaster_internal_error,
        }
    }
};
0 commit comments