/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/graph_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_def_util.h"
30 #include "tensorflow/core/framework/versions.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 
36 namespace tensorflow {
37 
SummarizeGraphDef(const GraphDef & graph_def)38 string SummarizeGraphDef(const GraphDef& graph_def) {
39   string ret;
40   strings::StrAppend(
41       &ret, "versions = ", graph_def.versions().ShortDebugString(), ";\n");
42   for (const NodeDef& node : graph_def.node()) {
43     strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
44   }
45   return ret;
46 }
47 
ValidateExternalGraphDefSyntax(const GraphDef & graph_def)48 Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
49   for (const NodeDef& node : graph_def.node()) {
50     TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
51   }
52   return Status::OK();
53 }
54 
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset)55 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
56                                  const OpRegistryInterface& op_registry,
57                                  int node_offset) {
58   return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
59 }
60 
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset,bool skip_unknown_ops)61 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
62                                  const OpRegistryInterface& op_registry,
63                                  int node_offset, bool skip_unknown_ops) {
64   if (node_offset > graph_def->node_size()) {
65     return errors::InvalidArgument(
66         "Tried to add default attrs to GraphDef "
67         "starting at offset ",
68         node_offset, " with total nodes in graph: ", graph_def->node_size());
69   }
70 
71   for (int i = node_offset; i < graph_def->node_size(); ++i) {
72     NodeDef* node_def = graph_def->mutable_node(i);
73     const OpDef* op_def;
74     Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
75     if (s.ok()) {
76       AddDefaultsToNodeDef(*op_def, node_def);
77     } else if (!skip_unknown_ops) {
78       return s;
79     }
80   }
81 
82   return Status::OK();
83 }
84 
RemoveNewDefaultAttrsFromNodeDef(NodeDef * node_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)85 static Status RemoveNewDefaultAttrsFromNodeDef(
86     NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
87     const OpRegistryInterface& producer_op_registry,
88     std::set<std::pair<string, string>>* op_attr_removed) {
89   const OpDef* producer_op_def;
90   const OpDef* consumer_op_def;
91   TF_RETURN_IF_ERROR(
92       producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
93   TF_RETURN_IF_ERROR(
94       consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
95 
96   std::vector<string> to_remove;
97   for (const auto& attr : node_def->attr()) {
98     // If the attr is not in consumer_op_def and doesn't start with '_'...
99     if (!absl::StartsWith(attr.first, "_") &&
100         FindAttr(attr.first, *consumer_op_def) == nullptr) {
101       const OpDef::AttrDef* producer_attr_def =
102           FindAttr(attr.first, *producer_op_def);
103       if (producer_attr_def == nullptr) {
104         return errors::InvalidArgument(
105             "Attr '", attr.first,
106             "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
107             " but found in node: ", FormatNodeDefForError(*node_def));
108       }
109       // ...and it has the same value as the default in producer,
110       if (producer_attr_def->has_default_value() &&
111           AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
112         // then we will remove it below.
113         to_remove.emplace_back(attr.first);
114       }
115     }
116   }
117   // We separate identifying which attrs should be removed from
118   // actually removing them to avoid invalidating the loop iterators
119   // above.
120   for (const string& attr_name : to_remove) {
121     node_def->mutable_attr()->erase(attr_name);
122     if (op_attr_removed != nullptr) {
123       op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
124     }
125   }
126 
127   return Status::OK();
128 }
129 
IsFunction(const GraphDef & graph_def,const string & op_name)130 static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
131   for (const auto& func_def : graph_def.library().function()) {
132     if (op_name == func_def.signature().name()) return true;
133   }
134   return false;
135 }
136 
RemoveNewDefaultAttrsFromGraphDef(GraphDef * graph_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)137 Status RemoveNewDefaultAttrsFromGraphDef(
138     GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
139     const OpRegistryInterface& producer_op_registry,
140     std::set<std::pair<string, string>>* op_attr_removed) {
141   // TODO(joshL): Make IsFunction() faster by collecting the names of
142   // all functions as a preprocessing step.
143   for (int n = 0; n < graph_def->node_size(); ++n) {
144     NodeDef* node_def = graph_def->mutable_node(n);
145     if (!IsFunction(*graph_def, node_def->op())) {
146       TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
147           node_def, consumer_op_registry, producer_op_registry,
148           op_attr_removed));
149     }
150   }
151   for (int f = 0; f < graph_def->library().function_size(); ++f) {
152     FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
153     for (int n = 0; n < func_def->node_def_size(); ++n) {
154       NodeDef* node_def = func_def->mutable_node_def(n);
155       if (!IsFunction(*graph_def, node_def->op())) {
156         // TODO(josh11b): Better handling of attrs with placeholder values.
157         TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
158             node_def, consumer_op_registry, producer_op_registry,
159             op_attr_removed));
160       }
161     }
162   }
163 
164   return Status::OK();
165 }
166 
StripDefaultAttributes(const OpRegistryInterface & op_registry,protobuf::RepeatedPtrField<NodeDef> * nodes)167 void StripDefaultAttributes(const OpRegistryInterface& op_registry,
168                             protobuf::RepeatedPtrField<NodeDef>* nodes) {
169   for (int i = 0; i < nodes->size(); ++i) {
170     NodeDef* node = nodes->Mutable(i);
171 
172     const OpDef* op_def;
173     const OpRegistrationData* op_reg_data = nullptr;
174     Status s = op_registry.LookUp(node->op(), &op_reg_data);
175     if (!s.ok()) {
176       VLOG(1) << "Ignoring encountered unknown operation "
177               << SummarizeNodeDef(*node)
178               << " when stripping default attributes. It is likely a function, "
179                  "in which case ignoring it is fine";
180       continue;
181     }
182     op_def = &op_reg_data->op_def;
183 
184     for (const OpDef::AttrDef& attr_def : op_def->attr()) {
185       if (attr_def.has_default_value()) {
186         AttrValueMap* attrs = node->mutable_attr();
187         const string& name = attr_def.name();
188         auto iter = attrs->find(name);
189         if (iter != attrs->end()) {
190           const AttrValue& default_value = attr_def.default_value();
191           // The "Fast*" version can return false negatives for very large
192           // AttrValues containing Tensors. There should never be an attribute
193           // whose default value is a tensor larger than 32MB.
194           if (FastAreAttrValuesEqual(iter->second, default_value)) {
195             attrs->erase(name);
196           }
197         }
198       }
199     }
200   }
201 }
202 
OpsUsedByGraph(const GraphDef & graph_def,std::set<string> * ops_used_in_graph)203 void OpsUsedByGraph(const GraphDef& graph_def,
204                     std::set<string>* ops_used_in_graph) {
205   // Map function names to definitions.
206   std::unordered_map<string, const FunctionDef*> name_to_function;
207   for (const auto& function : graph_def.library().function()) {
208     name_to_function.insert(
209         std::make_pair(function.signature().name(), &function));
210   }
211 
212   // Collect the sorted list of op names.  Since functions can reference
213   // functions, we need a recursive traversal.
214   std::set<string> used_ops;  // Includes both primitive ops and functions
215   std::vector<const FunctionDef*> functions_to_process;  // A subset of used_ops
216   // Collect the logic to mark an op in a lambda; it'll be used twice below.
217   const auto mark_op_as_used = [&used_ops, &functions_to_process,
218                                 &name_to_function](const string& op) {
219     if (used_ops.insert(op).second) {
220       // If it's a function, we'll need to process further
221       const auto it = name_to_function.find(op);
222       if (it != name_to_function.end()) {
223         functions_to_process.push_back(it->second);
224       }
225     }
226   };
227   for (const auto& node : graph_def.node()) {
228     mark_op_as_used(node.op());
229   }
230   while (!functions_to_process.empty()) {
231     const FunctionDef* fun = functions_to_process.back();
232     functions_to_process.pop_back();
233     for (const auto& node : fun->node_def()) {
234       mark_op_as_used(node.op());
235     }
236   }
237 
238   // Filter out function names to produce output.
239   // TODO(josh11b): Change the above code to produce this directly.
240   ops_used_in_graph->clear();
241   for (const string& op_name : used_ops) {
242     if (name_to_function.find(op_name) == name_to_function.end()) {
243       ops_used_in_graph->insert(op_name);
244     }
245   }
246 }
247 
StrippedOpListForGraph(const GraphDef & graph_def,const OpRegistryInterface & op_registry,OpList * stripped_op_list)248 Status StrippedOpListForGraph(const GraphDef& graph_def,
249                               const OpRegistryInterface& op_registry,
250                               OpList* stripped_op_list) {
251   std::set<string> used_ops;
252   OpsUsedByGraph(graph_def, &used_ops);
253 
254   // Build the stripped op list in sorted order, ignoring functions.
255   stripped_op_list->clear_op();
256   for (const string& op_name : used_ops) {
257     const OpDef* op_def;
258     TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
259     OpDef* stripped_op = stripped_op_list->add_op();
260     stripped_op->CopyFrom(*op_def);
261     RemoveDescriptionsFromOpDef(stripped_op);
262   }
263   return Status::OK();
264 }
265 
266 }  // namespace tensorflow
267