• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17 
18 #include <iostream>
19 #include <queue>
20 
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
34 // Gets tensor names from tensor_info and inserts them into the set of tensor
35 // names.
GetTensorNamesFromTensorInfo(const TensorInfo & tensor_info,std::unordered_set<string> * tensor_names)36 void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
37                                   std::unordered_set<string>* tensor_names) {
38   if (tensor_info.has_coo_sparse()) {
39     // If the tensor is sparse we have to add all three tensors of the sparse
40     // representations.
41     const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
42     tensor_names->insert(coo_sparse.values_tensor_name());
43     tensor_names->insert(coo_sparse.indices_tensor_name());
44     tensor_names->insert(coo_sparse.dense_shape_tensor_name());
45   } else {
46     tensor_names->insert(tensor_info.name());
47   }
48 }
49 
50 // Gets the union of all inputs and outputs of all SignatureDefs in the bundle
GetSignatureDefsInputsAndOutputs(const SavedModelBundle & saved_model_bundle,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)51 void GetSignatureDefsInputsAndOutputs(
52     const SavedModelBundle& saved_model_bundle,
53     std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
54   for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
55     const SignatureDef& signature_def = sigdef_elem.second;
56     for (auto& input_elem : signature_def.inputs()) {
57       GetTensorNamesFromTensorInfo(input_elem.second, inputs);
58     }
59     for (auto& output_elem : signature_def.outputs()) {
60       GetTensorNamesFromTensorInfo(output_elem.second, outputs);
61     }
62   }
63 }
64 
65 // Gets a map from string node name to NodeDef.
GetNodeNameToNodeDefMap(GraphDef * graph_def,std::unordered_map<string,NodeDef * > * name_to_node_map)66 void GetNodeNameToNodeDefMap(
67     GraphDef* graph_def,
68     std::unordered_map<string, NodeDef*>* name_to_node_map) {
69   for (size_t i = 0; i < graph_def->node_size(); i++) {
70     NodeDef* node = graph_def->mutable_node(i);
71     (*name_to_node_map)[node->name()] = node;
72   }
73 }
74 
// Strips off the tensor part of `tensor_name` to get the node name.
//
// A single leading '^' (control-dependency marker) is removed, and everything
// from the first ':' onwards (the output index) is dropped. For example
// "scope/foo:1" and "^scope/foo" both yield "scope/foo". An empty name, or a
// name consisting only of "^", yields an empty string.
const string GetNodeNameFromTensorName(string tensor_name) {
  if (!tensor_name.empty() && tensor_name[0] == '^') {
    tensor_name.erase(0, 1);
  }
  // Truncate in place at the first ':' instead of splitting into a vector;
  // this is also well defined when there is nothing left to split.
  const string::size_type colon_pos = tensor_name.find(':');
  if (colon_pos != string::npos) {
    tensor_name.erase(colon_pos);
  }
  return tensor_name;
}
83 
84 // Gets the set of node names needed by `outputs` and the corresponding set of
85 // variable nodes to convert.
GetReachableNodesAndVariables(GraphDef * graph_def,const std::unordered_set<string> & outputs,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> * reachable_node_names,std::unordered_set<string> * variable_node_names)86 void GetReachableNodesAndVariables(
87     GraphDef* graph_def, const std::unordered_set<string>& outputs,
88     const std::unordered_map<string, NodeDef*>& name_to_node_map,
89     std::unordered_set<string>* reachable_node_names,
90     std::unordered_set<string>* variable_node_names) {
91   // TODO(suharshs): Add support for ResourceVariables.
92   static const std::unordered_set<string>* kVariableTypes =
93       new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
94 
95   std::queue<string> nodes_to_visit;
96   for (const string& output_tensor_name : outputs) {
97     nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
98   }
99   // We do a traversal backwards from the outputs specified in the MetaGraphDef.
100   while (!nodes_to_visit.empty()) {
101     const string node_name = nodes_to_visit.front();
102     nodes_to_visit.pop();
103     if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
104       continue;
105     }
106     reachable_node_names->insert(node_name);
107     NodeDef* node = name_to_node_map.at(node_name);
108     if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
109       variable_node_names->insert(node->name());
110     }
111     for (const string& input_tensor_name : node->input()) {
112       nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
113     }
114   }
115 }
116 
117 // Gets a map from variable name to variable value.
GetVariableNameToTensorMap(Session * session,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> variable_names_set,std::unordered_map<string,Tensor> * variable_name_to_value_map)118 Status GetVariableNameToTensorMap(
119     Session* session,
120     const std::unordered_map<string, NodeDef*>& name_to_node_map,
121     std::unordered_set<string> variable_names_set,
122     std::unordered_map<string, Tensor>* variable_name_to_value_map) {
123   if (variable_names_set.empty()) {
124     return Status::OK();
125   }
126   std::vector<string> variable_names;
127   variable_names.reserve(variable_names_set.size());
128   std::vector<string> tensor_names;
129   tensor_names.reserve(variable_names_set.size());
130   for (const string& node_name : variable_names_set) {
131     variable_names.push_back(node_name);
132     NodeDef* node_def = name_to_node_map.at(node_name);
133     if (node_def->op() == "VarHandleOp") {
134       // If this is a resource variable, we have to run the corresponding
135       // ReadVariableOp.
136       tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
137     } else {
138       tensor_names.push_back(node_name + ":0");
139     }
140   }
141   std::vector<Tensor> outputs;
142   TF_RETURN_IF_ERROR(
143       session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
144   for (size_t i = 0; i < variable_names.size(); i++) {
145     (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
146   }
147   return Status::OK();
148 }
149 
150 // Converts a Variable NodeDef into a Constant NodeDef.
ConvertVariableToConstant(const NodeDef & variable_node,const Tensor & variable_value,NodeDef * const_node)151 void ConvertVariableToConstant(const NodeDef& variable_node,
152                                const Tensor& variable_value,
153                                NodeDef* const_node) {
154   const_node->set_name(variable_node.name());
155   const_node->set_op("Const");
156   (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
157   variable_value.AsProtoTensorContent(
158       (*const_node->mutable_attr())["value"].mutable_tensor());
159 }
160 
161 // Converts a ReadVariableOp NodeDef to an Identity NodeDef.
ConvertReadVariableOpToIdentity(const NodeDef & node,NodeDef * identity_node)162 void ConvertReadVariableOpToIdentity(const NodeDef& node,
163                                      NodeDef* identity_node) {
164   identity_node->set_name(node.name());
165   identity_node->set_op("Identity");
166   (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
167   identity_node->add_input(node.input(0));
168 }
169 
170 // Freezes the subgraph of all nodes needed by `outputs`.
FreezeGraphDef(const SavedModelBundle & saved_model_bundle,const std::unordered_set<string> & outputs,GraphDef * frozen_graph_def)171 Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
172                       const std::unordered_set<string>& outputs,
173                       GraphDef* frozen_graph_def) {
174   GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
175   // Copy versions and library as-is from original graph.
176   *frozen_graph_def->mutable_versions() = graph_def.versions();
177   *frozen_graph_def->mutable_library() = graph_def.library();
178   // If the graph is empty there is nothing left to do.
179   if (graph_def.node_size() == 0) {
180     return Status::OK();
181   }
182   // name_to_node_map is needed to get the inputs from the NodeDef corresponding
183   // the a string node name. These inputs are used when doing our backwards
184   // traversal.
185   std::unordered_map<string, NodeDef*> name_to_node_map;
186   GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
187   std::unordered_set<string> reachable_node_names;
188   std::unordered_set<string> variable_node_names;
189   GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
190                                 &reachable_node_names, &variable_node_names);
191   std::unordered_map<string, Tensor> variable_to_value_map;
192   TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
193       saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
194       &variable_to_value_map));
195   // We copy the nodes in the same order they were in the original graph_def.
196   for (const NodeDef& node : graph_def.node()) {
197     if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
198       continue;
199     }
200     if (variable_node_names.find(node.name()) != variable_node_names.end()) {
201       ConvertVariableToConstant(node, variable_to_value_map[node.name()],
202                                 frozen_graph_def->add_node());
203     } else if (node.op() == "ReadVariableOp" &&
204                variable_node_names.find(node.input(0)) !=
205                    variable_node_names.end()) {
206       // If the node is a ReadVariableOp, its input VarHandleOp will be
207       // converted to a Constant, so we will need to convert it to an Identity.
208       ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
209     } else {
210       // If the node isn't a variable, just copy the node as-is.
211       *frozen_graph_def->add_node() = node;
212     }
213   }
214   return Status::OK();
215 }
216 
217 }  // namespace
218 
FreezeSavedModel(const SavedModelBundle & saved_model_bundle,GraphDef * frozen_graph_def,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)219 Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
220                         GraphDef* frozen_graph_def,
221                         std::unordered_set<string>* inputs,
222                         std::unordered_set<string>* outputs) {
223   GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
224   TF_RETURN_IF_ERROR(
225       FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
226   return Status::OK();
227 }
228 
229 }  // namespace tensorflow
230