• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/reduce/reducer.h"
16 
17 #include <unordered_map>
18 
19 #include "source/opt/build_module.h"
20 #include "source/reduce/operand_to_const_reduction_opportunity_finder.h"
21 #include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h"
22 #include "test/reduce/reduce_test_util.h"
23 
24 namespace spvtools {
25 namespace reduce {
26 namespace {
27 
28 const spv_target_env kEnv = SPV_ENV_UNIVERSAL_1_3;
29 const MessageConsumer kMessageConsumer = NopDiagnostic;
30 
31 // This changes its mind each time IsInteresting is invoked as to whether the
32 // binary is interesting, until some limit is reached after which the binary is
33 // always deemed interesting.  This is useful to test that reduction passes
34 // interleave in interesting ways for a while, and then always succeed after
35 // some point; the latter is important to end up with a predictable final
36 // reduced binary for tests.
37 class PingPongInteresting {
38  public:
PingPongInteresting(uint32_t always_interesting_after)39   explicit PingPongInteresting(uint32_t always_interesting_after)
40       : is_interesting_(true),
41         always_interesting_after_(always_interesting_after),
42         count_(0) {}
43 
IsInteresting()44   bool IsInteresting() {
45     bool result;
46     if (count_ > always_interesting_after_) {
47       result = true;
48     } else {
49       result = is_interesting_;
50       is_interesting_ = !is_interesting_;
51     }
52     count_++;
53     return result;
54   }
55 
56  private:
57   bool is_interesting_;
58   const uint32_t always_interesting_after_;
59   uint32_t count_;
60 };
61 
TEST(ReducerTest,ExprToConstantAndRemoveUnreferenced)62 TEST(ReducerTest, ExprToConstantAndRemoveUnreferenced) {
63   // Check that ExprToConstant and RemoveUnreferenced work together; once some
64   // ID uses have been changed to constants, those IDs can be removed.
65   std::string original = R"(
66                OpCapability Shader
67           %1 = OpExtInstImport "GLSL.std.450"
68                OpMemoryModel Logical GLSL450
69                OpEntryPoint Fragment %4 "main" %60
70                OpExecutionMode %4 OriginUpperLeft
71                OpSource ESSL 310
72                OpName %4 "main"
73                OpName %16 "buf2"
74                OpMemberName %16 0 "i"
75                OpName %18 ""
76                OpName %25 "buf1"
77                OpMemberName %25 0 "f"
78                OpName %27 ""
79                OpName %60 "_GLF_color"
80                OpMemberDecorate %16 0 Offset 0
81                OpDecorate %16 Block
82                OpDecorate %18 DescriptorSet 0
83                OpDecorate %18 Binding 2
84                OpMemberDecorate %25 0 Offset 0
85                OpDecorate %25 Block
86                OpDecorate %27 DescriptorSet 0
87                OpDecorate %27 Binding 1
88                OpDecorate %60 Location 0
89           %2 = OpTypeVoid
90           %3 = OpTypeFunction %2
91           %6 = OpTypeInt 32 1
92           %9 = OpConstant %6 0
93          %16 = OpTypeStruct %6
94          %17 = OpTypePointer Uniform %16
95          %18 = OpVariable %17 Uniform
96          %19 = OpTypePointer Uniform %6
97          %22 = OpTypeBool
98         %100 = OpConstantTrue %22
99          %24 = OpTypeFloat 32
100          %25 = OpTypeStruct %24
101          %26 = OpTypePointer Uniform %25
102          %27 = OpVariable %26 Uniform
103          %28 = OpTypePointer Uniform %24
104          %31 = OpConstant %24 2
105          %56 = OpConstant %6 1
106          %58 = OpTypeVector %24 4
107          %59 = OpTypePointer Output %58
108          %60 = OpVariable %59 Output
109          %72 = OpUndef %24
110          %74 = OpUndef %6
111           %4 = OpFunction %2 None %3
112           %5 = OpLabel
113                OpBranch %10
114          %10 = OpLabel
115          %73 = OpPhi %6 %74 %5 %77 %34
116          %71 = OpPhi %24 %72 %5 %76 %34
117          %70 = OpPhi %6 %9 %5 %57 %34
118          %20 = OpAccessChain %19 %18 %9
119          %21 = OpLoad %6 %20
120          %23 = OpSLessThan %22 %70 %21
121                OpLoopMerge %12 %34 None
122                OpBranchConditional %23 %11 %12
123          %11 = OpLabel
124          %29 = OpAccessChain %28 %27 %9
125          %30 = OpLoad %24 %29
126          %32 = OpFOrdGreaterThan %22 %30 %31
127                OpSelectionMerge %90 None
128                OpBranchConditional %32 %33 %46
129          %33 = OpLabel
130          %40 = OpFAdd %24 %71 %30
131          %45 = OpISub %6 %73 %21
132                OpBranch %90
133          %46 = OpLabel
134          %50 = OpFMul %24 %71 %30
135          %54 = OpSDiv %6 %73 %21
136                OpBranch %90
137          %90 = OpLabel
138          %77 = OpPhi %6 %45 %33 %54 %46
139          %76 = OpPhi %24 %40 %33 %50 %46
140                OpBranch %34
141          %34 = OpLabel
142          %57 = OpIAdd %6 %70 %56
143                OpBranch %10
144          %12 = OpLabel
145          %61 = OpAccessChain %28 %27 %9
146          %62 = OpLoad %24 %61
147          %66 = OpConvertSToF %24 %21
148          %68 = OpConvertSToF %24 %73
149          %69 = OpCompositeConstruct %58 %62 %71 %66 %68
150                OpStore %60 %69
151                OpReturn
152                OpFunctionEnd
153   )";
154 
155   std::string expected = R"(
156                OpCapability Shader
157           %1 = OpExtInstImport "GLSL.std.450"
158                OpMemoryModel Logical GLSL450
159                OpEntryPoint Fragment %4 "main"
160                OpExecutionMode %4 OriginUpperLeft
161           %2 = OpTypeVoid
162           %3 = OpTypeFunction %2
163           %6 = OpTypeInt 32 1
164           %9 = OpConstant %6 0
165          %22 = OpTypeBool
166         %100 = OpConstantTrue %22
167          %24 = OpTypeFloat 32
168          %31 = OpConstant %24 2
169          %56 = OpConstant %6 1
170          %72 = OpUndef %24
171          %74 = OpUndef %6
172           %4 = OpFunction %2 None %3
173           %5 = OpLabel
174                OpBranch %10
175          %10 = OpLabel
176                OpLoopMerge %12 %34 None
177                OpBranchConditional %100 %11 %12
178          %11 = OpLabel
179                OpSelectionMerge %90 None
180                OpBranchConditional %100 %33 %46
181          %33 = OpLabel
182                OpBranch %90
183          %46 = OpLabel
184                OpBranch %90
185          %90 = OpLabel
186                OpBranch %34
187          %34 = OpLabel
188                OpBranch %10
189          %12 = OpLabel
190                OpReturn
191                OpFunctionEnd
192   )";
193 
194   Reducer reducer(kEnv);
195   PingPongInteresting ping_pong_interesting(10);
196   reducer.SetMessageConsumer(kMessageConsumer);
197   reducer.SetInterestingnessFunction(
198       [&ping_pong_interesting](const std::vector<uint32_t>&, uint32_t) -> bool {
199         return ping_pong_interesting.IsInteresting();
200       });
201   reducer.AddReductionPass(
202       MakeUnique<RemoveUnusedInstructionReductionOpportunityFinder>(false));
203   reducer.AddReductionPass(
204       MakeUnique<OperandToConstReductionOpportunityFinder>());
205 
206   std::vector<uint32_t> binary_in;
207   SpirvTools t(kEnv);
208 
209   ASSERT_TRUE(t.Assemble(original, &binary_in, kReduceAssembleOption));
210   std::vector<uint32_t> binary_out;
211   spvtools::ReducerOptions reducer_options;
212   reducer_options.set_step_limit(500);
213   reducer_options.set_fail_on_validation_error(true);
214   spvtools::ValidatorOptions validator_options;
215 
216   Reducer::ReductionResultStatus status = reducer.Run(
217       std::move(binary_in), &binary_out, reducer_options, validator_options);
218 
219   ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
220 
221   CheckEqual(kEnv, expected, binary_out);
222 }
223 
InterestingWhileOpcodeExists(const std::vector<uint32_t> & binary,uint32_t opcode,uint32_t count,bool dump)224 bool InterestingWhileOpcodeExists(const std::vector<uint32_t>& binary,
225                                   uint32_t opcode, uint32_t count, bool dump) {
226   if (dump) {
227     std::stringstream ss;
228     ss << "temp_" << count << ".spv";
229     DumpShader(binary, ss.str().c_str());
230   }
231 
232   std::unique_ptr<opt::IRContext> context =
233       BuildModule(kEnv, kMessageConsumer, binary.data(), binary.size());
234   assert(context);
235   bool interesting = false;
236   for (auto& function : *context->module()) {
237     context->cfg()->ForEachBlockInPostOrder(
238         &*function.begin(),
239         [opcode, &interesting](opt::BasicBlock* block) -> void {
240           for (auto& inst : *block) {
241             if (inst.opcode() == opcode) {
242               interesting = true;
243               break;
244             }
245           }
246         });
247     if (interesting) {
248       break;
249     }
250   }
251   return interesting;
252 }
253 
InterestingWhileIMulReachable(const std::vector<uint32_t> & binary,uint32_t count)254 bool InterestingWhileIMulReachable(const std::vector<uint32_t>& binary,
255                                    uint32_t count) {
256   return InterestingWhileOpcodeExists(binary, SpvOpIMul, count, false);
257 }
258 
InterestingWhileSDivReachable(const std::vector<uint32_t> & binary,uint32_t count)259 bool InterestingWhileSDivReachable(const std::vector<uint32_t>& binary,
260                                    uint32_t count) {
261   return InterestingWhileOpcodeExists(binary, SpvOpSDiv, count, false);
262 }
263 
264 // The shader below was derived from the following GLSL, and optimized.
265 // #version 310 es
266 // precision highp float;
267 // layout(location = 0) out vec4 _GLF_color;
268 // int foo() {
269 //    int x = 1;
270 //    int y;
271 //    x = y / x;   // SDiv
272 //    return x;
273 // }
274 // void main() {
275 //    int c;
276 //    while (bool(c)) {
277 //        do {
278 //            if (bool(c)) {
279 //                if (bool(c)) {
280 //                    ++c;
281 //                } else {
282 //                    _GLF_color.x = float(c*c);  // IMul
283 //                }
284 //                return;
285 //            }
286 //        } while(bool(foo()));
287 //        return;
288 //    }
289 // }
290 const std::string kShaderWithLoopsDivAndMul = R"(
291                OpCapability Shader
292           %1 = OpExtInstImport "GLSL.std.450"
293                OpMemoryModel Logical GLSL450
294                OpEntryPoint Fragment %4 "main" %49
295                OpExecutionMode %4 OriginUpperLeft
296                OpSource ESSL 310
297                OpName %4 "main"
298                OpName %49 "_GLF_color"
299                OpDecorate %49 Location 0
300                OpDecorate %52 RelaxedPrecision
301                OpDecorate %77 RelaxedPrecision
302           %2 = OpTypeVoid
303           %3 = OpTypeFunction %2
304           %6 = OpTypeInt 32 1
305          %12 = OpConstant %6 1
306          %27 = OpTypeBool
307          %28 = OpTypeInt 32 0
308          %29 = OpConstant %28 0
309          %46 = OpTypeFloat 32
310          %47 = OpTypeVector %46 4
311          %48 = OpTypePointer Output %47
312          %49 = OpVariable %48 Output
313          %54 = OpTypePointer Output %46
314          %64 = OpConstantFalse %27
315          %67 = OpConstantTrue %27
316          %81 = OpUndef %6
317           %4 = OpFunction %2 None %3
318           %5 = OpLabel
319                OpBranch %61
320          %61 = OpLabel
321                OpLoopMerge %60 %63 None
322                OpBranch %20
323          %20 = OpLabel
324          %30 = OpINotEqual %27 %81 %29
325                OpLoopMerge %22 %23 None
326                OpBranchConditional %30 %21 %22
327          %21 = OpLabel
328                OpBranch %31
329          %31 = OpLabel
330                OpLoopMerge %33 %38 None
331                OpBranch %32
332          %32 = OpLabel
333                OpBranchConditional %30 %37 %38
334          %37 = OpLabel
335                OpSelectionMerge %42 None
336                OpBranchConditional %30 %41 %45
337          %41 = OpLabel
338                OpBranch %42
339          %45 = OpLabel
340          %52 = OpIMul %6 %81 %81
341          %53 = OpConvertSToF %46 %52
342          %55 = OpAccessChain %54 %49 %29
343                OpStore %55 %53
344                OpBranch %42
345          %42 = OpLabel
346                OpBranch %33
347          %38 = OpLabel
348          %77 = OpSDiv %6 %81 %12
349          %58 = OpINotEqual %27 %77 %29
350                OpBranchConditional %58 %31 %33
351          %33 = OpLabel
352          %86 = OpPhi %27 %67 %42 %64 %38
353                OpSelectionMerge %68 None
354                OpBranchConditional %86 %22 %68
355          %68 = OpLabel
356                OpBranch %22
357          %23 = OpLabel
358                OpBranch %20
359          %22 = OpLabel
360          %90 = OpPhi %27 %64 %20 %86 %33 %67 %68
361                OpSelectionMerge %70 None
362                OpBranchConditional %90 %60 %70
363          %70 = OpLabel
364                OpBranch %60
365          %63 = OpLabel
366                OpBranch %61
367          %60 = OpLabel
368                OpReturn
369                OpFunctionEnd
370   )";
371 
372 // The shader below comes from the following GLSL.
373 // #version 320 es
374 //
375 //  int baz(int x) {
376 //   int y = x + 1;
377 //   y = y + 2;
378 //   if (y > 0) {
379 //     return x;
380 //   }
381 //   return x + 1;
382 // }
383 //
384 //  int bar(int a) {
385 //   if (a == 3) {
386 //     return baz(2*a);
387 //   }
388 //   a = a + 1;
389 //   for (int i = 0; i < 10; i++) {
390 //     a += baz(a);
391 //   }
392 //   return a;
393 // }
394 //
395 //  void main() {
396 //   int x;
397 //   x = 3;
398 //   x += 1;
399 //   x += bar(x);
400 //   x += baz(x);
401 // }
402 const std::string kShaderWithMultipleFunctions = R"(
403                OpCapability Shader
404           %1 = OpExtInstImport "GLSL.std.450"
405                OpMemoryModel Logical GLSL450
406                OpEntryPoint Fragment %4 "main"
407                OpExecutionMode %4 OriginUpperLeft
408                OpSource ESSL 320
409           %2 = OpTypeVoid
410           %3 = OpTypeFunction %2
411           %6 = OpTypeInt 32 1
412           %7 = OpTypePointer Function %6
413           %8 = OpTypeFunction %6 %7
414          %17 = OpConstant %6 1
415          %20 = OpConstant %6 2
416          %23 = OpConstant %6 0
417          %24 = OpTypeBool
418          %35 = OpConstant %6 3
419          %53 = OpConstant %6 10
420           %4 = OpFunction %2 None %3
421           %5 = OpLabel
422          %65 = OpVariable %7 Function
423          %68 = OpVariable %7 Function
424          %73 = OpVariable %7 Function
425                OpStore %65 %35
426          %66 = OpLoad %6 %65
427          %67 = OpIAdd %6 %66 %17
428                OpStore %65 %67
429          %69 = OpLoad %6 %65
430                OpStore %68 %69
431          %70 = OpFunctionCall %6 %13 %68
432          %71 = OpLoad %6 %65
433          %72 = OpIAdd %6 %71 %70
434                OpStore %65 %72
435          %74 = OpLoad %6 %65
436                OpStore %73 %74
437          %75 = OpFunctionCall %6 %10 %73
438          %76 = OpLoad %6 %65
439          %77 = OpIAdd %6 %76 %75
440                OpStore %65 %77
441                OpReturn
442                OpFunctionEnd
443          %10 = OpFunction %6 None %8
444           %9 = OpFunctionParameter %7
445          %11 = OpLabel
446          %15 = OpVariable %7 Function
447          %16 = OpLoad %6 %9
448          %18 = OpIAdd %6 %16 %17
449                OpStore %15 %18
450          %19 = OpLoad %6 %15
451          %21 = OpIAdd %6 %19 %20
452                OpStore %15 %21
453          %22 = OpLoad %6 %15
454          %25 = OpSGreaterThan %24 %22 %23
455                OpSelectionMerge %27 None
456                OpBranchConditional %25 %26 %27
457          %26 = OpLabel
458          %28 = OpLoad %6 %9
459                OpReturnValue %28
460          %27 = OpLabel
461          %30 = OpLoad %6 %9
462          %31 = OpIAdd %6 %30 %17
463                OpReturnValue %31
464                OpFunctionEnd
465          %13 = OpFunction %6 None %8
466          %12 = OpFunctionParameter %7
467          %14 = OpLabel
468          %41 = OpVariable %7 Function
469          %46 = OpVariable %7 Function
470          %55 = OpVariable %7 Function
471          %34 = OpLoad %6 %12
472          %36 = OpIEqual %24 %34 %35
473                OpSelectionMerge %38 None
474                OpBranchConditional %36 %37 %38
475          %37 = OpLabel
476          %39 = OpLoad %6 %12
477          %40 = OpIMul %6 %20 %39
478                OpStore %41 %40
479          %42 = OpFunctionCall %6 %10 %41
480                OpReturnValue %42
481          %38 = OpLabel
482          %44 = OpLoad %6 %12
483          %45 = OpIAdd %6 %44 %17
484                OpStore %12 %45
485                OpStore %46 %23
486                OpBranch %47
487          %47 = OpLabel
488                OpLoopMerge %49 %50 None
489                OpBranch %51
490          %51 = OpLabel
491          %52 = OpLoad %6 %46
492          %54 = OpSLessThan %24 %52 %53
493                OpBranchConditional %54 %48 %49
494          %48 = OpLabel
495          %56 = OpLoad %6 %12
496                OpStore %55 %56
497          %57 = OpFunctionCall %6 %10 %55
498          %58 = OpLoad %6 %12
499          %59 = OpIAdd %6 %58 %57
500                OpStore %12 %59
501                OpBranch %50
502          %50 = OpLabel
503          %60 = OpLoad %6 %46
504          %61 = OpIAdd %6 %60 %17
505                OpStore %46 %61
506                OpBranch %47
507          %49 = OpLabel
508          %62 = OpLoad %6 %12
509                OpReturnValue %62
510                OpFunctionEnd
511   )";
512 
TEST(ReducerTest,ShaderReduceWhileMulReachable)513 TEST(ReducerTest, ShaderReduceWhileMulReachable) {
514   Reducer reducer(kEnv);
515 
516   reducer.SetInterestingnessFunction(InterestingWhileIMulReachable);
517   reducer.AddDefaultReductionPasses();
518   reducer.SetMessageConsumer(kMessageConsumer);
519 
520   std::vector<uint32_t> binary_in;
521   SpirvTools t(kEnv);
522 
523   ASSERT_TRUE(
524       t.Assemble(kShaderWithLoopsDivAndMul, &binary_in, kReduceAssembleOption));
525   std::vector<uint32_t> binary_out;
526   spvtools::ReducerOptions reducer_options;
527   reducer_options.set_step_limit(500);
528   reducer_options.set_fail_on_validation_error(true);
529   spvtools::ValidatorOptions validator_options;
530 
531   Reducer::ReductionResultStatus status = reducer.Run(
532       std::move(binary_in), &binary_out, reducer_options, validator_options);
533 
534   ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
535 }
536 
TEST(ReducerTest,ShaderReduceWhileDivReachable)537 TEST(ReducerTest, ShaderReduceWhileDivReachable) {
538   Reducer reducer(kEnv);
539 
540   reducer.SetInterestingnessFunction(InterestingWhileSDivReachable);
541   reducer.AddDefaultReductionPasses();
542   reducer.SetMessageConsumer(kMessageConsumer);
543 
544   std::vector<uint32_t> binary_in;
545   SpirvTools t(kEnv);
546 
547   ASSERT_TRUE(
548       t.Assemble(kShaderWithLoopsDivAndMul, &binary_in, kReduceAssembleOption));
549   std::vector<uint32_t> binary_out;
550   spvtools::ReducerOptions reducer_options;
551   reducer_options.set_step_limit(500);
552   reducer_options.set_fail_on_validation_error(true);
553   spvtools::ValidatorOptions validator_options;
554 
555   Reducer::ReductionResultStatus status = reducer.Run(
556       std::move(binary_in), &binary_out, reducer_options, validator_options);
557 
558   ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
559 }
560 
561 // Computes an instruction count for each function in the module represented by
562 // |binary|.
GetFunctionInstructionCount(const std::vector<uint32_t> & binary)563 std::unordered_map<uint32_t, uint32_t> GetFunctionInstructionCount(
564     const std::vector<uint32_t>& binary) {
565   std::unique_ptr<opt::IRContext> context =
566       BuildModule(kEnv, kMessageConsumer, binary.data(), binary.size());
567   assert(context != nullptr && "Failed to build module.");
568   std::unordered_map<uint32_t, uint32_t> result;
569   for (auto& function : *context->module()) {
570     uint32_t& count = result[function.result_id()] = 0;
571     function.ForEachInst([&count](opt::Instruction*) { count++; });
572   }
573   return result;
574 }
575 
TEST(ReducerTest,SingleFunctionReduction)576 TEST(ReducerTest, SingleFunctionReduction) {
577   Reducer reducer(kEnv);
578 
579   PingPongInteresting ping_pong_interesting(4);
580   reducer.SetInterestingnessFunction(
581       [&ping_pong_interesting](const std::vector<uint32_t>&, uint32_t) -> bool {
582         return ping_pong_interesting.IsInteresting();
583       });
584   reducer.AddDefaultReductionPasses();
585   reducer.SetMessageConsumer(kMessageConsumer);
586 
587   std::vector<uint32_t> binary_in;
588   SpirvTools t(kEnv);
589 
590   ASSERT_TRUE(t.Assemble(kShaderWithMultipleFunctions, &binary_in,
591                          kReduceAssembleOption));
592 
593   auto original_instruction_count = GetFunctionInstructionCount(binary_in);
594 
595   std::vector<uint32_t> binary_out;
596   spvtools::ReducerOptions reducer_options;
597   reducer_options.set_step_limit(500);
598   reducer_options.set_fail_on_validation_error(true);
599 
600   // Instruct the reducer to only target function 13.
601   reducer_options.set_target_function(13);
602 
603   spvtools::ValidatorOptions validator_options;
604 
605   Reducer::ReductionResultStatus status = reducer.Run(
606       std::move(binary_in), &binary_out, reducer_options, validator_options);
607 
608   ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
609 
610   auto final_instruction_count = GetFunctionInstructionCount(binary_out);
611 
612   // Nothing should have been removed from these functions.
613   ASSERT_EQ(original_instruction_count.at(4), final_instruction_count.at(4));
614   ASSERT_EQ(original_instruction_count.at(10), final_instruction_count.at(10));
615 
616   // Function 13 should have been reduced to these five instructions:
617   //   OpFunction
618   //   OpFunctionParameter
619   //   OpLabel
620   //   OpReturnValue
621   //   OpFunctionEnd
622   ASSERT_EQ(5, final_instruction_count.at(13));
623 }
624 
625 }  // namespace
626 }  // namespace reduce
627 }  // namespace spvtools
628