/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_

#include <type_traits>
#include <utility>

#include "mlir/IR/Visitors.h"  // from @llvm-project

// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The
// walk() utility that MLIR core provides traverses operations in a block and
// blocks in a region in program order, and these walkers do the same. When
// operations have regions attached to them, the core MLIR walkers visit the
// regions attached to an Op first, and then visit the Op. So within the
// context of a single Op, the traversal is post-order (considering the Op as
// the parent node and the regions as the children). For certain use cases, it
// may be more efficient/desirable to visit the parent Op before visiting the
// attached regions. As an example, if the attached regions have region
// arguments that are related to the operation inputs (tf.WhileRegion is an
// example), then we may want to propagate some information from the Op inputs
// to the region arguments and then visit the regions to continue propagating
// that information within the regions. With just post-order traversal,
// achieving the same may require scheduling another walk to make sure the
// child regions get visited. A pre-order walk (within the context of a single
// operation) avoids that. Similarly, for certain operations, we may want to
// visit the Op both before and after all regions have been visited (say, to
// propagate information from inputs -> region arguments and then from region
// results -> outputs).

// In general, since the data flow between an operation and its regions is
// opaque in MLIR, we may need to visit the operation in between regions as
// well, if, say, region0 transfers control back to the Op and from there to
// region1. So a more general walker that supports pre/in/post-order walks is
// desirable. To support this, the generic walkers defined below invoke the
// walk callback on the parent Op at each stage of the child region walk, i.e.,
// before visiting any region, in between regions, and after visiting all
// regions. To indicate the current walk stage, the callback also receives a
// `WalkStage` parameter. The callback can inspect the current walk stage and
// take the appropriate action (including doing nothing). With this, the
// walkers below can support pre-, in-, and post-order walks as well as
// combined (pre+in+post)-order walks.
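//
// For instance, a combined (pre+post)-order walk might look like the
// following (an illustrative sketch only; `CollectInputFacts` and
// `SummarizeResults` are hypothetical helpers, not part of this header):
//
//   tensorflow::GenericWalk(
//       op, [](mlir::Operation *op, const WalkStage &stage) {
//         if (stage.IsBeforeAllRegions()) CollectInputFacts(op);
//         if (stage.IsAfterAllRegions()) SummarizeResults(op);
//       });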

namespace tensorflow {

// A class to indicate the current walk stage.
class WalkStage {
 public:
  explicit WalkStage(mlir::Operation *op);

  // Queries for the current position of the walk relative to the op's regions.
  bool IsBeforeAllRegions() const { return next_region_ == 0; }
  bool IsBeforeRegion(int region) const { return next_region_ == region; }
  bool IsAfterRegion(int region) const { return next_region_ == region + 1; }
  bool IsAfterAllRegions() const { return next_region_ == num_regions_; }
  // Advances the stage past the next region.
  void Advance() { next_region_++; }
  // Index of the next region to be visited.
  int GetNextRegion() const { return next_region_; }

 private:
  const int num_regions_;
  int next_region_;
};
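
// Example of inspecting the walk stage inside a callback (an illustrative
// sketch only; `TransferFactsBetweenRegions` is a hypothetical helper used to
// show where an in-order action would go, not part of this header):
//
//   tensorflow::GenericWalk(
//       op, [](mlir::Operation *op, const WalkStage &stage) {
//         // In-order: region 0 has been visited and at least one more region
//         // (e.g., region 1) is still to come.
//         if (stage.IsAfterRegion(0) && !stage.IsAfterAllRegions())
//           TransferFactsBetweenRegions(op);
//       });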

namespace detail {
// This is similar to the MLIR version, but works with multi-argument
// functions. Helper templates to deduce the first argument of a callback
// parameter.
template <typename Ret, typename Arg, typename... Rest>
Arg first_argument_type(Ret (*)(Arg, Rest...));
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
template <typename F>
decltype(first_argument_type(&F::operator())) first_argument_type(F);

/// Type definition of the first argument to the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));

using VoidCallback =
    llvm::function_ref<void(mlir::Operation *, const WalkStage &)>;
using InterruptCallback =
    llvm::function_ref<mlir::WalkResult(mlir::Operation *, const WalkStage &)>;

// Walk all of the operations nested under and including the given operation.
void WalkOperations(mlir::Operation *op, VoidCallback callback);

// Walk all of the operations nested under and including the given operation.
// This method walks operations until an interrupt result is returned by the
// callback.
mlir::WalkResult WalkOperations(mlir::Operation *op,
                                InterruptCallback callback);

}  // namespace detail

// Walk all of the operations nested under and including the given operation.
// This method is selected for stage-aware callbacks that operate on
// Operation*.
//
// Example:
//   tensorflow::GenericWalk(
//       op, [](mlir::Operation *op, const WalkStage &stage) { ... });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
          typename RetT = decltype(std::declval<FuncTy>()(
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<std::is_same<ArgT, mlir::Operation *>::value,
                        RetT>::type
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
  return detail::WalkOperations(
      op, llvm::function_ref<RetT(ArgT, const WalkStage &)>(callback));
}

// Walk all of the operations of type 'ArgT' nested under and including the
// given operation. This method is selected for void-returning callbacks that
// operate on a specific derived operation type.
//
// Example:
//   tensorflow::GenericWalk(
//       op, [](ReturnOp op, const WalkStage &stage) { ... });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
          typename RetT = decltype(std::declval<FuncTy>()(
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
                            std::is_same<RetT, void>::value,
                        RetT>::type
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
  auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
    if (auto derivedOp = llvm::dyn_cast<ArgT>(op)) callback(derivedOp, stage);
  };
  return detail::WalkOperations(op,
                                static_cast<detail::VoidCallback>(wrapperFn));
}

// Walk all of the operations of type 'ArgT' nested under and including the
// given operation. This method is selected for WalkResult-returning,
// interruptible callbacks that operate on a specific derived operation type.
//
// Example:
//   tensorflow::GenericWalk(op, [](ReturnOp op, const WalkStage &stage) {
//     if (some_invariant)
//       return mlir::WalkResult::interrupt();
//     return mlir::WalkResult::advance();
//   });
template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
          typename RetT = decltype(std::declval<FuncTy>()(
              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
                            std::is_same<RetT, mlir::WalkResult>::value,
                        RetT>::type
GenericWalk(mlir::Operation *op, FuncTy &&callback) {
  auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
    if (auto derivedOp = llvm::dyn_cast<ArgT>(op))
      return callback(derivedOp, stage);
    return mlir::WalkResult::advance();
  };
  return detail::WalkOperations(
      op, static_cast<detail::InterruptCallback>(wrapperFn));
}
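
// The stage checks and the interrupt mechanism can also be combined, e.g. to
// implement a pre-order-only, early-exiting walk (an illustrative sketch;
// `IsMalformed` is a hypothetical predicate, not part of this header):
//
//   auto result = tensorflow::GenericWalk(
//       op, [](mlir::Operation *op, const WalkStage &stage) {
//         if (stage.IsBeforeAllRegions() && IsMalformed(op))
//           return mlir::WalkResult::interrupt();
//         return mlir::WalkResult::advance();
//       });
//   if (result.wasInterrupted()) { /* handle the malformed op */ }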

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_