1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
17
18 #include "tensorflow/core/framework/device_base.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/util/ptr_util.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace function_utils {
27
FunctionDefTensorDesc(const string & node_name,const string & output,int position)28 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
29 const string& output, int position)
30 : node_name(node_name), node_output(output), position(position) {
31 full_str = strings::StrCat(node_name, ":", node_output, ":", position);
32 }
33
FunctionDefTensorDesc(const string & input)34 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
35 // Parses node_name:node_output:position string into its components.
36 full_str = input;
37 StringPiece capture;
38 StringPiece remaining;
39
40 // Parse "node_name"
41 if (strings::Scanner(input)
42 .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
43 .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
44 .GetResult(&remaining, &capture)) {
45 node_name = string(capture.data(), capture.size());
46 }
47
48 // Parse "node_output" if it exists
49 if (strings::Scanner(remaining)
50 .OneLiteral(":")
51 .RestartCapture()
52 .One(strings::Scanner::LETTER)
53 .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
54 .GetResult(&remaining, &capture)) {
55 node_output = string(capture.data(), capture.size());
56 }
57
58 // Parse "position" if it exists
59 if (strings::Scanner(remaining)
60 .OneLiteral(":")
61 .RestartCapture()
62 .Many(strings::Scanner::DIGIT)
63 .GetResult(nullptr, &capture)) {
64 CHECK(strings::safe_strto32(capture, &position));
65 }
66 }
67
// TODO(rachelim): Create a utility class similar to MutableGraphView for
// FunctionDefs, and use that to manipulate functions. It would be more
// performant to keep mappings of nodes->inputs/outputs, so that we don't
// have to search over all nodes each time.
// Note that we're not using GrapplerFunctionItem because it doesn't cover
// some of our desired uses (e.g. changing the outputs of a function), and the
// FunctionDef -> GraphDef conversion isn't really necessary in this case.
ReplaceReferences(const string & from,const string & to,FunctionDef * func)75 void ReplaceReferences(const string& from, const string& to,
76 FunctionDef* func) {
77 for (NodeDef& n : *func->mutable_node_def()) {
78 std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
79 to);
80 }
81
82 for (auto& p : *func->mutable_ret()) {
83 if (p.second == from) {
84 p.second = to;
85 }
86 }
87 }
88
AddFunctionOutputWithUniqueName(StringPiece prefix,StringPiece output_tensor_name,FunctionDef * fdef,DataType dtype)89 void AddFunctionOutputWithUniqueName(StringPiece prefix,
90 StringPiece output_tensor_name,
91 FunctionDef* fdef, DataType dtype) {
92 string name = string(prefix);
93 int id = fdef->signature().output_arg_size();
94 while (ContainsFunctionOutputWithName(name, *fdef)) {
95 name = strings::StrCat(prefix, "/_", id);
96 ++id;
97 }
98 auto* output = fdef->mutable_signature()->mutable_output_arg()->Add();
99 output->set_name(name);
100 output->set_type(dtype);
101
102 (*fdef->mutable_ret())[name] = string(output_tensor_name);
103 }
104
AddFunctionInput(const string & name,FunctionDef * fdef,DataType dtype)105 OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
106 DataType dtype) {
107 auto* input_arg = fdef->mutable_signature()->mutable_input_arg()->Add();
108 input_arg->set_type(dtype);
109 input_arg->set_name(name);
110
111 return input_arg;
112 }
113
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,FunctionDef * fd)114 NodeDef* AddNode(StringPiece name, StringPiece op,
115 const std::vector<string>& inputs,
116 const std::vector<std::pair<string, AttrValue>>& attributes,
117 FunctionDef* fd) {
118 NodeDef* node = fd->add_node_def();
119 if (!name.empty()) {
120 node->set_name(string(name));
121 } else {
122 SetUniqueFunctionNodeName(op, fd, node);
123 }
124 node->set_op(string(op));
125 for (const string& input : inputs) {
126 node->add_input(input);
127 }
128 for (const auto& attr : attributes) {
129 (*node->mutable_attr())[attr.first] = attr.second;
130 }
131 return node;
132 }
133
ContainsFunctionNodeWithName(StringPiece name,const FunctionDef & function)134 bool ContainsFunctionNodeWithName(StringPiece name,
135 const FunctionDef& function) {
136 return FindFunctionNodeWithName(name, function) != -1;
137 }
138
ContainsFunctionNodeWithOp(StringPiece op,const FunctionDef & function)139 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
140 return FindFunctionNodeWithOp(op, function) != -1;
141 }
142
ContainsFunctionOutputWithName(StringPiece name,const FunctionDef & function)143 bool ContainsFunctionOutputWithName(StringPiece name,
144 const FunctionDef& function) {
145 return FindFunctionOutputWithName(name, function) != -1;
146 }
147
FindFunctionInputWithName(StringPiece name,const FunctionDef & function)148 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
149 return graph_utils::GetFirstElementIndexWithPredicate(
150 [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
151 function.signature().input_arg());
152 }
153
FindFunctionOutputWithName(StringPiece name,const FunctionDef & function)154 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
155 return graph_utils::GetFirstElementIndexWithPredicate(
156 [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
157 function.signature().output_arg());
158 }
159
FindFunctionNodeWithName(StringPiece name,const FunctionDef & function)160 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
161 return graph_utils::GetFirstElementIndexWithPredicate(
162 [&name](const NodeDef& node) { return node.name() == name; },
163 function.node_def());
164 }
165
FindFunctionNodeWithOp(StringPiece op,const FunctionDef & function)166 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
167 return graph_utils::GetFirstElementIndexWithPredicate(
168 [&op](const NodeDef& node) { return node.op() == op; },
169 function.node_def());
170 }
171
SetUniqueFunctionNodeName(StringPiece prefix,FunctionDef * function,NodeDef * node)172 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
173 NodeDef* node) {
174 string name = string(prefix);
175 int id = function->node_def_size();
176 while (ContainsFunctionNodeWithName(name, *function)) {
177 name = strings::StrCat(prefix, "/_", id);
178 ++id;
179 }
180 node->set_name(std::move(name));
181 }
182
IsFunctionStateful(const FunctionLibraryDefinition & library,const FunctionDef & function_def,bool skip_assert)183 bool IsFunctionStateful(const FunctionLibraryDefinition& library,
184 const FunctionDef& function_def, bool skip_assert) {
185 if (!function_def.signature().is_stateful()) return false;
186
187 for (const NodeDef& node_def : function_def.node_def()) {
188 if (IsNodeStateful(library, node_def, skip_assert)) return true;
189 }
190 return false;
191 }
192
IsNodeStateful(const FunctionLibraryDefinition & library,const NodeDef & node,bool skip_assert)193 bool IsNodeStateful(const FunctionLibraryDefinition& library,
194 const NodeDef& node, bool skip_assert) {
195 const OpDef* op_def;
196 Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
197
198 if (!s.ok()) return true;
199
200 if (!op_def->is_stateful()) return false;
201
202 if (skip_assert && op_def->name() == "Assert") {
203 return false;
204 }
205
206 if (op_def->name() == "If") {
207 const FunctionDef* then_func =
208 library.Find(node.attr().at("then_branch").func().name());
209 const FunctionDef* else_func =
210 library.Find(node.attr().at("else_branch").func().name());
211 if ((then_func != nullptr &&
212 !IsFunctionStateful(library, *then_func, skip_assert)) &&
213 (else_func != nullptr &&
214 !IsFunctionStateful(library, *else_func, skip_assert))) {
215 return false;
216 }
217 }
218
219 if (op_def->name() == "While") {
220 const FunctionDef* cond_func =
221 library.Find(node.attr().at("cond").func().name());
222 const FunctionDef* body_func =
223 library.Find(node.attr().at("body").func().name());
224 if ((cond_func != nullptr &&
225 !IsFunctionStateful(library, *cond_func, skip_assert)) &&
226 (body_func != nullptr &&
227 !IsFunctionStateful(library, *body_func, skip_assert))) {
228 return false;
229 }
230 }
231 return true;
232 }
233
234 } // namespace function_utils
235 } // namespace grappler
236 } // namespace tensorflow
237