1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "llvm/ADT/STLExtras.h"
19
20 namespace mlir {
// Forward declarations of the IR and diagnostic classes that this header only
// refers to through pointers or rvalue references.
class Block;
class Diagnostic;
class InFlightDiagnostic;
class Operation;
class Region;
26
27 /// A utility result that is used to signal if a walk method should be
28 /// interrupted or advance.
29 class WalkResult {
30 enum ResultEnum { Interrupt, Advance } result;
31
32 public:
WalkResult(ResultEnum result)33 WalkResult(ResultEnum result) : result(result) {}
34
35 /// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)36 WalkResult(LogicalResult result)
37 : result(failed(result) ? Interrupt : Advance) {}
38
39 /// Allow diagnostics to interrupt the walk.
WalkResult(Diagnostic &&)40 WalkResult(Diagnostic &&) : result(Interrupt) {}
WalkResult(InFlightDiagnostic &&)41 WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
42
43 bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
44
interrupt()45 static WalkResult interrupt() { return {Interrupt}; }
advance()46 static WalkResult advance() { return {Advance}; }
47
48 /// Returns true if the walk was interrupted.
wasInterrupted()49 bool wasInterrupted() const { return result == Interrupt; }
50 };
51
52 namespace detail {
/// Helper templates to deduce the type of the first argument of a callback.
///
/// These are declaration-only: they are named solely inside unevaluated
/// `decltype` expressions below. The function-pointer and member-function-
/// pointer overloads must be declared before the catch-all callable overload,
/// because that overload's return type names them via unqualified lookup at
/// its point of definition.
///
/// A callable with an overloaded or templated `operator()` (for example a
/// generic lambda) is not supported: `&Callable::operator()` is then ill-formed
/// and the catch-all overload drops out of overload resolution.
template <typename R, typename A> A first_argument_type(R (*)(A));
template <typename R, typename C, typename A>
A first_argument_type(R (C::*)(A));
template <typename R, typename C, typename A>
A first_argument_type(R (C::*)(A) const);
template <typename Callable>
decltype(first_argument_type(&Callable::operator()))
first_argument_type(Callable);

/// The type of the first argument of a callable of type 'CallableT'. This
/// works for function pointers and for class types with a single
/// one-parameter `operator()`.
template <typename CallableT>
using first_argument =
    decltype(first_argument_type(std::declval<CallableT>()));
65
66 /// Walk all of the regions, blocks, or operations nested under (and including)
67 /// the given operation.
68 void walk(Operation *op, function_ref<void(Region *)> callback);
69 void walk(Operation *op, function_ref<void(Block *)> callback);
70 void walk(Operation *op, function_ref<void(Operation *)> callback);
71
72 /// Walk all of the regions, blocks, or operations nested under (and including)
73 /// the given operation. These functions walk until an interrupt result is
74 /// returned by the callback.
75 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
76 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
77 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
78
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are also templated to allow static dispatch based upon the
// type of the callback function.
83
84 /// Walk all of the regions, blocks, or operations nested under (and including)
85 /// the given operation. This method is selected for callbacks that operate on
86 /// Region*, Block*, and Operation*.
87 ///
88 /// Example:
89 /// op->walk([](Region *r) { ... });
90 /// op->walk([](Block *b) { ... });
91 /// op->walk([](Operation *op) { ... });
92 template <
93 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
94 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
95 typename std::enable_if<
96 llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
walk(Operation * op,FuncTy && callback)97 walk(Operation *op, FuncTy &&callback) {
98 return walk(op, function_ref<RetT(ArgT)>(callback));
99 }
100
101 /// Walk all of the operations of type 'ArgT' nested under and including the
102 /// given operation. This method is selected for void returning callbacks that
103 /// operate on a specific derived operation type.
104 ///
105 /// Example:
106 /// op->walk([](ReturnOp op) { ... });
107 template <
108 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
109 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
110 typename std::enable_if<
111 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
112 std::is_same<RetT, void>::value,
113 RetT>::type
walk(Operation * op,FuncTy && callback)114 walk(Operation *op, FuncTy &&callback) {
115 auto wrapperFn = [&](Operation *op) {
116 if (auto derivedOp = dyn_cast<ArgT>(op))
117 callback(derivedOp);
118 };
119 return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
120 }
121
122 /// Walk all of the operations of type 'ArgT' nested under and including the
123 /// given operation. This method is selected for WalkReturn returning
124 /// interruptible callbacks that operate on a specific derived operation type.
125 ///
126 /// Example:
127 /// op->walk([](ReturnOp op) {
128 /// if (some_invariant)
129 /// return WalkResult::interrupt();
130 /// return WalkResult::advance();
131 /// });
132 template <
133 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
134 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
135 typename std::enable_if<
136 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
137 std::is_same<RetT, WalkResult>::value,
138 RetT>::type
walk(Operation * op,FuncTy && callback)139 walk(Operation *op, FuncTy &&callback) {
140 auto wrapperFn = [&](Operation *op) {
141 if (auto derivedOp = dyn_cast<ArgT>(op))
142 return callback(derivedOp);
143 return WalkResult::advance();
144 };
145 return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
146 }
147
/// The return type of the `walk` overload that is selected for a callback of
/// type 'CallbackT'. For the overloads in this namespace, that is `void` or
/// `WalkResult`. `nullptr` stands in for the operation argument; the call is
/// unevaluated, so nothing is ever dereferenced.
template <typename CallbackT>
using walkResultType =
    decltype(walk(nullptr, std::declval<CallbackT>()));
151 } // end namespace detail
152
153 } // namespace mlir
154
155 #endif
156