1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/core/lib/io/path.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/strcat.h"
23 #include "tensorflow/core/protobuf/saved_model.pb.h"
24 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
25
26 namespace tensorflow {
27
28 namespace {
29
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)30 Status ReadSavedModelProto(const string& export_dir,
31 SavedModel* saved_model_proto) {
32 LOG(INFO) << "Reading SavedModel from: " << export_dir;
33
34 const string saved_model_pb_path =
35 io::JoinPath(export_dir, kSavedModelFilenamePb);
36 if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
37 return ReadBinaryProto(Env::Default(), saved_model_pb_path,
38 saved_model_proto);
39 }
40 const string saved_model_pbtxt_path =
41 io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
42 if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
43 return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
44 saved_model_proto);
45 }
46 return Status(error::Code::NOT_FOUND,
47 "Could not find SavedModel .pb or .pbtxt at supplied export "
48 "directory path: " +
49 export_dir);
50 }
51
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)52 Status ReadSavedModelDebugInfoIfPresent(
53 const string& export_dir,
54 std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
55 LOG(INFO) << "Reading SavedModel debug info (if present) from: "
56 << export_dir;
57
58 const string debug_info_pb_path =
59 io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
60 if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
61 GraphDebugInfo debug_info;
62 TF_RETURN_IF_ERROR(
63 ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
64 *debug_info_proto =
65 absl::make_unique<GraphDebugInfo>(std::move(debug_info));
66 }
67 return Status::OK();
68 }
69
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)70 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
71 TrackableObjectGraph* object_graph) {
72 Tensor object_graph_tensor;
73 TF_RETURN_WITH_CONTEXT_IF_ERROR(
74 bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
75 "SavedModel checkpoint does not contain object graph.");
76 if (object_graph_tensor.dtype() != DT_STRING ||
77 object_graph_tensor.dims() != 0 ||
78 object_graph_tensor.NumElements() != 1) {
79 return Status(
80 error::Code::FAILED_PRECONDITION,
81 "SavedModel checkpoint object graph was not the correct type.");
82 }
83
84 const tstring* object_graph_string = reinterpret_cast<const tstring*>(
85 object_graph_tensor.tensor_data().data());
86 if (!object_graph->ParseFromString(*object_graph_string)) {
87 return Status(
88 error::Code::FAILED_PRECONDITION,
89 "SavedModel checkpoint object graph could not be deserialized.");
90 }
91 return Status::OK();
92 }
93
94 } // namespace
95
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)96 Status SavedModelV2Bundle::Load(const std::string& export_dir,
97 SavedModelV2Bundle* const bundle) {
98 SavedModel saved_model_proto;
99 TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
100
101 // Load MetaGraphDef.
102 // In version 2 SavedModels, there is only one MetaGraphDef.
103 if (saved_model_proto.meta_graphs_size() != 1) {
104 return Status(
105 error::Code::INVALID_ARGUMENT,
106 strings::StrCat(
107 "SavedModelV2 should have exactly one MetaGraphDef but actually ",
108 "contains ", saved_model_proto.meta_graphs_size()));
109 }
110 bundle->meta_graph_def_ =
111 std::move(*saved_model_proto.mutable_meta_graphs(0));
112
113 // Load GraphDebugInfo.
114 TF_RETURN_IF_ERROR(
115 ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
116
117 const std::string variables_dir =
118 io::JoinPath(export_dir, kSavedModelVariablesDirectory);
119 if (!Env::Default()->FileExists(variables_dir).ok()) {
120 LOG(INFO)
121 << "No checkpoint found, assuming this is a program-only SavedModel";
122 } else {
123 // Load the variables checkpoint reader.
124 const std::string variables_prefix =
125 io::JoinPath(variables_dir, kSavedModelVariablesFilename);
126 bundle->variable_reader_.reset(
127 new BundleReader(Env::Default(), variables_prefix));
128 TF_RETURN_WITH_CONTEXT_IF_ERROR(
129 bundle->variable_reader_->status(),
130 "Unable to load SavedModel variables checkpoint from ",
131 variables_prefix);
132
133 // Deserialize the object graph proto from the tensor bundle.
134 TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
135 bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
136 }
137
138 return Status::OK();
139 }
140
VisitObjectsToRestore(RestoreObjectsCallback callback)141 Status SavedModelV2Bundle::VisitObjectsToRestore(
142 RestoreObjectsCallback callback) {
143 if (saved_object_graph().nodes_size() == 0 ||
144 trackable_object_graph().nodes_size() == 0) {
145 return Status::OK();
146 }
147
148 // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
149 // and descend to leaves. Note that the TrackableObjectGraph can have cycles
150 // (as can the SavedObjectGraph).
151 // This is detected and cycle edges are skipped.
152 const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
153 const TrackableObjectGraph::TrackableObject* root_trackable_object =
154 &trackable_object_graph().nodes(0);
155 absl::flat_hash_set<int> trackable_node_ids;
156 return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
157 std::string(), &trackable_node_ids,
158 std::move(callback));
159 }
160
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)161 Status SavedModelV2Bundle::RecurseObjectsToRestore(
162 const SavedObject* saved_object, int saved_object_node_id,
163 const TrackableObjectGraph::TrackableObject* trackable_object,
164 std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
165 RestoreObjectsCallback callback) {
166 // Callback if any attributes or slot variables.
167 // Note that the root is always excluded from the search (it can never
168 // be a restorable object). This matches some logic on the Python side.
169 if (saved_object_node_id != 0 &&
170 (trackable_object->attributes_size() > 0 ||
171 trackable_object->slot_variables_size() > 0)) {
172 TF_RETURN_WITH_CONTEXT_IF_ERROR(
173 callback(saved_object_node_id, *trackable_object), "Unable to restore ",
174 object_name);
175 }
176
177 for (const auto& trackable_child_ref : trackable_object->children()) {
178 const auto& local_name = trackable_child_ref.local_name();
179
180 // Compute the full child name.
181 std::string child_name;
182 if (object_name.empty()) {
183 child_name = local_name;
184 } else {
185 child_name = strings::StrCat(object_name, ".", local_name);
186 }
187
188 // Descend down the trackable graph.
189 int trackable_child_node_id = trackable_child_ref.node_id();
190 if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
191 // Cycle or duplicate detected - ignore this branch.
192 continue;
193 }
194 if (trackable_child_node_id < 0 ||
195 trackable_child_node_id >= trackable_object_graph().nodes_size()) {
196 return Status(
197 errors::Code::FAILED_PRECONDITION,
198 strings::StrCat("Illegal trackable child node id for ", child_name));
199 }
200 const auto* trackable_child =
201 &trackable_object_graph().nodes(trackable_child_node_id);
202
203 // Descend down the saved object graph.
204 int saved_child_node_id = -1;
205 const SavedObject* saved_child = nullptr;
206 for (const auto& saved_child_ref : saved_object->children()) {
207 if (saved_child_ref.local_name() == local_name) {
208 // Found.
209 saved_child_node_id = saved_child_ref.node_id();
210 if (saved_child_node_id >= 0 &&
211 saved_child_node_id < saved_object_graph().nodes_size()) {
212 saved_child = &saved_object_graph().nodes(saved_child_node_id);
213 }
214 break;
215 }
216 }
217
218 if (!saved_child) {
219 return Status(
220 errors::Code::FAILED_PRECONDITION,
221 strings::StrCat("Could not find saved object to restore for ",
222 child_name));
223 }
224
225 TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
226 saved_child, saved_child_node_id, trackable_child, child_name,
227 seen_trackable_node_ids, callback));
228 }
229 return Status::OK();
230 }
231
232 } // namespace tensorflow
233