// Copyright (c) 2021 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/lint/divergence_analysis.h"

#include "source/opt/basic_block.h"
#include "source/opt/control_dependence.h"
#include "source/opt/dataflow.h"
#include "source/opt/function.h"
#include "source/opt/instruction.h"

namespace spvtools {
namespace lint {

void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) {
  // Enqueue control dependents of block, if applicable.
  // There are two ways for a dependence source to be updated:
  // 1. control -> control: source block is marked divergent.
  // 2. data -> control: branch condition is marked divergent.
  uint32_t block_id;
  if (inst->IsBlockTerminator()) {
    block_id = context().get_instr_block(inst)->id();
  } else if (inst->opcode() == spv::Op::OpLabel) {
    block_id = inst->result_id();
    opt::BasicBlock* bb = context().cfg()->block(block_id);
    // Only enqueue phi instructions, as other uses don't affect divergence.
    bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); });
  } else {
    opt::ForwardDataFlowAnalysis::EnqueueUsers(inst);
    return;
  }
  if (!cd_.HasBlock(block_id)) {
    return;
  }
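  // Any block that is control-dependent on this one may need its divergence
  // level recomputed, so re-enqueue the labels of all dependence targets.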
  for (const spvtools::opt::ControlDependence& dep :
       cd_.GetDependenceTargets(block_id)) {
    opt::Instruction* target_inst =
        context().cfg()->block(dep.target_bb_id())->GetLabelInst();
    Enqueue(target_inst);
  }
}

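// Labels stand in for their basic blocks in the worklist, so they are routed
// to the block-level visit; every other instruction is visited as a value.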
opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit(
    opt::Instruction* inst) {
  if (inst->opcode() == spv::Op::OpLabel) {
    return VisitBlock(inst->result_id());
  } else {
    return VisitInstruction(inst);
  }
}

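// Recomputes the divergence level of the block labeled |id| as the maximum
// level implied by its control dependences, coming either from a divergent
// source block or from a divergent branch condition.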
opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) {
  if (!cd_.HasBlock(id)) {
    return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  }
  DivergenceLevel& cur_level = divergence_[id];
  if (cur_level == DivergenceLevel::kDivergent) {
    return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  }
  DivergenceLevel orig = cur_level;
  for (const spvtools::opt::ControlDependence& dep :
       cd_.GetDependenceSources(id)) {
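    // Control -> control: this block is at least as divergent as the source
    // block of the dependence.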
    if (divergence_[dep.source_bb_id()] > cur_level) {
      cur_level = divergence_[dep.source_bb_id()];
      divergence_source_[id] = dep.source_bb_id();
    } else if (dep.source_bb_id() != 0) {
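      // Data -> control: this block is at least as divergent as the branch
      // condition in the source block. Dependences on the pseudo-entry block
      // (source id 0) carry no condition and are excluded by the check above.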
      uint32_t condition_id = dep.GetConditionID(*context().cfg());
      DivergenceLevel dep_level = divergence_[condition_id];
      // Check whether this block lies on the chain of unconditional branches
      // that starts at the branch target.
      if (follow_unconditional_branches_[dep.branch_target_bb_id()] !=
          follow_unconditional_branches_[dep.target_bb_id()]) {
        // We must have reconverged in order to reach this block.
        // Promote partially uniform to divergent.
        if (dep_level == DivergenceLevel::kPartiallyUniform) {
          dep_level = DivergenceLevel::kDivergent;
        }
      }
      if (dep_level > cur_level) {
        cur_level = dep_level;
        divergence_source_[id] = condition_id;
        divergence_dependence_source_[id] = dep.source_bb_id();
      }
    }
  }
  return cur_level > orig ? VisitResult::kResultChanged
                          : VisitResult::kResultFixed;
}

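// Visits a non-label instruction. Terminators always report a change, since
// they are only enqueued after their condition changed; value instructions
// recompute their divergence level and report whether it increased.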
opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction(
    opt::Instruction* inst) {
  if (inst->IsBlockTerminator()) {
    // This is called only when the condition has changed, so return changed.
    return VisitResult::kResultChanged;
  }
  if (!inst->HasResultId()) {
    return VisitResult::kResultFixed;
  }
  uint32_t id = inst->result_id();
  DivergenceLevel& cur_level = divergence_[id];
  if (cur_level == DivergenceLevel::kDivergent) {
    return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  }
  DivergenceLevel orig = cur_level;
  cur_level = ComputeInstructionDivergence(inst);
  return cur_level > orig ? VisitResult::kResultChanged
                          : VisitResult::kResultFixed;
}

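// Determines the divergence level of |inst|. Function parameters and loads
// from non-uniform storage are divergence roots; any other instruction takes
// the maximum divergence level of its operands.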
DivergenceAnalysis::DivergenceLevel
DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) {
  // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId
  // and use that to short circuit other checks. Uniform is for subgroups which
  // would satisfy derivative groups too. UniformId takes a scope, so if it is
  // subgroup or greater it could satisfy derivative group and
  // Device/QueueFamily could satisfy fully uniform.
  uint32_t id = inst->result_id();
  // Handle divergence roots.
  if (inst->opcode() == spv::Op::OpFunctionParameter) {
    divergence_source_[id] = 0;
    return divergence_[id] = DivergenceLevel::kDivergent;
  } else if (inst->IsLoad()) {
    spvtools::opt::Instruction* var = inst->GetBaseAddress();
    if (var->opcode() != spv::Op::OpVariable) {
      // Assume divergent.
      divergence_source_[id] = 0;
      return DivergenceLevel::kDivergent;
    }
    DivergenceLevel ret = ComputeVariableDivergence(var);
    if (ret > DivergenceLevel::kUniform) {
      divergence_source_[inst->result_id()] = 0;
    }
    return divergence_[id] = ret;
  }
  // Get the maximum divergence of the operands.
  DivergenceLevel ret = DivergenceLevel::kUniform;
  inst->ForEachInId([this, inst, &ret](const uint32_t* op) {
    if (!op) return;
    if (divergence_[*op] > ret) {
      divergence_source_[inst->result_id()] = *op;
      ret = divergence_[*op];
    }
  });
  divergence_[inst->result_id()] = ret;
  return ret;
}

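// Classifies a variable by its storage class: inputs and writable storage are
// treated as divergent (Flat-decorated inputs as partially uniform), while
// Uniform, PushConstant, and read-only UniformConstant storage are uniform.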
DivergenceAnalysis::DivergenceLevel
DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) {
  uint32_t type_id = var->type_id();
  spvtools::opt::analysis::Pointer* type =
      context().get_type_mgr()->GetType(type_id)->AsPointer();
  assert(type != nullptr);
  uint32_t def_id = var->result_id();
  DivergenceLevel ret;
  switch (type->storage_class()) {
    case spv::StorageClass::Function:
    case spv::StorageClass::Generic:
    case spv::StorageClass::AtomicCounter:
    case spv::StorageClass::StorageBuffer:
    case spv::StorageClass::PhysicalStorageBuffer:
    case spv::StorageClass::Output:
    case spv::StorageClass::Workgroup:
    case spv::StorageClass::Image:  // Image atomics probably aren't uniform.
    case spv::StorageClass::Private:
      ret = DivergenceLevel::kDivergent;
      break;
    case spv::StorageClass::Input:
      ret = DivergenceLevel::kDivergent;
      // If this variable has a Flat decoration, it is partially uniform.
      // TODO(kuhar): Track access chain indices and also consider Flat members
      // of a structure.
      context().get_decoration_mgr()->WhileEachDecoration(
          def_id, static_cast<uint32_t>(spv::Decoration::Flat),
          [&ret](const opt::Instruction&) {
            ret = DivergenceLevel::kPartiallyUniform;
            return false;
          });
      break;
    case spv::StorageClass::UniformConstant:
      // May be a storage image which is also written to; mark those as
      // divergent.
      if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) {
        ret = DivergenceLevel::kUniform;
      } else {
        ret = DivergenceLevel::kDivergent;
      }
      break;
    case spv::StorageClass::Uniform:
    case spv::StorageClass::PushConstant:
    case spv::StorageClass::CrossWorkgroup:  // Not for shaders; default
                                             // uniform.
    default:
      ret = DivergenceLevel::kUniform;
      break;
  }
  return ret;
}

void DivergenceAnalysis::Setup(opt::Function* function) {
  // TODO(kuhar): Run functions called by |function| so we can detect
  // reconvergence caused by multiple returns.
  cd_.ComputeControlDependenceGraph(
      *context().cfg(), *context().GetPostDominatorAnalysis(function));
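  // For each block, record the block reached by following the chain of
  // unconditional OpBranch terminators starting there; blocks ending in any
  // other terminator map to themselves.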
  context().cfg()->ForEachBlockInPostOrder(
      function->entry().get(), [this](const opt::BasicBlock* bb) {
        uint32_t id = bb->id();
        if (bb->terminator() == nullptr ||
            bb->terminator()->opcode() != spv::Op::OpBranch) {
          follow_unconditional_branches_[id] = id;
        } else {
          uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0);
          // Target is guaranteed to have been visited before us in postorder.
          follow_unconditional_branches_[id] =
              follow_unconditional_branches_[target_id];
        }
      });
}

std::ostream& operator<<(std::ostream& os,
                         DivergenceAnalysis::DivergenceLevel level) {
  switch (level) {
    case DivergenceAnalysis::DivergenceLevel::kUniform:
      return os << "uniform";
    case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform:
      return os << "partially uniform";
    case DivergenceAnalysis::DivergenceLevel::kDivergent:
      return os << "divergent";
    default:
      return os << "<invalid divergence level>";
  }
}

}  // namespace lint
}  // namespace spvtools