1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
17
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/mutable_graph_view.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
25 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
26 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/gtl/flatmap.h"
29 #include "tensorflow/core/lib/gtl/flatset.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/protobuf.h"
33
34 namespace tensorflow {
35 namespace grappler {
36 namespace fusion_utils {
37
38 namespace {
// Returns the node-name part of a connection string: the text preceding the
// first ':' when one is present, otherwise the entire string.
string ParseNodeConnection(const string& name) {
  // `find` yields npos when there is no colon, and substr(0, npos) is the
  // whole string, so both cases are handled by the same call.
  const auto colon_pos = name.find(':');
  return name.substr(0, colon_pos);
}
44
// Returns the suffix of a connection string starting at (and including) the
// first ':'. Returns an empty string when `name` contains no ':'.
string ParseOutputNode(const string& name) {
  const auto colon_pos = name.find(':');
  if (colon_pos == string::npos) return {};
  return name.substr(colon_pos);
}
49
GetOutputNode(const FunctionDef & function,int output_idx)50 string GetOutputNode(const FunctionDef& function, int output_idx) {
51 const auto& ret_output_name =
52 function.signature().output_arg(output_idx).name();
53 return function.ret().at(ret_output_name);
54 }
55
GetMutableOutputNode(FunctionDef * function,int output_idx)56 string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
57 const auto& ret_output_name =
58 function->signature().output_arg(output_idx).name();
59 return function->mutable_ret()->at(ret_output_name);
60 }
61
62 template <typename Iterable>
GetNames(const Iterable & iterable,int allocate_size)63 StringCollection GetNames(const Iterable& iterable, int allocate_size) {
64 StringCollection names;
65 names.reserve(allocate_size);
66 for (auto& arg : iterable) names.push_back(arg.name());
67 return names;
68 }
69
70 template <typename Iterable>
GetNodeNamesSet(const Iterable & nodes)71 gtl::FlatSet<string> GetNodeNamesSet(const Iterable& nodes) {
72 // NOTE(prazek): Cases where the set is not modified after construction
73 // could use sorted vector with binary_search instead, to make it faster.
74 gtl::FlatSet<string> names;
75 for (const auto& node : nodes) {
76 CHECK(gtl::InsertIfNotPresent(&names, node.name()))
77 << "Functions should have unique node names. Node with name "
78 << node.name() << " already exists";
79 }
80 return names;
81 }
82
83 template <typename Iterable>
GetUniqueNames(const Iterable & first_iterable,const Iterable & second_iterable)84 gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
85 const Iterable& second_iterable) {
86 gtl::FlatMap<string, string> changed_node_names;
87 const auto first_names = GetNodeNamesSet(first_iterable);
88 auto second_names = GetNodeNamesSet(first_iterable);
89 int id = second_iterable.size();
90
91 for (const auto& node : second_iterable) {
92 string name_before = node.name();
93 string name = name_before;
94 bool changed_name = false;
95
96 while (first_names.count(name) ||
97 (changed_name && second_names.count(name))) {
98 name = strings::StrCat(name_before, "/_", id);
99 changed_name = true;
100 ++id;
101 }
102 if (changed_name) {
103 changed_node_names[name_before] = name;
104 // We don't want to pick a new name that would collide with another new
105 // name.
106 second_names.insert(std::move(name));
107 }
108 }
109 return changed_node_names;
110 }
111
112 // We need to rename them and the connections of the inputs that refer to them.
113 // Nodes that will be added to the function can have the same name as the nodes
114 // from parent function.
RenameFunctionNodes(const FunctionDef & first_function,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse,protobuf::Map<string,string> * rets_to_fuse)115 void RenameFunctionNodes(const FunctionDef& first_function,
116 protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
117 protobuf::Map<string, string>* rets_to_fuse) {
118 const gtl::FlatMap<string, string> changed_node_names =
119 GetUniqueNames(first_function.node_def(), *nodes_to_fuse);
120
121 auto update_name = [&changed_node_names](string* input) {
122 string input_node = ParseNodeConnection(*input);
123 auto iter = changed_node_names.find(input_node);
124 if (iter != changed_node_names.end()) {
125 *input = iter->second + ParseOutputNode(*input);
126 }
127 };
128
129 for (NodeDef& function_node : *nodes_to_fuse) {
130 if (const string* new_name =
131 gtl::FindOrNull(changed_node_names, function_node.name())) {
132 function_node.set_name(*new_name);
133 }
134
135 for (string& input : *function_node.mutable_input()) {
136 update_name(&input);
137 }
138 }
139
140 for (auto& ret : *rets_to_fuse) update_name(&ret.second);
141 }
142
GetFunctionInputs(const FunctionDef & function)143 StringCollection GetFunctionInputs(const FunctionDef& function) {
144 return GetNames(function.signature().input_arg(),
145 function.signature().input_arg_size());
146 }
147
148 // This function produces signature having names that do not conflict with
149 // `first_signature`. The input of returns and nodes that will be fused are
150 // updated to use new names.
GetUniqueSignature(const OpDef & first_signature,const OpDef & second_signature,protobuf::Map<string,string> * rets_to_fuse,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)151 OpDef GetUniqueSignature(const OpDef& first_signature,
152 const OpDef& second_signature,
153 protobuf::Map<string, string>* rets_to_fuse,
154 protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
155 const gtl::FlatMap<string, string> changed_input_names =
156 GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
157 OpDef signature;
158 signature.set_name(second_signature.name());
159
160 for (const auto& input_arg : second_signature.input_arg()) {
161 auto& input = *signature.add_input_arg();
162 input = input_arg;
163 if (const string* new_name =
164 gtl::FindOrNull(changed_input_names, input.name())) {
165 input.set_name(*new_name);
166 }
167 }
168 const gtl::FlatMap<string, string> changed_output_names = GetUniqueNames(
169 first_signature.output_arg(), second_signature.output_arg());
170
171 for (const auto& output_arg : second_signature.output_arg()) {
172 auto& output = *signature.add_output_arg();
173 output = output_arg;
174 if (const string* new_name =
175 gtl::FindOrNull(changed_output_names, output.name())) {
176 output.set_name(*new_name);
177 }
178 }
179
180 protobuf::Map<string, string> new_rets;
181 for (const auto& ret : *rets_to_fuse) {
182 const auto& key = changed_output_names.count(ret.first)
183 ? changed_output_names.at(ret.first)
184 : ret.first;
185 const auto& input = ParseNodeConnection(ret.second);
186 const auto& value =
187 changed_input_names.count(input)
188 ? changed_input_names.at(input) + ParseOutputNode(ret.second)
189 : ret.second;
190 new_rets[key] = value;
191 }
192 *rets_to_fuse = std::move(new_rets);
193
194 for (NodeDef& function_node : *nodes_to_fuse) {
195 for (auto& node_input : *function_node.mutable_input()) {
196 const auto& input = ParseNodeConnection(node_input);
197 if (const string* new_name =
198 gtl::FindOrNull(changed_input_names, input)) {
199 node_input = *new_name + ParseOutputNode(node_input);
200 }
201 }
202 }
203
204 return signature;
205 }
206
207 // This function adds new nodes and changes their input to the output nodes
208 // of parent function. It assumes that the name of nodes to fuse are not
209 // conflicting.
FuseFunctionNodes(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::RepeatedPtrField<NodeDef> * nodes_to_fuse)210 void FuseFunctionNodes(const StringCollection& first_inputs,
211 const StringCollection& second_inputs,
212 const StringCollection& first_outputs,
213 const SetInputFn& set_input,
214 protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse) {
215 for (NodeDef& function_node : *nodes_to_fuse) {
216 for (auto& node_input : *function_node.mutable_input()) {
217 auto parsed_name = ParseNodeConnection(node_input);
218
219 auto input_it =
220 std::find(second_inputs.begin(), second_inputs.end(), parsed_name);
221 if (input_it == second_inputs.end()) continue;
222
223 auto arg_num = std::distance(second_inputs.begin(), input_it);
224 node_input =
225 set_input(first_inputs, second_inputs, first_outputs, arg_num);
226 }
227 }
228 }
229
230 // This function looks for direct edges from input to return and rewrites
231 // them to the corresponding input of the return of `first_function`.
FuseReturns(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,const SetInputFn & set_input,protobuf::Map<string,string> * fused_ret)232 void FuseReturns(const StringCollection& first_inputs,
233 const StringCollection& second_inputs,
234 const StringCollection& first_outputs,
235 const SetInputFn& set_input,
236 protobuf::Map<string, string>* fused_ret) {
237 for (auto& ret : *fused_ret) {
238 auto return_input = ParseNodeConnection(ret.second);
239 auto input_it =
240 std::find(second_inputs.begin(), second_inputs.end(), return_input);
241 if (input_it == second_inputs.end()) continue;
242
243 auto input_idx = std::distance(second_inputs.begin(), input_it);
244 ret.second =
245 set_input(first_inputs, second_inputs, first_outputs, input_idx);
246 }
247 }
248
249 // Returns collection of node names that are used as a return from function.
GetFunctionOutputs(const FunctionDef & function)250 StringCollection GetFunctionOutputs(const FunctionDef& function) {
251 const auto number_of_outputs = function.signature().output_arg_size();
252 StringCollection outputs;
253 outputs.reserve(number_of_outputs);
254
255 for (int output_idx = 0; output_idx < number_of_outputs; output_idx++)
256 outputs.push_back(GetOutputNode(function, output_idx));
257 return outputs;
258 }
259
CreateFalsePredicate(const protobuf::RepeatedPtrField<OpDef_ArgDef> & fake_args,FunctionDefLibrary * library)260 FunctionDef* CreateFalsePredicate(
261 const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
262 FunctionDefLibrary* library) {
263 GraphDef graph;
264 MutableGraphView graph_view(&graph);
265 auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
266 auto* false_predicate = library->add_function();
267 graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
268 false_predicate);
269
270 int num = 0;
271 for (const auto& fake_arg : fake_args) {
272 auto* arg = false_predicate->mutable_signature()->add_input_arg();
273 arg->set_type(fake_arg.type());
274 arg->set_name(strings::StrCat("fake_arg", num));
275 num++;
276 }
277
278 auto* output = false_predicate->mutable_signature()->add_output_arg();
279 output->set_name("false_out");
280 output->set_type(DT_BOOL);
281
282 (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
283 *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
284 return false_predicate;
285 }
286
CheckIfCanCompose(const OpDef & first_signature,const OpDef & second_signature)287 void CheckIfCanCompose(const OpDef& first_signature,
288 const OpDef& second_signature) {
289 CHECK(CanCompose(first_signature, second_signature))
290 << "The number of input arguments of function " << second_signature.name()
291 << " should be the same as the number of output arguments of function "
292 << first_signature.name() << ".";
293 }
294
295 } // namespace
296
MergeNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)297 void MergeNodes(const FunctionDef& first_function,
298 const FunctionDef& second_function, FunctionDef* fused_function,
299 FunctionDefLibrary* library) {
300 // Copy all nodes from first_function.
301 fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
302 // Copy transformed nodes from the second function.
303 fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
304 }
305
CanCompose(const OpDef & first_signature,const OpDef & second_signature)306 bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
307 // TODO(prazek): Functions can have additional inputs being placeholders
308 // for a values used in function. We should be able to also fuse these
309 // functions.
310 return first_signature.output_arg_size() == second_signature.input_arg_size();
311 }
312
ComposeInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)313 string ComposeInput(const StringCollection& first_inputs,
314 const StringCollection& second_inputs,
315 const StringCollection& first_outputs, int arg_num) {
316 // Take corresponding parent output.
317 return first_outputs.at(arg_num);
318 }
319
ComposeSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)320 void ComposeSignature(const OpDef& first_signature,
321 const OpDef& second_signature, OpDef* fused_signature) {
322 CheckIfCanCompose(first_signature, second_signature);
323
324 // Copy input signature from parent function.
325 *fused_signature->mutable_input_arg() = first_signature.input_arg();
326 // Copy output signature from second function.
327 *fused_signature->mutable_output_arg() = second_signature.output_arg();
328 }
329
ComposeOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)330 void ComposeOutput(const protobuf::Map<string, string>& first_ret,
331 const protobuf::Map<string, string>& second_ret,
332 protobuf::Map<string, string>* fused_ret) {
333 *fused_ret = second_ret;
334 }
335
CombineSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)336 void CombineSignature(const OpDef& first_signature,
337 const OpDef& second_signature, OpDef* fused_signature) {
338 CheckIfCanCompose(first_signature, second_signature);
339 // Copy input and output signature from parent function.
340 *fused_signature = first_signature;
341
342 // Add new output parameter.
343 fused_signature->mutable_output_arg()->MergeFrom(
344 second_signature.output_arg());
345 }
346
CombineOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)347 void CombineOutput(const protobuf::Map<string, string>& first_ret,
348 const protobuf::Map<string, string>& second_ret,
349 protobuf::Map<string, string>* fused_ret) {
350 *fused_ret = first_ret;
351 fused_ret->insert(second_ret.begin(), second_ret.end());
352 }
353
SameInput(const StringCollection & first_inputs,const StringCollection & second_inputs,const StringCollection & first_outputs,int arg_num)354 string SameInput(const StringCollection& first_inputs,
355 const StringCollection& second_inputs,
356 const StringCollection& first_outputs, int arg_num) {
357 return first_inputs.at(arg_num);
358 }
359
HasSameSignature(const OpDef & first_signature,const OpDef & second_signature)360 bool HasSameSignature(const OpDef& first_signature,
361 const OpDef& second_signature) {
362 return first_signature.input_arg_size() ==
363 second_signature.input_arg_size() &&
364 first_signature.output_arg_size() ==
365 second_signature.output_arg_size();
366 }
367
SameSignature(const OpDef & first_signature,const OpDef & second_signature,OpDef * fused_signature)368 void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
369 OpDef* fused_signature) {
370 CHECK(HasSameSignature(first_signature, second_signature))
371 << "Functions do not have the same signature";
372 // Copy signature from first function.
373 *fused_signature = first_signature;
374 }
375
LazyConjunctionNodes(const FunctionDef & first_function,const FunctionDef & second_function,FunctionDef * fused_function,FunctionDefLibrary * library)376 void LazyConjunctionNodes(const FunctionDef& first_function,
377 const FunctionDef& second_function,
378 FunctionDef* fused_function,
379 FunctionDefLibrary* library) {
380 fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
381
382 NodeDefBuilder if_builder("", "If");
383 if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
384 DataTypeVector in_arg_types;
385 std::vector<NodeDefBuilder::NodeOut> inputs;
386 for (const auto& input_arg : first_function.signature().input_arg()) {
387 inputs.push_back({input_arg.name(), 0, input_arg.type()});
388 in_arg_types.push_back(input_arg.type());
389 }
390 if_builder.Attr("Tin", in_arg_types);
391
392 if_builder.Attr("Tcond", DT_BOOL);
393 if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
394 if_builder.Attr("_lower_using_switch_merge", true);
395
396 NameAttrList then_branch;
397 then_branch.set_name(second_function.signature().name());
398 if_builder.Attr("then_branch", then_branch);
399
400 auto* false_predicate =
401 CreateFalsePredicate(first_function.signature().input_arg(), library);
402
403 NameAttrList else_branch;
404 else_branch.set_name(false_predicate->signature().name());
405 if_builder.Attr("else_branch", else_branch);
406 if_builder.Input(inputs);
407
408 auto* if_node = fused_function->add_node_def();
409 // This is guaranteed to succeed.
410 TF_CHECK_OK(if_builder.Finalize(if_node));
411 function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
412
413 GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
414 }
415
LazyConjunctionOutput(const protobuf::Map<string,string> & first_ret,const protobuf::Map<string,string> & second_ret,protobuf::Map<string,string> * fused_ret)416 void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
417 const protobuf::Map<string, string>& second_ret,
418 protobuf::Map<string, string>* fused_ret) {
419 CHECK_EQ(first_ret.size(), 1);
420 CHECK_EQ(second_ret.size(), 1);
421 // Temporarily copy returns from first_ret. We are going to change the
422 // output node after creating it.
423 *fused_ret = first_ret;
424 }
425
// Fuses `first_function` and `second_function` into a new function that is
// added to `library` and returned. Returns nullptr (without touching
// `library`) if either function has attributes, which is currently not
// supported.
//
// The callbacks decide how the pieces are combined: `set_signature` fills the
// fused signature, `set_output` fills the fused return map, `set_input`
// produces the replacement for references to the second function's input
// arguments, and `set_nodes` fills the fused node list. The fused function's
// name is made unique in `library` starting from `fused_name_prefix`.
FunctionDef* FuseFunctions(
    const FunctionDef& first_function, const FunctionDef& second_function,
    StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
    const SetInputFn& set_input, const SetOutputFn& set_output,
    const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
  const bool has_attrs =
      first_function.attr_size() != 0 || second_function.attr_size() != 0;
  if (has_attrs) {
    return nullptr;  // Functions with attributes are currently not supported
  }

  // Work on a private clone of the second function so that its argument and
  // node names can be made unique with respect to the first function.
  FunctionDef renamed_second = second_function;
  *renamed_second.mutable_signature() = GetUniqueSignature(
      first_function.signature(), renamed_second.signature(),
      renamed_second.mutable_ret(), renamed_second.mutable_node_def());

  FunctionDef* fused_function = library->add_function();

  set_signature(first_function.signature(), renamed_second.signature(),
                fused_function->mutable_signature());

  graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library,
                                          fused_function);

  RenameFunctionNodes(first_function, renamed_second.mutable_node_def(),
                      renamed_second.mutable_ret());
  set_output(first_function.ret(), renamed_second.ret(),
             fused_function->mutable_ret());

  CHECK(fused_function->signature().output_arg_size() ==
        fused_function->ret_size())
      << "Fused function must have the same number of returns as output args. "
         "Output size: "
      << fused_function->signature().output_arg_size()
      << ", ret size: " << fused_function->ret_size();

  // Redirect references to the second function's inputs, both in its nodes
  // and in the fused returns.
  const auto first_inputs = GetFunctionInputs(first_function);
  const auto second_inputs = GetFunctionInputs(renamed_second);
  const auto first_outputs = GetFunctionOutputs(first_function);
  FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
                    renamed_second.mutable_node_def());
  FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
              fused_function->mutable_ret());

  set_nodes(first_function, renamed_second, fused_function, library);

  return fused_function;
}
473
474 } // namespace fusion_utils
475 } // namespace grappler
476 } // namespace tensorflow
477