/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace tfrt_stub {

StatusOr<std::unique_ptr<TfrtGraphExecutionState>>
TfrtGraphExecutionState::Create(tensorflow::GraphDef graph_def,
                                const FallbackState& fallback_state) {
  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile("create_input_graph_def", graph_def);
  }

  TF_RETURN_IF_ERROR(tensorflow::GenerateResourceSharedNameIfEmpty(
      graph_def, tensorflow::OpRegistry::Global()));

  if (VLOG_IS_ON(2)) {
    DumpGraphDefToFile("after_generate_resource_shared_name_graph_def",
                       graph_def);
  }

  // `CreateGraphExecutionState()` will preprocess the graph (e.g., apply
  // Placer).
  TF_ASSIGN_OR_RETURN(
      auto graph_execution_state,
      fallback_state.CreateGraphExecutionState(std::move(graph_def)));

  return std::make_unique<TfrtGraphExecutionState>(
      std::move(graph_execution_state));
}

namespace {

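// Copies the feed, fetch, and control-output (target) tensor names from
// `graph_import_config` into `callable_options` and returns the populated
// options.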
CallableOptions PopulateCallableOptions(
    CallableOptions& callable_options,
    const tensorflow::GraphImportConfig& graph_import_config) {
  // Configure pruning with the feed/fetch/target tensor names.
  callable_options.mutable_feed()->Reserve(graph_import_config.inputs.size());
  for (const auto& feed_tensor : graph_import_config.inputs) {
    callable_options.add_feed(feed_tensor.first);
  }
  callable_options.mutable_fetch()->Reserve(graph_import_config.outputs.size());
  for (const auto& fetch_tensor_name : graph_import_config.outputs) {
    callable_options.add_fetch(fetch_tensor_name);
  }
  callable_options.mutable_target()->Reserve(
      graph_import_config.control_outputs.size());
  for (const auto& target_tensor_name : graph_import_config.control_outputs) {
    callable_options.add_target(target_tensor_name);
  }

  return callable_options;
}
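// Serializes `graph` into a GraphDef and attaches `flib_def` as its function
// library.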
tensorflow::GraphDef CreateGraphDefFromGraphAndFlibDef(
    const tensorflow::Graph& graph,
    const tensorflow::FunctionLibraryDefinition& flib_def) {
  tensorflow::GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  *graph_def.mutable_library() = flib_def.ToProto();
  return graph_def;
}

// Creates a pruned graph from `graph_def` according to `callable_options`.
StatusOr<std::unique_ptr<tensorflow::Graph>> CreatePrunedGraph(
    tensorflow::GraphDef graph_def, const CallableOptions& callable_options) {
  VLOG(1) << "Creating pruned graph: " << callable_options.DebugString();

  // Prune the graph with `callable_options`. Although Grappler has a
  // model_pruner stage, it may leave v1 control flow in an invalid state that
  // cannot be functionalized, so we perform additional pruning before
  // functionalization.
  TF_RETURN_IF_ERROR(PruneGraphDef(graph_def, callable_options));

  if (VLOG_IS_ON(2)) {
    DumpGraphDefToFile("before_eliminate_ref_variables_graph_def", graph_def);
  }

  TF_RETURN_IF_ERROR(EliminateRefVariablesFromV1ControlFlow(graph_def));

  auto pruned_graph =
      std::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  tensorflow::GraphConstructorOptions options;
  options.allow_internal_ops = true;
  options.add_default_attributes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, std::move(graph_def),
                                            pruned_graph.get()));
  return pruned_graph;
}

// Creates a new identity node to replace an operand of a given `node`.
NodeDef CreateNewIdentityNode(const NodeDef& node,
                              const std::string& input_name,
                              const std::string& identity_name) {
  NodeDef identity;
  identity.set_name(identity_name);
  identity.set_op("Identity");
  identity.add_input(input_name);
  identity.set_device(node.device());
  for (const auto& name_and_attr : node.attr()) {
    if (name_and_attr.first == "T") {
      identity.mutable_attr()->insert(name_and_attr);
      break;
    }
  }
  return identity;
}

}  // namespace

StatusOr<TfrtGraphExecutionState::OptimizationResult>
TfrtGraphExecutionState::CreateOptimizedGraph(
    const tensorflow::GraphImportConfig& graph_import_config) {
  OptimizationResult result;

  tensorflow::BuildGraphOptions build_graph_options;
  PopulateCallableOptions(build_graph_options.callable_options,
                          graph_import_config);

  auto graph_def = CreateGraphDefFromGraphAndFlibDef(graph(), flib_def());

  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile("before_pruning", graph_def);
  }

  TF_ASSIGN_OR_RETURN(
      result.graph,
      CreatePrunedGraph(graph_def, build_graph_options.callable_options));
  DCHECK(result.graph);

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_pruning", *result.graph);
  }

  const auto functionalization_start_time = absl::Now();

  // Perform functionalization to convert v1 control flow to v2 control flow.
  // It should be applied to the unoptimized graph, because Grappler may leave
  // the graph in a state that cannot be functionalized.
  TF_RETURN_IF_ERROR(tensorflow::UpgradeLegacyGraph(
      result.graph.get(),
      const_cast<tensorflow::FunctionLibraryDefinition*>(
          &result.graph->flib_def()),
      /*restrict_functionalization_to_tpu_nodes=*/false));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_functionalization", *result.graph);
  }

  auto grappler_start_time = absl::Now();
  result.functionalization_duration =
      grappler_start_time - functionalization_start_time;

  TF_RETURN_IF_ERROR(OptimizeGraph(result.graph, build_graph_options));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_grappler", *result.graph);
  }

  result.grappler_duration = absl::Now() - grappler_start_time;

  return result;
}

namespace {

// Given an "Exit" node, finds its corresponding "LoopCond" node.
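// In TF v1 control flow, an "Exit" node takes its input from a "Switch" node,
// and one input of that "Switch" node is produced by the loop's "LoopCond"
// node; we recover the "LoopCond" node by walking those edges.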
StatusOr<const NodeDef*> FindLoopCondFromExitNode(
    const NodeDef& exit_node,
    const absl::flat_hash_map<std::string, NodeDef*>& name_to_node) {
  const NodeDef* switch_node = nullptr;
  for (const std::string& tensor_name : exit_node.input()) {
    const std::string node_name = grappler::NodeName(tensor_name);
    if (!name_to_node.contains(node_name)) {
      return errors::InvalidArgument("Graph does not contain input ", node_name,
                                     " of exit node ", exit_node.name());
    }
    const NodeDef* node = name_to_node.at(node_name);
    if (node->op() == "Switch") {
      switch_node = node;
      break;
    }
  }
  if (switch_node == nullptr) {
    return errors::InvalidArgument("Exit node ", exit_node.name(),
                                   " does not have a Switch node as its ",
                                   "predecessor.");
  }
  for (const std::string& tensor_name : switch_node->input()) {
    const std::string node_name = grappler::NodeName(tensor_name);
    if (!name_to_node.contains(node_name)) {
      return errors::InvalidArgument("Graph does not contain input ", node_name,
                                     " of switch node ", switch_node->name());
    }

    const NodeDef* node = name_to_node.at(node_name);
    if (node->op() == "LoopCond") {
      return node;
    }
  }

  return errors::InvalidArgument("Switch node ", switch_node->name(),
                                 " does not have a LoopCond node as its ",
                                 "predecessor.");
}

}  // namespace
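// Prunes `graph_def` in place so that only the nodes reachable from the feed,
// fetch, and target nodes in `callable_options` (plus the Exit nodes needed
// to keep v1 while loops intact) are kept.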
Status PruneGraphDef(GraphDef& graph_def,
                     const CallableOptions& callable_options) {
  // Gather node names and create a map from names to NodeDefs.
  absl::flat_hash_map<std::string, NodeDef*> name_to_node;
  // Collect all Exit nodes so that we can track every while loop.
  absl::flat_hash_set<const NodeDef*> exit_nodes;
  for (auto& node : *graph_def.mutable_node()) {
    name_to_node[node.name()] = &node;
    if (node.op() == "Exit") {
      exit_nodes.insert(&node);
    }

    // TODO(tfrt-devs): Add support for _Send and _Recv ops.
    if (node.op() == "_Send" || node.op() == "_Recv") {
      return errors::InvalidArgument(
          "TFRT GraphDef pruning cannot handle graphs containing _Send and "
          "_Recv ops.");
    }
  }

  // Build a map from each LoopCond node to its Exit nodes, so that when we
  // traverse to a LoopCond node we can add the corresponding Exit nodes to
  // the traversal queue and keep the structure of the while loop complete.
  absl::flat_hash_map<const NodeDef*, absl::flat_hash_set<const NodeDef*>>
      loop_cond_to_exit_nodes;
  for (const NodeDef* exit_node : exit_nodes) {
    TF_ASSIGN_OR_RETURN(const NodeDef* loop_cond_node,
                        FindLoopCondFromExitNode(*exit_node, name_to_node));
    loop_cond_to_exit_nodes[loop_cond_node].insert(exit_node);
  }

  // `queue` holds the candidate nodes we want to visit in the graph.
  std::vector<const NodeDef*> queue;

  // Add fetch nodes to the queue.
  absl::flat_hash_set<std::string> fetch_node_names;
  for (const std::string& tensor_name : callable_options.fetch()) {
    const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain fetch node ",
                                     tensor_name, ".");
    }
    queue.push_back(node);
    fetch_node_names.insert(node->name());
  }

  // Add control target nodes to the queue.
  for (const std::string& tensor_name : callable_options.target()) {
    const NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain target node ",
                                     tensor_name, ".");
    }
    queue.push_back(node);
    fetch_node_names.insert(node->name());
  }

  absl::flat_hash_set<NodeDef*> feed_node_defs;

  // Add feed nodes to the queue. In addition, perform necessary rewrites to
  // remove unnecessary input edges.
  for (const std::string& tensor_name : callable_options.feed()) {
    NodeDef* node = name_to_node[grappler::NodeName(tensor_name)];
    if (!node) {
      return errors::InvalidArgument("Graph does not contain feed node ",
                                     tensor_name, ".");
    }

    // If a feed node is a Const, we don't need its inputs at all.
    //
    // TODO(tfrt-devs): Consider a general solution in which we could just
    // rewrite all feed nodes to Placeholder nodes.
    if (node->op() == "Const") {
      node->clear_input();
    }

    queue.push_back(node);
    feed_node_defs.insert(node);
  }

  absl::flat_hash_set<const NodeDef*> visited;
  std::vector<NodeDef> keep;

  // Traverse the graph to find all nodes reachable from the fetches, targets,
  // and feeds.
  while (!queue.empty()) {
    const NodeDef* node = queue.back();
    queue.pop_back();

    if (!visited.insert(node).second) {
      continue;
    }

    keep.push_back(*node);
    if (node->op() == "LoopCond") {
      for (const NodeDef* exit_node : loop_cond_to_exit_nodes[node]) {
        queue.push_back(exit_node);
      }
    }

    for (const std::string& tensor_name : node->input()) {
      const NodeDef* in = name_to_node[grappler::NodeName(tensor_name)];
      if (!in) {
        return errors::InvalidArgument("Graph does not contain input ",
                                       grappler::NodeName(tensor_name),
                                       " of node ", node->name(), ".");
      }
      queue.push_back(in);
    }
  }

  graph_def.clear_node();
  for (auto& node : keep) {
    if (fetch_node_names.contains(node.name())) {
      // If the fetch node is an Exit op, we insert an Identity op right after
      // it and rename it to be the new fetch node. This is to prevent
      // functionalization from removing the fetch nodes.
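      //
      // For example, a fetched node `foo = Exit(...)` becomes
      // `foo/tfrt_renamed = Exit(...)` followed by
      // `foo = Identity(foo/tfrt_renamed)`.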
      if (node.op() == "Exit") {
        auto renamed_exit_node = node;
        renamed_exit_node.set_name(
            absl::StrCat(renamed_exit_node.name(), "/tfrt_renamed"));
        node.set_op("Identity");
        *node.mutable_input(0) = renamed_exit_node.name();
        *graph_def.add_node() = std::move(renamed_exit_node);
      }
    }

    *graph_def.add_node() = std::move(node);
  }

  return Status::OK();
}
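// Rewrites each "RefEnter"/"RefSwitch" node in `graph_def` into its non-ref
// "Enter"/"Switch" counterpart, inserting an Identity node between the
// rewritten node and its ref input to break the ref edge. Returns an error if
// some other consumer of a rewritten node requires a ref input.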
Status EliminateRefVariablesFromV1ControlFlow(tensorflow::GraphDef& graph_def) {
  auto* op_factory = OpRegistry::Global();

  absl::flat_hash_set<std::string> ref_nodes;
  for (const auto& node : graph_def.node()) {
    if (node.op() == "RefEnter" || node.op() == "RefSwitch") {
      ref_nodes.insert(node.name());
    }
  }

  tensorflow::GraphDef updated_graph_def;
  absl::flat_hash_set<std::string> new_identities;
  // Insert an Identity node between each "RefEnter" or "RefSwitch" node and
  // its ref input. Then modify each "RefEnter"/"RefSwitch" node in place into
  // an "Enter"/"Switch" node.
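  //
  // For example, `RefEnter(v)` becomes `Enter(v/identity)`, where
  // `v/identity = Identity(v)`.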
  for (auto& node : *graph_def.mutable_node()) {
    // First find the ref input name to this RefEnter or RefSwitch.
    std::string* ref_input_name = nullptr;
    if (node.op() == "RefEnter") {
      node.set_op("Enter");
      if (node.input_size() != 1) {
        return errors::InvalidArgument("RefEnter node ", node.name(),
                                       " does not have exactly 1 input.");
      }
      ref_input_name = node.mutable_input(0);
    } else if (node.op() == "RefSwitch") {
      node.set_op("Switch");
      if (node.input_size() != 2) {
        return errors::InvalidArgument("RefSwitch node ", node.name(),
                                       " does not have exactly 2 inputs.");
      }
      ref_input_name = node.mutable_input(0);
    } else {
      // For other ops, check whether any of their inputs are the ref ops we
      // want to eliminate; if so, these ops must not require their inputs to
      // be refs.
      std::string ref_input;
      for (const auto& tensor_name : node.input()) {
        std::string input = grappler::NodeName(tensor_name);
        if (ref_nodes.contains(input)) {
          ref_input = std::move(input);
          break;
        }
      }
      if (!ref_input.empty()) {
        const OpDef* op_def;
        TF_RETURN_IF_ERROR(op_factory->LookUpOpDef(node.op(), &op_def));
        // TODO(tfrt-devs): How to match input_args to input names in NodeDef?
        for (const auto& input_arg : op_def->input_arg()) {
          if (input_arg.is_ref()) {
            return errors::Unimplemented(
                "Cannot in-place update ref node ", ref_input,
                " to the non-ref counterpart since its user node ",
                node.name(), " requires its input to be refs.");
          }
        }
      }
    }

    if (ref_input_name != nullptr) {
      std::string identity_name =
          absl::StrCat(grappler::NodeName(*ref_input_name), "/identity");
      if (!new_identities.contains(identity_name)) {
        *updated_graph_def.add_node() =
            CreateNewIdentityNode(node, *ref_input_name, identity_name);
        new_identities.insert(identity_name);
      }
      *ref_input_name = std::move(identity_name);
    }

    *updated_graph_def.add_node() = std::move(node);
  }

  graph_def.mutable_node()->Swap(updated_graph_def.mutable_node());
  return Status::OK();
}
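// Runs Grappler, via the underlying GraphExecutionState, on `graph` and
// replaces `graph` with the optimized result when optimization succeeds.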
Status TfrtGraphExecutionState::OptimizeGraph(
    std::unique_ptr<tensorflow::Graph>& graph,
    const tensorflow::BuildGraphOptions& build_graph_options) {
  std::unique_ptr<tensorflow::Graph> optimized_graph;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> optimized_flib;

  // Invoke Grappler to optimize the graph.
  auto status = graph_execution_state_->OptimizeGraph(
      build_graph_options, *graph, &graph->flib_def(), &optimized_graph,
      &optimized_flib);

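  // A Grappler failure is not fatal: keep the unoptimized graph and only log
  // a warning.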
  if (!status.ok()) {
    LOG(WARNING) << "TFRT failed to optimize graph: " << status;
    return tensorflow::Status::OK();
  }

  TF_RETURN_IF_ERROR(
      optimized_graph->AddFunctionLibrary(optimized_flib->ToProto()));
  graph = std::move(optimized_graph);
  return tensorflow::Status::OK();
}
468
469 } // namespace tfrt_stub
470 } // namespace tensorflow
471