1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17
18 #include <iostream>
19 #include <queue>
20
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29
30 namespace tensorflow {
31
32 namespace {
33
34 // Gets tensor names from tensor_info and inserts them into the set of tensor
35 // names.
GetTensorNamesFromTensorInfo(const TensorInfo & tensor_info,std::unordered_set<string> * tensor_names)36 void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
37 std::unordered_set<string>* tensor_names) {
38 if (tensor_info.has_coo_sparse()) {
39 // If the tensor is sparse we have to add all three tensors of the sparse
40 // representations.
41 const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
42 tensor_names->insert(coo_sparse.values_tensor_name());
43 tensor_names->insert(coo_sparse.indices_tensor_name());
44 tensor_names->insert(coo_sparse.dense_shape_tensor_name());
45 } else if (tensor_info.has_composite_tensor()) {
46 for (const auto& component : tensor_info.composite_tensor().components()) {
47 tensor_names->insert(component.name());
48 }
49 } else {
50 tensor_names->insert(tensor_info.name());
51 }
52 }
53
54 // Gets the union of all inputs and outputs of all SignatureDefs in the bundle
GetSignatureDefsInputsAndOutputs(const SavedModelBundle & saved_model_bundle,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)55 void GetSignatureDefsInputsAndOutputs(
56 const SavedModelBundle& saved_model_bundle,
57 std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
58 for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
59 const SignatureDef& signature_def = sigdef_elem.second;
60 for (auto& input_elem : signature_def.inputs()) {
61 GetTensorNamesFromTensorInfo(input_elem.second, inputs);
62 }
63 for (auto& output_elem : signature_def.outputs()) {
64 GetTensorNamesFromTensorInfo(output_elem.second, outputs);
65 }
66 }
67 }
68
69 // Gets a map from string node name to NodeDef.
GetNodeNameToNodeDefMap(GraphDef * graph_def,std::unordered_map<string,NodeDef * > * name_to_node_map)70 void GetNodeNameToNodeDefMap(
71 GraphDef* graph_def,
72 std::unordered_map<string, NodeDef*>* name_to_node_map) {
73 for (size_t i = 0; i < graph_def->node_size(); i++) {
74 NodeDef* node = graph_def->mutable_node(i);
75 (*name_to_node_map)[node->name()] = node;
76 }
77 }
78
// Returns the name of the node that produces `tensor_name`.
//
// A leading control-dependency marker '^' is dropped, as is everything from
// the first ':' onward (the output-index suffix).
// Examples: "foo:0" -> "foo", "^foo" -> "foo", "foo" -> "foo".
// Empty input (or a lone "^") yields an empty string, not out-of-range access.
std::string GetNodeNameFromTensorName(const std::string& tensor_name) {
  const size_t begin =
      (!tensor_name.empty() && tensor_name[0] == '^') ? 1 : 0;
  const size_t colon = tensor_name.find(':', begin);
  if (colon == std::string::npos) {
    return tensor_name.substr(begin);
  }
  return tensor_name.substr(begin, colon - begin);
}
87
88 // Gets the set of node names needed by `outputs` and the corresponding set of
89 // variable nodes to convert.
GetReachableNodesAndVariables(GraphDef * graph_def,const std::unordered_set<string> & outputs,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> * reachable_node_names,std::unordered_set<string> * variable_node_names)90 void GetReachableNodesAndVariables(
91 GraphDef* graph_def, const std::unordered_set<string>& outputs,
92 const std::unordered_map<string, NodeDef*>& name_to_node_map,
93 std::unordered_set<string>* reachable_node_names,
94 std::unordered_set<string>* variable_node_names) {
95 // TODO(suharshs): Add support for ResourceVariables.
96 static const std::unordered_set<string>* kVariableTypes =
97 new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
98
99 std::queue<string> nodes_to_visit;
100 for (const string& output_tensor_name : outputs) {
101 nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
102 }
103 // We do a traversal backwards from the outputs specified in the MetaGraphDef.
104 while (!nodes_to_visit.empty()) {
105 const string node_name = nodes_to_visit.front();
106 nodes_to_visit.pop();
107 if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
108 continue;
109 }
110 reachable_node_names->insert(node_name);
111 NodeDef* node = name_to_node_map.at(node_name);
112 if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
113 variable_node_names->insert(node->name());
114 }
115 for (const string& input_tensor_name : node->input()) {
116 nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
117 }
118 }
119 }
120
121 // Gets a map from variable name to variable value.
GetVariableNameToTensorMap(Session * session,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> variable_names_set,std::unordered_map<string,Tensor> * variable_name_to_value_map)122 Status GetVariableNameToTensorMap(
123 Session* session,
124 const std::unordered_map<string, NodeDef*>& name_to_node_map,
125 std::unordered_set<string> variable_names_set,
126 std::unordered_map<string, Tensor>* variable_name_to_value_map) {
127 if (variable_names_set.empty()) {
128 return Status::OK();
129 }
130 std::vector<string> variable_names;
131 variable_names.reserve(variable_names_set.size());
132 std::vector<string> tensor_names;
133 tensor_names.reserve(variable_names_set.size());
134 for (const string& node_name : variable_names_set) {
135 variable_names.push_back(node_name);
136 NodeDef* node_def = name_to_node_map.at(node_name);
137 if (node_def->op() == "VarHandleOp") {
138 // If this is a resource variable, we have to run the corresponding
139 // ReadVariableOp.
140 tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
141 } else {
142 tensor_names.push_back(node_name + ":0");
143 }
144 }
145 std::vector<Tensor> outputs;
146 TF_RETURN_IF_ERROR(
147 session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
148 for (size_t i = 0; i < variable_names.size(); i++) {
149 (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
150 }
151 return Status::OK();
152 }
153
154 // Converts a Variable NodeDef into a Constant NodeDef.
ConvertVariableToConstant(const NodeDef & variable_node,const Tensor & variable_value,NodeDef * const_node)155 void ConvertVariableToConstant(const NodeDef& variable_node,
156 const Tensor& variable_value,
157 NodeDef* const_node) {
158 const_node->set_name(variable_node.name());
159 const_node->set_op("Const");
160 (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
161 variable_value.AsProtoTensorContent(
162 (*const_node->mutable_attr())["value"].mutable_tensor());
163 }
164
165 // Converts a ReadVariableOp NodeDef to an Identity NodeDef.
ConvertReadVariableOpToIdentity(const NodeDef & node,NodeDef * identity_node)166 void ConvertReadVariableOpToIdentity(const NodeDef& node,
167 NodeDef* identity_node) {
168 identity_node->set_name(node.name());
169 identity_node->set_op("Identity");
170 (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
171 identity_node->add_input(node.input(0));
172 }
173
174 // Freezes the subgraph of all nodes needed by `outputs`.
FreezeGraphDef(const SavedModelBundle & saved_model_bundle,const std::unordered_set<string> & outputs,GraphDef * frozen_graph_def)175 Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
176 const std::unordered_set<string>& outputs,
177 GraphDef* frozen_graph_def) {
178 GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
179 // Copy versions and library as-is from original graph.
180 *frozen_graph_def->mutable_versions() = graph_def.versions();
181 *frozen_graph_def->mutable_library() = graph_def.library();
182 // If the graph is empty there is nothing left to do.
183 if (graph_def.node_size() == 0) {
184 return Status::OK();
185 }
186 // name_to_node_map is needed to get the inputs from the NodeDef corresponding
187 // the a string node name. These inputs are used when doing our backwards
188 // traversal.
189 std::unordered_map<string, NodeDef*> name_to_node_map;
190 GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
191 std::unordered_set<string> reachable_node_names;
192 std::unordered_set<string> variable_node_names;
193 GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
194 &reachable_node_names, &variable_node_names);
195 std::unordered_map<string, Tensor> variable_to_value_map;
196 TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
197 saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
198 &variable_to_value_map));
199 // We copy the nodes in the same order they were in the original graph_def.
200 for (const NodeDef& node : graph_def.node()) {
201 if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
202 continue;
203 }
204 if (variable_node_names.find(node.name()) != variable_node_names.end()) {
205 ConvertVariableToConstant(node, variable_to_value_map[node.name()],
206 frozen_graph_def->add_node());
207 } else if (node.op() == "ReadVariableOp" &&
208 variable_node_names.find(node.input(0)) !=
209 variable_node_names.end()) {
210 // If the node is a ReadVariableOp, its input VarHandleOp will be
211 // converted to a Constant, so we will need to convert it to an Identity.
212 ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
213 } else {
214 // If the node isn't a variable, just copy the node as-is.
215 *frozen_graph_def->add_node() = node;
216 }
217 }
218 return Status::OK();
219 }
220
221 } // namespace
222
FreezeSavedModel(const SavedModelBundle & saved_model_bundle,GraphDef * frozen_graph_def,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)223 Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
224 GraphDef* frozen_graph_def,
225 std::unordered_set<string>* inputs,
226 std::unordered_set<string>* outputs) {
227 GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
228 TF_RETURN_IF_ERROR(
229 FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
230 return Status::OK();
231 }
232
233 } // namespace tensorflow
234