1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "source/opt/control_dependence.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/cfg.h"
#include "test/opt/function_utils.h"
25
26 namespace spvtools {
27 namespace opt {
28
29 namespace {
GatherEdges(const ControlDependenceAnalysis & cdg,std::vector<ControlDependence> & ret)30 void GatherEdges(const ControlDependenceAnalysis& cdg,
31 std::vector<ControlDependence>& ret) {
32 cdg.ForEachBlockLabel([&](uint32_t label) {
33 ret.reserve(ret.size() + cdg.GetDependenceTargets(label).size());
34 ret.insert(ret.end(), cdg.GetDependenceTargets(label).begin(),
35 cdg.GetDependenceTargets(label).end());
36 });
37 std::sort(ret.begin(), ret.end());
38 // Verify that reverse graph is the same.
39 std::vector<ControlDependence> reverse_edges;
40 reverse_edges.reserve(ret.size());
41 cdg.ForEachBlockLabel([&](uint32_t label) {
42 reverse_edges.insert(reverse_edges.end(),
43 cdg.GetDependenceSources(label).begin(),
44 cdg.GetDependenceSources(label).end());
45 });
46 std::sort(reverse_edges.begin(), reverse_edges.end());
47 ASSERT_THAT(reverse_edges, testing::ElementsAreArray(ret));
48 }
49
// Test-suite name used by the TEST() cases below. It aliases the plain
// ::testing::Test base and adds no fixture state of its own.
using ControlDependenceTest = ::testing::Test;
51
// Builds the control dependence graph for a small acyclic CFG containing an
// OpSwitch and two nested conditional branches. It checks every query API
// (HasBlock, ForEachBlockLabel, WhileEachBlockLabel, IsDependent,
// GetDependenceSources/Targets) and the branch condition attached to each
// edge.
TEST(ControlDependenceTest, DependenceSimpleCFG) {
  const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "main"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%4 = OpTypeBool
%5 = OpTypeInt 32 0
%6 = OpConstant %5 0
%7 = OpConstantFalse %4
%8 = OpConstantTrue %4
%9 = OpConstant %5 1
%1 = OpFunction %2 None %3
%10 = OpLabel
OpBranch %11
%11 = OpLabel
OpSwitch %6 %12 1 %13
%12 = OpLabel
OpBranch %14
%13 = OpLabel
OpBranch %14
%14 = OpLabel
OpBranchConditional %8 %15 %16
%15 = OpLabel
OpBranch %19
%16 = OpLabel
OpBranchConditional %8 %17 %18
%17 = OpLabel
OpBranch %18
%18 = OpLabel
OpBranch %19
%19 = OpLabel
OpReturn
OpFunctionEnd
)";

  // CFG: (all edges pointing downward)
  //       %10
  //        |
  //       %11
  //      /   \      (R: %6 == 1, L: default)
  //    %12   %13
  //      \   /
  //       %14
  //     T/   \F
  //    %15   %16
  //     |   T/  |F
  //     |  %17  |
  //     |     \ |
  //     |      %18
  //      \     /
  //        %19

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // These checks are fatal. Everything below dereferences these pointers, so
  // a failed assembly must stop the test rather than crash the binary.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n" << text;
  const Function* fn = spvtest::GetFunction(module, 1);
  ASSERT_NE(nullptr, fn) << "Function %1 not found";
  const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10);
  EXPECT_EQ(entry, fn->entry().get())
      << "The entry node is not the expected one";

  PostDominatorAnalysis pdom;
  const CFG& cfg = *context->cfg();
  pdom.InitializeTree(cfg, fn);
  ControlDependenceAnalysis cdg;
  cdg.ComputeControlDependenceGraph(cfg, pdom);

  // Test HasBlock.
  for (uint32_t id = 10; id <= 19; ++id) {
    EXPECT_TRUE(cdg.HasBlock(id));
  }
  EXPECT_TRUE(cdg.HasBlock(ControlDependenceAnalysis::kPseudoEntryBlock));
  // Check blocks before/after valid range.
  EXPECT_FALSE(cdg.HasBlock(5));
  EXPECT_FALSE(cdg.HasBlock(25));
  EXPECT_FALSE(cdg.HasBlock(UINT32_MAX));

  // Test ForEachBlockLabel.
  std::set<uint32_t> block_labels;
  cdg.ForEachBlockLabel([&block_labels](uint32_t id) {
    const bool inserted = block_labels.insert(id).second;
    EXPECT_TRUE(inserted);  // Should have no duplicates.
  });
  EXPECT_THAT(block_labels, testing::ElementsAre(0, 10, 11, 12, 13, 14, 15,
                                                 16, 17, 18, 19));

  {
    // Test WhileEachBlockLabel: a callback returning true visits every
    // label; one returning false stops after the first label.
    uint32_t iters = 0;
    EXPECT_TRUE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
      ++iters;
      return true;
    }));
    EXPECT_EQ(static_cast<uint32_t>(block_labels.size()), iters);
    iters = 0;
    EXPECT_FALSE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
      ++iters;
      return false;
    }));
    EXPECT_EQ(1u, iters);
  }

  // Test IsDependent.
  EXPECT_TRUE(cdg.IsDependent(12, 11));
  EXPECT_TRUE(cdg.IsDependent(13, 11));
  EXPECT_TRUE(cdg.IsDependent(15, 14));
  EXPECT_TRUE(cdg.IsDependent(16, 14));
  EXPECT_TRUE(cdg.IsDependent(18, 14));
  EXPECT_TRUE(cdg.IsDependent(17, 16));
  EXPECT_TRUE(cdg.IsDependent(10, 0));
  EXPECT_TRUE(cdg.IsDependent(11, 0));
  EXPECT_TRUE(cdg.IsDependent(14, 0));
  EXPECT_TRUE(cdg.IsDependent(19, 0));
  EXPECT_FALSE(cdg.IsDependent(14, 11));
  EXPECT_FALSE(cdg.IsDependent(17, 14));
  EXPECT_FALSE(cdg.IsDependent(19, 14));
  EXPECT_FALSE(cdg.IsDependent(12, 0));

  // Test GetDependenceSources/Targets.
  std::vector<ControlDependence> edges;
  // Propagate a fatal failure from the helper's reverse-graph check.
  ASSERT_NO_FATAL_FAILURE(GatherEdges(cdg, edges));
  EXPECT_THAT(edges,
              testing::ElementsAre(
                  ControlDependence(0, 10), ControlDependence(0, 11, 10),
                  ControlDependence(0, 14, 10), ControlDependence(0, 19, 10),
                  ControlDependence(11, 12), ControlDependence(11, 13),
                  ControlDependence(14, 15), ControlDependence(14, 16),
                  ControlDependence(14, 18, 16), ControlDependence(16, 17)));

  // Expected condition ID for each edge, in the sorted order used above.
  const std::vector<uint32_t> expected_condition_ids = {
      0, 0, 0, 0, 6, 6, 8, 8, 8, 8,
  };
  // Guard the indexed loop: if the edge list mismatched, it may be longer
  // than the expectation table.
  ASSERT_EQ(expected_condition_ids.size(), edges.size());
  for (size_t i = 0; i < edges.size(); ++i) {
    EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg))
        << "edge #" << i;
  }
}
196
// Builds the control dependence graph for a CFG with nested branches, an
// inner loop (%9..%11) and an outer loop back to %2. It checks the full sorted
// edge list and the branch condition attached to each edge.
TEST(ControlDependenceTest, DependencePaperCFG) {
  const std::string text = R"(
OpCapability Addresses
OpCapability Kernel
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %101 "main"
%102 = OpTypeVoid
%103 = OpTypeFunction %102
%104 = OpTypeBool
%108 = OpConstantTrue %104
%101 = OpFunction %102 None %103
%1 = OpLabel
OpBranch %2
%2 = OpLabel
OpBranchConditional %108 %3 %7
%3 = OpLabel
OpBranchConditional %108 %4 %5
%4 = OpLabel
OpBranch %6
%5 = OpLabel
OpBranch %6
%6 = OpLabel
OpBranch %8
%7 = OpLabel
OpBranch %8
%8 = OpLabel
OpBranch %9
%9 = OpLabel
OpBranchConditional %108 %10 %11
%10 = OpLabel
OpBranch %11
%11 = OpLabel
OpBranchConditional %108 %12 %9
%12 = OpLabel
OpBranchConditional %108 %13 %2
%13 = OpLabel
OpReturn
OpFunctionEnd
)";

  // CFG: (edges pointing downward if no arrow)
  //         %1
  //          |
  //         %2 <------+
  //       T/  \F      |
  //      %3    \      |
  //    T/  \F   \     |
  //   %4   %5   %7    |
  //     \  /    /     |
  //      %6    /      |
  //        \  /       |
  //         %8        |
  //          |        |
  //         %9 <-+    |
  //       T/ |F  |    |
  //     %10  |   |    |
  //        \ |   |    |
  //         %11-F+    |
  //         T|        |
  //         %12-F-----+
  //         T|
  //         %13

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // These checks are fatal. Everything below dereferences these pointers, so
  // a failed assembly must stop the test rather than crash the binary.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n" << text;
  const Function* fn = spvtest::GetFunction(module, 101);
  ASSERT_NE(nullptr, fn) << "Function %101 not found";
  const BasicBlock* entry = spvtest::GetBasicBlock(fn, 1);
  EXPECT_EQ(entry, fn->entry().get())
      << "The entry node is not the expected one";

  PostDominatorAnalysis pdom;
  const CFG& cfg = *context->cfg();
  pdom.InitializeTree(cfg, fn);
  ControlDependenceAnalysis cdg;
  cdg.ComputeControlDependenceGraph(cfg, pdom);

  std::vector<ControlDependence> edges;
  // Propagate a fatal failure from the helper's reverse-graph check.
  ASSERT_NO_FATAL_FAILURE(GatherEdges(cdg, edges));
  EXPECT_THAT(
      edges, testing::ElementsAre(
                 ControlDependence(0, 1), ControlDependence(0, 2, 1),
                 ControlDependence(0, 8, 1), ControlDependence(0, 9, 1),
                 ControlDependence(0, 11, 1), ControlDependence(0, 12, 1),
                 ControlDependence(0, 13, 1), ControlDependence(2, 3),
                 ControlDependence(2, 6, 3), ControlDependence(2, 7),
                 ControlDependence(3, 4), ControlDependence(3, 5),
                 ControlDependence(9, 10), ControlDependence(11, 9),
                 ControlDependence(11, 11, 9), ControlDependence(12, 2),
                 ControlDependence(12, 8, 2), ControlDependence(12, 9, 2),
                 ControlDependence(12, 11, 2), ControlDependence(12, 12, 2)));

  // Expected condition ID for each edge, in the sorted order used above.
  const std::vector<uint32_t> expected_condition_ids = {
      0,   0,   0,   0,   0,   0,   0,   108, 108, 108,
      108, 108, 108, 108, 108, 108, 108, 108, 108, 108,
  };
  // Guard the indexed loop: if the edge list mismatched, it may be longer
  // than the expectation table.
  ASSERT_EQ(expected_condition_ids.size(), edges.size());
  for (size_t i = 0; i < edges.size(); ++i) {
    EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg))
        << "edge #" << i;
  }
}
303
304 } // namespace
305 } // namespace opt
306 } // namespace spvtools
307