1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/IR/UseDefLists.h" // from @llvm-project
25 #include "mlir/IR/Value.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32
33 namespace mlir {
34 namespace tf_executor {
35 namespace {
36
37 // This transformation pass prunes a TF graph eliminating dead-nodes.
38 class GraphPruningPass
39 : public TF::ExecutorGraphPruningPassBase<GraphPruningPass> {
40 public:
41 GraphPruningPass() = default;
42 explicit GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve);
43 void runOnFunction() override;
44
45 private:
46 bool ShouldPreserveOp(Operation* op);
47 bool ShouldPreserveIsland(IslandOp island);
48 void PruneGraph(GraphOp graph);
49
50 llvm::SmallDenseSet<mlir::Identifier, 4> ops_to_preserve_ids_;
51 };
52
53 // Checks if a tf_executor.Graph can be pruned.
54 // For TensorFlow V1.0 compatibility: when importing a graph without providing
55 // feeds/fetches/targets we should not attempt to prune. The best approximation
56 // here is to check if the graph is of the "main" function and does not have the
57 // "tf.entry_function" attribute defined.
CanPruneGraph(FuncOp func)58 bool CanPruneGraph(FuncOp func) {
59 return func.getName() != "main" ||
60 func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
61 }
62
63 // Visits an op's operand if it is an output of an Operation in the same
64 // tf_executor.graph.
VisitOpOperand(GraphOp graph,Value operand,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)65 void VisitOpOperand(GraphOp graph, Value operand,
66 llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
67 llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
68 Operation* def = operand.getDefiningOp();
69 if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) {
70 // Op has not been visited, add to queue to visit later.
71 ops_to_visit->push_back(def);
72 }
73 }
74
75 // Visits all operands of an op where each operand is an output of an Operation
76 // in the same tf_executor.graph.
VisitOpOperands(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)77 void VisitOpOperands(GraphOp graph, Operation* op,
78 llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
79 llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
80 for (Value operand : op->getOperands())
81 VisitOpOperand(graph, operand, reachable_ops, ops_to_visit);
82 }
83
84 // Visits an op and it's associated operands. IslandOps are handled differently
85 // where it's regions op operands are also visited as values may be implicitly
86 // captured within. NextIterationSourceOp will also visit it's associated
87 // NextIterationSinkOp.
VisitOp(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)88 void VisitOp(GraphOp graph, Operation* op,
89 llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
90 llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
91 if (auto island = llvm::dyn_cast<IslandOp>(op)) {
92 mlir::visitUsedValuesDefinedAbove(
93 island.body(), island.body(), [&](OpOperand* operand) {
94 VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit);
95 });
96 }
97
98 VisitOpOperands(graph, op, reachable_ops, ops_to_visit);
99
100 // If op is a `tf_executor.NextIteration.Source`, visit its associated
101 // `tf_executor.NextIteration.Sink` op.
102 if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
103 Operation* sink_op = source_op.GetSink().getOperation();
104 if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op);
105 }
106 }
107
GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)108 GraphPruningPass::GraphPruningPass(
109 llvm::ArrayRef<std::string> ops_to_preserve) {
110 ops_to_preserve_ = ops_to_preserve;
111 }
112
runOnFunction()113 void GraphPruningPass::runOnFunction() {
114 for (const auto& op_name : ops_to_preserve_) {
115 ops_to_preserve_ids_.insert(mlir::Identifier::get(op_name, &getContext()));
116 }
117 if (!CanPruneGraph(getFunction())) return;
118 getFunction().walk([this](tf_executor::GraphOp graph) { PruneGraph(graph); });
119 }
120
121 // An op should be preserved if its identifier is contained in
122 // `ops_to_preserve_ids_`.
ShouldPreserveOp(Operation * op)123 bool GraphPruningPass::ShouldPreserveOp(Operation* op) {
124 return ops_to_preserve_ids_.contains(op->getName().getIdentifier());
125 }
126
127 // An island should be preserved if any of its inner ops should be preserved.
ShouldPreserveIsland(IslandOp island)128 bool GraphPruningPass::ShouldPreserveIsland(IslandOp island) {
129 auto result = island.walk([this](Operation* inner_op) {
130 if (ShouldPreserveOp(inner_op)) return WalkResult::interrupt();
131 return WalkResult::advance();
132 });
133 return result.wasInterrupted();
134 }
135
136 // Prunes unreachable operations of a tf_executor.graph operation.
PruneGraph(GraphOp graph)137 void GraphPruningPass::PruneGraph(GraphOp graph) {
138 // A graph has a single block which forms a DAG: operations that aren't
139 // reachable from the `fetch` operands can be eliminated.
140
141 llvm::SmallPtrSet<Operation*, 8> reachable_ops;
142 llvm::SmallVector<Operation*, 8> ops_to_visit;
143
144 // Visit fetches first to create a starting point for ops that are reachable.
145 reachable_ops.insert(graph.GetFetch());
146 VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit);
147
148 // Find and visit ops that should be preserved regardless of being reachable
149 // from a fetch.
150 for (Operation& op : graph.GetBody().without_terminator()) {
151 auto island = llvm::dyn_cast<IslandOp>(op);
152 if (!island) continue;
153 if (ShouldPreserveIsland(island)) {
154 reachable_ops.insert(&op);
155 VisitOp(graph, &op, &reachable_ops, &ops_to_visit);
156 }
157 }
158
159 // Visit transitive ops until no there are no reachable ops left that have not
160 // been visited.
161 while (!ops_to_visit.empty()) {
162 Operation* op = ops_to_visit.pop_back_val();
163 VisitOp(graph, op, &reachable_ops, &ops_to_visit);
164 }
165
166 // Erase unreachable ops in reverse order so references don't need to be
167 // dropped before removing an op. Going in reverse order will guarantee that
168 // when an op to be erased is reached, there are no users left.
169 for (Operation& op :
170 llvm::make_early_inc_range(llvm::reverse(graph.GetBody())))
171 if (!reachable_ops.contains(&op)) op.erase();
172 }
173
174 } // namespace
175
CreateTFExecutorGraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)176 std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass(
177 llvm::ArrayRef<std::string> ops_to_preserve) {
178 return std::make_unique<GraphPruningPass>(ops_to_preserve);
179 }
180
181 } // namespace tf_executor
182 } // namespace mlir
183