• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2    Copyright 2022 The StableHLO Authors.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #ifndef STABLEHLO_DIALECT_BASE_H
18 #define STABLEHLO_DIALECT_BASE_H
19 
20 #include <algorithm>
21 
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinAttributes.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/Diagnostics.h"
29 #include "mlir/IR/DialectInterface.h"
30 #include "mlir/IR/MLIRContext.h"
31 #include "mlir/IR/OpDefinition.h"
32 #include "mlir/IR/Operation.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/IR/Value.h"
35 #include "mlir/Interfaces/InferTypeOpInterface.h"
36 #include "mlir/Support/LogicalResult.h"
37 
38 // Include order matters
39 #include "dialect/BaseAttrInterfaces.h.inc"
40 
41 namespace mlir {
42 namespace hlo {
43 
44 // Returns true if the given types are the same for the purposes of HLO type
45 // inference, accounting for special properties of quantization and sparsity.
46 bool isCompatibleForHloTypeInference(Type tp1, Type tp2);
47 
48 // Shape derivation function that computes the shape of the result based on an
49 // operand. For a 2-dimensional input tensor, this produces IR of the form
50 //
51 //  %0 = dim %arg0, 0 : memref<?x?xf32>
52 //  %1 = index_cast %0 : index to i64
53 //  %2 = dim %arg0, 1 : memref<?x?xf32>
54 //  %3 = index_cast %2 : index to i64
55 //  %4 = "shape.shape_of"(%1, %3)
56 //    : (i64, i64) -> tensor<2xi64>
57 //
58 // and returns %4 as the shape value.
59 LogicalResult deriveShapeFromOperand(
60     OpBuilder *builder, Operation *op, Value operand,
61     SmallVectorImpl<Value> *reifiedReturnShapes);
62 
63 // Type derivation function that returns a tensor type with a new element type.
64 TensorType getSameShapeTensorType(TensorType tensorType, Type elementType);
65 
66 // Verify bounds expressed by HLO_BoundedInterface against the provided type.
67 // See documentation for HLO_BoundedInterface for the list of checks.
68 LogicalResult verifyBounds(ArrayRef<int64_t> bounds, ShapedType type,
69                            function_ref<InFlightDiagnostic()> emitError);
70 
71 // This interface is used for HLO dialects that have accompanying
72 // BoundedAttrInterface attributes which can carry bounds for dimension sizes
73 // of accompanying shaped types.
74 class BoundedDialectInterface
75     : public DialectInterface::Base<BoundedDialectInterface> {
76  public:
BoundedDialectInterface(Dialect * dialect)77   explicit BoundedDialectInterface(Dialect *dialect) : Base(dialect) {}
78   virtual Attribute createBoundedAttr(ArrayRef<int64_t> bounds) const = 0;
79 };
80 
81 namespace OpTrait {
82 
83 template <typename ConcreteType>
84 class BroadcastingElementwise
85     : public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
86 
87 template <typename ConcreteType>
88 class PairwiseSameOperandAndResultType
89     : public mlir::OpTrait::TraitBase<ConcreteType,
90                                       PairwiseSameOperandAndResultType> {
91  public:
verifyTrait(Operation * op)92   static LogicalResult verifyTrait(Operation *op) {
93     const int numOperands = op->getNumOperands();
94     const int numResults = op->getNumResults();
95     if (numOperands != numResults) {
96       return op->emitOpError()
97              << "requires the same number of operands and results";
98     }
99 
100     for (int idx : llvm::seq<int>(0, numOperands)) {
101       if (op->getOperand(idx).getType() != op->getResult(idx).getType()) {
102         return op->emitOpError()
103                << "requires the same type for operand and result at index "
104                << idx;
105       }
106     }
107     return success();
108   }
109 };
110 
111 template <typename ConcreteType>
112 class CompatibleOperandsAndResultType
113     : public mlir::OpTrait::TraitBase<ConcreteType,
114                                       CompatibleOperandsAndResultType> {
115  public:
verifyTrait(Operation * op)116   static LogicalResult verifyTrait(Operation *op) {
117     Type expected;
118     if (op->getNumResults() != 0) expected = op->getResult(0).getType();
119     if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
120     if (!expected) return failure();
121 
122     auto typeMatch = [&](Type actual) {
123       return isCompatibleForHloTypeInference(actual, expected);
124     };
125     auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) &&
126                     llvm::all_of(op->getResultTypes(), typeMatch);
127     if (!allMatch) {
128       return op->emitOpError(
129           "requires compatible types for all operands and results");
130     }
131 
132     return success(allMatch);
133   }
134 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)135   static LogicalResult inferReturnTypes(
136       MLIRContext * /*context*/, Optional<Location> location,
137       ValueRange operands, DictionaryAttr /*attributes*/,
138       RegionRange /*regions*/, SmallVectorImpl<Type> &inferredReturnTypes) {
139     // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
140     // support quantization or sparsity.
141     if (operands.empty())
142       return emitOptionalError(
143           location,
144           "Expected non-empty operands for [CompatibleOperandsAndResultType]");
145 
146     if (failed(inferMostSpecificType(location, operands.getTypes(),
147                                      inferredReturnTypes)))
148       return failure();
149     return success();
150   }
151 
152   // This function is not going to be called automatically.
153   // It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS
154   // (see examples in StablehloOps.cc).
inferReturnTypeComponentsFromOperands(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)155   static LogicalResult inferReturnTypeComponentsFromOperands(
156       MLIRContext *context, Optional<Location> location,
157       ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
158       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
159     SmallVector<Type> inferredReturnTypes;
160     if (failed(inferReturnTypes(context, location, operands.getValues(),
161                                 attributes, regions, inferredReturnTypes)))
162       return failure();
163     auto inferredReturnType = inferredReturnTypes[0].cast<ShapedType>();
164     inferredReturnShapes.push_back(inferredReturnType);
165     return success();
166   }
167 
168  private:
169   // Cases of infer return shape with bounds (lhs and rhs are commutative):
170   //       Dim of lhs     Dim of rhs      Infer
171   //  c0:  3              3               3
172   //  c1:  3              ?               3
173   //  c2:  3              ?, bound=4      3
174   //  c3:  3              ?, bound=2      Error out
175   //  c4:  ?              ?               ?
176   //  c5:  ?              ?, bound=3      ?, bound=3
177   //  c6:  ?, bound=3     ?, bound=3      ?, bound=3
178   //  c7:  ?, bound=3     ?, bound=4      ?, bound=3
179   // This method generalizes it to multiple inputs: 1) get the static input dims
180   // (if any) as infer dim, and 2) get min of input bounds as infer bound
inferMostSpecificType(Optional<Location> location,ValueTypeRange<ValueRange> inputTypes,SmallVectorImpl<Type> & inferredReturnTypes)181   static LogicalResult inferMostSpecificType(
182       Optional<Location> location, ValueTypeRange<ValueRange> inputTypes,
183       SmallVectorImpl<Type> &inferredReturnTypes) {
184     SmallVector<RankedTensorType> rankedTypes;
185     for (auto inputType : inputTypes)
186       if (auto rankedType = inputType.dyn_cast<RankedTensorType>())
187         rankedTypes.push_back(rankedType);
188     if (rankedTypes.empty()) {
189       inferredReturnTypes.push_back(inputTypes[0]);
190       return success();
191     }
192 
193     auto rank = rankedTypes[0].getRank();
194     BoundedDialectInterface *dialect = nullptr;
195     SmallVector<int64_t> inferredDimSizes(rank, ShapedType::kDynamicSize);
196     SmallVector<int64_t> inferredBounds(rank, ShapedType::kDynamicSize);
197     for (auto rankedType : rankedTypes) {
198       SmallVector<int64_t> bounds;
199       if (auto boundedAttr = rankedType.getEncoding()
200                                  .dyn_cast_or_null<BoundedAttrInterface>()) {
201         dialect = cast<BoundedDialectInterface>(&boundedAttr.getDialect());
202         bounds = llvm::to_vector<4>(boundedAttr.getBounds());
203       } else if (rankedType.getEncoding()) {
204         // TODO(zhouxin) infer sparsity encoding after b/238903065 is fixed.
205         inferredReturnTypes.push_back(inputTypes[0]);
206         return success();
207       }
208 
209       for (int dim = 0; dim < rank; ++dim) {
210         // Dimensions
211         auto dimSize = rankedType.getShape()[dim];
212         if (inferredDimSizes[dim] != ShapedType::kDynamicSize &&
213             dimSize != ShapedType::kDynamicSize &&
214             inferredDimSizes[dim] != dimSize)
215           return emitOptionalError(location, "Mismatch dimension size ",
216                                    inferredDimSizes[dim], " and ", dimSize,
217                                    " in dimension ", dim);
218         if (inferredDimSizes[dim] == ShapedType::kDynamicSize)
219           inferredDimSizes[dim] = dimSize;
220 
221         // Bounds
222         if (!bounds.empty() && bounds[dim] != ShapedType::kDynamicSize) {
223           if (inferredBounds[dim] == ShapedType::kDynamicSize) {
224             inferredBounds[dim] = bounds[dim];
225           } else {
226             inferredBounds[dim] = std::min(inferredBounds[dim], bounds[dim]);
227           }
228         }
229         // Error out case that the inferred bound is smaller than inferred dim
230         if (inferredBounds[dim] != ShapedType::kDynamicSize &&
231             inferredBounds[dim] < inferredDimSizes[dim])
232           return emitOptionalError(location,
233                                    "bound must not be less than static "
234                                    "dimension size but has bound ",
235                                    inferredBounds[dim], " vs static size ",
236                                    inferredDimSizes[dim], " in dimension ",
237                                    dim);
238         if (inferredDimSizes[dim] != ShapedType::kDynamicSize)
239           inferredBounds[dim] = ShapedType::kDynamicSize;
240       }
241     }
242 
243     Attribute encoding = nullptr;
244     if (llvm::any_of(inferredBounds,
245                      [](auto el) { return el != ShapedType::kDynamicSize; })) {
246       encoding = dialect->createBoundedAttr(inferredBounds);
247     }
248     inferredReturnTypes.push_back(RankedTensorType::get(
249         inferredDimSizes, rankedTypes[0].getElementType(), encoding));
250 
251     return success();
252   }
253 };
254 
255 }  // namespace OpTrait
256 }  // namespace hlo
257 }  // namespace mlir
258 
259 #endif
260