diff --git a/.github/workflows/compiler-build.yml b/.github/workflows/compiler-build.yml index 32ab43f05..04a0a08ea 100644 --- a/.github/workflows/compiler-build.yml +++ b/.github/workflows/compiler-build.yml @@ -138,7 +138,7 @@ jobs: working-directory: ${{github.workspace}} run: | dotnet tool install --global dotnet-coverage - dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/unit.xml "dotnet test -c ${{matrix.config.buildType}} -s test.runsettings --no-build --verbosity normal --blame" + dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/unit.xml "dotnet test -c ${{matrix.config.buildType}} -s test.runsettings --no-build --verbosity normal --filter FullyQualifiedName!~Nncase.Tests.TargetTest.UnitTestCUDAKernels --blame" dotnet-coverage merge -o coverage.unit.xml -f cobertura -r coverage/*.xml - name: Upload Coverage diff --git a/CMakeLists.txt b/CMakeLists.txt index 0193cbf70..01d61edc0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,18 @@ option(BUILD_TESTING "Build test programs" OFF) option(ENABLE_OP_PROFILE "Profile ops cast time" OFF) option(ENABLE_DUMP_MANAGER "Enable dump manager" OFF) option(ENABLE_DUMP_MEM "Dump mem usage" OFF) +option(ENABLE_CUDA_RUNTIME "Enable CUDA runtime" OFF) + +if(DEFINED CMAKE_CUDA_COMPILER AND NOT "${CMAKE_CUDA_COMPILER}" STREQUAL "") + set(ENABLE_CUDA_RUNTIME ON CACHE BOOL "Enable CUDA runtime" FORCE) +endif() + +if(ENABLE_CUDA_RUNTIME) + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 120) + endif() + enable_language(CUDA) +endif() if (BUILDING_RUNTIME) # option(ENABLE_VULKAN_RUNTIME "Enable Vulkan runtime" OFF) diff --git a/cmake/compile_flags.cmake b/cmake/compile_flags.cmake index b5d7a36c4..fa9022513 100644 --- a/cmake/compile_flags.cmake +++ b/cmake/compile_flags.cmake @@ -4,7 +4,7 @@ if (MSVC) set(PYBIND11_CPP_STANDARD "/std:c++latest") else() add_compile_options(-fvisibility=hidden) - add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare) + add_compile_options(-Wall -Wextra -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare) if (APPLE) add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized -Wno-deprecated -Wno-braced-scalar-init) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") @@ -15,6 +15,11 @@ else() endif() endif() +if (CMAKE_CUDA_COMPILER) + message(STATUS "Configuring for CUDA") + #add_compile_options(-save-temps) +endif() + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "(x86)|(X86)|(amd64)|(AMD64)|(x86_64)|(X86_64)") if (MSVC) diff --git a/conanfile.py b/conanfile.py index b20364120..33d5c946b 100644 --- a/conanfile.py +++ b/conanfile.py @@ -29,6 +29,7 @@ class nncaseConan(ConanFile): "k230_runtime": [True, False], "k80_runtime": [True, False], "vulkan_runtime": [True, False], + "cuda_runtime": [True, False], "tests": [True, False], "python": [True, False], "python_root": ["ANY"] @@ -40,6 +41,7 @@ class nncaseConan(ConanFile): "k230_runtime": False, "k80_runtime": False, "vulkan_runtime": False, + "cuda_runtime": False, "tests": False, "python": True, "python_root": "" @@ -88,8 +90,11 @@ def generate(self): tc.variables['ENABLE_K230_RUNTIME'] = self.options.k230_runtime tc.variables['ENABLE_K80_RUNTIME'] = self.options.k80_runtime tc.variables['ENABLE_VULKAN_RUNTIME'] = self.options.vulkan_runtime + tc.variables['ENABLE_CUDA_RUNTIME'] = self.options.cuda_runtime tc.variables['BUILD_PYTHON_BINDING'] = self.options.python tc.variables['BUILD_TESTING'] = self.options.tests + if self.options.cuda_runtime: + tc.variables['CMAKE_CUDA_ARCHITECTURES'] = "120" if self.options.get_safe("python_root", default="") != "": tc.variables['Python3_ROOT_DIR'] = self.options.python_root if self.options.runtime: diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs index a746e14ad..c9cd8d750 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs @@ -80,16 +80,16 @@ public static string TopoAwareRuntimeDef(NTTTargetOptions options, ulong dataAli return content; } - public static string ModuleTopologyDef(NTTTargetOptions options) + public static string ModuleTopologyDef(NTTTargetOptions options, bool isCUDA) { - var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/module_topology_def.h.cshtml", options).Result; + var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/module_topology_def.h.cshtml", new { Hierarchies = options.Hierarchies[0], IsCUDA = isCUDA }).Result; return content; } - public static string CMakeDef() + public static string CMakeDef(bool isCUDA) { var cmakePath = CMakePath(Path.Combine(Path.GetDirectoryName(typeof(CSourceBuiltn).Assembly.Location)!, "Runtime", "cmake", "ntt_module.cmake")); - var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/CMakeLists.txt.cshtml", new { CMakePath = cmakePath }).Result; + var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/CMakeLists.txt.cshtml", new { CMakePath = cmakePath, IsCUDA = isCUDA }).Result; return content; } diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs index 5d38ea682..34e15ba45 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs @@ -22,6 +22,8 @@ public class CSourceCompiler { private static string? _vcVarPath; + private readonly bool _isCUDA; + /// /// compiler exe name. /// @@ -37,8 +39,9 @@ public class CSourceCompiler /// private string _ext = string.Empty; - public CSourceCompiler() + public CSourceCompiler(bool isCUDA) { + _isCUDA = isCUDA; PlatformSpecific(); ArchSpecific(); } @@ -186,8 +189,16 @@ private void ArchSpecific() private string ArgumentsSpecific(string sourcePath, string outPath) { - var archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - "-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty; + string archConfig = string.Empty; + if (_isCUDA) + { + archConfig = $"-DCMAKE_CUDA_ARCHITECTURES=120 -DCMAKE_CUDA_COMPILER=clang++"; + } + else + { + archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + "-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty; + } #if DEBUG var config = "Release"; diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs index 26a66d3a8..b4b47410d 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs @@ -47,7 +47,7 @@ public static void WriteWithProfiler(string functionName, string tagName = "") IndentScope.Writer.IndWrite("{\n"); #if false // Disable device profiling for now. IndentScope.Writer.Write($"constexpr std::string_view function_name = \"{tagName}\";\n"); - IndentScope.Writer.Write($"auto_profiler profiler(function_name, runtime::profiling_level::device);\n"); + IndentScope.Writer.Write($"profile_scope profiler(function_name, profile_level::device);\n"); #endif IndentScope.Writer.Write($"{functionName};\n"); IndentScope.Writer.IndWrite("}\n"); @@ -69,7 +69,7 @@ public static void WriteIndWithProfiler(string functionName, string tagName = "" IndentScope.Writer.IndWrite("{\n"); #if false // Disable device profiling for now. IndentScope.Writer.IndWrite($"constexpr std::string_view function_name = \"{tagName}\";\n"); - IndentScope.Writer.IndWrite($"auto_profiler profiler(function_name, runtime::profiling_level::device);\n"); + IndentScope.Writer.IndWrite($"profile_scope profiler(function_name, profile_level::device);\n"); #endif IndentScope.Writer.IndWrite($"{functionName};\n"); IndentScope.Writer.IndWrite("}\n"); @@ -94,7 +94,7 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr) } var ctype = $"template<{string.Join(", ", Enumerable.Range(0, expr.Parameters.Length).Select(x => $"class T{x}"))}>" + Environment.NewLine + - $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"T{i} &&{s.Name}").ToArray())})"; + $"NTT_DEVICE void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"T{i} &&{s.Name}").ToArray())})"; using (var scope = new IndentScope(_deviceBuilder)) { @@ -192,7 +192,7 @@ protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr) _ => throw new NotSupportedException(expr.Location.ToString()), }; - var str = $"std::span({name} + {start.Name}, {size.Name})"; + var str = $"ntt::span({name} + {start.Name}, {size.Name})"; symbol = new(start.Type, str); _exprMemo.Add(expr, symbol); return symbol; diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs index b44ee466d..8d64872f2 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs @@ -21,15 +21,17 @@ internal class FunctionBuilder private readonly BinaryWriter _textWriter; private readonly BinaryWriter _rdataWriter; private readonly IReadOnlyList _threadLocalRdataWriters; + private readonly IReadOnlyList _warpLocalRdataWriters; private readonly IReadOnlyList _blockLocalRdataWriters; - public FunctionBuilder(uint id, BinaryWriter rdataWriter, IReadOnlyList threadLocalRdataWriters, IReadOnlyList blockLocalRdataWriters, Targets.NTTTargetOptions targetOptions) + public FunctionBuilder(uint id, BinaryWriter rdataWriter, IReadOnlyList threadLocalRdataWriters, IReadOnlyList warpLocalRdataWriters, IReadOnlyList blockLocalRdataWriters, Targets.NTTTargetOptions targetOptions) { _id = id; _sectionManager = new(); _textWriter = _sectionManager.GetWriter(WellknownSectionNames.Text); _rdataWriter = rdataWriter; _threadLocalRdataWriters = threadLocalRdataWriters; + _warpLocalRdataWriters = warpLocalRdataWriters; _blockLocalRdataWriters = blockLocalRdataWriters; TargetOptions = targetOptions; } @@ -58,66 +60,17 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc) tensor.Serialize(_rdataWriter.BaseStream); } - // 2. write the thread local rdata - ulong threadLocalRdataPoolSize = ulong.MinValue; - foreach (var (@const, range) in primFunc.SchedResult.ThreadLocalRdatas) - { - var tensor = ((TensorConst)@const).Value; - var distributedType = (DistributedType)@const.CheckedType; - var size = range.Max - range.Min; - threadLocalRdataPoolSize = System.Math.Max(range.Max, threadLocalRdataPoolSize); - var dividedDims = DistributedUtility.GetDividedTensorType(distributedType).Shape.ToValueArray(); - var localStrides = TensorUtilities.GetDefaultStrides(dividedDims); - for (int i = 0; i < _threadLocalRdataWriters.Count; i++) - { - var threadLocalRdataWriter = _threadLocalRdataWriters[i]; - var shardIndex = DistributedUtility.GetUnraveledIndex(i, TargetOptions.Hierarchies[0]); - (var localOffset, var localShape) = DistributedUtility.GetLocalOffsetAndShape(distributedType, shardIndex); - var linearOffset = TensorUtilities.GetLinearOffset(tensor.Strides, localOffset); - - if ((ulong)TensorUtilities.GetProduct(localShape) * (ulong)tensor.ElementType.SizeInBytes > size) - { - throw new InvalidDataException("The Buffer Size Not Equal!"); - } - - threadLocalRdataWriter.Position(checked((long)range.Min)); - tensor.Serialize(threadLocalRdataWriter.BaseStream, linearOffset, localShape, localStrides); - } - } - - // 2. write the block local rdata - ulong blockLocalRdataPoolSize = ulong.MinValue; - foreach (var (@const, range) in primFunc.SchedResult.BlockLocalRdatas) - { - var tensor = ((TensorConst)@const).Value; - var distributedType = (DistributedType)@const.CheckedType; - var size = range.Max - range.Min; - blockLocalRdataPoolSize = System.Math.Max(range.Max, blockLocalRdataPoolSize); - var dividedDims = DistributedUtility.GetDividedTensorType(distributedType).Shape.ToValueArray(); - var localStrides = TensorUtilities.GetDefaultStrides(dividedDims); - for (int i = 0; i < _blockLocalRdataWriters.Count; i++) - { - var blockLocalRdataWriter = _blockLocalRdataWriters[i]; - var shardIndex = DistributedUtility.GetUnraveledIndex(i, TargetOptions.Hierarchies[0][..^1]).Concat([0]).ToArray(); - (var localOffset, var localShape) = DistributedUtility.GetLocalOffsetAndShape(distributedType, shardIndex); - var linearOffset = TensorUtilities.GetLinearOffset(tensor.Strides, localOffset); - - if ((ulong)TensorUtilities.GetProduct(localShape) * (ulong)tensor.ElementType.SizeInBytes > size) - { - throw new InvalidDataException("The Buffer Size Not Equal!"); - } - - blockLocalRdataWriter.Position(checked((long)range.Min)); - tensor.Serialize(blockLocalRdataWriter.BaseStream, linearOffset, localShape, localStrides); - } - } + // 2. write the local rdatas + var threadLocalRdataPoolSize = SerializeLocalRdata(primFunc.SchedResult.ThreadLocalRdatas, _threadLocalRdataWriters, "t"); + var warpLocalRdataPoolSize = SerializeLocalRdata(primFunc.SchedResult.WarpLocalRdatas, _warpLocalRdataWriters, "w"); + var blockLocalRdataPoolSize = SerializeLocalRdata(primFunc.SchedResult.BlockLocalRdatas, _blockLocalRdataWriters, "b"); - // 4. build function. + // 3. build function. var visitor = new KernelCSourceConvertVisitor(TargetOptions); visitor.Visit(primFunc); var functionCSource = visitor.GetCSource(); - // 5. write the kernel desc + // 4. write the kernel desc using (var writer = _sectionManager.GetWriter(LinkableKernelFunction.KernelHeaderSectionName)) { var header = default(KernelDescHeader); @@ -125,6 +78,7 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc) header.LocalDataAlign = (uint)primFunc.SchedResult.DataAlign; header.OutputPoolSize = primFunc.SchedResult.OutputUsage; header.LocalDataPoolSize = primFunc.SchedResult.DataUsage; + header.WarpLocalDataPoolSize = primFunc.SchedResult.WarpLocalDataPoolSize; header.BlockLocalDataPoolSize = primFunc.SchedResult.BlockLocalDataPoolSize; writer.Write(ref header); } @@ -132,6 +86,7 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc) var memoryPoolDesc = new KernelMemoryPoolDesc( rdataPoolSize, threadLocalRdataPoolSize, + warpLocalRdataPoolSize, blockLocalRdataPoolSize); var kernelDescSection = new LinkedSection(_sectionManager.GetContent(LinkableKernelFunction.KernelHeaderSectionName)!, ".desc", 0, 8, (uint)sizeof(KernelDescHeader)); return new LinkableKernelFunction(_id, primFunc, functionCSource, memoryPoolDesc, _sectionManager.GetContent(WellknownSectionNames.Text)!, kernelDescSection); @@ -154,4 +109,50 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc) throw new NotSupportedException($"the {baseFunc.GetType()} {baseFunc.Name} is notsupport for codegen!"); } + + private ulong SerializeLocalRdata(IReadOnlyDictionary> localRdatas, IReadOnlyList localRdataWriters, string scopeName) + { + ulong localRdataPoolSize = ulong.MinValue; + foreach (var (@const, range) in localRdatas) + { + var tensor = ((TensorConst)@const).Value; + var distributedType = (DistributedType)@const.CheckedType; + var size = range.Max - range.Min; + localRdataPoolSize = System.Math.Max(range.Max, localRdataPoolSize); + var dividedDims = DistributedUtility.GetDividedTensorType(distributedType).Shape.ToValueArray(); + var localStrides = TensorUtilities.GetDefaultStrides(dividedDims); + for (int i = 0; i < localRdataWriters.Count; i++) + { + var localRdataWriter = localRdataWriters[i]; + var shardIndex = GetScopedShardIndex(i, scopeName); + (var localOffset, var localShape) = DistributedUtility.GetLocalOffsetAndShape(distributedType, shardIndex); + var linearOffset = TensorUtilities.GetLinearOffset(tensor.Strides, localOffset); + + if ((ulong)TensorUtilities.GetProduct(localShape) * (ulong)tensor.ElementType.SizeInBytes > size) + { + throw new InvalidDataException("The Buffer Size Not Equal!"); + } + + localRdataWriter.Position(checked((long)range.Min)); + tensor.Serialize(localRdataWriter.BaseStream, linearOffset, localShape, localStrides); + } + } + + return localRdataPoolSize; + } + + private int[] GetScopedShardIndex(int writerIndex, string scopeName) + { + var hierarchies = TargetOptions.Hierarchies[0]; + var scopeIndex = TargetOptions.HierarchyNames.IndexOf(scopeName, StringComparison.Ordinal); + if (scopeIndex < 0) + { + return DistributedUtility.GetUnraveledIndex(writerIndex, hierarchies); + } + + var scopedHierarchies = hierarchies[..(scopeIndex + 1)]; + return DistributedUtility.GetUnraveledIndex(writerIndex, scopedHierarchies) + .Concat(Enumerable.Repeat(0, hierarchies.Length - scopedHierarchies.Length)) + .ToArray(); + } } diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs index d655ebc19..31a804f95 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs @@ -60,7 +60,7 @@ protected override CSymbol VisitFusion(Fusion expr) IndentScope.Writer.IndWrite($"template<{string.Join(", ", Enumerable.Range(0, expr.Parameters.Length).Select(x => $"class T{x}"))}> struct {expr.Name} {{\n"); using (_ = new IndentScope()) { - IndentScope.Writer.IndWrite($"auto operator()({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"const T{i} &{s.Name}").ToArray())}) const noexcept {{\n"); + IndentScope.Writer.IndWrite($"constexpr auto operator()({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"const T{i} &{s.Name}").ToArray())}) const noexcept {{\n"); // 2. Function body using (_ = new IndentScope()) diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs index f20039466..96c6d89a2 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs @@ -27,7 +27,7 @@ namespace Nncase.CodeGen.NTT; /// internal sealed class KernelCSourceConvertVisitor : CSourceConvertVisitor, IDisposable { - private readonly HashSet _excludedVars = new() { "data", "block_local_data" }; + private readonly HashSet _excludedVars = new() { "data", "warp_local_data", "block_local_data" }; private readonly StringBuilder _kernelBuilder; private readonly HashSet _refFuncs; private readonly HashSet _declaredBuffers = new(ReferenceEqualityComparer.Instance); @@ -47,7 +47,7 @@ public KernelCSourceConvertVisitor(NTTTargetOptions targetOptions) private Var[] TensorParams => _tensorParams ??= VisitEntry.Parameters.ToArray().OfType().Where(x => !_excludedVars.Contains(x.Name)).ToArray(); - public static void WriteWithProfiler(string functionName, string tagName = "") + public void WriteWithProfiler(string functionName, string tagName = "") { functionName = functionName.TrimEnd(new char[] { ';', '\n' }); if (tagName == string.Empty) @@ -62,12 +62,12 @@ public static void WriteWithProfiler(string functionName, string tagName = "") tagName = tagName == string.Empty ? functionName : tagName; IndentScope.Writer.IndWrite("{\n"); IndentScope.Writer.Write($"constexpr std::string_view function_name = \"{tagName}\";\n"); - IndentScope.Writer.Write($"auto_profiler profiler(function_name, runtime::profiling_level::kernel);\n"); + IndentScope.Writer.Write($"profile_scope profiler(0, profile_level::kernel);\n"); IndentScope.Writer.Write($"{functionName};\n"); IndentScope.Writer.IndWrite("}\n"); } - public static void WriteIndWithProfiler(string functionName, string tagName = "") + public void WriteIndWithProfiler(string functionName, string tagName = "") { functionName = functionName.TrimEnd(new char[] { ';', '\n' }); if (tagName == string.Empty) @@ -82,7 +82,7 @@ public static void WriteIndWithProfiler(string functionName, string tagName = "" tagName = tagName == string.Empty ? functionName : tagName; IndentScope.Writer.IndWrite("{\n"); IndentScope.Writer.IndWrite($"constexpr std::string_view function_name = \"{tagName}\";\n"); - IndentScope.Writer.IndWrite($"auto_profiler profiler(function_name, runtime::profiling_level::kernel);\n"); + IndentScope.Writer.IndWrite($"profile_scope profiler(0, profile_level::kernel);\n"); IndentScope.Writer.IndWrite($"{functionName};\n"); IndentScope.Writer.IndWrite("}\n"); } @@ -92,7 +92,7 @@ public KernelCSource GetCSource() var paramsExcluded = VisitEntry.Parameters.ToArray().OfType().Where(x => !_excludedVars.Contains(x.Name)).ToArray(); var templateHeader = TensorParams.Length == 0 ? string.Empty : $"template<{string.Join(", ", Enumerable.Range(0, TensorParams.Length).Select(x => $"class T{x}"))}>" + Environment.NewLine; var ctype = templateHeader + - $"void {VisitEntry.Name}({string.Concat(paramsExcluded.Select(Visit).Select(s => $"{s.Type} {s.Name}, ").ToArray())}const std::byte *rdata, const std::byte *thread_local_rdata, const std::byte *block_local_rdata, std::byte *thread_local_data, std::byte *block_local_data, std::byte *output, nncase::ntt::runtime::thread_inout_desc *const output_descs)"; + $"NTT_DEVICE void {VisitEntry.Name}({string.Concat(paramsExcluded.Select(Visit).Select(s => $"{s.Type} {s.Name}, ").ToArray())}const std::byte *rdata, const std::byte *thread_local_rdata, const std::byte *warp_local_rdata, const std::byte *block_local_rdata, std::byte *thread_local_data, std::byte *warp_local_data, std::byte *block_local_data, std::byte *output, nncase::ntt::runtime::thread_inout_desc *const output_descs)"; return new( Declare: ctype + ";\n", Kernel: CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()), @@ -186,16 +186,18 @@ protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr) { (MemoryLocation.Rdata, 0) => "rdata", (MemoryLocation.ThreadLocalRdata, 0) => "thread_local_rdata", + (MemoryLocation.WarpLocalRdata, 0) => "warp_local_rdata", (MemoryLocation.BlockLocalRdata, 0) => "block_local_rdata", (MemoryLocation.Data, 0) => "thread_local_data", (MemoryLocation.Data, 1) => "thread_local_data", + (MemoryLocation.WarpLocalData, 0) => "warp_local_data", (MemoryLocation.BlockLocalData, 0) => "block_local_data", (MemoryLocation.Output, 0) => "output", _ => throw new NotSupportedException($"{expr.Location}, {expr.Hierarchy}"), }; var ptypeName = "std::byte"; - if (expr.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata or MemoryLocation.BlockLocalRdata) + if (expr.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata or MemoryLocation.WarpLocalRdata or MemoryLocation.BlockLocalRdata) { // Rdata, ThreadLocalRdata and BlockLocalRdata are const ptypeName = $"const {ptypeName}"; @@ -205,12 +207,12 @@ protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr) if (expr.Size is DimConst) { var spanSize = (ulong)expr.Size.FixedValue; - name = $"std::span<{ptypeName}, {spanSize}>({loc} + {start.Name}UL, {spanSize})"; + name = $"ntt::span<{ptypeName}, {spanSize}>({loc} + {start.Name}UL, {spanSize})"; } else { var spanSize = Visit(expr.Size).Name; - name = $"std::span<{ptypeName}>({loc} + {start.Name}UL, {spanSize})"; + name = $"ntt::span<{ptypeName}>({loc} + {start.Name}UL, {spanSize})"; } symbol = new(start.Type, name); @@ -471,7 +473,7 @@ protected override CSymbol VisitCall(Call expr) WriteWithProfiler($"slice({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[3], local: true).Name}, {VisitDimOrShape(args[1]).Name}, {VisitDimOrShape(args[2]).Name}, fixed_dims_v<{string.Join(",", slice.Axes)}>, fixed_dims_v<{string.Join(",", slice.Strides)}>);\n"); break; case TIR.NTT.Concat concat: - WriteWithProfiler($"concat(std::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name}, {concat.Axis}_dim);\n"); + WriteWithProfiler($"concat(ntt::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name}, {concat.Axis}_dim);\n"); break; case TIR.NTT.Transpose transpose: WriteWithProfiler($"transpose({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name}, fixed_dims_v<{string.Join(",", transpose.Perm)}>);\n"); @@ -555,7 +557,7 @@ protected override CSymbol VisitCall(Call expr) WriteIndWithProfiler($"get_position_ids({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name}, {KernelUtility.ShardingToC(getPositionIds.DistributedType)}, {Visit(getPositionIds.DistributedType.TensorType.Shape).Name});\n"); break; case TIR.NTT.Stack stack: - IndentScope.Writer.Write($"stack<{stack.Axis}>(std::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name});\n"); + IndentScope.Writer.Write($"stack<{stack.Axis}>(ntt::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name});\n"); break; case TIR.NTT.Reshape reshape: IndentScope.Writer.Write($"reshape({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n"); diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs index 3239ac741..1fb2eebb8 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs @@ -21,11 +21,14 @@ internal unsafe struct KernelDescHeader [MarshalAs(UnmanagedType.U8)] public ulong LocalDataPoolSize; + [MarshalAs(UnmanagedType.U8)] + public ulong WarpLocalDataPoolSize; + [MarshalAs(UnmanagedType.U8)] public ulong BlockLocalDataPoolSize; } -internal sealed record KernelMemoryPoolDesc(ulong RdataPoolSize, ulong ThreadLocalRdataPoolSize, ulong BlockLocalRdataPoolSize); +internal sealed record KernelMemoryPoolDesc(ulong RdataPoolSize, ulong ThreadLocalRdataPoolSize, ulong WarpLocalRdataPoolSize, ulong BlockLocalRdataPoolSize); internal sealed class LinkableKernelFunction : ILinkableFunction { diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs index 5b25ad596..d6b881cb1 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs @@ -17,20 +17,24 @@ namespace Nncase.CodeGen.NTT; internal sealed class LinkableModule : ILinkableModule { + private readonly string _moduleKind; private readonly Stream _desc; private readonly Stream _rdata; private readonly IReadOnlyList _threadLocalRdatas; private readonly IReadOnlyList _threadLocalCaches; + private readonly IReadOnlyList _warpLocalRdatas; private readonly IReadOnlyList _blockLocalRdatas; private readonly IReadOnlyList _functions; private readonly NTTTargetOptions _targetOptions; - public LinkableModule(Stream desc, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList blockLocalRdatas, IReadOnlyList functions, CompileOptions options) + public LinkableModule(string moduleKind, Stream desc, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList warpLocalRdatas, IReadOnlyList blockLocalRdatas, IReadOnlyList functions, CompileOptions options) { + _moduleKind = moduleKind; _desc = desc; _rdata = rdata; _threadLocalRdatas = threadLocalRdatas; _threadLocalCaches = threadLocalCaches; + _warpLocalRdatas = warpLocalRdatas; _blockLocalRdatas = blockLocalRdatas; _functions = functions; PublicFunctions = _functions.OfType().ToArray(); @@ -134,14 +138,15 @@ private void WriteModuleTopologyDef(string codegenDir) { using (var writer = new StreamWriter(fs)) { - writer.Write(CSourceBuiltn.ModuleTopologyDef(_targetOptions)); + writer.Write(CSourceBuiltn.ModuleTopologyDef(_targetOptions, isCUDA: _moduleKind == CUDATarget.Kind)); } } } private void WriteThreadMain(string codegenDir, LinkableKernelFunction mainFunc, IReadOnlyList kernelFiles) { - using (var fs = File.Open(Path.Join(codegenDir, "thread_main.cpp"), FileMode.Create)) + var threadMainExt = _moduleKind == CUDATarget.Kind ? "cu" : "cpp"; + using (var fs = File.Open(Path.Join(codegenDir, $"thread_main.{threadMainExt}"), FileMode.Create)) { using (var writer = new StreamWriter(fs)) { @@ -173,7 +178,7 @@ private void WriteCMakeLists(string codegenDir) { using (var writer = new StreamWriter(fs)) { - writer.Write(CSourceBuiltn.CMakeDef()); + writer.Write(CSourceBuiltn.CMakeDef(isCUDA: _moduleKind == CUDATarget.Kind)); } } } @@ -199,12 +204,12 @@ private ILinkedModule GenerateLinkedModule(string codegenDir, LinkableKernelFunc var funcText = File.ReadAllBytes(elfPath); textWriter.Write(funcText); linkedFunctions.Add(new LinkedFunction(mainFunc.Id, mainFunc.SourceFunction, 0, (uint)funcText.Length, mainFunc.Sections)); - return new LinkedModule(linkedFunctions, _desc, manager.GetContent(WellknownSectionNames.Text)!, _rdata, _threadLocalRdatas, _threadLocalCaches, _blockLocalRdatas, rdataAlign); + return new LinkedModule(_moduleKind, linkedFunctions, _desc, manager.GetContent(WellknownSectionNames.Text)!, _rdata, _threadLocalRdatas, _threadLocalCaches, _warpLocalRdatas, _blockLocalRdatas, rdataAlign); } private string CompileCSource(string sourcePath) { - var compiler = new CSourceCompiler(); + var compiler = new CSourceCompiler(_moduleKind == CUDATarget.Kind); var binDir = Path.Join(sourcePath, "build", "nncase_ntt_module"); return compiler.Compile(sourcePath, binDir); } diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs index 6165b08fc..827e4371a 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs @@ -17,21 +17,22 @@ internal unsafe struct ModuleDescHeader public uint ThreadDim; [MarshalAs(UnmanagedType.U4)] - public uint BlockDim; + public uint WarpDim; [MarshalAs(UnmanagedType.U4)] - public uint ChipDim; + public uint BlockDim; [MarshalAs(UnmanagedType.U4)] - public uint Reserved0; + public uint ChipDim; } internal sealed class LinkedModule : ILinkedModule { public const string ModuleHeaderSectionName = ".desc"; - public unsafe LinkedModule(IReadOnlyList functions, Stream desc, Stream text, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList blockLocalRdatas, ulong rdataAlign) + public unsafe LinkedModule(string moduleKind, IReadOnlyList functions, Stream desc, Stream text, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList warpLocalRdatas, IReadOnlyList blockLocalRdatas, ulong rdataAlign) { + ModuleKind = moduleKind; Functions = functions; Sections = [ @@ -40,11 +41,12 @@ public unsafe LinkedModule(IReadOnlyList functions, Stream desc new LinkedSection(rdata, WellknownSectionNames.Rdata, 0, (uint)rdataAlign, (ulong)rdata.Length), new LinkedMultipleContentsSection(threadLocalRdatas, WellknownSectionNames.ThreadLocalRdata, 0, (uint)rdataAlign), new LinkedMultipleContentsSection(threadLocalCaches, WellknownSectionNames.ThreadLocalCache, 0, (uint)rdataAlign), + new LinkedMultipleContentsSection(warpLocalRdatas, WellknownSectionNames.WarpLocalRdata, 0, (uint)rdataAlign), new LinkedMultipleContentsSection(blockLocalRdatas, WellknownSectionNames.BlockLocalRdata, 0, (uint)rdataAlign), ]; } - public string ModuleKind => "cpu"; + public string ModuleKind { get; } public uint Version => 0; diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs index fe934826a..1d4fef416 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs @@ -17,22 +17,36 @@ public sealed class NTTModuleBuilder : IModuleBuilder private readonly BinaryWriter _rdataWriter; private readonly BinaryWriter[] _threadLocalRdataWriters; private readonly BinaryWriter[] _threadLocalCacheWriters; + private readonly BinaryWriter[] _warpLocalRdataWriters; private readonly BinaryWriter[] _blockLocalRdataWriters; - public NTTModuleBuilder(CompileOptions options) + public NTTModuleBuilder(string moduleKind, CompileOptions options) { + var targetOptions = (NTTTargetOptions)options.TargetOptions; + var hierarchies = targetOptions.Hierarchies[0]; + ModuleKind = moduleKind; _sectionManager = new(); _rdataWriter = _sectionManager.GetWriter(WellknownSectionNames.Rdata); - var shardCount = TensorUtilities.GetProduct(((Targets.NTTTargetOptions)options.TargetOptions).Hierarchies[0]); + + var shardCount = TensorUtilities.GetProduct(hierarchies); _threadLocalRdataWriters = new BinaryWriter[shardCount]; _threadLocalCacheWriters = new BinaryWriter[shardCount]; - _blockLocalRdataWriters = new BinaryWriter[shardCount / ((Targets.NTTTargetOptions)options.TargetOptions).Hierarchies[0][^1]]; for (int i = 0; i < shardCount; i++) { _threadLocalRdataWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.ThreadLocalRdata, i); _threadLocalCacheWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.ThreadLocalCache, i); } + var isCUDA = ModuleKind == CUDATarget.Kind; + var warpsCount = isCUDA ? shardCount / hierarchies[^1] : 0; + _warpLocalRdataWriters = new BinaryWriter[warpsCount]; + for (int i = 0; i < _warpLocalRdataWriters.Length; i++) + { + _warpLocalRdataWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.WarpLocalRdata, i); + } + + var blocksCount = isCUDA ? (hierarchies.Length > 1 ? warpsCount / hierarchies[^2] : 1) : shardCount / hierarchies[^1]; + _blockLocalRdataWriters = new BinaryWriter[blocksCount]; for (int i = 0; i < _blockLocalRdataWriters.Length; i++) { _blockLocalRdataWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.BlockLocalRdata, i); @@ -44,7 +58,7 @@ public NTTModuleBuilder(CompileOptions options) public CompileOptions CompileOptions { get; } /// - public string ModuleKind => "cpu"; + public string ModuleKind { get; } /// public ILinkableModule Build(IReadOnlyList functions) @@ -55,9 +69,21 @@ public ILinkableModule Build(IReadOnlyList functions) using (var writer = _sectionManager.GetWriter(LinkedModule.ModuleHeaderSectionName)) { var header = default(ModuleDescHeader); + var hasWarp = targetOptions.HierarchyNames.Contains('w', StringComparison.Ordinal); header.ThreadDim = (uint)targetOptions.Hierarchies[0][^1]; - header.BlockDim = targetOptions.Hierarchies[0].Length < 2 ? 1 : (uint)targetOptions.Hierarchies[0][^2]; - header.ChipDim = targetOptions.Hierarchies[0].Length < 3 ? 1 : (uint)targetOptions.Hierarchies[0][^3]; + if (hasWarp) + { + header.WarpDim = targetOptions.Hierarchies[0].Length < 2 ? 1 : (uint)targetOptions.Hierarchies[0][^2]; + header.BlockDim = targetOptions.Hierarchies[0].Length < 3 ? 1 : (uint)targetOptions.Hierarchies[0][^3]; + header.ChipDim = targetOptions.Hierarchies[0].Length < 4 ? 1 : (uint)targetOptions.Hierarchies[0][^4]; + } + else + { + header.WarpDim = 1; + header.BlockDim = targetOptions.Hierarchies[0].Length < 2 ? 1 : (uint)targetOptions.Hierarchies[0][^2]; + header.ChipDim = targetOptions.Hierarchies[0].Length < 3 ? 1 : (uint)targetOptions.Hierarchies[0][^3]; + } + writer.Write(ref header); // cache offsets. @@ -76,7 +102,7 @@ public ILinkableModule Build(IReadOnlyList functions) } } - var linkableFunctions = functions.OfType().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter, _threadLocalRdataWriters, _blockLocalRdataWriters, (Targets.NTTTargetOptions)CompileOptions.TargetOptions).Build(f)).ToArray(); + var linkableFunctions = functions.OfType().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter, _threadLocalRdataWriters, _warpLocalRdataWriters, _blockLocalRdataWriters, (Targets.NTTTargetOptions)CompileOptions.TargetOptions).Build(f)).ToArray(); _rdataWriter.Flush(); var threadLocalRdataContents = Enumerable.Range(0, _threadLocalRdataWriters.Length).Select(i => { @@ -96,12 +122,18 @@ public ILinkableModule Build(IReadOnlyList functions) return _sectionManager.GetContent(WellknownSectionNames.ThreadLocalCache, i)!; }).ToArray(); + var warpLocalRdataContents = Enumerable.Range(0, _warpLocalRdataWriters.Length).Select(i => + { + _warpLocalRdataWriters[i].Flush(); + return _sectionManager.GetContent(WellknownSectionNames.WarpLocalRdata, i)!; + }).ToArray(); + var blockLocalRdataContents = Enumerable.Range(0, _blockLocalRdataWriters.Length).Select(i => { _blockLocalRdataWriters[i].Flush(); return _sectionManager.GetContent(WellknownSectionNames.BlockLocalRdata, i)!; }).ToArray(); - return new LinkableModule(_sectionManager.GetContent(LinkedModule.ModuleHeaderSectionName)!, _sectionManager.GetContent(WellknownSectionNames.Rdata)!, threadLocalRdataContents, threadLocalCacheContents, blockLocalRdataContents, linkableFunctions, CompileOptions); + return new LinkableModule(ModuleKind, _sectionManager.GetContent(LinkedModule.ModuleHeaderSectionName)!, _sectionManager.GetContent(WellknownSectionNames.Rdata)!, threadLocalRdataContents, threadLocalCacheContents, warpLocalRdataContents, blockLocalRdataContents, linkableFunctions, CompileOptions); } } diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml index 27199b31c..6b6b6a6fe 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml @@ -2,7 +2,11 @@ cmake_minimum_required(VERSION 3.15) -project(nncase_cpu_module) +@if (Model.IsCUDA) { +@:project(nncase_cpu_module CXX CUDA) +} else { +@:project(nncase_cpu_module CXX) +} option(BUILD_SHARED "Build shared library in linux" OFF) option(BUILD_STANDALONE "Build standalone executable" OFF) @@ -12,5 +16,9 @@ endif() include(@Html.Raw(Model.CMakePath)) -target_sources(nncase_ntt_module PRIVATE thread_main.cpp) -target_include_directories(nncase_ntt_module PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +@if (Model.IsCUDA) { +@:target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE thread_main.cu) +} else { +@:target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE thread_main.cpp) +} +target_include_directories(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml index 5c81a9b8c..57adcaaa0 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml @@ -8,18 +8,18 @@ { @:@(Model.Indent)if (@Html.Raw(Model.Arguments[3].Symbol.Name)) { // @:@(Model.Indent) constexpr std::string_view function_name = "matmul"; -// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device); +// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device); @:@(Model.Indent) matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @scale, fixed_shape_v<@string.Join(",", Model.Target.LhsVectorizedAxes)>, fixed_shape_v<>, fixed_shape_v<@string.Join(",", Model.Target.RhsVectorizedAxes)>, fixed_shape_v<>); @:@(Model.Indent)} else { // @:@(Model.Indent) constexpr std::string_view function_name = "matmul"; -// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device); +// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device); @:@(Model.Indent) matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @scale, fixed_shape_v<@string.Join(",", Model.Target.LhsVectorizedAxes)>, fixed_shape_v<>, fixed_shape_v<@string.Join(",", Model.Target.RhsVectorizedAxes)>, fixed_shape_v<>); @:@(Model.Indent)} } else { // @:@(Model.Indent) constexpr std::string_view function_name = "matmul"; -// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device); +// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device); @:@(Model.Indent) matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @scale, fixed_shape_v<@string.Join(",", Model.Target.LhsVectorizedAxes)>, fixed_shape_v<>, fixed_shape_v<@string.Join(",", Model.Target.RhsVectorizedAxes)>, fixed_shape_v<>); } @(Model.Indent)} diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml index 51d6568ff..0b42768ae 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml @@ -8,18 +8,18 @@ { @:@(Model.Indent)if (@Html.Raw(Model.Arguments[3].Symbol.Name)) { // @:@(Model.Indent) constexpr std::string_view function_name = "matmul"; -// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device); +// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device); @:@(Model.Indent) packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @(scale)); @:@(Model.Indent)} else { // @:@(Model.Indent) constexpr std::string_view function_name = "matmul"; -// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device); +// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device); @:@(Model.Indent) packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @(scale)); @:@(Model.Indent)} } else { // @:@(Model.Indent) constexpr std::string_view function_name = "matmul"; -// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device); +// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device); @:@(Model.Indent) packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @(scale)); } @(Model.Indent)} diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml index 58b3a43e0..83806efdc 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml @@ -1,14 +1,14 @@ @using System.Linq @using NetFabric.Hyperlinq @using Nncase -@model Nncase.Targets.NTTTargetOptions @{ - var hierarchy = Model.Hierarchies[0]; + var hierarchy = (int[])Model.Hierarchies; + var topologyLevels = Model.IsCUDA ? 4 : 3; } #pragma once #include namespace nncase::ntt::distributed { - constexpr auto topology_shape = ntt::fixed_shape_v<@(string.Join(", ", Enumerable.Repeat(1, 3 - hierarchy.Length).Concat(hierarchy)))>; + constexpr auto topology_shape = ntt::fixed_shape_v<@(string.Join(", ", Enumerable.Repeat(1, topologyLevels - hierarchy.Length).Concat(hierarchy)))>; } diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml index 51f00f036..5eb35ec89 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml @@ -7,12 +7,14 @@ var inputCount = Model.PrimFunction.Parameters.Length; } -extern "C" void thread_main(const nncase::ntt::runtime::thread_inout_desc *input_descs, +extern "C" NTT_DEVICE void thread_main(const nncase::ntt::runtime::thread_inout_desc *input_descs, nncase::ntt::runtime::thread_inout_desc *const output_descs, const std::byte *rdata, const std::byte *thread_local_rdata, + const std::byte *warp_local_rdata, const std::byte *block_local_rdata, std::byte *thread_local_data, + std::byte *warp_local_data, std::byte *block_local_data, std::byte *output) { /* prepare inputs */ @@ -72,7 +74,7 @@ extern "C" void thread_main(const nncase::ntt::runtime::thread_inout_desc *input throw new NotSupportedException($"not support multi form topology!"); } - @(Model.PrimFunction.Name)(@(string.Concat(names.Select(x => $"{x}, ")))rdata, thread_local_rdata, block_local_rdata, thread_local_data, block_local_data, output, output_descs); + @(Model.PrimFunction.Name)(@(string.Concat(names.Select(x => $"{x}, ")))rdata, thread_local_rdata, warp_local_rdata, block_local_rdata, thread_local_data, warp_local_data, block_local_data, output, output_descs); } #ifdef NNCASE_STANDALONE diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml index 0ad5b8837..c401fd887 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml @@ -15,8 +15,7 @@ #pragma once #include -#include -#include +#include /** * @@brief topology aware runtime @@ -33,22 +32,15 @@ namespace tar { foreach (var i in comb) { shape[i] = 1; } - var groupRawName = groupName + "_raw"; -@:std::barrier<> @(groupRawName)[@(groups)] { - @for (int i = 0; i < groups; i++) - { - @:std::barrier(@(groupSize)), - } -@:}; -@:auto @(groupName) = nncase::ntt::make_tensor_view_from_address>(@(groupRawName), nncase::ntt::fixed_shape_v<@(string.Join(",", shape))>); +@:NTT_DEVICE decltype(nncase::ntt::make_tensor>(nncase::ntt::fixed_shape_v<@(string.Join(",", shape))>)) @(groupName); @: } @if (Model.CollectivePoolSize > 0) { -@:alignas(@Model.Alignment) uint8_t collective_pool_ptr[@Model.CollectivePoolSize]; +@:alignas(@Model.Alignment) NTT_DEVICE uint8_t collective_pool_ptr[@Model.CollectivePoolSize]; } else { -@:alignas(@Model.Alignment) uint8_t collective_pool_ptr[1]; +@:alignas(@Model.Alignment) NTT_DEVICE uint8_t collective_pool_ptr[1]; } enum reduce_kind { @@ -57,11 +49,11 @@ enum reduce_kind { } }; -constexpr std::array Hierarchy = {@(string.Join(", ", hierarchy))}; -auto src_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>); -auto dest_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>); +NTT_DEVICE constexpr ntt::array Hierarchy = {@(string.Join(", ", hierarchy))}; +NTT_DEVICE auto src_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>); +NTT_DEVICE auto dest_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>); -template static std::byte *get_cache_address() { +template NTT_DEVICE static std::byte *get_cache_address() { return reinterpret_cast( ntt::distributed::detail::global_thread_local_cache_ptr(program_ids())( Level)); @@ -77,7 +69,7 @@ namespace tac { using namespace nncase; template -void tensor_boxing_load_sync(const GlobalShape &global_shape, const Index &index, TDst &dest) +NTT_DEVICE void tensor_boxing_load_sync(const GlobalShape &global_shape, const Index &index, TDst &dest) { using TOutBase = std::decay_t; using TElem = typename TOutBase::element_type; @@ -87,7 +79,7 @@ void tensor_boxing_load_sync(const GlobalShape &global_shape, const Index &index } template -void tensor_boxing_store_sync(const GlobalShape &global_shape, const Index &index, TSrc &src) +NTT_DEVICE void tensor_boxing_store_sync(const GlobalShape &global_shape, const Index &index, TSrc &src) { using TSrcBase = std::decay_t; using TElem = typename TSrcBase::element_type; @@ -103,14 +95,14 @@ template class group_hierarchy_getter; @:template <> class group_hierarchy_getter { var shape = Enumerable.Range(0, hierarchy.Length).Select(i => comb.Contains(i) ? hierarchy[i] : 1).ToArray(); @:public: -@: static constexpr auto group_hierarchy = ntt::fixed_shape_v<@(string.Join(", ", shape))>; +@: NTT_DEVICE static constexpr auto group_hierarchy = ntt::fixed_shape_v<@(string.Join(", ", shape))>; @:}; } template class tensor_reduce_sync_impl { public: - void reduce_group_sync() const noexcept { + NTT_DEVICE void reduce_group_sync() const noexcept { @foreach(var comb in combinations) { var reduce_group_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => comb.Contains(i) ? "0" : "ntt::distributed::" + hierarchyNames[i] + "id()")); @:if constexpr (Kind == tar::reduce_kind::@(GetName(comb, string.Empty))) { @@ -124,7 +116,7 @@ class tensor_reduce_sync_impl { } template - constexpr auto index_group2global(const TIndexInGroup &index_in_group, const TIndexInGlobal &index_in_global) const noexcept { + NTT_DEVICE constexpr auto index_group2global(const TIndexInGroup &index_in_group, const TIndexInGlobal &index_in_global) const noexcept { return ntt::generate_shape([&](auto axis) { if constexpr (Kind & (1 << (TIndexInGlobal::rank() - axis))) { return index_in_group[axis]; @@ -135,7 +127,7 @@ class tensor_reduce_sync_impl { } template - constexpr auto index_global2group(const TIndexInGlobal &index_in_global) const noexcept { + NTT_DEVICE constexpr auto index_global2group(const TIndexInGlobal &index_in_global) const noexcept { return ntt::generate_shape([&](auto axis) { if constexpr (Kind & (1 << (TIndexInGlobal::rank() - axis))) { return index_in_global[axis]; @@ -145,7 +137,7 @@ class tensor_reduce_sync_impl { }); } - static constexpr auto get_group_size() { + NTT_DEVICE static constexpr auto get_group_size() { size_t group_size = 1; for (size_t i = 1; i <= tar::Hierarchy.size(); i++) { if (Kind & (1 << i)) { @@ -160,7 +152,7 @@ class tensor_reduce_sync_impl { } template - void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) { + NTT_DEVICE void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) { if constexpr (Op == ntt::reduce_op::max) { ntt::binary(local, remote, dest); } else if constexpr (Op == ntt::reduce_op::sum || @@ -173,7 +165,7 @@ class tensor_reduce_sync_impl { } } - template void operator()(TIn &src, TOut &&dest) { + template NTT_DEVICE void operator()(TIn &src, TOut &&dest) { // collect all tensors pointer for access tensor from other nodes. using TElem = typename TIn::element_type; using TOutBase = std::decay_t; @@ -289,7 +281,7 @@ class tensor_reduce_sync_impl { } // namespace detail template -void tensor_reduce_sync(TIn &input, TOut &&output) { +NTT_DEVICE void tensor_reduce_sync(TIn &input, TOut &&output) { detail::tensor_reduce_sync_impl impl; impl(input, output); } diff --git a/modules/Nncase.Modules.NTT/NTTModule.cs b/modules/Nncase.Modules.NTT/NTTModule.cs index 9ba23e357..fe7497b94 100644 --- a/modules/Nncase.Modules.NTT/NTTModule.cs +++ b/modules/Nncase.Modules.NTT/NTTModule.cs @@ -15,5 +15,6 @@ internal class NTTModule : IApplicationPart public void ConfigureServices(IRegistrator registrator) { registrator.Register(reuse: Reuse.Singleton); + registrator.Register(reuse: Reuse.Singleton); } } diff --git a/modules/Nncase.Modules.NTT/Targets/NTTModuleCompiler.cs b/modules/Nncase.Modules.NTT/Targets/CPUModuleCompiler.cs similarity index 93% rename from modules/Nncase.Modules.NTT/Targets/NTTModuleCompiler.cs rename to modules/Nncase.Modules.NTT/Targets/CPUModuleCompiler.cs index 2585e965b..e0ab0705e 100644 --- a/modules/Nncase.Modules.NTT/Targets/NTTModuleCompiler.cs +++ b/modules/Nncase.Modules.NTT/Targets/CPUModuleCompiler.cs @@ -11,7 +11,7 @@ namespace Nncase.Targets; -public class NTTModuleCompiler : IModuleCompiler +public class CPUModuleCompiler : INTTModuleCompiler { public string ModuleKind => CPUTarget.Kind; @@ -35,7 +35,7 @@ public class NTTModuleCompiler : IModuleCompiler _ => throw new NotSupportedException($"Unsupported architecture: {RuntimeInformation.ProcessArchitecture}"), }; - public IModuleBuilder CreateModuleBuilder(CompileOptions options) => new NTTModuleBuilder(options); + public IModuleBuilder CreateModuleBuilder(CompileOptions options) => new NTTModuleBuilder(ModuleKind, options); public bool IsSupportedCall(Call call, CompileOptions options) { diff --git a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs index 04312ec46..15e5e76ac 100644 --- a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs @@ -23,110 +23,15 @@ namespace Nncase.Targets; /// -/// Target for NTT. +/// Target for CPU. /// -public class CPUTarget : Target +public class CPUTarget : NTTTarget { public const string Kind = "cpu"; - private readonly NTTModuleCompiler _nttModuleCompiler = new(); - public CPUTarget() { - ModuleCompilers = [_nttModuleCompiler]; - } - - public override string Name => Kind; - - public override IReadOnlyList ModuleCompilers { get; } - - public override (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser() - { - var cmd = new NTTTargetOptionsCommand(Kind); - - ITargetOptions ParseTargetCompileOptions(InvocationContext context, Command command) - { - var binder = new NTTTargetOptionsBinder(cmd); - return binder.GetBoundValue(context); - } - - return (cmd, ParseTargetCompileOptions); - } - - public override void RegisterAffineSelectionPass(IPassManager passManager, CompileOptions options) - { - passManager.Add(); - } - - public override void RegisterAutoPackingRules(IRulesAddable pass, CompileOptions options) - { - var nr = _nttModuleCompiler.Nr; - - pass.Add(nr); } - public override void RegisterAutoVectorizeRules(IRulesAddable pass, CompileOptions options) - { - // todo config it in the target options. - var rank = 1; - var lane = _nttModuleCompiler.Lane; - var maskVectorStyle = _nttModuleCompiler.MaskVectorStyle; - - pass.Add(rank, lane); - pass.Add(rank, lane); - pass.Add(rank, lane); - - pass.Add(); - pass.Add(); - pass.Add(maskVectorStyle); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - - // pass.Add(rank, lane); - pass.Add(); - - // pass.Add(rank, lane); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(maskVectorStyle); - - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - pass.Add(); - } - - public override void RegisterTIRSelectionPass(IPassManager passManager, CompileOptions optionsÍ) - { - passManager.Add(); - } - - public override void RegisterPostAutoVectorizePass(IPassManager passManager, CompileOptions options) - { - passManager.AddWithName("FoldPostOps").Configure(p => - { - p.Add(); - p.Add(); - }); - } + protected override INTTModuleCompiler NTTModuleCompiler { get; } = new CPUModuleCompiler(); } diff --git a/modules/Nncase.Modules.NTT/Targets/CUDAModuleCompiler.cs b/modules/Nncase.Modules.NTT/Targets/CUDAModuleCompiler.cs new file mode 100644 index 000000000..534e03a41 --- /dev/null +++ b/modules/Nncase.Modules.NTT/Targets/CUDAModuleCompiler.cs @@ -0,0 +1,34 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using Nncase.CodeGen; +using Nncase.CodeGen.NTT; +using Nncase.IR; +using Nncase.Passes; + +namespace Nncase.Targets; + +public class CUDAModuleCompiler : INTTModuleCompiler +{ + public string ModuleKind => CUDATarget.Kind; + + public MaskVectorStyle MaskVectorStyle => MaskVectorStyle.Fat; + + public int Lane => 16; + + public int Nr => 4; + + public IModuleBuilder CreateModuleBuilder(CompileOptions options) => new NTTModuleBuilder(ModuleKind, options); + + public bool IsSupportedCall(Call call, CompileOptions options) + { + return call.Target switch + { + Op op => PassUtility.IsCpuSupported(op, call, call.Arguments, ModuleKind), + _ => false, + }; + } +} diff --git a/modules/Nncase.Modules.NTT/Targets/CUDATarget.cs b/modules/Nncase.Modules.NTT/Targets/CUDATarget.cs new file mode 100644 index 000000000..76fd11b39 --- /dev/null +++ b/modules/Nncase.Modules.NTT/Targets/CUDATarget.cs @@ -0,0 +1,37 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.CommandLine; +using System.CommandLine.Invocation; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Options; +using Nncase.CodeGen; +using Nncase.CodeGen.NTT; +using Nncase.IR; +using Nncase.Passes; +using Nncase.Passes.Rules.Neutral; +using Nncase.Passes.Rules.ShapeBucket; +using Nncase.Passes.Transforms; +using Nncase.Quantization; + +namespace Nncase.Targets; + +/// +/// Target for CUDA. +/// +public class CUDATarget : NTTTarget +{ + public const string Kind = "cuda"; + + public CUDATarget() + { + } + + protected override INTTModuleCompiler NTTModuleCompiler { get; } = new CUDAModuleCompiler(); +} diff --git a/modules/Nncase.Modules.NTT/Targets/INTTModuleCompiler.cs b/modules/Nncase.Modules.NTT/Targets/INTTModuleCompiler.cs new file mode 100644 index 000000000..cd8d8cd7e --- /dev/null +++ b/modules/Nncase.Modules.NTT/Targets/INTTModuleCompiler.cs @@ -0,0 +1,19 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using Nncase.CodeGen; +using Nncase.CodeGen.NTT; +using Nncase.IR; +using Nncase.Passes; + +namespace Nncase.Targets; + +public interface INTTModuleCompiler : IModuleCompiler +{ + int Lane { get; } + + int Nr { get; } +} diff --git a/modules/Nncase.Modules.NTT/Targets/NTTTarget.cs b/modules/Nncase.Modules.NTT/Targets/NTTTarget.cs new file mode 100644 index 000000000..2dc8ca7ba --- /dev/null +++ b/modules/Nncase.Modules.NTT/Targets/NTTTarget.cs @@ -0,0 +1,130 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.CommandLine; +using System.CommandLine.Invocation; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Options; +using Nncase.CodeGen; +using Nncase.CodeGen.NTT; +using Nncase.IR; +using Nncase.Passes; +using Nncase.Passes.Rules.Neutral; +using Nncase.Passes.Rules.ShapeBucket; +using Nncase.Passes.Transforms; +using Nncase.Quantization; + +namespace Nncase.Targets; + +/// +/// Target for NTT. +/// +public abstract class NTTTarget : Target +{ + public NTTTarget() + { + ModuleCompilers = [NTTModuleCompiler]; + } + + public override string Name => NTTModuleCompiler.ModuleKind; + + public override IReadOnlyList ModuleCompilers { get; } + + protected abstract INTTModuleCompiler NTTModuleCompiler { get; } + + public override (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser() + { + var cmd = new NTTTargetOptionsCommand(Name); + + ITargetOptions ParseTargetCompileOptions(InvocationContext context, Command command) + { + var binder = new NTTTargetOptionsBinder(cmd); + return binder.GetBoundValue(context); + } + + return (cmd, ParseTargetCompileOptions); + } + + public override void RegisterAffineSelectionPass(IPassManager passManager, CompileOptions options) + { + passManager.Add(); + } + + public override void RegisterAutoPackingRules(IRulesAddable pass, CompileOptions options) + { + var nr = NTTModuleCompiler.Nr; + + pass.Add(nr); + } + + public override void RegisterAutoVectorizeRules(IRulesAddable pass, CompileOptions options) + { + // todo config it in the target options. + var rank = 1; + var lane = NTTModuleCompiler.Lane; + var maskVectorStyle = NTTModuleCompiler.MaskVectorStyle; + + pass.Add(rank, lane); + pass.Add(rank, lane); + pass.Add(rank, lane); + + pass.Add(); + pass.Add(); + pass.Add(maskVectorStyle); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + + // pass.Add(rank, lane); + pass.Add(); + + // pass.Add(rank, lane); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(maskVectorStyle); + + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + pass.Add(); + } + + public override void RegisterTIRSelectionPass(IPassManager passManager, CompileOptions optionsÍ) + { + passManager.Add(NTTModuleCompiler.ModuleKind); + } + + public override void RegisterPostAutoVectorizePass(IPassManager passManager, CompileOptions options) + { + passManager.AddWithName("FoldPostOps").Configure(p => + { + p.Add(); + p.Add(); + }); + } +} diff --git a/ntt/cmake/compile_flags.cmake b/ntt/cmake/compile_flags.cmake index 0fac799dc..c9050db84 100644 --- a/ntt/cmake/compile_flags.cmake +++ b/ntt/cmake/compile_flags.cmake @@ -70,3 +70,6 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES message(FATAL_ERROR "Unsupported riscv64 target") endif() endif() + +if (CMAKE_CUDA_COMPILER) +endif() diff --git a/ntt/cmake/ntt_module.cmake b/ntt/cmake/ntt_module.cmake index 8b91c5cfb..0f9ab7db9 100644 --- a/ntt/cmake/ntt_module.cmake +++ b/ntt/cmake/ntt_module.cmake @@ -1,37 +1,105 @@ -cmake_minimum_required(VERSION 3.15) +cmake_minimum_required(VERSION 3.16) include(${CMAKE_CURRENT_LIST_DIR}/compile_flags.cmake) -if (BUILD_STANDALONE) - add_executable(nncase_ntt_module ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp) +if (CMAKE_CUDA_COMPILER) + set(NNCASE_NTT_MODULE_TARGET_NAME nncase_ntt_module_bundle) + + find_program(NVLINK nvlink REQUIRED) + find_program(FATBINARY fatbinary REQUIRED) + message(STATUS "Found nvlink: ${NVLINK}") + message(STATUS "Found fatbinary: ${FATBINARY}") else() - add_library(nncase_ntt_module SHARED ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp) + set(NNCASE_NTT_MODULE_TARGET_NAME nncase_ntt_module) endif() -target_compile_features(nncase_ntt_module PUBLIC cxx_std_20) -target_include_directories(nncase_ntt_module PUBLIC ${CMAKE_CURRENT_LIST_DIR}/../include) -set_target_properties(nncase_ntt_module PROPERTIES PREFIX "" SUFFIX "") -set_target_properties(nncase_ntt_module PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_compile_definitions(nncase_ntt_module PUBLIC -DNNCASE_CPU_MODULE=1) -set_property(TARGET nncase_ntt_module PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) +if (BUILD_STANDALONE) + add_executable(${NNCASE_NTT_MODULE_TARGET_NAME} ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp) +elseif (CMAKE_CUDA_COMPILER) + add_library(${NNCASE_NTT_MODULE_TARGET_NAME} OBJECT) +else() + add_library(${NNCASE_NTT_MODULE_TARGET_NAME} SHARED ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp) +endif() -target_sources(nncase_ntt_module PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../src/cpu_runtime.cpp) +target_compile_features(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC cxx_std_20) +target_include_directories(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR}/../include) +set_target_properties(${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTIES PREFIX "" SUFFIX "") +set_target_properties(${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_property(TARGET ${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) if (BUILD_STANDALONE) - target_compile_definitions(nncase_ntt_module PUBLIC -DNNCASE_STANDALONE=1) + target_compile_definitions(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC -DNNCASE_STANDALONE=1) endif() if (MSVC) - set_property(TARGET nncase_ntt_module PROPERTY + set_property(TARGET ${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTY MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") - set_target_properties(nncase_ntt_module PROPERTIES LINK_FLAGS /SUBSYSTEM:CONSOLE) - target_link_options(nncase_ntt_module PRIVATE /NODEFAULTLIB) - target_link_libraries(nncase_ntt_module PRIVATE "libvcruntime$<$:d>" + set_target_properties(${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTIES LINK_FLAGS /SUBSYSTEM:CONSOLE) + target_link_options(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE /NODEFAULTLIB) + target_link_libraries(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE "libvcruntime$<$:d>" "msvcrt$<$:d>" "ucrt$<$:d>" "libcpmt$<$:d>") elseif(APPLE) - target_link_options(nncase_ntt_module PRIVATE -ld_classic -lc) + target_link_options(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE -ld_classic -lc) +else() + target_link_libraries(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE pthread) +endif() + +if (CMAKE_CUDA_COMPILER) + target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../src/cuda_runtime.cu) + target_compile_definitions(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC -DNNCASE_CUDA_MODULE=1) + target_compile_options(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE $<$: + -fgpu-rdc + --cuda-device-only + >) + + foreach(arch ${CMAKE_CUDA_ARCHITECTURES}) + # Link device code for this architecture + set(linked_obj "${CMAKE_CURRENT_BINARY_DIR}/linked_sm_${arch}.o") + add_custom_command( + OUTPUT ${linked_obj} + COMMAND ${NVLINK} + -arch=sm_${arch} + $ + -o ${linked_obj} + DEPENDS ${NNCASE_NTT_MODULE_TARGET_NAME} $ + COMMAND_EXPAND_LISTS + VERBATIM + COMMENT "Linking device code for sm_${arch}" + ) + + # Add to the list of all linked objects + list(APPEND ALL_LINKED_OBJECTS ${linked_obj}) + endforeach() + + add_custom_target(device_link ALL + DEPENDS ${ALL_LINKED_OBJECTS} + ) + + set(FATBIN_ARGS "") + foreach(arch ${CMAKE_CUDA_ARCHITECTURES}) + # Find the linked object for this architecture + set(arch_obj "${CMAKE_CURRENT_BINARY_DIR}/linked_sm_${arch}.o") + list(APPEND FATBIN_ARGS --image3=kind=elf,sm=${arch},file="${arch_obj}") + endforeach() + + # Create the fatbinary from all linked objects + add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/nncase_ntt_module + COMMAND ${FATBINARY} + -64 + --create ${CMAKE_CURRENT_BINARY_DIR}/nncase_ntt_module + ${FATBIN_ARGS} + DEPENDS ${ALL_LINKED_OBJECTS} + COMMENT "Creating fatbinary from linked objects" + VERBATIM + ) + + add_custom_target(fatbin ALL + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/nncase_ntt_module + ) else() - target_link_libraries(nncase_ntt_module PRIVATE pthread) + target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../src/cpu_runtime.cpp) + target_compile_definitions(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC -DNNCASE_CPU_MODULE=1) endif() diff --git a/ntt/include/nncase/bfloat16.h b/ntt/include/nncase/bfloat16.h index 1051c31e1..8c614eb63 100644 --- a/ntt/include/nncase/bfloat16.h +++ b/ntt/include/nncase/bfloat16.h @@ -43,11 +43,11 @@ struct bfloat16 { constexpr operator __bf16() const noexcept { return std::bit_cast<__bf16>(value_); } -// #else -// constexpr operator float() const noexcept { -// uint32_t value = raw() << 16; -// return std::bit_cast(value); -// } + // #else + // constexpr operator float() const noexcept { + // uint32_t value = raw() << 16; + // return std::bit_cast(value); + // } #endif @@ -133,7 +133,6 @@ struct bfloat16 { return uint64_t(double(*this)); } - constexpr explicit operator uint8_t() const noexcept { return uint8_t(float(*this)); } @@ -142,7 +141,6 @@ struct bfloat16 { return int8_t(float(*this)); } - constexpr explicit operator int16_t() const noexcept { return int16_t(float(*this)); } @@ -151,7 +149,6 @@ struct bfloat16 { return uint16_t(float(*this)); } - constexpr explicit operator bool() const noexcept { return bool(std::bit_cast(*this)); } @@ -162,8 +159,6 @@ struct bfloat16 { : nan(); } - - static constexpr bfloat16 epsilon() noexcept { // 0x1.0p-7 return from_raw(0x3c00); @@ -199,12 +194,14 @@ struct bfloat16 { }; #define DEFINE_BF16_BINARY_BF16RET(x) \ - inline bfloat16 operator x(bfloat16 a, bfloat16 b) noexcept { \ + NTT_ALWAYS_INLINE constexpr bfloat16 operator x(bfloat16 a, \ + bfloat16 b) noexcept { \ return bfloat16::round_to_bfloat16(float(a) x float(b)); \ } #define DEFINE_BF16_BINARY_BOOLRET(x) \ - inline bool operator x(bfloat16 a, bfloat16 b) noexcept { \ + NTT_ALWAYS_INLINE constexpr bool operator x(bfloat16 a, \ + bfloat16 b) noexcept { \ return float(a) x float(b); \ } @@ -218,7 +215,8 @@ DEFINE_BF16_BINARY_BOOLRET(>=) DEFINE_BF16_BINARY_BOOLRET(>) #define DEFINE_BF16_BINARY_SELF_MOD(x, op) \ - inline bfloat16 &operator x(bfloat16 & a, bfloat16 b) noexcept { \ + NTT_ALWAYS_INLINE constexpr bfloat16 &operator x(bfloat16 &a, \ + bfloat16 b) noexcept { \ a = a op b; \ return a; \ } @@ -228,15 +226,17 @@ DEFINE_BF16_BINARY_SELF_MOD(-=, -) DEFINE_BF16_BINARY_SELF_MOD(*=, *) DEFINE_BF16_BINARY_SELF_MOD(/=, /) -inline bfloat16 operator-(bfloat16 a) noexcept { +NTT_ALWAYS_INLINE constexpr bfloat16 operator-(bfloat16 a) noexcept { return bfloat16::round_to_bfloat16(-float(a)); } -inline bool operator==(const bfloat16 &lhs, const bfloat16 &rhs) noexcept { +NTT_ALWAYS_INLINE constexpr bool operator==(const bfloat16 &lhs, + const bfloat16 &rhs) noexcept { return lhs.raw() == rhs.raw(); } -inline bool operator!=(const bfloat16 &lhs, const bfloat16 &rhs) noexcept { +NTT_ALWAYS_INLINE constexpr bool operator!=(const bfloat16 &lhs, + const bfloat16 &rhs) noexcept { return lhs.raw() != rhs.raw(); } } // namespace nncase @@ -305,67 +305,76 @@ template <> struct numeric_limits { }; using nncase::bfloat16; -inline bool isinf(const bfloat16 &a) { return std::isinf(float(a)); } -inline bool isnan(const bfloat16 &a) { return std::isnan(float(a)); } -inline bool isfinite(const bfloat16 &a) { return std::isfinite(float(a)); } -inline bfloat16 abs(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isinf(const bfloat16 &a) { + return std::isinf(float(a)); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isnan(const bfloat16 &a) { + return std::isnan(float(a)); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isfinite(const bfloat16 &a) { + return std::isfinite(float(a)); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 abs(const bfloat16 &a) { return bfloat16::round_to_bfloat16(fabsf(float(a))); } -inline bfloat16 acos(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 acos(const bfloat16 &a) { return bfloat16::round_to_bfloat16(std::acos(float(a))); } -inline bfloat16 asin(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 asin(const bfloat16 &a) { return bfloat16::round_to_bfloat16(std::asin(float(a))); } -inline bfloat16 erf(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 erf(const bfloat16 &a) { return bfloat16::round_to_bfloat16(std::erff(float(a))); } -inline bfloat16 exp(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 exp(const bfloat16 &a) { return bfloat16::round_to_bfloat16(expf(float(a))); } -inline bfloat16 log(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 log(const bfloat16 &a) { return bfloat16::round_to_bfloat16(logf(float(a))); } -inline bfloat16 log10(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 log10(const bfloat16 &a) { return bfloat16::round_to_bfloat16(log10f(float(a))); } -inline bfloat16 sqrt(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 sqrt(const bfloat16 &a) { return bfloat16::round_to_bfloat16(sqrtf(float(a))); } -inline bfloat16 pow(const bfloat16 &a, const bfloat16 &b) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 pow(const bfloat16 &a, + const bfloat16 &b) { return bfloat16::round_to_bfloat16(powf(float(a), float(b))); } -inline bfloat16 sin(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 sin(const bfloat16 &a) { return bfloat16::round_to_bfloat16(sinf(float(a))); } -inline bfloat16 cos(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 cos(const bfloat16 &a) { return bfloat16::round_to_bfloat16(cosf(float(a))); } -inline bfloat16 tan(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 tan(const bfloat16 &a) { return bfloat16::round_to_bfloat16(tanf(float(a))); } -inline bfloat16 tanh(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 tanh(const bfloat16 &a) { return bfloat16::round_to_bfloat16(tanhf(float(a))); } -inline bfloat16 floor(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 floor(const bfloat16 &a) { return bfloat16::round_to_bfloat16(floorf(float(a))); } -inline bfloat16 ceil(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 ceil(const bfloat16 &a) { return bfloat16::round_to_bfloat16(ceilf(float(a))); } -inline bfloat16 round(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 round(const bfloat16 &a) { return bfloat16::round_to_bfloat16(roundf(float(a))); } -inline bfloat16 nearbyint(const bfloat16 &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 nearbyint(const bfloat16 &a) { return bfloat16::round_to_bfloat16(nearbyintf(float(a))); } -inline long lrint(const bfloat16 &a) { return lrintf(float(a)); } +NTT_ALWAYS_INLINE NTT_HOST_DEVICE long lrint(const bfloat16 &a) { + return lrintf(float(a)); +} template <> struct is_arithmetic : public true_type {}; } // namespace std -inline nncase::bfloat16 operator"" _bf16(long double x) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE nncase::bfloat16 +operator""_bf16(long double x) { return nncase::bfloat16(float(x)); } - diff --git a/ntt/include/nncase/float8.h b/ntt/include/nncase/float8.h index c6043c93f..7a36f16e8 100644 --- a/ntt/include/nncase/float8.h +++ b/ntt/include/nncase/float8.h @@ -36,7 +36,7 @@ */ #pragma once -#if defined(__GNUC__) && defined(__x86_64__) +#if defined(__GNUC__) && defined(__x86_64__) && !defined(__clang__) #pragma GCC optimize("no-strict-aliasing") #endif @@ -81,11 +81,12 @@ // #include // #include "nncase/nncase.h" +#include "ntt/compiler_defs.h" #include "bfloat16.h" #include "half.h" #ifndef CUTLASS_HOST_DEVICE -#define CUTLASS_HOST_DEVICE inline -#define CUTLASS_DEVICE inline +#define CUTLASS_HOST_DEVICE NTT_HOST_DEVICE inline +#define CUTLASS_DEVICE NTT_DEVICE inline #endif // !CUTLASS_HOST_DEVICE /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1384,22 +1385,22 @@ struct numeric_limits // CUTLASS_HOST_DEVICE -nncase::float_e4m3_t operator"" _fe4m3(long double x) { +nncase::float_e4m3_t operator""_fe4m3(long double x) { return nncase::float_e4m3_t(float(x)); } CUTLASS_HOST_DEVICE -nncase::float_e4m3_t operator"" _fe4m3(unsigned long long int x) { +nncase::float_e4m3_t operator""_fe4m3(unsigned long long int x) { return nncase::float_e4m3_t(int(x)); } CUTLASS_HOST_DEVICE -nncase::float_e5m2_t operator"" _fe5m2(long double x) { +nncase::float_e5m2_t operator""_fe5m2(long double x) { return nncase::float_e5m2_t(float(x)); } CUTLASS_HOST_DEVICE -nncase::float_e5m2_t operator"" _fe5m2(unsigned long long int x) { +nncase::float_e5m2_t operator""_fe5m2(unsigned long long int x) { return nncase::float_e5m2_t(int(x)); } diff --git a/ntt/include/nncase/half.h b/ntt/include/nncase/half.h index 7d1a3e0b2..f09826c43 100644 --- a/ntt/include/nncase/half.h +++ b/ntt/include/nncase/half.h @@ -24,11 +24,19 @@ #include #include -#ifdef __F16C__ +#ifdef __CUDA_ARCH__ +#include +#elif defined(__F16C__) #include #endif namespace nncase { +#ifdef __CUDA_ARCH__ +using native_half_t = __half; +#else +using native_half_t = _Float16; +#endif + struct fp16_from_raw_t { explicit fp16_from_raw_t() = default; }; @@ -44,7 +52,7 @@ struct half { public: constexpr half() noexcept = default; - constexpr half(_Float16 v) noexcept : value_(v) {} + constexpr half(native_half_t v) noexcept : value_(v) {} template ::value || @@ -54,9 +62,11 @@ struct half { static constexpr half round_to_half(float v) { if (std::is_constant_evaluated()) { - return (_Float16)v; + return (native_half_t)v; } else { -#ifdef __F16C__ +#ifdef __CUDA_ARCH__ + return __float2half_rn(v); +#elif defined(__F16C__) // To avoid truncsfhf2 return from_raw(_cvtss_sh(v, _MM_FROUND_NEARBYINT)); #else @@ -64,7 +74,7 @@ struct half { #endif } - return (_Float16)v; + return (native_half_t)v; } static constexpr half epsilon() noexcept { return from_raw(0x0800); } @@ -91,14 +101,16 @@ struct half { : value_(round_to_half(float(x)).value_) {} constexpr half(fp16_from_raw_t, uint16_t value) noexcept - : value_(std::bit_cast<_Float16>(value)) {} + : value_(std::bit_cast(value)) {} - constexpr operator _Float16() const noexcept { return value_; } + constexpr operator native_half_t() const noexcept { return value_; } constexpr operator float() const noexcept { if (std::is_constant_evaluated()) { return (float)value_; } else { -#ifdef __F16C__ +#ifdef __CUDA_ARCH__ + return __half2float(value_); +#elif defined(__F16C__) // To avoid extendhfdf2 return _cvtsh_ss(raw()); #else @@ -177,7 +189,7 @@ struct half { } private: - _Float16 value_; + native_half_t value_; }; #define DEFINE_FP16_BINARY_FP16RET(x) \ @@ -216,7 +228,7 @@ DEFINE_FP16_BINARY_BOOLRET(>=) DEFINE_FP16_BINARY_BOOLRET(>) #define DEFINE_FP16_BINARY_SELF_MOD(x, op) \ - NTT_ALWAYS_INLINE half &operator x(half & a, half b) noexcept { \ + NTT_ALWAYS_INLINE half &operator x(half &a, half b) noexcept { \ a = a op b; \ return a; \ } @@ -242,7 +254,8 @@ inline std::ostream &operator<<(std::ostream &os, const half &a) { os << std::to_string(float(a)); return os; } -inline half nextafter(const half &from, const half &to) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half nextafter(const half &from, + const half &to) { if (from.raw() == to.raw()) { return to; } @@ -365,48 +378,76 @@ template <> struct numeric_limits { }; using nncase::half; -inline bool isinf(const half &a) { return std::isinf((float)(a)); } -inline bool isnan(const half &a) { return std::isnan(float(a)); } -inline bool isfinite(const half &a) { return std::isfinite(float(a)); } -inline half abs(const half &a) { return half::round_to_half(fabsf(float(a))); } -inline half fabs(const half &a) { return half::round_to_half(fabs(float(a))); } -inline half exp(const half &a) { return half::round_to_half(expf(float(a))); } -inline half log(const half &a) { return half::round_to_half(logf(float(a))); } -inline half log10(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isinf(const half &a) { + return std::isinf((float)(a)); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isnan(const half &a) { + return std::isnan(float(a)); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isfinite(const half &a) { + return std::isfinite(float(a)); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half abs(const half &a) { + return half::round_to_half(fabsf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half fabs(const half &a) { + return half::round_to_half(fabs(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half exp(const half &a) { + return half::round_to_half(expf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half log(const half &a) { + return half::round_to_half(logf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half log10(const half &a) { return half::round_to_half(log10f(float(a))); } -inline half sqrt(const half &a) { return half::round_to_half(sqrtf(float(a))); } -inline half sin(const half &a) { return half::round_to_half(sinf(float(a))); } -inline half cos(const half &a) { return half::round_to_half(cosf(float(a))); } -inline half tan(const half &a) { return half::round_to_half(tanf(float(a))); } -inline half tanh(const half &a) { return half::round_to_half(tanh(float(a))); } -inline half floor(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half sqrt(const half &a) { + return half::round_to_half(sqrtf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half sin(const half &a) { + return half::round_to_half(sinf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half cos(const half &a) { + return half::round_to_half(cosf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half tan(const half &a) { + return half::round_to_half(tanf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half tanh(const half &a) { + return half::round_to_half(tanh(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half floor(const half &a) { return half::round_to_half(floorf(float(a))); } -inline half ceil(const half &a) { return half::round_to_half(ceilf(float(a))); } -inline half round(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half ceil(const half &a) { + return half::round_to_half(ceilf(float(a))); +} +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half round(const half &a) { return half::round_to_half(roundf(float(a))); } -inline half nearbyint(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half nearbyint(const half &a) { return half::round_to_half(nearbyintf(float(a))); } -inline half acos(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half acos(const half &a) { return half::round_to_half(std::acos(float(a))); } -inline half asin(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half asin(const half &a) { return half::round_to_half(std::asin(float(a))); } -inline half cosh(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half cosh(const half &a) { return half::round_to_half(std::cosh(float(a))); } -inline half sinh(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half sinh(const half &a) { return half::round_to_half(std::sinh(float(a))); } -inline half erf(const half &a) { +NTT_ALWAYS_INLINE NTT_HOST_DEVICE half erf(const half &a) { return half::round_to_half(std::erff(float(a))); } -inline long lrint(const half &a) { return lrintf(float(a)); } +NTT_ALWAYS_INLINE NTT_HOST_DEVICE long lrint(const half &a) { + return lrintf(float(a)); +} template <> struct is_floating_point : public std::true_type {}; template <> struct is_arithmetic : public true_type {}; -} // namespace std \ No newline at end of file +} // namespace std diff --git a/ntt/include/nncase/ntt/apply.h b/ntt/include/nncase/ntt/apply.h index 12677aae4..5192e0ddc 100644 --- a/ntt/include/nncase/ntt/apply.h +++ b/ntt/include/nncase/ntt/apply.h @@ -29,7 +29,7 @@ template &index, Offsets offsets, const Shape &shape, const TTile &tile, Callable &&callable, - const std::tuple &strides) { + const ntt::tuple &strides) { auto call = [&](std::index_sequence) { if constexpr (sizeof...(Strides)) { callable(index, offsets[fixed_dim_v]...); @@ -47,7 +47,7 @@ apply_impl(dynamic_shape_t &index, Offsets offsets, std::forward(callable), strides); } ntt::loop([&](auto i) { - offsets[i] += std::get(strides)[fixed_dim_v] * + offsets[i] += ntt::get(strides)[fixed_dim_v] * tile[fixed_dim_v]; }); } @@ -62,7 +62,7 @@ NTT_ALWAYS_INLINE constexpr void apply(const TShape &shape, Callable &&callable, detail::apply_impl<0>(index, make_repeat_shape(0), shape, make_ones_shape(), std::forward(callable), - std::forward_as_tuple(strides...)); + ntt::forward_as_tuple(strides...)); } else { if constexpr (sizeof...(TStrides)) { callable(fixed_shape_v<>, (strides, (dim_t)0)...); @@ -80,7 +80,7 @@ apply_tiled(const TShape &shape, const TTile &tile, Callable &&callable, dynamic_shape_t index{}; detail::apply_impl<0>(index, make_repeat_shape(0), shape, tile, std::forward(callable), - std::forward_as_tuple(strides...)); + ntt::forward_as_tuple(strides...)); } else { if constexpr (sizeof...(TStrides)) { callable(fixed_shape_v<>, (strides, (dim_t)0)...); diff --git a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h index f339dabd1..b9c7a3538 100644 --- a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h +++ b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h @@ -34,12 +34,13 @@ extern decltype(nncase::ntt::make_tensor>( template TLocalProgramIds, ScopedProgramIds TRemoteProgramIds> -static auto get_remote_address(const TLocalProgramIds &local_program_ids, - const TRemoteProgramIds &remote_program_ids, - T *local_address) { +auto get_remote_address(const TLocalProgramIds &local_program_ids, + const TRemoteProgramIds &remote_program_ids, + T *local_address) { auto start = (size_t)global_local_data_ptr(local_program_ids)(0_dim); auto end = (size_t)global_local_data_ptr(local_program_ids)(1_dim); - auto remote_address = (size_t)global_local_data_ptr(remote_program_ids)(0_dim); + auto remote_address = + (size_t)global_local_data_ptr(remote_program_ids)(0_dim); if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) { start = (size_t)global_thread_local_rdata_ptr(local_program_ids)(0_dim); end = (size_t)global_thread_local_rdata_ptr(local_program_ids)(1_dim); @@ -47,7 +48,8 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids, (size_t)global_thread_local_rdata_ptr(remote_program_ids)(0_dim); if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) { - start = (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim); + start = + (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim); remote_address = (size_t)global_block_local_rdata_ptr(remote_program_ids)(0_dim); } diff --git a/ntt/include/nncase/ntt/arch/cpu/runtime.h b/ntt/include/nncase/ntt/arch/cpu/runtime.h index 902c0f3f1..7a2b524bb 100644 --- a/ntt/include/nncase/ntt/arch/cpu/runtime.h +++ b/ntt/include/nncase/ntt/arch/cpu/runtime.h @@ -15,234 +15,13 @@ #pragma once #include "../../profiling.h" #include "../../runtime.h" -#include #include -#include -#include -#include -#include -#include -#include #ifdef __APPLE__ #include #endif namespace nncase::ntt::runtime { - -struct record_id { - int cid = -1; - int bid = -1; - int tid = -1; -}; - -class timer_record : public nncase::ntt::runtime::timer_record_base { - public: - bool is_valid() const override { - return instance_id_.cid != -1 && instance_id_.bid != -1 && - instance_id_.tid != -1; - } - - void set_time(std::string_view function_name, uint64_t start_time, - uint64_t end_time) override { - auto &stats = function_stats_[function_name]; - stats.calls.push_back({start_time, end_time}); - stats.call_count++; - stats.total_time += end_time - start_time; - } - - void set_level(std::string_view filename, profiling_level level) override { - auto &stats = function_stats_[filename]; - stats.level = level; - } - - // print statistics - void console_print() const override { - - if (is_valid()) { - - std::cout << "\033[34m\n" - << "Core Id:" << instance_id_.cid - << ", Block Id:" << instance_id_.bid - << ", Thread Id:" << instance_id_.tid << "\033[0m\n"; - std::cout << "\033[34mStatistics for NTT kernels. \033[0m\n"; - for (const auto &[name, stats] : function_stats_) { - std::cout << "Function: " << name << "\n"; - std::cout << "Level: " << ntt::runtime::to_string(stats.level) - << "\n"; - std::cout << "\tCalls: " << stats.call_count << "\n"; - std::cout << "\tTotal time: " << stats.total_time - << " microseconds\n"; - uint64_t call_count = 0; - for (const auto &call : stats.calls) { - std::cout << "\t\t" - << "Call " << call_count++ << ": \n"; - std::cout << "\t\tStart time: " << call.start_time - << " microseconds\n"; - std::cout << "\t\tEnd time: " << call.end_time - << " microseconds\n"; - std::cout - << "\t\tDuration: " << call.end_time - call.start_time - << " microseconds\n"; - } - } - } - } - - void csv_print(std::string_view filename) const override { - if (is_valid()) { - std::ofstream csv_file(filename.data()); - if (!csv_file.is_open()) { - std::cerr << "Failed to open file: " << filename << std::endl; - return; - } - - csv_file - << "Core Id,Block Id,Thread Id,Function,Level,Calls,Total Time " - "(microseconds),Call Index,Start Time (microseconds),End " - "Time (microseconds),Duration (microseconds)\n"; - - for (const auto &[name, stats] : function_stats_) { - uint64_t call_count = 0; - for (const auto &call : stats.calls) { - csv_file << instance_id_.cid << "," << instance_id_.bid - << "," << instance_id_.tid << "," << name << "," - << ntt::runtime::to_string(stats.level) << "," - << stats.call_count << "," << stats.total_time - << "," << call_count++ << "," << call.start_time - << "," << call.end_time << "," - << (call.end_time - call.start_time) << "\n"; - } - } - - csv_file.close(); - } - } - - void markdown_print(std::string_view filename) const override { - - if (is_valid()) { - std::ofstream md_file(filename.data()); - if (!md_file.is_open()) { - std::cerr << "Failed to open file: " << filename << std::endl; - return; - } - - md_file << "### Core Information\n"; - md_file << "| Core Id | Block Id | Thread Id |\n"; - md_file << "|---------|----------|-----------|\n"; - md_file << "| " << instance_id_.cid << " | " << instance_id_.bid - << " | " << instance_id_.tid << " |\n"; - - md_file << "\n### NTT Kernels Statistics\n"; - - for (const auto &[name, stats] : function_stats_) { - md_file << "#### Function: " << name << "\n"; - md_file << "| Level | Calls | Total Time (microseconds) |\n"; - md_file << "|-------|-------|---------------------------|\n"; - md_file << "| " << ntt::runtime::to_string(stats.level) << " | " - << stats.call_count << " | " << stats.total_time - << " |\n"; - - md_file << "\n**Call Details:**\n"; - md_file << "| Call Index | Start Time (microseconds) | End " - "Time (microseconds) | Duration (microseconds) |\n"; - md_file << "|------------|---------------------------|---------" - "----------------|-------------------------|\n"; - - uint64_t call_count = 0; - for (const auto &call : stats.calls) { - md_file << "| " << call_count++ << " | " << call.start_time - << " | " << call.end_time << " | " - << (call.end_time - call.start_time) << " |\n"; - } - md_file << "\n"; - } - - md_file.close(); - } - } - - void json_print(std::string_view filename) const override { - - if (is_valid()) { - std::ofstream json_file(filename.data()); - if (!json_file.is_open()) { - std::cerr << "Failed to open file: " << filename << std::endl; - return; - } - - std::string pid = "\"cid: " + std::to_string(instance_id_.cid) + - ", bid: " + std::to_string(instance_id_.bid) + - "\""; - std::string tid = - "\"tid: " + std::to_string(instance_id_.tid) + "\""; - json_file << "[\n"; - - bool first = true; - for (const auto &[name, stats] : function_stats_) { - for (const auto &call : stats.calls) { - if (stats.level == profiling_level::kernel) { - if (!first) { - json_file << ",\n"; - } - first = false; - json_file << " {\n"; - json_file << " \"name\": \"" << name << "\",\n"; - json_file << " \"ph\": \"X\",\n"; - json_file << " \"ts\": " << call.start_time << ",\n"; - json_file << " \"dur\": " - << (call.end_time - call.start_time) << ",\n"; - json_file << " \"pid\": " << pid << ",\n"; - json_file << " \"tid\": " << tid << ",\n"; - json_file << " \"args\": { \"level:\":\"" - << ntt::runtime::to_string(stats.level) - << " \"}\n"; - json_file << " }"; - } - } - } - - for (const auto &[name, stats] : function_stats_) { - for (const auto &call : stats.calls) { - if (stats.level == profiling_level::device) { - if (!first) { - json_file << ",\n"; - } - first = false; - json_file << " {\n"; - json_file << " \"name\": \"" << name << "\",\n"; - json_file << " \"ph\": \"X\",\n"; - json_file << " \"ts\": " << call.start_time << ",\n"; - json_file << " \"dur\": " - << (call.end_time - call.start_time) << ",\n"; - json_file << " \"pid\": " << pid << ",\n"; - json_file << " \"tid\": " << tid << ",\n"; - json_file << " \"args\": { \"level:\":\"" - << ntt::runtime::to_string(stats.level) - << " \"}\n"; - json_file << " }"; - } - } - } - - json_file << "\n]\n"; - json_file.close(); - } - } - - timer_record() = default; - - ~timer_record() { - console_print(); - markdown_print("nncase_profiling.md"); - csv_print("nncase_profiling.csv"); - json_print("nncase_profiling.json"); - } - - void set_id(record_id id) override { instance_id_ = id; } -}; - struct cpu_block_entry_params_t { size_t tdim; size_t bdim; @@ -250,12 +29,11 @@ struct cpu_block_entry_params_t { size_t bid; size_t cid; size_t cpu_id_offset; + uint8_t enable_profiling; const thread_inout_desc *input_descs; thread_inout_desc *const output_descs; std::span rdata; std::byte *output; - uint8_t enable_profiling; - timer_record *timer_records; const uint64_t *thread_local_rdata_header; const uint64_t *thread_local_cache_header; std::span thread_local_rdata; @@ -264,6 +42,8 @@ struct cpu_block_entry_params_t { std::span block_local_rdata; std::span thread_local_data; std::span block_local_data; + std::span profile_records; + uint32_t *profile_record_counts; #ifdef __APPLE__ pthread_key_t cpu_thread_context_key; #endif @@ -273,8 +53,9 @@ struct cpu_thread_context_t { size_t tid; size_t bid; size_t cid; - timer_record *timer_records; uint8_t enable_profiling; + std::span profile_records; + uint32_t *profile_record_counts; static cpu_thread_context_t ¤t() noexcept; }; diff --git a/ntt/include/nncase/ntt/arch/cuda/distributed.h b/ntt/include/nncase/ntt/arch/cuda/distributed.h new file mode 100644 index 000000000..0aeaee9ca --- /dev/null +++ b/ntt/include/nncase/ntt/arch/cuda/distributed.h @@ -0,0 +1,17 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "remote_tensor.h" +#include "topology.h" diff --git a/ntt/include/nncase/ntt/arch/cpu/profiling.h b/ntt/include/nncase/ntt/arch/cuda/profiling.h similarity index 74% rename from ntt/include/nncase/ntt/arch/cpu/profiling.h rename to ntt/include/nncase/ntt/arch/cuda/profiling.h index a5022fa17..7105f0f2e 100644 --- a/ntt/include/nncase/ntt/arch/cpu/profiling.h +++ b/ntt/include/nncase/ntt/arch/cuda/profiling.h @@ -28,16 +28,16 @@ namespace nncase::ntt { // static nncase::ntt::runtime::timer_record // timer_records[CHIP_COUNTER][BLOCK_COUNTER][THREAD_COUNTER]; -// auto_profiler, start timing and end timing -class auto_profiler { +// profile_scope, start timing and end timing +class profile_scope { public: - inline uint64_t get_current_time() const { + __device__ inline uint64_t get_current_time() const { return std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()) .count(); } - auto_profiler(std::string_view function_name) + __device__ profile_scope(std::string_view function_name) : cid_(program_id()), bid_(program_id()), tid_(program_id()) { @@ -50,15 +50,15 @@ class auto_profiler { } } - auto_profiler(std::string_view function_name, - runtime::profiling_level level) - : auto_profiler(function_name) { // 调用另一个构造函数 + __device__ profile_scope(std::string_view function_name, + profile_level level) + : profile_scope(function_name) { // 调用另一个构造函数 if (enable_profiling_) { level_ = level; // 设置 level } } - ~auto_profiler() { + __device__ ~profile_scope() { if (enable_profiling_) { timer_storage_->set_id({cid_, bid_, tid_}); end_time_ = get_current_time(); @@ -74,16 +74,17 @@ class auto_profiler { int cid_; int bid_; int tid_; - nncase::ntt::runtime::profiling_level level_; + nncase::ntt::profile_level level_; nncase::ntt::runtime::timer_record *timer_storage_; bool enable_profiling_; - inline bool get_profiler_option() noexcept { - return runtime::cpu_thread_context_t::current().enable_profiling; + __device__ inline bool get_profiler_option() noexcept { + return runtime::cuda_thread_context_t::current().enable_profiling; } - inline nncase::ntt::runtime::timer_record *get_timer_record() noexcept { - return runtime::cpu_thread_context_t::current().timer_records; + __device__ inline nncase::ntt::runtime::timer_record * + get_timer_record() noexcept { + return runtime::cuda_thread_context_t::current().timer_records; } }; diff --git a/ntt/include/nncase/ntt/arch/cuda/remote_tensor.h b/ntt/include/nncase/ntt/arch/cuda/remote_tensor.h new file mode 100644 index 000000000..06e644b89 --- /dev/null +++ b/ntt/include/nncase/ntt/arch/cuda/remote_tensor.h @@ -0,0 +1,81 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../../distributed/remote_tensor.h" +#include "../../tensor.h" +#include "../../vector.h" + +namespace nncase::ntt::distributed { +namespace detail { +extern __device__ decltype(nncase::ntt::make_tensor< + nncase::ntt::vector>( + nncase::ntt::distributed::topology_shape)) global_local_data_ptr; + +extern __device__ decltype(nncase::ntt::make_tensor< + nncase::ntt::vector>( + nncase::ntt::distributed::topology_shape)) global_thread_local_rdata_ptr; + +extern __device__ decltype(nncase::ntt::make_tensor< + nncase::ntt::vector>( + nncase::ntt::distributed::topology_shape)) global_thread_local_cache_ptr; + +extern __device__ decltype(nncase::ntt::make_tensor< + nncase::ntt::vector>( + nncase::ntt::distributed::topology_shape)) global_block_local_rdata_ptr; + +template TLocalProgramIds, + ScopedProgramIds TRemoteProgramIds> +__device__ auto get_remote_address(const TLocalProgramIds &local_program_ids, + const TRemoteProgramIds &remote_program_ids, + T *local_address) { + auto start = (size_t)global_local_data_ptr(local_program_ids)(0_dim); + auto end = (size_t)global_local_data_ptr(local_program_ids)(1_dim); + auto remote_address = + (size_t)global_local_data_ptr(remote_program_ids)(0_dim); + if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) { + start = (size_t)global_thread_local_rdata_ptr(local_program_ids)(0_dim); + end = (size_t)global_thread_local_rdata_ptr(local_program_ids)(1_dim); + remote_address = + (size_t)global_thread_local_rdata_ptr(remote_program_ids)(0_dim); + if ((uintptr_t)local_address < start || + (uintptr_t)local_address >= end) { + start = + (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim); + remote_address = + (size_t)global_block_local_rdata_ptr(remote_program_ids)(0_dim); + } + } + + return local_address - (T *)start + (T *)remote_address; +} +} // namespace detail + +template +struct remote_tensor_constructor { + template TLocalProgramIds, + ScopedProgramIds TRemoteProgramIds> + constexpr auto operator()(T *data, const TShape &shape, + const TStrides &strides, + const TLocalProgramIds &local_program_ids, + const TRemoteProgramIds &remote_program_ids) { + auto remote_address = + detail::get_remote_address( + local_program_ids, remote_program_ids, data); + return make_tensor_view_from_address(remote_address, shape, strides); + } +}; +} // namespace nncase::ntt::distributed diff --git a/ntt/include/nncase/ntt/arch/cuda/runtime.h b/ntt/include/nncase/ntt/arch/cuda/runtime.h new file mode 100644 index 000000000..4db64d06b --- /dev/null +++ b/ntt/include/nncase/ntt/arch/cuda/runtime.h @@ -0,0 +1,57 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../../profiling.h" +#include "../../runtime.h" +#include "../../std_containers.h" +#include + +namespace nncase::ntt::runtime { +struct cuda_block_entry_params_t { + size_t tdim; + size_t bdim; + size_t cdim; + size_t cid; + uint8_t enable_profiling; + const thread_inout_desc *input_descs; + thread_inout_desc *const output_descs; + ntt::span rdata; + std::byte *output; + const uint64_t *thread_local_rdata_header; + ntt::span thread_local_rdata; + const uint64_t *warp_local_rdata_header; + ntt::span warp_local_rdata; + const uint64_t *block_local_rdata_header; + ntt::span block_local_rdata; + ntt::span thread_local_data; + ntt::span warp_local_data; + ntt::span block_local_data; + ntt::span profile_records; + uint32_t *profile_record_counts; +}; + +struct cuda_thread_context_t { + size_t cid; + uint8_t enable_profiling; + ntt::span profile_records; + uint32_t *profile_record_counts; + + NTT_DEVICE static cuda_thread_context_t ¤t() noexcept; +}; +} // namespace nncase::ntt::runtime + +extern "C" NTT_KERNEL NTT_RUNTIME_API void +block_entry(const nncase::ntt::runtime::cuda_block_entry_params_t ¶ms); +using block_entry_t = decltype(block_entry) *; diff --git a/ntt/include/nncase/ntt/arch/cuda/topology.h b/ntt/include/nncase/ntt/arch/cuda/topology.h new file mode 100644 index 000000000..d5a81dd7a --- /dev/null +++ b/ntt/include/nncase/ntt/arch/cuda/topology.h @@ -0,0 +1,90 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../../distributed/topology.h" +#include "runtime.h" +#include +#include + +namespace nncase::ntt::distributed { +template <> struct program_id_getter { + __device__ static size_t id() noexcept { + if constexpr (program_dim() == warpSize) { + return cuda::ptx::get_sreg_laneid(); + + } else { + return threadIdx.x % program_dim(); + } + } +}; + +template <> struct program_id_getter { + __device__ static size_t id() noexcept { + return threadIdx.x / program_dim(); + } +}; + +template <> struct program_id_getter { + __device__ static size_t id() noexcept { return blockIdx.x; } +}; + +template <> struct program_id_getter { + __device__ static size_t id() noexcept { + return runtime::cuda_thread_context_t::current().cid; + } +}; + +inline __device__ size_t tid() noexcept { + return program_id(); +} + +inline __device__ size_t wid() noexcept { return program_id(); } + +inline __device__ size_t bid() noexcept { + return program_id(); +} + +inline __device__ size_t cid() noexcept { return program_id(); } + +inline constexpr auto tdim() noexcept { + return program_dim(); +} +inline constexpr auto wdim() noexcept { return program_dim(); } +inline constexpr auto bdim() noexcept { return program_dim(); } +inline constexpr auto cdim() noexcept { return program_dim(); } + +template <> class topology_synchronizer { + public: + __device__ static void synchronize() noexcept { __syncwarp(); } +}; + +template <> class topology_synchronizer { + public: + __device__ static void synchronize() noexcept { __syncthreads(); } +}; + +template <> class topology_synchronizer { + public: + __device__ static void synchronize() noexcept { + cooperative_groups::grid_group g = cooperative_groups::this_grid(); + g.sync(); + } +}; + +template <> class topology_synchronizer { + public: + __device__ static void synchronize() noexcept {} +}; +} // namespace nncase::ntt::distributed diff --git a/ntt/include/nncase/ntt/arch/cuda/topology_def.h b/ntt/include/nncase/ntt/arch/cuda/topology_def.h new file mode 100644 index 000000000..ae87a742e --- /dev/null +++ b/ntt/include/nncase/ntt/arch/cuda/topology_def.h @@ -0,0 +1,19 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace nncase::ntt::distributed { +enum class topology { chip, block, warp, thread, count__ }; +} // namespace nncase::ntt::distributed diff --git a/ntt/include/nncase/ntt/arch/cuda/vector_ops.h b/ntt/include/nncase/ntt/arch/cuda/vector_ops.h new file mode 100644 index 000000000..e3d0a326d --- /dev/null +++ b/ntt/include/nncase/ntt/arch/cuda/vector_ops.h @@ -0,0 +1,19 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../../vector_ops.h" +#include + +namespace nncase::ntt::ops {} // namespace nncase::ntt::ops diff --git a/ntt/include/nncase/ntt/arch/x86_64/ukernels.h b/ntt/include/nncase/ntt/arch/x86_64/ukernels.h index 78055ea97..095cf5f8c 100644 --- a/ntt/include/nncase/ntt/arch/x86_64/ukernels.h +++ b/ntt/include/nncase/ntt/arch/x86_64/ukernels.h @@ -337,9 +337,9 @@ class u_pack2d> { constexpr auto axes_temp = TAxes{}; constexpr auto conti_dims_input = - contiguous_dims(input.shape(), input.strides()); + contiguous_dims(TIn::shape(), TIn::strides()); constexpr auto conti_dims_output = - contiguous_dims(output.shape(), output.strides()); + contiguous_dims(TOut::shape(), TOut::strides()); if constexpr (TAxes::rank() == 2 && axes_temp[0_dim] + 1 == axes_temp[1_dim] && @@ -349,7 +349,6 @@ class u_pack2d> { if constexpr (TAxes::rank() > 0 && (TAxes{}[-1]) == (TIn::rank() - 1)) { using TVec = vector; - constexpr auto in_rank = TIn::rank(); constexpr auto out_rank = TOut::rank(); constexpr auto lanes = TVec::shape(); constexpr auto out_shape = TOut::shape(); @@ -371,7 +370,6 @@ class u_pack2d> { } else { using TVec = vector; constexpr auto in_rank = TIn::rank(); - constexpr auto out_rank = TOut::rank(); constexpr auto lanes = TVec::shape(); const auto out_shape = output.shape(); diff --git a/ntt/include/nncase/ntt/caching.h b/ntt/include/nncase/ntt/caching.h index 6108f5ed1..ace35c211 100644 --- a/ntt/include/nncase/ntt/caching.h +++ b/ntt/include/nncase/ntt/caching.h @@ -72,7 +72,7 @@ struct paged_attention_config using vectorized_axes_t = VectorizedAxes; using lanes_t = Lanes; using sharding_axes_t = ShardingAxes; - using axis_policies_t = std::tuple; + using axis_policies_t = ntt::tuple; static inline constexpr auto cache_layout = cache_layout_t{}; static inline constexpr auto block_layout = block_layout_t{}; @@ -88,7 +88,7 @@ struct paged_attention_config if constexpr (index == -1_dim) { return distributed::shard_policy::B; } else { - return std::get(axis_policies); + return ntt::get(axis_policies); } } }; @@ -125,14 +125,16 @@ template class attention_kv_cache { context_lens_(std::move(context_lens)), seq_lens_(std::move(seq_lens)) {} - size_t num_seqs() const noexcept { return num_seqs_; } - size_t num_tokens() const noexcept { return num_tokens_; } + constexpr size_t num_seqs() const noexcept { return num_seqs_; } + constexpr size_t num_tokens() const noexcept { return num_tokens_; } - int64_t context_len(int64_t request_id) const noexcept { + constexpr int64_t context_len(int64_t request_id) const noexcept { return context_lens_(request_id); } - int64_t seq_len(int64_t seq_id) const noexcept { return seq_lens_(seq_id); } + constexpr int64_t seq_len(int64_t seq_id) const noexcept { + return seq_lens_(seq_id); + } protected: size_t num_seqs_; @@ -150,7 +152,7 @@ kv_dim(const distributed::shard_policy::split &split) noexcept { } template -constexpr auto kv_addr_shape(std::tuple) noexcept { +constexpr auto kv_addr_shape(ntt::tuple) noexcept { return fixed_shape_v(AxisPolicies{})...>; } @@ -177,7 +179,7 @@ constexpr auto origin_kv_cache_one_block_shape() noexcept { auto shard_shape = TConfig::sharding_axes.aggregate( vectorized_shape, [&](auto last_shape, auto sharding_axis, auto i) { using axis_policy_t = - std::tuple_element_t; + ntt::tuple_element_t; if constexpr (sharding_axis == fixed_dim_v<( dim_t)paged_kvcache_dim_kind::num_blocks>) { @@ -246,10 +248,12 @@ class paged_attention_kv_cache : public attention_kv_cache { using kv_addrs_t = decltype(make_tensor_view_from_address( std::declval(), kv_addrs_shape)); - paged_attention_kv_cache(size_t num_seqs, size_t num_tokens, - context_lens_t context_lens, seq_lens_t seq_lens, - block_table_t block_table, - slot_mapping_t slot_mapping, kv_addrs_t kv_addrs) + constexpr paged_attention_kv_cache(size_t num_seqs, size_t num_tokens, + context_lens_t context_lens, + seq_lens_t seq_lens, + block_table_t block_table, + slot_mapping_t slot_mapping, + kv_addrs_t kv_addrs) : attention_kv_cache(num_seqs, num_tokens, context_lens, seq_lens), block_table_(block_table), @@ -311,7 +315,7 @@ class paged_attention_kv_cache : public attention_kv_cache { } } - auto block_table() const noexcept { return block_table_; } + constexpr auto block_table() const noexcept { return block_table_; } private: template @@ -326,7 +330,7 @@ class paged_attention_kv_cache : public attention_kv_cache { generate_shape([&](auto axis) { const auto index = block_shard_index(axis); const auto submesh_axes = - std::get(TConfig::axis_policies).axes; + ntt::get(TConfig::axis_policies).axes; const auto submesh_shape = Mesh::shape.select(submesh_axes); const auto local_program_id = linear_offset( local_index.select(submesh_axes), submesh_shape); diff --git a/ntt/include/nncase/ntt/compiler_defs.h b/ntt/include/nncase/ntt/compiler_defs.h index a101a294c..516d5c926 100644 --- a/ntt/include/nncase/ntt/compiler_defs.h +++ b/ntt/include/nncase/ntt/compiler_defs.h @@ -52,3 +52,13 @@ defined(__riscv_zvfbf) #define NTT_HAVE_NATIVE_BF16 1 #endif + +#ifdef __CUDACC__ +#define NTT_HOST_DEVICE __host__ __device__ +#define NTT_DEVICE __device__ +#define NTT_KERNEL __global__ +#else +#define NTT_HOST_DEVICE +#define NTT_DEVICE +#define NTT_KERNEL +#endif diff --git a/ntt/include/nncase/ntt/detail/shape_storage.h b/ntt/include/nncase/ntt/detail/shape_storage.h index 66ecbb029..aaf13e676 100644 --- a/ntt/include/nncase/ntt/detail/shape_storage.h +++ b/ntt/include/nncase/ntt/detail/shape_storage.h @@ -14,7 +14,7 @@ */ #pragma once #include "../compiler_defs.h" -#include "nncase/ntt/shape.h" +#include "../tensor_traits.h" #include #include @@ -68,7 +68,7 @@ struct NTT_EMPTY_BASES tensor_size_impl : public shape_storage, class = std::enable_if_t && FixedStrides>> constexpr tensor_size_impl() noexcept {} - tensor_size_impl(Shape shape, Strides strides) + constexpr tensor_size_impl(Shape shape, Strides strides) : shape_storage(shape), strides_storage(strides) {} }; } // namespace nncase::ntt::detail diff --git a/ntt/include/nncase/ntt/detail/tensor_storage.h b/ntt/include/nncase/ntt/detail/tensor_storage.h index 503150114..019c29512 100644 --- a/ntt/include/nncase/ntt/detail/tensor_storage.h +++ b/ntt/include/nncase/ntt/detail/tensor_storage.h @@ -14,6 +14,7 @@ */ #pragma once #include "../shape.h" +#include "../std_containers.h" #include "nncase/ntt/tensor_traits.h" #include @@ -23,7 +24,7 @@ template class tensor_storage; // fixed tensor template class tensor_storage { public: - using buffer_type = std::array; + using buffer_type = array; constexpr tensor_storage() = default; @@ -35,10 +36,10 @@ template class tensor_storage { constexpr const buffer_type &buffer() const noexcept { return buffer_; } constexpr buffer_type &buffer() noexcept { return buffer_; } - constexpr std::span elements() const noexcept { + constexpr span elements() const noexcept { return buffer_; } - constexpr std::span elements() noexcept { return buffer_; } + constexpr span elements() noexcept { return buffer_; } private: buffer_type buffer_; @@ -47,7 +48,7 @@ template class tensor_storage { // fixed view template class tensor_storage { public: - using buffer_type = std::span; + using buffer_type = span; constexpr tensor_storage(std::in_place_t, buffer_type value) : buffer_(value) {} @@ -55,10 +56,10 @@ template class tensor_storage { constexpr const buffer_type &buffer() const noexcept { return buffer_; } constexpr buffer_type &buffer() noexcept { return buffer_; } - constexpr std::span elements() const noexcept { + constexpr span elements() const noexcept { return buffer_; } - constexpr std::span elements() noexcept { return buffer_; } + constexpr span elements() noexcept { return buffer_; } private: buffer_type buffer_; @@ -76,10 +77,10 @@ template class tensor_storage { constexpr const buffer_type &buffer() const noexcept { return buffer_; } constexpr buffer_type &buffer() noexcept { return buffer_; } - constexpr std::span elements() const noexcept { + constexpr span elements() const noexcept { return {buffer_.data(), buffer_.size()}; } - constexpr std::span elements() noexcept { + constexpr span elements() noexcept { return {buffer_.data(), buffer_.size()}; } @@ -98,10 +99,10 @@ template <> class tensor_storage { constexpr const buffer_type &buffer() const noexcept { return buffer_; } constexpr buffer_type &buffer() noexcept { return buffer_; } - std::span elements() const noexcept { + span elements() const noexcept { return {reinterpret_cast(buffer_.data()), buffer_.size()}; } - std::span elements() noexcept { + span elements() noexcept { return {reinterpret_cast(buffer_.data()), buffer_.size()}; } @@ -112,8 +113,8 @@ template <> class tensor_storage { // dynamic view template class tensor_storage { public: - using const_buffer_type = std::span; - using buffer_type = std::span; + using const_buffer_type = span; + using buffer_type = span; constexpr tensor_storage(std::in_place_t, buffer_type value) : buffer_(value) {} diff --git a/ntt/include/nncase/ntt/dimension.h b/ntt/include/nncase/ntt/dimension.h index d77e7f5c8..aa9973d09 100644 --- a/ntt/include/nncase/ntt/dimension.h +++ b/ntt/include/nncase/ntt/dimension.h @@ -16,6 +16,7 @@ #include "primitive_ops.h" #include "tensor_traits.h" #include +#include #include #include #include @@ -60,7 +61,7 @@ template struct char_literal { }; } // namespace detail -template inline constexpr auto operator"" _dim() { +template inline constexpr auto operator""_dim() { constexpr auto value = detail::char_literal::to_int; return fixed_dim_v; } @@ -184,16 +185,15 @@ constexpr auto positive_index(const TIndex &index, } else { return index; } + } else if constexpr (std::unsigned_integral) { + return index; } else { return index < 0 ? index + dim : index; } } -namespace detail { -template struct dim_where_impl; - -template -struct dim_where_impl { +namespace ops { +template struct where { constexpr dim_t operator()(const Cond &cond, const T &true_dim, const F &false_dim) const noexcept { return cond ? dim_value(true_dim) : dim_value(false_dim); @@ -201,7 +201,7 @@ struct dim_where_impl { }; template -struct dim_where_impl, T, F> { +struct where, T, F> { constexpr auto operator()(const std::integral_constant &, [[maybe_unused]] const T &true_dim, @@ -213,12 +213,5 @@ struct dim_where_impl, T, F> { } } }; -} // namespace detail - -template -constexpr auto where(const Cond &cond, const T &true_dim, - const F &false_dim) noexcept { - detail::dim_where_impl impl; - return impl(cond, true_dim, false_dim); -} +} // namespace ops } // namespace nncase::ntt diff --git a/ntt/include/nncase/ntt/distributed.h b/ntt/include/nncase/ntt/distributed.h index 63dd0093f..482928657 100644 --- a/ntt/include/nncase/ntt/distributed.h +++ b/ntt/include/nncase/ntt/distributed.h @@ -17,4 +17,3 @@ #include "distributed/sharded_tensor.h" #include "distributed/sharding.h" #include "distributed/topology.h" -#include "kernels/reshard.h" diff --git a/ntt/include/nncase/ntt/distributed/sharding.h b/ntt/include/nncase/ntt/distributed/sharding.h index 13919a27b..40a415b41 100644 --- a/ntt/include/nncase/ntt/distributed/sharding.h +++ b/ntt/include/nncase/ntt/distributed/sharding.h @@ -116,7 +116,7 @@ concept SplitShardPolicy = is_split_shard_policy::value; template struct sharding { using mesh_type = Mesh; - using axis_policies_type = std::tuple; + using axis_policies_type = ntt::tuple; using dynamic_offset_t = dynamic_shape_t; static constexpr auto rank() { @@ -134,7 +134,7 @@ template struct sharding { global_offset(const GlobalShape &global_shape, const TShardIndex &shard_index) const noexcept { auto get_dim = [&, this] { - return std::get(axis_policies) + return ntt::get(axis_policies) .template global_offset(global_shape[fixed_dim_v], shard_index); }; @@ -148,7 +148,7 @@ template struct sharding { constexpr auto shard_shape(const GlobalShape &global_shape, const TShardIndex &shard_index) const noexcept { auto get_dim = [&, this] { - return std::get(axis_policies) + return ntt::get(axis_policies) .template shard_dim(global_shape[fixed_dim_v], shard_index); }; @@ -158,7 +158,7 @@ template struct sharding { return get_all_dims(std::make_index_sequence{}); } - std::tuple axis_policies; + ntt::tuple axis_policies; }; template @@ -170,7 +170,7 @@ namespace detail { template constexpr bool is_divisible(const Sharding &sharding, const GlobalShape &shape, std::index_sequence) noexcept { - return ((std::get(sharding.axis_policies) + return ((ntt::get(sharding.axis_policies) .template is_divisible( shape.at(Ids))) && ...); @@ -182,7 +182,7 @@ constexpr auto mesh_axes_mask_of_split_shard_policies() noexcept { return generate_shape([](auto mesh_axis) { return make_index_shape().aggregate( dim_zero, [&](auto last_mask, auto axis, auto) { - using policy_t = std::tuple_element_t< + using policy_t = ntt::tuple_element_t< axis, typename TSharding::axis_policies_type>; if constexpr (distributed::SplitShardPolicy) { if constexpr (policy_t::axes.contains(mesh_axis)) { @@ -212,7 +212,7 @@ template constexpr auto tensor_axes_mask_of_split_shard_policies() noexcept { return generate_shape([](auto axis) { using policy_t = - std::tuple_element_t; + ntt::tuple_element_t; if constexpr (distributed::SplitShardPolicy) { return dim_one; } else { @@ -239,7 +239,7 @@ constexpr auto local_shard_dim(const TSharding &sharding, using mesh_type = typename TSharding::mesh_type; const auto local_index = mesh_type::local_index(); - return std::get(sharding.axis_policies) + return ntt::get(sharding.axis_policies) .template shard_dim(global_shape[fixed_dim_v], local_index); } diff --git a/ntt/include/nncase/ntt/distributed/topology.h b/ntt/include/nncase/ntt/distributed/topology.h index 8491cf91e..c84971962 100644 --- a/ntt/include/nncase/ntt/distributed/topology.h +++ b/ntt/include/nncase/ntt/distributed/topology.h @@ -13,8 +13,11 @@ * limitations under the License. */ #pragma once +#include "nncase/ntt/compiler_defs.h" #if defined(NNCASE_XPU_MODULE) #include "../arch/xpu/topology_def.h" +#elif defined(__CUDA_ARCH__) +#include "../arch/cuda/topology_def.h" #else #include "../arch/cpu/topology_def.h" #endif @@ -70,17 +73,17 @@ constexpr auto topology_up_size() noexcept { } template struct program_id_getter { - static dim_t id() noexcept; + NTT_HOST_DEVICE static dim_t id() noexcept; }; -template dim_t program_id() noexcept { +template NTT_HOST_DEVICE dim_t program_id() noexcept { return program_id_getter::id(); } bool get_profiler_option() noexcept; template -auto program_ids() noexcept { +NTT_HOST_DEVICE auto program_ids() noexcept { auto f = [](std::index_sequence) { return make_shape(program_id(Is)>()...); }; @@ -89,7 +92,8 @@ auto program_ids() noexcept { template class topology_synchronizer; -template void topology_synchronize() noexcept { +template +NTT_HOST_DEVICE void topology_synchronize() noexcept { topology_synchronizer::synchronize(); } } // namespace nncase::ntt::distributed diff --git a/ntt/include/nncase/ntt/kernels/binary.h b/ntt/include/nncase/ntt/kernels/binary.h index 55f1f388b..8e10e54c3 100644 --- a/ntt/include/nncase/ntt/kernels/binary.h +++ b/ntt/include/nncase/ntt/kernels/binary.h @@ -26,8 +26,9 @@ class binary_impl TRhs, TOut> { public: template - void invoke_ukernel(const TBroadcastedLhs &lhs, const TBroadcastedRhs &rhs, - TOut &output, const TOp &op, bool is_broadcast) { + constexpr void invoke_ukernel(const TBroadcastedLhs &lhs, + const TBroadcastedRhs &rhs, TOut &output, + const TOp &op, bool is_broadcast) { auto lhs_conti_dims = contiguous_dims(lhs.shape(), lhs.strides()); auto rhs_conti_dims = contiguous_dims(rhs.shape(), rhs.strides()); @@ -106,7 +107,7 @@ class binary_impl template