/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
   Copyright 2022 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_DIALECT_BASE_H
#define STABLEHLO_DIALECT_BASE_H

#include <algorithm>

#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LogicalResult.h"

// Include order matters: this generated file must come after the MLIR headers
// above (presumably because it refers to types they declare -- confirm before
// reordering or letting a formatter sort this include).
#include "dialect/BaseAttrInterfaces.h.inc"

namespace mlir {
namespace hlo {

// Returns true if the given types are the same for the purposes of HLO type
// inference, accounting for special properties of quantization and sparsity.
// Used by the CompatibleOperandsAndResultType trait in this header to compare
// every operand and result type against a reference type.
bool isCompatibleForHloTypeInference(Type tp1, Type tp2);

// Shape derivation function that computes the shape of the result based on an
// operand. For a 2-dimensional input tensor, this produces IR of the form
//
//   %0 = dim %arg0, 0 : memref<?x?xf32>
//   %1 = index_cast %0 : index to i64
//   %2 = dim %arg0, 1 : memref<?x?xf32>
//   %3 = index_cast %2 : index to i64
//   %4 = "shape.shape_of"(%1, %3)
//     : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
//
// NOTE(review): the definition is not in this header. Presumably the IR is
// created through `builder` at the position of `op`, and the shape value is
// appended to `reifiedReturnShapes` -- confirm against the implementation.
LogicalResult deriveShapeFromOperand(
    OpBuilder *builder, Operation *op, Value operand,
    SmallVectorImpl<Value> *reifiedReturnShapes);

// Type derivation function that returns a tensor type with a new element type.
// NOTE(review): judging by the name, the shape of `tensorType` is carried over
// unchanged and only the element type is replaced by `elementType` -- confirm
// against the implementation.
TensorType getSameShapeTensorType(TensorType tensorType, Type elementType);

// Verify bounds expressed by HLO_BoundedInterface against the provided type.
// See documentation for HLO_BoundedInterface for the list of checks.
// `emitError` is a callback that produces the in-flight diagnostic;
// presumably it is invoked only when one of the checks fails.
LogicalResult verifyBounds(ArrayRef<int64_t> bounds, ShapedType type,
                           function_ref<InFlightDiagnostic()> emitError);

// This interface is used for HLO dialects that have accompanying
// BoundedAttrInterface attributes which can carry bounds for dimension sizes
// of accompanying shaped types.
74 class BoundedDialectInterface 75 : public DialectInterface::Base<BoundedDialectInterface> { 76 public: BoundedDialectInterface(Dialect * dialect)77 explicit BoundedDialectInterface(Dialect *dialect) : Base(dialect) {} 78 virtual Attribute createBoundedAttr(ArrayRef<int64_t> bounds) const = 0; 79 }; 80 81 namespace OpTrait { 82 83 template <typename ConcreteType> 84 class BroadcastingElementwise 85 : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {}; 86 87 template <typename ConcreteType> 88 class PairwiseSameOperandAndResultType 89 : public mlir::OpTrait::TraitBase<ConcreteType, 90 PairwiseSameOperandAndResultType> { 91 public: verifyTrait(Operation * op)92 static LogicalResult verifyTrait(Operation *op) { 93 const int numOperands = op->getNumOperands(); 94 const int numResults = op->getNumResults(); 95 if (numOperands != numResults) { 96 return op->emitOpError() 97 << "requires the same number of operands and results"; 98 } 99 100 for (int idx : llvm::seq<int>(0, numOperands)) { 101 if (op->getOperand(idx).getType() != op->getResult(idx).getType()) { 102 return op->emitOpError() 103 << "requires the same type for operand and result at index " 104 << idx; 105 } 106 } 107 return success(); 108 } 109 }; 110 111 template <typename ConcreteType> 112 class CompatibleOperandsAndResultType 113 : public mlir::OpTrait::TraitBase<ConcreteType, 114 CompatibleOperandsAndResultType> { 115 public: verifyTrait(Operation * op)116 static LogicalResult verifyTrait(Operation *op) { 117 Type expected; 118 if (op->getNumResults() != 0) expected = op->getResult(0).getType(); 119 if (op->getNumOperands() != 0) expected = op->getOperand(0).getType(); 120 if (!expected) return failure(); 121 122 auto typeMatch = [&](Type actual) { 123 return isCompatibleForHloTypeInference(actual, expected); 124 }; 125 auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) && 126 llvm::all_of(op->getResultTypes(), typeMatch); 127 if (!allMatch) { 128 return 
op->emitOpError( 129 "requires compatible types for all operands and results"); 130 } 131 132 return success(allMatch); 133 } 134 inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)135 static LogicalResult inferReturnTypes( 136 MLIRContext * /*context*/, Optional<Location> location, 137 ValueRange operands, DictionaryAttr /*attributes*/, 138 RegionRange /*regions*/, SmallVectorImpl<Type> &inferredReturnTypes) { 139 // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that 140 // support quantization or sparsity. 141 if (operands.empty()) 142 return emitOptionalError( 143 location, 144 "Expected non-empty operands for [CompatibleOperandsAndResultType]"); 145 146 if (failed(inferMostSpecificType(location, operands.getTypes(), 147 inferredReturnTypes))) 148 return failure(); 149 return success(); 150 } 151 152 // This function is not going to be called automatically. 153 // It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS 154 // (see examples in StablehloOps.cc). 
inferReturnTypeComponentsFromOperands(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)155 static LogicalResult inferReturnTypeComponentsFromOperands( 156 MLIRContext *context, Optional<Location> location, 157 ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, 158 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 159 SmallVector<Type> inferredReturnTypes; 160 if (failed(inferReturnTypes(context, location, operands.getValues(), 161 attributes, regions, inferredReturnTypes))) 162 return failure(); 163 auto inferredReturnType = inferredReturnTypes[0].cast<ShapedType>(); 164 inferredReturnShapes.push_back(inferredReturnType); 165 return success(); 166 } 167 168 private: 169 // Cases of infer return shape with bounds (lhs and rhs are commutative): 170 // Dim of lhs Dim of rhs Infer 171 // c0: 3 3 3 172 // c1: 3 ? 3 173 // c2: 3 ?, bound=4 3 174 // c3: 3 ?, bound=2 Error out 175 // c4: ? ? ? 176 // c5: ? 
?, bound=3 ?, bound=3 177 // c6: ?, bound=3 ?, bound=3 ?, bound=3 178 // c7: ?, bound=3 ?, bound=4 ?, bound=3 179 // This method generalizes it to multiple inputs: 1) get the static input dims 180 // (if any) as infer dim, and 2) get min of input bounds as infer bound inferMostSpecificType(Optional<Location> location,ValueTypeRange<ValueRange> inputTypes,SmallVectorImpl<Type> & inferredReturnTypes)181 static LogicalResult inferMostSpecificType( 182 Optional<Location> location, ValueTypeRange<ValueRange> inputTypes, 183 SmallVectorImpl<Type> &inferredReturnTypes) { 184 SmallVector<RankedTensorType> rankedTypes; 185 for (auto inputType : inputTypes) 186 if (auto rankedType = inputType.dyn_cast<RankedTensorType>()) 187 rankedTypes.push_back(rankedType); 188 if (rankedTypes.empty()) { 189 inferredReturnTypes.push_back(inputTypes[0]); 190 return success(); 191 } 192 193 auto rank = rankedTypes[0].getRank(); 194 BoundedDialectInterface *dialect = nullptr; 195 SmallVector<int64_t> inferredDimSizes(rank, ShapedType::kDynamicSize); 196 SmallVector<int64_t> inferredBounds(rank, ShapedType::kDynamicSize); 197 for (auto rankedType : rankedTypes) { 198 SmallVector<int64_t> bounds; 199 if (auto boundedAttr = rankedType.getEncoding() 200 .dyn_cast_or_null<BoundedAttrInterface>()) { 201 dialect = cast<BoundedDialectInterface>(&boundedAttr.getDialect()); 202 bounds = llvm::to_vector<4>(boundedAttr.getBounds()); 203 } else if (rankedType.getEncoding()) { 204 // TODO(zhouxin) infer sparsity encoding after b/238903065 is fixed. 
205 inferredReturnTypes.push_back(inputTypes[0]); 206 return success(); 207 } 208 209 for (int dim = 0; dim < rank; ++dim) { 210 // Dimensions 211 auto dimSize = rankedType.getShape()[dim]; 212 if (inferredDimSizes[dim] != ShapedType::kDynamicSize && 213 dimSize != ShapedType::kDynamicSize && 214 inferredDimSizes[dim] != dimSize) 215 return emitOptionalError(location, "Mismatch dimension size ", 216 inferredDimSizes[dim], " and ", dimSize, 217 " in dimension ", dim); 218 if (inferredDimSizes[dim] == ShapedType::kDynamicSize) 219 inferredDimSizes[dim] = dimSize; 220 221 // Bounds 222 if (!bounds.empty() && bounds[dim] != ShapedType::kDynamicSize) { 223 if (inferredBounds[dim] == ShapedType::kDynamicSize) { 224 inferredBounds[dim] = bounds[dim]; 225 } else { 226 inferredBounds[dim] = std::min(inferredBounds[dim], bounds[dim]); 227 } 228 } 229 // Error out case that the inferred bound is smaller than inferred dim 230 if (inferredBounds[dim] != ShapedType::kDynamicSize && 231 inferredBounds[dim] < inferredDimSizes[dim]) 232 return emitOptionalError(location, 233 "bound must not be less than static " 234 "dimension size but has bound ", 235 inferredBounds[dim], " vs static size ", 236 inferredDimSizes[dim], " in dimension ", 237 dim); 238 if (inferredDimSizes[dim] != ShapedType::kDynamicSize) 239 inferredBounds[dim] = ShapedType::kDynamicSize; 240 } 241 } 242 243 Attribute encoding = nullptr; 244 if (llvm::any_of(inferredBounds, 245 [](auto el) { return el != ShapedType::kDynamicSize; })) { 246 encoding = dialect->createBoundedAttr(inferredBounds); 247 } 248 inferredReturnTypes.push_back(RankedTensorType::get( 249 inferredDimSizes, rankedTypes[0].getElementType(), encoding)); 250 251 return success(); 252 } 253 }; 254 255 } // namespace OpTrait 256 } // namespace hlo 257 } // namespace mlir 258 259 #endif 260