1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/core/lib/io/path.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/strcat.h"
23 #include "tensorflow/core/protobuf/saved_model.pb.h"
24 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
25
26 namespace tensorflow {
27
28 namespace {
29
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)30 Status ReadSavedModelProto(const string& export_dir,
31 SavedModel* saved_model_proto) {
32 LOG(INFO) << "Reading SavedModel from: " << export_dir;
33
34 const string saved_model_pb_path =
35 io::JoinPath(export_dir, kSavedModelFilenamePb);
36 if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
37 return ReadBinaryProto(Env::Default(), saved_model_pb_path,
38 saved_model_proto);
39 }
40 const string saved_model_pbtxt_path =
41 io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
42 if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
43 return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
44 saved_model_proto);
45 }
46 return Status(error::Code::NOT_FOUND,
47 "Could not find SavedModel .pb or .pbtxt at supplied export "
48 "directory path: " +
49 export_dir);
50 }
51
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)52 Status ReadSavedModelDebugInfoIfPresent(
53 const string& export_dir,
54 std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
55 LOG(INFO) << "Reading SavedModel debug info (if present) from: "
56 << export_dir;
57
58 const string debug_info_pb_path =
59 io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
60 if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
61 GraphDebugInfo debug_info;
62 TF_RETURN_IF_ERROR(
63 ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
64 *debug_info_proto =
65 absl::make_unique<GraphDebugInfo>(std::move(debug_info));
66 }
67 return Status::OK();
68 }
69
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)70 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
71 TrackableObjectGraph* object_graph) {
72 Tensor object_graph_tensor;
73 TF_RETURN_WITH_CONTEXT_IF_ERROR(
74 bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
75 "SavedModel checkpoint does not contain object graph.");
76 if (object_graph_tensor.dtype() != DT_STRING ||
77 object_graph_tensor.dims() != 0 ||
78 object_graph_tensor.NumElements() != 1) {
79 return Status(
80 error::Code::FAILED_PRECONDITION,
81 "SavedModel checkpoint object graph was not the correct type.");
82 }
83
84 const tstring* object_graph_string = reinterpret_cast<const tstring*>(
85 object_graph_tensor.tensor_data().data());
86 if (!object_graph->ParseFromString(*object_graph_string)) {
87 return Status(
88 error::Code::FAILED_PRECONDITION,
89 "SavedModel checkpoint object graph could not be deserialized.");
90 }
91 return Status::OK();
92 }
93
94 } // namespace
95
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)96 Status SavedModelV2Bundle::Load(const std::string& export_dir,
97 SavedModelV2Bundle* const bundle) {
98 SavedModel saved_model_proto;
99 TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
100
101 // Load MetaGraphDef.
102 // In version 2 SavedModels, there is only one MetaGraphDef.
103 if (saved_model_proto.meta_graphs_size() != 1) {
104 return Status(
105 error::Code::INVALID_ARGUMENT,
106 strings::StrCat(
107 "SavedModelV2 should have exactly one MetaGraphDef but actually ",
108 "contains ", saved_model_proto.meta_graphs_size()));
109 }
110 bundle->meta_graph_def_ =
111 std::move(*saved_model_proto.mutable_meta_graphs(0));
112
113 // Load GraphDebugInfo.
114 TF_RETURN_IF_ERROR(
115 ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
116
117 // Load the variables checkpoint reader.
118 const std::string variables_prefix = io::JoinPath(
119 export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
120 bundle->variable_reader_.reset(
121 new BundleReader(Env::Default(), variables_prefix));
122 TF_RETURN_WITH_CONTEXT_IF_ERROR(
123 bundle->variable_reader_->status(),
124 "Unable to load SavedModel variables checkpoint from ", variables_prefix);
125
126 // Deserialize the object graph proto from the tensor bundle.
127 TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
128 bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
129 return Status::OK();
130 }
131
VisitObjectsToRestore(RestoreObjectsCallback callback)132 Status SavedModelV2Bundle::VisitObjectsToRestore(
133 RestoreObjectsCallback callback) {
134 if (saved_object_graph().nodes_size() == 0 ||
135 trackable_object_graph().nodes_size() == 0) {
136 return Status::OK();
137 }
138
139 // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
140 // and descend to leaves. Note that the TrackableObjectGraph can have cycles
141 // (as can the SavedObjectGraph).
142 // This is detected and cycle edges are skipped.
143 const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
144 const TrackableObjectGraph::TrackableObject* root_trackable_object =
145 &trackable_object_graph().nodes(0);
146 absl::flat_hash_set<int> trackable_node_ids;
147 return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
148 std::string(), &trackable_node_ids,
149 std::move(callback));
150 }
151
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)152 Status SavedModelV2Bundle::RecurseObjectsToRestore(
153 const SavedObject* saved_object, int saved_object_node_id,
154 const TrackableObjectGraph::TrackableObject* trackable_object,
155 std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
156 RestoreObjectsCallback callback) {
157 // Callback if any attributes or slot variables.
158 // Note that the root is always excluded from the search (it can never
159 // be a restorable object). This matches some logic on the Python side.
160 if (saved_object_node_id != 0 &&
161 (trackable_object->attributes_size() > 0 ||
162 trackable_object->slot_variables_size() > 0)) {
163 TF_RETURN_WITH_CONTEXT_IF_ERROR(
164 callback(saved_object_node_id, *trackable_object), "Unable to restore ",
165 object_name);
166 }
167
168 for (const auto& trackable_child_ref : trackable_object->children()) {
169 const auto& local_name = trackable_child_ref.local_name();
170
171 // Compute the full child name.
172 std::string child_name;
173 if (object_name.empty()) {
174 child_name = local_name;
175 } else {
176 child_name = strings::StrCat(object_name, ".", local_name);
177 }
178
179 // Descend down the trackable graph.
180 int trackable_child_node_id = trackable_child_ref.node_id();
181 if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
182 // Cycle or duplicate detected - ignore this branch.
183 continue;
184 }
185 if (trackable_child_node_id < 0 ||
186 trackable_child_node_id >= trackable_object_graph().nodes_size()) {
187 return Status(
188 errors::Code::FAILED_PRECONDITION,
189 strings::StrCat("Illegal trackable child node id for ", child_name));
190 }
191 const auto* trackable_child =
192 &trackable_object_graph().nodes(trackable_child_node_id);
193
194 // Descend down the saved object graph.
195 int saved_child_node_id = -1;
196 const SavedObject* saved_child = nullptr;
197 for (const auto& saved_child_ref : saved_object->children()) {
198 if (saved_child_ref.local_name() == local_name) {
199 // Found.
200 saved_child_node_id = saved_child_ref.node_id();
201 if (saved_child_node_id >= 0 &&
202 saved_child_node_id < saved_object_graph().nodes_size()) {
203 saved_child = &saved_object_graph().nodes(saved_child_node_id);
204 }
205 break;
206 }
207 }
208
209 if (!saved_child) {
210 return Status(
211 errors::Code::FAILED_PRECONDITION,
212 strings::StrCat("Could not find saved object to restore for ",
213 child_name));
214 }
215
216 TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
217 saved_child, saved_child_node_id, trackable_child, child_name,
218 seen_trackable_node_ids, callback));
219 }
220 return Status::OK();
221 }
222
223 } // namespace tensorflow
224