1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 Copyright 2022 The StableHLO Authors.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #include "dialect/ChloOps.h"
18
19 #include "dialect/BroadcastUtils.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "mlir/Dialect/Complex/IR/Complex.h"
24 #include "mlir/Dialect/Traits.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/PatternMatch.h"
27
28 // Include order matters
29 #include "dialect/ChloEnums.cpp.inc"
30 #define GET_ATTRDEF_CLASSES
31 #include "dialect/ChloAttrs.cpp.inc"
32
33 namespace mlir {
34 namespace chlo {
35
36 //===----------------------------------------------------------------------===//
37 // CompatibleOperandsAndResultType
38 //===----------------------------------------------------------------------===//
39
40 // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
41 // support quantization or sparsity.
42 #define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \
43 LogicalResult Op::inferReturnTypeComponents( \
44 MLIRContext* context, Optional<Location> location, \
45 ValueShapeRange operands, DictionaryAttr attributes, \
46 RegionRange regions, \
47 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) { \
48 return inferReturnTypeComponentsFromOperands(context, location, operands, \
49 attributes, regions, \
50 inferredReturnShapes); \
51 }
52
53 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcosOp)
54 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AcoshOp)
55 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinOp)
56 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AsinhOp)
57 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanOp)
58 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanhOp)
59 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(BesselI1eOp)
60 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ConjOp)
61 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp)
62 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp)
63 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
64 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
65 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
66 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
67 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
68 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp)
69 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanOp)
70 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ZetaOp)
71
72 //===----------------------------------------------------------------------===//
73 // BinaryOps
74 //===----------------------------------------------------------------------===//
75
76 namespace {
77 // Gets the resulting type from a broadcast between two types.
getBroadcastType(Type x,Type y,Type elementType,DenseIntElementsAttr broadcastDimensionsAttr)78 ShapedTypeComponents getBroadcastType(
79 Type x, Type y, Type elementType,
80 DenseIntElementsAttr broadcastDimensionsAttr) {
81 auto xRanked = x.dyn_cast<RankedTensorType>();
82 auto yRanked = y.dyn_cast<RankedTensorType>();
83 if (!xRanked || !yRanked) {
84 return {elementType};
85 }
86
87 auto shapeX = xRanked.getShape();
88 auto shapeY = yRanked.getShape();
89
90 // If no broadcast dimensions, assume "numpy" broadcasting.
91 if (shapeX.size() == shapeY.size() || !broadcastDimensionsAttr) {
92 llvm::SmallVector<int64_t, 4> outShape;
93 if (!mlir::OpTrait::util::getBroadcastedShape(shapeX, shapeY, outShape)) {
94 // Signal illegal broadcast_dimensions as unranked.
95 return {elementType};
96 }
97 return {outShape, elementType};
98 }
99
100 auto shapeLarge = shapeX.size() > shapeY.size() ? shapeX : shapeY;
101 auto shapeSmall = shapeX.size() <= shapeY.size() ? shapeX : shapeY;
102
103 auto broadcastDimensions = broadcastDimensionsAttr.getValues<APInt>();
104 if (broadcastDimensions.size() != shapeSmall.size()) {
105 // Signal illegal broadcast_dimensions as unranked.
106 return {elementType};
107 }
108
109 llvm::SmallVector<int64_t, 4> shapeLargeFiltered;
110 shapeLargeFiltered.reserve(shapeSmall.size());
111 for (const auto& dim : broadcastDimensions) {
112 if (dim.getZExtValue() >= shapeLarge.size()) return {elementType};
113 shapeLargeFiltered.push_back(shapeLarge[dim.getZExtValue()]);
114 }
115 llvm::SmallVector<int64_t, 4> outShapeFiltered;
116 if (!mlir::OpTrait::util::getBroadcastedShape(shapeSmall, shapeLargeFiltered,
117 outShapeFiltered)) {
118 // Signal illegal broadcast_dimensions as unranked.
119 return {elementType};
120 }
121
122 // Update according to the broadcast dimensions.
123 llvm::SmallVector<int64_t, 4> outShape(shapeLarge.begin(), shapeLarge.end());
124 for (const auto& indexPair : llvm::enumerate(broadcastDimensions)) {
125 auto newValue = outShapeFiltered[indexPair.index()];
126 outShape[indexPair.value().getZExtValue()] = newValue;
127 }
128
129 return {outShape, elementType};
130 }
131
InferBroadcastBinaryOpReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,Type elementType,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)132 LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
133 MLIRContext* context, Optional<Location> location, ValueRange operands,
134 DictionaryAttr attributes, Type elementType,
135 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
136 // Find broadcast_dimensions.
137 DenseIntElementsAttr broadcastDimensions =
138 attributes.get("broadcast_dimensions")
139 .dyn_cast_or_null<DenseIntElementsAttr>();
140
141 ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
142 ShapedType rhsType = operands[1].getType().dyn_cast<ShapedType>();
143 if (!lhsType || !rhsType ||
144 lhsType.getElementType() != rhsType.getElementType()) {
145 return emitOptionalError(location, "mismatched operand types");
146 }
147 if (!elementType) elementType = lhsType.getElementType();
148 inferredReturnShapes.push_back(
149 getBroadcastType(lhsType, rhsType, elementType, broadcastDimensions));
150 return success();
151 }
152
ReifyBroadcastBinaryOpReturnTypeShapes(OpBuilder & builder,Operation * op,ValueRange operands,SmallVectorImpl<Value> & result)153 LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
154 OpBuilder& builder, Operation* op, ValueRange operands,
155 SmallVectorImpl<Value>& result) {
156 assert(operands.size() == 2 && "expect binary op");
157 auto loc = op->getLoc();
158 auto lhs = operands[0];
159 auto rhs = operands[1];
160
161 // Check for "numpy"-style rank broadcast.
162 auto broadcastDimensions = op->getAttr("broadcast_dimensions")
163 .dyn_cast_or_null<DenseIntElementsAttr>();
164 if (broadcastDimensions &&
165 !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, broadcastDimensions)) {
166 // Note: It is unclear whether the general specification of explicit
167 // broadcast_dimensions on binary ops is a feature we want to carry
168 // forward. While it can technically be implemented for ranked-dynamic,
169 // it is incompatible with unranked inputs. If this warning is emitted
170 // in real programs, it is an indication that the feature should be
171 // implemented versus just falling back on the more standard definition
172 // of numpy-like prefix-padding.
173 return op->emitWarning()
174 << "unsupported non prefix-padded dynamic rank "
175 << "broadcast_dimensions = " << broadcastDimensions;
176 }
177
178 result.push_back(hlo::computeBinaryElementwiseBroadcastingResultExtents(
179 loc, lhs, rhs, builder));
180 return success();
181 }
182 } // namespace
183
184 //===----------------------------------------------------------------------===//
185 // BroadcastComplexOp (has custom type inference due to different result type).
186 //===----------------------------------------------------------------------===//
187
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)188 LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
189 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
190 DictionaryAttr attributes, RegionRange /*regions*/,
191 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
192 ShapedType lhsType = operands[0].getType().dyn_cast<ShapedType>();
193 if (!lhsType) {
194 return emitOptionalError(location, "expected ShapedType");
195 }
196 Type elementType = ComplexType::get(lhsType.getElementType());
197 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
198 attributes, elementType,
199 inferedReturnShapes);
200 }
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)201 LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
202 OpBuilder& builder, ValueRange operands,
203 SmallVectorImpl<Value>& reifiedReturnShapes) {
204 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
205 operands, reifiedReturnShapes);
206 }
207
208 //===----------------------------------------------------------------------===//
209 // BroadcastCompareOp (has custom type inference due to different result type).
210 //===----------------------------------------------------------------------===//
211
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,DenseIntElementsAttr broadcastDimensions,chlo::ComparisonDirection comparisonDirection,chlo::ComparisonType compareType)212 void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
213 Value lhs, Value rhs,
214 DenseIntElementsAttr broadcastDimensions,
215 chlo::ComparisonDirection comparisonDirection,
216 chlo::ComparisonType compareType) {
217 build(builder, result, lhs, rhs, broadcastDimensions,
218 chlo::ComparisonDirectionAttr::get(builder.getContext(),
219 comparisonDirection),
220 chlo::ComparisonTypeAttr::get(builder.getContext(), compareType));
221 }
222
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)223 LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
224 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
225 DictionaryAttr attributes, RegionRange /*regions*/,
226 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
227 Type elementType = IntegerType::get(context, 1);
228 return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
229 attributes, elementType,
230 inferedReturnShapes);
231 }
232
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)233 LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
234 OpBuilder& builder, ValueRange operands,
235 SmallVectorImpl<Value>& reifiedReturnShapes) {
236 return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
237 operands, reifiedReturnShapes);
238 }
239
240 //===----------------------------------------------------------------------===//
241 // IsInfOp
242 //===----------------------------------------------------------------------===//
243
getIsInfLikeReturnType(Value operand)244 static Type getIsInfLikeReturnType(Value operand) {
245 Builder b(operand.getContext());
246 return hlo::getSameShapeTensorType(operand.getType().cast<TensorType>(),
247 b.getI1Type());
248 }
249
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)250 LogicalResult IsInfOp::inferReturnTypes(
251 MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
252 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
253 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
254 return success();
255 }
256
257 //===----------------------------------------------------------------------===//
258 // IsNegInfOp
259 //===----------------------------------------------------------------------===//
260
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)261 LogicalResult IsNegInfOp::inferReturnTypes(
262 MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
263 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
264 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
265 return success();
266 }
267
268 //===----------------------------------------------------------------------===//
269 // IsPosInfOp
270 //===----------------------------------------------------------------------===//
271
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)272 LogicalResult IsPosInfOp::inferReturnTypes(
273 MLIRContext* /*ctx*/, Optional<Location>, ValueRange operands,
274 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
275 inferredReturnTypes.push_back(getIsInfLikeReturnType(operands.front()));
276 return success();
277 }
278
279 //===----------------------------------------------------------------------===//
280 // Macros for method definitions that are common to most broadcasting ops.
281 //===----------------------------------------------------------------------===//
282
283 #define BROADCAST_BINARY_OP_DEFS(Op) \
284 LogicalResult Op::inferReturnTypeComponents( \
285 MLIRContext* context, Optional<Location> location, \
286 ValueShapeRange operands, DictionaryAttr attributes, \
287 RegionRange regions, \
288 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
289 return InferBroadcastBinaryOpReturnTypeComponents( \
290 context, location, operands, attributes, /*element_type=*/nullptr, \
291 inferedReturnShapes); \
292 } \
293 LogicalResult Op::reifyReturnTypeShapes( \
294 OpBuilder& builder, ValueRange operands, \
295 SmallVectorImpl<Value>& reifiedReturnShapes) { \
296 return ReifyBroadcastBinaryOpReturnTypeShapes( \
297 builder, getOperation(), operands, reifiedReturnShapes); \
298 }
299
300 BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
301 BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
302 BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
303 BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
304 BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
305 BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
306 BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
307 BROADCAST_BINARY_OP_DEFS(BroadcastNextAfterOp);
308 BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
309 BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
310 BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
311 BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
312 BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
313 BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
314 BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
315 BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
316 BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
317 BROADCAST_BINARY_OP_DEFS(BroadcastZetaOp);
318
319 #undef BROADCAST_BINARY_OP_DEFS
320
verify()321 LogicalResult ConstantLikeOp::verify() {
322 if (value().getType() != getType().cast<ShapedType>().getElementType())
323 return emitOpError() << "value's type doesn't match element return type";
324 return success();
325 }
326
327 //===----------------------------------------------------------------------===//
328 // MinimumBroadcastShapesOp
329 //===----------------------------------------------------------------------===//
verify()330 LogicalResult MinimumBroadcastShapesOp::verify() {
331 // Check that the number of operands matches the number of outputs.
332 unsigned resultShapesCount = results().size();
333 unsigned operandShapesCount = shapes().size();
334 if (operandShapesCount != resultShapesCount) {
335 return emitOpError() << "number of operand shapes (" << operandShapesCount
336 << ") does not match number of result shapes ("
337 << resultShapesCount << ")";
338 }
339 if (operandShapesCount < 2) {
340 return emitOpError() << "number of operand shapes (" << operandShapesCount
341 << ") should be >= 2";
342 }
343 return success();
344 }
345
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferedReturnShapes)346 LogicalResult ConstantLikeOp::inferReturnTypeComponents(
347 MLIRContext* /*context*/, Optional<Location> location,
348 ValueShapeRange operands, DictionaryAttr attributes,
349 RegionRange /*regions*/,
350 SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
351 ConstantLikeOp::Adaptor op(operands, attributes);
352 if (failed(op.verify(location.value()))) return failure();
353 Type elementType = op.value().getType();
354 Type operandType = op.operand().getType();
355 if (operandType.isa<UnrankedTensorType>()) {
356 inferedReturnShapes.emplace_back(elementType);
357 } else {
358 const auto& shape = operandType.cast<RankedTensorType>().getShape();
359 inferedReturnShapes.emplace_back(shape, elementType);
360 }
361 return success();
362 }
363
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)364 LogicalResult ConstantLikeOp::reifyReturnTypeShapes(
365 OpBuilder& builder, ValueRange operands,
366 SmallVectorImpl<Value>& reifiedReturnShapes) {
367 return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
368 &reifiedReturnShapes);
369 }
370
fold(ArrayRef<Attribute>)371 OpFoldResult ConstantLikeOp::fold(ArrayRef<Attribute> /*operands*/) {
372 auto opType = operand().getType().cast<ShapedType>();
373 if (!opType.hasStaticShape()) return {};
374 auto type = RankedTensorType::get(opType.getShape(), value().getType());
375 if (auto complexAttr = value().dyn_cast<complex::NumberAttr>())
376 return DenseElementsAttr::get(type, complexAttr.getValue());
377 return DenseElementsAttr::get(type, value());
378 }
379
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)380 LogicalResult BroadcastSelectOp::inferReturnTypeComponents(
381 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
382 DictionaryAttr, RegionRange,
383 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
384 BroadcastSelectOp::Adaptor op(operands.getValues());
385 auto predType = op.pred().getType().dyn_cast<ShapedType>();
386 auto onTrueType = op.on_true().getType().dyn_cast<ShapedType>();
387 auto onFalseType = op.on_false().getType().dyn_cast<ShapedType>();
388
389 if (!predType || !onTrueType || !onFalseType ||
390 onTrueType.getElementType() != onFalseType.getElementType()) {
391 return emitOptionalError(location, "mismatched operand types");
392 }
393
394 Type elementType = onTrueType.getElementType();
395
396 // Compute the result shape as two binary broadcasts.
397 ShapedTypeComponents& components = inferredReturnShapes.emplace_back(
398 getBroadcastType(onTrueType, onFalseType, elementType, nullptr));
399 if (components.hasRank()) {
400 components = getBroadcastType(
401 RankedTensorType::get(components.getDims(), elementType), predType,
402 elementType, nullptr);
403 }
404 return success();
405 }
406
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & result)407 LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
408 OpBuilder& builder, ValueRange operands, SmallVectorImpl<Value>& result) {
409 result.push_back(hlo::computeNaryElementwiseBroadcastingResultExtents(
410 getLoc(), operands, builder));
411 return success();
412 }
413
414 //===----------------------------------------------------------------------===//
415 // RankSpecializationClusterOp
416 //===----------------------------------------------------------------------===//
417
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute>,SmallVectorImpl<RegionSuccessor> & regions)418 void RankSpecializationClusterOp::getSuccessorRegions(
419 Optional<unsigned> index, ArrayRef<Attribute> /*operands*/,
420 SmallVectorImpl<RegionSuccessor>& regions) {
421 // RankSpecializationClusterOp has unconditional control flows into the region
422 // and back to the parent, so return the correct RegionSuccessor purely based
423 // on the index being None or 0.
424 if (index.has_value()) {
425 regions.push_back(RegionSuccessor(getResults()));
426 return;
427 }
428 regions.push_back(RegionSuccessor(&body()));
429 }
430
verify()431 LogicalResult RankSpecializationClusterOp::verify() {
432 if (body().getArgumentTypes() != getOperandTypes())
433 return emitOpError() << "block argument types must match operand types";
434
435 // All operands of nested ops must be defined in the body or declared by the
436 // cluster.
437 Block* body = getBody();
438 for (Operation& nested : body->without_terminator()) {
439 if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
440 Operation* def = operand.get().getDefiningOp();
441 if (def != nullptr && def->getBlock() == body) return true;
442 return llvm::is_contained(body->getArguments(), operand.get());
443 })) {
444 return emitOpError() << "nested ops must not depend on implicit operands";
445 }
446 }
447
448 return success();
449 }
450
451 //===----------------------------------------------------------------------===//
452 // TopKOp
453 //===----------------------------------------------------------------------===//
454
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)455 LogicalResult TopKOp::inferReturnTypeComponents(
456 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
457 DictionaryAttr attributes, RegionRange regions,
458 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
459 Builder builder(context);
460 TopKOp::Adaptor adaptor(operands, attributes, regions);
461 Value operand = adaptor.operand();
462 uint64_t k = adaptor.k();
463
464 auto operandTy = operand.getType().dyn_cast<RankedTensorType>();
465 if (!operandTy) {
466 return emitOptionalError(location, "operand must be ranked");
467 }
468 if (operandTy.getRank() < 1) {
469 return emitOptionalError(location, "operand's rank must be at least 1");
470 }
471 auto operandLastDim = operandTy.getShape()[operandTy.getRank() - 1];
472 if (operandLastDim == ShapedType::kDynamicSize) {
473 return emitOptionalError(location,
474 "operand's last dimension must be static");
475 }
476 if (operandLastDim < static_cast<int64_t>(k)) {
477 return emitOptionalError(location,
478 "operand's last dimension must be at least ", k);
479 }
480
481 SmallVector<int64_t> resultShape;
482 append_range(resultShape, operandTy.getShape());
483 resultShape[operandTy.getRank() - 1] = k;
484
485 inferredReturnShapes.emplace_back(resultShape, operandTy.getElementType());
486 inferredReturnShapes.emplace_back(resultShape, builder.getI32Type());
487 return success();
488 }
489
490 //===----------------------------------------------------------------------===//
491 // ConstantOp
492 //===----------------------------------------------------------------------===//
493
fold(ArrayRef<Attribute>)494 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> /*operands*/) {
495 return value();
496 }
497
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)498 LogicalResult ConstantOp::inferReturnTypes(
499 MLIRContext*, Optional<Location>, ValueRange, DictionaryAttr attributes,
500 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
501 Type type = attributes.get("value").cast<TypedAttr>().getType();
502 inferredReturnTypes.push_back(type);
503 return success();
504 }
505
506 } // namespace chlo
507 } // namespace mlir
508
509 #define GET_OP_CLASSES
510 #include "dialect/ChloOps.cpp.inc"
511
512 namespace mlir {
513 namespace chlo {
514
515 //===----------------------------------------------------------------------===//
516 // chlo Dialect Constructor
517 //===----------------------------------------------------------------------===//
518
// Constructs the CHLO dialect in `context` and registers its TableGen-generated
// operations and attributes.
ChloDialect::ChloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<ChloDialect>()) {
  // The operation list is expanded from ChloOps.cpp.inc under GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "dialect/ChloOps.cpp.inc"
      >();
  // The attribute list is expanded from ChloAttrs.cpp.inc under
  // GET_ATTRDEF_LIST.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "dialect/ChloAttrs.cpp.inc"
      >();
}
530
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)531 Operation* ChloDialect::materializeConstant(OpBuilder& builder, Attribute value,
532 Type type, Location loc) {
533 if (value.isa<ElementsAttr>())
534 return builder.create<chlo::ConstantOp>(loc, type,
535 value.cast<ElementsAttr>());
536 return nullptr;
537 }
538
539 // Entry point for Attribute parsing, TableGen generated code will handle the
540 // dispatch to the individual classes.
parseAttribute(DialectAsmParser & parser,Type type) const541 Attribute ChloDialect::parseAttribute(DialectAsmParser& parser,
542 Type type) const {
543 StringRef attrTag;
544 Attribute attr;
545 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
546 if (parseResult.hasValue()) return attr;
547 parser.emitError(parser.getNameLoc(), "unknown chlo attribute");
548 return Attribute();
549 }
550
551 // Entry point for Attribute printing, TableGen generated code will handle the
552 // dispatch to the individual classes.
printAttribute(Attribute attr,DialectAsmPrinter & os) const553 void ChloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
554 LogicalResult result = generatedAttributePrinter(attr, os);
555 (void)result;
556 assert(succeeded(result));
557 }
558
559 } // namespace chlo
560 } // namespace mlir
561