1 //===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_ 10 #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_ 11 12 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 13 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "mlir/Support/LLVM.h" 19 20 namespace mlir { 21 namespace OpTrait { 22 namespace linalg { 23 24 /// This class provides the API for ops that are known to have a specified 25 /// number of inputs, all passed as operands. Use as a trait as follows: 26 /// 27 /// class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> { 28 /// 29 template <unsigned N> class NInputs { 30 public: 31 template <typename ConcreteType> 32 class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> { 33 public: getNumInputs()34 static unsigned getNumInputs() { return N; } 35 }; 36 }; 37 38 /// This class provides the API for ops that are known to not have init tensor 39 /// operands. Use as a trait as follows: 40 /// 41 /// class CopyOp : public Op<CopyOp, OpTrait::ZeroInitTensors> { 42 /// 43 template <typename ConcreteType> 44 class ZeroInitTensors : public TraitBase<ConcreteType, ZeroInitTensors> { 45 public: getNumInitTensors()46 static unsigned getNumInitTensors() { return 0; } 47 }; 48 49 /// This class provides the API for ops that are known to have a specified 50 /// number of outputs, all passed as operands. Use as a trait as follows: 51 /// 52 /// class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> { 53 /// 54 template <unsigned N> class NOutputs { 55 public: 56 template <typename ConcreteType> 57 class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> { 58 public: getNumOutputs()59 static unsigned getNumOutputs() { return N; } 60 }; 61 }; 62 63 /// This class provides a verifier for structured ops that are known to operate 64 /// on buffers or tensors. This trait must be used in conjunction with an op 65 /// definition or a trait that provides the methods `getNumInputs` and 66 /// `getNumOutputs`. Use as a trait as follows: 67 /// 68 /// class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> { 69 /// 70 template <typename ConcreteType> 71 class StructuredOpTraits 72 : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> { 73 public: verifyTrait(Operation * op)74 static LogicalResult verifyTrait(Operation *op) { 75 ConcreteType concreteOp = cast<ConcreteType>(op); 76 auto nOperands = concreteOp.getNumInputsAndOutputBuffers(); 77 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) 78 return failure(); 79 if (op->getNumResults() > concreteOp.getNumOutputs()) 80 return op->emitError("unexpected #results > #outputs"); 81 return success(); 82 } 83 }; 84 85 /// This class provides a verifier for structured ops that are known to operate 86 /// on buffers or tensors and that support `ins`, `outs` and `init` arguments. 87 /// This trait must be used in conjunction with an op definition or a trait that 88 /// provides the methods `getNumInputs` and `getNumOutputs`. 89 /// 90 /// Use as a trait as follows: 91 /// 92 /// class MatmulOp : public Op<MatmulOp, OpTrait::NamedStructuredOpTrait> { 93 /// 94 template <typename ConcreteType> 95 class NamedStructuredOpTrait 96 : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTrait> { 97 public: getNumInputs()98 unsigned getNumInputs() { 99 return cast<ConcreteType>(this->getOperation()).inputs().size(); 100 } getNumInitTensors()101 unsigned getNumInitTensors() { 102 return cast<ConcreteType>(this->getOperation()).init_tensors().size(); 103 } getNumOutputs()104 unsigned getNumOutputs() { 105 ConcreteType concreteOp = cast<ConcreteType>(this->getOperation()); 106 return concreteOp.output_buffers().size() + 107 concreteOp.result_tensors().size(); 108 } verifyTrait(Operation * op)109 static LogicalResult verifyTrait(Operation *op) { 110 ConcreteType concreteOp = cast<ConcreteType>(op); 111 unsigned nInputAndBufferOperands = 112 concreteOp.getNumInputsAndOutputBuffers(); 113 if (failed( 114 OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands))) 115 return failure(); 116 117 SmallVector<AffineExpr, 4> redDims; 118 concreteOp.getReductionDims(redDims); 119 // If no result and no reduction, only check there is no init tensor and we 120 // are done. 121 if (redDims.empty() || op->getNumResults() == 0) { 122 if (!concreteOp.init_tensors().empty()) 123 return op->emitError("expected empty `init` when op has no " 124 "results or no reduction dims"); 125 return success(); 126 } 127 128 // Only a single tensor result supported atm. 129 if (op->getNumResults() != 1) 130 return op->emitError( 131 "expected single tensor result when reduction present"); 132 133 if (concreteOp.init_tensors().size() != op->getNumResults()) 134 return op->emitError( 135 "expected #init tensors to match #results when reduction present"); 136 137 for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx) 138 if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx]) 139 return op->emitError("expected init tensor #") 140 << idx << " of the same type as result #" << idx; 141 142 // Output tensor indexing map may not depend on reduction index. 143 // TODO: this is not yet tested. Add a test when linalg.generic switches to 144 // this representation. 145 for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) { 146 AffineMap outputMap = concreteOp.getOutputIndexingMap(idx); 147 for (auto expr : outputMap.getResults()) { 148 for (auto dim : redDims) { 149 unsigned pos = dim.cast<AffineDimExpr>().getPosition(); 150 if (expr.isFunctionOfDim(pos)) 151 return op->emitError( 152 "unexpected single tensor output indexing map ") 153 << "is function of reduction dim @" << pos; 154 } 155 } 156 } 157 158 return success(); 159 } 160 }; 161 162 } // namespace linalg 163 } // namespace OpTrait 164 } // namespace mlir 165 166 #endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_ 167