• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_swap_commutable_operands.h"
16 #include "source/fuzz/instruction_descriptor.h"
17 #include "test/fuzz/fuzz_test_util.h"
18 
19 namespace spvtools {
20 namespace fuzz {
21 namespace {
22 
TEST(TransformationSwapCommutableOperandsTest,IsApplicableTest)23 TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest) {
24   std::string shader = R"(
25                OpCapability Shader
26           %1 = OpExtInstImport "GLSL.std.450"
27                OpMemoryModel Logical GLSL450
28                OpEntryPoint Fragment %4 "main"
29                OpExecutionMode %4 OriginUpperLeft
30                OpSource ESSL 310
31                OpName %4 "main"
32           %2 = OpTypeVoid
33           %3 = OpTypeFunction %2
34           %6 = OpTypeInt 32 1
35           %7 = OpTypeInt 32 0
36           %8 = OpConstant %7 2
37           %9 = OpTypeArray %6 %8
38          %10 = OpTypePointer Function %9
39          %12 = OpConstant %6 1
40          %13 = OpConstant %6 2
41          %14 = OpConstantComposite %9 %12 %13
42          %15 = OpTypePointer Function %6
43          %17 = OpConstant %6 0
44          %29 = OpTypeFloat 32
45          %30 = OpTypeArray %29 %8
46          %31 = OpTypePointer Function %30
47          %33 = OpConstant %29 1
48          %34 = OpConstant %29 2
49          %35 = OpConstantComposite %30 %33 %34
50          %36 = OpTypePointer Function %29
51          %49 = OpTypeVector %29 3
52          %50 = OpTypeArray %49 %8
53          %51 = OpTypePointer Function %50
54          %53 = OpConstant %29 3
55          %54 = OpConstantComposite %49 %33 %34 %53
56          %55 = OpConstant %29 4
57          %56 = OpConstant %29 5
58          %57 = OpConstant %29 6
59          %58 = OpConstantComposite %49 %55 %56 %57
60          %59 = OpConstantComposite %50 %54 %58
61          %61 = OpTypePointer Function %49
62           %4 = OpFunction %2 None %3
63           %5 = OpLabel
64          %11 = OpVariable %10 Function
65          %16 = OpVariable %15 Function
66          %23 = OpVariable %15 Function
67          %32 = OpVariable %31 Function
68          %37 = OpVariable %36 Function
69          %43 = OpVariable %36 Function
70          %52 = OpVariable %51 Function
71          %60 = OpVariable %36 Function
72                OpStore %11 %14
73          %18 = OpAccessChain %15 %11 %17
74          %19 = OpLoad %6 %18
75          %20 = OpAccessChain %15 %11 %12
76          %21 = OpLoad %6 %20
77          %22 = OpIAdd %6 %19 %21
78                OpStore %16 %22
79          %24 = OpAccessChain %15 %11 %17
80          %25 = OpLoad %6 %24
81          %26 = OpAccessChain %15 %11 %12
82          %27 = OpLoad %6 %26
83          %28 = OpIMul %6 %25 %27
84                OpStore %23 %28
85                OpStore %32 %35
86          %38 = OpAccessChain %36 %32 %17
87          %39 = OpLoad %29 %38
88          %40 = OpAccessChain %36 %32 %12
89          %41 = OpLoad %29 %40
90          %42 = OpFAdd %29 %39 %41
91                OpStore %37 %42
92          %44 = OpAccessChain %36 %32 %17
93          %45 = OpLoad %29 %44
94          %46 = OpAccessChain %36 %32 %12
95          %47 = OpLoad %29 %46
96          %48 = OpFMul %29 %45 %47
97                OpStore %43 %48
98                OpStore %52 %59
99          %62 = OpAccessChain %61 %52 %17
100          %63 = OpLoad %49 %62
101          %64 = OpAccessChain %61 %52 %12
102          %65 = OpLoad %49 %64
103          %66 = OpDot %29 %63 %65
104                OpStore %60 %66
105                OpReturn
106                OpFunctionEnd
107   )";
108 
109   const auto env = SPV_ENV_UNIVERSAL_1_5;
110   const auto consumer = nullptr;
111   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
112   ASSERT_TRUE(IsValid(env, context.get()));
113 
114   FactManager fact_manager;
115   spvtools::ValidatorOptions validator_options;
116   TransformationContext transformation_context(&fact_manager,
117                                                validator_options);
118 
119   // Tests existing commutative instructions
120   auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
121   auto transformation =
122       TransformationSwapCommutableOperands(instructionDescriptor);
123   ASSERT_TRUE(
124       transformation.IsApplicable(context.get(), transformation_context));
125 
126   instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
127   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
128   ASSERT_TRUE(
129       transformation.IsApplicable(context.get(), transformation_context));
130 
131   instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
132   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
133   ASSERT_TRUE(
134       transformation.IsApplicable(context.get(), transformation_context));
135 
136   instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
137   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
138   ASSERT_TRUE(
139       transformation.IsApplicable(context.get(), transformation_context));
140 
141   instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
142   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
143   ASSERT_TRUE(
144       transformation.IsApplicable(context.get(), transformation_context));
145 
146   // Tests existing non-commutative instructions
147   instructionDescriptor = MakeInstructionDescriptor(1, SpvOpExtInstImport, 0);
148   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
149   ASSERT_FALSE(
150       transformation.IsApplicable(context.get(), transformation_context));
151 
152   instructionDescriptor = MakeInstructionDescriptor(5, SpvOpLabel, 0);
153   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
154   ASSERT_FALSE(
155       transformation.IsApplicable(context.get(), transformation_context));
156 
157   instructionDescriptor = MakeInstructionDescriptor(8, SpvOpConstant, 0);
158   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
159   ASSERT_FALSE(
160       transformation.IsApplicable(context.get(), transformation_context));
161 
162   instructionDescriptor = MakeInstructionDescriptor(11, SpvOpVariable, 0);
163   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
164   ASSERT_FALSE(
165       transformation.IsApplicable(context.get(), transformation_context));
166 
167   instructionDescriptor =
168       MakeInstructionDescriptor(14, SpvOpConstantComposite, 0);
169   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
170   ASSERT_FALSE(
171       transformation.IsApplicable(context.get(), transformation_context));
172 
173   // Tests the base instruction id not existing
174   instructionDescriptor = MakeInstructionDescriptor(67, SpvOpIAddCarry, 0);
175   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
176   ASSERT_FALSE(
177       transformation.IsApplicable(context.get(), transformation_context));
178 
179   instructionDescriptor = MakeInstructionDescriptor(68, SpvOpIEqual, 0);
180   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
181   ASSERT_FALSE(
182       transformation.IsApplicable(context.get(), transformation_context));
183 
184   instructionDescriptor = MakeInstructionDescriptor(69, SpvOpINotEqual, 0);
185   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
186   ASSERT_FALSE(
187       transformation.IsApplicable(context.get(), transformation_context));
188 
189   instructionDescriptor = MakeInstructionDescriptor(70, SpvOpFOrdEqual, 0);
190   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
191   ASSERT_FALSE(
192       transformation.IsApplicable(context.get(), transformation_context));
193 
194   instructionDescriptor = MakeInstructionDescriptor(71, SpvOpPtrEqual, 0);
195   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
196   ASSERT_FALSE(
197       transformation.IsApplicable(context.get(), transformation_context));
198 
199   // Tests there being no instruction with the desired opcode after the base
200   // instruction id
201   instructionDescriptor = MakeInstructionDescriptor(24, SpvOpIAdd, 0);
202   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
203   ASSERT_FALSE(
204       transformation.IsApplicable(context.get(), transformation_context));
205 
206   instructionDescriptor = MakeInstructionDescriptor(38, SpvOpIMul, 0);
207   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
208   ASSERT_FALSE(
209       transformation.IsApplicable(context.get(), transformation_context));
210 
211   instructionDescriptor = MakeInstructionDescriptor(45, SpvOpFAdd, 0);
212   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
213   ASSERT_FALSE(
214       transformation.IsApplicable(context.get(), transformation_context));
215 
216   instructionDescriptor = MakeInstructionDescriptor(66, SpvOpFMul, 0);
217   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
218   ASSERT_FALSE(
219       transformation.IsApplicable(context.get(), transformation_context));
220 
221   // Tests there being an instruction with the desired opcode after the base
222   // instruction id, but the skip count associated with the instruction
223   // descriptor being so high.
224   instructionDescriptor = MakeInstructionDescriptor(11, SpvOpIAdd, 100);
225   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
226   ASSERT_FALSE(
227       transformation.IsApplicable(context.get(), transformation_context));
228 
229   instructionDescriptor = MakeInstructionDescriptor(16, SpvOpIMul, 100);
230   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
231   ASSERT_FALSE(
232       transformation.IsApplicable(context.get(), transformation_context));
233 
234   instructionDescriptor = MakeInstructionDescriptor(23, SpvOpFAdd, 100);
235   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
236   ASSERT_FALSE(
237       transformation.IsApplicable(context.get(), transformation_context));
238 
239   instructionDescriptor = MakeInstructionDescriptor(32, SpvOpFMul, 100);
240   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
241   ASSERT_FALSE(
242       transformation.IsApplicable(context.get(), transformation_context));
243 
244   instructionDescriptor = MakeInstructionDescriptor(37, SpvOpDot, 100);
245   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
246   ASSERT_FALSE(
247       transformation.IsApplicable(context.get(), transformation_context));
248 }
249 
TEST(TransformationSwapCommutableOperandsTest,ApplyTest)250 TEST(TransformationSwapCommutableOperandsTest, ApplyTest) {
251   std::string shader = R"(
252                OpCapability Shader
253           %1 = OpExtInstImport "GLSL.std.450"
254                OpMemoryModel Logical GLSL450
255                OpEntryPoint Fragment %4 "main"
256                OpExecutionMode %4 OriginUpperLeft
257                OpSource ESSL 310
258                OpName %4 "main"
259           %2 = OpTypeVoid
260           %3 = OpTypeFunction %2
261           %6 = OpTypeInt 32 1
262           %7 = OpTypeInt 32 0
263           %8 = OpConstant %7 2
264           %9 = OpTypeArray %6 %8
265          %10 = OpTypePointer Function %9
266          %12 = OpConstant %6 1
267          %13 = OpConstant %6 2
268          %14 = OpConstantComposite %9 %12 %13
269          %15 = OpTypePointer Function %6
270          %17 = OpConstant %6 0
271          %29 = OpTypeFloat 32
272          %30 = OpTypeArray %29 %8
273          %31 = OpTypePointer Function %30
274          %33 = OpConstant %29 1
275          %34 = OpConstant %29 2
276          %35 = OpConstantComposite %30 %33 %34
277          %36 = OpTypePointer Function %29
278          %49 = OpTypeVector %29 3
279          %50 = OpTypeArray %49 %8
280          %51 = OpTypePointer Function %50
281          %53 = OpConstant %29 3
282          %54 = OpConstantComposite %49 %33 %34 %53
283          %55 = OpConstant %29 4
284          %56 = OpConstant %29 5
285          %57 = OpConstant %29 6
286          %58 = OpConstantComposite %49 %55 %56 %57
287          %59 = OpConstantComposite %50 %54 %58
288          %61 = OpTypePointer Function %49
289           %4 = OpFunction %2 None %3
290           %5 = OpLabel
291          %11 = OpVariable %10 Function
292          %16 = OpVariable %15 Function
293          %23 = OpVariable %15 Function
294          %32 = OpVariable %31 Function
295          %37 = OpVariable %36 Function
296          %43 = OpVariable %36 Function
297          %52 = OpVariable %51 Function
298          %60 = OpVariable %36 Function
299                OpStore %11 %14
300          %18 = OpAccessChain %15 %11 %17
301          %19 = OpLoad %6 %18
302          %20 = OpAccessChain %15 %11 %12
303          %21 = OpLoad %6 %20
304          %22 = OpIAdd %6 %19 %21
305                OpStore %16 %22
306          %24 = OpAccessChain %15 %11 %17
307          %25 = OpLoad %6 %24
308          %26 = OpAccessChain %15 %11 %12
309          %27 = OpLoad %6 %26
310          %28 = OpIMul %6 %25 %27
311                OpStore %23 %28
312                OpStore %32 %35
313          %38 = OpAccessChain %36 %32 %17
314          %39 = OpLoad %29 %38
315          %40 = OpAccessChain %36 %32 %12
316          %41 = OpLoad %29 %40
317          %42 = OpFAdd %29 %39 %41
318                OpStore %37 %42
319          %44 = OpAccessChain %36 %32 %17
320          %45 = OpLoad %29 %44
321          %46 = OpAccessChain %36 %32 %12
322          %47 = OpLoad %29 %46
323          %48 = OpFMul %29 %45 %47
324                OpStore %43 %48
325                OpStore %52 %59
326          %62 = OpAccessChain %61 %52 %17
327          %63 = OpLoad %49 %62
328          %64 = OpAccessChain %61 %52 %12
329          %65 = OpLoad %49 %64
330          %66 = OpDot %29 %63 %65
331                OpStore %60 %66
332                OpReturn
333                OpFunctionEnd
334   )";
335 
336   const auto env = SPV_ENV_UNIVERSAL_1_5;
337   const auto consumer = nullptr;
338   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
339   ASSERT_TRUE(IsValid(env, context.get()));
340 
341   FactManager fact_manager;
342   spvtools::ValidatorOptions validator_options;
343   TransformationContext transformation_context(&fact_manager,
344                                                validator_options);
345 
346   auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
347   auto transformation =
348       TransformationSwapCommutableOperands(instructionDescriptor);
349   transformation.Apply(context.get(), &transformation_context);
350 
351   instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
352   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
353   transformation.Apply(context.get(), &transformation_context);
354 
355   instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
356   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
357   transformation.Apply(context.get(), &transformation_context);
358 
359   instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
360   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
361   transformation.Apply(context.get(), &transformation_context);
362 
363   instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
364   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
365   transformation.Apply(context.get(), &transformation_context);
366 
367   std::string variantShader = R"(
368                OpCapability Shader
369           %1 = OpExtInstImport "GLSL.std.450"
370                OpMemoryModel Logical GLSL450
371                OpEntryPoint Fragment %4 "main"
372                OpExecutionMode %4 OriginUpperLeft
373                OpSource ESSL 310
374                OpName %4 "main"
375           %2 = OpTypeVoid
376           %3 = OpTypeFunction %2
377           %6 = OpTypeInt 32 1
378           %7 = OpTypeInt 32 0
379           %8 = OpConstant %7 2
380           %9 = OpTypeArray %6 %8
381          %10 = OpTypePointer Function %9
382          %12 = OpConstant %6 1
383          %13 = OpConstant %6 2
384          %14 = OpConstantComposite %9 %12 %13
385          %15 = OpTypePointer Function %6
386          %17 = OpConstant %6 0
387          %29 = OpTypeFloat 32
388          %30 = OpTypeArray %29 %8
389          %31 = OpTypePointer Function %30
390          %33 = OpConstant %29 1
391          %34 = OpConstant %29 2
392          %35 = OpConstantComposite %30 %33 %34
393          %36 = OpTypePointer Function %29
394          %49 = OpTypeVector %29 3
395          %50 = OpTypeArray %49 %8
396          %51 = OpTypePointer Function %50
397          %53 = OpConstant %29 3
398          %54 = OpConstantComposite %49 %33 %34 %53
399          %55 = OpConstant %29 4
400          %56 = OpConstant %29 5
401          %57 = OpConstant %29 6
402          %58 = OpConstantComposite %49 %55 %56 %57
403          %59 = OpConstantComposite %50 %54 %58
404          %61 = OpTypePointer Function %49
405           %4 = OpFunction %2 None %3
406           %5 = OpLabel
407          %11 = OpVariable %10 Function
408          %16 = OpVariable %15 Function
409          %23 = OpVariable %15 Function
410          %32 = OpVariable %31 Function
411          %37 = OpVariable %36 Function
412          %43 = OpVariable %36 Function
413          %52 = OpVariable %51 Function
414          %60 = OpVariable %36 Function
415                OpStore %11 %14
416          %18 = OpAccessChain %15 %11 %17
417          %19 = OpLoad %6 %18
418          %20 = OpAccessChain %15 %11 %12
419          %21 = OpLoad %6 %20
420          %22 = OpIAdd %6 %21 %19
421                OpStore %16 %22
422          %24 = OpAccessChain %15 %11 %17
423          %25 = OpLoad %6 %24
424          %26 = OpAccessChain %15 %11 %12
425          %27 = OpLoad %6 %26
426          %28 = OpIMul %6 %27 %25
427                OpStore %23 %28
428                OpStore %32 %35
429          %38 = OpAccessChain %36 %32 %17
430          %39 = OpLoad %29 %38
431          %40 = OpAccessChain %36 %32 %12
432          %41 = OpLoad %29 %40
433          %42 = OpFAdd %29 %41 %39
434                OpStore %37 %42
435          %44 = OpAccessChain %36 %32 %17
436          %45 = OpLoad %29 %44
437          %46 = OpAccessChain %36 %32 %12
438          %47 = OpLoad %29 %46
439          %48 = OpFMul %29 %47 %45
440                OpStore %43 %48
441                OpStore %52 %59
442          %62 = OpAccessChain %61 %52 %17
443          %63 = OpLoad %49 %62
444          %64 = OpAccessChain %61 %52 %12
445          %65 = OpLoad %49 %64
446          %66 = OpDot %29 %65 %63
447                OpStore %60 %66
448                OpReturn
449                OpFunctionEnd
450   )";
451 
452   ASSERT_TRUE(IsEqual(env, variantShader, context.get()));
453 }
454 
455 }  // namespace
456 }  // namespace fuzz
457 }  // namespace spvtools
458