• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "upgrade_memory_model.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 #include "source/opt/ir_context.h"
21 #include "source/spirv_constant.h"
22 #include "source/util/make_unique.h"
23 #include "source/util/string_utils.h"
24 
25 namespace spvtools {
26 namespace opt {
27 
Process()28 Pass::Status UpgradeMemoryModel::Process() {
29   // TODO: This pass needs changes to support cooperative matrices.
30   if (context()->get_feature_mgr()->HasCapability(
31           spv::Capability::CooperativeMatrixNV)) {
32     return Pass::Status::SuccessWithoutChange;
33   }
34 
35   // Only update Logical GLSL450 to Logical VulkanKHR.
36   Instruction* memory_model = get_module()->GetMemoryModel();
37   if (memory_model->GetSingleWordInOperand(0u) !=
38           uint32_t(spv::AddressingModel::Logical) ||
39       memory_model->GetSingleWordInOperand(1u) !=
40           uint32_t(spv::MemoryModel::GLSL450)) {
41     return Pass::Status::SuccessWithoutChange;
42   }
43 
44   UpgradeMemoryModelInstruction();
45   UpgradeInstructions();
46   CleanupDecorations();
47   UpgradeBarriers();
48   UpgradeMemoryScope();
49 
50   return Pass::Status::SuccessWithChange;
51 }
52 
UpgradeMemoryModelInstruction()53 void UpgradeMemoryModel::UpgradeMemoryModelInstruction() {
54   // Overall changes necessary:
55   // 1. Add the OpExtension.
56   // 2. Add the OpCapability.
57   // 3. Modify the memory model.
58   Instruction* memory_model = get_module()->GetMemoryModel();
59   context()->AddCapability(MakeUnique<Instruction>(
60       context(), spv::Op::OpCapability, 0, 0,
61       std::initializer_list<Operand>{
62           {SPV_OPERAND_TYPE_CAPABILITY,
63            {uint32_t(spv::Capability::VulkanMemoryModelKHR)}}}));
64   const std::string extension = "SPV_KHR_vulkan_memory_model";
65   std::vector<uint32_t> words = spvtools::utils::MakeVector(extension);
66   context()->AddExtension(
67       MakeUnique<Instruction>(context(), spv::Op::OpExtension, 0, 0,
68                               std::initializer_list<Operand>{
69                                   {SPV_OPERAND_TYPE_LITERAL_STRING, words}}));
70   memory_model->SetInOperand(1u, {uint32_t(spv::MemoryModel::VulkanKHR)});
71 }
72 
UpgradeInstructions()73 void UpgradeMemoryModel::UpgradeInstructions() {
74   // Coherent and Volatile decorations are deprecated. Remove them and replace
75   // with flags on the memory/image operations. The decorations can occur on
76   // OpVariable, OpFunctionParameter (of pointer type) and OpStructType (member
77   // decoration). Trace from the decoration target(s) to the final memory/image
78   // instructions. Additionally, Workgroup storage class variables and function
79   // parameters are implicitly coherent in GLSL450.
80 
81   // Upgrade modf and frexp first since they generate new stores.
82   // In SPIR-V 1.4 or later, normalize OpCopyMemory* access operands.
83   for (auto& func : *get_module()) {
84     func.ForEachInst([this](Instruction* inst) {
85       if (inst->opcode() == spv::Op::OpExtInst) {
86         auto ext_inst = inst->GetSingleWordInOperand(1u);
87         if (ext_inst == GLSLstd450Modf || ext_inst == GLSLstd450Frexp) {
88           auto import =
89               get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
90           if (import->GetInOperand(0u).AsString() == "GLSL.std.450") {
91             UpgradeExtInst(inst);
92           }
93         }
94       } else if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
95         if (inst->opcode() == spv::Op::OpCopyMemory ||
96             inst->opcode() == spv::Op::OpCopyMemorySized) {
97           uint32_t start_operand =
98               inst->opcode() == spv::Op::OpCopyMemory ? 2u : 3u;
99           if (inst->NumInOperands() > start_operand) {
100             auto num_access_words = MemoryAccessNumWords(
101                 inst->GetSingleWordInOperand(start_operand));
102             if ((num_access_words + start_operand) == inst->NumInOperands()) {
103               // There is a single memory access operand. Duplicate it to have a
104               // separate operand for both source and target.
105               for (uint32_t i = 0; i < num_access_words; ++i) {
106                 auto operand = inst->GetInOperand(start_operand + i);
107                 inst->AddOperand(std::move(operand));
108               }
109             }
110           } else {
111             // Add two memory access operands.
112             inst->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
113                               {uint32_t(spv::MemoryAccessMask::MaskNone)}});
114             inst->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
115                               {uint32_t(spv::MemoryAccessMask::MaskNone)}});
116           }
117         }
118       }
119     });
120   }
121 
122   UpgradeMemoryAndImages();
123   UpgradeAtomics();
124 }
125 
UpgradeMemoryAndImages()126 void UpgradeMemoryModel::UpgradeMemoryAndImages() {
127   for (auto& func : *get_module()) {
128     func.ForEachInst([this](Instruction* inst) {
129       bool is_coherent = false;
130       bool is_volatile = false;
131       bool src_coherent = false;
132       bool src_volatile = false;
133       bool dst_coherent = false;
134       bool dst_volatile = false;
135       uint32_t start_operand = 0u;
136       spv::Scope scope = spv::Scope::QueueFamilyKHR;
137       spv::Scope src_scope = spv::Scope::QueueFamilyKHR;
138       spv::Scope dst_scope = spv::Scope::QueueFamilyKHR;
139       switch (inst->opcode()) {
140         case spv::Op::OpLoad:
141         case spv::Op::OpStore:
142           std::tie(is_coherent, is_volatile, scope) =
143               GetInstructionAttributes(inst->GetSingleWordInOperand(0u));
144           break;
145         case spv::Op::OpImageRead:
146         case spv::Op::OpImageSparseRead:
147         case spv::Op::OpImageWrite:
148           std::tie(is_coherent, is_volatile, scope) =
149               GetInstructionAttributes(inst->GetSingleWordInOperand(0u));
150           break;
151         case spv::Op::OpCopyMemory:
152         case spv::Op::OpCopyMemorySized:
153           std::tie(dst_coherent, dst_volatile, dst_scope) =
154               GetInstructionAttributes(inst->GetSingleWordInOperand(0u));
155           std::tie(src_coherent, src_volatile, src_scope) =
156               GetInstructionAttributes(inst->GetSingleWordInOperand(1u));
157           break;
158         default:
159           break;
160       }
161 
162       switch (inst->opcode()) {
163         case spv::Op::OpLoad:
164           UpgradeFlags(inst, 1u, is_coherent, is_volatile, kVisibility,
165                        kMemory);
166           break;
167         case spv::Op::OpStore:
168           UpgradeFlags(inst, 2u, is_coherent, is_volatile, kAvailability,
169                        kMemory);
170           break;
171         case spv::Op::OpCopyMemory:
172         case spv::Op::OpCopyMemorySized:
173           start_operand = inst->opcode() == spv::Op::OpCopyMemory ? 2u : 3u;
174           if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
175             // There are guaranteed to be two memory access operands at this
176             // point so treat source and target separately.
177             uint32_t num_access_words = MemoryAccessNumWords(
178                 inst->GetSingleWordInOperand(start_operand));
179             UpgradeFlags(inst, start_operand, dst_coherent, dst_volatile,
180                          kAvailability, kMemory);
181             UpgradeFlags(inst, start_operand + num_access_words, src_coherent,
182                          src_volatile, kVisibility, kMemory);
183           } else {
184             UpgradeFlags(inst, start_operand, dst_coherent, dst_volatile,
185                          kAvailability, kMemory);
186             UpgradeFlags(inst, start_operand, src_coherent, src_volatile,
187                          kVisibility, kMemory);
188           }
189           break;
190         case spv::Op::OpImageRead:
191         case spv::Op::OpImageSparseRead:
192           UpgradeFlags(inst, 2u, is_coherent, is_volatile, kVisibility, kImage);
193           break;
194         case spv::Op::OpImageWrite:
195           UpgradeFlags(inst, 3u, is_coherent, is_volatile, kAvailability,
196                        kImage);
197           break;
198         default:
199           break;
200       }
201 
202       // |is_coherent| is never used for the same instructions as
203       // |src_coherent| and |dst_coherent|.
204       if (is_coherent) {
205         inst->AddOperand(
206             {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(scope)}});
207       }
208       if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
209         // There are two memory access operands. The first is for the target and
210         // the second is for the source.
211         if (dst_coherent || src_coherent) {
212           start_operand = inst->opcode() == spv::Op::OpCopyMemory ? 2u : 3u;
213           std::vector<Operand> new_operands;
214           uint32_t num_access_words =
215               MemoryAccessNumWords(inst->GetSingleWordInOperand(start_operand));
216           // The flags were already updated so subtract if we're adding a
217           // scope.
218           if (dst_coherent) --num_access_words;
219           for (uint32_t i = 0; i < start_operand + num_access_words; ++i) {
220             new_operands.push_back(inst->GetInOperand(i));
221           }
222           // Add the target scope if necessary.
223           if (dst_coherent) {
224             new_operands.push_back(
225                 {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(dst_scope)}});
226           }
227           // Copy the remaining current operands.
228           for (uint32_t i = start_operand + num_access_words;
229                i < inst->NumInOperands(); ++i) {
230             new_operands.push_back(inst->GetInOperand(i));
231           }
232           // Add the source scope if necessary.
233           if (src_coherent) {
234             new_operands.push_back(
235                 {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(src_scope)}});
236           }
237           inst->SetInOperands(std::move(new_operands));
238         }
239       } else {
240         // According to SPV_KHR_vulkan_memory_model, if both available and
241         // visible flags are used the first scope operand is for availability
242         // (writes) and the second is for visibility (reads).
243         if (dst_coherent) {
244           inst->AddOperand(
245               {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(dst_scope)}});
246         }
247         if (src_coherent) {
248           inst->AddOperand(
249               {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(src_scope)}});
250         }
251       }
252     });
253   }
254 }
255 
UpgradeAtomics()256 void UpgradeMemoryModel::UpgradeAtomics() {
257   for (auto& func : *get_module()) {
258     func.ForEachInst([this](Instruction* inst) {
259       if (spvOpcodeIsAtomicOp(inst->opcode())) {
260         bool unused_coherent = false;
261         bool is_volatile = false;
262         spv::Scope unused_scope = spv::Scope::QueueFamilyKHR;
263         std::tie(unused_coherent, is_volatile, unused_scope) =
264             GetInstructionAttributes(inst->GetSingleWordInOperand(0));
265 
266         UpgradeSemantics(inst, 2u, is_volatile);
267         if (inst->opcode() == spv::Op::OpAtomicCompareExchange ||
268             inst->opcode() == spv::Op::OpAtomicCompareExchangeWeak) {
269           UpgradeSemantics(inst, 3u, is_volatile);
270         }
271       }
272     });
273   }
274 }
275 
UpgradeSemantics(Instruction * inst,uint32_t in_operand,bool is_volatile)276 void UpgradeMemoryModel::UpgradeSemantics(Instruction* inst,
277                                           uint32_t in_operand,
278                                           bool is_volatile) {
279   if (!is_volatile) return;
280 
281   uint32_t semantics_id = inst->GetSingleWordInOperand(in_operand);
282   const analysis::Constant* constant =
283       context()->get_constant_mgr()->FindDeclaredConstant(semantics_id);
284   const analysis::Integer* type = constant->type()->AsInteger();
285   assert(type && type->width() == 32);
286   uint32_t value = 0;
287   if (type->IsSigned()) {
288     value = static_cast<uint32_t>(constant->GetS32());
289   } else {
290     value = constant->GetU32();
291   }
292 
293   value |= uint32_t(spv::MemorySemanticsMask::Volatile);
294   auto new_constant = context()->get_constant_mgr()->GetConstant(type, {value});
295   auto new_semantics =
296       context()->get_constant_mgr()->GetDefiningInstruction(new_constant);
297   inst->SetInOperand(in_operand, {new_semantics->result_id()});
298 }
299 
GetInstructionAttributes(uint32_t id)300 std::tuple<bool, bool, spv::Scope> UpgradeMemoryModel::GetInstructionAttributes(
301     uint32_t id) {
302   // |id| is a pointer used in a memory/image instruction. Need to determine if
303   // that pointer points to volatile or coherent memory. Workgroup storage
304   // class is implicitly coherent and cannot be decorated with volatile, so
305   // short circuit that case.
306   Instruction* inst = context()->get_def_use_mgr()->GetDef(id);
307   analysis::Type* type = context()->get_type_mgr()->GetType(inst->type_id());
308   if (type->AsPointer() &&
309       type->AsPointer()->storage_class() == spv::StorageClass::Workgroup) {
310     return std::make_tuple(true, false, spv::Scope::Workgroup);
311   }
312 
313   bool is_coherent = false;
314   bool is_volatile = false;
315   std::unordered_set<uint32_t> visited;
316   std::tie(is_coherent, is_volatile) =
317       TraceInstruction(context()->get_def_use_mgr()->GetDef(id),
318                        std::vector<uint32_t>(), &visited);
319 
320   return std::make_tuple(is_coherent, is_volatile, spv::Scope::QueueFamilyKHR);
321 }
322 
TraceInstruction(Instruction * inst,std::vector<uint32_t> indices,std::unordered_set<uint32_t> * visited)323 std::pair<bool, bool> UpgradeMemoryModel::TraceInstruction(
324     Instruction* inst, std::vector<uint32_t> indices,
325     std::unordered_set<uint32_t>* visited) {
326   auto iter = cache_.find(std::make_pair(inst->result_id(), indices));
327   if (iter != cache_.end()) {
328     return iter->second;
329   }
330 
331   if (!visited->insert(inst->result_id()).second) {
332     return std::make_pair(false, false);
333   }
334 
335   // Initialize the cache before |indices| is (potentially) modified.
336   auto& cached_result = cache_[std::make_pair(inst->result_id(), indices)];
337   cached_result.first = false;
338   cached_result.second = false;
339 
340   bool is_coherent = false;
341   bool is_volatile = false;
342   switch (inst->opcode()) {
343     case spv::Op::OpVariable:
344     case spv::Op::OpFunctionParameter:
345       is_coherent |= HasDecoration(inst, 0, spv::Decoration::Coherent);
346       is_volatile |= HasDecoration(inst, 0, spv::Decoration::Volatile);
347       if (!is_coherent || !is_volatile) {
348         bool type_coherent = false;
349         bool type_volatile = false;
350         std::tie(type_coherent, type_volatile) =
351             CheckType(inst->type_id(), indices);
352         is_coherent |= type_coherent;
353         is_volatile |= type_volatile;
354       }
355       break;
356     case spv::Op::OpAccessChain:
357     case spv::Op::OpInBoundsAccessChain:
358       // Store indices in reverse order.
359       for (uint32_t i = inst->NumInOperands() - 1; i > 0; --i) {
360         indices.push_back(inst->GetSingleWordInOperand(i));
361       }
362       break;
363     case spv::Op::OpPtrAccessChain:
364       // Store indices in reverse order. Skip the |Element| operand.
365       for (uint32_t i = inst->NumInOperands() - 1; i > 1; --i) {
366         indices.push_back(inst->GetSingleWordInOperand(i));
367       }
368       break;
369     default:
370       break;
371   }
372 
373   // No point searching further.
374   if (is_coherent && is_volatile) {
375     cached_result.first = true;
376     cached_result.second = true;
377     return std::make_pair(true, true);
378   }
379 
380   // Variables and function parameters are sources. Continue searching until we
381   // reach them.
382   if (inst->opcode() != spv::Op::OpVariable &&
383       inst->opcode() != spv::Op::OpFunctionParameter) {
384     inst->ForEachInId([this, &is_coherent, &is_volatile, &indices,
385                        &visited](const uint32_t* id_ptr) {
386       Instruction* op_inst = context()->get_def_use_mgr()->GetDef(*id_ptr);
387       const analysis::Type* type =
388           context()->get_type_mgr()->GetType(op_inst->type_id());
389       if (type &&
390           (type->AsPointer() || type->AsImage() || type->AsSampledImage())) {
391         bool operand_coherent = false;
392         bool operand_volatile = false;
393         std::tie(operand_coherent, operand_volatile) =
394             TraceInstruction(op_inst, indices, visited);
395         is_coherent |= operand_coherent;
396         is_volatile |= operand_volatile;
397       }
398     });
399   }
400 
401   cached_result.first = is_coherent;
402   cached_result.second = is_volatile;
403   return std::make_pair(is_coherent, is_volatile);
404 }
405 
CheckType(uint32_t type_id,const std::vector<uint32_t> & indices)406 std::pair<bool, bool> UpgradeMemoryModel::CheckType(
407     uint32_t type_id, const std::vector<uint32_t>& indices) {
408   bool is_coherent = false;
409   bool is_volatile = false;
410   Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
411   assert(type_inst->opcode() == spv::Op::OpTypePointer);
412   Instruction* element_inst = context()->get_def_use_mgr()->GetDef(
413       type_inst->GetSingleWordInOperand(1u));
414   for (int i = (int)indices.size() - 1; i >= 0; --i) {
415     if (is_coherent && is_volatile) break;
416 
417     if (element_inst->opcode() == spv::Op::OpTypePointer) {
418       element_inst = context()->get_def_use_mgr()->GetDef(
419           element_inst->GetSingleWordInOperand(1u));
420     } else if (element_inst->opcode() == spv::Op::OpTypeStruct) {
421       uint32_t index = indices.at(i);
422       Instruction* index_inst = context()->get_def_use_mgr()->GetDef(index);
423       assert(index_inst->opcode() == spv::Op::OpConstant);
424       uint64_t value = GetIndexValue(index_inst);
425       is_coherent |= HasDecoration(element_inst, static_cast<uint32_t>(value),
426                                    spv::Decoration::Coherent);
427       is_volatile |= HasDecoration(element_inst, static_cast<uint32_t>(value),
428                                    spv::Decoration::Volatile);
429       element_inst = context()->get_def_use_mgr()->GetDef(
430           element_inst->GetSingleWordInOperand(static_cast<uint32_t>(value)));
431     } else {
432       assert(spvOpcodeIsComposite(element_inst->opcode()));
433       element_inst = context()->get_def_use_mgr()->GetDef(
434           element_inst->GetSingleWordInOperand(0u));
435     }
436   }
437 
438   if (!is_coherent || !is_volatile) {
439     bool remaining_coherent = false;
440     bool remaining_volatile = false;
441     std::tie(remaining_coherent, remaining_volatile) =
442         CheckAllTypes(element_inst);
443     is_coherent |= remaining_coherent;
444     is_volatile |= remaining_volatile;
445   }
446 
447   return std::make_pair(is_coherent, is_volatile);
448 }
449 
CheckAllTypes(const Instruction * inst)450 std::pair<bool, bool> UpgradeMemoryModel::CheckAllTypes(
451     const Instruction* inst) {
452   std::unordered_set<const Instruction*> visited;
453   std::vector<const Instruction*> stack;
454   stack.push_back(inst);
455 
456   bool is_coherent = false;
457   bool is_volatile = false;
458   while (!stack.empty()) {
459     const Instruction* def = stack.back();
460     stack.pop_back();
461 
462     if (!visited.insert(def).second) continue;
463 
464     if (def->opcode() == spv::Op::OpTypeStruct) {
465       // Any member decorated with coherent and/or volatile is enough to have
466       // the related operation be flagged as coherent and/or volatile.
467       is_coherent |= HasDecoration(def, std::numeric_limits<uint32_t>::max(),
468                                    spv::Decoration::Coherent);
469       is_volatile |= HasDecoration(def, std::numeric_limits<uint32_t>::max(),
470                                    spv::Decoration::Volatile);
471       if (is_coherent && is_volatile)
472         return std::make_pair(is_coherent, is_volatile);
473 
474       // Check the subtypes.
475       for (uint32_t i = 0; i < def->NumInOperands(); ++i) {
476         stack.push_back(context()->get_def_use_mgr()->GetDef(
477             def->GetSingleWordInOperand(i)));
478       }
479     } else if (spvOpcodeIsComposite(def->opcode())) {
480       stack.push_back(context()->get_def_use_mgr()->GetDef(
481           def->GetSingleWordInOperand(0u)));
482     } else if (def->opcode() == spv::Op::OpTypePointer) {
483       stack.push_back(context()->get_def_use_mgr()->GetDef(
484           def->GetSingleWordInOperand(1u)));
485     }
486   }
487 
488   return std::make_pair(is_coherent, is_volatile);
489 }
490 
GetIndexValue(Instruction * index_inst)491 uint64_t UpgradeMemoryModel::GetIndexValue(Instruction* index_inst) {
492   const analysis::Constant* index_constant =
493       context()->get_constant_mgr()->GetConstantFromInst(index_inst);
494   assert(index_constant->AsIntConstant());
495   if (index_constant->type()->AsInteger()->IsSigned()) {
496     if (index_constant->type()->AsInteger()->width() == 32) {
497       return index_constant->GetS32();
498     } else {
499       return index_constant->GetS64();
500     }
501   } else {
502     if (index_constant->type()->AsInteger()->width() == 32) {
503       return index_constant->GetU32();
504     } else {
505       return index_constant->GetU64();
506     }
507   }
508 }
509 
HasDecoration(const Instruction * inst,uint32_t value,spv::Decoration decoration)510 bool UpgradeMemoryModel::HasDecoration(const Instruction* inst, uint32_t value,
511                                        spv::Decoration decoration) {
512   // If the iteration was terminated early then an appropriate decoration was
513   // found.
514   return !context()->get_decoration_mgr()->WhileEachDecoration(
515       inst->result_id(), (uint32_t)decoration, [value](const Instruction& i) {
516         if (i.opcode() == spv::Op::OpDecorate ||
517             i.opcode() == spv::Op::OpDecorateId) {
518           return false;
519         } else if (i.opcode() == spv::Op::OpMemberDecorate) {
520           if (value == i.GetSingleWordInOperand(1u) ||
521               value == std::numeric_limits<uint32_t>::max())
522             return false;
523         }
524 
525         return true;
526       });
527 }
528 
UpgradeFlags(Instruction * inst,uint32_t in_operand,bool is_coherent,bool is_volatile,OperationType operation_type,InstructionType inst_type)529 void UpgradeMemoryModel::UpgradeFlags(Instruction* inst, uint32_t in_operand,
530                                       bool is_coherent, bool is_volatile,
531                                       OperationType operation_type,
532                                       InstructionType inst_type) {
533   if (!is_coherent && !is_volatile) return;
534 
535   uint32_t flags = 0;
536   if (inst->NumInOperands() > in_operand) {
537     flags |= inst->GetSingleWordInOperand(in_operand);
538   }
539   if (is_coherent) {
540     if (inst_type == kMemory) {
541       flags |= uint32_t(spv::MemoryAccessMask::NonPrivatePointerKHR);
542       if (operation_type == kVisibility) {
543         flags |= uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR);
544       } else {
545         flags |= uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR);
546       }
547     } else {
548       flags |= uint32_t(spv::ImageOperandsMask::NonPrivateTexelKHR);
549       if (operation_type == kVisibility) {
550         flags |= uint32_t(spv::ImageOperandsMask::MakeTexelVisibleKHR);
551       } else {
552         flags |= uint32_t(spv::ImageOperandsMask::MakeTexelAvailableKHR);
553       }
554     }
555   }
556 
557   if (is_volatile) {
558     if (inst_type == kMemory) {
559       flags |= uint32_t(spv::MemoryAccessMask::Volatile);
560     } else {
561       flags |= uint32_t(spv::ImageOperandsMask::VolatileTexelKHR);
562     }
563   }
564 
565   if (inst->NumInOperands() > in_operand) {
566     inst->SetInOperand(in_operand, {flags});
567   } else if (inst_type == kMemory) {
568     inst->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, {flags}});
569   } else {
570     inst->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_IMAGE, {flags}});
571   }
572 }
573 
GetScopeConstant(spv::Scope scope)574 uint32_t UpgradeMemoryModel::GetScopeConstant(spv::Scope scope) {
575   analysis::Integer int_ty(32, false);
576   uint32_t int_id = context()->get_type_mgr()->GetTypeInstruction(&int_ty);
577   const analysis::Constant* constant =
578       context()->get_constant_mgr()->GetConstant(
579           context()->get_type_mgr()->GetType(int_id),
580           {static_cast<uint32_t>(scope)});
581   return context()
582       ->get_constant_mgr()
583       ->GetDefiningInstruction(constant)
584       ->result_id();
585 }
586 
CleanupDecorations()587 void UpgradeMemoryModel::CleanupDecorations() {
588   // All of the volatile and coherent decorations have been dealt with, so now
589   // we can just remove them.
590   get_module()->ForEachInst([this](Instruction* inst) {
591     if (inst->result_id() != 0) {
592       context()->get_decoration_mgr()->RemoveDecorationsFrom(
593           inst->result_id(), [](const Instruction& dec) {
594             switch (dec.opcode()) {
595               case spv::Op::OpDecorate:
596               case spv::Op::OpDecorateId:
597                 if (spv::Decoration(dec.GetSingleWordInOperand(1u)) ==
598                         spv::Decoration::Coherent ||
599                     spv::Decoration(dec.GetSingleWordInOperand(1u)) ==
600                         spv::Decoration::Volatile)
601                   return true;
602                 break;
603               case spv::Op::OpMemberDecorate:
604                 if (spv::Decoration(dec.GetSingleWordInOperand(2u)) ==
605                         spv::Decoration::Coherent ||
606                     spv::Decoration(dec.GetSingleWordInOperand(2u)) ==
607                         spv::Decoration::Volatile)
608                   return true;
609                 break;
610               default:
611                 break;
612             }
613             return false;
614           });
615     }
616   });
617 }
618 
UpgradeBarriers()619 void UpgradeMemoryModel::UpgradeBarriers() {
620   std::vector<Instruction*> barriers;
621   // Collects all the control barriers in |function|. Returns true if the
622   // function operates on the Output storage class.
623   ProcessFunction CollectBarriers = [this, &barriers](Function* function) {
624     bool operates_on_output = false;
625     for (auto& block : *function) {
626       block.ForEachInst([this, &barriers,
627                          &operates_on_output](Instruction* inst) {
628         if (inst->opcode() == spv::Op::OpControlBarrier) {
629           barriers.push_back(inst);
630         } else if (!operates_on_output) {
631           // This instruction operates on output storage class if it is a
632           // pointer to output type or any input operand is a pointer to output
633           // type.
634           analysis::Type* type =
635               context()->get_type_mgr()->GetType(inst->type_id());
636           if (type && type->AsPointer() &&
637               type->AsPointer()->storage_class() == spv::StorageClass::Output) {
638             operates_on_output = true;
639             return;
640           }
641           inst->ForEachInId([this, &operates_on_output](uint32_t* id_ptr) {
642             Instruction* op_inst =
643                 context()->get_def_use_mgr()->GetDef(*id_ptr);
644             analysis::Type* op_type =
645                 context()->get_type_mgr()->GetType(op_inst->type_id());
646             if (op_type && op_type->AsPointer() &&
647                 op_type->AsPointer()->storage_class() ==
648                     spv::StorageClass::Output)
649               operates_on_output = true;
650           });
651         }
652       });
653     }
654     return operates_on_output;
655   };
656 
657   std::queue<uint32_t> roots;
658   for (auto& e : get_module()->entry_points())
659     if (spv::ExecutionModel(e.GetSingleWordInOperand(0u)) ==
660         spv::ExecutionModel::TessellationControl) {
661       roots.push(e.GetSingleWordInOperand(1u));
662       if (context()->ProcessCallTreeFromRoots(CollectBarriers, &roots)) {
663         for (auto barrier : barriers) {
664           // Add OutputMemoryKHR to the semantics of the barriers.
665           uint32_t semantics_id = barrier->GetSingleWordInOperand(2u);
666           Instruction* semantics_inst =
667               context()->get_def_use_mgr()->GetDef(semantics_id);
668           analysis::Type* semantics_type =
669               context()->get_type_mgr()->GetType(semantics_inst->type_id());
670           uint64_t semantics_value = GetIndexValue(semantics_inst);
671           const analysis::Constant* constant =
672               context()->get_constant_mgr()->GetConstant(
673                   semantics_type,
674                   {static_cast<uint32_t>(semantics_value) |
675                    uint32_t(spv::MemorySemanticsMask::OutputMemoryKHR)});
676           barrier->SetInOperand(2u, {context()
677                                          ->get_constant_mgr()
678                                          ->GetDefiningInstruction(constant)
679                                          ->result_id()});
680         }
681       }
682       barriers.clear();
683     }
684 }
685 
UpgradeMemoryScope()686 void UpgradeMemoryModel::UpgradeMemoryScope() {
687   get_module()->ForEachInst([this](Instruction* inst) {
688     // Don't need to handle all the operations that take a scope.
689     // * Group operations can only be subgroup
690     // * Non-uniform can only be workgroup or subgroup
691     // * Named barriers are not supported by Vulkan
692     // * Workgroup ops (e.g. async_copy) have at most workgroup scope.
693     if (spvOpcodeIsAtomicOp(inst->opcode())) {
694       if (IsDeviceScope(inst->GetSingleWordInOperand(1))) {
695         inst->SetInOperand(1, {GetScopeConstant(spv::Scope::QueueFamilyKHR)});
696       }
697     } else if (inst->opcode() == spv::Op::OpControlBarrier) {
698       if (IsDeviceScope(inst->GetSingleWordInOperand(1))) {
699         inst->SetInOperand(1, {GetScopeConstant(spv::Scope::QueueFamilyKHR)});
700       }
701     } else if (inst->opcode() == spv::Op::OpMemoryBarrier) {
702       if (IsDeviceScope(inst->GetSingleWordInOperand(0))) {
703         inst->SetInOperand(0, {GetScopeConstant(spv::Scope::QueueFamilyKHR)});
704       }
705     }
706   });
707 }
708 
IsDeviceScope(uint32_t scope_id)709 bool UpgradeMemoryModel::IsDeviceScope(uint32_t scope_id) {
710   const analysis::Constant* constant =
711       context()->get_constant_mgr()->FindDeclaredConstant(scope_id);
712   assert(constant && "Memory scope must be a constant");
713 
714   const analysis::Integer* type = constant->type()->AsInteger();
715   assert(type);
716   assert(type->width() == 32 || type->width() == 64);
717   if (type->width() == 32) {
718     if (type->IsSigned())
719       return static_cast<spv::Scope>(constant->GetS32()) == spv::Scope::Device;
720     else
721       return static_cast<spv::Scope>(constant->GetU32()) == spv::Scope::Device;
722   } else {
723     if (type->IsSigned())
724       return static_cast<spv::Scope>(constant->GetS64()) == spv::Scope::Device;
725     else
726       return static_cast<spv::Scope>(constant->GetU64()) == spv::Scope::Device;
727   }
728 
729   assert(false);
730   return false;
731 }
732 
UpgradeExtInst(Instruction * ext_inst)733 void UpgradeMemoryModel::UpgradeExtInst(Instruction* ext_inst) {
734   const bool is_modf = ext_inst->GetSingleWordInOperand(1u) == GLSLstd450Modf;
735   auto ptr_id = ext_inst->GetSingleWordInOperand(3u);
736   auto ptr_type_id = get_def_use_mgr()->GetDef(ptr_id)->type_id();
737   auto pointee_type_id =
738       get_def_use_mgr()->GetDef(ptr_type_id)->GetSingleWordInOperand(1u);
739   auto element_type_id = ext_inst->type_id();
740   std::vector<const analysis::Type*> element_types(2);
741   element_types[0] = context()->get_type_mgr()->GetType(element_type_id);
742   element_types[1] = context()->get_type_mgr()->GetType(pointee_type_id);
743   analysis::Struct struct_type(element_types);
744   uint32_t struct_id =
745       context()->get_type_mgr()->GetTypeInstruction(&struct_type);
746   // Change the operation
747   GLSLstd450 new_op = is_modf ? GLSLstd450ModfStruct : GLSLstd450FrexpStruct;
748   ext_inst->SetOperand(3u, {static_cast<uint32_t>(new_op)});
749   // Remove the pointer argument
750   ext_inst->RemoveOperand(5u);
751   // Set the type id to the new struct.
752   ext_inst->SetResultType(struct_id);
753 
754   // The result is now a struct of the original result. The zero'th element is
755   // old result and should replace the old result. The one'th element needs to
756   // be stored via a new instruction.
757   auto where = ext_inst->NextNode();
758   InstructionBuilder builder(
759       context(), where,
760       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
761   auto extract_0 =
762       builder.AddCompositeExtract(element_type_id, ext_inst->result_id(), {0});
763   context()->ReplaceAllUsesWith(ext_inst->result_id(), extract_0->result_id());
764   // The extract's input was just changed to itself, so fix that.
765   extract_0->SetInOperand(0u, {ext_inst->result_id()});
766   auto extract_1 =
767       builder.AddCompositeExtract(pointee_type_id, ext_inst->result_id(), {1});
768   builder.AddStore(ptr_id, extract_1->result_id());
769 }
770 
MemoryAccessNumWords(uint32_t mask)771 uint32_t UpgradeMemoryModel::MemoryAccessNumWords(uint32_t mask) {
772   uint32_t result = 1;
773   if (mask & uint32_t(spv::MemoryAccessMask::Aligned)) ++result;
774   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) ++result;
775   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerVisibleKHR)) ++result;
776   return result;
777 }
778 
779 }  // namespace opt
780 }  // namespace spvtools
781