1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
17
18 #include "llvm/ADT/APFloat.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
20 #include "mlir-hlo/utils/broadcast_utils.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Diagnostics.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27
28 namespace mlir {
29 namespace chlo {
30
31 template <typename T>
Verify(T op)32 static LogicalResult Verify(T op) {
33 return success();
34 }
35
getConstantLikeMaxFiniteValue(OpBuilder & b,Location loc,Value val)36 Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
37 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
38 return getConstantLike(
39 b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
40 }
41
getConstantLikeInfValue(OpBuilder & b,Location loc,Value val,bool negative)42 Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
43 bool negative) {
44 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
45 return getConstantLike(
46 b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
47 }
48
getConstantLikeSmallestFiniteValue(OpBuilder & b,Location loc,Value val)49 Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
50 Value val) {
51 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
52 return getConstantLike(
53 b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
54 }
55
getConstantLike(OpBuilder & b,Location loc,const APFloat & constant,Value val)56 Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
57 Value val) {
58 Type ty = getElementTypeOrSelf(val.getType());
59 return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
60 }
61
62 //===----------------------------------------------------------------------===//
63 // BinaryOps
64 //===----------------------------------------------------------------------===//
65
66 namespace {
67 // Gets the resulting type from a broadcast between two types.
GetBroadcastType(Type x,Type y,Type element_type,DenseIntElementsAttr broadcast_dimensions_attr)68 static Type GetBroadcastType(Type x, Type y, Type element_type,
69 DenseIntElementsAttr broadcast_dimensions_attr) {
70 auto x_ranked = x.dyn_cast<RankedTensorType>();
71 auto y_ranked = y.dyn_cast<RankedTensorType>();
72 if (!x_ranked || !y_ranked) {
73 return UnrankedTensorType::get(element_type);
74 }
75
76 auto shape_x = x_ranked.getShape();
77 auto shape_y = y_ranked.getShape();
78
79 if (shape_x.size() == shape_y.size()) {
80 llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
81 for (int i = 0, e = shape_x.size(); i < e; i++) {
82 auto x_val = shape_x[i];
83 auto y_val = shape_y[i];
84 if (x_val == -1 || y_val == -1) {
85 out_shape[i] = -1;
86 } else {
87 out_shape[i] = std::max(x_val, y_val);
88 }
89 }
90 return RankedTensorType::get(out_shape, element_type);
91 }
92
93 auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
94 auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
95
96 llvm::SmallVector<int64_t, 4> broadcast_dimensions;
97 if (broadcast_dimensions_attr) {
98 // Explicit broadcast dimensions.
99 for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
100 broadcast_dimensions.push_back(int_value.getSExtValue());
101 }
102 if (broadcast_dimensions.size() != shape_small.size()) {
103 // Signal illegal broadcast_dimensions as unranked.
104 return UnrankedTensorType::get(element_type);
105 }
106 } else {
107 // If no broadcast dimensions, assume "numpy" broadcasting.
108 broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
109 shape_large.size() - shape_small.size(), shape_large.size()));
110 }
111
112 llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
113 shape_large.end());
114
115 // Update according to the broadcast dimensions.
116 for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
117 auto old_value = out_shape[index_pair.value()];
118 auto new_value = shape_small[index_pair.index()];
119 if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
120 out_shape[index_pair.value()] = new_value;
121 }
122 }
123
124 return RankedTensorType::get(out_shape, element_type);
125 }
126
InferBroadcastBinaryOpReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,Type element_type,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)127 LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
128 MLIRContext* context, Optional<Location> location, ValueRange operands,
129 DictionaryAttr attributes, Type element_type,
130 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
131 // Find broadcast_dimensions.
132 DenseIntElementsAttr broadcast_dimensions =
133 attributes.get("broadcast_dimensions")
134 .dyn_cast_or_null<DenseIntElementsAttr>();
135
136 ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
137 ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
138 if (!lhs_type || !rhs_type ||
139 lhs_type.getElementType() != rhs_type.getElementType()) {
140 return emitOptionalError(location, "mismatched operand types");
141 }
142 if (!element_type) element_type = lhs_type.getElementType();
143 Type result_type =
144 GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);
145
146 if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
147 inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
148 element_type);
149 return success();
150 }
151
152 // TODO(laurenzo): This should be constructing with `element_type` but that
153 // constructor variant needs to be added upstream.
154 inferedReturnShapes.emplace_back(/* element_type */);
155 return success();
156 }
157
ReifyBroadcastBinaryOpReturnTypeShapes(OpBuilder & builder,Operation * op,SmallVectorImpl<Value> & reifiedReturnShapes)158 LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
159 OpBuilder& builder, Operation* op,
160 SmallVectorImpl<Value>& reifiedReturnShapes) {
161 auto loc = op->getLoc();
162 auto lhs = op->getOperand(0);
163 auto rhs = op->getOperand(1);
164
165 // Check for "numpy"-style rank broadcast.
166 auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
167 .dyn_cast_or_null<DenseIntElementsAttr>();
168 if (broadcast_dimensions &&
169 !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
170 // Note: It is unclear whether the general specification of explicit
171 // broadcast_dimensions on binary ops is a feature we want to carry
172 // forward. While it can technically be implemented for ranked-dynamic,
173 // it is incompatible with unranked inputs. If this warning is emitted
174 // in real programs, it is an indication that the feature should be
175 // implemented versus just falling back on the more standard definition
176 // of numpy-like prefix-padding.
177 return op->emitWarning()
178 << "unsupported non prefix-padded dynamic rank "
179 << "broadcast_dimensions = " << broadcast_dimensions;
180 }
181
182 Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
183 loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
184 if (!computed_shape) return failure();
185 reifiedReturnShapes.push_back(computed_shape);
186 return success();
187 }
188 } // namespace
189
190 //===----------------------------------------------------------------------===//
191 // BroadcastComplexOp (has custom type inference due to different result type).
192 //===----------------------------------------------------------------------===//
193
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)194 LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
195 MLIRContext* context, Optional<Location> location, ValueRange operands,
196 DictionaryAttr attributes, RegionRange regions,
197 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
198 ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
199 if (!lhs_type) {
200 return emitOptionalError(location, "expected ShapedType");
201 }
202 Type element_type = ComplexType::get(lhs_type.getElementType());
203 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
204 attributes, element_type,
205 inferedReturnShapes);
206 }
reifyReturnTypeShapes(OpBuilder & builder,SmallVectorImpl<Value> & reifiedReturnShapes)207 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
208 OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
209 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
210 reifiedReturnShapes);
211 }
212
213 //===----------------------------------------------------------------------===//
214 // BroadcastCompareOp (has custom type inference due to different result type).
215 //===----------------------------------------------------------------------===//
216
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,DenseIntElementsAttr broadcast_dimensions,StringAttr comparison_direction,StringAttr compare_type)217 void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
218 Value lhs, Value rhs,
219 DenseIntElementsAttr broadcast_dimensions,
220 StringAttr comparison_direction,
221 StringAttr compare_type) {
222 auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
223 builder.getI1Type(), broadcast_dimensions);
224 build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
225 comparison_direction, compare_type);
226 }
227
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)228 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
229 MLIRContext* context, Optional<Location> location, ValueRange operands,
230 DictionaryAttr attributes, RegionRange regions,
231 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
232 Type element_type = IntegerType::get(context, 1);
233 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
234 attributes, element_type,
235 inferedReturnShapes);
236 }
237
reifyReturnTypeShapes(OpBuilder & builder,SmallVectorImpl<Value> & reifiedReturnShapes)238 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
239 OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
240 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
241 reifiedReturnShapes);
242 }
243
244 //===----------------------------------------------------------------------===//
245 // IsInfOp
246 //===----------------------------------------------------------------------===//
247
getIsInfLikeReturnType(Value operand)248 static Type getIsInfLikeReturnType(Value operand) {
249 Builder b(operand.getContext());
250 return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
251 b.getI1Type());
252 }
253
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)254 LogicalResult IsInfOp::inferReturnTypes(
255 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
256 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
257 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
258 return success();
259 }
260
261 //===----------------------------------------------------------------------===//
262 // IsNegInfOp
263 //===----------------------------------------------------------------------===//
264
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)265 LogicalResult IsNegInfOp::inferReturnTypes(
266 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
267 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
268 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
269 return success();
270 }
271
272 //===----------------------------------------------------------------------===//
273 // IsPosInfOp
274 //===----------------------------------------------------------------------===//
275
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)276 LogicalResult IsPosInfOp::inferReturnTypes(
277 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
278 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
279 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
280 return success();
281 }
282
283 //===----------------------------------------------------------------------===//
284 // Macros for method definitions that are common to most broadcasting ops.
285 //===----------------------------------------------------------------------===//
286
287 #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
288 LogicalResult Op::inferReturnTypeComponents( \
289 MLIRContext* context, Optional<Location> location, ValueRange operands, \
290 DictionaryAttr attributes, RegionRange regions, \
291 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
292 return InferBroadcastBinaryOpReturnTypeComponents( \
293 context, location, operands, attributes, /*element_type=*/nullptr, \
294 inferedReturnShapes); \
295 } \
296 LogicalResult Op::reifyReturnTypeShapes( \
297 OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { \
298 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \
299 reifiedReturnShapes); \
300 }
301
302 #define BROADCAST_BINARY_OP_DEFS(Op) \
303 void Op::build(OpBuilder& builder, OperationState& result, Value left, \
304 Value right, DenseIntElementsAttr broadcast_dimensions) { \
305 auto type = GetBroadcastType( \
306 left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
307 getElementTypeOrSelf(right.getType()), broadcast_dimensions); \
308 return Op::build(builder, result, type, left, right, \
309 broadcast_dimensions); \
310 } \
311 BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)
312
313 BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
314 BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
315 BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
316 BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
317 BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
318 BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
319 BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
320 BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
321 BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
322 BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
323 BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
324 BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
325 BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
326 BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
327 BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
328 BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
329 BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);
330
331 #undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
332 #undef BROADCAST_BINARY_OP_DEFS
333
Verify(ConstantLikeOp op)334 static LogicalResult Verify(ConstantLikeOp op) {
335 if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
336 return op.emitOpError() << "value's type doesn't match element return type";
337 return success();
338 }
339
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)340 LogicalResult ConstantLikeOp::inferReturnTypeComponents(
341 MLIRContext* context, Optional<Location> location, ValueRange operands,
342 DictionaryAttr attributes, RegionRange regions,
343 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
344 ConstantLikeOp::Adaptor op(operands, attributes);
345 if (failed(op.verify(location.getValue()))) return failure();
346 Type element_type = op.value().getType();
347 Type operand_type = op.operand().getType();
348 if (operand_type.isa<UnrankedTensorType>()) {
349 inferedReturnShapes.emplace_back(element_type);
350 } else {
351 const auto& shape = operand_type.cast<RankedTensorType>().getShape();
352 inferedReturnShapes.emplace_back(shape, element_type);
353 }
354 return success();
355 }
356
357 struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
358 using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
359
matchAndRewritemlir::chlo::ConstantLikeToConstant360 LogicalResult matchAndRewrite(ConstantLikeOp op,
361 PatternRewriter& rewriter) const override {
362 auto op_type = op.operand().getType().cast<ShapedType>();
363 if (!op_type.hasStaticShape()) return failure();
364 auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
365 ElementsAttr attr = DenseElementsAttr::get(type, op.value());
366 rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
367 return success();
368 }
369 };
370
// Registers ConstantLikeOp's canonicalization patterns. ConstantLikeToConstant
// replaces statically shaped constant_like ops with mhlo::ConstOp.
void ConstantLikeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConstantLikeToConstant>(context);
}
375
376 } // namespace chlo
377 } // namespace mlir
378
379 #define GET_OP_CLASSES
380 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
381
382 namespace mlir {
383 namespace chlo {
384
385 //===----------------------------------------------------------------------===//
386 // chlo Dialect Constructor
387 //===----------------------------------------------------------------------===//
388
// Registers with the dialect every operation listed in the ODS-generated
// chlo_ops.cc.inc, which is expanded here with GET_OP_LIST defined.
void HloClientDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}
395
396 } // namespace chlo
397 } // namespace mlir
398