• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "source/opt/control_dependence.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/cfg.h"
#include "test/opt/function_utils.h"
25 
26 namespace spvtools {
27 namespace opt {
28 
29 namespace {
GatherEdges(const ControlDependenceAnalysis & cdg,std::vector<ControlDependence> & ret)30 void GatherEdges(const ControlDependenceAnalysis& cdg,
31                  std::vector<ControlDependence>& ret) {
32   cdg.ForEachBlockLabel([&](uint32_t label) {
33     ret.reserve(ret.size() + cdg.GetDependenceTargets(label).size());
34     ret.insert(ret.end(), cdg.GetDependenceTargets(label).begin(),
35                cdg.GetDependenceTargets(label).end());
36   });
37   std::sort(ret.begin(), ret.end());
38   // Verify that reverse graph is the same.
39   std::vector<ControlDependence> reverse_edges;
40   reverse_edges.reserve(ret.size());
41   cdg.ForEachBlockLabel([&](uint32_t label) {
42     reverse_edges.insert(reverse_edges.end(),
43                          cdg.GetDependenceSources(label).begin(),
44                          cdg.GetDependenceSources(label).end());
45   });
46   std::sort(reverse_edges.begin(), reverse_edges.end());
47   ASSERT_THAT(reverse_edges, testing::ElementsAreArray(ret));
48 }
49 
50 using ControlDependenceTest = ::testing::Test;
51 
TEST(ControlDependenceTest,DependenceSimpleCFG)52 TEST(ControlDependenceTest, DependenceSimpleCFG) {
53   const std::string text = R"(
54                OpCapability Addresses
55                OpCapability Kernel
56                OpMemoryModel Physical64 OpenCL
57                OpEntryPoint Kernel %1 "main"
58           %2 = OpTypeVoid
59           %3 = OpTypeFunction %2
60           %4 = OpTypeBool
61           %5 = OpTypeInt 32 0
62           %6 = OpConstant %5 0
63           %7 = OpConstantFalse %4
64           %8 = OpConstantTrue %4
65           %9 = OpConstant %5 1
66           %1 = OpFunction %2 None %3
67          %10 = OpLabel
68                OpBranch %11
69          %11 = OpLabel
70                OpSwitch %6 %12 1 %13
71          %12 = OpLabel
72                OpBranch %14
73          %13 = OpLabel
74                OpBranch %14
75          %14 = OpLabel
76                OpBranchConditional %8 %15 %16
77          %15 = OpLabel
78                OpBranch %19
79          %16 = OpLabel
80                OpBranchConditional %8 %17 %18
81          %17 = OpLabel
82                OpBranch %18
83          %18 = OpLabel
84                OpBranch %19
85          %19 = OpLabel
86                OpReturn
87                OpFunctionEnd
88 )";
89 
90   // CFG: (all edges pointing downward)
91   //   %10
92   //    |
93   //   %11
94   //  /   \ (R: %6 == 1, L: default)
95   // %12 %13
96   //  \   /
97   //   %14
98   // T/   \F
99   // %15  %16
100   //  | T/ |F
101   //  | %17|
102   //  |  \ |
103   //  |   %18
104   //  |  /
105   // %19
106 
107   std::unique_ptr<IRContext> context =
108       BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
109                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
110   Module* module = context->module();
111   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
112                              << text << std::endl;
113   const Function* fn = spvtest::GetFunction(module, 1);
114   const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10);
115   EXPECT_EQ(entry, fn->entry().get())
116       << "The entry node is not the expected one";
117 
118   {
119     PostDominatorAnalysis pdom;
120     const CFG& cfg = *context->cfg();
121     pdom.InitializeTree(cfg, fn);
122     ControlDependenceAnalysis cdg;
123     cdg.ComputeControlDependenceGraph(cfg, pdom);
124 
125     // Test HasBlock.
126     for (uint32_t id = 10; id <= 19; id++) {
127       EXPECT_TRUE(cdg.HasBlock(id));
128     }
129     EXPECT_TRUE(cdg.HasBlock(ControlDependenceAnalysis::kPseudoEntryBlock));
130     // Check blocks before/after valid range.
131     EXPECT_FALSE(cdg.HasBlock(5));
132     EXPECT_FALSE(cdg.HasBlock(25));
133     EXPECT_FALSE(cdg.HasBlock(UINT32_MAX));
134 
135     // Test ForEachBlockLabel.
136     std::set<uint32_t> block_labels;
137     cdg.ForEachBlockLabel([&block_labels](uint32_t id) {
138       bool inserted = block_labels.insert(id).second;
139       EXPECT_TRUE(inserted);  // Should have no duplicates.
140     });
141     EXPECT_THAT(block_labels, testing::ElementsAre(0, 10, 11, 12, 13, 14, 15,
142                                                    16, 17, 18, 19));
143 
144     {
145       // Test WhileEachBlockLabel.
146       uint32_t iters = 0;
147       EXPECT_TRUE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
148         ++iters;
149         return true;
150       }));
151       EXPECT_EQ((uint32_t)block_labels.size(), iters);
152       iters = 0;
153       EXPECT_FALSE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
154         ++iters;
155         return false;
156       }));
157       EXPECT_EQ(1, iters);
158     }
159 
160     // Test IsDependent.
161     EXPECT_TRUE(cdg.IsDependent(12, 11));
162     EXPECT_TRUE(cdg.IsDependent(13, 11));
163     EXPECT_TRUE(cdg.IsDependent(15, 14));
164     EXPECT_TRUE(cdg.IsDependent(16, 14));
165     EXPECT_TRUE(cdg.IsDependent(18, 14));
166     EXPECT_TRUE(cdg.IsDependent(17, 16));
167     EXPECT_TRUE(cdg.IsDependent(10, 0));
168     EXPECT_TRUE(cdg.IsDependent(11, 0));
169     EXPECT_TRUE(cdg.IsDependent(14, 0));
170     EXPECT_TRUE(cdg.IsDependent(19, 0));
171     EXPECT_FALSE(cdg.IsDependent(14, 11));
172     EXPECT_FALSE(cdg.IsDependent(17, 14));
173     EXPECT_FALSE(cdg.IsDependent(19, 14));
174     EXPECT_FALSE(cdg.IsDependent(12, 0));
175 
176     // Test GetDependenceSources/Targets.
177     std::vector<ControlDependence> edges;
178     GatherEdges(cdg, edges);
179     EXPECT_THAT(edges,
180                 testing::ElementsAre(
181                     ControlDependence(0, 10), ControlDependence(0, 11, 10),
182                     ControlDependence(0, 14, 10), ControlDependence(0, 19, 10),
183                     ControlDependence(11, 12), ControlDependence(11, 13),
184                     ControlDependence(14, 15), ControlDependence(14, 16),
185                     ControlDependence(14, 18, 16), ControlDependence(16, 17)));
186 
187     const uint32_t expected_condition_ids[] = {
188         0, 0, 0, 0, 6, 6, 8, 8, 8, 8,
189     };
190 
191     for (uint32_t i = 0; i < edges.size(); i++) {
192       EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg));
193     }
194   }
195 }
196 
TEST(ControlDependenceTest,DependencePaperCFG)197 TEST(ControlDependenceTest, DependencePaperCFG) {
198   const std::string text = R"(
199                OpCapability Addresses
200                OpCapability Kernel
201                OpMemoryModel Physical64 OpenCL
202                OpEntryPoint Kernel %101 "main"
203         %102 = OpTypeVoid
204         %103 = OpTypeFunction %102
205         %104 = OpTypeBool
206         %108 = OpConstantTrue %104
207         %101 = OpFunction %102 None %103
208           %1 = OpLabel
209                OpBranch %2
210           %2 = OpLabel
211                OpBranchConditional %108 %3 %7
212           %3 = OpLabel
213                OpBranchConditional %108 %4 %5
214           %4 = OpLabel
215                OpBranch %6
216           %5 = OpLabel
217                OpBranch %6
218           %6 = OpLabel
219                OpBranch %8
220           %7 = OpLabel
221                OpBranch %8
222           %8 = OpLabel
223                OpBranch %9
224           %9 = OpLabel
225                OpBranchConditional %108 %10 %11
226          %10 = OpLabel
227                OpBranch %11
228          %11 = OpLabel
229                OpBranchConditional %108 %12 %9
230          %12 = OpLabel
231                OpBranchConditional %108 %13 %2
232          %13 = OpLabel
233                OpReturn
234                OpFunctionEnd
235 )";
236 
237   // CFG: (edges pointing downward if no arrow)
238   //         %1
239   //         |
240   //         %2 <----+
241   //       T/  \F    |
242   //      %3    \    |
243   //    T/  \F   \   |
244   //    %4  %5    %7 |
245   //     \  /    /   |
246   //      %6    /    |
247   //        \  /     |
248   //         %8      |
249   //         |       |
250   //         %9 <-+  |
251   //       T/  |  |  |
252   //       %10 |  |  |
253   //        \  |  |  |
254   //         %11-F+  |
255   //         T|      |
256   //         %12-F---+
257   //         T|
258   //         %13
259 
260   std::unique_ptr<IRContext> context =
261       BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
262                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
263   Module* module = context->module();
264   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
265                              << text << std::endl;
266   const Function* fn = spvtest::GetFunction(module, 101);
267   const BasicBlock* entry = spvtest::GetBasicBlock(fn, 1);
268   EXPECT_EQ(entry, fn->entry().get())
269       << "The entry node is not the expected one";
270 
271   {
272     PostDominatorAnalysis pdom;
273     const CFG& cfg = *context->cfg();
274     pdom.InitializeTree(cfg, fn);
275     ControlDependenceAnalysis cdg;
276     cdg.ComputeControlDependenceGraph(cfg, pdom);
277 
278     std::vector<ControlDependence> edges;
279     GatherEdges(cdg, edges);
280     EXPECT_THAT(
281         edges, testing::ElementsAre(
282                    ControlDependence(0, 1), ControlDependence(0, 2, 1),
283                    ControlDependence(0, 8, 1), ControlDependence(0, 9, 1),
284                    ControlDependence(0, 11, 1), ControlDependence(0, 12, 1),
285                    ControlDependence(0, 13, 1), ControlDependence(2, 3),
286                    ControlDependence(2, 6, 3), ControlDependence(2, 7),
287                    ControlDependence(3, 4), ControlDependence(3, 5),
288                    ControlDependence(9, 10), ControlDependence(11, 9),
289                    ControlDependence(11, 11, 9), ControlDependence(12, 2),
290                    ControlDependence(12, 8, 2), ControlDependence(12, 9, 2),
291                    ControlDependence(12, 11, 2), ControlDependence(12, 12, 2)));
292 
293     const uint32_t expected_condition_ids[] = {
294         0,   0,   0,   0,   0,   0,   0,   108, 108, 108,
295         108, 108, 108, 108, 108, 108, 108, 108, 108, 108,
296     };
297 
298     for (uint32_t i = 0; i < edges.size(); i++) {
299       EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg));
300     }
301   }
302 }
303 
304 }  // namespace
305 }  // namespace opt
306 }  // namespace spvtools
307