• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/trim_capabilities_pass.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <stack>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <vector>
27 
28 #include "source/enum_set.h"
29 #include "source/enum_string_mapping.h"
30 #include "source/opt/ir_context.h"
31 #include "source/opt/reflect.h"
32 #include "source/spirv_target_env.h"
33 #include "source/util/string_utils.h"
34 
35 namespace spvtools {
36 namespace opt {
37 
38 namespace {
39 constexpr uint32_t kOpTypeFloatSizeIndex = 0;
40 constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
41 constexpr uint32_t kTypeArrayTypeIndex = 0;
42 constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
43 constexpr uint32_t kTypePointerTypeIdInIndex = 1;
44 constexpr uint32_t kOpTypeIntSizeIndex = 0;
45 constexpr uint32_t kOpTypeImageDimIndex = 1;
46 constexpr uint32_t kOpTypeImageArrayedIndex = kOpTypeImageDimIndex + 2;
47 constexpr uint32_t kOpTypeImageMSIndex = kOpTypeImageArrayedIndex + 1;
48 constexpr uint32_t kOpTypeImageSampledIndex = kOpTypeImageMSIndex + 1;
49 constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1;
50 constexpr uint32_t kOpImageReadImageIndex = 0;
51 constexpr uint32_t kOpImageSparseReadImageIndex = 0;
52 
53 // DFS visit of the type defined by `instruction`.
54 // If `condition` is true, children of the current node are visited.
55 // If `condition` is false, the children of the current node are ignored.
56 template <class UnaryPredicate>
DFSWhile(const Instruction * instruction,UnaryPredicate condition)57 static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
58   std::stack<uint32_t> instructions_to_visit;
59   instructions_to_visit.push(instruction->result_id());
60   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
61 
62   while (!instructions_to_visit.empty()) {
63     const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
64     instructions_to_visit.pop();
65 
66     if (!condition(item)) {
67       continue;
68     }
69 
70     if (item->opcode() == spv::Op::OpTypePointer) {
71       instructions_to_visit.push(
72           item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
73       continue;
74     }
75 
76     if (item->opcode() == spv::Op::OpTypeMatrix ||
77         item->opcode() == spv::Op::OpTypeVector ||
78         item->opcode() == spv::Op::OpTypeArray ||
79         item->opcode() == spv::Op::OpTypeRuntimeArray) {
80       instructions_to_visit.push(
81           item->GetSingleWordInOperand(kTypeArrayTypeIndex));
82       continue;
83     }
84 
85     if (item->opcode() == spv::Op::OpTypeStruct) {
86       item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
87         instructions_to_visit.push(*op_id);
88       });
89       continue;
90     }
91   }
92 }
93 
94 // Walks the type defined by `instruction` (OpType* only).
95 // Returns `true` if any call to `predicate` with the type/subtype returns true.
96 template <class UnaryPredicate>
AnyTypeOf(const Instruction * instruction,UnaryPredicate predicate)97 static bool AnyTypeOf(const Instruction* instruction,
98                       UnaryPredicate predicate) {
99   assert(IsTypeInst(instruction->opcode()) &&
100          "AnyTypeOf called with a non-type instruction.");
101 
102   bool found_one = false;
103   DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
104     if (found_one || predicate(node)) {
105       found_one = true;
106       return false;
107     }
108 
109     return true;
110   });
111   return found_one;
112 }
113 
is16bitType(const Instruction * instruction)114 static bool is16bitType(const Instruction* instruction) {
115   if (instruction->opcode() != spv::Op::OpTypeInt &&
116       instruction->opcode() != spv::Op::OpTypeFloat) {
117     return false;
118   }
119 
120   return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
121 }
122 
Has16BitCapability(const FeatureManager * feature_manager)123 static bool Has16BitCapability(const FeatureManager* feature_manager) {
124   const CapabilitySet& capabilities = feature_manager->GetCapabilities();
125   return capabilities.contains(spv::Capability::Float16) ||
126          capabilities.contains(spv::Capability::Int16);
127 }
128 
129 }  // namespace
130 
131 // ============== Begin opcode handler implementations. =======================
132 //
// Adding support for a new capability should only require adding a new handler,
// and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
136 //
137 // Handler names follow the following convention:
138 //  Handler_<Opcode>_<Capability>()
139 
Handler_OpTypeFloat_Float16(const Instruction * instruction)140 static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
141     const Instruction* instruction) {
142   assert(instruction->opcode() == spv::Op::OpTypeFloat &&
143          "This handler only support OpTypeFloat opcodes.");
144 
145   const uint32_t size =
146       instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
147   return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
148 }
149 
Handler_OpTypeFloat_Float64(const Instruction * instruction)150 static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
151     const Instruction* instruction) {
152   assert(instruction->opcode() == spv::Op::OpTypeFloat &&
153          "This handler only support OpTypeFloat opcodes.");
154 
155   const uint32_t size =
156       instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
157   return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
158 }
159 
160 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageInputOutput16(const Instruction * instruction)161 Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
162   assert(instruction->opcode() == spv::Op::OpTypePointer &&
163          "This handler only support OpTypePointer opcodes.");
164 
165   // This capability is only required if the variable has an Input/Output
166   // storage class.
167   spv::StorageClass storage_class = spv::StorageClass(
168       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
169   if (storage_class != spv::StorageClass::Input &&
170       storage_class != spv::StorageClass::Output) {
171     return std::nullopt;
172   }
173 
174   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
175     return std::nullopt;
176   }
177 
178   return AnyTypeOf(instruction, is16bitType)
179              ? std::optional(spv::Capability::StorageInputOutput16)
180              : std::nullopt;
181 }
182 
183 static std::optional<spv::Capability>
Handler_OpTypePointer_StoragePushConstant16(const Instruction * instruction)184 Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
185   assert(instruction->opcode() == spv::Op::OpTypePointer &&
186          "This handler only support OpTypePointer opcodes.");
187 
188   // This capability is only required if the variable has a PushConstant storage
189   // class.
190   spv::StorageClass storage_class = spv::StorageClass(
191       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
192   if (storage_class != spv::StorageClass::PushConstant) {
193     return std::nullopt;
194   }
195 
196   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
197     return std::nullopt;
198   }
199 
200   return AnyTypeOf(instruction, is16bitType)
201              ? std::optional(spv::Capability::StoragePushConstant16)
202              : std::nullopt;
203 }
204 
205 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageUniformBufferBlock16(const Instruction * instruction)206 Handler_OpTypePointer_StorageUniformBufferBlock16(
207     const Instruction* instruction) {
208   assert(instruction->opcode() == spv::Op::OpTypePointer &&
209          "This handler only support OpTypePointer opcodes.");
210 
211   // This capability is only required if the variable has a Uniform storage
212   // class.
213   spv::StorageClass storage_class = spv::StorageClass(
214       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
215   if (storage_class != spv::StorageClass::Uniform) {
216     return std::nullopt;
217   }
218 
219   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
220     return std::nullopt;
221   }
222 
223   const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
224   const bool matchesCondition =
225       AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
226         if (!decoration_mgr->HasDecoration(item->result_id(),
227                                            spv::Decoration::BufferBlock)) {
228           return false;
229         }
230 
231         return AnyTypeOf(item, is16bitType);
232       });
233 
234   return matchesCondition
235              ? std::optional(spv::Capability::StorageUniformBufferBlock16)
236              : std::nullopt;
237 }
238 
Handler_OpTypePointer_StorageUniform16(const Instruction * instruction)239 static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
240     const Instruction* instruction) {
241   assert(instruction->opcode() == spv::Op::OpTypePointer &&
242          "This handler only support OpTypePointer opcodes.");
243 
244   // This capability is only required if the variable has a Uniform storage
245   // class.
246   spv::StorageClass storage_class = spv::StorageClass(
247       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
248   if (storage_class != spv::StorageClass::Uniform) {
249     return std::nullopt;
250   }
251 
252   const auto* feature_manager = instruction->context()->get_feature_mgr();
253   if (!Has16BitCapability(feature_manager)) {
254     return std::nullopt;
255   }
256 
257   const bool hasBufferBlockCapability =
258       feature_manager->GetCapabilities().contains(
259           spv::Capability::StorageUniformBufferBlock16);
260   const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
261   bool found16bitType = false;
262 
263   DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
264                          &found16bitType](const Instruction* item) {
265     if (found16bitType) {
266       return false;
267     }
268 
269     if (hasBufferBlockCapability &&
270         decoration_mgr->HasDecoration(item->result_id(),
271                                       spv::Decoration::BufferBlock)) {
272       return false;
273     }
274 
275     if (is16bitType(item)) {
276       found16bitType = true;
277       return false;
278     }
279 
280     return true;
281   });
282 
283   return found16bitType ? std::optional(spv::Capability::StorageUniform16)
284                         : std::nullopt;
285 }
286 
Handler_OpTypeInt_Int16(const Instruction * instruction)287 static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
288     const Instruction* instruction) {
289   assert(instruction->opcode() == spv::Op::OpTypeInt &&
290          "This handler only support OpTypeInt opcodes.");
291 
292   const uint32_t size =
293       instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
294   return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
295 }
296 
Handler_OpTypeInt_Int64(const Instruction * instruction)297 static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
298     const Instruction* instruction) {
299   assert(instruction->opcode() == spv::Op::OpTypeInt &&
300          "This handler only support OpTypeInt opcodes.");
301 
302   const uint32_t size =
303       instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
304   return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
305 }
306 
Handler_OpTypeImage_ImageMSArray(const Instruction * instruction)307 static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
308     const Instruction* instruction) {
309   assert(instruction->opcode() == spv::Op::OpTypeImage &&
310          "This handler only support OpTypeImage opcodes.");
311 
312   const uint32_t arrayed =
313       instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
314   const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
315   const uint32_t sampled =
316       instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
317 
318   return arrayed == 1 && sampled == 2 && ms == 1
319              ? std::optional(spv::Capability::ImageMSArray)
320              : std::nullopt;
321 }
322 
323 static std::optional<spv::Capability>
Handler_OpImageRead_StorageImageReadWithoutFormat(const Instruction * instruction)324 Handler_OpImageRead_StorageImageReadWithoutFormat(
325     const Instruction* instruction) {
326   assert(instruction->opcode() == spv::Op::OpImageRead &&
327          "This handler only support OpImageRead opcodes.");
328   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
329 
330   const uint32_t image_index =
331       instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
332   const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
333   const Instruction* type = def_use_mgr->GetDef(type_index);
334   const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
335   const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
336 
337   const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
338   const bool requires_capability_for_unknown =
339       spv::Dim(dim) != spv::Dim::SubpassData;
340   return is_unknown && requires_capability_for_unknown
341              ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
342              : std::nullopt;
343 }
344 
345 static std::optional<spv::Capability>
Handler_OpImageSparseRead_StorageImageReadWithoutFormat(const Instruction * instruction)346 Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
347     const Instruction* instruction) {
348   assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
349          "This handler only support OpImageSparseRead opcodes.");
350   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
351 
352   const uint32_t image_index =
353       instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
354   const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
355   const Instruction* type = def_use_mgr->GetDef(type_index);
356   const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
357 
358   return spv::ImageFormat(format) == spv::ImageFormat::Unknown
359              ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
360              : std::nullopt;
361 }
362 
363 // Opcode of interest to determine capabilities requirements.
364 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
365     // clang-format off
366     {spv::Op::OpImageRead,         Handler_OpImageRead_StorageImageReadWithoutFormat},
367     {spv::Op::OpImageSparseRead,   Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
368     {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float16 },
369     {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float64 },
370     {spv::Op::OpTypeImage,         Handler_OpTypeImage_ImageMSArray},
371     {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int16 },
372     {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int64 },
373     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageInputOutput16},
374     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StoragePushConstant16},
375     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
376     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
377     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniformBufferBlock16},
378     // clang-format on
379 }};
380 
381 // ==============  End opcode handler implementations.  =======================
382 
383 namespace {
getExtensionsRelatedTo(const CapabilitySet & capabilities,const AssemblyGrammar & grammar)384 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
385                                     const AssemblyGrammar& grammar) {
386   ExtensionSet output;
387   const spv_operand_desc_t* desc = nullptr;
388   for (auto capability : capabilities) {
389     if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
390                                              static_cast<uint32_t>(capability),
391                                              &desc)) {
392       continue;
393     }
394 
395     for (uint32_t i = 0; i < desc->numExtensions; ++i) {
396       output.insert(desc->extensions[i]);
397     }
398   }
399 
400   return output;
401 }
402 }  // namespace
403 
TrimCapabilitiesPass()404 TrimCapabilitiesPass::TrimCapabilitiesPass()
405     : supportedCapabilities_(
406           TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
407           TrimCapabilitiesPass::kSupportedCapabilities.cend()),
408       forbiddenCapabilities_(
409           TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
410           TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
411       untouchableCapabilities_(
412           TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
413           TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
414       opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
415 
addInstructionRequirementsForOpcode(spv::Op opcode,CapabilitySet * capabilities,ExtensionSet * extensions) const416 void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
417     spv::Op opcode, CapabilitySet* capabilities,
418     ExtensionSet* extensions) const {
419   // Ignoring OpBeginInvocationInterlockEXT and OpEndInvocationInterlockEXT
420   // because they have three possible capabilities, only one of which is needed
421   if (opcode == spv::Op::OpBeginInvocationInterlockEXT ||
422       opcode == spv::Op::OpEndInvocationInterlockEXT) {
423     return;
424   }
425 
426   const spv_opcode_desc_t* desc = {};
427   auto result = context()->grammar().lookupOpcode(opcode, &desc);
428   if (result != SPV_SUCCESS) {
429     return;
430   }
431 
432   addSupportedCapabilitiesToSet(desc, capabilities);
433   addSupportedExtensionsToSet(desc, extensions);
434 }
435 
addInstructionRequirementsForOperand(const Operand & operand,CapabilitySet * capabilities,ExtensionSet * extensions) const436 void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
437     const Operand& operand, CapabilitySet* capabilities,
438     ExtensionSet* extensions) const {
439   // No supported capability relies on a 2+-word operand.
440   if (operand.words.size() != 1) {
441     return;
442   }
443 
444   // No supported capability relies on a literal string operand or an ID.
445   if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
446       operand.type == SPV_OPERAND_TYPE_ID ||
447       operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
448     return;
449   }
450 
451   // If the Vulkan memory model is declared and any instruction uses Device
452   // scope, the VulkanMemoryModelDeviceScope capability must be declared. This
453   // rule cannot be covered by the grammar, so must be checked explicitly.
454   if (operand.type == SPV_OPERAND_TYPE_SCOPE_ID) {
455     const Instruction* memory_model = context()->GetMemoryModel();
456     if (memory_model && memory_model->GetSingleWordInOperand(1u) ==
457                             uint32_t(spv::MemoryModel::Vulkan)) {
458       capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
459     }
460   }
461 
462   // case 1: Operand is a single value, can directly lookup.
463   if (!spvOperandIsConcreteMask(operand.type)) {
464     const spv_operand_desc_t* desc = {};
465     auto result = context()->grammar().lookupOperand(operand.type,
466                                                      operand.words[0], &desc);
467     if (result != SPV_SUCCESS) {
468       return;
469     }
470     addSupportedCapabilitiesToSet(desc, capabilities);
471     addSupportedExtensionsToSet(desc, extensions);
472     return;
473   }
474 
475   // case 2: operand can be a bitmask, we need to decompose the lookup.
476   for (uint32_t i = 0; i < 32; i++) {
477     const uint32_t mask = (1 << i) & operand.words[0];
478     if (!mask) {
479       continue;
480     }
481 
482     const spv_operand_desc_t* desc = {};
483     auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
484     if (result != SPV_SUCCESS) {
485       continue;
486     }
487 
488     addSupportedCapabilitiesToSet(desc, capabilities);
489     addSupportedExtensionsToSet(desc, extensions);
490   }
491 }
492 
addInstructionRequirements(Instruction * instruction,CapabilitySet * capabilities,ExtensionSet * extensions) const493 void TrimCapabilitiesPass::addInstructionRequirements(
494     Instruction* instruction, CapabilitySet* capabilities,
495     ExtensionSet* extensions) const {
496   // Ignoring OpCapability and OpExtension instructions.
497   if (instruction->opcode() == spv::Op::OpCapability ||
498       instruction->opcode() == spv::Op::OpExtension) {
499     return;
500   }
501 
502   addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
503                                       extensions);
504 
505   // Second case: one of the opcode operand is gated by a capability.
506   const uint32_t operandCount = instruction->NumOperands();
507   for (uint32_t i = 0; i < operandCount; i++) {
508     addInstructionRequirementsForOperand(instruction->GetOperand(i),
509                                          capabilities, extensions);
510   }
511 
512   // Last case: some complex logic needs to be run to determine capabilities.
513   auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
514   for (auto it = begin; it != end; it++) {
515     const OpcodeHandler handler = it->second;
516     auto result = handler(instruction);
517     if (!result.has_value()) {
518       continue;
519     }
520 
521     capabilities->insert(*result);
522   }
523 }
524 
AddExtensionsForOperand(const spv_operand_type_t type,const uint32_t value,ExtensionSet * extensions) const525 void TrimCapabilitiesPass::AddExtensionsForOperand(
526     const spv_operand_type_t type, const uint32_t value,
527     ExtensionSet* extensions) const {
528   const spv_operand_desc_t* desc = nullptr;
529   spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
530   if (result != SPV_SUCCESS) {
531     return;
532   }
533   addSupportedExtensionsToSet(desc, extensions);
534 }
535 
536 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const537 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
538   CapabilitySet required_capabilities;
539   ExtensionSet required_extensions;
540 
541   get_module()->ForEachInst([&](Instruction* instruction) {
542     addInstructionRequirements(instruction, &required_capabilities,
543                                &required_extensions);
544   });
545 
546   for (auto capability : required_capabilities) {
547     AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
548                             static_cast<uint32_t>(capability),
549                             &required_extensions);
550   }
551 
552 #if !defined(NDEBUG)
553   // Debug only. We check the outputted required capabilities against the
554   // supported capabilities list. The supported capabilities list is useful for
555   // API users to quickly determine if they can use the pass or not. But this
556   // list has to remain up-to-date with the pass code. If we can detect a
557   // capability as required, but it's not listed, it means the list is
558   // out-of-sync. This method is not ideal, but should cover most cases.
559   {
560     for (auto capability : required_capabilities) {
561       assert(supportedCapabilities_.contains(capability) &&
562              "Module is using a capability that is not listed as supported.");
563     }
564   }
565 #endif
566 
567   return std::make_pair(std::move(required_capabilities),
568                         std::move(required_extensions));
569 }
570 
TrimUnrequiredCapabilities(const CapabilitySet & required_capabilities) const571 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
572     const CapabilitySet& required_capabilities) const {
573   const FeatureManager* feature_manager = context()->get_feature_mgr();
574   CapabilitySet capabilities_to_trim;
575   for (auto capability : feature_manager->GetCapabilities()) {
576     // Some capabilities cannot be safely removed. Leaving them untouched.
577     if (untouchableCapabilities_.contains(capability)) {
578       continue;
579     }
580 
581     // If the capability is unsupported, don't trim it.
582     if (!supportedCapabilities_.contains(capability)) {
583       continue;
584     }
585 
586     if (required_capabilities.contains(capability)) {
587       continue;
588     }
589 
590     capabilities_to_trim.insert(capability);
591   }
592 
593   for (auto capability : capabilities_to_trim) {
594     context()->RemoveCapability(capability);
595   }
596 
597   return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
598                                           : Pass::Status::SuccessWithChange;
599 }
600 
TrimUnrequiredExtensions(const ExtensionSet & required_extensions) const601 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
602     const ExtensionSet& required_extensions) const {
603   const auto supported_extensions =
604       getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
605 
606   bool modified_module = false;
607   for (auto extension : supported_extensions) {
608     if (required_extensions.contains(extension)) {
609       continue;
610     }
611 
612     if (context()->RemoveExtension(extension)) {
613       modified_module = true;
614     }
615   }
616 
617   return modified_module ? Pass::Status::SuccessWithChange
618                          : Pass::Status::SuccessWithoutChange;
619 }
620 
HasForbiddenCapabilities() const621 bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
622   // EnumSet.HasAnyOf returns `true` if the given set is empty.
623   if (forbiddenCapabilities_.size() == 0) {
624     return false;
625   }
626 
627   const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
628   return capabilities.HasAnyOf(forbiddenCapabilities_);
629 }
630 
Process()631 Pass::Status TrimCapabilitiesPass::Process() {
632   if (HasForbiddenCapabilities()) {
633     return Status::SuccessWithoutChange;
634   }
635 
636   auto[required_capabilities, required_extensions] =
637       DetermineRequiredCapabilitiesAndExtensions();
638 
639   Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
640   Pass::Status extStatus = TrimUnrequiredExtensions(required_extensions);
641 
642   return capStatus == Pass::Status::SuccessWithChange ||
643                  extStatus == Pass::Status::SuccessWithChange
644              ? Pass::Status::SuccessWithChange
645              : Pass::Status::SuccessWithoutChange;
646 }
647 
648 }  // namespace opt
649 }  // namespace spvtools
650