1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17
18 #include <functional>
19 #include <queue>
20 #include <random>
21 #include <set>
22 #include <unordered_map>
23
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/tf2xla/sharding_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/common_runtime/function.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/graph_def_util.h"
31 #include "tensorflow/core/framework/graph_to_functiondef.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_builder.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_shape.pb.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/tensor_id.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41
42 namespace tensorflow {
43
44 namespace {
45
ValidateTensorId(const tf2xla::TensorId & id)46 Status ValidateTensorId(const tf2xla::TensorId& id) {
47 if (id.node_name().empty()) {
48 return errors::InvalidArgument("TensorId node_name must be non-empty");
49 }
50 if (id.output_index() < 0) {
51 return errors::InvalidArgument("TensorId output_index must be positive");
52 }
53 return Status::OK();
54 }
55
CheckNameDuplicates(const string & kind,const string & name,std::set<string> * names)56 Status CheckNameDuplicates(const string& kind, const string& name,
57 std::set<string>* names) {
58 if (!name.empty()) {
59 if (!names->insert(name).second) {
60 return errors::InvalidArgument("duplicate ", kind, " name: ", name);
61 }
62 }
63 return Status::OK();
64 }
65
CheckFeedFetchNameConflicts(const string & kind,const std::set<string> & names)66 Status CheckFeedFetchNameConflicts(const string& kind,
67 const std::set<string>& names) {
68 // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
69 // since that will cause a collision in codegen symbols.
70 for (const string& name : names) {
71 const string name_data(name + "_data");
72 if (names.find(name_data) != names.end()) {
73 return errors::InvalidArgument("conflicting ", kind, " name: ", name,
74 " and ", name_data);
75 }
76 }
77 return Status::OK();
78 }
79
80 // For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
81 // `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
CopyAssociatedFunctions(Graph * g,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)82 Status CopyAssociatedFunctions(Graph* g,
83 const FunctionLibraryDefinition* lookup_fld,
84 FunctionLibraryDefinition* fld) {
85 for (Node* n : g->op_nodes()) {
86 for (const auto& associated_function :
87 GetAssociatedFunctions(*n, lookup_fld)) {
88 switch (associated_function.type()) {
89 case AssociatedFunctionInfo::kFunctionCallNode: {
90 const FunctionDef* fdef =
91 lookup_fld->Find(associated_function.func_name());
92 if (!fdef) {
93 return errors::Internal(
94 "Cannot find function ", associated_function.func_name(),
95 " for function call node ", n->DebugString());
96 }
97 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
98 break;
99 }
100 case AssociatedFunctionInfo::kSymbolicGradient:
101 case AssociatedFunctionInfo::kFunctionAttr:
102 break;
103 }
104 }
105 }
106 return Status::OK();
107 }
108
109 // For graph `g`, replaces _Arg nodes whose "index" attribute is in
110 // `const_input_index_to_node` with Const nodes.
ReplaceArgUsageWithConstNode(Graph * g,const std::unordered_map<int,const Node * > & const_input_index_to_node)111 Status ReplaceArgUsageWithConstNode(
112 Graph* g,
113 const std::unordered_map<int, const Node*>& const_input_index_to_node) {
114 // Collect all _Arg nodes.
115 std::unordered_map<int, Node*> arg_nodes;
116 for (Node* n : g->op_nodes()) {
117 if (n->IsArg()) {
118 int index;
119 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
120 arg_nodes[index] = n;
121 }
122 }
123
124 for (const auto& iter : const_input_index_to_node) {
125 int arg_index = iter.first;
126 NodeDef const_def = iter.second->def();
127 const_def.set_name(g->NewName(const_def.name()));
128 Status s;
129 Node* const_node = g->AddNode(const_def, &s);
130 TF_RETURN_IF_ERROR(s);
131
132 Node* arg_node = arg_nodes[arg_index];
133
134 // Collect all usages of the _Arg node.
135 struct OutEdgeInfo {
136 int dst_node_id, dst_input;
137 };
138 std::vector<OutEdgeInfo> usages;
139 for (const Edge* e : arg_node->out_edges()) {
140 if (e->IsControlEdge()) {
141 continue;
142 }
143 usages.push_back({e->dst()->id(), e->dst_input()});
144 }
145
146 for (int i = 0; i < usages.size(); i++) {
147 // Make a copy of `usage_node`, and change its input to const node.
148 Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
149 NodeDef replace_def = usage_node->def();
150 *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
151 TF_ASSIGN_OR_RETURN(Node * replace_node,
152 ReplaceNode(g, usage_node, replace_def));
153 const Edge* usage_edge;
154 TF_RETURN_IF_ERROR(
155 replace_node->input_edge(usages[i].dst_input, &usage_edge));
156 g->RemoveEdge(usage_edge);
157 g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);
158
159 // Later entries in `usages` might have `usage_node` as dst node, but
160 // `usage_node` is removed. Replace such entries with `replace_node`.
161 for (int j = i + 1; j < usages.size(); j++) {
162 if (usages[j].dst_node_id == usages[i].dst_node_id) {
163 usages[j].dst_node_id = replace_node->id();
164 }
165 }
166 }
167 }
168 return Status::OK();
169 }
170
// For a node's function attr `attr_name` (e.g. then/else branch for "If"
// nodes), rewrites the referenced function so that _Arg nodes listed in
// `const_input_index_to_node` are replaced with Const inputs. The rewritten
// function is stored in `fld` under a fresh unique name and the node's attr is
// repointed at it; the original function in `lookup_fld` is left untouched.
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const std::unordered_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  // Instantiate the function named by the attr into an editable Graph.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  FunctionBody* fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld,
      // Resolve op signatures against `lookup_fld`, not the global registry.
      [lookup_fld](const string& op, const OpDef** sig) {
        return lookup_fld->LookUpOpDef(op, sig);
      },
      &fbody));
  // FunctionDefToBodyHelper hands back an owning raw pointer; ensure it is
  // deleted on every exit path below.
  std::unique_ptr<FunctionBody> fbody_deleter(fbody);

  // Rewrite _Arg usages with Const node.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));

  // Save rewritten function under a new unique name, so the original remains
  // available to any other callers.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));

  // Change the node to use rewritten function.
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);

  // Copy associated functions, so `fld` can instantiate everything the
  // rewritten body calls.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));

  return Status::OK();
}
219
220 // For an "If" node in graph `g`, if it has Const node inputs, rewrite its
221 // then/else branch function to replace _Arg nodes with those Const inputs.
PropagateConstIntoIfNode(Graph * g,Node * if_node,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)222 Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
223 const FunctionLibraryDefinition* lookup_fld,
224 FunctionLibraryDefinition* fld) {
225 // Notice that first input for If node is predicate; other inputs are function
226 // inputs.
227 std::unordered_map<int, const Node*> const_input_index_to_node;
228 for (int i = 1; i < if_node->num_inputs(); i++) {
229 const Node* input_node;
230 TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
231 if (input_node->type_string() == "Const") {
232 const_input_index_to_node[i - 1] = input_node;
233 }
234 }
235 if (const_input_index_to_node.empty()) {
236 return Status::OK();
237 }
238
239 // Rewrite "then_branch" and "else_branch" function, replace usage of those
240 // _Arg nodes with corresponding const node.
241 for (const auto& attr_name :
242 std::vector<string>{"then_branch", "else_branch"}) {
243 TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
244 if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
245 }
246
247 return Status::OK();
248 }
249
// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
                                   const FunctionLibraryDefinition* lookup_fld,
                                   FunctionLibraryDefinition* fld) {
  // For "While" node, we should only replace _Arg nodes which are loop
  // invariants. For such _Arg nodes, the return value's input will come
  // directly from the corresponding arg.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
  if (!body_func) {
    return errors::Internal("Cannot find body function ", body_attr.name(),
                            " for While node ", while_node->name());
  }
  for (int i = 0; i < while_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
    if (input_node->type_string() != "Const") {
      continue;
    }

    // Check if i-th retval's input comes from i-th arg directly.
    // For resource variable input of While nodes, TF2XLA convention is to place
    // them at the end of all inputs (after all data inputs), and *not* return
    // them. So number of While node inputs might be larger than number of its
    // outputs.
    if (i >= body_func->signature().output_arg_size()) {
      continue;
    }
    // The body's ret map records, for each output arg name, the tensor that
    // produces it; a loop-invariant arg is one whose i-th output is fed
    // directly by the i-th input arg.
    const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
    auto output_arg_input = body_func->ret().find(output_arg.name());
    if (output_arg_input == body_func->ret().end()) {
      return errors::Internal("Cannot find input for output arg ",
                              output_arg.name(), " in function ",
                              body_attr.name());
    }
    const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
    if (output_arg_input->second != input_arg.name()) {
      // Not loop-invariant: the body rewires this value each iteration, so it
      // cannot be replaced by a constant.
      continue;
    }

    const_input_index_to_node[i] = input_node;
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
  // corresponding const node.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}
307
308 } // namespace
309
310 const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
311
ValidateConfig(const tf2xla::Config & config)312 Status ValidateConfig(const tf2xla::Config& config) {
313 std::set<string> names;
314 for (const tf2xla::Feed& feed : config.feed()) {
315 TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
316 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
317 TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
318 }
319 TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
320 names.clear();
321 for (const tf2xla::Fetch& fetch : config.fetch()) {
322 TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
323 TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
324 }
325 TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
326 if (config.fetch().empty()) {
327 return errors::InvalidArgument("fetches must be specified");
328 }
329 return Status::OK();
330 }
331
// Rewrites `graph_def` so that every fed tensor is produced by a newly added
// PlaceholderV2 node, and remaps all references to fed tensors onto those
// placeholders. Records the "node:port" -> placeholder-name mapping in
// `feed_remapping`. `op_registry` supplies op definitions used to fill in
// default attrs when a feed's data type must be inferred from the graph.
Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for determinism
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
                                         "/", feed->id().node_name());
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify node exists and determine data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      // The config states the type explicitly; trust it.
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // Must first add default attrs as well, so do this in a copied GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      Status status;
      Node* feed_node = g.AddNode(gd.node(0), &status);
      TF_RETURN_IF_ERROR(status);

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more cases
  // in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("PlaceholderV2");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return Status::OK();
}
430
// Copies into `out` the subset of `in`'s nodes reachable backwards from the
// fetches in `config`, preserving the original node order. Traversal stops at
// fed tensors (a fed edge is supplied externally, so its producer need not be
// kept unless reached through some non-fed edge). Returns InvalidArgument if
// a needed node is missing from the graph.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse: breadth-first walk backwards from the fetch nodes.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      // Already marked reachable; don't re-expand.
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over, preserving order of original and only nodes that are reachable
  // from the fetches.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}
495
TensorIdToString(const tf2xla::TensorId & id)496 string TensorIdToString(const tf2xla::TensorId& id) {
497 return absl::StrCat(id.node_name(), ":", id.output_index());
498 }
499
// Copies the device assignment of one of `n`'s neighbors onto `n`: the
// neighbor considered is the dst of each out edge if `out_edges` is true,
// else the src of each in edge, and among annotated neighbors the one with
// the smallest maximal-sharding core number wins (for determinism). Leaves
// `n` untouched if no neighbor carries a sharding annotation.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;                        // smallest core seen so far; -1 = none
  const Node* matching_node = nullptr;  // neighbor that supplied `core`
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
    if (sharding.has_value()) {
      // Only maximal (single-device) sharding is supported here; for that
      // type the device id lives in tile_assignment_devices(0).
      TF_RET_CHECK(sharding.value().type() ==
                   xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return Status::OK();
}
527
AddDtypeToKernelDefConstraint(absl::string_view name,DataType dtype,KernelDef * kdef)528 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
529 KernelDef* kdef) {
530 for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
531 if (constraint.name() == name) {
532 constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
533 }
534 }
535 }
536
537 namespace {
InitialRandomSeed()538 uint32 InitialRandomSeed() {
539 // Support plumbing the TF seed through to XLA is being worked on.
540 // If a user wants deterministic behavior, their best option
541 // is to start with a known checkpoint. This also handles issues when
542 // multiple random calls can be invoked in any order by TF executor.
543 // Another option is to use stateless random ops. They have much cleaner
544 // semantics.
545 // If a user really wants to set a deterministic seed for XLA-based
546 // devices, this is the place to do it.
547 std::random_device rd;
548 // Make the starting value odd.
549 return rd() | 1;
550 }
551 } // namespace
552
GetXLARandomSeed()553 uint32 GetXLARandomSeed() {
554 // We initialize counter with an odd number and increment it by two
555 // everytime. This ensures that it will never be zero, even
556 // after an overflow. When seeded with zero, some XLA backends
557 // can return all zeros instead of random numbers.
558 static std::atomic<uint32> counter(InitialRandomSeed());
559 uint32 seed = counter.fetch_add(2);
560 std::srand(seed);
561 return std::rand() | 1;
562 }
563
564 // TODO(b/77601805): add tests for associated function related stuff.
HasAssociatedFunction(const NodeDef & node_def,const FunctionLibraryDefinition * fld)565 bool HasAssociatedFunction(const NodeDef& node_def,
566 const FunctionLibraryDefinition* fld) {
567 if (fld->Contains(node_def.op())) {
568 return true;
569 }
570
571 if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
572 // Gradient op has "f" attr, which is set to the function we are getting
573 // gradient for. We need to functionalize the gradient function.
574 return true;
575 }
576
577 if (node_def.op() == "XlaHostCompute") {
578 // XlaHostCompute has "shape_inference_graph" func attr, but that's not
579 // related to graph execution.
580 return false;
581 }
582
583 for (const auto& iter : node_def.attr()) {
584 if (iter.second.has_func()) {
585 return true;
586 }
587 }
588
589 return false;
590 }
591
GetAssociatedFunctions(const Node & node,const FunctionLibraryDefinition * fld)592 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
593 const Node& node, const FunctionLibraryDefinition* fld) {
594 std::vector<AssociatedFunctionInfo> results;
595 const string& op = node.type_string();
596 if (fld->Contains(op)) {
597 // This is a function call node.
598 AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
599 results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
600 } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
601 // This is a SymbolicGradient op.
602 AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
603 results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
604 } else if (node.type_string() == "XlaHostCompute") {
605 // XlaHostCompute has "shape_inference_graph" func attr, but that's not
606 // related to graph execution.
607 } else {
608 // Collect all function attrs for the node.
609 for (auto& iter : node.attrs()) {
610 if (iter.second.has_func()) {
611 VLOG(2) << "Found function attr for node " << node.name() << ": "
612 << iter.first << " = " << iter.second.func().name();
613 results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
614 iter.second.func().name(), iter.second.func().attr(), iter.first));
615 }
616 }
617 }
618 return results;
619 }
620
// Rewires `node`'s association so it refers to `rewritten_function_name`
// instead of the function described by `associated_function`. For function
// call nodes the node itself is rebuilt and replaced; for symbolic gradients
// the library's gradient registration is updated; for func attrs only the
// attr value changes.
Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      // Carry over every attr and every data input from the original node.
      for (auto attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      // Prefer the assigned device (set by placement) over the requested one.
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      Status s;
      Node* new_node = graph->AddNode(node_def, &s);
      TF_RETURN_IF_ERROR(s);
      // Splice the new node in with the old node's edges, then remove the
      // old node (which also removes the old edges).
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      // Point the gradient registration for the function named by the node's
      // "f" attr at the rewritten function.
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change function attr to rewritten functions.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return Status::OK();
}
687
GetOrInstantiate(const string & func_name,AttrSlice attrs,FunctionLibraryRuntime::Handle * handle)688 Status CachedFunctionHandles::GetOrInstantiate(
689 const string& func_name, AttrSlice attrs,
690 FunctionLibraryRuntime::Handle* handle) {
691 string canonicalized_name = Canonicalize(func_name, attrs);
692 auto iter = handles_.find(canonicalized_name);
693 if (iter != handles_.end()) {
694 *handle = iter->second;
695 return Status::OK();
696 }
697
698 TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
699 handles_[canonicalized_name] = *handle;
700 return Status::OK();
701 }
702
ReleaseAllHandles()703 Status CachedFunctionHandles::ReleaseAllHandles() {
704 Status result;
705 for (auto iter : handles_) {
706 result.Update(flr_->ReleaseHandle(iter.second));
707 }
708 handles_.clear();
709 return result;
710 }
711
// Replaces node `n` in graph `g` with a new node built from `node_def`,
// transplanting all of `n`'s input and output edges onto the replacement.
// Removes `n` and returns the new node.
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  Status s;
  Node* new_node = g->AddNode(node_def, &s);
  if (!s.ok()) {
    return s;
  }

  // Record original node's output edges and remove them first. This is to avoid
  // multiple producers for dst nodes' input.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  // Removal happens in a second pass because removing an edge while iterating
  // n->out_edges() would invalidate that iteration.
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}
747
BuildIdentityNode(Graph * graph,const string & node_name,DataType dtype,const Node * input,absl::optional<string> requested_device)748 xla::StatusOr<Node*> BuildIdentityNode(
749 Graph* graph, const string& node_name, DataType dtype, const Node* input,
750 absl::optional<string> requested_device) {
751 // Create identity node.
752 NodeDef ndef;
753 ndef.set_name(node_name);
754 ndef.set_op("Identity");
755 if (input) {
756 ndef.add_input(input->name());
757 }
758 if (requested_device) {
759 ndef.set_device(*requested_device);
760 }
761 AddNodeAttr("T", dtype, &ndef);
762 Status s;
763 Node* id_node = graph->AddNode(ndef, &s);
764 TF_RETURN_IF_ERROR(s);
765 return id_node;
766 }
767
PropagateConstIntoFunctionalNodes(Graph * g,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)768 Status PropagateConstIntoFunctionalNodes(
769 Graph* g, const FunctionLibraryDefinition* lookup_fld,
770 FunctionLibraryDefinition* fld) {
771 for (Node* n : g->op_nodes()) {
772 if (n->type_string() == "If") {
773 TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
774 } else if (n->type_string() == "While") {
775 TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
776 }
777 }
778 return Status::OK();
779 }
780
781 } // namespace tensorflow
782