• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/spread_volatile_semantics.h"
16 
17 #include "source/opt/decoration_manager.h"
18 #include "source/spirv_constant.h"
19 
20 namespace spvtools {
21 namespace opt {
22 namespace {
// In-operand indices used by this pass. In-operands exclude the result type
// and result id, so they count from the first "real" operand of each opcode.

// OpDecorate <target> BuiltIn <built-in>: position of the built-in literal.
constexpr uint32_t kOpDecorateInOperandBuiltinDecoration{2u};
// OpLoad <pointer> [<memory operands>]: position of the optional
// memory-access mask.
constexpr uint32_t kOpLoadInOperandMemoryOperands{1u};
// OpEntryPoint <model> <function> <name> <interface...>: the entry function.
constexpr uint32_t kOpEntryPointInOperandEntryPoint{1u};
// OpEntryPoint: position of the first interface variable id.
constexpr uint32_t kOpEntryPointInOperandInterface{3u};
27 
HasBuiltinDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id,uint32_t built_in)28 bool HasBuiltinDecoration(analysis::DecorationManager* decoration_manager,
29                           uint32_t var_id, uint32_t built_in) {
30   return decoration_manager->FindDecoration(
31       var_id, uint32_t(spv::Decoration::BuiltIn),
32       [built_in](const Instruction& inst) {
33         return built_in == inst.GetSingleWordInOperand(
34                                kOpDecorateInOperandBuiltinDecoration);
35       });
36 }
37 
IsBuiltInForRayTracingVolatileSemantics(spv::BuiltIn built_in)38 bool IsBuiltInForRayTracingVolatileSemantics(spv::BuiltIn built_in) {
39   switch (built_in) {
40     case spv::BuiltIn::SMIDNV:
41     case spv::BuiltIn::WarpIDNV:
42     case spv::BuiltIn::SubgroupSize:
43     case spv::BuiltIn::SubgroupLocalInvocationId:
44     case spv::BuiltIn::SubgroupEqMask:
45     case spv::BuiltIn::SubgroupGeMask:
46     case spv::BuiltIn::SubgroupGtMask:
47     case spv::BuiltIn::SubgroupLeMask:
48     case spv::BuiltIn::SubgroupLtMask:
49       return true;
50     default:
51       return false;
52   }
53 }
54 
HasBuiltinForRayTracingVolatileSemantics(analysis::DecorationManager * decoration_manager,uint32_t var_id)55 bool HasBuiltinForRayTracingVolatileSemantics(
56     analysis::DecorationManager* decoration_manager, uint32_t var_id) {
57   return decoration_manager->FindDecoration(
58       var_id, uint32_t(spv::Decoration::BuiltIn), [](const Instruction& inst) {
59         spv::BuiltIn built_in = spv::BuiltIn(
60             inst.GetSingleWordInOperand(kOpDecorateInOperandBuiltinDecoration));
61         return IsBuiltInForRayTracingVolatileSemantics(built_in);
62       });
63 }
64 
HasVolatileDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id)65 bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
66                            uint32_t var_id) {
67   return decoration_manager->HasDecoration(var_id,
68                                            uint32_t(spv::Decoration::Volatile));
69 }
70 
71 }  // namespace
72 
Process()73 Pass::Status SpreadVolatileSemantics::Process() {
74   if (HasNoExecutionModel()) {
75     return Status::SuccessWithoutChange;
76   }
77   const bool is_vk_memory_model_enabled =
78       context()->get_feature_mgr()->HasCapability(
79           spv::Capability::VulkanMemoryModel);
80   CollectTargetsForVolatileSemantics(is_vk_memory_model_enabled);
81 
82   // If VulkanMemoryModel capability is not enabled, we have to set Volatile
83   // decoration for interface variables instead of setting Volatile for load
84   // instructions. If an interface (or pointers to it) is used by two load
85   // instructions in two entry points and one must be volatile while another
86   // is not, we have to report an error for the conflict.
87   if (!is_vk_memory_model_enabled &&
88       HasInterfaceInConflictOfVolatileSemantics()) {
89     return Status::Failure;
90   }
91 
92   return SpreadVolatileSemanticsToVariables(is_vk_memory_model_enabled);
93 }
94 
SpreadVolatileSemanticsToVariables(const bool is_vk_memory_model_enabled)95 Pass::Status SpreadVolatileSemantics::SpreadVolatileSemanticsToVariables(
96     const bool is_vk_memory_model_enabled) {
97   Status status = Status::SuccessWithoutChange;
98   for (Instruction& var : context()->types_values()) {
99     auto entry_function_ids =
100         EntryFunctionsToSpreadVolatileSemanticsForVar(var.result_id());
101     if (entry_function_ids.empty()) {
102       continue;
103     }
104 
105     if (is_vk_memory_model_enabled) {
106       SetVolatileForLoadsInEntries(&var, entry_function_ids);
107     } else {
108       DecorateVarWithVolatile(&var);
109     }
110     status = Status::SuccessWithChange;
111   }
112   return status;
113 }
114 
IsTargetUsedByNonVolatileLoadInEntryPoint(uint32_t var_id,Instruction * entry_point)115 bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
116     uint32_t var_id, Instruction* entry_point) {
117   uint32_t entry_function_id =
118       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
119   std::unordered_set<uint32_t> funcs;
120   context()->CollectCallTreeFromRoots(entry_function_id, &funcs);
121   return !VisitLoadsOfPointersToVariableInEntries(
122       var_id,
123       [](Instruction* load) {
124         // If it has a load without volatile memory operand, finish traversal
125         // and return false.
126         if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
127           return false;
128         }
129         uint32_t memory_operands =
130             load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
131         return (memory_operands & uint32_t(spv::MemoryAccessMask::Volatile)) !=
132                0;
133       },
134       funcs);
135 }
136 
HasInterfaceInConflictOfVolatileSemantics()137 bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
138   for (Instruction& entry_point : get_module()->entry_points()) {
139     spv::ExecutionModel execution_model =
140         static_cast<spv::ExecutionModel>(entry_point.GetSingleWordInOperand(0));
141     for (uint32_t operand_index = kOpEntryPointInOperandInterface;
142          operand_index < entry_point.NumInOperands(); ++operand_index) {
143       uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
144       if (!EntryFunctionsToSpreadVolatileSemanticsForVar(var_id).empty() &&
145           !IsTargetForVolatileSemantics(var_id, execution_model) &&
146           IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
147         Instruction* inst = context()->get_def_use_mgr()->GetDef(var_id);
148         context()->EmitErrorMessage(
149             "Variable is a target for Volatile semantics for an entry point, "
150             "but it is not for another entry point",
151             inst);
152         return true;
153       }
154     }
155   }
156   return false;
157 }
158 
MarkVolatileSemanticsForVariable(uint32_t var_id,Instruction * entry_point)159 void SpreadVolatileSemantics::MarkVolatileSemanticsForVariable(
160     uint32_t var_id, Instruction* entry_point) {
161   uint32_t entry_function_id =
162       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
163   auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id);
164   if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) {
165     var_ids_to_entry_fn_for_volatile_semantics_[var_id] = {entry_function_id};
166     return;
167   }
168   itr->second.insert(entry_function_id);
169 }
170 
CollectTargetsForVolatileSemantics(const bool is_vk_memory_model_enabled)171 void SpreadVolatileSemantics::CollectTargetsForVolatileSemantics(
172     const bool is_vk_memory_model_enabled) {
173   for (Instruction& entry_point : get_module()->entry_points()) {
174     spv::ExecutionModel execution_model =
175         static_cast<spv::ExecutionModel>(entry_point.GetSingleWordInOperand(0));
176     for (uint32_t operand_index = kOpEntryPointInOperandInterface;
177          operand_index < entry_point.NumInOperands(); ++operand_index) {
178       uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
179       if (!IsTargetForVolatileSemantics(var_id, execution_model)) {
180         continue;
181       }
182       if (is_vk_memory_model_enabled ||
183           IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
184         MarkVolatileSemanticsForVariable(var_id, &entry_point);
185       }
186     }
187   }
188 }
189 
DecorateVarWithVolatile(Instruction * var)190 void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
191   analysis::DecorationManager* decoration_manager =
192       context()->get_decoration_mgr();
193   uint32_t var_id = var->result_id();
194   if (HasVolatileDecoration(decoration_manager, var_id)) {
195     return;
196   }
197   get_decoration_mgr()->AddDecoration(
198       spv::Op::OpDecorate,
199       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
200        {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
201         {uint32_t(spv::Decoration::Volatile)}}});
202 }
203 
VisitLoadsOfPointersToVariableInEntries(uint32_t var_id,const std::function<bool (Instruction *)> & handle_load,const std::unordered_set<uint32_t> & function_ids)204 bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
205     uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
206     const std::unordered_set<uint32_t>& function_ids) {
207   std::vector<uint32_t> worklist({var_id});
208   auto* def_use_mgr = context()->get_def_use_mgr();
209   while (!worklist.empty()) {
210     uint32_t ptr_id = worklist.back();
211     worklist.pop_back();
212     bool finish_traversal = !def_use_mgr->WhileEachUser(
213         ptr_id, [this, &worklist, &ptr_id, handle_load,
214                  &function_ids](Instruction* user) {
215           BasicBlock* block = context()->get_instr_block(user);
216           if (block == nullptr ||
217               function_ids.find(block->GetParent()->result_id()) ==
218                   function_ids.end()) {
219             return true;
220           }
221 
222           if (user->opcode() == spv::Op::OpAccessChain ||
223               user->opcode() == spv::Op::OpInBoundsAccessChain ||
224               user->opcode() == spv::Op::OpPtrAccessChain ||
225               user->opcode() == spv::Op::OpInBoundsPtrAccessChain ||
226               user->opcode() == spv::Op::OpCopyObject) {
227             if (ptr_id == user->GetSingleWordInOperand(0))
228               worklist.push_back(user->result_id());
229             return true;
230           }
231 
232           if (user->opcode() != spv::Op::OpLoad) {
233             return true;
234           }
235 
236           return handle_load(user);
237         });
238     if (finish_traversal) return false;
239   }
240   return true;
241 }
242 
SetVolatileForLoadsInEntries(Instruction * var,const std::unordered_set<uint32_t> & entry_function_ids)243 void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
244     Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
245   // Set Volatile memory operand for all load instructions if they do not have
246   // it.
247   for (auto entry_id : entry_function_ids) {
248     std::unordered_set<uint32_t> funcs;
249     context()->CollectCallTreeFromRoots(entry_id, &funcs);
250     VisitLoadsOfPointersToVariableInEntries(
251         var->result_id(),
252         [](Instruction* load) {
253           if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
254             load->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
255                               {uint32_t(spv::MemoryAccessMask::Volatile)}});
256             return true;
257           }
258           uint32_t memory_operands =
259               load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
260           memory_operands |= uint32_t(spv::MemoryAccessMask::Volatile);
261           load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
262           return true;
263         },
264         funcs);
265   }
266 }
267 
IsTargetForVolatileSemantics(uint32_t var_id,spv::ExecutionModel execution_model)268 bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
269     uint32_t var_id, spv::ExecutionModel execution_model) {
270   analysis::DecorationManager* decoration_manager =
271       context()->get_decoration_mgr();
272   if (execution_model == spv::ExecutionModel::Fragment) {
273     return get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 6) &&
274            HasBuiltinDecoration(decoration_manager, var_id,
275                                 uint32_t(spv::BuiltIn::HelperInvocation));
276   }
277 
278   if (execution_model == spv::ExecutionModel::IntersectionKHR ||
279       execution_model == spv::ExecutionModel::IntersectionNV) {
280     if (HasBuiltinDecoration(decoration_manager, var_id,
281                              uint32_t(spv::BuiltIn::RayTmaxKHR))) {
282       return true;
283     }
284   }
285 
286   switch (execution_model) {
287     case spv::ExecutionModel::RayGenerationKHR:
288     case spv::ExecutionModel::ClosestHitKHR:
289     case spv::ExecutionModel::MissKHR:
290     case spv::ExecutionModel::CallableKHR:
291     case spv::ExecutionModel::IntersectionKHR:
292       return HasBuiltinForRayTracingVolatileSemantics(decoration_manager,
293                                                       var_id);
294     default:
295       return false;
296   }
297 }
298 
299 }  // namespace opt
300 }  // namespace spvtools
301