• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/spread_volatile_semantics.h"
16 
17 #include "source/opt/decoration_manager.h"
18 #include "source/opt/ir_builder.h"
19 #include "source/spirv_constant.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 
// In-operand indices used with Instruction::GetSingleWordInOperand.

// OpDecorate %target BuiltIn <value>: index of the BuiltIn <value>.
constexpr uint32_t kOpDecorateInOperandBuiltinDecoration = 2u;
// OpLoad %ptr [<memory access mask> ...]: index of the optional mask.
constexpr uint32_t kOpLoadInOperandMemoryOperands = 1u;
// OpEntryPoint <model> %entry_fn "name" <interface ids>...: index of %entry_fn.
constexpr uint32_t kOpEntryPointInOperandEntryPoint = 1u;
// OpEntryPoint: index of the first interface id.
constexpr uint32_t kOpEntryPointInOperandInterface = 3u;
29 
HasBuiltinDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id,uint32_t built_in)30 bool HasBuiltinDecoration(analysis::DecorationManager* decoration_manager,
31                           uint32_t var_id, uint32_t built_in) {
32   return decoration_manager->FindDecoration(
33       var_id, SpvDecorationBuiltIn, [built_in](const Instruction& inst) {
34         return built_in == inst.GetSingleWordInOperand(
35                                kOpDecorateInOperandBuiltinDecoration);
36       });
37 }
38 
IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in)39 bool IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in) {
40   switch (built_in) {
41     case SpvBuiltInSMIDNV:
42     case SpvBuiltInWarpIDNV:
43     case SpvBuiltInSubgroupSize:
44     case SpvBuiltInSubgroupLocalInvocationId:
45     case SpvBuiltInSubgroupEqMask:
46     case SpvBuiltInSubgroupGeMask:
47     case SpvBuiltInSubgroupGtMask:
48     case SpvBuiltInSubgroupLeMask:
49     case SpvBuiltInSubgroupLtMask:
50       return true;
51     default:
52       return false;
53   }
54 }
55 
HasBuiltinForRayTracingVolatileSemantics(analysis::DecorationManager * decoration_manager,uint32_t var_id)56 bool HasBuiltinForRayTracingVolatileSemantics(
57     analysis::DecorationManager* decoration_manager, uint32_t var_id) {
58   return decoration_manager->FindDecoration(
59       var_id, SpvDecorationBuiltIn, [](const Instruction& inst) {
60         uint32_t built_in =
61             inst.GetSingleWordInOperand(kOpDecorateInOperandBuiltinDecoration);
62         return IsBuiltInForRayTracingVolatileSemantics(built_in);
63       });
64 }
65 
HasVolatileDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id)66 bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
67                            uint32_t var_id) {
68   return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
69 }
70 
71 }  // namespace
72 
Process()73 Pass::Status SpreadVolatileSemantics::Process() {
74   if (HasNoExecutionModel()) {
75     return Status::SuccessWithoutChange;
76   }
77   const bool is_vk_memory_model_enabled =
78       context()->get_feature_mgr()->HasCapability(
79           SpvCapabilityVulkanMemoryModel);
80   CollectTargetsForVolatileSemantics(is_vk_memory_model_enabled);
81 
82   // If VulkanMemoryModel capability is not enabled, we have to set Volatile
83   // decoration for interface variables instead of setting Volatile for load
84   // instructions. If an interface (or pointers to it) is used by two load
85   // instructions in two entry points and one must be volatile while another
86   // is not, we have to report an error for the conflict.
87   if (!is_vk_memory_model_enabled &&
88       HasInterfaceInConflictOfVolatileSemantics()) {
89     return Status::Failure;
90   }
91 
92   return SpreadVolatileSemanticsToVariables(is_vk_memory_model_enabled);
93 }
94 
SpreadVolatileSemanticsToVariables(const bool is_vk_memory_model_enabled)95 Pass::Status SpreadVolatileSemantics::SpreadVolatileSemanticsToVariables(
96     const bool is_vk_memory_model_enabled) {
97   Status status = Status::SuccessWithoutChange;
98   for (Instruction& var : context()->types_values()) {
99     auto entry_function_ids =
100         EntryFunctionsToSpreadVolatileSemanticsForVar(var.result_id());
101     if (entry_function_ids.empty()) {
102       continue;
103     }
104 
105     if (is_vk_memory_model_enabled) {
106       SetVolatileForLoadsInEntries(&var, entry_function_ids);
107     } else {
108       DecorateVarWithVolatile(&var);
109     }
110     status = Status::SuccessWithChange;
111   }
112   return status;
113 }
114 
IsTargetUsedByNonVolatileLoadInEntryPoint(uint32_t var_id,Instruction * entry_point)115 bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
116     uint32_t var_id, Instruction* entry_point) {
117   uint32_t entry_function_id =
118       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
119   std::unordered_set<uint32_t> funcs;
120   context()->CollectCallTreeFromRoots(entry_function_id, &funcs);
121   return !VisitLoadsOfPointersToVariableInEntries(
122       var_id,
123       [](Instruction* load) {
124         // If it has a load without volatile memory operand, finish traversal
125         // and return false.
126         if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
127           return false;
128         }
129         uint32_t memory_operands =
130             load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
131         return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
132       },
133       funcs);
134 }
135 
HasInterfaceInConflictOfVolatileSemantics()136 bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
137   for (Instruction& entry_point : get_module()->entry_points()) {
138     SpvExecutionModel execution_model =
139         static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
140     for (uint32_t operand_index = kOpEntryPointInOperandInterface;
141          operand_index < entry_point.NumInOperands(); ++operand_index) {
142       uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
143       if (!EntryFunctionsToSpreadVolatileSemanticsForVar(var_id).empty() &&
144           !IsTargetForVolatileSemantics(var_id, execution_model) &&
145           IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
146         Instruction* inst = context()->get_def_use_mgr()->GetDef(var_id);
147         context()->EmitErrorMessage(
148             "Variable is a target for Volatile semantics for an entry point, "
149             "but it is not for another entry point",
150             inst);
151         return true;
152       }
153     }
154   }
155   return false;
156 }
157 
MarkVolatileSemanticsForVariable(uint32_t var_id,Instruction * entry_point)158 void SpreadVolatileSemantics::MarkVolatileSemanticsForVariable(
159     uint32_t var_id, Instruction* entry_point) {
160   uint32_t entry_function_id =
161       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
162   auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id);
163   if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) {
164     var_ids_to_entry_fn_for_volatile_semantics_[var_id] = {entry_function_id};
165     return;
166   }
167   itr->second.insert(entry_function_id);
168 }
169 
CollectTargetsForVolatileSemantics(const bool is_vk_memory_model_enabled)170 void SpreadVolatileSemantics::CollectTargetsForVolatileSemantics(
171     const bool is_vk_memory_model_enabled) {
172   for (Instruction& entry_point : get_module()->entry_points()) {
173     SpvExecutionModel execution_model =
174         static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
175     for (uint32_t operand_index = kOpEntryPointInOperandInterface;
176          operand_index < entry_point.NumInOperands(); ++operand_index) {
177       uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
178       if (!IsTargetForVolatileSemantics(var_id, execution_model)) {
179         continue;
180       }
181       if (is_vk_memory_model_enabled ||
182           IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
183         MarkVolatileSemanticsForVariable(var_id, &entry_point);
184       }
185     }
186   }
187 }
188 
DecorateVarWithVolatile(Instruction * var)189 void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
190   analysis::DecorationManager* decoration_manager =
191       context()->get_decoration_mgr();
192   uint32_t var_id = var->result_id();
193   if (HasVolatileDecoration(decoration_manager, var_id)) {
194     return;
195   }
196   get_decoration_mgr()->AddDecoration(
197       SpvOpDecorate, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
198                       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
199                        {SpvDecorationVolatile}}});
200 }
201 
VisitLoadsOfPointersToVariableInEntries(uint32_t var_id,const std::function<bool (Instruction *)> & handle_load,const std::unordered_set<uint32_t> & function_ids)202 bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
203     uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
204     const std::unordered_set<uint32_t>& function_ids) {
205   std::vector<uint32_t> worklist({var_id});
206   auto* def_use_mgr = context()->get_def_use_mgr();
207   while (!worklist.empty()) {
208     uint32_t ptr_id = worklist.back();
209     worklist.pop_back();
210     bool finish_traversal = !def_use_mgr->WhileEachUser(
211         ptr_id, [this, &worklist, &ptr_id, handle_load,
212                  &function_ids](Instruction* user) {
213           BasicBlock* block = context()->get_instr_block(user);
214           if (block == nullptr ||
215               function_ids.find(block->GetParent()->result_id()) ==
216                   function_ids.end()) {
217             return true;
218           }
219 
220           if (user->opcode() == SpvOpAccessChain ||
221               user->opcode() == SpvOpInBoundsAccessChain ||
222               user->opcode() == SpvOpPtrAccessChain ||
223               user->opcode() == SpvOpInBoundsPtrAccessChain ||
224               user->opcode() == SpvOpCopyObject) {
225             if (ptr_id == user->GetSingleWordInOperand(0))
226               worklist.push_back(user->result_id());
227             return true;
228           }
229 
230           if (user->opcode() != SpvOpLoad) {
231             return true;
232           }
233 
234           return handle_load(user);
235         });
236     if (finish_traversal) return false;
237   }
238   return true;
239 }
240 
SetVolatileForLoadsInEntries(Instruction * var,const std::unordered_set<uint32_t> & entry_function_ids)241 void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
242     Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
243   // Set Volatile memory operand for all load instructions if they do not have
244   // it.
245   for (auto entry_id : entry_function_ids) {
246     std::unordered_set<uint32_t> funcs;
247     context()->CollectCallTreeFromRoots(entry_id, &funcs);
248     VisitLoadsOfPointersToVariableInEntries(
249         var->result_id(),
250         [](Instruction* load) {
251           if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
252             load->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
253                               {SpvMemoryAccessVolatileMask}});
254             return true;
255           }
256           uint32_t memory_operands =
257               load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
258           memory_operands |= SpvMemoryAccessVolatileMask;
259           load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
260           return true;
261         },
262         funcs);
263   }
264 }
265 
IsTargetForVolatileSemantics(uint32_t var_id,SpvExecutionModel execution_model)266 bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
267     uint32_t var_id, SpvExecutionModel execution_model) {
268   analysis::DecorationManager* decoration_manager =
269       context()->get_decoration_mgr();
270   if (execution_model == SpvExecutionModelFragment) {
271     return get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 6) &&
272            HasBuiltinDecoration(decoration_manager, var_id,
273                                 SpvBuiltInHelperInvocation);
274   }
275 
276   if (execution_model == SpvExecutionModelIntersectionKHR ||
277       execution_model == SpvExecutionModelIntersectionNV) {
278     if (HasBuiltinDecoration(decoration_manager, var_id,
279                              SpvBuiltInRayTmaxKHR)) {
280       return true;
281     }
282   }
283 
284   switch (execution_model) {
285     case SpvExecutionModelRayGenerationKHR:
286     case SpvExecutionModelClosestHitKHR:
287     case SpvExecutionModelMissKHR:
288     case SpvExecutionModelCallableKHR:
289     case SpvExecutionModelIntersectionKHR:
290       return HasBuiltinForRayTracingVolatileSemantics(decoration_manager,
291                                                       var_id);
292     default:
293       return false;
294   }
295 }
296 
297 }  // namespace opt
298 }  // namespace spvtools
299