• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
25 #include "mlir/IR/Value.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32 
33 namespace mlir {
34 namespace tf_executor {
35 namespace {
36 
37 // This transformation pass prunes a TF graph eliminating dead-nodes.
38 class GraphPruningPass
39     : public TF::ExecutorGraphPruningPassBase<GraphPruningPass> {
40  public:
41   GraphPruningPass() = default;
42   explicit GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve);
43   void runOnFunction() override;
44 
45  private:
46   bool ShouldPreserveOp(Operation* op);
47   bool ShouldPreserveIsland(IslandOp island);
48   void PruneGraph(GraphOp graph);
49 
50   llvm::SmallDenseSet<mlir::Identifier, 4> ops_to_preserve_ids_;
51 };
52 
53 // Checks if a tf_executor.Graph can be pruned.
54 // For TensorFlow V1.0 compatibility: when importing a graph without providing
55 // feeds/fetches/targets we should not attempt to prune. The best approximation
56 // here is to check if the graph is of the "main" function and does not have the
57 // "tf.entry_function" attribute defined.
CanPruneGraph(FuncOp func)58 bool CanPruneGraph(FuncOp func) {
59   return func.getName() != "main" ||
60          func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
61 }
62 
63 // Visits an op's operand if it is an output of an Operation in the same
64 // tf_executor.graph.
VisitOpOperand(GraphOp graph,Value operand,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)65 void VisitOpOperand(GraphOp graph, Value operand,
66                     llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
67                     llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
68   Operation* def = operand.getDefiningOp();
69   if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) {
70     // Op has not been visited, add to queue to visit later.
71     ops_to_visit->push_back(def);
72   }
73 }
74 
75 // Visits all operands of an op where each operand is an output of an Operation
76 // in the same tf_executor.graph.
VisitOpOperands(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)77 void VisitOpOperands(GraphOp graph, Operation* op,
78                      llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
79                      llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
80   for (Value operand : op->getOperands())
81     VisitOpOperand(graph, operand, reachable_ops, ops_to_visit);
82 }
83 
84 // Visits an op and it's associated operands. IslandOps are handled differently
85 // where it's regions op operands are also visited as values may be implicitly
86 // captured within. NextIterationSourceOp will also visit it's associated
87 // NextIterationSinkOp.
VisitOp(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)88 void VisitOp(GraphOp graph, Operation* op,
89              llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
90              llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
91   if (auto island = llvm::dyn_cast<IslandOp>(op)) {
92     mlir::visitUsedValuesDefinedAbove(
93         island.body(), island.body(), [&](OpOperand* operand) {
94           VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit);
95         });
96   }
97 
98   VisitOpOperands(graph, op, reachable_ops, ops_to_visit);
99 
100   // If op is a `tf_executor.NextIteration.Source`, visit its associated
101   // `tf_executor.NextIteration.Sink` op.
102   if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
103     Operation* sink_op = source_op.GetSink().getOperation();
104     if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op);
105   }
106 }
107 
GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)108 GraphPruningPass::GraphPruningPass(
109     llvm::ArrayRef<std::string> ops_to_preserve) {
110   ops_to_preserve_ = ops_to_preserve;
111 }
112 
runOnFunction()113 void GraphPruningPass::runOnFunction() {
114   for (const auto& op_name : ops_to_preserve_) {
115     ops_to_preserve_ids_.insert(mlir::Identifier::get(op_name, &getContext()));
116   }
117   if (!CanPruneGraph(getFunction())) return;
118   getFunction().walk([this](tf_executor::GraphOp graph) { PruneGraph(graph); });
119 }
120 
121 // An op should be preserved if its identifier is contained in
122 // `ops_to_preserve_ids_`.
ShouldPreserveOp(Operation * op)123 bool GraphPruningPass::ShouldPreserveOp(Operation* op) {
124   return ops_to_preserve_ids_.contains(op->getName().getIdentifier());
125 }
126 
127 // An island should be preserved if any of its inner ops should be preserved.
ShouldPreserveIsland(IslandOp island)128 bool GraphPruningPass::ShouldPreserveIsland(IslandOp island) {
129   auto result = island.walk([this](Operation* inner_op) {
130     if (ShouldPreserveOp(inner_op)) return WalkResult::interrupt();
131     return WalkResult::advance();
132   });
133   return result.wasInterrupted();
134 }
135 
136 // Prunes unreachable operations of a tf_executor.graph operation.
PruneGraph(GraphOp graph)137 void GraphPruningPass::PruneGraph(GraphOp graph) {
138   // A graph has a single block which forms a DAG: operations that aren't
139   // reachable from the `fetch` operands can be eliminated.
140 
141   llvm::SmallPtrSet<Operation*, 8> reachable_ops;
142   llvm::SmallVector<Operation*, 8> ops_to_visit;
143 
144   // Visit fetches first to create a starting point for ops that are reachable.
145   reachable_ops.insert(graph.GetFetch());
146   VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit);
147 
148   // Find and visit ops that should be preserved regardless of being reachable
149   // from a fetch.
150   for (Operation& op : graph.GetBody().without_terminator()) {
151     auto island = llvm::dyn_cast<IslandOp>(op);
152     if (!island) continue;
153     if (ShouldPreserveIsland(island)) {
154       reachable_ops.insert(&op);
155       VisitOp(graph, &op, &reachable_ops, &ops_to_visit);
156     }
157   }
158 
159   // Visit transitive ops until no there are no reachable ops left that have not
160   // been visited.
161   while (!ops_to_visit.empty()) {
162     Operation* op = ops_to_visit.pop_back_val();
163     VisitOp(graph, op, &reachable_ops, &ops_to_visit);
164   }
165 
166   // Erase unreachable ops in reverse order so references don't need to be
167   // dropped before removing an op. Going in reverse order will guarantee that
168   // when an op to be erased is reached, there are no users left.
169   for (Operation& op :
170        llvm::make_early_inc_range(llvm::reverse(graph.GetBody())))
171     if (!reachable_ops.contains(&op)) op.erase();
172 }
173 
174 }  // namespace
175 
CreateTFExecutorGraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)176 std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass(
177     llvm::ArrayRef<std::string> ops_to_preserve) {
178   return std::make_unique<GraphPruningPass>(ops_to_preserve);
179 }
180 
181 }  // namespace tf_executor
182 }  // namespace mlir
183