1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/reduce/merge_blocks_reduction_opportunity_finder.h"
16
17 #include "source/opt/build_module.h"
18 #include "source/reduce/reduction_opportunity.h"
19 #include "test/reduce/reduce_test_util.h"
20
21 namespace spvtools {
22 namespace reduce {
23 namespace {
24
TEST(MergeBlocksReductionPassTest,BasicCheck)25 TEST(MergeBlocksReductionPassTest, BasicCheck) {
26 std::string shader = R"(
27 OpCapability Shader
28 %1 = OpExtInstImport "GLSL.std.450"
29 OpMemoryModel Logical GLSL450
30 OpEntryPoint Fragment %4 "main"
31 OpExecutionMode %4 OriginUpperLeft
32 OpSource ESSL 310
33 OpName %4 "main"
34 OpName %8 "x"
35 %2 = OpTypeVoid
36 %3 = OpTypeFunction %2
37 %6 = OpTypeInt 32 1
38 %7 = OpTypePointer Function %6
39 %9 = OpConstant %6 1
40 %10 = OpConstant %6 2
41 %11 = OpConstant %6 3
42 %12 = OpConstant %6 4
43 %4 = OpFunction %2 None %3
44 %5 = OpLabel
45 %8 = OpVariable %7 Function
46 OpBranch %13
47 %13 = OpLabel
48 OpStore %8 %9
49 OpBranch %14
50 %14 = OpLabel
51 OpStore %8 %10
52 OpBranch %15
53 %15 = OpLabel
54 OpStore %8 %11
55 OpBranch %16
56 %16 = OpLabel
57 OpStore %8 %12
58 OpBranch %17
59 %17 = OpLabel
60 OpReturn
61 OpFunctionEnd
62 )";
63 const auto env = SPV_ENV_UNIVERSAL_1_3;
64 const auto consumer = nullptr;
65 const auto context =
66 BuildModule(env, consumer, shader, kReduceAssembleOption);
67 const auto ops =
68 MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
69 context.get());
70 ASSERT_EQ(5, ops.size());
71
72 // Try order 3, 0, 2, 4, 1
73
74 ASSERT_TRUE(ops[3]->PreconditionHolds());
75 ops[3]->TryToApply();
76
77 std::string after_op_3 = R"(
78 OpCapability Shader
79 %1 = OpExtInstImport "GLSL.std.450"
80 OpMemoryModel Logical GLSL450
81 OpEntryPoint Fragment %4 "main"
82 OpExecutionMode %4 OriginUpperLeft
83 OpSource ESSL 310
84 OpName %4 "main"
85 OpName %8 "x"
86 %2 = OpTypeVoid
87 %3 = OpTypeFunction %2
88 %6 = OpTypeInt 32 1
89 %7 = OpTypePointer Function %6
90 %9 = OpConstant %6 1
91 %10 = OpConstant %6 2
92 %11 = OpConstant %6 3
93 %12 = OpConstant %6 4
94 %4 = OpFunction %2 None %3
95 %5 = OpLabel
96 %8 = OpVariable %7 Function
97 OpBranch %13
98 %13 = OpLabel
99 OpStore %8 %9
100 OpBranch %14
101 %14 = OpLabel
102 OpStore %8 %10
103 OpBranch %15
104 %15 = OpLabel
105 OpStore %8 %11
106 OpStore %8 %12
107 OpBranch %17
108 %17 = OpLabel
109 OpReturn
110 OpFunctionEnd
111 )";
112
113 CheckEqual(env, after_op_3, context.get());
114
115 ASSERT_TRUE(ops[0]->PreconditionHolds());
116 ops[0]->TryToApply();
117
118 std::string after_op_0 = R"(
119 OpCapability Shader
120 %1 = OpExtInstImport "GLSL.std.450"
121 OpMemoryModel Logical GLSL450
122 OpEntryPoint Fragment %4 "main"
123 OpExecutionMode %4 OriginUpperLeft
124 OpSource ESSL 310
125 OpName %4 "main"
126 OpName %8 "x"
127 %2 = OpTypeVoid
128 %3 = OpTypeFunction %2
129 %6 = OpTypeInt 32 1
130 %7 = OpTypePointer Function %6
131 %9 = OpConstant %6 1
132 %10 = OpConstant %6 2
133 %11 = OpConstant %6 3
134 %12 = OpConstant %6 4
135 %4 = OpFunction %2 None %3
136 %5 = OpLabel
137 %8 = OpVariable %7 Function
138 OpStore %8 %9
139 OpBranch %14
140 %14 = OpLabel
141 OpStore %8 %10
142 OpBranch %15
143 %15 = OpLabel
144 OpStore %8 %11
145 OpStore %8 %12
146 OpBranch %17
147 %17 = OpLabel
148 OpReturn
149 OpFunctionEnd
150 )";
151
152 CheckEqual(env, after_op_0, context.get());
153
154 ASSERT_TRUE(ops[2]->PreconditionHolds());
155 ops[2]->TryToApply();
156
157 std::string after_op_2 = R"(
158 OpCapability Shader
159 %1 = OpExtInstImport "GLSL.std.450"
160 OpMemoryModel Logical GLSL450
161 OpEntryPoint Fragment %4 "main"
162 OpExecutionMode %4 OriginUpperLeft
163 OpSource ESSL 310
164 OpName %4 "main"
165 OpName %8 "x"
166 %2 = OpTypeVoid
167 %3 = OpTypeFunction %2
168 %6 = OpTypeInt 32 1
169 %7 = OpTypePointer Function %6
170 %9 = OpConstant %6 1
171 %10 = OpConstant %6 2
172 %11 = OpConstant %6 3
173 %12 = OpConstant %6 4
174 %4 = OpFunction %2 None %3
175 %5 = OpLabel
176 %8 = OpVariable %7 Function
177 OpStore %8 %9
178 OpBranch %14
179 %14 = OpLabel
180 OpStore %8 %10
181 OpStore %8 %11
182 OpStore %8 %12
183 OpBranch %17
184 %17 = OpLabel
185 OpReturn
186 OpFunctionEnd
187 )";
188
189 CheckEqual(env, after_op_2, context.get());
190
191 ASSERT_TRUE(ops[4]->PreconditionHolds());
192 ops[4]->TryToApply();
193
194 std::string after_op_4 = R"(
195 OpCapability Shader
196 %1 = OpExtInstImport "GLSL.std.450"
197 OpMemoryModel Logical GLSL450
198 OpEntryPoint Fragment %4 "main"
199 OpExecutionMode %4 OriginUpperLeft
200 OpSource ESSL 310
201 OpName %4 "main"
202 OpName %8 "x"
203 %2 = OpTypeVoid
204 %3 = OpTypeFunction %2
205 %6 = OpTypeInt 32 1
206 %7 = OpTypePointer Function %6
207 %9 = OpConstant %6 1
208 %10 = OpConstant %6 2
209 %11 = OpConstant %6 3
210 %12 = OpConstant %6 4
211 %4 = OpFunction %2 None %3
212 %5 = OpLabel
213 %8 = OpVariable %7 Function
214 OpStore %8 %9
215 OpBranch %14
216 %14 = OpLabel
217 OpStore %8 %10
218 OpStore %8 %11
219 OpStore %8 %12
220 OpReturn
221 OpFunctionEnd
222 )";
223
224 CheckEqual(env, after_op_4, context.get());
225
226 ASSERT_TRUE(ops[1]->PreconditionHolds());
227 ops[1]->TryToApply();
228
229 std::string after_op_1 = R"(
230 OpCapability Shader
231 %1 = OpExtInstImport "GLSL.std.450"
232 OpMemoryModel Logical GLSL450
233 OpEntryPoint Fragment %4 "main"
234 OpExecutionMode %4 OriginUpperLeft
235 OpSource ESSL 310
236 OpName %4 "main"
237 OpName %8 "x"
238 %2 = OpTypeVoid
239 %3 = OpTypeFunction %2
240 %6 = OpTypeInt 32 1
241 %7 = OpTypePointer Function %6
242 %9 = OpConstant %6 1
243 %10 = OpConstant %6 2
244 %11 = OpConstant %6 3
245 %12 = OpConstant %6 4
246 %4 = OpFunction %2 None %3
247 %5 = OpLabel
248 %8 = OpVariable %7 Function
249 OpStore %8 %9
250 OpStore %8 %10
251 OpStore %8 %11
252 OpStore %8 %12
253 OpReturn
254 OpFunctionEnd
255 )";
256
257 CheckEqual(env, after_op_1, context.get());
258 }
259
TEST(MergeBlocksReductionPassTest,Loops)260 TEST(MergeBlocksReductionPassTest, Loops) {
261 std::string shader = R"(
262 OpCapability Shader
263 %1 = OpExtInstImport "GLSL.std.450"
264 OpMemoryModel Logical GLSL450
265 OpEntryPoint Fragment %4 "main"
266 OpExecutionMode %4 OriginUpperLeft
267 OpSource ESSL 310
268 OpName %4 "main"
269 OpName %8 "x"
270 OpName %10 "i"
271 OpName %29 "i"
272 %2 = OpTypeVoid
273 %3 = OpTypeFunction %2
274 %6 = OpTypeInt 32 1
275 %7 = OpTypePointer Function %6
276 %9 = OpConstant %6 1
277 %11 = OpConstant %6 0
278 %18 = OpConstant %6 10
279 %19 = OpTypeBool
280 %4 = OpFunction %2 None %3
281 %5 = OpLabel
282 %8 = OpVariable %7 Function
283 %10 = OpVariable %7 Function
284 %29 = OpVariable %7 Function
285 OpStore %8 %9
286 OpBranch %45
287 %45 = OpLabel
288 OpStore %10 %11
289 OpBranch %12
290 %12 = OpLabel
291 OpLoopMerge %14 %15 None
292 OpBranch %16
293 %16 = OpLabel
294 %17 = OpLoad %6 %10
295 OpBranch %46
296 %46 = OpLabel
297 %20 = OpSLessThan %19 %17 %18
298 OpBranchConditional %20 %13 %14
299 %13 = OpLabel
300 %21 = OpLoad %6 %10
301 OpBranch %47
302 %47 = OpLabel
303 %22 = OpLoad %6 %8
304 %23 = OpIAdd %6 %22 %21
305 OpStore %8 %23
306 %24 = OpLoad %6 %10
307 %25 = OpLoad %6 %8
308 %26 = OpIAdd %6 %25 %24
309 OpStore %8 %26
310 OpBranch %48
311 %48 = OpLabel
312 OpBranch %15
313 %15 = OpLabel
314 %27 = OpLoad %6 %10
315 %28 = OpIAdd %6 %27 %9
316 OpStore %10 %28
317 OpBranch %12
318 %14 = OpLabel
319 OpStore %29 %11
320 OpBranch %49
321 %49 = OpLabel
322 OpBranch %30
323 %30 = OpLabel
324 OpLoopMerge %32 %33 None
325 OpBranch %34
326 %34 = OpLabel
327 %35 = OpLoad %6 %29
328 %36 = OpSLessThan %19 %35 %18
329 OpBranch %50
330 %50 = OpLabel
331 OpBranchConditional %36 %31 %32
332 %31 = OpLabel
333 %37 = OpLoad %6 %29
334 %38 = OpLoad %6 %8
335 %39 = OpIAdd %6 %38 %37
336 OpStore %8 %39
337 %40 = OpLoad %6 %29
338 %41 = OpLoad %6 %8
339 %42 = OpIAdd %6 %41 %40
340 OpStore %8 %42
341 OpBranch %33
342 %33 = OpLabel
343 %43 = OpLoad %6 %29
344 %44 = OpIAdd %6 %43 %9
345 OpBranch %51
346 %51 = OpLabel
347 OpStore %29 %44
348 OpBranch %30
349 %32 = OpLabel
350 OpReturn
351 OpFunctionEnd
352 )";
353 const auto env = SPV_ENV_UNIVERSAL_1_3;
354 const auto consumer = nullptr;
355 const auto context =
356 BuildModule(env, consumer, shader, kReduceAssembleOption);
357 const auto ops =
358 MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
359 context.get());
360 ASSERT_EQ(11, ops.size());
361
362 for (auto& ri : ops) {
363 ASSERT_TRUE(ri->PreconditionHolds());
364 ri->TryToApply();
365 }
366
367 std::string after = R"(
368 OpCapability Shader
369 %1 = OpExtInstImport "GLSL.std.450"
370 OpMemoryModel Logical GLSL450
371 OpEntryPoint Fragment %4 "main"
372 OpExecutionMode %4 OriginUpperLeft
373 OpSource ESSL 310
374 OpName %4 "main"
375 OpName %8 "x"
376 OpName %10 "i"
377 OpName %29 "i"
378 %2 = OpTypeVoid
379 %3 = OpTypeFunction %2
380 %6 = OpTypeInt 32 1
381 %7 = OpTypePointer Function %6
382 %9 = OpConstant %6 1
383 %11 = OpConstant %6 0
384 %18 = OpConstant %6 10
385 %19 = OpTypeBool
386 %4 = OpFunction %2 None %3
387 %5 = OpLabel
388 %8 = OpVariable %7 Function
389 %10 = OpVariable %7 Function
390 %29 = OpVariable %7 Function
391 OpStore %8 %9
392 OpStore %10 %11
393 OpBranch %12
394 %12 = OpLabel
395 %17 = OpLoad %6 %10
396 %20 = OpSLessThan %19 %17 %18
397 OpLoopMerge %14 %13 None
398 OpBranchConditional %20 %13 %14
399 %13 = OpLabel
400 %21 = OpLoad %6 %10
401 %22 = OpLoad %6 %8
402 %23 = OpIAdd %6 %22 %21
403 OpStore %8 %23
404 %24 = OpLoad %6 %10
405 %25 = OpLoad %6 %8
406 %26 = OpIAdd %6 %25 %24
407 OpStore %8 %26
408 %27 = OpLoad %6 %10
409 %28 = OpIAdd %6 %27 %9
410 OpStore %10 %28
411 OpBranch %12
412 %14 = OpLabel
413 OpStore %29 %11
414 OpBranch %30
415 %30 = OpLabel
416 %35 = OpLoad %6 %29
417 %36 = OpSLessThan %19 %35 %18
418 OpLoopMerge %32 %31 None
419 OpBranchConditional %36 %31 %32
420 %31 = OpLabel
421 %37 = OpLoad %6 %29
422 %38 = OpLoad %6 %8
423 %39 = OpIAdd %6 %38 %37
424 OpStore %8 %39
425 %40 = OpLoad %6 %29
426 %41 = OpLoad %6 %8
427 %42 = OpIAdd %6 %41 %40
428 OpStore %8 %42
429 %43 = OpLoad %6 %29
430 %44 = OpIAdd %6 %43 %9
431 OpStore %29 %44
432 OpBranch %30
433 %32 = OpLabel
434 OpReturn
435 OpFunctionEnd
436 )";
437
438 CheckEqual(env, after, context.get());
439 }
440
TEST(MergeBlocksReductionPassTest,MergeWithOpPhi)441 TEST(MergeBlocksReductionPassTest, MergeWithOpPhi) {
442 std::string shader = R"(
443 OpCapability Shader
444 %1 = OpExtInstImport "GLSL.std.450"
445 OpMemoryModel Logical GLSL450
446 OpEntryPoint Fragment %4 "main"
447 OpExecutionMode %4 OriginUpperLeft
448 OpSource ESSL 310
449 OpName %4 "main"
450 OpName %8 "x"
451 OpName %10 "y"
452 %2 = OpTypeVoid
453 %3 = OpTypeFunction %2
454 %6 = OpTypeInt 32 1
455 %7 = OpTypePointer Function %6
456 %9 = OpConstant %6 1
457 %4 = OpFunction %2 None %3
458 %5 = OpLabel
459 %8 = OpVariable %7 Function
460 %10 = OpVariable %7 Function
461 OpStore %8 %9
462 %11 = OpLoad %6 %8
463 OpBranch %12
464 %12 = OpLabel
465 %13 = OpPhi %6 %11 %5
466 OpStore %10 %13
467 OpReturn
468 OpFunctionEnd
469 )";
470
471 const auto env = SPV_ENV_UNIVERSAL_1_3;
472 const auto consumer = nullptr;
473 const auto context =
474 BuildModule(env, consumer, shader, kReduceAssembleOption);
475 const auto ops =
476 MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
477 context.get());
478 ASSERT_EQ(1, ops.size());
479
480 ASSERT_TRUE(ops[0]->PreconditionHolds());
481 ops[0]->TryToApply();
482
483 std::string after = R"(
484 OpCapability Shader
485 %1 = OpExtInstImport "GLSL.std.450"
486 OpMemoryModel Logical GLSL450
487 OpEntryPoint Fragment %4 "main"
488 OpExecutionMode %4 OriginUpperLeft
489 OpSource ESSL 310
490 OpName %4 "main"
491 OpName %8 "x"
492 OpName %10 "y"
493 %2 = OpTypeVoid
494 %3 = OpTypeFunction %2
495 %6 = OpTypeInt 32 1
496 %7 = OpTypePointer Function %6
497 %9 = OpConstant %6 1
498 %4 = OpFunction %2 None %3
499 %5 = OpLabel
500 %8 = OpVariable %7 Function
501 %10 = OpVariable %7 Function
502 OpStore %8 %9
503 %11 = OpLoad %6 %8
504 OpStore %10 %11
505 OpReturn
506 OpFunctionEnd
507 )";
508
509 CheckEqual(env, after, context.get());
510 }
511
MergeBlocksReductionPassTest_LoopReturn_Helper(bool reverse)512 void MergeBlocksReductionPassTest_LoopReturn_Helper(bool reverse) {
513 // A merge block opportunity stores a block that can be merged with its
514 // predecessor.
515 // Given blocks A -> B -> C:
516 // This test demonstrates how merging B->C can invalidate
517 // the opportunity of merging A->B, and vice-versa. E.g.
518 // B->C are merged: B is now terminated with OpReturn.
519 // A->B can now no longer be merged because A is a loop header, which
520 // cannot be terminated with OpReturn.
521
522 std::string shader = R"(
523 OpCapability Shader
524 %1 = OpExtInstImport "GLSL.std.450"
525 OpMemoryModel Logical GLSL450
526 OpEntryPoint Fragment %2 "main"
527 OpExecutionMode %2 OriginUpperLeft
528 OpSource ESSL 310
529 OpName %2 "main"
530 %3 = OpTypeVoid
531 %4 = OpTypeFunction %3
532 %5 = OpTypeInt 32 1
533 %6 = OpTypePointer Function %5
534 %7 = OpTypeBool
535 %8 = OpConstantFalse %7
536 %2 = OpFunction %3 None %4
537 %9 = OpLabel
538 OpBranch %10
539 %10 = OpLabel ; A (loop header)
540 OpLoopMerge %13 %12 None
541 OpBranch %11
542 %12 = OpLabel ; (unreachable continue block)
543 OpBranch %10
544 %11 = OpLabel ; B
545 OpBranch %15
546 %15 = OpLabel ; C
547 OpReturn
548 %13 = OpLabel ; (unreachable merge block)
549 OpReturn
550 OpFunctionEnd
551 )";
552 const auto env = SPV_ENV_UNIVERSAL_1_3;
553 const auto consumer = nullptr;
554 const auto context =
555 BuildModule(env, consumer, shader, kReduceAssembleOption);
556 ASSERT_NE(context.get(), nullptr);
557 auto opportunities =
558 MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
559 context.get());
560
561 // A->B and B->C
562 ASSERT_EQ(opportunities.size(), 2);
563
564 // Test applying opportunities in both orders.
565 if (reverse) {
566 std::reverse(opportunities.begin(), opportunities.end());
567 }
568
569 size_t num_applied = 0;
570 for (auto& ri : opportunities) {
571 if (ri->PreconditionHolds()) {
572 ri->TryToApply();
573 ++num_applied;
574 }
575 }
576
577 // Only 1 opportunity can be applied, as both disable each other.
578 ASSERT_EQ(num_applied, 1);
579
580 std::string after = R"(
581 OpCapability Shader
582 %1 = OpExtInstImport "GLSL.std.450"
583 OpMemoryModel Logical GLSL450
584 OpEntryPoint Fragment %2 "main"
585 OpExecutionMode %2 OriginUpperLeft
586 OpSource ESSL 310
587 OpName %2 "main"
588 %3 = OpTypeVoid
589 %4 = OpTypeFunction %3
590 %5 = OpTypeInt 32 1
591 %6 = OpTypePointer Function %5
592 %7 = OpTypeBool
593 %8 = OpConstantFalse %7
594 %2 = OpFunction %3 None %4
595 %9 = OpLabel
596 OpBranch %10
597 %10 = OpLabel ; A-B (loop header)
598 OpLoopMerge %13 %12 None
599 OpBranch %15
600 %12 = OpLabel ; (unreachable continue block)
601 OpBranch %10
602 %15 = OpLabel ; C
603 OpReturn
604 %13 = OpLabel ; (unreachable merge block)
605 OpReturn
606 OpFunctionEnd
607 )";
608
609 // The only difference is the labels.
610 std::string after_reversed = R"(
611 OpCapability Shader
612 %1 = OpExtInstImport "GLSL.std.450"
613 OpMemoryModel Logical GLSL450
614 OpEntryPoint Fragment %2 "main"
615 OpExecutionMode %2 OriginUpperLeft
616 OpSource ESSL 310
617 OpName %2 "main"
618 %3 = OpTypeVoid
619 %4 = OpTypeFunction %3
620 %5 = OpTypeInt 32 1
621 %6 = OpTypePointer Function %5
622 %7 = OpTypeBool
623 %8 = OpConstantFalse %7
624 %2 = OpFunction %3 None %4
625 %9 = OpLabel
626 OpBranch %10
627 %10 = OpLabel ; A (loop header)
628 OpLoopMerge %13 %12 None
629 OpBranch %11
630 %12 = OpLabel ; (unreachable continue block)
631 OpBranch %10
632 %11 = OpLabel ; B-C
633 OpReturn
634 %13 = OpLabel ; (unreachable merge block)
635 OpReturn
636 OpFunctionEnd
637 )";
638
639 CheckEqual(env, reverse ? after_reversed : after, context.get());
640 }
641
TEST(MergeBlocksReductionPassTest,LoopReturn)642 TEST(MergeBlocksReductionPassTest, LoopReturn) {
643 MergeBlocksReductionPassTest_LoopReturn_Helper(false);
644 }
645
TEST(MergeBlocksReductionPassTest,LoopReturnReverse)646 TEST(MergeBlocksReductionPassTest, LoopReturnReverse) {
647 MergeBlocksReductionPassTest_LoopReturn_Helper(true);
648 }
649
650 } // namespace
651 } // namespace reduce
652 } // namespace spvtools
653