/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"

#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace tensorflow {
namespace grappler {
namespace vectorization_utils {

namespace {

// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;

constexpr char kRetValOp[] = "_Retval";

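// Rewrites `graph` so that every edge that previously read the tensor
// `old_src` (node, output index) reads the tensor `new_src` instead.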
void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
                        Graph* graph) {
  // NOTE: We need two for loops here because we can't mutate the set of output
  // edges as we iterate over them.
  std::vector<const Edge*> edges_to_replace;
  for (auto edge : old_src.first->out_edges()) {
    if (edge->src_output() == old_src.second) {
      edges_to_replace.push_back(edge);
    }
  }
  for (auto edge : edges_to_replace) {
    graph->AddEdge(new_src.first, new_src.second, edge->dst(),
                   edge->dst_input());
    graph->RemoveEdge(edge);
  }
}

// Updates the MapDefun node's attrs to keep them consistent with the function.
void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
  map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);

  // TODO(rachelim): Propagate precise shapes if they're known, which may enable
  // subsequent optimizations.
  map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
                                               map_defun_fn->ret_types.size()));
}

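// Adds a new output to `map_defun_fn` that forwards the tensor `output`: a
// _Retval node fed by `output` is appended to the function, and
// `map_defun_node`'s attrs are updated to match the new signature.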
Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
                         const TensorDesc& output) {
  DataType type = output.first->output_type(output.second);
  int index = map_defun_fn->ret_nodes.size();

  NodeDef ret_node_def;
  ret_node_def.set_name("map_out");
  ret_node_def.set_op(kRetValOp);
  AddNodeAttr("T", type, &ret_node_def);
  AddNodeAttr("index", index, &ret_node_def);

  Status s;
  Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
  TF_RETURN_IF_ERROR(s);

  map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
  map_defun_fn->ret_nodes.push_back(ret_node);
  map_defun_fn->ret_types.push_back(type);
  UpdateMapDefunAttrs(map_defun_fn, map_defun_node);

  return s;
}

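// Removes the output at `output_position` from `map_defun_fn`'s signature and
// graph, updates `map_defun_node`'s attrs, and renumbers any later outputs so
// their edges in `outer_scope` and their "index" attrs stay consistent.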
void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
                          FunctionBody* map_defun_fn, Node* map_defun_node) {
  DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
      << "Trying to remove output that doesn't exist. Output number: "
      << output_position;

  int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;

  // Modify map_defun_fn's signature and remove the output node from its graph
  map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
  map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
                                output_position);
  map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
                                output_position);
  UpdateMapDefunAttrs(map_defun_fn, map_defun_node);

  // Renumber the nodes and edges that come after
  for (int i = 0; i < num_later_outputs; ++i) {
    ReplaceEdgeSources({map_defun_node, output_position + i + 1},
                       {map_defun_node, output_position + i}, outer_scope);
    // Each ret node has an "index" attr that has to be updated
    map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
                                                          output_position + i);
  }
}

// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
// This class transforms the input FunctionDefs into their corresponding
// Graph objects and works on the graphs directly, then converts them back
// to FunctionDefs when GetResult is called.
// TODO(rachelim): Move this to its own header.
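// Example (a sketch, using the Cast op that the comments below also use as an
// illustration): a MapDefun whose function applies Cast to each slice of its
// input is rewritten so that `outer_scope_` applies a single Cast to the whole
// batched input, the MapDefun output that carried the per-slice results is
// rerouted to the new Cast's output, and the now-unneeded output is removed
// from the MapDefun function and node.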
class Vectorization {
 public:
  explicit Vectorization(FunctionDefLibrary* lib)
      : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}

  // Adds the vectorized function and new map_defun_fn to lib, and points
  // vectorized_function to the former. Returns an error status if
  // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
  // along the way.
  Status Vectorize(const FunctionDef& outer_scope,
                   const NodeDef& map_defun_node, FunctionDef** result);

 private:
  // Converts FunctionDefs to Graphs and adds mappings from
  // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
  Status Initialize(const FunctionDef& outer_scope,
                    const NodeDef& map_defun_node);

  // Converts Graphs back to FunctionDefs and adds them to `lib_`.
  Status GetResult(FunctionDef** vectorized_function);

  // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
  // `outer_scope_`, until there are no convertible outputs remaining.
  void VectorizeHelper();

  // Vectorizes map_defun_fn's output at output_position.
  Status ConvertOutput(int output_position);

  // Adds mappings from the node's output tensors to converted output tensors,
  // creating the necessary new node(s). Generally, the steps to convert an op
  // are:
  // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
  //    These operations collectively compute the same value as what running
  //    the original operation on slices of the input tensors would produce.
  //    For example, a Cast op in MapDefun translates to a Cast op in
  //    `outer_scope_`, since the vectorized version of Cast is itself.
  // 2) Promote the inputs of the op to outputs of the
  //    `map_defun_node_` and `map_defun_fn_`.
  // 3) Add edges between the promoted inputs (that are now outputs of
  //    `map_defun_node`) and the input ports of the new node(s).
  // 4) For each output of the old node, add the mapping of output tensors to
  //    the conversion map.
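  // For example, for a Cast op inside the MapDefun: a new Cast node is added
  // to `outer_scope_`, the old Cast's input tensor is promoted to a MapDefun
  // output that feeds the new Cast, and the old Cast's output tensor is mapped
  // to the new Cast's output in `conversion_map_`.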
  Status AddConversionMapping(Node* op_node);

  // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
  // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
  // inputs to `map_defun_node_`. This stacked tensor will be compatible with
  // the expected output shape of `map_defun_node_`.
  // This is equivalent to the _stack function in python Pfor.
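  // For example, stacking an unstacked scalar constant c when the MapDefun
  // batch size is n yields a rank-1 tensor of shape [n] filled with c, and
  // stacking a [2, 3] tensor yields an [n, 2, 3] tensor made of n copies.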
  Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);

  // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
  // doing a depth-first search from the ret nodes. Lifts tensors that are
  // unstacked (i.e. don't derive from arg tensors) into `outer_scope_` directly
  // and adds mappings to `conversion_map_`.
  // Note that this function may have false negatives, i.e. not
  // add mappings for some tensors that are unstacked. This may happen in the
  // following cases: 1) a vectorized op produces unstacked outputs from stacked
  // inputs (e.g. the vectorized "Shape" op), 2) the tensors are in a cycle, or
  // 3) the unstacked op could not be lifted into `outer_scope_`.
  Status AddUnstackedTensorMappings();

  // Recursive helper for `AddUnstackedTensorMappings`. If an op node is
  // unstacked, lifts its output tensors into `outer_scope_`, adding the
  // mappings to `conversion_map_`. Returns true if the unstacked mappings were
  // added.
  bool AddUnstackedTensorMappingsHelper(
      TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited);

  // Adds mappings from `map_defun_fn_` arg tensors to `map_defun_node_` input
  // tensors to `conversion_map_`.
  Status AddArgTensorMappings();

  // Maps a tensor to the corresponding WrappedTensor. For example,
  // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
  std::map<TensorDesc, WrappedTensor> conversion_map_;

  // Unconvertible ret nodes
  std::set<Node*> unconvertible_;

  FunctionDefLibrary* lib_;  // Not owned
  FunctionLibraryDefinition lib_def_;
  // Note that FunctionBody has a pointer to a Graph object that corresponds
  // to the function's subgraph, with additional kArgOp and kRetValOp nodes
  // that denote the function arguments and return values. These nodes have the
  // attrs "T" for the type, and "index" for the argument / retval index
  // respectively. FunctionBody also keeps track of arg/ret_nodes and
  // arg/ret_types, that should be ordered according to argument/output indices.
  std::unique_ptr<Graph> outer_scope_;
  std::unique_ptr<FunctionBody> map_defun_fn_;
  Node* map_defun_node_ = nullptr;  // Owned by `outer_scope_`

  // Caches the loop_len_node_ needed for tiling unstacked output. This
  // corresponds to a vector with one element.
  Node* loop_len_node_ = nullptr;  // Owned by `outer_scope_`
  Status status_;
};

Status Vectorization::AddConversionMapping(Node* op_node) {
  for (auto edge : op_node->in_edges()) {
    if (edge->IsControlEdge()) {
      return errors::InvalidArgument(
          "Vectorizing outputs with control inputs is currently not "
          "supported.");
    }
  }

  auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
  if (vectorizer == nullptr) {
    return errors::Unimplemented("No vectorizer registered for op: ",
                                 op_node->type_string());
  }
  std::vector<WrappedTensor> inputs, outputs;
  inputs.reserve(op_node->num_inputs());
  outputs.reserve(op_node->num_outputs());

  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));

  // The inputs for the node to be converted may already have been converted
  // themselves. For those that are not, we promote them to MapDefun outputs.
  for (size_t i = 0; i < op_node->num_inputs(); ++i) {
    auto edge = input_edges[i];
    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      inputs.push_back(*found);
    } else {
      // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
      // We assume that all unconverted inputs will be stacked, since we
      // converted all unstacked nodes in `Initialize`. However, it's actually
      // possible that yet-unconverted nodes may produce unstacked outputs after
      // they are vectorized. (For example, see the "Shape" converter in
      // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
      // an unstacked input but receives a stacked one, vectorizer->Vectorize
      // will return an error.
      TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
                                           {edge->src(), edge->src_output()}));
      int output_index = map_defun_fn_->ret_nodes.size() - 1;
      inputs.push_back({map_defun_node_, output_index, true});
    }
  }

  Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
                                   std::move(inputs), &outputs);
  if (!s.ok()) {
    VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
            << "\" failed with error: " << s;
    return s;
  }

  if (op_node->num_outputs() != outputs.size()) {
    return errors::Internal(
        "Number of vectorizer outputs does not match. Expected: ",
        op_node->num_outputs(), " Actual: ", outputs.size());
  }

  // Add output mappings.
  for (size_t i = 0; i < op_node->num_outputs(); ++i) {
    conversion_map_.insert({{op_node, i}, outputs[i]});
  }

  return Status::OK();
}

Status Vectorization::ConvertOutput(int output_position) {
  // ret_edge->src() is the actual op that generated the retval, and
  // ret_edge->dst() is the retval node whose op is "_Retval"
  const Edge* ret_edge;
  TF_RETURN_IF_ERROR(
      map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));

  TensorDesc output({ret_edge->src(), ret_edge->src_output()});
  TensorDesc converted_output;

  // It's possible the output already has a mapping, if it comes from a node
  // that has already been converted.
  auto found = gtl::FindOrNull(conversion_map_, output);
  if (!found) {
    TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
    found = &conversion_map_.at(output);
  }

  if (found->stacked) {
    converted_output = {found->node, found->output_index};
  } else {
    // Some outputs may be unstacked if they don't derive from arg nodes
    // (for example, if a function returns a constant). For these, we
    // have to add extra nodes to tile them in the 0th dimension.
    TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
  }

  ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
                     outer_scope_.get());
  RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
                       map_defun_node_);

  return Status::OK();
}

Status Vectorization::Vectorize(const FunctionDef& outer_scope,
                                const NodeDef& map_defun_node,
                                FunctionDef** result) {
  TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
  VectorizeHelper();
  return GetResult(result);
}

void Vectorization::VectorizeHelper() {
  while (true) {
    int output_position = graph_utils::GetFirstElementIndexWithPredicate(
        [this](Node* n) {
          return this->unconvertible_.find(n) == this->unconvertible_.end();
        },
        map_defun_fn_->ret_nodes);

    // No outputs left to convert
    if (output_position == -1) break;

    Status s = ConvertOutput(output_position);
    if (!s.ok()) {
      Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
      VLOG(2) << "Could not convert the output at node: "
              << output_node->DebugString() << "\nError: " << s;
      unconvertible_.insert(output_node);
    }
  }

  // If we've converted all the outputs of the MapDefun function, we no longer
  // need the MapDefun node and can delete it.
  if (map_defun_fn_->ret_nodes.empty()) {
    outer_scope_->RemoveNode(map_defun_node_);
  }
}

Status Vectorization::Initialize(const FunctionDef& outer_scope,
                                 const NodeDef& map_defun_node) {
  // Convert outer_scope and map_defun_fn to FunctionBodys so we can
  // work on Graphs directly.
  const FunctionDef* map_defun_fn =
      lib_def_.Find(map_defun_node.attr().at("f").func().name());

  if (map_defun_fn == nullptr) {
    return errors::NotFound("Could not find function with name ",
                            map_defun_node.attr().at("f").func().name(),
                            " in function library.");
  }

  auto get_func_sig = [this](const string& op, const OpDef** sig) {
    return this->lib_def_.LookUpOpDef(op, sig);
  };

  FunctionBody* outer_fn;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
                                             get_func_sig, &outer_fn));
  // We don't need outer_fn, just the graph
  outer_scope_.reset(outer_fn->graph);
  outer_fn->graph = nullptr;
  delete outer_fn;

  FunctionBody* tmp;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
                                             get_func_sig, &tmp));
  map_defun_fn_.reset(tmp);

  // Find the MapDefun node in outer_scope_
  int node_id = graph_utils::GetFirstElementIndexWithPredicate(
      [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
      outer_scope_->nodes());
  if (node_id == -1) {
    return errors::NotFound("Could not find node with name ",
                            map_defun_node.name(), " in outer_scope.");
  }
  map_defun_node_ = outer_scope_->FindNodeId(node_id);

  TF_RETURN_IF_ERROR(AddArgTensorMappings());
  TF_RETURN_IF_ERROR(AddUnstackedTensorMappings());
  loop_len_node_ = nullptr;

  return Status::OK();
}

// TODO(rachelim): It might be profitable to use the C++ API for this instead of
// NodeBuilder
Status Vectorization::StackTensor(WrappedTensor* unstacked,
                                  TensorDesc* result) {
  if (unstacked->node->output_type(unstacked->output_index) == DT_VARIANT) {
    // TODO(b/124069171): "ExpandDims" doesn't work with Variant tensors.
    return errors::Unimplemented("Cannot stack tensor with Variant type.");
  }
  // Note that all these nodes are necessary as the size of the batch may not be
  // constant.
  if (unstacked->stacked) {
    return errors::Internal("Can only stack unstacked tensor.");
  }

  Graph* g = outer_scope_.get();
  auto node_builder = [](StringPiece op) {
    return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
  };

  auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
                                    Node** result) {
    TF_RETURN_IF_ERROR(val.status);
    return node_builder("Const")
        .Attr("value", val.tensor)
        .Attr("dtype", val.tensor.dtype())
        .Finalize(graph, result);
  };

  // If loop_len_node_ hasn't been created yet, add the node and cache it.
  if (loop_len_node_ == nullptr) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));

    Node* shape_node;
    TF_RETURN_IF_ERROR(
        node_builder("Shape").Input(input_node).Finalize(g, &shape_node));

    Node* const_vec_0;
    TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
    Node* const_vec_1;
    TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));

    Node* strided_slice_node;
    TF_RETURN_IF_ERROR(node_builder("StridedSlice")
                           .Input(shape_node)   // input
                           .Input(const_vec_0)  // begin
                           .Input(const_vec_1)  // end
                           .Input(const_vec_1)  // strides
                           .Finalize(g, &strided_slice_node));

    // Produces a vector of length 1
    TF_RETURN_IF_ERROR(node_builder("Reshape")
                           .Input(strided_slice_node)  // tensor
                           .Input(const_vec_1)         // shape
                           .Finalize(g, &loop_len_node_));
  }

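  // Build multiples = [loop_len, 1, 1, ...] (a 1 for each dimension of the
  // unstacked tensor), then compute Tile(ExpandDims(t, 0), multiples).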
  Node* ones_shape;
  TF_RETURN_IF_ERROR(node_builder("Shape")
                         .Input(unstacked->node)  // input
                         .Finalize(g, &ones_shape));

  Node* ones;
  TF_RETURN_IF_ERROR(
      node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));

  Node* const_0;
  TF_RETURN_IF_ERROR(make_const(0, g, &const_0));

  Node* multiples;
  TF_RETURN_IF_ERROR(node_builder("Concat")
                         .Input(const_0)                           // concat_dim
                         .Input({{loop_len_node_, 0}, {ones, 0}})  // values
                         .Finalize(g, &multiples));

  Node* expand_dims;
  TF_RETURN_IF_ERROR(node_builder("ExpandDims")
                         .Input(unstacked->node)  // input
                         .Input(const_0)          // dim
                         .Finalize(g, &expand_dims));

  TF_RETURN_IF_ERROR(node_builder("Tile")
                         .Input(expand_dims)  // input
                         .Input(multiples)    // multiples
                         .Finalize(g, &result->first));
  result->second = 0;
  return Status::OK();
}

Status Vectorization::AddArgTensorMappings() {
  // Note that inputs to map_defun_fn_ are either regular arguments (for which
  // the operations are mapped across their 0th dimension) or captured inputs
  // (for which the operations apply to the argument wholesale).
  int num_args =
      map_defun_node_->attrs().Find("Targuments")->list().type_size();

  auto add_conversion = [this](Node* arg_node, bool stacked) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(
        arg_node->attrs().Find("index")->i(), &input_node));

    conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});

    // Control inputs
    conversion_map_.insert({{arg_node, Graph::kControlSlot},
                            {input_node, Graph::kControlSlot, stacked}});

    return Status::OK();
  };

  // Regular arguments
  for (int i = 0; i < num_args; ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
  }

  // Captured inputs. These are applied (without slicing) to every iteration of
  // the map function, hence are mapped to unstacked nodes.
  for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
  }

  return Status::OK();
}

bool Vectorization::AddUnstackedTensorMappingsHelper(
    TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited) {
  if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
    return !found->stacked;
  }

  if (tensor.first->op_def().is_stateful()) {
    // We don't lift stateful nodes directly out of the MapDefun, since they may
    // have to be executed N times.
    return false;
  }

  bool is_unstacked = true;
  for (const auto& edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (visited->find(edge) != visited->end()) {
      // If we've visited this edge already, we're in a cycle. In this case, we
      // are conservative and don't mark the node as unstacked.
      is_unstacked = false;
      continue;
    }
    visited->insert(edge);

    // A node is unstacked if all of its inputs are unstacked
    is_unstacked &= AddUnstackedTensorMappingsHelper(
        {edge->src(), edge->src_output()}, visited);
  }

  if (!is_unstacked) {
    return false;
  }

  // If the node is unstacked, we copy it into outer_scope_ and
  // add it to the map. Note that we don't clean up the nodes that are copied
  // in map_defun_fn_, and rely on them being pruned out later.
  Status status;
  Node* node = outer_scope_->AddNode(tensor.first->def(), &status);
  if (!status.ok()) return false;

  // Add input edges to nodes that should already have been lifted.
  for (const auto& edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      outer_scope_->AddEdge(found->node, found->output_index, node,
                            edge->dst_input());
    } else {
      return false;
    }
  }

  // Add output mappings
  for (int i = 0; i < tensor.first->num_outputs(); ++i) {
    conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
  }
  conversion_map_.insert({{tensor.first, Graph::kControlSlot},
                          WrappedTensor(node, Graph::kControlSlot, false)});

  return true;
}

Status Vectorization::AddUnstackedTensorMappings() {
  absl::flat_hash_set<const Edge*> visited;
  for (const auto& ret_node : map_defun_fn_->ret_nodes) {
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
    AddUnstackedTensorMappingsHelper({in_edge->src(), in_edge->src_output()},
                                     &visited);
  }
  return Status::OK();
}

Status Vectorization::GetResult(FunctionDef** vectorized_function) {
  TF_RETURN_IF_ERROR(status_);
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));

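  // If some outputs could not be converted, the (now smaller) MapDefun
  // function is serialized under a fresh name and the MapDefun node's "f"
  // attr is pointed at it.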
  if (!map_defun_fn_->ret_nodes.empty()) {
    FunctionDef* map_defun_fn = lib_->add_function();
    graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
    TF_RETURN_IF_ERROR(GraphToFunctionDef(
        *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));

    AttrValue func_attr;
    func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
    map_defun_node_->AddAttr("f", func_attr);
  }

  *vectorized_function = lib_->add_function();
  graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
                                          *vectorized_function);
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *outer_scope_, (*vectorized_function)->signature().name(),
      *vectorized_function));
  return Status::OK();
}

}  // namespace

Status VectorizeMapDefun(const FunctionDef& outer_scope,
                         const NodeDef& map_defun_node, FunctionDefLibrary* lib,
                         FunctionDef** result) {
  *result = nullptr;
  return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}

}  // namespace vectorization_utils
}  // namespace grappler
}  // namespace tensorflow