Skip to content

Commit 48f54b3

Browse files
support cooperative matrix
1 parent 53a4c89 commit 48f54b3

6 files changed

Lines changed: 49 additions & 3 deletions

File tree

docs/amber_script.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ with:
4141
* `Float16Int8Features.shaderInt8`
4242
* `VulkanMemoryModelFeatures.vulkanMemoryModel`
4343
* `VulkanMemoryModelFeatures.vulkanMemoryModelDeviceScope`
44+
* `CooperativeMatrixFeaturesKHR.cooperativeMatrix`
4445
* `ZeroInitializeWorkgroupMemoryFeatures.shaderZeroInitializeWorkgroupMemory`
4546
* `Storage8BitFeatures.storageBuffer8BitAccess`
4647
* `Storage8BitFeatures.uniformAndStorageBuffer8BitAccess`

samples/config_helper_vulkan.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ const char kDepthClampZeroOne[] = "DepthClampZeroOneFeatures.depthClampZeroOne";
8686

8787
const char kShaderSubgroupExtendedTypes[] =
8888
"ShaderSubgroupExtendedTypesFeatures.shaderSubgroupExtendedTypes";
89+
const char kCooperativeMatrix[] =
90+
"CooperativeMatrixFeaturesKHR.cooperativeMatrix";
8991

9092
const char kAccelerationStructure[] =
9193
"AccelerationStructureFeaturesKHR.accelerationStructure";
@@ -912,6 +914,8 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
912914
supports_.depth_clamp_zero_one = true;
913915
} else if (ext == VK_KHR_SHADER_SUBGROUP_EXTENDED_TYPES_EXTENSION_NAME) {
914916
supports_.shader_subgroup_extended_types = true;
917+
} else if (ext == VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME) {
918+
supports_.cooperative_matrix = true;
915919
} else if (ext == VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME) {
916920
supports_.variable_pointers = true;
917921
} else if (ext == VK_KHR_ACCELERATION_STRUCTURE_EXTENSION_NAME) {
@@ -954,6 +958,7 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
954958
VkPhysicalDevice8BitStorageFeaturesKHR storage_8bit_features = {};
955959
VkPhysicalDevice16BitStorageFeaturesKHR storage_16bit_features = {};
956960
VkPhysicalDeviceVulkanMemoryModelFeatures memory_model_structure_features{};
961+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR cooperative_matrix_features{};
957962
VkPhysicalDeviceZeroInitializeWorkgroupMemoryFeaturesKHR
958963
zero_initialize_workgroup_memory_features{};
959964
#ifdef VK_EXT_SHADER_LONG_VECTOR_EXTENSION_NAME
@@ -993,6 +998,13 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
993998
memory_model_structure_features.pNext = next_ptr;
994999
next_ptr = &memory_model_structure_features;
9951000

1001+
if (supports_.cooperative_matrix) {
1002+
cooperative_matrix_features.sType =
1003+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
1004+
cooperative_matrix_features.pNext = next_ptr;
1005+
next_ptr = &cooperative_matrix_features;
1006+
}
1007+
9961008
zero_initialize_workgroup_memory_features.sType =
9971009
// NOLINTNEXTLINE(whitespace/line_length)
9981010
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ZERO_INITIALIZE_WORKGROUP_MEMORY_FEATURES_KHR;
@@ -1080,6 +1092,10 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
10801092
supports_.depth_clamp_zero_one =
10811093
depth_clamp_zero_one_features.depthClampZeroOne;
10821094
}
1095+
if (supports_.cooperative_matrix) {
1096+
supports_.cooperative_matrix =
1097+
cooperative_matrix_features.cooperativeMatrix;
1098+
}
10831099

10841100
std::vector<std::string> required_features1;
10851101
for (const auto& feature : required_features) {
@@ -1103,6 +1119,8 @@ amber::Result ConfigHelperVulkan::CheckVulkanPhysicalDeviceRequirements(
11031119
(feature == kVulkanMemoryModel_vulkanMemoryModelDeviceScope &&
11041120
memory_model_structure_features.vulkanMemoryModelDeviceScope ==
11051121
VK_FALSE) ||
1122+
(feature == kCooperativeMatrix &&
1123+
cooperative_matrix_features.cooperativeMatrix == VK_FALSE) ||
11061124
(feature ==
11071125
// NOLINTNEXTLINE(whitespace/line_length)
11081126
kZeroInitializeWorkgroupMemory_shaderZeroInitializeWorkgroupMemory &&
@@ -1424,6 +1442,12 @@ amber::Result ConfigHelperVulkan::CreateDeviceWithFeatures2(
14241442
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DEPTH_CLAMP_ZERO_ONE_FEATURES_EXT,
14251443
VK_EXT_DEPTH_CLAMP_ZERO_ONE_EXTENSION_NAME);
14261444
features_.depth_clamp_zero_one.depthClampZeroOne = VK_TRUE;
1445+
} else if (feature == kCooperativeMatrix) {
1446+
init_feature(
1447+
supports_.cooperative_matrix, features_.cooperative_matrix,
1448+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR,
1449+
VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME);
1450+
features_.cooperative_matrix.cooperativeMatrix = VK_TRUE;
14271451
} else if (feature == kAccelerationStructure) {
14281452
init_feature(
14291453
supports_.acceleration_structure, features_.acceleration_structure,

samples/config_helper_vulkan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class ConfigHelperVulkan : public ConfigHelperImpl {
123123
bool subgroup_size_control = false;
124124
bool depth_clamp_zero_one = false;
125125
bool shader_subgroup_extended_types = false;
126+
bool cooperative_matrix = false;
126127
bool acceleration_structure = false;
127128
bool buffer_device_address = false;
128129
bool ray_tracing_pipeline = false;
@@ -146,6 +147,7 @@ class ConfigHelperVulkan : public ConfigHelperImpl {
146147
VkPhysicalDeviceDepthClampZeroOneFeaturesEXT depth_clamp_zero_one{};
147148
VkPhysicalDeviceShaderSubgroupExtendedTypesFeatures
148149
shader_subgroup_extended_types{};
150+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR cooperative_matrix{};
149151
VkPhysicalDeviceAccelerationStructureFeaturesKHR acceleration_structure{};
150152
VkPhysicalDeviceBufferDeviceAddressFeatures buffer_device_address{};
151153
VkPhysicalDeviceRayTracingPipelineFeaturesKHR ray_tracing_pipeline{};

src/amberscript/parser_device_feature_test.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ DEVICE_FEATURE SubgroupSizeControl.subgroupSizeControl
3737
DEVICE_FEATURE SubgroupSizeControl.computeFullSubgroups
3838
DEVICE_FEATURE VulkanMemoryModelFeatures.vulkanMemoryModel
3939
DEVICE_FEATURE VulkanMemoryModelFeatures.vulkanMemoryModelDeviceScope
40+
DEVICE_FEATURE CooperativeMatrixFeaturesKHR.cooperativeMatrix
4041
DEVICE_FEATURE ZeroInitializeWorkgroupMemoryFeatures.shaderZeroInitializeWorkgroupMemory
4142
DEVICE_FEATURE ShaderLongVectorFeaturesEXT.longVector)";
4243

@@ -46,7 +47,7 @@ DEVICE_FEATURE ShaderLongVectorFeaturesEXT.longVector)";
4647

4748
auto script = parser.GetScript();
4849
const auto& features = script->GetRequiredFeatures();
49-
ASSERT_EQ(17U, features.size());
50+
ASSERT_EQ(18U, features.size());
5051
EXPECT_EQ("vertexPipelineStoresAndAtomics", features[0]);
5152
EXPECT_EQ("VariablePointerFeatures.variablePointersStorageBuffer",
5253
features[1]);
@@ -66,10 +67,11 @@ DEVICE_FEATURE ShaderLongVectorFeaturesEXT.longVector)";
6667
EXPECT_EQ("VulkanMemoryModelFeatures.vulkanMemoryModel", features[13]);
6768
EXPECT_EQ("VulkanMemoryModelFeatures.vulkanMemoryModelDeviceScope",
6869
features[14]);
70+
EXPECT_EQ("CooperativeMatrixFeaturesKHR.cooperativeMatrix", features[15]);
6971
EXPECT_EQ("ZeroInitializeWorkgroupMemoryFeatures."
7072
"shaderZeroInitializeWorkgroupMemory",
71-
features[15]);
72-
EXPECT_EQ("ShaderLongVectorFeaturesEXT.longVector", features[16]);
73+
features[16]);
74+
EXPECT_EQ("ShaderLongVectorFeaturesEXT.longVector", features[17]);
7375
}
7476

7577
TEST_F(AmberScriptParserTest, DeviceFeatureMissingFeature) {

src/script.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ bool Script::IsKnownFeature(const std::string& name) const {
140140
name ==
141141
"ShaderSubgroupExtendedTypesFeatures"
142142
".shaderSubgroupExtendedTypes" ||
143+
name == "CooperativeMatrixFeaturesKHR.cooperativeMatrix" ||
143144
name == "RayTracingPipelineFeaturesKHR.rayTracingPipeline" ||
144145
name == "AccelerationStructureFeaturesKHR.accelerationStructure" ||
145146
name == "BufferDeviceAddressFeatures.bufferDeviceAddress" ||

src/vulkan/device.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ const char kShaderSubgroupExtendedTypes[] =
100100
"ShaderSubgroupExtendedTypesFeatures.shaderSubgroupExtendedTypes";
101101

102102
const char kIndexTypeUint8[] = "IndexTypeUint8Features.indexTypeUint8";
103+
const char kCooperativeMatrix[] =
104+
"CooperativeMatrixFeaturesKHR.cooperativeMatrix";
103105

104106
const char kAccelerationStructure[] =
105107
"AccelerationStructureFeaturesKHR.accelerationStructure";
@@ -574,6 +576,8 @@ Result Device::Initialize(
574576
VkPhysicalDeviceShaderSubgroupExtendedTypesFeatures*
575577
shader_subgroup_extended_types_ptrs = nullptr;
576578
VkPhysicalDeviceIndexTypeUint8FeaturesEXT* index_type_uint8_ptrs = nullptr;
579+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR* cooperative_matrix_ptrs =
580+
nullptr;
577581
VkPhysicalDeviceAccelerationStructureFeaturesKHR*
578582
acceleration_structure_ptrs = nullptr;
579583
VkPhysicalDeviceBufferDeviceAddressFeatures* bda_ptrs = nullptr;
@@ -618,6 +622,10 @@ Result Device::Initialize(
618622
index_type_uint8_ptrs =
619623
static_cast<VkPhysicalDeviceIndexTypeUint8FeaturesEXT*>(ptr);
620624
break;
625+
case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR:
626+
cooperative_matrix_ptrs =
627+
static_cast<VkPhysicalDeviceCooperativeMatrixFeaturesKHR*>(ptr);
628+
break;
621629
// NOLINTNEXTLINE(whitespace/line_length)
622630
case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ACCELERATION_STRUCTURE_FEATURES_KHR:
623631
acceleration_structure_ptrs =
@@ -703,6 +711,10 @@ Result Device::Initialize(
703711
depth_clamp_zero_one_features->depthClampZeroOne != VK_TRUE)) {
704712
return amber::Result("Depth clamp zero one requested but not returned");
705713
}
714+
if (feature == kCooperativeMatrix && cooperative_matrix_ptrs == nullptr) {
715+
return amber::Result(
716+
"Cooperative matrix requested but feature not returned");
717+
}
706718
if (feature == kAccelerationStructure) {
707719
if (acceleration_structure_ptrs == nullptr) {
708720
return amber::Result(
@@ -921,6 +933,10 @@ Result Device::Initialize(
921933
return amber::Result(
922934
"Index type uint8_t requested but feature not returned");
923935
}
936+
if (feature == kCooperativeMatrix &&
937+
cooperative_matrix_ptrs->cooperativeMatrix != VK_TRUE) {
938+
return amber::Result("Missing cooperative matrix feature");
939+
}
924940
}
925941

926942
if (!AreAllExtensionsSupported(available_extensions,

0 commit comments

Comments
 (0)