1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/metrics.h"
23 #include "tensorflow/cc/saved_model/reader.h"
24 #include "tensorflow/cc/saved_model/util.h"
25 #include "tensorflow/core/lib/io/path.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/strcat.h"
29 #include "tensorflow/core/protobuf/saved_model.pb.h"
30 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
31 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
32
33 namespace tensorflow {
34 namespace {
35
// Label reported to the SavedModel read-API metric each time
// `tensorflow::SavedModelV2Bundle::Load` is called.
constexpr char kCCLoadBundleV2Label[]{"cc_load_bundle_v2"};
38
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)39 Status ReadSavedModelProto(const string& export_dir,
40 SavedModel* saved_model_proto) {
41 LOG(INFO) << "Reading SavedModel from: " << export_dir;
42
43 const string saved_model_pb_path =
44 io::JoinPath(export_dir, kSavedModelFilenamePb);
45
46 if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
47 Status result =
48 ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
49 if (result.ok()) {
50 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
51 .IncrementBy(1);
52 }
53 return result;
54 }
55 const string saved_model_pbtxt_path =
56 io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
57 if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
58 Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
59 saved_model_proto);
60 if (result.ok()) {
61 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
62 .IncrementBy(1);
63 }
64 return result;
65 }
66
67 return Status(error::Code::NOT_FOUND,
68 "Could not find SavedModel .pb or .pbtxt at supplied export "
69 "directory path: " +
70 export_dir);
71 }
72
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)73 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
74 TrackableObjectGraph* object_graph) {
75 Tensor object_graph_tensor;
76 TF_RETURN_WITH_CONTEXT_IF_ERROR(
77 bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
78 "SavedModel checkpoint does not contain object graph.");
79 if (object_graph_tensor.dtype() != DT_STRING ||
80 object_graph_tensor.dims() != 0 ||
81 object_graph_tensor.NumElements() != 1) {
82 return Status(
83 error::Code::FAILED_PRECONDITION,
84 "SavedModel checkpoint object graph was not the correct type.");
85 }
86
87 const tstring* object_graph_string = reinterpret_cast<const tstring*>(
88 object_graph_tensor.tensor_data().data());
89 if (!object_graph->ParseFromString(*object_graph_string)) {
90 return Status(
91 error::Code::FAILED_PRECONDITION,
92 "SavedModel checkpoint object graph could not be deserialized.");
93 }
94 return OkStatus();
95 }
96
97 } // namespace
98
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)99 Status SavedModelV2Bundle::Load(const std::string& export_dir,
100 SavedModelV2Bundle* const bundle) {
101 metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
102 SavedModel saved_model_proto;
103 TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
104
105 // Load MetaGraphDef.
106 // In version 2 SavedModels, there is only one MetaGraphDef.
107 if (saved_model_proto.meta_graphs_size() != 1) {
108 return Status(
109 error::Code::INVALID_ARGUMENT,
110 strings::StrCat(
111 "SavedModelV2 should have exactly one MetaGraphDef but actually ",
112 "contains ", saved_model_proto.meta_graphs_size()));
113 }
114 bundle->meta_graph_def_ =
115 std::move(*saved_model_proto.mutable_meta_graphs(0));
116
117 // Correct the endiness of Tensor content on big-endian system
118 if (!port::kLittleEndian) {
119 TF_RETURN_IF_ERROR(ByteSwapTensorContent(&(bundle->meta_graph_def_)));
120 }
121
122 // Load GraphDebugInfo.
123 TF_RETURN_IF_ERROR(
124 ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
125
126 const std::string variables_dir =
127 io::JoinPath(export_dir, kSavedModelVariablesDirectory);
128 if (!Env::Default()->FileExists(variables_dir).ok()) {
129 LOG(INFO)
130 << "No checkpoint found, assuming this is a program-only SavedModel";
131 } else {
132 // Load the variables checkpoint reader.
133 const std::string variables_prefix =
134 io::JoinPath(variables_dir, kSavedModelVariablesFilename);
135 bundle->variable_reader_.reset(
136 new BundleReader(Env::Default(), variables_prefix));
137 TF_RETURN_WITH_CONTEXT_IF_ERROR(
138 bundle->variable_reader_->status(),
139 "Unable to load SavedModel variables checkpoint from ",
140 variables_prefix);
141
142 // Deserialize the object graph proto from the tensor bundle.
143 TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
144 bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
145 }
146 return OkStatus();
147 }
148
VisitObjectsToRestore(RestoreObjectsCallback callback)149 Status SavedModelV2Bundle::VisitObjectsToRestore(
150 RestoreObjectsCallback callback) {
151 if (saved_object_graph().nodes_size() == 0 ||
152 trackable_object_graph().nodes_size() == 0) {
153 return OkStatus();
154 }
155
156 // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
157 // and descend to leaves. Note that the TrackableObjectGraph can have cycles
158 // (as can the SavedObjectGraph).
159 // This is detected and cycle edges are skipped.
160 const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
161 const TrackableObjectGraph::TrackableObject* root_trackable_object =
162 &trackable_object_graph().nodes(0);
163 absl::flat_hash_set<int> trackable_node_ids;
164 return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
165 std::string(), &trackable_node_ids,
166 std::move(callback));
167 }
168
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)169 Status SavedModelV2Bundle::RecurseObjectsToRestore(
170 const SavedObject* saved_object, int saved_object_node_id,
171 const TrackableObjectGraph::TrackableObject* trackable_object,
172 std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
173 RestoreObjectsCallback callback) {
174 // Callback if any attributes or slot variables.
175 // Note that the root is always excluded from the search (it can never
176 // be a restorable object). This matches some logic on the Python side.
177 if (saved_object_node_id != 0 &&
178 (trackable_object->attributes_size() > 0 ||
179 trackable_object->slot_variables_size() > 0)) {
180 TF_RETURN_WITH_CONTEXT_IF_ERROR(
181 callback(saved_object_node_id, *trackable_object), "Unable to restore ",
182 object_name);
183 }
184
185 for (const auto& trackable_child_ref : trackable_object->children()) {
186 const auto& local_name = trackable_child_ref.local_name();
187
188 // Compute the full child name.
189 std::string child_name;
190 if (object_name.empty()) {
191 child_name = local_name;
192 } else {
193 child_name = strings::StrCat(object_name, ".", local_name);
194 }
195
196 // Descend down the trackable graph.
197 int trackable_child_node_id = trackable_child_ref.node_id();
198 if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
199 // Cycle or duplicate detected - ignore this branch.
200 continue;
201 }
202 if (trackable_child_node_id < 0 ||
203 trackable_child_node_id >= trackable_object_graph().nodes_size()) {
204 return errors::FailedPrecondition(
205 strings::StrCat("Illegal trackable child node id for ", child_name));
206 }
207 const auto* trackable_child =
208 &trackable_object_graph().nodes(trackable_child_node_id);
209
210 // Descend down the saved object graph.
211 int saved_child_node_id = -1;
212 const SavedObject* saved_child = nullptr;
213 for (const auto& saved_child_ref : saved_object->children()) {
214 if (saved_child_ref.local_name() == local_name) {
215 // Found.
216 saved_child_node_id = saved_child_ref.node_id();
217 if (saved_child_node_id >= 0 &&
218 saved_child_node_id < saved_object_graph().nodes_size()) {
219 saved_child = &saved_object_graph().nodes(saved_child_node_id);
220 }
221 break;
222 }
223 }
224
225 if (!saved_child) {
226 return Status(
227 errors::Code::FAILED_PRECONDITION,
228 strings::StrCat("Could not find saved object to restore for ",
229 child_name));
230 }
231
232 TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
233 saved_child, saved_child_node_id, trackable_child, child_name,
234 seen_trackable_node_ids, callback));
235 }
236 return OkStatus();
237 }
238
239 } // namespace tensorflow
240