1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/encapsulate_util.h"
17
18 #include <algorithm>
19 #include <iterator>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/jit/shape_inference.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/node_builder.h"
29 #include "tensorflow/core/protobuf/error_codes.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31
32 using stream_executor::port::StatusOr;
33
34 namespace tensorflow {
35
36 namespace {
37
38 // Returns string attribute value for the node if the attribute is present,
39 // otherwise returns empty optional value.
GetStringAttr(const Node & n,const string & attr_name)40 absl::optional<string> GetStringAttr(const Node& n, const string& attr_name) {
41 auto attr = n.attrs().Find(attr_name);
42 if (!attr) {
43 return absl::nullopt;
44 } else {
45 return attr->s();
46 }
47 }
48
49 // Adds a value to the node's list attribute.
50 template <typename T>
AppendToListAttr(Node * n,const string & attr_name,const string & value)51 Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
52 std::vector<T> attr_value;
53 Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
54 if (!s.ok() && s.code() != error::NOT_FOUND) {
55 return s;
56 }
57
58 n->ClearAttr(attr_name);
59 attr_value.push_back(value);
60 n->AddAttr(attr_name, attr_value);
61 return Status::OK();
62 }
63
64 // Replaces attribute value.
65 template <typename T>
ReplaceAttr(Node * n,const string & attr_name,const T & value)66 void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
67 n->ClearAttr(attr_name);
68 n->AddAttr(attr_name, value);
69 }
70
71 // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
72 // `PreprocessEdgesBetweenOutsideCompilations` for details.
PreprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)73 Status PreprocessControlEdgesBetweenOutsideCompilations(
74 Graph* g, const string& outside_compilation_attr_name) {
75 // Gather edges to remove. We should not remove the edge while iterating.
76 std::vector<const Edge*> edges_to_remove;
77 for (const Edge* e : g->edges()) {
78 if (!e->IsControlEdge()) {
79 continue;
80 }
81
82 auto src_outside_compilation =
83 GetStringAttr(*e->src(), outside_compilation_attr_name);
84 auto dst_outside_compilation =
85 GetStringAttr(*e->dst(), outside_compilation_attr_name);
86
87 if (src_outside_compilation && dst_outside_compilation) {
88 if (*src_outside_compilation != *dst_outside_compilation) {
89 // Case 1a: outside compilation to outside compilation control edge.
90 edges_to_remove.push_back(e);
91
92 TF_RETURN_IF_ERROR(AppendToListAttr<string>(
93 e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
94 e->src()->name()));
95 }
96 } else if (src_outside_compilation && !dst_outside_compilation) {
97 // Case 1b: outside compilation to its XLA computation control edge.
98 ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
99 } else if (!src_outside_compilation && dst_outside_compilation) {
100 // Case 1b: XLA computation to outside compilation in it control edge.
101 ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
102 }
103 }
104
105 for (auto e : edges_to_remove) {
106 g->RemoveEdge(e);
107 }
108 return Status::OK();
109 }
110
// Step 2 of `PreprocessEdgesBetweenOutsideCompilations` (see that function's
// comments). Every data edge whose two endpoints both carry
// `outside_compilation_attr_name` with *different* values (an edge between
// two different outside compilation clusters) is removed. Its destination is
// fed instead from a "Placeholder" node that:
//   - is named "<src name>_oc_to_oc_placeholder_<src output>",
//   - has `dtype` equal to the type of the source output,
//   - carries the destination node's outside compilation attribute value,
//   - records the original source node name and output index in
//     `kOutsideCompilationOriginalNodeAttrName` and
//     `kOutsideCompilationSrcOutputAttrName`, which
//     `PostprocessDataEdgesBetweenOutsideCompilations` reads to restore the
//     edge.
// All consumers of the same (src name, src output) share one placeholder.
// Each rewired destination is replaced (via `ReplaceNode`) by a node whose
// NodeDef input at that position names the placeholder.
Status PreprocessDataEdgesBetweenOutsideCompilations(
    Graph* g, const string& outside_compilation_attr_name) {
  // Gather data edges between different outside compilation clusters. The
  // edges are stored as (dst input, dst node id) instead of `Edge*` because
  // destination nodes are replaced below via `ReplaceNode`, which (presumably)
  // removes the old node and its edges, so stored Edge pointers could dangle.
  struct EdgeInfo {
    int dst_input, dst_node_id;
  };
  std::vector<EdgeInfo> edges;
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    auto src_outside_compilation =
        GetStringAttr(*e->src(), outside_compilation_attr_name);
    auto dst_outside_compilation =
        GetStringAttr(*e->dst(), outside_compilation_attr_name);

    if (src_outside_compilation && dst_outside_compilation &&
        *src_outside_compilation != *dst_outside_compilation) {
      edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
      VLOG(4) << "Oc -> oc edge: " << e->DebugString();
    }
  }

  // Remove each oc -> oc data edge and feed its destination from a
  // placeholder instead. Placeholders are keyed by (src name, src output).
  std::map<std::pair<string, int>, Node*> placeholders;
  for (int i = 0, end = edges.size(); i < end; i++) {
    // Re-resolve the edge through the destination's current input edge, so
    // `src` is the live node even if the source was itself replaced by an
    // earlier iteration.
    Node* dst = g->FindNodeId(edges[i].dst_node_id);
    const Edge* e;
    TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
    Node* src = e->src();
    int src_output = e->src_output(), dst_input = e->dst_input();
    g->RemoveEdge(e);

    // Find or create placeholder node.
    string new_name =
        absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
    auto placeholder_index = std::make_pair(src->name(), src_output);
    auto iter = placeholders.find(placeholder_index);
    Node* placeholder_node;
    if (iter == placeholders.end()) {
      NodeDefBuilder placeholder_builder(new_name, "Placeholder");
      placeholder_builder.Attr("dtype", src->output_type(src_output));
      // NOTE(review): the placeholder takes the outside compilation attribute
      // of the *first* destination seen for this (src, src_output). If the
      // same output feeds nodes in two different clusters, later consumers are
      // wired to a placeholder tagged with the first cluster. Confirm this is
      // intended.
      string outside_compilation_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
                                     outside_compilation_attr_name,
                                     &outside_compilation_attr));
      placeholder_builder.Attr(outside_compilation_attr_name,
                               outside_compilation_attr);
      // Remember where the value really comes from, for postprocessing.
      placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
                               src->name());
      placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
                               src_output);
      NodeDef placeholder_def;
      TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
      Status s;
      placeholder_node = g->AddNode(placeholder_def, &s);
      TF_RETURN_IF_ERROR(s);
      placeholders[placeholder_index] = placeholder_node;
    } else {
      placeholder_node = iter->second;
    }
    g->AddEdge(placeholder_node, 0, dst, dst_input);

    // Replace `dst` so that its NodeDef input list agrees with the new graph
    // edge (the input at `dst_input` now names the placeholder).
    NodeDef new_def = dst->def();
    *new_def.mutable_input(dst_input) = placeholder_node->name();
    TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));

    // Later entries in `edges` may target the same destination node, which
    // has just been replaced. Point them at `dst_replace_node` so the
    // `FindNodeId` lookup above finds the live node.
    for (int j = i + 1, end = edges.size(); j < end; j++) {
      if (edges[j].dst_node_id == edges[i].dst_node_id) {
        edges[j].dst_node_id = dst_replace_node->id();
      }
    }
  }
  return Status::OK();
}
196
197 // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
198 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessDataEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)199 Status PostprocessDataEdgesBetweenOutsideCompilations(
200 Graph* g, const string& outside_compilation_attr_name) {
201 // Gather all outside compilation to outside compilation nodes.
202 std::vector<Node*> placeholder_nodes;
203 for (Node* n : g->nodes()) {
204 if (n->type_string() == "Placeholder" &&
205 HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
206 placeholder_nodes.push_back(n);
207 }
208 }
209
210 // Remove the placeholder nodes, and reconnect original edge.
211 auto node_name_index = g->BuildNodeNameIndex();
212 for (auto n : placeholder_nodes) {
213 string node_name;
214 int node_src_output;
215 TF_RETURN_IF_ERROR(GetNodeAttr(
216 n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
217 TF_RETURN_IF_ERROR(GetNodeAttr(
218 n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
219 auto iter = node_name_index.find(node_name);
220 if (iter == node_name_index.end()) {
221 return errors::Internal(
222 "Cannot find original node for oc -> host placeholder node ",
223 node_name);
224 }
225
226 // Change all usage node to use the original node instead.
227 Node* original_node = iter->second;
228 std::vector<const Edge*> control_edges;
229 std::vector<OutEdgeInfo> data_edges;
230 for (auto e : n->out_edges()) {
231 if (e->IsControlEdge()) {
232 control_edges.push_back(e);
233 } else {
234 data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
235 }
236 }
237 for (const Edge* e : control_edges) {
238 g->AddControlEdge(original_node, e->dst());
239 g->RemoveEdge(e);
240 }
241 for (int i = 0, end = data_edges.size(); i < end; i++) {
242 Node* dst = data_edges[i].dst;
243 NodeDef new_def = dst->def();
244 int dst_input = data_edges[i].dst_input;
245 *new_def.mutable_input(dst_input) =
246 absl::StrCat(original_node->name(), ":", node_src_output);
247 TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
248
249 const Edge* edge_to_replace = nullptr;
250 TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
251 g->RemoveEdge(edge_to_replace);
252 g->AddEdge(original_node, node_src_output, replace_node, dst_input);
253
254 // Other edges might have `dst` as dst node. Update those edges with
255 // `replace_node`.
256 for (int j = i + 1, end = data_edges.size(); j < end; j++) {
257 if (data_edges[j].dst == dst) {
258 data_edges[j].dst = replace_node;
259 }
260 }
261
262 // Other placeholder node might have `dst` as original node. Update
263 // `node_name_index` with `replace_node`.
264 node_name_index[replace_node->name()] = replace_node;
265 }
266
267 // Remove placeholder node.
268 g->RemoveNode(n);
269 }
270 return Status::OK();
271 }
272
273 // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
274 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)275 Status PostprocessControlEdgesBetweenOutsideCompilations(
276 Graph* g, const string& outside_compilation_attr_name) {
277 auto node_name_index = g->BuildNodeNameIndex();
278
279 // Reconnect outside compilation to outside compilation control edge.
280 for (Node* n : g->nodes()) {
281 std::vector<string> control_deps;
282 Status s =
283 GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
284 &control_deps);
285 if (!s.ok()) {
286 if (s.code() != error::NOT_FOUND) {
287 return s;
288 } else {
289 continue;
290 }
291 } else {
292 n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
293 for (const string& control_input : control_deps) {
294 auto iter = node_name_index.find(control_input);
295 if (iter == node_name_index.end()) {
296 return errors::Internal("Cannot find original node for ",
297 control_input);
298 }
299 g->AddControlEdge(iter->second, n);
300 }
301 }
302 }
303 return Status::OK();
304 }
305 } // namespace
306
// Names of node attributes used by the XLA encapsulation helpers. They are
// defined outside the anonymous namespace, whose helpers above already refer
// to them, so presumably they are declared `extern` in encapsulate_util.h.
// Confirm against the header.

// Output shapes (a list of PartialTensorShape, one per node output) stored by
// `PerformStaticShapeInferenceBeforeEncapsulation`.
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";

// Boolean markers set to true by
// `PreprocessControlEdgesBetweenOutsideCompilations` on an outside compilation
// node that has a control edge to (`To`) or from (`From`) a node without the
// outside compilation attribute.
const char kXlaConnectedToXlaComputationAttrName[] =
    "_xla_connected_to_xla_computation";
const char kXlaConnectedFromXlaComputationAttrName[] =
    "_xla_connected_from_xla_computation";
// Set on the placeholders that stand in for cross-cluster (oc -> oc) data
// edges: the original source node's name and its output index. Written by
// `PreprocessDataEdgesBetweenOutsideCompilations` and read back by
// `PostprocessDataEdgesBetweenOutsideCompilations`.
const char kOutsideCompilationOriginalNodeAttrName[] =
    "_xla_oc_to_oc_node_name";
const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
// List of source node names for the cross-cluster control edges removed by
// `PreprocessControlEdgesBetweenOutsideCompilations`. Consumed and cleared by
// `PostprocessControlEdgesBetweenOutsideCompilations`, which re-adds the edges.
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
    "_xla_control_dependencies_within_xla_cluster";
// The following names are not referenced elsewhere in this file. Presumably
// other encapsulation passes use them; see those users for their semantics.
const char kXlaIsLiftedArgAttrName[] = "_xla_is_lifted_arg";
const char kXlaLiftedArgOutsideCompilationAttrName[] = "_xla_lifted_arg_oc";
const char kXlaOutsideCompilationInputsAttrName[] = "_xla_oc_inputs";
const char kXlaIsPlaceholderForArg[] = "_xla_is_placeholder_for_arg";
322
PerformStaticShapeInferenceBeforeEncapsulation(Graph * g)323 Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
324 // Perform shape inference.
325 std::map<int, InferredShape> arg_shapes;
326 GraphShapeInfo shape_info;
327 TF_RETURN_IF_ERROR(
328 InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
329
330 // Add attribute for output shapes.
331 auto node_name_index = g->BuildNodeNameIndex();
332 for (auto iter : shape_info) {
333 std::vector<PartialTensorShape> output_shapes;
334 std::transform(iter.second.begin(), iter.second.end(),
335 std::back_inserter(output_shapes),
336 [](const InferredShape& inferred_shape) {
337 return inferred_shape.shape;
338 });
339 Node* n = node_name_index[iter.first];
340 n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
341 }
342
343 return Status::OK();
344 }
345
346 StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(const Graph * g,const string & outside_compilation_attr_name)347 OutsideCompilationClusterDependencies(
348 const Graph* g, const string& outside_compilation_attr_name) {
349 auto cluster_deps = absl::make_unique<
350 absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
351
352 for (const Edge* e : g->edges()) {
353 auto src_outside_compilation =
354 GetStringAttr(*e->src(), outside_compilation_attr_name);
355 auto dst_outside_compilation =
356 GetStringAttr(*e->dst(), outside_compilation_attr_name);
357
358 if (src_outside_compilation && dst_outside_compilation &&
359 *src_outside_compilation != *dst_outside_compilation) {
360 auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
361 if (dst_deps_it == cluster_deps->end()) {
362 cluster_deps->insert(std::make_pair(
363 *dst_outside_compilation,
364 absl::flat_hash_set<string>({*src_outside_compilation})));
365 } else {
366 dst_deps_it->second.insert(*src_outside_compilation);
367 }
368 }
369 }
370
371 auto cluster_deps_ordered =
372 absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
373
374 for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
375 std::vector<string> ordered_deps(it->second.begin(), it->second.end());
376 std::sort(ordered_deps.begin(), ordered_deps.end());
377 cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
378 }
379
380 return std::move(cluster_deps_ordered);
381 }
382
PreprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)383 Status PreprocessEdgesBetweenOutsideCompilations(
384 Graph* g, const string& outside_compilation_attr_name) {
385 // Remove edges from source node to outside compilation nodes, and edges
386 // from outside compilation nodes to sink node.
387 std::vector<const Edge*> edges_to_remove;
388 for (const Edge* e : g->source_node()->out_edges()) {
389 if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
390 edges_to_remove.push_back(e);
391 }
392 }
393 for (const Edge* e : g->sink_node()->in_edges()) {
394 if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
395 edges_to_remove.push_back(e);
396 }
397 }
398 for (auto e : edges_to_remove) {
399 g->RemoveEdge(e);
400 }
401
402 TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
403 g, outside_compilation_attr_name));
404 TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
405 g, outside_compilation_attr_name));
406 return Status::OK();
407 }
408
PostprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)409 Status PostprocessEdgesBetweenOutsideCompilations(
410 Graph* g, const string& outside_compilation_attr_name) {
411 TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
412 g, outside_compilation_attr_name));
413 TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
414 g, outside_compilation_attr_name));
415 return Status::OK();
416 }
417
418 } // namespace tensorflow
419