/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

#include "llvm/ADT/APFloat.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

namespace mlir {
namespace chlo {

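// Catch-all verifier for ops without custom verification; it accepts the op
// unconditionally.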
template <typename T>
static LogicalResult Verify(T op) {
  return success();
}

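// Creates a ConstantLikeOp holding the largest finite value representable in
// `val`'s floating-point element type.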
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
}

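// Creates a ConstantLikeOp holding positive or negative infinity (per
// `negative`) in `val`'s floating-point element type.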
Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
                              bool negative) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
}

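// Creates a ConstantLikeOp holding the smallest positive value representable
// in `val`'s floating-point element type.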
Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
                                         Value val) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
}

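// Materializes `constant` as a ConstantLikeOp with the same element type as
// `val`; the result is shaped like `val`.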
Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
                      Value val) {
  Type ty = getElementTypeOrSelf(val.getType());
  return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
}

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

namespace {
// Gets the resulting type from a broadcast between two types.
static Type GetBroadcastType(Type x, Type y, Type element_type,
                             DenseIntElementsAttr broadcast_dimensions_attr) {
  auto x_ranked = x.dyn_cast<RankedTensorType>();
  auto y_ranked = y.dyn_cast<RankedTensorType>();
  if (!x_ranked || !y_ranked) {
    return UnrankedTensorType::get(element_type);
  }

  auto shape_x = x_ranked.getShape();
  auto shape_y = y_ranked.getShape();

  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0, e = shape_x.size(); i < e; i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      if (x_val == -1 || y_val == -1) {
        out_shape[i] = -1;
      } else {
        out_shape[i] = std::max(x_val, y_val);
      }
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  if (broadcast_dimensions_attr) {
    // Explicit broadcast dimensions.
    for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
      broadcast_dimensions.push_back(int_value.getSExtValue());
    }
    if (broadcast_dimensions.size() != shape_small.size()) {
      // Signal illegal broadcast_dimensions as unranked.
      return UnrankedTensorType::get(element_type);
    }
  } else {
    // If no broadcast dimensions, assume "numpy" broadcasting.
    broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
        shape_large.size() - shape_small.size(), shape_large.size()));
  }

  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
      out_shape[index_pair.value()] = new_value;
    }
  }

  return RankedTensorType::get(out_shape, element_type);
}

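// Shared inferReturnTypeComponents implementation for the broadcasting binary
// ops: checks that both operands have matching element types and derives the
// result shape via GetBroadcastType. A null `element_type` defaults to the
// lhs element type.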
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, Type element_type,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  // Find broadcast_dimensions.
  DenseIntElementsAttr broadcast_dimensions =
      attributes.get("broadcast_dimensions")
          .dyn_cast_or_null<DenseIntElementsAttr>();

  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
  if (!lhs_type || !rhs_type ||
      lhs_type.getElementType() != rhs_type.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }
  if (!element_type) element_type = lhs_type.getElementType();
  Type result_type =
      GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);

  if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
    inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
                                     element_type);
    return success();
  }

  // TODO(laurenzo): This should be constructing with `element_type` but that
  // constructor variant needs to be added upstream.
  inferedReturnShapes.emplace_back(/* element_type */);
  return success();
}

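// Shared reifyReturnTypeShapes implementation for the broadcasting binary ops:
// computes the broadcasted result extents from the two operands. Explicit
// broadcast_dimensions are only accepted in numpy-style prefix-padded form.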
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
    OpBuilder& builder, Operation* op,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto loc = op->getLoc();
  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);

  // Check for "numpy"-style rank broadcast.
  auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
                                  .dyn_cast_or_null<DenseIntElementsAttr>();
  if (broadcast_dimensions &&
      !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
    // Note: It is unclear whether the general specification of explicit
    // broadcast_dimensions on binary ops is a feature we want to carry
    // forward. While it can technically be implemented for ranked-dynamic,
    // it is incompatible with unranked inputs. If this warning is emitted
    // in real programs, it is an indication that the feature should be
    // implemented versus just falling back on the more standard definition
    // of numpy-like prefix-padding.
    return op->emitWarning()
           << "unsupported non prefix-padded dynamic rank "
           << "broadcast_dimensions = " << broadcast_dimensions;
  }

  Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
      loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
  if (!computed_shape) return failure();
  reifiedReturnShapes.push_back(computed_shape);
  return success();
}
}  // namespace

//===----------------------------------------------------------------------===//
// BroadcastComplexOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//

LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  if (!lhs_type) {
    return emitOptionalError(location, "expected ShapedType");
  }
  Type element_type = ComplexType::get(lhs_type.getElementType());
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, element_type,
                                                    inferedReturnShapes);
}
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// BroadcastCompareOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//

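// Builds a BroadcastCompareOp whose result type is the broadcast of the two
// operand shapes with an i1 element type.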
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
                               Value lhs, Value rhs,
                               DenseIntElementsAttr broadcast_dimensions,
                               StringAttr comparison_direction,
                               StringAttr compare_type) {
  auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
                                   builder.getI1Type(), broadcast_dimensions);
  build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
        comparison_direction, compare_type);
}

LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  Type element_type = IntegerType::get(context, 1);
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, element_type,
                                                    inferedReturnShapes);
}

LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// IsInfOp
//===----------------------------------------------------------------------===//

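// Returns an i1 tensor with the same shape as `operand`; used as the result
// type of the IsInf/IsNegInf/IsPosInf ops.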
static Type getIsInfLikeReturnType(Value operand) {
  Builder b(operand.getContext());
  return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
                                      b.getI1Type());
}

LogicalResult IsInfOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}

//===----------------------------------------------------------------------===//
// IsNegInfOp
//===----------------------------------------------------------------------===//

LogicalResult IsNegInfOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}

//===----------------------------------------------------------------------===//
// IsPosInfOp
//===----------------------------------------------------------------------===//

LogicalResult IsPosInfOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}

//===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===//

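// Defines Op::inferReturnTypeComponents and Op::reifyReturnTypeShapes in terms
// of the shared helpers above, deriving the result element type from the
// operands.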
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)                                \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location, ValueRange operands, \
      DictionaryAttr attributes, RegionRange regions,                         \
      SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {           \
    return InferBroadcastBinaryOpReturnTypeComponents(                        \
        context, location, operands, attributes, /*element_type=*/nullptr,    \
        inferedReturnShapes);                                                 \
  }                                                                           \
  LogicalResult Op::reifyReturnTypeShapes(                                    \
      OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {      \
    return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),    \
                                                  reifiedReturnShapes);       \
  }

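// In addition to the shape/type inference methods, defines an Op::build
// overload that computes the result type via GetBroadcastType. For example,
// BROADCAST_BINARY_OP_DEFS(BroadcastAddOp) emits BroadcastAddOp::build,
// BroadcastAddOp::inferReturnTypeComponents and
// BroadcastAddOp::reifyReturnTypeShapes.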
#define BROADCAST_BINARY_OP_DEFS(Op)                                           \
  void Op::build(OpBuilder& builder, OperationState& result, Value left,       \
                 Value right, DenseIntElementsAttr broadcast_dimensions) {     \
    auto type = GetBroadcastType(                                              \
        left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
        getElementTypeOrSelf(right.getType()), broadcast_dimensions);          \
    return Op::build(builder, result, type, left, right,                       \
                     broadcast_dimensions);                                    \
  }                                                                            \
  BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)

BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);

#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS

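// The scalar `value` attribute must have the same type as the result tensor's
// element type.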
static LogicalResult Verify(ConstantLikeOp op) {
  if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
    return op.emitOpError() << "value's type doesn't match element return type";
  return success();
}

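// The inferred result takes its shape from the operand and its element type
// from the `value` attribute.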
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  ConstantLikeOp::Adaptor op(operands, attributes);
  if (failed(op.verify(location.getValue()))) return failure();
  Type element_type = op.value().getType();
  Type operand_type = op.operand().getType();
  if (operand_type.isa<UnrankedTensorType>()) {
    inferedReturnShapes.emplace_back(element_type);
  } else {
    const auto& shape = operand_type.cast<RankedTensorType>().getShape();
    inferedReturnShapes.emplace_back(shape, element_type);
  }
  return success();
}

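// Canonicalization pattern: when the operand has a static shape, replace the
// ConstantLikeOp with an equivalent mhlo::ConstOp holding a splat of `value`.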
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
  using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConstantLikeOp op,
                                PatternRewriter& rewriter) const override {
    auto op_type = op.operand().getType().cast<ShapedType>();
    if (!op_type.hasStaticShape()) return failure();
    auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
    ElementsAttr attr = DenseElementsAttr::get(type, op.value());
    rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
    return success();
  }
};

void ConstantLikeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConstantLikeToConstant>(context);
}

}  // namespace chlo
}  // namespace mlir

#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

namespace mlir {
namespace chlo {

//===----------------------------------------------------------------------===//
// chlo Dialect Constructor
//===----------------------------------------------------------------------===//

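// Registers all generated CHLO op classes with the dialect.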
void HloClientDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}

}  // namespace chlo
}  // namespace mlir