/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/utils/hlo_utils.h"

#include <numeric>

#include "mlir/IR/Attributes.h"

namespace mlir {
namespace hlo {

DenseIntElementsAttr getBroadcastDimensionsAttr(Builder* b, Value x, Value y,
                                                bool allow_empty) {
  TensorType xType = x.getType().dyn_cast<RankedTensorType>();
  TensorType yType = y.getType().dyn_cast<RankedTensorType>();
  if (!xType || !yType) return {};
  if (allow_empty && xType == yType) return {};

  // If the shapes have the same rank, then there is nothing to do.
  auto xRank = xType.getRank(), yRank = yType.getRank();
  if (allow_empty && xRank == yRank) return {};

  // Otherwise if the ranks of the inputs don't match, TensorFlow automatically
  // reshapes the smaller by padding with dimensions of size 1 as a prefix. In
  // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
  // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
  // from lower to higher rank, but doesn't assume you want to pad as a prefix
  // of the dimensions, and instead needs to be told which dimensions of the
  // higher rank tensor to match to the lower rank tensor.
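  //
  // For example, broadcasting a shape [5] tensor against a shape [2,3,5]
  // tensor gives minRank = 1 and maxRank = 3 below, so the computed
  // attribute is broadcast_dimensions = [2]: the vector is matched to the
  // last dimension of the rank-3 operand.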
  auto maxRank = std::max(xRank, yRank);
  auto minRank = std::min(xRank, yRank);

  // Match the lower rank tensor along the larger-numbered dimensions of the
  // higher rank tensor.
  SmallVector<int64_t, 4> broadcastDimensions(minRank);
  std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
            maxRank - minRank);

  RankedTensorType type =
      RankedTensorType::get({minRank}, b->getIntegerType(64));
  return DenseIntElementsAttr::get(type, broadcastDimensions);
}

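// Builds a rank-0 (splat) DenseElementsAttr of the given element type from an
// integer seed value. As a sketch of the intended use, with an OpBuilder `b`
// (name assumed here for illustration): GetScalarOfType(b.getF32Type(), 0)
// yields a scalar attribute holding 0.0f, and
// GetScalarOfType(b.getIntegerType(32), -1) one holding -1.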
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    APFloat value(float_ty.getFloatSemantics(), raw_value);
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (auto int_ty = ty.dyn_cast<IntegerType>()) {
    APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
    Type complex_element_ty = complex_ty.getElementType();
    if (complex_element_ty.isF32()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<float>>(raw_value));
    } else if (complex_element_ty.isF64()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<double>>(raw_value));
    }
  }
  llvm_unreachable("unsupported type");
}
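// Returns the requested limit as an APFloat in the type's semantics. For f32,
// kLowest and kMax are the most negative and most positive finite values
// (-FLT_MAX and +FLT_MAX), while kInfinityLowest and kInfinityMax are -inf
// and +inf.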
static APFloat GetScalarLimitOfFloatType(FloatType float_ty,
                                         ScalarLimit limit) {
  auto& semantics = float_ty.getFloatSemantics();
  switch (limit) {
    case kLowest:
      return APFloat::getLargest(semantics, /*negative=*/true);
    case kInfinityLowest:
      return APFloat::getInf(semantics, /*negative=*/true);
    case kMax:
      return APFloat::getLargest(semantics, /*negative=*/false);
    case kInfinityMax:
      return APFloat::getInf(semantics, /*negative=*/false);
  }
  llvm_unreachable("invalid limit");
}

// Returns the scalar limit value of the requested kind for the given integer
// type. Integers have no infinity values, so kInfinityLowest and kInfinityMax
// fall back to the finite kLowest and kMax limits.
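//
// For example, an i8 (signless or signed) yields -128 for kLowest and 127 for
// kMax; an explicitly unsigned ui8 yields 0 and 255; and i1 is always treated
// as unsigned, giving 0 and 1.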
static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty,
                                         ScalarLimit limit) {
  unsigned width = integer_ty.getWidth();
  bool is_bool = (width == 1);
  switch (limit) {
    case kLowest:
    case kInfinityLowest:
      if (integer_ty.isUnsigned() || is_bool) {
        return APInt::getMinValue(width);
      } else {
        return APInt::getSignedMinValue(width);
      }

    case kMax:
    case kInfinityMax:
      if (integer_ty.isUnsigned() || is_bool) {
        return APInt::getMaxValue(width);
      } else {
        return APInt::getSignedMaxValue(width);
      }
  }
  llvm_unreachable("invalid limit");
}

DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    return DenseElementsAttr::get(scalar_ty,
                                  GetScalarLimitOfFloatType(float_ty, limit));
  } else if (auto integer_ty = ty.dyn_cast<IntegerType>()) {
    return DenseElementsAttr::get(
        scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit));
  }
  llvm_unreachable("unsupported type");
}

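// Maps an LMHLO op name to the name of its MHLO counterpart. Beyond the two
// special cases handled below, the rule is simply to drop the leading 'l':
// for instance LmhloToMhloOpName("lmhlo.add", ctx) returns "mhlo.add"
// (assuming the MHLO dialect is registered in the context `ctx`, a name used
// here only for illustration). Names whose result is not a registered
// operation map to "".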
std::string LmhloToMhloOpName(llvm::StringRef op_name,
                              mlir::MLIRContext* context) {
  assert(op_name.startswith("lmhlo.") && "Expected an LMHLO op");

  if (op_name == "lmhlo.dot") {
    return "mhlo.dot_general";
  }

  if (op_name == "lmhlo.dynamic_slice") {
    return "mhlo.dynamic-slice";
  }

  // Drop the leading 'l' so that "lmhlo.foo" becomes "mhlo.foo".
  std::string mhlo_op_name(op_name.drop_front(1));
  if (context->isOperationRegistered(mhlo_op_name)) return mhlo_op_name;
  return "";
}
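// Returns true iff `attr` holds exactly the sequence 0, 1, ..., N-1: true for
// [0, 1, 2] (and, trivially, for an empty attribute), false for [0, 2] or
// [1, 2, 3].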
bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
  for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
    if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
  return true;
}
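// Returns the position of `value` among the arguments of `op`'s entry block,
// or -1 if it is not one of them; e.g. within func @f(%arg0, %arg1), passing
// %arg1 returns 1.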
int64_t getArgumentIndex(mlir::FuncOp op, Value value) {
  BlockArgument arg = value.dyn_cast<BlockArgument>();
  if (!arg || arg.getOwner() != &op.front()) return -1;
  return arg.getArgNumber();
}

}  // namespace hlo
}  // namespace mlir