• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
17 
18 #include "tensorflow/core/framework/device_base.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/util/ptr_util.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace function_utils {
27 
FunctionDefTensorDesc(const string & node_name,const string & output,int position)28 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
29                                              const string& output, int position)
30     : node_name(node_name), node_output(output), position(position) {
31   full_str = strings::StrCat(node_name, ":", node_output, ":", position);
32 }
33 
FunctionDefTensorDesc(const string & input)34 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
35   // Parses node_name:node_output:position string into its components.
36   full_str = input;
37   StringPiece capture;
38   StringPiece remaining;
39 
40   // Parse "node_name"
41   if (strings::Scanner(input)
42           .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
43           .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
44           .GetResult(&remaining, &capture)) {
45     node_name = string(capture.data(), capture.size());
46   }
47 
48   // Parse "node_output" if it exists
49   if (strings::Scanner(remaining)
50           .OneLiteral(":")
51           .RestartCapture()
52           .One(strings::Scanner::LETTER)
53           .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
54           .GetResult(&remaining, &capture)) {
55     node_output = string(capture.data(), capture.size());
56   }
57 
58   // Parse "position" if it exists
59   if (strings::Scanner(remaining)
60           .OneLiteral(":")
61           .RestartCapture()
62           .Many(strings::Scanner::DIGIT)
63           .GetResult(nullptr, &capture)) {
64     CHECK(strings::safe_strto32(capture, &position));
65   }
66 }
67 
68 // TODO(rachelim): Create a utility class similar to MutableGraphView for
69 // FunctionDefs, and use that to manipulate functions. It'll be more
70 // performant if we kept mappings of nodes->inputs/outputs, so that we don't
71 // have to search over all nodes each time.
72 // Note that we're not using GrapplerFunctionItem because it doesn't cover
73 // some of our desired uses (eg changing the outputs of a function), and the
74 // FunctionDef -> GraphDef conversion isn't really necessary in this case.
ReplaceReferences(const string & from,const string & to,FunctionDef * func)75 void ReplaceReferences(const string& from, const string& to,
76                        FunctionDef* func) {
77   for (NodeDef& n : *func->mutable_node_def()) {
78     std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
79                  to);
80   }
81 
82   for (auto& p : *func->mutable_ret()) {
83     if (p.second == from) {
84       p.second = to;
85     }
86   }
87 }
88 
AddFunctionOutputWithUniqueName(StringPiece prefix,StringPiece output_tensor_name,FunctionDef * fdef,DataType dtype)89 void AddFunctionOutputWithUniqueName(StringPiece prefix,
90                                      StringPiece output_tensor_name,
91                                      FunctionDef* fdef, DataType dtype) {
92   string name = string(prefix);
93   int id = fdef->signature().output_arg_size();
94   while (ContainsFunctionOutputWithName(name, *fdef)) {
95     name = strings::StrCat(prefix, "/_", id);
96     ++id;
97   }
98   auto* output = fdef->mutable_signature()->mutable_output_arg()->Add();
99   output->set_name(name);
100   output->set_type(dtype);
101 
102   (*fdef->mutable_ret())[name] = string(output_tensor_name);
103 }
104 
AddFunctionInput(const string & name,FunctionDef * fdef,DataType dtype)105 OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
106                                DataType dtype) {
107   auto* input_arg = fdef->mutable_signature()->mutable_input_arg()->Add();
108   input_arg->set_type(dtype);
109   input_arg->set_name(name);
110 
111   return input_arg;
112 }
113 
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,FunctionDef * fd)114 NodeDef* AddNode(StringPiece name, StringPiece op,
115                  const std::vector<string>& inputs,
116                  const std::vector<std::pair<string, AttrValue>>& attributes,
117                  FunctionDef* fd) {
118   NodeDef* node = fd->add_node_def();
119   if (!name.empty()) {
120     node->set_name(string(name));
121   } else {
122     SetUniqueFunctionNodeName(op, fd, node);
123   }
124   node->set_op(string(op));
125   for (const string& input : inputs) {
126     node->add_input(input);
127   }
128   for (const auto& attr : attributes) {
129     (*node->mutable_attr())[attr.first] = attr.second;
130   }
131   return node;
132 }
133 
ContainsFunctionNodeWithName(StringPiece name,const FunctionDef & function)134 bool ContainsFunctionNodeWithName(StringPiece name,
135                                   const FunctionDef& function) {
136   return FindFunctionNodeWithName(name, function) != -1;
137 }
138 
ContainsFunctionNodeWithOp(StringPiece op,const FunctionDef & function)139 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
140   return FindFunctionNodeWithOp(op, function) != -1;
141 }
142 
ContainsFunctionOutputWithName(StringPiece name,const FunctionDef & function)143 bool ContainsFunctionOutputWithName(StringPiece name,
144                                     const FunctionDef& function) {
145   return FindFunctionOutputWithName(name, function) != -1;
146 }
147 
FindFunctionInputWithName(StringPiece name,const FunctionDef & function)148 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
149   return graph_utils::GetFirstElementIndexWithPredicate(
150       [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
151       function.signature().input_arg());
152 }
153 
FindFunctionOutputWithName(StringPiece name,const FunctionDef & function)154 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
155   return graph_utils::GetFirstElementIndexWithPredicate(
156       [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
157       function.signature().output_arg());
158 }
159 
FindFunctionNodeWithName(StringPiece name,const FunctionDef & function)160 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
161   return graph_utils::GetFirstElementIndexWithPredicate(
162       [&name](const NodeDef& node) { return node.name() == name; },
163       function.node_def());
164 }
165 
FindFunctionNodeWithOp(StringPiece op,const FunctionDef & function)166 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
167   return graph_utils::GetFirstElementIndexWithPredicate(
168       [&op](const NodeDef& node) { return node.op() == op; },
169       function.node_def());
170 }
171 
SetUniqueFunctionNodeName(StringPiece prefix,FunctionDef * function,NodeDef * node)172 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
173                                NodeDef* node) {
174   string name = string(prefix);
175   int id = function->node_def_size();
176   while (ContainsFunctionNodeWithName(name, *function)) {
177     name = strings::StrCat(prefix, "/_", id);
178     ++id;
179   }
180   node->set_name(std::move(name));
181 }
182 
IsFunctionStateful(const FunctionLibraryDefinition & library,const FunctionDef & function_def,bool skip_assert)183 bool IsFunctionStateful(const FunctionLibraryDefinition& library,
184                         const FunctionDef& function_def, bool skip_assert) {
185   if (!function_def.signature().is_stateful()) return false;
186 
187   for (const NodeDef& node_def : function_def.node_def()) {
188     if (IsNodeStateful(library, node_def, skip_assert)) return true;
189   }
190   return false;
191 }
192 
IsNodeStateful(const FunctionLibraryDefinition & library,const NodeDef & node,bool skip_assert)193 bool IsNodeStateful(const FunctionLibraryDefinition& library,
194                     const NodeDef& node, bool skip_assert) {
195   const OpDef* op_def;
196   Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
197 
198   if (!s.ok()) return true;
199 
200   if (!op_def->is_stateful()) return false;
201 
202   if (skip_assert && op_def->name() == "Assert") {
203     return false;
204   }
205 
206   if (op_def->name() == "If") {
207     const FunctionDef* then_func =
208         library.Find(node.attr().at("then_branch").func().name());
209     const FunctionDef* else_func =
210         library.Find(node.attr().at("else_branch").func().name());
211     if ((then_func != nullptr &&
212          !IsFunctionStateful(library, *then_func, skip_assert)) &&
213         (else_func != nullptr &&
214          !IsFunctionStateful(library, *else_func, skip_assert))) {
215       return false;
216     }
217   }
218 
219   if (op_def->name() == "While") {
220     const FunctionDef* cond_func =
221         library.Find(node.attr().at("cond").func().name());
222     const FunctionDef* body_func =
223         library.Find(node.attr().at("body").func().name());
224     if ((cond_func != nullptr &&
225          !IsFunctionStateful(library, *cond_func, skip_assert)) &&
226         (body_func != nullptr &&
227          !IsFunctionStateful(library, *body_func, skip_assert))) {
228       return false;
229     }
230   }
231   return true;
232 }
233 
234 }  // namespace function_utils
235 }  // namespace grappler
236 }  // namespace tensorflow
237