• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
17 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
18 
19 #include "tensorflow/core/framework/device_base.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/util/ptr_util.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace function_utils {
27 
FunctionDefTensorDesc(const string & node_name,const string & output,int position)28 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
29                                              const string& output, int position)
30     : node_name(node_name), node_output(output), position(position) {
31   full_str = strings::StrCat(node_name, ":", node_output, ":", position);
32 }
33 
FunctionDefTensorDesc(const string & input)34 FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
35   // Parses node_name:node_output:position string into its components.
36   full_str = input;
37   StringPiece capture;
38   StringPiece remaining;
39 
40   // Parse "node_name"
41   if (strings::Scanner(input)
42           .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
43           .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
44           .GetResult(&remaining, &capture)) {
45     node_name = string(capture.data(), capture.size());
46   }
47 
48   // Parse "node_output" if it exists
49   if (strings::Scanner(remaining)
50           .OneLiteral(":")
51           .RestartCapture()
52           .One(strings::Scanner::LETTER)
53           .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
54           .GetResult(&remaining, &capture)) {
55     node_output = string(capture.data(), capture.size());
56   }
57 
58   // Parse "position" if it exists
59   if (strings::Scanner(remaining)
60           .OneLiteral(":")
61           .RestartCapture()
62           .Many(strings::Scanner::DIGIT)
63           .GetResult(nullptr, &capture)) {
64     CHECK(strings::safe_strto32(capture, &position));
65   }
66 }
67 
// TODO(rachelim): Create a utility class similar to MutableGraphView for
// FunctionDefs, and use that to manipulate functions. It would be more
// performant to keep mappings of nodes->inputs/outputs, so that we don't
// have to search over all nodes each time.
// Note that we're not using GrapplerFunctionItem because it doesn't cover
// some of our desired uses (e.g. changing the outputs of a function), and the
// FunctionDef -> GraphDef conversion isn't really necessary in this case.
ReplaceReferences(const string & from,const string & to,FunctionDef * func)75 void ReplaceReferences(const string& from, const string& to,
76                        FunctionDef* func) {
77   for (NodeDef& n : *func->mutable_node_def()) {
78     std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
79                  to);
80   }
81 
82   for (auto& p : *func->mutable_ret()) {
83     if (p.second == from) {
84       p.second = to;
85     }
86   }
87 }
88 
AddFunctionOutputWithUniqueName(StringPiece prefix,StringPiece output_tensor_name,FunctionDef * function,DataType dt)89 void AddFunctionOutputWithUniqueName(StringPiece prefix,
90                                      StringPiece output_tensor_name,
91                                      FunctionDef* function, DataType dt) {
92   string name = string(prefix);
93   int id = function->signature().output_arg_size();
94   while (ContainsFunctionOutputWithName(name, *function)) {
95     name = strings::StrCat(prefix, "/_", id);
96     ++id;
97   }
98   auto* output = function->mutable_signature()->mutable_output_arg()->Add();
99   output->set_name(name);
100   output->set_type(dt);
101 
102   (*function->mutable_ret())[name] = string(output_tensor_name);
103 }
104 
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,FunctionDef * fd)105 NodeDef* AddNode(StringPiece name, StringPiece op,
106                  const std::vector<string>& inputs,
107                  const std::vector<std::pair<string, AttrValue>>& attributes,
108                  FunctionDef* fd) {
109   NodeDef* node = fd->add_node_def();
110   if (!name.empty()) {
111     node->set_name(string(name));
112   } else {
113     SetUniqueFunctionNodeName(op, fd, node);
114   }
115   node->set_op(string(op));
116   for (const string& input : inputs) {
117     node->add_input(input);
118   }
119   for (auto attr : attributes) {
120     (*node->mutable_attr())[attr.first] = attr.second;
121   }
122   return node;
123 }
124 
ContainsFunctionNodeWithName(StringPiece name,const FunctionDef & function)125 bool ContainsFunctionNodeWithName(StringPiece name,
126                                   const FunctionDef& function) {
127   return FindFunctionNodeWithName(name, function) != -1;
128 }
129 
ContainsFunctionNodeWithOp(StringPiece op,const FunctionDef & function)130 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
131   return FindFunctionNodeWithOp(op, function) != -1;
132 }
133 
ContainsFunctionOutputWithName(StringPiece name,const FunctionDef & function)134 bool ContainsFunctionOutputWithName(StringPiece name,
135                                     const FunctionDef& function) {
136   return FindFunctionOutputWithName(name, function) != -1;
137 }
138 
FindFunctionInputWithName(StringPiece name,const FunctionDef & function)139 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
140   return graph_utils::GetFirstElementIndexWithPredicate(
141       [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
142       function.signature().input_arg());
143 }
144 
FindFunctionOutputWithName(StringPiece name,const FunctionDef & function)145 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
146   return graph_utils::GetFirstElementIndexWithPredicate(
147       [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
148       function.signature().output_arg());
149 }
150 
FindFunctionNodeWithName(StringPiece name,const FunctionDef & function)151 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
152   return graph_utils::GetFirstElementIndexWithPredicate(
153       [&name](const NodeDef& node) { return node.name() == name; },
154       function.node_def());
155 }
156 
FindFunctionNodeWithOp(StringPiece op,const FunctionDef & function)157 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
158   return graph_utils::GetFirstElementIndexWithPredicate(
159       [&op](const NodeDef& node) { return node.op() == op; },
160       function.node_def());
161 }
162 
SetUniqueFunctionNodeName(StringPiece prefix,FunctionDef * function,NodeDef * node)163 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
164                                NodeDef* node) {
165   string name = string(prefix);
166   int id = function->node_def_size();
167   while (ContainsFunctionNodeWithName(name, *function)) {
168     name = strings::StrCat(prefix, "/_", id);
169     ++id;
170   }
171   node->set_name(std::move(name));
172 }
173 
IsFunctionStateful(const FunctionLibraryDefinition & library,const FunctionDef & function_def,bool skip_assert)174 bool IsFunctionStateful(const FunctionLibraryDefinition& library,
175                         const FunctionDef& function_def, bool skip_assert) {
176   if (!function_def.signature().is_stateful()) return false;
177 
178   for (const NodeDef& node_def : function_def.node_def()) {
179     if (IsNodeStateful(library, node_def, skip_assert)) return true;
180   }
181   return false;
182 }
183 
IsNodeStateful(const FunctionLibraryDefinition & library,const NodeDef & node,bool skip_assert)184 bool IsNodeStateful(const FunctionLibraryDefinition& library,
185                     const NodeDef& node, bool skip_assert) {
186   const OpDef* op_def;
187   Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
188 
189   if (!s.ok()) return true;
190 
191   if (!op_def->is_stateful()) return false;
192 
193   if (skip_assert && op_def->name() == "Assert") {
194     return false;
195   }
196 
197   if (op_def->name() == "If") {
198     const FunctionDef* then_func =
199         library.Find(node.attr().at("then_branch").func().name());
200     const FunctionDef* else_func =
201         library.Find(node.attr().at("else_branch").func().name());
202     if ((then_func != nullptr &&
203          !IsFunctionStateful(library, *then_func, skip_assert)) &&
204         (else_func != nullptr &&
205          !IsFunctionStateful(library, *else_func, skip_assert))) {
206       return false;
207     }
208   }
209 
210   if (op_def->name() == "While") {
211     const FunctionDef* cond_func =
212         library.Find(node.attr().at("cond").func().name());
213     const FunctionDef* body_func =
214         library.Find(node.attr().at("body").func().name());
215     if ((cond_func != nullptr &&
216          !IsFunctionStateful(library, *cond_func, skip_assert)) &&
217         (body_func != nullptr &&
218          !IsFunctionStateful(library, *body_func, skip_assert))) {
219       return false;
220     }
221   }
222   return true;
223 }
224 
225 }  // namespace function_utils
226 }  // namespace grappler
227 }  // namespace tensorflow
228