• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/graph_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.pb.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/versions.pb_text.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 
35 namespace tensorflow {
36 
SummarizeGraphDef(const GraphDef & graph_def)37 string SummarizeGraphDef(const GraphDef& graph_def) {
38   string ret;
39   strings::StrAppend(
40       &ret, "versions = ", ProtoShortDebugString(graph_def.versions()), ";\n");
41   for (const NodeDef& node : graph_def.node()) {
42     strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
43   }
44   return ret;
45 }
46 
ValidateExternalGraphDefSyntax(const GraphDef & graph_def)47 Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
48   for (const NodeDef& node : graph_def.node()) {
49     TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
50   }
51   return Status::OK();
52 }
53 
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset)54 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
55                                  const OpRegistryInterface& op_registry,
56                                  int node_offset) {
57   return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
58 }
59 
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset,bool skip_unknown_ops)60 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
61                                  const OpRegistryInterface& op_registry,
62                                  int node_offset, bool skip_unknown_ops) {
63   if (node_offset > graph_def->node_size()) {
64     return errors::InvalidArgument(
65         "Tried to add default attrs to GraphDef "
66         "starting at offset ",
67         node_offset, " with total nodes in graph: ", graph_def->node_size());
68   }
69 
70   for (int i = node_offset; i < graph_def->node_size(); ++i) {
71     NodeDef* node_def = graph_def->mutable_node(i);
72     const OpDef* op_def;
73     Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
74     if (s.ok()) {
75       AddDefaultsToNodeDef(*op_def, node_def);
76     } else if (!skip_unknown_ops) {
77       return s;
78     }
79   }
80 
81   return Status::OK();
82 }
83 
RemoveNewDefaultAttrsFromNodeDef(NodeDef * node_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)84 static Status RemoveNewDefaultAttrsFromNodeDef(
85     NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
86     const OpRegistryInterface& producer_op_registry,
87     std::set<std::pair<string, string>>* op_attr_removed) {
88   const OpDef* producer_op_def;
89   const OpDef* consumer_op_def;
90   TF_RETURN_IF_ERROR(
91       producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
92   TF_RETURN_IF_ERROR(
93       consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
94 
95   std::vector<string> to_remove;
96   for (const auto& attr : node_def->attr()) {
97     // If the attr is not in consumer_op_def and doesn't start with '_'...
98     if (!str_util::StartsWith(attr.first, "_") &&
99         FindAttr(attr.first, *consumer_op_def) == nullptr) {
100       const OpDef::AttrDef* producer_attr_def =
101           FindAttr(attr.first, *producer_op_def);
102       if (producer_attr_def == nullptr) {
103         return errors::InvalidArgument(
104             "Attr '", attr.first,
105             "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
106             " but found in node: ", FormatNodeDefForError(*node_def));
107       }
108       // ...and it has the same value as the default in producer,
109       if (producer_attr_def->has_default_value() &&
110           AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
111         // then we will remove it below.
112         to_remove.emplace_back(attr.first);
113       }
114     }
115   }
116   // We separate identifying which attrs should be removed from
117   // actually removing them to avoid invalidating the loop iterators
118   // above.
119   for (const string& attr_name : to_remove) {
120     node_def->mutable_attr()->erase(attr_name);
121     if (op_attr_removed != nullptr) {
122       op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
123     }
124   }
125 
126   return Status::OK();
127 }
128 
IsFunction(const GraphDef & graph_def,const string & op_name)129 static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
130   for (const auto& func_def : graph_def.library().function()) {
131     if (op_name == func_def.signature().name()) return true;
132   }
133   return false;
134 }
135 
RemoveNewDefaultAttrsFromGraphDef(GraphDef * graph_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)136 Status RemoveNewDefaultAttrsFromGraphDef(
137     GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
138     const OpRegistryInterface& producer_op_registry,
139     std::set<std::pair<string, string>>* op_attr_removed) {
140   // TODO(joshL): Make IsFunction() faster by collecting the names of
141   // all functions as a preprocessing step.
142   for (int n = 0; n < graph_def->node_size(); ++n) {
143     NodeDef* node_def = graph_def->mutable_node(n);
144     if (!IsFunction(*graph_def, node_def->op())) {
145       TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
146           node_def, consumer_op_registry, producer_op_registry,
147           op_attr_removed));
148     }
149   }
150   for (int f = 0; f < graph_def->library().function_size(); ++f) {
151     FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
152     for (int n = 0; n < func_def->node_def_size(); ++n) {
153       NodeDef* node_def = func_def->mutable_node_def(n);
154       if (!IsFunction(*graph_def, node_def->op())) {
155         // TODO(josh11b): Better handling of attrs with placeholder values.
156         TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
157             node_def, consumer_op_registry, producer_op_registry,
158             op_attr_removed));
159       }
160     }
161   }
162 
163   return Status::OK();
164 }
165 
OpsUsedByGraph(const GraphDef & graph_def,std::set<string> * ops_used_in_graph)166 void OpsUsedByGraph(const GraphDef& graph_def,
167                     std::set<string>* ops_used_in_graph) {
168   // Map function names to definitions.
169   std::unordered_map<string, const FunctionDef*> name_to_function;
170   for (const auto& function : graph_def.library().function()) {
171     name_to_function.insert(
172         std::make_pair(function.signature().name(), &function));
173   }
174 
175   // Collect the sorted list of op names.  Since functions can reference
176   // functions, we need a recursive traversal.
177   std::set<string> used_ops;  // Includes both primitive ops and functions
178   std::vector<const FunctionDef*> functions_to_process;  // A subset of used_ops
179   // Collect the logic to mark an op in a lambda; it'll be used twice below.
180   const auto mark_op_as_used = [&used_ops, &functions_to_process,
181                                 &name_to_function](const string& op) {
182     if (used_ops.insert(op).second) {
183       // If it's a function, we'll need to process further
184       const auto it = name_to_function.find(op);
185       if (it != name_to_function.end()) {
186         functions_to_process.push_back(it->second);
187       }
188     }
189   };
190   for (const auto& node : graph_def.node()) {
191     mark_op_as_used(node.op());
192   }
193   while (!functions_to_process.empty()) {
194     const FunctionDef* fun = functions_to_process.back();
195     functions_to_process.pop_back();
196     for (const auto& node : fun->node_def()) {
197       mark_op_as_used(node.op());
198     }
199   }
200 
201   // Filter out function names to produce output.
202   // TODO(josh11b): Change the above code to produce this directly.
203   ops_used_in_graph->clear();
204   for (const string& op_name : used_ops) {
205     if (name_to_function.find(op_name) == name_to_function.end()) {
206       ops_used_in_graph->insert(op_name);
207     }
208   }
209 }
210 
StrippedOpListForGraph(const GraphDef & graph_def,const OpRegistryInterface & op_registry,OpList * stripped_op_list)211 Status StrippedOpListForGraph(const GraphDef& graph_def,
212                               const OpRegistryInterface& op_registry,
213                               OpList* stripped_op_list) {
214   std::set<string> used_ops;
215   OpsUsedByGraph(graph_def, &used_ops);
216 
217   // Build the stripped op list in sorted order, ignoring functions.
218   stripped_op_list->clear_op();
219   for (const string& op_name : used_ops) {
220     const OpDef* op_def;
221     TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
222     OpDef* stripped_op = stripped_op_list->add_op();
223     stripped_op->CopyFrom(*op_def);
224     RemoveDescriptionsFromOpDef(stripped_op);
225   }
226   return Status::OK();
227 }
228 
229 }  // namespace tensorflow
230