/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Copyright 2022 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "dialect/Base.h"

#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"

// Include order matters
#include "dialect/BaseAttrInterfaces.cpp.inc"

namespace mlir {
namespace hlo {

namespace {
Type getExpressedTypeOrSelf(Type type) {
  auto quantType = type.dyn_cast<quant::QuantizedType>();
  return quantType ? quantType.getExpressedType() : type;
}

LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2) {
  if (failed(verifyCompatibleShape(type1, type2))) return failure();

  // Verify shapes against bounds.
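  // For example, a shape [5] is incompatible with bounds [3] (the bound is
  // exceeded), whereas shapes [3], [2], and [?] all fit within bounds [3].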
  auto isCompatible = [](ArrayRef<int64_t> shape,
                         BoundedAttrInterface boundedAttr) {
    if (shape.empty() || !boundedAttr) return true;
    auto bounds = boundedAttr.getBounds();
    for (auto [dim_size, bound] : llvm::zip(shape, bounds))  // NOLINT
      if (bound != ShapedType::kDynamicSize && bound < dim_size) return false;
    return true;
  };

  RankedTensorType rankedType1 = type1.dyn_cast<RankedTensorType>();
  RankedTensorType rankedType2 = type2.dyn_cast<RankedTensorType>();
  if (rankedType1 && rankedType2) {
    auto boundedAttr1 =
        rankedType1.getEncoding().dyn_cast_or_null<BoundedAttrInterface>();
    auto boundedAttr2 =
        rankedType2.getEncoding().dyn_cast_or_null<BoundedAttrInterface>();
    return LogicalResult::success(
        isCompatible(rankedType1.getShape(), boundedAttr2) &&
        isCompatible(rankedType2.getShape(), boundedAttr1));
  }
  return success();
}
}  // namespace

bool isCompatibleForHloTypeInference(Type tp1, Type tp2) {
  // Dynamism: We don't require shapes to be the same; we only require them
  // to be compatible, which means that:
  //   1) At least one of the shapes is unranked.
  //   2) Or both shapes have the same rank and their dimensions are
  //      compatible, i.e. for each pair of corresponding dimensions:
  //      2.1) At least one of the dimensions is dynamic,
  //      2.2) Or both dimensions are equal.
  // These relaxed rules simplify the implementation of type inference,
  // allowing ops with partially inferred types to pass verification.
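  // For example, tensor<2x?xf32> is compatible with tensor<2x3xf32> and with
  // tensor<*xf32>, but tensor<2xf32> is not compatible with tensor<3xf32>.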
  auto stp1 = tp1.dyn_cast<ShapedType>();
  auto stp2 = tp2.dyn_cast<ShapedType>();
  if (stp1 && stp2) {
    return succeeded(verifyCompatibleShapeWithBounds(stp1, stp2)) &&
           isCompatibleForHloTypeInference(stp1.getElementType(),
                                           stp2.getElementType());
  }

  // Quantization: In the most general case, we allow any combination of
  // quantized/non-quantized across any combination of operands/results,
  // and some differences in quantization parameters across operands/results.
  // Individual ops may introduce additional constraints.
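  // For example, !quant.uniform<i8:f32, 0.5> is compatible with
  // !quant.uniform<i8:f32, 2.0> (scales may differ) and with plain f32 (the
  // expressed types match), but not with !quant.uniform<i16:f32, 0.5>, whose
  // storage type differs.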
  auto qtp1 = tp1.dyn_cast<quant::QuantizedType>();
  auto qtp2 = tp2.dyn_cast<quant::QuantizedType>();
  if (qtp1 && qtp2) {
    if (qtp1.getStorageType() != qtp2.getStorageType() ||
        qtp1.getStorageTypeMin() != qtp2.getStorageTypeMin() ||
        qtp1.getStorageTypeMax() != qtp2.getStorageTypeMax())
      return false;
  }
  auto etp1 = getExpressedTypeOrSelf(tp1);
  auto etp2 = getExpressedTypeOrSelf(tp2);

  // Sparsity: In the most general case, we allow any combination of
  // sparsity/denseness across any combination of operands/results, as well as
  // differences in sparsity encodings for operands and results.
  // Individual ops may introduce additional constraints.
  // No additional code is needed to check this because of how sparsity is
  // currently implemented.
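  // For example, a tensor<16xf32> with a sparse encoding is compatible with a
  // dense tensor<16xf32>: sparse encodings don't implement
  // BoundedAttrInterface, so only the shape and element type checks above
  // apply.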

  // Default case: Unless dynamism, quantization and/or sparsity are involved,
  // the types are required to be exactly equal.
  return etp1 == etp2;
}

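// Populates reifiedReturnShapes with a single shape.shape_of value computed
// from the given operand, failing if the operand is not a shaped type.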
LogicalResult deriveShapeFromOperand(
    OpBuilder* builder, Operation* op, Value operand,
    SmallVectorImpl<Value>* reifiedReturnShapes) {
  auto shapedTy = operand.getType().dyn_cast<ShapedType>();
  if (!shapedTy) {
    op->emitOpError() << "operand is not a shaped type";
    return failure();
  }
  reifiedReturnShapes->assign(
      {builder->create<shape::ShapeOfOp>(op->getLoc(), operand)});
  return success();
}

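// Returns a tensor type with the same shape (and encoding, for ranked
// tensors) as tensorType, but with the given element type.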
TensorType getSameShapeTensorType(TensorType tensorType, Type elementType) {
  if (auto rankedTensorTy = tensorType.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(rankedTensorTy.getShape(), elementType,
                                 rankedTensorTy.getEncoding());
  }
  if (auto unrankedTensorTy = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedTensorType::get(elementType);
  }
  llvm_unreachable("unhandled type");
}

// TODO(hinsu): Verify that `bounds` has the same size as the rank of the
// tensor and that statically-sized dimensions don't have bounds.
LogicalResult verifyBounds(ArrayRef<int64_t> /*bounds*/, ShapedType /*type*/,
                           function_ref<InFlightDiagnostic()> /*emitError*/) {
  return success();
}

}  // namespace hlo
}  // namespace mlir