1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/cc/saved_model/bundle_v2.h"

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
31
32 namespace tensorflow {
33 namespace {
34
// Label passed to metrics::SavedModelReadApi() each time the
// `tensorflow::SavedModelV2Bundle::Load` entry point is used.
constexpr char kCCLoadBundleV2Label[]{"cc_load_bundle_v2"};
37
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)38 Status ReadSavedModelProto(const string& export_dir,
39 SavedModel* saved_model_proto) {
40 LOG(INFO) << "Reading SavedModel from: " << export_dir;
41
42 const string saved_model_pb_path =
43 io::JoinPath(export_dir, kSavedModelFilenamePb);
44
45 if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
46 Status result =
47 ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
48 if (result.ok()) {
49 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
50 .IncrementBy(1);
51 }
52 return result;
53 }
54 const string saved_model_pbtxt_path =
55 io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
56 if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
57 Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
58 saved_model_proto);
59 if (result.ok()) {
60 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
61 .IncrementBy(1);
62 }
63 return result;
64 }
65
66 return Status(error::Code::NOT_FOUND,
67 "Could not find SavedModel .pb or .pbtxt at supplied export "
68 "directory path: " +
69 export_dir);
70 }
71
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)72 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
73 TrackableObjectGraph* object_graph) {
74 Tensor object_graph_tensor;
75 TF_RETURN_WITH_CONTEXT_IF_ERROR(
76 bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
77 "SavedModel checkpoint does not contain object graph.");
78 if (object_graph_tensor.dtype() != DT_STRING ||
79 object_graph_tensor.dims() != 0 ||
80 object_graph_tensor.NumElements() != 1) {
81 return Status(
82 error::Code::FAILED_PRECONDITION,
83 "SavedModel checkpoint object graph was not the correct type.");
84 }
85
86 const tstring* object_graph_string = reinterpret_cast<const tstring*>(
87 object_graph_tensor.tensor_data().data());
88 if (!object_graph->ParseFromString(*object_graph_string)) {
89 return Status(
90 error::Code::FAILED_PRECONDITION,
91 "SavedModel checkpoint object graph could not be deserialized.");
92 }
93 return Status::OK();
94 }
95
96 } // namespace
97
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)98 Status SavedModelV2Bundle::Load(const std::string& export_dir,
99 SavedModelV2Bundle* const bundle) {
100 metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
101 SavedModel saved_model_proto;
102 TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
103
104 // Load MetaGraphDef.
105 // In version 2 SavedModels, there is only one MetaGraphDef.
106 if (saved_model_proto.meta_graphs_size() != 1) {
107 return Status(
108 error::Code::INVALID_ARGUMENT,
109 strings::StrCat(
110 "SavedModelV2 should have exactly one MetaGraphDef but actually ",
111 "contains ", saved_model_proto.meta_graphs_size()));
112 }
113 bundle->meta_graph_def_ =
114 std::move(*saved_model_proto.mutable_meta_graphs(0));
115
116 // Load GraphDebugInfo.
117 TF_RETURN_IF_ERROR(
118 ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
119
120 const std::string variables_dir =
121 io::JoinPath(export_dir, kSavedModelVariablesDirectory);
122 if (!Env::Default()->FileExists(variables_dir).ok()) {
123 LOG(INFO)
124 << "No checkpoint found, assuming this is a program-only SavedModel";
125 } else {
126 // Load the variables checkpoint reader.
127 const std::string variables_prefix =
128 io::JoinPath(variables_dir, kSavedModelVariablesFilename);
129 bundle->variable_reader_.reset(
130 new BundleReader(Env::Default(), variables_prefix));
131 TF_RETURN_WITH_CONTEXT_IF_ERROR(
132 bundle->variable_reader_->status(),
133 "Unable to load SavedModel variables checkpoint from ",
134 variables_prefix);
135
136 // Deserialize the object graph proto from the tensor bundle.
137 TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
138 bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
139 }
140 return Status::OK();
141 }
142
VisitObjectsToRestore(RestoreObjectsCallback callback)143 Status SavedModelV2Bundle::VisitObjectsToRestore(
144 RestoreObjectsCallback callback) {
145 if (saved_object_graph().nodes_size() == 0 ||
146 trackable_object_graph().nodes_size() == 0) {
147 return Status::OK();
148 }
149
150 // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
151 // and descend to leaves. Note that the TrackableObjectGraph can have cycles
152 // (as can the SavedObjectGraph).
153 // This is detected and cycle edges are skipped.
154 const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
155 const TrackableObjectGraph::TrackableObject* root_trackable_object =
156 &trackable_object_graph().nodes(0);
157 absl::flat_hash_set<int> trackable_node_ids;
158 return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
159 std::string(), &trackable_node_ids,
160 std::move(callback));
161 }
162
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)163 Status SavedModelV2Bundle::RecurseObjectsToRestore(
164 const SavedObject* saved_object, int saved_object_node_id,
165 const TrackableObjectGraph::TrackableObject* trackable_object,
166 std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
167 RestoreObjectsCallback callback) {
168 // Callback if any attributes or slot variables.
169 // Note that the root is always excluded from the search (it can never
170 // be a restorable object). This matches some logic on the Python side.
171 if (saved_object_node_id != 0 &&
172 (trackable_object->attributes_size() > 0 ||
173 trackable_object->slot_variables_size() > 0)) {
174 TF_RETURN_WITH_CONTEXT_IF_ERROR(
175 callback(saved_object_node_id, *trackable_object), "Unable to restore ",
176 object_name);
177 }
178
179 for (const auto& trackable_child_ref : trackable_object->children()) {
180 const auto& local_name = trackable_child_ref.local_name();
181
182 // Compute the full child name.
183 std::string child_name;
184 if (object_name.empty()) {
185 child_name = local_name;
186 } else {
187 child_name = strings::StrCat(object_name, ".", local_name);
188 }
189
190 // Descend down the trackable graph.
191 int trackable_child_node_id = trackable_child_ref.node_id();
192 if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
193 // Cycle or duplicate detected - ignore this branch.
194 continue;
195 }
196 if (trackable_child_node_id < 0 ||
197 trackable_child_node_id >= trackable_object_graph().nodes_size()) {
198 return Status(
199 errors::Code::FAILED_PRECONDITION,
200 strings::StrCat("Illegal trackable child node id for ", child_name));
201 }
202 const auto* trackable_child =
203 &trackable_object_graph().nodes(trackable_child_node_id);
204
205 // Descend down the saved object graph.
206 int saved_child_node_id = -1;
207 const SavedObject* saved_child = nullptr;
208 for (const auto& saved_child_ref : saved_object->children()) {
209 if (saved_child_ref.local_name() == local_name) {
210 // Found.
211 saved_child_node_id = saved_child_ref.node_id();
212 if (saved_child_node_id >= 0 &&
213 saved_child_node_id < saved_object_graph().nodes_size()) {
214 saved_child = &saved_object_graph().nodes(saved_child_node_id);
215 }
216 break;
217 }
218 }
219
220 if (!saved_child) {
221 return Status(
222 errors::Code::FAILED_PRECONDITION,
223 strings::StrCat("Could not find saved object to restore for ",
224 child_name));
225 }
226
227 TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
228 saved_child, saved_child_node_id, trackable_child, child_name,
229 seen_trackable_node_ids, callback));
230 }
231 return Status::OK();
232 }
233
234 } // namespace tensorflow
235