• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/Passes.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 /// Pass to deduce minimal version/extension/capability requirements for a
29 /// spirv::ModuleOp.
30 class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
31   void runOnOperation() override;
32 };
33 } // namespace
34 
35 /// Checks that `candidates` extension requirements are possible to be satisfied
36 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
37 /// errors attaching to the given `op` on failures.
38 ///
39 ///  `candidates` is a vector of vector for extension requirements following
40 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
41 /// convention.
checkAndUpdateExtensionRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates,llvm::SetVector<spirv::Extension> & deducedExtensions)42 static LogicalResult checkAndUpdateExtensionRequirements(
43     Operation *op, const spirv::TargetEnv &targetEnv,
44     const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
45     llvm::SetVector<spirv::Extension> &deducedExtensions) {
46   for (const auto &ors : candidates) {
47     if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
48       deducedExtensions.insert(*chosen);
49     } else {
50       SmallVector<StringRef, 4> extStrings;
51       for (spirv::Extension ext : ors)
52         extStrings.push_back(spirv::stringifyExtension(ext));
53 
54       return op->emitError("'")
55              << op->getName() << "' requires at least one extension in ["
56              << llvm::join(extStrings, ", ")
57              << "] but none allowed in target environment";
58     }
59   }
60   return success();
61 }
62 
63 /// Checks that `candidates`capability requirements are possible to be satisfied
64 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
65 /// errors attaching to the given `op` on failures.
66 ///
67 ///  `candidates` is a vector of vector for capability requirements following
68 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
69 /// convention.
checkAndUpdateCapabilityRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates,llvm::SetVector<spirv::Capability> & deducedCapabilities)70 static LogicalResult checkAndUpdateCapabilityRequirements(
71     Operation *op, const spirv::TargetEnv &targetEnv,
72     const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
73     llvm::SetVector<spirv::Capability> &deducedCapabilities) {
74   for (const auto &ors : candidates) {
75     if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
76       deducedCapabilities.insert(*chosen);
77     } else {
78       SmallVector<StringRef, 4> capStrings;
79       for (spirv::Capability cap : ors)
80         capStrings.push_back(spirv::stringifyCapability(cap));
81 
82       return op->emitError("'")
83              << op->getName() << "' requires at least one capability in ["
84              << llvm::join(capStrings, ", ")
85              << "] but none allowed in target environment";
86     }
87   }
88   return success();
89 }
90 
runOnOperation()91 void UpdateVCEPass::runOnOperation() {
92   spirv::ModuleOp module = getOperation();
93 
94   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
95   if (!targetAttr) {
96     module.emitError("missing 'spv.target_env' attribute");
97     return signalPassFailure();
98   }
99 
100   spirv::TargetEnv targetEnv(targetAttr);
101   spirv::Version allowedVersion = targetAttr.getVersion();
102 
103   spirv::Version deducedVersion = spirv::Version::V_1_0;
104   llvm::SetVector<spirv::Extension> deducedExtensions;
105   llvm::SetVector<spirv::Capability> deducedCapabilities;
106 
107   // Walk each SPIR-V op to deduce the minimal version/extension/capability
108   // requirements.
109   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
110     // Op min version requirements
111     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
112       deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
113       if (deducedVersion > allowedVersion) {
114         return op->emitError("'") << op->getName() << "' requires min version "
115                                   << spirv::stringifyVersion(deducedVersion)
116                                   << " but target environment allows up to "
117                                   << spirv::stringifyVersion(allowedVersion);
118       }
119     }
120 
121     // Op extension requirements
122     if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
123       if (failed(checkAndUpdateExtensionRequirements(
124               op, targetEnv, extensions.getExtensions(), deducedExtensions)))
125         return WalkResult::interrupt();
126 
127     // Op capability requirements
128     if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
129       if (failed(checkAndUpdateCapabilityRequirements(
130               op, targetEnv, capabilities.getCapabilities(),
131               deducedCapabilities)))
132         return WalkResult::interrupt();
133 
134     SmallVector<Type, 4> valueTypes;
135     valueTypes.append(op->operand_type_begin(), op->operand_type_end());
136     valueTypes.append(op->result_type_begin(), op->result_type_end());
137 
138     // Special treatment for global variables, whose type requirements are
139     // conveyed by type attributes.
140     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
141       valueTypes.push_back(globalVar.type());
142 
143     // Requirements from values' types
144     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
145     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
146     for (Type valueType : valueTypes) {
147       typeExtensions.clear();
148       valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
149       if (failed(checkAndUpdateExtensionRequirements(
150               op, targetEnv, typeExtensions, deducedExtensions)))
151         return WalkResult::interrupt();
152 
153       typeCapabilities.clear();
154       valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
155       if (failed(checkAndUpdateCapabilityRequirements(
156               op, targetEnv, typeCapabilities, deducedCapabilities)))
157         return WalkResult::interrupt();
158     }
159 
160     return WalkResult::advance();
161   });
162 
163   if (walkResult.wasInterrupted())
164     return signalPassFailure();
165 
166   // TODO: verify that the deduced version is consistent with
167   // SPIR-V ops' maximal version requirements.
168 
169   auto triple = spirv::VerCapExtAttr::get(
170       deducedVersion, deducedCapabilities.getArrayRef(),
171       deducedExtensions.getArrayRef(), &getContext());
172   module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
173 }
174 
175 std::unique_ptr<OperationPass<spirv::ModuleOp>>
createUpdateVersionCapabilityExtensionPass()176 mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
177   return std::make_unique<UpdateVCEPass>();
178 }
179