1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/IR/UseDefLists.h" // from @llvm-project
25 #include "mlir/IR/Value.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33
34 namespace mlir {
35 namespace tf_executor {
36 namespace {
37
38 // This transformation pass prunes a TF graph eliminating dead-nodes.
39 class GraphPruningPass
40 : public TF::ExecutorGraphPruningPassBase<GraphPruningPass> {
41 public:
42 GraphPruningPass() = default;
43 explicit GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve);
44 void runOnOperation() override;
45
46 private:
47 bool ShouldPreserveOp(Operation* op);
48 bool ShouldPreserveIsland(IslandOp island);
49 void PruneGraph(GraphOp graph);
50
51 llvm::SmallDenseSet<mlir::StringAttr, 4> ops_to_preserve_ids_;
52 };
53
54 // Checks if a tf_executor.Graph can be pruned.
55 // For TensorFlow V1.0 compatibility: when importing a graph without providing
56 // feeds/fetches/targets we should not attempt to prune. The best approximation
57 // here is to check if the graph is of the "main" function and does not have the
58 // "tf.entry_function" attribute defined.
CanPruneGraph(func::FuncOp func)59 bool CanPruneGraph(func::FuncOp func) {
60 return func.getName() != "main" ||
61 func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
62 }
63
64 // Visits an op's operand if it is an output of an Operation in the same
65 // tf_executor.graph.
VisitOpOperand(GraphOp graph,Value operand,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)66 void VisitOpOperand(GraphOp graph, Value operand,
67 llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
68 llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
69 Operation* def = operand.getDefiningOp();
70 if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) {
71 // Op has not been visited, add to queue to visit later.
72 ops_to_visit->push_back(def);
73 }
74 }
75
76 // Visits all operands of an op where each operand is an output of an Operation
77 // in the same tf_executor.graph.
VisitOpOperands(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)78 void VisitOpOperands(GraphOp graph, Operation* op,
79 llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
80 llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
81 for (Value operand : op->getOperands())
82 VisitOpOperand(graph, operand, reachable_ops, ops_to_visit);
83 }
84
85 // Visits an op and it's associated operands. IslandOps are handled differently
86 // where it's regions op operands are also visited as values may be implicitly
87 // captured within. NextIterationSourceOp will also visit it's associated
88 // NextIterationSinkOp.
VisitOp(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)89 void VisitOp(GraphOp graph, Operation* op,
90 llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
91 llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
92 if (auto island = llvm::dyn_cast<IslandOp>(op)) {
93 mlir::visitUsedValuesDefinedAbove(
94 island.body(), island.body(), [&](OpOperand* operand) {
95 VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit);
96 });
97 }
98
99 VisitOpOperands(graph, op, reachable_ops, ops_to_visit);
100
101 // If op is a `tf_executor.NextIteration.Source`, visit its associated
102 // `tf_executor.NextIteration.Sink` op.
103 if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
104 Operation* sink_op = source_op.GetSink().getOperation();
105 if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op);
106 }
107 }
108
GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)109 GraphPruningPass::GraphPruningPass(
110 llvm::ArrayRef<std::string> ops_to_preserve) {
111 ops_to_preserve_ = ops_to_preserve;
112 }
113
runOnOperation()114 void GraphPruningPass::runOnOperation() {
115 for (const auto& op_name : ops_to_preserve_) {
116 ops_to_preserve_ids_.insert(mlir::StringAttr::get(&getContext(), op_name));
117 }
118 if (!CanPruneGraph(getOperation())) return;
119 getOperation().walk(
120 [this](tf_executor::GraphOp graph) { PruneGraph(graph); });
121 }
122
123 // An op should be preserved if either its identifier is contained in
124 // `ops_to_preserve_ids_` or if it has a `MustExecute` effect.
ShouldPreserveOp(Operation * op)125 bool GraphPruningPass::ShouldPreserveOp(Operation* op) {
126 if (ops_to_preserve_ids_.contains(op->getName().getIdentifier())) return true;
127
128 llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
129 auto interface = dyn_cast<MemoryEffectOpInterface>(op);
130 if (interface) interface.getEffects(effects);
131
132 for (const auto& effect : effects) {
133 if (llvm::isa<TF::ResourceEffects::MustExecute>(effect.getResource())) {
134 return true;
135 }
136 }
137 return false;
138 }
139
140 // An island should be preserved if any of its inner ops should be preserved.
ShouldPreserveIsland(IslandOp island)141 bool GraphPruningPass::ShouldPreserveIsland(IslandOp island) {
142 auto result = island.walk([this](Operation* inner_op) {
143 if (ShouldPreserveOp(inner_op)) return WalkResult::interrupt();
144 return WalkResult::advance();
145 });
146 return result.wasInterrupted();
147 }
148
149 // Prunes unreachable operations of a tf_executor.graph operation.
PruneGraph(GraphOp graph)150 void GraphPruningPass::PruneGraph(GraphOp graph) {
151 // A graph has a single block which forms a DAG: operations that aren't
152 // reachable from the `fetch` operands can be eliminated.
153
154 llvm::SmallPtrSet<Operation*, 8> reachable_ops;
155 llvm::SmallVector<Operation*, 8> ops_to_visit;
156
157 // Visit fetches first to create a starting point for ops that are reachable.
158 reachable_ops.insert(graph.GetFetch());
159 VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit);
160
161 // Find and visit ops that should be preserved regardless of being reachable
162 // from a fetch.
163 for (Operation& op : graph.GetBody().without_terminator()) {
164 auto island = llvm::dyn_cast<IslandOp>(op);
165 if (!island) continue;
166 if (ShouldPreserveIsland(island)) {
167 reachable_ops.insert(&op);
168 VisitOp(graph, &op, &reachable_ops, &ops_to_visit);
169 }
170 }
171
172 // Visit transitive ops until no there are no reachable ops left that have not
173 // been visited.
174 while (!ops_to_visit.empty()) {
175 Operation* op = ops_to_visit.pop_back_val();
176 VisitOp(graph, op, &reachable_ops, &ops_to_visit);
177 }
178
179 // Erase unreachable ops in reverse order so references don't need to be
180 // dropped before removing an op. Going in reverse order will guarantee that
181 // when an op to be erased is reached, there are no users left.
182 for (Operation& op :
183 llvm::make_early_inc_range(llvm::reverse(graph.GetBody())))
184 if (!reachable_ops.contains(&op)) op.erase();
185 }
186
187 } // namespace
188
CreateTFExecutorGraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)189 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFExecutorGraphPruningPass(
190 llvm::ArrayRef<std::string> ops_to_preserve) {
191 return std::make_unique<GraphPruningPass>(ops_to_preserve);
192 }
193
194 } // namespace tf_executor
195 } // namespace mlir
196