1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <sstream>
16 #include <string>
17 #include <tuple>
18
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
27 using ::testing::Combine;
28 using ::testing::HasSubstr;
29 using ::testing::Values;
30 using ::testing::ValuesIn;
31
// Assembles a complete SPIR-V module around `body`.
//
// The module declares the GroupNonUniform* capabilities, then appends
// `capabilities_and_extensions` verbatim directly before OpMemoryModel (so it
// should end with a newline). It declares a single entry point %main using
// `execution_model`; GLCompute entry points also get LocalSize 1 1 1. Shared
// types and constants (%bool, %u32, %float, the vector types, %true/%false,
// zero constants, the scope ids %cross_device..%invocation and the group
// operation ids %reduce..%clustered_reduce) are declared before %main.
// `body` is inserted right after %main's entry label, followed by OpReturn.
std::string GenerateShaderCode(
    const std::string& body,
    const std::string& capabilities_and_extensions = "",
    const std::string& execution_model = "GLCompute") {
  std::string module = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformVote
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformClustered
OpCapability GroupNonUniformQuad
)";

  module += capabilities_and_extensions;
  module += "OpMemoryModel Logical GLSL450\n";
  module += "OpEntryPoint ";
  module += execution_model;
  module += " %main \"main\"\n";

  // Compute entry points get a fixed 1x1x1 local size; other models get none.
  if (execution_model == "GLCompute") {
    module += "OpExecutionMode %main LocalSize 1 1 1\n";
  }

  module += R"(
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%u32 = OpTypeInt 32 0
%int = OpTypeInt 32 1
%float = OpTypeFloat 32
%u32vec4 = OpTypeVector %u32 4
%u32vec3 = OpTypeVector %u32 3

%true = OpConstantTrue %bool
%false = OpConstantFalse %bool

%u32_0 = OpConstant %u32 0

%float_0 = OpConstant %float 0

%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
%u32vec3_null = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0

%cross_device = OpConstant %u32 0
%device = OpConstant %u32 1
%workgroup = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%invocation = OpConstant %u32 4

%reduce = OpConstant %u32 0
%inclusive_scan = OpConstant %u32 1
%exclusive_scan = OpConstant %u32 2
%clustered_reduce = OpConstant %u32 3

%main = OpFunction %void None %func
%main_entry = OpLabel
)";

  module += body;

  module += R"(
OpReturn
OpFunctionEnd)";

  return module;
}
99
100 SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
101 SpvScopeSubgroup, SpvScopeInvocation};
102
103 using ValidateGroupNonUniform = spvtest::ValidateBase<bool>;
104 using GroupNonUniform = spvtest::ValidateBase<
105 std::tuple<std::string, std::string, SpvScope, std::string, std::string>>;
106
ConvertScope(SpvScope scope)107 std::string ConvertScope(SpvScope scope) {
108 switch (scope) {
109 case SpvScopeCrossDevice:
110 return "%cross_device";
111 case SpvScopeDevice:
112 return "%device";
113 case SpvScopeWorkgroup:
114 return "%workgroup";
115 case SpvScopeSubgroup:
116 return "%subgroup";
117 case SpvScopeInvocation:
118 return "%invocation";
119 default:
120 return "";
121 }
122 }
123
TEST_P(GroupNonUniform,Vulkan1p1)124 TEST_P(GroupNonUniform, Vulkan1p1) {
125 std::string opcode = std::get<0>(GetParam());
126 std::string type = std::get<1>(GetParam());
127 SpvScope execution_scope = std::get<2>(GetParam());
128 std::string args = std::get<3>(GetParam());
129 std::string error = std::get<4>(GetParam());
130
131 std::ostringstream sstr;
132 sstr << "%result = " << opcode << " ";
133 sstr << type << " ";
134 sstr << ConvertScope(execution_scope) << " ";
135 sstr << args << "\n";
136
137 CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1);
138 spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);
139 if (error == "") {
140 if (execution_scope == SpvScopeSubgroup) {
141 EXPECT_EQ(SPV_SUCCESS, result);
142 } else {
143 EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
144 EXPECT_THAT(getDiagnosticString(),
145 AnyVUID("VUID-StandaloneSpirv-None-04642"));
146 EXPECT_THAT(
147 getDiagnosticString(),
148 HasSubstr(
149 "in Vulkan environment Execution scope is limited to Subgroup"));
150 }
151 } else {
152 EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
153 EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
154 }
155 }
156
TEST_P(GroupNonUniform,Spirv1p3)157 TEST_P(GroupNonUniform, Spirv1p3) {
158 std::string opcode = std::get<0>(GetParam());
159 std::string type = std::get<1>(GetParam());
160 SpvScope execution_scope = std::get<2>(GetParam());
161 std::string args = std::get<3>(GetParam());
162 std::string error = std::get<4>(GetParam());
163
164 std::ostringstream sstr;
165 sstr << "%result = " << opcode << " ";
166 sstr << type << " ";
167 sstr << ConvertScope(execution_scope) << " ";
168 sstr << args << "\n";
169
170 CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3);
171 spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);
172 if (error == "") {
173 if (execution_scope == SpvScopeSubgroup ||
174 execution_scope == SpvScopeWorkgroup) {
175 EXPECT_EQ(SPV_SUCCESS, result);
176 } else {
177 EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
178 EXPECT_THAT(
179 getDiagnosticString(),
180 HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
181 }
182 } else {
183 EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
184 EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
185 }
186 }
187
// Positive cases. Each well-formed instruction is instantiated across every
// scope in `scopes`. The expected-error column is empty, so each case expects
// success at the permitted scopes and a scope diagnostic at every other scope.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformElect, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformElect"),
                                 Values("%bool"), ValuesIn(scopes), Values(""),
                                 Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformVote, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformAll",
                                        "OpGroupNonUniformAny",
                                        "OpGroupNonUniformAllEqual"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%true"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcast, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBroadcast"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%true %u32_0"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcastFirst, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBroadcastFirst"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%true"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallot, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallot"),
                                 Values("%u32vec4"), ValuesIn(scopes),
                                 Values("%true"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformInverseBallot, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformInverseBallot"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%u32vec4_null"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitExtract, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitExtract"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("%u32vec4_null %u32_0"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCount, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitCount"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("Reduce %u32vec4_null"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotFind, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotFindLSB",
                                        "OpGroupNonUniformBallotFindMSB"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("%u32vec4_null"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformShuffle, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformShuffle",
                                        "OpGroupNonUniformShuffleXor",
                                        "OpGroupNonUniformShuffleUp",
                                        "OpGroupNonUniformShuffleDown"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("%u32_0 %u32_0"), Values("")));

// Integer and bitwise reductions over a %u32 value.
INSTANTIATE_TEST_SUITE_P(
    GroupNonUniformIntegerArithmetic, GroupNonUniform,
    Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
                   "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
                   "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
                   "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
                   "OpGroupNonUniformBitwiseXor"),
            Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"),
            Values("")));

// Floating-point reductions over a %float value.
INSTANTIATE_TEST_SUITE_P(
    GroupNonUniformFloatArithmetic, GroupNonUniform,
    Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
                   "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
            Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"),
            Values("")));

// Logical reductions over a %bool value.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformLogicalArithmetic, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformLogicalAnd",
                                        "OpGroupNonUniformLogicalOr",
                                        "OpGroupNonUniformLogicalXor"),
                                 Values("%bool"), ValuesIn(scopes),
                                 Values("Reduce %true"), Values("")));

INSTANTIATE_TEST_SUITE_P(GroupNonUniformQuad, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformQuadBroadcast",
                                        "OpGroupNonUniformQuadSwap"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("%u32_0 %u32_0"), Values("")));

// NOTE(review): this instantiation has exactly the same parameters as
// GroupNonUniformBallotBitCount above. It therefore runs the same cases a
// second time under another suite name. Consider removing one of the two.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountScope, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitCount"),
                                 Values("%u32"), ValuesIn(scopes),
                                 Values("Reduce %u32vec4_null"), Values("")));
278
// Negative cases for OpGroupNonUniformBallotBitCount. They run only at
// Subgroup scope, which both test environments accept, so the failure comes
// from the operand/type check named in the expected-error column.

// Result Type must be an unsigned integer scalar; %float and %int (signed)
// are rejected.
INSTANTIATE_TEST_SUITE_P(
    GroupNonUniformBallotBitCountBadResultType, GroupNonUniform,
    Combine(
        Values("OpGroupNonUniformBallotBitCount"), Values("%float", "%int"),
        Values(SpvScopeSubgroup), Values("Reduce %u32vec4_null"),
        Values("Expected Result Type to be an unsigned integer type scalar.")));

// Value must be a 4-component integer vector. A 3-component vector, an
// integer scalar and a float scalar are all rejected.
INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountBadValue, GroupNonUniform,
                         Combine(Values("OpGroupNonUniformBallotBitCount"),
                                 Values("%u32"), Values(SpvScopeSubgroup),
                                 Values("Reduce %u32vec3_null", "Reduce %u32_0",
                                        "Reduce %float_0"),
                                 Values("Expected Value to be a vector of four "
                                        "components of integer type scalar")));
293
// In a Vulkan 1.1 environment, OpGroupNonUniformBallotBitCount with the
// ClusteredReduce group operation must be rejected. The check looks for
// VUID-StandaloneSpirv-OpGroupNonUniformBallotBitCount-04685 and a message
// listing the permitted operations.
TEST_F(ValidateGroupNonUniform, VulkanGroupNonUniformBallotBitCountOperation) {
  const std::string spirv_module = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformClustered
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%func = OpTypeFunction %void
%u32 = OpTypeInt 32 0
%u32vec4 = OpTypeVector %u32 4
%u32_0 = OpConstant %u32 0
%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
%subgroup = OpConstant %u32 3
%main = OpFunction %void None %func
%main_entry = OpLabel
%result = OpGroupNonUniformBallotBitCount %u32 %subgroup ClusteredReduce %u32vec4_null
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv_module, SPV_ENV_VULKAN_1_1);
  // Stop here on success: the diagnostic checks below would be meaningless.
  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_1));
  EXPECT_THAT(
      getDiagnosticString(),
      AnyVUID("VUID-StandaloneSpirv-OpGroupNonUniformBallotBitCount-04685"));
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("In Vulkan: The OpGroupNonUniformBallotBitCount group "
                        "operation must be only: Reduce, InclusiveScan, or "
                        "ExclusiveScan."));
}
328
329 } // namespace
330 } // namespace val
331 } // namespace spvtools
332