1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/lint/divergence_analysis.h"
16
17 #include "source/opt/basic_block.h"
18 #include "source/opt/control_dependence.h"
19 #include "source/opt/dataflow.h"
20 #include "source/opt/function.h"
21 #include "source/opt/instruction.h"
22 #include "spirv/unified1/spirv.h"
23
24 namespace spvtools {
25 namespace lint {
26
EnqueueSuccessors(opt::Instruction * inst)27 void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) {
28 // Enqueue control dependents of block, if applicable.
29 // There are two ways for a dependence source to be updated:
30 // 1. control -> control: source block is marked divergent.
31 // 2. data -> control: branch condition is marked divergent.
32 uint32_t block_id;
33 if (inst->IsBlockTerminator()) {
34 block_id = context().get_instr_block(inst)->id();
35 } else if (inst->opcode() == SpvOpLabel) {
36 block_id = inst->result_id();
37 opt::BasicBlock* bb = context().cfg()->block(block_id);
38 // Only enqueue phi instructions, as other uses don't affect divergence.
39 bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); });
40 } else {
41 opt::ForwardDataFlowAnalysis::EnqueueUsers(inst);
42 return;
43 }
44 if (!cd_.HasBlock(block_id)) {
45 return;
46 }
47 for (const spvtools::opt::ControlDependence& dep :
48 cd_.GetDependenceTargets(block_id)) {
49 opt::Instruction* target_inst =
50 context().cfg()->block(dep.target_bb_id())->GetLabelInst();
51 Enqueue(target_inst);
52 }
53 }
54
Visit(opt::Instruction * inst)55 opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit(
56 opt::Instruction* inst) {
57 if (inst->opcode() == SpvOpLabel) {
58 return VisitBlock(inst->result_id());
59 } else {
60 return VisitInstruction(inst);
61 }
62 }
63
VisitBlock(uint32_t id)64 opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) {
65 if (!cd_.HasBlock(id)) {
66 return opt::DataFlowAnalysis::VisitResult::kResultFixed;
67 }
68 DivergenceLevel& cur_level = divergence_[id];
69 if (cur_level == DivergenceLevel::kDivergent) {
70 return opt::DataFlowAnalysis::VisitResult::kResultFixed;
71 }
72 DivergenceLevel orig = cur_level;
73 for (const spvtools::opt::ControlDependence& dep :
74 cd_.GetDependenceSources(id)) {
75 if (divergence_[dep.source_bb_id()] > cur_level) {
76 cur_level = divergence_[dep.source_bb_id()];
77 divergence_source_[id] = dep.source_bb_id();
78 } else if (dep.source_bb_id() != 0) {
79 uint32_t condition_id = dep.GetConditionID(*context().cfg());
80 DivergenceLevel dep_level = divergence_[condition_id];
81 // Check if we are along the chain of unconditional branches starting from
82 // the branch target.
83 if (follow_unconditional_branches_[dep.branch_target_bb_id()] !=
84 follow_unconditional_branches_[dep.target_bb_id()]) {
85 // We must have reconverged in order to reach this block.
86 // Promote partially uniform to divergent.
87 if (dep_level == DivergenceLevel::kPartiallyUniform) {
88 dep_level = DivergenceLevel::kDivergent;
89 }
90 }
91 if (dep_level > cur_level) {
92 cur_level = dep_level;
93 divergence_source_[id] = condition_id;
94 divergence_dependence_source_[id] = dep.source_bb_id();
95 }
96 }
97 }
98 return cur_level > orig ? VisitResult::kResultChanged
99 : VisitResult::kResultFixed;
100 }
101
VisitInstruction(opt::Instruction * inst)102 opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction(
103 opt::Instruction* inst) {
104 if (inst->IsBlockTerminator()) {
105 // This is called only when the condition has changed, so return changed.
106 return VisitResult::kResultChanged;
107 }
108 if (!inst->HasResultId()) {
109 return VisitResult::kResultFixed;
110 }
111 uint32_t id = inst->result_id();
112 DivergenceLevel& cur_level = divergence_[id];
113 if (cur_level == DivergenceLevel::kDivergent) {
114 return opt::DataFlowAnalysis::VisitResult::kResultFixed;
115 }
116 DivergenceLevel orig = cur_level;
117 cur_level = ComputeInstructionDivergence(inst);
118 return cur_level > orig ? VisitResult::kResultChanged
119 : VisitResult::kResultFixed;
120 }
121
122 DivergenceAnalysis::DivergenceLevel
ComputeInstructionDivergence(opt::Instruction * inst)123 DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) {
124 // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId
125 // and use that to short circuit other checks. Uniform is for subgroups which
126 // would satisfy derivative groups too. UniformId takes a scope, so if it is
127 // subgroup or greater it could satisfy derivative group and
128 // Device/QueueFamily could satisfy fully uniform.
129 uint32_t id = inst->result_id();
130 // Handle divergence roots.
131 if (inst->opcode() == SpvOpFunctionParameter) {
132 divergence_source_[id] = 0;
133 return divergence_[id] = DivergenceLevel::kDivergent;
134 } else if (inst->IsLoad()) {
135 spvtools::opt::Instruction* var = inst->GetBaseAddress();
136 if (var->opcode() != SpvOpVariable) {
137 // Assume divergent.
138 divergence_source_[id] = 0;
139 return DivergenceLevel::kDivergent;
140 }
141 DivergenceLevel ret = ComputeVariableDivergence(var);
142 if (ret > DivergenceLevel::kUniform) {
143 divergence_source_[inst->result_id()] = 0;
144 }
145 return divergence_[id] = ret;
146 }
147 // Get the maximum divergence of the operands.
148 DivergenceLevel ret = DivergenceLevel::kUniform;
149 inst->ForEachInId([this, inst, &ret](const uint32_t* op) {
150 if (!op) return;
151 if (divergence_[*op] > ret) {
152 divergence_source_[inst->result_id()] = *op;
153 ret = divergence_[*op];
154 }
155 });
156 divergence_[inst->result_id()] = ret;
157 return ret;
158 }
159
160 DivergenceAnalysis::DivergenceLevel
ComputeVariableDivergence(opt::Instruction * var)161 DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) {
162 uint32_t type_id = var->type_id();
163 spvtools::opt::analysis::Pointer* type =
164 context().get_type_mgr()->GetType(type_id)->AsPointer();
165 assert(type != nullptr);
166 uint32_t def_id = var->result_id();
167 DivergenceLevel ret;
168 switch (type->storage_class()) {
169 case SpvStorageClassFunction:
170 case SpvStorageClassGeneric:
171 case SpvStorageClassAtomicCounter:
172 case SpvStorageClassStorageBuffer:
173 case SpvStorageClassPhysicalStorageBuffer:
174 case SpvStorageClassOutput:
175 case SpvStorageClassWorkgroup:
176 case SpvStorageClassImage: // Image atomics probably aren't uniform.
177 case SpvStorageClassPrivate:
178 ret = DivergenceLevel::kDivergent;
179 break;
180 case SpvStorageClassInput:
181 ret = DivergenceLevel::kDivergent;
182 // If this variable has a Flat decoration, it is partially uniform.
183 // TODO(kuhar): Track access chain indices and also consider Flat members
184 // of a structure.
185 context().get_decoration_mgr()->WhileEachDecoration(
186 def_id, SpvDecorationFlat, [&ret](const opt::Instruction&) {
187 ret = DivergenceLevel::kPartiallyUniform;
188 return false;
189 });
190 break;
191 case SpvStorageClassUniformConstant:
192 // May be a storage image which is also written to; mark those as
193 // divergent.
194 if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) {
195 ret = DivergenceLevel::kUniform;
196 } else {
197 ret = DivergenceLevel::kDivergent;
198 }
199 break;
200 case SpvStorageClassUniform:
201 case SpvStorageClassPushConstant:
202 case SpvStorageClassCrossWorkgroup: // Not for shaders; default uniform.
203 default:
204 ret = DivergenceLevel::kUniform;
205 break;
206 }
207 return ret;
208 }
209
Setup(opt::Function * function)210 void DivergenceAnalysis::Setup(opt::Function* function) {
211 // TODO(kuhar): Run functions called by |function| so we can detect
212 // reconvergence caused by multiple returns.
213 cd_.ComputeControlDependenceGraph(
214 *context().cfg(), *context().GetPostDominatorAnalysis(function));
215 context().cfg()->ForEachBlockInPostOrder(
216 function->entry().get(), [this](const opt::BasicBlock* bb) {
217 uint32_t id = bb->id();
218 if (bb->terminator() == nullptr ||
219 bb->terminator()->opcode() != SpvOpBranch) {
220 follow_unconditional_branches_[id] = id;
221 } else {
222 uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0);
223 // Target is guaranteed to have been visited before us in postorder.
224 follow_unconditional_branches_[id] =
225 follow_unconditional_branches_[target_id];
226 }
227 });
228 }
229
operator <<(std::ostream & os,DivergenceAnalysis::DivergenceLevel level)230 std::ostream& operator<<(std::ostream& os,
231 DivergenceAnalysis::DivergenceLevel level) {
232 switch (level) {
233 case DivergenceAnalysis::DivergenceLevel::kUniform:
234 return os << "uniform";
235 case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform:
236 return os << "partially uniform";
237 case DivergenceAnalysis::DivergenceLevel::kDivergent:
238 return os << "divergent";
239 default:
240 return os << "<invalid divergence level>";
241 }
242 }
243
244 } // namespace lint
245 } // namespace spvtools
246