• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/trim_capabilities_pass.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "source/enum_set.h"
28 #include "source/enum_string_mapping.h"
29 #include "source/opt/ir_context.h"
30 #include "source/spirv_target_env.h"
31 #include "source/util/string_utils.h"
32 
33 namespace spvtools {
34 namespace opt {
35 
namespace {
// Positions of interesting operands, counted the way
// Instruction::GetSingleWordInOperand counts them.

// OpVariable: the storage class is the first in-operand.
constexpr uint32_t kVariableStorageClassIndex{0};
// OpTypePointer: the pointee type id follows the storage class.
constexpr uint32_t kTypePointerTypeIdInIdx{1};
// OpTypeArray / OpTypeRuntimeArray / OpTypeVector / OpTypeMatrix: the
// element (component, column) type id comes first.
constexpr uint32_t kTypeArrayTypeIndex{0};
// OpTypeInt / OpTypeFloat: the bit width comes first.
constexpr uint32_t kOpTypeScalarBitWidthIndex{0};
}  // namespace
42 
43 // ============== Begin opcode handler implementations. =======================
44 //
45 // Adding support for a new capability should only require adding a new handler,
46 // and updating the
47 // kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
48 //
49 // Handler names follow the following convention:
50 //  Handler_<Opcode>_<Capability>()
51 
Handler_OpVariable_StorageInputOutput16(const Instruction * instruction)52 static std::optional<spv::Capability> Handler_OpVariable_StorageInputOutput16(
53     const Instruction* instruction) {
54   assert(instruction->opcode() == spv::Op::OpVariable &&
55          "This handler only support OpVariable opcodes.");
56 
57   // This capability is only required if the variable as an Input/Output storage
58   // class.
59   spv::StorageClass storage_class = spv::StorageClass(
60       instruction->GetSingleWordInOperand(kVariableStorageClassIndex));
61   if (storage_class != spv::StorageClass::Input &&
62       storage_class != spv::StorageClass::Output) {
63     return std::nullopt;
64   }
65 
66   // This capability is only required if the type involves a 16-bit component.
67   // Quick check: are 16-bit types allowed?
68   const CapabilitySet& capabilities =
69       instruction->context()->get_feature_mgr()->GetCapabilities();
70   if (!capabilities.contains(spv::Capability::Float16) &&
71       !capabilities.contains(spv::Capability::Int16)) {
72     return std::nullopt;
73   }
74 
75   // We need to walk the type definition.
76   std::queue<uint32_t> instructions_to_visit;
77   instructions_to_visit.push(instruction->type_id());
78   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
79   while (!instructions_to_visit.empty()) {
80     const Instruction* item =
81         def_use_mgr->GetDef(instructions_to_visit.front());
82     instructions_to_visit.pop();
83 
84     if (item->opcode() == spv::Op::OpTypePointer) {
85       instructions_to_visit.push(
86           item->GetSingleWordInOperand(kTypePointerTypeIdInIdx));
87       continue;
88     }
89 
90     if (item->opcode() == spv::Op::OpTypeMatrix ||
91         item->opcode() == spv::Op::OpTypeVector ||
92         item->opcode() == spv::Op::OpTypeArray ||
93         item->opcode() == spv::Op::OpTypeRuntimeArray) {
94       instructions_to_visit.push(
95           item->GetSingleWordInOperand(kTypeArrayTypeIndex));
96       continue;
97     }
98 
99     if (item->opcode() == spv::Op::OpTypeStruct) {
100       item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
101         instructions_to_visit.push(*op_id);
102       });
103       continue;
104     }
105 
106     if (item->opcode() != spv::Op::OpTypeInt &&
107         item->opcode() != spv::Op::OpTypeFloat) {
108       continue;
109     }
110 
111     if (item->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16) {
112       return spv::Capability::StorageInputOutput16;
113     }
114   }
115 
116   return std::nullopt;
117 }
118 
119 // Opcode of interest to determine capabilities requirements.
120 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 1> kOpcodeHandlers{{
121     {spv::Op::OpVariable, Handler_OpVariable_StorageInputOutput16},
122 }};
123 
124 // ==============  End opcode handler implementations.  =======================
125 
126 namespace {
getExtensionsRelatedTo(const CapabilitySet & capabilities,const AssemblyGrammar & grammar)127 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
128                                     const AssemblyGrammar& grammar) {
129   ExtensionSet output;
130   const spv_operand_desc_t* desc = nullptr;
131   for (auto capability : capabilities) {
132     if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
133                                              static_cast<uint32_t>(capability),
134                                              &desc)) {
135       continue;
136     }
137 
138     for (uint32_t i = 0; i < desc->numExtensions; ++i) {
139       output.insert(desc->extensions[i]);
140     }
141   }
142 
143   return output;
144 }
145 }  // namespace
146 
TrimCapabilitiesPass()147 TrimCapabilitiesPass::TrimCapabilitiesPass()
148     : supportedCapabilities_(
149           TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
150           TrimCapabilitiesPass::kSupportedCapabilities.cend()),
151       forbiddenCapabilities_(
152           TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
153           TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
154       untouchableCapabilities_(
155           TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
156           TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
157       opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
158 
addInstructionRequirements(Instruction * instruction,CapabilitySet * capabilities,ExtensionSet * extensions) const159 void TrimCapabilitiesPass::addInstructionRequirements(
160     Instruction* instruction, CapabilitySet* capabilities,
161     ExtensionSet* extensions) const {
162   // Ignoring OpCapability instructions.
163   if (instruction->opcode() == spv::Op::OpCapability) {
164     return;
165   }
166 
167   // First case: the opcode is itself gated by a capability.
168   {
169     const spv_opcode_desc_t* desc = {};
170     auto result =
171         context()->grammar().lookupOpcode(instruction->opcode(), &desc);
172     if (result == SPV_SUCCESS) {
173       addSupportedCapabilitiesToSet(desc->numCapabilities, desc->capabilities,
174                                     capabilities);
175       if (desc->minVersion <=
176           spvVersionForTargetEnv(context()->GetTargetEnv())) {
177         extensions->insert(desc->extensions,
178                            desc->extensions + desc->numExtensions);
179       }
180     }
181   }
182 
183   // Second case: one of the opcode operand is gated by a capability.
184   const uint32_t operandCount = instruction->NumOperands();
185   for (uint32_t i = 0; i < operandCount; i++) {
186     const auto& operand = instruction->GetOperand(i);
187     // No supported capability relies on a 2+-word operand.
188     if (operand.words.size() != 1) {
189       continue;
190     }
191 
192     // No supported capability relies on a literal string operand.
193     if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
194       continue;
195     }
196 
197     const spv_operand_desc_t* desc = {};
198     auto result = context()->grammar().lookupOperand(operand.type,
199                                                      operand.words[0], &desc);
200     if (result != SPV_SUCCESS) {
201       continue;
202     }
203 
204     addSupportedCapabilitiesToSet(desc->numCapabilities, desc->capabilities,
205                                   capabilities);
206     if (desc->minVersion <= spvVersionForTargetEnv(context()->GetTargetEnv())) {
207       extensions->insert(desc->extensions,
208                          desc->extensions + desc->numExtensions);
209     }
210   }
211 
212   // Last case: some complex logic needs to be run to determine capabilities.
213   auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
214   for (auto it = begin; it != end; it++) {
215     const OpcodeHandler handler = it->second;
216     auto result = handler(instruction);
217     if (result.has_value()) {
218       capabilities->insert(*result);
219     }
220   }
221 }
222 
223 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const224 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
225   CapabilitySet required_capabilities;
226   ExtensionSet required_extensions;
227 
228   get_module()->ForEachInst([&](Instruction* instruction) {
229     addInstructionRequirements(instruction, &required_capabilities,
230                                &required_extensions);
231   });
232 
233 #if !defined(NDEBUG)
234   // Debug only. We check the outputted required capabilities against the
235   // supported capabilities list. The supported capabilities list is useful for
236   // API users to quickly determine if they can use the pass or not. But this
237   // list has to remain up-to-date with the pass code. If we can detect a
238   // capability as required, but it's not listed, it means the list is
239   // out-of-sync. This method is not ideal, but should cover most cases.
240   {
241     for (auto capability : required_capabilities) {
242       assert(supportedCapabilities_.contains(capability) &&
243              "Module is using a capability that is not listed as supported.");
244     }
245   }
246 #endif
247 
248   return std::make_pair(std::move(required_capabilities),
249                         std::move(required_extensions));
250 }
251 
TrimUnrequiredCapabilities(const CapabilitySet & required_capabilities) const252 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
253     const CapabilitySet& required_capabilities) const {
254   const FeatureManager* feature_manager = context()->get_feature_mgr();
255   CapabilitySet capabilities_to_trim;
256   for (auto capability : feature_manager->GetCapabilities()) {
257     // Forbidden capability completely prevents trimming. Early exit.
258     if (forbiddenCapabilities_.contains(capability)) {
259       return Pass::Status::SuccessWithoutChange;
260     }
261 
262     // Some capabilities cannot be safely removed. Leaving them untouched.
263     if (untouchableCapabilities_.contains(capability)) {
264       continue;
265     }
266 
267     // If the capability is unsupported, don't trim it.
268     if (!supportedCapabilities_.contains(capability)) {
269       continue;
270     }
271 
272     if (required_capabilities.contains(capability)) {
273       continue;
274     }
275 
276     capabilities_to_trim.insert(capability);
277   }
278 
279   for (auto capability : capabilities_to_trim) {
280     context()->RemoveCapability(capability);
281   }
282 
283   return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
284                                           : Pass::Status::SuccessWithChange;
285 }
286 
TrimUnrequiredExtensions(const ExtensionSet & required_extensions) const287 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
288     const ExtensionSet& required_extensions) const {
289   const auto supported_extensions =
290       getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
291 
292   bool modified_module = false;
293   for (auto extension : supported_extensions) {
294     if (!required_extensions.contains(extension)) {
295       modified_module = true;
296       context()->RemoveExtension(extension);
297     }
298   }
299 
300   return modified_module ? Pass::Status::SuccessWithChange
301                          : Pass::Status::SuccessWithoutChange;
302 }
303 
Process()304 Pass::Status TrimCapabilitiesPass::Process() {
305   auto[required_capabilities, required_extensions] =
306       DetermineRequiredCapabilitiesAndExtensions();
307 
308   Pass::Status status = TrimUnrequiredCapabilities(required_capabilities);
309   // If no capabilities were removed, we have no extension to trim.
310   // Note: this is true because this pass only removes unused extensions caused
311   // by unused capabilities.
312   //       This is not an extension trimming pass.
313   if (status == Pass::Status::SuccessWithoutChange) {
314     return status;
315   }
316   return TrimUnrequiredExtensions(required_extensions);
317 }
318 
319 }  // namespace opt
320 }  // namespace spvtools
321