1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
17 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
18
19 #include "tensorflow/core/framework/device_base.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/util/ptr_util.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace function_utils {
27
FunctionDefTensorDesc(const string & node_name,const string & output,int position)28 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
29 const string& output, int position)
30 : node_name(node_name), node_output(output), position(position) {
31 full_str = strings::StrCat(node_name, ":", node_output, ":", position);
32 }
33
FunctionDefTensorDesc(const string & input)34 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
35 // Parses node_name:node_output:position string into its components.
36 full_str = input;
37 StringPiece capture;
38 StringPiece remaining;
39
40 // Parse "node_name"
41 if (strings::Scanner(input)
42 .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
43 .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
44 .GetResult(&remaining, &capture)) {
45 node_name = string(capture.data(), capture.size());
46 }
47
48 // Parse "node_output" if it exists
49 if (strings::Scanner(remaining)
50 .OneLiteral(":")
51 .RestartCapture()
52 .One(strings::Scanner::LETTER)
53 .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
54 .GetResult(&remaining, &capture)) {
55 node_output = string(capture.data(), capture.size());
56 }
57
58 // Parse "position" if it exists
59 if (strings::Scanner(remaining)
60 .OneLiteral(":")
61 .RestartCapture()
62 .Many(strings::Scanner::DIGIT)
63 .GetResult(nullptr, &capture)) {
64 CHECK(strings::safe_strto32(capture, &position));
65 }
66 }
67
68 // TODO(rachelim): Create a utility class similar to MutableGraphView for
69 // FunctionDefs, and use that to manipulate functions. It'll be more
70 // performant if we kept mappings of nodes->inputs/outputs, so that we don't
71 // have to search over all nodes each time.
72 // Note that we're not using GrapplerFunctionItem because it doesn't cover
73 // some of our desired uses (eg changing the outputs of a function), and the
74 // FunctionDef -> GraphDef conversion isn't really necessary in this case.
ReplaceReferences(const string & from,const string & to,FunctionDef * func)75 void ReplaceReferences(const string& from, const string& to,
76 FunctionDef* func) {
77 for (NodeDef& n : *func->mutable_node_def()) {
78 std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
79 to);
80 }
81
82 for (auto& p : *func->mutable_ret()) {
83 if (p.second == from) {
84 p.second = to;
85 }
86 }
87 }
88
AddFunctionOutputWithUniqueName(StringPiece prefix,StringPiece output_tensor_name,FunctionDef * function,DataType dt)89 void AddFunctionOutputWithUniqueName(StringPiece prefix,
90 StringPiece output_tensor_name,
91 FunctionDef* function, DataType dt) {
92 string name = string(prefix);
93 int id = function->signature().output_arg_size();
94 while (ContainsFunctionOutputWithName(name, *function)) {
95 name = strings::StrCat(prefix, "/_", id);
96 ++id;
97 }
98 auto* output = function->mutable_signature()->mutable_output_arg()->Add();
99 output->set_name(name);
100 output->set_type(dt);
101
102 (*function->mutable_ret())[name] = string(output_tensor_name);
103 }
104
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,FunctionDef * fd)105 NodeDef* AddNode(StringPiece name, StringPiece op,
106 const std::vector<string>& inputs,
107 const std::vector<std::pair<string, AttrValue>>& attributes,
108 FunctionDef* fd) {
109 NodeDef* node = fd->add_node_def();
110 if (!name.empty()) {
111 node->set_name(string(name));
112 } else {
113 SetUniqueFunctionNodeName(op, fd, node);
114 }
115 node->set_op(string(op));
116 for (const string& input : inputs) {
117 node->add_input(input);
118 }
119 for (auto attr : attributes) {
120 (*node->mutable_attr())[attr.first] = attr.second;
121 }
122 return node;
123 }
124
ContainsFunctionNodeWithName(StringPiece name,const FunctionDef & function)125 bool ContainsFunctionNodeWithName(StringPiece name,
126 const FunctionDef& function) {
127 return FindFunctionNodeWithName(name, function) != -1;
128 }
129
ContainsFunctionNodeWithOp(StringPiece op,const FunctionDef & function)130 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
131 return FindFunctionNodeWithOp(op, function) != -1;
132 }
133
ContainsFunctionOutputWithName(StringPiece name,const FunctionDef & function)134 bool ContainsFunctionOutputWithName(StringPiece name,
135 const FunctionDef& function) {
136 return FindFunctionOutputWithName(name, function) != -1;
137 }
138
FindFunctionInputWithName(StringPiece name,const FunctionDef & function)139 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
140 return graph_utils::GetFirstElementIndexWithPredicate(
141 [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
142 function.signature().input_arg());
143 }
144
FindFunctionOutputWithName(StringPiece name,const FunctionDef & function)145 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
146 return graph_utils::GetFirstElementIndexWithPredicate(
147 [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
148 function.signature().output_arg());
149 }
150
FindFunctionNodeWithName(StringPiece name,const FunctionDef & function)151 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
152 return graph_utils::GetFirstElementIndexWithPredicate(
153 [&name](const NodeDef& node) { return node.name() == name; },
154 function.node_def());
155 }
156
FindFunctionNodeWithOp(StringPiece op,const FunctionDef & function)157 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
158 return graph_utils::GetFirstElementIndexWithPredicate(
159 [&op](const NodeDef& node) { return node.op() == op; },
160 function.node_def());
161 }
162
SetUniqueFunctionNodeName(StringPiece prefix,FunctionDef * function,NodeDef * node)163 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
164 NodeDef* node) {
165 string name = string(prefix);
166 int id = function->node_def_size();
167 while (ContainsFunctionNodeWithName(name, *function)) {
168 name = strings::StrCat(prefix, "/_", id);
169 ++id;
170 }
171 node->set_name(std::move(name));
172 }
173
IsFunctionStateful(const FunctionLibraryDefinition & library,const FunctionDef & function_def,bool skip_assert)174 bool IsFunctionStateful(const FunctionLibraryDefinition& library,
175 const FunctionDef& function_def, bool skip_assert) {
176 if (!function_def.signature().is_stateful()) return false;
177
178 for (const NodeDef& node_def : function_def.node_def()) {
179 if (IsNodeStateful(library, node_def, skip_assert)) return true;
180 }
181 return false;
182 }
183
IsNodeStateful(const FunctionLibraryDefinition & library,const NodeDef & node,bool skip_assert)184 bool IsNodeStateful(const FunctionLibraryDefinition& library,
185 const NodeDef& node, bool skip_assert) {
186 const OpDef* op_def;
187 Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
188
189 if (!s.ok()) return true;
190
191 if (!op_def->is_stateful()) return false;
192
193 if (skip_assert && op_def->name() == "Assert") {
194 return false;
195 }
196
197 if (op_def->name() == "If") {
198 const FunctionDef* then_func =
199 library.Find(node.attr().at("then_branch").func().name());
200 const FunctionDef* else_func =
201 library.Find(node.attr().at("else_branch").func().name());
202 if ((then_func != nullptr &&
203 !IsFunctionStateful(library, *then_func, skip_assert)) &&
204 (else_func != nullptr &&
205 !IsFunctionStateful(library, *else_func, skip_assert))) {
206 return false;
207 }
208 }
209
210 if (op_def->name() == "While") {
211 const FunctionDef* cond_func =
212 library.Find(node.attr().at("cond").func().name());
213 const FunctionDef* body_func =
214 library.Find(node.attr().at("body").func().name());
215 if ((cond_func != nullptr &&
216 !IsFunctionStateful(library, *cond_func, skip_assert)) &&
217 (body_func != nullptr &&
218 !IsFunctionStateful(library, *body_func, skip_assert))) {
219 return false;
220 }
221 }
222 return true;
223 }
224
225 } // namespace function_utils
226 } // namespace grappler
227 } // namespace tensorflow
228