• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/iterator_range.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
25 #include "mlir/IR/Value.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
33 
34 namespace mlir {
35 namespace tf_executor {
36 namespace {
37 
38 // This transformation pass prunes a TF graph eliminating dead-nodes.
39 class GraphPruningPass
40     : public TF::ExecutorGraphPruningPassBase<GraphPruningPass> {
41  public:
42   GraphPruningPass() = default;
43   explicit GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve);
44   void runOnOperation() override;
45 
46  private:
47   bool ShouldPreserveOp(Operation* op);
48   bool ShouldPreserveIsland(IslandOp island);
49   void PruneGraph(GraphOp graph);
50 
51   llvm::SmallDenseSet<mlir::StringAttr, 4> ops_to_preserve_ids_;
52 };
53 
54 // Checks if a tf_executor.Graph can be pruned.
55 // For TensorFlow V1.0 compatibility: when importing a graph without providing
56 // feeds/fetches/targets we should not attempt to prune. The best approximation
57 // here is to check if the graph is of the "main" function and does not have the
58 // "tf.entry_function" attribute defined.
CanPruneGraph(func::FuncOp func)59 bool CanPruneGraph(func::FuncOp func) {
60   return func.getName() != "main" ||
61          func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
62 }
63 
64 // Visits an op's operand if it is an output of an Operation in the same
65 // tf_executor.graph.
VisitOpOperand(GraphOp graph,Value operand,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)66 void VisitOpOperand(GraphOp graph, Value operand,
67                     llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
68                     llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
69   Operation* def = operand.getDefiningOp();
70   if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) {
71     // Op has not been visited, add to queue to visit later.
72     ops_to_visit->push_back(def);
73   }
74 }
75 
76 // Visits all operands of an op where each operand is an output of an Operation
77 // in the same tf_executor.graph.
VisitOpOperands(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)78 void VisitOpOperands(GraphOp graph, Operation* op,
79                      llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
80                      llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
81   for (Value operand : op->getOperands())
82     VisitOpOperand(graph, operand, reachable_ops, ops_to_visit);
83 }
84 
85 // Visits an op and it's associated operands. IslandOps are handled differently
86 // where it's regions op operands are also visited as values may be implicitly
87 // captured within. NextIterationSourceOp will also visit it's associated
88 // NextIterationSinkOp.
VisitOp(GraphOp graph,Operation * op,llvm::SmallPtrSetImpl<Operation * > * reachable_ops,llvm::SmallVectorImpl<Operation * > * ops_to_visit)89 void VisitOp(GraphOp graph, Operation* op,
90              llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
91              llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
92   if (auto island = llvm::dyn_cast<IslandOp>(op)) {
93     mlir::visitUsedValuesDefinedAbove(
94         island.body(), island.body(), [&](OpOperand* operand) {
95           VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit);
96         });
97   }
98 
99   VisitOpOperands(graph, op, reachable_ops, ops_to_visit);
100 
101   // If op is a `tf_executor.NextIteration.Source`, visit its associated
102   // `tf_executor.NextIteration.Sink` op.
103   if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
104     Operation* sink_op = source_op.GetSink().getOperation();
105     if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op);
106   }
107 }
108 
GraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)109 GraphPruningPass::GraphPruningPass(
110     llvm::ArrayRef<std::string> ops_to_preserve) {
111   ops_to_preserve_ = ops_to_preserve;
112 }
113 
runOnOperation()114 void GraphPruningPass::runOnOperation() {
115   for (const auto& op_name : ops_to_preserve_) {
116     ops_to_preserve_ids_.insert(mlir::StringAttr::get(&getContext(), op_name));
117   }
118   if (!CanPruneGraph(getOperation())) return;
119   getOperation().walk(
120       [this](tf_executor::GraphOp graph) { PruneGraph(graph); });
121 }
122 
123 // An op should be preserved if either its identifier is contained in
124 // `ops_to_preserve_ids_` or if it has a `MustExecute` effect.
ShouldPreserveOp(Operation * op)125 bool GraphPruningPass::ShouldPreserveOp(Operation* op) {
126   if (ops_to_preserve_ids_.contains(op->getName().getIdentifier())) return true;
127 
128   llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
129   auto interface = dyn_cast<MemoryEffectOpInterface>(op);
130   if (interface) interface.getEffects(effects);
131 
132   for (const auto& effect : effects) {
133     if (llvm::isa<TF::ResourceEffects::MustExecute>(effect.getResource())) {
134       return true;
135     }
136   }
137   return false;
138 }
139 
140 // An island should be preserved if any of its inner ops should be preserved.
ShouldPreserveIsland(IslandOp island)141 bool GraphPruningPass::ShouldPreserveIsland(IslandOp island) {
142   auto result = island.walk([this](Operation* inner_op) {
143     if (ShouldPreserveOp(inner_op)) return WalkResult::interrupt();
144     return WalkResult::advance();
145   });
146   return result.wasInterrupted();
147 }
148 
149 // Prunes unreachable operations of a tf_executor.graph operation.
PruneGraph(GraphOp graph)150 void GraphPruningPass::PruneGraph(GraphOp graph) {
151   // A graph has a single block which forms a DAG: operations that aren't
152   // reachable from the `fetch` operands can be eliminated.
153 
154   llvm::SmallPtrSet<Operation*, 8> reachable_ops;
155   llvm::SmallVector<Operation*, 8> ops_to_visit;
156 
157   // Visit fetches first to create a starting point for ops that are reachable.
158   reachable_ops.insert(graph.GetFetch());
159   VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit);
160 
161   // Find and visit ops that should be preserved regardless of being reachable
162   // from a fetch.
163   for (Operation& op : graph.GetBody().without_terminator()) {
164     auto island = llvm::dyn_cast<IslandOp>(op);
165     if (!island) continue;
166     if (ShouldPreserveIsland(island)) {
167       reachable_ops.insert(&op);
168       VisitOp(graph, &op, &reachable_ops, &ops_to_visit);
169     }
170   }
171 
172   // Visit transitive ops until no there are no reachable ops left that have not
173   // been visited.
174   while (!ops_to_visit.empty()) {
175     Operation* op = ops_to_visit.pop_back_val();
176     VisitOp(graph, op, &reachable_ops, &ops_to_visit);
177   }
178 
179   // Erase unreachable ops in reverse order so references don't need to be
180   // dropped before removing an op. Going in reverse order will guarantee that
181   // when an op to be erased is reached, there are no users left.
182   for (Operation& op :
183        llvm::make_early_inc_range(llvm::reverse(graph.GetBody())))
184     if (!reachable_ops.contains(&op)) op.erase();
185 }
186 
187 }  // namespace
188 
CreateTFExecutorGraphPruningPass(llvm::ArrayRef<std::string> ops_to_preserve)189 std::unique_ptr<OperationPass<func::FuncOp>> CreateTFExecutorGraphPruningPass(
190     llvm::ArrayRef<std::string> ops_to_preserve) {
191   return std::make_unique<GraphPruningPass>(ops_to_preserve);
192 }
193 
194 }  // namespace tf_executor
195 }  // namespace mlir
196