• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <string>
17 #include <unordered_set>
18 #include <vector>
19 
20 #include "gmock/gmock.h"
21 #include "source/opt/iterator.h"
22 #include "source/opt/loop_descriptor.h"
23 #include "source/opt/pass.h"
24 #include "source/opt/scalar_analysis.h"
25 #include "source/opt/tree_iterator.h"
26 #include "test/opt/assembly_builder.h"
27 #include "test/opt/function_utils.h"
28 #include "test/opt/pass_fixture.h"
29 #include "test/opt/pass_utils.h"
30 
31 namespace spvtools {
32 namespace opt {
33 namespace {
34 
35 using ::testing::UnorderedElementsAre;
36 using ScalarAnalysisTest = PassTest<::testing::Test>;
37 
38 /*
39 Generated from the following GLSL + --eliminate-local-multi-store
40 
41 #version 410 core
42 layout (location = 1) out float array[10];
43 void main() {
44   for (int i = 0; i < 10; ++i) {
45     array[i] = array[i+1];
46   }
47 }
48 */
TEST_F(ScalarAnalysisTest,BasicEvolutionTest)49 TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
50   const std::string text = R"(
51                OpCapability Shader
52           %1 = OpExtInstImport "GLSL.std.450"
53                OpMemoryModel Logical GLSL450
54                OpEntryPoint Fragment %4 "main" %24
55                OpExecutionMode %4 OriginUpperLeft
56                OpSource GLSL 410
57                OpName %4 "main"
58                OpName %24 "array"
59                OpDecorate %24 Location 1
60           %2 = OpTypeVoid
61           %3 = OpTypeFunction %2
62           %6 = OpTypeInt 32 1
63           %7 = OpTypePointer Function %6
64           %9 = OpConstant %6 0
65          %16 = OpConstant %6 10
66          %17 = OpTypeBool
67          %19 = OpTypeFloat 32
68          %20 = OpTypeInt 32 0
69          %21 = OpConstant %20 10
70          %22 = OpTypeArray %19 %21
71          %23 = OpTypePointer Output %22
72          %24 = OpVariable %23 Output
73          %27 = OpConstant %6 1
74          %29 = OpTypePointer Output %19
75           %4 = OpFunction %2 None %3
76           %5 = OpLabel
77                OpBranch %10
78          %10 = OpLabel
79          %35 = OpPhi %6 %9 %5 %34 %13
80                OpLoopMerge %12 %13 None
81                OpBranch %14
82          %14 = OpLabel
83          %18 = OpSLessThan %17 %35 %16
84                OpBranchConditional %18 %11 %12
85          %11 = OpLabel
86          %28 = OpIAdd %6 %35 %27
87          %30 = OpAccessChain %29 %24 %28
88          %31 = OpLoad %19 %30
89          %32 = OpAccessChain %29 %24 %35
90                OpStore %32 %31
91                OpBranch %13
92          %13 = OpLabel
93          %34 = OpIAdd %6 %35 %27
94                OpBranch %10
95          %12 = OpLabel
96                OpReturn
97                OpFunctionEnd
98   )";
99   // clang-format on
100   std::unique_ptr<IRContext> context =
101       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
102                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
103   Module* module = context->module();
104   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
105                              << text << std::endl;
106   const Function* f = spvtest::GetFunction(module, 4);
107   ScalarEvolutionAnalysis analysis{context.get()};
108 
109   const Instruction* store = nullptr;
110   const Instruction* load = nullptr;
111   for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) {
112     if (inst.opcode() == SpvOp::SpvOpStore) {
113       store = &inst;
114     }
115     if (inst.opcode() == SpvOp::SpvOpLoad) {
116       load = &inst;
117     }
118   }
119 
120   EXPECT_NE(load, nullptr);
121   EXPECT_NE(store, nullptr);
122 
123   Instruction* access_chain =
124       context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
125 
126   Instruction* child = context->get_def_use_mgr()->GetDef(
127       access_chain->GetSingleWordInOperand(1));
128   const SENode* node = analysis.AnalyzeInstruction(child);
129 
130   EXPECT_NE(node, nullptr);
131 
132   // Unsimplified node should have the form of ADD(REC(0,1), 1)
133   EXPECT_EQ(node->GetType(), SENode::Add);
134 
135   const SENode* child_1 = node->GetChild(0);
136   EXPECT_TRUE(child_1->GetType() == SENode::Constant ||
137               child_1->GetType() == SENode::RecurrentAddExpr);
138 
139   const SENode* child_2 = node->GetChild(1);
140   EXPECT_TRUE(child_2->GetType() == SENode::Constant ||
141               child_2->GetType() == SENode::RecurrentAddExpr);
142 
143   SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
144   // Simplified should be in the form of REC(1,1)
145   EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
146 
147   EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant);
148   EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
149             1);
150 
151   EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant);
152   EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
153             1);
154 
155   EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
156 }
157 
158 /*
159 Generated from the following GLSL + --eliminate-local-multi-store
160 
161 #version 410 core
162 layout (location = 1) out float array[10];
163 layout (location = 2) flat in int loop_invariant;
164 void main() {
165   for (int i = 0; i < 10; ++i) {
166     array[i] = array[i+loop_invariant];
167   }
168 }
169 
170 */
// Checks that "i + loop_invariant" is analyzed as ADD(REC(0,1), X) and
// simplifies to a recurrence whose offset is the unknown value X.
TEST_F(ScalarAnalysisTest, LoadTest) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeInt 32 1
          %8 = OpTypePointer Function %7
          %9 = OpConstant %7 0
         %10 = OpConstant %7 10
         %11 = OpTypeBool
         %12 = OpTypeFloat 32
         %13 = OpTypeInt 32 0
         %14 = OpConstant %13 10
         %15 = OpTypeArray %12 %14
         %16 = OpTypePointer Output %15
          %3 = OpVariable %16 Output
         %17 = OpTypePointer Input %7
          %4 = OpVariable %17 Input
         %18 = OpTypePointer Output %12
         %19 = OpConstant %7 1
          %2 = OpFunction %5 None %6
         %20 = OpLabel
               OpBranch %21
         %21 = OpLabel
         %22 = OpPhi %7 %9 %20 %23 %24
               OpLoopMerge %25 %24 None
               OpBranch %26
         %26 = OpLabel
         %27 = OpSLessThan %11 %22 %10
               OpBranchConditional %27 %28 %25
         %28 = OpLabel
         %29 = OpLoad %7 %4
         %30 = OpIAdd %7 %22 %29
         %31 = OpAccessChain %18 %3 %30
         %32 = OpLoad %12 %31
         %33 = OpAccessChain %18 %3 %22
               OpStore %33 %32
               OpBranch %24
         %24 = OpLabel
         %23 = OpIAdd %7 %22 %19
               OpBranch %21
         %25 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // BuildModule yields nullptr when assembly fails; check before use.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module);
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  const auto* loop_body = spvtest::GetBasicBlock(f, 28);
  ASSERT_NE(nullptr, loop_body);
  ScalarEvolutionAnalysis analysis{context.get()};

  // Block %28 contains two loads: %29 reads loop_invariant and %32 reads
  // array[i + loop_invariant]. Select the array load explicitly by its id.
  const Instruction* load = nullptr;
  for (const Instruction& inst : *loop_body) {
    if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 32) {
      load = &inst;
    }
  }

  ASSERT_NE(load, nullptr);

  // The load's pointer is an OpAccessChain; its in-operand 1 is the index.
  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  ASSERT_NE(access_chain, nullptr);

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));

  SENode* node = analysis.AnalyzeInstruction(child);

  ASSERT_NE(node, nullptr);

  // Unsimplified node should have the form of ADD(REC(0,1), X).
  ASSERT_EQ(node->GetType(), SENode::Add);

  const SENode* child_1 = node->GetChild(0);
  EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown ||
              child_1->GetType() == SENode::RecurrentAddExpr);

  const SENode* child_2 = node->GetChild(1);
  EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown ||
              child_2->GetType() == SENode::RecurrentAddExpr);

  SENode* simplified = analysis.SimplifyExpression(node);
  ASSERT_NE(simplified, nullptr);
  ASSERT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);

  const SERecurrentNode* rec = simplified->AsSERecurrentNode();
  ASSERT_NE(rec, nullptr);

  EXPECT_NE(rec->GetChild(0), rec->GetChild(1));

  EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown);

  ASSERT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant);
  EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1);
}
281 
282 /*
283 Generated from the following GLSL + --eliminate-local-multi-store
284 
285 #version 410 core
286 layout (location = 1) out float array[10];
287 layout (location = 2) flat in int loop_invariant;
288 void main() {
289   array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
290 loop_invariant+ 16 * 3];
291 }
292 
293 */
// Checks that a straight-line index expression in which the loop_invariant
// terms cancel simplifies to a single constant.
TEST_F(ScalarAnalysisTest, SimplifySimple) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeFloat 32
          %8 = OpTypeInt 32 0
          %9 = OpConstant %8 10
         %10 = OpTypeArray %7 %9
         %11 = OpTypePointer Output %10
          %3 = OpVariable %11 Output
         %12 = OpTypeInt 32 1
         %13 = OpConstant %12 0
         %14 = OpTypePointer Input %12
          %4 = OpVariable %14 Input
         %15 = OpConstant %12 2
         %16 = OpConstant %12 4
         %17 = OpConstant %12 5
         %18 = OpConstant %12 24
         %19 = OpConstant %12 48
         %20 = OpTypePointer Output %7
          %2 = OpFunction %5 None %6
         %21 = OpLabel
         %22 = OpLoad %12 %4
         %23 = OpIMul %12 %22 %15
         %24 = OpIAdd %12 %23 %16
         %25 = OpIAdd %12 %24 %17
         %26 = OpISub %12 %25 %18
         %28 = OpISub %12 %26 %22
         %30 = OpISub %12 %28 %22
         %31 = OpIAdd %12 %30 %19
         %32 = OpAccessChain %20 %3 %31
         %33 = OpLoad %7 %32
         %34 = OpAccessChain %20 %3 %13
               OpStore %34 %33
               OpReturn
               OpFunctionEnd
    )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // BuildModule yields nullptr when assembly fails; check before use.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module);
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  const auto* entry = spvtest::GetBasicBlock(f, 21);
  ASSERT_NE(nullptr, entry);
  ScalarEvolutionAnalysis analysis{context.get()};

  // %33 is the load of array[<index expression>].
  const Instruction* load = nullptr;
  for (const Instruction& inst : *entry) {
    if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) {
      load = &inst;
    }
  }

  ASSERT_NE(load, nullptr);

  // The load's pointer is an OpAccessChain; its in-operand 1 is the index.
  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  ASSERT_NE(access_chain, nullptr);

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));

  SENode* node = analysis.AnalyzeInstruction(child);

  // Unsimplified is a very large graph with an add at the top.
  ASSERT_NE(node, nullptr);
  EXPECT_EQ(node->GetType(), SENode::Add);

  // The loop_invariant terms cancel (x * 2 - x - x == 0), so the expression
  // should fold to the constant 4 + 5 - 24 + 48 == 33.
  SENode* simplified = analysis.SimplifyExpression(node);
  ASSERT_NE(simplified, nullptr);

  ASSERT_EQ(simplified->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33);
}
381 
382 /*
383 Generated from the following GLSL + --eliminate-local-multi-store
384 
385 #version 410 core
386 layout(location = 0) in vec4 c;
387 layout (location = 1) out float array[10];
388 void main() {
389   int N = int(c.x);
390   for (int i = 0; i < 10; ++i) {
391     array[i] = array[i];
392     array[i] = array[i-1];
393     array[i] = array[i+1];
394     array[i+1] = array[i+1];
395     array[i+N] = array[i+N];
396     array[i] = array[i+N];
397   }
398 }
399 
400 */
TEST_F(ScalarAnalysisTest,Simplify)401 TEST_F(ScalarAnalysisTest, Simplify) {
402   const std::string text = R"(               OpCapability Shader
403           %1 = OpExtInstImport "GLSL.std.450"
404                OpMemoryModel Logical GLSL450
405                OpEntryPoint Fragment %4 "main" %12 %33
406                OpExecutionMode %4 OriginUpperLeft
407                OpSource GLSL 410
408                OpName %4 "main"
409                OpName %8 "N"
410                OpName %12 "c"
411                OpName %19 "i"
412                OpName %33 "array"
413                OpDecorate %12 Location 0
414                OpDecorate %33 Location 1
415           %2 = OpTypeVoid
416           %3 = OpTypeFunction %2
417           %6 = OpTypeInt 32 1
418           %7 = OpTypePointer Function %6
419           %9 = OpTypeFloat 32
420          %10 = OpTypeVector %9 4
421          %11 = OpTypePointer Input %10
422          %12 = OpVariable %11 Input
423          %13 = OpTypeInt 32 0
424          %14 = OpConstant %13 0
425          %15 = OpTypePointer Input %9
426          %20 = OpConstant %6 0
427          %27 = OpConstant %6 10
428          %28 = OpTypeBool
429          %30 = OpConstant %13 10
430          %31 = OpTypeArray %9 %30
431          %32 = OpTypePointer Output %31
432          %33 = OpVariable %32 Output
433          %36 = OpTypePointer Output %9
434          %42 = OpConstant %6 1
435           %4 = OpFunction %2 None %3
436           %5 = OpLabel
437           %8 = OpVariable %7 Function
438          %19 = OpVariable %7 Function
439          %16 = OpAccessChain %15 %12 %14
440          %17 = OpLoad %9 %16
441          %18 = OpConvertFToS %6 %17
442                OpStore %8 %18
443                OpStore %19 %20
444                OpBranch %21
445          %21 = OpLabel
446          %78 = OpPhi %6 %20 %5 %77 %24
447                OpLoopMerge %23 %24 None
448                OpBranch %25
449          %25 = OpLabel
450          %29 = OpSLessThan %28 %78 %27
451                OpBranchConditional %29 %22 %23
452          %22 = OpLabel
453          %37 = OpAccessChain %36 %33 %78
454          %38 = OpLoad %9 %37
455          %39 = OpAccessChain %36 %33 %78
456                OpStore %39 %38
457          %43 = OpISub %6 %78 %42
458          %44 = OpAccessChain %36 %33 %43
459          %45 = OpLoad %9 %44
460          %46 = OpAccessChain %36 %33 %78
461                OpStore %46 %45
462          %49 = OpIAdd %6 %78 %42
463          %50 = OpAccessChain %36 %33 %49
464          %51 = OpLoad %9 %50
465          %52 = OpAccessChain %36 %33 %78
466                OpStore %52 %51
467          %54 = OpIAdd %6 %78 %42
468          %56 = OpIAdd %6 %78 %42
469          %57 = OpAccessChain %36 %33 %56
470          %58 = OpLoad %9 %57
471          %59 = OpAccessChain %36 %33 %54
472                OpStore %59 %58
473          %62 = OpIAdd %6 %78 %18
474          %65 = OpIAdd %6 %78 %18
475          %66 = OpAccessChain %36 %33 %65
476          %67 = OpLoad %9 %66
477          %68 = OpAccessChain %36 %33 %62
478                OpStore %68 %67
479          %72 = OpIAdd %6 %78 %18
480          %73 = OpAccessChain %36 %33 %72
481          %74 = OpLoad %9 %73
482          %75 = OpAccessChain %36 %33 %78
483                OpStore %75 %74
484                OpBranch %24
485          %24 = OpLabel
486          %77 = OpIAdd %6 %78 %42
487                OpStore %19 %77
488                OpBranch %21
489          %23 = OpLabel
490                OpReturn
491                OpFunctionEnd
492 )";
493   // clang-format on
494   std::unique_ptr<IRContext> context =
495       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
496                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
497   Module* module = context->module();
498   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
499                              << text << std::endl;
500   const Function* f = spvtest::GetFunction(module, 4);
501   ScalarEvolutionAnalysis analysis{context.get()};
502 
503   const Instruction* loads[6];
504   const Instruction* stores[6];
505   int load_count = 0;
506   int store_count = 0;
507 
508   for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) {
509     if (inst.opcode() == SpvOp::SpvOpLoad) {
510       loads[load_count] = &inst;
511       ++load_count;
512     }
513     if (inst.opcode() == SpvOp::SpvOpStore) {
514       stores[store_count] = &inst;
515       ++store_count;
516     }
517   }
518 
519   EXPECT_EQ(load_count, 6);
520   EXPECT_EQ(store_count, 6);
521 
522   Instruction* load_access_chain;
523   Instruction* store_access_chain;
524   Instruction* load_child;
525   Instruction* store_child;
526   SENode* load_node;
527   SENode* store_node;
528   SENode* subtract_node;
529   SENode* simplified_node;
530 
531   // Testing [i] - [i] == 0
532   load_access_chain =
533       context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
534   store_access_chain =
535       context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
536 
537   load_child = context->get_def_use_mgr()->GetDef(
538       load_access_chain->GetSingleWordInOperand(1));
539   store_child = context->get_def_use_mgr()->GetDef(
540       store_access_chain->GetSingleWordInOperand(1));
541 
542   load_node = analysis.AnalyzeInstruction(load_child);
543   store_node = analysis.AnalyzeInstruction(store_child);
544 
545   subtract_node = analysis.CreateSubtraction(store_node, load_node);
546   simplified_node = analysis.SimplifyExpression(subtract_node);
547   EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
548   EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
549 
550   // Testing [i] - [i-1] == 1
551   load_access_chain =
552       context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
553   store_access_chain =
554       context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
555 
556   load_child = context->get_def_use_mgr()->GetDef(
557       load_access_chain->GetSingleWordInOperand(1));
558   store_child = context->get_def_use_mgr()->GetDef(
559       store_access_chain->GetSingleWordInOperand(1));
560 
561   load_node = analysis.AnalyzeInstruction(load_child);
562   store_node = analysis.AnalyzeInstruction(store_child);
563 
564   subtract_node = analysis.CreateSubtraction(store_node, load_node);
565   simplified_node = analysis.SimplifyExpression(subtract_node);
566 
567   EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
568   EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u);
569 
570   // Testing [i] - [i+1] == -1
571   load_access_chain =
572       context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
573   store_access_chain =
574       context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0));
575 
576   load_child = context->get_def_use_mgr()->GetDef(
577       load_access_chain->GetSingleWordInOperand(1));
578   store_child = context->get_def_use_mgr()->GetDef(
579       store_access_chain->GetSingleWordInOperand(1));
580 
581   load_node = analysis.AnalyzeInstruction(load_child);
582   store_node = analysis.AnalyzeInstruction(store_child);
583 
584   subtract_node = analysis.CreateSubtraction(store_node, load_node);
585   simplified_node = analysis.SimplifyExpression(subtract_node);
586   EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
587   EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1);
588 
589   // Testing [i+1] - [i+1] == 0
590   load_access_chain =
591       context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0));
592   store_access_chain =
593       context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0));
594 
595   load_child = context->get_def_use_mgr()->GetDef(
596       load_access_chain->GetSingleWordInOperand(1));
597   store_child = context->get_def_use_mgr()->GetDef(
598       store_access_chain->GetSingleWordInOperand(1));
599 
600   load_node = analysis.AnalyzeInstruction(load_child);
601   store_node = analysis.AnalyzeInstruction(store_child);
602 
603   subtract_node = analysis.CreateSubtraction(store_node, load_node);
604   simplified_node = analysis.SimplifyExpression(subtract_node);
605   EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
606   EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
607 
608   // Testing [i+N] - [i+N] == 0
609   load_access_chain =
610       context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0));
611   store_access_chain =
612       context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0));
613 
614   load_child = context->get_def_use_mgr()->GetDef(
615       load_access_chain->GetSingleWordInOperand(1));
616   store_child = context->get_def_use_mgr()->GetDef(
617       store_access_chain->GetSingleWordInOperand(1));
618 
619   load_node = analysis.AnalyzeInstruction(load_child);
620   store_node = analysis.AnalyzeInstruction(store_child);
621 
622   subtract_node = analysis.CreateSubtraction(store_node, load_node);
623 
624   simplified_node = analysis.SimplifyExpression(subtract_node);
625   EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
626   EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
627 
628   // Testing [i] - [i+N] == -N
629   load_access_chain =
630       context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0));
631   store_access_chain =
632       context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0));
633 
634   load_child = context->get_def_use_mgr()->GetDef(
635       load_access_chain->GetSingleWordInOperand(1));
636   store_child = context->get_def_use_mgr()->GetDef(
637       store_access_chain->GetSingleWordInOperand(1));
638 
639   load_node = analysis.AnalyzeInstruction(load_child);
640   store_node = analysis.AnalyzeInstruction(store_child);
641 
642   subtract_node = analysis.CreateSubtraction(store_node, load_node);
643   simplified_node = analysis.SimplifyExpression(subtract_node);
644   EXPECT_EQ(simplified_node->GetType(), SENode::Negative);
645 }
646 
647 /*
648 Generated from the following GLSL + --eliminate-local-multi-store
649 
650 #version 430
651 layout(location = 1) out float array[10];
652 layout(location = 2) flat in int loop_invariant;
653 void main(void) {
654   for (int i = 0; i < 10; ++i) {
655     array[i * 2 + i * 5] = array[i * i * 2];
656     array[i * 2] = array[i * 5];
657   }
658 }
659 
660 */
661 
TEST_F(ScalarAnalysisTest,SimplifyMultiplyInductions)662 TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
663   const std::string text = R"(
664                OpCapability Shader
665           %1 = OpExtInstImport "GLSL.std.450"
666                OpMemoryModel Logical GLSL450
667                OpEntryPoint Fragment %2 "main" %3 %4
668                OpExecutionMode %2 OriginUpperLeft
669                OpSource GLSL 430
670                OpName %2 "main"
671                OpName %5 "i"
672                OpName %3 "array"
673                OpName %4 "loop_invariant"
674                OpDecorate %3 Location 1
675                OpDecorate %4 Flat
676                OpDecorate %4 Location 2
677           %6 = OpTypeVoid
678           %7 = OpTypeFunction %6
679           %8 = OpTypeInt 32 1
680           %9 = OpTypePointer Function %8
681          %10 = OpConstant %8 0
682          %11 = OpConstant %8 10
683          %12 = OpTypeBool
684          %13 = OpTypeFloat 32
685          %14 = OpTypeInt 32 0
686          %15 = OpConstant %14 10
687          %16 = OpTypeArray %13 %15
688          %17 = OpTypePointer Output %16
689           %3 = OpVariable %17 Output
690          %18 = OpConstant %8 2
691          %19 = OpConstant %8 5
692          %20 = OpTypePointer Output %13
693          %21 = OpConstant %8 1
694          %22 = OpTypePointer Input %8
695           %4 = OpVariable %22 Input
696           %2 = OpFunction %6 None %7
697          %23 = OpLabel
698           %5 = OpVariable %9 Function
699                OpStore %5 %10
700                OpBranch %24
701          %24 = OpLabel
702          %25 = OpPhi %8 %10 %23 %26 %27
703                OpLoopMerge %28 %27 None
704                OpBranch %29
705          %29 = OpLabel
706          %30 = OpSLessThan %12 %25 %11
707                OpBranchConditional %30 %31 %28
708          %31 = OpLabel
709          %32 = OpIMul %8 %25 %18
710          %33 = OpIMul %8 %25 %19
711          %34 = OpIAdd %8 %32 %33
712          %35 = OpIMul %8 %25 %25
713          %36 = OpIMul %8 %35 %18
714          %37 = OpAccessChain %20 %3 %36
715          %38 = OpLoad %13 %37
716          %39 = OpAccessChain %20 %3 %34
717                OpStore %39 %38
718          %40 = OpIMul %8 %25 %18
719          %41 = OpIMul %8 %25 %19
720          %42 = OpAccessChain %20 %3 %41
721          %43 = OpLoad %13 %42
722          %44 = OpAccessChain %20 %3 %40
723                OpStore %44 %43
724                OpBranch %27
725          %27 = OpLabel
726          %26 = OpIAdd %8 %25 %21
727                OpStore %5 %26
728                OpBranch %24
729          %28 = OpLabel
730                OpReturn
731                OpFunctionEnd
732     )";
733   std::unique_ptr<IRContext> context =
734       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
735                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
736   Module* module = context->module();
737   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
738                              << text << std::endl;
739   const Function* f = spvtest::GetFunction(module, 2);
740   ScalarEvolutionAnalysis analysis{context.get()};
741 
742   const Instruction* loads[2] = {nullptr, nullptr};
743   const Instruction* stores[2] = {nullptr, nullptr};
744   int load_count = 0;
745   int store_count = 0;
746 
747   for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) {
748     if (inst.opcode() == SpvOp::SpvOpLoad) {
749       loads[load_count] = &inst;
750       ++load_count;
751     }
752     if (inst.opcode() == SpvOp::SpvOpStore) {
753       stores[store_count] = &inst;
754       ++store_count;
755     }
756   }
757 
758   EXPECT_EQ(load_count, 2);
759   EXPECT_EQ(store_count, 2);
760 
761   Instruction* load_access_chain =
762       context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
763   Instruction* store_access_chain =
764       context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
765 
766   Instruction* load_child = context->get_def_use_mgr()->GetDef(
767       load_access_chain->GetSingleWordInOperand(1));
768   Instruction* store_child = context->get_def_use_mgr()->GetDef(
769       store_access_chain->GetSingleWordInOperand(1));
770 
771   SENode* store_node = analysis.AnalyzeInstruction(store_child);
772 
773   SENode* store_simplified = analysis.SimplifyExpression(store_node);
774 
775   load_access_chain =
776       context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
777   store_access_chain =
778       context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
779   load_child = context->get_def_use_mgr()->GetDef(
780       load_access_chain->GetSingleWordInOperand(1));
781   store_child = context->get_def_use_mgr()->GetDef(
782       store_access_chain->GetSingleWordInOperand(1));
783 
784   SENode* second_store =
785       analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child));
786   SENode* second_load =
787       analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child));
788   SENode* combined_add = analysis.SimplifyExpression(
789       analysis.CreateAddNode(second_load, second_store));
790 
791   // We're checking that the two recurrent expression have been correctly
792   // folded. In store_simplified they will have been folded as the entire
793   // expression was simplified as one. In combined_add the two expressions have
794   // been simplified one after the other which means the recurrent expressions
795   // aren't exactly the same but should still be folded as they are with respect
796   // to the same loop.
797   EXPECT_EQ(combined_add, store_simplified);
798 }
799 
/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int loop_invariant;
void main(void) {
    for (int i = 0; i < 10; --i) {
        array[i] = array[i];
    }
}

*/
811 
TEST_F(ScalarAnalysisTest,SimplifyNegativeSteps)812 TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
813   const std::string text = R"(
814                OpCapability Shader
815           %1 = OpExtInstImport "GLSL.std.450"
816                OpMemoryModel Logical GLSL450
817                OpEntryPoint Fragment %2 "main" %3 %4
818                OpExecutionMode %2 OriginUpperLeft
819                OpSource GLSL 430
820                OpName %2 "main"
821                OpName %5 "i"
822                OpName %3 "array"
823                OpName %4 "loop_invariant"
824                OpDecorate %3 Location 1
825                OpDecorate %4 Flat
826                OpDecorate %4 Location 2
827           %6 = OpTypeVoid
828           %7 = OpTypeFunction %6
829           %8 = OpTypeInt 32 1
830           %9 = OpTypePointer Function %8
831          %10 = OpConstant %8 0
832          %11 = OpConstant %8 10
833          %12 = OpTypeBool
834          %13 = OpTypeFloat 32
835          %14 = OpTypeInt 32 0
836          %15 = OpConstant %14 10
837          %16 = OpTypeArray %13 %15
838          %17 = OpTypePointer Output %16
839           %3 = OpVariable %17 Output
840          %18 = OpTypePointer Output %13
841          %19 = OpConstant %8 1
842          %20 = OpTypePointer Input %8
843           %4 = OpVariable %20 Input
844           %2 = OpFunction %6 None %7
845          %21 = OpLabel
846           %5 = OpVariable %9 Function
847                OpStore %5 %10
848                OpBranch %22
849          %22 = OpLabel
850          %23 = OpPhi %8 %10 %21 %24 %25
851                OpLoopMerge %26 %25 None
852                OpBranch %27
853          %27 = OpLabel
854          %28 = OpSLessThan %12 %23 %11
855                OpBranchConditional %28 %29 %26
856          %29 = OpLabel
857          %30 = OpAccessChain %18 %3 %23
858          %31 = OpLoad %13 %30
859          %32 = OpAccessChain %18 %3 %23
860                OpStore %32 %31
861                OpBranch %25
862          %25 = OpLabel
863          %24 = OpISub %8 %23 %19
864                OpStore %5 %24
865                OpBranch %22
866          %26 = OpLabel
867                OpReturn
868                OpFunctionEnd
869     )";
870   std::unique_ptr<IRContext> context =
871       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
872                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
873   Module* module = context->module();
874   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
875                              << text << std::endl;
876   const Function* f = spvtest::GetFunction(module, 2);
877   ScalarEvolutionAnalysis analysis{context.get()};
878 
879   const Instruction* loads[1] = {nullptr};
880   int load_count = 0;
881 
882   for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) {
883     if (inst.opcode() == SpvOp::SpvOpLoad) {
884       loads[load_count] = &inst;
885       ++load_count;
886     }
887   }
888 
889   EXPECT_EQ(load_count, 1);
890 
891   Instruction* load_access_chain =
892       context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
893   Instruction* load_child = context->get_def_use_mgr()->GetDef(
894       load_access_chain->GetSingleWordInOperand(1));
895 
896   SENode* load_node = analysis.AnalyzeInstruction(load_child);
897 
898   EXPECT_TRUE(load_node);
899   EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr);
900   EXPECT_TRUE(load_node->AsSERecurrentNode());
901 
902   SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient();
903   SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset();
904 
905   EXPECT_EQ(child_1->GetType(), SENode::Constant);
906   EXPECT_EQ(child_2->GetType(), SENode::Constant);
907 
908   EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
909   EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u);
910 
911   SERecurrentNode* load_simplified =
912       analysis.SimplifyExpression(load_node)->AsSERecurrentNode();
913 
914   EXPECT_TRUE(load_simplified);
915   EXPECT_EQ(load_node, load_simplified);
916 
917   EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr);
918   EXPECT_TRUE(load_simplified->AsSERecurrentNode());
919 
920   SENode* simplified_child_1 =
921       load_simplified->AsSERecurrentNode()->GetCoefficient();
922   SENode* simplified_child_2 =
923       load_simplified->AsSERecurrentNode()->GetOffset();
924 
925   EXPECT_EQ(child_1, simplified_child_1);
926   EXPECT_EQ(child_2, simplified_child_2);
927 }
928 
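/*
Generated from the following GLSL + --eliminate-local-multi-store.
In the assembly below every use of N reads the single load %31.

#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int N;
void main(void) {
    for (int i = 0; i < 10; --i) {
        array[i + 2*N] = array[i + N];
        array[2*i + 2*N + 1] = array[2*i + N + 1];
    }
}

*/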
940 
TEST_F(ScalarAnalysisTest,SimplifyInductionsAndLoads)941 TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
942   const std::string text = R"(
943                OpCapability Shader
944           %1 = OpExtInstImport "GLSL.std.450"
945                OpMemoryModel Logical GLSL450
946                OpEntryPoint Fragment %2 "main" %3 %4
947                OpExecutionMode %2 OriginUpperLeft
948                OpSource GLSL 430
949                OpName %2 "main"
950                OpName %5 "i"
951                OpName %3 "array"
952                OpName %4 "N"
953                OpDecorate %3 Location 1
954                OpDecorate %4 Flat
955                OpDecorate %4 Location 2
956           %6 = OpTypeVoid
957           %7 = OpTypeFunction %6
958           %8 = OpTypeInt 32 1
959           %9 = OpTypePointer Function %8
960          %10 = OpConstant %8 0
961          %11 = OpConstant %8 10
962          %12 = OpTypeBool
963          %13 = OpTypeFloat 32
964          %14 = OpTypeInt 32 0
965          %15 = OpConstant %14 10
966          %16 = OpTypeArray %13 %15
967          %17 = OpTypePointer Output %16
968           %3 = OpVariable %17 Output
969          %18 = OpConstant %8 2
970          %19 = OpTypePointer Input %8
971           %4 = OpVariable %19 Input
972          %20 = OpTypePointer Output %13
973          %21 = OpConstant %8 1
974           %2 = OpFunction %6 None %7
975          %22 = OpLabel
976           %5 = OpVariable %9 Function
977                OpStore %5 %10
978                OpBranch %23
979          %23 = OpLabel
980          %24 = OpPhi %8 %10 %22 %25 %26
981                OpLoopMerge %27 %26 None
982                OpBranch %28
983          %28 = OpLabel
984          %29 = OpSLessThan %12 %24 %11
985                OpBranchConditional %29 %30 %27
986          %30 = OpLabel
987          %31 = OpLoad %8 %4
988          %32 = OpIMul %8 %18 %31
989          %33 = OpIAdd %8 %24 %32
990          %35 = OpIAdd %8 %24 %31
991          %36 = OpAccessChain %20 %3 %35
992          %37 = OpLoad %13 %36
993          %38 = OpAccessChain %20 %3 %33
994                OpStore %38 %37
995          %39 = OpIMul %8 %18 %24
996          %41 = OpIMul %8 %18 %31
997          %42 = OpIAdd %8 %39 %41
998          %43 = OpIAdd %8 %42 %21
999          %44 = OpIMul %8 %18 %24
1000          %46 = OpIAdd %8 %44 %31
1001          %47 = OpIAdd %8 %46 %21
1002          %48 = OpAccessChain %20 %3 %47
1003          %49 = OpLoad %13 %48
1004          %50 = OpAccessChain %20 %3 %43
1005                OpStore %50 %49
1006                OpBranch %26
1007          %26 = OpLabel
1008          %25 = OpISub %8 %24 %21
1009                OpStore %5 %25
1010                OpBranch %23
1011          %27 = OpLabel
1012                OpReturn
1013                OpFunctionEnd
1014     )";
1015   std::unique_ptr<IRContext> context =
1016       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1017                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1018   Module* module = context->module();
1019   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1020                              << text << std::endl;
1021   const Function* f = spvtest::GetFunction(module, 2);
1022   ScalarEvolutionAnalysis analysis{context.get()};
1023 
1024   std::vector<const Instruction*> loads{};
1025   std::vector<const Instruction*> stores{};
1026 
1027   for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
1028     if (inst.opcode() == SpvOp::SpvOpLoad) {
1029       loads.push_back(&inst);
1030     }
1031     if (inst.opcode() == SpvOp::SpvOpStore) {
1032       stores.push_back(&inst);
1033     }
1034   }
1035 
1036   EXPECT_EQ(loads.size(), 3u);
1037   EXPECT_EQ(stores.size(), 2u);
1038   {
1039     Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
1040         stores[0]->GetSingleWordInOperand(0));
1041 
1042     Instruction* store_child = context->get_def_use_mgr()->GetDef(
1043         store_access_chain->GetSingleWordInOperand(1));
1044 
1045     SENode* store_node = analysis.AnalyzeInstruction(store_child);
1046 
1047     SENode* store_simplified = analysis.SimplifyExpression(store_node);
1048 
1049     Instruction* load_access_chain =
1050         context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
1051 
1052     Instruction* load_child = context->get_def_use_mgr()->GetDef(
1053         load_access_chain->GetSingleWordInOperand(1));
1054 
1055     SENode* load_node = analysis.AnalyzeInstruction(load_child);
1056 
1057     SENode* load_simplified = analysis.SimplifyExpression(load_node);
1058 
1059     SENode* difference =
1060         analysis.CreateSubtraction(store_simplified, load_simplified);
1061 
1062     SENode* difference_simplified = analysis.SimplifyExpression(difference);
1063 
1064     // Check that i+2*N  -  i*N, turns into just N when both sides have already
1065     // been simplified into a single recurrent expression.
1066     EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
1067 
1068     // Check that the inverse, i*N - i+2*N turns into -N.
1069     SENode* difference_inverse = analysis.SimplifyExpression(
1070         analysis.CreateSubtraction(load_simplified, store_simplified));
1071 
1072     EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
1073     EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
1074     EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
1075   }
1076 
1077   {
1078     Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
1079         stores[1]->GetSingleWordInOperand(0));
1080 
1081     Instruction* store_child = context->get_def_use_mgr()->GetDef(
1082         store_access_chain->GetSingleWordInOperand(1));
1083     SENode* store_node = analysis.AnalyzeInstruction(store_child);
1084     SENode* store_simplified = analysis.SimplifyExpression(store_node);
1085 
1086     Instruction* load_access_chain =
1087         context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
1088 
1089     Instruction* load_child = context->get_def_use_mgr()->GetDef(
1090         load_access_chain->GetSingleWordInOperand(1));
1091 
1092     SENode* load_node = analysis.AnalyzeInstruction(load_child);
1093 
1094     SENode* load_simplified = analysis.SimplifyExpression(load_node);
1095 
1096     SENode* difference =
1097         analysis.CreateSubtraction(store_simplified, load_simplified);
1098     SENode* difference_simplified = analysis.SimplifyExpression(difference);
1099 
1100     // Check that 2*i + 2*N + 1  -  2*i + N + 1, turns into just N when both
1101     // sides have already been simplified into a single recurrent expression.
1102     EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
1103 
1104     // Check that the inverse, (2*i + N + 1)  -  (2*i + 2*N + 1) turns into -N.
1105     SENode* difference_inverse = analysis.SimplifyExpression(
1106         analysis.CreateSubtraction(load_simplified, store_simplified));
1107 
1108     EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
1109     EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
1110     EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
1111   }
1112 }
1113 
1114 /* Generated from the following GLSL + --eliminate-local-multi-store
1115 
1116   #version 430
1117   layout(location = 1) out float array[10];
1118   layout(location = 2) flat in int N;
1119   void main(void) {
1120     int step = 0;
1121     for (int i = 0; i < N; i += step) {
1122       step++;
1123     }
1124   }
1125 */
// Checks a loop whose induction variable i is advanced by a step that itself
// changes every iteration (step++; i += step). The step phi is expected to be
// a recurrent add expression. The i phi is expected to be CanNotCompute,
// because its increment (%23) is loop-variant. Simplification must preserve
// both results.
TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "step"
               OpName %6 "i"
               OpName %3 "N"
               OpName %4 "array"
               OpDecorate %3 Flat
               OpDecorate %3 Location 2
               OpDecorate %4 Location 1
          %7 = OpTypeVoid
          %8 = OpTypeFunction %7
          %9 = OpTypeInt 32 1
         %10 = OpTypePointer Function %9
         %11 = OpConstant %9 0
         %12 = OpTypePointer Input %9
          %3 = OpVariable %12 Input
         %13 = OpTypeBool
         %14 = OpConstant %9 1
         %15 = OpTypeFloat 32
         %16 = OpTypeInt 32 0
         %17 = OpConstant %16 10
         %18 = OpTypeArray %15 %17
         %19 = OpTypePointer Output %18
          %4 = OpVariable %19 Output
          %2 = OpFunction %7 None %8
         %20 = OpLabel
          %5 = OpVariable %10 Function
          %6 = OpVariable %10 Function
               OpStore %5 %11
               OpStore %6 %11
               OpBranch %21
         %21 = OpLabel
         %22 = OpPhi %9 %11 %20 %23 %24
         %25 = OpPhi %9 %11 %20 %26 %24
               OpLoopMerge %27 %24 None
               OpBranch %28
         %28 = OpLabel
         %29 = OpLoad %9 %3
         %30 = OpSLessThan %13 %25 %29
               OpBranchConditional %30 %31 %27
         %31 = OpLabel
         %23 = OpIAdd %9 %22 %14
               OpStore %5 %23
               OpBranch %24
         %24 = OpLabel
         %26 = OpIAdd %9 %25 %23
               OpStore %6 %26
               OpBranch %21
         %27 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fatal checks: everything below dereferences the context and the module.
  ASSERT_NE(nullptr, context) << "Assembling failed for shader:\n"
                              << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  std::vector<const Instruction*> phis{};

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
    if (inst.opcode() == SpvOp::SpvOpPhi) {
      phis.push_back(&inst);
    }
  }

  // phis[0] is %22 ("step") and phis[1] is %25 ("i"). The size check is fatal
  // because both entries are indexed below.
  ASSERT_EQ(phis.size(), 2u);
  SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
  SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
  ASSERT_NE(phi_node_1, nullptr);
  ASSERT_NE(phi_node_2, nullptr);

  EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr);
  EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute);

  SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
  SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);
  ASSERT_NE(simplified_1, nullptr);
  ASSERT_NE(simplified_2, nullptr);

  EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr);
  EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute);
}
1217 
1218 }  // namespace
1219 }  // namespace opt
1220 }  // namespace spvtools
1221