1 // Copyright (c) 2023 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_ 16 #define SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_ 17 18 #include <algorithm> 19 #include <array> 20 #include <functional> 21 #include <optional> 22 #include <unordered_map> 23 #include <unordered_set> 24 25 #include "source/enum_set.h" 26 #include "source/extensions.h" 27 #include "source/opt/ir_context.h" 28 #include "source/opt/module.h" 29 #include "source/opt/pass.h" 30 #include "source/spirv_target_env.h" 31 32 namespace spvtools { 33 namespace opt { 34 35 // This is required for NDK build. The unordered_set/unordered_map 36 // implementation don't work with class enums. 37 struct ClassEnumHash { operatorClassEnumHash38 std::size_t operator()(spv::Capability value) const { 39 using StoringType = typename std::underlying_type_t<spv::Capability>; 40 return std::hash<StoringType>{}(static_cast<StoringType>(value)); 41 } 42 operatorClassEnumHash43 std::size_t operator()(spv::Op value) const { 44 using StoringType = typename std::underlying_type_t<spv::Op>; 45 return std::hash<StoringType>{}(static_cast<StoringType>(value)); 46 } 47 }; 48 49 // An opcode handler is a function which, given an instruction, returns either 50 // the required capability, or nothing. 51 // Each handler checks one case for a capability requirement. 52 // 53 // Example: 54 // - `OpTypeImage` can have operand `A` operand which requires capability 1 55 // - `OpTypeImage` can also have operand `B` which requires capability 2. 56 // -> We have 2 handlers: `Handler_OpTypeImage_1` and 57 // `Handler_OpTypeImage_2`. 58 using OpcodeHandler = 59 std::optional<spv::Capability> (*)(const Instruction* instruction); 60 61 // This pass tried to remove superfluous capabilities declared in the module. 62 // - If all the capabilities listed by an extension are removed, the extension 63 // is also trimmed. 64 // - If the module countains any capability listed in `kForbiddenCapabilities`, 65 // the module is left untouched. 66 // - No capabilities listed in `kUntouchableCapabilities` are trimmed, even when 67 // not used. 68 // - Only capabilitied listed in `kSupportedCapabilities` are supported. 69 // - If the module contains unsupported capabilities, results might be 70 // incorrect. 71 class TrimCapabilitiesPass : public Pass { 72 private: 73 // All the capabilities supported by this optimization pass. If your module 74 // contains unsupported instruction, the pass could yield bad results. 75 static constexpr std::array kSupportedCapabilities{ 76 // clang-format off 77 spv::Capability::ComputeDerivativeGroupLinearNV, 78 spv::Capability::ComputeDerivativeGroupQuadsNV, 79 spv::Capability::Float16, 80 spv::Capability::Float64, 81 spv::Capability::FragmentShaderPixelInterlockEXT, 82 spv::Capability::FragmentShaderSampleInterlockEXT, 83 spv::Capability::FragmentShaderShadingRateInterlockEXT, 84 spv::Capability::GroupNonUniform, 85 spv::Capability::GroupNonUniformArithmetic, 86 spv::Capability::GroupNonUniformClustered, 87 spv::Capability::GroupNonUniformPartitionedNV, 88 spv::Capability::GroupNonUniformVote, 89 spv::Capability::Groups, 90 spv::Capability::ImageMSArray, 91 spv::Capability::Int16, 92 spv::Capability::Int64, 93 spv::Capability::Linkage, 94 spv::Capability::MinLod, 95 spv::Capability::PhysicalStorageBufferAddresses, 96 spv::Capability::RayQueryKHR, 97 spv::Capability::RayTracingKHR, 98 spv::Capability::RayTraversalPrimitiveCullingKHR, 99 spv::Capability::Shader, 100 spv::Capability::ShaderClockKHR, 101 spv::Capability::StorageImageReadWithoutFormat, 102 spv::Capability::StorageInputOutput16, 103 spv::Capability::StoragePushConstant16, 104 spv::Capability::StorageUniform16, 105 spv::Capability::StorageUniformBufferBlock16, 106 spv::Capability::VulkanMemoryModelDeviceScope, 107 // clang-format on 108 }; 109 110 // Those capabilities disable all transformation of the module. 111 static constexpr std::array kForbiddenCapabilities{ 112 spv::Capability::Linkage, 113 }; 114 115 // Those capabilities are never removed from a module because we cannot 116 // guess from the SPIR-V only if they are required or not. 117 static constexpr std::array kUntouchableCapabilities{ 118 spv::Capability::Shader, 119 }; 120 121 public: 122 TrimCapabilitiesPass(); 123 TrimCapabilitiesPass(const TrimCapabilitiesPass&) = delete; 124 TrimCapabilitiesPass(TrimCapabilitiesPass&&) = delete; 125 126 private: 127 // Inserts every capability listed by `descriptor` this pass supports into 128 // `output`. Expects a Descriptor like `spv_opcode_desc_t` or 129 // `spv_operand_desc_t`. 130 template <class Descriptor> addSupportedCapabilitiesToSet(const Descriptor * const descriptor,CapabilitySet * output)131 inline void addSupportedCapabilitiesToSet(const Descriptor* const descriptor, 132 CapabilitySet* output) const { 133 const uint32_t capabilityCount = descriptor->numCapabilities; 134 for (uint32_t i = 0; i < capabilityCount; ++i) { 135 const auto capability = descriptor->capabilities[i]; 136 if (supportedCapabilities_.contains(capability)) { 137 output->insert(capability); 138 } 139 } 140 } 141 142 // Inserts every extension listed by `descriptor` required by the module into 143 // `output`. Expects a Descriptor like `spv_opcode_desc_t` or 144 // `spv_operand_desc_t`. 145 template <class Descriptor> addSupportedExtensionsToSet(const Descriptor * const descriptor,ExtensionSet * output)146 inline void addSupportedExtensionsToSet(const Descriptor* const descriptor, 147 ExtensionSet* output) const { 148 if (descriptor->minVersion <= 149 spvVersionForTargetEnv(context()->GetTargetEnv())) { 150 return; 151 } 152 output->insert(descriptor->extensions, 153 descriptor->extensions + descriptor->numExtensions); 154 } 155 156 void addInstructionRequirementsForOpcode(spv::Op opcode, 157 CapabilitySet* capabilities, 158 ExtensionSet* extensions) const; 159 void addInstructionRequirementsForOperand(const Operand& operand, 160 CapabilitySet* capabilities, 161 ExtensionSet* extensions) const; 162 163 // Given an `instruction`, determines the capabilities it requires, and output 164 // them in `capabilities`. The returned capabilities form a subset of 165 // kSupportedCapabilities. 166 void addInstructionRequirements(Instruction* instruction, 167 CapabilitySet* capabilities, 168 ExtensionSet* extensions) const; 169 170 // Given an operand `type` and `value`, adds the extensions it would require 171 // to `extensions`. 172 void AddExtensionsForOperand(const spv_operand_type_t type, 173 const uint32_t value, 174 ExtensionSet* extensions) const; 175 176 // Returns the list of required capabilities and extensions for the module. 177 // The returned capabilities form a subset of kSupportedCapabilities. 178 std::pair<CapabilitySet, ExtensionSet> 179 DetermineRequiredCapabilitiesAndExtensions() const; 180 181 // Trims capabilities not listed in `required_capabilities` if possible. 182 // Returns whether or not the module was modified. 183 Pass::Status TrimUnrequiredCapabilities( 184 const CapabilitySet& required_capabilities) const; 185 186 // Trims extensions not listed in `required_extensions` if supported by this 187 // pass. An extensions is considered supported as soon as one capability this 188 // pass support requires it. 189 Pass::Status TrimUnrequiredExtensions( 190 const ExtensionSet& required_extensions) const; 191 192 // Returns if the analyzed module contains any forbidden capability. 193 bool HasForbiddenCapabilities() const; 194 195 public: name()196 const char* name() const override { return "trim-capabilities"; } 197 Status Process() override; 198 199 private: 200 const CapabilitySet supportedCapabilities_; 201 const CapabilitySet forbiddenCapabilities_; 202 const CapabilitySet untouchableCapabilities_; 203 const std::unordered_multimap<spv::Op, OpcodeHandler, ClassEnumHash> 204 opcodeHandlers_; 205 }; 206 207 } // namespace opt 208 } // namespace spvtools 209 #endif // SOURCE_OPT_TRIM_CAPABILITIES_H_ 210