1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <memory>
16 #include <string>
17 #include <unordered_set>
18 #include <vector>
19
20 #include "gmock/gmock.h"
21 #include "source/opt/iterator.h"
22 #include "source/opt/loop_descriptor.h"
23 #include "source/opt/pass.h"
24 #include "source/opt/scalar_analysis.h"
25 #include "source/opt/tree_iterator.h"
26 #include "test/opt/assembly_builder.h"
27 #include "test/opt/function_utils.h"
28 #include "test/opt/pass_fixture.h"
29 #include "test/opt/pass_utils.h"
30
31 namespace spvtools {
32 namespace opt {
33 namespace {
34
35 using ::testing::UnorderedElementsAre;
36 using ScalarAnalysisTest = PassTest<::testing::Test>;
37
38 /*
39 Generated from the following GLSL + --eliminate-local-multi-store
40
41 #version 410 core
42 layout (location = 1) out float array[10];
43 void main() {
44 for (int i = 0; i < 10; ++i) {
45 array[i] = array[i+1];
46 }
47 }
48 */
// Checks that the index `i + 1` of a simple counted loop is analyzed as
// ADD(REC(0,1), 1) and simplifies to the recurrence REC(1,1).
TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %24
               OpExecutionMode %4 OriginUpperLeft
               OpSource GLSL 410
               OpName %4 "main"
               OpName %24 "array"
               OpDecorate %24 Location 1
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %9 = OpConstant %6 0
         %16 = OpConstant %6 10
         %17 = OpTypeBool
         %19 = OpTypeFloat 32
         %20 = OpTypeInt 32 0
         %21 = OpConstant %20 10
         %22 = OpTypeArray %19 %21
         %23 = OpTypePointer Output %22
         %24 = OpVariable %23 Output
         %27 = OpConstant %6 1
         %29 = OpTypePointer Output %19
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpBranch %10
         %10 = OpLabel
         %35 = OpPhi %6 %9 %5 %34 %13
               OpLoopMerge %12 %13 None
               OpBranch %14
         %14 = OpLabel
         %18 = OpSLessThan %17 %35 %16
               OpBranchConditional %18 %11 %12
         %11 = OpLabel
         %28 = OpIAdd %6 %35 %27
         %30 = OpAccessChain %29 %24 %28
         %31 = OpLoad %19 %30
         %32 = OpAccessChain %29 %24 %35
               OpStore %32 %31
               OpBranch %13
         %13 = OpLabel
         %34 = OpIAdd %6 %35 %27
               OpBranch %10
         %12 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // ASSERT rather than EXPECT: every statement below dereferences these
  // pointers, so a failure must end the test instead of crashing it.
  // BuildModule yields nullptr when the text fails to assemble.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 4);
  ASSERT_NE(nullptr, f);
  const auto* loop_body = spvtest::GetBasicBlock(f, 11);
  ASSERT_NE(nullptr, loop_body);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* store = nullptr;
  const Instruction* load = nullptr;
  for (const Instruction& inst : *loop_body) {
    if (inst.opcode() == SpvOp::SpvOpStore) {
      store = &inst;
    }
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      load = &inst;
    }
  }

  ASSERT_NE(load, nullptr);
  ASSERT_NE(store, nullptr);

  // In-operand 0 of the load is its OpAccessChain pointer; in-operand 1 of
  // that chain is the array index being analyzed.
  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));
  SENode* node = analysis.AnalyzeInstruction(child);

  ASSERT_NE(node, nullptr);

  // Unsimplified node should have the form of ADD(REC(0,1), 1). GetChild()
  // does not bounds-check, so the node kind must be confirmed first.
  ASSERT_EQ(node->GetType(), SENode::Add);

  const SENode* child_1 = node->GetChild(0);
  EXPECT_TRUE(child_1->GetType() == SENode::Constant ||
              child_1->GetType() == SENode::RecurrentAddExpr);

  const SENode* child_2 = node->GetChild(1);
  EXPECT_TRUE(child_2->GetType() == SENode::Constant ||
              child_2->GetType() == SENode::RecurrentAddExpr);

  SENode* simplified = analysis.SimplifyExpression(node);
  // Simplified should be in the form of REC(1,1).
  ASSERT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);

  // AsSEConstantNode() is only non-null for constant nodes.
  ASSERT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
            1);

  ASSERT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
            1);

  // Equal constants are expected to share a single node.
  EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
}
157
/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
layout (location = 1) out float array[10];
layout (location = 2) flat in int loop_invariant;
void main() {
  for (int i = 0; i < 10; ++i) {
    array[i] = array[i+loop_invariant];
  }
}

*/
// Checks that `i + loop_invariant` (a loop-invariant load added to the
// induction variable) simplifies to REC(loop_invariant, 1).
TEST_F(ScalarAnalysisTest, LoadTest) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeInt 32 1
          %8 = OpTypePointer Function %7
          %9 = OpConstant %7 0
         %10 = OpConstant %7 10
         %11 = OpTypeBool
         %12 = OpTypeFloat 32
         %13 = OpTypeInt 32 0
         %14 = OpConstant %13 10
         %15 = OpTypeArray %12 %14
         %16 = OpTypePointer Output %15
          %3 = OpVariable %16 Output
         %17 = OpTypePointer Input %7
          %4 = OpVariable %17 Input
         %18 = OpTypePointer Output %12
         %19 = OpConstant %7 1
          %2 = OpFunction %5 None %6
         %20 = OpLabel
               OpBranch %21
         %21 = OpLabel
         %22 = OpPhi %7 %9 %20 %23 %24
               OpLoopMerge %25 %24 None
               OpBranch %26
         %26 = OpLabel
         %27 = OpSLessThan %11 %22 %10
               OpBranchConditional %27 %28 %25
         %28 = OpLabel
         %29 = OpLoad %7 %4
         %30 = OpIAdd %7 %22 %29
         %31 = OpAccessChain %18 %3 %30
         %32 = OpLoad %12 %31
         %33 = OpAccessChain %18 %3 %22
               OpStore %33 %32
               OpBranch %24
         %24 = OpLabel
         %23 = OpIAdd %7 %22 %19
               OpBranch %21
         %25 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // ASSERT rather than EXPECT: the rest of the test dereferences these.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  const auto* loop_body = spvtest::GetBasicBlock(f, 28);
  ASSERT_NE(nullptr, loop_body);
  ScalarEvolutionAnalysis analysis{context.get()};

  // Keeps the LAST load of the block: the float load %32 through the access
  // chain, not the earlier integer load %29 of loop_invariant.
  const Instruction* load = nullptr;
  for (const Instruction& inst : *loop_body) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      load = &inst;
    }
  }

  ASSERT_NE(load, nullptr);

  // In-operand 0 of the load is its OpAccessChain pointer; in-operand 1 of
  // that chain is the array index being analyzed.
  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));

  SENode* node = analysis.AnalyzeInstruction(child);

  ASSERT_NE(node, nullptr);

  // Unsimplified node should have the form of ADD(REC(0,1), X). GetChild()
  // does not bounds-check, so the node kind must be confirmed first.
  ASSERT_EQ(node->GetType(), SENode::Add);

  const SENode* child_1 = node->GetChild(0);
  EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown ||
              child_1->GetType() == SENode::RecurrentAddExpr);

  const SENode* child_2 = node->GetChild(1);
  EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown ||
              child_2->GetType() == SENode::RecurrentAddExpr);

  SENode* simplified = analysis.SimplifyExpression(node);
  ASSERT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);

  const SERecurrentNode* rec = simplified->AsSERecurrentNode();
  ASSERT_NE(rec, nullptr);

  EXPECT_NE(rec->GetChild(0), rec->GetChild(1));

  EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown);

  ASSERT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant);
  EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1);
}
281
/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
layout (location = 1) out float array[10];
layout (location = 2) flat in int loop_invariant;
void main() {
  array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
                   loop_invariant + 16 * 3];
}

With x = loop_invariant the index is 2x + 4 + 5 - 24 - x - x + 48, in which
the x terms cancel, leaving the constant 33.
*/
// Checks that an index expression whose loop_invariant terms cancel out
// simplifies to the constant 33.
TEST_F(ScalarAnalysisTest, SimplifySimple) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeFloat 32
          %8 = OpTypeInt 32 0
          %9 = OpConstant %8 10
         %10 = OpTypeArray %7 %9
         %11 = OpTypePointer Output %10
          %3 = OpVariable %11 Output
         %12 = OpTypeInt 32 1
         %13 = OpConstant %12 0
         %14 = OpTypePointer Input %12
          %4 = OpVariable %14 Input
         %15 = OpConstant %12 2
         %16 = OpConstant %12 4
         %17 = OpConstant %12 5
         %18 = OpConstant %12 24
         %19 = OpConstant %12 48
         %20 = OpTypePointer Output %7
          %2 = OpFunction %5 None %6
         %21 = OpLabel
         %22 = OpLoad %12 %4
         %23 = OpIMul %12 %22 %15
         %24 = OpIAdd %12 %23 %16
         %25 = OpIAdd %12 %24 %17
         %26 = OpISub %12 %25 %18
         %28 = OpISub %12 %26 %22
         %30 = OpISub %12 %28 %22
         %31 = OpIAdd %12 %30 %19
         %32 = OpAccessChain %20 %3 %31
         %33 = OpLoad %7 %32
         %34 = OpAccessChain %20 %3 %13
               OpStore %34 %33
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // ASSERT rather than EXPECT: the rest of the test dereferences these.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  const auto* entry_block = spvtest::GetBasicBlock(f, 21);
  ASSERT_NE(nullptr, entry_block);
  ScalarEvolutionAnalysis analysis{context.get()};

  // Only the float load %33 goes through the access chain; %22 is the load of
  // loop_invariant itself.
  const Instruction* load = nullptr;
  for (const Instruction& inst : *entry_block) {
    if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) {
      load = &inst;
    }
  }

  ASSERT_NE(load, nullptr);

  // In-operand 0 of the load is its OpAccessChain pointer; in-operand 1 of
  // that chain is the array index being analyzed.
  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));

  SENode* node = analysis.AnalyzeInstruction(child);

  // Unsimplified is a very large graph with an add at the top.
  ASSERT_NE(node, nullptr);
  EXPECT_EQ(node->GetType(), SENode::Add);

  // Simplified node should resolve down to a constant expression as the loads
  // will eliminate themselves.
  SENode* simplified = analysis.SimplifyExpression(node);

  ASSERT_EQ(simplified->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33);
}
381
382 /*
383 Generated from the following GLSL + --eliminate-local-multi-store
384
385 #version 410 core
386 layout(location = 0) in vec4 c;
387 layout (location = 1) out float array[10];
388 void main() {
389 int N = int(c.x);
390 for (int i = 0; i < 10; ++i) {
391 array[i] = array[i];
392 array[i] = array[i-1];
393 array[i] = array[i+1];
394 array[i+1] = array[i+1];
395 array[i+N] = array[i+N];
396 array[i] = array[i+N];
397 }
398 }
399
400 */
// For each of six store/load pairs in the loop body, simplifies
// (store index - load index) and checks the resulting distance.
TEST_F(ScalarAnalysisTest, Simplify) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %12 %33
               OpExecutionMode %4 OriginUpperLeft
               OpSource GLSL 410
               OpName %4 "main"
               OpName %8 "N"
               OpName %12 "c"
               OpName %19 "i"
               OpName %33 "array"
               OpDecorate %12 Location 0
               OpDecorate %33 Location 1
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %9 = OpTypeFloat 32
         %10 = OpTypeVector %9 4
         %11 = OpTypePointer Input %10
         %12 = OpVariable %11 Input
         %13 = OpTypeInt 32 0
         %14 = OpConstant %13 0
         %15 = OpTypePointer Input %9
         %20 = OpConstant %6 0
         %27 = OpConstant %6 10
         %28 = OpTypeBool
         %30 = OpConstant %13 10
         %31 = OpTypeArray %9 %30
         %32 = OpTypePointer Output %31
         %33 = OpVariable %32 Output
         %36 = OpTypePointer Output %9
         %42 = OpConstant %6 1
          %4 = OpFunction %2 None %3
          %5 = OpLabel
          %8 = OpVariable %7 Function
         %19 = OpVariable %7 Function
         %16 = OpAccessChain %15 %12 %14
         %17 = OpLoad %9 %16
         %18 = OpConvertFToS %6 %17
               OpStore %8 %18
               OpStore %19 %20
               OpBranch %21
         %21 = OpLabel
         %78 = OpPhi %6 %20 %5 %77 %24
               OpLoopMerge %23 %24 None
               OpBranch %25
         %25 = OpLabel
         %29 = OpSLessThan %28 %78 %27
               OpBranchConditional %29 %22 %23
         %22 = OpLabel
         %37 = OpAccessChain %36 %33 %78
         %38 = OpLoad %9 %37
         %39 = OpAccessChain %36 %33 %78
               OpStore %39 %38
         %43 = OpISub %6 %78 %42
         %44 = OpAccessChain %36 %33 %43
         %45 = OpLoad %9 %44
         %46 = OpAccessChain %36 %33 %78
               OpStore %46 %45
         %49 = OpIAdd %6 %78 %42
         %50 = OpAccessChain %36 %33 %49
         %51 = OpLoad %9 %50
         %52 = OpAccessChain %36 %33 %78
               OpStore %52 %51
         %54 = OpIAdd %6 %78 %42
         %56 = OpIAdd %6 %78 %42
         %57 = OpAccessChain %36 %33 %56
         %58 = OpLoad %9 %57
         %59 = OpAccessChain %36 %33 %54
               OpStore %59 %58
         %62 = OpIAdd %6 %78 %18
         %65 = OpIAdd %6 %78 %18
         %66 = OpAccessChain %36 %33 %65
         %67 = OpLoad %9 %66
         %68 = OpAccessChain %36 %33 %62
               OpStore %68 %67
         %72 = OpIAdd %6 %78 %18
         %73 = OpAccessChain %36 %33 %72
         %74 = OpLoad %9 %73
         %75 = OpAccessChain %36 %33 %78
               OpStore %75 %74
               OpBranch %24
         %24 = OpLabel
         %77 = OpIAdd %6 %78 %42
               OpStore %19 %77
               OpBranch %21
         %23 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // ASSERT rather than EXPECT: the rest of the test dereferences these.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 4);
  ASSERT_NE(nullptr, f);
  const auto* loop_body = spvtest::GetBasicBlock(f, 22);
  ASSERT_NE(nullptr, loop_body);
  ScalarEvolutionAnalysis analysis{context.get()};

  // Collect the loop body's loads and stores in program order. Vectors cannot
  // overflow if the body ever gains more memory operations.
  std::vector<const Instruction*> loads;
  std::vector<const Instruction*> stores;
  for (const Instruction& inst : *loop_body) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads.push_back(&inst);
    }
    if (inst.opcode() == SpvOp::SpvOpStore) {
      stores.push_back(&inst);
    }
  }

  // Fatal: the pairs below are indexed 0..5.
  ASSERT_EQ(loads.size(), 6u);
  ASSERT_EQ(stores.size(), 6u);

  // Analyzes the array index used by |memory_inst|. In-operand 0 of the
  // OpLoad/OpStore is its OpAccessChain pointer, and in-operand 1 of that
  // chain is the index.
  auto analyze_index = [&context, &analysis](const Instruction* memory_inst) {
    Instruction* access_chain = context->get_def_use_mgr()->GetDef(
        memory_inst->GetSingleWordInOperand(0));
    Instruction* index = context->get_def_use_mgr()->GetDef(
        access_chain->GetSingleWordInOperand(1));
    return analysis.AnalyzeInstruction(index);
  };

  // Returns SIMPLIFY(store index - load index) for the |pair|-th load/store.
  // The load is analyzed before the store, in the same order as the original
  // per-pair stanzas.
  auto simplified_difference = [&](size_t pair) {
    SENode* load_node = analyze_index(loads[pair]);
    SENode* store_node = analyze_index(stores[pair]);
    return analysis.SimplifyExpression(
        analysis.CreateSubtraction(store_node, load_node));
  };

  // Checks that |node| folded to the constant |expected|. A fatal failure only
  // returns from this lambda, so the remaining pairs are still checked, but a
  // non-constant node is never dereferenced as a constant.
  auto expect_constant = [](const char* what, SENode* node, int64_t expected) {
    SCOPED_TRACE(what);
    ASSERT_EQ(node->GetType(), SENode::Constant);
    EXPECT_EQ(node->AsSEConstantNode()->FoldToSingleValue(), expected);
  };

  expect_constant("[i] - [i] == 0", simplified_difference(0), 0);
  expect_constant("[i] - [i-1] == 1", simplified_difference(1), 1);
  expect_constant("[i] - [i+1] == -1", simplified_difference(2), -1);
  expect_constant("[i+1] - [i+1] == 0", simplified_difference(3), 0);
  expect_constant("[i+N] - [i+N] == 0", simplified_difference(4), 0);

  // Testing [i] - [i+N] == -N
  EXPECT_EQ(simplified_difference(5)->GetType(), SENode::Negative);
}
646
647 /*
648 Generated from the following GLSL + --eliminate-local-multi-store
649
650 #version 430
651 layout(location = 1) out float array[10];
652 layout(location = 2) flat in int loop_invariant;
653 void main(void) {
654 for (int i = 0; i < 10; ++i) {
655 array[i * 2 + i * 5] = array[i * i * 2];
656 array[i * 2] = array[i * 5];
657 }
658 }
659
660 */
661
// Checks that recurrences over the same loop fold together whether the
// expression is simplified as one tree (i*2 + i*5) or its operands are
// simplified separately and then added (i*5 + i*2).
TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "i"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %6 = OpTypeVoid
          %7 = OpTypeFunction %6
          %8 = OpTypeInt 32 1
          %9 = OpTypePointer Function %8
         %10 = OpConstant %8 0
         %11 = OpConstant %8 10
         %12 = OpTypeBool
         %13 = OpTypeFloat 32
         %14 = OpTypeInt 32 0
         %15 = OpConstant %14 10
         %16 = OpTypeArray %13 %15
         %17 = OpTypePointer Output %16
          %3 = OpVariable %17 Output
         %18 = OpConstant %8 2
         %19 = OpConstant %8 5
         %20 = OpTypePointer Output %13
         %21 = OpConstant %8 1
         %22 = OpTypePointer Input %8
          %4 = OpVariable %22 Input
          %2 = OpFunction %6 None %7
         %23 = OpLabel
          %5 = OpVariable %9 Function
               OpStore %5 %10
               OpBranch %24
         %24 = OpLabel
         %25 = OpPhi %8 %10 %23 %26 %27
               OpLoopMerge %28 %27 None
               OpBranch %29
         %29 = OpLabel
         %30 = OpSLessThan %12 %25 %11
               OpBranchConditional %30 %31 %28
         %31 = OpLabel
         %32 = OpIMul %8 %25 %18
         %33 = OpIMul %8 %25 %19
         %34 = OpIAdd %8 %32 %33
         %35 = OpIMul %8 %25 %25
         %36 = OpIMul %8 %35 %18
         %37 = OpAccessChain %20 %3 %36
         %38 = OpLoad %13 %37
         %39 = OpAccessChain %20 %3 %34
               OpStore %39 %38
         %40 = OpIMul %8 %25 %18
         %41 = OpIMul %8 %25 %19
         %42 = OpAccessChain %20 %3 %41
         %43 = OpLoad %13 %42
         %44 = OpAccessChain %20 %3 %40
               OpStore %44 %43
               OpBranch %27
         %27 = OpLabel
         %26 = OpIAdd %8 %25 %21
               OpStore %5 %26
               OpBranch %24
         %28 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // ASSERT rather than EXPECT: the rest of the test dereferences these.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  const auto* loop_body = spvtest::GetBasicBlock(f, 31);
  ASSERT_NE(nullptr, loop_body);
  ScalarEvolutionAnalysis analysis{context.get()};

  // Collect the loop body's loads and stores in program order. Vectors cannot
  // overflow if the body ever gains more memory operations.
  std::vector<const Instruction*> loads;
  std::vector<const Instruction*> stores;
  for (const Instruction& inst : *loop_body) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads.push_back(&inst);
    }
    if (inst.opcode() == SpvOp::SpvOpStore) {
      stores.push_back(&inst);
    }
  }

  // Fatal: loads[1], stores[0] and stores[1] are indexed below.
  ASSERT_EQ(loads.size(), 2u);
  ASSERT_EQ(stores.size(), 2u);

  // Analyzes the array index used by |memory_inst|. In-operand 0 of the
  // OpLoad/OpStore is its OpAccessChain pointer, and in-operand 1 of that
  // chain is the index.
  auto analyze_index = [&context, &analysis](const Instruction* memory_inst) {
    Instruction* access_chain = context->get_def_use_mgr()->GetDef(
        memory_inst->GetSingleWordInOperand(0));
    Instruction* index = context->get_def_use_mgr()->GetDef(
        access_chain->GetSingleWordInOperand(1));
    return analysis.AnalyzeInstruction(index);
  };

  // First store index: i * 2 + i * 5, simplified as one expression.
  SENode* store_simplified =
      analysis.SimplifyExpression(analyze_index(stores[0]));

  // Second pair: store index i * 2 and load index i * 5, each simplified on
  // its own and then added.
  SENode* second_store = analysis.SimplifyExpression(analyze_index(stores[1]));
  SENode* second_load = analysis.SimplifyExpression(analyze_index(loads[1]));
  SENode* combined_add = analysis.SimplifyExpression(
      analysis.CreateAddNode(second_load, second_store));

  // We're checking that the two recurrent expression have been correctly
  // folded. In store_simplified they will have been folded as the entire
  // expression was simplified as one. In combined_add the two expressions have
  // been simplified one after the other which means the recurrent expressions
  // aren't exactly the same but should still be folded as they are with respect
  // to the same loop.
  EXPECT_EQ(combined_add, store_simplified);
}
799
/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int loop_invariant;
void main(void) {
  for (int i = 0; i < 10; --i) {
    array[i] = array[i];
  }
}

*/
811
// Checks that a decrementing induction variable is analyzed as REC(0,-1) and
// that simplifying this already-canonical recurrence returns the same node.
TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
  // clang-format off
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "i"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %6 = OpTypeVoid
          %7 = OpTypeFunction %6
          %8 = OpTypeInt 32 1
          %9 = OpTypePointer Function %8
         %10 = OpConstant %8 0
         %11 = OpConstant %8 10
         %12 = OpTypeBool
         %13 = OpTypeFloat 32
         %14 = OpTypeInt 32 0
         %15 = OpConstant %14 10
         %16 = OpTypeArray %13 %15
         %17 = OpTypePointer Output %16
          %3 = OpVariable %17 Output
         %18 = OpTypePointer Output %13
         %19 = OpConstant %8 1
         %20 = OpTypePointer Input %8
          %4 = OpVariable %20 Input
          %2 = OpFunction %6 None %7
         %21 = OpLabel
          %5 = OpVariable %9 Function
               OpStore %5 %10
               OpBranch %22
         %22 = OpLabel
         %23 = OpPhi %8 %10 %21 %24 %25
               OpLoopMerge %26 %25 None
               OpBranch %27
         %27 = OpLabel
         %28 = OpSLessThan %12 %23 %11
               OpBranchConditional %28 %29 %26
         %29 = OpLabel
         %30 = OpAccessChain %18 %3 %23
         %31 = OpLoad %13 %30
         %32 = OpAccessChain %18 %3 %23
               OpStore %32 %31
               OpBranch %25
         %25 = OpLabel
         %24 = OpISub %8 %23 %19
               OpStore %5 %24
               OpBranch %22
         %26 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // ASSERT rather than EXPECT: the rest of the test dereferences these.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  const auto* loop_body = spvtest::GetBasicBlock(f, 29);
  ASSERT_NE(nullptr, loop_body);
  ScalarEvolutionAnalysis analysis{context.get()};

  // A vector cannot overflow if the body ever gains more loads.
  std::vector<const Instruction*> loads;
  for (const Instruction& inst : *loop_body) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads.push_back(&inst);
    }
  }

  ASSERT_EQ(loads.size(), 1u);

  // In-operand 0 of the load is its OpAccessChain pointer; in-operand 1 of
  // that chain is the array index (the induction variable %23).
  Instruction* load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  Instruction* load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));

  SENode* load_node = analysis.AnalyzeInstruction(load_child);

  ASSERT_NE(load_node, nullptr);
  ASSERT_EQ(load_node->GetType(), SENode::RecurrentAddExpr);
  SERecurrentNode* load_rec = load_node->AsSERecurrentNode();
  ASSERT_NE(load_rec, nullptr);

  SENode* child_1 = load_rec->GetCoefficient();
  SENode* child_2 = load_rec->GetOffset();

  // AsSEConstantNode() is only non-null for constant nodes.
  ASSERT_EQ(child_1->GetType(), SENode::Constant);
  ASSERT_EQ(child_2->GetType(), SENode::Constant);

  EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
  EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0);

  SERecurrentNode* load_simplified =
      analysis.SimplifyExpression(load_node)->AsSERecurrentNode();

  // REC(0,-1) is already canonical, so simplification should hand back the
  // very same node with the very same children.
  ASSERT_NE(load_simplified, nullptr);
  EXPECT_EQ(load_node, load_simplified);
  EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr);

  SENode* simplified_child_1 = load_simplified->GetCoefficient();
  SENode* simplified_child_2 = load_simplified->GetOffset();

  EXPECT_EQ(child_1, simplified_child_1);
  EXPECT_EQ(child_2, simplified_child_2);
}
928
/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int N;
void main(void) {
  for (int i = 0; i < 10; --i) {
    array[i + 2 * N] = array[i + N];
    array[2 * i + 2 * N + 1] = array[2 * i + N + 1];
  }
}

*/
940
// The loop body (block %30) contains two store/load pairs whose array indices
// differ only by the loop-invariant value N (loaded once into %31):
//   store array[i + 2*N]       <- load array[i + N]
//   store array[2*i + 2*N + 1] <- load array[2*i + N + 1]
// For each pair, subtracting the simplified index expressions must cancel the
// induction-variable parts and leave just N (or -N for the reverse order).
TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "i"
               OpName %3 "array"
               OpName %4 "N"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %6 = OpTypeVoid
          %7 = OpTypeFunction %6
          %8 = OpTypeInt 32 1
          %9 = OpTypePointer Function %8
         %10 = OpConstant %8 0
         %11 = OpConstant %8 10
         %12 = OpTypeBool
         %13 = OpTypeFloat 32
         %14 = OpTypeInt 32 0
         %15 = OpConstant %14 10
         %16 = OpTypeArray %13 %15
         %17 = OpTypePointer Output %16
          %3 = OpVariable %17 Output
         %18 = OpConstant %8 2
         %19 = OpTypePointer Input %8
          %4 = OpVariable %19 Input
         %20 = OpTypePointer Output %13
         %21 = OpConstant %8 1
          %2 = OpFunction %6 None %7
         %22 = OpLabel
          %5 = OpVariable %9 Function
               OpStore %5 %10
               OpBranch %23
         %23 = OpLabel
         %24 = OpPhi %8 %10 %22 %25 %26
               OpLoopMerge %27 %26 None
               OpBranch %28
         %28 = OpLabel
         %29 = OpSLessThan %12 %24 %11
               OpBranchConditional %29 %30 %27
         %30 = OpLabel
         %31 = OpLoad %8 %4
         %32 = OpIMul %8 %18 %31
         %33 = OpIAdd %8 %24 %32
         %35 = OpIAdd %8 %24 %31
         %36 = OpAccessChain %20 %3 %35
         %37 = OpLoad %13 %36
         %38 = OpAccessChain %20 %3 %33
               OpStore %38 %37
         %39 = OpIMul %8 %18 %24
         %41 = OpIMul %8 %18 %31
         %42 = OpIAdd %8 %39 %41
         %43 = OpIAdd %8 %42 %21
         %44 = OpIMul %8 %18 %24
         %46 = OpIAdd %8 %44 %31
         %47 = OpIAdd %8 %46 %21
         %48 = OpAccessChain %20 %3 %47
         %49 = OpLoad %13 %48
         %50 = OpAccessChain %20 %3 %43
               OpStore %50 %49
               OpBranch %26
         %26 = OpLabel
         %25 = OpISub %8 %24 %21
               OpStore %5 %25
               OpBranch %23
         %27 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fatal assertions: everything below dereferences the context and module,
  // so continuing after a failure here would crash instead of reporting.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  std::vector<const Instruction*> loads{};
  std::vector<const Instruction*> stores{};

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads.push_back(&inst);
    }
    if (inst.opcode() == SpvOp::SpvOpStore) {
      stores.push_back(&inst);
    }
  }

  // loads[0] is the load of N (%31); loads[1] and loads[2] read the array.
  // These are indexed below, so a size mismatch must stop the test.
  ASSERT_EQ(loads.size(), 3u);
  ASSERT_EQ(stores.size(), 2u);

  // Returns the simplified scalar-evolution node for the index operand
  // (in-operand 1) of the OpAccessChain that |mem_inst| uses as its pointer
  // (in-operand 0). Works for both OpLoad and OpStore.
  auto simplified_index = [&context,
                           &analysis](const Instruction* mem_inst) -> SENode* {
    Instruction* access_chain = context->get_def_use_mgr()->GetDef(
        mem_inst->GetSingleWordInOperand(0));
    Instruction* index = context->get_def_use_mgr()->GetDef(
        access_chain->GetSingleWordInOperand(1));
    return analysis.SimplifyExpression(analysis.AnalyzeInstruction(index));
  };

  // Checks that (store index - load index) simplifies to the lone
  // ValueUnknown node for N, and that (load index - store index) simplifies
  // to the negation of that same node.
  auto check_pair = [&analysis, &simplified_index](const Instruction* store,
                                                   const Instruction* load) {
    SENode* store_simplified = simplified_index(store);
    SENode* load_simplified = simplified_index(load);

    SENode* difference = analysis.SimplifyExpression(
        analysis.CreateSubtraction(store_simplified, load_simplified));
    EXPECT_EQ(difference->GetType(), SENode::ValueUnknown);

    SENode* difference_inverse = analysis.SimplifyExpression(
        analysis.CreateSubtraction(load_simplified, store_simplified));
    // Only inspect the child once the node is known to be a Negative.
    ASSERT_EQ(difference_inverse->GetType(), SENode::Negative);
    EXPECT_EQ(difference_inverse->GetChild(0)->GetType(),
              SENode::ValueUnknown);
    EXPECT_EQ(difference_inverse->GetChild(0), difference);
  };

  {
    // (i + 2*N) - (i + N) == N, and (i + N) - (i + 2*N) == -N.
    SCOPED_TRACE("store array[i + 2*N] <- load array[i + N]");
    check_pair(stores[0], loads[1]);
  }

  {
    // (2*i + 2*N + 1) - (2*i + N + 1) == N, and the reverse gives -N.
    SCOPED_TRACE("store array[2*i + 2*N + 1] <- load array[2*i + N + 1]");
    check_pair(stores[1], loads[2]);
  }
}
1113
1114 /* Generated from the following GLSL + --eliminate-local-multi-store
1115
1116 #version 430
1117 layout(location = 1) out float array[10];
1118 layout(location = 2) flat in int N;
1119 void main(void) {
1120 int step = 0;
1121 for (int i = 0; i < N; i += step) {
1122 step++;
1123 }
1124 }
1125 */
// The loop header (block %21) has two phis:
//   %22 (step): advances by the constant 1 each iteration (%23 = %22 + 1).
//   %25 (i):    advances by step (%26 = %25 + %23), a step that itself varies.
// The analysis should model step as a recurrence but refuse to model i.
TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "step"
               OpName %6 "i"
               OpName %3 "N"
               OpName %4 "array"
               OpDecorate %3 Flat
               OpDecorate %3 Location 2
               OpDecorate %4 Location 1
          %7 = OpTypeVoid
          %8 = OpTypeFunction %7
          %9 = OpTypeInt 32 1
         %10 = OpTypePointer Function %9
         %11 = OpConstant %9 0
         %12 = OpTypePointer Input %9
          %3 = OpVariable %12 Input
         %13 = OpTypeBool
         %14 = OpConstant %9 1
         %15 = OpTypeFloat 32
         %16 = OpTypeInt 32 0
         %17 = OpConstant %16 10
         %18 = OpTypeArray %15 %17
         %19 = OpTypePointer Output %18
          %4 = OpVariable %19 Output
          %2 = OpFunction %7 None %8
         %20 = OpLabel
          %5 = OpVariable %10 Function
          %6 = OpVariable %10 Function
               OpStore %5 %11
               OpStore %6 %11
               OpBranch %21
         %21 = OpLabel
         %22 = OpPhi %9 %11 %20 %23 %24
         %25 = OpPhi %9 %11 %20 %26 %24
               OpLoopMerge %27 %24 None
               OpBranch %28
         %28 = OpLabel
         %29 = OpLoad %9 %3
         %30 = OpSLessThan %13 %25 %29
               OpBranchConditional %30 %31 %27
         %31 = OpLabel
         %23 = OpIAdd %9 %22 %14
               OpStore %5 %23
               OpBranch %24
         %24 = OpLabel
         %26 = OpIAdd %9 %25 %23
               OpStore %6 %26
               OpBranch %21
         %27 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fatal assertions: the code below dereferences both pointers.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  std::vector<const Instruction*> phis{};

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
    if (inst.opcode() == SpvOp::SpvOpPhi) {
      phis.push_back(&inst);
    }
  }

  // phis[0] is %22 (step), phis[1] is %25 (i). Both are indexed below, so a
  // count mismatch must stop the test rather than read out of bounds.
  ASSERT_EQ(phis.size(), 2u);
  SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
  SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
  // The nodes are dereferenced below, so null must be fatal.
  ASSERT_NE(phi_node_1, nullptr);
  ASSERT_NE(phi_node_2, nullptr);

  // step: constant increment of 1 -> a recurrent add expression.
  EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr);
  // i: incremented by a value that changes every iteration -> not computable.
  EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute);

  // Simplification must not change either classification.
  SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
  SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);

  EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr);
  EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute);
}
1217
1218 } // namespace
1219 } // namespace opt
1220 } // namespace spvtools
1221