• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/utils/hlo_utils.h"
17 
18 #include <numeric>
19 #include <string>
20 
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/IR/Attributes.h"
23 
24 namespace mlir {
25 namespace hlo {
26 
// Granularity, in bytes, to which computeMemory rounds every allocation size up.
// (A namespace-scope constexpr variable already has internal linkage.)
constexpr size_t kPaddingSize = 64;
28 
getBroadcastDimensionsAttr(Builder * b,Value x,Value y,bool allowEmpty)29 DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y,
30                                                 bool allowEmpty) {
31   TensorType xType = x.getType().dyn_cast<RankedTensorType>();
32   TensorType yType = y.getType().dyn_cast<RankedTensorType>();
33   if (!xType || !yType) return {};
34   if (allowEmpty && xType == yType) return {};
35 
36   // If the shapes have the same rank, then there is nothing to do.
37   auto xRank = xType.getRank(), yRank = yType.getRank();
38   if (allowEmpty && xRank == yRank) return {};
39 
40   // Otherwise if the ranks of the inputs don't match, TensorFlow automatically
41   // reshapes the smaller by padding with dimensions of size 1 as a prefix. In
42   // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
43   // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
44   // from lower to higher rank, but doesn't assume you want to pad as a prefix
45   // of the dimensions, and instead needs to be told which dimensions of the
46   // higher rank tensor to match to the lower rank tensor.
47   auto maxRank = std::max(xRank, yRank);
48   auto minRank = std::min(xRank, yRank);
49 
50   // Match the lower rank tensor along the larger-numbered dimensions of the
51   // higher rank tensor.
52   SmallVector<int64_t, 4> broadcastDimensions(minRank);
53   std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
54             maxRank - minRank);
55 
56   RankedTensorType type =
57       RankedTensorType::get({minRank}, b->getIntegerType(64));
58   return DenseIntElementsAttr::get(type, broadcastDimensions);
59 }
60 
getScalarOfType(Type ty,int64_t rawValue)61 DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue) {
62   RankedTensorType scalarTy = RankedTensorType::get({}, ty);
63 
64   if (auto floatTy = ty.dyn_cast<FloatType>()) {
65     APFloat value(floatTy.getFloatSemantics(), rawValue);
66     return DenseElementsAttr::get(scalarTy, value);
67   }
68   if (auto intTy = ty.dyn_cast<IntegerType>()) {
69     APInt value(intTy.getWidth(), static_cast<int64_t>(rawValue),
70                 /*isSigned=*/true);
71     return DenseElementsAttr::get(scalarTy, value);
72   }
73   if (auto complexTy = ty.dyn_cast<ComplexType>()) {
74     if (auto floatTy = complexTy.getElementType().cast<FloatType>()) {
75       APFloat real(floatTy.getFloatSemantics(), rawValue);
76       APFloat imag = APFloat::getZero(floatTy.getFloatSemantics());
77       return DenseElementsAttr::get(scalarTy,
78                                     std::complex<APFloat>(real, imag));
79     }
80   }
81   llvm_unreachable("unsupported type");
82 }
83 
getScalarNegZeroOfType(Type ty)84 DenseElementsAttr getScalarNegZeroOfType(Type ty) {
85   RankedTensorType scalarTy = RankedTensorType::get({}, ty);
86 
87   if (auto floatTy = ty.dyn_cast<FloatType>()) {
88     APFloat negZero =
89         APFloat::getZero(floatTy.getFloatSemantics(), /*Negative=*/true);
90     return DenseElementsAttr::get(scalarTy, negZero);
91   }
92   if (auto intTy = ty.dyn_cast<IntegerType>()) {
93     return DenseElementsAttr::get(scalarTy, APInt::getZero(intTy.getWidth()));
94   }
95   if (auto complexTy = ty.dyn_cast<ComplexType>()) {
96     if (auto floatTy = complexTy.getElementType().cast<FloatType>()) {
97       APFloat negZero =
98           APFloat::getZero(floatTy.getFloatSemantics(), /*Negative=*/true);
99       return DenseElementsAttr::get(scalarTy,
100                                     std::complex<APFloat>(negZero, negZero));
101     }
102   }
103   llvm_unreachable("unsupported type");
104 }
105 
getScalarLimitOfFloatType(FloatType floatTy,ScalarLimit limit)106 static APFloat getScalarLimitOfFloatType(FloatType floatTy, ScalarLimit limit) {
107   auto& semantics = floatTy.getFloatSemantics();
108   switch (limit) {
109     case kLowest:
110       return APFloat::getLargest(semantics, /*negative=*/true);
111     case kInfinityLowest:
112       return APFloat::getInf(semantics, /*negative=*/true);
113     case kMax:
114       return APFloat::getLargest(semantics, /*negative=*/false);
115     case kInfinityMax:
116       return APFloat::getInf(semantics, /*negative=*/false);
117   }
118   llvm_unreachable("invalid limit");
119 }
120 
121 // Returns a scalar value for the given integer type.
122 //
123 // The argument 'scalar' describes which scalar value to return. `integer_value`
124 // is used to specify the integer value for kInteger. For any other scalar,
125 // integer_value is ignored.
getScalarLimitOfIntegerType(IntegerType integerTy,ScalarLimit limit)126 static APInt getScalarLimitOfIntegerType(IntegerType integerTy,
127                                          ScalarLimit limit) {
128   unsigned width = integerTy.getWidth();
129   bool isBool = (width == 1);
130   switch (limit) {
131     case kLowest:
132     case kInfinityLowest:
133       if (integerTy.isUnsigned() || isBool) {
134         return APInt::getMinValue(width);
135       } else {
136         return APInt::getSignedMinValue(width);
137       }
138 
139     case kMax:
140     case kInfinityMax:
141       if (integerTy.isUnsigned() || isBool) {
142         return APInt::getMaxValue(width);
143       } else {
144         return APInt::getSignedMaxValue(width);
145       }
146   }
147   llvm_unreachable("invalid limit");
148 }
149 
getScalarLimitOfType(Type ty,ScalarLimit limit)150 DenseElementsAttr getScalarLimitOfType(Type ty, ScalarLimit limit) {
151   RankedTensorType scalarTy = RankedTensorType::get({}, ty);
152   if (auto floatTy = ty.dyn_cast<FloatType>()) {
153     return DenseElementsAttr::get(scalarTy,
154                                   getScalarLimitOfFloatType(floatTy, limit));
155   }
156   if (auto integerTy = ty.dyn_cast<IntegerType>()) {
157     return DenseElementsAttr::get(
158         scalarTy, getScalarLimitOfIntegerType(integerTy, limit));
159   }
160   llvm_unreachable("unsupported type");
161 }
162 
lmhloToMhloOpName(llvm::StringRef opName,mlir::MLIRContext * context)163 std::string lmhloToMhloOpName(llvm::StringRef opName,
164                               mlir::MLIRContext* context) {
165   assert(opName.startswith("lmhlo.") && "Expected an LMHLO op");
166 
167   if (opName == "lmhlo.dot") {
168     return "mhlo.dot_general";
169   }
170 
171   if (opName == "lmhlo.dynamic_slice") {
172     return "mhlo.dynamic_slice";
173   }
174 
175   std::string mhloOpName(opName.drop_front(1));
176   if (context->isOperationRegistered(mhloOpName)) return mhloOpName;
177   return "";
178 }
179 
isSequenceStartingWith0(Attribute attr)180 bool isSequenceStartingWith0(Attribute attr) {
181   DenseIntElementsAttr denseAttr = attr.dyn_cast<DenseIntElementsAttr>();
182   for (int64_t i = 0, e = denseAttr.getNumElements(); i < e; ++i)
183     if (denseAttr.getValues<APInt>()[i].getSExtValue() != i) return false;
184   return true;
185 }
186 
getArgumentIndex(mlir::func::FuncOp op,Value value)187 int64_t getArgumentIndex(mlir::func::FuncOp op, Value value) {
188   BlockArgument arg = value.dyn_cast<BlockArgument>();
189   if (!arg || arg.getOwner() != &op.front()) return -1;
190   return arg.getArgNumber();
191 }
192 
193 /// Computes the memory usage of the given allocations.
computeMemory(const std::vector<Value> & allocs)194 std::pair<size_t, size_t> computeMemory(const std::vector<Value>& allocs) {
195   size_t totalSize = 0;
196   size_t allocCounter = 0;
197   for (const Value alloc : allocs) {
198     auto shape = alloc.getType().cast<ShapedType>();
199     size_t shapeBytes = llvm::divideCeil(shape.getSizeInBits(), 8);
200     size_t alignFactor = llvm::divideCeil(shapeBytes, kPaddingSize);
201     size_t size = alignFactor * kPaddingSize;
202     totalSize += size;
203     allocCounter++;
204   }
205   return std::make_pair(totalSize, allocCounter);
206 }
207 
208 }  // namespace hlo
209 }  // namespace mlir
210