diff --git a/.github/workflows/compiler-build.yml b/.github/workflows/compiler-build.yml
index 32ab43f05..04a0a08ea 100644
--- a/.github/workflows/compiler-build.yml
+++ b/.github/workflows/compiler-build.yml
@@ -138,7 +138,7 @@ jobs:
working-directory: ${{github.workspace}}
run: |
dotnet tool install --global dotnet-coverage
- dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/unit.xml "dotnet test -c ${{matrix.config.buildType}} -s test.runsettings --no-build --verbosity normal --blame"
+ dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/unit.xml "dotnet test -c ${{matrix.config.buildType}} -s test.runsettings --no-build --verbosity normal --filter FullyQualifiedName!~Nncase.Tests.TargetTest.UnitTestCUDAKernels --blame"
dotnet-coverage merge -o coverage.unit.xml -f cobertura -r coverage/*.xml
- name: Upload Coverage
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0193cbf70..01d61edc0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -47,6 +47,18 @@ option(BUILD_TESTING "Build test programs" OFF)
option(ENABLE_OP_PROFILE "Profile ops cast time" OFF)
option(ENABLE_DUMP_MANAGER "Enable dump manager" OFF)
option(ENABLE_DUMP_MEM "Dump mem usage" OFF)
+option(ENABLE_CUDA_RUNTIME "Enable CUDA runtime" OFF)
+
+if(DEFINED CMAKE_CUDA_COMPILER AND NOT "${CMAKE_CUDA_COMPILER}" STREQUAL "")
+ set(ENABLE_CUDA_RUNTIME ON CACHE BOOL "Enable CUDA runtime" FORCE)
+endif()
+
+if(ENABLE_CUDA_RUNTIME)
+ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+ set(CMAKE_CUDA_ARCHITECTURES 120)
+ endif()
+ enable_language(CUDA)
+endif()
if (BUILDING_RUNTIME)
# option(ENABLE_VULKAN_RUNTIME "Enable Vulkan runtime" OFF)
diff --git a/cmake/compile_flags.cmake b/cmake/compile_flags.cmake
index b5d7a36c4..fa9022513 100644
--- a/cmake/compile_flags.cmake
+++ b/cmake/compile_flags.cmake
@@ -4,7 +4,7 @@ if (MSVC)
set(PYBIND11_CPP_STANDARD "/std:c++latest")
else()
add_compile_options(-fvisibility=hidden)
- add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare)
+ add_compile_options(-Wall -Wextra -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare)
if (APPLE)
add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized -Wno-deprecated -Wno-braced-scalar-init)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
@@ -15,6 +15,11 @@ else()
endif()
endif()
+if (CMAKE_CUDA_COMPILER)
+ message(STATUS "Configuring for CUDA")
+ #add_compile_options(-save-temps)
+endif()
+
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES
"(x86)|(X86)|(amd64)|(AMD64)|(x86_64)|(X86_64)")
if (MSVC)
diff --git a/conanfile.py b/conanfile.py
index b20364120..33d5c946b 100644
--- a/conanfile.py
+++ b/conanfile.py
@@ -29,6 +29,7 @@ class nncaseConan(ConanFile):
"k230_runtime": [True, False],
"k80_runtime": [True, False],
"vulkan_runtime": [True, False],
+ "cuda_runtime": [True, False],
"tests": [True, False],
"python": [True, False],
"python_root": ["ANY"]
@@ -40,6 +41,7 @@ class nncaseConan(ConanFile):
"k230_runtime": False,
"k80_runtime": False,
"vulkan_runtime": False,
+ "cuda_runtime": False,
"tests": False,
"python": True,
"python_root": ""
@@ -88,8 +90,11 @@ def generate(self):
tc.variables['ENABLE_K230_RUNTIME'] = self.options.k230_runtime
tc.variables['ENABLE_K80_RUNTIME'] = self.options.k80_runtime
tc.variables['ENABLE_VULKAN_RUNTIME'] = self.options.vulkan_runtime
+ tc.variables['ENABLE_CUDA_RUNTIME'] = self.options.cuda_runtime
tc.variables['BUILD_PYTHON_BINDING'] = self.options.python
tc.variables['BUILD_TESTING'] = self.options.tests
+ if self.options.cuda_runtime:
+ tc.variables['CMAKE_CUDA_ARCHITECTURES'] = "120"
if self.options.get_safe("python_root", default="") != "":
tc.variables['Python3_ROOT_DIR'] = self.options.python_root
if self.options.runtime:
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs
index a746e14ad..c9cd8d750 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceBuiltn.cs
@@ -80,16 +80,16 @@ public static string TopoAwareRuntimeDef(NTTTargetOptions options, ulong dataAli
return content;
}
- public static string ModuleTopologyDef(NTTTargetOptions options)
+ public static string ModuleTopologyDef(NTTTargetOptions options, bool isCUDA)
{
- var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/module_topology_def.h.cshtml", options).Result;
+ var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/module_topology_def.h.cshtml", new { Hierarchies = options.Hierarchies[0], IsCUDA = isCUDA }).Result;
return content;
}
- public static string CMakeDef()
+ public static string CMakeDef(bool isCUDA)
{
var cmakePath = CMakePath(Path.Combine(Path.GetDirectoryName(typeof(CSourceBuiltn).Assembly.Location)!, "Runtime", "cmake", "ntt_module.cmake"));
- var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/CMakeLists.txt.cshtml", new { CMakePath = cmakePath }).Result;
+ var content = RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/CMakeLists.txt.cshtml", new { CMakePath = cmakePath, IsCUDA = isCUDA }).Result;
return content;
}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs
index 5d38ea682..34e15ba45 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceCompiler.cs
@@ -22,6 +22,8 @@ public class CSourceCompiler
{
private static string? _vcVarPath;
+ private readonly bool _isCUDA;
+
///
/// compiler exe name.
///
@@ -37,8 +39,9 @@ public class CSourceCompiler
///
private string _ext = string.Empty;
- public CSourceCompiler()
+ public CSourceCompiler(bool isCUDA)
{
+ _isCUDA = isCUDA;
PlatformSpecific();
ArchSpecific();
}
@@ -186,8 +189,16 @@ private void ArchSpecific()
private string ArgumentsSpecific(string sourcePath, string outPath)
{
- var archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
- "-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty;
+ string archConfig = string.Empty;
+ if (_isCUDA)
+ {
+ archConfig = $"-DCMAKE_CUDA_ARCHITECTURES=120 -DCMAKE_CUDA_COMPILER=clang++";
+ }
+ else
+ {
+ archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
+ "-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty;
+ }
#if DEBUG
var config = "Release";
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs
index 26a66d3a8..b4b47410d 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/DeviceCSourceConvertVisitor.cs
@@ -47,7 +47,7 @@ public static void WriteWithProfiler(string functionName, string tagName = "")
IndentScope.Writer.IndWrite("{\n");
#if false // Disable device profiling for now.
IndentScope.Writer.Write($"constexpr std::string_view function_name = \"{tagName}\";\n");
- IndentScope.Writer.Write($"auto_profiler profiler(function_name, runtime::profiling_level::device);\n");
+ IndentScope.Writer.Write($"profile_scope profiler(function_name, profile_level::device);\n");
#endif
IndentScope.Writer.Write($"{functionName};\n");
IndentScope.Writer.IndWrite("}\n");
@@ -69,7 +69,7 @@ public static void WriteIndWithProfiler(string functionName, string tagName = ""
IndentScope.Writer.IndWrite("{\n");
#if false // Disable device profiling for now.
IndentScope.Writer.IndWrite($"constexpr std::string_view function_name = \"{tagName}\";\n");
- IndentScope.Writer.IndWrite($"auto_profiler profiler(function_name, runtime::profiling_level::device);\n");
+ IndentScope.Writer.IndWrite($"profile_scope profiler(function_name, profile_level::device);\n");
#endif
IndentScope.Writer.IndWrite($"{functionName};\n");
IndentScope.Writer.IndWrite("}\n");
@@ -94,7 +94,7 @@ protected override CSymbol VisitPrimFunction(PrimFunction expr)
}
var ctype = $"template<{string.Join(", ", Enumerable.Range(0, expr.Parameters.Length).Select(x => $"class T{x}"))}>" + Environment.NewLine +
- $"void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"T{i} &&{s.Name}").ToArray())})";
+ $"NTT_DEVICE void {expr.Name}({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"T{i} &&{s.Name}").ToArray())})";
using (var scope = new IndentScope(_deviceBuilder))
{
@@ -192,7 +192,7 @@ protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr)
_ => throw new NotSupportedException(expr.Location.ToString()),
};
- var str = $"std::span({name} + {start.Name}, {size.Name})";
+ var str = $"ntt::span({name} + {start.Name}, {size.Name})";
symbol = new(start.Type, str);
_exprMemo.Add(expr, symbol);
return symbol;
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs
index b44ee466d..8d64872f2 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/FunctionBuilder.cs
@@ -21,15 +21,17 @@ internal class FunctionBuilder
private readonly BinaryWriter _textWriter;
private readonly BinaryWriter _rdataWriter;
private readonly IReadOnlyList _threadLocalRdataWriters;
+ private readonly IReadOnlyList _warpLocalRdataWriters;
private readonly IReadOnlyList _blockLocalRdataWriters;
- public FunctionBuilder(uint id, BinaryWriter rdataWriter, IReadOnlyList threadLocalRdataWriters, IReadOnlyList blockLocalRdataWriters, Targets.NTTTargetOptions targetOptions)
+ public FunctionBuilder(uint id, BinaryWriter rdataWriter, IReadOnlyList threadLocalRdataWriters, IReadOnlyList warpLocalRdataWriters, IReadOnlyList blockLocalRdataWriters, Targets.NTTTargetOptions targetOptions)
{
_id = id;
_sectionManager = new();
_textWriter = _sectionManager.GetWriter(WellknownSectionNames.Text);
_rdataWriter = rdataWriter;
_threadLocalRdataWriters = threadLocalRdataWriters;
+ _warpLocalRdataWriters = warpLocalRdataWriters;
_blockLocalRdataWriters = blockLocalRdataWriters;
TargetOptions = targetOptions;
}
@@ -58,66 +60,17 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc)
tensor.Serialize(_rdataWriter.BaseStream);
}
- // 2. write the thread local rdata
- ulong threadLocalRdataPoolSize = ulong.MinValue;
- foreach (var (@const, range) in primFunc.SchedResult.ThreadLocalRdatas)
- {
- var tensor = ((TensorConst)@const).Value;
- var distributedType = (DistributedType)@const.CheckedType;
- var size = range.Max - range.Min;
- threadLocalRdataPoolSize = System.Math.Max(range.Max, threadLocalRdataPoolSize);
- var dividedDims = DistributedUtility.GetDividedTensorType(distributedType).Shape.ToValueArray();
- var localStrides = TensorUtilities.GetDefaultStrides(dividedDims);
- for (int i = 0; i < _threadLocalRdataWriters.Count; i++)
- {
- var threadLocalRdataWriter = _threadLocalRdataWriters[i];
- var shardIndex = DistributedUtility.GetUnraveledIndex(i, TargetOptions.Hierarchies[0]);
- (var localOffset, var localShape) = DistributedUtility.GetLocalOffsetAndShape(distributedType, shardIndex);
- var linearOffset = TensorUtilities.GetLinearOffset(tensor.Strides, localOffset);
-
- if ((ulong)TensorUtilities.GetProduct(localShape) * (ulong)tensor.ElementType.SizeInBytes > size)
- {
- throw new InvalidDataException("The Buffer Size Not Equal!");
- }
-
- threadLocalRdataWriter.Position(checked((long)range.Min));
- tensor.Serialize(threadLocalRdataWriter.BaseStream, linearOffset, localShape, localStrides);
- }
- }
-
- // 2. write the block local rdata
- ulong blockLocalRdataPoolSize = ulong.MinValue;
- foreach (var (@const, range) in primFunc.SchedResult.BlockLocalRdatas)
- {
- var tensor = ((TensorConst)@const).Value;
- var distributedType = (DistributedType)@const.CheckedType;
- var size = range.Max - range.Min;
- blockLocalRdataPoolSize = System.Math.Max(range.Max, blockLocalRdataPoolSize);
- var dividedDims = DistributedUtility.GetDividedTensorType(distributedType).Shape.ToValueArray();
- var localStrides = TensorUtilities.GetDefaultStrides(dividedDims);
- for (int i = 0; i < _blockLocalRdataWriters.Count; i++)
- {
- var blockLocalRdataWriter = _blockLocalRdataWriters[i];
- var shardIndex = DistributedUtility.GetUnraveledIndex(i, TargetOptions.Hierarchies[0][..^1]).Concat([0]).ToArray();
- (var localOffset, var localShape) = DistributedUtility.GetLocalOffsetAndShape(distributedType, shardIndex);
- var linearOffset = TensorUtilities.GetLinearOffset(tensor.Strides, localOffset);
-
- if ((ulong)TensorUtilities.GetProduct(localShape) * (ulong)tensor.ElementType.SizeInBytes > size)
- {
- throw new InvalidDataException("The Buffer Size Not Equal!");
- }
-
- blockLocalRdataWriter.Position(checked((long)range.Min));
- tensor.Serialize(blockLocalRdataWriter.BaseStream, linearOffset, localShape, localStrides);
- }
- }
+ // 2. write the local rdatas
+ var threadLocalRdataPoolSize = SerializeLocalRdata(primFunc.SchedResult.ThreadLocalRdatas, _threadLocalRdataWriters, "t");
+ var warpLocalRdataPoolSize = SerializeLocalRdata(primFunc.SchedResult.WarpLocalRdatas, _warpLocalRdataWriters, "w");
+ var blockLocalRdataPoolSize = SerializeLocalRdata(primFunc.SchedResult.BlockLocalRdatas, _blockLocalRdataWriters, "b");
- // 4. build function.
+ // 3. build function.
var visitor = new KernelCSourceConvertVisitor(TargetOptions);
visitor.Visit(primFunc);
var functionCSource = visitor.GetCSource();
- // 5. write the kernel desc
+ // 4. write the kernel desc
using (var writer = _sectionManager.GetWriter(LinkableKernelFunction.KernelHeaderSectionName))
{
var header = default(KernelDescHeader);
@@ -125,6 +78,7 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc)
header.LocalDataAlign = (uint)primFunc.SchedResult.DataAlign;
header.OutputPoolSize = primFunc.SchedResult.OutputUsage;
header.LocalDataPoolSize = primFunc.SchedResult.DataUsage;
+ header.WarpLocalDataPoolSize = primFunc.SchedResult.WarpLocalDataPoolSize;
header.BlockLocalDataPoolSize = primFunc.SchedResult.BlockLocalDataPoolSize;
writer.Write(ref header);
}
@@ -132,6 +86,7 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc)
var memoryPoolDesc = new KernelMemoryPoolDesc(
rdataPoolSize,
threadLocalRdataPoolSize,
+ warpLocalRdataPoolSize,
blockLocalRdataPoolSize);
var kernelDescSection = new LinkedSection(_sectionManager.GetContent(LinkableKernelFunction.KernelHeaderSectionName)!, ".desc", 0, 8, (uint)sizeof(KernelDescHeader));
return new LinkableKernelFunction(_id, primFunc, functionCSource, memoryPoolDesc, _sectionManager.GetContent(WellknownSectionNames.Text)!, kernelDescSection);
@@ -154,4 +109,50 @@ public unsafe ILinkableFunction Build(BaseFunction baseFunc)
throw new NotSupportedException($"the {baseFunc.GetType()} {baseFunc.Name} is notsupport for codegen!");
}
+
+ private ulong SerializeLocalRdata(IReadOnlyDictionary> localRdatas, IReadOnlyList localRdataWriters, string scopeName)
+ {
+ ulong localRdataPoolSize = ulong.MinValue;
+ foreach (var (@const, range) in localRdatas)
+ {
+ var tensor = ((TensorConst)@const).Value;
+ var distributedType = (DistributedType)@const.CheckedType;
+ var size = range.Max - range.Min;
+ localRdataPoolSize = System.Math.Max(range.Max, localRdataPoolSize);
+ var dividedDims = DistributedUtility.GetDividedTensorType(distributedType).Shape.ToValueArray();
+ var localStrides = TensorUtilities.GetDefaultStrides(dividedDims);
+ for (int i = 0; i < localRdataWriters.Count; i++)
+ {
+ var localRdataWriter = localRdataWriters[i];
+ var shardIndex = GetScopedShardIndex(i, scopeName);
+ (var localOffset, var localShape) = DistributedUtility.GetLocalOffsetAndShape(distributedType, shardIndex);
+ var linearOffset = TensorUtilities.GetLinearOffset(tensor.Strides, localOffset);
+
+ if ((ulong)TensorUtilities.GetProduct(localShape) * (ulong)tensor.ElementType.SizeInBytes > size)
+ {
+ throw new InvalidDataException("The Buffer Size Not Equal!");
+ }
+
+ localRdataWriter.Position(checked((long)range.Min));
+ tensor.Serialize(localRdataWriter.BaseStream, linearOffset, localShape, localStrides);
+ }
+ }
+
+ return localRdataPoolSize;
+ }
+
+ private int[] GetScopedShardIndex(int writerIndex, string scopeName)
+ {
+ var hierarchies = TargetOptions.Hierarchies[0];
+ var scopeIndex = TargetOptions.HierarchyNames.IndexOf(scopeName, StringComparison.Ordinal);
+ if (scopeIndex < 0)
+ {
+ return DistributedUtility.GetUnraveledIndex(writerIndex, hierarchies);
+ }
+
+ var scopedHierarchies = hierarchies[..(scopeIndex + 1)];
+ return DistributedUtility.GetUnraveledIndex(writerIndex, scopedHierarchies)
+ .Concat(Enumerable.Repeat(0, hierarchies.Length - scopedHierarchies.Length))
+ .ToArray();
+ }
}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs
index d655ebc19..31a804f95 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/FusionCSourceConvertVisitor.cs
@@ -60,7 +60,7 @@ protected override CSymbol VisitFusion(Fusion expr)
IndentScope.Writer.IndWrite($"template<{string.Join(", ", Enumerable.Range(0, expr.Parameters.Length).Select(x => $"class T{x}"))}> struct {expr.Name} {{\n");
using (_ = new IndentScope())
{
- IndentScope.Writer.IndWrite($"auto operator()({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"const T{i} &{s.Name}").ToArray())}) const noexcept {{\n");
+ IndentScope.Writer.IndWrite($"constexpr auto operator()({string.Join(", ", expr.Parameters.AsValueEnumerable().Select(Visit).Select((s, i) => $"const T{i} &{s.Name}").ToArray())}) const noexcept {{\n");
// 2. Function body
using (_ = new IndentScope())
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs
index f20039466..96c6d89a2 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs
@@ -27,7 +27,7 @@ namespace Nncase.CodeGen.NTT;
///
internal sealed class KernelCSourceConvertVisitor : CSourceConvertVisitor, IDisposable
{
- private readonly HashSet _excludedVars = new() { "data", "block_local_data" };
+ private readonly HashSet _excludedVars = new() { "data", "warp_local_data", "block_local_data" };
private readonly StringBuilder _kernelBuilder;
private readonly HashSet _refFuncs;
private readonly HashSet _declaredBuffers = new(ReferenceEqualityComparer.Instance);
@@ -47,7 +47,7 @@ public KernelCSourceConvertVisitor(NTTTargetOptions targetOptions)
private Var[] TensorParams => _tensorParams ??= VisitEntry.Parameters.ToArray().OfType().Where(x => !_excludedVars.Contains(x.Name)).ToArray();
- public static void WriteWithProfiler(string functionName, string tagName = "")
+ public void WriteWithProfiler(string functionName, string tagName = "")
{
functionName = functionName.TrimEnd(new char[] { ';', '\n' });
if (tagName == string.Empty)
@@ -62,12 +62,12 @@ public static void WriteWithProfiler(string functionName, string tagName = "")
tagName = tagName == string.Empty ? functionName : tagName;
IndentScope.Writer.IndWrite("{\n");
IndentScope.Writer.Write($"constexpr std::string_view function_name = \"{tagName}\";\n");
- IndentScope.Writer.Write($"auto_profiler profiler(function_name, runtime::profiling_level::kernel);\n");
+ IndentScope.Writer.Write($"profile_scope profiler(0, profile_level::kernel);\n");
IndentScope.Writer.Write($"{functionName};\n");
IndentScope.Writer.IndWrite("}\n");
}
- public static void WriteIndWithProfiler(string functionName, string tagName = "")
+ public void WriteIndWithProfiler(string functionName, string tagName = "")
{
functionName = functionName.TrimEnd(new char[] { ';', '\n' });
if (tagName == string.Empty)
@@ -82,7 +82,7 @@ public static void WriteIndWithProfiler(string functionName, string tagName = ""
tagName = tagName == string.Empty ? functionName : tagName;
IndentScope.Writer.IndWrite("{\n");
IndentScope.Writer.IndWrite($"constexpr std::string_view function_name = \"{tagName}\";\n");
- IndentScope.Writer.IndWrite($"auto_profiler profiler(function_name, runtime::profiling_level::kernel);\n");
+ IndentScope.Writer.IndWrite($"profile_scope profiler(0, profile_level::kernel);\n");
IndentScope.Writer.IndWrite($"{functionName};\n");
IndentScope.Writer.IndWrite("}\n");
}
@@ -92,7 +92,7 @@ public KernelCSource GetCSource()
var paramsExcluded = VisitEntry.Parameters.ToArray().OfType().Where(x => !_excludedVars.Contains(x.Name)).ToArray();
var templateHeader = TensorParams.Length == 0 ? string.Empty : $"template<{string.Join(", ", Enumerable.Range(0, TensorParams.Length).Select(x => $"class T{x}"))}>" + Environment.NewLine;
var ctype = templateHeader +
- $"void {VisitEntry.Name}({string.Concat(paramsExcluded.Select(Visit).Select(s => $"{s.Type} {s.Name}, ").ToArray())}const std::byte *rdata, const std::byte *thread_local_rdata, const std::byte *block_local_rdata, std::byte *thread_local_data, std::byte *block_local_data, std::byte *output, nncase::ntt::runtime::thread_inout_desc *const output_descs)";
+ $"NTT_DEVICE void {VisitEntry.Name}({string.Concat(paramsExcluded.Select(Visit).Select(s => $"{s.Type} {s.Name}, ").ToArray())}const std::byte *rdata, const std::byte *thread_local_rdata, const std::byte *warp_local_rdata, const std::byte *block_local_rdata, std::byte *thread_local_data, std::byte *warp_local_data, std::byte *block_local_data, std::byte *output, nncase::ntt::runtime::thread_inout_desc *const output_descs)";
return new(
Declare: ctype + ";\n",
Kernel: CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()),
@@ -186,16 +186,18 @@ protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr)
{
(MemoryLocation.Rdata, 0) => "rdata",
(MemoryLocation.ThreadLocalRdata, 0) => "thread_local_rdata",
+ (MemoryLocation.WarpLocalRdata, 0) => "warp_local_rdata",
(MemoryLocation.BlockLocalRdata, 0) => "block_local_rdata",
(MemoryLocation.Data, 0) => "thread_local_data",
(MemoryLocation.Data, 1) => "thread_local_data",
+ (MemoryLocation.WarpLocalData, 0) => "warp_local_data",
(MemoryLocation.BlockLocalData, 0) => "block_local_data",
(MemoryLocation.Output, 0) => "output",
_ => throw new NotSupportedException($"{expr.Location}, {expr.Hierarchy}"),
};
var ptypeName = "std::byte";
- if (expr.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata or MemoryLocation.BlockLocalRdata)
+ if (expr.Location is MemoryLocation.Rdata or MemoryLocation.ThreadLocalRdata or MemoryLocation.WarpLocalRdata or MemoryLocation.BlockLocalRdata)
{
// Rdata, ThreadLocalRdata and BlockLocalRdata are const
ptypeName = $"const {ptypeName}";
@@ -205,12 +207,12 @@ protected override CSymbol VisitPhysicalBuffer(PhysicalBuffer expr)
if (expr.Size is DimConst)
{
var spanSize = (ulong)expr.Size.FixedValue;
- name = $"std::span<{ptypeName}, {spanSize}>({loc} + {start.Name}UL, {spanSize})";
+ name = $"ntt::span<{ptypeName}, {spanSize}>({loc} + {start.Name}UL, {spanSize})";
}
else
{
var spanSize = Visit(expr.Size).Name;
- name = $"std::span<{ptypeName}>({loc} + {start.Name}UL, {spanSize})";
+ name = $"ntt::span<{ptypeName}>({loc} + {start.Name}UL, {spanSize})";
}
symbol = new(start.Type, name);
@@ -471,7 +473,7 @@ protected override CSymbol VisitCall(Call expr)
WriteWithProfiler($"slice({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[3], local: true).Name}, {VisitDimOrShape(args[1]).Name}, {VisitDimOrShape(args[2]).Name}, fixed_dims_v<{string.Join(",", slice.Axes)}>, fixed_dims_v<{string.Join(",", slice.Strides)}>);\n");
break;
case TIR.NTT.Concat concat:
- WriteWithProfiler($"concat(std::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name}, {concat.Axis}_dim);\n");
+ WriteWithProfiler($"concat(ntt::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name}, {concat.Axis}_dim);\n");
break;
case TIR.NTT.Transpose transpose:
WriteWithProfiler($"transpose({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name}, fixed_dims_v<{string.Join(",", transpose.Perm)}>);\n");
@@ -555,7 +557,7 @@ protected override CSymbol VisitCall(Call expr)
WriteIndWithProfiler($"get_position_ids({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name}, {KernelUtility.ShardingToC(getPositionIds.DistributedType)}, {Visit(getPositionIds.DistributedType.TensorType.Shape).Name});\n");
break;
case TIR.NTT.Stack stack:
- IndentScope.Writer.Write($"stack<{stack.Axis}>(std::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name});\n");
+ IndentScope.Writer.Write($"stack<{stack.Axis}>(ntt::make_tuple({string.Join(",", args.SkipLast(1).Select(x => VisitBuffer(x, local: true)).Select(s => s.Name))}), {VisitBuffer(args[^1], local: true).Name});\n");
break;
case TIR.NTT.Reshape reshape:
IndentScope.Writer.Write($"reshape({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n");
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs
index 3239ac741..1fb2eebb8 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableFunction.cs
@@ -21,11 +21,14 @@ internal unsafe struct KernelDescHeader
[MarshalAs(UnmanagedType.U8)]
public ulong LocalDataPoolSize;
+ [MarshalAs(UnmanagedType.U8)]
+ public ulong WarpLocalDataPoolSize;
+
[MarshalAs(UnmanagedType.U8)]
public ulong BlockLocalDataPoolSize;
}
-internal sealed record KernelMemoryPoolDesc(ulong RdataPoolSize, ulong ThreadLocalRdataPoolSize, ulong BlockLocalRdataPoolSize);
+internal sealed record KernelMemoryPoolDesc(ulong RdataPoolSize, ulong ThreadLocalRdataPoolSize, ulong WarpLocalRdataPoolSize, ulong BlockLocalRdataPoolSize);
internal sealed class LinkableKernelFunction : ILinkableFunction
{
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs
index 5b25ad596..d6b881cb1 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkableModule.cs
@@ -17,20 +17,24 @@ namespace Nncase.CodeGen.NTT;
internal sealed class LinkableModule : ILinkableModule
{
+ private readonly string _moduleKind;
private readonly Stream _desc;
private readonly Stream _rdata;
private readonly IReadOnlyList _threadLocalRdatas;
private readonly IReadOnlyList _threadLocalCaches;
+ private readonly IReadOnlyList _warpLocalRdatas;
private readonly IReadOnlyList _blockLocalRdatas;
private readonly IReadOnlyList _functions;
private readonly NTTTargetOptions _targetOptions;
- public LinkableModule(Stream desc, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList blockLocalRdatas, IReadOnlyList functions, CompileOptions options)
+ public LinkableModule(string moduleKind, Stream desc, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList warpLocalRdatas, IReadOnlyList blockLocalRdatas, IReadOnlyList functions, CompileOptions options)
{
+ _moduleKind = moduleKind;
_desc = desc;
_rdata = rdata;
_threadLocalRdatas = threadLocalRdatas;
_threadLocalCaches = threadLocalCaches;
+ _warpLocalRdatas = warpLocalRdatas;
_blockLocalRdatas = blockLocalRdatas;
_functions = functions;
PublicFunctions = _functions.OfType().ToArray();
@@ -134,14 +138,15 @@ private void WriteModuleTopologyDef(string codegenDir)
{
using (var writer = new StreamWriter(fs))
{
- writer.Write(CSourceBuiltn.ModuleTopologyDef(_targetOptions));
+ writer.Write(CSourceBuiltn.ModuleTopologyDef(_targetOptions, isCUDA: _moduleKind == CUDATarget.Kind));
}
}
}
private void WriteThreadMain(string codegenDir, LinkableKernelFunction mainFunc, IReadOnlyList kernelFiles)
{
- using (var fs = File.Open(Path.Join(codegenDir, "thread_main.cpp"), FileMode.Create))
+ var threadMainExt = _moduleKind == CUDATarget.Kind ? "cu" : "cpp";
+ using (var fs = File.Open(Path.Join(codegenDir, $"thread_main.{threadMainExt}"), FileMode.Create))
{
using (var writer = new StreamWriter(fs))
{
@@ -173,7 +178,7 @@ private void WriteCMakeLists(string codegenDir)
{
using (var writer = new StreamWriter(fs))
{
- writer.Write(CSourceBuiltn.CMakeDef());
+ writer.Write(CSourceBuiltn.CMakeDef(isCUDA: _moduleKind == CUDATarget.Kind));
}
}
}
@@ -199,12 +204,12 @@ private ILinkedModule GenerateLinkedModule(string codegenDir, LinkableKernelFunc
var funcText = File.ReadAllBytes(elfPath);
textWriter.Write(funcText);
linkedFunctions.Add(new LinkedFunction(mainFunc.Id, mainFunc.SourceFunction, 0, (uint)funcText.Length, mainFunc.Sections));
- return new LinkedModule(linkedFunctions, _desc, manager.GetContent(WellknownSectionNames.Text)!, _rdata, _threadLocalRdatas, _threadLocalCaches, _blockLocalRdatas, rdataAlign);
+ return new LinkedModule(_moduleKind, linkedFunctions, _desc, manager.GetContent(WellknownSectionNames.Text)!, _rdata, _threadLocalRdatas, _threadLocalCaches, _warpLocalRdatas, _blockLocalRdatas, rdataAlign);
}
private string CompileCSource(string sourcePath)
{
- var compiler = new CSourceCompiler();
+ var compiler = new CSourceCompiler(_moduleKind == CUDATarget.Kind);
var binDir = Path.Join(sourcePath, "build", "nncase_ntt_module");
return compiler.Compile(sourcePath, binDir);
}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs
index 6165b08fc..827e4371a 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/LinkedModule.cs
@@ -17,21 +17,22 @@ internal unsafe struct ModuleDescHeader
public uint ThreadDim;
[MarshalAs(UnmanagedType.U4)]
- public uint BlockDim;
+ public uint WarpDim;
[MarshalAs(UnmanagedType.U4)]
- public uint ChipDim;
+ public uint BlockDim;
[MarshalAs(UnmanagedType.U4)]
- public uint Reserved0;
+ public uint ChipDim;
}
internal sealed class LinkedModule : ILinkedModule
{
public const string ModuleHeaderSectionName = ".desc";
- public unsafe LinkedModule(IReadOnlyList functions, Stream desc, Stream text, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList blockLocalRdatas, ulong rdataAlign)
+ public unsafe LinkedModule(string moduleKind, IReadOnlyList functions, Stream desc, Stream text, Stream rdata, IReadOnlyList threadLocalRdatas, IReadOnlyList threadLocalCaches, IReadOnlyList warpLocalRdatas, IReadOnlyList blockLocalRdatas, ulong rdataAlign)
{
+ ModuleKind = moduleKind;
Functions = functions;
Sections =
[
@@ -40,11 +41,12 @@ public unsafe LinkedModule(IReadOnlyList functions, Stream desc
new LinkedSection(rdata, WellknownSectionNames.Rdata, 0, (uint)rdataAlign, (ulong)rdata.Length),
new LinkedMultipleContentsSection(threadLocalRdatas, WellknownSectionNames.ThreadLocalRdata, 0, (uint)rdataAlign),
new LinkedMultipleContentsSection(threadLocalCaches, WellknownSectionNames.ThreadLocalCache, 0, (uint)rdataAlign),
+ new LinkedMultipleContentsSection(warpLocalRdatas, WellknownSectionNames.WarpLocalRdata, 0, (uint)rdataAlign),
new LinkedMultipleContentsSection(blockLocalRdatas, WellknownSectionNames.BlockLocalRdata, 0, (uint)rdataAlign),
];
}
- public string ModuleKind => "cpu";
+ public string ModuleKind { get; }
public uint Version => 0;
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs
index fe934826a..1d4fef416 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/ModuleBuilder.cs
@@ -17,22 +17,36 @@ public sealed class NTTModuleBuilder : IModuleBuilder
private readonly BinaryWriter _rdataWriter;
private readonly BinaryWriter[] _threadLocalRdataWriters;
private readonly BinaryWriter[] _threadLocalCacheWriters;
+ private readonly BinaryWriter[] _warpLocalRdataWriters;
private readonly BinaryWriter[] _blockLocalRdataWriters;
- public NTTModuleBuilder(CompileOptions options)
+ public NTTModuleBuilder(string moduleKind, CompileOptions options)
{
+ var targetOptions = (NTTTargetOptions)options.TargetOptions;
+ var hierarchies = targetOptions.Hierarchies[0];
+ ModuleKind = moduleKind;
_sectionManager = new();
_rdataWriter = _sectionManager.GetWriter(WellknownSectionNames.Rdata);
- var shardCount = TensorUtilities.GetProduct(((Targets.NTTTargetOptions)options.TargetOptions).Hierarchies[0]);
+
+ var shardCount = TensorUtilities.GetProduct(hierarchies);
_threadLocalRdataWriters = new BinaryWriter[shardCount];
_threadLocalCacheWriters = new BinaryWriter[shardCount];
- _blockLocalRdataWriters = new BinaryWriter[shardCount / ((Targets.NTTTargetOptions)options.TargetOptions).Hierarchies[0][^1]];
for (int i = 0; i < shardCount; i++)
{
_threadLocalRdataWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.ThreadLocalRdata, i);
_threadLocalCacheWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.ThreadLocalCache, i);
}
+ var isCUDA = ModuleKind == CUDATarget.Kind;
+ var warpsCount = isCUDA ? shardCount / hierarchies[^1] : 0;
+ _warpLocalRdataWriters = new BinaryWriter[warpsCount];
+ for (int i = 0; i < _warpLocalRdataWriters.Length; i++)
+ {
+ _warpLocalRdataWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.WarpLocalRdata, i);
+ }
+
+ var blocksCount = isCUDA ? (hierarchies.Length > 1 ? warpsCount / hierarchies[^2] : 1) : shardCount / hierarchies[^1];
+ _blockLocalRdataWriters = new BinaryWriter[blocksCount];
for (int i = 0; i < _blockLocalRdataWriters.Length; i++)
{
_blockLocalRdataWriters[i] = _sectionManager.GetWriter(WellknownSectionNames.BlockLocalRdata, i);
@@ -44,7 +58,7 @@ public NTTModuleBuilder(CompileOptions options)
public CompileOptions CompileOptions { get; }
///
- public string ModuleKind => "cpu";
+ public string ModuleKind { get; }
///
public ILinkableModule Build(IReadOnlyList functions)
@@ -55,9 +69,21 @@ public ILinkableModule Build(IReadOnlyList functions)
using (var writer = _sectionManager.GetWriter(LinkedModule.ModuleHeaderSectionName))
{
var header = default(ModuleDescHeader);
+ var hasWarp = targetOptions.HierarchyNames.Contains('w', StringComparison.Ordinal);
header.ThreadDim = (uint)targetOptions.Hierarchies[0][^1];
- header.BlockDim = targetOptions.Hierarchies[0].Length < 2 ? 1 : (uint)targetOptions.Hierarchies[0][^2];
- header.ChipDim = targetOptions.Hierarchies[0].Length < 3 ? 1 : (uint)targetOptions.Hierarchies[0][^3];
+ if (hasWarp)
+ {
+ header.WarpDim = targetOptions.Hierarchies[0].Length < 2 ? 1 : (uint)targetOptions.Hierarchies[0][^2];
+ header.BlockDim = targetOptions.Hierarchies[0].Length < 3 ? 1 : (uint)targetOptions.Hierarchies[0][^3];
+ header.ChipDim = targetOptions.Hierarchies[0].Length < 4 ? 1 : (uint)targetOptions.Hierarchies[0][^4];
+ }
+ else
+ {
+ header.WarpDim = 1;
+ header.BlockDim = targetOptions.Hierarchies[0].Length < 2 ? 1 : (uint)targetOptions.Hierarchies[0][^2];
+ header.ChipDim = targetOptions.Hierarchies[0].Length < 3 ? 1 : (uint)targetOptions.Hierarchies[0][^3];
+ }
+
writer.Write(ref header);
// cache offsets.
@@ -76,7 +102,7 @@ public ILinkableModule Build(IReadOnlyList functions)
}
}
- var linkableFunctions = functions.OfType().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter, _threadLocalRdataWriters, _blockLocalRdataWriters, (Targets.NTTTargetOptions)CompileOptions.TargetOptions).Build(f)).ToArray();
+ var linkableFunctions = functions.OfType().Select((f, i) => new FunctionBuilder((uint)i, _rdataWriter, _threadLocalRdataWriters, _warpLocalRdataWriters, _blockLocalRdataWriters, (Targets.NTTTargetOptions)CompileOptions.TargetOptions).Build(f)).ToArray();
_rdataWriter.Flush();
var threadLocalRdataContents = Enumerable.Range(0, _threadLocalRdataWriters.Length).Select(i =>
{
@@ -96,12 +122,18 @@ public ILinkableModule Build(IReadOnlyList functions)
return _sectionManager.GetContent(WellknownSectionNames.ThreadLocalCache, i)!;
}).ToArray();
+ var warpLocalRdataContents = Enumerable.Range(0, _warpLocalRdataWriters.Length).Select(i =>
+ {
+ _warpLocalRdataWriters[i].Flush();
+ return _sectionManager.GetContent(WellknownSectionNames.WarpLocalRdata, i)!;
+ }).ToArray();
+
var blockLocalRdataContents = Enumerable.Range(0, _blockLocalRdataWriters.Length).Select(i =>
{
_blockLocalRdataWriters[i].Flush();
return _sectionManager.GetContent(WellknownSectionNames.BlockLocalRdata, i)!;
}).ToArray();
- return new LinkableModule(_sectionManager.GetContent(LinkedModule.ModuleHeaderSectionName)!, _sectionManager.GetContent(WellknownSectionNames.Rdata)!, threadLocalRdataContents, threadLocalCacheContents, blockLocalRdataContents, linkableFunctions, CompileOptions);
+ return new LinkableModule(ModuleKind, _sectionManager.GetContent(LinkedModule.ModuleHeaderSectionName)!, _sectionManager.GetContent(WellknownSectionNames.Rdata)!, threadLocalRdataContents, threadLocalCacheContents, warpLocalRdataContents, blockLocalRdataContents, linkableFunctions, CompileOptions);
}
}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml
index 27199b31c..6b6b6a6fe 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/CMakeLists.txt.cshtml
@@ -2,7 +2,11 @@
cmake_minimum_required(VERSION 3.15)
-project(nncase_cpu_module)
+@if (Model.IsCUDA) {
+@:project(nncase_cpu_module CXX CUDA)
+} else {
+@:project(nncase_cpu_module CXX)
+}
option(BUILD_SHARED "Build shared library in linux" OFF)
option(BUILD_STANDALONE "Build standalone executable" OFF)
@@ -12,5 +16,9 @@ endif()
include(@Html.Raw(Model.CMakePath))
-target_sources(nncase_ntt_module PRIVATE thread_main.cpp)
-target_include_directories(nncase_ntt_module PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+@if (Model.IsCUDA) {
+@:target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE thread_main.cu)
+} else {
+@:target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE thread_main.cpp)
+}
+target_include_directories(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml
index 5c81a9b8c..57adcaaa0 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/Matmul.cshtml
@@ -8,18 +8,18 @@
{
@:@(Model.Indent)if (@Html.Raw(Model.Arguments[3].Symbol.Name)) {
// @:@(Model.Indent) constexpr std::string_view function_name = "matmul";
-// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device);
+// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device);
@:@(Model.Indent) matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @scale, fixed_shape_v<@string.Join(",", Model.Target.LhsVectorizedAxes)>, fixed_shape_v<>, fixed_shape_v<@string.Join(",", Model.Target.RhsVectorizedAxes)>, fixed_shape_v<>);
@:@(Model.Indent)} else {
// @:@(Model.Indent) constexpr std::string_view function_name = "matmul";
-// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device);
+// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device);
@:@(Model.Indent) matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @scale, fixed_shape_v<@string.Join(",", Model.Target.LhsVectorizedAxes)>, fixed_shape_v<>, fixed_shape_v<@string.Join(",", Model.Target.RhsVectorizedAxes)>, fixed_shape_v<>);
@:@(Model.Indent)}
}
else
{
// @:@(Model.Indent) constexpr std::string_view function_name = "matmul";
-// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device);
+// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device);
@:@(Model.Indent) matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @scale, fixed_shape_v<@string.Join(",", Model.Target.LhsVectorizedAxes)>, fixed_shape_v<>, fixed_shape_v<@string.Join(",", Model.Target.RhsVectorizedAxes)>, fixed_shape_v<>);
}
@(Model.Indent)}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml
index 51d6568ff..0b42768ae 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml
@@ -8,18 +8,18 @@
{
@:@(Model.Indent)if (@Html.Raw(Model.Arguments[3].Symbol.Name)) {
// @:@(Model.Indent) constexpr std::string_view function_name = "matmul";
-// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device);
+// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device);
@:@(Model.Indent) packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @(scale));
@:@(Model.Indent)} else {
// @:@(Model.Indent) constexpr std::string_view function_name = "matmul";
-// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device);
+// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device);
@:@(Model.Indent) packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @(scale));
@:@(Model.Indent)}
}
else
{
// @:@(Model.Indent) constexpr std::string_view function_name = "matmul";
-// @:@(Model.Indent) auto_profiler profiler(function_name, runtime::profiling_level::device);
+// @:@(Model.Indent) profile_scope profiler(function_name, profile_level::device);
@:@(Model.Indent) packed_matmul(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name), @Html.Raw(Model.Arguments[2].Symbol.Name), @(scale));
}
@(Model.Indent)}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml
index 58b3a43e0..83806efdc 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/module_topology_def.h.cshtml
@@ -1,14 +1,14 @@
@using System.Linq
@using NetFabric.Hyperlinq
@using Nncase
-@model Nncase.Targets.NTTTargetOptions
@{
- var hierarchy = Model.Hierarchies[0];
+ var hierarchy = (int[])Model.Hierarchies;
+ var topologyLevels = Model.IsCUDA ? 4 : 3;
}
#pragma once
#include
namespace nncase::ntt::distributed {
- constexpr auto topology_shape = ntt::fixed_shape_v<@(string.Join(", ", Enumerable.Repeat(1, 3 - hierarchy.Length).Concat(hierarchy)))>;
+ constexpr auto topology_shape = ntt::fixed_shape_v<@(string.Join(", ", Enumerable.Repeat(1, topologyLevels - hierarchy.Length).Concat(hierarchy)))>;
}
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml
index 51f00f036..5eb35ec89 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/thread_main.cpp.cshtml
@@ -7,12 +7,14 @@
var inputCount = Model.PrimFunction.Parameters.Length;
}
-extern "C" void thread_main(const nncase::ntt::runtime::thread_inout_desc *input_descs,
+extern "C" NTT_DEVICE void thread_main(const nncase::ntt::runtime::thread_inout_desc *input_descs,
nncase::ntt::runtime::thread_inout_desc *const output_descs,
const std::byte *rdata,
const std::byte *thread_local_rdata,
+ const std::byte *warp_local_rdata,
const std::byte *block_local_rdata,
std::byte *thread_local_data,
+ std::byte *warp_local_data,
std::byte *block_local_data,
std::byte *output) {
/* prepare inputs */
@@ -72,7 +74,7 @@ extern "C" void thread_main(const nncase::ntt::runtime::thread_inout_desc *input
throw new NotSupportedException($"not support multi form topology!");
}
- @(Model.PrimFunction.Name)(@(string.Concat(names.Select(x => $"{x}, ")))rdata, thread_local_rdata, block_local_rdata, thread_local_data, block_local_data, output, output_descs);
+ @(Model.PrimFunction.Name)(@(string.Concat(names.Select(x => $"{x}, ")))rdata, thread_local_rdata, warp_local_rdata, block_local_rdata, thread_local_data, warp_local_data, block_local_data, output, output_descs);
}
#ifdef NNCASE_STANDALONE
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml
index 0ad5b8837..c401fd887 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml
@@ -15,8 +15,7 @@
#pragma once
#include
-#include
-#include
+#include
/**
* @@brief topology aware runtime
@@ -33,22 +32,15 @@ namespace tar {
foreach (var i in comb) {
shape[i] = 1;
}
- var groupRawName = groupName + "_raw";
-@:std::barrier<> @(groupRawName)[@(groups)] {
- @for (int i = 0; i < groups; i++)
- {
- @:std::barrier(@(groupSize)),
- }
-@:};
-@:auto @(groupName) = nncase::ntt::make_tensor_view_from_address>(@(groupRawName), nncase::ntt::fixed_shape_v<@(string.Join(",", shape))>);
+@:NTT_DEVICE decltype(nncase::ntt::make_tensor>(nncase::ntt::fixed_shape_v<@(string.Join(",", shape))>)) @(groupName);
@:
}
@if (Model.CollectivePoolSize > 0) {
-@:alignas(@Model.Alignment) uint8_t collective_pool_ptr[@Model.CollectivePoolSize];
+@:alignas(@Model.Alignment) NTT_DEVICE uint8_t collective_pool_ptr[@Model.CollectivePoolSize];
} else {
-@:alignas(@Model.Alignment) uint8_t collective_pool_ptr[1];
+@:alignas(@Model.Alignment) NTT_DEVICE uint8_t collective_pool_ptr[1];
}
enum reduce_kind {
@@ -57,11 +49,11 @@ enum reduce_kind {
}
};
-constexpr std::array Hierarchy = {@(string.Join(", ", hierarchy))};
-auto src_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>);
-auto dest_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>);
+NTT_DEVICE constexpr ntt::array Hierarchy = {@(string.Join(", ", hierarchy))};
+NTT_DEVICE auto src_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>);
+NTT_DEVICE auto dest_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>);
-template static std::byte *get_cache_address() {
+template NTT_DEVICE static std::byte *get_cache_address() {
return reinterpret_cast(
ntt::distributed::detail::global_thread_local_cache_ptr(program_ids())(
Level));
@@ -77,7 +69,7 @@ namespace tac {
using namespace nncase;
template
-void tensor_boxing_load_sync(const GlobalShape &global_shape, const Index &index, TDst &dest)
+NTT_DEVICE void tensor_boxing_load_sync(const GlobalShape &global_shape, const Index &index, TDst &dest)
{
using TOutBase = std::decay_t;
using TElem = typename TOutBase::element_type;
@@ -87,7 +79,7 @@ void tensor_boxing_load_sync(const GlobalShape &global_shape, const Index &index
}
template
-void tensor_boxing_store_sync(const GlobalShape &global_shape, const Index &index, TSrc &src)
+NTT_DEVICE void tensor_boxing_store_sync(const GlobalShape &global_shape, const Index &index, TSrc &src)
{
using TSrcBase = std::decay_t;
using TElem = typename TSrcBase::element_type;
@@ -103,14 +95,14 @@ template class group_hierarchy_getter;
@:template <> class group_hierarchy_getter {
var shape = Enumerable.Range(0, hierarchy.Length).Select(i => comb.Contains(i) ? hierarchy[i] : 1).ToArray();
@:public:
-@: static constexpr auto group_hierarchy = ntt::fixed_shape_v<@(string.Join(", ", shape))>;
+@: NTT_DEVICE static constexpr auto group_hierarchy = ntt::fixed_shape_v<@(string.Join(", ", shape))>;
@:};
}
template
class tensor_reduce_sync_impl {
public:
- void reduce_group_sync() const noexcept {
+ NTT_DEVICE void reduce_group_sync() const noexcept {
@foreach(var comb in combinations) {
var reduce_group_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => comb.Contains(i) ? "0" : "ntt::distributed::" + hierarchyNames[i] + "id()"));
@:if constexpr (Kind == tar::reduce_kind::@(GetName(comb, string.Empty))) {
@@ -124,7 +116,7 @@ class tensor_reduce_sync_impl {
}
template
- constexpr auto index_group2global(const TIndexInGroup &index_in_group, const TIndexInGlobal &index_in_global) const noexcept {
+ NTT_DEVICE constexpr auto index_group2global(const TIndexInGroup &index_in_group, const TIndexInGlobal &index_in_global) const noexcept {
return ntt::generate_shape([&](auto axis) {
if constexpr (Kind & (1 << (TIndexInGlobal::rank() - axis))) {
return index_in_group[axis];
@@ -135,7 +127,7 @@ class tensor_reduce_sync_impl {
}
template
- constexpr auto index_global2group(const TIndexInGlobal &index_in_global) const noexcept {
+ NTT_DEVICE constexpr auto index_global2group(const TIndexInGlobal &index_in_global) const noexcept {
return ntt::generate_shape([&](auto axis) {
if constexpr (Kind & (1 << (TIndexInGlobal::rank() - axis))) {
return index_in_global[axis];
@@ -145,7 +137,7 @@ class tensor_reduce_sync_impl {
});
}
- static constexpr auto get_group_size() {
+ NTT_DEVICE static constexpr auto get_group_size() {
size_t group_size = 1;
for (size_t i = 1; i <= tar::Hierarchy.size(); i++) {
if (Kind & (1 << i)) {
@@ -160,7 +152,7 @@ class tensor_reduce_sync_impl {
}
template
- void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) {
+ NTT_DEVICE void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) {
if constexpr (Op == ntt::reduce_op::max) {
ntt::binary(local, remote, dest);
} else if constexpr (Op == ntt::reduce_op::sum ||
@@ -173,7 +165,7 @@ class tensor_reduce_sync_impl {
}
}
- template void operator()(TIn &src, TOut &&dest) {
+ template NTT_DEVICE void operator()(TIn &src, TOut &&dest) {
// collect all tensors pointer for access tensor from other nodes.
using TElem = typename TIn::element_type;
using TOutBase = std::decay_t;
@@ -289,7 +281,7 @@ class tensor_reduce_sync_impl {
} // namespace detail
template
-void tensor_reduce_sync(TIn &input, TOut &&output) {
+NTT_DEVICE void tensor_reduce_sync(TIn &input, TOut &&output) {
detail::tensor_reduce_sync_impl impl;
impl(input, output);
}
diff --git a/modules/Nncase.Modules.NTT/NTTModule.cs b/modules/Nncase.Modules.NTT/NTTModule.cs
index 9ba23e357..fe7497b94 100644
--- a/modules/Nncase.Modules.NTT/NTTModule.cs
+++ b/modules/Nncase.Modules.NTT/NTTModule.cs
@@ -15,5 +15,6 @@ internal class NTTModule : IApplicationPart
public void ConfigureServices(IRegistrator registrator)
{
registrator.Register(reuse: Reuse.Singleton);
+ registrator.Register(reuse: Reuse.Singleton);
}
}
diff --git a/modules/Nncase.Modules.NTT/Targets/NTTModuleCompiler.cs b/modules/Nncase.Modules.NTT/Targets/CPUModuleCompiler.cs
similarity index 93%
rename from modules/Nncase.Modules.NTT/Targets/NTTModuleCompiler.cs
rename to modules/Nncase.Modules.NTT/Targets/CPUModuleCompiler.cs
index 2585e965b..e0ab0705e 100644
--- a/modules/Nncase.Modules.NTT/Targets/NTTModuleCompiler.cs
+++ b/modules/Nncase.Modules.NTT/Targets/CPUModuleCompiler.cs
@@ -11,7 +11,7 @@
namespace Nncase.Targets;
-public class NTTModuleCompiler : IModuleCompiler
+public class CPUModuleCompiler : INTTModuleCompiler
{
public string ModuleKind => CPUTarget.Kind;
@@ -35,7 +35,7 @@ public class NTTModuleCompiler : IModuleCompiler
_ => throw new NotSupportedException($"Unsupported architecture: {RuntimeInformation.ProcessArchitecture}"),
};
- public IModuleBuilder CreateModuleBuilder(CompileOptions options) => new NTTModuleBuilder(options);
+ public IModuleBuilder CreateModuleBuilder(CompileOptions options) => new NTTModuleBuilder(ModuleKind, options);
public bool IsSupportedCall(Call call, CompileOptions options)
{
diff --git a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs
index 04312ec46..15e5e76ac 100644
--- a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs
+++ b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs
@@ -23,110 +23,15 @@
namespace Nncase.Targets;
///
-/// Target for NTT.
+/// Target for CPU.
///
-public class CPUTarget : Target
+public class CPUTarget : NTTTarget
{
public const string Kind = "cpu";
- private readonly NTTModuleCompiler _nttModuleCompiler = new();
-
public CPUTarget()
{
- ModuleCompilers = [_nttModuleCompiler];
- }
-
- public override string Name => Kind;
-
- public override IReadOnlyList ModuleCompilers { get; }
-
- public override (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser()
- {
- var cmd = new NTTTargetOptionsCommand(Kind);
-
- ITargetOptions ParseTargetCompileOptions(InvocationContext context, Command command)
- {
- var binder = new NTTTargetOptionsBinder(cmd);
- return binder.GetBoundValue(context);
- }
-
- return (cmd, ParseTargetCompileOptions);
- }
-
- public override void RegisterAffineSelectionPass(IPassManager passManager, CompileOptions options)
- {
- passManager.Add();
- }
-
- public override void RegisterAutoPackingRules(IRulesAddable pass, CompileOptions options)
- {
- var nr = _nttModuleCompiler.Nr;
-
- pass.Add(nr);
}
- public override void RegisterAutoVectorizeRules(IRulesAddable pass, CompileOptions options)
- {
- // todo config it in the target options.
- var rank = 1;
- var lane = _nttModuleCompiler.Lane;
- var maskVectorStyle = _nttModuleCompiler.MaskVectorStyle;
-
- pass.Add(rank, lane);
- pass.Add(rank, lane);
- pass.Add(rank, lane);
-
- pass.Add();
- pass.Add();
- pass.Add(maskVectorStyle);
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
-
- // pass.Add(rank, lane);
- pass.Add();
-
- // pass.Add(rank, lane);
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add(maskVectorStyle);
-
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
-
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- pass.Add();
- }
-
- public override void RegisterTIRSelectionPass(IPassManager passManager, CompileOptions optionsÍ)
- {
- passManager.Add();
- }
-
- public override void RegisterPostAutoVectorizePass(IPassManager passManager, CompileOptions options)
- {
- passManager.AddWithName("FoldPostOps").Configure(p =>
- {
- p.Add();
- p.Add();
- });
- }
+ protected override INTTModuleCompiler NTTModuleCompiler { get; } = new CPUModuleCompiler();
}
diff --git a/modules/Nncase.Modules.NTT/Targets/CUDAModuleCompiler.cs b/modules/Nncase.Modules.NTT/Targets/CUDAModuleCompiler.cs
new file mode 100644
index 000000000..534e03a41
--- /dev/null
+++ b/modules/Nncase.Modules.NTT/Targets/CUDAModuleCompiler.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using Nncase.CodeGen;
+using Nncase.CodeGen.NTT;
+using Nncase.IR;
+using Nncase.Passes;
+
+namespace Nncase.Targets;
+
+public class CUDAModuleCompiler : INTTModuleCompiler
+{
+ public string ModuleKind => CUDATarget.Kind;
+
+ public MaskVectorStyle MaskVectorStyle => MaskVectorStyle.Fat;
+
+ public int Lane => 16;
+
+ public int Nr => 4;
+
+ public IModuleBuilder CreateModuleBuilder(CompileOptions options) => new NTTModuleBuilder(ModuleKind, options);
+
+ public bool IsSupportedCall(Call call, CompileOptions options)
+ {
+ return call.Target switch
+ {
+ Op op => PassUtility.IsCpuSupported(op, call, call.Arguments, ModuleKind),
+ _ => false,
+ };
+ }
+}
diff --git a/modules/Nncase.Modules.NTT/Targets/CUDATarget.cs b/modules/Nncase.Modules.NTT/Targets/CUDATarget.cs
new file mode 100644
index 000000000..76fd11b39
--- /dev/null
+++ b/modules/Nncase.Modules.NTT/Targets/CUDATarget.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.CommandLine;
+using System.CommandLine.Invocation;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Options;
+using Nncase.CodeGen;
+using Nncase.CodeGen.NTT;
+using Nncase.IR;
+using Nncase.Passes;
+using Nncase.Passes.Rules.Neutral;
+using Nncase.Passes.Rules.ShapeBucket;
+using Nncase.Passes.Transforms;
+using Nncase.Quantization;
+
+namespace Nncase.Targets;
+
+///
+/// Target for CUDA.
+///
+public class CUDATarget : NTTTarget
+{
+ public const string Kind = "cuda";
+
+ public CUDATarget()
+ {
+ }
+
+ protected override INTTModuleCompiler NTTModuleCompiler { get; } = new CUDAModuleCompiler();
+}
diff --git a/modules/Nncase.Modules.NTT/Targets/INTTModuleCompiler.cs b/modules/Nncase.Modules.NTT/Targets/INTTModuleCompiler.cs
new file mode 100644
index 000000000..cd8d8cd7e
--- /dev/null
+++ b/modules/Nncase.Modules.NTT/Targets/INTTModuleCompiler.cs
@@ -0,0 +1,19 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using Nncase.CodeGen;
+using Nncase.CodeGen.NTT;
+using Nncase.IR;
+using Nncase.Passes;
+
+namespace Nncase.Targets;
+
+public interface INTTModuleCompiler : IModuleCompiler
+{
+ int Lane { get; }
+
+ int Nr { get; }
+}
diff --git a/modules/Nncase.Modules.NTT/Targets/NTTTarget.cs b/modules/Nncase.Modules.NTT/Targets/NTTTarget.cs
new file mode 100644
index 000000000..2dc8ca7ba
--- /dev/null
+++ b/modules/Nncase.Modules.NTT/Targets/NTTTarget.cs
@@ -0,0 +1,130 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.CommandLine;
+using System.CommandLine.Invocation;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Options;
+using Nncase.CodeGen;
+using Nncase.CodeGen.NTT;
+using Nncase.IR;
+using Nncase.Passes;
+using Nncase.Passes.Rules.Neutral;
+using Nncase.Passes.Rules.ShapeBucket;
+using Nncase.Passes.Transforms;
+using Nncase.Quantization;
+
+namespace Nncase.Targets;
+
+///
+/// Target for NTT.
+///
+public abstract class NTTTarget : Target
+{
+ public NTTTarget()
+ {
+ ModuleCompilers = [NTTModuleCompiler];
+ }
+
+ public override string Name => NTTModuleCompiler.ModuleKind;
+
+ public override IReadOnlyList ModuleCompilers { get; }
+
+ protected abstract INTTModuleCompiler NTTModuleCompiler { get; }
+
+ public override (System.CommandLine.Command Command, Func Parser) RegisterCommandAndParser()
+ {
+ var cmd = new NTTTargetOptionsCommand(Name);
+
+ ITargetOptions ParseTargetCompileOptions(InvocationContext context, Command command)
+ {
+ var binder = new NTTTargetOptionsBinder(cmd);
+ return binder.GetBoundValue(context);
+ }
+
+ return (cmd, ParseTargetCompileOptions);
+ }
+
+ public override void RegisterAffineSelectionPass(IPassManager passManager, CompileOptions options)
+ {
+ passManager.Add();
+ }
+
+ public override void RegisterAutoPackingRules(IRulesAddable pass, CompileOptions options)
+ {
+ var nr = NTTModuleCompiler.Nr;
+
+ pass.Add(nr);
+ }
+
+ public override void RegisterAutoVectorizeRules(IRulesAddable pass, CompileOptions options)
+ {
+ // todo config it in the target options.
+ var rank = 1;
+ var lane = NTTModuleCompiler.Lane;
+ var maskVectorStyle = NTTModuleCompiler.MaskVectorStyle;
+
+ pass.Add(rank, lane);
+ pass.Add(rank, lane);
+ pass.Add(rank, lane);
+
+ pass.Add();
+ pass.Add();
+ pass.Add(maskVectorStyle);
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+
+ // pass.Add(rank, lane);
+ pass.Add();
+
+ // pass.Add(rank, lane);
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add(maskVectorStyle);
+
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ pass.Add();
+ }
+
+ public override void RegisterTIRSelectionPass(IPassManager passManager, CompileOptions optionsÍ)
+ {
+ passManager.Add(NTTModuleCompiler.ModuleKind);
+ }
+
+ public override void RegisterPostAutoVectorizePass(IPassManager passManager, CompileOptions options)
+ {
+ passManager.AddWithName("FoldPostOps").Configure(p =>
+ {
+ p.Add();
+ p.Add();
+ });
+ }
+}
diff --git a/ntt/cmake/compile_flags.cmake b/ntt/cmake/compile_flags.cmake
index 0fac799dc..c9050db84 100644
--- a/ntt/cmake/compile_flags.cmake
+++ b/ntt/cmake/compile_flags.cmake
@@ -70,3 +70,6 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES
message(FATAL_ERROR "Unsupported riscv64 target")
endif()
endif()
+
+if (CMAKE_CUDA_COMPILER)
+endif()
diff --git a/ntt/cmake/ntt_module.cmake b/ntt/cmake/ntt_module.cmake
index 8b91c5cfb..0f9ab7db9 100644
--- a/ntt/cmake/ntt_module.cmake
+++ b/ntt/cmake/ntt_module.cmake
@@ -1,37 +1,105 @@
-cmake_minimum_required(VERSION 3.15)
+cmake_minimum_required(VERSION 3.16)
include(${CMAKE_CURRENT_LIST_DIR}/compile_flags.cmake)
-if (BUILD_STANDALONE)
- add_executable(nncase_ntt_module ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp)
+if (CMAKE_CUDA_COMPILER)
+ set(NNCASE_NTT_MODULE_TARGET_NAME nncase_ntt_module_bundle)
+
+ find_program(NVLINK nvlink REQUIRED)
+ find_program(FATBINARY fatbinary REQUIRED)
+ message(STATUS "Found nvlink: ${NVLINK}")
+ message(STATUS "Found fatbinary: ${FATBINARY}")
else()
- add_library(nncase_ntt_module SHARED ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp)
+ set(NNCASE_NTT_MODULE_TARGET_NAME nncase_ntt_module)
endif()
-target_compile_features(nncase_ntt_module PUBLIC cxx_std_20)
-target_include_directories(nncase_ntt_module PUBLIC ${CMAKE_CURRENT_LIST_DIR}/../include)
-set_target_properties(nncase_ntt_module PROPERTIES PREFIX "" SUFFIX "")
-set_target_properties(nncase_ntt_module PROPERTIES POSITION_INDEPENDENT_CODE ON)
-target_compile_definitions(nncase_ntt_module PUBLIC -DNNCASE_CPU_MODULE=1)
-set_property(TARGET nncase_ntt_module PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
+if (BUILD_STANDALONE)
+ add_executable(${NNCASE_NTT_MODULE_TARGET_NAME} ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp)
+elseif (CMAKE_CUDA_COMPILER)
+ add_library(${NNCASE_NTT_MODULE_TARGET_NAME} OBJECT)
+else()
+ add_library(${NNCASE_NTT_MODULE_TARGET_NAME} SHARED ${CMAKE_CURRENT_LIST_DIR}/../src/dummy.cpp)
+endif()
-target_sources(nncase_ntt_module PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../src/cpu_runtime.cpp)
+target_compile_features(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC cxx_std_20)
+target_include_directories(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR}/../include)
+set_target_properties(${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTIES PREFIX "" SUFFIX "")
+set_target_properties(${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET ${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
if (BUILD_STANDALONE)
- target_compile_definitions(nncase_ntt_module PUBLIC -DNNCASE_STANDALONE=1)
+ target_compile_definitions(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC -DNNCASE_STANDALONE=1)
endif()
if (MSVC)
- set_property(TARGET nncase_ntt_module PROPERTY
+ set_property(TARGET ${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTY
MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>")
- set_target_properties(nncase_ntt_module PROPERTIES LINK_FLAGS /SUBSYSTEM:CONSOLE)
- target_link_options(nncase_ntt_module PRIVATE /NODEFAULTLIB)
- target_link_libraries(nncase_ntt_module PRIVATE "libvcruntime$<$:d>"
+ set_target_properties(${NNCASE_NTT_MODULE_TARGET_NAME} PROPERTIES LINK_FLAGS /SUBSYSTEM:CONSOLE)
+ target_link_options(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE /NODEFAULTLIB)
+ target_link_libraries(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE "libvcruntime$<$:d>"
"msvcrt$<$:d>"
"ucrt$<$:d>"
"libcpmt$<$:d>")
elseif(APPLE)
- target_link_options(nncase_ntt_module PRIVATE -ld_classic -lc)
+ target_link_options(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE -ld_classic -lc)
+else()
+ target_link_libraries(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE pthread)
+endif()
+
+if (CMAKE_CUDA_COMPILER)
+ target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../src/cuda_runtime.cu)
+ target_compile_definitions(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC -DNNCASE_CUDA_MODULE=1)
+ target_compile_options(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE $<$:
+ -fgpu-rdc
+ --cuda-device-only
+ >)
+
+ foreach(arch ${CMAKE_CUDA_ARCHITECTURES})
+ # Link device code for this architecture
+ set(linked_obj "${CMAKE_CURRENT_BINARY_DIR}/linked_sm_${arch}.o")
+ add_custom_command(
+ OUTPUT ${linked_obj}
+ COMMAND ${NVLINK}
+ -arch=sm_${arch}
+ $
+ -o ${linked_obj}
+ DEPENDS ${NNCASE_NTT_MODULE_TARGET_NAME} $
+ COMMAND_EXPAND_LISTS
+ VERBATIM
+ COMMENT "Linking device code for sm_${arch}"
+ )
+
+ # Add to the list of all linked objects
+ list(APPEND ALL_LINKED_OBJECTS ${linked_obj})
+ endforeach()
+
+ add_custom_target(device_link ALL
+ DEPENDS ${ALL_LINKED_OBJECTS}
+ )
+
+ set(FATBIN_ARGS "")
+ foreach(arch ${CMAKE_CUDA_ARCHITECTURES})
+ # Find the linked object for this architecture
+ set(arch_obj "${CMAKE_CURRENT_BINARY_DIR}/linked_sm_${arch}.o")
+ list(APPEND FATBIN_ARGS --image3=kind=elf,sm=${arch},file="${arch_obj}")
+ endforeach()
+
+ # Create the fatbinary from all linked objects
+ add_custom_command(
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/nncase_ntt_module
+ COMMAND ${FATBINARY}
+ -64
+ --create ${CMAKE_CURRENT_BINARY_DIR}/nncase_ntt_module
+ ${FATBIN_ARGS}
+ DEPENDS ${ALL_LINKED_OBJECTS}
+ COMMENT "Creating fatbinary from linked objects"
+ VERBATIM
+ )
+
+ add_custom_target(fatbin ALL
+ DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/nncase_ntt_module
+ )
else()
- target_link_libraries(nncase_ntt_module PRIVATE pthread)
+ target_sources(${NNCASE_NTT_MODULE_TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../src/cpu_runtime.cpp)
+ target_compile_definitions(${NNCASE_NTT_MODULE_TARGET_NAME} PUBLIC -DNNCASE_CPU_MODULE=1)
endif()
diff --git a/ntt/include/nncase/bfloat16.h b/ntt/include/nncase/bfloat16.h
index 1051c31e1..8c614eb63 100644
--- a/ntt/include/nncase/bfloat16.h
+++ b/ntt/include/nncase/bfloat16.h
@@ -43,11 +43,11 @@ struct bfloat16 {
constexpr operator __bf16() const noexcept {
return std::bit_cast<__bf16>(value_);
}
-// #else
-// constexpr operator float() const noexcept {
-// uint32_t value = raw() << 16;
-// return std::bit_cast(value);
-// }
+ // #else
+ // constexpr operator float() const noexcept {
+ // uint32_t value = raw() << 16;
+ // return std::bit_cast(value);
+ // }
#endif
@@ -133,7 +133,6 @@ struct bfloat16 {
return uint64_t(double(*this));
}
-
constexpr explicit operator uint8_t() const noexcept {
return uint8_t(float(*this));
}
@@ -142,7 +141,6 @@ struct bfloat16 {
return int8_t(float(*this));
}
-
constexpr explicit operator int16_t() const noexcept {
return int16_t(float(*this));
}
@@ -151,7 +149,6 @@ struct bfloat16 {
return uint16_t(float(*this));
}
-
constexpr explicit operator bool() const noexcept {
return bool(std::bit_cast(*this));
}
@@ -162,8 +159,6 @@ struct bfloat16 {
: nan();
}
-
-
static constexpr bfloat16 epsilon() noexcept {
// 0x1.0p-7
return from_raw(0x3c00);
@@ -199,12 +194,14 @@ struct bfloat16 {
};
#define DEFINE_BF16_BINARY_BF16RET(x) \
- inline bfloat16 operator x(bfloat16 a, bfloat16 b) noexcept { \
+ NTT_ALWAYS_INLINE constexpr bfloat16 operator x(bfloat16 a, \
+ bfloat16 b) noexcept { \
return bfloat16::round_to_bfloat16(float(a) x float(b)); \
}
#define DEFINE_BF16_BINARY_BOOLRET(x) \
- inline bool operator x(bfloat16 a, bfloat16 b) noexcept { \
+ NTT_ALWAYS_INLINE constexpr bool operator x(bfloat16 a, \
+ bfloat16 b) noexcept { \
return float(a) x float(b); \
}
@@ -218,7 +215,8 @@ DEFINE_BF16_BINARY_BOOLRET(>=)
DEFINE_BF16_BINARY_BOOLRET(>)
#define DEFINE_BF16_BINARY_SELF_MOD(x, op) \
- inline bfloat16 &operator x(bfloat16 & a, bfloat16 b) noexcept { \
+ NTT_ALWAYS_INLINE constexpr bfloat16 &operator x(bfloat16 &a, \
+ bfloat16 b) noexcept { \
a = a op b; \
return a; \
}
@@ -228,15 +226,17 @@ DEFINE_BF16_BINARY_SELF_MOD(-=, -)
DEFINE_BF16_BINARY_SELF_MOD(*=, *)
DEFINE_BF16_BINARY_SELF_MOD(/=, /)
-inline bfloat16 operator-(bfloat16 a) noexcept {
+NTT_ALWAYS_INLINE constexpr bfloat16 operator-(bfloat16 a) noexcept {
return bfloat16::round_to_bfloat16(-float(a));
}
-inline bool operator==(const bfloat16 &lhs, const bfloat16 &rhs) noexcept {
+NTT_ALWAYS_INLINE constexpr bool operator==(const bfloat16 &lhs,
+ const bfloat16 &rhs) noexcept {
return lhs.raw() == rhs.raw();
}
-inline bool operator!=(const bfloat16 &lhs, const bfloat16 &rhs) noexcept {
+NTT_ALWAYS_INLINE constexpr bool operator!=(const bfloat16 &lhs,
+ const bfloat16 &rhs) noexcept {
return lhs.raw() != rhs.raw();
}
} // namespace nncase
@@ -305,67 +305,76 @@ template <> struct numeric_limits {
};
using nncase::bfloat16;
-inline bool isinf(const bfloat16 &a) { return std::isinf(float(a)); }
-inline bool isnan(const bfloat16 &a) { return std::isnan(float(a)); }
-inline bool isfinite(const bfloat16 &a) { return std::isfinite(float(a)); }
-inline bfloat16 abs(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isinf(const bfloat16 &a) {
+ return std::isinf(float(a));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isnan(const bfloat16 &a) {
+ return std::isnan(float(a));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isfinite(const bfloat16 &a) {
+ return std::isfinite(float(a));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 abs(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(fabsf(float(a)));
}
-inline bfloat16 acos(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 acos(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(std::acos(float(a)));
}
-inline bfloat16 asin(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 asin(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(std::asin(float(a)));
}
-inline bfloat16 erf(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 erf(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(std::erff(float(a)));
}
-inline bfloat16 exp(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 exp(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(expf(float(a)));
}
-inline bfloat16 log(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 log(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(logf(float(a)));
}
-inline bfloat16 log10(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 log10(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(log10f(float(a)));
}
-inline bfloat16 sqrt(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 sqrt(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(sqrtf(float(a)));
}
-inline bfloat16 pow(const bfloat16 &a, const bfloat16 &b) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 pow(const bfloat16 &a,
+ const bfloat16 &b) {
return bfloat16::round_to_bfloat16(powf(float(a), float(b)));
}
-inline bfloat16 sin(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 sin(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(sinf(float(a)));
}
-inline bfloat16 cos(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 cos(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(cosf(float(a)));
}
-inline bfloat16 tan(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 tan(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(tanf(float(a)));
}
-inline bfloat16 tanh(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 tanh(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(tanhf(float(a)));
}
-inline bfloat16 floor(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 floor(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(floorf(float(a)));
}
-inline bfloat16 ceil(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 ceil(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(ceilf(float(a)));
}
-inline bfloat16 round(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 round(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(roundf(float(a)));
}
-inline bfloat16 nearbyint(const bfloat16 &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bfloat16 nearbyint(const bfloat16 &a) {
return bfloat16::round_to_bfloat16(nearbyintf(float(a)));
}
-inline long lrint(const bfloat16 &a) { return lrintf(float(a)); }
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE long lrint(const bfloat16 &a) {
+ return lrintf(float(a));
+}
template <> struct is_arithmetic : public true_type {};
} // namespace std
-inline nncase::bfloat16 operator"" _bf16(long double x) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE nncase::bfloat16
+operator""_bf16(long double x) {
return nncase::bfloat16(float(x));
}
-
diff --git a/ntt/include/nncase/float8.h b/ntt/include/nncase/float8.h
index c6043c93f..7a36f16e8 100644
--- a/ntt/include/nncase/float8.h
+++ b/ntt/include/nncase/float8.h
@@ -36,7 +36,7 @@
*/
#pragma once
-#if defined(__GNUC__) && defined(__x86_64__)
+#if defined(__GNUC__) && defined(__x86_64__) && !defined(__clang__)
#pragma GCC optimize("no-strict-aliasing")
#endif
@@ -81,11 +81,12 @@
// #include
// #include "nncase/nncase.h"
+#include "ntt/compiler_defs.h"
#include "bfloat16.h"
#include "half.h"
#ifndef CUTLASS_HOST_DEVICE
-#define CUTLASS_HOST_DEVICE inline
-#define CUTLASS_DEVICE inline
+#define CUTLASS_HOST_DEVICE NTT_HOST_DEVICE inline
+#define CUTLASS_DEVICE NTT_DEVICE inline
#endif // !CUTLASS_HOST_DEVICE
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1384,22 +1385,22 @@ struct numeric_limits
//
CUTLASS_HOST_DEVICE
-nncase::float_e4m3_t operator"" _fe4m3(long double x) {
+nncase::float_e4m3_t operator""_fe4m3(long double x) {
return nncase::float_e4m3_t(float(x));
}
CUTLASS_HOST_DEVICE
-nncase::float_e4m3_t operator"" _fe4m3(unsigned long long int x) {
+nncase::float_e4m3_t operator""_fe4m3(unsigned long long int x) {
return nncase::float_e4m3_t(int(x));
}
CUTLASS_HOST_DEVICE
-nncase::float_e5m2_t operator"" _fe5m2(long double x) {
+nncase::float_e5m2_t operator""_fe5m2(long double x) {
return nncase::float_e5m2_t(float(x));
}
CUTLASS_HOST_DEVICE
-nncase::float_e5m2_t operator"" _fe5m2(unsigned long long int x) {
+nncase::float_e5m2_t operator""_fe5m2(unsigned long long int x) {
return nncase::float_e5m2_t(int(x));
}
diff --git a/ntt/include/nncase/half.h b/ntt/include/nncase/half.h
index 7d1a3e0b2..f09826c43 100644
--- a/ntt/include/nncase/half.h
+++ b/ntt/include/nncase/half.h
@@ -24,11 +24,19 @@
#include
#include
-#ifdef __F16C__
+#ifdef __CUDA_ARCH__
+#include
+#elif defined(__F16C__)
#include
#endif
namespace nncase {
+#ifdef __CUDA_ARCH__
+using native_half_t = __half;
+#else
+using native_half_t = _Float16;
+#endif
+
struct fp16_from_raw_t {
explicit fp16_from_raw_t() = default;
};
@@ -44,7 +52,7 @@ struct half {
public:
constexpr half() noexcept = default;
- constexpr half(_Float16 v) noexcept : value_(v) {}
+ constexpr half(native_half_t v) noexcept : value_(v) {}
template ::value ||
@@ -54,9 +62,11 @@ struct half {
static constexpr half round_to_half(float v) {
if (std::is_constant_evaluated()) {
- return (_Float16)v;
+ return (native_half_t)v;
} else {
-#ifdef __F16C__
+#ifdef __CUDA_ARCH__
+ return __float2half_rn(v);
+#elif defined(__F16C__)
// To avoid truncsfhf2
return from_raw(_cvtss_sh(v, _MM_FROUND_NEARBYINT));
#else
@@ -64,7 +74,7 @@ struct half {
#endif
}
- return (_Float16)v;
+ return (native_half_t)v;
}
static constexpr half epsilon() noexcept { return from_raw(0x0800); }
@@ -91,14 +101,16 @@ struct half {
: value_(round_to_half(float(x)).value_) {}
constexpr half(fp16_from_raw_t, uint16_t value) noexcept
- : value_(std::bit_cast<_Float16>(value)) {}
+ : value_(std::bit_cast(value)) {}
- constexpr operator _Float16() const noexcept { return value_; }
+ constexpr operator native_half_t() const noexcept { return value_; }
constexpr operator float() const noexcept {
if (std::is_constant_evaluated()) {
return (float)value_;
} else {
-#ifdef __F16C__
+#ifdef __CUDA_ARCH__
+ return __half2float(value_);
+#elif defined(__F16C__)
// To avoid extendhfdf2
return _cvtsh_ss(raw());
#else
@@ -177,7 +189,7 @@ struct half {
}
private:
- _Float16 value_;
+ native_half_t value_;
};
#define DEFINE_FP16_BINARY_FP16RET(x) \
@@ -216,7 +228,7 @@ DEFINE_FP16_BINARY_BOOLRET(>=)
DEFINE_FP16_BINARY_BOOLRET(>)
#define DEFINE_FP16_BINARY_SELF_MOD(x, op) \
- NTT_ALWAYS_INLINE half &operator x(half & a, half b) noexcept { \
+ NTT_ALWAYS_INLINE half &operator x(half &a, half b) noexcept { \
a = a op b; \
return a; \
}
@@ -242,7 +254,8 @@ inline std::ostream &operator<<(std::ostream &os, const half &a) {
os << std::to_string(float(a));
return os;
}
-inline half nextafter(const half &from, const half &to) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half nextafter(const half &from,
+ const half &to) {
if (from.raw() == to.raw()) {
return to;
}
@@ -365,48 +378,76 @@ template <> struct numeric_limits {
};
using nncase::half;
-inline bool isinf(const half &a) { return std::isinf((float)(a)); }
-inline bool isnan(const half &a) { return std::isnan(float(a)); }
-inline bool isfinite(const half &a) { return std::isfinite(float(a)); }
-inline half abs(const half &a) { return half::round_to_half(fabsf(float(a))); }
-inline half fabs(const half &a) { return half::round_to_half(fabs(float(a))); }
-inline half exp(const half &a) { return half::round_to_half(expf(float(a))); }
-inline half log(const half &a) { return half::round_to_half(logf(float(a))); }
-inline half log10(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isinf(const half &a) {
+ return std::isinf((float)(a));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isnan(const half &a) {
+ return std::isnan(float(a));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE bool isfinite(const half &a) {
+ return std::isfinite(float(a));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half abs(const half &a) {
+ return half::round_to_half(fabsf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half fabs(const half &a) {
+ return half::round_to_half(fabs(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half exp(const half &a) {
+ return half::round_to_half(expf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half log(const half &a) {
+ return half::round_to_half(logf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half log10(const half &a) {
return half::round_to_half(log10f(float(a)));
}
-inline half sqrt(const half &a) { return half::round_to_half(sqrtf(float(a))); }
-inline half sin(const half &a) { return half::round_to_half(sinf(float(a))); }
-inline half cos(const half &a) { return half::round_to_half(cosf(float(a))); }
-inline half tan(const half &a) { return half::round_to_half(tanf(float(a))); }
-inline half tanh(const half &a) { return half::round_to_half(tanh(float(a))); }
-inline half floor(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half sqrt(const half &a) {
+ return half::round_to_half(sqrtf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half sin(const half &a) {
+ return half::round_to_half(sinf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half cos(const half &a) {
+ return half::round_to_half(cosf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half tan(const half &a) {
+ return half::round_to_half(tanf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half tanh(const half &a) {
+ return half::round_to_half(tanh(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half floor(const half &a) {
return half::round_to_half(floorf(float(a)));
}
-inline half ceil(const half &a) { return half::round_to_half(ceilf(float(a))); }
-inline half round(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half ceil(const half &a) {
+ return half::round_to_half(ceilf(float(a)));
+}
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half round(const half &a) {
return half::round_to_half(roundf(float(a)));
}
-inline half nearbyint(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half nearbyint(const half &a) {
return half::round_to_half(nearbyintf(float(a)));
}
-inline half acos(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half acos(const half &a) {
return half::round_to_half(std::acos(float(a)));
}
-inline half asin(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half asin(const half &a) {
return half::round_to_half(std::asin(float(a)));
}
-inline half cosh(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half cosh(const half &a) {
return half::round_to_half(std::cosh(float(a)));
}
-inline half sinh(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half sinh(const half &a) {
return half::round_to_half(std::sinh(float(a)));
}
-inline half erf(const half &a) {
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE half erf(const half &a) {
return half::round_to_half(std::erff(float(a)));
}
-inline long lrint(const half &a) { return lrintf(float(a)); }
+NTT_ALWAYS_INLINE NTT_HOST_DEVICE long lrint(const half &a) {
+ return lrintf(float(a));
+}
template <> struct is_floating_point : public std::true_type {};
template <> struct is_arithmetic : public true_type {};
-} // namespace std
\ No newline at end of file
+} // namespace std
diff --git a/ntt/include/nncase/ntt/apply.h b/ntt/include/nncase/ntt/apply.h
index 12677aae4..5192e0ddc 100644
--- a/ntt/include/nncase/ntt/apply.h
+++ b/ntt/include/nncase/ntt/apply.h
@@ -29,7 +29,7 @@ template &index, Offsets offsets,
const Shape &shape, const TTile &tile, Callable &&callable,
- const std::tuple &strides) {
+ const ntt::tuple &strides) {
auto call = [&](std::index_sequence) {
if constexpr (sizeof...(Strides)) {
callable(index, offsets[fixed_dim_v]...);
@@ -47,7 +47,7 @@ apply_impl(dynamic_shape_t &index, Offsets offsets,
std::forward(callable), strides);
}
ntt::loop([&](auto i) {
- offsets[i] += std::get(strides)[fixed_dim_v] *
+ offsets[i] += ntt::get(strides)[fixed_dim_v] *
tile[fixed_dim_v];
});
}
@@ -62,7 +62,7 @@ NTT_ALWAYS_INLINE constexpr void apply(const TShape &shape, Callable &&callable,
detail::apply_impl<0>(index, make_repeat_shape(0),
shape, make_ones_shape(),
std::forward(callable),
- std::forward_as_tuple(strides...));
+ ntt::forward_as_tuple(strides...));
} else {
if constexpr (sizeof...(TStrides)) {
callable(fixed_shape_v<>, (strides, (dim_t)0)...);
@@ -80,7 +80,7 @@ apply_tiled(const TShape &shape, const TTile &tile, Callable &&callable,
dynamic_shape_t index{};
detail::apply_impl<0>(index, make_repeat_shape(0),
shape, tile, std::forward(callable),
- std::forward_as_tuple(strides...));
+ ntt::forward_as_tuple(strides...));
} else {
if constexpr (sizeof...(TStrides)) {
callable(fixed_shape_v<>, (strides, (dim_t)0)...);
diff --git a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h
index f339dabd1..b9c7a3538 100644
--- a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h
+++ b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h
@@ -34,12 +34,13 @@ extern decltype(nncase::ntt::make_tensor>(
template TLocalProgramIds,
ScopedProgramIds TRemoteProgramIds>
-static auto get_remote_address(const TLocalProgramIds &local_program_ids,
- const TRemoteProgramIds &remote_program_ids,
- T *local_address) {
+auto get_remote_address(const TLocalProgramIds &local_program_ids,
+ const TRemoteProgramIds &remote_program_ids,
+ T *local_address) {
auto start = (size_t)global_local_data_ptr(local_program_ids)(0_dim);
auto end = (size_t)global_local_data_ptr(local_program_ids)(1_dim);
- auto remote_address = (size_t)global_local_data_ptr(remote_program_ids)(0_dim);
+ auto remote_address =
+ (size_t)global_local_data_ptr(remote_program_ids)(0_dim);
if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) {
start = (size_t)global_thread_local_rdata_ptr(local_program_ids)(0_dim);
end = (size_t)global_thread_local_rdata_ptr(local_program_ids)(1_dim);
@@ -47,7 +48,8 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids,
(size_t)global_thread_local_rdata_ptr(remote_program_ids)(0_dim);
if ((uintptr_t)local_address < start ||
(uintptr_t)local_address >= end) {
- start = (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim);
+ start =
+ (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim);
remote_address =
(size_t)global_block_local_rdata_ptr(remote_program_ids)(0_dim);
}
diff --git a/ntt/include/nncase/ntt/arch/cpu/runtime.h b/ntt/include/nncase/ntt/arch/cpu/runtime.h
index 902c0f3f1..7a2b524bb 100644
--- a/ntt/include/nncase/ntt/arch/cpu/runtime.h
+++ b/ntt/include/nncase/ntt/arch/cpu/runtime.h
@@ -15,234 +15,13 @@
#pragma once
#include "../../profiling.h"
#include "../../runtime.h"
-#include
#include
-#include
-#include
-#include
-#include
-#include
-#include
#ifdef __APPLE__
#include
#endif
namespace nncase::ntt::runtime {
-
-struct record_id {
- int cid = -1;
- int bid = -1;
- int tid = -1;
-};
-
-class timer_record : public nncase::ntt::runtime::timer_record_base {
- public:
- bool is_valid() const override {
- return instance_id_.cid != -1 && instance_id_.bid != -1 &&
- instance_id_.tid != -1;
- }
-
- void set_time(std::string_view function_name, uint64_t start_time,
- uint64_t end_time) override {
- auto &stats = function_stats_[function_name];
- stats.calls.push_back({start_time, end_time});
- stats.call_count++;
- stats.total_time += end_time - start_time;
- }
-
- void set_level(std::string_view filename, profiling_level level) override {
- auto &stats = function_stats_[filename];
- stats.level = level;
- }
-
- // print statistics
- void console_print() const override {
-
- if (is_valid()) {
-
- std::cout << "\033[34m\n"
- << "Core Id:" << instance_id_.cid
- << ", Block Id:" << instance_id_.bid
- << ", Thread Id:" << instance_id_.tid << "\033[0m\n";
- std::cout << "\033[34mStatistics for NTT kernels. \033[0m\n";
- for (const auto &[name, stats] : function_stats_) {
- std::cout << "Function: " << name << "\n";
- std::cout << "Level: " << ntt::runtime::to_string(stats.level)
- << "\n";
- std::cout << "\tCalls: " << stats.call_count << "\n";
- std::cout << "\tTotal time: " << stats.total_time
- << " microseconds\n";
- uint64_t call_count = 0;
- for (const auto &call : stats.calls) {
- std::cout << "\t\t"
- << "Call " << call_count++ << ": \n";
- std::cout << "\t\tStart time: " << call.start_time
- << " microseconds\n";
- std::cout << "\t\tEnd time: " << call.end_time
- << " microseconds\n";
- std::cout
- << "\t\tDuration: " << call.end_time - call.start_time
- << " microseconds\n";
- }
- }
- }
- }
-
- void csv_print(std::string_view filename) const override {
- if (is_valid()) {
- std::ofstream csv_file(filename.data());
- if (!csv_file.is_open()) {
- std::cerr << "Failed to open file: " << filename << std::endl;
- return;
- }
-
- csv_file
- << "Core Id,Block Id,Thread Id,Function,Level,Calls,Total Time "
- "(microseconds),Call Index,Start Time (microseconds),End "
- "Time (microseconds),Duration (microseconds)\n";
-
- for (const auto &[name, stats] : function_stats_) {
- uint64_t call_count = 0;
- for (const auto &call : stats.calls) {
- csv_file << instance_id_.cid << "," << instance_id_.bid
- << "," << instance_id_.tid << "," << name << ","
- << ntt::runtime::to_string(stats.level) << ","
- << stats.call_count << "," << stats.total_time
- << "," << call_count++ << "," << call.start_time
- << "," << call.end_time << ","
- << (call.end_time - call.start_time) << "\n";
- }
- }
-
- csv_file.close();
- }
- }
-
- void markdown_print(std::string_view filename) const override {
-
- if (is_valid()) {
- std::ofstream md_file(filename.data());
- if (!md_file.is_open()) {
- std::cerr << "Failed to open file: " << filename << std::endl;
- return;
- }
-
- md_file << "### Core Information\n";
- md_file << "| Core Id | Block Id | Thread Id |\n";
- md_file << "|---------|----------|-----------|\n";
- md_file << "| " << instance_id_.cid << " | " << instance_id_.bid
- << " | " << instance_id_.tid << " |\n";
-
- md_file << "\n### NTT Kernels Statistics\n";
-
- for (const auto &[name, stats] : function_stats_) {
- md_file << "#### Function: " << name << "\n";
- md_file << "| Level | Calls | Total Time (microseconds) |\n";
- md_file << "|-------|-------|---------------------------|\n";
- md_file << "| " << ntt::runtime::to_string(stats.level) << " | "
- << stats.call_count << " | " << stats.total_time
- << " |\n";
-
- md_file << "\n**Call Details:**\n";
- md_file << "| Call Index | Start Time (microseconds) | End "
- "Time (microseconds) | Duration (microseconds) |\n";
- md_file << "|------------|---------------------------|---------"
- "----------------|-------------------------|\n";
-
- uint64_t call_count = 0;
- for (const auto &call : stats.calls) {
- md_file << "| " << call_count++ << " | " << call.start_time
- << " | " << call.end_time << " | "
- << (call.end_time - call.start_time) << " |\n";
- }
- md_file << "\n";
- }
-
- md_file.close();
- }
- }
-
- void json_print(std::string_view filename) const override {
-
- if (is_valid()) {
- std::ofstream json_file(filename.data());
- if (!json_file.is_open()) {
- std::cerr << "Failed to open file: " << filename << std::endl;
- return;
- }
-
- std::string pid = "\"cid: " + std::to_string(instance_id_.cid) +
- ", bid: " + std::to_string(instance_id_.bid) +
- "\"";
- std::string tid =
- "\"tid: " + std::to_string(instance_id_.tid) + "\"";
- json_file << "[\n";
-
- bool first = true;
- for (const auto &[name, stats] : function_stats_) {
- for (const auto &call : stats.calls) {
- if (stats.level == profiling_level::kernel) {
- if (!first) {
- json_file << ",\n";
- }
- first = false;
- json_file << " {\n";
- json_file << " \"name\": \"" << name << "\",\n";
- json_file << " \"ph\": \"X\",\n";
- json_file << " \"ts\": " << call.start_time << ",\n";
- json_file << " \"dur\": "
- << (call.end_time - call.start_time) << ",\n";
- json_file << " \"pid\": " << pid << ",\n";
- json_file << " \"tid\": " << tid << ",\n";
- json_file << " \"args\": { \"level:\":\""
- << ntt::runtime::to_string(stats.level)
- << " \"}\n";
- json_file << " }";
- }
- }
- }
-
- for (const auto &[name, stats] : function_stats_) {
- for (const auto &call : stats.calls) {
- if (stats.level == profiling_level::device) {
- if (!first) {
- json_file << ",\n";
- }
- first = false;
- json_file << " {\n";
- json_file << " \"name\": \"" << name << "\",\n";
- json_file << " \"ph\": \"X\",\n";
- json_file << " \"ts\": " << call.start_time << ",\n";
- json_file << " \"dur\": "
- << (call.end_time - call.start_time) << ",\n";
- json_file << " \"pid\": " << pid << ",\n";
- json_file << " \"tid\": " << tid << ",\n";
- json_file << " \"args\": { \"level:\":\""
- << ntt::runtime::to_string(stats.level)
- << " \"}\n";
- json_file << " }";
- }
- }
- }
-
- json_file << "\n]\n";
- json_file.close();
- }
- }
-
- timer_record() = default;
-
- ~timer_record() {
- console_print();
- markdown_print("nncase_profiling.md");
- csv_print("nncase_profiling.csv");
- json_print("nncase_profiling.json");
- }
-
- void set_id(record_id id) override { instance_id_ = id; }
-};
-
struct cpu_block_entry_params_t {
size_t tdim;
size_t bdim;
@@ -250,12 +29,11 @@ struct cpu_block_entry_params_t {
size_t bid;
size_t cid;
size_t cpu_id_offset;
+ uint8_t enable_profiling;
const thread_inout_desc *input_descs;
thread_inout_desc *const output_descs;
std::span rdata;
std::byte *output;
- uint8_t enable_profiling;
- timer_record *timer_records;
const uint64_t *thread_local_rdata_header;
const uint64_t *thread_local_cache_header;
std::span thread_local_rdata;
@@ -264,6 +42,8 @@ struct cpu_block_entry_params_t {
std::span block_local_rdata;
std::span thread_local_data;
std::span block_local_data;
+ std::span profile_records;
+ uint32_t *profile_record_counts;
#ifdef __APPLE__
pthread_key_t cpu_thread_context_key;
#endif
@@ -273,8 +53,9 @@ struct cpu_thread_context_t {
size_t tid;
size_t bid;
size_t cid;
- timer_record *timer_records;
uint8_t enable_profiling;
+ std::span profile_records;
+ uint32_t *profile_record_counts;
static cpu_thread_context_t ¤t() noexcept;
};
diff --git a/ntt/include/nncase/ntt/arch/cuda/distributed.h b/ntt/include/nncase/ntt/arch/cuda/distributed.h
new file mode 100644
index 000000000..0aeaee9ca
--- /dev/null
+++ b/ntt/include/nncase/ntt/arch/cuda/distributed.h
@@ -0,0 +1,17 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "remote_tensor.h"
+#include "topology.h"
diff --git a/ntt/include/nncase/ntt/arch/cpu/profiling.h b/ntt/include/nncase/ntt/arch/cuda/profiling.h
similarity index 74%
rename from ntt/include/nncase/ntt/arch/cpu/profiling.h
rename to ntt/include/nncase/ntt/arch/cuda/profiling.h
index a5022fa17..7105f0f2e 100644
--- a/ntt/include/nncase/ntt/arch/cpu/profiling.h
+++ b/ntt/include/nncase/ntt/arch/cuda/profiling.h
@@ -28,16 +28,16 @@ namespace nncase::ntt {
// static nncase::ntt::runtime::timer_record
// timer_records[CHIP_COUNTER][BLOCK_COUNTER][THREAD_COUNTER];
-// auto_profiler, start timing and end timing
-class auto_profiler {
+// profile_scope, start timing and end timing
+class profile_scope {
public:
- inline uint64_t get_current_time() const {
+ __device__ inline uint64_t get_current_time() const {
return std::chrono::duration_cast(
std::chrono::high_resolution_clock::now().time_since_epoch())
.count();
}
- auto_profiler(std::string_view function_name)
+ __device__ profile_scope(std::string_view function_name)
: cid_(program_id()),
bid_(program_id()),
tid_(program_id()) {
@@ -50,15 +50,15 @@ class auto_profiler {
}
}
- auto_profiler(std::string_view function_name,
- runtime::profiling_level level)
- : auto_profiler(function_name) { // 调用另一个构造函数
+ __device__ profile_scope(std::string_view function_name,
+ profile_level level)
+ : profile_scope(function_name) { // 调用另一个构造函数
if (enable_profiling_) {
level_ = level; // 设置 level
}
}
- ~auto_profiler() {
+ __device__ ~profile_scope() {
if (enable_profiling_) {
timer_storage_->set_id({cid_, bid_, tid_});
end_time_ = get_current_time();
@@ -74,16 +74,17 @@ class auto_profiler {
int cid_;
int bid_;
int tid_;
- nncase::ntt::runtime::profiling_level level_;
+ nncase::ntt::profile_level level_;
nncase::ntt::runtime::timer_record *timer_storage_;
bool enable_profiling_;
- inline bool get_profiler_option() noexcept {
- return runtime::cpu_thread_context_t::current().enable_profiling;
+ __device__ inline bool get_profiler_option() noexcept {
+ return runtime::cuda_thread_context_t::current().enable_profiling;
}
- inline nncase::ntt::runtime::timer_record *get_timer_record() noexcept {
- return runtime::cpu_thread_context_t::current().timer_records;
+ __device__ inline nncase::ntt::runtime::timer_record *
+ get_timer_record() noexcept {
+ return runtime::cuda_thread_context_t::current().timer_records;
}
};
diff --git a/ntt/include/nncase/ntt/arch/cuda/remote_tensor.h b/ntt/include/nncase/ntt/arch/cuda/remote_tensor.h
new file mode 100644
index 000000000..06e644b89
--- /dev/null
+++ b/ntt/include/nncase/ntt/arch/cuda/remote_tensor.h
@@ -0,0 +1,81 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../../distributed/remote_tensor.h"
+#include "../../tensor.h"
+#include "../../vector.h"
+
+namespace nncase::ntt::distributed {
+namespace detail {
+extern __device__ decltype(nncase::ntt::make_tensor<
+ nncase::ntt::vector>(
+ nncase::ntt::distributed::topology_shape)) global_local_data_ptr;
+
+extern __device__ decltype(nncase::ntt::make_tensor<
+ nncase::ntt::vector>(
+ nncase::ntt::distributed::topology_shape)) global_thread_local_rdata_ptr;
+
+extern __device__ decltype(nncase::ntt::make_tensor<
+ nncase::ntt::vector>(
+ nncase::ntt::distributed::topology_shape)) global_thread_local_cache_ptr;
+
+extern __device__ decltype(nncase::ntt::make_tensor<
+ nncase::ntt::vector