1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <sstream>
16 #include <string>
17 #include <tuple>
18
19 #include "gmock/gmock.h"
20 #include "test/unit_spirv.h"
21 #include "test/val/val_fixtures.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
27 using ::testing::Combine;
28 using ::testing::HasSubstr;
29 using ::testing::Values;
30 using ::testing::ValuesIn;
31
// Builds a complete SPIR-V assembly module around |body| for these tests.
//
// The module starts with the capabilities used by the OpGroupNonUniform*
// instructions exercised in this file, then |capabilities_and_extensions|
// verbatim, then a single entry point %main with the given |execution_model|.
// Only when that model is GLCompute is a "LocalSize 1 1 1" execution mode
// emitted.
//
// It then declares ids that |body| may reference: the types %void, %bool,
// %u32, %int, %float, %u32vec4 and %u32vec3; the constants %true, %false,
// %u32_0, %float_0, %u32vec4_null and %u32vec3_null; the Scope values
// %cross_device (0) through %invocation (4); and the GroupOperation values
// %reduce (0) through %clustered_reduce (3). |body| is placed directly after
// %main's entry label and is followed by OpReturn / OpFunctionEnd.
std::string GenerateShaderCode(
    const std::string& body,
    const std::string& capabilities_and_extensions = "",
    const std::string& execution_model = "GLCompute") {
  // Capability header shared by every generated module.
  const char* const kCapabilities = R"(
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformVote
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformClustered
OpCapability GroupNonUniformQuad
)";

  // Shared types and constants, followed by the opening of %main up to and
  // including its entry label.
  const char* const kDeclarations = R"(
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
%u32 = OpTypeInt 32 0
%int = OpTypeInt 32 1
%float = OpTypeFloat 32
%u32vec4 = OpTypeVector %u32 4
%u32vec3 = OpTypeVector %u32 3

%true = OpConstantTrue %bool
%false = OpConstantFalse %bool

%u32_0 = OpConstant %u32 0

%float_0 = OpConstant %float 0

%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
%u32vec3_null = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0

%cross_device = OpConstant %u32 0
%device = OpConstant %u32 1
%workgroup = OpConstant %u32 2
%subgroup = OpConstant %u32 3
%invocation = OpConstant %u32 4

%reduce = OpConstant %u32 0
%inclusive_scan = OpConstant %u32 1
%exclusive_scan = OpConstant %u32 2
%clustered_reduce = OpConstant %u32 3

%main = OpFunction %void None %func
%main_entry = OpLabel
)";

  // Closes %main once |body| has been emitted.
  const char* const kEpilogue = R"(
OpReturn
OpFunctionEnd)";

  std::string module = kCapabilities;
  module += capabilities_and_extensions;
  module += "OpMemoryModel Logical GLSL450\n";
  module += "OpEntryPoint " + execution_model + " %main \"main\"\n";
  if (execution_model == "GLCompute") {
    module += "OpExecutionMode %main LocalSize 1 1 1\n";
  }
  module += kDeclarations;
  module += body;
  module += kEpilogue;
  return module;
}
99
100 SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
101 SpvScopeSubgroup, SpvScopeInvocation};
102
103 using GroupNonUniform = spvtest::ValidateBase<
104 std::tuple<std::string, std::string, SpvScope, std::string, std::string>>;
105
ConvertScope(SpvScope scope)106 std::string ConvertScope(SpvScope scope) {
107 switch (scope) {
108 case SpvScopeCrossDevice:
109 return "%cross_device";
110 case SpvScopeDevice:
111 return "%device";
112 case SpvScopeWorkgroup:
113 return "%workgroup";
114 case SpvScopeSubgroup:
115 return "%subgroup";
116 case SpvScopeInvocation:
117 return "%invocation";
118 default:
119 return "";
120 }
121 }
122
// Validates one parameterized OpGroupNonUniform* instruction in the Vulkan 1.1
// environment.
//
// When the expected error is non-empty, validation must fail with a diagnostic
// containing it, whatever the scope. Otherwise the instruction is well-formed:
// it must validate when the execution scope is Subgroup and must report the
// Vulkan execution-scope diagnostic for any other scope.
TEST_P(GroupNonUniform, Vulkan1p1) {
  std::string opcode;
  std::string type;
  SpvScope execution_scope{};
  std::string args;
  std::string error;
  std::tie(opcode, type, execution_scope, args, error) = GetParam();

  const std::string instruction = "%result = " + opcode + " " + type + " " +
                                  ConvertScope(execution_scope) + " " + args +
                                  "\n";

  CompileSuccessfully(GenerateShaderCode(instruction), SPV_ENV_VULKAN_1_1);
  const spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);

  if (!error.empty()) {
    // Malformed instruction: the specific diagnostic must be reported.
    EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
    return;
  }

  if (execution_scope == SpvScopeSubgroup) {
    EXPECT_EQ(SPV_SUCCESS, result);
    return;
  }

  // Well-formed instruction, but with a scope Vulkan does not allow here.
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "in Vulkan environment Execution scope is limited to Subgroup"));
}
153
// Validates one parameterized OpGroupNonUniform* instruction in the universal
// SPIR-V 1.3 environment.
//
// When the expected error is non-empty, validation must fail with a diagnostic
// containing it, whatever the scope. Otherwise the instruction is well-formed:
// it must validate when the execution scope is Subgroup or Workgroup and must
// report the execution-scope diagnostic for any other scope.
TEST_P(GroupNonUniform, Spirv1p3) {
  std::string opcode;
  std::string type;
  SpvScope execution_scope{};
  std::string args;
  std::string error;
  std::tie(opcode, type, execution_scope, args, error) = GetParam();

  const std::string instruction = "%result = " + opcode + " " + type + " " +
                                  ConvertScope(execution_scope) + " " + args +
                                  "\n";

  CompileSuccessfully(GenerateShaderCode(instruction), SPV_ENV_UNIVERSAL_1_3);
  const spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);

  if (!error.empty()) {
    // Malformed instruction: the specific diagnostic must be reported.
    EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
    EXPECT_THAT(getDiagnosticString(), HasSubstr(error));
    return;
  }

  const bool scope_allowed = execution_scope == SpvScopeSubgroup ||
                             execution_scope == SpvScopeWorkgroup;
  if (scope_allowed) {
    EXPECT_EQ(SPV_SUCCESS, result);
    return;
  }

  // Well-formed instruction, but with a scope outside Subgroup/Workgroup.
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
}
184
185 INSTANTIATE_TEST_SUITE_P(GroupNonUniformElect, GroupNonUniform,
186 Combine(Values("OpGroupNonUniformElect"),
187 Values("%bool"), ValuesIn(scopes), Values(""),
188 Values("")));
189
190 INSTANTIATE_TEST_SUITE_P(GroupNonUniformVote, GroupNonUniform,
191 Combine(Values("OpGroupNonUniformAll",
192 "OpGroupNonUniformAny",
193 "OpGroupNonUniformAllEqual"),
194 Values("%bool"), ValuesIn(scopes),
195 Values("%true"), Values("")));
196
197 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcast, GroupNonUniform,
198 Combine(Values("OpGroupNonUniformBroadcast"),
199 Values("%bool"), ValuesIn(scopes),
200 Values("%true %u32_0"), Values("")));
201
202 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBroadcastFirst, GroupNonUniform,
203 Combine(Values("OpGroupNonUniformBroadcastFirst"),
204 Values("%bool"), ValuesIn(scopes),
205 Values("%true"), Values("")));
206
207 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallot, GroupNonUniform,
208 Combine(Values("OpGroupNonUniformBallot"),
209 Values("%u32vec4"), ValuesIn(scopes),
210 Values("%true"), Values("")));
211
212 INSTANTIATE_TEST_SUITE_P(GroupNonUniformInverseBallot, GroupNonUniform,
213 Combine(Values("OpGroupNonUniformInverseBallot"),
214 Values("%bool"), ValuesIn(scopes),
215 Values("%u32vec4_null"), Values("")));
216
217 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitExtract, GroupNonUniform,
218 Combine(Values("OpGroupNonUniformBallotBitExtract"),
219 Values("%bool"), ValuesIn(scopes),
220 Values("%u32vec4_null %u32_0"), Values("")));
221
222 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCount, GroupNonUniform,
223 Combine(Values("OpGroupNonUniformBallotBitCount"),
224 Values("%u32"), ValuesIn(scopes),
225 Values("Reduce %u32vec4_null"), Values("")));
226
227 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotFind, GroupNonUniform,
228 Combine(Values("OpGroupNonUniformBallotFindLSB",
229 "OpGroupNonUniformBallotFindMSB"),
230 Values("%u32"), ValuesIn(scopes),
231 Values("%u32vec4_null"), Values("")));
232
233 INSTANTIATE_TEST_SUITE_P(GroupNonUniformShuffle, GroupNonUniform,
234 Combine(Values("OpGroupNonUniformShuffle",
235 "OpGroupNonUniformShuffleXor",
236 "OpGroupNonUniformShuffleUp",
237 "OpGroupNonUniformShuffleDown"),
238 Values("%u32"), ValuesIn(scopes),
239 Values("%u32_0 %u32_0"), Values("")));
240
241 INSTANTIATE_TEST_SUITE_P(
242 GroupNonUniformIntegerArithmetic, GroupNonUniform,
243 Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
244 "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
245 "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
246 "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
247 "OpGroupNonUniformBitwiseXor"),
248 Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0"),
249 Values("")));
250
251 INSTANTIATE_TEST_SUITE_P(
252 GroupNonUniformFloatArithmetic, GroupNonUniform,
253 Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
254 "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
255 Values("%float"), ValuesIn(scopes), Values("Reduce %float_0"),
256 Values("")));
257
258 INSTANTIATE_TEST_SUITE_P(GroupNonUniformLogicalArithmetic, GroupNonUniform,
259 Combine(Values("OpGroupNonUniformLogicalAnd",
260 "OpGroupNonUniformLogicalOr",
261 "OpGroupNonUniformLogicalXor"),
262 Values("%bool"), ValuesIn(scopes),
263 Values("Reduce %true"), Values("")));
264
265 INSTANTIATE_TEST_SUITE_P(GroupNonUniformQuad, GroupNonUniform,
266 Combine(Values("OpGroupNonUniformQuadBroadcast",
267 "OpGroupNonUniformQuadSwap"),
268 Values("%u32"), ValuesIn(scopes),
269 Values("%u32_0 %u32_0"), Values("")));
270
271 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountScope, GroupNonUniform,
272 Combine(Values("OpGroupNonUniformBallotBitCount"),
273 Values("%u32"), ValuesIn(scopes),
274 Values("Reduce %u32vec4_null"), Values("")));
275
276 INSTANTIATE_TEST_SUITE_P(
277 GroupNonUniformBallotBitCountBadResultType, GroupNonUniform,
278 Combine(
279 Values("OpGroupNonUniformBallotBitCount"), Values("%float", "%int"),
280 Values(SpvScopeSubgroup), Values("Reduce %u32vec4_null"),
281 Values("Expected Result Type to be an unsigned integer type scalar.")));
282
283 INSTANTIATE_TEST_SUITE_P(GroupNonUniformBallotBitCountBadValue, GroupNonUniform,
284 Combine(Values("OpGroupNonUniformBallotBitCount"),
285 Values("%u32"), Values(SpvScopeSubgroup),
286 Values("Reduce %u32vec3_null", "Reduce %u32_0",
287 "Reduce %float_0"),
288 Values("Expected Value to be a vector of four "
289 "components of integer type scalar")));
290
291 } // namespace
292 } // namespace val
293 } // namespace spvtools
294