• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/metrics.h"
23 #include "tensorflow/cc/saved_model/reader.h"
24 #include "tensorflow/cc/saved_model/util.h"
25 #include "tensorflow/core/lib/io/path.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/strcat.h"
29 #include "tensorflow/core/protobuf/saved_model.pb.h"
30 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
31 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
32 
33 namespace tensorflow {
34 namespace {
35 
// `tensorflow::SavedModelV2Bundle::Load` API label: the name under which each
// `Load` call is counted by `metrics::SavedModelReadApi`.
constexpr char kCCLoadBundleV2Label[] = "cc_load_bundle_v2";
38 
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)39 Status ReadSavedModelProto(const string& export_dir,
40                            SavedModel* saved_model_proto) {
41   LOG(INFO) << "Reading SavedModel from: " << export_dir;
42 
43   const string saved_model_pb_path =
44       io::JoinPath(export_dir, kSavedModelFilenamePb);
45 
46   if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
47     Status result =
48         ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
49     if (result.ok()) {
50       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
51           .IncrementBy(1);
52     }
53     return result;
54   }
55   const string saved_model_pbtxt_path =
56       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
57   if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
58     Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
59                                   saved_model_proto);
60     if (result.ok()) {
61       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
62           .IncrementBy(1);
63     }
64     return result;
65   }
66 
67   return Status(error::Code::NOT_FOUND,
68                 "Could not find SavedModel .pb or .pbtxt at supplied export "
69                 "directory path: " +
70                     export_dir);
71 }
72 
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)73 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
74                                  TrackableObjectGraph* object_graph) {
75   Tensor object_graph_tensor;
76   TF_RETURN_WITH_CONTEXT_IF_ERROR(
77       bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
78       "SavedModel checkpoint does not contain object graph.");
79   if (object_graph_tensor.dtype() != DT_STRING ||
80       object_graph_tensor.dims() != 0 ||
81       object_graph_tensor.NumElements() != 1) {
82     return Status(
83         error::Code::FAILED_PRECONDITION,
84         "SavedModel checkpoint object graph was not the correct type.");
85   }
86 
87   const tstring* object_graph_string = reinterpret_cast<const tstring*>(
88       object_graph_tensor.tensor_data().data());
89   if (!object_graph->ParseFromString(*object_graph_string)) {
90     return Status(
91         error::Code::FAILED_PRECONDITION,
92         "SavedModel checkpoint object graph could not be deserialized.");
93   }
94   return OkStatus();
95 }
96 
97 }  // namespace
98 
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)99 Status SavedModelV2Bundle::Load(const std::string& export_dir,
100                                 SavedModelV2Bundle* const bundle) {
101   metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
102   SavedModel saved_model_proto;
103   TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
104 
105   // Load MetaGraphDef.
106   // In version 2 SavedModels, there is only one MetaGraphDef.
107   if (saved_model_proto.meta_graphs_size() != 1) {
108     return Status(
109         error::Code::INVALID_ARGUMENT,
110         strings::StrCat(
111             "SavedModelV2 should have exactly one MetaGraphDef but actually ",
112             "contains ", saved_model_proto.meta_graphs_size()));
113   }
114   bundle->meta_graph_def_ =
115       std::move(*saved_model_proto.mutable_meta_graphs(0));
116 
117   // Correct the endiness of Tensor content on big-endian system
118   if (!port::kLittleEndian) {
119     TF_RETURN_IF_ERROR(ByteSwapTensorContent(&(bundle->meta_graph_def_)));
120   }
121 
122   // Load GraphDebugInfo.
123   TF_RETURN_IF_ERROR(
124       ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
125 
126   const std::string variables_dir =
127       io::JoinPath(export_dir, kSavedModelVariablesDirectory);
128   if (!Env::Default()->FileExists(variables_dir).ok()) {
129     LOG(INFO)
130         << "No checkpoint found, assuming this is a program-only SavedModel";
131   } else {
132     // Load the variables checkpoint reader.
133     const std::string variables_prefix =
134         io::JoinPath(variables_dir, kSavedModelVariablesFilename);
135     bundle->variable_reader_.reset(
136         new BundleReader(Env::Default(), variables_prefix));
137     TF_RETURN_WITH_CONTEXT_IF_ERROR(
138         bundle->variable_reader_->status(),
139         "Unable to load SavedModel variables checkpoint from ",
140         variables_prefix);
141 
142     // Deserialize the object graph proto from the tensor bundle.
143     TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
144         bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
145   }
146   return OkStatus();
147 }
148 
VisitObjectsToRestore(RestoreObjectsCallback callback)149 Status SavedModelV2Bundle::VisitObjectsToRestore(
150     RestoreObjectsCallback callback) {
151   if (saved_object_graph().nodes_size() == 0 ||
152       trackable_object_graph().nodes_size() == 0) {
153     return OkStatus();
154   }
155 
156   // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
157   // and descend to leaves. Note that the TrackableObjectGraph can have cycles
158   // (as can the SavedObjectGraph).
159   // This is detected and cycle edges are skipped.
160   const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
161   const TrackableObjectGraph::TrackableObject* root_trackable_object =
162       &trackable_object_graph().nodes(0);
163   absl::flat_hash_set<int> trackable_node_ids;
164   return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
165                                  std::string(), &trackable_node_ids,
166                                  std::move(callback));
167 }
168 
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)169 Status SavedModelV2Bundle::RecurseObjectsToRestore(
170     const SavedObject* saved_object, int saved_object_node_id,
171     const TrackableObjectGraph::TrackableObject* trackable_object,
172     std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
173     RestoreObjectsCallback callback) {
174   // Callback if any attributes or slot variables.
175   // Note that the root is always excluded from the search (it can never
176   // be a restorable object). This matches some logic on the Python side.
177   if (saved_object_node_id != 0 &&
178       (trackable_object->attributes_size() > 0 ||
179        trackable_object->slot_variables_size() > 0)) {
180     TF_RETURN_WITH_CONTEXT_IF_ERROR(
181         callback(saved_object_node_id, *trackable_object), "Unable to restore ",
182         object_name);
183   }
184 
185   for (const auto& trackable_child_ref : trackable_object->children()) {
186     const auto& local_name = trackable_child_ref.local_name();
187 
188     // Compute the full child name.
189     std::string child_name;
190     if (object_name.empty()) {
191       child_name = local_name;
192     } else {
193       child_name = strings::StrCat(object_name, ".", local_name);
194     }
195 
196     // Descend down the trackable graph.
197     int trackable_child_node_id = trackable_child_ref.node_id();
198     if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
199       // Cycle or duplicate detected - ignore this branch.
200       continue;
201     }
202     if (trackable_child_node_id < 0 ||
203         trackable_child_node_id >= trackable_object_graph().nodes_size()) {
204       return errors::FailedPrecondition(
205           strings::StrCat("Illegal trackable child node id for ", child_name));
206     }
207     const auto* trackable_child =
208         &trackable_object_graph().nodes(trackable_child_node_id);
209 
210     // Descend down the saved object graph.
211     int saved_child_node_id = -1;
212     const SavedObject* saved_child = nullptr;
213     for (const auto& saved_child_ref : saved_object->children()) {
214       if (saved_child_ref.local_name() == local_name) {
215         // Found.
216         saved_child_node_id = saved_child_ref.node_id();
217         if (saved_child_node_id >= 0 &&
218             saved_child_node_id < saved_object_graph().nodes_size()) {
219           saved_child = &saved_object_graph().nodes(saved_child_node_id);
220         }
221         break;
222       }
223     }
224 
225     if (!saved_child) {
226       return Status(
227           errors::Code::FAILED_PRECONDITION,
228           strings::StrCat("Could not find saved object to restore for ",
229                           child_name));
230     }
231 
232     TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
233         saved_child, saved_child_node_id, trackable_child, child_name,
234         seen_trackable_node_ids, callback));
235   }
236   return OkStatus();
237 }
238 
239 }  // namespace tensorflow
240