1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/trim_capabilities_pass.h"
16
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "source/enum_set.h"
28 #include "source/enum_string_mapping.h"
29 #include "source/opt/ir_context.h"
30 #include "source/spirv_target_env.h"
31 #include "source/util/string_utils.h"
32
33 namespace spvtools {
34 namespace opt {
35
namespace {
// In-operand indices used by this pass to decode the instructions it inspects.
// OpVariable: the Storage Class operand.
constexpr uint32_t kVariableStorageClassIndex{0};
// OpTypeVector/OpTypeMatrix/OpTypeArray/OpTypeRuntimeArray: the
// component/column/element type operand.
constexpr uint32_t kTypeArrayTypeIndex{0};
// OpTypeInt/OpTypeFloat: the Width operand.
constexpr uint32_t kOpTypeScalarBitWidthIndex{0};
// OpTypePointer: the pointee Type operand (operand 0 is the storage class).
constexpr uint32_t kTypePointerTypeIdInIdx{1};
}  // namespace
42
43 // ============== Begin opcode handler implementations. =======================
44 //
// Adding support for a new capability should only require adding a new handler
// and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
//
// Handler names use this convention:
// Handler_<Opcode>_<Capability>()
51
Handler_OpVariable_StorageInputOutput16(const Instruction * instruction)52 static std::optional<spv::Capability> Handler_OpVariable_StorageInputOutput16(
53 const Instruction* instruction) {
54 assert(instruction->opcode() == spv::Op::OpVariable &&
55 "This handler only support OpVariable opcodes.");
56
57 // This capability is only required if the variable as an Input/Output storage
58 // class.
59 spv::StorageClass storage_class = spv::StorageClass(
60 instruction->GetSingleWordInOperand(kVariableStorageClassIndex));
61 if (storage_class != spv::StorageClass::Input &&
62 storage_class != spv::StorageClass::Output) {
63 return std::nullopt;
64 }
65
66 // This capability is only required if the type involves a 16-bit component.
67 // Quick check: are 16-bit types allowed?
68 const CapabilitySet& capabilities =
69 instruction->context()->get_feature_mgr()->GetCapabilities();
70 if (!capabilities.contains(spv::Capability::Float16) &&
71 !capabilities.contains(spv::Capability::Int16)) {
72 return std::nullopt;
73 }
74
75 // We need to walk the type definition.
76 std::queue<uint32_t> instructions_to_visit;
77 instructions_to_visit.push(instruction->type_id());
78 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
79 while (!instructions_to_visit.empty()) {
80 const Instruction* item =
81 def_use_mgr->GetDef(instructions_to_visit.front());
82 instructions_to_visit.pop();
83
84 if (item->opcode() == spv::Op::OpTypePointer) {
85 instructions_to_visit.push(
86 item->GetSingleWordInOperand(kTypePointerTypeIdInIdx));
87 continue;
88 }
89
90 if (item->opcode() == spv::Op::OpTypeMatrix ||
91 item->opcode() == spv::Op::OpTypeVector ||
92 item->opcode() == spv::Op::OpTypeArray ||
93 item->opcode() == spv::Op::OpTypeRuntimeArray) {
94 instructions_to_visit.push(
95 item->GetSingleWordInOperand(kTypeArrayTypeIndex));
96 continue;
97 }
98
99 if (item->opcode() == spv::Op::OpTypeStruct) {
100 item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
101 instructions_to_visit.push(*op_id);
102 });
103 continue;
104 }
105
106 if (item->opcode() != spv::Op::OpTypeInt &&
107 item->opcode() != spv::Op::OpTypeFloat) {
108 continue;
109 }
110
111 if (item->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16) {
112 return spv::Capability::StorageInputOutput16;
113 }
114 }
115
116 return std::nullopt;
117 }
118
119 // Opcode of interest to determine capabilities requirements.
120 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 1> kOpcodeHandlers{{
121 {spv::Op::OpVariable, Handler_OpVariable_StorageInputOutput16},
122 }};
123
124 // ============== End opcode handler implementations. =======================
125
126 namespace {
getExtensionsRelatedTo(const CapabilitySet & capabilities,const AssemblyGrammar & grammar)127 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
128 const AssemblyGrammar& grammar) {
129 ExtensionSet output;
130 const spv_operand_desc_t* desc = nullptr;
131 for (auto capability : capabilities) {
132 if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
133 static_cast<uint32_t>(capability),
134 &desc)) {
135 continue;
136 }
137
138 for (uint32_t i = 0; i < desc->numExtensions; ++i) {
139 output.insert(desc->extensions[i]);
140 }
141 }
142
143 return output;
144 }
145 } // namespace
146
TrimCapabilitiesPass()147 TrimCapabilitiesPass::TrimCapabilitiesPass()
148 : supportedCapabilities_(
149 TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
150 TrimCapabilitiesPass::kSupportedCapabilities.cend()),
151 forbiddenCapabilities_(
152 TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
153 TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
154 untouchableCapabilities_(
155 TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
156 TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
157 opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
158
addInstructionRequirements(Instruction * instruction,CapabilitySet * capabilities,ExtensionSet * extensions) const159 void TrimCapabilitiesPass::addInstructionRequirements(
160 Instruction* instruction, CapabilitySet* capabilities,
161 ExtensionSet* extensions) const {
162 // Ignoring OpCapability instructions.
163 if (instruction->opcode() == spv::Op::OpCapability) {
164 return;
165 }
166
167 // First case: the opcode is itself gated by a capability.
168 {
169 const spv_opcode_desc_t* desc = {};
170 auto result =
171 context()->grammar().lookupOpcode(instruction->opcode(), &desc);
172 if (result == SPV_SUCCESS) {
173 addSupportedCapabilitiesToSet(desc->numCapabilities, desc->capabilities,
174 capabilities);
175 if (desc->minVersion <=
176 spvVersionForTargetEnv(context()->GetTargetEnv())) {
177 extensions->insert(desc->extensions,
178 desc->extensions + desc->numExtensions);
179 }
180 }
181 }
182
183 // Second case: one of the opcode operand is gated by a capability.
184 const uint32_t operandCount = instruction->NumOperands();
185 for (uint32_t i = 0; i < operandCount; i++) {
186 const auto& operand = instruction->GetOperand(i);
187 // No supported capability relies on a 2+-word operand.
188 if (operand.words.size() != 1) {
189 continue;
190 }
191
192 // No supported capability relies on a literal string operand.
193 if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
194 continue;
195 }
196
197 const spv_operand_desc_t* desc = {};
198 auto result = context()->grammar().lookupOperand(operand.type,
199 operand.words[0], &desc);
200 if (result != SPV_SUCCESS) {
201 continue;
202 }
203
204 addSupportedCapabilitiesToSet(desc->numCapabilities, desc->capabilities,
205 capabilities);
206 if (desc->minVersion <= spvVersionForTargetEnv(context()->GetTargetEnv())) {
207 extensions->insert(desc->extensions,
208 desc->extensions + desc->numExtensions);
209 }
210 }
211
212 // Last case: some complex logic needs to be run to determine capabilities.
213 auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
214 for (auto it = begin; it != end; it++) {
215 const OpcodeHandler handler = it->second;
216 auto result = handler(instruction);
217 if (result.has_value()) {
218 capabilities->insert(*result);
219 }
220 }
221 }
222
223 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const224 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
225 CapabilitySet required_capabilities;
226 ExtensionSet required_extensions;
227
228 get_module()->ForEachInst([&](Instruction* instruction) {
229 addInstructionRequirements(instruction, &required_capabilities,
230 &required_extensions);
231 });
232
233 #if !defined(NDEBUG)
234 // Debug only. We check the outputted required capabilities against the
235 // supported capabilities list. The supported capabilities list is useful for
236 // API users to quickly determine if they can use the pass or not. But this
237 // list has to remain up-to-date with the pass code. If we can detect a
238 // capability as required, but it's not listed, it means the list is
239 // out-of-sync. This method is not ideal, but should cover most cases.
240 {
241 for (auto capability : required_capabilities) {
242 assert(supportedCapabilities_.contains(capability) &&
243 "Module is using a capability that is not listed as supported.");
244 }
245 }
246 #endif
247
248 return std::make_pair(std::move(required_capabilities),
249 std::move(required_extensions));
250 }
251
TrimUnrequiredCapabilities(const CapabilitySet & required_capabilities) const252 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
253 const CapabilitySet& required_capabilities) const {
254 const FeatureManager* feature_manager = context()->get_feature_mgr();
255 CapabilitySet capabilities_to_trim;
256 for (auto capability : feature_manager->GetCapabilities()) {
257 // Forbidden capability completely prevents trimming. Early exit.
258 if (forbiddenCapabilities_.contains(capability)) {
259 return Pass::Status::SuccessWithoutChange;
260 }
261
262 // Some capabilities cannot be safely removed. Leaving them untouched.
263 if (untouchableCapabilities_.contains(capability)) {
264 continue;
265 }
266
267 // If the capability is unsupported, don't trim it.
268 if (!supportedCapabilities_.contains(capability)) {
269 continue;
270 }
271
272 if (required_capabilities.contains(capability)) {
273 continue;
274 }
275
276 capabilities_to_trim.insert(capability);
277 }
278
279 for (auto capability : capabilities_to_trim) {
280 context()->RemoveCapability(capability);
281 }
282
283 return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
284 : Pass::Status::SuccessWithChange;
285 }
286
TrimUnrequiredExtensions(const ExtensionSet & required_extensions) const287 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
288 const ExtensionSet& required_extensions) const {
289 const auto supported_extensions =
290 getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
291
292 bool modified_module = false;
293 for (auto extension : supported_extensions) {
294 if (!required_extensions.contains(extension)) {
295 modified_module = true;
296 context()->RemoveExtension(extension);
297 }
298 }
299
300 return modified_module ? Pass::Status::SuccessWithChange
301 : Pass::Status::SuccessWithoutChange;
302 }
303
Process()304 Pass::Status TrimCapabilitiesPass::Process() {
305 auto[required_capabilities, required_extensions] =
306 DetermineRequiredCapabilitiesAndExtensions();
307
308 Pass::Status status = TrimUnrequiredCapabilities(required_capabilities);
309 // If no capabilities were removed, we have no extension to trim.
310 // Note: this is true because this pass only removes unused extensions caused
311 // by unused capabilities.
312 // This is not an extension trimming pass.
313 if (status == Pass::Status::SuccessWithoutChange) {
314 return status;
315 }
316 return TrimUnrequiredExtensions(required_extensions);
317 }
318
319 } // namespace opt
320 } // namespace spvtools
321