• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/lint/divergence_analysis.h"
16 
17 #include "source/opt/basic_block.h"
18 #include "source/opt/control_dependence.h"
19 #include "source/opt/dataflow.h"
20 #include "source/opt/function.h"
21 #include "source/opt/instruction.h"
22 #include "spirv/unified1/spirv.h"
23 
24 namespace spvtools {
25 namespace lint {
26 
EnqueueSuccessors(opt::Instruction * inst)27 void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) {
28   // Enqueue control dependents of block, if applicable.
29   // There are two ways for a dependence source to be updated:
30   // 1. control -> control: source block is marked divergent.
31   // 2. data -> control: branch condition is marked divergent.
32   uint32_t block_id;
33   if (inst->IsBlockTerminator()) {
34     block_id = context().get_instr_block(inst)->id();
35   } else if (inst->opcode() == SpvOpLabel) {
36     block_id = inst->result_id();
37     opt::BasicBlock* bb = context().cfg()->block(block_id);
38     // Only enqueue phi instructions, as other uses don't affect divergence.
39     bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); });
40   } else {
41     opt::ForwardDataFlowAnalysis::EnqueueUsers(inst);
42     return;
43   }
44   if (!cd_.HasBlock(block_id)) {
45     return;
46   }
47   for (const spvtools::opt::ControlDependence& dep :
48        cd_.GetDependenceTargets(block_id)) {
49     opt::Instruction* target_inst =
50         context().cfg()->block(dep.target_bb_id())->GetLabelInst();
51     Enqueue(target_inst);
52   }
53 }
54 
Visit(opt::Instruction * inst)55 opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit(
56     opt::Instruction* inst) {
57   if (inst->opcode() == SpvOpLabel) {
58     return VisitBlock(inst->result_id());
59   } else {
60     return VisitInstruction(inst);
61   }
62 }
63 
VisitBlock(uint32_t id)64 opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) {
65   if (!cd_.HasBlock(id)) {
66     return opt::DataFlowAnalysis::VisitResult::kResultFixed;
67   }
68   DivergenceLevel& cur_level = divergence_[id];
69   if (cur_level == DivergenceLevel::kDivergent) {
70     return opt::DataFlowAnalysis::VisitResult::kResultFixed;
71   }
72   DivergenceLevel orig = cur_level;
73   for (const spvtools::opt::ControlDependence& dep :
74        cd_.GetDependenceSources(id)) {
75     if (divergence_[dep.source_bb_id()] > cur_level) {
76       cur_level = divergence_[dep.source_bb_id()];
77       divergence_source_[id] = dep.source_bb_id();
78     } else if (dep.source_bb_id() != 0) {
79       uint32_t condition_id = dep.GetConditionID(*context().cfg());
80       DivergenceLevel dep_level = divergence_[condition_id];
81       // Check if we are along the chain of unconditional branches starting from
82       // the branch target.
83       if (follow_unconditional_branches_[dep.branch_target_bb_id()] !=
84           follow_unconditional_branches_[dep.target_bb_id()]) {
85         // We must have reconverged in order to reach this block.
86         // Promote partially uniform to divergent.
87         if (dep_level == DivergenceLevel::kPartiallyUniform) {
88           dep_level = DivergenceLevel::kDivergent;
89         }
90       }
91       if (dep_level > cur_level) {
92         cur_level = dep_level;
93         divergence_source_[id] = condition_id;
94         divergence_dependence_source_[id] = dep.source_bb_id();
95       }
96     }
97   }
98   return cur_level > orig ? VisitResult::kResultChanged
99                           : VisitResult::kResultFixed;
100 }
101 
VisitInstruction(opt::Instruction * inst)102 opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction(
103     opt::Instruction* inst) {
104   if (inst->IsBlockTerminator()) {
105     // This is called only when the condition has changed, so return changed.
106     return VisitResult::kResultChanged;
107   }
108   if (!inst->HasResultId()) {
109     return VisitResult::kResultFixed;
110   }
111   uint32_t id = inst->result_id();
112   DivergenceLevel& cur_level = divergence_[id];
113   if (cur_level == DivergenceLevel::kDivergent) {
114     return opt::DataFlowAnalysis::VisitResult::kResultFixed;
115   }
116   DivergenceLevel orig = cur_level;
117   cur_level = ComputeInstructionDivergence(inst);
118   return cur_level > orig ? VisitResult::kResultChanged
119                           : VisitResult::kResultFixed;
120 }
121 
122 DivergenceAnalysis::DivergenceLevel
ComputeInstructionDivergence(opt::Instruction * inst)123 DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) {
124   // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId
125   // and use that to short circuit other checks. Uniform is for subgroups which
126   // would satisfy derivative groups too. UniformId takes a scope, so if it is
127   // subgroup or greater it could satisfy derivative group and
128   // Device/QueueFamily could satisfy fully uniform.
129   uint32_t id = inst->result_id();
130   // Handle divergence roots.
131   if (inst->opcode() == SpvOpFunctionParameter) {
132     divergence_source_[id] = 0;
133     return divergence_[id] = DivergenceLevel::kDivergent;
134   } else if (inst->IsLoad()) {
135     spvtools::opt::Instruction* var = inst->GetBaseAddress();
136     if (var->opcode() != SpvOpVariable) {
137       // Assume divergent.
138       divergence_source_[id] = 0;
139       return DivergenceLevel::kDivergent;
140     }
141     DivergenceLevel ret = ComputeVariableDivergence(var);
142     if (ret > DivergenceLevel::kUniform) {
143       divergence_source_[inst->result_id()] = 0;
144     }
145     return divergence_[id] = ret;
146   }
147   // Get the maximum divergence of the operands.
148   DivergenceLevel ret = DivergenceLevel::kUniform;
149   inst->ForEachInId([this, inst, &ret](const uint32_t* op) {
150     if (!op) return;
151     if (divergence_[*op] > ret) {
152       divergence_source_[inst->result_id()] = *op;
153       ret = divergence_[*op];
154     }
155   });
156   divergence_[inst->result_id()] = ret;
157   return ret;
158 }
159 
160 DivergenceAnalysis::DivergenceLevel
ComputeVariableDivergence(opt::Instruction * var)161 DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) {
162   uint32_t type_id = var->type_id();
163   spvtools::opt::analysis::Pointer* type =
164       context().get_type_mgr()->GetType(type_id)->AsPointer();
165   assert(type != nullptr);
166   uint32_t def_id = var->result_id();
167   DivergenceLevel ret;
168   switch (type->storage_class()) {
169     case SpvStorageClassFunction:
170     case SpvStorageClassGeneric:
171     case SpvStorageClassAtomicCounter:
172     case SpvStorageClassStorageBuffer:
173     case SpvStorageClassPhysicalStorageBuffer:
174     case SpvStorageClassOutput:
175     case SpvStorageClassWorkgroup:
176     case SpvStorageClassImage:  // Image atomics probably aren't uniform.
177     case SpvStorageClassPrivate:
178       ret = DivergenceLevel::kDivergent;
179       break;
180     case SpvStorageClassInput:
181       ret = DivergenceLevel::kDivergent;
182       // If this variable has a Flat decoration, it is partially uniform.
183       // TODO(kuhar): Track access chain indices and also consider Flat members
184       // of a structure.
185       context().get_decoration_mgr()->WhileEachDecoration(
186           def_id, SpvDecorationFlat, [&ret](const opt::Instruction&) {
187             ret = DivergenceLevel::kPartiallyUniform;
188             return false;
189           });
190       break;
191     case SpvStorageClassUniformConstant:
192       // May be a storage image which is also written to; mark those as
193       // divergent.
194       if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) {
195         ret = DivergenceLevel::kUniform;
196       } else {
197         ret = DivergenceLevel::kDivergent;
198       }
199       break;
200     case SpvStorageClassUniform:
201     case SpvStorageClassPushConstant:
202     case SpvStorageClassCrossWorkgroup:  // Not for shaders; default uniform.
203     default:
204       ret = DivergenceLevel::kUniform;
205       break;
206   }
207   return ret;
208 }
209 
Setup(opt::Function * function)210 void DivergenceAnalysis::Setup(opt::Function* function) {
211   // TODO(kuhar): Run functions called by |function| so we can detect
212   // reconvergence caused by multiple returns.
213   cd_.ComputeControlDependenceGraph(
214       *context().cfg(), *context().GetPostDominatorAnalysis(function));
215   context().cfg()->ForEachBlockInPostOrder(
216       function->entry().get(), [this](const opt::BasicBlock* bb) {
217         uint32_t id = bb->id();
218         if (bb->terminator() == nullptr ||
219             bb->terminator()->opcode() != SpvOpBranch) {
220           follow_unconditional_branches_[id] = id;
221         } else {
222           uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0);
223           // Target is guaranteed to have been visited before us in postorder.
224           follow_unconditional_branches_[id] =
225               follow_unconditional_branches_[target_id];
226         }
227       });
228 }
229 
operator <<(std::ostream & os,DivergenceAnalysis::DivergenceLevel level)230 std::ostream& operator<<(std::ostream& os,
231                          DivergenceAnalysis::DivergenceLevel level) {
232   switch (level) {
233     case DivergenceAnalysis::DivergenceLevel::kUniform:
234       return os << "uniform";
235     case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform:
236       return os << "partially uniform";
237     case DivergenceAnalysis::DivergenceLevel::kDivergent:
238       return os << "divergent";
239     default:
240       return os << "<invalid divergence level>";
241   }
242 }
243 
244 }  // namespace lint
245 }  // namespace spvtools
246