1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
17 #include "absl/container/flat_hash_set.h"
18 #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
19
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/cc/framework/ops.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/device_base.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/graph_to_functiondef.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/graph/node_builder.h"
33 #include "tensorflow/core/grappler/mutable_graph_view.h"
34 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
35 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
36 #include "tensorflow/core/grappler/utils.h"
37 #include "tensorflow/core/grappler/utils/functions.h"
38 #include "tensorflow/core/lib/gtl/map_util.h"
39
40 namespace tensorflow {
41 namespace grappler {
42 namespace vectorization_utils {
43
44 namespace {
45
46 // Describes a tensor with its operation Node and output position
47 typedef std::pair<Node*, int> TensorDesc;
48
49 constexpr char kRetValOp[] = "_Retval";
50
ReplaceEdgeSources(const TensorDesc & old_src,const TensorDesc & new_src,Graph * graph)51 void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
52 Graph* graph) {
53 // NOTE: We need two for loops here because we can't mutate the set of output
54 // edges as we iterate over them.
55 std::vector<const Edge*> edges_to_replace;
56 for (auto edge : old_src.first->out_edges()) {
57 if (edge->src_output() == old_src.second) {
58 edges_to_replace.push_back(edge);
59 }
60 }
61 for (auto edge : edges_to_replace) {
62 graph->AddEdge(new_src.first, new_src.second, edge->dst(),
63 edge->dst_input());
64 graph->RemoveEdge(edge);
65 }
66 }
67
68 // Update node attrs to keep its properties consistent with the function
UpdateMapDefunAttrs(FunctionBody * map_defun_fn,Node * map_defun_node)69 void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
70 map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);
71
72 // TODO(rachelim): Propagate precise shapes if they're known, which may enable
73 // subsequent optimizations.
74 map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
75 map_defun_fn->ret_types.size()));
76 }
77
AddMapDefunOutput(FunctionBody * map_defun_fn,Node * map_defun_node,const TensorDesc & output)78 Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
79 const TensorDesc& output) {
80 DataType type = output.first->output_type(output.second);
81 int index = map_defun_fn->ret_nodes.size();
82
83 NodeDef ret_node_def;
84 ret_node_def.set_name("map_out");
85 ret_node_def.set_op(kRetValOp);
86 AddNodeAttr("T", type, &ret_node_def);
87 AddNodeAttr("index", index, &ret_node_def);
88
89 Status s;
90 Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
91 TF_RETURN_IF_ERROR(s);
92
93 map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
94 map_defun_fn->ret_nodes.push_back(ret_node);
95 map_defun_fn->ret_types.push_back(type);
96 UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
97
98 return s;
99 }
100
RemoveMapDefunOutput(int output_position,Graph * outer_scope,FunctionBody * map_defun_fn,Node * map_defun_node)101 void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
102 FunctionBody* map_defun_fn, Node* map_defun_node) {
103 DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
104 << "Trying to remove output that doesn't exist. Output number: "
105 << output_position;
106
107 int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
108
109 // Modify map_defun_fn's signature and remove the output node from its graph
110 map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
111 map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
112 output_position);
113 map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
114 output_position);
115 UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
116
117 // Renumber the nodes and edges that come after
118 for (int i = 0; i < num_later_outputs; ++i) {
119 ReplaceEdgeSources({map_defun_node, output_position + i + 1},
120 {map_defun_node, output_position + i}, outer_scope);
121 // Each ret node has an "index" attr that has to be updated
122 map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
123 output_position + i);
124 }
125 }
126
127 // Helper class that vectorizes the body of a MapDefun node, adding new
128 // operations to the graph that collectively compute the same value as what
129 // running the MapDefun function on slices of the input would produce.
130 // This class transforms the input FunctionDefs into their corresponding
131 // Graph objects and works on the graphs directly, then converts them back
132 // to FunctionDefs when GetResult is called.
133 // TODO(rachelim): Move this to its own header.
134 class Vectorization {
135 public:
Vectorization(FunctionDefLibrary * lib)136 explicit Vectorization(FunctionDefLibrary* lib)
137 : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
138
139 // Adds the vectorized function and new map_defun_fn to lib, and points
140 // vectorized_function to the former. Returns an error status if
141 // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
142 // along the way.
143 Status Vectorize(const FunctionDef& outer_scope,
144 const NodeDef& map_defun_node, FunctionDef** result);
145
146 private:
147 // Converts FunctionDefs to Graphs and adds mappings from
148 // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
149 Status Initialize(const FunctionDef& outer_scope,
150 const NodeDef& map_defun_node);
151
152 // Converts Graphs back to FunctionDefs and adds them to `lib_`.
153 Status GetResult(FunctionDef** vectorized_function);
154
155 // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
156 // `outer_scope_`, until there are no convertible outputs remaining.
157 void VectorizeHelper();
158
159 // Vectorizes map_defun_fn's output at output_position.
160 Status ConvertOutput(int output_position);
161
162 // Adds mappings from node's outputs tensors to converted output tensors,
163 // creating the necessary new node(s). Generally, the steps to convert an op
164 // are:
165 // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
166 // These operations collectively compute the same value as what running
167 // the original operation on slices of the input tensors would produce.
168 // For example, a Cast op in MapDefun translates to a Cast op in
169 // `outer_scope_`, since the vectorized version of Cast is itself.
170 // 2) Promote the inputs of the op inputs to outputs of the
171 // `map_defun_node_` and `map_defun_fn_`.
172 // 3) Add edges between the promoted inputs (that are now outputs of
173 // `map_defun_node`) and the inputs ports of the new node(s).
174 // 4) For each output of the old node, add the mapping of output tensors to
175 // the conversion map.
176 Status AddConversionMapping(Node* op_node);
177
178 // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
179 // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
180 // inputs to `map_defun_node_`. This stacked tensor will be compatible with
181 // the expected output shape of `map_defun_node_`.
182 // This is equivalent to the _stack function in python Pfor.
183 Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
184
185 // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
186 // doing a depth-first search from the ret nodes. Lifts tensors that are
187 // unstacked (i.e. don't derive from arg tensors) into `outer_scope_` directly
188 // and adds mappings to `conversion_map_`.
189 // Note that this function may have false negatives, i.e. not
190 // add mappings for some tensors that are unstacked. This may happen in the
191 // following cases: 1) a vectorized op produces unstacked outputs from stacked
192 // inputs (e.g. the vectorized "Shape" op), 2) the tensors are in a cycle, or
193 // 3) the unstacked op could not be lifted into `outer_scope`.
194 Status AddUnstackedTensorMappings();
195
196 // Recursive helper for `AddUnstackedTensorMappings`. If an op node is
197 // unstacked, lifts its output tensors into `outer_scope`, adding the mappings
198 // to `conversion_map`. Returns true if the unstacked mappings were added.
199 bool AddUnstackedTensorMappingsHelper(
200 TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited);
201
202 // Add mappings from `map_defun_fn_` arg tensors to `map_defun_node_` input
203 // tensors to `conversion_map_`.
204 Status AddArgTensorMappings();
205
206 // Maps a tensor to the corresponding WrappedTensor. For example,
207 // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
208 std::map<TensorDesc, WrappedTensor> conversion_map_;
209
210 // Unconvertible ret nodes
211 std::set<Node*> unconvertible_;
212
213 FunctionDefLibrary* lib_; // Not owned
214 FunctionLibraryDefinition lib_def_;
215 // Note that FunctionBody has a pointer to a Graph object that corresponds
216 // to the function's subgraph, with additional kArgOp and kRetValOp nodes
217 // that denote that function arguments and return values. These nodes have the
218 // attrs "T" for the type, and "index" for the argument / retval index
219 // respectively. FunctionBody also keeps track of arg/ret_nodes and
220 // arg/ret_types, that should be ordered according to argument/output indices.
221 std::unique_ptr<Graph> outer_scope_;
222 std::unique_ptr<FunctionBody> map_defun_fn_;
223 Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
224
225 // Caches the loop_len_node_ needed for tiling unstacked output. This
226 // corresponds to a vector with one element.
227 Node* loop_len_node_ = nullptr; // Owned by `outer_scope`
228 Status status_;
229 };
230
AddConversionMapping(Node * op_node)231 Status Vectorization::AddConversionMapping(Node* op_node) {
232 for (auto edge : op_node->in_edges()) {
233 if (edge->IsControlEdge()) {
234 return errors::InvalidArgument(
235 "Vectorizing outputs with control inputs is currently not "
236 "supported.");
237 }
238 }
239
240 auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
241 if (vectorizer == nullptr) {
242 return errors::Unimplemented("No vectorizer registered for op: ",
243 op_node->type_string());
244 }
245 std::vector<WrappedTensor> inputs, outputs;
246 inputs.reserve(op_node->num_inputs());
247 outputs.reserve(op_node->num_outputs());
248
249 std::vector<const Edge*> input_edges;
250 TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
251
252 // The inputs for the node to be converted may already have been converted
253 // themselves. For those that are not, we promote them to MapDefun outputs.
254 for (int i = 0; i < op_node->num_inputs(); ++i) {
255 auto edge = input_edges[i];
256 if (auto found = gtl::FindOrNull(conversion_map_,
257 {edge->src(), edge->src_output()})) {
258 inputs.push_back(*found);
259 } else {
260 // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
261 // We assume that all unconverted inputs will be stacked, since we
262 // converted all unstacked nodes in `Initialize`. However, it's actually
263 // possible that yet-unconverted nodes may produce unstacked outputs after
264 // they are vectorized. (For example, see the "Shape" converter in
265 // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
266 // an unstacked input but receives a stacked one, vectorizer->Vectorize
267 // will return an error.
268 TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
269 {edge->src(), edge->src_output()}));
270 int output_index = map_defun_fn_->ret_nodes.size() - 1;
271 inputs.push_back({map_defun_node_, output_index, true});
272 }
273 }
274
275 Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
276 std::move(inputs), &outputs);
277 if (!s.ok()) {
278 VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
279 << "\" failed with error: " << s;
280 return s;
281 }
282 const int64 op_node_num_outputs = op_node->num_outputs();
283 if (op_node_num_outputs != outputs.size()) {
284 return errors::Internal(
285 "Number of vectorizer outputs does not match. Expected: ",
286 op_node->num_outputs(), " Actual: ", outputs.size());
287 }
288
289 // Add output mappings.
290 for (int i = 0; i < op_node->num_outputs(); ++i) {
291 conversion_map_.insert({{op_node, i}, outputs[i]});
292 }
293
294 return Status::OK();
295 }
296
ConvertOutput(int output_position)297 Status Vectorization::ConvertOutput(int output_position) {
298 // ret_edge->src() is the actual op that generated the retval, and
299 // ret_edge->dst() is the retval node whose op is "_Retval"
300 const Edge* ret_edge;
301 TF_RETURN_IF_ERROR(
302 map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
303
304 TensorDesc output({ret_edge->src(), ret_edge->src_output()});
305 TensorDesc converted_output;
306
307 // It's possible the output already has a mapping, if it comes from a node
308 // that has already been converted.
309 auto found = gtl::FindOrNull(conversion_map_, output);
310 if (!found) {
311 TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
312 found = &conversion_map_.at(output);
313 }
314
315 if (found->stacked) {
316 converted_output = {found->node, found->output_index};
317 } else {
318 // Some outputs may be unstacked if they don't derive from arg nodes
319 // (for example, if a function returns a constant). For these, we
320 // have to add extra nodes to tile it in the 0th dimension.
321 TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
322 }
323
324 ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
325 outer_scope_.get());
326 RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
327 map_defun_node_);
328
329 return Status::OK();
330 }
331
Vectorize(const FunctionDef & outer_scope,const NodeDef & map_defun_node,FunctionDef ** result)332 Status Vectorization::Vectorize(const FunctionDef& outer_scope,
333 const NodeDef& map_defun_node,
334 FunctionDef** result) {
335 TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
336 VectorizeHelper();
337 return GetResult(result);
338 }
339
VectorizeHelper()340 void Vectorization::VectorizeHelper() {
341 while (true) {
342 int output_position = graph_utils::GetFirstElementIndexWithPredicate(
343 [this](Node* n) {
344 return this->unconvertible_.find(n) == this->unconvertible_.end();
345 },
346 map_defun_fn_->ret_nodes);
347
348 // No outputs left to convert
349 if (output_position == -1) break;
350
351 Status s = ConvertOutput(output_position);
352 if (!s.ok()) {
353 Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
354 VLOG(2) << "Could not convert the output at node: "
355 << output_node->DebugString() << "\nError: " << s;
356 unconvertible_.insert(output_node);
357 }
358 }
359
360 // If we've converted all the outputs of the MapDefun function, we no longer
361 // need the MapDefun node and can delete it.
362 if (map_defun_fn_->ret_nodes.empty()) {
363 outer_scope_->RemoveNode(map_defun_node_);
364 }
365 }
366
Initialize(const FunctionDef & outer_scope,const NodeDef & map_defun_node)367 Status Vectorization::Initialize(const FunctionDef& outer_scope,
368 const NodeDef& map_defun_node) {
369 // Convert outer_scope and map_defun_fn to FunctionBodys so we can
370 // work on Graphs directly.
371 const FunctionDef* map_defun_fn =
372 lib_def_.Find(map_defun_node.attr().at("f").func().name());
373
374 if (map_defun_fn == nullptr) {
375 return errors::NotFound("Could not find function with name ",
376 map_defun_node.attr().at("f").func().name(),
377 " in function library.");
378 }
379
380 std::unique_ptr<FunctionBody> outer_fn;
381 TF_RETURN_IF_ERROR(
382 FunctionDefToBodyHelper(outer_scope, {}, &lib_def_, &outer_fn));
383 // We don't need outer_fn, just the graph
384 outer_scope_.reset(outer_fn->graph);
385 outer_fn->graph = nullptr;
386
387 TF_RETURN_IF_ERROR(
388 FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_, &map_defun_fn_));
389
390 // Find the MapDefun node in outer_scope_
391 int node_id = graph_utils::GetFirstElementIndexWithPredicate(
392 [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
393 outer_scope_->nodes());
394 if (node_id == -1) {
395 return errors::NotFound("Could not find node with name ",
396 map_defun_node.name(), " in outer_scope.");
397 }
398 map_defun_node_ = outer_scope_->FindNodeId(node_id);
399
400 TF_RETURN_IF_ERROR(AddArgTensorMappings());
401 TF_RETURN_IF_ERROR(AddUnstackedTensorMappings());
402 loop_len_node_ = nullptr;
403
404 return Status::OK();
405 }
406
407 // TODO(rachelim): It might be profitable to use the C++ API for this instead of
408 // NodeBuilder
StackTensor(WrappedTensor * unstacked,TensorDesc * result)409 Status Vectorization::StackTensor(WrappedTensor* unstacked,
410 TensorDesc* result) {
411 if (unstacked->node->output_type(unstacked->output_index) == DT_VARIANT) {
412 // TODO(b/124069171): "ExpandDims" doesn't work with Variant tensors.
413 return errors::Unimplemented("Cannot stack tensor with Variant type.");
414 }
415 // Note that all these nodes are necessary as the size of the batch may not be
416 // constant.
417 if (unstacked->stacked) {
418 return errors::Internal("Can only stack unstacked tensor.");
419 }
420
421 Graph* g = outer_scope_.get();
422 auto node_builder = [](StringPiece op) {
423 return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
424 };
425
426 auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
427 Node** result) {
428 TF_RETURN_IF_ERROR(val.status);
429 return node_builder("Const")
430 .Attr("value", val.tensor)
431 .Attr("dtype", val.tensor.dtype())
432 .Finalize(graph, result);
433 };
434
435 // If loop_len_node_ hasn't been created yet, add the node and cache it.
436 if (loop_len_node_ == nullptr) {
437 Node* input_node;
438 TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
439
440 Node* shape_node;
441 TF_RETURN_IF_ERROR(
442 node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
443
444 Node* const_vec_0;
445 TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
446 Node* const_vec_1;
447 TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
448
449 Node* strided_slice_node;
450 TF_RETURN_IF_ERROR(node_builder("StridedSlice")
451 .Input(shape_node) // input
452 .Input(const_vec_0) // begin
453 .Input(const_vec_1) // end
454 .Input(const_vec_1) // strides
455 .Finalize(g, &strided_slice_node));
456
457 // Produces a vector of length 1
458 TF_RETURN_IF_ERROR(node_builder("Reshape")
459 .Input(strided_slice_node) // tensor
460 .Input(const_vec_1) // shape
461 .Finalize(g, &loop_len_node_));
462 }
463
464 Node* ones_shape;
465 TF_RETURN_IF_ERROR(node_builder("Shape")
466 .Input(unstacked->node) // input
467 .Finalize(g, &ones_shape));
468
469 Node* ones;
470 TF_RETURN_IF_ERROR(
471 node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
472
473 Node* const_0;
474 TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
475
476 Node* multiples;
477 TF_RETURN_IF_ERROR(node_builder("Concat")
478 .Input(const_0) // concat_dim
479 .Input({{loop_len_node_, 0}, {ones, 0}}) // values
480 .Finalize(g, &multiples));
481
482 Node* expand_dims;
483 TF_RETURN_IF_ERROR(node_builder("ExpandDims")
484 .Input(unstacked->node) // input
485 .Input(const_0) // dim
486 .Finalize(g, &expand_dims));
487
488 TF_RETURN_IF_ERROR(node_builder("Tile")
489 .Input(expand_dims) // input
490 .Input(multiples) // multiples
491 .Finalize(g, &result->first));
492 result->second = 0;
493 return Status::OK();
494 }
495
AddArgTensorMappings()496 Status Vectorization::AddArgTensorMappings() {
497 // Note that inputs to map_defun_fn_ are either regular arguments (for which
498 // the operations are mapped across their 0th dimension) or captured inputs
499 // (for which the operations apply to the argument wholesale).
500 int num_args =
501 map_defun_node_->attrs().Find("Targuments")->list().type_size();
502
503 auto add_conversion = [this](Node* arg_node, bool stacked) {
504 Node* input_node;
505 TF_RETURN_IF_ERROR(map_defun_node_->input_node(
506 arg_node->attrs().Find("index")->i(), &input_node));
507
508 conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});
509
510 // Control inputs
511 conversion_map_.insert({{arg_node, Graph::kControlSlot},
512 {input_node, Graph::kControlSlot, stacked}});
513
514 return Status::OK();
515 };
516
517 // Regular arguments
518 for (int i = 0; i < num_args; ++i) {
519 TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
520 }
521
522 // Captured inputs. These are applied (without slicing) to every iteration of
523 // the map function, hence are mapped to unstacked nodes.
524 for (int i = num_args, end = map_defun_fn_->arg_nodes.size(); i < end; ++i) {
525 TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
526 }
527
528 return Status::OK();
529 }
530
AddUnstackedTensorMappingsHelper(TensorDesc && tensor,absl::flat_hash_set<const Edge * > * visited)531 bool Vectorization::AddUnstackedTensorMappingsHelper(
532 TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited) {
533 if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
534 return !found->stacked;
535 }
536
537 if (tensor.first->op_def().is_stateful()) {
538 // We don't lift stateful nodes directly out of the MapDefun, since they may
539 // have to be executed N times.
540 return false;
541 }
542
543 bool is_unstacked = true;
544 for (const auto& edge : tensor.first->in_edges()) {
545 // Ignore Source nodes. Note that these are also ignored in the
546 // GraphToFunctionDef conversion.
547 if (edge->src()->IsSource()) continue;
548
549 if (visited->find(edge) != visited->end()) {
550 // If we've visited this edge already, we're in a cycle. In this case, we
551 // are conservative and don't mark the node as unstacked.
552 is_unstacked = false;
553 continue;
554 }
555 visited->insert(edge);
556
557 // A node is unstacked if all of its inputs are unstacked
558 is_unstacked &= AddUnstackedTensorMappingsHelper(
559 {edge->src(), edge->src_output()}, visited);
560 }
561
562 if (!is_unstacked) {
563 return false;
564 }
565
566 // If the node is unstacked, we copy it into outer_scope_ and
567 // add it to the map. Note that we don't clean up the nodes that are copied
568 // in map_defun_fn_, and rely on them being pruned out later.
569 Status status;
570 Node* node = outer_scope_->AddNode(tensor.first->def(), &status);
571 if (!status.ok()) return false;
572
573 // Add input edges to nodes that should already have been lifted.
574 for (const auto& edge : tensor.first->in_edges()) {
575 // Ignore Source nodes. Note that these are also ignored in the
576 // GraphToFunctionDef conversion.
577 if (edge->src()->IsSource()) continue;
578
579 if (auto found = gtl::FindOrNull(conversion_map_,
580 {edge->src(), edge->src_output()})) {
581 outer_scope_->AddEdge(found->node, found->output_index, node,
582 edge->dst_input());
583 } else {
584 return false;
585 }
586 }
587
588 // Add output mappings
589 for (int i = 0; i < tensor.first->num_outputs(); ++i) {
590 conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
591 }
592 conversion_map_.insert({{tensor.first, Graph::kControlSlot},
593 WrappedTensor(node, Graph::kControlSlot, false)});
594
595 return true;
596 }
597
AddUnstackedTensorMappings()598 Status Vectorization::AddUnstackedTensorMappings() {
599 absl::flat_hash_set<const Edge*> visited;
600 for (const auto& ret_node : map_defun_fn_->ret_nodes) {
601 const Edge* in_edge = nullptr;
602 TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
603 AddUnstackedTensorMappingsHelper({in_edge->src(), in_edge->src_output()},
604 &visited);
605 }
606 return Status::OK();
607 }
608
GetResult(FunctionDef ** vectorized_function)609 Status Vectorization::GetResult(FunctionDef** vectorized_function) {
610 TF_RETURN_IF_ERROR(status_);
611 TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
612 TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));
613
614 if (!map_defun_fn_->ret_nodes.empty()) {
615 FunctionDef* map_defun_fn = lib_->add_function();
616 graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
617 TF_RETURN_IF_ERROR(GraphToFunctionDef(
618 *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
619
620 AttrValue func_attr;
621 func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
622 map_defun_node_->AddAttr("f", func_attr);
623 }
624
625 *vectorized_function = lib_->add_function();
626 graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
627 *vectorized_function);
628 TF_RETURN_IF_ERROR(GraphToFunctionDef(
629 *outer_scope_, (*vectorized_function)->signature().name(),
630 *vectorized_function));
631 return Status::OK();
632 }
633
634 } // namespace
635
VectorizeMapDefun(const FunctionDef & outer_scope,const NodeDef & map_defun_node,FunctionDefLibrary * lib,FunctionDef ** result)636 Status VectorizeMapDefun(const FunctionDef& outer_scope,
637 const NodeDef& map_defun_node, FunctionDefLibrary* lib,
638 FunctionDef** result) {
639 *result = nullptr;
640 return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
641 }
642
643 } // namespace vectorization_utils
644 } // namespace grappler
645 } // namespace tensorflow
646