• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
16 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
17 
#include <type_traits>
#include <utility>

#include "mlir/IR/Visitors.h"  // from @llvm-project
21 
// This file defines generic (pre/in/post)-order MLIR IR visitors/walkers. The
// walk() utility that MLIR core provides traverses the operations in a block,
// and the blocks in a region, in program order; these walkers do the same.
// When operations have regions attached to them, the core MLIR walkers visit
// the regions attached to an Op first, and then visit the Op itself. So within
// the context of a single Op, the traversal is post-order (considering the Op
// as the parent node and its regions as the children). For certain use cases,
// it may be more efficient or desirable to visit the parent Op before visiting
// the attached regions. For example, if the attached regions have region
// arguments that are related to the operation inputs (tf.WhileRegion is an
// example), then we may want to propagate some information from the Op inputs
// to the region inputs and then visit the regions to continue propagating that
// information within the regions. With just a post-order traversal, achieving
// the same may require scheduling another walk to make sure the child regions
// get visited. A pre-order walk (within the context of a single operation)
// avoids that. Similarly, for certain operations, we may want to visit the Op
// both before and after all regions have been visited (say, to propagate
// information from inputs -> region arguments and then from region results ->
// outputs).
40 
// In general, since the data flow between an operation and its regions is
// opaque in MLIR, we may also need to visit the operation in between regions,
// e.g. if region0 transfers control back to the Op and from there to region1.
// So a more general walker that supports pre/in/post-order walks is desirable.
// To support this, the generic walkers defined below invoke the walk callback
// on the parent Op at each stage of the child region walk, i.e., before
// visiting any region, in between regions, and after visiting all regions. To
// indicate the current walk stage, the callback also receives a `WalkStage`
// parameter. The callback can inspect the current walk stage and decide to
// take appropriate actions (including doing nothing). With this, the walkers
// below support pre-, in-, and post-order walks as well as combined
// (pre+in+post)-order walks.
53 
54 namespace tensorflow {
55 
56 // A class to indicate the current walk stage.
57 class WalkStage {
58  public:
59   explicit WalkStage(mlir::Operation *op);
60 
IsBeforeAllRegions()61   bool IsBeforeAllRegions() const { return next_region_ == 0; }
IsBeforeRegion(int region)62   bool IsBeforeRegion(int region) const { return next_region_ == region; }
IsAfterRegion(int region)63   bool IsAfterRegion(int region) const { return next_region_ == region + 1; }
IsAfterAllRegions()64   bool IsAfterAllRegions() const { return next_region_ == num_regions_; }
Advance()65   void Advance() { next_region_++; }
GetNextRegion()66   int GetNextRegion() const { return next_region_; }
67 
68  private:
69   const int num_regions_;
70   int next_region_;
71 };
72 
73 namespace detail {
74 // This is similar to MLIR version, but works with multiple argument functions.
75 // Helper templates to deduce the first argument of a callback parameter.
76 template <typename Ret, typename Arg, typename... Rest>
77 Arg first_argument_type(Ret (*)(Arg, Rest...));
78 template <typename Ret, typename F, typename Arg, typename... Rest>
79 Arg first_argument_type(Ret (F::*)(Arg, Rest...));
80 template <typename Ret, typename F, typename Arg, typename... Rest>
81 Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
82 template <typename F>
83 decltype(first_argument_type(&F::operator())) first_argument_type(F);
84 
85 /// Type definition of the first argument to the given callable 'T'.
86 template <typename T>
87 using first_argument = decltype(first_argument_type(std::declval<T>()));
88 
89 using VoidCallback =
90     llvm::function_ref<void(mlir::Operation *, const WalkStage &)>;
91 using InterruptCallback =
92     llvm::function_ref<mlir::WalkResult(mlir::Operation *, const WalkStage &)>;
93 
94 // Walk all of the operations nested under and including the given operation.
95 void WalkOperations(mlir::Operation *op, VoidCallback callback);
96 
97 // Walk all of the operations nested under and including the given operation.
98 // This methods walks operations until an interrupt result is returned by the
99 // callback.
100 mlir::WalkResult WalkOperations(mlir::Operation *op,
101                                 InterruptCallback callback);
102 
103 }  // namespace detail
104 
105 // Walk all of the operations nested under and including the given operation.
106 // This method is selected for stage-aware callbacks that operate on Operation*.
107 //
108 // Example:
109 //   tensorflow::walk(op, [](Operation *op, const WalkStage &stage) { ... });
110 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
111           typename RetT = decltype(std::declval<FuncTy>()(
112               std::declval<ArgT>(), std::declval<const WalkStage &>()))>
113 typename std::enable_if<std::is_same<ArgT, mlir::Operation *>::value,
114                         RetT>::type
GenericWalk(mlir::Operation * op,FuncTy && callback)115 GenericWalk(mlir::Operation *op, FuncTy &&callback) {
116   return detail::WalkOperations(
117       op, llvm::function_ref<RetT(ArgT, const WalkStage &)>(callback));
118 }
119 
120 // Walk all of the operations of type 'ArgT' nested under and including the
121 // given operation. This method is selected for void returning callbacks that
122 // operate on a specific derived operation type.
123 //
124 // Example:
125 //   tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) { ... });
126 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
127           typename RetT = decltype(std::declval<FuncTy>()(
128               std::declval<ArgT>(), std::declval<const WalkStage &>()))>
129 typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
130                             std::is_same<RetT, void>::value,
131                         RetT>::type
GenericWalk(mlir::Operation * op,FuncTy && callback)132 GenericWalk(mlir::Operation *op, FuncTy &&callback) {
133   auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
134     if (auto derivedOp = llvm::dyn_cast<ArgT>(op)) callback(derivedOp, stage);
135   };
136   return detail::WalkOperations(op,
137                                 static_cast<detail::VoidCallback>(wrapperFn));
138 }
139 
140 // Walk all of the operations of type 'ArgT' nested under and including the
141 // given operation. This method is selected for WalkReturn returning
142 // interruptible callbacks that operate on a specific derived operation type.
143 //
144 // Example:
145 //   tensorflow::walk(op, [](ReturnOp op, const WalkStage &stage) {
146 //     if (some_invariant)
147 //       return WalkResult::interrupt();
148 //     return WalkResult::advance();
149 //   });
150 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
151           typename RetT = decltype(std::declval<FuncTy>()(
152               std::declval<ArgT>(), std::declval<const WalkStage &>()))>
153 typename std::enable_if<!std::is_same<ArgT, mlir::Operation *>::value &&
154                             std::is_same<RetT, mlir::WalkResult>::value,
155                         RetT>::type
GenericWalk(mlir::Operation * op,FuncTy && callback)156 GenericWalk(mlir::Operation *op, FuncTy &&callback) {
157   auto wrapperFn = [&](mlir::Operation *op, const WalkStage &stage) {
158     if (auto derivedOp = llvm::dyn_cast<ArgT>(op))
159       return callback(derivedOp, stage);
160     return mlir::WalkResult::advance();
161   };
162   return detail::WalkOperations(
163       op, static_cast<detail::InterruptCallback>(wrapperFn));
164 }
165 
166 }  // namespace tensorflow
167 
168 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_UTIL_H_
169