1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/graph_def_util.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.pb.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/versions.pb_text.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34
35 namespace tensorflow {
36
SummarizeGraphDef(const GraphDef & graph_def)37 string SummarizeGraphDef(const GraphDef& graph_def) {
38 string ret;
39 strings::StrAppend(
40 &ret, "versions = ", ProtoShortDebugString(graph_def.versions()), ";\n");
41 for (const NodeDef& node : graph_def.node()) {
42 strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
43 }
44 return ret;
45 }
46
ValidateExternalGraphDefSyntax(const GraphDef & graph_def)47 Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
48 for (const NodeDef& node : graph_def.node()) {
49 TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
50 }
51 return Status::OK();
52 }
53
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset)54 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
55 const OpRegistryInterface& op_registry,
56 int node_offset) {
57 return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
58 }
59
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset,bool skip_unknown_ops)60 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
61 const OpRegistryInterface& op_registry,
62 int node_offset, bool skip_unknown_ops) {
63 if (node_offset > graph_def->node_size()) {
64 return errors::InvalidArgument(
65 "Tried to add default attrs to GraphDef "
66 "starting at offset ",
67 node_offset, " with total nodes in graph: ", graph_def->node_size());
68 }
69
70 for (int i = node_offset; i < graph_def->node_size(); ++i) {
71 NodeDef* node_def = graph_def->mutable_node(i);
72 const OpDef* op_def;
73 Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
74 if (s.ok()) {
75 AddDefaultsToNodeDef(*op_def, node_def);
76 } else if (!skip_unknown_ops) {
77 return s;
78 }
79 }
80
81 return Status::OK();
82 }
83
RemoveNewDefaultAttrsFromNodeDef(NodeDef * node_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)84 static Status RemoveNewDefaultAttrsFromNodeDef(
85 NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
86 const OpRegistryInterface& producer_op_registry,
87 std::set<std::pair<string, string>>* op_attr_removed) {
88 const OpDef* producer_op_def;
89 const OpDef* consumer_op_def;
90 TF_RETURN_IF_ERROR(
91 producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
92 TF_RETURN_IF_ERROR(
93 consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
94
95 std::vector<string> to_remove;
96 for (const auto& attr : node_def->attr()) {
97 // If the attr is not in consumer_op_def and doesn't start with '_'...
98 if (!str_util::StartsWith(attr.first, "_") &&
99 FindAttr(attr.first, *consumer_op_def) == nullptr) {
100 const OpDef::AttrDef* producer_attr_def =
101 FindAttr(attr.first, *producer_op_def);
102 if (producer_attr_def == nullptr) {
103 return errors::InvalidArgument(
104 "Attr '", attr.first,
105 "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
106 " but found in node: ", FormatNodeDefForError(*node_def));
107 }
108 // ...and it has the same value as the default in producer,
109 if (producer_attr_def->has_default_value() &&
110 AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
111 // then we will remove it below.
112 to_remove.emplace_back(attr.first);
113 }
114 }
115 }
116 // We separate identifying which attrs should be removed from
117 // actually removing them to avoid invalidating the loop iterators
118 // above.
119 for (const string& attr_name : to_remove) {
120 node_def->mutable_attr()->erase(attr_name);
121 if (op_attr_removed != nullptr) {
122 op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
123 }
124 }
125
126 return Status::OK();
127 }
128
IsFunction(const GraphDef & graph_def,const string & op_name)129 static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
130 for (const auto& func_def : graph_def.library().function()) {
131 if (op_name == func_def.signature().name()) return true;
132 }
133 return false;
134 }
135
RemoveNewDefaultAttrsFromGraphDef(GraphDef * graph_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)136 Status RemoveNewDefaultAttrsFromGraphDef(
137 GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
138 const OpRegistryInterface& producer_op_registry,
139 std::set<std::pair<string, string>>* op_attr_removed) {
140 // TODO(joshL): Make IsFunction() faster by collecting the names of
141 // all functions as a preprocessing step.
142 for (int n = 0; n < graph_def->node_size(); ++n) {
143 NodeDef* node_def = graph_def->mutable_node(n);
144 if (!IsFunction(*graph_def, node_def->op())) {
145 TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
146 node_def, consumer_op_registry, producer_op_registry,
147 op_attr_removed));
148 }
149 }
150 for (int f = 0; f < graph_def->library().function_size(); ++f) {
151 FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
152 for (int n = 0; n < func_def->node_def_size(); ++n) {
153 NodeDef* node_def = func_def->mutable_node_def(n);
154 if (!IsFunction(*graph_def, node_def->op())) {
155 // TODO(josh11b): Better handling of attrs with placeholder values.
156 TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
157 node_def, consumer_op_registry, producer_op_registry,
158 op_attr_removed));
159 }
160 }
161 }
162
163 return Status::OK();
164 }
165
OpsUsedByGraph(const GraphDef & graph_def,std::set<string> * ops_used_in_graph)166 void OpsUsedByGraph(const GraphDef& graph_def,
167 std::set<string>* ops_used_in_graph) {
168 // Map function names to definitions.
169 std::unordered_map<string, const FunctionDef*> name_to_function;
170 for (const auto& function : graph_def.library().function()) {
171 name_to_function.insert(
172 std::make_pair(function.signature().name(), &function));
173 }
174
175 // Collect the sorted list of op names. Since functions can reference
176 // functions, we need a recursive traversal.
177 std::set<string> used_ops; // Includes both primitive ops and functions
178 std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops
179 // Collect the logic to mark an op in a lambda; it'll be used twice below.
180 const auto mark_op_as_used = [&used_ops, &functions_to_process,
181 &name_to_function](const string& op) {
182 if (used_ops.insert(op).second) {
183 // If it's a function, we'll need to process further
184 const auto it = name_to_function.find(op);
185 if (it != name_to_function.end()) {
186 functions_to_process.push_back(it->second);
187 }
188 }
189 };
190 for (const auto& node : graph_def.node()) {
191 mark_op_as_used(node.op());
192 }
193 while (!functions_to_process.empty()) {
194 const FunctionDef* fun = functions_to_process.back();
195 functions_to_process.pop_back();
196 for (const auto& node : fun->node_def()) {
197 mark_op_as_used(node.op());
198 }
199 }
200
201 // Filter out function names to produce output.
202 // TODO(josh11b): Change the above code to produce this directly.
203 ops_used_in_graph->clear();
204 for (const string& op_name : used_ops) {
205 if (name_to_function.find(op_name) == name_to_function.end()) {
206 ops_used_in_graph->insert(op_name);
207 }
208 }
209 }
210
StrippedOpListForGraph(const GraphDef & graph_def,const OpRegistryInterface & op_registry,OpList * stripped_op_list)211 Status StrippedOpListForGraph(const GraphDef& graph_def,
212 const OpRegistryInterface& op_registry,
213 OpList* stripped_op_list) {
214 std::set<string> used_ops;
215 OpsUsedByGraph(graph_def, &used_ops);
216
217 // Build the stripped op list in sorted order, ignoring functions.
218 stripped_op_list->clear_op();
219 for (const string& op_name : used_ops) {
220 const OpDef* op_def;
221 TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
222 OpDef* stripped_op = stripped_op_list->add_op();
223 stripped_op->CopyFrom(*op_def);
224 RemoveDescriptionsFromOpDef(stripped_op);
225 }
226 return Status::OK();
227 }
228
229 } // namespace tensorflow
230