@@ -86,6 +86,8 @@ const char kDepthClampZeroOne[] = "DepthClampZeroOneFeatures.depthClampZeroOne";
8686
8787const char kShaderSubgroupExtendedTypes [] =
8888 " ShaderSubgroupExtendedTypesFeatures.shaderSubgroupExtendedTypes" ;
89+ const char kCooperativeMatrix [] =
90+ " CooperativeMatrixFeaturesKHR.cooperativeMatrix" ;
8991
9092const char kAccelerationStructure [] =
9193 " AccelerationStructureFeaturesKHR.accelerationStructure" ;
@@ -912,6 +914,8 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
912914 supports_.depth_clamp_zero_one = true ;
913915 } else if (ext == VK_KHR_SHADER_SUBGROUP_EXTENDED_TYPES_EXTENSION_NAME) {
914916 supports_.shader_subgroup_extended_types = true ;
917+ } else if (ext == VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME) {
918+ supports_.cooperative_matrix = true ;
915919 } else if (ext == VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME) {
916920 supports_.variable_pointers = true ;
917921 } else if (ext == VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME) {
@@ -954,6 +958,7 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
954958 VkPhysicalDevice8BitStorageFeaturesKHR storage_8bit_features = {};
955959 VkPhysicalDevice16BitStorageFeaturesKHR storage_16bit_features = {};
956960 VkPhysicalDeviceVulkanMemoryModelFeatures memory_model_structure_features{};
961+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR cooperative_matrix_features{};
957962 VkPhysicalDeviceZeroInitializeWorkgroupMemoryFeaturesKHR
958963 zero_initialize_workgroup_memory_features{};
959964#ifdef VK_EXT_SHADER_LONG_VECTOR_EXTENSION_NAME
@@ -993,6 +998,13 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
993998 memory_model_structure_features.pNext = next_ptr;
994999 next_ptr = &memory_model_structure_features;
9951000
1001+ if (supports_.cooperative_matrix ) {
1002+ cooperative_matrix_features.sType =
1003+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
1004+ cooperative_matrix_features.pNext = next_ptr;
1005+ next_ptr = &cooperative_matrix_features;
1006+ }
1007+
9961008 zero_initialize_workgroup_memory_features.sType =
9971009 // NOLINTNEXTLINE(whitespace/line_length)
9981010 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ZERO_INITIALIZE_WORKGROUP_MEMORY_FEATURES_KHR;
@@ -1080,6 +1092,10 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
10801092 supports_.depth_clamp_zero_one =
10811093 depth_clamp_zero_one_features.depthClampZeroOne ;
10821094 }
1095+ if (supports_.cooperative_matrix ) {
1096+ supports_.cooperative_matrix =
1097+ cooperative_matrix_features.cooperativeMatrix ;
1098+ }
10831099
10841100 std::vector<std::string> required_features1;
10851101 for (const auto & feature : required_features) {
@@ -1103,6 +1119,8 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
11031119 (feature == kVulkanMemoryModel_vulkanMemoryModelDeviceScope &&
11041120 memory_model_structure_features.vulkanMemoryModelDeviceScope ==
11051121 VK_FALSE) ||
1122+ (feature == kCooperativeMatrix &&
1123+ cooperative_matrix_features.cooperativeMatrix == VK_FALSE) ||
11061124 (feature ==
11071125 // NOLINTNEXTLINE(whitespace/line_length)
11081126 kZeroInitializeWorkgroupMemory_shaderZeroInitializeWorkgroupMemory &&
@@ -1424,6 +1442,12 @@ amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures2(
14241442 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DEPTH_CLAMP_ZERO_ONE_FEATURES_EXT,
14251443 VK_EXT_DEPTH_CLAMP_ZERO_ONE_EXTENSION_NAME);
14261444 features_.depth_clamp_zero_one .depthClampZeroOne = VK_TRUE;
1445+ } else if (feature == kCooperativeMatrix ) {
1446+ init_feature (
1447+ supports_.cooperative_matrix , features_.cooperative_matrix ,
1448+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR,
1449+ VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME);
1450+ features_.cooperative_matrix .cooperativeMatrix = VK_TRUE;
14271451 } else if (feature == kAccelerationStructure ) {
14281452 init_feature (
14291453 supports_.acceleration_structure , features_.acceleration_structure ,
0 commit comments