1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
17
18 #include "llvm/ADT/APFloat.h"
19 #include "mlir-hlo/utils/broadcast_utils.h"
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Diagnostics.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26
27 namespace mlir {
28 namespace chlo {
29
30 template <typename T>
Verify(T op)31 static LogicalResult Verify(T op) {
32 return success();
33 }
34
getConstantLikeMaxFiniteValue(OpBuilder & b,Location loc,Value val)35 Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
36 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
37 return getConstantLike(
38 b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
39 }
40
getConstantLikeInfValue(OpBuilder & b,Location loc,Value val,bool negative)41 Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val,
42 bool negative) {
43 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
44 return getConstantLike(
45 b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
46 }
47
getConstantLikeSmallestFiniteValue(OpBuilder & b,Location loc,Value val)48 Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc,
49 Value val) {
50 auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
51 return getConstantLike(
52 b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
53 }
54
getConstantLike(OpBuilder & b,Location loc,const APFloat & constant,Value val)55 Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
56 Value val) {
57 Type ty = getElementTypeOrSelf(val.getType());
58 return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
59 }
60
61 //===----------------------------------------------------------------------===//
62 // BinaryOps
63 //===----------------------------------------------------------------------===//
64
65 namespace {
66 // Gets the resulting type from a broadcast between two types.
GetBroadcastType(Type x,Type y,Type element_type,DenseIntElementsAttr broadcast_dimensions_attr)67 static Type GetBroadcastType(Type x, Type y, Type element_type,
68 DenseIntElementsAttr broadcast_dimensions_attr) {
69 auto x_ranked = x.dyn_cast<RankedTensorType>();
70 auto y_ranked = y.dyn_cast<RankedTensorType>();
71 if (!x_ranked || !y_ranked) {
72 return UnrankedTensorType::get(element_type);
73 }
74
75 auto shape_x = x_ranked.getShape();
76 auto shape_y = y_ranked.getShape();
77
78 if (shape_x.size() == shape_y.size()) {
79 llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
80 for (int i = 0, e = shape_x.size(); i < e; i++) {
81 auto x_val = shape_x[i];
82 auto y_val = shape_y[i];
83 if (x_val == -1 || y_val == -1) {
84 out_shape[i] = -1;
85 } else {
86 out_shape[i] = std::max(x_val, y_val);
87 }
88 }
89 return RankedTensorType::get(out_shape, element_type);
90 }
91
92 auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
93 auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
94
95 llvm::SmallVector<int64_t, 4> broadcast_dimensions;
96 if (broadcast_dimensions_attr) {
97 // Explicit broadcast dimensions.
98 for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
99 broadcast_dimensions.push_back(int_value.getSExtValue());
100 }
101 if (broadcast_dimensions.size() != shape_small.size()) {
102 // Signal illegal broadcast_dimensions as unranked.
103 return UnrankedTensorType::get(element_type);
104 }
105 } else {
106 // If no broadcast dimensions, assume "numpy" broadcasting.
107 broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
108 shape_large.size() - shape_small.size(), shape_large.size()));
109 }
110
111 llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
112 shape_large.end());
113
114 // Update according to the broadcast dimensions.
115 for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
116 auto old_value = out_shape[index_pair.value()];
117 auto new_value = shape_small[index_pair.index()];
118 if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
119 out_shape[index_pair.value()] = new_value;
120 }
121 }
122
123 return RankedTensorType::get(out_shape, element_type);
124 }
125
InferBroadcastBinaryOpReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,Type element_type,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)126 LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
127 MLIRContext* context, Optional<Location> location, ValueRange operands,
128 DictionaryAttr attributes, Type element_type,
129 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
130 // Find broadcast_dimensions.
131 DenseIntElementsAttr broadcast_dimensions =
132 attributes.get("broadcast_dimensions")
133 .dyn_cast_or_null<DenseIntElementsAttr>();
134
135 ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
136 ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
137 if (!lhs_type || !rhs_type ||
138 lhs_type.getElementType() != rhs_type.getElementType()) {
139 return emitOptionalError(location, "mismatched operand types");
140 }
141 if (!element_type) element_type = lhs_type.getElementType();
142 Type result_type =
143 GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);
144
145 if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
146 inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
147 element_type);
148 return success();
149 }
150
151 // TODO(laurenzo): This should be constructing with `element_type` but that
152 // constructor variant needs to be added upstream.
153 inferedReturnShapes.emplace_back(/* element_type */);
154 return success();
155 }
156
ReifyBroadcastBinaryOpReturnTypeShapes(OpBuilder & builder,Operation * op,ValueRange operands,SmallVectorImpl<Value> & result)157 LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
158 OpBuilder& builder, Operation* op, ValueRange operands,
159 SmallVectorImpl<Value>& result) {
160 assert(operands.size() == 2 && "expect binary op");
161 auto loc = op->getLoc();
162 auto lhs = operands[0];
163 auto rhs = operands[1];
164
165 // Check for "numpy"-style rank broadcast.
166 auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
167 .dyn_cast_or_null<DenseIntElementsAttr>();
168 if (broadcast_dimensions &&
169 !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
170 // Note: It is unclear whether the general specification of explicit
171 // broadcast_dimensions on binary ops is a feature we want to carry
172 // forward. While it can technically be implemented for ranked-dynamic,
173 // it is incompatible with unranked inputs. If this warning is emitted
174 // in real programs, it is an indication that the feature should be
175 // implemented versus just falling back on the more standard definition
176 // of numpy-like prefix-padding.
177 return op->emitWarning()
178 << "unsupported non prefix-padded dynamic rank "
179 << "broadcast_dimensions = " << broadcast_dimensions;
180 }
181
182 result.push_back(hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
183 loc, lhs, rhs, builder));
184 return success();
185 }
186 } // namespace
187
188 //===----------------------------------------------------------------------===//
189 // BroadcastComplexOp (has custom type inference due to different result type).
190 //===----------------------------------------------------------------------===//
191
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)192 LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
193 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
194 DictionaryAttr attributes, RegionRange regions,
195 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
196 ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
197 if (!lhs_type) {
198 return emitOptionalError(location, "expected ShapedType");
199 }
200 Type element_type = ComplexType::get(lhs_type.getElementType());
201 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
202 attributes, element_type,
203 inferedReturnShapes);
204 }
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)205 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
206 OpBuilder& builder, ValueRange operands,
207 SmallVectorImpl<Value>& reifiedReturnShapes) {
208 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
209 operands, reifiedReturnShapes);
210 }
211
212 //===----------------------------------------------------------------------===//
213 // BroadcastCompareOp (has custom type inference due to different result type).
214 //===----------------------------------------------------------------------===//
215
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,DenseIntElementsAttr broadcast_dimensions,StringAttr comparison_direction,StringAttr compare_type)216 void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
217 Value lhs, Value rhs,
218 DenseIntElementsAttr broadcast_dimensions,
219 StringAttr comparison_direction,
220 StringAttr compare_type) {
221 auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
222 builder.getI1Type(), broadcast_dimensions);
223 build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
224 comparison_direction, compare_type);
225 }
226
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)227 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
228 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
229 DictionaryAttr attributes, RegionRange regions,
230 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
231 Type element_type = IntegerType::get(context, 1);
232 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
233 attributes, element_type,
234 inferedReturnShapes);
235 }
236
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)237 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
238 OpBuilder& builder, ValueRange operands,
239 SmallVectorImpl<Value>& reifiedReturnShapes) {
240 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
241 operands, reifiedReturnShapes);
242 }
243
244 //===----------------------------------------------------------------------===//
245 // IsInfOp
246 //===----------------------------------------------------------------------===//
247
getIsInfLikeReturnType(Value operand)248 static Type getIsInfLikeReturnType(Value operand) {
249 Builder b(operand.getContext());
250 return mhlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
251 b.getI1Type());
252 }
253
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)254 LogicalResult IsInfOp::inferReturnTypes(
255 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
256 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
257 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
258 return success();
259 }
260
261 //===----------------------------------------------------------------------===//
262 // IsNegInfOp
263 //===----------------------------------------------------------------------===//
264
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)265 LogicalResult IsNegInfOp::inferReturnTypes(
266 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
267 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
268 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
269 return success();
270 }
271
272 //===----------------------------------------------------------------------===//
273 // IsPosInfOp
274 //===----------------------------------------------------------------------===//
275
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)276 LogicalResult IsPosInfOp::inferReturnTypes(
277 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
278 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
279 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
280 return success();
281 }
282
283 //===----------------------------------------------------------------------===//
284 // Macros for method definitions that are common to most broadcasting ops.
285 //===----------------------------------------------------------------------===//
286
// Defines the two shape-inference methods shared by the broadcasting binary
// ops:
//  * inferReturnTypeComponents forwards to
//    InferBroadcastBinaryOpReturnTypeComponents with a null element type, so
//    the result element type is taken from the lhs operand.
//  * reifyReturnTypeShapes forwards to ReifyBroadcastBinaryOpReturnTypeShapes.
// (Comments are kept outside the macro: a `//` comment inside a line-spliced
// macro would swallow the following lines.)
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)                                \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location,                      \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {           \
    return InferBroadcastBinaryOpReturnTypeComponents(                        \
        context, location, operands, attributes, /*element_type=*/nullptr,    \
        inferedReturnShapes);                                                 \
  }                                                                           \
  LogicalResult Op::reifyReturnTypeShapes(                                    \
      OpBuilder& builder, ValueRange operands,                                \
      SmallVectorImpl<Value>& reifiedReturnShapes) {                          \
    return ReifyBroadcastBinaryOpReturnTypeShapes(                            \
        builder, getOperation(), operands, reifiedReturnShapes);              \
  }
303
// Defines, for a broadcasting binary op `Op`:
//  * a convenience `build` that computes the result type with
//    GetBroadcastType, using the element type of `right`. Both operands must
//    be shaped types, which the casts assert.
//  * the shape-inference methods from BROADCAST_INFER_SHAPE_TYPE_OP_DEFS.
#define BROADCAST_BINARY_OP_DEFS(Op)                                           \
  void Op::build(OpBuilder& builder, OperationState& result, Value left,      \
                 Value right, DenseIntElementsAttr broadcast_dimensions) {     \
    auto type = GetBroadcastType(                                              \
        left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
        getElementTypeOrSelf(right.getType()), broadcast_dimensions);          \
    return Op::build(builder, result, type, left, right,                       \
                     broadcast_dimensions);                                    \
  }                                                                            \
  BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)
314
// Instantiate the shared definitions for each broadcasting binary op whose
// result element type is taken from its operands. BroadcastComplexOp and
// BroadcastCompareOp produce a different element type and are defined by hand
// earlier in this file.
BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastNextAfterOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);

// The helper macros are not used past this point.
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS
336
Verify(ConstantLikeOp op)337 static LogicalResult Verify(ConstantLikeOp op) {
338 if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
339 return op.emitOpError() << "value's type doesn't match element return type";
340 return success();
341 }
342
343 //===----------------------------------------------------------------------===//
344 // MinimumBroadcastShapesOp
345 //===----------------------------------------------------------------------===//
Verify(MinimumBroadcastShapesOp op)346 static LogicalResult Verify(MinimumBroadcastShapesOp op) {
347 // Check that the number of operands matches the number of outputs.
348 unsigned result_shapes_count = op.results().size();
349 unsigned operand_shapes_count = op.shapes().size();
350 if (operand_shapes_count != result_shapes_count) {
351 return op.emitOpError()
352 << "number of operand shapes (" << operand_shapes_count
353 << ") does not match number of result shapes ("
354 << result_shapes_count << ")";
355 }
356 if (operand_shapes_count < 2) {
357 return op.emitOpError() << "number of operand shapes ("
358 << operand_shapes_count << ") should be >= 2";
359 }
360 return success();
361 }
362
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)363 LogicalResult ConstantLikeOp::inferReturnTypeComponents(
364 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
365 DictionaryAttr attributes, RegionRange regions,
366 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
367 ConstantLikeOp::Adaptor op(operands, attributes);
368 if (failed(op.verify(location.getValue()))) return failure();
369 Type element_type = op.value().getType();
370 Type operand_type = op.operand().getType();
371 if (operand_type.isa<UnrankedTensorType>()) {
372 inferedReturnShapes.emplace_back(element_type);
373 } else {
374 const auto& shape = operand_type.cast<RankedTensorType>().getShape();
375 inferedReturnShapes.emplace_back(shape, element_type);
376 }
377 return success();
378 }
379
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)380 LogicalResult ConstantLikeOp::reifyReturnTypeShapes(
381 OpBuilder& builder, ValueRange operands,
382 SmallVectorImpl<Value>& reifiedReturnShapes) {
383 return ::mlir::mhlo::deriveShapeFromOperand(
384 &builder, getOperation(), operands.front(), &reifiedReturnShapes);
385 }
386
387 struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
388 using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
389
matchAndRewritemlir::chlo::ConstantLikeToConstant390 LogicalResult matchAndRewrite(ConstantLikeOp op,
391 PatternRewriter& rewriter) const override {
392 auto op_type = op.operand().getType().cast<ShapedType>();
393 if (!op_type.hasStaticShape()) return failure();
394 auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
395 ElementsAttr attr = DenseElementsAttr::get(type, op.value());
396 rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
397 return success();
398 }
399 };
400
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)401 void ConstantLikeOp::getCanonicalizationPatterns(
402 OwningRewritePatternList& results, MLIRContext* context) {
403 results.insert<ConstantLikeToConstant>(context);
404 }
405
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)406 LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
407 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
408 DictionaryAttr, RegionRange,
409 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
410 BroadcastSelectOp::Adaptor op(operands.getValues());
411 auto pred_type = op.pred().getType().dyn_cast<ShapedType>();
412 auto on_true_type = op.on_true().getType().dyn_cast<ShapedType>();
413 auto on_false_type = op.on_false().getType().dyn_cast<ShapedType>();
414
415 if (!pred_type || !on_true_type || !on_false_type ||
416 on_true_type.getElementType() != on_false_type.getElementType()) {
417 return emitOptionalError(location, "mismatched operand types");
418 }
419
420 Type element_type = on_true_type.getElementType();
421
422 // Compute the result shape as two binary broadcasts.
423 Type other =
424 GetBroadcastType(on_true_type, on_false_type, element_type, nullptr);
425 Type output = GetBroadcastType(other, pred_type, element_type, nullptr);
426
427 inferredReturnShapes.push_back(output);
428 return success();
429 }
430
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & result)431 LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
432 OpBuilder& builder, ValueRange operands, SmallVectorImpl<Value>& result) {
433 result.push_back(hlo::ComputeNaryElementwiseBroadcastingResultExtents(
434 getLoc(), operands, builder));
435 return success();
436 }
437
438 //===----------------------------------------------------------------------===//
439 // RankSpecializationClusterOp
440 //===----------------------------------------------------------------------===//
441
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)442 void RankSpecializationClusterOp::getSuccessorRegions(
443 Optional<unsigned> index, ArrayRef<Attribute> operands,
444 SmallVectorImpl<RegionSuccessor>& regions) {
445 // RankSpecializationClusterOp has unconditional control flows into the region
446 // and back to the parent, so return the correct RegionSuccessor purely based
447 // on the index being None or 0.
448 if (index.hasValue()) {
449 regions.push_back(RegionSuccessor(getResults()));
450 return;
451 }
452 regions.push_back(RegionSuccessor(&body()));
453 }
454
Verify(RankSpecializationClusterOp op)455 static LogicalResult Verify(RankSpecializationClusterOp op) {
456 if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure();
457 if (op.body().getArgumentTypes() != op.getOperandTypes())
458 return op.emitOpError() << "block argument types must match operand types";
459
460 // All operands of nested ops must be defined in the body or declared by the
461 // cluster.
462 Block* body = op.getBody();
463 for (Operation& nested : body->without_terminator()) {
464 if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
465 Operation* def = operand.get().getDefiningOp();
466 if (def != nullptr && def->getBlock() == body) return true;
467 return llvm::is_contained(body->getArguments(), operand.get());
468 })) {
469 return op.emitOpError()
470 << "nested ops must not depend on implicit operands";
471 }
472 }
473
474 return success();
475 }
476
477 } // namespace chlo
478 } // namespace mlir
479
480 #define GET_OP_CLASSES
481 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
482
483 namespace mlir {
484 namespace chlo {
485
486 //===----------------------------------------------------------------------===//
487 // chlo Dialect Constructor
488 //===----------------------------------------------------------------------===//
489
// Registers this dialect's operations. The op class list is expanded from the
// generated chlo_ops.cc.inc file, included here with GET_OP_LIST defined.
void HloClientDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}
496
497 } // namespace chlo
498 } // namespace mlir
499