• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/dataflow.h"
16 
17 #include <map>
18 #include <set>
19 
20 #include "gmock/gmock.h"
21 #include "gtest/gtest.h"
22 #include "opt/function_utils.h"
23 #include "source/opt/build_module.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
29 using DataFlowTest = ::testing::Test;
30 
31 // Simple analyses for testing:
32 
33 // Stores the result IDs of visited instructions in visit order.
34 struct VisitOrder : public ForwardDataFlowAnalysis {
35   std::vector<uint32_t> visited_result_ids;
36 
VisitOrderspvtools::opt::__anon7deebcc40111::VisitOrder37   VisitOrder(IRContext& context, LabelPosition label_position)
38       : ForwardDataFlowAnalysis(context, label_position) {}
39 
Visitspvtools::opt::__anon7deebcc40111::VisitOrder40   VisitResult Visit(Instruction* inst) override {
41     if (inst->HasResultId()) {
42       visited_result_ids.push_back(inst->result_id());
43     }
44     return DataFlowAnalysis::VisitResult::kResultFixed;
45   }
46 };
47 
48 // For each block, stores the set of blocks it can be preceded by.
49 // For example, with the following CFG:
50 //    V-----------.
51 // -> 11 -> 12 -> 13 -> 15
52 //            \-> 14 ---^
53 //
54 // The answer is:
55 // 11: 11, 12, 13
56 // 12: 11, 12, 13
57 // 13: 11, 12, 13
58 // 14: 11, 12, 13
59 // 15: 11, 12, 13, 14
60 struct BackwardReachability : public ForwardDataFlowAnalysis {
61   std::map<uint32_t, std::set<uint32_t>> reachable_from;
62 
BackwardReachabilityspvtools::opt::__anon7deebcc40111::BackwardReachability63   BackwardReachability(IRContext& context)
64       : ForwardDataFlowAnalysis(
65             context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {}
66 
Visitspvtools::opt::__anon7deebcc40111::BackwardReachability67   VisitResult Visit(Instruction* inst) override {
68     // Conditional branches can be enqueued from labels, so skip them.
69     if (inst->opcode() != SpvOpLabel)
70       return DataFlowAnalysis::VisitResult::kResultFixed;
71     uint32_t id = inst->result_id();
72     VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed;
73     std::set<uint32_t>& precedents = reachable_from[id];
74     for (uint32_t pred : context().cfg()->preds(id)) {
75       bool pred_inserted = precedents.insert(pred).second;
76       if (pred_inserted) {
77         ret = DataFlowAnalysis::VisitResult::kResultChanged;
78       }
79       for (uint32_t block : reachable_from[pred]) {
80         bool inserted = precedents.insert(block).second;
81         if (inserted) {
82           ret = DataFlowAnalysis::VisitResult::kResultChanged;
83         }
84       }
85     }
86     return ret;
87   }
88 
InitializeWorklistspvtools::opt::__anon7deebcc40111::BackwardReachability89   void InitializeWorklist(Function* function,
90                           bool is_first_iteration) override {
91     // Since successor function is exact, only need one pass.
92     if (is_first_iteration) {
93       ForwardDataFlowAnalysis::InitializeWorklist(function, true);
94     }
95   }
96 };
97 
// Checks that VisitOrder sees result IDs in the expected reverse-post-order
// sequence for each LabelPosition setting.
TEST_F(DataFlowTest, ReversePostOrder) {
  // Note: labels and IDs are intentionally out of order.
  //
  // CFG: (order of branches is from bottom to top)
  //          V-----------.
  // -> 50 -> 40 -> 20 -> 60 -> 70
  //            \-> 30 ---^

  // DFS tree with RPO numbering:
  // -> 50[0] -> 40[1] -> 20[2]    60[4] -> 70[5]
  //                  \-> 30[3] ---^

  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main"
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
          %3 = OpTypeVoid
          %4 = OpTypeFunction %3
          %6 = OpTypeBool
          %5 = OpConstantTrue %6
          %2 = OpFunction %3 None %4
         %50 = OpLabel
         %51 = OpUndef %6
         %52 = OpUndef %6
               OpBranch %40
         %70 = OpLabel
         %69 = OpUndef %6
               OpReturn
         %60 = OpLabel
         %61 = OpUndef %6
               OpBranchConditional %5 %70 %40
         %30 = OpLabel
         %29 = OpUndef %6
               OpBranch %60
         %20 = OpLabel
         %21 = OpUndef %6
               OpBranch %60
         %40 = OpLabel
         %39 = OpUndef %6
               OpBranchConditional %5 %30 %20
               OpFunctionEnd
  )";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  Function* function = spvtest::GetFunction(context->module(), 2);
  // Fail the test cleanly, rather than crash inside Run(), if the function
  // lookup did not produce a function.
  ASSERT_NE(function, nullptr);

  // Expected visit order of result IDs, keyed by label position.
  std::map<ForwardDataFlowAnalysis::LabelPosition, std::vector<uint32_t>>
      expected_order;
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly] = {
      50, 40, 20, 30, 60, 70,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtBeginning] = {
      50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtEnd] = {
      51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kNoLabels] = {
      51, 52, 39, 21, 29, 61, 69,
  };

  for (const auto& test_case : expected_order) {
    // Identify the failing label position in any failure message below.
    SCOPED_TRACE(::testing::Message()
                 << "LabelPosition value: "
                 << static_cast<int>(test_case.first));
    VisitOrder analysis(*context, test_case.first);
    analysis.Run(function);
    EXPECT_EQ(test_case.second, analysis.visited_result_ids);
  }
}
172 
// Runs the BackwardReachability analysis on a small CFG containing a loop and
// a diamond, and checks the computed set of preceding blocks for each block.
TEST_F(DataFlowTest, BackwardReachability) {
  // CFG:
  //    V-----------.
  // -> 11 -> 12 -> 13 -> 15
  //            \-> 14 ---^

  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main"
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
          %3 = OpTypeVoid
          %4 = OpTypeFunction %3
          %6 = OpTypeBool
          %5 = OpConstantTrue %6
          %2 = OpFunction %3 None %4
         %11 = OpLabel
               OpBranch %12
         %12 = OpLabel
               OpBranchConditional %5 %14 %13
         %13 = OpLabel
               OpBranchConditional %5 %15 %11
         %14 = OpLabel
               OpBranch %15
         %15 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  Function* function = spvtest::GetFunction(context->module(), 2);
  // Fail the test cleanly, rather than crash inside Run(), if the function
  // lookup did not produce a function.
  ASSERT_NE(function, nullptr);

  BackwardReachability analysis(*context);
  analysis.Run(function);

  // Block 14 is only reached through 12, and 15 through 13 or 14; the loop
  // 11 -> 12 -> 13 -> 11 makes each of those three a precedent of every block.
  std::map<uint32_t, std::set<uint32_t>> expected_result;
  expected_result[11] = {11, 12, 13};
  expected_result[12] = {11, 12, 13};
  expected_result[13] = {11, 12, 13};
  expected_result[14] = {11, 12, 13};
  expected_result[15] = {11, 12, 13, 14};
  EXPECT_EQ(expected_result, analysis.reachable_from);
}
222 
223 }  // namespace
224 }  // namespace opt
225 }  // namespace spvtools
226