• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
16 
17 #include <memory>
18 #include <utility>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/time/clock.h"
23 #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
24 #include "tensorflow/core/common_runtime/graph_constructor.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_def.pb.h"
27 #include "tensorflow/core/framework/versions.pb.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/status.h"
31 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
32 #include "tensorflow/core/util/dump_graph.h"
33 
34 namespace tensorflow {
35 namespace tfrt_stub {
36 
37 StatusOr<std::unique_ptr<TfrtGraphExecutionState>>
Create(tensorflow::GraphDef graph_def,const FallbackState & fallback_state)38 TfrtGraphExecutionState::Create(tensorflow::GraphDef graph_def,
39                                 const FallbackState& fallback_state) {
40   if (VLOG_IS_ON(1)) {
41     DumpGraphDefToFile("create_input_graph_def", graph_def);
42   }
43 
44   TF_RETURN_IF_ERROR(tensorflow::GenerateResourceSharedNameIfEmpty(
45       graph_def, tensorflow::OpRegistry::Global()));
46 
47   if (VLOG_IS_ON(2)) {
48     DumpGraphDefToFile("after_generate_resource_shared_name_graph_def",
49                        graph_def);
50   }
51 
52   // `CreateExecutionState()` will preprocess the graph (e.g., apply Placer).
53   TF_ASSIGN_OR_RETURN(
54       auto graph_execution_state,
55       fallback_state.CreateGraphExecutionState(std::move(graph_def)));
56 
57   return std::make_unique<TfrtGraphExecutionState>(
58       std::move(graph_execution_state));
59 }
60 
61 namespace {
62 
PopulateCallableOptions(CallableOptions & callable_options,const tensorflow::GraphImportConfig & graph_import_config)63 CallableOptions PopulateCallableOptions(
64     CallableOptions& callable_options,
65     const tensorflow::GraphImportConfig& graph_import_config) {
66   // Configure pruning with the feed/fetch/target tensor names.
67   callable_options.mutable_feed()->Reserve(graph_import_config.inputs.size());
68   for (const auto& feed_tensor : graph_import_config.inputs) {
69     callable_options.add_feed(feed_tensor.first);
70   }
71   callable_options.mutable_fetch()->Reserve(graph_import_config.outputs.size());
72   for (const auto& fetch_tensor_name : graph_import_config.outputs) {
73     callable_options.add_fetch(fetch_tensor_name);
74   }
75   callable_options.mutable_target()->Reserve(
76       graph_import_config.control_outputs.size());
77   for (const auto& target_tensor_name : graph_import_config.control_outputs) {
78     callable_options.add_target(target_tensor_name);
79   }
80 
81   return callable_options;
82 }
83 
// Serializes `graph` into a GraphDef and attaches the function library
// `flib_def` to it.
tensorflow::GraphDef CreateGraphDefFromGraphAndFlibDef(
    const tensorflow::Graph& graph,
    const tensorflow::FunctionLibraryDefinition& flib_def) {
  tensorflow::GraphDef result;
  graph.ToGraphDef(&result);
  *result.mutable_library() = flib_def.ToProto();
  return result;
}
92 
93 // Creates a pruned graph from `graph_def` according to `callable_options`.
CreatePrunedGraph(tensorflow::GraphDef graph_def,const CallableOptions & callable_options)94 StatusOr<std::unique_ptr<tensorflow::Graph>> CreatePrunedGraph(
95     tensorflow::GraphDef graph_def, const CallableOptions& callable_options) {
96   VLOG(1) << "Creating pruned graph: " << callable_options.DebugString();
97 
98   // Prune the graph with `callable_options`. Although
99   // grappler has model_pruner stage, it may leave v1 control flows in an
100   // invalid state that cannot be functionalized. So we perform additional
101   // pruning before functionalization.
102   TF_RETURN_IF_ERROR(PruneGraphDef(graph_def, callable_options));
103 
104   if (VLOG_IS_ON(2)) {
105     DumpGraphDefToFile("before_eliminate_ref_variables_graph_def", graph_def);
106   }
107 
108   TF_RETURN_IF_ERROR(EliminateRefVariablesFromV1ControlFlow(graph_def));
109 
110   auto pruned_graph =
111       std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
112   tensorflow::GraphConstructorOptions options;
113   options.allow_internal_ops = true;
114   options.add_default_attributes = true;
115   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, std::move(graph_def),
116                                             pruned_graph.get()));
117   return pruned_graph;
118 }
119 
120 // Creates a new identity node to replace an operand of a given `node`.
CreateNewIdentityNode(const NodeDef & node,const std::string & input_name,const std::string & identity_name)121 NodeDef CreateNewIdentityNode(const NodeDef& node,
122                               const std::string& input_name,
123                               const std::string& identity_name) {
124   NodeDef identity;
125   identity.set_name(identity_name);
126   identity.set_op("Identity");
127   identity.add_input(input_name);
128   identity.set_device(node.device());
129   for (const auto& name_and_attr : node.attr()) {
130     if (name_and_attr.first == "T") {
131       identity.mutable_attr()->insert(name_and_attr);
132       break;
133     }
134   }
135   return identity;
136 }
137 
138 }  // namespace
139 
140 StatusOr<TfrtGraphExecutionState::OptimizationResult>
CreateOptimizedGraph(const tensorflow::GraphImportConfig & graph_import_config)141 TfrtGraphExecutionState::CreateOptimizedGraph(
142     const tensorflow::GraphImportConfig& graph_import_config) {
143   OptimizationResult result;
144 
145   tensorflow::BuildGraphOptions build_graph_options;
146   PopulateCallableOptions(build_graph_options.callable_options,
147                           graph_import_config);
148 
149   auto graph_def = CreateGraphDefFromGraphAndFlibDef(graph(), flib_def());
150 
151   if (VLOG_IS_ON(1)) {
152     DumpGraphDefToFile("before_pruning", graph_def);
153   }
154 
155   TF_ASSIGN_OR_RETURN(
156       result.graph,
157       CreatePrunedGraph(graph_def, build_graph_options.callable_options));
158   DCHECK(result.graph);
159 
160   if (VLOG_IS_ON(1)) {
161     DumpGraphToFile("after_pruning", *result.graph);
162   }
163 
164   const auto functionalization_start_time = absl::Now();
165 
166   // Perform functionalization to convert v1 control flow to v2 control flow. It
167   // should be applied to the unoptimized graph, because Grappler may cause
168   // unfunctionalizablity.
169   TF_RETURN_IF_ERROR(tensorflow::UpgradeLegacyGraph(
170       result.graph.get(),
171       const_cast<tensorflow::FunctionLibraryDefinition*>(
172           &result.graph->flib_def()),
173       /*restrict_functionalization_to_tpu_nodes=*/false));
174 
175   if (VLOG_IS_ON(1)) {
176     DumpGraphToFile("after_functionalization", *result.graph);
177   }
178 
179   auto grappler_start_time = absl::Now();
180   result.functionalization_duration =
181       grappler_start_time - functionalization_start_time;
182 
183   TF_RETURN_IF_ERROR(OptimizeGraph(result.graph, build_graph_options));
184 
185   if (VLOG_IS_ON(1)) {
186     DumpGraphToFile("after_grappler", *result.graph);
187   }
188 
189   result.grappler_duration = absl::Now() - grappler_start_time;
190 
191   return result;
192 }
193 
194 namespace {
195 
196 // Given an "Exit" node, finds its corresponding "LoopCond" node.
FindLoopCondFromExitNode(const NodeDef & exit_node,const absl::flat_hash_map<std::string,NodeDef * > & name_to_node)197 StatusOr<const NodeDef*> FindLoopCondFromExitNode(
198     const NodeDef& exit_node,
199     const absl::flat_hash_map<std::string, NodeDef*>& name_to_node) {
200   const NodeDef* switch_node = nullptr;
201   for (const std::string& tensor_name : exit_node.input()) {
202     const std::string node_name = grappler::NodeName(tensor_name);
203     if (!name_to_node.contains(node_name)) {
204       return errors::InvalidArgument("Graph does not contain input ", node_name,
205                                      " of exit node ", exit_node.name());
206     }
207     const NodeDef* node = name_to_node.at(node_name);
208     if (node->op() == "Switch") {
209       switch_node = node;
210       break;
211     }
212   }
213   if (switch_node == nullptr) {
214     return errors::InvalidArgument("Exit node ", exit_node.name(),
215                                    " does not have a Switch node as its ",
216                                    "predecessor.");
217   }
218   for (const std::string& tensor_name : switch_node->input()) {
219     const std::string node_name = grappler::NodeName(tensor_name);
220     if (!name_to_node.contains(node_name)) {
221       return errors::InvalidArgument("Graph does not contain input ", node_name,
222                                      " of switch node ", switch_node->name());
223     }
224 
225     const NodeDef* node = name_to_node.at(node_name);
226     if (node->op() == "LoopCond") {
227       return node;
228     }
229   }
230 
231   return errors::InvalidArgument("Switch node ", switch_node->name(),
232                                  " does not have a LoopCond node as its ",
233                                  "predecessor.");
234 }
235 
236 }  // namespace
237 
// Prunes `graph_def` in place so that it contains only the nodes reachable
// (via data or control inputs) from the feed/fetch/target names in
// `callable_options`. V1 while loops are kept structurally complete: whenever
// a LoopCond node is kept, all of its Exit nodes are kept as well. A fetched
// "Exit" node is renamed and fronted by an Identity carrying the original
// name, so later functionalization cannot remove the fetch.
//
// Returns InvalidArgument if the graph contains _Send/_Recv ops, or if any
// name in `callable_options` (or any transitive input) is missing from the
// graph.
Status PruneGraphDef(GraphDef& graph_def,
                     const CallableOptions& callable_options) {
  // Gather node names and create a map from names to NodeDefs.
  absl::flat_hash_map<std::string, NodeDef*> name_to_node;
  // All exit nodes in order to track all while loops.
  absl::flat_hash_set<const NodeDef*> exit_nodes;
  for (auto& node : *graph_def.mutable_node()) {
    name_to_node[node.name()] = &node;
    if (node.op() == "Exit") {
      exit_nodes.insert(&node);
    }

    // TODO(tfrt-devs): Add support for _Send and _Recv ops.
    if (node.op() == "_Send" || node.op() == "_Recv") {
      return errors::InvalidArgument(
          "TFRT prune graphdef cannot handle graphs contains _Send and _Recv "
          "ops.");
    }
  }

  // Find all LoopCond -> Exit nodes mapping. So when we traverse to a LoopCond
  // node, we can add corresponding Exit nodes to the traversal queue in order
  // to maintain complete structure of a while loop.
  absl::flat_hash_map<const NodeDef*, absl::flat_hash_set<const NodeDef*>>
      loop_cond_to_exit_nodes;
  for (const NodeDef* exit_node : exit_nodes) {
    TF_ASSIGN_OR_RETURN(const NodeDef* loop_cond_node,
                        FindLoopCondFromExitNode(*exit_node, name_to_node));
    loop_cond_to_exit_nodes[loop_cond_node].insert(exit_node);
  }

  // `queue` is for candidate nodes we want to visit in the graph.
  std::vector<const NodeDef*> queue;

  // Add fetch nodes to the queue.
  absl::flat_hash_set<std::string> fetch_node_names;
  for (const std::string& tensor_name : callable_options.fetch()) {
    const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain fetch node ",
                                     tensor_name, ".");
    }
    queue.push_back(node);
    fetch_node_names.insert(node->name());
  }

  // Add control target nodes to the queue. Targets are treated like fetches
  // for the purpose of the Exit-node rewrite below.
  for (const std::string& tensor_name : callable_options.target()) {
    const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain target node ",
                                     tensor_name, ".");
    }
    queue.push_back(node);
    fetch_node_names.insert(node->name());
  }

  absl::flat_hash_set<NodeDef*> feed_node_defs;

  // Add feed nodes to the queue. In addition, perform necessary rewrites to
  // remove unnecessary input edges.
  for (const std::string& tensor_name : callable_options.feed()) {
    NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain feed node ",
                                     tensor_name, ".");
    }

    // If a feed node is a Const, we don't need its inputs at all.
    //
    // TODO(tfrt-devs): Consider a general solution that we could just rewrite
    // all feed nodes to Placeholde nodes.
    if (node->op() == "Const") {
      node->clear_input();
    }

    queue.push_back(node);
    feed_node_defs.insert(node);
  }

  absl::flat_hash_set<const NodeDef*> visited;
  // Copies of the nodes that survive pruning.
  std::vector<NodeDef> keep;

  // Perform graph traversal to find out connected nodes from fetches.
  while (!queue.empty()) {
    const NodeDef* node = queue.back();
    queue.pop_back();

    if (!visited.insert(node).second) {
      // Already visited this node; skip it.
      continue;
    }

    keep.push_back(*node);
    // Visiting a LoopCond pulls in all of its Exit nodes to keep the while
    // loop structurally complete (see `loop_cond_to_exit_nodes` above).
    if (node->op() == "LoopCond") {
      for (const NodeDef* exit_node : loop_cond_to_exit_nodes[node]) {
        queue.push_back(exit_node);
      }
    }

    // Enqueue every input (data and control) of the node.
    for (const std::string& tensor_name : node->input()) {
      const NodeDef* in = name_to_node[grappler::NodeName(tensor_name)];
      if (!in) {
        return errors::InvalidArgument("Graph does not contain input ",
                                       grappler::NodeName(tensor_name),
                                       " of node ", node->name(), ".");
      }
      queue.push_back(in);
    }
  }

  // Rebuild the node list from the kept nodes.
  graph_def.clear_node();
  for (auto& node : keep) {
    if (fetch_node_names.contains(node.name())) {
      // If the fetch node is an Exit op, we insert an Identity op right after
      // it and rename it to be the new fetch node. This is to prevent
      // functionalization from removing the fetch nodes.
      if (node.op() == "Exit") {
        auto renamed_exit_node = node;
        renamed_exit_node.set_name(
            absl::StrCat(renamed_exit_node.name(), "/tfrt_renamed"));
        node.set_op("Identity");
        *node.mutable_input(0) = renamed_exit_node.name();
        *graph_def.add_node() = std::move(renamed_exit_node);
      }
    }

    *graph_def.add_node() = std::move(node);
  }

  return Status::OK();
}
369 
// Rewrites `graph_def` in place so that v1 control flow no longer uses ref
// ops: each "RefEnter"/"RefSwitch" is renamed to "Enter"/"Switch", and an
// Identity node is inserted between it and its first input (breaking the
// direct edge from the ref-producing node).
//
// Returns InvalidArgument if a RefEnter/RefSwitch has an unexpected input
// count, and Unimplemented if some other node consuming a rewritten ref
// node's output requires a ref input (such a consumer would break after the
// rewrite).
Status EliminateRefVariablesFromV1ControlFlow(tensorflow::GraphDef& graph_def) {
  auto* op_factory = OpRegistry::Global();

  // Names of the nodes whose op type will be rewritten below.
  absl::flat_hash_set<std::string> ref_nodes;
  for (const auto& node : graph_def.node()) {
    if (node.op() == "RefEnter" || node.op() == "RefSwitch") {
      ref_nodes.insert(node.name());
    }
  }

  tensorflow::GraphDef updated_graph_def;
  // Identity nodes created so far, so that several ref ops sharing the same
  // input get a single shared Identity instead of duplicates.
  absl::flat_hash_set<std::string> new_identities;
  // Insert an identity node between each "RefEnter" or "RefSwitch" node and its
  // ref input. Then modify each "RefEnter"/"RefSwitch" node in-place to an
  // "Enter"/"Switch" node.
  for (auto& node : *graph_def.mutable_node()) {
    // First find the ref input name to this RefEnter or RefSwitch.
    // `ref_input_name` aliases node.input(0) when set.
    std::string* ref_input_name = nullptr;
    if (node.op() == "RefEnter") {
      node.set_op("Enter");
      if (node.input_size() != 1) {
        return errors::InvalidArgument("RefEnter node ", node.name(),
                                       " does not have exactly 1 input.");
      }
      ref_input_name = node.mutable_input(0);
    } else if (node.op() == "RefSwitch") {
      node.set_op("Switch");
      if (node.input_size() != 2) {
        return errors::InvalidArgument("RefSwitch node", node.name(),
                                       " does not have exactly 2 inputs.");
      }
      // Only input 0 is rewritten through an Identity; input 1 is left as is.
      ref_input_name = node.mutable_input(0);
    } else {
      // For other ops, check if their inputs are the ref ops we want to
      // eliminate, and if so, these ops must not require their inputs to be
      // refs.
      std::string ref_input;
      for (const auto& tensor_name : node.input()) {
        std::string input = grappler::NodeName(tensor_name);
        if (ref_nodes.contains(input)) {
          ref_input = std::move(input);
          break;
        }
      }
      if (!ref_input.empty()) {
        const OpDef* op_def;
        TF_RETURN_IF_ERROR(op_factory->LookUpOpDef(node.op(), &op_def));
        // TODO(tfrt-devs): How to match input_args to input names in NodeDef?
        for (const auto& input_arg : op_def->input_arg()) {
          if (input_arg.is_ref()) {
            return errors::Unimplemented(
                "Cannot in-place update ref node ", ref_input,
                " to the non-ref counterpart since its user node ", node.name(),
                " requires its input to be refs.");
          }
        }
      }
    }

    if (ref_input_name != nullptr) {
      std::string identity_name =
          absl::StrCat(grappler::NodeName(*ref_input_name), "/identity");
      if (!new_identities.contains(identity_name)) {
        *updated_graph_def.add_node() =
            CreateNewIdentityNode(node, *ref_input_name, identity_name);
        new_identities.insert(identity_name);
      }
      // Redirect this node's input to the Identity node.
      *ref_input_name = std::move(identity_name);
    }

    *updated_graph_def.add_node() = std::move(node);
  }

  // Replace the original node list with the rewritten one.
  graph_def.mutable_node()->Swap(updated_graph_def.mutable_node());
  return Status::OK();
}
446 
OptimizeGraph(std::unique_ptr<tensorflow::Graph> & graph,const tensorflow::BuildGraphOptions & build_graph_options)447 Status TfrtGraphExecutionState::OptimizeGraph(
448     std::unique_ptr<tensorflow::Graph>& graph,
449     const tensorflow::BuildGraphOptions& build_graph_options) {
450   std::unique_ptr<tensorflow::Graph> optimized_graph;
451   std::unique_ptr<tensorflow::FunctionLibraryDefinition> optimized_flib;
452 
453   // Invoke Grappler to optimize the graph.
454   auto status = graph_execution_state_->OptimizeGraph(
455       build_graph_options, *graph, &graph->flib_def(), &optimized_graph,
456       &optimized_flib);
457 
458   if (!status.ok()) {
459     LOG(WARNING) << "TFRT failed to optimize graph: " << status;
460     return tensorflow::Status::OK();
461   }
462 
463   TF_RETURN_IF_ERROR(
464       optimized_graph->AddFunctionLibrary(optimized_flib->ToProto()));
465   graph = std::move(optimized_graph);
466   return tensorflow::Status::OK();
467 }
468 
469 }  // namespace tfrt_stub
470 }  // namespace tensorflow
471