• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper classes for implementing the "Op" types.  This
10 // includes the Op type, which is the base class for Op class definitions,
11 // as well as number of traits in the OpTrait namespace that provide a
12 // declarative way to specify properties of Ops.
13 //
14 // The purpose of these types are to allow light-weight implementation of
15 // concrete ops (like DimOp) with very little boilerplate.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef MLIR_IR_OPDEFINITION_H
20 #define MLIR_IR_OPDEFINITION_H
21 
22 #include "mlir/IR/Operation.h"
23 #include "llvm/Support/PointerLikeTypeTraits.h"
24 
25 #include <type_traits>
26 
27 namespace mlir {
28 class Builder;
29 class OpBuilder;
30 
31 namespace OpTrait {
32 template <typename ConcreteType> class OneResult;
33 }
34 
35 /// This class represents success/failure for operation parsing. It is
36 /// essentially a simple wrapper class around LogicalResult that allows for
37 /// explicit conversion to bool. This allows for the parser to chain together
38 /// parse rules without the clutter of "failed/succeeded".
39 class ParseResult : public LogicalResult {
40 public:
LogicalResult(result)41   ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
42 
43   // Allow diagnostics emitted during parsing to be converted to failure.
ParseResult(const InFlightDiagnostic &)44   ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
ParseResult(const Diagnostic &)45   ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
46 
47   /// Failure is true in a boolean context.
48   explicit operator bool() const { return failed(*this); }
49 };
50 /// This class implements `Optional` functionality for ParseResult. We don't
51 /// directly use Optional here, because it provides an implicit conversion
52 /// to 'bool' which we want to avoid. This class is used to implement tri-state
53 /// 'parseOptional' functions that may have a failure mode when parsing that
54 /// shouldn't be attributed to "not present".
55 class OptionalParseResult {
56 public:
57   OptionalParseResult() = default;
OptionalParseResult(LogicalResult result)58   OptionalParseResult(LogicalResult result) : impl(result) {}
OptionalParseResult(ParseResult result)59   OptionalParseResult(ParseResult result) : impl(result) {}
OptionalParseResult(const InFlightDiagnostic &)60   OptionalParseResult(const InFlightDiagnostic &)
61       : OptionalParseResult(failure()) {}
OptionalParseResult(llvm::NoneType)62   OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
63 
64   /// Returns true if we contain a valid ParseResult value.
hasValue()65   bool hasValue() const { return impl.hasValue(); }
66 
67   /// Access the internal ParseResult value.
getValue()68   ParseResult getValue() const { return impl.getValue(); }
69   ParseResult operator*() const { return getValue(); }
70 
71 private:
72   Optional<ParseResult> impl;
73 };
74 
75 // These functions are out-of-line utilities, which avoids them being template
76 // instantiated/duplicated.
77 namespace impl {
78 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
79 /// region's only block if it does not have a terminator already. If the region
80 /// is empty, insert a new block first. `buildTerminatorOp` should return the
81 /// terminator operation to insert.
82 void ensureRegionTerminator(
83     Region &region, OpBuilder &builder, Location loc,
84     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
85 void ensureRegionTerminator(
86     Region &region, Builder &builder, Location loc,
87     function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
88 
89 } // namespace impl
90 
91 /// This is the concrete base class that holds the operation pointer and has
92 /// non-generic methods that only depend on State (to avoid having them
93 /// instantiated on template types that don't affect them.
94 ///
95 /// This also has the fallback implementations of customization hooks for when
96 /// they aren't customized.
97 class OpState {
98 public:
99   /// Ops are pointer-like, so we allow implicit conversion to bool.
100   operator bool() { return getOperation() != nullptr; }
101 
102   /// This implicitly converts to Operation*.
103   operator Operation *() const { return state; }
104 
105   /// Shortcut of `->` to access a member of Operation.
106   Operation *operator->() const { return state; }
107 
108   /// Return the operation that this refers to.
getOperation()109   Operation *getOperation() { return state; }
110 
111   /// Return the dialect that this refers to.
getDialect()112   Dialect *getDialect() { return getOperation()->getDialect(); }
113 
114   /// Return the parent Region of this operation.
getParentRegion()115   Region *getParentRegion() { return getOperation()->getParentRegion(); }
116 
117   /// Returns the closest surrounding operation that contains this operation
118   /// or nullptr if this is a top-level operation.
getParentOp()119   Operation *getParentOp() { return getOperation()->getParentOp(); }
120 
121   /// Return the closest surrounding parent operation that is of type 'OpTy'.
getParentOfType()122   template <typename OpTy> OpTy getParentOfType() {
123     return getOperation()->getParentOfType<OpTy>();
124   }
125 
126   /// Returns the closest surrounding parent operation with trait `Trait`.
127   template <template <typename T> class Trait>
getParentWithTrait()128   Operation *getParentWithTrait() {
129     return getOperation()->getParentWithTrait<Trait>();
130   }
131 
132   /// Return the context this operation belongs to.
getContext()133   MLIRContext *getContext() { return getOperation()->getContext(); }
134 
135   /// Print the operation to the given stream.
136   void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
137     state->print(os, flags);
138   }
139   void print(raw_ostream &os, AsmState &asmState,
140              OpPrintingFlags flags = llvm::None) {
141     state->print(os, asmState, flags);
142   }
143 
144   /// Dump this operation.
dump()145   void dump() { state->dump(); }
146 
147   /// The source location the operation was defined or derived from.
getLoc()148   Location getLoc() { return state->getLoc(); }
setLoc(Location loc)149   void setLoc(Location loc) { state->setLoc(loc); }
150 
151   /// Return all of the attributes on this operation.
getAttrs()152   ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
153 
154   /// A utility iterator that filters out non-dialect attributes.
155   using dialect_attr_iterator = Operation::dialect_attr_iterator;
156   using dialect_attr_range = Operation::dialect_attr_range;
157 
158   /// Return a range corresponding to the dialect attributes for this operation.
getDialectAttrs()159   dialect_attr_range getDialectAttrs() { return state->getDialectAttrs(); }
dialect_attr_begin()160   dialect_attr_iterator dialect_attr_begin() {
161     return state->dialect_attr_begin();
162   }
dialect_attr_end()163   dialect_attr_iterator dialect_attr_end() { return state->dialect_attr_end(); }
164 
165   /// Return an attribute with the specified name.
getAttr(StringRef name)166   Attribute getAttr(StringRef name) { return state->getAttr(name); }
167 
168   /// If the operation has an attribute of the specified type, return it.
getAttrOfType(StringRef name)169   template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
170     return getAttr(name).dyn_cast_or_null<AttrClass>();
171   }
172 
173   /// If the an attribute exists with the specified name, change it to the new
174   /// value.  Otherwise, add a new attribute with the specified name/value.
setAttr(Identifier name,Attribute value)175   void setAttr(Identifier name, Attribute value) {
176     state->setAttr(name, value);
177   }
setAttr(StringRef name,Attribute value)178   void setAttr(StringRef name, Attribute value) {
179     setAttr(Identifier::get(name, getContext()), value);
180   }
181 
182   /// Set the attributes held by this operation.
setAttrs(ArrayRef<NamedAttribute> attributes)183   void setAttrs(ArrayRef<NamedAttribute> attributes) {
184     state->setAttrs(attributes);
185   }
setAttrs(MutableDictionaryAttr newAttrs)186   void setAttrs(MutableDictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
187 
188   /// Set the dialect attributes for this operation, and preserve all dependent.
setDialectAttrs(DialectAttrs && attrs)189   template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
190     state->setDialectAttrs(std::move(attrs));
191   }
192 
193   /// Remove the attribute with the specified name if it exists.  The return
194   /// value indicates whether the attribute was present or not.
removeAttr(Identifier name)195   MutableDictionaryAttr::RemoveResult removeAttr(Identifier name) {
196     return state->removeAttr(name);
197   }
removeAttr(StringRef name)198   MutableDictionaryAttr::RemoveResult removeAttr(StringRef name) {
199     return state->removeAttr(Identifier::get(name, getContext()));
200   }
201 
202   /// Return true if there are no users of any results of this operation.
use_empty()203   bool use_empty() { return state->use_empty(); }
204 
205   /// Remove this operation from its parent block and delete it.
erase()206   void erase() { state->erase(); }
207 
208   /// Emit an error with the op name prefixed, like "'dim' op " which is
209   /// convenient for verifiers.
210   InFlightDiagnostic emitOpError(const Twine &message = {});
211 
212   /// Emit an error about fatal conditions with this operation, reporting up to
213   /// any diagnostic handlers that may be listening.
214   InFlightDiagnostic emitError(const Twine &message = {});
215 
216   /// Emit a warning about this operation, reporting up to any diagnostic
217   /// handlers that may be listening.
218   InFlightDiagnostic emitWarning(const Twine &message = {});
219 
220   /// Emit a remark about this operation, reporting up to any diagnostic
221   /// handlers that may be listening.
222   InFlightDiagnostic emitRemark(const Twine &message = {});
223 
224   /// Walk the operation in postorder, calling the callback for each nested
225   /// operation(including this one).
226   /// See Operation::walk for more details.
227   template <typename FnT, typename RetT = detail::walkResultType<FnT>>
walk(FnT && callback)228   RetT walk(FnT &&callback) {
229     return state->walk(std::forward<FnT>(callback));
230   }
231 
232   // These are default implementations of customization hooks.
233 public:
234   /// This hook returns any canonicalization pattern rewrites that the operation
235   /// supports, for use by the canonicalization pass.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)236   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
237                                           MLIRContext *context) {}
238 
239 protected:
240   /// If the concrete type didn't implement a custom verifier hook, just fall
241   /// back to this one which accepts everything.
verify()242   LogicalResult verify() { return success(); }
243 
244   /// Unless overridden, the custom assembly form of an op is always rejected.
245   /// Op implementations should implement this to return failure.
246   /// On success, they should fill in result with the fields to use.
247   static ParseResult parse(OpAsmParser &parser, OperationState &result);
248 
249   // The fallback for the printer is to print it the generic assembly form.
250   static void print(Operation *op, OpAsmPrinter &p);
251 
252   /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
253   /// so we can cast it away here.
OpState(Operation * state)254   explicit OpState(Operation *state) : state(state) {}
255 
256 private:
257   Operation *state;
258 
259   /// Allow access to internal hook implementation methods.
260   friend AbstractOperation;
261 };
262 
263 // Allow comparing operators.
264 inline bool operator==(OpState lhs, OpState rhs) {
265   return lhs.getOperation() == rhs.getOperation();
266 }
267 inline bool operator!=(OpState lhs, OpState rhs) {
268   return lhs.getOperation() != rhs.getOperation();
269 }
270 
271 /// This class represents a single result from folding an operation.
272 class OpFoldResult : public PointerUnion<Attribute, Value> {
273   using PointerUnion<Attribute, Value>::PointerUnion;
274 };
275 
276 /// Allow printing to a stream.
277 inline raw_ostream &operator<<(raw_ostream &os, OpState &op) {
278   op.print(os, OpPrintingFlags().useLocalScope());
279   return os;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Operation Trait Types
284 //===----------------------------------------------------------------------===//
285 
286 namespace OpTrait {
287 
288 // These functions are out-of-line implementations of the methods in the
289 // corresponding trait classes.  This avoids them being template
290 // instantiated/duplicated.
291 namespace impl {
292 OpFoldResult foldIdempotent(Operation *op);
293 OpFoldResult foldInvolution(Operation *op);
294 LogicalResult verifyZeroOperands(Operation *op);
295 LogicalResult verifyOneOperand(Operation *op);
296 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
297 LogicalResult verifyIsIdempotent(Operation *op);
298 LogicalResult verifyIsInvolution(Operation *op);
299 LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
300 LogicalResult verifyOperandsAreFloatLike(Operation *op);
301 LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
302 LogicalResult verifySameTypeOperands(Operation *op);
303 LogicalResult verifyZeroRegion(Operation *op);
304 LogicalResult verifyOneRegion(Operation *op);
305 LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
306 LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
307 LogicalResult verifyZeroResult(Operation *op);
308 LogicalResult verifyOneResult(Operation *op);
309 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
310 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
311 LogicalResult verifySameOperandsShape(Operation *op);
312 LogicalResult verifySameOperandsAndResultShape(Operation *op);
313 LogicalResult verifySameOperandsElementType(Operation *op);
314 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
315 LogicalResult verifySameOperandsAndResultType(Operation *op);
316 LogicalResult verifyResultsAreBoolLike(Operation *op);
317 LogicalResult verifyResultsAreFloatLike(Operation *op);
318 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
319 LogicalResult verifyIsTerminator(Operation *op);
320 LogicalResult verifyZeroSuccessor(Operation *op);
321 LogicalResult verifyOneSuccessor(Operation *op);
322 LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
323 LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
324 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
325 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
326 LogicalResult verifyNoRegionArguments(Operation *op);
327 LogicalResult verifyElementwiseMappable(Operation *op);
328 } // namespace impl
329 
330 /// Helper class for implementing traits.  Clients are not expected to interact
331 /// with this directly, so its members are all protected.
332 template <typename ConcreteType, template <typename> class TraitType>
333 class TraitBase {
334 protected:
335   /// Return the ultimate Operation being worked on.
getOperation()336   Operation *getOperation() {
337     // We have to cast up to the trait type, then to the concrete type, then to
338     // the BaseState class in explicit hops because the concrete type will
339     // multiply derive from the (content free) TraitBase class, and we need to
340     // be able to disambiguate the path for the C++ compiler.
341     auto *trait = static_cast<TraitType<ConcreteType> *>(this);
342     auto *concrete = static_cast<ConcreteType *>(trait);
343     auto *base = static_cast<OpState *>(concrete);
344     return base->getOperation();
345   }
346 };
347 
348 //===----------------------------------------------------------------------===//
349 // Operand Traits
350 
351 namespace detail {
352 /// Utility trait base that provides accessors for derived traits that have
353 /// multiple operands.
354 template <typename ConcreteType, template <typename> class TraitType>
355 struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
356   using operand_iterator = Operation::operand_iterator;
357   using operand_range = Operation::operand_range;
358   using operand_type_iterator = Operation::operand_type_iterator;
359   using operand_type_range = Operation::operand_type_range;
360 
361   /// Return the number of operands.
getNumOperandsMultiOperandTraitBase362   unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
363 
364   /// Return the operand at index 'i'.
getOperandMultiOperandTraitBase365   Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
366 
367   /// Set the operand at index 'i' to 'value'.
setOperandMultiOperandTraitBase368   void setOperand(unsigned i, Value value) {
369     this->getOperation()->setOperand(i, value);
370   }
371 
372   /// Operand iterator access.
operand_beginMultiOperandTraitBase373   operand_iterator operand_begin() {
374     return this->getOperation()->operand_begin();
375   }
operand_endMultiOperandTraitBase376   operand_iterator operand_end() { return this->getOperation()->operand_end(); }
getOperandsMultiOperandTraitBase377   operand_range getOperands() { return this->getOperation()->getOperands(); }
378 
379   /// Operand type access.
operand_type_beginMultiOperandTraitBase380   operand_type_iterator operand_type_begin() {
381     return this->getOperation()->operand_type_begin();
382   }
operand_type_endMultiOperandTraitBase383   operand_type_iterator operand_type_end() {
384     return this->getOperation()->operand_type_end();
385   }
getOperandTypesMultiOperandTraitBase386   operand_type_range getOperandTypes() {
387     return this->getOperation()->getOperandTypes();
388   }
389 };
390 } // end namespace detail
391 
392 /// This class provides the API for ops that are known to have no
393 /// SSA operand.
394 template <typename ConcreteType>
395 class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
396 public:
verifyTrait(Operation * op)397   static LogicalResult verifyTrait(Operation *op) {
398     return impl::verifyZeroOperands(op);
399   }
400 
401 private:
402   // Disable these.
getOperand()403   void getOperand() {}
setOperand()404   void setOperand() {}
405 };
406 
407 /// This class provides the API for ops that are known to have exactly one
408 /// SSA operand.
409 template <typename ConcreteType>
410 class OneOperand : public TraitBase<ConcreteType, OneOperand> {
411 public:
getOperand()412   Value getOperand() { return this->getOperation()->getOperand(0); }
413 
setOperand(Value value)414   void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
415 
verifyTrait(Operation * op)416   static LogicalResult verifyTrait(Operation *op) {
417     return impl::verifyOneOperand(op);
418   }
419 };
420 
421 /// This class provides the API for ops that are known to have a specified
422 /// number of operands.  This is used as a trait like this:
423 ///
424 ///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
425 ///
426 template <unsigned N> class NOperands {
427 public:
428   static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
429 
430   template <typename ConcreteType>
431   class Impl
432       : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
433   public:
verifyTrait(Operation * op)434     static LogicalResult verifyTrait(Operation *op) {
435       return impl::verifyNOperands(op, N);
436     }
437   };
438 };
439 
440 /// This class provides the API for ops that are known to have a at least a
441 /// specified number of operands.  This is used as a trait like this:
442 ///
443 ///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
444 ///
445 template <unsigned N> class AtLeastNOperands {
446 public:
447   template <typename ConcreteType>
448   class Impl : public detail::MultiOperandTraitBase<ConcreteType,
449                                                     AtLeastNOperands<N>::Impl> {
450   public:
verifyTrait(Operation * op)451     static LogicalResult verifyTrait(Operation *op) {
452       return impl::verifyAtLeastNOperands(op, N);
453     }
454   };
455 };
456 
457 /// This class provides the API for ops which have an unknown number of
458 /// SSA operands.
459 template <typename ConcreteType>
460 class VariadicOperands
461     : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
462 
463 //===----------------------------------------------------------------------===//
464 // Region Traits
465 
466 /// This class provides verification for ops that are known to have zero
467 /// regions.
468 template <typename ConcreteType>
469 class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
470 public:
verifyTrait(Operation * op)471   static LogicalResult verifyTrait(Operation *op) {
472     return impl::verifyZeroRegion(op);
473   }
474 };
475 
476 namespace detail {
477 /// Utility trait base that provides accessors for derived traits that have
478 /// multiple regions.
479 template <typename ConcreteType, template <typename> class TraitType>
480 struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
481   using region_iterator = MutableArrayRef<Region>;
482   using region_range = RegionRange;
483 
484   /// Return the number of regions.
getNumRegionsMultiRegionTraitBase485   unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
486 
487   /// Return the region at `index`.
getRegionMultiRegionTraitBase488   Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
489 
490   /// Region iterator access.
region_beginMultiRegionTraitBase491   region_iterator region_begin() {
492     return this->getOperation()->region_begin();
493   }
region_endMultiRegionTraitBase494   region_iterator region_end() { return this->getOperation()->region_end(); }
getRegionsMultiRegionTraitBase495   region_range getRegions() { return this->getOperation()->getRegions(); }
496 };
497 } // end namespace detail
498 
499 /// This class provides APIs for ops that are known to have a single region.
500 template <typename ConcreteType>
501 class OneRegion : public TraitBase<ConcreteType, OneRegion> {
502 public:
getRegion()503   Region &getRegion() { return this->getOperation()->getRegion(0); }
504 
505   /// Returns a range of operations within the region of this operation.
getOps()506   auto getOps() { return getRegion().getOps(); }
507   template <typename OpT>
getOps()508   auto getOps() {
509     return getRegion().template getOps<OpT>();
510   }
511 
verifyTrait(Operation * op)512   static LogicalResult verifyTrait(Operation *op) {
513     return impl::verifyOneRegion(op);
514   }
515 };
516 
517 /// This class provides the API for ops that are known to have a specified
518 /// number of regions.
519 template <unsigned N> class NRegions {
520 public:
521   static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
522 
523   template <typename ConcreteType>
524   class Impl
525       : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
526   public:
verifyTrait(Operation * op)527     static LogicalResult verifyTrait(Operation *op) {
528       return impl::verifyNRegions(op, N);
529     }
530   };
531 };
532 
533 /// This class provides APIs for ops that are known to have at least a specified
534 /// number of regions.
535 template <unsigned N> class AtLeastNRegions {
536 public:
537   template <typename ConcreteType>
538   class Impl : public detail::MultiRegionTraitBase<ConcreteType,
539                                                    AtLeastNRegions<N>::Impl> {
540   public:
verifyTrait(Operation * op)541     static LogicalResult verifyTrait(Operation *op) {
542       return impl::verifyAtLeastNRegions(op, N);
543     }
544   };
545 };
546 
547 /// This class provides the API for ops which have an unknown number of
548 /// regions.
549 template <typename ConcreteType>
550 class VariadicRegions
551     : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
552 
553 //===----------------------------------------------------------------------===//
554 // Result Traits
555 
556 /// This class provides return value APIs for ops that are known to have
557 /// zero results.
558 template <typename ConcreteType>
559 class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
560 public:
verifyTrait(Operation * op)561   static LogicalResult verifyTrait(Operation *op) {
562     return impl::verifyZeroResult(op);
563   }
564 };
565 
566 namespace detail {
567 /// Utility trait base that provides accessors for derived traits that have
568 /// multiple results.
569 template <typename ConcreteType, template <typename> class TraitType>
570 struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
571   using result_iterator = Operation::result_iterator;
572   using result_range = Operation::result_range;
573   using result_type_iterator = Operation::result_type_iterator;
574   using result_type_range = Operation::result_type_range;
575 
576   /// Return the number of results.
getNumResultsMultiResultTraitBase577   unsigned getNumResults() { return this->getOperation()->getNumResults(); }
578 
579   /// Return the result at index 'i'.
getResultMultiResultTraitBase580   Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
581 
582   /// Replace all uses of results of this operation with the provided 'values'.
583   /// 'values' may correspond to an existing operation, or a range of 'Value'.
replaceAllUsesWithMultiResultTraitBase584   template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
585     this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
586   }
587 
588   /// Return the type of the `i`-th result.
getTypeMultiResultTraitBase589   Type getType(unsigned i) { return getResult(i).getType(); }
590 
591   /// Result iterator access.
result_beginMultiResultTraitBase592   result_iterator result_begin() {
593     return this->getOperation()->result_begin();
594   }
result_endMultiResultTraitBase595   result_iterator result_end() { return this->getOperation()->result_end(); }
getResultsMultiResultTraitBase596   result_range getResults() { return this->getOperation()->getResults(); }
597 
598   /// Result type access.
result_type_beginMultiResultTraitBase599   result_type_iterator result_type_begin() {
600     return this->getOperation()->result_type_begin();
601   }
result_type_endMultiResultTraitBase602   result_type_iterator result_type_end() {
603     return this->getOperation()->result_type_end();
604   }
getResultTypesMultiResultTraitBase605   result_type_range getResultTypes() {
606     return this->getOperation()->getResultTypes();
607   }
608 };
609 } // end namespace detail
610 
611 /// This class provides return value APIs for ops that are known to have a
612 /// single result.
613 template <typename ConcreteType>
614 class OneResult : public TraitBase<ConcreteType, OneResult> {
615 public:
getResult()616   Value getResult() { return this->getOperation()->getResult(0); }
getType()617   Type getType() { return getResult().getType(); }
618 
619   /// If the operation returns a single value, then the Op can be implicitly
620   /// converted to an Value. This yields the value of the only result.
Value()621   operator Value() { return getResult(); }
622 
623   /// Replace all uses of 'this' value with the new value, updating anything in
624   /// the IR that uses 'this' to use the other value instead.  When this returns
625   /// there are zero uses of 'this'.
replaceAllUsesWith(Value newValue)626   void replaceAllUsesWith(Value newValue) {
627     getResult().replaceAllUsesWith(newValue);
628   }
629 
630   /// Replace all uses of 'this' value with the result of 'op'.
replaceAllUsesWith(Operation * op)631   void replaceAllUsesWith(Operation *op) {
632     this->getOperation()->replaceAllUsesWith(op);
633   }
634 
verifyTrait(Operation * op)635   static LogicalResult verifyTrait(Operation *op) {
636     return impl::verifyOneResult(op);
637   }
638 };
639 
640 /// This class provides the API for ops that are known to have a specified
641 /// number of results.  This is used as a trait like this:
642 ///
643 ///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
644 ///
645 template <unsigned N> class NResults {
646 public:
647   static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
648 
649   template <typename ConcreteType>
650   class Impl
651       : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
652   public:
verifyTrait(Operation * op)653     static LogicalResult verifyTrait(Operation *op) {
654       return impl::verifyNResults(op, N);
655     }
656   };
657 };
658 
659 /// This class provides the API for ops that are known to have at least a
660 /// specified number of results.  This is used as a trait like this:
661 ///
662 ///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
663 ///
664 template <unsigned N> class AtLeastNResults {
665 public:
666   template <typename ConcreteType>
667   class Impl : public detail::MultiResultTraitBase<ConcreteType,
668                                                    AtLeastNResults<N>::Impl> {
669   public:
verifyTrait(Operation * op)670     static LogicalResult verifyTrait(Operation *op) {
671       return impl::verifyAtLeastNResults(op, N);
672     }
673   };
674 };
675 
676 /// This class provides the API for ops which have an unknown number of
677 /// results.
678 template <typename ConcreteType>
679 class VariadicResults
680     : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
681 
682 //===----------------------------------------------------------------------===//
683 // Terminator Traits
684 
685 /// This class provides the API for ops that are known to be terminators.
686 template <typename ConcreteType>
687 class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
688 public:
getTraitProperties()689   static AbstractOperation::OperationProperties getTraitProperties() {
690     return static_cast<AbstractOperation::OperationProperties>(
691         OperationProperty::Terminator);
692   }
verifyTrait(Operation * op)693   static LogicalResult verifyTrait(Operation *op) {
694     return impl::verifyIsTerminator(op);
695   }
696 };
697 
698 /// This class provides verification for ops that are known to have zero
699 /// successors.
700 template <typename ConcreteType>
701 class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
702 public:
verifyTrait(Operation * op)703   static LogicalResult verifyTrait(Operation *op) {
704     return impl::verifyZeroSuccessor(op);
705   }
706 };
707 
708 namespace detail {
709 /// Utility trait base that provides accessors for derived traits that have
710 /// multiple successors.
711 template <typename ConcreteType, template <typename> class TraitType>
712 struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
713   using succ_iterator = Operation::succ_iterator;
714   using succ_range = SuccessorRange;
715 
716   /// Return the number of successors.
getNumSuccessorsMultiSuccessorTraitBase717   unsigned getNumSuccessors() {
718     return this->getOperation()->getNumSuccessors();
719   }
720 
721   /// Return the successor at `index`.
getSuccessorMultiSuccessorTraitBase722   Block *getSuccessor(unsigned i) {
723     return this->getOperation()->getSuccessor(i);
724   }
725 
726   /// Set the successor at `index`.
setSuccessorMultiSuccessorTraitBase727   void setSuccessor(Block *block, unsigned i) {
728     return this->getOperation()->setSuccessor(block, i);
729   }
730 
731   /// Successor iterator access.
succ_beginMultiSuccessorTraitBase732   succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
succ_endMultiSuccessorTraitBase733   succ_iterator succ_end() { return this->getOperation()->succ_end(); }
getSuccessorsMultiSuccessorTraitBase734   succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
735 };
736 } // end namespace detail
737 
738 /// This class provides APIs for ops that are known to have a single successor.
739 template <typename ConcreteType>
740 class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
741 public:
getSuccessor()742   Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
setSuccessor(Block * succ)743   void setSuccessor(Block *succ) {
744     this->getOperation()->setSuccessor(succ, 0);
745   }
746 
verifyTrait(Operation * op)747   static LogicalResult verifyTrait(Operation *op) {
748     return impl::verifyOneSuccessor(op);
749   }
750 };
751 
752 /// This class provides the API for ops that are known to have a specified
753 /// number of successors.
754 template <unsigned N>
755 class NSuccessors {
756 public:
757   static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
758 
759   template <typename ConcreteType>
760   class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
761                                                       NSuccessors<N>::Impl> {
762   public:
verifyTrait(Operation * op)763     static LogicalResult verifyTrait(Operation *op) {
764       return impl::verifyNSuccessors(op, N);
765     }
766   };
767 };
768 
769 /// This class provides APIs for ops that are known to have at least a specified
770 /// number of successors.
771 template <unsigned N>
772 class AtLeastNSuccessors {
773 public:
774   template <typename ConcreteType>
775   class Impl
776       : public detail::MultiSuccessorTraitBase<ConcreteType,
777                                                AtLeastNSuccessors<N>::Impl> {
778   public:
verifyTrait(Operation * op)779     static LogicalResult verifyTrait(Operation *op) {
780       return impl::verifyAtLeastNSuccessors(op, N);
781     }
782   };
783 };
784 
785 /// This class provides the API for ops which have an unknown number of
786 /// successors.
787 template <typename ConcreteType>
788 class VariadicSuccessors
789     : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
790 };
791 
792 //===----------------------------------------------------------------------===//
793 // SingleBlockImplicitTerminator
794 
795 /// This class provides APIs and verifiers for ops with regions having a single
796 /// block that must terminate with `TerminatorOpType`.
797 template <typename TerminatorOpType>
798 struct SingleBlockImplicitTerminator {
799   template <typename ConcreteType>
800   class Impl : public TraitBase<ConcreteType, Impl> {
801   private:
802     /// Builds a terminator operation without relying on OpBuilder APIs to avoid
803     /// cyclic header inclusion.
buildTerminatorSingleBlockImplicitTerminator804     static Operation *buildTerminator(OpBuilder &builder, Location loc) {
805       OperationState state(loc, TerminatorOpType::getOperationName());
806       TerminatorOpType::build(builder, state);
807       return Operation::create(state);
808     }
809 
810   public:
811     /// The type of the operation used as the implicit terminator type.
812     using ImplicitTerminatorOpT = TerminatorOpType;
813 
verifyTraitSingleBlockImplicitTerminator814     static LogicalResult verifyTrait(Operation *op) {
815       for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
816         Region &region = op->getRegion(i);
817 
818         // Empty regions are fine.
819         if (region.empty())
820           continue;
821 
822         // Non-empty regions must contain a single basic block.
823         if (std::next(region.begin()) != region.end())
824           return op->emitOpError("expects region #")
825                  << i << " to have 0 or 1 blocks";
826 
827         Block &block = region.front();
828         if (block.empty())
829           return op->emitOpError() << "expects a non-empty block";
830         Operation &terminator = block.back();
831         if (isa<TerminatorOpType>(terminator))
832           continue;
833 
834         return op->emitOpError("expects regions to end with '" +
835                                TerminatorOpType::getOperationName() +
836                                "', found '" +
837                                terminator.getName().getStringRef() + "'")
838                    .attachNote()
839                << "in custom textual format, the absence of terminator implies "
840                   "'"
841                << TerminatorOpType::getOperationName() << '\'';
842       }
843 
844       return success();
845     }
846 
847     /// Ensure that the given region has the terminator required by this trait.
848     /// If OpBuilder is provided, use it to build the terminator and notify the
849     /// OpBuilder litsteners accordingly. If only a Builder is provided, locally
850     /// construct an OpBuilder with no listeners; this should only be used if no
851     /// OpBuilder is available at the call site, e.g., in the parser.
ensureTerminatorSingleBlockImplicitTerminator852     static void ensureTerminator(Region &region, Builder &builder,
853                                  Location loc) {
854       ::mlir::impl::ensureRegionTerminator(region, builder, loc,
855                                            buildTerminator);
856     }
ensureTerminatorSingleBlockImplicitTerminator857     static void ensureTerminator(Region &region, OpBuilder &builder,
858                                  Location loc) {
859       ::mlir::impl::ensureRegionTerminator(region, builder, loc,
860                                            buildTerminator);
861     }
862 
863     Block *getBody(unsigned idx = 0) {
864       Region &region = this->getOperation()->getRegion(idx);
865       assert(!region.empty() && "unexpected empty region");
866       return &region.front();
867     }
868     Region &getBodyRegion(unsigned idx = 0) {
869       return this->getOperation()->getRegion(idx);
870     }
871 
872     //===------------------------------------------------------------------===//
873     // Single Region Utilities
874     //===------------------------------------------------------------------===//
875 
876     /// The following are a set of methods only enabled when the parent
877     /// operation has a single region. Each of these methods take an additional
878     /// template parameter that represents the concrete operation so that we
879     /// can use SFINAE to disable the methods for non-single region operations.
880     template <typename OpT, typename T = void>
881     using enable_if_single_region =
882         typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
883 
884     template <typename OpT = ConcreteType>
beginSingleBlockImplicitTerminator885     enable_if_single_region<OpT, Block::iterator> begin() {
886       return getBody()->begin();
887     }
888     template <typename OpT = ConcreteType>
endSingleBlockImplicitTerminator889     enable_if_single_region<OpT, Block::iterator> end() {
890       return getBody()->end();
891     }
892     template <typename OpT = ConcreteType>
frontSingleBlockImplicitTerminator893     enable_if_single_region<OpT, Operation &> front() {
894       return *begin();
895     }
896 
897     /// Insert the operation into the back of the body, before the terminator.
898     template <typename OpT = ConcreteType>
push_backSingleBlockImplicitTerminator899     enable_if_single_region<OpT> push_back(Operation *op) {
900       insert(Block::iterator(getBody()->getTerminator()), op);
901     }
902 
903     /// Insert the operation at the given insertion point. Note: The operation
904     /// is never inserted after the terminator, even if the insertion point is
905     /// end().
906     template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator907     enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
908       insert(Block::iterator(insertPt), op);
909     }
910     template <typename OpT = ConcreteType>
insertSingleBlockImplicitTerminator911     enable_if_single_region<OpT> insert(Block::iterator insertPt,
912                                         Operation *op) {
913       auto *body = getBody();
914       if (insertPt == body->end())
915         insertPt = Block::iterator(body->getTerminator());
916       body->getOperations().insert(insertPt, op);
917     }
918   };
919 };
920 
921 //===----------------------------------------------------------------------===//
922 // Misc Traits
923 
924 /// This class provides verification for ops that are known to have the same
925 /// operand shape: all operands are scalars, vectors/tensors of the same
926 /// shape.
927 template <typename ConcreteType>
928 class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
929 public:
verifyTrait(Operation * op)930   static LogicalResult verifyTrait(Operation *op) {
931     return impl::verifySameOperandsShape(op);
932   }
933 };
934 
935 /// This class provides verification for ops that are known to have the same
936 /// operand and result shape: both are scalars, vectors/tensors of the same
937 /// shape.
938 template <typename ConcreteType>
939 class SameOperandsAndResultShape
940     : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
941 public:
verifyTrait(Operation * op)942   static LogicalResult verifyTrait(Operation *op) {
943     return impl::verifySameOperandsAndResultShape(op);
944   }
945 };
946 
947 /// This class provides verification for ops that are known to have the same
948 /// operand element type (or the type itself if it is scalar).
949 ///
950 template <typename ConcreteType>
951 class SameOperandsElementType
952     : public TraitBase<ConcreteType, SameOperandsElementType> {
953 public:
verifyTrait(Operation * op)954   static LogicalResult verifyTrait(Operation *op) {
955     return impl::verifySameOperandsElementType(op);
956   }
957 };
958 
959 /// This class provides verification for ops that are known to have the same
960 /// operand and result element type (or the type itself if it is scalar).
961 ///
962 template <typename ConcreteType>
963 class SameOperandsAndResultElementType
964     : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
965 public:
verifyTrait(Operation * op)966   static LogicalResult verifyTrait(Operation *op) {
967     return impl::verifySameOperandsAndResultElementType(op);
968   }
969 };
970 
971 /// This class provides verification for ops that are known to have the same
972 /// operand and result type.
973 ///
974 /// Note: this trait subsumes the SameOperandsAndResultShape and
975 /// SameOperandsAndResultElementType traits.
976 template <typename ConcreteType>
977 class SameOperandsAndResultType
978     : public TraitBase<ConcreteType, SameOperandsAndResultType> {
979 public:
verifyTrait(Operation * op)980   static LogicalResult verifyTrait(Operation *op) {
981     return impl::verifySameOperandsAndResultType(op);
982   }
983 };
984 
985 /// This class verifies that any results of the specified op have a boolean
986 /// type, a vector thereof, or a tensor thereof.
987 template <typename ConcreteType>
988 class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
989 public:
verifyTrait(Operation * op)990   static LogicalResult verifyTrait(Operation *op) {
991     return impl::verifyResultsAreBoolLike(op);
992   }
993 };
994 
995 /// This class verifies that any results of the specified op have a floating
996 /// point type, a vector thereof, or a tensor thereof.
997 template <typename ConcreteType>
998 class ResultsAreFloatLike
999     : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1000 public:
verifyTrait(Operation * op)1001   static LogicalResult verifyTrait(Operation *op) {
1002     return impl::verifyResultsAreFloatLike(op);
1003   }
1004 };
1005 
1006 /// This class verifies that any results of the specified op have a signless
1007 /// integer or index type, a vector thereof, or a tensor thereof.
1008 template <typename ConcreteType>
1009 class ResultsAreSignlessIntegerLike
1010     : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1011 public:
verifyTrait(Operation * op)1012   static LogicalResult verifyTrait(Operation *op) {
1013     return impl::verifyResultsAreSignlessIntegerLike(op);
1014   }
1015 };
1016 
1017 /// This class adds property that the operation is commutative.
1018 template <typename ConcreteType>
1019 class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
1020 public:
getTraitProperties()1021   static AbstractOperation::OperationProperties getTraitProperties() {
1022     return static_cast<AbstractOperation::OperationProperties>(
1023         OperationProperty::Commutative);
1024   }
1025 };
1026 
1027 /// This class adds property that the operation is an involution.
1028 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1029 template <typename ConcreteType>
1030 class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1031 public:
verifyTrait(Operation * op)1032   static LogicalResult verifyTrait(Operation *op) {
1033     static_assert(ConcreteType::template hasTrait<OneResult>(),
1034                   "expected operation to produce one result");
1035     static_assert(ConcreteType::template hasTrait<OneOperand>(),
1036                   "expected operation to take one operand");
1037     static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1038                   "expected operation to preserve type");
1039     // Involution requires the operation to be side effect free as well
1040     // but currently this check is under a FIXME and is not actually done.
1041     return impl::verifyIsInvolution(op);
1042   }
1043 
foldTrait(Operation * op,ArrayRef<Attribute> operands)1044   static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1045     return impl::foldInvolution(op);
1046   }
1047 };
1048 
1049 /// This class adds property that the operation is idempotent.
1050 /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
1051 template <typename ConcreteType>
1052 class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1053 public:
verifyTrait(Operation * op)1054   static LogicalResult verifyTrait(Operation *op) {
1055     static_assert(ConcreteType::template hasTrait<OneResult>(),
1056                   "expected operation to produce one result");
1057     static_assert(ConcreteType::template hasTrait<OneOperand>(),
1058                   "expected operation to take one operand");
1059     static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1060                   "expected operation to preserve type");
1061     // Idempotent requires the operation to be side effect free as well
1062     // but currently this check is under a FIXME and is not actually done.
1063     return impl::verifyIsIdempotent(op);
1064   }
1065 
foldTrait(Operation * op,ArrayRef<Attribute> operands)1066   static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1067     return impl::foldIdempotent(op);
1068   }
1069 };
1070 
1071 /// This class verifies that all operands of the specified op have a float type,
1072 /// a vector thereof, or a tensor thereof.
1073 template <typename ConcreteType>
1074 class OperandsAreFloatLike
1075     : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1076 public:
verifyTrait(Operation * op)1077   static LogicalResult verifyTrait(Operation *op) {
1078     return impl::verifyOperandsAreFloatLike(op);
1079   }
1080 };
1081 
1082 /// This class verifies that all operands of the specified op have a signless
1083 /// integer or index type, a vector thereof, or a tensor thereof.
1084 template <typename ConcreteType>
1085 class OperandsAreSignlessIntegerLike
1086     : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1087 public:
verifyTrait(Operation * op)1088   static LogicalResult verifyTrait(Operation *op) {
1089     return impl::verifyOperandsAreSignlessIntegerLike(op);
1090   }
1091 };
1092 
1093 /// This class verifies that all operands of the specified op have the same
1094 /// type.
1095 template <typename ConcreteType>
1096 class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1097 public:
verifyTrait(Operation * op)1098   static LogicalResult verifyTrait(Operation *op) {
1099     return impl::verifySameTypeOperands(op);
1100   }
1101 };
1102 
1103 /// This class provides the API for a sub-set of ops that are known to be
1104 /// constant-like. These are non-side effecting operations with one result and
1105 /// zero operands that can always be folded to a specific attribute value.
1106 template <typename ConcreteType>
1107 class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1108 public:
verifyTrait(Operation * op)1109   static LogicalResult verifyTrait(Operation *op) {
1110     static_assert(ConcreteType::template hasTrait<OneResult>(),
1111                   "expected operation to produce one result");
1112     static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1113                   "expected operation to take zero operands");
1114     // TODO: We should verify that the operation can always be folded, but this
1115     // requires that the attributes of the op already be verified. We should add
1116     // support for verifying traits "after" the operation to enable this use
1117     // case.
1118     return success();
1119   }
1120 };
1121 
1122 /// This class provides the API for ops that are known to be isolated from
1123 /// above.
1124 template <typename ConcreteType>
1125 class IsIsolatedFromAbove
1126     : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1127 public:
getTraitProperties()1128   static AbstractOperation::OperationProperties getTraitProperties() {
1129     return static_cast<AbstractOperation::OperationProperties>(
1130         OperationProperty::IsolatedFromAbove);
1131   }
verifyTrait(Operation * op)1132   static LogicalResult verifyTrait(Operation *op) {
1133     for (auto &region : op->getRegions())
1134       if (!region.isIsolatedFromAbove(op->getLoc()))
1135         return failure();
1136     return success();
1137   }
1138 };
1139 
1140 /// A trait of region holding operations that defines a new scope for polyhedral
1141 /// optimization purposes. Any SSA values of 'index' type that either dominate
1142 /// such an operation or are used at the top-level of such an operation
1143 /// automatically become valid symbols for the polyhedral scope defined by that
1144 /// operation. For more details, see `Traits.md#AffineScope`.
1145 template <typename ConcreteType>
1146 class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1147 public:
verifyTrait(Operation * op)1148   static LogicalResult verifyTrait(Operation *op) {
1149     static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1150                   "expected operation to have one or more regions");
1151     return success();
1152   }
1153 };
1154 
1155 /// A trait of region holding operations that define a new scope for automatic
1156 /// allocations, i.e., allocations that are freed when control is transferred
1157 /// back from the operation's region. Any operations performing such allocations
1158 /// (for eg. std.alloca) will have their allocations automatically freed at
1159 /// their closest enclosing operation with this trait.
1160 template <typename ConcreteType>
1161 class AutomaticAllocationScope
1162     : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1163 public:
verifyTrait(Operation * op)1164   static LogicalResult verifyTrait(Operation *op) {
1165     if (op->hasTrait<ZeroRegion>())
1166       return op->emitOpError("is expected to have regions");
1167     return success();
1168   }
1169 };
1170 
1171 /// This class provides a verifier for ops that are expecting their parent
1172 /// to be one of the given parent ops
1173 template <typename... ParentOpTypes>
1174 struct HasParent {
1175   template <typename ConcreteType>
1176   class Impl : public TraitBase<ConcreteType, Impl> {
1177   public:
verifyTraitHasParent1178     static LogicalResult verifyTrait(Operation *op) {
1179       if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1180         return success();
1181 
1182       return op->emitOpError()
1183              << "expects parent op "
1184              << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1185              << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1186              << "'";
1187     }
1188   };
1189 };
1190 
1191 /// A trait for operations that have an attribute specifying operand segments.
1192 ///
1193 /// Certain operations can have multiple variadic operands and their size
1194 /// relationship is not always known statically. For such cases, we need
1195 /// a per-op-instance specification to divide the operands into logical groups
1196 /// or segments. This can be modeled by attributes. The attribute will be named
1197 /// as `operand_segment_sizes`.
1198 ///
1199 /// This trait verifies the attribute for specifying operand segments has
1200 /// the correct type (1D vector) and values (non-negative), etc.
1201 template <typename ConcreteType>
1202 class AttrSizedOperandSegments
1203     : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1204 public:
getOperandSegmentSizeAttr()1205   static StringRef getOperandSegmentSizeAttr() {
1206     return "operand_segment_sizes";
1207   }
1208 
verifyTrait(Operation * op)1209   static LogicalResult verifyTrait(Operation *op) {
1210     return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1211         op, getOperandSegmentSizeAttr());
1212   }
1213 };
1214 
1215 /// Similar to AttrSizedOperandSegments but used for results.
1216 template <typename ConcreteType>
1217 class AttrSizedResultSegments
1218     : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1219 public:
getResultSegmentSizeAttr()1220   static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1221 
verifyTrait(Operation * op)1222   static LogicalResult verifyTrait(Operation *op) {
1223     return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1224         op, getResultSegmentSizeAttr());
1225   }
1226 };
1227 
1228 /// This trait provides a verifier for ops that are expecting their regions to
1229 /// not have any arguments
1230 template <typename ConcrentType>
1231 struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
verifyTraitNoRegionArguments1232   static LogicalResult verifyTrait(Operation *op) {
1233     return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1234   }
1235 };
1236 
1237 // This trait is used to flag operations that consume or produce
1238 // values of `MemRef` type where those references can be 'normalized'.
1239 // TODO: Right now, the operands of an operation are either all normalizable,
1240 // or not. In the future, we may want to allow some of the operands to be
1241 // normalizable.
1242 template <typename ConcrentType>
1243 struct MemRefsNormalizable
1244     : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1245 
1246 /// This trait tags scalar ops that also can be applied to vectors/tensors, with
1247 /// their semantics on vectors/tensors being elementwise application.
1248 ///
1249 /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1250 /// trait. In particular, broadcasting behavior is not allowed. This trait
1251 /// describes a set of invariants that allow systematic
1252 /// vectorization/tensorization, and the reverse, scalarization. The properties
1253 /// needed for this also can be used to implement a number of
1254 /// transformations/analyses/interfaces.
1255 ///
1256 /// An `ElementwiseMappable` op must satisfy the following properties:
1257 ///
1258 /// 1. If any result is a vector (resp. tensor), then at least one operand must
1259 /// be a vector (resp. tensor).
1260 /// 2. If any operand is a vector (resp. tensor), then there must be at least
1261 /// one result, and all results must be vectors (resp. tensors).
1262 /// 3. The static types of all vector (resp. tensor) operands and results must
1263 /// have the same shape.
1264 /// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
1265 /// must be the same, otherwise the op has undefined behavior.
1266 /// 5. ("systematic scalarization" property) If an op has vector/tensor
1267 /// operands/results, then the same op, with the operand/result types changed to
1268 /// their corresponding element type, shall be a verifier-valid op.
1269 /// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
1270 /// applying the scalarized version of the op for each corresponding element of
1271 /// the vector (resp. tensor) operands in parallel.
1272 /// 7. ("systematic vectorization/tensorization" property) If an op has
1273 /// scalar operands/results, the op shall remain verifier-valid if all scalar
1274 /// operands are replaced with vectors/tensors of the same shape and
1275 /// corresponding element types.
1276 ///
1277 /// Together, these properties provide an easy way for scalar operations to
1278 /// conveniently generalize their behavior to vectors/tensors, and systematize
1279 /// conversion between these forms.
1280 ///
1281 /// Examples:
1282 /// ```
1283 /// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
1284 /// // Applying the systematic vectorization/tensorization property, this op
1285 /// // must also be valid:
1286 /// %tensor = "std.addf"(%a_tensor, %b_tensor)
///           : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1288 ///
1289 /// // These properties generalize well to the cases of non-scalar operands.
1290 /// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
1291 ///                       : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1292 /// // Applying the systematic vectorization / tensorization property, this
1293 /// // op must also be valid:
1294 /// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
1295 ///                       : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1296 ///                       -> tensor<?xf32>
1297 /// // Applying the systematic scalarization property, this op must also
1298 /// // be valid.
1299 /// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1300 ///                  : (i1, f32, f32) -> f32
1301 /// ```
1302 ///
1303 /// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
1304 /// implementing a new "ElementwiseMappableTypeInterface" that describes types
1305 /// for which it makes sense to apply a scalar function to each element.
1306 ///
1307 /// Rationale:
1308 /// - 1. and 2. guarantee a well-defined iteration space for 6.
1309 ///   - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
1310 ///     results, which complicate a generic definition of the iteration space.
1311 /// - 3. guarantees that folding can be done across scalars/vectors/tensors
1312 ///   with the same pattern, as otherwise lots of special handling of type
1313 ///   mismatches would be needed.
1314 /// - 4. guarantees that no error handling cases need to be considered.
1315 ///   - Higher-level dialects should reify any needed guards / error handling
1316 ///   code before lowering to an ElementwiseMappable op.
1317 /// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
1318 ///   semantics and provide a constructive procedure for IR transformations
1319 ///   to e.g. create scalar loop bodies from tensor ops.
1320 /// - 7. provides the reverse of 5., which when chained together allows
1321 ///   reasoning about the relationship between the tensor and vector case.
1322 ///   Additionally, it permits reasoning about promoting scalars to
1323 ///   vectors/tensors via broadcasting in cases like `%select_scalar_pred`
1324 ///   above.
1325 template <typename ConcreteType>
1326 struct ElementwiseMappable
1327     : public TraitBase<ConcreteType, ElementwiseMappable> {
verifyTraitElementwiseMappable1328   static LogicalResult verifyTrait(Operation *op) {
1329     return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
1330   }
1331 };
1332 
1333 } // end namespace OpTrait
1334 
1335 //===----------------------------------------------------------------------===//
1336 // Internal Trait Utilities
1337 //===----------------------------------------------------------------------===//
1338 
1339 namespace op_definition_impl {
1340 //===----------------------------------------------------------------------===//
1341 // Trait Existence
1342 
1343 /// Returns true if this given Trait ID matches the IDs of any of the provided
1344 /// trait types `Traits`.
1345 template <template <typename T> class... Traits>
hasTrait(TypeID traitID)1346 static bool hasTrait(TypeID traitID) {
1347   TypeID traitIDs[] = {TypeID::get<Traits>()...};
1348   for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1349     if (traitIDs[i] == traitID)
1350       return true;
1351   return false;
1352 }
1353 
1354 //===----------------------------------------------------------------------===//
1355 // Trait Folding
1356 
1357 /// Trait to check if T provides a 'foldTrait' method for single result
1358 /// operations.
1359 template <typename T, typename... Args>
1360 using has_single_result_fold_trait = decltype(T::foldTrait(
1361     std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1362 template <typename T>
1363 using detect_has_single_result_fold_trait =
1364     llvm::is_detected<has_single_result_fold_trait, T>;
1365 /// Trait to check if T provides a general 'foldTrait' method.
1366 template <typename T, typename... Args>
1367 using has_fold_trait =
1368     decltype(T::foldTrait(std::declval<Operation *>(),
1369                           std::declval<ArrayRef<Attribute>>(),
1370                           std::declval<SmallVectorImpl<OpFoldResult> &>()));
1371 template <typename T>
1372 using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1373 /// Trait to check if T provides any `foldTrait` method.
1374 /// NOTE: This should use std::disjunction when C++17 is available.
1375 template <typename T>
1376 using detect_has_any_fold_trait =
1377     std::conditional_t<bool(detect_has_fold_trait<T>::value),
1378                        detect_has_fold_trait<T>,
1379                        detect_has_single_result_fold_trait<T>>;
1380 
1381 /// Returns the result of folding a trait that implements a `foldTrait` function
1382 /// that is specialized for operations that have a single result.
1383 template <typename Trait>
1384 static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1385                         LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1386 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1387           SmallVectorImpl<OpFoldResult> &results) {
1388   assert(op->hasTrait<OpTrait::OneResult>() &&
1389          "expected trait on non single-result operation to implement the "
1390          "general `foldTrait` method");
1391   // If a previous trait has already been folded and replaced this operation, we
1392   // fail to fold this trait.
1393   if (!results.empty())
1394     return failure();
1395 
1396   if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1397     if (result.template dyn_cast<Value>() != op->getResult(0))
1398       results.push_back(result);
1399     return success();
1400   }
1401   return failure();
1402 }
1403 /// Returns the result of folding a trait that implements a generalized
1404 /// `foldTrait` function that is supports any operation type.
1405 template <typename Trait>
1406 static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
foldTrait(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1407 foldTrait(Operation *op, ArrayRef<Attribute> operands,
1408           SmallVectorImpl<OpFoldResult> &results) {
1409   // If a previous trait has already been folded and replaced this operation, we
1410   // fail to fold this trait.
1411   return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1412 }
1413 
1414 /// The internal implementation of `foldTraits` below that returns the result of
1415 /// folding a set of trait types `Ts` that implement a `foldTrait` method.
1416 template <typename... Ts>
foldTraitsImpl(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results,std::tuple<Ts...> *)1417 static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1418                                     SmallVectorImpl<OpFoldResult> &results,
1419                                     std::tuple<Ts...> *) {
1420   bool anyFolded = false;
1421   (void)std::initializer_list<int>{
1422       (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1423   return success(anyFolded);
1424 }
1425 
1426 /// Given a tuple type containing a set of traits that contain a `foldTrait`
1427 /// method, return the result of folding the given operation.
1428 template <typename TraitTupleT>
1429 static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1430 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1431            SmallVectorImpl<OpFoldResult> &results) {
1432   return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1433 }
1434 /// A variant of the method above that is specialized when there are no traits
1435 /// that contain a `foldTrait` method.
1436 template <typename TraitTupleT>
1437 static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
foldTraits(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1438 foldTraits(Operation *op, ArrayRef<Attribute> operands,
1439            SmallVectorImpl<OpFoldResult> &results) {
1440   return failure();
1441 }
1442 
1443 //===----------------------------------------------------------------------===//
1444 // Trait Properties
1445 
1446 /// Trait to check if T provides a `getTraitProperties` method.
1447 template <typename T, typename... Args>
1448 using has_get_trait_properties = decltype(T::getTraitProperties());
1449 template <typename T>
1450 using detect_has_get_trait_properties =
1451     llvm::is_detected<has_get_trait_properties, T>;
1452 
1453 /// The internal implementation of `getTraitProperties` below that returns the
1454 /// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
1455 template <typename... Ts>
1456 static AbstractOperation::OperationProperties
getTraitPropertiesImpl(std::tuple<Ts...> *)1457 getTraitPropertiesImpl(std::tuple<Ts...> *) {
1458   AbstractOperation::OperationProperties result = 0;
1459   (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
1460   return result;
1461 }
1462 
1463 /// Given a tuple type containing a set of traits that contain a
1464 /// `getTraitProperties` method, return the OR of all of the results of invoking
1465 /// those methods.
1466 template <typename TraitTupleT>
getTraitProperties()1467 static AbstractOperation::OperationProperties getTraitProperties() {
1468   return getTraitPropertiesImpl((TraitTupleT *)nullptr);
1469 }
1470 
1471 //===----------------------------------------------------------------------===//
1472 // Trait Verification
1473 
1474 /// Trait to check if T provides a `verifyTrait` method.
1475 template <typename T, typename... Args>
1476 using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
1477 template <typename T>
1478 using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1479 
1480 /// The internal implementation of `verifyTraits` below that returns the result
1481 /// of verifying the current operation with all of the provided trait types
1482 /// `Ts`.
1483 template <typename... Ts>
verifyTraitsImpl(Operation * op,std::tuple<Ts...> *)1484 static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1485   LogicalResult result = success();
1486   (void)std::initializer_list<int>{
1487       (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1488   return result;
1489 }
1490 
1491 /// Given a tuple type containing a set of traits that contain a
1492 /// `verifyTrait` method, return the result of verifying the given operation.
1493 template <typename TraitTupleT>
verifyTraits(Operation * op)1494 static LogicalResult verifyTraits(Operation *op) {
1495   return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1496 }
1497 } // namespace op_definition_impl
1498 
1499 //===----------------------------------------------------------------------===//
1500 // Operation Definition classes
1501 //===----------------------------------------------------------------------===//
1502 
1503 /// This provides public APIs that all operations should have.  The template
1504 /// argument 'ConcreteType' should be the concrete type by CRTP and the others
1505 /// are base classes by the policy pattern.
1506 template <typename ConcreteType, template <typename T> class... Traits>
1507 class Op : public OpState, public Traits<ConcreteType>... {
1508 public:
1509   /// Inherit getOperation from `OpState`.
1510   using OpState::getOperation;
1511 
1512   /// Return if this operation contains the provided trait.
1513   template <template <typename T> class Trait>
hasTrait()1514   static constexpr bool hasTrait() {
1515     return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1516   }
1517 
1518   /// Create a deep copy of this operation.
clone()1519   ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1520 
1521   /// Create a partial copy of this operation without traversing into attached
1522   /// regions. The new operation will have the same number of regions as the
1523   /// original one, but they will be left empty.
cloneWithoutRegions()1524   ConcreteType cloneWithoutRegions() {
1525     return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1526   }
1527 
1528   /// Return true if this "op class" can match against the specified operation.
classof(Operation * op)1529   static bool classof(Operation *op) {
1530     if (auto *abstractOp = op->getAbstractOperation())
1531       return TypeID::get<ConcreteType>() == abstractOp->typeID;
1532 #ifndef NDEBUG
1533     if (op->getName().getStringRef() == ConcreteType::getOperationName())
1534       llvm::report_fatal_error(
1535           "classof on '" + ConcreteType::getOperationName() +
1536           "' failed due to the operation not being registered");
1537 #endif
1538     return false;
1539   }
1540 
1541   /// Expose the type we are instantiated on to template machinery that may want
1542   /// to introspect traits on this operation.
1543   using ConcreteOpType = ConcreteType;
1544 
1545   /// This is a public constructor.  Any op can be initialized to null.
Op()1546   explicit Op() : OpState(nullptr) {}
Op(std::nullptr_t)1547   Op(std::nullptr_t) : OpState(nullptr) {}
1548 
1549   /// This is a public constructor to enable access via the llvm::cast family of
1550   /// methods. This should not be used directly.
Op(Operation * state)1551   explicit Op(Operation *state) : OpState(state) {}
1552 
1553   /// Methods for supporting PointerLikeTypeTraits.
getAsOpaquePointer()1554   const void *getAsOpaquePointer() const {
1555     return static_cast<const void *>((Operation *)*this);
1556   }
getFromOpaquePointer(const void * pointer)1557   static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1558     return ConcreteOpType(
1559         reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1560   }
1561 
1562 private:
1563   /// Trait to check if T provides a 'fold' method for a single result op.
1564   template <typename T, typename... Args>
1565   using has_single_result_fold =
1566       decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1567   template <typename T>
1568   using detect_has_single_result_fold =
1569       llvm::is_detected<has_single_result_fold, T>;
1570   /// Trait to check if T provides a general 'fold' method.
1571   template <typename T, typename... Args>
1572   using has_fold = decltype(
1573       std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
1574                              std::declval<SmallVectorImpl<OpFoldResult> &>()));
1575   template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
1576   /// Trait to check if T provides a 'print' method.
1577   template <typename T, typename... Args>
1578   using has_print =
1579       decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1580   template <typename T>
1581   using detect_has_print = llvm::is_detected<has_print, T>;
1582   /// A tuple type containing the traits that have a `foldTrait` function.
1583   using FoldableTraitsTupleT = typename detail::FilterTypes<
1584       op_definition_impl::detect_has_any_fold_trait,
1585       Traits<ConcreteType>...>::type;
1586   /// A tuple type containing the traits that have a verify function.
1587   using VerifiableTraitsTupleT =
1588       typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
1589                                    Traits<ConcreteType>...>::type;
1590 
1591   /// Returns the properties of this operation by combining the properties
1592   /// defined by the traits.
getOperationProperties()1593   static AbstractOperation::OperationProperties getOperationProperties() {
1594     return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
1595         op_definition_impl::detect_has_get_trait_properties,
1596         Traits<ConcreteType>...>::type>();
1597   }
1598 
1599   /// Returns an interface map containing the interfaces registered to this
1600   /// operation.
getInterfaceMap()1601   static detail::InterfaceMap getInterfaceMap() {
1602     return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1603   }
1604 
1605   /// Return the internal implementations of each of the AbstractOperation
1606   /// hooks.
1607   /// Implementation of `FoldHookFn` AbstractOperation hook.
getFoldHookFn()1608   static AbstractOperation::FoldHookFn getFoldHookFn() {
1609     return getFoldHookFnImpl<ConcreteType>();
1610   }
1611   /// The internal implementation of `getFoldHookFn` above that is invoked if
1612   /// the operation is single result and defines a `fold` method.
1613   template <typename ConcreteOpT>
1614   static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1615                                           Traits<ConcreteOpT>...>::value &&
1616                               detect_has_single_result_fold<ConcreteOpT>::value,
1617                           AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1618   getFoldHookFnImpl() {
1619     return &foldSingleResultHook<ConcreteOpT>;
1620   }
1621   /// The internal implementation of `getFoldHookFn` above that is invoked if
1622   /// the operation is not single result and defines a `fold` method.
1623   template <typename ConcreteOpT>
1624   static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1625                                            Traits<ConcreteOpT>...>::value &&
1626                               detect_has_fold<ConcreteOpT>::value,
1627                           AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1628   getFoldHookFnImpl() {
1629     return &foldHook<ConcreteOpT>;
1630   }
1631   /// The internal implementation of `getFoldHookFn` above that is invoked if
1632   /// the operation does not define a `fold` method.
1633   template <typename ConcreteOpT>
1634   static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1635                               !detect_has_fold<ConcreteOpT>::value,
1636                           AbstractOperation::FoldHookFn>
getFoldHookFnImpl()1637   getFoldHookFnImpl() {
1638     // In this case, we only need to fold the traits of the operation.
1639     return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
1640   }
1641   /// Return the result of folding a single result operation that defines a
1642   /// `fold` method.
1643   template <typename ConcreteOpT>
1644   static LogicalResult
foldSingleResultHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1645   foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1646                        SmallVectorImpl<OpFoldResult> &results) {
1647     OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1648 
1649     // If the fold failed or was in-place, try to fold the traits of the
1650     // operation.
1651     if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1652       if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1653               op, operands, results)))
1654         return success();
1655       return success(static_cast<bool>(result));
1656     }
1657     results.push_back(result);
1658     return success();
1659   }
1660   /// Return the result of folding an operation that defines a `fold` method.
1661   template <typename ConcreteOpT>
foldHook(Operation * op,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1662   static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1663                                 SmallVectorImpl<OpFoldResult> &results) {
1664     LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1665 
1666     // If the fold failed or was in-place, try to fold the traits of the
1667     // operation.
1668     if (failed(result) || results.empty()) {
1669       if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1670               op, operands, results)))
1671         return success();
1672     }
1673     return result;
1674   }
1675 
1676   /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
1677   static AbstractOperation::GetCanonicalizationPatternsFn
getGetCanonicalizationPatternsFn()1678   getGetCanonicalizationPatternsFn() {
1679     return &ConcreteType::getCanonicalizationPatterns;
1680   }
1681   /// Implementation of `GetHasTraitFn`
getHasTraitFn()1682   static AbstractOperation::HasTraitFn getHasTraitFn() {
1683     return &op_definition_impl::hasTrait<Traits...>;
1684   }
1685   /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
getParseAssemblyFn()1686   static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
1687     return &ConcreteType::parse;
1688   }
1689   /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
getPrintAssemblyFn()1690   static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
1691     return getPrintAssemblyFnImpl<ConcreteType>();
1692   }
1693   /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1694   /// the concrete operation does not define a `print` method.
1695   template <typename ConcreteOpT>
1696   static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1697                           AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl()1698   getPrintAssemblyFnImpl() {
1699     return &OpState::print;
1700   }
1701   /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1702   /// the concrete operation defines a `print` method.
1703   template <typename ConcreteOpT>
1704   static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1705                           AbstractOperation::PrintAssemblyFn>
getPrintAssemblyFnImpl()1706   getPrintAssemblyFnImpl() {
1707     return &printAssembly;
1708   }
printAssembly(Operation * op,OpAsmPrinter & p)1709   static void printAssembly(Operation *op, OpAsmPrinter &p) {
1710     return cast<ConcreteType>(op).print(p);
1711   }
1712   /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
getVerifyInvariantsFn()1713   static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
1714     return &verifyInvariants;
1715   }
verifyInvariants(Operation * op)1716   static LogicalResult verifyInvariants(Operation *op) {
1717     return failure(
1718         failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
1719         failed(cast<ConcreteType>(op).verify()));
1720   }
1721 
1722   /// Allow access to internal implementation methods.
1723   friend AbstractOperation;
1724 };
1725 
1726 /// This class represents the base of an operation interface. See the definition
1727 /// of `detail::Interface` for requirements on the `Traits` type.
1728 template <typename ConcreteType, typename Traits>
1729 class OpInterface
1730     : public detail::Interface<ConcreteType, Operation *, Traits,
1731                                Op<ConcreteType>, OpTrait::TraitBase> {
1732 public:
1733   using Base = OpInterface<ConcreteType, Traits>;
1734   using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1735                                           Op<ConcreteType>, OpTrait::TraitBase>;
1736 
1737   /// Inherit the base class constructor.
1738   using InterfaceBase::InterfaceBase;
1739 
1740 protected:
1741   /// Returns the impl interface instance for the given operation.
getInterfaceFor(Operation * op)1742   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1743     // Access the raw interface from the abstract operation.
1744     auto *abstractOp = op->getAbstractOperation();
1745     return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
1746   }
1747 
1748   /// Allow access to `getInterfaceFor`.
1749   friend InterfaceBase;
1750 };
1751 
1752 //===----------------------------------------------------------------------===//
1753 // Common Operation Folders/Parsers/Printers
1754 //===----------------------------------------------------------------------===//
1755 
1756 // These functions are out-of-line implementations of the methods in UnaryOp and
1757 // BinaryOp, which avoids them being template instantiated/duplicated.
1758 namespace impl {
1759 ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
1760                                            OperationState &result);
1761 
1762 void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
1763                    Value rhs);
1764 ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
1765                                             OperationState &result);
1766 
1767 // Prints the given binary `op` in custom assembly form if both the two operands
1768 // and the result have the same time. Otherwise, prints the generic assembly
1769 // form.
1770 void printOneResultOp(Operation *op, OpAsmPrinter &p);
1771 } // namespace impl
1772 
1773 // These functions are out-of-line implementations of the methods in CastOp,
1774 // which avoids them being template instantiated/duplicated.
1775 namespace impl {
1776 void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
1777                  Type destType);
1778 ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
1779 void printCastOp(Operation *op, OpAsmPrinter &p);
1780 Value foldCastOp(Operation *op);
1781 } // namespace impl
1782 } // end namespace mlir
1783 
1784 #endif
1785