• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
21 class Diagnostic;
22 class InFlightDiagnostic;
23 class Operation;
24 class Block;
25 class Region;
26 
27 /// A utility result that is used to signal if a walk method should be
28 /// interrupted or advance.
29 class WalkResult {
30   enum ResultEnum { Interrupt, Advance } result;
31 
32 public:
WalkResult(ResultEnum result)33   WalkResult(ResultEnum result) : result(result) {}
34 
35   /// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)36   WalkResult(LogicalResult result)
37       : result(failed(result) ? Interrupt : Advance) {}
38 
39   /// Allow diagnostics to interrupt the walk.
WalkResult(Diagnostic &&)40   WalkResult(Diagnostic &&) : result(Interrupt) {}
WalkResult(InFlightDiagnostic &&)41   WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
42 
43   bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
44 
interrupt()45   static WalkResult interrupt() { return {Interrupt}; }
advance()46   static WalkResult advance() { return {Advance}; }
47 
48   /// Returns true if the walk was interrupted.
wasInterrupted()49   bool wasInterrupted() const { return result == Interrupt; }
50 };
51 
52 namespace detail {
/// Compile-time-only overload set (declared, never defined; only ever named
/// inside `decltype`) that recovers the parameter type of a single-parameter
/// callable: a plain function pointer, a pointer to a const or non-const
/// member function, or a class type with one non-template `operator()`.
/// The member-pointer overloads must stay above the class-type overload,
/// whose return type is computed by calling them.
template <class R, class A>
A first_argument_type(R (*)(A));
template <class R, class C, class A>
A first_argument_type(R (C::*)(A));
template <class R, class C, class A>
A first_argument_type(R (C::*)(A) const);
template <class Callable>
decltype(first_argument_type(&Callable::operator()))
    first_argument_type(Callable);

/// The parameter type accepted by the callable type `CallableT`.
template <class CallableT>
using first_argument =
    decltype(first_argument_type(std::declval<CallableT>()));
65 
66 /// Walk all of the regions, blocks, or operations nested under (and including)
67 /// the given operation.
68 void walk(Operation *op, function_ref<void(Region *)> callback);
69 void walk(Operation *op, function_ref<void(Block *)> callback);
70 void walk(Operation *op, function_ref<void(Operation *)> callback);
71 
72 /// Walk all of the regions, blocks, or operations nested under (and including)
73 /// the given operation. These functions walk until an interrupt result is
74 /// returned by the callback.
75 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
76 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
77 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
78 
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are templated so that the walk can be dispatched statically
// based upon the type of the callback function.
83 
84 /// Walk all of the regions, blocks, or operations nested under (and including)
85 /// the given operation. This method is selected for callbacks that operate on
86 /// Region*, Block*, and Operation*.
87 ///
88 /// Example:
89 ///   op->walk([](Region *r) { ... });
90 ///   op->walk([](Block *b) { ... });
91 ///   op->walk([](Operation *op) { ... });
92 template <
93     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
94     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
95 typename std::enable_if<
96     llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
walk(Operation * op,FuncTy && callback)97 walk(Operation *op, FuncTy &&callback) {
98   return walk(op, function_ref<RetT(ArgT)>(callback));
99 }
100 
101 /// Walk all of the operations of type 'ArgT' nested under and including the
102 /// given operation. This method is selected for void returning callbacks that
103 /// operate on a specific derived operation type.
104 ///
105 /// Example:
106 ///   op->walk([](ReturnOp op) { ... });
107 template <
108     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
109     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
110 typename std::enable_if<
111     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
112         std::is_same<RetT, void>::value,
113     RetT>::type
walk(Operation * op,FuncTy && callback)114 walk(Operation *op, FuncTy &&callback) {
115   auto wrapperFn = [&](Operation *op) {
116     if (auto derivedOp = dyn_cast<ArgT>(op))
117       callback(derivedOp);
118   };
119   return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
120 }
121 
122 /// Walk all of the operations of type 'ArgT' nested under and including the
123 /// given operation. This method is selected for WalkReturn returning
124 /// interruptible callbacks that operate on a specific derived operation type.
125 ///
126 /// Example:
127 ///   op->walk([](ReturnOp op) {
128 ///     if (some_invariant)
129 ///       return WalkResult::interrupt();
130 ///     return WalkResult::advance();
131 ///   });
132 template <
133     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
134     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
135 typename std::enable_if<
136     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
137         std::is_same<RetT, WalkResult>::value,
138     RetT>::type
walk(Operation * op,FuncTy && callback)139 walk(Operation *op, FuncTy &&callback) {
140   auto wrapperFn = [&](Operation *op) {
141     if (auto derivedOp = dyn_cast<ArgT>(op))
142       return callback(derivedOp);
143     return WalkResult::advance();
144   };
145   return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
146 }
147 
/// Return type of the `walk` overload selected for a callback of type
/// `CallbackT`. It is computed in an unevaluated `decltype` context, so no walk
/// is actually performed.
template <typename CallbackT>
using walkResultType = decltype(walk(nullptr, std::declval<CallbackT>()));
151 } // end namespace detail
152 
153 } // namespace mlir
154 
155 #endif
156