1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_swap_commutable_operands.h"
16
17 #include "gtest/gtest.h"
18 #include "source/fuzz/fuzzer_util.h"
19 #include "source/fuzz/instruction_descriptor.h"
20 #include "test/fuzz/fuzz_test_util.h"
21
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25
// Checks the applicability precondition of TransformationSwapCommutableOperands
// against a fragment shader that contains exactly one OpIAdd (%22), one
// OpIMul (%28), one OpFAdd (%42), one OpFMul (%48) and one OpDot (%66).
TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest) {
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypeInt 32 0
          %8 = OpConstant %7 2
          %9 = OpTypeArray %6 %8
         %10 = OpTypePointer Function %9
         %12 = OpConstant %6 1
         %13 = OpConstant %6 2
         %14 = OpConstantComposite %9 %12 %13
         %15 = OpTypePointer Function %6
         %17 = OpConstant %6 0
         %29 = OpTypeFloat 32
         %30 = OpTypeArray %29 %8
         %31 = OpTypePointer Function %30
         %33 = OpConstant %29 1
         %34 = OpConstant %29 2
         %35 = OpConstantComposite %30 %33 %34
         %36 = OpTypePointer Function %29
         %49 = OpTypeVector %29 3
         %50 = OpTypeArray %49 %8
         %51 = OpTypePointer Function %50
         %53 = OpConstant %29 3
         %54 = OpConstantComposite %49 %33 %34 %53
         %55 = OpConstant %29 4
         %56 = OpConstant %29 5
         %57 = OpConstant %29 6
         %58 = OpConstantComposite %49 %55 %56 %57
         %59 = OpConstantComposite %50 %54 %58
         %61 = OpTypePointer Function %49
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpVariable %10 Function
         %16 = OpVariable %15 Function
         %23 = OpVariable %15 Function
         %32 = OpVariable %31 Function
         %37 = OpVariable %36 Function
         %43 = OpVariable %36 Function
         %52 = OpVariable %51 Function
         %60 = OpVariable %36 Function
               OpStore %11 %14
         %18 = OpAccessChain %15 %11 %17
         %19 = OpLoad %6 %18
         %20 = OpAccessChain %15 %11 %12
         %21 = OpLoad %6 %20
         %22 = OpIAdd %6 %19 %21
               OpStore %16 %22
         %24 = OpAccessChain %15 %11 %17
         %25 = OpLoad %6 %24
         %26 = OpAccessChain %15 %11 %12
         %27 = OpLoad %6 %26
         %28 = OpIMul %6 %25 %27
               OpStore %23 %28
               OpStore %32 %35
         %38 = OpAccessChain %36 %32 %17
         %39 = OpLoad %29 %38
         %40 = OpAccessChain %36 %32 %12
         %41 = OpLoad %29 %40
         %42 = OpFAdd %29 %39 %41
               OpStore %37 %42
         %44 = OpAccessChain %36 %32 %17
         %45 = OpLoad %29 %44
         %46 = OpAccessChain %36 %32 %12
         %47 = OpLoad %29 %46
         %48 = OpFMul %29 %45 %47
               OpStore %43 %48
               OpStore %52 %59
         %62 = OpAccessChain %61 %52 %17
         %63 = OpLoad %49 %62
         %64 = OpAccessChain %61 %52 %12
         %65 = OpLoad %49 %64
         %66 = OpDot %29 %63 %65
               OpStore %60 %66
               OpReturn
               OpFunctionEnd
  )";

  // Assemble the module and make sure the starting point is itself valid.
  const auto env = SPV_ENV_UNIVERSAL_1_5;
  const auto consumer = nullptr;
  const auto ir_context =
      BuildModule(env, consumer, shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(
      ir_context.get(), validator_options, kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(ir_context.get()), validator_options);

  // Each commutative instruction present in the module, addressed directly by
  // its own result id, is accepted.
  ASSERT_TRUE(TransformationSwapCommutableOperands(
                  MakeInstructionDescriptor(22, SpvOpIAdd, 0))
                  .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_TRUE(TransformationSwapCommutableOperands(
                  MakeInstructionDescriptor(28, SpvOpIMul, 0))
                  .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_TRUE(TransformationSwapCommutableOperands(
                  MakeInstructionDescriptor(42, SpvOpFAdd, 0))
                  .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_TRUE(TransformationSwapCommutableOperands(
                  MakeInstructionDescriptor(48, SpvOpFMul, 0))
                  .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_TRUE(TransformationSwapCommutableOperands(
                  MakeInstructionDescriptor(66, SpvOpDot, 0))
                  .IsApplicable(ir_context.get(), transformation_context));

  // Instructions that exist in the module but are not commutative are
  // rejected.
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(1, SpvOpExtInstImport, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(5, SpvOpLabel, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(8, SpvOpConstant, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(11, SpvOpVariable, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(14, SpvOpConstantComposite, 0))
                   .IsApplicable(ir_context.get(), transformation_context));

  // Base instruction ids 67..71 are not defined anywhere in the module, so the
  // descriptor cannot be resolved regardless of the requested opcode.
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(67, SpvOpIAddCarry, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(68, SpvOpIEqual, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(69, SpvOpINotEqual, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(70, SpvOpFOrdEqual, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(71, SpvOpPtrEqual, 0))
                   .IsApplicable(ir_context.get(), transformation_context));

  // The base instruction exists, but no instruction with the requested opcode
  // appears at or after it (the only matching instruction comes earlier).
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(24, SpvOpIAdd, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(38, SpvOpIMul, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(45, SpvOpFAdd, 0))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(66, SpvOpFMul, 0))
                   .IsApplicable(ir_context.get(), transformation_context));

  // An instruction with the requested opcode does follow the base
  // instruction, but the descriptor's skip count (100) is too high for any
  // such instruction to be reached.
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(11, SpvOpIAdd, 100))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(16, SpvOpIMul, 100))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(23, SpvOpFAdd, 100))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(32, SpvOpFMul, 100))
                   .IsApplicable(ir_context.get(), transformation_context));
  ASSERT_FALSE(TransformationSwapCommutableOperands(
                   MakeInstructionDescriptor(37, SpvOpDot, 100))
                   .IsApplicable(ir_context.get(), transformation_context));
}
250
// Applies TransformationSwapCommutableOperands to every commutative
// instruction in the shader (%22 OpIAdd, %28 OpIMul, %42 OpFAdd, %48 OpFMul,
// %66 OpDot) and checks that the resulting module matches a reference shader
// in which exactly those five instructions have their two operands swapped.
TEST(TransformationSwapCommutableOperandsTest, ApplyTest) {
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypeInt 32 0
          %8 = OpConstant %7 2
          %9 = OpTypeArray %6 %8
         %10 = OpTypePointer Function %9
         %12 = OpConstant %6 1
         %13 = OpConstant %6 2
         %14 = OpConstantComposite %9 %12 %13
         %15 = OpTypePointer Function %6
         %17 = OpConstant %6 0
         %29 = OpTypeFloat 32
         %30 = OpTypeArray %29 %8
         %31 = OpTypePointer Function %30
         %33 = OpConstant %29 1
         %34 = OpConstant %29 2
         %35 = OpConstantComposite %30 %33 %34
         %36 = OpTypePointer Function %29
         %49 = OpTypeVector %29 3
         %50 = OpTypeArray %49 %8
         %51 = OpTypePointer Function %50
         %53 = OpConstant %29 3
         %54 = OpConstantComposite %49 %33 %34 %53
         %55 = OpConstant %29 4
         %56 = OpConstant %29 5
         %57 = OpConstant %29 6
         %58 = OpConstantComposite %49 %55 %56 %57
         %59 = OpConstantComposite %50 %54 %58
         %61 = OpTypePointer Function %49
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpVariable %10 Function
         %16 = OpVariable %15 Function
         %23 = OpVariable %15 Function
         %32 = OpVariable %31 Function
         %37 = OpVariable %36 Function
         %43 = OpVariable %36 Function
         %52 = OpVariable %51 Function
         %60 = OpVariable %36 Function
               OpStore %11 %14
         %18 = OpAccessChain %15 %11 %17
         %19 = OpLoad %6 %18
         %20 = OpAccessChain %15 %11 %12
         %21 = OpLoad %6 %20
         %22 = OpIAdd %6 %19 %21
               OpStore %16 %22
         %24 = OpAccessChain %15 %11 %17
         %25 = OpLoad %6 %24
         %26 = OpAccessChain %15 %11 %12
         %27 = OpLoad %6 %26
         %28 = OpIMul %6 %25 %27
               OpStore %23 %28
               OpStore %32 %35
         %38 = OpAccessChain %36 %32 %17
         %39 = OpLoad %29 %38
         %40 = OpAccessChain %36 %32 %12
         %41 = OpLoad %29 %40
         %42 = OpFAdd %29 %39 %41
               OpStore %37 %42
         %44 = OpAccessChain %36 %32 %17
         %45 = OpLoad %29 %44
         %46 = OpAccessChain %36 %32 %12
         %47 = OpLoad %29 %46
         %48 = OpFMul %29 %45 %47
               OpStore %43 %48
               OpStore %52 %59
         %62 = OpAccessChain %61 %52 %17
         %63 = OpLoad %49 %62
         %64 = OpAccessChain %61 %52 %12
         %65 = OpLoad %49 %64
         %66 = OpDot %29 %63 %65
               OpStore %60 %66
               OpReturn
               OpFunctionEnd
  )";

  // Assemble the module and make sure the starting point is itself valid.
  const auto env = SPV_ENV_UNIVERSAL_1_5;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(context.get()), validator_options);

  // For each commutative instruction: confirm the precondition holds before
  // mutating the module, so that a precondition regression is reported here
  // rather than surfacing as an opaque failure while applying.
  auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
  auto transformation =
      TransformationSwapCommutableOperands(instructionDescriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
  transformation = TransformationSwapCommutableOperands(instructionDescriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
  transformation = TransformationSwapCommutableOperands(instructionDescriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
  transformation = TransformationSwapCommutableOperands(instructionDescriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
  transformation = TransformationSwapCommutableOperands(instructionDescriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  // The transformed module must still be valid, not merely textually equal to
  // the expectation below.
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));

  // Expected result: identical to the input except that the two operands of
  // %22, %28, %42, %48 and %66 are exchanged.
  std::string variantShader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypeInt 32 0
          %8 = OpConstant %7 2
          %9 = OpTypeArray %6 %8
         %10 = OpTypePointer Function %9
         %12 = OpConstant %6 1
         %13 = OpConstant %6 2
         %14 = OpConstantComposite %9 %12 %13
         %15 = OpTypePointer Function %6
         %17 = OpConstant %6 0
         %29 = OpTypeFloat 32
         %30 = OpTypeArray %29 %8
         %31 = OpTypePointer Function %30
         %33 = OpConstant %29 1
         %34 = OpConstant %29 2
         %35 = OpConstantComposite %30 %33 %34
         %36 = OpTypePointer Function %29
         %49 = OpTypeVector %29 3
         %50 = OpTypeArray %49 %8
         %51 = OpTypePointer Function %50
         %53 = OpConstant %29 3
         %54 = OpConstantComposite %49 %33 %34 %53
         %55 = OpConstant %29 4
         %56 = OpConstant %29 5
         %57 = OpConstant %29 6
         %58 = OpConstantComposite %49 %55 %56 %57
         %59 = OpConstantComposite %50 %54 %58
         %61 = OpTypePointer Function %49
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpVariable %10 Function
         %16 = OpVariable %15 Function
         %23 = OpVariable %15 Function
         %32 = OpVariable %31 Function
         %37 = OpVariable %36 Function
         %43 = OpVariable %36 Function
         %52 = OpVariable %51 Function
         %60 = OpVariable %36 Function
               OpStore %11 %14
         %18 = OpAccessChain %15 %11 %17
         %19 = OpLoad %6 %18
         %20 = OpAccessChain %15 %11 %12
         %21 = OpLoad %6 %20
         %22 = OpIAdd %6 %21 %19
               OpStore %16 %22
         %24 = OpAccessChain %15 %11 %17
         %25 = OpLoad %6 %24
         %26 = OpAccessChain %15 %11 %12
         %27 = OpLoad %6 %26
         %28 = OpIMul %6 %27 %25
               OpStore %23 %28
               OpStore %32 %35
         %38 = OpAccessChain %36 %32 %17
         %39 = OpLoad %29 %38
         %40 = OpAccessChain %36 %32 %12
         %41 = OpLoad %29 %40
         %42 = OpFAdd %29 %41 %39
               OpStore %37 %42
         %44 = OpAccessChain %36 %32 %17
         %45 = OpLoad %29 %44
         %46 = OpAccessChain %36 %32 %12
         %47 = OpLoad %29 %46
         %48 = OpFMul %29 %47 %45
               OpStore %43 %48
               OpStore %52 %59
         %62 = OpAccessChain %61 %52 %17
         %63 = OpLoad %49 %62
         %64 = OpAccessChain %61 %52 %12
         %65 = OpLoad %49 %64
         %66 = OpDot %29 %65 %63
               OpStore %60 %66
               OpReturn
               OpFunctionEnd
  )";

  ASSERT_TRUE(IsEqual(env, variantShader, context.get()));
}
453
454 } // namespace
455 } // namespace fuzz
456 } // namespace spvtools
457