1 //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
13 //
// The purpose of these types is to allow lightweight implementation of
// concrete ops (like DimOp) with very little boilerplate.
16 //
17 //===----------------------------------------------------------------------===//
18
19 #ifndef MLIR_IR_OPDEFINITION_H
20 #define MLIR_IR_OPDEFINITION_H
21
22 #include "mlir/IR/Operation.h"
23 #include "llvm/Support/PointerLikeTypeTraits.h"
24
25 #include <type_traits>
26
27 namespace mlir {
28 class Builder;
29 class OpBuilder;
30
31 namespace OpTrait {
32 template <typename ConcreteType> class OneResult;
33 }
34
35 /// This class represents success/failure for operation parsing. It is
36 /// essentially a simple wrapper class around LogicalResult that allows for
37 /// explicit conversion to bool. This allows for the parser to chain together
38 /// parse rules without the clutter of "failed/succeeded".
39 class ParseResult : public LogicalResult {
40 public:
LogicalResult(result)41 ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
42
43 // Allow diagnostics emitted during parsing to be converted to failure.
ParseResult(const InFlightDiagnostic &)44 ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
ParseResult(const Diagnostic &)45 ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
46
47 /// Failure is true in a boolean context.
48 explicit operator bool() const { return failed(*this); }
49 };
50 /// This class implements `Optional` functionality for ParseResult. We don't
51 /// directly use Optional here, because it provides an implicit conversion
52 /// to 'bool' which we want to avoid. This class is used to implement tri-state
53 /// 'parseOptional' functions that may have a failure mode when parsing that
54 /// shouldn't be attributed to "not present".
55 class OptionalParseResult {
56 public:
57 OptionalParseResult() = default;
OptionalParseResult(LogicalResult result)58 OptionalParseResult(LogicalResult result) : impl(result) {}
OptionalParseResult(ParseResult result)59 OptionalParseResult(ParseResult result) : impl(result) {}
OptionalParseResult(const InFlightDiagnostic &)60 OptionalParseResult(const InFlightDiagnostic &)
61 : OptionalParseResult(failure()) {}
OptionalParseResult(llvm::NoneType)62 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
63
64 /// Returns true if we contain a valid ParseResult value.
hasValue()65 bool hasValue() const { return impl.hasValue(); }
66
67 /// Access the internal ParseResult value.
getValue()68 ParseResult getValue() const { return impl.getValue(); }
69 ParseResult operator*() const { return getValue(); }
70
71 private:
72 Optional<ParseResult> impl;
73 };
74
75 // These functions are out-of-line utilities, which avoids them being template
76 // instantiated/duplicated.
77 namespace impl {
78 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
79 /// region's only block if it does not have a terminator already. If the region
80 /// is empty, insert a new block first. `buildTerminatorOp` should return the
81 /// terminator operation to insert.
82 void ensureRegionTerminator(
83 Region ®ion, OpBuilder &builder, Location loc,
84 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
85 void ensureRegionTerminator(
86 Region ®ion, Builder &builder, Location loc,
87 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
88
89 } // namespace impl
90
91 /// This is the concrete base class that holds the operation pointer and has
92 /// non-generic methods that only depend on State (to avoid having them
93 /// instantiated on template types that don't affect them.
94 ///
95 /// This also has the fallback implementations of customization hooks for when
96 /// they aren't customized.
97 class OpState {
98 public:
99 /// Ops are pointer-like, so we allow implicit conversion to bool.
100 operator bool() { return getOperation() != nullptr; }
101
102 /// This implicitly converts to Operation*.
103 operator Operation *() const { return state; }
104
105 /// Shortcut of `->` to access a member of Operation.
106 Operation *operator->() const { return state; }
107
108 /// Return the operation that this refers to.
getOperation()109 Operation *getOperation() { return state; }
110
111 /// Return the dialect that this refers to.
getDialect()112 Dialect *getDialect() { return getOperation()->getDialect(); }
113
114 /// Return the parent Region of this operation.
getParentRegion()115 Region *getParentRegion() { return getOperation()->getParentRegion(); }
116
117 /// Returns the closest surrounding operation that contains this operation
118 /// or nullptr if this is a top-level operation.
getParentOp()119 Operation *getParentOp() { return getOperation()->getParentOp(); }
120
121 /// Return the closest surrounding parent operation that is of type 'OpTy'.
getParentOfType()122 template <typename OpTy> OpTy getParentOfType() {
123 return getOperation()->getParentOfType<OpTy>();
124 }
125
126 /// Returns the closest surrounding parent operation with trait `Trait`.
127 template <template <typename T> class Trait>
getParentWithTrait()128 Operation *getParentWithTrait() {
129 return getOperation()->getParentWithTrait<Trait>();
130 }
131
132 /// Return the context this operation belongs to.
getContext()133 MLIRContext *getContext() { return getOperation()->getContext(); }
134
135 /// Print the operation to the given stream.
136 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
137 state->print(os, flags);
138 }
139 void print(raw_ostream &os, AsmState &asmState,
140 OpPrintingFlags flags = llvm::None) {
141 state->print(os, asmState, flags);
142 }
143
144 /// Dump this operation.
dump()145 void dump() { state->dump(); }
146
147 /// The source location the operation was defined or derived from.
getLoc()148 Location getLoc() { return state->getLoc(); }
setLoc(Location loc)149 void setLoc(Location loc) { state->setLoc(loc); }
150
151 /// Return all of the attributes on this operation.
getAttrs()152 ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
153
154 /// A utility iterator that filters out non-dialect attributes.
155 using dialect_attr_iterator = Operation::dialect_attr_iterator;
156 using dialect_attr_range = Operation::dialect_attr_range;
157
158 /// Return a range corresponding to the dialect attributes for this operation.
getDialectAttrs()159 dialect_attr_range getDialectAttrs() { return state->getDialectAttrs(); }
dialect_attr_begin()160 dialect_attr_iterator dialect_attr_begin() {
161 return state->dialect_attr_begin();
162 }
dialect_attr_end()163 dialect_attr_iterator dialect_attr_end() { return state->dialect_attr_end(); }
164
165 /// Return an attribute with the specified name.
getAttr(StringRef name)166 Attribute getAttr(StringRef name) { return state->getAttr(name); }
167
168 /// If the operation has an attribute of the specified type, return it.
getAttrOfType(StringRef name)169 template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
170 return getAttr(name).dyn_cast_or_null<AttrClass>();
171 }
172
173 /// If the an attribute exists with the specified name, change it to the new
174 /// value. Otherwise, add a new attribute with the specified name/value.
setAttr(Identifier name,Attribute value)175 void setAttr(Identifier name, Attribute value) {
176 state->setAttr(name, value);
177 }
setAttr(StringRef name,Attribute value)178 void setAttr(StringRef name, Attribute value) {
179 setAttr(Identifier::get(name, getContext()), value);
180 }
181
182 /// Set the attributes held by this operation.
setAttrs(ArrayRef<NamedAttribute> attributes)183 void setAttrs(ArrayRef<NamedAttribute> attributes) {
184 state->setAttrs(attributes);
185 }
setAttrs(MutableDictionaryAttr newAttrs)186 void setAttrs(MutableDictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
187
188 /// Set the dialect attributes for this operation, and preserve all dependent.
setDialectAttrs(DialectAttrs && attrs)189 template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
190 state->setDialectAttrs(std::move(attrs));
191 }
192
193 /// Remove the attribute with the specified name if it exists. The return
194 /// value indicates whether the attribute was present or not.
removeAttr(Identifier name)195 MutableDictionaryAttr::RemoveResult removeAttr(Identifier name) {
196 return state->removeAttr(name);
197 }
removeAttr(StringRef name)198 MutableDictionaryAttr::RemoveResult removeAttr(StringRef name) {
199 return state->removeAttr(Identifier::get(name, getContext()));
200 }
201
202 /// Return true if there are no users of any results of this operation.
use_empty()203 bool use_empty() { return state->use_empty(); }
204
205 /// Remove this operation from its parent block and delete it.
erase()206 void erase() { state->erase(); }
207
208 /// Emit an error with the op name prefixed, like "'dim' op " which is
209 /// convenient for verifiers.
210 InFlightDiagnostic emitOpError(const Twine &message = {});
211
212 /// Emit an error about fatal conditions with this operation, reporting up to
213 /// any diagnostic handlers that may be listening.
214 InFlightDiagnostic emitError(const Twine &message = {});
215
216 /// Emit a warning about this operation, reporting up to any diagnostic
217 /// handlers that may be listening.
218 InFlightDiagnostic emitWarning(const Twine &message = {});
219
220 /// Emit a remark about this operation, reporting up to any diagnostic
221 /// handlers that may be listening.
222 InFlightDiagnostic emitRemark(const Twine &message = {});
223
224 /// Walk the operation in postorder, calling the callback for each nested
225 /// operation(including this one).
226 /// See Operation::walk for more details.
227 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
walk(FnT && callback)228 RetT walk(FnT &&callback) {
229 return state->walk(std::forward<FnT>(callback));
230 }
231
232 // These are default implementations of customization hooks.
233 public:
234 /// This hook returns any canonicalization pattern rewrites that the operation
235 /// supports, for use by the canonicalization pass.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)236 static void getCanonicalizationPatterns(OwningRewritePatternList &results,
237 MLIRContext *context) {}
238
239 protected:
240 /// If the concrete type didn't implement a custom verifier hook, just fall
241 /// back to this one which accepts everything.
verify()242 LogicalResult verify() { return success(); }
243
244 /// Unless overridden, the custom assembly form of an op is always rejected.
245 /// Op implementations should implement this to return failure.
246 /// On success, they should fill in result with the fields to use.
247 static ParseResult parse(OpAsmParser &parser, OperationState &result);
248
249 // The fallback for the printer is to print it the generic assembly form.
250 static void print(Operation *op, OpAsmPrinter &p);
251
252 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
253 /// so we can cast it away here.
OpState(Operation * state)254 explicit OpState(Operation *state) : state(state) {}
255
256 private:
257 Operation *state;
258
259 /// Allow access to internal hook implementation methods.
260 friend AbstractOperation;
261 };
262
263 // Allow comparing operators.
264 inline bool operator==(OpState lhs, OpState rhs) {
265 return lhs.getOperation() == rhs.getOperation();
266 }
267 inline bool operator!=(OpState lhs, OpState rhs) {
268 return lhs.getOperation() != rhs.getOperation();
269 }
270
271 /// This class represents a single result from folding an operation.
272 class OpFoldResult : public PointerUnion<Attribute, Value> {
273 using PointerUnion<Attribute, Value>::PointerUnion;
274 };
275
276 /// Allow printing to a stream.
277 inline raw_ostream &operator<<(raw_ostream &os, OpState &op) {
278 op.print(os, OpPrintingFlags().useLocalScope());
279 return os;
280 }
281
282 //===----------------------------------------------------------------------===//
283 // Operation Trait Types
284 //===----------------------------------------------------------------------===//
285
286 namespace OpTrait {
287
288 // These functions are out-of-line implementations of the methods in the
289 // corresponding trait classes. This avoids them being template
290 // instantiated/duplicated.
291 namespace impl {
292 OpFoldResult foldIdempotent(Operation *op);
293 OpFoldResult foldInvolution(Operation *op);
294 LogicalResult verifyZeroOperands(Operation *op);
295 LogicalResult verifyOneOperand(Operation *op);
296 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
297 LogicalResult verifyIsIdempotent(Operation *op);
298 LogicalResult verifyIsInvolution(Operation *op);
299 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
300 LogicalResult verifyOperandsAreFloatLike(Operation *op);
301 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
302 LogicalResult verifySameTypeOperands(Operation *op);
303 LogicalResult verifyZeroRegion(Operation *op);
304 LogicalResult verifyOneRegion(Operation *op);
305 LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
306 LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
307 LogicalResult verifyZeroResult(Operation *op);
308 LogicalResult verifyOneResult(Operation *op);
309 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
310 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
311 LogicalResult verifySameOperandsShape(Operation *op);
312 LogicalResult verifySameOperandsAndResultShape(Operation *op);
313 LogicalResult verifySameOperandsElementType(Operation *op);
314 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
315 LogicalResult verifySameOperandsAndResultType(Operation *op);
316 LogicalResult verifyResultsAreBoolLike(Operation *op);
317 LogicalResult verifyResultsAreFloatLike(Operation *op);
318 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
319 LogicalResult verifyIsTerminator(Operation *op);
320 LogicalResult verifyZeroSuccessor(Operation *op);
321 LogicalResult verifyOneSuccessor(Operation *op);
322 LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
323 LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
324 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
325 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
326 LogicalResult verifyNoRegionArguments(Operation *op);
327 LogicalResult verifyElementwiseMappable(Operation *op);
328 } // namespace impl
329
330 /// Helper class for implementing traits. Clients are not expected to interact
331 /// with this directly, so its members are all protected.
332 template <typename ConcreteType, template <typename> class TraitType>
333 class TraitBase {
334 protected:
335 /// Return the ultimate Operation being worked on.
getOperation()336 Operation *getOperation() {
337 // We have to cast up to the trait type, then to the concrete type, then to
338 // the BaseState class in explicit hops because the concrete type will
339 // multiply derive from the (content free) TraitBase class, and we need to
340 // be able to disambiguate the path for the C++ compiler.
341 auto *trait = static_cast<TraitType<ConcreteType> *>(this);
342 auto *concrete = static_cast<ConcreteType *>(trait);
343 auto *base = static_cast<OpState *>(concrete);
344 return base->getOperation();
345 }
346 };
347
348 //===----------------------------------------------------------------------===//
349 // Operand Traits
350
351 namespace detail {
352 /// Utility trait base that provides accessors for derived traits that have
353 /// multiple operands.
354 template <typename ConcreteType, template <typename> class TraitType>
355 struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
356 using operand_iterator = Operation::operand_iterator;
357 using operand_range = Operation::operand_range;
358 using operand_type_iterator = Operation::operand_type_iterator;
359 using operand_type_range = Operation::operand_type_range;
360
361 /// Return the number of operands.
getNumOperandsMultiOperandTraitBase362 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
363
364 /// Return the operand at index 'i'.
getOperandMultiOperandTraitBase365 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
366
367 /// Set the operand at index 'i' to 'value'.
setOperandMultiOperandTraitBase368 void setOperand(unsigned i, Value value) {
369 this->getOperation()->setOperand(i, value);
370 }
371
372 /// Operand iterator access.
operand_beginMultiOperandTraitBase373 operand_iterator operand_begin() {
374 return this->getOperation()->operand_begin();
375 }
operand_endMultiOperandTraitBase376 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
getOperandsMultiOperandTraitBase377 operand_range getOperands() { return this->getOperation()->getOperands(); }
378
379 /// Operand type access.
operand_type_beginMultiOperandTraitBase380 operand_type_iterator operand_type_begin() {
381 return this->getOperation()->operand_type_begin();
382 }
operand_type_endMultiOperandTraitBase383 operand_type_iterator operand_type_end() {
384 return this->getOperation()->operand_type_end();
385 }
getOperandTypesMultiOperandTraitBase386 operand_type_range getOperandTypes() {
387 return this->getOperation()->getOperandTypes();
388 }
389 };
390 } // end namespace detail
391
392 /// This class provides the API for ops that are known to have no
393 /// SSA operand.
394 template <typename ConcreteType>
395 class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
396 public:
verifyTrait(Operation * op)397 static LogicalResult verifyTrait(Operation *op) {
398 return impl::verifyZeroOperands(op);
399 }
400
401 private:
402 // Disable these.
getOperand()403 void getOperand() {}
setOperand()404 void setOperand() {}
405 };
406
407 /// This class provides the API for ops that are known to have exactly one
408 /// SSA operand.
409 template <typename ConcreteType>
410 class OneOperand : public TraitBase<ConcreteType, OneOperand> {
411 public:
getOperand()412 Value getOperand() { return this->getOperation()->getOperand(0); }
413
setOperand(Value value)414 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
415
verifyTrait(Operation * op)416 static LogicalResult verifyTrait(Operation *op) {
417 return impl::verifyOneOperand(op);
418 }
419 };
420
421 /// This class provides the API for ops that are known to have a specified
422 /// number of operands. This is used as a trait like this:
423 ///
424 /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
425 ///
426 template <unsigned N> class NOperands {
427 public:
428 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
429
430 template <typename ConcreteType>
431 class Impl
432 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
433 public:
verifyTrait(Operation * op)434 static LogicalResult verifyTrait(Operation *op) {
435 return impl::verifyNOperands(op, N);
436 }
437 };
438 };
439
440 /// This class provides the API for ops that are known to have a at least a
441 /// specified number of operands. This is used as a trait like this:
442 ///
443 /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
444 ///
445 template <unsigned N> class AtLeastNOperands {
446 public:
447 template <typename ConcreteType>
448 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
449 AtLeastNOperands<N>::Impl> {
450 public:
verifyTrait(Operation * op)451 static LogicalResult verifyTrait(Operation *op) {
452 return impl::verifyAtLeastNOperands(op, N);
453 }
454 };
455 };
456
457 /// This class provides the API for ops which have an unknown number of
458 /// SSA operands.
459 template <typename ConcreteType>
460 class VariadicOperands
461 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
462
463 //===----------------------------------------------------------------------===//
464 // Region Traits
465
466 /// This class provides verification for ops that are known to have zero
467 /// regions.
468 template <typename ConcreteType>
469 class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
470 public:
verifyTrait(Operation * op)471 static LogicalResult verifyTrait(Operation *op) {
472 return impl::verifyZeroRegion(op);
473 }
474 };
475
476 namespace detail {
477 /// Utility trait base that provides accessors for derived traits that have
478 /// multiple regions.
479 template <typename ConcreteType, template <typename> class TraitType>
480 struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
481 using region_iterator = MutableArrayRef<Region>;
482 using region_range = RegionRange;
483
484 /// Return the number of regions.
getNumRegionsMultiRegionTraitBase485 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
486
487 /// Return the region at `index`.
getRegionMultiRegionTraitBase488 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
489
490 /// Region iterator access.
region_beginMultiRegionTraitBase491 region_iterator region_begin() {
492 return this->getOperation()->region_begin();
493 }
region_endMultiRegionTraitBase494 region_iterator region_end() { return this->getOperation()->region_end(); }
getRegionsMultiRegionTraitBase495 region_range getRegions() { return this->getOperation()->getRegions(); }
496 };
497 } // end namespace detail
498
499 /// This class provides APIs for ops that are known to have a single region.
500 template <typename ConcreteType>
501 class OneRegion : public TraitBase<ConcreteType, OneRegion> {
502 public:
getRegion()503 Region &getRegion() { return this->getOperation()->getRegion(0); }
504
505 /// Returns a range of operations within the region of this operation.
getOps()506 auto getOps() { return getRegion().getOps(); }
507 template <typename OpT>
getOps()508 auto getOps() {
509 return getRegion().template getOps<OpT>();
510 }
511
verifyTrait(Operation * op)512 static LogicalResult verifyTrait(Operation *op) {
513 return impl::verifyOneRegion(op);
514 }
515 };
516
517 /// This class provides the API for ops that are known to have a specified
518 /// number of regions.
519 template <unsigned N> class NRegions {
520 public:
521 static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
522
523 template <typename ConcreteType>
524 class Impl
525 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
526 public:
verifyTrait(Operation * op)527 static LogicalResult verifyTrait(Operation *op) {
528 return impl::verifyNRegions(op, N);
529 }
530 };
531 };
532
533 /// This class provides APIs for ops that are known to have at least a specified
534 /// number of regions.
535 template <unsigned N> class AtLeastNRegions {
536 public:
537 template <typename ConcreteType>
538 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
539 AtLeastNRegions<N>::Impl> {
540 public:
verifyTrait(Operation * op)541 static LogicalResult verifyTrait(Operation *op) {
542 return impl::verifyAtLeastNRegions(op, N);
543 }
544 };
545 };
546
547 /// This class provides the API for ops which have an unknown number of
548 /// regions.
549 template <typename ConcreteType>
550 class VariadicRegions
551 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
552
553 //===----------------------------------------------------------------------===//
554 // Result Traits
555
556 /// This class provides return value APIs for ops that are known to have
557 /// zero results.
558 template <typename ConcreteType>
559 class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
560 public:
verifyTrait(Operation * op)561 static LogicalResult verifyTrait(Operation *op) {
562 return impl::verifyZeroResult(op);
563 }
564 };
565
566 namespace detail {
567 /// Utility trait base that provides accessors for derived traits that have
568 /// multiple results.
569 template <typename ConcreteType, template <typename> class TraitType>
570 struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
571 using result_iterator = Operation::result_iterator;
572 using result_range = Operation::result_range;
573 using result_type_iterator = Operation::result_type_iterator;
574 using result_type_range = Operation::result_type_range;
575
576 /// Return the number of results.
getNumResultsMultiResultTraitBase577 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
578
579 /// Return the result at index 'i'.
getResultMultiResultTraitBase580 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
581
582 /// Replace all uses of results of this operation with the provided 'values'.
583 /// 'values' may correspond to an existing operation, or a range of 'Value'.
replaceAllUsesWithMultiResultTraitBase584 template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
585 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
586 }
587
588 /// Return the type of the `i`-th result.
getTypeMultiResultTraitBase589 Type getType(unsigned i) { return getResult(i).getType(); }
590
591 /// Result iterator access.
result_beginMultiResultTraitBase592 result_iterator result_begin() {
593 return this->getOperation()->result_begin();
594 }
result_endMultiResultTraitBase595 result_iterator result_end() { return this->getOperation()->result_end(); }
getResultsMultiResultTraitBase596 result_range getResults() { return this->getOperation()->getResults(); }
597
598 /// Result type access.
result_type_beginMultiResultTraitBase599 result_type_iterator result_type_begin() {
600 return this->getOperation()->result_type_begin();
601 }
result_type_endMultiResultTraitBase602 result_type_iterator result_type_end() {
603 return this->getOperation()->result_type_end();
604 }
getResultTypesMultiResultTraitBase605 result_type_range getResultTypes() {
606 return this->getOperation()->getResultTypes();
607 }
608 };
609 } // end namespace detail
610
611 /// This class provides return value APIs for ops that are known to have a
612 /// single result.
613 template <typename ConcreteType>
614 class OneResult : public TraitBase<ConcreteType, OneResult> {
615 public:
getResult()616 Value getResult() { return this->getOperation()->getResult(0); }
getType()617 Type getType() { return getResult().getType(); }
618
619 /// If the operation returns a single value, then the Op can be implicitly
620 /// converted to an Value. This yields the value of the only result.
Value()621 operator Value() { return getResult(); }
622
623 /// Replace all uses of 'this' value with the new value, updating anything in
624 /// the IR that uses 'this' to use the other value instead. When this returns
625 /// there are zero uses of 'this'.
replaceAllUsesWith(Value newValue)626 void replaceAllUsesWith(Value newValue) {
627 getResult().replaceAllUsesWith(newValue);
628 }
629
630 /// Replace all uses of 'this' value with the result of 'op'.
replaceAllUsesWith(Operation * op)631 void replaceAllUsesWith(Operation *op) {
632 this->getOperation()->replaceAllUsesWith(op);
633 }
634
verifyTrait(Operation * op)635 static LogicalResult verifyTrait(Operation *op) {
636 return impl::verifyOneResult(op);
637 }
638 };
639
640 /// This class provides the API for ops that are known to have a specified
641 /// number of results. This is used as a trait like this:
642 ///
643 /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
644 ///
645 template <unsigned N> class NResults {
646 public:
647 static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
648
649 template <typename ConcreteType>
650 class Impl
651 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
652 public:
verifyTrait(Operation * op)653 static LogicalResult verifyTrait(Operation *op) {
654 return impl::verifyNResults(op, N);
655 }
656 };
657 };
658
659 /// This class provides the API for ops that are known to have at least a
660 /// specified number of results. This is used as a trait like this:
661 ///
662 /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
663 ///
664 template <unsigned N> class AtLeastNResults {
665 public:
666 template <typename ConcreteType>
667 class Impl : public detail::MultiResultTraitBase<ConcreteType,
668 AtLeastNResults<N>::Impl> {
669 public:
verifyTrait(Operation * op)670 static LogicalResult verifyTrait(Operation *op) {
671 return impl::verifyAtLeastNResults(op, N);
672 }
673 };
674 };
675
676 /// This class provides the API for ops which have an unknown number of
677 /// results.
678 template <typename ConcreteType>
679 class VariadicResults
680 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
681
682 //===----------------------------------------------------------------------===//
683 // Terminator Traits
684
685 /// This class provides the API for ops that are known to be terminators.
686 template <typename ConcreteType>
687 class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
688 public:
getTraitProperties()689 static AbstractOperation::OperationProperties getTraitProperties() {
690 return static_cast<AbstractOperation::OperationProperties>(
691 OperationProperty::Terminator);
692 }
verifyTrait(Operation * op)693 static LogicalResult verifyTrait(Operation *op) {
694 return impl::verifyIsTerminator(op);
695 }
696 };
697
698 /// This class provides verification for ops that are known to have zero
699 /// successors.
700 template <typename ConcreteType>
701 class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
702 public:
verifyTrait(Operation * op)703 static LogicalResult verifyTrait(Operation *op) {
704 return impl::verifyZeroSuccessor(op);
705 }
706 };
707
708 namespace detail {
709 /// Utility trait base that provides accessors for derived traits that have
710 /// multiple successors.
711 template <typename ConcreteType, template <typename> class TraitType>
712 struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
713 using succ_iterator = Operation::succ_iterator;
714 using succ_range = SuccessorRange;
715
716 /// Return the number of successors.
getNumSuccessorsMultiSuccessorTraitBase717 unsigned getNumSuccessors() {
718 return this->getOperation()->getNumSuccessors();
719 }
720
721 /// Return the successor at `index`.
getSuccessorMultiSuccessorTraitBase722 Block *getSuccessor(unsigned i) {
723 return this->getOperation()->getSuccessor(i);
724 }
725
726 /// Set the successor at `index`.
setSuccessorMultiSuccessorTraitBase727 void setSuccessor(Block *block, unsigned i) {
728 return this->getOperation()->setSuccessor(block, i);
729 }
730
731 /// Successor iterator access.
succ_beginMultiSuccessorTraitBase732 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
succ_endMultiSuccessorTraitBase733 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
getSuccessorsMultiSuccessorTraitBase734 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
735 };
736 } // end namespace detail
737
738 /// This class provides APIs for ops that are known to have a single successor.
739 template <typename ConcreteType>
740 class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
741 public:
getSuccessor()742 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
setSuccessor(Block * succ)743 void setSuccessor(Block *succ) {
744 this->getOperation()->setSuccessor(succ, 0);
745 }
746
verifyTrait(Operation * op)747 static LogicalResult verifyTrait(Operation *op) {
748 return impl::verifyOneSuccessor(op);
749 }
750 };
751
752 /// This class provides the API for ops that are known to have a specified
753 /// number of successors.
754 template <unsigned N>
755 class NSuccessors {
756 public:
757 static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
758
759 template <typename ConcreteType>
760 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
761 NSuccessors<N>::Impl> {
762 public:
verifyTrait(Operation * op)763 static LogicalResult verifyTrait(Operation *op) {
764 return impl::verifyNSuccessors(op, N);
765 }
766 };
767 };
768
769 /// This class provides APIs for ops that are known to have at least a specified
770 /// number of successors.
771 template <unsigned N>
772 class AtLeastNSuccessors {
773 public:
774 template <typename ConcreteType>
775 class Impl
776 : public detail::MultiSuccessorTraitBase<ConcreteType,
777 AtLeastNSuccessors<N>::Impl> {
778 public:
verifyTrait(Operation * op)779 static LogicalResult verifyTrait(Operation *op) {
780 return impl::verifyAtLeastNSuccessors(op, N);
781 }
782 };
783 };
784
785 /// This class provides the API for ops which have an unknown number of
786 /// successors.
787 template <typename ConcreteType>
788 class VariadicSuccessors
789 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
790 };
791
792 //===----------------------------------------------------------------------===//
793 // SingleBlockImplicitTerminator
794
795 /// This class provides APIs and verifiers for ops with regions having a single
796 /// block that must terminate with `TerminatorOpType`.
797 template <typename TerminatorOpType>
798 struct SingleBlockImplicitTerminator {
799 template <typename ConcreteType>
800 class Impl : public TraitBase<ConcreteType, Impl> {
801 private:
802 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
803 /// cyclic header inclusion.
buildTerminatorSingleBlockImplicitTerminator804 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
805 OperationState state(loc, TerminatorOpType::getOperationName());
806 TerminatorOpType::build(builder, state);
807 return Operation::create(state);
808 }
809
810 public:
811 /// The type of the operation used as the implicit terminator type.
812 using ImplicitTerminatorOpT = TerminatorOpType;
813
verifyTraitSingleBlockImplicitTerminator814 static LogicalResult verifyTrait(Operation *op) {
815 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
816 Region ®ion = op->getRegion(i);
817
818 // Empty regions are fine.
819 if (region.empty())
820 continue;
821
822 // Non-empty regions must contain a single basic block.
823 if (std::next(region.begin()) != region.end())
824 return op->emitOpError("expects region #")
825 << i << " to have 0 or 1 blocks";
826
827 Block &block = region.front();
828 if (block.empty())
829 return op->emitOpError() << "expects a non-empty block";
830 Operation &terminator = block.back();
831 if (isa<TerminatorOpType>(terminator))
832 continue;
833
834 return op->emitOpError("expects regions to end with '" +
835 TerminatorOpType::getOperationName() +
836 "', found '" +
837 terminator.getName().getStringRef() + "'")
838 .attachNote()
839 << "in custom textual format, the absence of terminator implies "
840 "'"
841 << TerminatorOpType::getOperationName() << '\'';
842 }
843
844 return success();
845 }
846
847 /// Ensure that the given region has the terminator required by this trait.
848 /// If OpBuilder is provided, use it to build the terminator and notify the
849 /// OpBuilder litsteners accordingly. If only a Builder is provided, locally
850 /// construct an OpBuilder with no listeners; this should only be used if no
851 /// OpBuilder is available at the call site, e.g., in the parser.
ensureTerminatorSingleBlockImplicitTerminator852 static void ensureTerminator(Region ®ion, Builder &builder,
853 Location loc) {
854 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
855 buildTerminator);
856 }
ensureTerminatorSingleBlockImplicitTerminator857 static void ensureTerminator(Region ®ion, OpBuilder &builder,
858 Location loc) {
859 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
860 buildTerminator);
861 }
862
863 Block *getBody(unsigned idx = 0) {
864 Region ®ion = this->getOperation()->getRegion(idx);
865 assert(!region.empty() && "unexpected empty region");
866 return ®ion.front();
867 }
868 Region &getBodyRegion(unsigned idx = 0) {
869 return this->getOperation()->getRegion(idx);
870 }
871
872 //===------------------------------------------------------------------===//
873 // Single Region Utilities
874 //===------------------------------------------------------------------===//
875
876 /// The following are a set of methods only enabled when the parent
877 /// operation has a single region. Each of these methods take an additional
878 /// template parameter that represents the concrete operation so that we
879 /// can use SFINAE to disable the methods for non-single region operations.
880 template <typename OpT, typename T = void>
881 using enable_if_single_region =
882 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
883
884 template <typename OpT = ConcreteType>
beginSingleBlockImplicitTerminator885 enable_if_single_region<OpT, Block::iterator> begin() {
886 return getBody()->begin();
887 }
888 template <typename OpT = ConcreteType>
endSingleBlockImplicitTerminator889 enable_if_single_region<OpT, Block::iterator> end() {
890 return getBody()->end();
891 }
892 template <typename OpT = ConcreteType>
frontSingleBlockImplicitTerminator893 enable_if_single_region<OpT, Operation &> front() {
894 return *begin();
895 }
896
897 /// Insert the operation into the back of the body, before the terminator.
898 template <typename OpT = ConcreteType>
push_backSingleBlockImplicitTerminator899 enable_if_single_region<OpT> push_back(Operation *op) {
900 insert(Block::iterator(getBody()->getTerminator()), op);
901 }
902
903 /// Insert the operation at the given insertion point. Note: The operation
904 /// is never inserted after the terminator, even if the insertion point is
905 /// end().
906 template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator907 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
908 insert(Block::iterator(insertPt), op);
909 }
910 template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator911 enable_if_single_region<OpT> insert(Block::iterator insertPt,
912 Operation *op) {
913 auto *body = getBody();
914 if (insertPt == body->end())
915 insertPt = Block::iterator(body->getTerminator());
916 body->getOperations().insert(insertPt, op);
917 }
918 };
919 };
920
921 //===----------------------------------------------------------------------===//
922 // Misc Traits
923
924 /// This class provides verification for ops that are known to have the same
925 /// operand shape: all operands are scalars, vectors/tensors of the same
926 /// shape.
927 template <typename ConcreteType>
928 class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
929 public:
verifyTrait(Operation * op)930 static LogicalResult verifyTrait(Operation *op) {
931 return impl::verifySameOperandsShape(op);
932 }
933 };
934
935 /// This class provides verification for ops that are known to have the same
936 /// operand and result shape: both are scalars, vectors/tensors of the same
937 /// shape.
938 template <typename ConcreteType>
939 class SameOperandsAndResultShape
940 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
941 public:
verifyTrait(Operation * op)942 static LogicalResult verifyTrait(Operation *op) {
943 return impl::verifySameOperandsAndResultShape(op);
944 }
945 };
946
947 /// This class provides verification for ops that are known to have the same
948 /// operand element type (or the type itself if it is scalar).
949 ///
950 template <typename ConcreteType>
951 class SameOperandsElementType
952 : public TraitBase<ConcreteType, SameOperandsElementType> {
953 public:
verifyTrait(Operation * op)954 static LogicalResult verifyTrait(Operation *op) {
955 return impl::verifySameOperandsElementType(op);
956 }
957 };
958
959 /// This class provides verification for ops that are known to have the same
960 /// operand and result element type (or the type itself if it is scalar).
961 ///
962 template <typename ConcreteType>
963 class SameOperandsAndResultElementType
964 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
965 public:
verifyTrait(Operation * op)966 static LogicalResult verifyTrait(Operation *op) {
967 return impl::verifySameOperandsAndResultElementType(op);
968 }
969 };
970
971 /// This class provides verification for ops that are known to have the same
972 /// operand and result type.
973 ///
974 /// Note: this trait subsumes the SameOperandsAndResultShape and
975 /// SameOperandsAndResultElementType traits.
976 template <typename ConcreteType>
977 class SameOperandsAndResultType
978 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
979 public:
verifyTrait(Operation * op)980 static LogicalResult verifyTrait(Operation *op) {
981 return impl::verifySameOperandsAndResultType(op);
982 }
983 };
984
985 /// This class verifies that any results of the specified op have a boolean
986 /// type, a vector thereof, or a tensor thereof.
987 template <typename ConcreteType>
988 class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
989 public:
verifyTrait(Operation * op)990 static LogicalResult verifyTrait(Operation *op) {
991 return impl::verifyResultsAreBoolLike(op);
992 }
993 };
994
995 /// This class verifies that any results of the specified op have a floating
996 /// point type, a vector thereof, or a tensor thereof.
997 template <typename ConcreteType>
998 class ResultsAreFloatLike
999 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1000 public:
verifyTrait(Operation * op)1001 static LogicalResult verifyTrait(Operation *op) {
1002 return impl::verifyResultsAreFloatLike(op);
1003 }
1004 };
1005
1006 /// This class verifies that any results of the specified op have a signless
1007 /// integer or index type, a vector thereof, or a tensor thereof.
1008 template <typename ConcreteType>
1009 class ResultsAreSignlessIntegerLike
1010 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1011 public:
verifyTrait(Operation * op)1012 static LogicalResult verifyTrait(Operation *op) {
1013 return impl::verifyResultsAreSignlessIntegerLike(op);
1014 }
1015 };
1016
1017 /// This class adds property that the operation is commutative.
1018 template <typename ConcreteType>
1019 class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
1020 public:
getTraitProperties()1021 static AbstractOperation::OperationProperties getTraitProperties() {
1022 return static_cast<AbstractOperation::OperationProperties>(
1023 OperationProperty::Commutative);
1024 }
1025 };
1026
1027 /// This class adds property that the operation is an involution.
1028 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1029 template <typename ConcreteType>
1030 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1031 public:
verifyTrait(Operation * op)1032 static LogicalResult verifyTrait(Operation *op) {
1033 static_assert(ConcreteType::template hasTrait<OneResult>(),
1034 "expected operation to produce one result");
1035 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1036 "expected operation to take one operand");
1037 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1038 "expected operation to preserve type");
1039 // Involution requires the operation to be side effect free as well
1040 // but currently this check is under a FIXME and is not actually done.
1041 return impl::verifyIsInvolution(op);
1042 }
1043
foldTrait(Operation * op,ArrayRef<Attribute> operands)1044 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1045 return impl::foldInvolution(op);
1046 }
1047 };
1048
1049 /// This class adds property that the operation is idempotent.
1050 /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
1051 template <typename ConcreteType>
1052 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1053 public:
verifyTrait(Operation * op)1054 static LogicalResult verifyTrait(Operation *op) {
1055 static_assert(ConcreteType::template hasTrait<OneResult>(),
1056 "expected operation to produce one result");
1057 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1058 "expected operation to take one operand");
1059 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1060 "expected operation to preserve type");
1061 // Idempotent requires the operation to be side effect free as well
1062 // but currently this check is under a FIXME and is not actually done.
1063 return impl::verifyIsIdempotent(op);
1064 }
1065
foldTrait(Operation * op,ArrayRef<Attribute> operands)1066 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1067 return impl::foldIdempotent(op);
1068 }
1069 };
1070
1071 /// This class verifies that all operands of the specified op have a float type,
1072 /// a vector thereof, or a tensor thereof.
1073 template <typename ConcreteType>
1074 class OperandsAreFloatLike
1075 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1076 public:
verifyTrait(Operation * op)1077 static LogicalResult verifyTrait(Operation *op) {
1078 return impl::verifyOperandsAreFloatLike(op);
1079 }
1080 };
1081
1082 /// This class verifies that all operands of the specified op have a signless
1083 /// integer or index type, a vector thereof, or a tensor thereof.
1084 template <typename ConcreteType>
1085 class OperandsAreSignlessIntegerLike
1086 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1087 public:
verifyTrait(Operation * op)1088 static LogicalResult verifyTrait(Operation *op) {
1089 return impl::verifyOperandsAreSignlessIntegerLike(op);
1090 }
1091 };
1092
1093 /// This class verifies that all operands of the specified op have the same
1094 /// type.
1095 template <typename ConcreteType>
1096 class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1097 public:
verifyTrait(Operation * op)1098 static LogicalResult verifyTrait(Operation *op) {
1099 return impl::verifySameTypeOperands(op);
1100 }
1101 };
1102
1103 /// This class provides the API for a sub-set of ops that are known to be
1104 /// constant-like. These are non-side effecting operations with one result and
1105 /// zero operands that can always be folded to a specific attribute value.
1106 template <typename ConcreteType>
1107 class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1108 public:
verifyTrait(Operation * op)1109 static LogicalResult verifyTrait(Operation *op) {
1110 static_assert(ConcreteType::template hasTrait<OneResult>(),
1111 "expected operation to produce one result");
1112 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1113 "expected operation to take zero operands");
1114 // TODO: We should verify that the operation can always be folded, but this
1115 // requires that the attributes of the op already be verified. We should add
1116 // support for verifying traits "after" the operation to enable this use
1117 // case.
1118 return success();
1119 }
1120 };
1121
1122 /// This class provides the API for ops that are known to be isolated from
1123 /// above.
1124 template <typename ConcreteType>
1125 class IsIsolatedFromAbove
1126 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1127 public:
getTraitProperties()1128 static AbstractOperation::OperationProperties getTraitProperties() {
1129 return static_cast<AbstractOperation::OperationProperties>(
1130 OperationProperty::IsolatedFromAbove);
1131 }
verifyTrait(Operation * op)1132 static LogicalResult verifyTrait(Operation *op) {
1133 for (auto ®ion : op->getRegions())
1134 if (!region.isIsolatedFromAbove(op->getLoc()))
1135 return failure();
1136 return success();
1137 }
1138 };
1139
1140 /// A trait of region holding operations that defines a new scope for polyhedral
1141 /// optimization purposes. Any SSA values of 'index' type that either dominate
1142 /// such an operation or are used at the top-level of such an operation
1143 /// automatically become valid symbols for the polyhedral scope defined by that
1144 /// operation. For more details, see `Traits.md#AffineScope`.
1145 template <typename ConcreteType>
1146 class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1147 public:
verifyTrait(Operation * op)1148 static LogicalResult verifyTrait(Operation *op) {
1149 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1150 "expected operation to have one or more regions");
1151 return success();
1152 }
1153 };
1154
1155 /// A trait of region holding operations that define a new scope for automatic
1156 /// allocations, i.e., allocations that are freed when control is transferred
1157 /// back from the operation's region. Any operations performing such allocations
1158 /// (for eg. std.alloca) will have their allocations automatically freed at
1159 /// their closest enclosing operation with this trait.
1160 template <typename ConcreteType>
1161 class AutomaticAllocationScope
1162 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1163 public:
verifyTrait(Operation * op)1164 static LogicalResult verifyTrait(Operation *op) {
1165 if (op->hasTrait<ZeroRegion>())
1166 return op->emitOpError("is expected to have regions");
1167 return success();
1168 }
1169 };
1170
1171 /// This class provides a verifier for ops that are expecting their parent
1172 /// to be one of the given parent ops
1173 template <typename... ParentOpTypes>
1174 struct HasParent {
1175 template <typename ConcreteType>
1176 class Impl : public TraitBase<ConcreteType, Impl> {
1177 public:
verifyTraitHasParent1178 static LogicalResult verifyTrait(Operation *op) {
1179 if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1180 return success();
1181
1182 return op->emitOpError()
1183 << "expects parent op "
1184 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1185 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1186 << "'";
1187 }
1188 };
1189 };
1190
1191 /// A trait for operations that have an attribute specifying operand segments.
1192 ///
1193 /// Certain operations can have multiple variadic operands and their size
1194 /// relationship is not always known statically. For such cases, we need
1195 /// a per-op-instance specification to divide the operands into logical groups
1196 /// or segments. This can be modeled by attributes. The attribute will be named
1197 /// as `operand_segment_sizes`.
1198 ///
1199 /// This trait verifies the attribute for specifying operand segments has
1200 /// the correct type (1D vector) and values (non-negative), etc.
1201 template <typename ConcreteType>
1202 class AttrSizedOperandSegments
1203 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1204 public:
getOperandSegmentSizeAttr()1205 static StringRef getOperandSegmentSizeAttr() {
1206 return "operand_segment_sizes";
1207 }
1208
verifyTrait(Operation * op)1209 static LogicalResult verifyTrait(Operation *op) {
1210 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1211 op, getOperandSegmentSizeAttr());
1212 }
1213 };
1214
1215 /// Similar to AttrSizedOperandSegments but used for results.
1216 template <typename ConcreteType>
1217 class AttrSizedResultSegments
1218 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1219 public:
getResultSegmentSizeAttr()1220 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1221
verifyTrait(Operation * op)1222 static LogicalResult verifyTrait(Operation *op) {
1223 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1224 op, getResultSegmentSizeAttr());
1225 }
1226 };
1227
1228 /// This trait provides a verifier for ops that are expecting their regions to
1229 /// not have any arguments
1230 template <typename ConcrentType>
1231 struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
verifyTraitNoRegionArguments1232 static LogicalResult verifyTrait(Operation *op) {
1233 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1234 }
1235 };
1236
1237 // This trait is used to flag operations that consume or produce
1238 // values of `MemRef` type where those references can be 'normalized'.
1239 // TODO: Right now, the operands of an operation are either all normalizable,
1240 // or not. In the future, we may want to allow some of the operands to be
1241 // normalizable.
1242 template <typename ConcrentType>
1243 struct MemRefsNormalizable
1244 : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1245
1246 /// This trait tags scalar ops that also can be applied to vectors/tensors, with
1247 /// their semantics on vectors/tensors being elementwise application.
1248 ///
1249 /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1250 /// trait. In particular, broadcasting behavior is not allowed. This trait
1251 /// describes a set of invariants that allow systematic
1252 /// vectorization/tensorization, and the reverse, scalarization. The properties
1253 /// needed for this also can be used to implement a number of
1254 /// transformations/analyses/interfaces.
1255 ///
1256 /// An `ElementwiseMappable` op must satisfy the following properties:
1257 ///
1258 /// 1. If any result is a vector (resp. tensor), then at least one operand must
1259 /// be a vector (resp. tensor).
1260 /// 2. If any operand is a vector (resp. tensor), then there must be at least
1261 /// one result, and all results must be vectors (resp. tensors).
1262 /// 3. The static types of all vector (resp. tensor) operands and results must
1263 /// have the same shape.
1264 /// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
1265 /// must be the same, otherwise the op has undefined behavior.
1266 /// 5. ("systematic scalarization" property) If an op has vector/tensor
1267 /// operands/results, then the same op, with the operand/result types changed to
1268 /// their corresponding element type, shall be a verifier-valid op.
1269 /// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
1270 /// applying the scalarized version of the op for each corresponding element of
1271 /// the vector (resp. tensor) operands in parallel.
1272 /// 7. ("systematic vectorization/tensorization" property) If an op has
1273 /// scalar operands/results, the op shall remain verifier-valid if all scalar
1274 /// operands are replaced with vectors/tensors of the same shape and
1275 /// corresponding element types.
1276 ///
1277 /// Together, these properties provide an easy way for scalar operations to
1278 /// conveniently generalize their behavior to vectors/tensors, and systematize
1279 /// conversion between these forms.
1280 ///
1281 /// Examples:
1282 /// ```
1283 /// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
1284 /// // Applying the systematic vectorization/tensorization property, this op
1285 /// // must also be valid:
1286 /// %tensor = "std.addf"(%a_tensor, %b_tensor)
///       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1288 ///
1289 /// // These properties generalize well to the cases of non-scalar operands.
1290 /// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
1291 /// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1292 /// // Applying the systematic vectorization / tensorization property, this
1293 /// // op must also be valid:
1294 /// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
1295 /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1296 /// -> tensor<?xf32>
1297 /// // Applying the systematic scalarization property, this op must also
1298 /// // be valid.
1299 /// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1300 /// : (i1, f32, f32) -> f32
1301 /// ```
1302 ///
1303 /// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
1304 /// implementing a new "ElementwiseMappableTypeInterface" that describes types
1305 /// for which it makes sense to apply a scalar function to each element.
1306 ///
1307 /// Rationale:
1308 /// - 1. and 2. guarantee a well-defined iteration space for 6.
1309 /// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
1310 /// results, which complicate a generic definition of the iteration space.
1311 /// - 3. guarantees that folding can be done across scalars/vectors/tensors
1312 /// with the same pattern, as otherwise lots of special handling of type
1313 /// mismatches would be needed.
1314 /// - 4. guarantees that no error handling cases need to be considered.
1315 /// - Higher-level dialects should reify any needed guards / error handling
1316 /// code before lowering to an ElementwiseMappable op.
1317 /// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
1318 /// semantics and provide a constructive procedure for IR transformations
1319 /// to e.g. create scalar loop bodies from tensor ops.
1320 /// - 7. provides the reverse of 5., which when chained together allows
1321 /// reasoning about the relationship between the tensor and vector case.
1322 /// Additionally, it permits reasoning about promoting scalars to
1323 /// vectors/tensors via broadcasting in cases like `%select_scalar_pred`
1324 /// above.
1325 template <typename ConcreteType>
1326 struct ElementwiseMappable
1327 : public TraitBase<ConcreteType, ElementwiseMappable> {
verifyTraitElementwiseMappable1328 static LogicalResult verifyTrait(Operation *op) {
1329 return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
1330 }
1331 };
1332
1333 } // end namespace OpTrait
1334
1335 //===----------------------------------------------------------------------===//
1336 // Internal Trait Utilities
1337 //===----------------------------------------------------------------------===//
1338
1339 namespace op_definition_impl {
1340 //===----------------------------------------------------------------------===//
1341 // Trait Existence
1342
1343 /// Returns true if this given Trait ID matches the IDs of any of the provided
1344 /// trait types `Traits`.
1345 template <template <typename T> class... Traits>
hasTrait(TypeID traitID)1346 static bool hasTrait(TypeID traitID) {
1347 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1348 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1349 if (traitIDs[i] == traitID)
1350 return true;
1351 return false;
1352 }
1353
1354 //===----------------------------------------------------------------------===//
1355 // Trait Folding
1356
1357 /// Trait to check if T provides a 'foldTrait' method for single result
1358 /// operations.
1359 template <typename T, typename... Args>
1360 using has_single_result_fold_trait = decltype(T::foldTrait(
1361 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1362 template <typename T>
1363 using detect_has_single_result_fold_trait =
1364 llvm::is_detected<has_single_result_fold_trait, T>;
1365 /// Trait to check if T provides a general 'foldTrait' method.
1366 template <typename T, typename... Args>
1367 using has_fold_trait =
1368 decltype(T::foldTrait(std::declval<Operation *>(),
1369 std::declval<ArrayRef<Attribute>>(),
1370 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1371 template <typename T>
1372 using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1373 /// Trait to check if T provides any `foldTrait` method.
1374 /// NOTE: This should use std::disjunction when C++17 is available.
1375 template <typename T>
1376 using detect_has_any_fold_trait =
1377 std::conditional_t<bool(detect_has_fold_trait<T>::value),
1378 detect_has_fold_trait<T>,
1379 detect_has_single_result_fold_trait<T>>;
1380
1381 /// Returns the result of folding a trait that implements a `foldTrait` function
1382 /// that is specialized for operations that have a single result.
1383 template <typename Trait>
1384 static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1385 LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1386 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1387 SmallVectorImpl<OpFoldResult> &results) {
1388 assert(op->hasTrait<OpTrait::OneResult>() &&
1389 "expected trait on non single-result operation to implement the "
1390 "general `foldTrait` method");
1391 // If a previous trait has already been folded and replaced this operation, we
1392 // fail to fold this trait.
1393 if (!results.empty())
1394 return failure();
1395
1396 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1397 if (result.template dyn_cast<Value>() != op->getResult(0))
1398 results.push_back(result);
1399 return success();
1400 }
1401 return failure();
1402 }
1403 /// Returns the result of folding a trait that implements a generalized
1404 /// `foldTrait` function that is supports any operation type.
1405 template <typename Trait>
1406 static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1407 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1408 SmallVectorImpl<OpFoldResult> &results) {
1409 // If a previous trait has already been folded and replaced this operation, we
1410 // fail to fold this trait.
1411 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1412 }
1413
1414 /// The internal implementation of `foldTraits` below that returns the result of
1415 /// folding a set of trait types `Ts` that implement a `foldTrait` method.
1416 template <typename... Ts>
foldTraitsImpl(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results,std::tuple<Ts...> *)1417 static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1418 SmallVectorImpl<OpFoldResult> &results,
1419 std::tuple<Ts...> *) {
1420 bool anyFolded = false;
1421 (void)std::initializer_list<int>{
1422 (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1423 return success(anyFolded);
1424 }
1425
1426 /// Given a tuple type containing a set of traits that contain a `foldTrait`
1427 /// method, return the result of folding the given operation.
1428 template <typename TraitTupleT>
1429 static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1430 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1431 SmallVectorImpl<OpFoldResult> &results) {
1432 return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1433 }
1434 /// A variant of the method above that is specialized when there are no traits
1435 /// that contain a `foldTrait` method.
1436 template <typename TraitTupleT>
1437 static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1438 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1439 SmallVectorImpl<OpFoldResult> &results) {
1440 return failure();
1441 }
1442
1443 //===----------------------------------------------------------------------===//
1444 // Trait Properties
1445
/// Trait to check if T provides a `getTraitProperties` method.
/// `has_get_trait_properties<T>` names the type of the expression
/// `T::getTraitProperties()` (called with no arguments) and is ill-formed when
/// that expression is not valid. The `Args` pack is not used by the alias.
/// NOTE(review): `Args` is presumably only present so that the alias has the
/// variadic shape expected by `llvm::is_detected` - confirm.
template <typename T, typename... Args>
using has_get_trait_properties = decltype(T::getTraitProperties());
/// Detector for `has_get_trait_properties`, built on `llvm::is_detected`.
template <typename T>
using detect_has_get_trait_properties =
    llvm::is_detected<has_get_trait_properties, T>;
1452
1453 /// The internal implementation of `getTraitProperties` below that returns the
1454 /// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
1455 template <typename... Ts>
1456 static AbstractOperation::OperationProperties
getTraitPropertiesImpl(std::tuple<Ts...> *)1457 getTraitPropertiesImpl(std::tuple<Ts...> *) {
1458 AbstractOperation::OperationProperties result = 0;
1459 (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
1460 return result;
1461 }
1462
1463 /// Given a tuple type containing a set of traits that contain a
1464 /// `getTraitProperties` method, return the OR of all of the results of invoking
1465 /// those methods.
1466 template <typename TraitTupleT>
getTraitProperties()1467 static AbstractOperation::OperationProperties getTraitProperties() {
1468 return getTraitPropertiesImpl((TraitTupleT *)nullptr);
1469 }
1470
1471 //===----------------------------------------------------------------------===//
1472 // Trait Verification
1473
/// Trait to check if T provides a `verifyTrait` method.
/// `has_verify_trait<T>` names the type of the expression
/// `T::verifyTrait(op)` for an `Operation *` argument and is ill-formed when
/// that expression is not valid. The `Args` pack is not used by the alias.
/// NOTE(review): `Args` is presumably only present so that the alias has the
/// variadic shape expected by `llvm::is_detected` - confirm.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
/// Detector for `has_verify_trait`, built on `llvm::is_detected`.
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1479
1480 /// The internal implementation of `verifyTraits` below that returns the result
1481 /// of verifying the current operation with all of the provided trait types
1482 /// `Ts`.
1483 template <typename... Ts>
verifyTraitsImpl(Operation * op,std::tuple<Ts...> *)1484 static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1485 LogicalResult result = success();
1486 (void)std::initializer_list<int>{
1487 (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1488 return result;
1489 }
1490
1491 /// Given a tuple type containing a set of traits that contain a
1492 /// `verifyTrait` method, return the result of verifying the given operation.
1493 template <typename TraitTupleT>
verifyTraits(Operation * op)1494 static LogicalResult verifyTraits(Operation *op) {
1495 return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1496 }
1497 } // namespace op_definition_impl
1498
1499 //===----------------------------------------------------------------------===//
1500 // Operation Definition classes
1501 //===----------------------------------------------------------------------===//
1502
1503 /// This provides public APIs that all operations should have. The template
1504 /// argument 'ConcreteType' should be the concrete type by CRTP and the others
1505 /// are base classes by the policy pattern.
1506 template <typename ConcreteType, template <typename T> class... Traits>
1507 class Op : public OpState, public Traits<ConcreteType>... {
1508 public:
1509 /// Inherit getOperation from `OpState`.
1510 using OpState::getOperation;
1511
1512 /// Return if this operation contains the provided trait.
1513 template <template <typename T> class Trait>
hasTrait()1514 static constexpr bool hasTrait() {
1515 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1516 }
1517
1518 /// Create a deep copy of this operation.
clone()1519 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1520
1521 /// Create a partial copy of this operation without traversing into attached
1522 /// regions. The new operation will have the same number of regions as the
1523 /// original one, but they will be left empty.
cloneWithoutRegions()1524 ConcreteType cloneWithoutRegions() {
1525 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1526 }
1527
1528 /// Return true if this "op class" can match against the specified operation.
classof(Operation * op)1529 static bool classof(Operation *op) {
1530 if (auto *abstractOp = op->getAbstractOperation())
1531 return TypeID::get<ConcreteType>() == abstractOp->typeID;
1532 #ifndef NDEBUG
1533 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1534 llvm::report_fatal_error(
1535 "classof on '" + ConcreteType::getOperationName() +
1536 "' failed due to the operation not being registered");
1537 #endif
1538 return false;
1539 }
1540
1541 /// Expose the type we are instantiated on to template machinery that may want
1542 /// to introspect traits on this operation.
1543 using ConcreteOpType = ConcreteType;
1544
1545 /// This is a public constructor. Any op can be initialized to null.
Op()1546 explicit Op() : OpState(nullptr) {}
Op(std::nullptr_t)1547 Op(std::nullptr_t) : OpState(nullptr) {}
1548
1549 /// This is a public constructor to enable access via the llvm::cast family of
1550 /// methods. This should not be used directly.
Op(Operation * state)1551 explicit Op(Operation *state) : OpState(state) {}
1552
1553 /// Methods for supporting PointerLikeTypeTraits.
getAsOpaquePointer()1554 const void *getAsOpaquePointer() const {
1555 return static_cast<const void *>((Operation *)*this);
1556 }
getFromOpaquePointer(const void * pointer)1557 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1558 return ConcreteOpType(
1559 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1560 }
1561
1562 private:
1563 /// Trait to check if T provides a 'fold' method for a single result op.
1564 template <typename T, typename... Args>
1565 using has_single_result_fold =
1566 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1567 template <typename T>
1568 using detect_has_single_result_fold =
1569 llvm::is_detected<has_single_result_fold, T>;
1570 /// Trait to check if T provides a general 'fold' method.
1571 template <typename T, typename... Args>
1572 using has_fold = decltype(
1573 std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
1574 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1575 template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
1576 /// Trait to check if T provides a 'print' method.
1577 template <typename T, typename... Args>
1578 using has_print =
1579 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1580 template <typename T>
1581 using detect_has_print = llvm::is_detected<has_print, T>;
1582 /// A tuple type containing the traits that have a `foldTrait` function.
1583 using FoldableTraitsTupleT = typename detail::FilterTypes<
1584 op_definition_impl::detect_has_any_fold_trait,
1585 Traits<ConcreteType>...>::type;
1586 /// A tuple type containing the traits that have a verify function.
1587 using VerifiableTraitsTupleT =
1588 typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
1589 Traits<ConcreteType>...>::type;
1590
1591 /// Returns the properties of this operation by combining the properties
1592 /// defined by the traits.
getOperationProperties()1593 static AbstractOperation::OperationProperties getOperationProperties() {
1594 return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
1595 op_definition_impl::detect_has_get_trait_properties,
1596 Traits<ConcreteType>...>::type>();
1597 }
1598
1599 /// Returns an interface map containing the interfaces registered to this
1600 /// operation.
getInterfaceMap()1601 static detail::InterfaceMap getInterfaceMap() {
1602 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1603 }
1604
1605 /// Return the internal implementations of each of the AbstractOperation
1606 /// hooks.
1607 /// Implementation of `FoldHookFn` AbstractOperation hook.
getFoldHookFn()1608 static AbstractOperation::FoldHookFn getFoldHookFn() {
1609 return getFoldHookFnImpl<ConcreteType>();
1610 }
1611 /// The internal implementation of `getFoldHookFn` above that is invoked if
1612 /// the operation is single result and defines a `fold` method.
1613 template <typename ConcreteOpT>
1614 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1615 Traits<ConcreteOpT>...>::value &&
1616 detect_has_single_result_fold<ConcreteOpT>::value,
1617 AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1618 getFoldHookFnImpl() {
1619 return &foldSingleResultHook<ConcreteOpT>;
1620 }
1621 /// The internal implementation of `getFoldHookFn` above that is invoked if
1622 /// the operation is not single result and defines a `fold` method.
1623 template <typename ConcreteOpT>
1624 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1625 Traits<ConcreteOpT>...>::value &&
1626 detect_has_fold<ConcreteOpT>::value,
1627 AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1628 getFoldHookFnImpl() {
1629 return &foldHook<ConcreteOpT>;
1630 }
1631 /// The internal implementation of `getFoldHookFn` above that is invoked if
1632 /// the operation does not define a `fold` method.
1633 template <typename ConcreteOpT>
1634 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1635 !detect_has_fold<ConcreteOpT>::value,
1636 AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1637 getFoldHookFnImpl() {
1638 // In this case, we only need to fold the traits of the operation.
1639 return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
1640 }
1641 /// Return the result of folding a single result operation that defines a
1642 /// `fold` method.
1643 template <typename ConcreteOpT>
1644 static LogicalResult
foldSingleResultHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1645 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1646 SmallVectorImpl<OpFoldResult> &results) {
1647 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1648
1649 // If the fold failed or was in-place, try to fold the traits of the
1650 // operation.
1651 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1652 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1653 op, operands, results)))
1654 return success();
1655 return success(static_cast<bool>(result));
1656 }
1657 results.push_back(result);
1658 return success();
1659 }
1660 /// Return the result of folding an operation that defines a `fold` method.
1661 template <typename ConcreteOpT>
foldHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1662 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1663 SmallVectorImpl<OpFoldResult> &results) {
1664 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1665
1666 // If the fold failed or was in-place, try to fold the traits of the
1667 // operation.
1668 if (failed(result) || results.empty()) {
1669 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1670 op, operands, results)))
1671 return success();
1672 }
1673 return result;
1674 }
1675
1676 /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
1677 static AbstractOperation::GetCanonicalizationPatternsFn
getGetCanonicalizationPatternsFn()1678 getGetCanonicalizationPatternsFn() {
1679 return &ConcreteType::getCanonicalizationPatterns;
1680 }
1681 /// Implementation of `GetHasTraitFn`
getHasTraitFn()1682 static AbstractOperation::HasTraitFn getHasTraitFn() {
1683 return &op_definition_impl::hasTrait<Traits...>;
1684 }
1685 /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
getParseAssemblyFn()1686 static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
1687 return &ConcreteType::parse;
1688 }
1689 /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
getPrintAssemblyFn()1690 static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
1691 return getPrintAssemblyFnImpl<ConcreteType>();
1692 }
1693 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1694 /// the concrete operation does not define a `print` method.
1695 template <typename ConcreteOpT>
1696 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1697 AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl()1698 getPrintAssemblyFnImpl() {
1699 return &OpState::print;
1700 }
1701 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1702 /// the concrete operation defines a `print` method.
1703 template <typename ConcreteOpT>
1704 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1705 AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl()1706 getPrintAssemblyFnImpl() {
1707 return &printAssembly;
1708 }
printAssembly(Operation * op,OpAsmPrinter & p)1709 static void printAssembly(Operation *op, OpAsmPrinter &p) {
1710 return cast<ConcreteType>(op).print(p);
1711 }
1712 /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
getVerifyInvariantsFn()1713 static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
1714 return &verifyInvariants;
1715 }
verifyInvariants(Operation * op)1716 static LogicalResult verifyInvariants(Operation *op) {
1717 return failure(
1718 failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
1719 failed(cast<ConcreteType>(op).verify()));
1720 }
1721
1722 /// Allow access to internal implementation methods.
1723 friend AbstractOperation;
1724 };
1725
1726 /// This class represents the base of an operation interface. See the definition
1727 /// of `detail::Interface` for requirements on the `Traits` type.
1728 template <typename ConcreteType, typename Traits>
1729 class OpInterface
1730 : public detail::Interface<ConcreteType, Operation *, Traits,
1731 Op<ConcreteType>, OpTrait::TraitBase> {
1732 public:
1733 using Base = OpInterface<ConcreteType, Traits>;
1734 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1735 Op<ConcreteType>, OpTrait::TraitBase>;
1736
1737 /// Inherit the base class constructor.
1738 using InterfaceBase::InterfaceBase;
1739
1740 protected:
1741 /// Returns the impl interface instance for the given operation.
getInterfaceFor(Operation * op)1742 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1743 // Access the raw interface from the abstract operation.
1744 auto *abstractOp = op->getAbstractOperation();
1745 return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
1746 }
1747
1748 /// Allow access to `getInterfaceFor`.
1749 friend InterfaceBase;
1750 };
1751
1752 //===----------------------------------------------------------------------===//
1753 // Common Operation Folders/Parsers/Printers
1754 //===----------------------------------------------------------------------===//
1755
// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp, which avoids them being template instantiated/duplicated.
// Only the declarations appear here; the definitions live outside this header.
namespace impl {
// Parser hook that records what it parses into `result`.
// NOTE(review): judging by the name, this is for ops with one result and a
// single operand type - confirm against the definition.
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
                                           OperationState &result);

// Populates `result` for a binary op from its two operands `lhs` and `rhs`.
// NOTE(review): how the result type is chosen is not visible here - confirm
// against the definition.
void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
                   Value rhs);
// Parser hook that records what it parses into `result`.
// NOTE(review): judging by the name, this is for ops with one result whose
// operands all share a type - confirm against the definition.
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                            OperationState &result);

// Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same type. Otherwise, prints the generic assembly
// form.
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
1772
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
// Only the declarations appear here; the definitions live outside this header.
namespace impl {
// Populates `result` for a cast of `source` to `destType`.
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                 Type destType);
// Parser hook for cast ops; records what it parses into `result`.
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
// Printer hook for cast ops; prints `op` to `p`.
void printCastOp(Operation *op, OpAsmPrinter &p);
// Fold helper for cast ops.
// NOTE(review): the folding conditions, and whether a null Value signals "no
// fold", are not visible here - confirm against the definition.
Value foldCastOp(Operation *op);
} // namespace impl
1782 } // end namespace mlir
1783
1784 #endif
1785