From 48f54b32d0ae0ed80a4c823daf76f882c21220a3 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 9 Apr 2026 21:13:49 -0700 Subject: [PATCH] support cooperative matrix --- docs/amber_script.md | 1 + samples/config_helper_vulkan.cc | 24 +++++++++++++++++++ samples/config_helper_vulkan.h | 2 ++ src/amberscript/parser_device_feature_test.cc | 8 ++++--- src/script.cc | 1 + src/vulkan/device.cc | 16 +++++++++++++ 6 files changed, 49 insertions(+), 3 deletions(-) diff --git a/docs/amber_script.md b/docs/amber_script.md index db93bd4b..59d20e4d 100644 --- a/docs/amber_script.md +++ b/docs/amber_script.md @@ -41,6 +41,7 @@ with: * `Float16Int8Features.shaderInt8` * `VulkanMemoryModelFeatures.vulkanMemoryModel` * `VulkanMemoryModelFeatures.vulkanMemoryModelDeviceScope` + * `CooperativeMatrixFeaturesKHR.cooperativeMatrix` * `ZeroInitializeWorkgroupMemoryFeatures.shaderZeroInitializeWorkgroupMemory` * `Storage8BitFeatures.storageBuffer8BitAccess` * `Storage8BitFeatures.uniformAndStorageBuffer8BitAccess` diff --git a/samples/config_helper_vulkan.cc b/samples/config_helper_vulkan.cc index c8b8b015..c21655c8 100644 --- a/samples/config_helper_vulkan.cc +++ b/samples/config_helper_vulkan.cc @@ -86,6 +86,8 @@ const char kDepthClampZeroOne[] = "DepthClampZeroOneFeatures.depthClampZeroOne"; const char kShaderSubgroupExtendedTypes[] = "ShaderSubgroupExtendedTypesFeatures.shaderSubgroupExtendedTypes"; +const char kCooperativeMatrix[] = + "CooperativeMatrixFeaturesKHR.cooperativeMatrix"; const char kAccelerationStructure[] = "AccelerationStructureFeaturesKHR.accelerationStructure"; @@ -912,6 +914,8 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements( supports_.depth_clamp_zero_one = true; } else if (ext == VK_KHR_SHADER_SUBGROUP_EXTENDED_TYPES_EXTENSION_NAME) { supports_.shader_subgroup_extended_types = true; + } else if (ext == VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME) { + supports_.cooperative_matrix = true; } else if (ext == VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME) { supports_.variable_pointers = true; } else if (ext == VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME) { @@ -954,6 +958,7 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements( VkPhysicalDevice8BitStorageFeaturesKHR storage_8bit_features = {}; VkPhysicalDevice16BitStorageFeaturesKHR storage_16bit_features = {}; VkPhysicalDeviceVulkanMemoryModelFeatures memory_model_structure_features{}; + VkPhysicalDeviceCooperativeMatrixFeaturesKHR cooperative_matrix_features{}; VkPhysicalDeviceZeroInitializeWorkgroupMemoryFeaturesKHR zero_initialize_workgroup_memory_features{}; #ifdef VK_EXT_SHADER_LONG_VECTOR_EXTENSION_NAME @@ -993,6 +998,13 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements( memory_model_structure_features.pNext = next_ptr; next_ptr = &memory_model_structure_features; + if (supports_.cooperative_matrix) { + cooperative_matrix_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + cooperative_matrix_features.pNext = next_ptr; + next_ptr = &cooperative_matrix_features; + } + zero_initialize_workgroup_memory_features.sType = // NOLINTNEXTLINE(whitespace/line_length) VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ZERO_INITIALIZE_WORKGROUP_MEMORY_FEATURES_KHR; @@ -1080,6 +1092,10 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements( supports_.depth_clamp_zero_one = depth_clamp_zero_one_features.depthClampZeroOne; } + if (supports_.cooperative_matrix) { + supports_.cooperative_matrix = + cooperative_matrix_features.cooperativeMatrix; + } std::vector required_features1; for (const auto& feature : required_features) { @@ -1103,6 +1119,8 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements( (feature == kVulkanMemoryModel_vulkanMemoryModelDeviceScope && memory_model_structure_features.vulkanMemoryModelDeviceScope == VK_FALSE) || + (feature == kCooperativeMatrix && + cooperative_matrix_features.cooperativeMatrix == VK_FALSE) || (feature == // NOLINTNEXTLINE(whitespace/line_length) kZeroInitializeWorkgroupMemory_shaderZeroInitializeWorkgroupMemory && @@ -1424,6 +1442,12 @@ amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures2( VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DEPTH_CLAMP_ZERO_ONE_FEATURES_EXT, VK_EXT_DEPTH_CLAMP_ZERO_ONE_EXTENSION_NAME); features_.depth_clamp_zero_one.depthClampZeroOne = VK_TRUE; + } else if (feature == kCooperativeMatrix) { + init_feature( + supports_.cooperative_matrix, features_.cooperative_matrix, + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR, + VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME); + features_.cooperative_matrix.cooperativeMatrix = VK_TRUE; } else if (feature == kAccelerationStructure) { init_feature( supports_.acceleration_structure, features_.acceleration_structure, diff --git a/samples/config_helper_vulkan.h b/samples/config_helper_vulkan.h index fd143238..27054bba 100644 --- a/samples/config_helper_vulkan.h +++ b/samples/config_helper_vulkan.h @@ -123,6 +123,7 @@ class ConfigHelperVulkan : public ConfigHelperImpl { bool subgroup_size_control = false; bool depth_clamp_zero_one = false; bool shader_subgroup_extended_types = false; + bool cooperative_matrix = false; bool acceleration_structure = false; bool buffer_device_address = false; bool ray_tracing_pipeline = false; @@ -146,6 +147,7 @@ class ConfigHelperVulkan : public ConfigHelperImpl { VkPhysicalDeviceDepthClampZeroOneFeaturesEXT depth_clamp_zero_one{}; VkPhysicalDeviceShaderSubgroupExtendedTypesFeatures shader_subgroup_extended_types{}; + VkPhysicalDeviceCooperativeMatrixFeaturesKHR cooperative_matrix{}; VkPhysicalDeviceAccelerationStructureFeaturesKHR acceleration_structure{}; VkPhysicalDeviceBufferDeviceAddressFeatures buffer_device_address{}; VkPhysicalDeviceRayTracingPipelineFeaturesKHR ray_tracing_pipeline{}; diff --git a/src/amberscript/parser_device_feature_test.cc b/src/amberscript/parser_device_feature_test.cc index f7c26254..e25b8636 100644 --- a/src/amberscript/parser_device_feature_test.cc +++ b/src/amberscript/parser_device_feature_test.cc @@ -37,6 +37,7 @@ DEVICE_FEATURE SubgroupSizeControl.subgroupSizeControl DEVICE_FEATURE SubgroupSizeControl.computeFullSubgroups DEVICE_FEATURE VulkanMemoryModelFeatures.vulkanMemoryModel DEVICE_FEATURE VulkanMemoryModelFeatures.vulkanMemoryModelDeviceScope +DEVICE_FEATURE CooperativeMatrixFeaturesKHR.cooperativeMatrix DEVICE_FEATURE ZeroInitializeWorkgroupMemoryFeatures.shaderZeroInitializeWorkgroupMemory DEVICE_FEATURE ShaderLongVectorFeaturesEXT.longVector)"; @@ -46,7 +47,7 @@ DEVICE_FEATURE ShaderLongVectorFeaturesEXT.longVector)"; auto script = parser.GetScript(); const auto& features = script->GetRequiredFeatures(); - ASSERT_EQ(17U, features.size()); + ASSERT_EQ(18U, features.size()); EXPECT_EQ("vertexPipelineStoresAndAtomics", features[0]); EXPECT_EQ("VariablePointerFeatures.variablePointersStorageBuffer", features[1]); @@ -66,10 +67,11 @@ DEVICE_FEATURE ShaderLongVectorFeaturesEXT.longVector)"; EXPECT_EQ("VulkanMemoryModelFeatures.vulkanMemoryModel", features[13]); EXPECT_EQ("VulkanMemoryModelFeatures.vulkanMemoryModelDeviceScope", features[14]); + EXPECT_EQ("CooperativeMatrixFeaturesKHR.cooperativeMatrix", features[15]); EXPECT_EQ("ZeroInitializeWorkgroupMemoryFeatures." "shaderZeroInitializeWorkgroupMemory", - features[15]); - EXPECT_EQ("ShaderLongVectorFeaturesEXT.longVector", features[16]); + features[16]); + EXPECT_EQ("ShaderLongVectorFeaturesEXT.longVector", features[17]); } TEST_F(AmberScriptParserTest, DeviceFeatureMissingFeature) { diff --git a/src/script.cc b/src/script.cc index e6e4b4cd..0611801b 100644 --- a/src/script.cc +++ b/src/script.cc @@ -140,6 +140,7 @@ bool Script::IsKnownFeature(const std::string& name) const { name == "ShaderSubgroupExtendedTypesFeatures" ".shaderSubgroupExtendedTypes" || + name == "CooperativeMatrixFeaturesKHR.cooperativeMatrix" || name == "RayTracingPipelineFeaturesKHR.rayTracingPipeline" || name == "AccelerationStructureFeaturesKHR.accelerationStructure" || name == "BufferDeviceAddressFeatures.bufferDeviceAddress" || diff --git a/src/vulkan/device.cc b/src/vulkan/device.cc index 73d5afca..fd419e9e 100644 --- a/src/vulkan/device.cc +++ b/src/vulkan/device.cc @@ -100,6 +100,8 @@ const char kShaderSubgroupExtendedTypes[] = "ShaderSubgroupExtendedTypesFeatures.shaderSubgroupExtendedTypes"; const char kIndexTypeUint8[] = "IndexTypeUint8Features.indexTypeUint8"; +const char kCooperativeMatrix[] = + "CooperativeMatrixFeaturesKHR.cooperativeMatrix"; const char kAccelerationStructure[] = "AccelerationStructureFeaturesKHR.accelerationStructure"; @@ -574,6 +576,8 @@ Result Device::Initialize( VkPhysicalDeviceShaderSubgroupExtendedTypesFeatures* shader_subgroup_extended_types_ptrs = nullptr; VkPhysicalDeviceIndexTypeUint8FeaturesEXT* index_type_uint8_ptrs = nullptr; + VkPhysicalDeviceCooperativeMatrixFeaturesKHR* cooperative_matrix_ptrs = + nullptr; VkPhysicalDeviceAccelerationStructureFeaturesKHR* acceleration_structure_ptrs = nullptr; VkPhysicalDeviceBufferDeviceAddressFeatures* bda_ptrs = nullptr; @@ -618,6 +622,10 @@ Result Device::Initialize( index_type_uint8_ptrs = static_cast(ptr); break; + case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR: + cooperative_matrix_ptrs = + static_cast(ptr); + break; // NOLINTNEXTLINE(whitespace/line_length) case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ACCELERATION_STRUCTURE_FEATURES_KHR: acceleration_structure_ptrs = @@ -703,6 +711,10 @@ Result Device::Initialize( depth_clamp_zero_one_features->depthClampZeroOne != VK_TRUE)) { return amber::Result("Depth clamp zero one requested but not returned"); } + if (feature == kCooperativeMatrix && cooperative_matrix_ptrs == nullptr) { + return amber::Result( + "Cooperative matrix requested but feature not returned"); + } if (feature == kAccelerationStructure) { if (acceleration_structure_ptrs == nullptr) { return amber::Result( @@ -921,6 +933,10 @@ Result Device::Initialize( return amber::Result( "Index type uint8_t requested but feature not returned"); } + if (feature == kCooperativeMatrix && + cooperative_matrix_ptrs->cooperativeMatrix != VK_TRUE) { + return amber::Result("Missing cooperative matrix feature"); + } } if (!AreAllExtensionsSupported(available_extensions,