/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/utils/hlo_utils.h"

#include <complex>
#include <numeric>
#include <string>
#include <utility>

#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"

namespace mlir {
namespace hlo {

static constexpr size_t kPaddingSize = 64;

DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y,
                                                bool allowEmpty) {
  TensorType xType = x.getType().dyn_cast<RankedTensorType>();
  TensorType yType = y.getType().dyn_cast<RankedTensorType>();
  if (!xType || !yType) return {};
  if (allowEmpty && xType == yType) return {};

  // If the shapes have the same rank, then there is nothing to do.
  auto xRank = xType.getRank(), yRank = yType.getRank();
  if (allowEmpty && xRank == yRank) return {};

  // Otherwise, if the ranks of the inputs don't match, TensorFlow
  // automatically reshapes the smaller operand by padding with dimensions of
  // size 1 as a prefix. In other words, to pad a 5-vector to a 3-dimensional
  // tensor it is reshaped to have shape [1,1,5]. XLA's automatic broadcast
  // code is able to broadcast from lower to higher rank, but doesn't assume
  // you want to pad as a prefix of the dimensions, and instead needs to be
  // told which dimensions of the higher rank tensor to match to the lower
  // rank tensor.
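  //
  // For example, broadcasting a tensor of shape [5] against one of shape
  // [2,3,5] yields broadcast_dimensions = [2]: the single dimension of the
  // 5-vector is matched to the last dimension of the rank-3 operand.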
  auto maxRank = std::max(xRank, yRank);
  auto minRank = std::min(xRank, yRank);

  // Match the lower rank tensor along the larger-numbered dimensions of the
  // higher rank tensor.
  SmallVector<int64_t, 4> broadcastDimensions(minRank);
  std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
            maxRank - minRank);

  RankedTensorType type =
      RankedTensorType::get({minRank}, b->getIntegerType(64));
  return DenseIntElementsAttr::get(type, broadcastDimensions);
}

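// Returns a rank-0 tensor attribute of type `ty` holding `rawValue`. For
// complex element types the value becomes the real part and the imaginary
// part is zero.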
DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue) {
  RankedTensorType scalarTy = RankedTensorType::get({}, ty);

  if (auto floatTy = ty.dyn_cast<FloatType>()) {
    APFloat value(floatTy.getFloatSemantics(), rawValue);
    return DenseElementsAttr::get(scalarTy, value);
  }
  if (auto intTy = ty.dyn_cast<IntegerType>()) {
    APInt value(intTy.getWidth(), static_cast<int64_t>(rawValue),
                /*isSigned=*/true);
    return DenseElementsAttr::get(scalarTy, value);
  }
  if (auto complexTy = ty.dyn_cast<ComplexType>()) {
    if (auto floatTy = complexTy.getElementType().cast<FloatType>()) {
      APFloat real(floatTy.getFloatSemantics(), rawValue);
      APFloat imag = APFloat::getZero(floatTy.getFloatSemantics());
      return DenseElementsAttr::get(scalarTy,
                                    std::complex<APFloat>(real, imag));
    }
  }
  llvm_unreachable("unsupported type");
}

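// Returns a rank-0 tensor attribute holding -0.0 for floating-point and
// complex types. Integer types have no negative zero, so plain zero is
// returned instead.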
DenseElementsAttr getScalarNegZeroOfType(Type ty) {
  RankedTensorType scalarTy = RankedTensorType::get({}, ty);

  if (auto floatTy = ty.dyn_cast<FloatType>()) {
    APFloat negZero =
        APFloat::getZero(floatTy.getFloatSemantics(), /*Negative=*/true);
    return DenseElementsAttr::get(scalarTy, negZero);
  }
  if (auto intTy = ty.dyn_cast<IntegerType>()) {
    return DenseElementsAttr::get(scalarTy, APInt::getZero(intTy.getWidth()));
  }
  if (auto complexTy = ty.dyn_cast<ComplexType>()) {
    if (auto floatTy = complexTy.getElementType().cast<FloatType>()) {
      APFloat negZero =
          APFloat::getZero(floatTy.getFloatSemantics(), /*Negative=*/true);
      return DenseElementsAttr::get(scalarTy,
                                    std::complex<APFloat>(negZero, negZero));
    }
  }
  llvm_unreachable("unsupported type");
}

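// Returns the requested limit for the given float type: the largest finite
// value for kLowest/kMax, or the corresponding infinity for
// kInfinityLowest/kInfinityMax.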
static APFloat getScalarLimitOfFloatType(FloatType floatTy, ScalarLimit limit) {
  auto& semantics = floatTy.getFloatSemantics();
  switch (limit) {
    case kLowest:
      return APFloat::getLargest(semantics, /*negative=*/true);
    case kInfinityLowest:
      return APFloat::getInf(semantics, /*negative=*/true);
    case kMax:
      return APFloat::getLargest(semantics, /*negative=*/false);
    case kInfinityMax:
      return APFloat::getInf(semantics, /*negative=*/false);
  }
  llvm_unreachable("invalid limit");
}

// Returns the scalar limit value for the given integer type.
//
// The argument `limit` selects which limit to return. Since integer types
// have no infinities, kInfinityLowest maps to the minimum value and
// kInfinityMax to the maximum value, the same as kLowest and kMax.
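//
// For example, for i8 the lowest limit is -128 and the max is 127; for ui8
// (and for i1, which is treated as unsigned here) the limits are 0 and 255
// (0 and 1 for i1).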
static APInt getScalarLimitOfIntegerType(IntegerType integerTy,
                                         ScalarLimit limit) {
  unsigned width = integerTy.getWidth();
  bool isBool = (width == 1);
  switch (limit) {
    case kLowest:
    case kInfinityLowest:
      if (integerTy.isUnsigned() || isBool) {
        return APInt::getMinValue(width);
      } else {
        return APInt::getSignedMinValue(width);
      }

    case kMax:
    case kInfinityMax:
      if (integerTy.isUnsigned() || isBool) {
        return APInt::getMaxValue(width);
      } else {
        return APInt::getSignedMaxValue(width);
      }
  }
  llvm_unreachable("invalid limit");
}

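// Returns a rank-0 tensor attribute holding the requested limit of `ty`.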
DenseElementsAttr getScalarLimitOfType(Type ty, ScalarLimit limit) {
  RankedTensorType scalarTy = RankedTensorType::get({}, ty);
  if (auto floatTy = ty.dyn_cast<FloatType>()) {
    return DenseElementsAttr::get(scalarTy,
                                  getScalarLimitOfFloatType(floatTy, limit));
  }
  if (auto integerTy = ty.dyn_cast<IntegerType>()) {
    return DenseElementsAttr::get(
        scalarTy, getScalarLimitOfIntegerType(integerTy, limit));
  }
  llvm_unreachable("unsupported type");
}

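// Maps an LMHLO op name to its MHLO counterpart. Most op names differ only by
// the leading "l" in the dialect prefix, e.g. "lmhlo.add" -> "mhlo.add"; a few
// ops are special-cased below. Returns the empty string when no matching MHLO
// op is registered.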
std::string lmhloToMhloOpName(llvm::StringRef opName,
                              mlir::MLIRContext* context) {
  assert(opName.startswith("lmhlo.") && "Expected an LMHLO op");

  if (opName == "lmhlo.dot") {
    return "mhlo.dot_general";
  }

  if (opName == "lmhlo.dynamic_slice") {
    return "mhlo.dynamic_slice";
  }

  std::string mhloOpName(opName.drop_front(1));
  if (context->isOperationRegistered(mhloOpName)) return mhloOpName;
  return "";
}

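// Returns true if `attr` is a DenseIntElementsAttr holding the sequence
// 0, 1, ..., n-1.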
bool isSequenceStartingWith0(Attribute attr) {
  DenseIntElementsAttr denseAttr = attr.dyn_cast<DenseIntElementsAttr>();
  // dyn_cast returns null for non-dense-integer attributes; guard against
  // dereferencing it.
  if (!denseAttr) return false;
  for (int64_t i = 0, e = denseAttr.getNumElements(); i < e; ++i)
    if (denseAttr.getValues<APInt>()[i].getSExtValue() != i) return false;
  return true;
}

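// Returns the index of `value` among the entry-block arguments of `op`, or -1
// if `value` is not an entry-block argument.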
int64_t getArgumentIndex(mlir::func::FuncOp op, Value value) {
  BlockArgument arg = value.dyn_cast<BlockArgument>();
  if (!arg || arg.getOwner() != &op.front()) return -1;
  return arg.getArgNumber();
}

/// Computes the memory usage of the given allocations: the total padded size
/// in bytes together with the number of allocations.
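/// For example, a 40-byte buffer is padded up to kPaddingSize = 64 bytes, so
/// two such buffers yield a total of 128 bytes and a count of 2.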
std::pair<size_t, size_t> computeMemory(const std::vector<Value>& allocs) {
  size_t totalSize = 0;
  size_t allocCounter = 0;
  for (const Value alloc : allocs) {
    auto shape = alloc.getType().cast<ShapedType>();
    size_t shapeBytes = llvm::divideCeil(shape.getSizeInBits(), 8);
    size_t alignFactor = llvm::divideCeil(shapeBytes, kPaddingSize);
    size_t size = alignFactor * kPaddingSize;
    totalSize += size;
    allocCounter++;
  }
  return std::make_pair(totalSize, allocCounter);
}

}  // namespace hlo
}  // namespace mlir