/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"

#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace tensorflow {
namespace grappler {
namespace vectorization_utils {

namespace {

// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;

constexpr char kRetValOp[] = "_Retval";

void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
                        Graph* graph) {
  // NOTE: We need two for loops here because we can't mutate the set of output
  // edges as we iterate over them.
  std::vector<const Edge*> edges_to_replace;
  for (auto edge : old_src.first->out_edges()) {
    if (edge->src_output() == old_src.second) {
      edges_to_replace.push_back(edge);
    }
  }
  for (auto edge : edges_to_replace) {
    graph->AddEdge(new_src.first, new_src.second, edge->dst(),
                   edge->dst_input());
    graph->RemoveEdge(edge);
  }
}

// Updates the node's attrs to keep them consistent with the function.
void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
  map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);

  // TODO(rachelim): Propagate precise shapes if they're known, which may enable
  // subsequent optimizations.
  map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
                                               map_defun_fn->ret_types.size()));
}

Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
                         const TensorDesc& output) {
  DataType type = output.first->output_type(output.second);
  int index = map_defun_fn->ret_nodes.size();

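  // Note: every output added here gets the same node name ("map_out");
  // duplicate names are made unique by graph_utils::EnsureNodeNamesUnique in
  // GetResult before the graph is converted back to a FunctionDef.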
  NodeDef ret_node_def;
  ret_node_def.set_name("map_out");
  ret_node_def.set_op(kRetValOp);
  AddNodeAttr("T", type, &ret_node_def);
  AddNodeAttr("index", index, &ret_node_def);

  Status s;
  Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
  TF_RETURN_IF_ERROR(s);

  map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
  map_defun_fn->ret_nodes.push_back(ret_node);
  map_defun_fn->ret_types.push_back(type);
  UpdateMapDefunAttrs(map_defun_fn, map_defun_node);

  return s;
}

void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
                          FunctionBody* map_defun_fn, Node* map_defun_node) {
  DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
      << "Trying to remove an output that doesn't exist. Output number: "
      << output_position;

  int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;

  // Modify map_defun_fn's signature and remove the output node from its graph
  map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
  map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
                                output_position);
  map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
                                output_position);
  UpdateMapDefunAttrs(map_defun_fn, map_defun_node);

  // Renumber the nodes and edges that come after
  for (int i = 0; i < num_later_outputs; ++i) {
    ReplaceEdgeSources({map_defun_node, output_position + i + 1},
                       {map_defun_node, output_position + i}, outer_scope);
    // Each ret node has an "index" attr that has to be updated
    map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
                                                          output_position + i);
  }
}

// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
// This class transforms the input FunctionDefs into their corresponding
// Graph objects and works on the graphs directly, then converts them back
// to FunctionDefs when GetResult is called.
// TODO(rachelim): Move this to its own header.
class Vectorization {
 public:
  explicit Vectorization(FunctionDefLibrary* lib)
      : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}

  // Adds the vectorized function and new map_defun_fn to lib, and points
  // vectorized_function to the former. Returns an error status if
  // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
  // along the way.
  Status Vectorize(const FunctionDef& outer_scope,
                   const NodeDef& map_defun_node, FunctionDef** result);

 private:
  // Converts FunctionDefs to Graphs and adds mappings from
  // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
  Status Initialize(const FunctionDef& outer_scope,
                    const NodeDef& map_defun_node);

  // Converts Graphs back to FunctionDefs and adds them to `lib_`.
  Status GetResult(FunctionDef** vectorized_function);

  // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
  // `outer_scope_`, until there are no convertible outputs remaining.
  void VectorizeHelper();

  // Vectorizes map_defun_fn's output at output_position.
  Status ConvertOutput(int output_position);

  // Adds mappings from the node's output tensors to converted output tensors,
  // creating the necessary new node(s). Generally, the steps to convert an op
  // are:
  // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
  //    These operations collectively compute the same value as what running
  //    the original operation on slices of the input tensors would produce.
  //    For example, a Cast op in MapDefun translates to a Cast op in
  //    `outer_scope_`, since the vectorized version of Cast is itself.
  // 2) Promote the op's inputs to outputs of the
  //    `map_defun_node_` and `map_defun_fn_`.
  // 3) Add edges between the promoted inputs (that are now outputs of
  //    `map_defun_node_`) and the input ports of the new node(s).
  // 4) For each output of the old node, add the mapping of output tensors to
  //    the conversion map.
  Status AddConversionMapping(Node* op_node);

  // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
  // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
  // inputs to `map_defun_node_`. This stacked tensor will be compatible with
  // the expected output shape of `map_defun_node_`.
  // This is equivalent to the _stack function in python Pfor.
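  // For example, stacking an unstacked tensor t of shape [2, 3] with loop
  // length n produces Tile(ExpandDims(t, 0), [n, 1, 1]), which has shape
  // [n, 2, 3].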
  Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);

  // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
  // doing a depth-first search from the ret nodes. Lifts tensors that are
  // unstacked (i.e. don't derive from arg tensors) into `outer_scope_` directly
  // and adds mappings to `conversion_map_`.
  // Note that this function may have false negatives, i.e. not
  // add mappings for some tensors that are unstacked. This may happen in the
  // following cases: 1) a vectorized op produces unstacked outputs from stacked
  // inputs (e.g. the vectorized "Shape" op), 2) the tensors are in a cycle, or
  // 3) the unstacked op could not be lifted into `outer_scope_`.
  Status AddUnstackedTensorMappings();

  // Recursive helper for `AddUnstackedTensorMappings`. If an op node is
  // unstacked, lifts its output tensors into `outer_scope_`, adding the
  // mappings to `conversion_map_`. Returns true if the unstacked mappings were
  // added.
  bool AddUnstackedTensorMappingsHelper(
      TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited);

  // Adds mappings from `map_defun_fn_` arg tensors to `map_defun_node_` input
  // tensors to `conversion_map_`.
  Status AddArgTensorMappings();

  // Maps a tensor to the corresponding WrappedTensor. For example,
  // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
  std::map<TensorDesc, WrappedTensor> conversion_map_;

  // Unconvertible ret nodes
  std::set<Node*> unconvertible_;

  FunctionDefLibrary* lib_;  // Not owned
  FunctionLibraryDefinition lib_def_;
  // Note that FunctionBody has a pointer to a Graph object that corresponds
  // to the function's subgraph, with additional kArgOp and kRetValOp nodes
  // that denote the function arguments and return values. These nodes have the
  // attrs "T" for the type, and "index" for the argument / retval index
  // respectively. FunctionBody also keeps track of arg/ret_nodes and
  // arg/ret_types, which should be ordered according to argument/output
  // indices.
  std::unique_ptr<Graph> outer_scope_;
  std::unique_ptr<FunctionBody> map_defun_fn_;
  Node* map_defun_node_ = nullptr;  // Owned by `outer_scope_`

  // Caches the loop_len_node_ needed for tiling unstacked output. This
  // corresponds to a vector with one element.
  Node* loop_len_node_ = nullptr;  // Owned by `outer_scope_`
  Status status_;
};

Status Vectorization::AddConversionMapping(Node* op_node) {
  for (auto edge : op_node->in_edges()) {
    if (edge->IsControlEdge()) {
      return errors::InvalidArgument(
          "Vectorizing outputs with control inputs is currently not "
          "supported.");
    }
  }

  auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
  if (vectorizer == nullptr) {
    return errors::Unimplemented("No vectorizer registered for op: ",
                                 op_node->type_string());
  }
  std::vector<WrappedTensor> inputs, outputs;
  inputs.reserve(op_node->num_inputs());
  outputs.reserve(op_node->num_outputs());

  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));

  // The inputs for the node to be converted may already have been converted
  // themselves. For those that are not, we promote them to MapDefun outputs.
  for (size_t i = 0; i < op_node->num_inputs(); ++i) {
    auto edge = input_edges[i];
    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      inputs.push_back(*found);
    } else {
      // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
      // We assume that all unconverted inputs will be stacked, since we
      // converted all unstacked nodes in `Initialize`. However, it's actually
      // possible that yet-unconverted nodes may produce unstacked outputs after
      // they are vectorized. (For example, see the "Shape" converter in
      // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
      // an unstacked input but receives a stacked one, vectorizer->Vectorize
      // will return an error.
      TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
                                           {edge->src(), edge->src_output()}));
      int output_index = map_defun_fn_->ret_nodes.size() - 1;
      inputs.push_back({map_defun_node_, output_index, true});
    }
  }

  Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
                                   std::move(inputs), &outputs);
  if (!s.ok()) {
    VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
            << "\" failed with error: " << s;
    return s;
  }

  if (op_node->num_outputs() != outputs.size()) {
    return errors::Internal(
        "Number of vectorizer outputs does not match. Expected: ",
        op_node->num_outputs(), " Actual: ", outputs.size());
  }

  // Add output mappings.
  for (size_t i = 0; i < op_node->num_outputs(); ++i) {
    conversion_map_.insert({{op_node, i}, outputs[i]});
  }

  return Status::OK();
}

Status Vectorization::ConvertOutput(int output_position) {
  // ret_edge->src() is the actual op that generated the retval, and
  // ret_edge->dst() is the retval node whose op is "_Retval"
  const Edge* ret_edge;
  TF_RETURN_IF_ERROR(
      map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));

  TensorDesc output({ret_edge->src(), ret_edge->src_output()});
  TensorDesc converted_output;

  // It's possible the output already has a mapping, if it comes from a node
  // that has already been converted.
  auto found = gtl::FindOrNull(conversion_map_, output);
  if (!found) {
    TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
    found = &conversion_map_.at(output);
  }

  if (found->stacked) {
    converted_output = {found->node, found->output_index};
  } else {
    // Some outputs may be unstacked if they don't derive from arg nodes
    // (for example, if a function returns a constant). For these, we
    // have to add extra nodes to tile them in the 0th dimension.
    TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
  }

  ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
                     outer_scope_.get());
  RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
                       map_defun_node_);

  return Status::OK();
}

Status Vectorization::Vectorize(const FunctionDef& outer_scope,
                                const NodeDef& map_defun_node,
                                FunctionDef** result) {
  TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
  VectorizeHelper();
  return GetResult(result);
}

void Vectorization::VectorizeHelper() {
  while (true) {
    int output_position = graph_utils::GetFirstElementIndexWithPredicate(
        [this](Node* n) {
          return this->unconvertible_.find(n) == this->unconvertible_.end();
        },
        map_defun_fn_->ret_nodes);

    // No outputs left to convert
    if (output_position == -1) break;

    Status s = ConvertOutput(output_position);
    if (!s.ok()) {
      Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
      VLOG(2) << "Could not convert the output at node: "
              << output_node->DebugString() << "\nError: " << s;
      unconvertible_.insert(output_node);
    }
  }

  // If we've converted all the outputs of the MapDefun function, we no longer
  // need the MapDefun node and can delete it.
  if (map_defun_fn_->ret_nodes.empty()) {
    outer_scope_->RemoveNode(map_defun_node_);
  }
}

Status Vectorization::Initialize(const FunctionDef& outer_scope,
                                 const NodeDef& map_defun_node) {
  // Convert outer_scope and map_defun_fn to FunctionBodys so we can
  // work on Graphs directly.
  const FunctionDef* map_defun_fn =
      lib_def_.Find(map_defun_node.attr().at("f").func().name());

  if (map_defun_fn == nullptr) {
    return errors::NotFound("Could not find function with name ",
                            map_defun_node.attr().at("f").func().name(),
                            " in function library.");
  }

  auto get_func_sig = [this](const string& op, const OpDef** sig) {
    return this->lib_def_.LookUpOpDef(op, sig);
  };

  FunctionBody* outer_fn;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
                                             get_func_sig, &outer_fn));
  // We don't need outer_fn, just the graph
  outer_scope_.reset(outer_fn->graph);
  outer_fn->graph = nullptr;
  delete outer_fn;

  FunctionBody* tmp;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
                                             get_func_sig, &tmp));
  map_defun_fn_.reset(tmp);

  // Find the MapDefun node in outer_scope_
  int node_id = graph_utils::GetFirstElementIndexWithPredicate(
      [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
      outer_scope_->nodes());
  if (node_id == -1) {
    return errors::NotFound("Could not find node with name ",
                            map_defun_node.name(), " in outer_scope.");
  }
  map_defun_node_ = outer_scope_->FindNodeId(node_id);

  TF_RETURN_IF_ERROR(AddArgTensorMappings());
  TF_RETURN_IF_ERROR(AddUnstackedTensorMappings());
  loop_len_node_ = nullptr;

  return Status::OK();
}

// TODO(rachelim): It might be profitable to use the C++ API for this instead of
// NodeBuilder
Status Vectorization::StackTensor(WrappedTensor* unstacked,
                                  TensorDesc* result) {
  if (unstacked->node->output_type(unstacked->output_index) == DT_VARIANT) {
    // TODO(b/124069171): "ExpandDims" doesn't work with Variant tensors.
    return errors::Unimplemented("Cannot stack tensor with Variant type.");
  }
  // Note that all these nodes are necessary as the size of the batch may not be
  // constant.
  if (unstacked->stacked) {
    return errors::Internal("Can only stack unstacked tensor.");
  }

  Graph* g = outer_scope_.get();
  auto node_builder = [](StringPiece op) {
    return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
  };

  auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
                                    Node** result) {
    TF_RETURN_IF_ERROR(val.status);
    return node_builder("Const")
        .Attr("value", val.tensor)
        .Attr("dtype", val.tensor.dtype())
        .Finalize(graph, result);
  };

  // If loop_len_node_ hasn't been created yet, add the node and cache it.
  if (loop_len_node_ == nullptr) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));

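    // Compute the loop length n as a length-1 vector: take the shape of the
    // MapDefun node's first input and slice out its 0th element
    // (Shape -> StridedSlice -> Reshape).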
    Node* shape_node;
    TF_RETURN_IF_ERROR(
        node_builder("Shape").Input(input_node).Finalize(g, &shape_node));

    Node* const_vec_0;
    TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
    Node* const_vec_1;
    TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));

    Node* strided_slice_node;
    TF_RETURN_IF_ERROR(node_builder("StridedSlice")
                           .Input(shape_node)   // input
                           .Input(const_vec_0)  // begin
                           .Input(const_vec_1)  // end
                           .Input(const_vec_1)  // strides
                           .Finalize(g, &strided_slice_node));

    // Produces a vector of length 1
    TF_RETURN_IF_ERROR(node_builder("Reshape")
                           .Input(strided_slice_node)  // tensor
                           .Input(const_vec_1)         // shape
                           .Finalize(g, &loop_len_node_));
  }

  Node* ones_shape;
  TF_RETURN_IF_ERROR(node_builder("Shape")
                         .Input(unstacked->node)  // input
                         .Finalize(g, &ones_shape));

  Node* ones;
  TF_RETURN_IF_ERROR(
      node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));

  Node* const_0;
  TF_RETURN_IF_ERROR(make_const(0, g, &const_0));

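  // Build multiples = [n, 1, 1, ...], with one "1" per dimension of the
  // unstacked tensor; this matches the rank of the ExpandDims output below.
  // Note that the "Concat" op takes the concat dimension as its first input.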
  Node* multiples;
  TF_RETURN_IF_ERROR(node_builder("Concat")
                         .Input(const_0)                           // concat_dim
                         .Input({{loop_len_node_, 0}, {ones, 0}})  // values
                         .Finalize(g, &multiples));

  Node* expand_dims;
  TF_RETURN_IF_ERROR(node_builder("ExpandDims")
                         .Input(unstacked->node)  // input
                         .Input(const_0)          // dim
                         .Finalize(g, &expand_dims));

  TF_RETURN_IF_ERROR(node_builder("Tile")
                         .Input(expand_dims)  // input
                         .Input(multiples)    // multiples
                         .Finalize(g, &result->first));
  result->second = 0;
  return Status::OK();
}

Status Vectorization::AddArgTensorMappings() {
  // Note that inputs to map_defun_fn_ are either regular arguments (for which
  // the operations are mapped across their 0th dimension) or captured inputs
  // (for which the operations apply to the argument wholesale).
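  // The MapDefun node's "Targuments" attr lists the types of the regular
  // arguments only, so its size is the number of regular args; any remaining
  // arg nodes correspond to captured inputs.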
  int num_args =
      map_defun_node_->attrs().Find("Targuments")->list().type_size();

  auto add_conversion = [this](Node* arg_node, bool stacked) {
    Node* input_node;
    TF_RETURN_IF_ERROR(map_defun_node_->input_node(
        arg_node->attrs().Find("index")->i(), &input_node));

    conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});

    // Control inputs
    conversion_map_.insert({{arg_node, Graph::kControlSlot},
                            {input_node, Graph::kControlSlot, stacked}});

    return Status::OK();
  };

  // Regular arguments
  for (int i = 0; i < num_args; ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
  }

  // Captured inputs. These are applied (without slicing) to every iteration of
  // the map function, hence are mapped to unstacked nodes.
  for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) {
    TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
  }

  return Status::OK();
}

bool Vectorization::AddUnstackedTensorMappingsHelper(
    TensorDesc&& tensor, absl::flat_hash_set<const Edge*>* visited) {
  if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
    return !found->stacked;
  }

  if (tensor.first->op_def().is_stateful()) {
    // We don't lift stateful nodes directly out of the MapDefun, since they may
    // have to be executed N times.
    return false;
  }

  bool is_unstacked = true;
  for (const auto& edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (visited->find(edge) != visited->end()) {
      // If we've visited this edge already, we're in a cycle. In this case, we
      // are conservative and don't mark the node as unstacked.
      is_unstacked = false;
      continue;
    }
    visited->insert(edge);

    // A node is unstacked if all of its inputs are unstacked
    is_unstacked &= AddUnstackedTensorMappingsHelper(
        {edge->src(), edge->src_output()}, visited);
  }

  if (!is_unstacked) {
    return false;
  }

  // If the node is unstacked, we copy it into outer_scope_ and
  // add it to the map. Note that we don't clean up the original nodes that
  // remain in map_defun_fn_, and rely on them being pruned out later.
  Status status;
  Node* node = outer_scope_->AddNode(tensor.first->def(), &status);
  if (!status.ok()) return false;

  // Add input edges to nodes that should already have been lifted.
  for (const auto& edge : tensor.first->in_edges()) {
    // Ignore Source nodes. Note that these are also ignored in the
    // GraphToFunctionDef conversion.
    if (edge->src()->IsSource()) continue;

    if (auto found = gtl::FindOrNull(conversion_map_,
                                     {edge->src(), edge->src_output()})) {
      outer_scope_->AddEdge(found->node, found->output_index, node,
                            edge->dst_input());
    } else {
      return false;
    }
  }

  // Add output mappings
  for (int i = 0; i < tensor.first->num_outputs(); ++i) {
    conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
  }
  conversion_map_.insert({{tensor.first, Graph::kControlSlot},
                          WrappedTensor(node, Graph::kControlSlot, false)});

  return true;
}

Status Vectorization::AddUnstackedTensorMappings() {
  absl::flat_hash_set<const Edge*> visited;
  for (const auto& ret_node : map_defun_fn_->ret_nodes) {
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
    AddUnstackedTensorMappingsHelper({in_edge->src(), in_edge->src_output()},
                                     &visited);
  }
  return Status::OK();
}

Status Vectorization::GetResult(FunctionDef** vectorized_function) {
  TF_RETURN_IF_ERROR(status_);
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
  TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));

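  // If some outputs could not be converted, the MapDefun node is kept: add the
  // modified map_defun_fn to the library and point the node's "f" attr at it.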
  if (!map_defun_fn_->ret_nodes.empty()) {
    FunctionDef* map_defun_fn = lib_->add_function();
    graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
    TF_RETURN_IF_ERROR(GraphToFunctionDef(
        *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));

    AttrValue func_attr;
    func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
    map_defun_node_->AddAttr("f", func_attr);
  }

  *vectorized_function = lib_->add_function();
  graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
                                          *vectorized_function);
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *outer_scope_, (*vectorized_function)->signature().name(),
      *vectorized_function));
  return Status::OK();
}

}  // namespace

Status VectorizeMapDefun(const FunctionDef& outer_scope,
                         const NodeDef& map_defun_node, FunctionDefLibrary* lib,
                         FunctionDef** result) {
  *result = nullptr;
  return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}

}  // namespace vectorization_utils
}  // namespace grappler
}  // namespace tensorflow