• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17 
18 #include <iostream>
19 #include <queue>
20 
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/protobuf/meta_graph.pb.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
34 // Gets tensor names from tensor_info and inserts them into the set of tensor
35 // names.
GetTensorNamesFromTensorInfo(const TensorInfo & tensor_info,std::unordered_set<string> * tensor_names)36 void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
37                                   std::unordered_set<string>* tensor_names) {
38   if (tensor_info.has_coo_sparse()) {
39     // If the tensor is sparse we have to add all three tensors of the sparse
40     // representations.
41     const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
42     tensor_names->insert(coo_sparse.values_tensor_name());
43     tensor_names->insert(coo_sparse.indices_tensor_name());
44     tensor_names->insert(coo_sparse.dense_shape_tensor_name());
45   } else if (tensor_info.has_composite_tensor()) {
46     for (const auto& component : tensor_info.composite_tensor().components()) {
47       tensor_names->insert(component.name());
48     }
49   } else {
50     tensor_names->insert(tensor_info.name());
51   }
52 }
53 
54 // Gets the union of all inputs and outputs of all SignatureDefs in the bundle
GetSignatureDefsInputsAndOutputs(const SavedModelBundle & saved_model_bundle,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)55 void GetSignatureDefsInputsAndOutputs(
56     const SavedModelBundle& saved_model_bundle,
57     std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
58   for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
59     const SignatureDef& signature_def = sigdef_elem.second;
60     for (auto& input_elem : signature_def.inputs()) {
61       GetTensorNamesFromTensorInfo(input_elem.second, inputs);
62     }
63     for (auto& output_elem : signature_def.outputs()) {
64       GetTensorNamesFromTensorInfo(output_elem.second, outputs);
65     }
66   }
67 }
68 
69 // Gets a map from string node name to NodeDef.
GetNodeNameToNodeDefMap(GraphDef * graph_def,std::unordered_map<string,NodeDef * > * name_to_node_map)70 void GetNodeNameToNodeDefMap(
71     GraphDef* graph_def,
72     std::unordered_map<string, NodeDef*>* name_to_node_map) {
73   for (size_t i = 0; i < graph_def->node_size(); i++) {
74     NodeDef* node = graph_def->mutable_node(i);
75     (*name_to_node_map)[node->name()] = node;
76   }
77 }
78 
// Returns the node-name portion of `tensor_name`.
//
// A single leading '^' (presumably the control-input marker -- confirm against
// the GraphDef input format) is dropped, and everything from the first ':'
// onwards is removed. Examples: "foo:0" -> "foo", "^foo" -> "foo",
// "foo" -> "foo".
//
// An empty name, or a name consisting only of "^", yields an empty string
// rather than indexing into a possibly empty list of split parts.
const string GetNodeNameFromTensorName(string tensor_name) {
  if (!tensor_name.empty() && tensor_name[0] == '^') {
    tensor_name.erase(0, 1);
  }
  const auto colon_pos = tensor_name.find(':');
  if (colon_pos != string::npos) {
    // Drop the ":<output index>" suffix (and anything after it).
    tensor_name.erase(colon_pos);
  }
  return tensor_name;
}
87 
88 // Gets the set of node names needed by `outputs` and the corresponding set of
89 // variable nodes to convert.
GetReachableNodesAndVariables(GraphDef * graph_def,const std::unordered_set<string> & outputs,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> * reachable_node_names,std::unordered_set<string> * variable_node_names)90 void GetReachableNodesAndVariables(
91     GraphDef* graph_def, const std::unordered_set<string>& outputs,
92     const std::unordered_map<string, NodeDef*>& name_to_node_map,
93     std::unordered_set<string>* reachable_node_names,
94     std::unordered_set<string>* variable_node_names) {
95   // TODO(suharshs): Add support for ResourceVariables.
96   static const std::unordered_set<string>* kVariableTypes =
97       new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
98 
99   std::queue<string> nodes_to_visit;
100   for (const string& output_tensor_name : outputs) {
101     nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
102   }
103   // We do a traversal backwards from the outputs specified in the MetaGraphDef.
104   while (!nodes_to_visit.empty()) {
105     const string node_name = nodes_to_visit.front();
106     nodes_to_visit.pop();
107     if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
108       continue;
109     }
110     reachable_node_names->insert(node_name);
111     NodeDef* node = name_to_node_map.at(node_name);
112     if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
113       variable_node_names->insert(node->name());
114     }
115     for (const string& input_tensor_name : node->input()) {
116       nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
117     }
118   }
119 }
120 
121 // Gets a map from variable name to variable value.
GetVariableNameToTensorMap(Session * session,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> variable_names_set,std::unordered_map<string,Tensor> * variable_name_to_value_map)122 Status GetVariableNameToTensorMap(
123     Session* session,
124     const std::unordered_map<string, NodeDef*>& name_to_node_map,
125     std::unordered_set<string> variable_names_set,
126     std::unordered_map<string, Tensor>* variable_name_to_value_map) {
127   if (variable_names_set.empty()) {
128     return Status::OK();
129   }
130   std::vector<string> variable_names;
131   variable_names.reserve(variable_names_set.size());
132   std::vector<string> tensor_names;
133   tensor_names.reserve(variable_names_set.size());
134   for (const string& node_name : variable_names_set) {
135     variable_names.push_back(node_name);
136     NodeDef* node_def = name_to_node_map.at(node_name);
137     if (node_def->op() == "VarHandleOp") {
138       // If this is a resource variable, we have to run the corresponding
139       // ReadVariableOp.
140       tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
141     } else {
142       tensor_names.push_back(node_name + ":0");
143     }
144   }
145   std::vector<Tensor> outputs;
146   TF_RETURN_IF_ERROR(
147       session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
148   for (size_t i = 0; i < variable_names.size(); i++) {
149     (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
150   }
151   return Status::OK();
152 }
153 
154 // Converts a Variable NodeDef into a Constant NodeDef.
ConvertVariableToConstant(const NodeDef & variable_node,const Tensor & variable_value,NodeDef * const_node)155 void ConvertVariableToConstant(const NodeDef& variable_node,
156                                const Tensor& variable_value,
157                                NodeDef* const_node) {
158   const_node->set_name(variable_node.name());
159   const_node->set_op("Const");
160   (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
161   variable_value.AsProtoTensorContent(
162       (*const_node->mutable_attr())["value"].mutable_tensor());
163 }
164 
165 // Converts a ReadVariableOp NodeDef to an Identity NodeDef.
ConvertReadVariableOpToIdentity(const NodeDef & node,NodeDef * identity_node)166 void ConvertReadVariableOpToIdentity(const NodeDef& node,
167                                      NodeDef* identity_node) {
168   identity_node->set_name(node.name());
169   identity_node->set_op("Identity");
170   (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
171   identity_node->add_input(node.input(0));
172 }
173 
174 // Freezes the subgraph of all nodes needed by `outputs`.
FreezeGraphDef(const SavedModelBundle & saved_model_bundle,const std::unordered_set<string> & outputs,GraphDef * frozen_graph_def)175 Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
176                       const std::unordered_set<string>& outputs,
177                       GraphDef* frozen_graph_def) {
178   GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
179   // Copy versions and library as-is from original graph.
180   *frozen_graph_def->mutable_versions() = graph_def.versions();
181   *frozen_graph_def->mutable_library() = graph_def.library();
182   // If the graph is empty there is nothing left to do.
183   if (graph_def.node_size() == 0) {
184     return Status::OK();
185   }
186   // name_to_node_map is needed to get the inputs from the NodeDef corresponding
187   // the a string node name. These inputs are used when doing our backwards
188   // traversal.
189   std::unordered_map<string, NodeDef*> name_to_node_map;
190   GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
191   std::unordered_set<string> reachable_node_names;
192   std::unordered_set<string> variable_node_names;
193   GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
194                                 &reachable_node_names, &variable_node_names);
195   std::unordered_map<string, Tensor> variable_to_value_map;
196   TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
197       saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
198       &variable_to_value_map));
199   // We copy the nodes in the same order they were in the original graph_def.
200   for (const NodeDef& node : graph_def.node()) {
201     if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
202       continue;
203     }
204     if (variable_node_names.find(node.name()) != variable_node_names.end()) {
205       ConvertVariableToConstant(node, variable_to_value_map[node.name()],
206                                 frozen_graph_def->add_node());
207     } else if (node.op() == "ReadVariableOp" &&
208                variable_node_names.find(node.input(0)) !=
209                    variable_node_names.end()) {
210       // If the node is a ReadVariableOp, its input VarHandleOp will be
211       // converted to a Constant, so we will need to convert it to an Identity.
212       ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
213     } else {
214       // If the node isn't a variable, just copy the node as-is.
215       *frozen_graph_def->add_node() = node;
216     }
217   }
218   return Status::OK();
219 }
220 
221 }  // namespace
222 
FreezeSavedModel(const SavedModelBundle & saved_model_bundle,GraphDef * frozen_graph_def,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)223 Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
224                         GraphDef* frozen_graph_def,
225                         std::unordered_set<string>* inputs,
226                         std::unordered_set<string>* outputs) {
227   GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
228   TF_RETURN_IF_ERROR(
229       FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
230   return Status::OK();
231 }
232 
233 }  // namespace tensorflow
234