• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17 
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/core/lib/io/path.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/strcat.h"
23 #include "tensorflow/core/protobuf/saved_model.pb.h"
24 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)30 Status ReadSavedModelProto(const string& export_dir,
31                            SavedModel* saved_model_proto) {
32   LOG(INFO) << "Reading SavedModel from: " << export_dir;
33 
34   const string saved_model_pb_path =
35       io::JoinPath(export_dir, kSavedModelFilenamePb);
36   if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
37     return ReadBinaryProto(Env::Default(), saved_model_pb_path,
38                            saved_model_proto);
39   }
40   const string saved_model_pbtxt_path =
41       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
42   if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
43     return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
44                          saved_model_proto);
45   }
46   return Status(error::Code::NOT_FOUND,
47                 "Could not find SavedModel .pb or .pbtxt at supplied export "
48                 "directory path: " +
49                     export_dir);
50 }
51 
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)52 Status ReadSavedModelDebugInfoIfPresent(
53     const string& export_dir,
54     std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
55   LOG(INFO) << "Reading SavedModel debug info (if present) from: "
56             << export_dir;
57 
58   const string debug_info_pb_path =
59       io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
60   if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
61     GraphDebugInfo debug_info;
62     TF_RETURN_IF_ERROR(
63         ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
64     *debug_info_proto =
65         absl::make_unique<GraphDebugInfo>(std::move(debug_info));
66   }
67   return Status::OK();
68 }
69 
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)70 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
71                                  TrackableObjectGraph* object_graph) {
72   Tensor object_graph_tensor;
73   TF_RETURN_WITH_CONTEXT_IF_ERROR(
74       bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
75       "SavedModel checkpoint does not contain object graph.");
76   if (object_graph_tensor.dtype() != DT_STRING ||
77       object_graph_tensor.dims() != 0 ||
78       object_graph_tensor.NumElements() != 1) {
79     return Status(
80         error::Code::FAILED_PRECONDITION,
81         "SavedModel checkpoint object graph was not the correct type.");
82   }
83 
84   const tstring* object_graph_string = reinterpret_cast<const tstring*>(
85       object_graph_tensor.tensor_data().data());
86   if (!object_graph->ParseFromString(*object_graph_string)) {
87     return Status(
88         error::Code::FAILED_PRECONDITION,
89         "SavedModel checkpoint object graph could not be deserialized.");
90   }
91   return Status::OK();
92 }
93 
94 }  // namespace
95 
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)96 Status SavedModelV2Bundle::Load(const std::string& export_dir,
97                                 SavedModelV2Bundle* const bundle) {
98   SavedModel saved_model_proto;
99   TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
100 
101   // Load MetaGraphDef.
102   // In version 2 SavedModels, there is only one MetaGraphDef.
103   if (saved_model_proto.meta_graphs_size() != 1) {
104     return Status(
105         error::Code::INVALID_ARGUMENT,
106         strings::StrCat(
107             "SavedModelV2 should have exactly one MetaGraphDef but actually ",
108             "contains ", saved_model_proto.meta_graphs_size()));
109   }
110   bundle->meta_graph_def_ =
111       std::move(*saved_model_proto.mutable_meta_graphs(0));
112 
113   // Load GraphDebugInfo.
114   TF_RETURN_IF_ERROR(
115       ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
116 
117   // Load the variables checkpoint reader.
118   const std::string variables_prefix = io::JoinPath(
119       export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
120   bundle->variable_reader_.reset(
121       new BundleReader(Env::Default(), variables_prefix));
122   TF_RETURN_WITH_CONTEXT_IF_ERROR(
123       bundle->variable_reader_->status(),
124       "Unable to load SavedModel variables checkpoint from ", variables_prefix);
125 
126   // Deserialize the object graph proto from the tensor bundle.
127   TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
128       bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
129   return Status::OK();
130 }
131 
VisitObjectsToRestore(RestoreObjectsCallback callback)132 Status SavedModelV2Bundle::VisitObjectsToRestore(
133     RestoreObjectsCallback callback) {
134   if (saved_object_graph().nodes_size() == 0 ||
135       trackable_object_graph().nodes_size() == 0) {
136     return Status::OK();
137   }
138 
139   // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
140   // and descend to leaves. Note that the TrackableObjectGraph can have cycles
141   // (as can the SavedObjectGraph).
142   // This is detected and cycle edges are skipped.
143   const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
144   const TrackableObjectGraph::TrackableObject* root_trackable_object =
145       &trackable_object_graph().nodes(0);
146   absl::flat_hash_set<int> trackable_node_ids;
147   return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
148                                  std::string(), &trackable_node_ids,
149                                  std::move(callback));
150 }
151 
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)152 Status SavedModelV2Bundle::RecurseObjectsToRestore(
153     const SavedObject* saved_object, int saved_object_node_id,
154     const TrackableObjectGraph::TrackableObject* trackable_object,
155     std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
156     RestoreObjectsCallback callback) {
157   // Callback if any attributes or slot variables.
158   // Note that the root is always excluded from the search (it can never
159   // be a restorable object). This matches some logic on the Python side.
160   if (saved_object_node_id != 0 &&
161       (trackable_object->attributes_size() > 0 ||
162        trackable_object->slot_variables_size() > 0)) {
163     TF_RETURN_WITH_CONTEXT_IF_ERROR(
164         callback(saved_object_node_id, *trackable_object), "Unable to restore ",
165         object_name);
166   }
167 
168   for (const auto& trackable_child_ref : trackable_object->children()) {
169     const auto& local_name = trackable_child_ref.local_name();
170 
171     // Compute the full child name.
172     std::string child_name;
173     if (object_name.empty()) {
174       child_name = local_name;
175     } else {
176       child_name = strings::StrCat(object_name, ".", local_name);
177     }
178 
179     // Descend down the trackable graph.
180     int trackable_child_node_id = trackable_child_ref.node_id();
181     if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
182       // Cycle or duplicate detected - ignore this branch.
183       continue;
184     }
185     if (trackable_child_node_id < 0 ||
186         trackable_child_node_id >= trackable_object_graph().nodes_size()) {
187       return Status(
188           errors::Code::FAILED_PRECONDITION,
189           strings::StrCat("Illegal trackable child node id for ", child_name));
190     }
191     const auto* trackable_child =
192         &trackable_object_graph().nodes(trackable_child_node_id);
193 
194     // Descend down the saved object graph.
195     int saved_child_node_id = -1;
196     const SavedObject* saved_child = nullptr;
197     for (const auto& saved_child_ref : saved_object->children()) {
198       if (saved_child_ref.local_name() == local_name) {
199         // Found.
200         saved_child_node_id = saved_child_ref.node_id();
201         if (saved_child_node_id >= 0 &&
202             saved_child_node_id < saved_object_graph().nodes_size()) {
203           saved_child = &saved_object_graph().nodes(saved_child_node_id);
204         }
205         break;
206       }
207     }
208 
209     if (!saved_child) {
210       return Status(
211           errors::Code::FAILED_PRECONDITION,
212           strings::StrCat("Could not find saved object to restore for ",
213                           child_name));
214     }
215 
216     TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
217         saved_child, saved_child_node_id, trackable_child, child_name,
218         seen_trackable_node_ids, callback));
219   }
220   return Status::OK();
221 }
222 
223 }  // namespace tensorflow
224