• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2    Copyright 2022 The StableHLO Authors.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #include "dialect/ChloOps.h"
18 
19 #include "dialect/BroadcastUtils.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "mlir/Dialect/Complex/IR/Complex.h"
24 #include "mlir/Dialect/Traits.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/PatternMatch.h"
27 
28 // Include order matters
29 #include "dialect/ChloEnums.cpp.inc"
30 #define GET_ATTRDEF_CLASSES
31 #include "dialect/ChloAttrs.cpp.inc"
32 
33 namespace mlir {
34 namespace chlo {
35 
36 //===----------------------------------------------------------------------===//
37 // CompatibleOperandsAndResultType
38 //===----------------------------------------------------------------------===//
39 
// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
//
// Stamps out the InferTypeOpInterface hook for elementwise CHLO ops whose
// result shape and element type are derived directly from their operands via
// the shared inferReturnTypeComponentsFromOperands helper.
#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op)                        \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location,                      \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {          \
    return inferReturnTypeComponentsFromOperands(context, location, operands, \
                                                 attributes, regions,         \
                                                 inferredReturnShapes);       \
  }

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcosOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcoshOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(BesselI1eOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ConjOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ZetaOp)
71 
72 //===----------------------------------------------------------------------===//
73 // BinaryOps
74 //===----------------------------------------------------------------------===//
75 
76 namespace {
77 // Gets the resulting type from a broadcast between two types.
// Gets the resulting type from a broadcast between two types.
//
// Behavior:
//   - If either operand is unranked, the result is unranked (element type
//     only).
//   - With equal ranks, or no broadcast_dimensions attribute, plain
//     numpy-style prefix-padded broadcasting applies.
//   - Otherwise broadcast_dimensions maps each dimension of the lower-rank
//     operand into the higher-rank operand's shape.
// Any illegal broadcast_dimensions is signalled by returning an unranked
// result rather than reporting an error.
ShapedTypeComponents getBroadcastType(
    Type x, Type y, Type elementType,
    DenseIntElementsAttr broadcastDimensionsAttr) {
  auto xRanked = x.dyn_cast<RankedTensorType>();
  auto yRanked = y.dyn_cast<RankedTensorType>();
  // Unranked input => unranked result.
  if (!xRanked || !yRanked) {
    return {elementType};
  }

  auto shapeX = xRanked.getShape();
  auto shapeY = yRanked.getShape();

  // If no broadcast dimensions, assume "numpy" broadcasting.
  if (shapeX.size() == shapeY.size() || !broadcastDimensionsAttr) {
    llvm::SmallVector<int64_t, 4> outShape;
    if (!mlir::OpTrait::util::getBroadcastedShape(shapeX, shapeY, outShape)) {
      // Signal illegal broadcast_dimensions as unranked.
      return {elementType};
    }
    return {outShape, elementType};
  }

  auto shapeLarge = shapeX.size() > shapeY.size() ? shapeX : shapeY;
  auto shapeSmall = shapeX.size() <= shapeY.size() ? shapeX : shapeY;

  // broadcast_dimensions must provide exactly one target dimension per
  // dimension of the smaller-rank operand.
  auto broadcastDimensions = broadcastDimensionsAttr.getValues<APInt>();
  if (broadcastDimensions.size() != shapeSmall.size()) {
    // Signal illegal broadcast_dimensions as unranked.
    return {elementType};
  }

  // Gather the large-shape extents that the small shape broadcasts against,
  // in broadcast_dimensions order.
  llvm::SmallVector<int64_t, 4> shapeLargeFiltered;
  shapeLargeFiltered.reserve(shapeSmall.size());
  for (const auto& dim : broadcastDimensions) {
    // Out-of-range dimension index: signal as unranked.
    if (dim.getZExtValue() >= shapeLarge.size()) return {elementType};
    shapeLargeFiltered.push_back(shapeLarge[dim.getZExtValue()]);
  }
  llvm::SmallVector<int64_t, 4> outShapeFiltered;
  if (!mlir::OpTrait::util::getBroadcastedShape(shapeSmall, shapeLargeFiltered,
                                                outShapeFiltered)) {
    // Signal illegal broadcast_dimensions as unranked.
    return {elementType};
  }

  // Update according to the broadcast dimensions: start from the large shape
  // and overwrite each mapped dimension with its broadcasted extent.
  llvm::SmallVector<int64_t, 4> outShape(shapeLarge.begin(), shapeLarge.end());
  for (const auto& indexPair : llvm::enumerate(broadcastDimensions)) {
    auto newValue = outShapeFiltered[indexPair.index()];
    outShape[indexPair.value().getZExtValue()] = newValue;
  }

  return {outShape, elementType};
}
131 
InferBroadcastBinaryOpReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,Type elementType,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)132 LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
133     MLIRContext* context, Optional<Location> location, ValueRange operands,
134     DictionaryAttr attributes, Type elementType,
135     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
136   // Find broadcast_dimensions.
137   DenseIntElementsAttr broadcastDimensions =
138       attributes.get("broadcast_dimensions")
139           .dyn_cast_or_null<DenseIntElementsAttr>();
140 
141   ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
142   ShapedType rhsType = operands[1].getType().dyn_cast<ShapedType>();
143   if (!lhsType || !rhsType ||
144       lhsType.getElementType() != rhsType.getElementType()) {
145     return emitOptionalError(location, "mismatched operand types");
146   }
147   if (!elementType) elementType = lhsType.getElementType();
148   inferredReturnShapes.push_back(
149       getBroadcastType(lhsType, rhsType, elementType, broadcastDimensions));
150   return success();
151 }
152 
// Shared reifyReturnTypeShapes implementation for broadcasting binary ops:
// emits the shape computation producing the broadcasted result extents.
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
    OpBuilder& builder, Operation* op, ValueRange operands,
    SmallVectorImpl<Value>& result) {
  assert(operands.size() == 2 && "expect binary op");
  auto loc = op->getLoc();
  auto lhs = operands[0];
  auto rhs = operands[1];

  // Check for "numpy"-style rank broadcast.
  auto broadcastDimensions = op->getAttr("broadcast_dimensions")
                                 .dyn_cast_or_null<DenseIntElementsAttr>();
  if (broadcastDimensions &&
      !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, broadcastDimensions)) {
    // Note: It is unclear whether the general specification of explicit
    // broadcast_dimensions on binary ops is a feature we want to carry
    // forward. While it can technically be implemented for ranked-dynamic,
    // it is incompatible with unranked inputs. If this warning is emitted
    // in real programs, it is an indication that the feature should be
    // implemented versus just falling back on the more standard definition
    // of numpy-like prefix-padding.
    // NOTE(review): returning the in-flight diagnostic converts it into a
    // failure result even though it is emitted at warning severity.
    return op->emitWarning()
           << "unsupported non prefix-padded dynamic rank "
           << "broadcast_dimensions = " << broadcastDimensions;
  }

  result.push_back(hlo::computeBinaryElementwiseBroadcastingResultExtents(
      loc, lhs, rhs, builder));
  return success();
}
182 }  // namespace
183 
184 //===----------------------------------------------------------------------===//
185 // BroadcastComplexOp (has custom type inference due to different result type).
186 //===----------------------------------------------------------------------===//
187 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)188 LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
189     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
190     DictionaryAttr attributes, RegionRange /*regions*/,
191     SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
192   ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
193   if (!lhsType) {
194     return emitOptionalError(location, "expected ShapedType");
195   }
196   Type elementType = ComplexType::get(lhsType.getElementType());
197   return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
198                                                     attributes, elementType,
199                                                     inferedReturnShapes);
200 }
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)201 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
202     OpBuilder& builder, ValueRange operands,
203     SmallVectorImpl<Value>& reifiedReturnShapes) {
204   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
205                                                 operands, reifiedReturnShapes);
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // BroadcastCompareOp (has custom type inference due to different result type).
210 //===----------------------------------------------------------------------===//
211 
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,DenseIntElementsAttr broadcastDimensions,chlo::ComparisonDirection comparisonDirection,chlo::ComparisonType compareType)212 void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
213                                Value lhs, Value rhs,
214                                DenseIntElementsAttr broadcastDimensions,
215                                chlo::ComparisonDirection comparisonDirection,
216                                chlo::ComparisonType compareType) {
217   build(builder, result, lhs, rhs, broadcastDimensions,
218         chlo::ComparisonDirectionAttr::get(builder.getContext(),
219                                            comparisonDirection),
220         chlo::ComparisonTypeAttr::get(builder.getContext(), compareType));
221 }
222 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)223 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
224     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
225     DictionaryAttr attributes, RegionRange /*regions*/,
226     SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
227   Type elementType = IntegerType::get(context, 1);
228   return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
229                                                     attributes, elementType,
230                                                     inferedReturnShapes);
231 }
232 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)233 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
234     OpBuilder& builder, ValueRange operands,
235     SmallVectorImpl<Value>& reifiedReturnShapes) {
236   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
237                                                 operands, reifiedReturnShapes);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // IsInfOp
242 //===----------------------------------------------------------------------===//
243 
getIsInfLikeReturnType(Value operand)244 static Type getIsInfLikeReturnType(Value operand) {
245   Builder b(operand.getContext());
246   return hlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
247                                      b.getI1Type());
248 }
249 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)250 LogicalResult IsInfOp::inferReturnTypes(
251     MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
252     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
253   inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
254   return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // IsNegInfOp
259 //===----------------------------------------------------------------------===//
260 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)261 LogicalResult IsNegInfOp::inferReturnTypes(
262     MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
263     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
264   inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
265   return success();
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // IsPosInfOp
270 //===----------------------------------------------------------------------===//
271 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)272 LogicalResult IsPosInfOp::inferReturnTypes(
273     MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
274     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
275   inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
276   return success();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // Macros for method definitions that are common to most broadcasting ops.
281 //===----------------------------------------------------------------------===//
282 
// Stamps out the InferTypeOpInterface and reifyReturnTypeShapes hooks shared
// by the broadcasting binary ops below: the result element type is taken
// from the lhs operand (element_type == nullptr), and the result shape
// follows the broadcast rules implemented above.
#define BROADCAST_BINARY_OP_DEFS(Op)                                       \
  LogicalResult Op::inferReturnTypeComponents(                             \
      MLIRContext* context, Optional<Location> location,                   \
      ValueShapeRange operands, DictionaryAttr attributes,                 \
      RegionRange regions,                                                 \
      SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {        \
    return InferBroadcastBinaryOpReturnTypeComponents(                     \
        context, location, operands, attributes, /*element_type=*/nullptr, \
        inferedReturnShapes);                                              \
  }                                                                        \
  LogicalResult Op::reifyReturnTypeShapes(                                 \
      OpBuilder& builder, ValueRange operands,                             \
      SmallVectorImpl<Value>& reifiedReturnShapes) {                       \
    return ReifyBroadcastBinaryOpReturnTypeShapes(                         \
        builder, getOperation(), operands, reifiedReturnShapes);           \
  }

BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastNextAfterOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);

#undef BROADCAST_BINARY_OP_DEFS
320 
verify()321 LogicalResult ConstantLikeOp::verify() {
322   if (value().getType() != getType().cast<ShapedType>().getElementType())
323     return emitOpError() << "value's type doesn't match element return type";
324   return success();
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // MinimumBroadcastShapesOp
329 //===----------------------------------------------------------------------===//
verify()330 LogicalResult MinimumBroadcastShapesOp::verify() {
331   // Check that the number of operands matches the number of outputs.
332   unsigned resultShapesCount = results().size();
333   unsigned operandShapesCount = shapes().size();
334   if (operandShapesCount != resultShapesCount) {
335     return emitOpError() << "number of operand shapes (" << operandShapesCount
336                          << ") does not match number of result shapes ("
337                          << resultShapesCount << ")";
338   }
339   if (operandShapesCount < 2) {
340     return emitOpError() << "number of operand shapes (" << operandShapesCount
341                          << ") should be >= 2";
342   }
343   return success();
344 }
345 
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)346 LogicalResult ConstantLikeOp::inferReturnTypeComponents(
347     MLIRContext* /*context*/, Optional<Location> location,
348     ValueShapeRange operands, DictionaryAttr attributes,
349     RegionRange /*regions*/,
350     SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
351   ConstantLikeOp::Adaptor op(operands, attributes);
352   if (failed(op.verify(location.value()))) return failure();
353   Type elementType = op.value().getType();
354   Type operandType = op.operand().getType();
355   if (operandType.isa<UnrankedTensorType>()) {
356     inferedReturnShapes.emplace_back(elementType);
357   } else {
358     const auto& shape = operandType.cast<RankedTensorType>().getShape();
359     inferedReturnShapes.emplace_back(shape, elementType);
360   }
361   return success();
362 }
363 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)364 LogicalResult ConstantLikeOp::reifyReturnTypeShapes(
365     OpBuilder& builder, ValueRange operands,
366     SmallVectorImpl<Value>& reifiedReturnShapes) {
367   return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
368                                      &reifiedReturnShapes);
369 }
370 
fold(ArrayRef<Attribute>)371 OpFoldResult ConstantLikeOp::fold(ArrayRef<Attribute> /*operands*/) {
372   auto opType = operand().getType().cast<ShapedType>();
373   if (!opType.hasStaticShape()) return {};
374   auto type = RankedTensorType::get(opType.getShape(), value().getType());
375   if (auto complexAttr = value().dyn_cast<complex::NumberAttr>())
376     return DenseElementsAttr::get(type, complexAttr.getValue());
377   return DenseElementsAttr::get(type, value());
378 }
379 
// Infers the result components for broadcast_select by composing two binary
// broadcasts: on_true against on_false, then that result against pred.
LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr, RegionRange,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  BroadcastSelectOp::Adaptor op(operands.getValues());
  auto predType = op.pred().getType().dyn_cast<ShapedType>();
  auto onTrueType = op.on_true().getType().dyn_cast<ShapedType>();
  auto onFalseType = op.on_false().getType().dyn_cast<ShapedType>();

  // All three operands must be shaped, and both branches must agree on the
  // element type (pred's element type is not checked here).
  if (!predType || !onTrueType || !onFalseType ||
      onTrueType.getElementType() != onFalseType.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }

  Type elementType = onTrueType.getElementType();

  // Compute the result shape as two binary broadcasts.
  // Step 1: broadcast the two branch operands against each other.
  ShapedTypeComponents& components = inferredReturnShapes.emplace_back(
      getBroadcastType(onTrueType, onFalseType, elementType, nullptr));
  // Step 2: if that produced a ranked shape, broadcast it against pred,
  // overwriting the entry pushed above in place.
  if (components.hasRank()) {
    components = getBroadcastType(
        RankedTensorType::get(components.getDims(), elementType), predType,
        elementType, nullptr);
  }
  return success();
}
406 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & result)407 LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
408     OpBuilder& builder, ValueRange operands, SmallVectorImpl<Value>& result) {
409   result.push_back(hlo::computeNaryElementwiseBroadcastingResultExtents(
410       getLoc(), operands, builder));
411   return success();
412 }
413 
414 //===----------------------------------------------------------------------===//
415 // RankSpecializationClusterOp
416 //===----------------------------------------------------------------------===//
417 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute>,SmallVectorImpl<RegionSuccessor> & regions)418 void RankSpecializationClusterOp::getSuccessorRegions(
419     Optional<unsigned> index, ArrayRef<Attribute> /*operands*/,
420     SmallVectorImpl<RegionSuccessor>& regions) {
421   // RankSpecializationClusterOp has unconditional control flows into the region
422   // and back to the parent, so return the correct RegionSuccessor purely based
423   // on the index being None or 0.
424   if (index.has_value()) {
425     regions.push_back(RegionSuccessor(getResults()));
426     return;
427   }
428   regions.push_back(RegionSuccessor(&body()));
429 }
430 
verify()431 LogicalResult RankSpecializationClusterOp::verify() {
432   if (body().getArgumentTypes() != getOperandTypes())
433     return emitOpError() << "block argument types must match operand types";
434 
435   // All operands of nested ops must be defined in the body or declared by the
436   // cluster.
437   Block* body = getBody();
438   for (Operation& nested : body->without_terminator()) {
439     if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
440           Operation* def = operand.get().getDefiningOp();
441           if (def != nullptr && def->getBlock() == body) return true;
442           return llvm::is_contained(body->getArguments(), operand.get());
443         })) {
444       return emitOpError() << "nested ops must not depend on implicit operands";
445     }
446   }
447 
448   return success();
449 }
450 
451 //===----------------------------------------------------------------------===//
452 // TopKOp
453 //===----------------------------------------------------------------------===//
454 
// Infers the two results of chlo.top_k: the top-k values (operand element
// type) and their i32 indices, both shaped like the operand with the last
// dimension replaced by k.
LogicalResult TopKOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  Builder builder(context);
  TopKOp::Adaptor adaptor(operands, attributes, regions);
  Value operand = adaptor.operand();
  uint64_t k = adaptor.k();

  // The operand must be ranked, with a static minor-most dimension that
  // holds at least k elements.
  auto operandTy = operand.getType().dyn_cast<RankedTensorType>();
  if (!operandTy) {
    return emitOptionalError(location, "operand must be ranked");
  }
  if (operandTy.getRank() < 1) {
    return emitOptionalError(location, "operand's rank must be at least 1");
  }
  auto operandLastDim = operandTy.getShape()[operandTy.getRank() - 1];
  if (operandLastDim == ShapedType::kDynamicSize) {
    return emitOptionalError(location,
                             "operand's last dimension must be static");
  }
  if (operandLastDim < static_cast<int64_t>(k)) {
    return emitOptionalError(location,
                             "operand's last dimension must be at least ", k);
  }

  // Result shape: operand shape with the last dimension set to k.
  SmallVector<int64_t> resultShape;
  append_range(resultShape, operandTy.getShape());
  resultShape[operandTy.getRank() - 1] = k;

  inferredReturnShapes.emplace_back(resultShape, operandTy.getElementType());
  inferredReturnShapes.emplace_back(resultShape, builder.getI32Type());
  return success();
}
489 
490 //===----------------------------------------------------------------------===//
491 // ConstantOp
492 //===----------------------------------------------------------------------===//
493 
// chlo.constant always folds to its `value` attribute.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> /*operands*/) {
  return value();
}
497 
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)498 LogicalResult ConstantOp::inferReturnTypes(
499     MLIRContext*, Optional<Location>, ValueRange, DictionaryAttr attributes,
500     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
501   Type type = attributes.get("value").cast<TypedAttr>().getType();
502   inferredReturnTypes.push_back(type);
503   return success();
504 }
505 
506 }  // namespace chlo
507 }  // namespace mlir
508 
509 #define GET_OP_CLASSES
510 #include "dialect/ChloOps.cpp.inc"
511 
512 namespace mlir {
513 namespace chlo {
514 
515 //===----------------------------------------------------------------------===//
516 // chlo Dialect Constructor
517 //===----------------------------------------------------------------------===//
518 
// Registers all TableGen-generated CHLO operations and attributes with the
// dialect.
ChloDialect::ChloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<ChloDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "dialect/ChloOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "dialect/ChloAttrs.cpp.inc"
      >();
}
530 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)531 Operation* ChloDialect::materializeConstant(OpBuilder& builder, Attribute value,
532                                             Type type, Location loc) {
533   if (value.isa<ElementsAttr>())
534     return builder.create<chlo::ConstantOp>(loc, type,
535                                             value.cast<ElementsAttr>());
536   return nullptr;
537 }
538 
539 // Entry point for Attribute parsing, TableGen generated code will handle the
540 // dispatch to the individual classes.
parseAttribute(DialectAsmParser & parser,Type type) const541 Attribute ChloDialect::parseAttribute(DialectAsmParser& parser,
542                                       Type type) const {
543   StringRef attrTag;
544   Attribute attr;
545   auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
546   if (parseResult.hasValue()) return attr;
547   parser.emitError(parser.getNameLoc(), "unknown chlo attribute");
548   return Attribute();
549 }
550 
551 // Entry point for Attribute printing, TableGen generated code will handle the
552 // dispatch to the individual classes.
printAttribute(Attribute attr,DialectAsmPrinter & os) const553 void ChloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
554   LogicalResult result = generatedAttributePrinter(attr, os);
555   (void)result;
556   assert(succeeded(result));
557 }
558 
559 }  // namespace chlo
560 }  // namespace mlir
561