/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include <functional>
#include <queue>
#include <random>
#include <set>
#include <unordered_map>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

namespace {

Status ValidateTensorId(const tf2xla::TensorId& id) {
  if (id.node_name().empty()) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
  if (id.output_index() < 0) {
    return errors::InvalidArgument(
        "TensorId output_index must be non-negative");
  }
  return Status::OK();
}

Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (!name.empty()) {
    if (!names->insert(name).second) {
      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
    }
  }
  return Status::OK();
}

Status CheckFeedFetchNameConflicts(const string& kind,
                                   const std::set<string>& names) {
  // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
  // since that will cause a collision in codegen symbols.
  for (const string& name : names) {
    const string name_data(name + "_data");
    if (names.find(name_data) != names.end()) {
      return errors::InvalidArgument("conflicting ", kind, " name: ", name,
                                     " and ", name_data);
    }
  }
  return Status::OK();
}

// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    for (const auto& associated_function :
         GetAssociatedFunctions(*n, lookup_fld)) {
      switch (associated_function.type()) {
        case AssociatedFunctionInfo::kFunctionCallNode: {
          const FunctionDef* fdef =
              lookup_fld->Find(associated_function.func_name());
          if (!fdef) {
            return errors::Internal(
                "Cannot find function ", associated_function.func_name(),
                " for function call node ", n->DebugString());
          }
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
          break;
        }
        case AssociatedFunctionInfo::kSymbolicGradient:
        case AssociatedFunctionInfo::kFunctionAttr:
          break;
      }
    }
  }
  return Status::OK();
}

// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const std::unordered_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes.
  std::unordered_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    Status s;
    Node* const_node = g->AddNode(const_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* arg_node = arg_nodes[arg_index];

    // Collect all usages of the _Arg node.
    struct OutEdgeInfo {
      int dst_node_id, dst_input;
    };
    std::vector<OutEdgeInfo> usages;
    for (const Edge* e : arg_node->out_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      usages.push_back({e->dst()->id(), e->dst_input()});
    }

    for (int i = 0; i < usages.size(); i++) {
      // Make a copy of `usage_node`, and change its input to the const node.
      Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
      NodeDef replace_def = usage_node->def();
      *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
      TF_ASSIGN_OR_RETURN(Node * replace_node,
                          ReplaceNode(g, usage_node, replace_def));
      const Edge* usage_edge;
      TF_RETURN_IF_ERROR(
          replace_node->input_edge(usages[i].dst_input, &usage_edge));
      g->RemoveEdge(usage_edge);
      g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);

      // Later entries in `usages` might have `usage_node` as their dst node,
      // but `usage_node` has just been removed. Replace such entries with
      // `replace_node`.
      for (int j = i + 1; j < usages.size(); j++) {
        if (usages[j].dst_node_id == usages[i].dst_node_id) {
          usages[j].dst_node_id = replace_node->id();
        }
      }
    }
  }
  return Status::OK();
}
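
// Example for ReplaceArgUsageWithConstNode (an illustrative sketch; A0, A1,
// and C are made-up node names): if `g` has _Arg nodes A0 (index 0) and A1
// (index 1), and `const_input_index_to_node` maps 1 -> a Const node C, then
// after the rewrite every data edge that read from A1 reads from a fresh copy
// of C added to `g`, while A0 and all control edges are left untouched.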

// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const std::unordered_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  // Instantiate the function.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));

  // Rewrite _Arg usages with the Const nodes.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));

  // Save the rewritten function.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));

  // Change the node to use the rewritten function.
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);

  // Copy associated functions.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));

  return Status::OK();
}

// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch functions to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
                                const FunctionLibraryDefinition* lookup_fld,
                                FunctionLibraryDefinition* fld) {
  // Note that the first input of an If node is the predicate; the remaining
  // inputs are the function inputs.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  for (int i = 1; i < if_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
    if (input_node->type_string() == "Const") {
      const_input_index_to_node[i - 1] = input_node;
    }
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "then_branch" and "else_branch" functions, replacing usages of
  // those _Arg nodes with the corresponding Const nodes.
  for (const auto& attr_name :
       std::vector<string>{"then_branch", "else_branch"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }

  return Status::OK();
}
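
// Example for PropagateConstIntoIfNode (an illustrative sketch; the names are
// made up): for
//   r = If(pred, c, then_branch=f, else_branch=g)   // c is a Const node
// the Const feeds If input 1, i.e. function input 0, so `f` and `g` are cloned
// into rewritten functions whose _Arg 0 usages read from a copy of `c`, and
// the "then_branch"/"else_branch" attrs are pointed at the clones.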

// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body functions to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
                                   const FunctionLibraryDefinition* lookup_fld,
                                   FunctionLibraryDefinition* fld) {
  // For a "While" node, we should only replace _Arg nodes that are loop
  // invariant. For such an _Arg node, the corresponding return value of the
  // body function comes directly from that arg.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
  if (!body_func) {
    return errors::Internal("Cannot find body function ", body_attr.name(),
                            " for While node ", while_node->name());
  }
  for (int i = 0; i < while_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
    if (input_node->type_string() != "Const") {
      continue;
    }

    // Check if the i-th retval's input comes directly from the i-th arg.
    // For resource variable inputs of While nodes, the TF2XLA convention is to
    // place them at the end of all inputs (after all data inputs), and *not*
    // return them. So the number of While node inputs might be larger than the
    // number of its outputs.
    if (i >= body_func->signature().output_arg_size()) {
      continue;
    }
    const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
    auto output_arg_input = body_func->ret().find(output_arg.name());
    if (output_arg_input == body_func->ret().end()) {
      return errors::Internal("Cannot find input for output arg ",
                              output_arg.name(), " in function ",
                              body_attr.name());
    }
    const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
    if (output_arg_input->second != input_arg.name()) {
      continue;
    }

    const_input_index_to_node[i] = input_node;
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "cond" and "body" functions, replacing usages of those _Arg
  // nodes with the corresponding Const nodes.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}
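
// Example for PropagateConstIntoWhileNode (an illustrative sketch; the
// signature is made up): for a While loop whose body is
//   body(x, y) -> (x, y + 1)
// arg 0 is loop invariant (retval 0 comes straight from arg 0) while arg 1 is
// not, so only a Const feeding While input 0 would be propagated into the
// "cond" and "body" functions.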

}  // namespace

const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";

Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> names;
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
  names.clear();
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
  if (config.fetch().empty()) {
    return errors::InvalidArgument("fetches must be specified");
  }
  return Status::OK();
}
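
// Example of a config accepted by ValidateConfig (an illustrative sketch in
// text-proto form; the names are made up):
//   feed { id { node_name: "x" output_index: 0 } shape { dim { size: 2 } } }
//   fetch { id { node_name: "y" } name: "result" }
// A config with no fetch entries, with a negative output_index, or containing
// both a feed named "foo" and one named "foo_data" is rejected.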

Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // Points to a Feed in `config`.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for determinism
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
                                         "/", feed->id().node_name());
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify that each fed node exists, and determine its data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine its data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // We must first add default attrs as well, so do this in a copied
      // GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      Status status;
      Node* feed_node = g.AddNode(gd.node(0), &status);
      TF_RETURN_IF_ERROR(status);

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more cases
  // in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("PlaceholderV2");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return Status::OK();
}
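
// Example for AddPlaceholdersForFeeds (an illustrative sketch; "img" is a
// made-up node name): feeding tensor "img:0" whose type resolves to DT_FLOAT
// appends a "PlaceholderV2" node named "aot_feed_0/img" to `graph_def`, sets
// (*feed_remapping)["img:0"] to that name, and rewrites every node input that
// previously read "img:0" (or "img") to read the placeholder instead.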

Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over only the nodes that are reachable from the fetches, preserving
  // their original order.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}
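
// Example for PruneGraphDefInto (an illustrative sketch; a, b, c are made-up
// nodes): with a chain a -> b -> c, a config that fetches "c" and feeds "b:0"
// keeps only c. The traversal starts at c and does not follow the fed edge
// b:0 -> c, so neither a nor b is marked reachable; the fed tensor is expected
// to be supplied separately, e.g. by AddPlaceholdersForFeeds.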

string TensorIdToString(const tf2xla::TensorId& id) {
  return absl::StrCat(id.node_name(), ":", id.output_index());
}
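
// Example: a TensorId with node_name "img" and output_index 1 is rendered as
// "img:1", matching the "name:port" keys used by AddPlaceholdersForFeeds and
// the feed_remapping map.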

Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;
  const Node* matching_node = nullptr;
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return Status::OK();
}

void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef) {
  for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
    if (constraint.name() == name) {
      constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
    }
  }
}

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also sidesteps the issue that
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}
}  // namespace

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  uint32 seed = counter.fetch_add(2);
  std::srand(seed);
  return std::rand() | 1;
}
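
// Behavioral note on GetXLARandomSeed (an observation, not a contract): the
// counter starts odd and is bumped by 2, so the value passed to std::srand is
// always odd and never zero, even after uint32 wrap-around, and the final
// `| 1` keeps the returned seed odd as well.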

// TODO(b/77601805): add tests for associated function related stuff.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld) {
  if (fld->Contains(node_def.op())) {
    return true;
  }

  if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
    // The gradient op has an "f" attr, which is set to the function we are
    // taking the gradient of. We need to functionalize the gradient function.
    return true;
  }

  if (node_def.op() == "XlaHostCompute") {
    // XlaHostCompute has a "shape_inference_graph" func attr, but that's not
    // related to graph execution.
    return false;
  }

  for (const auto& iter : node_def.attr()) {
    if (iter.second.has_func()) {
      return true;
    }
  }

  return false;
}

std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld) {
  std::vector<AssociatedFunctionInfo> results;
  const string& op = node.type_string();
  if (fld->Contains(op)) {
    // This is a function call node.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
  } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
    // This is a SymbolicGradient op.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
  } else if (node.type_string() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
  } else {
    // Collect all function attrs for the node.
    for (auto& iter : node.attrs()) {
      if (iter.second.has_func()) {
        VLOG(2) << "Found function attr for node " << node.name() << ": "
                << iter.first << " = " << iter.second.func().name();
        results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
            iter.second.func().name(), iter.second.func().attr(), iter.first));
      }
    }
  }
  return results;
}

Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      for (auto attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      Status s;
      Node* new_node = graph->AddNode(node_def, &s);
      TF_RETURN_IF_ERROR(s);
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change the function attr to the rewritten function.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return Status::OK();
}

Status CachedFunctionHandles::GetOrInstantiate(
    const string& func_name, AttrSlice attrs,
    FunctionLibraryRuntime::Handle* handle) {
  string canonicalized_name = Canonicalize(func_name, attrs);
  auto iter = handles_.find(canonicalized_name);
  if (iter != handles_.end()) {
    *handle = iter->second;
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
  handles_[canonicalized_name] = *handle;
  return Status::OK();
}
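
// Usage sketch for CachedFunctionHandles (a hypothetical caller; `flr` is a
// FunctionLibraryRuntime* owned elsewhere):
//   CachedFunctionHandles cache(flr);
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(cache.GetOrInstantiate("MyFunc", AttrSlice(), &handle));
// A second call with the same name and attrs returns the cached handle instead
// of instantiating the function again.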

Status CachedFunctionHandles::ReleaseAllHandles() {
  Status result;
  for (auto iter : handles_) {
    result.Update(flr_->ReleaseHandle(iter.second));
  }
  handles_.clear();
  return result;
}

xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  Status s;
  Node* new_node = g->AddNode(node_def, &s);
  if (!s.ok()) {
    return s;
  }

  // Record the original node's output edges and remove them first. This is to
  // avoid multiple producers for the dst nodes' inputs.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add the original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}
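
// Example for ReplaceNode (an illustrative note): this is the helper that
// ReplaceArgUsageWithConstNode above relies on. Given a node `n` and a NodeDef
// that differs only in one input string, the returned node carries the new def
// while keeping all of `n`'s incoming and outgoing edges, and `n` itself is
// removed from `g`.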

xla::StatusOr<Node*> BuildIdentityNode(
    Graph* graph, const string& node_name, DataType dtype, const Node* input,
    absl::optional<string> requested_device) {
  // Create identity node.
  NodeDef ndef;
  ndef.set_name(node_name);
  ndef.set_op("Identity");
  if (input) {
    ndef.add_input(input->name());
  }
  if (requested_device) {
    ndef.set_device(*requested_device);
  }
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* id_node = graph->AddNode(ndef, &s);
  TF_RETURN_IF_ERROR(s);
  return id_node;
}
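
// Usage sketch for BuildIdentityNode (hypothetical values): to mirror a node
// `src` of type DT_FLOAT onto a particular device, one might write
//   TF_ASSIGN_OR_RETURN(Node * id,
//                       BuildIdentityNode(graph, "src_identity", DT_FLOAT,
//                                         src, "/device:CPU:0"));
// If `input` is nullptr, no input is recorded in the NodeDef and the caller is
// expected to wire up the input edge itself.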

Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    if (n->IsIfNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
    } else if (n->IsWhileNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
    }
  }
  return Status::OK();
}

Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld) {
  GraphDef graph_def;
  g.ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_functions =
      fld->ReachableDefinitions(graph_def);
  for (const string& func_name : fld->ListFunctionNames()) {
    if (!reachable_functions.Find(func_name)) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
    }
  }
  return Status::OK();
}
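
// The rewrite below targets one specific pattern (a summary of the checks that
// follow): an EmptyTensorList whose only While consumer is a forward loop that
// TensorListPushBack's a Const element, and whose corresponding output is
// consumed by exactly one backward While loop that TensorListPopBack's that
// element. In that case the PopBack result inside the backward body can be
// replaced by the Const itself.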
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld) {
  for (Node* n : g->nodes()) {
    if (n->type_string() != "EmptyTensorList") {
      continue;
    }

    // Find the forward While op.
    std::vector<const Edge*> fwd_while_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
        fwd_while_edges.push_back(e);
      }
    }
    if (fwd_while_edges.size() != 1) {
      // No forward While op found, or multiple forward While ops.
      continue;
    }

    // Find the backward While op.
    Node* fwd_while = fwd_while_edges[0]->dst();
    int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
    std::vector<const Edge*> bwd_while_edges;
    for (const Edge* e : fwd_while->out_edges()) {
      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
        bwd_while_edges.push_back(e);
      }
    }
    if (bwd_while_edges.size() != 1) {
      // No backward While op found, or multiple backward While ops.
      continue;
    }

    Node* bwd_while = bwd_while_edges[0]->dst();
    int bwd_while_dst_input = bwd_while_edges[0]->dst_input();

    // Look into the forward While body function and check whether the
    // TensorListPushBack op has a Const input.
    NameAttrList fwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
    const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
    if (!fwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     fwd_body_attr.name(), " for While node ",
                                     fwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> fwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));

    // Find the TensorListPushBack node; it's one of fwd_arg's successors.
    Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
    std::vector<Node*> tl_push_nodes;
    for (const Edge* out_edge : fwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPushBack") {
        tl_push_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_push_nodes.size() != 1) {
      // No TensorListPushBack found, or multiple TensorListPushBack nodes.
      continue;
    }

    // Get the element input of the TensorListPushBack node.
    Node* input_node;
    TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
    if (input_node->type_string() != "Const") {
      // The input for the TensorList is not a Const node.
      continue;
    }

    NodeDef const_input_nodedef = input_node->def();

    // Rewrite the backward While body function, replacing usages of
    // TensorListPopBack with a Const node.
    NameAttrList bwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
    const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
    if (!bwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     bwd_body_attr.name(), " for While node ",
                                     bwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> bwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));

    // Find the TensorListPopBack node; it's one of bwd_arg's successors.
    Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
    std::vector<Node*> tl_pop_nodes;
    for (const Edge* out_edge : bwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPopBack") {
        tl_pop_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_pop_nodes.size() != 1) {
      // No TensorListPopBack found, or multiple TensorListPopBack nodes.
      continue;
    }

    // Replace TensorListPopBack usages with the Const node.
    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
      if (e->src_output() == 1) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }
    Status s;
    const_input_nodedef.set_name(
        bwd_fbody->graph->NewName(const_input_nodedef.name()));
    Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
    TF_RETURN_IF_ERROR(s);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      bwd_fbody->graph->RemoveEdge(e);
      bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
    }

    // Add the rewritten backward While body function.
    FunctionDef new_fdef;
    string new_name = fld->UniqueFunctionName(
        absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));

    // Change the backward While op to use the new body function.
    bwd_body_attr.set_name(new_name);
    bwd_while->ClearAttr("body");
    bwd_while->AddAttr("body", bwd_body_attr);
  }
  return Status::OK();
}

}  // namespace tensorflow