1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_swap_commutable_operands.h"
16 #include "source/fuzz/instruction_descriptor.h"
17 #include "test/fuzz/fuzz_test_util.h"
18
19 namespace spvtools {
20 namespace fuzz {
21 namespace {
22
// Checks TransformationSwapCommutableOperands::IsApplicable.
// - It must accept the commutative binary instructions in the shader: OpIAdd,
//   OpIMul, OpFAdd, OpFMul and OpDot.
// - It must reject descriptors that name a non-commutative instruction, a
//   missing base id, a base id with no matching opcode after it, or a skip
//   count that is too large.
TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest) {
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypeInt 32 0
          %8 = OpConstant %7 2
          %9 = OpTypeArray %6 %8
         %10 = OpTypePointer Function %9
         %12 = OpConstant %6 1
         %13 = OpConstant %6 2
         %14 = OpConstantComposite %9 %12 %13
         %15 = OpTypePointer Function %6
         %17 = OpConstant %6 0
         %29 = OpTypeFloat 32
         %30 = OpTypeArray %29 %8
         %31 = OpTypePointer Function %30
         %33 = OpConstant %29 1
         %34 = OpConstant %29 2
         %35 = OpConstantComposite %30 %33 %34
         %36 = OpTypePointer Function %29
         %49 = OpTypeVector %29 3
         %50 = OpTypeArray %49 %8
         %51 = OpTypePointer Function %50
         %53 = OpConstant %29 3
         %54 = OpConstantComposite %49 %33 %34 %53
         %55 = OpConstant %29 4
         %56 = OpConstant %29 5
         %57 = OpConstant %29 6
         %58 = OpConstantComposite %49 %55 %56 %57
         %59 = OpConstantComposite %50 %54 %58
         %61 = OpTypePointer Function %49
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpVariable %10 Function
         %16 = OpVariable %15 Function
         %23 = OpVariable %15 Function
         %32 = OpVariable %31 Function
         %37 = OpVariable %36 Function
         %43 = OpVariable %36 Function
         %52 = OpVariable %51 Function
         %60 = OpVariable %36 Function
               OpStore %11 %14
         %18 = OpAccessChain %15 %11 %17
         %19 = OpLoad %6 %18
         %20 = OpAccessChain %15 %11 %12
         %21 = OpLoad %6 %20
         %22 = OpIAdd %6 %19 %21
               OpStore %16 %22
         %24 = OpAccessChain %15 %11 %17
         %25 = OpLoad %6 %24
         %26 = OpAccessChain %15 %11 %12
         %27 = OpLoad %6 %26
         %28 = OpIMul %6 %25 %27
               OpStore %23 %28
               OpStore %32 %35
         %38 = OpAccessChain %36 %32 %17
         %39 = OpLoad %29 %38
         %40 = OpAccessChain %36 %32 %12
         %41 = OpLoad %29 %40
         %42 = OpFAdd %29 %39 %41
               OpStore %37 %42
         %44 = OpAccessChain %36 %32 %17
         %45 = OpLoad %29 %44
         %46 = OpAccessChain %36 %32 %12
         %47 = OpLoad %29 %46
         %48 = OpFMul %29 %45 %47
               OpStore %43 %48
               OpStore %52 %59
         %62 = OpAccessChain %61 %52 %17
         %63 = OpLoad %49 %62
         %64 = OpAccessChain %61 %52 %12
         %65 = OpLoad %49 %64
         %66 = OpDot %29 %63 %65
               OpStore %60 %66
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_5;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  ASSERT_TRUE(IsValid(env, context.get()));

  FactManager fact_manager;
  spvtools::ValidatorOptions validator_options;
  TransformationContext transformation_context(&fact_manager,
                                               validator_options);

  // Tests existing commutative instructions
  auto instruction_descriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
  auto transformation =
      TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));

  // Tests a commutative instruction that is reached from a preceding base
  // instruction (%21) rather than being the base instruction itself
  instruction_descriptor = MakeInstructionDescriptor(21, SpvOpIAdd, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));

  // Tests existing non-commutative instructions
  instruction_descriptor = MakeInstructionDescriptor(1, SpvOpExtInstImport, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(5, SpvOpLabel, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(8, SpvOpConstant, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(11, SpvOpVariable, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor =
      MakeInstructionDescriptor(14, SpvOpConstantComposite, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  // Tests non-commutative instructions that live inside the body of function
  // %4, next to the commutative ones, so that only the opcode differs
  instruction_descriptor = MakeInstructionDescriptor(18, SpvOpAccessChain, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(19, SpvOpLoad, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  // OpStore has no result id, so it is addressed from the preceding %16
  instruction_descriptor = MakeInstructionDescriptor(16, SpvOpStore, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  // Tests the base instruction id not existing
  instruction_descriptor = MakeInstructionDescriptor(67, SpvOpIAddCarry, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(68, SpvOpIEqual, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(69, SpvOpINotEqual, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(70, SpvOpFOrdEqual, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(71, SpvOpPtrEqual, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  // Tests there being no instruction with the desired opcode after the base
  // instruction id
  instruction_descriptor = MakeInstructionDescriptor(24, SpvOpIAdd, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(38, SpvOpIMul, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(45, SpvOpFAdd, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(66, SpvOpFMul, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  // Tests there being an instruction with the desired opcode after the base
  // instruction id, but the skip count associated with the instruction
  // descriptor being too high.
  instruction_descriptor = MakeInstructionDescriptor(11, SpvOpIAdd, 100);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(16, SpvOpIMul, 100);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(23, SpvOpFAdd, 100);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(32, SpvOpFMul, 100);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));

  instruction_descriptor = MakeInstructionDescriptor(37, SpvOpDot, 100);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_FALSE(
      transformation.IsApplicable(context.get(), transformation_context));
}
249
// Applies TransformationSwapCommutableOperands to every commutative binary
// instruction in the shader (OpIAdd, OpIMul, OpFAdd, OpFMul, OpDot).
// It checks three things:
// - Each transformation is applicable before it is applied.
// - The resulting module is still valid.
// - The module matches the expected text, in which the two operands of each of
//   those instructions are exchanged.
TEST(TransformationSwapCommutableOperandsTest, ApplyTest) {
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypeInt 32 0
          %8 = OpConstant %7 2
          %9 = OpTypeArray %6 %8
         %10 = OpTypePointer Function %9
         %12 = OpConstant %6 1
         %13 = OpConstant %6 2
         %14 = OpConstantComposite %9 %12 %13
         %15 = OpTypePointer Function %6
         %17 = OpConstant %6 0
         %29 = OpTypeFloat 32
         %30 = OpTypeArray %29 %8
         %31 = OpTypePointer Function %30
         %33 = OpConstant %29 1
         %34 = OpConstant %29 2
         %35 = OpConstantComposite %30 %33 %34
         %36 = OpTypePointer Function %29
         %49 = OpTypeVector %29 3
         %50 = OpTypeArray %49 %8
         %51 = OpTypePointer Function %50
         %53 = OpConstant %29 3
         %54 = OpConstantComposite %49 %33 %34 %53
         %55 = OpConstant %29 4
         %56 = OpConstant %29 5
         %57 = OpConstant %29 6
         %58 = OpConstantComposite %49 %55 %56 %57
         %59 = OpConstantComposite %50 %54 %58
         %61 = OpTypePointer Function %49
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpVariable %10 Function
         %16 = OpVariable %15 Function
         %23 = OpVariable %15 Function
         %32 = OpVariable %31 Function
         %37 = OpVariable %36 Function
         %43 = OpVariable %36 Function
         %52 = OpVariable %51 Function
         %60 = OpVariable %36 Function
               OpStore %11 %14
         %18 = OpAccessChain %15 %11 %17
         %19 = OpLoad %6 %18
         %20 = OpAccessChain %15 %11 %12
         %21 = OpLoad %6 %20
         %22 = OpIAdd %6 %19 %21
               OpStore %16 %22
         %24 = OpAccessChain %15 %11 %17
         %25 = OpLoad %6 %24
         %26 = OpAccessChain %15 %11 %12
         %27 = OpLoad %6 %26
         %28 = OpIMul %6 %25 %27
               OpStore %23 %28
               OpStore %32 %35
         %38 = OpAccessChain %36 %32 %17
         %39 = OpLoad %29 %38
         %40 = OpAccessChain %36 %32 %12
         %41 = OpLoad %29 %40
         %42 = OpFAdd %29 %39 %41
               OpStore %37 %42
         %44 = OpAccessChain %36 %32 %17
         %45 = OpLoad %29 %44
         %46 = OpAccessChain %36 %32 %12
         %47 = OpLoad %29 %46
         %48 = OpFMul %29 %45 %47
               OpStore %43 %48
               OpStore %52 %59
         %62 = OpAccessChain %61 %52 %17
         %63 = OpLoad %49 %62
         %64 = OpAccessChain %61 %52 %12
         %65 = OpLoad %49 %64
         %66 = OpDot %29 %63 %65
               OpStore %60 %66
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_5;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  ASSERT_TRUE(IsValid(env, context.get()));

  FactManager fact_manager;
  spvtools::ValidatorOptions validator_options;
  TransformationContext transformation_context(&fact_manager,
                                               validator_options);

  // Each transformation is checked for applicability first: Apply is only
  // meant to be called on transformations whose IsApplicable holds.
  auto instruction_descriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
  auto transformation =
      TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  transformation.Apply(context.get(), &transformation_context);

  instruction_descriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  transformation.Apply(context.get(), &transformation_context);

  instruction_descriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  transformation.Apply(context.get(), &transformation_context);

  instruction_descriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  transformation.Apply(context.get(), &transformation_context);

  instruction_descriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
  transformation = TransformationSwapCommutableOperands(instruction_descriptor);
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  transformation.Apply(context.get(), &transformation_context);

  // The swapped module must still pass validation.
  ASSERT_TRUE(IsValid(env, context.get()));

  // Expected module: identical to the input except that the two operands of
  // %22, %28, %42, %48 and %66 are exchanged.
  std::string variant_shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypeInt 32 0
          %8 = OpConstant %7 2
          %9 = OpTypeArray %6 %8
         %10 = OpTypePointer Function %9
         %12 = OpConstant %6 1
         %13 = OpConstant %6 2
         %14 = OpConstantComposite %9 %12 %13
         %15 = OpTypePointer Function %6
         %17 = OpConstant %6 0
         %29 = OpTypeFloat 32
         %30 = OpTypeArray %29 %8
         %31 = OpTypePointer Function %30
         %33 = OpConstant %29 1
         %34 = OpConstant %29 2
         %35 = OpConstantComposite %30 %33 %34
         %36 = OpTypePointer Function %29
         %49 = OpTypeVector %29 3
         %50 = OpTypeArray %49 %8
         %51 = OpTypePointer Function %50
         %53 = OpConstant %29 3
         %54 = OpConstantComposite %49 %33 %34 %53
         %55 = OpConstant %29 4
         %56 = OpConstant %29 5
         %57 = OpConstant %29 6
         %58 = OpConstantComposite %49 %55 %56 %57
         %59 = OpConstantComposite %50 %54 %58
         %61 = OpTypePointer Function %49
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpVariable %10 Function
         %16 = OpVariable %15 Function
         %23 = OpVariable %15 Function
         %32 = OpVariable %31 Function
         %37 = OpVariable %36 Function
         %43 = OpVariable %36 Function
         %52 = OpVariable %51 Function
         %60 = OpVariable %36 Function
               OpStore %11 %14
         %18 = OpAccessChain %15 %11 %17
         %19 = OpLoad %6 %18
         %20 = OpAccessChain %15 %11 %12
         %21 = OpLoad %6 %20
         %22 = OpIAdd %6 %21 %19
               OpStore %16 %22
         %24 = OpAccessChain %15 %11 %17
         %25 = OpLoad %6 %24
         %26 = OpAccessChain %15 %11 %12
         %27 = OpLoad %6 %26
         %28 = OpIMul %6 %27 %25
               OpStore %23 %28
               OpStore %32 %35
         %38 = OpAccessChain %36 %32 %17
         %39 = OpLoad %29 %38
         %40 = OpAccessChain %36 %32 %12
         %41 = OpLoad %29 %40
         %42 = OpFAdd %29 %41 %39
               OpStore %37 %42
         %44 = OpAccessChain %36 %32 %17
         %45 = OpLoad %29 %44
         %46 = OpAccessChain %36 %32 %12
         %47 = OpLoad %29 %46
         %48 = OpFMul %29 %47 %45
               OpStore %43 %48
               OpStore %52 %59
         %62 = OpAccessChain %61 %52 %17
         %63 = OpLoad %49 %62
         %64 = OpAccessChain %61 %52 %12
         %65 = OpLoad %49 %64
         %66 = OpDot %29 %65 %63
               OpStore %60 %66
               OpReturn
               OpFunctionEnd
  )";

  ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
454
455 } // namespace
456 } // namespace fuzz
457 } // namespace spvtools
458