• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2    Copyright 2022 The StableHLO Authors.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #include "dialect/Base.h"
18 
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "mlir/Dialect/Quant/QuantTypes.h"
21 #include "mlir/Dialect/Shape/IR/Shape.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Support/LLVM.h"
25 
26 // Include order matters
27 #include "dialect/BaseAttrInterfaces.cpp.inc"
28 
29 namespace mlir {
30 namespace hlo {
31 
32 namespace {
getExpressedTypeOrSelf(Type type)33 Type getExpressedTypeOrSelf(Type type) {
34   auto quantType = type.dyn_cast<quant::QuantizedType>();
35   return quantType ? quantType.getExpressedType() : type;
36 }
37 
verifyCompatibleShapeWithBounds(Type type1,Type type2)38 LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
39   if (failed(verifyCompatibleShape(type1, type2))) return failure();
40 
41   // Verify shapes against bounds
42   auto isCompatible = [](ArrayRef<int64_t> shape,
43                          BoundedAttrInterface boundedAttr) {
44     if (shape.empty() || !boundedAttr) return true;
45     auto bounds = boundedAttr.getBounds();
46     for (auto [dim_size, bound] : llvm::zip(shape, bounds))  // NOLINT
47       if (bound != ShapedType::kDynamicSize && bound < dim_size) return false;
48     return true;
49   };
50 
51   RankedTensorType rankedType1 = type1.dyn_cast<RankedTensorType>();
52   RankedTensorType rankedType2 = type2.dyn_cast<RankedTensorType>();
53   if (rankedType1 && rankedType2) {
54     auto boundedAttr1 =
55         rankedType1.getEncoding().dyn_cast_or_null<BoundedAttrInterface>();
56     auto boundedAttr2 =
57         rankedType2.getEncoding().dyn_cast_or_null<BoundedAttrInterface>();
58     return LogicalResult::success(
59         isCompatible(rankedType1.getShape(), boundedAttr2) &&
60         isCompatible(rankedType2.getShape(), boundedAttr1));
61   }
62   return success();
63 }
64 }  // namespace
65 
isCompatibleForHloTypeInference(Type tp1,Type tp2)66 bool isCompatibleForHloTypeInference(Type tp1, Type tp2) {
67   // Dynamism: We don't require shapes to be the same, we only require them
68   // to be compatible, which means that:
69   //   1) At least one of the shapes is unranked.
70   //   2) Or both shapes have the same rank and their dimensions are compatible,
71   //     i.e. for each pair of corresponding dimensions:
72   //       2.1) At least one of the dimensions is dynamic,
73   //       2.2) Or both dimensions are equal.
74   // These relaxed rules simplify the implementation of type inference, allowing
75   // ops with partially inferred types to pass verification.
76   auto stp1 = tp1.dyn_cast<ShapedType>();
77   auto stp2 = tp2.dyn_cast<ShapedType>();
78   if (stp1 && stp2) {
79     return succeeded(verifyCompatibleShapeWithBounds(stp1, stp2)) &&
80            isCompatibleForHloTypeInference(stp1.getElementType(),
81                                            stp2.getElementType());
82   }
83 
84   // Quantization: In the most general case, we allow any combination of
85   // quantized/non-quantized across any combination of operands/results,
86   // and some differences in quantization parameters across operands/results.
87   // Individual ops may introduce additional constraints.
88   auto qtp1 = tp1.dyn_cast<quant::QuantizedType>();
89   auto qtp2 = tp2.dyn_cast<quant::QuantizedType>();
90   if (qtp1 && qtp2) {
91     if (qtp1.getStorageType() != qtp2.getStorageType() ||
92         qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() ||
93         qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax())
94       return false;
95   }
96   auto etp1 = getExpressedTypeOrSelf(tp1);
97   auto etp2 = getExpressedTypeOrSelf(tp2);
98 
99   // Sparsity: In the most general case, we allow any combination of
100   // sparsity/denseness across any combination of operands/results, as well as
101   // differences in sparsity encodings for operands and results.
102   // Individual ops may introduce additional constraints.
103   // No additional code is needed to check this because of how sparsity is
104   // currently implemented.
105 
106   // Default case: Unless dynamism, quantization and/or sparsity are involved,
107   // the types are required to be exactly equal.
108   return etp1 == etp2;
109 }
110 
deriveShapeFromOperand(OpBuilder * builder,Operation * op,Value operand,SmallVectorImpl<Value> * reifiedReturnShapes)111 LogicalResult deriveShapeFromOperand(
112     OpBuilder* builder, Operation* op, Value operand,
113     SmallVectorImpl<Value>* reifiedReturnShapes) {
114   auto shapedTy = operand.getType().dyn_cast<ShapedType>();
115   if (!shapedTy) {
116     op->emitOpError() << "operand is not a shaped type";
117     return failure();
118   }
119   reifiedReturnShapes->assign(
120       {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
121   return success();
122 }
123 
getSameShapeTensorType(TensorType tensorType,Type elementType)124 TensorType getSameShapeTensorType(TensorType tensorType, Type elementType) {
125   if (auto rankedTensorTy = tensorType.dyn_cast<RankedTensorType>()) {
126     return RankedTensorType::get(rankedTensorTy.getShape(), elementType,
127                                  rankedTensorTy.getEncoding());
128   }
129   if (auto unrankedTensorTy = tensorType.dyn_cast<UnrankedTensorType>()) {
130     return UnrankedTensorType::get(elementType);
131   }
132   llvm_unreachable("unhandled type");
133 }
134 
135 // TODO(hinsu): Add verification for bounds that it has the same size as rank
136 // of the tensor and static dimensions don't have bounds.
verifyBounds(ArrayRef<int64_t>,ShapedType,function_ref<InFlightDiagnostic ()>)137 LogicalResult verifyBounds(ArrayRef<int64_t> /*bounds*/, ShapedType /*type*/,
138                            function_ref<InFlightDiagnostic()> /*emitError*/) {
139   return success();
140 }
141 
142 }  // namespace hlo
143 }  // namespace mlir
144