1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/trim_capabilities_pass.h"
16
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <stack>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <vector>
27
28 #include "source/enum_set.h"
29 #include "source/enum_string_mapping.h"
30 #include "source/opt/ir_context.h"
31 #include "source/opt/reflect.h"
32 #include "source/spirv_target_env.h"
33 #include "source/util/string_utils.h"
34
35 namespace spvtools {
36 namespace opt {
37
38 namespace {
// In-operand positions used by this pass, grouped by the instruction they
// apply to.

// Scalar types (OpTypeInt / OpTypeFloat): in-operand holding the bit width.
constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
constexpr uint32_t kOpTypeFloatSizeIndex = 0;
constexpr uint32_t kOpTypeIntSizeIndex = 0;

// OpTypePointer: storage class first, then the id of the pointed-to type.
constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
constexpr uint32_t kTypePointerTypeIdInIndex = 1;

// Vector, matrix, array and runtime-array types: id of the element type.
constexpr uint32_t kTypeArrayTypeIndex = 0;

// OpTypeImage. In-operand 2 is not needed by this pass, hence the gap between
// the dimension and the arrayed flag.
constexpr uint32_t kOpTypeImageDimIndex = 1;
constexpr uint32_t kOpTypeImageArrayedIndex = 3;
constexpr uint32_t kOpTypeImageMSIndex = 4;
constexpr uint32_t kOpTypeImageSampledIndex = 5;
constexpr uint32_t kOpTypeImageFormatIndex = 6;

// OpImageRead / OpImageSparseRead: in-operand holding the image id.
constexpr uint32_t kOpImageReadImageIndex = 0;
constexpr uint32_t kOpImageSparseReadImageIndex = 0;
52
53 // DFS visit of the type defined by `instruction`.
54 // If `condition` is true, children of the current node are visited.
55 // If `condition` is false, the children of the current node are ignored.
56 template <class UnaryPredicate>
DFSWhile(const Instruction * instruction,UnaryPredicate condition)57 static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
58 std::stack<uint32_t> instructions_to_visit;
59 instructions_to_visit.push(instruction->result_id());
60 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
61
62 while (!instructions_to_visit.empty()) {
63 const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
64 instructions_to_visit.pop();
65
66 if (!condition(item)) {
67 continue;
68 }
69
70 if (item->opcode() == spv::Op::OpTypePointer) {
71 instructions_to_visit.push(
72 item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
73 continue;
74 }
75
76 if (item->opcode() == spv::Op::OpTypeMatrix ||
77 item->opcode() == spv::Op::OpTypeVector ||
78 item->opcode() == spv::Op::OpTypeArray ||
79 item->opcode() == spv::Op::OpTypeRuntimeArray) {
80 instructions_to_visit.push(
81 item->GetSingleWordInOperand(kTypeArrayTypeIndex));
82 continue;
83 }
84
85 if (item->opcode() == spv::Op::OpTypeStruct) {
86 item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
87 instructions_to_visit.push(*op_id);
88 });
89 continue;
90 }
91 }
92 }
93
94 // Walks the type defined by `instruction` (OpType* only).
95 // Returns `true` if any call to `predicate` with the type/subtype returns true.
96 template <class UnaryPredicate>
AnyTypeOf(const Instruction * instruction,UnaryPredicate predicate)97 static bool AnyTypeOf(const Instruction* instruction,
98 UnaryPredicate predicate) {
99 assert(IsTypeInst(instruction->opcode()) &&
100 "AnyTypeOf called with a non-type instruction.");
101
102 bool found_one = false;
103 DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
104 if (found_one || predicate(node)) {
105 found_one = true;
106 return false;
107 }
108
109 return true;
110 });
111 return found_one;
112 }
113
is16bitType(const Instruction * instruction)114 static bool is16bitType(const Instruction* instruction) {
115 if (instruction->opcode() != spv::Op::OpTypeInt &&
116 instruction->opcode() != spv::Op::OpTypeFloat) {
117 return false;
118 }
119
120 return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
121 }
122
Has16BitCapability(const FeatureManager * feature_manager)123 static bool Has16BitCapability(const FeatureManager* feature_manager) {
124 const CapabilitySet& capabilities = feature_manager->GetCapabilities();
125 return capabilities.contains(spv::Capability::Float16) ||
126 capabilities.contains(spv::Capability::Int16);
127 }
128
129 } // namespace
130
// ============== Begin opcode handler implementations. =======================
//
// Adding support for a new capability should only require adding a new handler,
// and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
//
// Handler names follow this convention:
// Handler_<Opcode>_<Capability>()
139
Handler_OpTypeFloat_Float16(const Instruction * instruction)140 static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
141 const Instruction* instruction) {
142 assert(instruction->opcode() == spv::Op::OpTypeFloat &&
143 "This handler only support OpTypeFloat opcodes.");
144
145 const uint32_t size =
146 instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
147 return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
148 }
149
Handler_OpTypeFloat_Float64(const Instruction * instruction)150 static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
151 const Instruction* instruction) {
152 assert(instruction->opcode() == spv::Op::OpTypeFloat &&
153 "This handler only support OpTypeFloat opcodes.");
154
155 const uint32_t size =
156 instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
157 return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
158 }
159
160 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageInputOutput16(const Instruction * instruction)161 Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
162 assert(instruction->opcode() == spv::Op::OpTypePointer &&
163 "This handler only support OpTypePointer opcodes.");
164
165 // This capability is only required if the variable has an Input/Output
166 // storage class.
167 spv::StorageClass storage_class = spv::StorageClass(
168 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
169 if (storage_class != spv::StorageClass::Input &&
170 storage_class != spv::StorageClass::Output) {
171 return std::nullopt;
172 }
173
174 if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
175 return std::nullopt;
176 }
177
178 return AnyTypeOf(instruction, is16bitType)
179 ? std::optional(spv::Capability::StorageInputOutput16)
180 : std::nullopt;
181 }
182
183 static std::optional<spv::Capability>
Handler_OpTypePointer_StoragePushConstant16(const Instruction * instruction)184 Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
185 assert(instruction->opcode() == spv::Op::OpTypePointer &&
186 "This handler only support OpTypePointer opcodes.");
187
188 // This capability is only required if the variable has a PushConstant storage
189 // class.
190 spv::StorageClass storage_class = spv::StorageClass(
191 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
192 if (storage_class != spv::StorageClass::PushConstant) {
193 return std::nullopt;
194 }
195
196 if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
197 return std::nullopt;
198 }
199
200 return AnyTypeOf(instruction, is16bitType)
201 ? std::optional(spv::Capability::StoragePushConstant16)
202 : std::nullopt;
203 }
204
205 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageUniformBufferBlock16(const Instruction * instruction)206 Handler_OpTypePointer_StorageUniformBufferBlock16(
207 const Instruction* instruction) {
208 assert(instruction->opcode() == spv::Op::OpTypePointer &&
209 "This handler only support OpTypePointer opcodes.");
210
211 // This capability is only required if the variable has a Uniform storage
212 // class.
213 spv::StorageClass storage_class = spv::StorageClass(
214 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
215 if (storage_class != spv::StorageClass::Uniform) {
216 return std::nullopt;
217 }
218
219 if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
220 return std::nullopt;
221 }
222
223 const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
224 const bool matchesCondition =
225 AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
226 if (!decoration_mgr->HasDecoration(item->result_id(),
227 spv::Decoration::BufferBlock)) {
228 return false;
229 }
230
231 return AnyTypeOf(item, is16bitType);
232 });
233
234 return matchesCondition
235 ? std::optional(spv::Capability::StorageUniformBufferBlock16)
236 : std::nullopt;
237 }
238
Handler_OpTypePointer_StorageUniform16(const Instruction * instruction)239 static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
240 const Instruction* instruction) {
241 assert(instruction->opcode() == spv::Op::OpTypePointer &&
242 "This handler only support OpTypePointer opcodes.");
243
244 // This capability is only required if the variable has a Uniform storage
245 // class.
246 spv::StorageClass storage_class = spv::StorageClass(
247 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
248 if (storage_class != spv::StorageClass::Uniform) {
249 return std::nullopt;
250 }
251
252 const auto* feature_manager = instruction->context()->get_feature_mgr();
253 if (!Has16BitCapability(feature_manager)) {
254 return std::nullopt;
255 }
256
257 const bool hasBufferBlockCapability =
258 feature_manager->GetCapabilities().contains(
259 spv::Capability::StorageUniformBufferBlock16);
260 const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
261 bool found16bitType = false;
262
263 DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
264 &found16bitType](const Instruction* item) {
265 if (found16bitType) {
266 return false;
267 }
268
269 if (hasBufferBlockCapability &&
270 decoration_mgr->HasDecoration(item->result_id(),
271 spv::Decoration::BufferBlock)) {
272 return false;
273 }
274
275 if (is16bitType(item)) {
276 found16bitType = true;
277 return false;
278 }
279
280 return true;
281 });
282
283 return found16bitType ? std::optional(spv::Capability::StorageUniform16)
284 : std::nullopt;
285 }
286
Handler_OpTypeInt_Int16(const Instruction * instruction)287 static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
288 const Instruction* instruction) {
289 assert(instruction->opcode() == spv::Op::OpTypeInt &&
290 "This handler only support OpTypeInt opcodes.");
291
292 const uint32_t size =
293 instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
294 return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
295 }
296
Handler_OpTypeInt_Int64(const Instruction * instruction)297 static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
298 const Instruction* instruction) {
299 assert(instruction->opcode() == spv::Op::OpTypeInt &&
300 "This handler only support OpTypeInt opcodes.");
301
302 const uint32_t size =
303 instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
304 return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
305 }
306
Handler_OpTypeImage_ImageMSArray(const Instruction * instruction)307 static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
308 const Instruction* instruction) {
309 assert(instruction->opcode() == spv::Op::OpTypeImage &&
310 "This handler only support OpTypeImage opcodes.");
311
312 const uint32_t arrayed =
313 instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
314 const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
315 const uint32_t sampled =
316 instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
317
318 return arrayed == 1 && sampled == 2 && ms == 1
319 ? std::optional(spv::Capability::ImageMSArray)
320 : std::nullopt;
321 }
322
323 static std::optional<spv::Capability>
Handler_OpImageRead_StorageImageReadWithoutFormat(const Instruction * instruction)324 Handler_OpImageRead_StorageImageReadWithoutFormat(
325 const Instruction* instruction) {
326 assert(instruction->opcode() == spv::Op::OpImageRead &&
327 "This handler only support OpImageRead opcodes.");
328 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
329
330 const uint32_t image_index =
331 instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
332 const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
333 const Instruction* type = def_use_mgr->GetDef(type_index);
334 const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
335 const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
336
337 const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
338 const bool requires_capability_for_unknown =
339 spv::Dim(dim) != spv::Dim::SubpassData;
340 return is_unknown && requires_capability_for_unknown
341 ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
342 : std::nullopt;
343 }
344
345 static std::optional<spv::Capability>
Handler_OpImageSparseRead_StorageImageReadWithoutFormat(const Instruction * instruction)346 Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
347 const Instruction* instruction) {
348 assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
349 "This handler only support OpImageSparseRead opcodes.");
350 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
351
352 const uint32_t image_index =
353 instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
354 const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
355 const Instruction* type = def_use_mgr->GetDef(type_index);
356 const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
357
358 return spv::ImageFormat(format) == spv::ImageFormat::Unknown
359 ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
360 : std::nullopt;
361 }
362
363 // Opcode of interest to determine capabilities requirements.
364 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
365 // clang-format off
366 {spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
367 {spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
368 {spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
369 {spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
370 {spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
371 {spv::Op::OpTypeInt, Handler_OpTypeInt_Int16 },
372 {spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
373 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
374 {spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
375 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
376 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
377 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniformBufferBlock16},
378 // clang-format on
379 }};
380
381 // ============== End opcode handler implementations. =======================
382
383 namespace {
getExtensionsRelatedTo(const CapabilitySet & capabilities,const AssemblyGrammar & grammar)384 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
385 const AssemblyGrammar& grammar) {
386 ExtensionSet output;
387 const spv_operand_desc_t* desc = nullptr;
388 for (auto capability : capabilities) {
389 if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
390 static_cast<uint32_t>(capability),
391 &desc)) {
392 continue;
393 }
394
395 for (uint32_t i = 0; i < desc->numExtensions; ++i) {
396 output.insert(desc->extensions[i]);
397 }
398 }
399
400 return output;
401 }
402 } // namespace
403
TrimCapabilitiesPass()404 TrimCapabilitiesPass::TrimCapabilitiesPass()
405 : supportedCapabilities_(
406 TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
407 TrimCapabilitiesPass::kSupportedCapabilities.cend()),
408 forbiddenCapabilities_(
409 TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
410 TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
411 untouchableCapabilities_(
412 TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
413 TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
414 opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
415
addInstructionRequirementsForOpcode(spv::Op opcode,CapabilitySet * capabilities,ExtensionSet * extensions) const416 void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
417 spv::Op opcode, CapabilitySet* capabilities,
418 ExtensionSet* extensions) const {
419 // Ignoring OpBeginInvocationInterlockEXT and OpEndInvocationInterlockEXT
420 // because they have three possible capabilities, only one of which is needed
421 if (opcode == spv::Op::OpBeginInvocationInterlockEXT ||
422 opcode == spv::Op::OpEndInvocationInterlockEXT) {
423 return;
424 }
425
426 const spv_opcode_desc_t* desc = {};
427 auto result = context()->grammar().lookupOpcode(opcode, &desc);
428 if (result != SPV_SUCCESS) {
429 return;
430 }
431
432 addSupportedCapabilitiesToSet(desc, capabilities);
433 addSupportedExtensionsToSet(desc, extensions);
434 }
435
addInstructionRequirementsForOperand(const Operand & operand,CapabilitySet * capabilities,ExtensionSet * extensions) const436 void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
437 const Operand& operand, CapabilitySet* capabilities,
438 ExtensionSet* extensions) const {
439 // No supported capability relies on a 2+-word operand.
440 if (operand.words.size() != 1) {
441 return;
442 }
443
444 // No supported capability relies on a literal string operand or an ID.
445 if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
446 operand.type == SPV_OPERAND_TYPE_ID ||
447 operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
448 return;
449 }
450
451 // If the Vulkan memory model is declared and any instruction uses Device
452 // scope, the VulkanMemoryModelDeviceScope capability must be declared. This
453 // rule cannot be covered by the grammar, so must be checked explicitly.
454 if (operand.type == SPV_OPERAND_TYPE_SCOPE_ID) {
455 const Instruction* memory_model = context()->GetMemoryModel();
456 if (memory_model && memory_model->GetSingleWordInOperand(1u) ==
457 uint32_t(spv::MemoryModel::Vulkan)) {
458 capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
459 }
460 }
461
462 // case 1: Operand is a single value, can directly lookup.
463 if (!spvOperandIsConcreteMask(operand.type)) {
464 const spv_operand_desc_t* desc = {};
465 auto result = context()->grammar().lookupOperand(operand.type,
466 operand.words[0], &desc);
467 if (result != SPV_SUCCESS) {
468 return;
469 }
470 addSupportedCapabilitiesToSet(desc, capabilities);
471 addSupportedExtensionsToSet(desc, extensions);
472 return;
473 }
474
475 // case 2: operand can be a bitmask, we need to decompose the lookup.
476 for (uint32_t i = 0; i < 32; i++) {
477 const uint32_t mask = (1 << i) & operand.words[0];
478 if (!mask) {
479 continue;
480 }
481
482 const spv_operand_desc_t* desc = {};
483 auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
484 if (result != SPV_SUCCESS) {
485 continue;
486 }
487
488 addSupportedCapabilitiesToSet(desc, capabilities);
489 addSupportedExtensionsToSet(desc, extensions);
490 }
491 }
492
addInstructionRequirements(Instruction * instruction,CapabilitySet * capabilities,ExtensionSet * extensions) const493 void TrimCapabilitiesPass::addInstructionRequirements(
494 Instruction* instruction, CapabilitySet* capabilities,
495 ExtensionSet* extensions) const {
496 // Ignoring OpCapability and OpExtension instructions.
497 if (instruction->opcode() == spv::Op::OpCapability ||
498 instruction->opcode() == spv::Op::OpExtension) {
499 return;
500 }
501
502 addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
503 extensions);
504
505 // Second case: one of the opcode operand is gated by a capability.
506 const uint32_t operandCount = instruction->NumOperands();
507 for (uint32_t i = 0; i < operandCount; i++) {
508 addInstructionRequirementsForOperand(instruction->GetOperand(i),
509 capabilities, extensions);
510 }
511
512 // Last case: some complex logic needs to be run to determine capabilities.
513 auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
514 for (auto it = begin; it != end; it++) {
515 const OpcodeHandler handler = it->second;
516 auto result = handler(instruction);
517 if (!result.has_value()) {
518 continue;
519 }
520
521 capabilities->insert(*result);
522 }
523 }
524
AddExtensionsForOperand(const spv_operand_type_t type,const uint32_t value,ExtensionSet * extensions) const525 void TrimCapabilitiesPass::AddExtensionsForOperand(
526 const spv_operand_type_t type, const uint32_t value,
527 ExtensionSet* extensions) const {
528 const spv_operand_desc_t* desc = nullptr;
529 spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
530 if (result != SPV_SUCCESS) {
531 return;
532 }
533 addSupportedExtensionsToSet(desc, extensions);
534 }
535
536 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const537 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
538 CapabilitySet required_capabilities;
539 ExtensionSet required_extensions;
540
541 get_module()->ForEachInst([&](Instruction* instruction) {
542 addInstructionRequirements(instruction, &required_capabilities,
543 &required_extensions);
544 });
545
546 for (auto capability : required_capabilities) {
547 AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
548 static_cast<uint32_t>(capability),
549 &required_extensions);
550 }
551
552 #if !defined(NDEBUG)
553 // Debug only. We check the outputted required capabilities against the
554 // supported capabilities list. The supported capabilities list is useful for
555 // API users to quickly determine if they can use the pass or not. But this
556 // list has to remain up-to-date with the pass code. If we can detect a
557 // capability as required, but it's not listed, it means the list is
558 // out-of-sync. This method is not ideal, but should cover most cases.
559 {
560 for (auto capability : required_capabilities) {
561 assert(supportedCapabilities_.contains(capability) &&
562 "Module is using a capability that is not listed as supported.");
563 }
564 }
565 #endif
566
567 return std::make_pair(std::move(required_capabilities),
568 std::move(required_extensions));
569 }
570
TrimUnrequiredCapabilities(const CapabilitySet & required_capabilities) const571 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
572 const CapabilitySet& required_capabilities) const {
573 const FeatureManager* feature_manager = context()->get_feature_mgr();
574 CapabilitySet capabilities_to_trim;
575 for (auto capability : feature_manager->GetCapabilities()) {
576 // Some capabilities cannot be safely removed. Leaving them untouched.
577 if (untouchableCapabilities_.contains(capability)) {
578 continue;
579 }
580
581 // If the capability is unsupported, don't trim it.
582 if (!supportedCapabilities_.contains(capability)) {
583 continue;
584 }
585
586 if (required_capabilities.contains(capability)) {
587 continue;
588 }
589
590 capabilities_to_trim.insert(capability);
591 }
592
593 for (auto capability : capabilities_to_trim) {
594 context()->RemoveCapability(capability);
595 }
596
597 return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
598 : Pass::Status::SuccessWithChange;
599 }
600
TrimUnrequiredExtensions(const ExtensionSet & required_extensions) const601 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
602 const ExtensionSet& required_extensions) const {
603 const auto supported_extensions =
604 getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
605
606 bool modified_module = false;
607 for (auto extension : supported_extensions) {
608 if (required_extensions.contains(extension)) {
609 continue;
610 }
611
612 if (context()->RemoveExtension(extension)) {
613 modified_module = true;
614 }
615 }
616
617 return modified_module ? Pass::Status::SuccessWithChange
618 : Pass::Status::SuccessWithoutChange;
619 }
620
HasForbiddenCapabilities() const621 bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
622 // EnumSet.HasAnyOf returns `true` if the given set is empty.
623 if (forbiddenCapabilities_.size() == 0) {
624 return false;
625 }
626
627 const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
628 return capabilities.HasAnyOf(forbiddenCapabilities_);
629 }
630
Process()631 Pass::Status TrimCapabilitiesPass::Process() {
632 if (HasForbiddenCapabilities()) {
633 return Status::SuccessWithoutChange;
634 }
635
636 auto[required_capabilities, required_extensions] =
637 DetermineRequiredCapabilitiesAndExtensions();
638
639 Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
640 Pass::Status extStatus = TrimUnrequiredExtensions(required_extensions);
641
642 return capStatus == Pass::Status::SuccessWithChange ||
643 extStatus == Pass::Status::SuccessWithChange
644 ? Pass::Status::SuccessWithChange
645 : Pass::Status::SuccessWithoutChange;
646 }
647
648 } // namespace opt
649 } // namespace spvtools
650