1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/utils/hlo_utils.h"
17
18 #include <numeric>
19
20 #include "mlir/IR/Attributes.h"
21
22 namespace mlir {
23 namespace hlo {
24
getBroadcastDimensionsAttr(Builder * b,Value x,Value y,bool allow_empty)25 DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
26 bool allow_empty) {
27 TensorType xType = x.getType().dyn_cast<RankedTensorType>();
28 TensorType yType = y.getType().dyn_cast<RankedTensorType>();
29 if (!xType || !yType) return {};
30 if (allow_empty && xType == yType) return {};
31
32 // If the shapes have the same rank, then there is nothing to do.
33 auto xRank = xType.getRank(), yRank = yType.getRank();
34 if (allow_empty && xRank == yRank) return {};
35
36 // Otherwise if the ranks of the inputs don't match, TensorFlow automatically
37 // reshapes the smaller by padding with dimensions of size 1 as a prefix. In
38 // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
39 // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
40 // from lower to higher rank, but doesn't assume you want to pad as a prefix
41 // of the dimensions, and instead needs to be told which dimensions of the
42 // higher rank tensor to match to the lower rank tensor.
43 auto maxRank = std::max(xRank, yRank);
44 auto minRank = std::min(xRank, yRank);
45
46 // Match the lower rank tensor along the larger-numbered dimensions of the
47 // higher rank tensor.
48 SmallVector<int64_t, 4> broadcastDimensions(minRank);
49 std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
50 maxRank - minRank);
51
52 RankedTensorType type =
53 RankedTensorType::get({minRank}, b->getIntegerType(64));
54 return DenseIntElementsAttr::get(type, broadcastDimensions);
55 }
56
GetScalarOfType(Type ty,int64_t raw_value)57 DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
58 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
59
60 if (auto float_ty = ty.dyn_cast<FloatType>()) {
61 APFloat value(float_ty.getFloatSemantics(), raw_value);
62 return DenseElementsAttr::get(scalar_ty, value);
63 } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
64 APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
65 return DenseElementsAttr::get(scalar_ty, value);
66 } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
67 Type complex_element_ty = complex_ty.getElementType();
68 if (complex_element_ty.isF32()) {
69 return DenseElementsAttr::get(
70 scalar_ty, static_cast<std::complex<float>>(raw_value));
71 } else if (complex_element_ty.isF64()) {
72 return DenseElementsAttr::get(
73 scalar_ty, static_cast<std::complex<double>>(raw_value));
74 }
75 }
76 llvm_unreachable("unsupported type");
77 }
78
GetScalarLimitOfFloatType(FloatType float_ty,ScalarLimit limit)79 static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
80 ScalarLimit limit) {
81 auto &semantics = float_ty.getFloatSemantics();
82 switch (limit) {
83 case kLowest:
84 return APFloat::getLargest(semantics, /*negative=*/true);
85 case kInfinityLowest:
86 return APFloat::getInf(semantics, /*negative=*/true);
87 case kMax:
88 return APFloat::getLargest(semantics, /*negative=*/false);
89 case kInfinityMax:
90 return APFloat::getInf(semantics, /*negative=*/false);
91 }
92 llvm_unreachable("invalid limit");
93 }
94
95 // Returns a scalar value for the given integer type.
96 //
97 // The argument 'scalar' describes which scalar value to return. `integer_value`
98 // is used to specify the integer value for kInteger. For any other scalar,
99 // integer_value is ignored.
GetScalarLimitOfIntegerType(IntegerType integer_ty,ScalarLimit limit)100 static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
101 ScalarLimit limit) {
102 unsigned width = integer_ty.getWidth();
103 switch (limit) {
104 case kLowest:
105 case kInfinityLowest:
106 if (integer_ty.isUnsigned()) {
107 return APInt::getMinValue(width);
108 } else {
109 return APInt::getSignedMinValue(width);
110 }
111
112 case kMax:
113 case kInfinityMax:
114 if (integer_ty.isUnsigned()) {
115 return APInt::getMaxValue(width);
116 } else {
117 return APInt::getSignedMaxValue(width);
118 }
119 }
120 llvm_unreachable("invalid limit");
121 }
122
GetScalarLimitOfType(Type ty,ScalarLimit limit)123 DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
124 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
125 if (auto float_ty = ty.dyn_cast<FloatType>()) {
126 return DenseElementsAttr::get(scalar_ty,
127 GetScalarLimitOfFloatType(float_ty, limit));
128 } else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
129 return DenseElementsAttr::get(
130 scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
131 }
132 llvm_unreachable("unsupported type");
133 }
134
LmhloToMhloOpName(llvm::StringRef op_name,mlir::MLIRContext * context)135 std::string LmhloToMhloOpName(llvm::StringRef op_name,
136 mlir::MLIRContext *context) {
137 assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");
138
139 if (op_name == "lmhlo.dot") {
140 return "mhlo.dot_general";
141 }
142
143 if (op_name == "lmhlo.dynamic_slice") {
144 return "mhlo.dynamic-slice";
145 }
146
147 std::string mhlo_op_name(op_name.drop_front(1));
148 if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
149 return "";
150 }
151
IsSequenceStartingWith0(DenseIntElementsAttr attr)152 bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
153 for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
154 if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
155 return true;
156 }
157
158 } // namespace hlo
159 } // namespace mlir
160