• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/trim_capabilities_pass.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <stack>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <vector>
27 
28 #include "source/enum_set.h"
29 #include "source/enum_string_mapping.h"
30 #include "source/opt/ir_context.h"
31 #include "source/opt/reflect.h"
32 #include "source/spirv_target_env.h"
33 #include "source/util/string_utils.h"
34 
35 namespace spvtools {
36 namespace opt {
37 
38 namespace {
// In-operand indices used to decode the instructions inspected by this pass.

// OpTypeFloat / OpTypeInt: the bit width is the first in-operand.
constexpr uint32_t kOpTypeFloatSizeIndex{0};
constexpr uint32_t kOpTypeIntSizeIndex{0};
constexpr uint32_t kOpTypeScalarBitWidthIndex{0};

// OpTypePointer: storage class first, then the pointee type id.
constexpr uint32_t kOpTypePointerStorageClassIndex{0};
constexpr uint32_t kTypePointerTypeIdInIndex{1};

// OpTypeMatrix / OpTypeVector / OpTypeArray / OpTypeRuntimeArray: the
// component/element type id is the first in-operand.
constexpr uint32_t kTypeArrayTypeIndex{0};

// OpTypeImage: Dim, then Arrayed two slots later, followed directly by MS,
// Sampled and Image Format.
constexpr uint32_t kOpTypeImageDimIndex{1};
constexpr uint32_t kOpTypeImageArrayedIndex{kOpTypeImageDimIndex + 2};
constexpr uint32_t kOpTypeImageMSIndex{kOpTypeImageArrayedIndex + 1};
constexpr uint32_t kOpTypeImageSampledIndex{kOpTypeImageMSIndex + 1};
constexpr uint32_t kOpTypeImageFormatIndex{kOpTypeImageSampledIndex + 1};

// OpImageRead / OpImageSparseRead: the image id is the first in-operand.
constexpr uint32_t kOpImageReadImageIndex{0};
constexpr uint32_t kOpImageSparseReadImageIndex{0};
52 
53 // DFS visit of the type defined by `instruction`.
54 // If `condition` is true, children of the current node are visited.
55 // If `condition` is false, the children of the current node are ignored.
56 template <class UnaryPredicate>
DFSWhile(const Instruction * instruction,UnaryPredicate condition)57 static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
58   std::stack<uint32_t> instructions_to_visit;
59   instructions_to_visit.push(instruction->result_id());
60   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
61 
62   while (!instructions_to_visit.empty()) {
63     const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
64     instructions_to_visit.pop();
65 
66     if (!condition(item)) {
67       continue;
68     }
69 
70     if (item->opcode() == spv::Op::OpTypePointer) {
71       instructions_to_visit.push(
72           item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
73       continue;
74     }
75 
76     if (item->opcode() == spv::Op::OpTypeMatrix ||
77         item->opcode() == spv::Op::OpTypeVector ||
78         item->opcode() == spv::Op::OpTypeArray ||
79         item->opcode() == spv::Op::OpTypeRuntimeArray) {
80       instructions_to_visit.push(
81           item->GetSingleWordInOperand(kTypeArrayTypeIndex));
82       continue;
83     }
84 
85     if (item->opcode() == spv::Op::OpTypeStruct) {
86       item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
87         instructions_to_visit.push(*op_id);
88       });
89       continue;
90     }
91   }
92 }
93 
94 // Walks the type defined by `instruction` (OpType* only).
95 // Returns `true` if any call to `predicate` with the type/subtype returns true.
96 template <class UnaryPredicate>
AnyTypeOf(const Instruction * instruction,UnaryPredicate predicate)97 static bool AnyTypeOf(const Instruction* instruction,
98                       UnaryPredicate predicate) {
99   assert(IsTypeInst(instruction->opcode()) &&
100          "AnyTypeOf called with a non-type instruction.");
101 
102   bool found_one = false;
103   DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
104     if (found_one || predicate(node)) {
105       found_one = true;
106       return false;
107     }
108 
109     return true;
110   });
111   return found_one;
112 }
113 
is16bitType(const Instruction * instruction)114 static bool is16bitType(const Instruction* instruction) {
115   if (instruction->opcode() != spv::Op::OpTypeInt &&
116       instruction->opcode() != spv::Op::OpTypeFloat) {
117     return false;
118   }
119 
120   return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
121 }
122 
Has16BitCapability(const FeatureManager * feature_manager)123 static bool Has16BitCapability(const FeatureManager* feature_manager) {
124   const CapabilitySet& capabilities = feature_manager->GetCapabilities();
125   return capabilities.contains(spv::Capability::Float16) ||
126          capabilities.contains(spv::Capability::Int16);
127 }
128 
129 }  // namespace
130 
131 // ============== Begin opcode handler implementations. =======================
132 //
// Adding support for a new capability should only require adding a new handler
// (and registering it in kOpcodeHandlers), and updating the
// kSupportedCapabilities, kUntouchableCapabilities and kForbiddenCapabilities
// lists.
136 //
137 // Handler names follow the following convention:
138 //  Handler_<Opcode>_<Capability>()
139 
Handler_OpTypeFloat_Float16(const Instruction * instruction)140 static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
141     const Instruction* instruction) {
142   assert(instruction->opcode() == spv::Op::OpTypeFloat &&
143          "This handler only support OpTypeFloat opcodes.");
144 
145   const uint32_t size =
146       instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
147   return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
148 }
149 
Handler_OpTypeFloat_Float64(const Instruction * instruction)150 static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
151     const Instruction* instruction) {
152   assert(instruction->opcode() == spv::Op::OpTypeFloat &&
153          "This handler only support OpTypeFloat opcodes.");
154 
155   const uint32_t size =
156       instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
157   return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
158 }
159 
160 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageInputOutput16(const Instruction * instruction)161 Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
162   assert(instruction->opcode() == spv::Op::OpTypePointer &&
163          "This handler only support OpTypePointer opcodes.");
164 
165   // This capability is only required if the variable has an Input/Output
166   // storage class.
167   spv::StorageClass storage_class = spv::StorageClass(
168       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
169   if (storage_class != spv::StorageClass::Input &&
170       storage_class != spv::StorageClass::Output) {
171     return std::nullopt;
172   }
173 
174   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
175     return std::nullopt;
176   }
177 
178   return AnyTypeOf(instruction, is16bitType)
179              ? std::optional(spv::Capability::StorageInputOutput16)
180              : std::nullopt;
181 }
182 
183 static std::optional<spv::Capability>
Handler_OpTypePointer_StoragePushConstant16(const Instruction * instruction)184 Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
185   assert(instruction->opcode() == spv::Op::OpTypePointer &&
186          "This handler only support OpTypePointer opcodes.");
187 
188   // This capability is only required if the variable has a PushConstant storage
189   // class.
190   spv::StorageClass storage_class = spv::StorageClass(
191       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
192   if (storage_class != spv::StorageClass::PushConstant) {
193     return std::nullopt;
194   }
195 
196   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
197     return std::nullopt;
198   }
199 
200   return AnyTypeOf(instruction, is16bitType)
201              ? std::optional(spv::Capability::StoragePushConstant16)
202              : std::nullopt;
203 }
204 
205 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageUniformBufferBlock16(const Instruction * instruction)206 Handler_OpTypePointer_StorageUniformBufferBlock16(
207     const Instruction* instruction) {
208   assert(instruction->opcode() == spv::Op::OpTypePointer &&
209          "This handler only support OpTypePointer opcodes.");
210 
211   // This capability is only required if the variable has a Uniform storage
212   // class.
213   spv::StorageClass storage_class = spv::StorageClass(
214       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
215   if (storage_class != spv::StorageClass::Uniform) {
216     return std::nullopt;
217   }
218 
219   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
220     return std::nullopt;
221   }
222 
223   const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
224   const bool matchesCondition =
225       AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
226         if (!decoration_mgr->HasDecoration(item->result_id(),
227                                            spv::Decoration::BufferBlock)) {
228           return false;
229         }
230 
231         return AnyTypeOf(item, is16bitType);
232       });
233 
234   return matchesCondition
235              ? std::optional(spv::Capability::StorageUniformBufferBlock16)
236              : std::nullopt;
237 }
238 
Handler_OpTypePointer_StorageUniform16(const Instruction * instruction)239 static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
240     const Instruction* instruction) {
241   assert(instruction->opcode() == spv::Op::OpTypePointer &&
242          "This handler only support OpTypePointer opcodes.");
243 
244   // This capability is only required if the variable has a Uniform storage
245   // class.
246   spv::StorageClass storage_class = spv::StorageClass(
247       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
248   if (storage_class != spv::StorageClass::Uniform) {
249     return std::nullopt;
250   }
251 
252   const auto* feature_manager = instruction->context()->get_feature_mgr();
253   if (!Has16BitCapability(feature_manager)) {
254     return std::nullopt;
255   }
256 
257   const bool hasBufferBlockCapability =
258       feature_manager->GetCapabilities().contains(
259           spv::Capability::StorageUniformBufferBlock16);
260   const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
261   bool found16bitType = false;
262 
263   DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
264                          &found16bitType](const Instruction* item) {
265     if (found16bitType) {
266       return false;
267     }
268 
269     if (hasBufferBlockCapability &&
270         decoration_mgr->HasDecoration(item->result_id(),
271                                       spv::Decoration::BufferBlock)) {
272       return false;
273     }
274 
275     if (is16bitType(item)) {
276       found16bitType = true;
277       return false;
278     }
279 
280     return true;
281   });
282 
283   return found16bitType ? std::optional(spv::Capability::StorageUniform16)
284                         : std::nullopt;
285 }
286 
Handler_OpTypeInt_Int16(const Instruction * instruction)287 static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
288     const Instruction* instruction) {
289   assert(instruction->opcode() == spv::Op::OpTypeInt &&
290          "This handler only support OpTypeInt opcodes.");
291 
292   const uint32_t size =
293       instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
294   return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
295 }
296 
Handler_OpTypeInt_Int64(const Instruction * instruction)297 static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
298     const Instruction* instruction) {
299   assert(instruction->opcode() == spv::Op::OpTypeInt &&
300          "This handler only support OpTypeInt opcodes.");
301 
302   const uint32_t size =
303       instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
304   return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
305 }
306 
Handler_OpTypeImage_ImageMSArray(const Instruction * instruction)307 static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
308     const Instruction* instruction) {
309   assert(instruction->opcode() == spv::Op::OpTypeImage &&
310          "This handler only support OpTypeImage opcodes.");
311 
312   const uint32_t arrayed =
313       instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
314   const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
315   const uint32_t sampled =
316       instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
317 
318   return arrayed == 1 && sampled == 2 && ms == 1
319              ? std::optional(spv::Capability::ImageMSArray)
320              : std::nullopt;
321 }
322 
323 static std::optional<spv::Capability>
Handler_OpImageRead_StorageImageReadWithoutFormat(const Instruction * instruction)324 Handler_OpImageRead_StorageImageReadWithoutFormat(
325     const Instruction* instruction) {
326   assert(instruction->opcode() == spv::Op::OpImageRead &&
327          "This handler only support OpImageRead opcodes.");
328   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
329 
330   const uint32_t image_index =
331       instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
332   const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
333   const Instruction* type = def_use_mgr->GetDef(type_index);
334   const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
335   const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
336 
337   const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
338   const bool requires_capability_for_unknown =
339       spv::Dim(dim) != spv::Dim::SubpassData;
340   return is_unknown && requires_capability_for_unknown
341              ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
342              : std::nullopt;
343 }
344 
345 static std::optional<spv::Capability>
Handler_OpImageSparseRead_StorageImageReadWithoutFormat(const Instruction * instruction)346 Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
347     const Instruction* instruction) {
348   assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
349          "This handler only support OpImageSparseRead opcodes.");
350   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
351 
352   const uint32_t image_index =
353       instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
354   const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
355   const Instruction* type = def_use_mgr->GetDef(type_index);
356   const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
357 
358   return spv::ImageFormat(format) == spv::ImageFormat::Unknown
359              ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
360              : std::nullopt;
361 }
362 
363 // Opcode of interest to determine capabilities requirements.
364 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
365     // clang-format off
366     {spv::Op::OpImageRead,         Handler_OpImageRead_StorageImageReadWithoutFormat},
367     {spv::Op::OpImageSparseRead,   Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
368     {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float16 },
369     {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float64 },
370     {spv::Op::OpTypeImage,         Handler_OpTypeImage_ImageMSArray},
371     {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int16 },
372     {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int64 },
373     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageInputOutput16},
374     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StoragePushConstant16},
375     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
376     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
377     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniformBufferBlock16},
378     // clang-format on
379 }};
380 
381 // ==============  End opcode handler implementations.  =======================
382 
383 namespace {
getExtensionsRelatedTo(const CapabilitySet & capabilities,const AssemblyGrammar & grammar)384 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
385                                     const AssemblyGrammar& grammar) {
386   ExtensionSet output;
387   const spv_operand_desc_t* desc = nullptr;
388   for (auto capability : capabilities) {
389     if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
390                                              static_cast<uint32_t>(capability),
391                                              &desc)) {
392       continue;
393     }
394 
395     for (uint32_t i = 0; i < desc->numExtensions; ++i) {
396       output.insert(desc->extensions[i]);
397     }
398   }
399 
400   return output;
401 }
402 
hasOpcodeConflictingCapabilities(spv::Op opcode)403 bool hasOpcodeConflictingCapabilities(spv::Op opcode) {
404   switch (opcode) {
405     case spv::Op::OpBeginInvocationInterlockEXT:
406     case spv::Op::OpEndInvocationInterlockEXT:
407     case spv::Op::OpGroupNonUniformIAdd:
408     case spv::Op::OpGroupNonUniformFAdd:
409     case spv::Op::OpGroupNonUniformIMul:
410     case spv::Op::OpGroupNonUniformFMul:
411     case spv::Op::OpGroupNonUniformSMin:
412     case spv::Op::OpGroupNonUniformUMin:
413     case spv::Op::OpGroupNonUniformFMin:
414     case spv::Op::OpGroupNonUniformSMax:
415     case spv::Op::OpGroupNonUniformUMax:
416     case spv::Op::OpGroupNonUniformFMax:
417     case spv::Op::OpGroupNonUniformBitwiseAnd:
418     case spv::Op::OpGroupNonUniformBitwiseOr:
419     case spv::Op::OpGroupNonUniformBitwiseXor:
420     case spv::Op::OpGroupNonUniformLogicalAnd:
421     case spv::Op::OpGroupNonUniformLogicalOr:
422     case spv::Op::OpGroupNonUniformLogicalXor:
423       return true;
424     default:
425       return false;
426   }
427 }
428 
429 }  // namespace
430 
TrimCapabilitiesPass()431 TrimCapabilitiesPass::TrimCapabilitiesPass()
432     : supportedCapabilities_(
433           TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
434           TrimCapabilitiesPass::kSupportedCapabilities.cend()),
435       forbiddenCapabilities_(
436           TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
437           TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
438       untouchableCapabilities_(
439           TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
440           TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
441       opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
442 
addInstructionRequirementsForOpcode(spv::Op opcode,CapabilitySet * capabilities,ExtensionSet * extensions) const443 void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
444     spv::Op opcode, CapabilitySet* capabilities,
445     ExtensionSet* extensions) const {
446   if (hasOpcodeConflictingCapabilities(opcode)) {
447     return;
448   }
449 
450   const spv_opcode_desc_t* desc = {};
451   auto result = context()->grammar().lookupOpcode(opcode, &desc);
452   if (result != SPV_SUCCESS) {
453     return;
454   }
455 
456   addSupportedCapabilitiesToSet(desc, capabilities);
457   addSupportedExtensionsToSet(desc, extensions);
458 }
459 
addInstructionRequirementsForOperand(const Operand & operand,CapabilitySet * capabilities,ExtensionSet * extensions) const460 void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
461     const Operand& operand, CapabilitySet* capabilities,
462     ExtensionSet* extensions) const {
463   // No supported capability relies on a 2+-word operand.
464   if (operand.words.size() != 1) {
465     return;
466   }
467 
468   // No supported capability relies on a literal string operand or an ID.
469   if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
470       operand.type == SPV_OPERAND_TYPE_ID ||
471       operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
472     return;
473   }
474 
475   // If the Vulkan memory model is declared and any instruction uses Device
476   // scope, the VulkanMemoryModelDeviceScope capability must be declared. This
477   // rule cannot be covered by the grammar, so must be checked explicitly.
478   if (operand.type == SPV_OPERAND_TYPE_SCOPE_ID) {
479     const Instruction* memory_model = context()->GetMemoryModel();
480     if (memory_model && memory_model->GetSingleWordInOperand(1u) ==
481                             uint32_t(spv::MemoryModel::Vulkan)) {
482       capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
483     }
484   }
485 
486   // case 1: Operand is a single value, can directly lookup.
487   if (!spvOperandIsConcreteMask(operand.type)) {
488     const spv_operand_desc_t* desc = {};
489     auto result = context()->grammar().lookupOperand(operand.type,
490                                                      operand.words[0], &desc);
491     if (result != SPV_SUCCESS) {
492       return;
493     }
494     addSupportedCapabilitiesToSet(desc, capabilities);
495     addSupportedExtensionsToSet(desc, extensions);
496     return;
497   }
498 
499   // case 2: operand can be a bitmask, we need to decompose the lookup.
500   for (uint32_t i = 0; i < 32; i++) {
501     const uint32_t mask = (1 << i) & operand.words[0];
502     if (!mask) {
503       continue;
504     }
505 
506     const spv_operand_desc_t* desc = {};
507     auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
508     if (result != SPV_SUCCESS) {
509       continue;
510     }
511 
512     addSupportedCapabilitiesToSet(desc, capabilities);
513     addSupportedExtensionsToSet(desc, extensions);
514   }
515 }
516 
addInstructionRequirements(Instruction * instruction,CapabilitySet * capabilities,ExtensionSet * extensions) const517 void TrimCapabilitiesPass::addInstructionRequirements(
518     Instruction* instruction, CapabilitySet* capabilities,
519     ExtensionSet* extensions) const {
520   // Ignoring OpCapability and OpExtension instructions.
521   if (instruction->opcode() == spv::Op::OpCapability ||
522       instruction->opcode() == spv::Op::OpExtension) {
523     return;
524   }
525 
526   addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
527                                       extensions);
528 
529   // Second case: one of the opcode operand is gated by a capability.
530   const uint32_t operandCount = instruction->NumOperands();
531   for (uint32_t i = 0; i < operandCount; i++) {
532     addInstructionRequirementsForOperand(instruction->GetOperand(i),
533                                          capabilities, extensions);
534   }
535 
536   // Last case: some complex logic needs to be run to determine capabilities.
537   auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
538   for (auto it = begin; it != end; it++) {
539     const OpcodeHandler handler = it->second;
540     auto result = handler(instruction);
541     if (!result.has_value()) {
542       continue;
543     }
544 
545     capabilities->insert(*result);
546   }
547 }
548 
AddExtensionsForOperand(const spv_operand_type_t type,const uint32_t value,ExtensionSet * extensions) const549 void TrimCapabilitiesPass::AddExtensionsForOperand(
550     const spv_operand_type_t type, const uint32_t value,
551     ExtensionSet* extensions) const {
552   const spv_operand_desc_t* desc = nullptr;
553   spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
554   if (result != SPV_SUCCESS) {
555     return;
556   }
557   addSupportedExtensionsToSet(desc, extensions);
558 }
559 
560 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const561 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
562   CapabilitySet required_capabilities;
563   ExtensionSet required_extensions;
564 
565   get_module()->ForEachInst([&](Instruction* instruction) {
566     addInstructionRequirements(instruction, &required_capabilities,
567                                &required_extensions);
568   });
569 
570   for (auto capability : required_capabilities) {
571     AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
572                             static_cast<uint32_t>(capability),
573                             &required_extensions);
574   }
575 
576 #if !defined(NDEBUG)
577   // Debug only. We check the outputted required capabilities against the
578   // supported capabilities list. The supported capabilities list is useful for
579   // API users to quickly determine if they can use the pass or not. But this
580   // list has to remain up-to-date with the pass code. If we can detect a
581   // capability as required, but it's not listed, it means the list is
582   // out-of-sync. This method is not ideal, but should cover most cases.
583   {
584     for (auto capability : required_capabilities) {
585       assert(supportedCapabilities_.contains(capability) &&
586              "Module is using a capability that is not listed as supported.");
587     }
588   }
589 #endif
590 
591   return std::make_pair(std::move(required_capabilities),
592                         std::move(required_extensions));
593 }
594 
TrimUnrequiredCapabilities(const CapabilitySet & required_capabilities) const595 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
596     const CapabilitySet& required_capabilities) const {
597   const FeatureManager* feature_manager = context()->get_feature_mgr();
598   CapabilitySet capabilities_to_trim;
599   for (auto capability : feature_manager->GetCapabilities()) {
600     // Some capabilities cannot be safely removed. Leaving them untouched.
601     if (untouchableCapabilities_.contains(capability)) {
602       continue;
603     }
604 
605     // If the capability is unsupported, don't trim it.
606     if (!supportedCapabilities_.contains(capability)) {
607       continue;
608     }
609 
610     if (required_capabilities.contains(capability)) {
611       continue;
612     }
613 
614     capabilities_to_trim.insert(capability);
615   }
616 
617   for (auto capability : capabilities_to_trim) {
618     context()->RemoveCapability(capability);
619   }
620 
621   return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
622                                           : Pass::Status::SuccessWithChange;
623 }
624 
TrimUnrequiredExtensions(const ExtensionSet & required_extensions) const625 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
626     const ExtensionSet& required_extensions) const {
627   const auto supported_extensions =
628       getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
629 
630   bool modified_module = false;
631   for (auto extension : supported_extensions) {
632     if (required_extensions.contains(extension)) {
633       continue;
634     }
635 
636     if (context()->RemoveExtension(extension)) {
637       modified_module = true;
638     }
639   }
640 
641   return modified_module ? Pass::Status::SuccessWithChange
642                          : Pass::Status::SuccessWithoutChange;
643 }
644 
HasForbiddenCapabilities() const645 bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
646   // EnumSet.HasAnyOf returns `true` if the given set is empty.
647   if (forbiddenCapabilities_.size() == 0) {
648     return false;
649   }
650 
651   const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
652   return capabilities.HasAnyOf(forbiddenCapabilities_);
653 }
654 
Process()655 Pass::Status TrimCapabilitiesPass::Process() {
656   if (HasForbiddenCapabilities()) {
657     return Status::SuccessWithoutChange;
658   }
659 
660   auto[required_capabilities, required_extensions] =
661       DetermineRequiredCapabilitiesAndExtensions();
662 
663   Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
664   Pass::Status extStatus = TrimUnrequiredExtensions(required_extensions);
665 
666   return capStatus == Pass::Status::SuccessWithChange ||
667                  extStatus == Pass::Status::SuccessWithChange
668              ? Pass::Status::SuccessWithChange
669              : Pass::Status::SuccessWithoutChange;
670 }
671 
672 }  // namespace opt
673 }  // namespace spvtools
674