• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/spread_volatile_semantics.h"
16 
17 #include "source/opt/decoration_manager.h"
18 #include "source/opt/ir_builder.h"
19 #include "source/spirv_constant.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 
// In-operand index of the BuiltIn value in `OpDecorate %id BuiltIn <value>`.
constexpr uint32_t kOpDecorateInOperandBuiltinDecoration = 2u;
// In-operand index of the optional memory-access mask of OpLoad.
constexpr uint32_t kOpLoadInOperandMemoryOperands = 1u;
// In-operand index of the entry function id of OpEntryPoint.
constexpr uint32_t kOpEntryPointInOperandEntryPoint = 1u;
// In-operand index of the first interface id of OpEntryPoint.
constexpr uint32_t kOpEntryPointInOperandInterface = 3u;
29 
HasBuiltinDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id,uint32_t built_in)30 bool HasBuiltinDecoration(analysis::DecorationManager* decoration_manager,
31                           uint32_t var_id, uint32_t built_in) {
32   return decoration_manager->FindDecoration(
33       var_id, SpvDecorationBuiltIn, [built_in](const Instruction& inst) {
34         return built_in == inst.GetSingleWordInOperand(
35                                kOpDecorateInOperandBuiltinDecoration);
36       });
37 }
38 
IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in)39 bool IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in) {
40   switch (built_in) {
41     case SpvBuiltInSMIDNV:
42     case SpvBuiltInWarpIDNV:
43     case SpvBuiltInSubgroupSize:
44     case SpvBuiltInSubgroupLocalInvocationId:
45     case SpvBuiltInSubgroupEqMask:
46     case SpvBuiltInSubgroupGeMask:
47     case SpvBuiltInSubgroupGtMask:
48     case SpvBuiltInSubgroupLeMask:
49     case SpvBuiltInSubgroupLtMask:
50       return true;
51     default:
52       return false;
53   }
54 }
55 
HasBuiltinForRayTracingVolatileSemantics(analysis::DecorationManager * decoration_manager,uint32_t var_id)56 bool HasBuiltinForRayTracingVolatileSemantics(
57     analysis::DecorationManager* decoration_manager, uint32_t var_id) {
58   return decoration_manager->FindDecoration(
59       var_id, SpvDecorationBuiltIn, [](const Instruction& inst) {
60         uint32_t built_in =
61             inst.GetSingleWordInOperand(kOpDecorateInOperandBuiltinDecoration);
62         return IsBuiltInForRayTracingVolatileSemantics(built_in);
63       });
64 }
65 
HasVolatileDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id)66 bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
67                            uint32_t var_id) {
68   return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
69 }
70 
HasOnlyEntryPointsAsFunctions(IRContext * context,Module * module)71 bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
72   std::unordered_set<uint32_t> entry_function_ids;
73   for (Instruction& entry_point : module->entry_points()) {
74     entry_function_ids.insert(
75         entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
76   }
77   for (auto& function : *module) {
78     if (entry_function_ids.find(function.result_id()) ==
79         entry_function_ids.end()) {
80       std::string message(
81           "Functions of SPIR-V for spread-volatile-semantics pass input must "
82           "be inlined except entry points");
83       message += "\n  " + function.DefInst().PrettyPrint(
84                               SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
85       context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
86       return false;
87     }
88   }
89   return true;
90 }
91 
92 }  // namespace
93 
Process()94 Pass::Status SpreadVolatileSemantics::Process() {
95   if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
96     return Status::Failure;
97   }
98 
99   const bool is_vk_memory_model_enabled =
100       context()->get_feature_mgr()->HasCapability(
101           SpvCapabilityVulkanMemoryModel);
102   CollectTargetsForVolatileSemantics(is_vk_memory_model_enabled);
103 
104   // If VulkanMemoryModel capability is not enabled, we have to set Volatile
105   // decoration for interface variables instead of setting Volatile for load
106   // instructions. If an interface (or pointers to it) is used by two load
107   // instructions in two entry points and one must be volatile while another
108   // is not, we have to report an error for the conflict.
109   if (!is_vk_memory_model_enabled &&
110       HasInterfaceInConflictOfVolatileSemantics()) {
111     return Status::Failure;
112   }
113 
114   return SpreadVolatileSemanticsToVariables(is_vk_memory_model_enabled);
115 }
116 
SpreadVolatileSemanticsToVariables(const bool is_vk_memory_model_enabled)117 Pass::Status SpreadVolatileSemantics::SpreadVolatileSemanticsToVariables(
118     const bool is_vk_memory_model_enabled) {
119   Status status = Status::SuccessWithoutChange;
120   for (Instruction& var : context()->types_values()) {
121     auto entry_function_ids =
122         EntryFunctionsToSpreadVolatileSemanticsForVar(var.result_id());
123     if (entry_function_ids.empty()) {
124       continue;
125     }
126 
127     if (is_vk_memory_model_enabled) {
128       SetVolatileForLoadsInEntries(&var, entry_function_ids);
129     } else {
130       DecorateVarWithVolatile(&var);
131     }
132     status = Status::SuccessWithChange;
133   }
134   return status;
135 }
136 
IsTargetUsedByNonVolatileLoadInEntryPoint(uint32_t var_id,Instruction * entry_point)137 bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
138     uint32_t var_id, Instruction* entry_point) {
139   uint32_t entry_function_id =
140       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
141   return !VisitLoadsOfPointersToVariableInEntries(
142       var_id,
143       [](Instruction* load) {
144         // If it has a load without volatile memory operand, finish traversal
145         // and return false.
146         if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
147           return false;
148         }
149         uint32_t memory_operands =
150             load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
151         return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
152       },
153       {entry_function_id});
154 }
155 
HasInterfaceInConflictOfVolatileSemantics()156 bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
157   for (Instruction& entry_point : get_module()->entry_points()) {
158     SpvExecutionModel execution_model =
159         static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
160     for (uint32_t operand_index = kOpEntryPointInOperandInterface;
161          operand_index < entry_point.NumInOperands(); ++operand_index) {
162       uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
163       if (!EntryFunctionsToSpreadVolatileSemanticsForVar(var_id).empty() &&
164           !IsTargetForVolatileSemantics(var_id, execution_model) &&
165           IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
166         Instruction* inst = context()->get_def_use_mgr()->GetDef(var_id);
167         context()->EmitErrorMessage(
168             "Variable is a target for Volatile semantics for an entry point, "
169             "but it is not for another entry point",
170             inst);
171         return true;
172       }
173     }
174   }
175   return false;
176 }
177 
MarkVolatileSemanticsForVariable(uint32_t var_id,Instruction * entry_point)178 void SpreadVolatileSemantics::MarkVolatileSemanticsForVariable(
179     uint32_t var_id, Instruction* entry_point) {
180   uint32_t entry_function_id =
181       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
182   auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id);
183   if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) {
184     var_ids_to_entry_fn_for_volatile_semantics_[var_id] = {entry_function_id};
185     return;
186   }
187   itr->second.insert(entry_function_id);
188 }
189 
CollectTargetsForVolatileSemantics(const bool is_vk_memory_model_enabled)190 void SpreadVolatileSemantics::CollectTargetsForVolatileSemantics(
191     const bool is_vk_memory_model_enabled) {
192   for (Instruction& entry_point : get_module()->entry_points()) {
193     SpvExecutionModel execution_model =
194         static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
195     for (uint32_t operand_index = kOpEntryPointInOperandInterface;
196          operand_index < entry_point.NumInOperands(); ++operand_index) {
197       uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
198       if (!IsTargetForVolatileSemantics(var_id, execution_model)) {
199         continue;
200       }
201       if (is_vk_memory_model_enabled ||
202           IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
203         MarkVolatileSemanticsForVariable(var_id, &entry_point);
204       }
205     }
206   }
207 }
208 
DecorateVarWithVolatile(Instruction * var)209 void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
210   analysis::DecorationManager* decoration_manager =
211       context()->get_decoration_mgr();
212   uint32_t var_id = var->result_id();
213   if (HasVolatileDecoration(decoration_manager, var_id)) {
214     return;
215   }
216   get_decoration_mgr()->AddDecoration(
217       SpvOpDecorate, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
218                       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
219                        {SpvDecorationVolatile}}});
220 }
221 
VisitLoadsOfPointersToVariableInEntries(uint32_t var_id,const std::function<bool (Instruction *)> & handle_load,const std::unordered_set<uint32_t> & entry_function_ids)222 bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
223     uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
224     const std::unordered_set<uint32_t>& entry_function_ids) {
225   std::vector<uint32_t> worklist({var_id});
226   auto* def_use_mgr = context()->get_def_use_mgr();
227   while (!worklist.empty()) {
228     uint32_t ptr_id = worklist.back();
229     worklist.pop_back();
230     bool finish_traversal = !def_use_mgr->WhileEachUser(
231         ptr_id, [this, &worklist, &ptr_id, handle_load,
232                  &entry_function_ids](Instruction* user) {
233           BasicBlock* block = context()->get_instr_block(user);
234           if (block == nullptr ||
235               entry_function_ids.find(block->GetParent()->result_id()) ==
236                   entry_function_ids.end()) {
237             return true;
238           }
239 
240           if (user->opcode() == SpvOpAccessChain ||
241               user->opcode() == SpvOpInBoundsAccessChain ||
242               user->opcode() == SpvOpPtrAccessChain ||
243               user->opcode() == SpvOpInBoundsPtrAccessChain ||
244               user->opcode() == SpvOpCopyObject) {
245             if (ptr_id == user->GetSingleWordInOperand(0))
246               worklist.push_back(user->result_id());
247             return true;
248           }
249 
250           if (user->opcode() != SpvOpLoad) {
251             return true;
252           }
253 
254           return handle_load(user);
255         });
256     if (finish_traversal) return false;
257   }
258   return true;
259 }
260 
SetVolatileForLoadsInEntries(Instruction * var,const std::unordered_set<uint32_t> & entry_function_ids)261 void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
262     Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
263   // Set Volatile memory operand for all load instructions if they do not have
264   // it.
265   VisitLoadsOfPointersToVariableInEntries(
266       var->result_id(),
267       [](Instruction* load) {
268         if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
269           load->AddOperand(
270               {SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
271           return true;
272         }
273         uint32_t memory_operands =
274             load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
275         memory_operands |= SpvMemoryAccessVolatileMask;
276         load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
277         return true;
278       },
279       entry_function_ids);
280 }
281 
IsTargetForVolatileSemantics(uint32_t var_id,SpvExecutionModel execution_model)282 bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
283     uint32_t var_id, SpvExecutionModel execution_model) {
284   analysis::DecorationManager* decoration_manager =
285       context()->get_decoration_mgr();
286   if (execution_model == SpvExecutionModelFragment) {
287     return get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 6) &&
288            HasBuiltinDecoration(decoration_manager, var_id,
289                                 SpvBuiltInHelperInvocation);
290   }
291 
292   if (execution_model == SpvExecutionModelIntersectionKHR ||
293       execution_model == SpvExecutionModelIntersectionNV) {
294     if (HasBuiltinDecoration(decoration_manager, var_id,
295                              SpvBuiltInRayTmaxKHR)) {
296       return true;
297     }
298   }
299 
300   switch (execution_model) {
301     case SpvExecutionModelRayGenerationKHR:
302     case SpvExecutionModelClosestHitKHR:
303     case SpvExecutionModelMissKHR:
304     case SpvExecutionModelCallableKHR:
305     case SpvExecutionModelIntersectionKHR:
306       return HasBuiltinForRayTracingVolatileSemantics(decoration_manager,
307                                                       var_id);
308     default:
309       return false;
310   }
311 }
312 
313 }  // namespace opt
314 }  // namespace spvtools
315