• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
17 
18 #include "llvm/ADT/APFloat.h"
19 #include "mlir-hlo/utils/broadcast_utils.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Diagnostics.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 
27 namespace mlir {
28 namespace chlo {
29 
// Catch-all verifier: ops that do not provide a dedicated `Verify` overload
// below are accepted unconditionally. Specific ops (ConstantLikeOp,
// MinimumBroadcastShapesOp, RankSpecializationClusterOp) overload this
// template with real checks.
template <typename T>
static LogicalResult Verify(T op) {
  return success();
}
34 
getConstantLikeMaxFiniteValue(OpBuilder & b,Location loc,Value val)35 Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
36   auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
37   return getConstantLike(
38       b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
39 }
40 
getConstantLikeInfValue(OpBuilder & b,Location loc,Value val,bool negative)41 Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
42                               bool negative) {
43   auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
44   return getConstantLike(
45       b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
46 }
47 
getConstantLikeSmallestFiniteValue(OpBuilder & b,Location loc,Value val)48 Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
49                                          Value val) {
50   auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
51   return getConstantLike(
52       b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
53 }
54 
getConstantLike(OpBuilder & b,Location loc,const APFloat & constant,Value val)55 Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
56                       Value val) {
57   Type ty = getElementTypeOrSelf(val.getType());
58   return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // BinaryOps
63 //===----------------------------------------------------------------------===//
64 
65 namespace {
// Gets the resulting type from a broadcast between two types.
//
// Dynamic dimensions are encoded as -1 throughout. Returns an unranked tensor
// when either input is unranked or when explicit broadcast_dimensions are
// malformed (wrong count for the smaller-ranked operand).
static Type GetBroadcastType(Type x, Type y, Type element_type,
                             DenseIntElementsAttr broadcast_dimensions_attr) {
  auto x_ranked = x.dyn_cast<RankedTensorType>();
  auto y_ranked = y.dyn_cast<RankedTensorType>();
  if (!x_ranked || !y_ranked) {
    return UnrankedTensorType::get(element_type);
  }

  auto shape_x = x_ranked.getShape();
  auto shape_y = y_ranked.getShape();

  // Equal ranks: broadcast dimension-wise. A dynamic dim on either side makes
  // the result dim dynamic; otherwise take the larger extent.
  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0, e = shape_x.size(); i < e; i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      if (x_val == -1 || y_val == -1) {
        out_shape[i] = -1;
      } else {
        out_shape[i] = std::max(x_val, y_val);
      }
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  // Unequal ranks: map the smaller shape's dims into the larger shape.
  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  if (broadcast_dimensions_attr) {
    // Explicit broadcast dimensions.
    for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
      broadcast_dimensions.push_back(int_value.getSExtValue());
    }
    if (broadcast_dimensions.size() != shape_small.size()) {
      // Signal illegal broadcast_dimensions as unranked.
      return UnrankedTensorType::get(element_type);
    }
  } else {
    // If no broadcast dimensions, assume "numpy" broadcasting.
    broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
        shape_large.size() - shape_small.size(), shape_large.size()));
  }

  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions: a mapped small dim replaces
  // the large dim when the large dim is static and the small dim is dynamic
  // or strictly larger.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
      out_shape[index_pair.value()] = new_value;
    }
  }

  return RankedTensorType::get(out_shape, element_type);
}
125 
// Shared shape inference for broadcasting binary ops.
//
// Reads the optional `broadcast_dimensions` attribute and derives the result
// shape from the two operand shapes via GetBroadcastType. A non-null
// `element_type` overrides the result element type (used by complex/compare
// ops); otherwise the lhs element type is reused.
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, Type element_type,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  // Find broadcast_dimensions.
  DenseIntElementsAttr broadcast_dimensions =
      attributes.get("broadcast_dimensions")
          .dyn_cast_or_null<DenseIntElementsAttr>();

  // Both operands must be shaped and agree on the element type.
  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
  if (!lhs_type || !rhs_type ||
      lhs_type.getElementType() != rhs_type.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }
  if (!element_type) element_type = lhs_type.getElementType();
  Type result_type =
      GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);

  if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
    inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
                                     element_type);
    return success();
  }

  // Unranked result: emit shape components without a shape.
  // TODO(laurenzo): This should be constructing with `element_type` but that
  // constructor variant needs to be added upstream.
  inferedReturnShapes.emplace_back(/* element_type */);
  return success();
}
156 
// Shared shape reification for broadcasting binary ops: computes the result
// extent tensor at runtime from the two operand shapes. Rejects explicit
// broadcast_dimensions that are not the numpy-style prefix padding, since
// those cannot be expressed by the generic extent computation.
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
    OpBuilder& builder, Operation* op, ValueRange operands,
    SmallVectorImpl<Value>& result) {
  assert(operands.size() == 2 && "expect binary op");
  auto loc = op->getLoc();
  auto lhs = operands[0];
  auto rhs = operands[1];

  // Check for "numpy"-style rank broadcast.
  auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
                                  .dyn_cast_or_null<DenseIntElementsAttr>();
  if (broadcast_dimensions &&
      !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
    // Note: It is unclear whether the general specification of explicit
    // broadcast_dimensions on binary ops is a feature we want to carry
    // forward. While it can technically be implemented for ranked-dynamic,
    // it is incompatible with unranked inputs. If this warning is emitted
    // in real programs, it is an indication that the feature should be
    // implemented versus just falling back on the more standard definition
    // of numpy-like prefix-padding.
    return op->emitWarning()
           << "unsupported non prefix-padded dynamic rank "
           << "broadcast_dimensions = " << broadcast_dimensions;
  }

  result.push_back(hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
      loc, lhs, rhs, builder));
  return success();
}
186 }  // namespace
187 
188 //===----------------------------------------------------------------------===//
189 // BroadcastComplexOp (has custom type inference due to different result type).
190 //===----------------------------------------------------------------------===//
191 
// BroadcastComplexOp result element type is complex<T> where T is the lhs
// element type; shape inference otherwise follows the shared broadcast rules.
LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  if (!lhs_type) {
    return emitOptionalError(location, "expected ShapedType");
  }
  // Wrap the (real) operand element type in a ComplexType for the result.
  Type element_type = ComplexType::get(lhs_type.getElementType());
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, element_type,
                                                    inferedReturnShapes);
}
// Delegates shape reification to the shared broadcasting-binary-op helper.
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                operands, reifiedReturnShapes);
}
211 
212 //===----------------------------------------------------------------------===//
213 // BroadcastCompareOp (has custom type inference due to different result type).
214 //===----------------------------------------------------------------------===//
215 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,DenseIntElementsAttr broadcast_dimensions,StringAttr comparison_direction,StringAttr compare_type)216 void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
217                                Value lhs, Value rhs,
218                                DenseIntElementsAttr broadcast_dimensions,
219                                StringAttr comparison_direction,
220                                StringAttr compare_type) {
221   auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
222                                    builder.getI1Type(), broadcast_dimensions);
223   build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
224         comparison_direction, compare_type);
225 }
226 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)227 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
228     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
229     DictionaryAttr attributes, RegionRange regions,
230     SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
231   Type element_type = IntegerType::get(context, 1);
232   return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
233                                                     attributes, element_type,
234                                                     inferedReturnShapes);
235 }
236 
// Delegates shape reification to the shared broadcasting-binary-op helper.
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                operands, reifiedReturnShapes);
}
243 
244 //===----------------------------------------------------------------------===//
245 // IsInfOp
246 //===----------------------------------------------------------------------===//
247 
getIsInfLikeReturnType(Value operand)248 static Type getIsInfLikeReturnType(Value operand) {
249   Builder b(operand.getContext());
250   return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
251                                       b.getI1Type());
252 }
253 
// IsInfOp returns an i1 tensor shaped like its (single) operand.
LogicalResult IsInfOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}
260 
261 //===----------------------------------------------------------------------===//
262 // IsNegInfOp
263 //===----------------------------------------------------------------------===//
264 
// IsNegInfOp returns an i1 tensor shaped like its (single) operand.
LogicalResult IsNegInfOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}
271 
272 //===----------------------------------------------------------------------===//
273 // IsPosInfOp
274 //===----------------------------------------------------------------------===//
275 
// IsPosInfOp returns an i1 tensor shaped like its (single) operand.
LogicalResult IsPosInfOp::inferReturnTypes(
    MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
    RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
  inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
  return success();
}
282 
283 //===----------------------------------------------------------------------===//
284 // Macros for method definitions that are common to most broadcasting ops.
285 //===----------------------------------------------------------------------===//
286 
// Defines inferReturnTypeComponents/reifyReturnTypeShapes for a broadcasting
// binary op by delegating to the shared helpers above. No element-type
// override is passed, so the result element type follows the operands.
// (Comments must stay outside the macro bodies: backslash-newline splicing
// happens before comment removal, so a // comment inside would swallow the
// continuation.)
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)                             \
  LogicalResult Op::inferReturnTypeComponents(                             \
      MLIRContext* context, Optional<Location> location,                   \
      ValueShapeRange operands, DictionaryAttr attributes,                 \
      RegionRange regions,                                                 \
      SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {        \
    return InferBroadcastBinaryOpReturnTypeComponents(                     \
        context, location, operands, attributes, /*element_type=*/nullptr, \
        inferedReturnShapes);                                              \
  }                                                                        \
  LogicalResult Op::reifyReturnTypeShapes(                                 \
      OpBuilder& builder, ValueRange operands,                             \
      SmallVectorImpl<Value>& reifiedReturnShapes) {                       \
    return ReifyBroadcastBinaryOpReturnTypeShapes(                         \
        builder, getOperation(), operands, reifiedReturnShapes);           \
  }

// Adds the convenience build() overload (result type inferred via
// GetBroadcastType from the operand types) on top of the shape/type
// inference definitions above.
#define BROADCAST_BINARY_OP_DEFS(Op)                                           \
  void Op::build(OpBuilder& builder, OperationState& result, Value left,       \
                 Value right, DenseIntElementsAttr broadcast_dimensions) {     \
    auto type = GetBroadcastType(                                              \
        left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
        getElementTypeOrSelf(right.getType()), broadcast_dimensions);          \
    return Op::build(builder, result, type, left, right,                       \
                     broadcast_dimensions);                                    \
  }                                                                            \
  BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)

// Instantiate the common definitions for every broadcasting binary CHLO op.
BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastNextAfterOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);

#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS
336 
Verify(ConstantLikeOp op)337 static LogicalResult Verify(ConstantLikeOp op) {
338   if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
339     return op.emitOpError() << "value's type doesn't match element return type";
340   return success();
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // MinimumBroadcastShapesOp
345 //===----------------------------------------------------------------------===//
Verify(MinimumBroadcastShapesOp op)346 static LogicalResult Verify(MinimumBroadcastShapesOp op) {
347   // Check that the number of operands matches the number of outputs.
348   unsigned result_shapes_count = op.results().size();
349   unsigned operand_shapes_count = op.shapes().size();
350   if (operand_shapes_count != result_shapes_count) {
351     return op.emitOpError()
352            << "number of operand shapes (" << operand_shapes_count
353            << ") does not match number of result shapes ("
354            << result_shapes_count << ")";
355   }
356   if (operand_shapes_count < 2) {
357     return op.emitOpError() << "number of operand shapes ("
358                             << operand_shapes_count << ") should be >= 2";
359   }
360   return success();
361 }
362 
// The result has the shape of `operand` and the element type of the scalar
// `value` attribute.
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  ConstantLikeOp::Adaptor op(operands, attributes);
  // Run the adaptor verifier first so the attribute accesses below are safe.
  if (failed(op.verify(location.getValue()))) return failure();
  Type element_type = op.value().getType();
  Type operand_type = op.operand().getType();
  if (operand_type.isa<UnrankedTensorType>()) {
    // Unranked operand: only the element type is known.
    inferedReturnShapes.emplace_back(element_type);
  } else {
    const auto& shape = operand_type.cast<RankedTensorType>().getShape();
    inferedReturnShapes.emplace_back(shape, element_type);
  }
  return success();
}
379 
// The result shape equals the operand's shape, so reify it directly from
// operand 0.
LogicalResult ConstantLikeOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ::mlir::mhlo::deriveShapeFromOperand(
      &builder, getOperation(), operands.front(), &reifiedReturnShapes);
}
386 
// Canonicalization: a chlo.constant_like whose operand has a fully static
// shape can be folded into an mhlo.constant with a splat value.
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
  using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConstantLikeOp op,
                                PatternRewriter& rewriter) const override {
    auto op_type = op.operand().getType().cast<ShapedType>();
    // A dynamic shape cannot be materialized as a dense constant.
    if (!op_type.hasStaticShape()) return failure();
    auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
    // Splat the scalar `value` attribute across the static shape.
    ElementsAttr attr = DenseElementsAttr::get(type, op.value());
    rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
    return success();
  }
};
400 
// Registers the constant_like -> mhlo.constant folding pattern.
void ConstantLikeOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<ConstantLikeToConstant>(context);
}
405 
// Shape inference for the ternary broadcast select: the result shape is the
// broadcast of all three operands (computed as two pairwise broadcasts) with
// the element type of the on_true/on_false branches.
LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr, RegionRange,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  BroadcastSelectOp::Adaptor op(operands.getValues());
  auto pred_type = op.pred().getType().dyn_cast<ShapedType>();
  auto on_true_type = op.on_true().getType().dyn_cast<ShapedType>();
  auto on_false_type = op.on_false().getType().dyn_cast<ShapedType>();

  // All operands must be shaped; both branches must agree on element type.
  if (!pred_type || !on_true_type || !on_false_type ||
      on_true_type.getElementType() != on_false_type.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }

  Type element_type = on_true_type.getElementType();

  // Compute the result shape as two binary broadcasts.
  Type other =
      GetBroadcastType(on_true_type, on_false_type, element_type, nullptr);
  Type output = GetBroadcastType(other, pred_type, element_type, nullptr);

  inferredReturnShapes.push_back(output);
  return success();
}
430 
// Reifies the result extent tensor as the n-ary broadcast of all operand
// shapes (pred, on_true, on_false).
LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands, SmallVectorImpl<Value>& result) {
  result.push_back(hlo::ComputeNaryElementwiseBroadcastingResultExtents(
      getLoc(), operands, builder));
  return success();
}
437 
438 //===----------------------------------------------------------------------===//
439 // RankSpecializationClusterOp
440 //===----------------------------------------------------------------------===//
441 
// RegionBranchOpInterface implementation: control flows unconditionally from
// the parent into the single body region, and from the body back to the
// parent's results.
void RankSpecializationClusterOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor>& regions) {
  // RankSpecializationClusterOp has unconditional control flows into the region
  // and back to the parent, so return the correct RegionSuccessor purely based
  // on the index being None or 0.
  if (index.hasValue()) {
    // Branching from inside the body region: successor is the op's results.
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }
  // Branching from the parent: successor is the body region.
  regions.push_back(RegionSuccessor(&body()));
}
454 
// Verifies RankSpecializationClusterOp: region/operand type agreement and
// that the cluster is "closed" — nested ops may only use values defined in
// the body or passed in as cluster operands (no implicit captures).
static LogicalResult Verify(RankSpecializationClusterOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
  // Body block arguments mirror the cluster operands one-to-one.
  if (op.body().getArgumentTypes() != op.getOperandTypes())
    return op.emitOpError() << "block argument types must match operand types";

  // All operands of nested ops must be defined in the body or declared by the
  // cluster.
  Block* body = op.getBody();
  for (Operation& nested : body->without_terminator()) {
    if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
          Operation* def = operand.get().getDefiningOp();
          // Defined by another op in the same block, ...
          if (def != nullptr && def->getBlock() == body) return true;
          // ... or one of the body's block arguments.
          return llvm::is_contained(body->getArguments(), operand.get());
        })) {
      return op.emitOpError()
             << "nested ops must not depend on implicit operands";
    }
  }

  return success();
}
476 
477 }  // namespace chlo
478 }  // namespace mlir
479 
480 #define GET_OP_CLASSES
481 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
482 
483 namespace mlir {
484 namespace chlo {
485 
486 //===----------------------------------------------------------------------===//
487 // chlo Dialect Constructor
488 //===----------------------------------------------------------------------===//
489 
// Registers all CHLO ops (generated from the ODS .td file) with the dialect.
void HloClientDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}
496 
497 }  // namespace chlo
498 }  // namespace mlir
499