/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/utils/hlo_utils.h"

#include <algorithm>
#include <complex>
#include <numeric>
#include <string>

#include "mlir/IR/Attributes.h"

namespace mlir {
namespace hlo {

DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
                                                bool allow_empty) {
  TensorType xType = x.getType().dyn_cast<RankedTensorType>();
  TensorType yType = y.getType().dyn_cast<RankedTensorType>();
  if (!xType || !yType) return {};
  if (allow_empty && xType == yType) return {};

  // If the shapes have the same rank, then there is nothing to do.
  auto xRank = xType.getRank(), yRank = yType.getRank();
  if (allow_empty && xRank == yRank) return {};

  // Otherwise, if the ranks of the inputs don't match, TensorFlow
  // automatically reshapes the smaller input by padding it with dimensions of
  // size 1 as a prefix. In other words, to pad a 5-vector to a 3-dimensional
  // tensor it is reshaped to have shape [1,1,5]. XLA's automatic broadcast
  // code is able to broadcast from lower to higher rank, but it doesn't assume
  // you want to pad as a prefix of the dimensions, and instead needs to be
  // told which dimensions of the higher rank tensor to match to the lower
  // rank tensor.
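  // For example, broadcasting a tensor of shape [5] against one of shape
  // [2,3,5] produces broadcast_dimensions = [2]: the vector's single dimension
  // is matched to the last dimension of the rank-3 tensor.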
  auto maxRank = std::max(xRank, yRank);
  auto minRank = std::min(xRank, yRank);

  // Match the lower rank tensor along the larger-numbered dimensions of the
  // higher rank tensor.
  SmallVector<int64_t, 4> broadcastDimensions(minRank);
  std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
            maxRank - minRank);

  RankedTensorType type =
      RankedTensorType::get({minRank}, b->getIntegerType(64));
  return DenseIntElementsAttr::get(type, broadcastDimensions);
}
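
// Returns a rank-0 (scalar) DenseElementsAttr holding `raw_value` converted
// to the element type `ty`; supports float, integer, and complex<f32>/<f64>
// element types. For example, with `ty` an f32 type and raw_value 1, the
// result is a splat of 1.0 of type tensor<f32>.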
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    APFloat value(float_ty.getFloatSemantics(), raw_value);
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
    APInt value(int_ty.getWidth(), raw_value, /*isSigned=*/true);
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
    Type complex_element_ty = complex_ty.getElementType();
    if (complex_element_ty.isF32()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<float>>(raw_value));
    } else if (complex_element_ty.isF64()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<double>>(raw_value));
    }
  }
  llvm_unreachable("unsupported type");
}
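
// Returns the requested limit of `float_ty` as an APFloat: kLowest/kMax give
// the most negative/positive finite values, kInfinityLowest/kInfinityMax
// give -inf/+inf.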
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
                                         ScalarLimit limit) {
  auto &semantics = float_ty.getFloatSemantics();
  switch (limit) {
    case kLowest:
      return APFloat::getLargest(semantics, /*negative=*/true);
    case kInfinityLowest:
      return APFloat::getInf(semantics, /*negative=*/true);
    case kMax:
      return APFloat::getLargest(semantics, /*negative=*/false);
    case kInfinityMax:
      return APFloat::getInf(semantics, /*negative=*/false);
  }
  llvm_unreachable("invalid limit");
}

// Returns the requested limit of `integer_ty` as an APInt: kLowest/kMax give
// the minimum/maximum representable value (signed or unsigned as
// appropriate). Integers have no infinities, so kInfinityLowest and
// kInfinityMax clamp to the same minimum and maximum values.
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
                                         ScalarLimit limit) {
  unsigned width = integer_ty.getWidth();
  switch (limit) {
    case kLowest:
    case kInfinityLowest:
      if (integer_ty.isUnsigned()) {
        return APInt::getMinValue(width);
      } else {
        return APInt::getSignedMinValue(width);
      }

    case kMax:
    case kInfinityMax:
      if (integer_ty.isUnsigned()) {
        return APInt::getMaxValue(width);
      } else {
        return APInt::getSignedMaxValue(width);
      }
  }
  llvm_unreachable("invalid limit");
}
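
// Returns a rank-0 DenseElementsAttr holding the requested limit of `ty`
// (float or integer), dispatching to the helpers above.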
DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    return DenseElementsAttr::get(scalar_ty,
                                  GetScalarLimitOfFloatType(float_ty, limit));
  } else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
    return DenseElementsAttr::get(
        scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
  }
  llvm_unreachable("unsupported type");
}
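
// Maps an LMHLO op name to its MHLO counterpart, e.g. "lmhlo.add" becomes
// "mhlo.add". A few ops map to differently named counterparts and are handled
// explicitly below; any name that does not correspond to a registered MHLO op
// maps to the empty string.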
std::string LmhloToMhloOpName(llvm::StringRef op_name,
                              mlir::MLIRContext *context) {
  assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");

  if (op_name == "lmhlo.dot") {
    return "mhlo.dot_general";
  }

  if (op_name == "lmhlo.dynamic_slice") {
    return "mhlo.dynamic-slice";
  }

  // Drop the leading "l" to turn "lmhlo.<op>" into "mhlo.<op>".
  std::string mhlo_op_name(op_name.drop_front(1));
  if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
  return "";
}
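
// Returns true iff `attr` holds the iota sequence 0, 1, ..., N-1, where N is
// attr.getNumElements().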
bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
  for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
    if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
  return true;
}

}  // namespace hlo
}  // namespace mlir