1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
16
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
27
28 namespace tensorflow {
29 namespace {
30
IsFunctionCall(const Node & node)31 bool IsFunctionCall(const Node& node) {
32 // TODO(iga): Handle non-PCO functions when we add multi-device support
33 // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
34 // the function name is stored in the `f` attribute of the node. That code
35 // will need to change as well.
36 const string& op_type = node.op_def().name();
37 return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
38 }
39
40 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)41 Status Set(const Node& node, bool value, bool* is_deep,
42 std::vector<absl::optional<bool>>* cache) {
43 *is_deep = value;
44 (*cache)[node.id()] = value;
45 return Status::OK();
46 }
47
48 } // namespace
49
PlacerInspectionRequiredOpChecker(const Graph * graph,const FunctionLibraryDefinition * flib_def)50 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
51 const Graph* graph, const FunctionLibraryDefinition* flib_def)
52 : graph_(*graph), flib_def_(*flib_def) {
53 cache_.resize(graph_.num_node_ids());
54 }
55
IsPlacerInspectionRequired(const Node & node,bool * is_deep)56 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
57 const Node& node, bool* is_deep) {
58 if (cache_[node.id()].has_value()) {
59 *is_deep = cache_[node.id()].value();
60 return Status::OK();
61 }
62
63 if (!IsFunctionCall(node)) {
64 return Set(node, false, is_deep, &cache_);
65 }
66 const FunctionDef* fdef;
67 NameAttrList func;
68 TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
69 DataTypeVector types;
70 TF_RETURN_IF_ERROR(
71 OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
72 for (DataType type : types) {
73 if (type == DT_RESOURCE) {
74 return Set(node, true, is_deep, &cache_);
75 }
76 }
77 return Set(node, false, is_deep, &cache_);
78 }
79
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)80 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
81 const Node& node, const FunctionDef** fdef,
82 NameAttrList* func) {
83 TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
84 const string& function_name = func->name();
85 *fdef = flib_def.Find(function_name);
86 if (*fdef == nullptr) {
87 return errors::InvalidArgument(
88 "Failed to find function \"", function_name,
89 "\" in function library: ", flib_def.ToProto().DebugString());
90 }
91 return Status::OK();
92 }
93
FunctionStack(const string & function_name)94 FunctionStack::FunctionStack(const string& function_name)
95 : current_function_name_(function_name) {}
96
Push(const Node * node_in_current_function,const string & new_current_function) const97 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
98 const string& new_current_function) const {
99 FunctionStack new_stack(new_current_function);
100 new_stack.frames_ = frames_;
101 new_stack.frames_.emplace_back(current_function_name_,
102 node_in_current_function);
103 return new_stack;
104 }
105
HasFunction(const string & function_name) const106 bool FunctionStack::HasFunction(const string& function_name) const {
107 if (current_function_name_ == function_name) {
108 return true;
109 }
110 for (const Frame& frame : frames_) {
111 if (frame.function_name == function_name) {
112 return true;
113 }
114 }
115 return false;
116 }
117
FormatForError() const118 string FunctionStack::FormatForError() const {
119 std::vector<string> msgs;
120 for (int i = 0; i < frames_.size(); ++i) {
121 if (frames_[i].function_name.empty()) {
122 // Empty function body should only happen at the top level, i.e. i = 0.
123 // All internal frames should have valid function names.
124 msgs.push_back(absl::StrCat("Graph contains node ",
125 FormatNodeForError(*frames_[i].node)));
126
127 } else {
128 msgs.push_back(absl::StrCat(
129 "Function ", errors::FormatFunctionForError(frames_[i].function_name),
130 " contains node ", FormatNodeForError(*frames_[i].node)));
131 }
132 const string& fname = (i + 1 < frames_.size())
133 ? frames_[i + 1].function_name
134 : current_function_name_;
135 msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
136 " calls function ",
137 errors::FormatFunctionForError(fname)));
138 }
139 return absl::StrJoin(msgs, "\n ");
140 }
141
142 namespace {
143
// Op type of the pass-through nodes inserted around isolated nodes.
// (The former `OutputEdgeMap` alias was never referenced in this file and has
// been removed.)
constexpr char kIdentityOp[] = "Identity";
147
// Returns a name not yet in `*node_names` and records it there. Uses
// `candidate_name` unchanged if it is free; otherwise it tries
// "<candidate_name>_0", "<candidate_name>_1", ... and takes the first free one.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  // insert() reports via `.second` whether the name was newly added.
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }
  for (int suffix = 0;; ++suffix) {
    string attempt = absl::StrCat(candidate_name, "_", suffix);
    if (node_names->insert(attempt).second) {
      return attempt;
    }
  }
}
163
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)164 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
165 std::unordered_set<string>* node_names) {
166 const Edge* edge;
167 TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
168
169 string identity_name = Uniquify(
170 absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
171
172 NodeDefBuilder builder(identity_name, kIdentityOp);
173 builder.Attr("T", node->input_type(input_idx));
174 NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
175 node->input_type(input_idx));
176 builder.Input(input);
177 NodeDef identity_def;
178 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
179 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
180
181 VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
182 << edge->src_output() << " -> " << edge->dst()->name() << ":"
183 << input_idx << " \n"
184 << identity_def.DebugString();
185
186 Status status;
187 Node* identity_node = graph->AddNode(identity_def, &status);
188 if (!status.ok()) {
189 return status;
190 }
191 graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
192
193 // Replace node's `input_idx` input with the new identity's 0'th output
194 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
195
196 VLOG(6) << "Successfully inserted identity. Modified node: \n"
197 << node->DebugString();
198 return Status::OK();
199 }
200
201 struct EdgePtrCompare {
operator ()tensorflow::__anon58bc023b0211::EdgePtrCompare202 bool operator()(const Edge* lhs, const Edge* rhs) const {
203 return lhs->id() < rhs->id();
204 }
205 };
206
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)207 Status AddOutputIdentities(Node* node, Graph* graph,
208 std::unordered_set<string>* node_names) {
209 auto add_identity = [&](int src_output, const string& identity_name,
210 Node** identity_node) {
211 NodeDefBuilder builder(identity_name, kIdentityOp);
212 builder.Attr("T", node->output_type(src_output));
213 NodeDefBuilder::NodeOut input(node->name(), src_output,
214 node->output_type(src_output));
215 builder.Input(input);
216 NodeDef identity_def;
217 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
218 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
219
220 Status status;
221 *identity_node = graph->AddNode(identity_def, &status);
222 if (!status.ok()) {
223 return status;
224 }
225 graph->AddEdge(node, src_output, *identity_node, 0);
226 return Status::OK();
227 };
228
229 // output_used[i] == true iff `node`'s i'th output is used
230 // in this graph
231 std::vector<bool> output_used(node->num_outputs(), false);
232 // Copy the set of edges since EdgeSet does not allow modifications
233 // to graph edges during iteration.
234 const EdgeSet& out_edges = node->out_edges();
235 std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
236 std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
237 for (const Edge* edge : edge_vector) {
238 if (edge->IsControlEdge()) {
239 continue;
240 }
241 output_used[edge->src_output()] = true;
242
243 Node* dst = edge->dst();
244 int dst_input = edge->dst_input();
245 int src_output = edge->src_output();
246 string identity_name =
247 Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
248 Node* identity_node;
249 TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
250 VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
251 << " -> " << dst->name() << ":" << dst_input << " \n"
252 << identity_node->DebugString();
253
254 // Make original dst node consume the new identity's output instead of
255 // `node`'s output.
256 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
257 }
258
259 for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
260 if (output_used[output_idx]) {
261 continue;
262 }
263 // The output is unused in the graph. Just add an identity
264 // consuming it.
265 string identity_name = Uniquify(node->name(), node_names);
266 Node* identity_node;
267 TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
268 VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
269 << " -> <no consumer>: \n"
270 << identity_node->DebugString();
271 }
272 return Status::OK();
273 }
274
IsolateNode(Node * node,Graph * graph)275 Status IsolateNode(Node* node, Graph* graph) {
276 // We use `node_names` to make sure we pick unique names.
277 // We don't use graph->NewName() because it produces verbose names and
278 // does not actually ensure that they are unique (it assumes all names
279 // are generated using it, which is not true today).
280 std::unordered_set<string> node_names(graph->num_nodes());
281 for (Node* n : graph->nodes()) {
282 node_names.insert(n->name());
283 }
284
285 for (int i = 0; i < node->num_inputs(); ++i) {
286 TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
287 }
288 TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
289 return Status::OK();
290 }
291
292 } // namespace
293
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)294 Status IsolatePlacerInspectionRequiredOps(
295 const FunctionLibraryDefinition& flib_def, Graph* graph) {
296 PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
297 // It is OK to add nodes to the graph during iteration.
298 // New nodes will get ids above current ids. The loop
299 // will loop over current nodes only because the op_nodes()
300 // iterator uses node ids to iterate.
301 // Because the new nodes will be higher ids, the caching in
302 // the checker will also work fine as new nodes are added.
303 for (Node* node : graph->op_nodes()) {
304 bool should_be_isolated = false;
305 TF_RETURN_IF_ERROR(
306 checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
307 if (!should_be_isolated) {
308 continue;
309 }
310 TF_RETURN_IF_ERROR(IsolateNode(node, graph));
311 }
312
313 return Status::OK();
314 }
315
316 } // namespace tensorflow
317