1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <vector>
16
17 #include "test/opt/pass_fixture.h"
18 #include "test/opt/pass_utils.h"
19
20 namespace spvtools {
21 namespace opt {
22 namespace {
23
// One test case: (constant value placed in the input module's instruction,
// constant value expected in that instruction after the pass runs).
using StripAtomicCounterMemoryParam = std::tuple<std::string, std::string>;
25
// Value-parameterized fixture: every TEST_P below runs once per
// StripAtomicCounterMemoryParam listed in INSTANTIATE_TEST_SUITE_P.
using MemorySemanticsModified =
    PassTest<::testing::TestWithParam<StripAtomicCounterMemoryParam>>;
// Plain fixture for tests that build their own parameter in the test body.
using NonMemorySemanticsUnmodifiedTest = PassTest<::testing::Test>;
29
// Appends |rhs| to |lhs| so instruction lists can be grown with
// `result += "...";` (as GetChangedString does). Like the built-in compound
// assignment operators it returns the left operand, which allows chaining.
// The vector stores the pointer only; the caller must keep the pointed-to
// text alive for as long as the vector is used.
std::vector<const char*>& operator+=(std::vector<const char*>& lhs,
                                     const char* rhs) {
  lhs.push_back(rhs);
  return lhs;
}
33
// Returns the assembly line that declares constant |val| of type %uint under
// the friendly name %uint_<val>, e.g. "%uint_5 = OpConstant %uint 5".
std::string GetConstDecl(std::string val) {
  return "%uint_" + val + " = OpConstant %uint " + val;
}
39
GetUnchangedString(std::string (generate_inst)(std::string),std::string val)40 std::string GetUnchangedString(std::string(generate_inst)(std::string),
41 std::string val) {
42 std::string decl = GetConstDecl(val);
43 std::string inst = generate_inst(val);
44
45 std::vector<const char*> result = {
46 // clang-format off
47 "OpCapability Shader",
48 "OpCapability VulkanMemoryModel",
49 "OpExtension \"SPV_KHR_vulkan_memory_model\"",
50 "OpMemoryModel Logical Vulkan",
51 "OpEntryPoint Vertex %1 \"shader\"",
52 "%uint = OpTypeInt 32 0",
53 "%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint",
54 "%4 = OpVariable %_ptr_Workgroup_uint Workgroup",
55 "%uint_0 = OpConstant %uint 0",
56 "%uint_1 = OpConstant %uint 1",
57 "%void = OpTypeVoid",
58 "%8 = OpTypeFunction %void",
59 decl.c_str(),
60 "%1 = OpFunction %void None %8",
61 "%10 = OpLabel",
62 inst.c_str(),
63 "OpReturn",
64 "OpFunctionEnd"
65 // clang-format on
66 };
67 return JoinAllInsts(result);
68 }
69
GetChangedString(std::string (generate_inst)(std::string),std::string orig,std::string changed)70 std::string GetChangedString(std::string(generate_inst)(std::string),
71 std::string orig, std::string changed) {
72 std::string orig_decl = GetConstDecl(orig);
73 std::string changed_decl = GetConstDecl(changed);
74 std::string inst = generate_inst(changed);
75
76 std::vector<const char*> result = {
77 // clang-format off
78 "OpCapability Shader",
79 "OpCapability VulkanMemoryModel",
80 "OpExtension \"SPV_KHR_vulkan_memory_model\"",
81 "OpMemoryModel Logical Vulkan",
82 "OpEntryPoint Vertex %1 \"shader\"",
83 "%uint = OpTypeInt 32 0",
84 "%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint",
85 "%4 = OpVariable %_ptr_Workgroup_uint Workgroup",
86 "%uint_0 = OpConstant %uint 0",
87 "%uint_1 = OpConstant %uint 1",
88 "%void = OpTypeVoid",
89 "%8 = OpTypeFunction %void",
90 orig_decl.c_str() };
91 // clang-format on
92 if (changed != "0") result += changed_decl.c_str();
93 result += "%1 = OpFunction %void None %8";
94 result += "%10 = OpLabel";
95 result += inst.c_str();
96 result += "OpReturn";
97 result += "OpFunctionEnd";
98 return JoinAllInsts(result);
99 }
100
GetInputAndExpected(std::string (generate_inst)(std::string),StripAtomicCounterMemoryParam param)101 std::tuple<std::string, std::string> GetInputAndExpected(
102 std::string(generate_inst)(std::string),
103 StripAtomicCounterMemoryParam param) {
104 std::string orig = std::get<0>(param);
105 std::string changed = std::get<1>(param);
106 std::string input = GetUnchangedString(generate_inst, orig);
107 std::string expected = orig == changed
108 ? GetUnchangedString(generate_inst, changed)
109 : GetChangedString(generate_inst, orig, changed);
110 return std::make_tuple(input, expected);
111 }
112
// OpControlBarrier whose Semantics operand (the third) is %uint_<val>.
std::string GetOpControlBarrierInst(std::string val) {
  std::string inst("OpControlBarrier %uint_1 %uint_1 %uint_");
  inst.append(val);
  return inst;
}
116
// Runs the pass on a module holding one OpControlBarrier built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpControlBarrier) {
  const auto modules = GetInputAndExpected(GetOpControlBarrierInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
124
// OpMemoryBarrier whose Semantics operand (the second) is %uint_<val>.
std::string GetOpMemoryBarrierInst(std::string val) {
  return std::string("OpMemoryBarrier %uint_1 %uint_").append(val);
}
128
// Runs the pass on a module holding one OpMemoryBarrier built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpMemoryBarrier) {
  const auto modules = GetInputAndExpected(GetOpMemoryBarrierInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
136
// OpAtomicLoad from %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicLoadInst(std::string val) {
  std::string inst = "%11 = OpAtomicLoad %uint %4 %uint_1 %uint_";
  inst += val;
  return inst;
}
140
// Runs the pass on a module holding one OpAtomicLoad built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicLoad) {
  const auto modules = GetInputAndExpected(GetOpAtomicLoadInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
148
// OpAtomicStore of %uint_1 into %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicStoreInst(std::string val) {
  std::string inst = "OpAtomicStore %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_1";
  return inst;
}
152
// Runs the pass on a module holding one OpAtomicStore built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicStore) {
  const auto modules = GetInputAndExpected(GetOpAtomicStoreInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
160
// OpAtomicExchange on %4 (new value %uint_0) whose Semantics operand is
// %uint_<val>.
std::string GetOpAtomicExchangeInst(std::string val) {
  std::string inst = "%11 = OpAtomicExchange %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_0";
  return inst;
}
164
// Runs the pass on a module holding one OpAtomicExchange built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicExchange) {
  const auto modules = GetInputAndExpected(GetOpAtomicExchangeInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
172
// OpAtomicCompareExchange on %4 whose Equal and Unequal Semantics operands
// are both %uint_<val>.
std::string GetOpAtomicCompareExchangeInst(std::string val) {
  std::string inst = "%11 = OpAtomicCompareExchange %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_";
  inst += val;
  inst += " %uint_0 %uint_0";
  return inst;
}
177
// Runs the pass on a module holding one OpAtomicCompareExchange (both
// Semantics operands taken from the test parameter) and compares the output
// with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicCompareExchange) {
  const auto modules =
      GetInputAndExpected(GetOpAtomicCompareExchangeInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
185
// OpAtomicCompareExchangeWeak on %4 whose Equal and Unequal Semantics
// operands are both %uint_<val>.
std::string GetOpAtomicCompareExchangeWeakInst(std::string val) {
  std::string inst =
      "%11 = OpAtomicCompareExchangeWeak %uint %4 %uint_1 %uint_";
  inst.append(val).append(" %uint_").append(val).append(" %uint_0 %uint_0");
  return inst;
}
190
// Runs the pass on a module holding one OpAtomicCompareExchangeWeak (both
// Semantics operands taken from the test parameter) and compares the output
// with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicCompareExchangeWeak) {
  const auto modules =
      GetInputAndExpected(GetOpAtomicCompareExchangeWeakInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
198
// OpAtomicIIncrement on %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicIIncrementInst(std::string val) {
  std::string inst = "%11 = OpAtomicIIncrement %uint %4 %uint_1 %uint_";
  inst += val;
  return inst;
}
202
// Runs the pass on a module holding one OpAtomicIIncrement built from the
// test parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicIIncrement) {
  const auto modules =
      GetInputAndExpected(GetOpAtomicIIncrementInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
210
// OpAtomicIDecrement on %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicIDecrementInst(std::string val) {
  return std::string("%11 = OpAtomicIDecrement %uint %4 %uint_1 %uint_")
      .append(val);
}
214
// Runs the pass on a module holding one OpAtomicIDecrement built from the
// test parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicIDecrement) {
  const auto modules =
      GetInputAndExpected(GetOpAtomicIDecrementInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
222
// OpAtomicIAdd of %uint_1 to %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicIAddInst(std::string val) {
  std::string inst = "%11 = OpAtomicIAdd %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_1";
  return inst;
}
226
// Runs the pass on a module holding one OpAtomicIAdd built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicIAdd) {
  const auto modules = GetInputAndExpected(GetOpAtomicIAddInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
234
// OpAtomicISub of %uint_1 from %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicISubInst(std::string val) {
  std::string inst = "%11 = OpAtomicISub %uint %4 %uint_1 %uint_";
  inst.append(val).append(" %uint_1");
  return inst;
}
238
// Runs the pass on a module holding one OpAtomicISub built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicISub) {
  const auto modules = GetInputAndExpected(GetOpAtomicISubInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
246
// OpAtomicSMin of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicSMinInst(std::string val) {
  std::string inst = "%11 = OpAtomicSMin %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_1";
  return inst;
}
250
// Runs the pass on a module holding one OpAtomicSMin built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicSMin) {
  const auto modules = GetInputAndExpected(GetOpAtomicSMinInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
258
// OpAtomicUMin of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicUMinInst(std::string val) {
  std::string inst = "%11 = OpAtomicUMin %uint %4 %uint_1 %uint_";
  inst.append(val).append(" %uint_1");
  return inst;
}
262
// Runs the pass on a module holding one OpAtomicUMin built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicUMin) {
  const auto modules = GetInputAndExpected(GetOpAtomicUMinInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
270
// OpAtomicSMax of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicSMaxInst(std::string val) {
  std::string inst = "%11 = OpAtomicSMax %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_1";
  return inst;
}
274
// Runs the pass on a module holding one OpAtomicSMax built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicSMax) {
  const auto modules = GetInputAndExpected(GetOpAtomicSMaxInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
282
// OpAtomicUMax of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicUMaxInst(std::string val) {
  std::string inst = "%11 = OpAtomicUMax %uint %4 %uint_1 %uint_";
  inst.append(val).append(" %uint_1");
  return inst;
}
286
// Runs the pass on a module holding one OpAtomicUMax built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicUMax) {
  const auto modules = GetInputAndExpected(GetOpAtomicUMaxInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
294
// OpAtomicAnd of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicAndInst(std::string val) {
  std::string inst = "%11 = OpAtomicAnd %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_1";
  return inst;
}
298
// Runs the pass on a module holding one OpAtomicAnd built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicAnd) {
  const auto modules = GetInputAndExpected(GetOpAtomicAndInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
306
// OpAtomicOr of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicOrInst(std::string val) {
  std::string inst = "%11 = OpAtomicOr %uint %4 %uint_1 %uint_";
  inst.append(val).append(" %uint_1");
  return inst;
}
310
// Runs the pass on a module holding one OpAtomicOr built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicOr) {
  const auto modules = GetInputAndExpected(GetOpAtomicOrInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
318
// OpAtomicXor of %4 with %uint_1 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicXorInst(std::string val) {
  std::string inst = "%11 = OpAtomicXor %uint %4 %uint_1 %uint_";
  inst += val;
  inst += " %uint_1";
  return inst;
}
322
// Runs the pass on a module holding one OpAtomicXor built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicXor) {
  const auto modules = GetInputAndExpected(GetOpAtomicXorInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
330
// OpAtomicFlagTestAndSet on %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicFlagTestAndSetInst(std::string val) {
  std::string inst = "%11 = OpAtomicFlagTestAndSet %uint %4 %uint_1 %uint_";
  inst += val;
  return inst;
}
334
// Runs the pass on a module holding one OpAtomicFlagTestAndSet built from the
// test parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicFlagTestAndSet) {
  const auto modules =
      GetInputAndExpected(GetOpAtomicFlagTestAndSetInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
342
// OpAtomicFlagClear on %4 whose Semantics operand is %uint_<val>.
std::string GetOpAtomicFlagClearInst(std::string val) {
  return std::string("OpAtomicFlagClear %4 %uint_1 %uint_").append(val);
}
346
// Runs the pass on a module holding one OpAtomicFlagClear built from the test
// parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpAtomicFlagClear) {
  const auto modules =
      GetInputAndExpected(GetOpAtomicFlagClearInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
354
// OpMemoryNamedBarrier on %4 whose Semantics operand is %uint_<val>.
std::string GetOpMemoryNamedBarrierInst(std::string val) {
  std::string inst("OpMemoryNamedBarrier %4 %uint_1 %uint_");
  inst.append(val);
  return inst;
}
358
// Runs the pass on a module holding one OpMemoryNamedBarrier built from the
// test parameter and compares the output with the expected module.
TEST_P(MemorySemanticsModified, OpMemoryNamedBarrier) {
  const auto modules =
      GetInputAndExpected(GetOpMemoryNamedBarrierInst, GetParam());
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
366
367 // clang-format off
368 INSTANTIATE_TEST_SUITE_P(
369 StripAtomicCounterMemoryTest, MemorySemanticsModified,
370 ::testing::ValuesIn(std::vector<StripAtomicCounterMemoryParam>({
371 std::make_tuple("1024", "0"),
372 std::make_tuple("5", "5"),
373 std::make_tuple("1288", "264"),
374 std::make_tuple("264", "264")
375 })));
376 // clang-format on
377
// OpVariable initialized with %uint_<val>. The instruction has no Semantics
// operand, so the constant is not a memory-semantics value.
std::string GetNoMemorySemanticsPresentInst(std::string val) {
  std::string inst = "%11 = OpVariable %_ptr_Workgroup_uint Workgroup %uint_";
  inst += val;
  return inst;
}
381
// A constant carrying the 1024 bit that is used only as an OpVariable
// initializer must come through the pass untouched: input and expected
// modules are built from the same value.
TEST_F(NonMemorySemanticsUnmodifiedTest, NoMemorySemanticsPresent) {
  const StripAtomicCounterMemoryParam param = std::make_tuple("1288", "1288");
  const auto modules =
      GetInputAndExpected(GetNoMemorySemanticsPresentInst, param);
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
390
// OpAtomicIAdd whose Semantics operand is %uint_<val> and whose Value operand
// is always %uint_1288, so only the Semantics operand varies.
std::string GetMemorySemanticsPresentInst(std::string val) {
  std::string inst = "%11 = OpAtomicIAdd %uint %4 %uint_1 %uint_";
  inst.append(val).append(" %uint_1288");
  return inst;
}
394
// The same constant %uint_1288 is used both as the Semantics operand and as
// the Value operand. The expected module rewrites only the Semantics use, to
// %uint_264, and leaves the Value operand as %uint_1288.
TEST_F(NonMemorySemanticsUnmodifiedTest, MemorySemanticsPresent) {
  const StripAtomicCounterMemoryParam param = std::make_tuple("1288", "264");
  const auto modules =
      GetInputAndExpected(GetMemorySemanticsPresentInst, param);
  SinglePassRunAndCheck<StripAtomicCounterMemoryPass>(
      std::get<0>(modules), std::get<1>(modules), /* skip_nop = */ false);
}
403
404 } // namespace
405 } // namespace opt
406 } // namespace spvtools
407