1 //===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert standard ops to SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/Debug.h"
22
23 #define DEBUG_TYPE "std-to-spirv-pattern"
24
25 using namespace mlir;
26
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30
31 /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)32 static bool isBoolScalarOrVector(Type type) {
33 if (type.isInteger(1))
34 return true;
35 if (auto vecType = type.dyn_cast<VectorType>())
36 return vecType.getElementType().isInteger(1);
37 return false;
38 }
39
40 /// Converts the given `srcAttr` into a boolean attribute if it holds an
41 /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)42 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
43 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
44 return boolAttr;
45 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
46 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
47 return BoolAttr();
48 }
49
50 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
51 /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)52 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
53 Builder builder) {
54 // If the source number uses less active bits than the target bitwidth, then
55 // it should be safe to convert.
56 if (srcAttr.getValue().isIntN(dstType.getWidth()))
57 return builder.getIntegerAttr(dstType, srcAttr.getInt());
58
59 // XXX: Try again by interpreting the source number as a signed value.
60 // Although integers in the standard dialect are signless, they can represent
61 // a signed number. It's the operation decides how to interpret. This is
62 // dangerous, but it seems there is no good way of handling this if we still
63 // want to change the bitwidth. Emit a message at least.
64 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
65 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
66 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
67 << dstAttr << "' for type '" << dstType << "'\n");
68 return dstAttr;
69 }
70
71 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
72 << "' illegal: cannot fit into target type '"
73 << dstType << "'\n");
74 return IntegerAttr();
75 }
76
77 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
78 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)79 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
80 Builder builder) {
81 // Only support converting to float for now.
82 if (!dstType.isF32())
83 return FloatAttr();
84
85 // Try to convert the source floating-point number to single precision.
86 APFloat dstVal = srcAttr.getValue();
87 bool losesInfo = false;
88 APFloat::opStatus status =
89 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
90 if (status != APFloat::opOK || losesInfo) {
91 LLVM_DEBUG(llvm::dbgs()
92 << srcAttr << " illegal: cannot fit into converted type '"
93 << dstType << "'\n");
94 return FloatAttr();
95 }
96
97 return builder.getF32FloatAttr(dstVal.convertToFloat());
98 }
99
100 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
101 /// the sign of `signOperand`.
102 ///
103 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
104 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
105 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
106 /// if either operand can be negative. Emulate it via spv.UMod.
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)107 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
108 Value signOperand, OpBuilder &builder) {
109 assert(lhs.getType() == rhs.getType());
110 assert(lhs == signOperand || rhs == signOperand);
111
112 Type type = lhs.getType();
113
114 // Calculate the remainder with spv.UMod.
115 Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
116 Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
117 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
118
119 // Fix the sign.
120 Value isPositive;
121 if (lhs == signOperand)
122 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
123 else
124 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
125 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
126 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
127 }
128
129 /// Returns the offset of the value in `targetBits` representation.
130 ///
131 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
132 /// It's assumed to be non-negative.
133 ///
134 /// When accessing an element in the array treating as having elements of
135 /// `targetBits`, multiple values are loaded in the same time. The method
136 /// returns the offset where the `srcIdx` locates in the value. For example, if
137 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
138 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
139 /// element has 8 bits.
getOffsetForBitwidth(Location loc,Value srcIdx,int sourceBits,int targetBits,OpBuilder & builder)140 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
141 int targetBits, OpBuilder &builder) {
142 assert(targetBits % sourceBits == 0);
143 IntegerType targetType = builder.getIntegerType(targetBits);
144 IntegerAttr idxAttr =
145 builder.getIntegerAttr(targetType, targetBits / sourceBits);
146 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
147 IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
148 auto srcBitsValue =
149 builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
150 auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
151 return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
152 }
153
/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion if a memref of an unsupported type is used,
/// load/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extracting the required bits. For accessing a
/// 1-D array (spv.array or spv.rt_array), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
                                          spirv::AccessChainOp op,
                                          int sourceBits, int targetBits,
                                          OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  IntegerType targetType = builder.getIntegerType(targetBits);
  // Number of source elements packed into one target element.
  IntegerAttr attr =
      builder.getIntegerAttr(targetType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
  auto lastDim = op->getOperand(op.getNumOperands() - 1);
  auto indices = llvm::to_vector<4>(op.indices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  // Replace the last index with `lastDim / idx` so the chain addresses the
  // wider element that contains the requested narrow element.
  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.component_ptr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}
180
181 /// Returns the shifted `targetBits`-bit value with the given offset.
shiftValue(Location loc,Value value,Value offset,Value mask,int targetBits,OpBuilder & builder)182 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
183 int targetBits, OpBuilder &builder) {
184 Type targetType = builder.getIntegerType(targetBits);
185 Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
186 return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
187 offset);
188 }
189
/// Returns true if the operator is operating on unsigned integers.
/// The primary template answers false; specializations below opt specific
/// SPIR-V ops in.
/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
/// to the ops themselves.
template <typename SPIRVOp>
bool isUnsignedOp() {
  return false;
}

/// Declares a specialization of isUnsignedOp() that returns true for the
/// given SPIR-V op, marking it as operating on unsigned integers.
#define CHECK_UNSIGNED_OP(SPIRVOp) \
  template <> \
  bool isUnsignedOp<SPIRVOp>() { \
    return true; \
  }

CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
CHECK_UNSIGNED_OP(spirv::UConvertOp)
CHECK_UNSIGNED_OP(spirv::UDivOp)
CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
CHECK_UNSIGNED_OP(spirv::ULessThanOp)
CHECK_UNSIGNED_OP(spirv::UModOp)

#undef CHECK_UNSIGNED_OP
219
220 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
221 static bool isAllocationSupported(MemRefType t) {
222 // Currently only support workgroup local memory allocations with static
223 // shape and int or float or vector of int or float element type.
224 if (!(t.hasStaticShape() &&
225 SPIRVTypeConverter::getMemorySpaceForStorageClass(
226 spirv::StorageClass::Workgroup) == t.getMemorySpace()))
227 return false;
228 Type elementType = t.getElementType();
229 if (auto vecType = elementType.dyn_cast<VectorType>())
230 elementType = vecType.getElementType();
231 return elementType.isIntOrFloat();
232 }
233
234 /// Returns the scope to use for atomic operations use for emulating store
235 /// operations of unsupported integer bitwidths, based on the memref
236 /// type. Returns None on failure.
getAtomicOpScope(MemRefType t)237 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
238 Optional<spirv::StorageClass> storageClass =
239 SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
240 if (!storageClass)
241 return {};
242 switch (*storageClass) {
243 case spirv::StorageClass::StorageBuffer:
244 return spirv::Scope::Device;
245 case spirv::StorageClass::Workgroup:
246 return spirv::Scope::Workgroup;
247 default: {
248 }
249 }
250 return {};
251 }
252
253 //===----------------------------------------------------------------------===//
254 // Operation conversion
255 //===----------------------------------------------------------------------===//
256
257 // Note that DRR cannot be used for the patterns in this file: we may need to
258 // convert type along the way, which requires ConversionPattern. DRR generates
259 // normal RewritePattern.
260
261 namespace {
262
/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spv.module
/// scope since it will add global variables into the spv.module.
class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
public:
  using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType allocType = operation.getType();
    if (!isAllocationSupported(allocType))
      return operation.emitError("unhandled allocation type");

    // Get the SPIR-V type for the allocation.
    Type spirvType = typeConverter.convertType(allocType);

    // Insert spv.globalVariable for this allocation.
    Operation *parent =
        SymbolTable::getNearestSymbolTable(operation->getParentOp());
    if (!parent)
      return failure();
    Location loc = operation.getLoc();
    spirv::GlobalVariableOp varOp;
    {
      // The variable must live in the symbol table's entry block, so move the
      // insertion point there temporarily.
      OpBuilder::InsertionGuard guard(rewriter);
      Block &entryBlock = *parent->getRegion(0).begin();
      rewriter.setInsertionPointToStart(&entryBlock);
      // Derive a unique name by counting the global variables already present.
      auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
      std::string varName =
          std::string("__workgroup_mem__") +
          std::to_string(std::distance(varOps.begin(), varOps.end()));
      varOp = rewriter.create<spirv::GlobalVariableOp>(
          loc, TypeAttr::get(spirvType), varName,
          /*initializer = */ nullptr);
    }

    // Get pointer to global variable at the current scope.
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
    return success();
  }
};
306
307 /// Removed a deallocation if it is a supported allocation. Currently only
308 /// removes deallocation if the memory space is workgroup memory.
309 class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
310 public:
311 using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;
312
313 LogicalResult
matchAndRewrite(DeallocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const314 matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
315 ConversionPatternRewriter &rewriter) const override {
316 MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
317 if (!isAllocationSupported(deallocType))
318 return operation.emitError("unhandled deallocation type");
319 rewriter.eraseOp(operation);
320 return success();
321 }
322 };
323
324 /// Converts unary and binary standard operations to SPIR-V operations.
325 template <typename StdOp, typename SPIRVOp>
326 class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
327 public:
328 using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
329
330 LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const331 matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
332 ConversionPatternRewriter &rewriter) const override {
333 assert(operands.size() <= 2);
334 auto dstType = this->typeConverter.convertType(operation.getType());
335 if (!dstType)
336 return failure();
337 if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
338 return operation.emitError(
339 "bitwidth emulation is not implemented yet on unsigned op");
340 }
341 rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
342 return success();
343 }
344 };
345
/// Converts std.remi_signed to SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to
/// Vulkan restrictions over spv.SRem and spv.SMod: with either operand
/// negative, their results are undefined, so the remainder is emulated.
class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
public:
  using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
358
359 /// Converts bitwise standard operations to SPIR-V operations. This is a special
360 /// pattern other than the BinaryOpPatternPattern because if the operands are
361 /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
362 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
363 template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
364 class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
365 public:
366 using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
367
368 LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const369 matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
370 ConversionPatternRewriter &rewriter) const override {
371 assert(operands.size() == 2);
372 auto dstType =
373 this->typeConverter.convertType(operation.getResult().getType());
374 if (!dstType)
375 return failure();
376 if (isBoolScalarOrVector(operands.front().getType())) {
377 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
378 operands);
379 } else {
380 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
381 operands);
382 }
383 return success();
384 }
385 };
386
/// Converts composite std.constant operation to spv.constant, i.e. constants
/// with vector or tensor type.
class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
public:
  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
396
/// Converts scalar std.constant operation to spv.constant, i.e. constants
/// with int, index, or float type.
class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
public:
  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
406
/// Converts floating-point comparison operations to SPIR-V ordered/unordered
/// comparison ops, dispatching on the predicate.
class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
public:
  using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
416
/// Converts integer compare operation on i1 type operands to SPIR-V logical
/// ops (only eq/ne predicates apply to booleans).
class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
public:
  using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
426
/// Converts integer compare operation (on non-i1 operands) to SPIR-V ops.
class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
public:
  using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
436
/// Converts std.load to spv.Load for integer memrefs, emulating unsupported
/// integer bitwidths via wider loads plus bit extraction when needed.
class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
public:
  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
446
/// Converts std.load to spv.Load (non-integer element types).
class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
public:
  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
456
/// Converts std.return to spv.Return.
class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
public:
  using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
466
/// Converts std.select to spv.Select.
class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
public:
  using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
  LogicalResult
  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
475
/// Converts std.store to spv.Store on integers, emulating unsupported integer
/// bitwidths via atomic read-modify-write on the wider container when needed.
class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
public:
  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
485
/// Converts std.store to spv.Store (non-integer element types).
class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
public:
  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
495
496 /// Converts std.zexti to spv.Select if the type of source is i1 or vector of
497 /// i1.
498 class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
499 public:
500 using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
501
502 LogicalResult
matchAndRewrite(ZeroExtendIOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const503 matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
504 ConversionPatternRewriter &rewriter) const override {
505 auto srcType = operands.front().getType();
506 if (!isBoolScalarOrVector(srcType))
507 return failure();
508
509 auto dstType = this->typeConverter.convertType(op.getResult().getType());
510 Location loc = op.getLoc();
511 Attribute zeroAttr, oneAttr;
512 if (auto vectorType = dstType.dyn_cast<VectorType>()) {
513 zeroAttr = DenseElementsAttr::get(vectorType, 0);
514 oneAttr = DenseElementsAttr::get(vectorType, 1);
515 } else {
516 zeroAttr = IntegerAttr::get(dstType, 0);
517 oneAttr = IntegerAttr::get(dstType, 1);
518 }
519 Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
520 Value one = rewriter.create<ConstantOp>(loc, oneAttr);
521 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
522 op, dstType, operands.front(), one, zero);
523 return success();
524 }
525 };
526
527 /// Converts type-casting standard operations to SPIR-V operations.
528 template <typename StdOp, typename SPIRVOp>
529 class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
530 public:
531 using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
532
533 LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const534 matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
535 ConversionPatternRewriter &rewriter) const override {
536 assert(operands.size() == 1);
537 auto srcType = operands.front().getType();
538 if (isBoolScalarOrVector(srcType))
539 return failure();
540 auto dstType =
541 this->typeConverter.convertType(operation.getResult().getType());
542 if (dstType == srcType) {
543 // Due to type conversion, we are seeing the same source and target type.
544 // Then we can just erase this operation by forwarding its operand.
545 rewriter.replaceOp(operation, operands.front());
546 } else {
547 rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
548 operands);
549 }
550 return success();
551 }
552 };
553
/// Converts std.xor to SPIR-V operations.
class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
public:
  using SPIRVOpLowering<XOrOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
563
564 } // namespace
565
566 //===----------------------------------------------------------------------===//
567 // SignedRemIOpPattern
568 //===----------------------------------------------------------------------===//
569
matchAndRewrite(SignedRemIOp remOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const570 LogicalResult SignedRemIOpPattern::matchAndRewrite(
571 SignedRemIOp remOp, ArrayRef<Value> operands,
572 ConversionPatternRewriter &rewriter) const {
573 Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
574 operands[1], operands[0], rewriter);
575 rewriter.replaceOp(remOp, result);
576
577 return success();
578 }
579
580 //===----------------------------------------------------------------------===//
581 // ConstantOp with composite type.
582 //===----------------------------------------------------------------------===//
583
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const584 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
585 ConstantOp constOp, ArrayRef<Value> operands,
586 ConversionPatternRewriter &rewriter) const {
587 auto srcType = constOp.getType().dyn_cast<ShapedType>();
588 if (!srcType)
589 return failure();
590
591 // std.constant should only have vector or tenor types.
592 assert((srcType.isa<VectorType, RankedTensorType>()));
593
594 auto dstType = typeConverter.convertType(srcType);
595 if (!dstType)
596 return failure();
597
598 auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
599 ShapedType dstAttrType = dstElementsAttr.getType();
600 if (!dstElementsAttr)
601 return failure();
602
603 // If the composite type has more than one dimensions, perform linearization.
604 if (srcType.getRank() > 1) {
605 if (srcType.isa<RankedTensorType>()) {
606 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
607 srcType.getElementType());
608 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
609 } else {
610 // TODO: add support for large vectors.
611 return failure();
612 }
613 }
614
615 Type srcElemType = srcType.getElementType();
616 Type dstElemType;
617 // Tensor types are converted to SPIR-V array types; vector types are
618 // converted to SPIR-V vector/array types.
619 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
620 dstElemType = arrayType.getElementType();
621 else
622 dstElemType = dstType.cast<VectorType>().getElementType();
623
624 // If the source and destination element types are different, perform
625 // attribute conversion.
626 if (srcElemType != dstElemType) {
627 SmallVector<Attribute, 8> elements;
628 if (srcElemType.isa<FloatType>()) {
629 for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
630 FloatAttr dstAttr = convertFloatAttr(
631 srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
632 if (!dstAttr)
633 return failure();
634 elements.push_back(dstAttr);
635 }
636 } else if (srcElemType.isInteger(1)) {
637 return failure();
638 } else {
639 for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
640 IntegerAttr dstAttr =
641 convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
642 dstElemType.cast<IntegerType>(), rewriter);
643 if (!dstAttr)
644 return failure();
645 elements.push_back(dstAttr);
646 }
647 }
648
649 // Unfortunately, we cannot use dialect-specific types for element
650 // attributes; element attributes only works with builtin types. So we need
651 // to prepare another converted builtin types for the destination elements
652 // attribute.
653 if (dstAttrType.isa<RankedTensorType>())
654 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
655 else
656 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
657
658 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
659 }
660
661 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
662 dstElementsAttr);
663 return success();
664 }
665
666 //===----------------------------------------------------------------------===//
667 // ConstantOp with scalar type.
668 //===----------------------------------------------------------------------===//
669
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const670 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
671 ConstantOp constOp, ArrayRef<Value> operands,
672 ConversionPatternRewriter &rewriter) const {
673 Type srcType = constOp.getType();
674 if (!srcType.isIntOrIndexOrFloat())
675 return failure();
676
677 Type dstType = typeConverter.convertType(srcType);
678 if (!dstType)
679 return failure();
680
681 // Floating-point types.
682 if (srcType.isa<FloatType>()) {
683 auto srcAttr = constOp.value().cast<FloatAttr>();
684 auto dstAttr = srcAttr;
685
686 // Floating-point types not supported in the target environment are all
687 // converted to float type.
688 if (srcType != dstType) {
689 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
690 if (!dstAttr)
691 return failure();
692 }
693
694 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
695 return success();
696 }
697
698 // Bool type.
699 if (srcType.isInteger(1)) {
700 // std.constant can use 0/1 instead of true/false for i1 values. We need to
701 // handle that here.
702 auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
703 if (!dstAttr)
704 return failure();
705 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
706 return success();
707 }
708
709 // IndexType or IntegerType. Index values are converted to 32-bit integer
710 // values when converting to SPIR-V.
711 auto srcAttr = constOp.value().cast<IntegerAttr>();
712 auto dstAttr =
713 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
714 if (!dstAttr)
715 return failure();
716 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
717 return success();
718 }
719
720 //===----------------------------------------------------------------------===//
721 // CmpFOp
722 //===----------------------------------------------------------------------===//
723
LogicalResult
CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpFOpAdaptor cmpFOpOperands(operands);

  // Map each std.cmpf predicate to the corresponding SPIR-V float comparison
  // op; unhandled predicates (e.g. AlwaysTrue/AlwaysFalse/ORD/UNO) fail.
  switch (cmpFOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
                                         cmpFOpOperands.lhs(),                 \
                                         cmpFOpOperands.rhs());                \
    return success();

    // Ordered.
    DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered.
    DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef DISPATCH

  default:
    break;
  }
  return failure();
}
759
760 //===----------------------------------------------------------------------===//
761 // CmpIOp
762 //===----------------------------------------------------------------------===//
763
LogicalResult
BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  CmpIOpAdaptor cmpIOpOperands(operands);

  // Only fires for boolean (i1 / vector-of-i1) operands; other integer
  // comparisons are handled by CmpIOpPattern.
  Type operandType = cmpIOp.lhs().getType();
  if (!isBoolScalarOrVector(operandType))
    return failure();

  switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                         cmpIOpOperands.lhs(),                 \
                                         cmpIOpOperands.rhs());                \
    return success();

    // Only equality predicates are meaningful on booleans.
    DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
    DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);

#undef DISPATCH
  default:;
  }
  return failure();
}
789
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpIOpAdaptor cmpIOpOperands(operands);

  // i1 (boolean) comparisons are handled by BoolCmpIOpPattern, which emits
  // SPIR-V logical ops; this pattern covers all other integer widths.
  Type operandType = cmpIOp.lhs().getType();
  if (isBoolScalarOrVector(operandType))
    return failure();

  // Each supported predicate maps 1:1 onto a SPIR-V comparison op. Unsigned
  // comparison ops are rejected when the type converter changes the operand
  // type, since bitwidth emulation is not implemented for unsigned ops yet
  // (see the error message below).
  switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (isUnsignedOp<spirvOp>() &&                                             \
        operandType != this->typeConverter.convertType(operandType)) {         \
      return cmpIOp.emitError(                                                 \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                         cmpIOpOperands.lhs(),                 \
                                         cmpIOpOperands.rhs());                \
    return success();

  DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
  DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
  DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
  DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
  DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
  DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
  DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
  DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
  DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
  DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  return failure();
}
827
828 //===----------------------------------------------------------------------===//
829 // LoadOp
830 //===----------------------------------------------------------------------===//
831
// Lowers loads from signless-integer memrefs, emulating narrow bitwidths
// (e.g. i8/i16) on top of the wider integer type chosen by the type
// converter via shift-and-mask extraction.
LogicalResult
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  LoadOpAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  // Non-signless-integer element types are handled by LoadOpPattern instead.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);

  // srcBits: bitwidth of the source memref element type. dstBits: bitwidth of
  // the converted element type, unwrapped from the pointer-to-struct-of-array
  // representation the type converter produces for memrefs.
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto dstType = typeConverter.convertType(memrefType)
                     .cast<spirv::PointerType>()
                     .getPointeeType()
                     .cast<spirv::StructType>()
                     .getElementType(0)
                     .cast<spirv::ArrayType>()
                     .getElementType();
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly; no emulation is needed.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
                                               accessChainOp.getResult());
    return success();
  }

  // Assume that getElementPtr() produces a linearized access. If it's a
  // scalar, the method still returns a linearized accessing. If the accessing
  // is not linearized, there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  // Carry over any memory-access/alignment attributes from the original load.
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType<IntegerAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp->getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally by shifting the
  // bits to the top and arithmetic-shifting them back. The signedness
  // semantic is carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
                                                      shiftValue);
  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                          shiftValue);
  rewriter.replaceOp(loadOp, result);

  // The original access chain is dead now that loads go through the adjusted
  // pointer; erase it explicitly.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
905
906 LogicalResult
matchAndRewrite(LoadOp loadOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const907 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
908 ConversionPatternRewriter &rewriter) const {
909 LoadOpAdaptor loadOperands(operands);
910 auto memrefType = loadOp.memref().getType().cast<MemRefType>();
911 if (memrefType.getElementType().isSignlessInteger())
912 return failure();
913 auto loadPtr =
914 spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
915 loadOperands.indices(), loadOp.getLoc(), rewriter);
916 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
917 return success();
918 }
919
920 //===----------------------------------------------------------------------===//
921 // ReturnOp
922 //===----------------------------------------------------------------------===//
923
924 LogicalResult
matchAndRewrite(ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const925 ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
926 ConversionPatternRewriter &rewriter) const {
927 if (returnOp.getNumOperands()) {
928 return failure();
929 }
930 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
931 return success();
932 }
933
934 //===----------------------------------------------------------------------===//
935 // SelectOp
936 //===----------------------------------------------------------------------===//
937
938 LogicalResult
matchAndRewrite(SelectOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const939 SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
940 ConversionPatternRewriter &rewriter) const {
941 SelectOpAdaptor selectOperands(operands);
942 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
943 selectOperands.true_value(),
944 selectOperands.false_value());
945 return success();
946 }
947
948 //===----------------------------------------------------------------------===//
949 // StoreOp
950 //===----------------------------------------------------------------------===//
951
// Lowers stores to signless-integer memrefs, emulating narrow bitwidths
// (e.g. i8/i16) on top of the wider integer type chosen by the type
// converter via a pair of atomic read-modify-write operations.
LogicalResult
IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  // Non-signless-integer element types are handled by StoreOpPattern instead.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                           storeOperands.indices(), loc, rewriter);
  // srcBits: bitwidth of the source memref element type. dstBits: bitwidth of
  // the converted element type, unwrapped from the pointer-to-struct-of-array
  // representation the type converter produces for memrefs.
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto dstType = typeConverter.convertType(memrefType)
                     .cast<spirv::PointerType>()
                     .getPointeeType()
                     .cast<spirv::StructType>()
                     .getElementType(0)
                     .cast<spirv::ArrayType>()
                     .getElementType();
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // Same bitwidth after conversion: a plain store suffices.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, accessChainOp.getResult(), storeOperands.value());
    return success();
  }

  // Since there are multiple threads doing the processing, the emulation will
  // be done with atomic operations. E.g., if the stored value is i8, rewrite
  // the StoreOp to
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
  // are done by AtomicOr as another atomic step.
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

  // Position the (masked) stored value at the target bit offset.
  Value storeVal =
      shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  // The atomic scope depends on the memref's storage class; bail out if none
  // can be determined.
  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();
  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The AtomicOrOp has no side effect. Since it is already inserted, we can
  // just remove the original StoreOp. Note that rewriter.replaceOp()
  // doesn't work because it only accepts that the numbers of result are the
  // same.
  rewriter.eraseOp(storeOp);

  // The original access chain is dead now that stores go through the adjusted
  // pointer; erase it explicitly.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
1029
1030 LogicalResult
matchAndRewrite(StoreOp storeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1031 StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
1032 ConversionPatternRewriter &rewriter) const {
1033 StoreOpAdaptor storeOperands(operands);
1034 auto memrefType = storeOp.memref().getType().cast<MemRefType>();
1035 if (memrefType.getElementType().isSignlessInteger())
1036 return failure();
1037 auto storePtr =
1038 spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
1039 storeOperands.indices(), storeOp.getLoc(), rewriter);
1040 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
1041 storeOperands.value());
1042 return success();
1043 }
1044
1045 //===----------------------------------------------------------------------===//
1046 // XorOp
1047 //===----------------------------------------------------------------------===//
1048
1049 LogicalResult
matchAndRewrite(XOrOp xorOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1050 XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
1051 ConversionPatternRewriter &rewriter) const {
1052 assert(operands.size() == 2);
1053
1054 if (isBoolScalarOrVector(operands.front().getType()))
1055 return failure();
1056
1057 auto dstType = typeConverter.convertType(xorOp.getType());
1058 if (!dstType)
1059 return failure();
1060 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
1061
1062 return success();
1063 }
1064
1065 //===----------------------------------------------------------------------===//
1066 // Pattern population
1067 //===----------------------------------------------------------------------===//
1068
namespace mlir {
/// Populates `patterns` with all conversion patterns that lower standard
/// dialect ops to their SPIR-V dialect equivalents, using `typeConverter`
/// to map standard types to SPIR-V-compatible types.
void populateStandardToSPIRVPatterns(MLIRContext *context,
                                     SPIRVTypeConverter &typeConverter,
                                     OwningRewritePatternList &patterns) {
  patterns.insert<
      // Unary and binary patterns
      BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
      BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
      UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
      UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
      UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
      UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
      UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
      UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
      UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
      UnaryAndBinaryOpPattern<FloorFOp, spirv::GLSLFloorOp>,
      UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
      UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
      UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
      UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
      UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
      UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
      UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
      UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
      UnaryAndBinaryOpPattern<SignedShiftRightOp,
                              spirv::ShiftRightArithmeticOp>,
      UnaryAndBinaryOpPattern<SinOp, spirv::GLSLSinOp>,
      UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
      UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
      UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
      UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
      UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
      UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
      UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
      SignedRemIOpPattern, XOrOpPattern,

      // Comparison patterns
      BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,

      // Constant patterns
      ConstantCompositeOpPattern, ConstantScalarOpPattern,

      // Memory patterns
      AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
      LoadOpPattern, StoreOpPattern,

      ReturnOpPattern, SelectOpPattern,

      // Type cast patterns
      ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
      TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
      TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
      TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
      TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
      TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(context,
                                                          typeConverter);
}
} // namespace mlir
1128