1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <map>
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/cfg.h"
24 #include "source/opt/ir_context.h"
25 #include "source/opt/pass.h"
26 #include "source/opt/propagator.h"
27
28 namespace spvtools {
29 namespace opt {
30 namespace {
31
32 using ::testing::UnorderedElementsAre;
33
34 class PropagatorTest : public testing::Test {
35 protected:
TearDown()36 virtual void TearDown() {
37 ctx_.reset(nullptr);
38 values_.clear();
39 values_vec_.clear();
40 }
41
Assemble(const std::string & input)42 void Assemble(const std::string& input) {
43 ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
44 ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n"
45 << input << "\n";
46 }
47
Propagate(const SSAPropagator::VisitFunction & visit_fn)48 bool Propagate(const SSAPropagator::VisitFunction& visit_fn) {
49 SSAPropagator propagator(ctx_.get(), visit_fn);
50 bool retval = false;
51 for (auto& fn : *ctx_->module()) {
52 retval |= propagator.Run(&fn);
53 }
54 return retval;
55 }
56
GetValues()57 const std::vector<uint32_t>& GetValues() {
58 values_vec_.clear();
59 for (const auto& it : values_) {
60 values_vec_.push_back(it.second);
61 }
62 return values_vec_;
63 }
64
65 std::unique_ptr<IRContext> ctx_;
66 std::map<uint32_t, uint32_t> values_;
67 std::vector<uint32_t> values_vec_;
68 };
69
// Each OpStore of an OpConstant must be reported as interesting, and the
// visitor records the stored constant for the pointer being stored to.
TEST_F(PropagatorTest, LocalPropagate) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %y "y"
               OpName %z "z"
               OpName %outparm "outparm"
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %5 = OpLabel
          %x = OpVariable %_ptr_Function_int Function
          %y = OpVariable %_ptr_Function_int Function
          %z = OpVariable %_ptr_Function_int Function
               OpStore %x %int_4
               OpStore %y %int_3
               OpStore %z %int_1
         %20 = OpLoad %int %z
               OpStore %outparm %20
               OpReturn
               OpFunctionEnd
)";
  // Stop here if assembly failed, so the rest of the test does not run
  // without a module.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) {
    // This visitor never selects a branch target.
    *dest_bb = nullptr;
    if (instr->opcode() == SpvOpStore) {
      // OpStore operands: 0 is the pointer, 1 is the object being stored.
      uint32_t lhs_id = instr->GetSingleWordOperand(0);
      uint32_t rhs_id = instr->GetSingleWordOperand(1);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def->opcode() == SpvOpConstant) {
        // Operand 2 of OpConstant (after type and result id) is its literal.
        uint32_t val = rhs_def->GetSingleWordOperand(2);
        values_[lhs_id] = val;
        return SSAPropagator::kInteresting;
      }
    }
    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));
  // %x, %y and %z were stored 4, 3 and 1. The store to %outparm stores a
  // load result, not a constant, so it records nothing.
  EXPECT_THAT(GetValues(), UnorderedElementsAre(4, 3, 1));
}
126
// Values derived in both arms of a selection must flow into the OpPhi at the
// merge block, and the Phi must receive the common value 4.
TEST_F(PropagatorTest, PropagateThroughPhis) {
  // NOTE: "OpLoad %int %int_4" is not valid SPIR-V, because the operand is not
  // a pointer. It is a deliberate shortcut so that the visitor below can treat
  // each such load as producing the constant 4. BuildModule presumably does
  // not run the validator, which is why this assembles.
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %x %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %outparm "outparm"
               OpDecorate %x Flat
               OpDecorate %x Location 0
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
       %bool = OpTypeBool
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Input_int = OpTypePointer Input %int
          %x = OpVariable %_ptr_Input_int Input
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %4 = OpLabel
          %5 = OpLoad %int %x
          %6 = OpSGreaterThan %bool %5 %int_3
               OpSelectionMerge %25 None
               OpBranchConditional %6 %22 %23
         %22 = OpLabel
          %7 = OpLoad %int %int_4
               OpBranch %25
         %23 = OpLabel
          %8 = OpLoad %int %int_4
               OpBranch %25
         %25 = OpLabel
         %35 = OpPhi %int %7 %22 %8 %23
               OpStore %outparm %35
               OpReturn
               OpFunctionEnd
)";

  // Stop here if assembly failed, so the rest of the test does not run
  // without a module.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  Instruction* phi_instr = nullptr;
  const auto visit_fn = [this, &phi_instr](Instruction* instr,
                                           BasicBlock** dest_bb) {
    // This visitor never selects a branch target.
    *dest_bb = nullptr;
    if (instr->opcode() == SpvOpLoad) {
      // Operand 2 of OpLoad (after type and result id) is the id loaded from.
      uint32_t rhs_id = instr->GetSingleWordOperand(2);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def->opcode() == SpvOpConstant) {
        // Operand 2 of OpConstant is its literal value.
        uint32_t val = rhs_def->GetSingleWordOperand(2);
        values_[instr->result_id()] = val;
        return SSAPropagator::kInteresting;
      }
    } else if (instr->opcode() == SpvOpPhi) {
      phi_instr = instr;
      // Incoming values are at operands 2, 4, 6, ..., each followed by its
      // parent block id. The Phi only becomes interesting once every incoming
      // value is known, so nothing is recorded for it before then.
      bool has_incoming = false;
      uint32_t phi_val = 0;
      for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
        uint32_t phi_arg_id = instr->GetSingleWordOperand(i);
        auto it = values_.find(phi_arg_id);
        if (it == values_.end()) {
          return SSAPropagator::kNotInteresting;
        }
        EXPECT_EQ(it->second, 4u);
        phi_val = it->second;
        has_incoming = true;
      }
      if (!has_incoming) {
        // A Phi with no incoming values has nothing to propagate.
        return SSAPropagator::kNotInteresting;
      }
      values_[instr->result_id()] = phi_val;
      return SSAPropagator::kInteresting;
    }

    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));

  // The propagator should've concluded that the Phi instruction has a constant
  // value of 4. Use find() rather than operator[] so that a missing entry is
  // not default-inserted, which would corrupt the GetValues() check below.
  ASSERT_NE(phi_instr, nullptr);
  const auto phi_it = values_.find(phi_instr->result_id());
  ASSERT_TRUE(phi_it != values_.end()) << "No value recorded for the Phi";
  EXPECT_EQ(phi_it->second, 4u);

  // %7, %8 and the Phi %35 each hold 4.
  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u));
}
216
217 } // namespace
218 } // namespace opt
219 } // namespace spvtools
220