• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
16 #define SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
17 
18 #include <algorithm>
19 #include <array>
20 #include <functional>
21 #include <optional>
22 #include <unordered_map>
23 #include <unordered_set>
24 
25 #include "source/enum_set.h"
26 #include "source/extensions.h"
27 #include "source/opt/ir_context.h"
28 #include "source/opt/module.h"
29 #include "source/opt/pass.h"
30 #include "source/spirv_target_env.h"
31 
32 namespace spvtools {
33 namespace opt {
34 
35 // This is required for NDK build. The unordered_set/unordered_map
36 // implementation don't work with class enums.
37 struct ClassEnumHash {
operatorClassEnumHash38   std::size_t operator()(spv::Capability value) const {
39     using StoringType = typename std::underlying_type_t<spv::Capability>;
40     return std::hash<StoringType>{}(static_cast<StoringType>(value));
41   }
42 
operatorClassEnumHash43   std::size_t operator()(spv::Op value) const {
44     using StoringType = typename std::underlying_type_t<spv::Op>;
45     return std::hash<StoringType>{}(static_cast<StoringType>(value));
46   }
47 };
48 
49 // An opcode handler is a function which, given an instruction, returns either
50 // the required capability, or nothing.
51 // Each handler checks one case for a capability requirement.
52 //
53 // Example:
54 //  - `OpTypeImage` can have operand `A` operand which requires capability 1
55 //  - `OpTypeImage` can also have operand `B` which requires capability 2.
56 //    -> We have 2 handlers: `Handler_OpTypeImage_1` and
57 //    `Handler_OpTypeImage_2`.
58 using OpcodeHandler =
59     std::optional<spv::Capability> (*)(const Instruction* instruction);
60 
61 // This pass tried to remove superfluous capabilities declared in the module.
62 // - If all the capabilities listed by an extension are removed, the extension
63 //   is also trimmed.
64 // - If the module countains any capability listed in `kForbiddenCapabilities`,
65 //   the module is left untouched.
66 // - No capabilities listed in `kUntouchableCapabilities` are trimmed, even when
67 //   not used.
68 // - Only capabilitied listed in `kSupportedCapabilities` are supported.
69 // - If the module contains unsupported capabilities, results might be
70 //   incorrect.
71 class TrimCapabilitiesPass : public Pass {
72  private:
73   // All the capabilities supported by this optimization pass. If your module
74   // contains unsupported instruction, the pass could yield bad results.
75   static constexpr std::array kSupportedCapabilities{
76       // clang-format off
77       spv::Capability::ComputeDerivativeGroupLinearNV,
78       spv::Capability::ComputeDerivativeGroupQuadsNV,
79       spv::Capability::Float16,
80       spv::Capability::Float64,
81       spv::Capability::FragmentShaderPixelInterlockEXT,
82       spv::Capability::FragmentShaderSampleInterlockEXT,
83       spv::Capability::FragmentShaderShadingRateInterlockEXT,
84       spv::Capability::GroupNonUniform,
85       spv::Capability::GroupNonUniformArithmetic,
86       spv::Capability::GroupNonUniformClustered,
87       spv::Capability::GroupNonUniformPartitionedNV,
88       spv::Capability::GroupNonUniformVote,
89       spv::Capability::Groups,
90       spv::Capability::ImageMSArray,
91       spv::Capability::Int16,
92       spv::Capability::Int64,
93       spv::Capability::Linkage,
94       spv::Capability::MinLod,
95       spv::Capability::PhysicalStorageBufferAddresses,
96       spv::Capability::RayQueryKHR,
97       spv::Capability::RayTracingKHR,
98       spv::Capability::RayTraversalPrimitiveCullingKHR,
99       spv::Capability::Shader,
100       spv::Capability::ShaderClockKHR,
101       spv::Capability::StorageImageReadWithoutFormat,
102       spv::Capability::StorageInputOutput16,
103       spv::Capability::StoragePushConstant16,
104       spv::Capability::StorageUniform16,
105       spv::Capability::StorageUniformBufferBlock16,
106       spv::Capability::VulkanMemoryModelDeviceScope,
107       // clang-format on
108   };
109 
110   // Those capabilities disable all transformation of the module.
111   static constexpr std::array kForbiddenCapabilities{
112       spv::Capability::Linkage,
113   };
114 
115   // Those capabilities are never removed from a module because we cannot
116   // guess from the SPIR-V only if they are required or not.
117   static constexpr std::array kUntouchableCapabilities{
118       spv::Capability::Shader,
119   };
120 
121  public:
122   TrimCapabilitiesPass();
123   TrimCapabilitiesPass(const TrimCapabilitiesPass&) = delete;
124   TrimCapabilitiesPass(TrimCapabilitiesPass&&) = delete;
125 
126  private:
127   // Inserts every capability listed by `descriptor` this pass supports into
128   // `output`. Expects a Descriptor like `spv_opcode_desc_t` or
129   // `spv_operand_desc_t`.
130   template <class Descriptor>
addSupportedCapabilitiesToSet(const Descriptor * const descriptor,CapabilitySet * output)131   inline void addSupportedCapabilitiesToSet(const Descriptor* const descriptor,
132                                             CapabilitySet* output) const {
133     const uint32_t capabilityCount = descriptor->numCapabilities;
134     for (uint32_t i = 0; i < capabilityCount; ++i) {
135       const auto capability = descriptor->capabilities[i];
136       if (supportedCapabilities_.contains(capability)) {
137         output->insert(capability);
138       }
139     }
140   }
141 
142   // Inserts every extension listed by `descriptor` required by the module into
143   // `output`. Expects a Descriptor like `spv_opcode_desc_t` or
144   // `spv_operand_desc_t`.
145   template <class Descriptor>
addSupportedExtensionsToSet(const Descriptor * const descriptor,ExtensionSet * output)146   inline void addSupportedExtensionsToSet(const Descriptor* const descriptor,
147                                           ExtensionSet* output) const {
148     if (descriptor->minVersion <=
149         spvVersionForTargetEnv(context()->GetTargetEnv())) {
150       return;
151     }
152     output->insert(descriptor->extensions,
153                    descriptor->extensions + descriptor->numExtensions);
154   }
155 
156   void addInstructionRequirementsForOpcode(spv::Op opcode,
157                                            CapabilitySet* capabilities,
158                                            ExtensionSet* extensions) const;
159   void addInstructionRequirementsForOperand(const Operand& operand,
160                                             CapabilitySet* capabilities,
161                                             ExtensionSet* extensions) const;
162 
163   // Given an `instruction`, determines the capabilities it requires, and output
164   // them in `capabilities`. The returned capabilities form a subset of
165   // kSupportedCapabilities.
166   void addInstructionRequirements(Instruction* instruction,
167                                   CapabilitySet* capabilities,
168                                   ExtensionSet* extensions) const;
169 
170   // Given an operand `type` and `value`, adds the extensions it would require
171   // to `extensions`.
172   void AddExtensionsForOperand(const spv_operand_type_t type,
173                                const uint32_t value,
174                                ExtensionSet* extensions) const;
175 
176   // Returns the list of required capabilities and extensions for the module.
177   // The returned capabilities form a subset of kSupportedCapabilities.
178   std::pair<CapabilitySet, ExtensionSet>
179   DetermineRequiredCapabilitiesAndExtensions() const;
180 
181   // Trims capabilities not listed in `required_capabilities` if possible.
182   // Returns whether or not the module was modified.
183   Pass::Status TrimUnrequiredCapabilities(
184       const CapabilitySet& required_capabilities) const;
185 
186   // Trims extensions not listed in `required_extensions` if supported by this
187   // pass. An extensions is considered supported as soon as one capability this
188   // pass support requires it.
189   Pass::Status TrimUnrequiredExtensions(
190       const ExtensionSet& required_extensions) const;
191 
192   // Returns if the analyzed module contains any forbidden capability.
193   bool HasForbiddenCapabilities() const;
194 
195  public:
name()196   const char* name() const override { return "trim-capabilities"; }
197   Status Process() override;
198 
199  private:
200   const CapabilitySet supportedCapabilities_;
201   const CapabilitySet forbiddenCapabilities_;
202   const CapabilitySet untouchableCapabilities_;
203   const std::unordered_multimap<spv::Op, OpcodeHandler, ClassEnumHash>
204       opcodeHandlers_;
205 };
206 
207 }  // namespace opt
208 }  // namespace spvtools
#endif  // SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
210