1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/dataflow.h"
16
17 #include <map>
18 #include <set>
19
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22 #include "opt/function_utils.h"
23 #include "source/opt/build_module.h"
24
25 namespace spvtools {
26 namespace opt {
27 namespace {
28
29 using DataFlowTest = ::testing::Test;
30
31 // Simple analyses for testing:
32
33 // Stores the result IDs of visited instructions in visit order.
34 struct VisitOrder : public ForwardDataFlowAnalysis {
35 std::vector<uint32_t> visited_result_ids;
36
VisitOrderspvtools::opt::__anon7deebcc40111::VisitOrder37 VisitOrder(IRContext& context, LabelPosition label_position)
38 : ForwardDataFlowAnalysis(context, label_position) {}
39
Visitspvtools::opt::__anon7deebcc40111::VisitOrder40 VisitResult Visit(Instruction* inst) override {
41 if (inst->HasResultId()) {
42 visited_result_ids.push_back(inst->result_id());
43 }
44 return DataFlowAnalysis::VisitResult::kResultFixed;
45 }
46 };
47
48 // For each block, stores the set of blocks it can be preceded by.
49 // For example, with the following CFG:
50 // V-----------.
51 // -> 11 -> 12 -> 13 -> 15
52 // \-> 14 ---^
53 //
54 // The answer is:
55 // 11: 11, 12, 13
56 // 12: 11, 12, 13
57 // 13: 11, 12, 13
58 // 14: 11, 12, 13
59 // 15: 11, 12, 13, 14
60 struct BackwardReachability : public ForwardDataFlowAnalysis {
61 std::map<uint32_t, std::set<uint32_t>> reachable_from;
62
BackwardReachabilityspvtools::opt::__anon7deebcc40111::BackwardReachability63 BackwardReachability(IRContext& context)
64 : ForwardDataFlowAnalysis(
65 context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {}
66
Visitspvtools::opt::__anon7deebcc40111::BackwardReachability67 VisitResult Visit(Instruction* inst) override {
68 // Conditional branches can be enqueued from labels, so skip them.
69 if (inst->opcode() != SpvOpLabel)
70 return DataFlowAnalysis::VisitResult::kResultFixed;
71 uint32_t id = inst->result_id();
72 VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed;
73 std::set<uint32_t>& precedents = reachable_from[id];
74 for (uint32_t pred : context().cfg()->preds(id)) {
75 bool pred_inserted = precedents.insert(pred).second;
76 if (pred_inserted) {
77 ret = DataFlowAnalysis::VisitResult::kResultChanged;
78 }
79 for (uint32_t block : reachable_from[pred]) {
80 bool inserted = precedents.insert(block).second;
81 if (inserted) {
82 ret = DataFlowAnalysis::VisitResult::kResultChanged;
83 }
84 }
85 }
86 return ret;
87 }
88
InitializeWorklistspvtools::opt::__anon7deebcc40111::BackwardReachability89 void InitializeWorklist(Function* function,
90 bool is_first_iteration) override {
91 // Since successor function is exact, only need one pass.
92 if (is_first_iteration) {
93 ForwardDataFlowAnalysis::InitializeWorklist(function, true);
94 }
95 }
96 };
97
// Checks that ForwardDataFlowAnalysis visits blocks in reverse post-order and
// places (or omits) labels according to each LabelPosition.
TEST_F(DataFlowTest, ReversePostOrder) {
  // Note: labels and IDs are intentionally out of order.
  //
  // CFG: (order of branches is from bottom to top)
  //          V-----------.
  // -> 50 -> 40 -> 20 -> 60 -> 70
  //             \-> 30 ---^

  // DFS tree with RPO numbering:
  // -> 50[0] -> 40[1] -> 20[2]    60[4] -> 70[5]
  //                  \-> 30[3] ---^

  const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 430
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%6 = OpTypeBool
%5 = OpConstantTrue %6
%2 = OpFunction %3 None %4
%50 = OpLabel
%51 = OpUndef %6
%52 = OpUndef %6
OpBranch %40
%70 = OpLabel
%69 = OpUndef %6
OpReturn
%60 = OpLabel
%61 = OpUndef %6
OpBranchConditional %5 %70 %40
%30 = OpLabel
%29 = OpUndef %6
OpBranch %60
%20 = OpLabel
%21 = OpUndef %6
OpBranch %60
%40 = OpLabel
%39 = OpUndef %6
OpBranchConditional %5 %30 %20
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  // Guard against the lookup not finding %2 before dereferencing it below.
  Function* function = spvtest::GetFunction(context->module(), 2);
  ASSERT_NE(function, nullptr);

  std::map<ForwardDataFlowAnalysis::LabelPosition, std::vector<uint32_t>>
      expected_order;
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly] = {
      50, 40, 20, 30, 60, 70,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtBeginning] = {
      50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtEnd] = {
      51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kNoLabels] = {
      51, 52, 39, 21, 29, 61, 69,
  };

  for (const auto& test_case : expected_order) {
    // Without this, a mismatch would not say which label position failed.
    SCOPED_TRACE(::testing::Message()
                 << "label position " << static_cast<int>(test_case.first));
    VisitOrder analysis(*context, test_case.first);
    analysis.Run(function);
    EXPECT_EQ(test_case.second, analysis.visited_result_ids);
  }
}
172
// Checks that a label-only forward analysis with a back edge converges to the
// full set of preceding blocks for every block.
TEST_F(DataFlowTest, BackwardReachability) {
  // CFG:
  //    V-----------.
  // -> 11 -> 12 -> 13 -> 15
  //             \-> 14 ---^

  const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 430
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%6 = OpTypeBool
%5 = OpConstantTrue %6
%2 = OpFunction %3 None %4
%11 = OpLabel
OpBranch %12
%12 = OpLabel
OpBranchConditional %5 %14 %13
%13 = OpLabel
OpBranchConditional %5 %15 %11
%14 = OpLabel
OpBranch %15
%15 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  // Guard against the lookup not finding %2 before dereferencing it below.
  Function* function = spvtest::GetFunction(context->module(), 2);
  ASSERT_NE(function, nullptr);

  BackwardReachability analysis(*context);
  analysis.Run(function);

  const std::map<uint32_t, std::set<uint32_t>> expected_result = {
      {11, {11, 12, 13}},
      {12, {11, 12, 13}},
      {13, {11, 12, 13}},
      {14, {11, 12, 13}},
      {15, {11, 12, 13, 14}},
  };
  EXPECT_EQ(expected_result, analysis.reachable_from);
}
222
223 } // namespace
224 } // namespace opt
225 } // namespace spvtools
226