1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17
18 #include <iostream>
19 #include <queue>
20
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29
30 namespace tensorflow {
31
32 namespace {
33
34 // Gets tensor names from tensor_info and inserts them into the set of tensor
35 // names.
GetTensorNamesFromTensorInfo(const TensorInfo & tensor_info,std::unordered_set<string> * tensor_names)36 void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
37 std::unordered_set<string>* tensor_names) {
38 if (tensor_info.has_coo_sparse()) {
39 // If the tensor is sparse we have to add all three tensors of the sparse
40 // representations.
41 const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
42 tensor_names->insert(coo_sparse.values_tensor_name());
43 tensor_names->insert(coo_sparse.indices_tensor_name());
44 tensor_names->insert(coo_sparse.dense_shape_tensor_name());
45 } else {
46 tensor_names->insert(tensor_info.name());
47 }
48 }
49
50 // Gets the union of all inputs and outputs of all SignatureDefs in the bundle
GetSignatureDefsInputsAndOutputs(const SavedModelBundle & saved_model_bundle,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)51 void GetSignatureDefsInputsAndOutputs(
52 const SavedModelBundle& saved_model_bundle,
53 std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
54 for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
55 const SignatureDef& signature_def = sigdef_elem.second;
56 for (auto& input_elem : signature_def.inputs()) {
57 GetTensorNamesFromTensorInfo(input_elem.second, inputs);
58 }
59 for (auto& output_elem : signature_def.outputs()) {
60 GetTensorNamesFromTensorInfo(output_elem.second, outputs);
61 }
62 }
63 }
64
65 // Gets a map from string node name to NodeDef.
GetNodeNameToNodeDefMap(GraphDef * graph_def,std::unordered_map<string,NodeDef * > * name_to_node_map)66 void GetNodeNameToNodeDefMap(
67 GraphDef* graph_def,
68 std::unordered_map<string, NodeDef*>* name_to_node_map) {
69 for (size_t i = 0; i < graph_def->node_size(); i++) {
70 NodeDef* node = graph_def->mutable_node(i);
71 (*name_to_node_map)[node->name()] = node;
72 }
73 }
74
// Returns the name of the node referenced by `tensor_name`.
//
// Tensor references have the form "node" or "node:output_index", and control
// inputs are written "^node". A leading '^' is dropped, as is everything from
// the first ':' onward. Empty input (or a lone "^") yields an empty string.
std::string GetNodeNameFromTensorName(const std::string& tensor_name) {
  // Skip the '^' marker of a control input, if present.
  const std::string::size_type begin =
      (!tensor_name.empty() && tensor_name[0] == '^') ? 1 : 0;
  const std::string::size_type colon = tensor_name.find(':', begin);
  if (colon == std::string::npos) {
    return tensor_name.substr(begin);
  }
  return tensor_name.substr(begin, colon - begin);
}
83
84 // Gets the set of node names needed by `outputs` and the corresponding set of
85 // variable nodes to convert.
GetReachableNodesAndVariables(GraphDef * graph_def,const std::unordered_set<string> & outputs,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> * reachable_node_names,std::unordered_set<string> * variable_node_names)86 void GetReachableNodesAndVariables(
87 GraphDef* graph_def, const std::unordered_set<string>& outputs,
88 const std::unordered_map<string, NodeDef*>& name_to_node_map,
89 std::unordered_set<string>* reachable_node_names,
90 std::unordered_set<string>* variable_node_names) {
91 // TODO(suharshs): Add support for ResourceVariables.
92 static const std::unordered_set<string>* kVariableTypes =
93 new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
94
95 std::queue<string> nodes_to_visit;
96 for (const string& output_tensor_name : outputs) {
97 nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
98 }
99 // We do a traversal backwards from the outputs specified in the MetaGraphDef.
100 while (!nodes_to_visit.empty()) {
101 const string node_name = nodes_to_visit.front();
102 nodes_to_visit.pop();
103 if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
104 continue;
105 }
106 reachable_node_names->insert(node_name);
107 NodeDef* node = name_to_node_map.at(node_name);
108 if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
109 variable_node_names->insert(node->name());
110 }
111 for (const string& input_tensor_name : node->input()) {
112 nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
113 }
114 }
115 }
116
117 // Gets a map from variable name to variable value.
GetVariableNameToTensorMap(Session * session,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> variable_names_set,std::unordered_map<string,Tensor> * variable_name_to_value_map)118 Status GetVariableNameToTensorMap(
119 Session* session,
120 const std::unordered_map<string, NodeDef*>& name_to_node_map,
121 std::unordered_set<string> variable_names_set,
122 std::unordered_map<string, Tensor>* variable_name_to_value_map) {
123 if (variable_names_set.empty()) {
124 return Status::OK();
125 }
126 std::vector<string> variable_names;
127 variable_names.reserve(variable_names_set.size());
128 std::vector<string> tensor_names;
129 tensor_names.reserve(variable_names_set.size());
130 for (const string& node_name : variable_names_set) {
131 variable_names.push_back(node_name);
132 NodeDef* node_def = name_to_node_map.at(node_name);
133 if (node_def->op() == "VarHandleOp") {
134 // If this is a resource variable, we have to run the corresponding
135 // ReadVariableOp.
136 tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
137 } else {
138 tensor_names.push_back(node_name + ":0");
139 }
140 }
141 std::vector<Tensor> outputs;
142 TF_RETURN_IF_ERROR(
143 session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
144 for (size_t i = 0; i < variable_names.size(); i++) {
145 (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
146 }
147 return Status::OK();
148 }
149
150 // Converts a Variable NodeDef into a Constant NodeDef.
ConvertVariableToConstant(const NodeDef & variable_node,const Tensor & variable_value,NodeDef * const_node)151 void ConvertVariableToConstant(const NodeDef& variable_node,
152 const Tensor& variable_value,
153 NodeDef* const_node) {
154 const_node->set_name(variable_node.name());
155 const_node->set_op("Const");
156 (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
157 variable_value.AsProtoTensorContent(
158 (*const_node->mutable_attr())["value"].mutable_tensor());
159 }
160
161 // Converts a ReadVariableOp NodeDef to an Identity NodeDef.
ConvertReadVariableOpToIdentity(const NodeDef & node,NodeDef * identity_node)162 void ConvertReadVariableOpToIdentity(const NodeDef& node,
163 NodeDef* identity_node) {
164 identity_node->set_name(node.name());
165 identity_node->set_op("Identity");
166 (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
167 identity_node->add_input(node.input(0));
168 }
169
170 // Freezes the subgraph of all nodes needed by `outputs`.
FreezeGraphDef(const SavedModelBundle & saved_model_bundle,const std::unordered_set<string> & outputs,GraphDef * frozen_graph_def)171 Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
172 const std::unordered_set<string>& outputs,
173 GraphDef* frozen_graph_def) {
174 GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
175 // Copy versions and library as-is from original graph.
176 *frozen_graph_def->mutable_versions() = graph_def.versions();
177 *frozen_graph_def->mutable_library() = graph_def.library();
178 // If the graph is empty there is nothing left to do.
179 if (graph_def.node_size() == 0) {
180 return Status::OK();
181 }
182 // name_to_node_map is needed to get the inputs from the NodeDef corresponding
183 // the a string node name. These inputs are used when doing our backwards
184 // traversal.
185 std::unordered_map<string, NodeDef*> name_to_node_map;
186 GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
187 std::unordered_set<string> reachable_node_names;
188 std::unordered_set<string> variable_node_names;
189 GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
190 &reachable_node_names, &variable_node_names);
191 std::unordered_map<string, Tensor> variable_to_value_map;
192 TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
193 saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
194 &variable_to_value_map));
195 // We copy the nodes in the same order they were in the original graph_def.
196 for (const NodeDef& node : graph_def.node()) {
197 if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
198 continue;
199 }
200 if (variable_node_names.find(node.name()) != variable_node_names.end()) {
201 ConvertVariableToConstant(node, variable_to_value_map[node.name()],
202 frozen_graph_def->add_node());
203 } else if (node.op() == "ReadVariableOp" &&
204 variable_node_names.find(node.input(0)) !=
205 variable_node_names.end()) {
206 // If the node is a ReadVariableOp, its input VarHandleOp will be
207 // converted to a Constant, so we will need to convert it to an Identity.
208 ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
209 } else {
210 // If the node isn't a variable, just copy the node as-is.
211 *frozen_graph_def->add_node() = node;
212 }
213 }
214 return Status::OK();
215 }
216
217 } // namespace
218
FreezeSavedModel(const SavedModelBundle & saved_model_bundle,GraphDef * frozen_graph_def,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)219 Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
220 GraphDef* frozen_graph_def,
221 std::unordered_set<string>* inputs,
222 std::unordered_set<string>* outputs) {
223 GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
224 TF_RETURN_IF_ERROR(
225 FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
226 return Status::OK();
227 }
228
229 } // namespace tensorflow
230