1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
10 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
11 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14
15 using namespace mlir;
16
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
20
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24 : public PassWrapper<PrintOpAvailability, FunctionPass> {
25 void runOnFunction() override;
26 };
27 } // end anonymous namespace
28
runOnFunction()29 void PrintOpAvailability::runOnFunction() {
30 auto f = getFunction();
31 llvm::outs() << f.getName() << "\n";
32
33 Dialect *spvDialect = getContext().getLoadedDialect("spv");
34
35 f->walk([&](Operation *op) {
36 if (op->getDialect() != spvDialect)
37 return WalkResult::advance();
38
39 auto opName = op->getName();
40 auto &os = llvm::outs();
41
42 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
43 os << opName << " min version: "
44 << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
45
46 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
47 os << opName << " max version: "
48 << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
49
50 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
51 os << opName << " extensions: [";
52 for (const auto &exts : extension.getExtensions()) {
53 os << " [";
54 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
55 os << spirv::stringifyExtension(ext);
56 });
57 os << "]";
58 }
59 os << " ]\n";
60 }
61
62 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
63 os << opName << " capabilities: [";
64 for (const auto &caps : capability.getCapabilities()) {
65 os << " [";
66 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
67 os << spirv::stringifyCapability(cap);
68 });
69 os << "]";
70 }
71 os << " ]\n";
72 }
73 os.flush();
74
75 return WalkResult::advance();
76 });
77 }
78
79 namespace mlir {
registerPrintOpAvailabilityPass()80 void registerPrintOpAvailabilityPass() {
81 PassRegistration<PrintOpAvailability> printOpAvailabilityPass(
82 "test-spirv-op-availability", "Test SPIR-V op availability");
83 }
84 } // namespace mlir
85
86 //===----------------------------------------------------------------------===//
87 // Converting target environment pass
88 //===----------------------------------------------------------------------===//
89
90 namespace {
91 /// A pass for testing SPIR-V op availability.
92 struct ConvertToTargetEnv
93 : public PassWrapper<ConvertToTargetEnv, FunctionPass> {
94 void runOnFunction() override;
95 };
96
97 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
98 ConvertToAtomCmpExchangeWeak(MLIRContext *context);
99 LogicalResult matchAndRewrite(Operation *op,
100 PatternRewriter &rewriter) const override;
101 };
102
103 struct ConvertToBitReverse : public RewritePattern {
104 ConvertToBitReverse(MLIRContext *context);
105 LogicalResult matchAndRewrite(Operation *op,
106 PatternRewriter &rewriter) const override;
107 };
108
109 struct ConvertToGroupNonUniformBallot : public RewritePattern {
110 ConvertToGroupNonUniformBallot(MLIRContext *context);
111 LogicalResult matchAndRewrite(Operation *op,
112 PatternRewriter &rewriter) const override;
113 };
114
115 struct ConvertToModule : public RewritePattern {
116 ConvertToModule(MLIRContext *context);
117 LogicalResult matchAndRewrite(Operation *op,
118 PatternRewriter &rewriter) const override;
119 };
120
121 struct ConvertToSubgroupBallot : public RewritePattern {
122 ConvertToSubgroupBallot(MLIRContext *context);
123 LogicalResult matchAndRewrite(Operation *op,
124 PatternRewriter &rewriter) const override;
125 };
126 } // end anonymous namespace
127
runOnFunction()128 void ConvertToTargetEnv::runOnFunction() {
129 MLIRContext *context = &getContext();
130 FuncOp fn = getFunction();
131
132 auto targetEnv = fn.getOperation()
133 ->getAttr(spirv::getTargetEnvAttrName())
134 .cast<spirv::TargetEnvAttr>();
135 if (!targetEnv) {
136 fn.emitError("missing 'spv.target_env' attribute");
137 return signalPassFailure();
138 }
139
140 auto target = spirv::SPIRVConversionTarget::get(targetEnv);
141
142 OwningRewritePatternList patterns;
143 patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
144 ConvertToGroupNonUniformBallot, ConvertToModule,
145 ConvertToSubgroupBallot>(context);
146
147 if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
148 return signalPassFailure();
149 }
150
ConvertToAtomCmpExchangeWeak(MLIRContext * context)151 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
152 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
153 {"spv.AtomicCompareExchangeWeak"}, 1, context) {}
154
155 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const156 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
157 PatternRewriter &rewriter) const {
158 Value ptr = op->getOperand(0);
159 Value value = op->getOperand(1);
160 Value comparator = op->getOperand(2);
161
162 // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
163 // memory semantics to additionally require AtomicStorage capability.
164 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
165 op, value.getType(), ptr, spirv::Scope::Workgroup,
166 spirv::MemorySemantics::AcquireRelease |
167 spirv::MemorySemantics::AtomicCounterMemory,
168 spirv::MemorySemantics::Acquire, value, comparator);
169 return success();
170 }
171
ConvertToBitReverse(MLIRContext * context)172 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
173 : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
174 context) {}
175
176 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const177 ConvertToBitReverse::matchAndRewrite(Operation *op,
178 PatternRewriter &rewriter) const {
179 Value predicate = op->getOperand(0);
180
181 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
182 op, op->getResult(0).getType(), predicate);
183 return success();
184 }
185
ConvertToGroupNonUniformBallot(MLIRContext * context)186 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
187 MLIRContext *context)
188 : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
189 {"spv.GroupNonUniformBallot"}, 1, context) {}
190
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const191 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
192 Operation *op, PatternRewriter &rewriter) const {
193 Value predicate = op->getOperand(0);
194
195 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
196 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
197 return success();
198 }
199
ConvertToModule(MLIRContext * context)200 ConvertToModule::ConvertToModule(MLIRContext *context)
201 : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
202
203 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const204 ConvertToModule::matchAndRewrite(Operation *op,
205 PatternRewriter &rewriter) const {
206 rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
207 op, spirv::AddressingModel::PhysicalStorageBuffer64,
208 spirv::MemoryModel::Vulkan);
209 return success();
210 }
211
ConvertToSubgroupBallot(MLIRContext * context)212 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
213 : RewritePattern("test.convert_to_subgroup_ballot_op",
214 {"spv.SubgroupBallotKHR"}, 1, context) {}
215
216 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const217 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
218 PatternRewriter &rewriter) const {
219 Value predicate = op->getOperand(0);
220
221 rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
222 op, op->getResult(0).getType(), predicate);
223 return success();
224 }
225
226 namespace mlir {
registerConvertToTargetEnvPass()227 void registerConvertToTargetEnvPass() {
228 PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass(
229 "test-spirv-target-env", "Test SPIR-V target environment");
230 }
231 } // namespace mlir
232