• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/cc/saved_model/bundle_v2.h"

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
31 
32 namespace tensorflow {
33 namespace {
34 
// Metrics label that attributes SavedModel reads to the C++
// `tensorflow::SavedModelV2Bundle::Load` entry point.
constexpr const char* kCCLoadBundleV2Label = "cc_load_bundle_v2";
37 
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)38 Status ReadSavedModelProto(const string& export_dir,
39                            SavedModel* saved_model_proto) {
40   LOG(INFO) << "Reading SavedModel from: " << export_dir;
41 
42   const string saved_model_pb_path =
43       io::JoinPath(export_dir, kSavedModelFilenamePb);
44 
45   if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
46     Status result =
47         ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
48     if (result.ok()) {
49       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
50           .IncrementBy(1);
51     }
52     return result;
53   }
54   const string saved_model_pbtxt_path =
55       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
56   if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
57     Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
58                                   saved_model_proto);
59     if (result.ok()) {
60       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
61           .IncrementBy(1);
62     }
63     return result;
64   }
65 
66   return Status(error::Code::NOT_FOUND,
67                 "Could not find SavedModel .pb or .pbtxt at supplied export "
68                 "directory path: " +
69                     export_dir);
70 }
71 
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)72 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
73                                  TrackableObjectGraph* object_graph) {
74   Tensor object_graph_tensor;
75   TF_RETURN_WITH_CONTEXT_IF_ERROR(
76       bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
77       "SavedModel checkpoint does not contain object graph.");
78   if (object_graph_tensor.dtype() != DT_STRING ||
79       object_graph_tensor.dims() != 0 ||
80       object_graph_tensor.NumElements() != 1) {
81     return Status(
82         error::Code::FAILED_PRECONDITION,
83         "SavedModel checkpoint object graph was not the correct type.");
84   }
85 
86   const tstring* object_graph_string = reinterpret_cast<const tstring*>(
87       object_graph_tensor.tensor_data().data());
88   if (!object_graph->ParseFromString(*object_graph_string)) {
89     return Status(
90         error::Code::FAILED_PRECONDITION,
91         "SavedModel checkpoint object graph could not be deserialized.");
92   }
93   return Status::OK();
94 }
95 
96 }  // namespace
97 
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)98 Status SavedModelV2Bundle::Load(const std::string& export_dir,
99                                 SavedModelV2Bundle* const bundle) {
100   metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
101   SavedModel saved_model_proto;
102   TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
103 
104   // Load MetaGraphDef.
105   // In version 2 SavedModels, there is only one MetaGraphDef.
106   if (saved_model_proto.meta_graphs_size() != 1) {
107     return Status(
108         error::Code::INVALID_ARGUMENT,
109         strings::StrCat(
110             "SavedModelV2 should have exactly one MetaGraphDef but actually ",
111             "contains ", saved_model_proto.meta_graphs_size()));
112   }
113   bundle->meta_graph_def_ =
114       std::move(*saved_model_proto.mutable_meta_graphs(0));
115 
116   // Load GraphDebugInfo.
117   TF_RETURN_IF_ERROR(
118       ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
119 
120   const std::string variables_dir =
121       io::JoinPath(export_dir, kSavedModelVariablesDirectory);
122   if (!Env::Default()->FileExists(variables_dir).ok()) {
123     LOG(INFO)
124         << "No checkpoint found, assuming this is a program-only SavedModel";
125   } else {
126     // Load the variables checkpoint reader.
127     const std::string variables_prefix =
128         io::JoinPath(variables_dir, kSavedModelVariablesFilename);
129     bundle->variable_reader_.reset(
130         new BundleReader(Env::Default(), variables_prefix));
131     TF_RETURN_WITH_CONTEXT_IF_ERROR(
132         bundle->variable_reader_->status(),
133         "Unable to load SavedModel variables checkpoint from ",
134         variables_prefix);
135 
136     // Deserialize the object graph proto from the tensor bundle.
137     TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
138         bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
139   }
140   return Status::OK();
141 }
142 
VisitObjectsToRestore(RestoreObjectsCallback callback)143 Status SavedModelV2Bundle::VisitObjectsToRestore(
144     RestoreObjectsCallback callback) {
145   if (saved_object_graph().nodes_size() == 0 ||
146       trackable_object_graph().nodes_size() == 0) {
147     return Status::OK();
148   }
149 
150   // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
151   // and descend to leaves. Note that the TrackableObjectGraph can have cycles
152   // (as can the SavedObjectGraph).
153   // This is detected and cycle edges are skipped.
154   const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
155   const TrackableObjectGraph::TrackableObject* root_trackable_object =
156       &trackable_object_graph().nodes(0);
157   absl::flat_hash_set<int> trackable_node_ids;
158   return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
159                                  std::string(), &trackable_node_ids,
160                                  std::move(callback));
161 }
162 
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)163 Status SavedModelV2Bundle::RecurseObjectsToRestore(
164     const SavedObject* saved_object, int saved_object_node_id,
165     const TrackableObjectGraph::TrackableObject* trackable_object,
166     std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
167     RestoreObjectsCallback callback) {
168   // Callback if any attributes or slot variables.
169   // Note that the root is always excluded from the search (it can never
170   // be a restorable object). This matches some logic on the Python side.
171   if (saved_object_node_id != 0 &&
172       (trackable_object->attributes_size() > 0 ||
173        trackable_object->slot_variables_size() > 0)) {
174     TF_RETURN_WITH_CONTEXT_IF_ERROR(
175         callback(saved_object_node_id, *trackable_object), "Unable to restore ",
176         object_name);
177   }
178 
179   for (const auto& trackable_child_ref : trackable_object->children()) {
180     const auto& local_name = trackable_child_ref.local_name();
181 
182     // Compute the full child name.
183     std::string child_name;
184     if (object_name.empty()) {
185       child_name = local_name;
186     } else {
187       child_name = strings::StrCat(object_name, ".", local_name);
188     }
189 
190     // Descend down the trackable graph.
191     int trackable_child_node_id = trackable_child_ref.node_id();
192     if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
193       // Cycle or duplicate detected - ignore this branch.
194       continue;
195     }
196     if (trackable_child_node_id < 0 ||
197         trackable_child_node_id >= trackable_object_graph().nodes_size()) {
198       return Status(
199           errors::Code::FAILED_PRECONDITION,
200           strings::StrCat("Illegal trackable child node id for ", child_name));
201     }
202     const auto* trackable_child =
203         &trackable_object_graph().nodes(trackable_child_node_id);
204 
205     // Descend down the saved object graph.
206     int saved_child_node_id = -1;
207     const SavedObject* saved_child = nullptr;
208     for (const auto& saved_child_ref : saved_object->children()) {
209       if (saved_child_ref.local_name() == local_name) {
210         // Found.
211         saved_child_node_id = saved_child_ref.node_id();
212         if (saved_child_node_id >= 0 &&
213             saved_child_node_id < saved_object_graph().nodes_size()) {
214           saved_child = &saved_object_graph().nodes(saved_child_node_id);
215         }
216         break;
217       }
218     }
219 
220     if (!saved_child) {
221       return Status(
222           errors::Code::FAILED_PRECONDITION,
223           strings::StrCat("Could not find saved object to restore for ",
224                           child_name));
225     }
226 
227     TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
228         saved_child, saved_child_node_id, trackable_child, child_name,
229         seen_trackable_node_ids, callback));
230   }
231   return Status::OK();
232 }
233 
234 }  // namespace tensorflow
235