1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/graph_def_util.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_def_util.h"
30 #include "tensorflow/core/framework/versions.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35
36 namespace tensorflow {
37
SummarizeGraphDef(const GraphDef & graph_def)38 string SummarizeGraphDef(const GraphDef& graph_def) {
39 string ret;
40 strings::StrAppend(
41 &ret, "versions = ", graph_def.versions().ShortDebugString(), ";\n");
42 for (const NodeDef& node : graph_def.node()) {
43 strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
44 }
45 return ret;
46 }
47
ValidateExternalGraphDefSyntax(const GraphDef & graph_def)48 Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
49 for (const NodeDef& node : graph_def.node()) {
50 TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
51 }
52 return Status::OK();
53 }
54
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset)55 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
56 const OpRegistryInterface& op_registry,
57 int node_offset) {
58 return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
59 }
60
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset,bool skip_unknown_ops)61 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
62 const OpRegistryInterface& op_registry,
63 int node_offset, bool skip_unknown_ops) {
64 if (node_offset > graph_def->node_size()) {
65 return errors::InvalidArgument(
66 "Tried to add default attrs to GraphDef "
67 "starting at offset ",
68 node_offset, " with total nodes in graph: ", graph_def->node_size());
69 }
70
71 for (int i = node_offset; i < graph_def->node_size(); ++i) {
72 NodeDef* node_def = graph_def->mutable_node(i);
73 const OpDef* op_def;
74 Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
75 if (s.ok()) {
76 AddDefaultsToNodeDef(*op_def, node_def);
77 } else if (!skip_unknown_ops) {
78 return s;
79 }
80 }
81
82 return Status::OK();
83 }
84
RemoveNewDefaultAttrsFromNodeDef(NodeDef * node_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)85 static Status RemoveNewDefaultAttrsFromNodeDef(
86 NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
87 const OpRegistryInterface& producer_op_registry,
88 std::set<std::pair<string, string>>* op_attr_removed) {
89 const OpDef* producer_op_def;
90 const OpDef* consumer_op_def;
91 TF_RETURN_IF_ERROR(
92 producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
93 TF_RETURN_IF_ERROR(
94 consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
95
96 std::vector<string> to_remove;
97 for (const auto& attr : node_def->attr()) {
98 // If the attr is not in consumer_op_def and doesn't start with '_'...
99 if (!absl::StartsWith(attr.first, "_") &&
100 FindAttr(attr.first, *consumer_op_def) == nullptr) {
101 const OpDef::AttrDef* producer_attr_def =
102 FindAttr(attr.first, *producer_op_def);
103 if (producer_attr_def == nullptr) {
104 return errors::InvalidArgument(
105 "Attr '", attr.first,
106 "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
107 " but found in node: ", FormatNodeDefForError(*node_def));
108 }
109 // ...and it has the same value as the default in producer,
110 if (producer_attr_def->has_default_value() &&
111 AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
112 // then we will remove it below.
113 to_remove.emplace_back(attr.first);
114 }
115 }
116 }
117 // We separate identifying which attrs should be removed from
118 // actually removing them to avoid invalidating the loop iterators
119 // above.
120 for (const string& attr_name : to_remove) {
121 node_def->mutable_attr()->erase(attr_name);
122 if (op_attr_removed != nullptr) {
123 op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
124 }
125 }
126
127 return Status::OK();
128 }
129
IsFunction(const GraphDef & graph_def,const string & op_name)130 static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
131 for (const auto& func_def : graph_def.library().function()) {
132 if (op_name == func_def.signature().name()) return true;
133 }
134 return false;
135 }
136
RemoveNewDefaultAttrsFromGraphDef(GraphDef * graph_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)137 Status RemoveNewDefaultAttrsFromGraphDef(
138 GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
139 const OpRegistryInterface& producer_op_registry,
140 std::set<std::pair<string, string>>* op_attr_removed) {
141 // TODO(joshL): Make IsFunction() faster by collecting the names of
142 // all functions as a preprocessing step.
143 for (int n = 0; n < graph_def->node_size(); ++n) {
144 NodeDef* node_def = graph_def->mutable_node(n);
145 if (!IsFunction(*graph_def, node_def->op())) {
146 TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
147 node_def, consumer_op_registry, producer_op_registry,
148 op_attr_removed));
149 }
150 }
151 for (int f = 0; f < graph_def->library().function_size(); ++f) {
152 FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
153 for (int n = 0; n < func_def->node_def_size(); ++n) {
154 NodeDef* node_def = func_def->mutable_node_def(n);
155 if (!IsFunction(*graph_def, node_def->op())) {
156 // TODO(josh11b): Better handling of attrs with placeholder values.
157 TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
158 node_def, consumer_op_registry, producer_op_registry,
159 op_attr_removed));
160 }
161 }
162 }
163
164 return Status::OK();
165 }
166
StripDefaultAttributes(const OpRegistryInterface & op_registry,protobuf::RepeatedPtrField<NodeDef> * nodes)167 void StripDefaultAttributes(const OpRegistryInterface& op_registry,
168 protobuf::RepeatedPtrField<NodeDef>* nodes) {
169 for (int i = 0; i < nodes->size(); ++i) {
170 NodeDef* node = nodes->Mutable(i);
171
172 const OpDef* op_def;
173 const OpRegistrationData* op_reg_data = nullptr;
174 Status s = op_registry.LookUp(node->op(), &op_reg_data);
175 if (!s.ok()) {
176 VLOG(1) << "Ignoring encountered unknown operation "
177 << SummarizeNodeDef(*node)
178 << " when stripping default attributes. It is likely a function, "
179 "in which case ignoring it is fine";
180 continue;
181 }
182 op_def = &op_reg_data->op_def;
183
184 for (const OpDef::AttrDef& attr_def : op_def->attr()) {
185 if (attr_def.has_default_value()) {
186 AttrValueMap* attrs = node->mutable_attr();
187 const string& name = attr_def.name();
188 auto iter = attrs->find(name);
189 if (iter != attrs->end()) {
190 const AttrValue& default_value = attr_def.default_value();
191 // The "Fast*" version can return false negatives for very large
192 // AttrValues containing Tensors. There should never be an attribute
193 // whose default value is a tensor larger than 32MB.
194 if (FastAreAttrValuesEqual(iter->second, default_value)) {
195 attrs->erase(name);
196 }
197 }
198 }
199 }
200 }
201 }
202
OpsUsedByGraph(const GraphDef & graph_def,std::set<string> * ops_used_in_graph)203 void OpsUsedByGraph(const GraphDef& graph_def,
204 std::set<string>* ops_used_in_graph) {
205 // Map function names to definitions.
206 std::unordered_map<string, const FunctionDef*> name_to_function;
207 for (const auto& function : graph_def.library().function()) {
208 name_to_function.insert(
209 std::make_pair(function.signature().name(), &function));
210 }
211
212 // Collect the sorted list of op names. Since functions can reference
213 // functions, we need a recursive traversal.
214 std::set<string> used_ops; // Includes both primitive ops and functions
215 std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops
216 // Collect the logic to mark an op in a lambda; it'll be used twice below.
217 const auto mark_op_as_used = [&used_ops, &functions_to_process,
218 &name_to_function](const string& op) {
219 if (used_ops.insert(op).second) {
220 // If it's a function, we'll need to process further
221 const auto it = name_to_function.find(op);
222 if (it != name_to_function.end()) {
223 functions_to_process.push_back(it->second);
224 }
225 }
226 };
227 for (const auto& node : graph_def.node()) {
228 mark_op_as_used(node.op());
229 }
230 while (!functions_to_process.empty()) {
231 const FunctionDef* fun = functions_to_process.back();
232 functions_to_process.pop_back();
233 for (const auto& node : fun->node_def()) {
234 mark_op_as_used(node.op());
235 }
236 }
237
238 // Filter out function names to produce output.
239 // TODO(josh11b): Change the above code to produce this directly.
240 ops_used_in_graph->clear();
241 for (const string& op_name : used_ops) {
242 if (name_to_function.find(op_name) == name_to_function.end()) {
243 ops_used_in_graph->insert(op_name);
244 }
245 }
246 }
247
StrippedOpListForGraph(const GraphDef & graph_def,const OpRegistryInterface & op_registry,OpList * stripped_op_list)248 Status StrippedOpListForGraph(const GraphDef& graph_def,
249 const OpRegistryInterface& op_registry,
250 OpList* stripped_op_list) {
251 std::set<string> used_ops;
252 OpsUsedByGraph(graph_def, &used_ops);
253
254 // Build the stripped op list in sorted order, ignoring functions.
255 stripped_op_list->clear_op();
256 for (const string& op_name : used_ops) {
257 const OpDef* op_def;
258 TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
259 OpDef* stripped_op = stripped_op_list->add_op();
260 stripped_op->CopyFrom(*op_def);
261 RemoveDescriptionsFromOpDef(stripped_op);
262 }
263 return Status::OK();
264 }
265
266 } // namespace tensorflow
267