1 // Copyright (c) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/spread_volatile_semantics.h"
16
17 #include "source/opt/decoration_manager.h"
18 #include "source/opt/ir_builder.h"
19 #include "source/spirv_constant.h"
20
21 namespace spvtools {
22 namespace opt {
23 namespace {
24
// In-operand indices (GetSingleWordInOperand) of the instructions this pass
// inspects.
//
// OpDecorate <target> BuiltIn <built-in>: index of the built-in value.
constexpr uint32_t kOpDecorateInOperandBuiltinDecoration = 2u;
// OpLoad <pointer> [<memory operands>]: index of the optional memory-operand
// mask.
constexpr uint32_t kOpLoadInOperandMemoryOperands = 1u;
// OpEntryPoint <execution model> <entry point> <name> <interface>...:
// index of the entry point function id, and of the first interface id.
constexpr uint32_t kOpEntryPointInOperandEntryPoint = 1u;
constexpr uint32_t kOpEntryPointInOperandInterface = 3u;
29
HasBuiltinDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id,uint32_t built_in)30 bool HasBuiltinDecoration(analysis::DecorationManager* decoration_manager,
31 uint32_t var_id, uint32_t built_in) {
32 return decoration_manager->FindDecoration(
33 var_id, SpvDecorationBuiltIn, [built_in](const Instruction& inst) {
34 return built_in == inst.GetSingleWordInOperand(
35 kOpDecorateInOperandBuiltinDecoration);
36 });
37 }
38
IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in)39 bool IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in) {
40 switch (built_in) {
41 case SpvBuiltInSMIDNV:
42 case SpvBuiltInWarpIDNV:
43 case SpvBuiltInSubgroupSize:
44 case SpvBuiltInSubgroupLocalInvocationId:
45 case SpvBuiltInSubgroupEqMask:
46 case SpvBuiltInSubgroupGeMask:
47 case SpvBuiltInSubgroupGtMask:
48 case SpvBuiltInSubgroupLeMask:
49 case SpvBuiltInSubgroupLtMask:
50 return true;
51 default:
52 return false;
53 }
54 }
55
HasBuiltinForRayTracingVolatileSemantics(analysis::DecorationManager * decoration_manager,uint32_t var_id)56 bool HasBuiltinForRayTracingVolatileSemantics(
57 analysis::DecorationManager* decoration_manager, uint32_t var_id) {
58 return decoration_manager->FindDecoration(
59 var_id, SpvDecorationBuiltIn, [](const Instruction& inst) {
60 uint32_t built_in =
61 inst.GetSingleWordInOperand(kOpDecorateInOperandBuiltinDecoration);
62 return IsBuiltInForRayTracingVolatileSemantics(built_in);
63 });
64 }
65
HasVolatileDecoration(analysis::DecorationManager * decoration_manager,uint32_t var_id)66 bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
67 uint32_t var_id) {
68 return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
69 }
70
HasOnlyEntryPointsAsFunctions(IRContext * context,Module * module)71 bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
72 std::unordered_set<uint32_t> entry_function_ids;
73 for (Instruction& entry_point : module->entry_points()) {
74 entry_function_ids.insert(
75 entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
76 }
77 for (auto& function : *module) {
78 if (entry_function_ids.find(function.result_id()) ==
79 entry_function_ids.end()) {
80 std::string message(
81 "Functions of SPIR-V for spread-volatile-semantics pass input must "
82 "be inlined except entry points");
83 message += "\n " + function.DefInst().PrettyPrint(
84 SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
85 context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
86 return false;
87 }
88 }
89 return true;
90 }
91
92 } // namespace
93
Process()94 Pass::Status SpreadVolatileSemantics::Process() {
95 if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
96 return Status::Failure;
97 }
98
99 const bool is_vk_memory_model_enabled =
100 context()->get_feature_mgr()->HasCapability(
101 SpvCapabilityVulkanMemoryModel);
102 CollectTargetsForVolatileSemantics(is_vk_memory_model_enabled);
103
104 // If VulkanMemoryModel capability is not enabled, we have to set Volatile
105 // decoration for interface variables instead of setting Volatile for load
106 // instructions. If an interface (or pointers to it) is used by two load
107 // instructions in two entry points and one must be volatile while another
108 // is not, we have to report an error for the conflict.
109 if (!is_vk_memory_model_enabled &&
110 HasInterfaceInConflictOfVolatileSemantics()) {
111 return Status::Failure;
112 }
113
114 return SpreadVolatileSemanticsToVariables(is_vk_memory_model_enabled);
115 }
116
SpreadVolatileSemanticsToVariables(const bool is_vk_memory_model_enabled)117 Pass::Status SpreadVolatileSemantics::SpreadVolatileSemanticsToVariables(
118 const bool is_vk_memory_model_enabled) {
119 Status status = Status::SuccessWithoutChange;
120 for (Instruction& var : context()->types_values()) {
121 auto entry_function_ids =
122 EntryFunctionsToSpreadVolatileSemanticsForVar(var.result_id());
123 if (entry_function_ids.empty()) {
124 continue;
125 }
126
127 if (is_vk_memory_model_enabled) {
128 SetVolatileForLoadsInEntries(&var, entry_function_ids);
129 } else {
130 DecorateVarWithVolatile(&var);
131 }
132 status = Status::SuccessWithChange;
133 }
134 return status;
135 }
136
IsTargetUsedByNonVolatileLoadInEntryPoint(uint32_t var_id,Instruction * entry_point)137 bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
138 uint32_t var_id, Instruction* entry_point) {
139 uint32_t entry_function_id =
140 entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
141 return !VisitLoadsOfPointersToVariableInEntries(
142 var_id,
143 [](Instruction* load) {
144 // If it has a load without volatile memory operand, finish traversal
145 // and return false.
146 if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
147 return false;
148 }
149 uint32_t memory_operands =
150 load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
151 return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
152 },
153 {entry_function_id});
154 }
155
HasInterfaceInConflictOfVolatileSemantics()156 bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
157 for (Instruction& entry_point : get_module()->entry_points()) {
158 SpvExecutionModel execution_model =
159 static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
160 for (uint32_t operand_index = kOpEntryPointInOperandInterface;
161 operand_index < entry_point.NumInOperands(); ++operand_index) {
162 uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
163 if (!EntryFunctionsToSpreadVolatileSemanticsForVar(var_id).empty() &&
164 !IsTargetForVolatileSemantics(var_id, execution_model) &&
165 IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
166 Instruction* inst = context()->get_def_use_mgr()->GetDef(var_id);
167 context()->EmitErrorMessage(
168 "Variable is a target for Volatile semantics for an entry point, "
169 "but it is not for another entry point",
170 inst);
171 return true;
172 }
173 }
174 }
175 return false;
176 }
177
MarkVolatileSemanticsForVariable(uint32_t var_id,Instruction * entry_point)178 void SpreadVolatileSemantics::MarkVolatileSemanticsForVariable(
179 uint32_t var_id, Instruction* entry_point) {
180 uint32_t entry_function_id =
181 entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
182 auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id);
183 if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) {
184 var_ids_to_entry_fn_for_volatile_semantics_[var_id] = {entry_function_id};
185 return;
186 }
187 itr->second.insert(entry_function_id);
188 }
189
CollectTargetsForVolatileSemantics(const bool is_vk_memory_model_enabled)190 void SpreadVolatileSemantics::CollectTargetsForVolatileSemantics(
191 const bool is_vk_memory_model_enabled) {
192 for (Instruction& entry_point : get_module()->entry_points()) {
193 SpvExecutionModel execution_model =
194 static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
195 for (uint32_t operand_index = kOpEntryPointInOperandInterface;
196 operand_index < entry_point.NumInOperands(); ++operand_index) {
197 uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
198 if (!IsTargetForVolatileSemantics(var_id, execution_model)) {
199 continue;
200 }
201 if (is_vk_memory_model_enabled ||
202 IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
203 MarkVolatileSemanticsForVariable(var_id, &entry_point);
204 }
205 }
206 }
207 }
208
DecorateVarWithVolatile(Instruction * var)209 void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
210 analysis::DecorationManager* decoration_manager =
211 context()->get_decoration_mgr();
212 uint32_t var_id = var->result_id();
213 if (HasVolatileDecoration(decoration_manager, var_id)) {
214 return;
215 }
216 get_decoration_mgr()->AddDecoration(
217 SpvOpDecorate, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
218 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
219 {SpvDecorationVolatile}}});
220 }
221
VisitLoadsOfPointersToVariableInEntries(uint32_t var_id,const std::function<bool (Instruction *)> & handle_load,const std::unordered_set<uint32_t> & entry_function_ids)222 bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
223 uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
224 const std::unordered_set<uint32_t>& entry_function_ids) {
225 std::vector<uint32_t> worklist({var_id});
226 auto* def_use_mgr = context()->get_def_use_mgr();
227 while (!worklist.empty()) {
228 uint32_t ptr_id = worklist.back();
229 worklist.pop_back();
230 bool finish_traversal = !def_use_mgr->WhileEachUser(
231 ptr_id, [this, &worklist, &ptr_id, handle_load,
232 &entry_function_ids](Instruction* user) {
233 BasicBlock* block = context()->get_instr_block(user);
234 if (block == nullptr ||
235 entry_function_ids.find(block->GetParent()->result_id()) ==
236 entry_function_ids.end()) {
237 return true;
238 }
239
240 if (user->opcode() == SpvOpAccessChain ||
241 user->opcode() == SpvOpInBoundsAccessChain ||
242 user->opcode() == SpvOpPtrAccessChain ||
243 user->opcode() == SpvOpInBoundsPtrAccessChain ||
244 user->opcode() == SpvOpCopyObject) {
245 if (ptr_id == user->GetSingleWordInOperand(0))
246 worklist.push_back(user->result_id());
247 return true;
248 }
249
250 if (user->opcode() != SpvOpLoad) {
251 return true;
252 }
253
254 return handle_load(user);
255 });
256 if (finish_traversal) return false;
257 }
258 return true;
259 }
260
SetVolatileForLoadsInEntries(Instruction * var,const std::unordered_set<uint32_t> & entry_function_ids)261 void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
262 Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
263 // Set Volatile memory operand for all load instructions if they do not have
264 // it.
265 VisitLoadsOfPointersToVariableInEntries(
266 var->result_id(),
267 [](Instruction* load) {
268 if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
269 load->AddOperand(
270 {SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
271 return true;
272 }
273 uint32_t memory_operands =
274 load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
275 memory_operands |= SpvMemoryAccessVolatileMask;
276 load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
277 return true;
278 },
279 entry_function_ids);
280 }
281
IsTargetForVolatileSemantics(uint32_t var_id,SpvExecutionModel execution_model)282 bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
283 uint32_t var_id, SpvExecutionModel execution_model) {
284 analysis::DecorationManager* decoration_manager =
285 context()->get_decoration_mgr();
286 if (execution_model == SpvExecutionModelFragment) {
287 return get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 6) &&
288 HasBuiltinDecoration(decoration_manager, var_id,
289 SpvBuiltInHelperInvocation);
290 }
291
292 if (execution_model == SpvExecutionModelIntersectionKHR ||
293 execution_model == SpvExecutionModelIntersectionNV) {
294 if (HasBuiltinDecoration(decoration_manager, var_id,
295 SpvBuiltInRayTmaxKHR)) {
296 return true;
297 }
298 }
299
300 switch (execution_model) {
301 case SpvExecutionModelRayGenerationKHR:
302 case SpvExecutionModelClosestHitKHR:
303 case SpvExecutionModelMissKHR:
304 case SpvExecutionModelCallableKHR:
305 case SpvExecutionModelIntersectionKHR:
306 return HasBuiltinForRayTracingVolatileSemantics(decoration_manager,
307 var_id);
308 default:
309 return false;
310 }
311 }
312
313 } // namespace opt
314 } // namespace spvtools
315