1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/utils/hlo_utils.h"
17
18 #include <numeric>
19
20 #include "mlir/IR/Attributes.h"
21
22 namespace mlir {
23 namespace hlo {
24
getBroadcastDimensionsAttr(Builder * b,Value x,Value y,bool allow_empty)25 DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y,
26 bool allow_empty) {
27 TensorType xType = x.getType().dyn_cast<RankedTensorType>();
28 TensorType yType = y.getType().dyn_cast<RankedTensorType>();
29 if (!xType || !yType) return {};
30 if (allow_empty && xType == yType) return {};
31
32 // If the shapes have the same rank, then there is nothing to do.
33 auto xRank = xType.getRank(), yRank = yType.getRank();
34 if (allow_empty && xRank == yRank) return {};
35
36 // Otherwise if the ranks of the inputs don't match, TensorFlow automatically
37 // reshapes the smaller by padding with dimensions of size 1 as a prefix. In
38 // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
39 // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
40 // from lower to higher rank, but doesn't assume you want to pad as a prefix
41 // of the dimensions, and instead needs to be told which dimensions of the
42 // higher rank tensor to match to the lower rank tensor.
43 auto maxRank = std::max(xRank, yRank);
44 auto minRank = std::min(xRank, yRank);
45
46 // Match the lower rank tensor along the larger-numbered dimensions of the
47 // higher rank tensor.
48 SmallVector<int64_t, 4> broadcastDimensions(minRank);
49 std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
50 maxRank - minRank);
51
52 RankedTensorType type =
53 RankedTensorType::get({minRank}, b->getIntegerType(64));
54 return DenseIntElementsAttr::get(type, broadcastDimensions);
55 }
56
GetScalarOfType(Type ty,int64_t raw_value)57 DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
58 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
59
60 if (auto float_ty = ty.dyn_cast<FloatType>()) {
61 APFloat value(float_ty.getFloatSemantics(), raw_value);
62 return DenseElementsAttr::get(scalar_ty, value);
63 } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
64 APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
65 return DenseElementsAttr::get(scalar_ty, value);
66 } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
67 Type complex_element_ty = complex_ty.getElementType();
68 if (complex_element_ty.isF32()) {
69 return DenseElementsAttr::get(
70 scalar_ty, static_cast<std::complex<float>>(raw_value));
71 } else if (complex_element_ty.isF64()) {
72 return DenseElementsAttr::get(
73 scalar_ty, static_cast<std::complex<double>>(raw_value));
74 }
75 }
76 llvm_unreachable("unsupported type");
77 }
78
GetScalarLimitOfFloatType(FloatType float_ty,ScalarLimit limit)79 static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
80 ScalarLimit limit) {
81 auto& semantics = float_ty.getFloatSemantics();
82 switch (limit) {
83 case kLowest:
84 return APFloat::getLargest(semantics, /*negative=*/true);
85 case kInfinityLowest:
86 return APFloat::getInf(semantics, /*negative=*/true);
87 case kMax:
88 return APFloat::getLargest(semantics, /*negative=*/false);
89 case kInfinityMax:
90 return APFloat::getInf(semantics, /*negative=*/false);
91 }
92 llvm_unreachable("invalid limit");
93 }
94
95 // Returns a scalar value for the given integer type.
96 //
97 // The argument 'scalar' describes which scalar value to return. `integer_value`
98 // is used to specify the integer value for kInteger. For any other scalar,
99 // integer_value is ignored.
GetScalarLimitOfIntegerType(IntegerType integer_ty,ScalarLimit limit)100 static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
101 ScalarLimit limit) {
102 unsigned width = integer_ty.getWidth();
103 bool is_bool = (width == 1);
104 switch (limit) {
105 case kLowest:
106 case kInfinityLowest:
107 if (integer_ty.isUnsigned() || is_bool) {
108 return APInt::getMinValue(width);
109 } else {
110 return APInt::getSignedMinValue(width);
111 }
112
113 case kMax:
114 case kInfinityMax:
115 if (integer_ty.isUnsigned() || is_bool) {
116 return APInt::getMaxValue(width);
117 } else {
118 return APInt::getSignedMaxValue(width);
119 }
120 }
121 llvm_unreachable("invalid limit");
122 }
123
GetScalarLimitOfType(Type ty,ScalarLimit limit)124 DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
125 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
126 if (auto float_ty = ty.dyn_cast<FloatType>()) {
127 return DenseElementsAttr::get(scalar_ty,
128 GetScalarLimitOfFloatType(float_ty, limit));
129 } else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
130 return DenseElementsAttr::get(
131 scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
132 }
133 llvm_unreachable("unsupported type");
134 }
135
LmhloToMhloOpName(llvm::StringRef op_name,mlir::MLIRContext * context)136 std::string LmhloToMhloOpName(llvm::StringRef op_name,
137 mlir::MLIRContext* context) {
138 assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
139
140 if (op_name == "lmhlo.dot") {
141 return "mhlo.dot_general";
142 }
143
144 if (op_name == "lmhlo.dynamic_slice") {
145 return "mhlo.dynamic-slice";
146 }
147
148 std::string mhlo_op_name(op_name.drop_front(1));
149 if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
150 return "";
151 }
152
IsSequenceStartingWith0(DenseIntElementsAttr attr)153 bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
154 for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
155 if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
156 return true;
157 }
158
getArgumentIndex(mlir::FuncOp op,Value value)159 int64_t getArgumentIndex(mlir::FuncOp op, Value value) {
160 BlockArgument arg = value.dyn_cast<BlockArgument>();
161 if (!arg || arg.getOwner() != &op.front()) return -1;
162 return arg.getArgNumber();
163 }
164
165 } // namespace hlo
166 } // namespace mlir
167