• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert standard ops to SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "std-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)32 static bool isBoolScalarOrVector(Type type) {
33   if (type.isInteger(1))
34     return true;
35   if (auto vecType = type.dyn_cast<VectorType>())
36     return vecType.getElementType().isInteger(1);
37   return false;
38 }
39 
40 /// Converts the given `srcAttr` into a boolean attribute if it holds an
41 /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)42 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
43   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
44     return boolAttr;
45   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
46     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
47   return BoolAttr();
48 }
49 
50 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
51 /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)52 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
53                                       Builder builder) {
54   // If the source number uses less active bits than the target bitwidth, then
55   // it should be safe to convert.
56   if (srcAttr.getValue().isIntN(dstType.getWidth()))
57     return builder.getIntegerAttr(dstType, srcAttr.getInt());
58 
59   // XXX: Try again by interpreting the source number as a signed value.
60   // Although integers in the standard dialect are signless, they can represent
61   // a signed number. It's the operation decides how to interpret. This is
62   // dangerous, but it seems there is no good way of handling this if we still
63   // want to change the bitwidth. Emit a message at least.
64   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
65     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
66     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
67                             << dstAttr << "' for type '" << dstType << "'\n");
68     return dstAttr;
69   }
70 
71   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
72                           << "' illegal: cannot fit into target type '"
73                           << dstType << "'\n");
74   return IntegerAttr();
75 }
76 
77 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
78 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)79 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
80                                   Builder builder) {
81   // Only support converting to float for now.
82   if (!dstType.isF32())
83     return FloatAttr();
84 
85   // Try to convert the source floating-point number to single precision.
86   APFloat dstVal = srcAttr.getValue();
87   bool losesInfo = false;
88   APFloat::opStatus status =
89       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
90   if (status != APFloat::opOK || losesInfo) {
91     LLVM_DEBUG(llvm::dbgs()
92                << srcAttr << " illegal: cannot fit into converted type '"
93                << dstType << "'\n");
94     return FloatAttr();
95   }
96 
97   return builder.getF32FloatAttr(dstVal.convertToFloat());
98 }
99 
100 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
101 /// the sign of `signOperand`.
102 ///
103 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
104 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
105 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
106 /// if either operand can be negative. Emulate it via spv.UMod.
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)107 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
108                                     Value signOperand, OpBuilder &builder) {
109   assert(lhs.getType() == rhs.getType());
110   assert(lhs == signOperand || rhs == signOperand);
111 
112   Type type = lhs.getType();
113 
114   // Calculate the remainder with spv.UMod.
115   Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
116   Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
117   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
118 
119   // Fix the sign.
120   Value isPositive;
121   if (lhs == signOperand)
122     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
123   else
124     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
125   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
126   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
127 }
128 
129 /// Returns the offset of the value in `targetBits` representation.
130 ///
131 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
132 /// It's assumed to be non-negative.
133 ///
134 /// When accessing an element in the array treating as having elements of
135 /// `targetBits`, multiple values are loaded in the same time. The method
136 /// returns the offset where the `srcIdx` locates in the value. For example, if
137 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
138 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
139 /// element has 8 bits.
getOffsetForBitwidth(Location loc,Value srcIdx,int sourceBits,int targetBits,OpBuilder & builder)140 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
141                                   int targetBits, OpBuilder &builder) {
142   assert(targetBits % sourceBits == 0);
143   IntegerType targetType = builder.getIntegerType(targetBits);
144   IntegerAttr idxAttr =
145       builder.getIntegerAttr(targetType, targetBits / sourceBits);
146   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
147   IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
148   auto srcBitsValue =
149       builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
150   auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
151   return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
152 }
153 
154 /// Returns an adjusted spirv::AccessChainOp. Based on the
155 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
156 /// supported. During conversion if a memref of an unsupported type is used,
157 /// load/stores to this memref need to be modified to use a supported higher
158 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
159 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
160 /// bits needed. The extraction of the actual bits needed are handled
161 /// separately. Note that this only works for a 1-D tensor.
adjustAccessChainForBitwidth(SPIRVTypeConverter & typeConverter,spirv::AccessChainOp op,int sourceBits,int targetBits,OpBuilder & builder)162 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
163                                           spirv::AccessChainOp op,
164                                           int sourceBits, int targetBits,
165                                           OpBuilder &builder) {
166   assert(targetBits % sourceBits == 0);
167   const auto loc = op.getLoc();
168   IntegerType targetType = builder.getIntegerType(targetBits);
169   IntegerAttr attr =
170       builder.getIntegerAttr(targetType, targetBits / sourceBits);
171   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
172   auto lastDim = op->getOperand(op.getNumOperands() - 1);
173   auto indices = llvm::to_vector<4>(op.indices());
174   // There are two elements if this is a 1-D tensor.
175   assert(indices.size() == 2);
176   indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
177   Type t = typeConverter.convertType(op.component_ptr().getType());
178   return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
179 }
180 
181 /// Returns the shifted `targetBits`-bit value with the given offset.
shiftValue(Location loc,Value value,Value offset,Value mask,int targetBits,OpBuilder & builder)182 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
183                         int targetBits, OpBuilder &builder) {
184   Type targetType = builder.getIntegerType(targetBits);
185   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
186   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
187                                                    offset);
188 }
189 
190 /// Returns true if the operator is operating on unsigned integers.
191 /// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
192 /// to the ops themselves.
193 template <typename SPIRVOp>
isUnsignedOp()194 bool isUnsignedOp() {
195   return false;
196 }
197 
198 #define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
199   template <>                                                                  \
200   bool isUnsignedOp<SPIRVOp>() {                                               \
201     return true;                                                               \
202   }
203 
204 CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)205 CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
206 CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
207 CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
208 CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
209 CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
210 CHECK_UNSIGNED_OP(spirv::UConvertOp)
211 CHECK_UNSIGNED_OP(spirv::UDivOp)
212 CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
213 CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
214 CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
215 CHECK_UNSIGNED_OP(spirv::ULessThanOp)
216 CHECK_UNSIGNED_OP(spirv::UModOp)
217 
218 #undef CHECK_UNSIGNED_OP
219 
220 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
221 static bool isAllocationSupported(MemRefType t) {
222   // Currently only support workgroup local memory allocations with static
223   // shape and int or float or vector of int or float element type.
224   if (!(t.hasStaticShape() &&
225         SPIRVTypeConverter::getMemorySpaceForStorageClass(
226             spirv::StorageClass::Workgroup) == t.getMemorySpace()))
227     return false;
228   Type elementType = t.getElementType();
229   if (auto vecType = elementType.dyn_cast<VectorType>())
230     elementType = vecType.getElementType();
231   return elementType.isIntOrFloat();
232 }
233 
234 /// Returns the scope to use for atomic operations use for emulating store
235 /// operations of unsupported integer bitwidths, based on the memref
236 /// type. Returns None on failure.
getAtomicOpScope(MemRefType t)237 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
238   Optional<spirv::StorageClass> storageClass =
239       SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
240   if (!storageClass)
241     return {};
242   switch (*storageClass) {
243   case spirv::StorageClass::StorageBuffer:
244     return spirv::Scope::Device;
245   case spirv::StorageClass::Workgroup:
246     return spirv::Scope::Workgroup;
247   default: {
248   }
249   }
250   return {};
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // Operation conversion
255 //===----------------------------------------------------------------------===//
256 
257 // Note that DRR cannot be used for the patterns in this file: we may need to
258 // convert type along the way, which requires ConversionPattern. DRR generates
259 // normal RewritePattern.
260 
261 namespace {
262 
263 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
264 /// to Workgroup memory when the size is constant.  Note that this pattern needs
265 /// to be applied in a pass that runs at least at spv.module scope since it wil
266 /// ladd global variables into the spv.module.
267 class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
268 public:
269   using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;
270 
271   LogicalResult
matchAndRewrite(AllocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const272   matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
273                   ConversionPatternRewriter &rewriter) const override {
274     MemRefType allocType = operation.getType();
275     if (!isAllocationSupported(allocType))
276       return operation.emitError("unhandled allocation type");
277 
278     // Get the SPIR-V type for the allocation.
279     Type spirvType = typeConverter.convertType(allocType);
280 
281     // Insert spv.globalVariable for this allocation.
282     Operation *parent =
283         SymbolTable::getNearestSymbolTable(operation->getParentOp());
284     if (!parent)
285       return failure();
286     Location loc = operation.getLoc();
287     spirv::GlobalVariableOp varOp;
288     {
289       OpBuilder::InsertionGuard guard(rewriter);
290       Block &entryBlock = *parent->getRegion(0).begin();
291       rewriter.setInsertionPointToStart(&entryBlock);
292       auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
293       std::string varName =
294           std::string("__workgroup_mem__") +
295           std::to_string(std::distance(varOps.begin(), varOps.end()));
296       varOp = rewriter.create<spirv::GlobalVariableOp>(
297           loc, TypeAttr::get(spirvType), varName,
298           /*initializer = */ nullptr);
299     }
300 
301     // Get pointer to global variable at the current scope.
302     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
303     return success();
304   }
305 };
306 
307 /// Removed a deallocation if it is a supported allocation. Currently only
308 /// removes deallocation if the memory space is workgroup memory.
309 class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
310 public:
311   using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;
312 
313   LogicalResult
matchAndRewrite(DeallocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const314   matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
315                   ConversionPatternRewriter &rewriter) const override {
316     MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
317     if (!isAllocationSupported(deallocType))
318       return operation.emitError("unhandled deallocation type");
319     rewriter.eraseOp(operation);
320     return success();
321   }
322 };
323 
324 /// Converts unary and binary standard operations to SPIR-V operations.
325 template <typename StdOp, typename SPIRVOp>
326 class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
327 public:
328   using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
329 
330   LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const331   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
332                   ConversionPatternRewriter &rewriter) const override {
333     assert(operands.size() <= 2);
334     auto dstType = this->typeConverter.convertType(operation.getType());
335     if (!dstType)
336       return failure();
337     if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
338       return operation.emitError(
339           "bitwidth emulation is not implemented yet on unsigned op");
340     }
341     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
342     return success();
343   }
344 };
345 
346 /// Converts std.remi_signed to SPIR-V ops.
347 ///
348 /// This cannot be merged into the template unary/binary pattern due to
349 /// Vulkan restrictions over spv.SRem and spv.SMod.
350 class SignedRemIOpPattern final : public SPIRVOpLowering<SignedRemIOp> {
351 public:
352   using SPIRVOpLowering<SignedRemIOp>::SPIRVOpLowering;
353 
354   LogicalResult
355   matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
356                   ConversionPatternRewriter &rewriter) const override;
357 };
358 
359 /// Converts bitwise standard operations to SPIR-V operations. This is a special
360 /// pattern other than the BinaryOpPatternPattern because if the operands are
361 /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
362 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
363 template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
364 class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
365 public:
366   using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
367 
368   LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const369   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
370                   ConversionPatternRewriter &rewriter) const override {
371     assert(operands.size() == 2);
372     auto dstType =
373         this->typeConverter.convertType(operation.getResult().getType());
374     if (!dstType)
375       return failure();
376     if (isBoolScalarOrVector(operands.front().getType())) {
377       rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
378                                                            operands);
379     } else {
380       rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
381                                                            operands);
382     }
383     return success();
384   }
385 };
386 
387 /// Converts composite std.constant operation to spv.constant.
388 class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
389 public:
390   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
391 
392   LogicalResult
393   matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
394                   ConversionPatternRewriter &rewriter) const override;
395 };
396 
397 /// Converts scalar std.constant operation to spv.constant.
398 class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
399 public:
400   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
401 
402   LogicalResult
403   matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
404                   ConversionPatternRewriter &rewriter) const override;
405 };
406 
407 /// Converts floating-point comparison operations to SPIR-V ops.
408 class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
409 public:
410   using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
411 
412   LogicalResult
413   matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
414                   ConversionPatternRewriter &rewriter) const override;
415 };
416 
417 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
418 class BoolCmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
419 public:
420   using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
421 
422   LogicalResult
423   matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
424                   ConversionPatternRewriter &rewriter) const override;
425 };
426 
427 /// Converts integer compare operation to SPIR-V ops.
428 class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
429 public:
430   using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
431 
432   LogicalResult
433   matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
434                   ConversionPatternRewriter &rewriter) const override;
435 };
436 
437 /// Converts std.load to spv.Load.
438 class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
439 public:
440   using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
441 
442   LogicalResult
443   matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
444                   ConversionPatternRewriter &rewriter) const override;
445 };
446 
447 /// Converts std.load to spv.Load.
448 class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
449 public:
450   using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
451 
452   LogicalResult
453   matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
454                   ConversionPatternRewriter &rewriter) const override;
455 };
456 
457 /// Converts std.return to spv.Return.
458 class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
459 public:
460   using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
461 
462   LogicalResult
463   matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
464                   ConversionPatternRewriter &rewriter) const override;
465 };
466 
467 /// Converts std.select to spv.Select.
468 class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
469 public:
470   using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
471   LogicalResult
472   matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
473                   ConversionPatternRewriter &rewriter) const override;
474 };
475 
476 /// Converts std.store to spv.Store on integers.
477 class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
478 public:
479   using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
480 
481   LogicalResult
482   matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
483                   ConversionPatternRewriter &rewriter) const override;
484 };
485 
486 /// Converts std.store to spv.Store.
487 class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
488 public:
489   using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
490 
491   LogicalResult
492   matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
493                   ConversionPatternRewriter &rewriter) const override;
494 };
495 
496 /// Converts std.zexti to spv.Select if the type of source is i1 or vector of
497 /// i1.
498 class ZeroExtendI1Pattern final : public SPIRVOpLowering<ZeroExtendIOp> {
499 public:
500   using SPIRVOpLowering<ZeroExtendIOp>::SPIRVOpLowering;
501 
502   LogicalResult
matchAndRewrite(ZeroExtendIOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const503   matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
504                   ConversionPatternRewriter &rewriter) const override {
505     auto srcType = operands.front().getType();
506     if (!isBoolScalarOrVector(srcType))
507       return failure();
508 
509     auto dstType = this->typeConverter.convertType(op.getResult().getType());
510     Location loc = op.getLoc();
511     Attribute zeroAttr, oneAttr;
512     if (auto vectorType = dstType.dyn_cast<VectorType>()) {
513       zeroAttr = DenseElementsAttr::get(vectorType, 0);
514       oneAttr = DenseElementsAttr::get(vectorType, 1);
515     } else {
516       zeroAttr = IntegerAttr::get(dstType, 0);
517       oneAttr = IntegerAttr::get(dstType, 1);
518     }
519     Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
520     Value one = rewriter.create<ConstantOp>(loc, oneAttr);
521     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
522         op, dstType, operands.front(), one, zero);
523     return success();
524   }
525 };
526 
527 /// Converts type-casting standard operations to SPIR-V operations.
528 template <typename StdOp, typename SPIRVOp>
529 class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
530 public:
531   using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
532 
533   LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const534   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
535                   ConversionPatternRewriter &rewriter) const override {
536     assert(operands.size() == 1);
537     auto srcType = operands.front().getType();
538     if (isBoolScalarOrVector(srcType))
539       return failure();
540     auto dstType =
541         this->typeConverter.convertType(operation.getResult().getType());
542     if (dstType == srcType) {
543       // Due to type conversion, we are seeing the same source and target type.
544       // Then we can just erase this operation by forwarding its operand.
545       rewriter.replaceOp(operation, operands.front());
546     } else {
547       rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
548                                                     operands);
549     }
550     return success();
551   }
552 };
553 
554 /// Converts std.xor to SPIR-V operations.
555 class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
556 public:
557   using SPIRVOpLowering<XOrOp>::SPIRVOpLowering;
558 
559   LogicalResult
560   matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
561                   ConversionPatternRewriter &rewriter) const override;
562 };
563 
564 } // namespace
565 
566 //===----------------------------------------------------------------------===//
567 // SignedRemIOpPattern
568 //===----------------------------------------------------------------------===//
569 
matchAndRewrite(SignedRemIOp remOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const570 LogicalResult SignedRemIOpPattern::matchAndRewrite(
571     SignedRemIOp remOp, ArrayRef<Value> operands,
572     ConversionPatternRewriter &rewriter) const {
573   Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
574                                         operands[1], operands[0], rewriter);
575   rewriter.replaceOp(remOp, result);
576 
577   return success();
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // ConstantOp with composite type.
582 //===----------------------------------------------------------------------===//
583 
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const584 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
585     ConstantOp constOp, ArrayRef<Value> operands,
586     ConversionPatternRewriter &rewriter) const {
587   auto srcType = constOp.getType().dyn_cast<ShapedType>();
588   if (!srcType)
589     return failure();
590 
591   // std.constant should only have vector or tenor types.
592   assert((srcType.isa<VectorType, RankedTensorType>()));
593 
594   auto dstType = typeConverter.convertType(srcType);
595   if (!dstType)
596     return failure();
597 
598   auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
599   ShapedType dstAttrType = dstElementsAttr.getType();
600   if (!dstElementsAttr)
601     return failure();
602 
603   // If the composite type has more than one dimensions, perform linearization.
604   if (srcType.getRank() > 1) {
605     if (srcType.isa<RankedTensorType>()) {
606       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
607                                           srcType.getElementType());
608       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
609     } else {
610       // TODO: add support for large vectors.
611       return failure();
612     }
613   }
614 
615   Type srcElemType = srcType.getElementType();
616   Type dstElemType;
617   // Tensor types are converted to SPIR-V array types; vector types are
618   // converted to SPIR-V vector/array types.
619   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
620     dstElemType = arrayType.getElementType();
621   else
622     dstElemType = dstType.cast<VectorType>().getElementType();
623 
624   // If the source and destination element types are different, perform
625   // attribute conversion.
626   if (srcElemType != dstElemType) {
627     SmallVector<Attribute, 8> elements;
628     if (srcElemType.isa<FloatType>()) {
629       for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
630         FloatAttr dstAttr = convertFloatAttr(
631             srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
632         if (!dstAttr)
633           return failure();
634         elements.push_back(dstAttr);
635       }
636     } else if (srcElemType.isInteger(1)) {
637       return failure();
638     } else {
639       for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
640         IntegerAttr dstAttr =
641             convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
642                                dstElemType.cast<IntegerType>(), rewriter);
643         if (!dstAttr)
644           return failure();
645         elements.push_back(dstAttr);
646       }
647     }
648 
649     // Unfortunately, we cannot use dialect-specific types for element
650     // attributes; element attributes only works with builtin types. So we need
651     // to prepare another converted builtin types for the destination elements
652     // attribute.
653     if (dstAttrType.isa<RankedTensorType>())
654       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
655     else
656       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
657 
658     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
659   }
660 
661   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
662                                                  dstElementsAttr);
663   return success();
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // ConstantOp with scalar type.
668 //===----------------------------------------------------------------------===//
669 
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const670 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
671     ConstantOp constOp, ArrayRef<Value> operands,
672     ConversionPatternRewriter &rewriter) const {
673   Type srcType = constOp.getType();
674   if (!srcType.isIntOrIndexOrFloat())
675     return failure();
676 
677   Type dstType = typeConverter.convertType(srcType);
678   if (!dstType)
679     return failure();
680 
681   // Floating-point types.
682   if (srcType.isa<FloatType>()) {
683     auto srcAttr = constOp.value().cast<FloatAttr>();
684     auto dstAttr = srcAttr;
685 
686     // Floating-point types not supported in the target environment are all
687     // converted to float type.
688     if (srcType != dstType) {
689       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
690       if (!dstAttr)
691         return failure();
692     }
693 
694     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
695     return success();
696   }
697 
698   // Bool type.
699   if (srcType.isInteger(1)) {
700     // std.constant can use 0/1 instead of true/false for i1 values. We need to
701     // handle that here.
702     auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
703     if (!dstAttr)
704       return failure();
705     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
706     return success();
707   }
708 
709   // IndexType or IntegerType. Index values are converted to 32-bit integer
710   // values when converting to SPIR-V.
711   auto srcAttr = constOp.value().cast<IntegerAttr>();
712   auto dstAttr =
713       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
714   if (!dstAttr)
715     return failure();
716   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
717   return success();
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // CmpFOp
722 //===----------------------------------------------------------------------===//
723 
724 LogicalResult
matchAndRewrite(CmpFOp cmpFOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const725 CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
726                                ConversionPatternRewriter &rewriter) const {
727   CmpFOpAdaptor cmpFOpOperands(operands);
728 
729   switch (cmpFOp.getPredicate()) {
730 #define DISPATCH(cmpPredicate, spirvOp)                                        \
731   case cmpPredicate:                                                           \
732     rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
733                                          cmpFOpOperands.lhs(),                 \
734                                          cmpFOpOperands.rhs());                \
735     return success();
736 
737     // Ordered.
738     DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
739     DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
740     DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
741     DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
742     DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
743     DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
744     // Unordered.
745     DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
746     DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
747     DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
748     DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
749     DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
750     DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
751 
752 #undef DISPATCH
753 
754   default:
755     break;
756   }
757   return failure();
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // CmpIOp
762 //===----------------------------------------------------------------------===//
763 
764 LogicalResult
matchAndRewrite(CmpIOp cmpIOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const765 BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
766                                    ConversionPatternRewriter &rewriter) const {
767   CmpIOpAdaptor cmpIOpOperands(operands);
768 
769   Type operandType = cmpIOp.lhs().getType();
770   if (!isBoolScalarOrVector(operandType))
771     return failure();
772 
773   switch (cmpIOp.getPredicate()) {
774 #define DISPATCH(cmpPredicate, spirvOp)                                        \
775   case cmpPredicate:                                                           \
776     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
777                                          cmpIOpOperands.lhs(),                 \
778                                          cmpIOpOperands.rhs());                \
779     return success();
780 
781     DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
782     DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);
783 
784 #undef DISPATCH
785   default:;
786   }
787   return failure();
788 }
789 
790 LogicalResult
matchAndRewrite(CmpIOp cmpIOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const791 CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
792                                ConversionPatternRewriter &rewriter) const {
793   CmpIOpAdaptor cmpIOpOperands(operands);
794 
795   Type operandType = cmpIOp.lhs().getType();
796   if (isBoolScalarOrVector(operandType))
797     return failure();
798 
799   switch (cmpIOp.getPredicate()) {
800 #define DISPATCH(cmpPredicate, spirvOp)                                        \
801   case cmpPredicate:                                                           \
802     if (isUnsignedOp<spirvOp>() &&                                             \
803         operandType != this->typeConverter.convertType(operandType)) {         \
804       return cmpIOp.emitError(                                                 \
805           "bitwidth emulation is not implemented yet on unsigned op");         \
806     }                                                                          \
807     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
808                                          cmpIOpOperands.lhs(),                 \
809                                          cmpIOpOperands.rhs());                \
810     return success();
811 
812     DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
813     DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
814     DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
815     DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
816     DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
817     DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
818     DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
819     DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
820     DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
821     DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
822 
823 #undef DISPATCH
824   }
825   return failure();
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // LoadOp
830 //===----------------------------------------------------------------------===//
831 
832 LogicalResult
matchAndRewrite(LoadOp loadOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const833 IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
834                                   ConversionPatternRewriter &rewriter) const {
835   LoadOpAdaptor loadOperands(operands);
836   auto loc = loadOp.getLoc();
837   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
838   if (!memrefType.getElementType().isSignlessInteger())
839     return failure();
840   spirv::AccessChainOp accessChainOp =
841       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
842                            loadOperands.indices(), loc, rewriter);
843 
844   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
845   auto dstType = typeConverter.convertType(memrefType)
846                      .cast<spirv::PointerType>()
847                      .getPointeeType()
848                      .cast<spirv::StructType>()
849                      .getElementType(0)
850                      .cast<spirv::ArrayType>()
851                      .getElementType();
852   int dstBits = dstType.getIntOrFloatBitWidth();
853   assert(dstBits % srcBits == 0);
854 
855   // If the rewrited load op has the same bit width, use the loading value
856   // directly.
857   if (srcBits == dstBits) {
858     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
859                                                accessChainOp.getResult());
860     return success();
861   }
862 
863   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
864   // still returns a linearized accessing. If the accessing is not linearized,
865   // there will be offset issues.
866   assert(accessChainOp.indices().size() == 2);
867   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
868                                                    srcBits, dstBits, rewriter);
869   Value spvLoadOp = rewriter.create<spirv::LoadOp>(
870       loc, dstType, adjustedPtr,
871       loadOp->getAttrOfType<IntegerAttr>(
872           spirv::attributeName<spirv::MemoryAccess>()),
873       loadOp->getAttrOfType<IntegerAttr>("alignment"));
874 
875   // Shift the bits to the rightmost.
876   // ____XXXX________ -> ____________XXXX
877   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
878   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
879   Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
880       loc, spvLoadOp.getType(), spvLoadOp, offset);
881 
882   // Apply the mask to extract corresponding bits.
883   Value mask = rewriter.create<spirv::ConstantOp>(
884       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
885   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
886 
887   // Apply sign extension on the loading value unconditionally. The signedness
888   // semantic is carried in the operator itself, we relies other pattern to
889   // handle the casting.
890   IntegerAttr shiftValueAttr =
891       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
892   Value shiftValue =
893       rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
894   result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
895                                                       shiftValue);
896   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
897                                                           shiftValue);
898   rewriter.replaceOp(loadOp, result);
899 
900   assert(accessChainOp.use_empty());
901   rewriter.eraseOp(accessChainOp);
902 
903   return success();
904 }
905 
906 LogicalResult
matchAndRewrite(LoadOp loadOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const907 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
908                                ConversionPatternRewriter &rewriter) const {
909   LoadOpAdaptor loadOperands(operands);
910   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
911   if (memrefType.getElementType().isSignlessInteger())
912     return failure();
913   auto loadPtr =
914       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
915                            loadOperands.indices(), loadOp.getLoc(), rewriter);
916   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
917   return success();
918 }
919 
920 //===----------------------------------------------------------------------===//
921 // ReturnOp
922 //===----------------------------------------------------------------------===//
923 
924 LogicalResult
matchAndRewrite(ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const925 ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
926                                  ConversionPatternRewriter &rewriter) const {
927   if (returnOp.getNumOperands()) {
928     return failure();
929   }
930   rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
931   return success();
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // SelectOp
936 //===----------------------------------------------------------------------===//
937 
938 LogicalResult
matchAndRewrite(SelectOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const939 SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
940                                  ConversionPatternRewriter &rewriter) const {
941   SelectOpAdaptor selectOperands(operands);
942   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
943                                                selectOperands.true_value(),
944                                                selectOperands.false_value());
945   return success();
946 }
947 
948 //===----------------------------------------------------------------------===//
949 // StoreOp
950 //===----------------------------------------------------------------------===//
951 
952 LogicalResult
matchAndRewrite(StoreOp storeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const953 IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
954                                    ConversionPatternRewriter &rewriter) const {
955   StoreOpAdaptor storeOperands(operands);
956   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
957   if (!memrefType.getElementType().isSignlessInteger())
958     return failure();
959 
960   auto loc = storeOp.getLoc();
961   spirv::AccessChainOp accessChainOp =
962       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
963                            storeOperands.indices(), loc, rewriter);
964   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
965   auto dstType = typeConverter.convertType(memrefType)
966                      .cast<spirv::PointerType>()
967                      .getPointeeType()
968                      .cast<spirv::StructType>()
969                      .getElementType(0)
970                      .cast<spirv::ArrayType>()
971                      .getElementType();
972   int dstBits = dstType.getIntOrFloatBitWidth();
973   assert(dstBits % srcBits == 0);
974 
975   if (srcBits == dstBits) {
976     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
977         storeOp, accessChainOp.getResult(), storeOperands.value());
978     return success();
979   }
980 
981   // Since there are multi threads in the processing, the emulation will be done
982   // with atomic operations. E.g., if the storing value is i8, rewrite the
983   // StoreOp to
984   // 1) load a 32-bit integer
985   // 2) clear 8 bits in the loading value
986   // 3) store 32-bit value back
987   // 4) load a 32-bit integer
988   // 5) modify 8 bits in the loading value
989   // 6) store 32-bit value back
990   // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
991   // 4 to step 6 are done by AtomicOr as another atomic step.
992   assert(accessChainOp.indices().size() == 2);
993   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
994   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
995 
996   // Create a mask to clear the destination. E.g., if it is the second i8 in
997   // i32, 0xFFFF00FF is created.
998   Value mask = rewriter.create<spirv::ConstantOp>(
999       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1000   Value clearBitsMask =
1001       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
1002   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
1003 
1004   Value storeVal =
1005       shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
1006   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
1007                                                    srcBits, dstBits, rewriter);
1008   Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
1009   if (!scope)
1010     return failure();
1011   Value result = rewriter.create<spirv::AtomicAndOp>(
1012       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
1013       clearBitsMask);
1014   result = rewriter.create<spirv::AtomicOrOp>(
1015       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
1016       storeVal);
1017 
1018   // The AtomicOrOp has no side effect. Since it is already inserted, we can
1019   // just remove the original StoreOp. Note that rewriter.replaceOp()
1020   // doesn't work because it only accepts that the numbers of result are the
1021   // same.
1022   rewriter.eraseOp(storeOp);
1023 
1024   assert(accessChainOp.use_empty());
1025   rewriter.eraseOp(accessChainOp);
1026 
1027   return success();
1028 }
1029 
1030 LogicalResult
matchAndRewrite(StoreOp storeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1031 StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
1032                                 ConversionPatternRewriter &rewriter) const {
1033   StoreOpAdaptor storeOperands(operands);
1034   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
1035   if (memrefType.getElementType().isSignlessInteger())
1036     return failure();
1037   auto storePtr =
1038       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
1039                            storeOperands.indices(), storeOp.getLoc(), rewriter);
1040   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
1041                                               storeOperands.value());
1042   return success();
1043 }
1044 
1045 //===----------------------------------------------------------------------===//
1046 // XorOp
1047 //===----------------------------------------------------------------------===//
1048 
1049 LogicalResult
matchAndRewrite(XOrOp xorOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1050 XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
1051                               ConversionPatternRewriter &rewriter) const {
1052   assert(operands.size() == 2);
1053 
1054   if (isBoolScalarOrVector(operands.front().getType()))
1055     return failure();
1056 
1057   auto dstType = typeConverter.convertType(xorOp.getType());
1058   if (!dstType)
1059     return failure();
1060   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
1061 
1062   return success();
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // Pattern population
1067 //===----------------------------------------------------------------------===//
1068 
1069 namespace mlir {
populateStandardToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)1070 void populateStandardToSPIRVPatterns(MLIRContext *context,
1071                                      SPIRVTypeConverter &typeConverter,
1072                                      OwningRewritePatternList &patterns) {
1073   patterns.insert<
1074       // Unary and binary patterns
1075       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1076       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1077       UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
1078       UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
1079       UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
1080       UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
1081       UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
1082       UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
1083       UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
1084       UnaryAndBinaryOpPattern<FloorFOp, spirv::GLSLFloorOp>,
1085       UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
1086       UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
1087       UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
1088       UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
1089       UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
1090       UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
1091       UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
1092       UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
1093       UnaryAndBinaryOpPattern<SignedShiftRightOp,
1094                               spirv::ShiftRightArithmeticOp>,
1095       UnaryAndBinaryOpPattern<SinOp, spirv::GLSLSinOp>,
1096       UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
1097       UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
1098       UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
1099       UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
1100       UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
1101       UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
1102       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
1103       SignedRemIOpPattern, XOrOpPattern,
1104 
1105       // Comparison patterns
1106       BoolCmpIOpPattern, CmpFOpPattern, CmpIOpPattern,
1107 
1108       // Constant patterns
1109       ConstantCompositeOpPattern, ConstantScalarOpPattern,
1110 
1111       // Memory patterns
1112       AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
1113       LoadOpPattern, StoreOpPattern,
1114 
1115       ReturnOpPattern, SelectOpPattern,
1116 
1117       // Type cast patterns
1118       ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
1119       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
1120       TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
1121       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
1122       TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
1123       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
1124       TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(context,
1125                                                           typeConverter);
1126 }
1127 } // namespace mlir
1128