• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/loader.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/cc/saved_model/constants.h"
21 #include "tensorflow/cc/saved_model/loader_util.h"
22 #include "tensorflow/cc/saved_model/reader.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/monitoring/counter.h"
29 #include "tensorflow/core/lib/monitoring/sampler.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
35 #include "tensorflow/core/protobuf/meta_graph.pb.h"
36 #include "tensorflow/core/protobuf/saver.pb.h"
37 #include "tensorflow/core/public/session.h"
38 #include "tensorflow/core/public/session_options.h"
39 #include "tensorflow/core/util/tensor_bundle/naming.h"
40 
41 namespace tensorflow {
42 namespace {
43 
// Counts SavedModel load attempts, labeled by model path and by outcome
// ("success"/"fail", see kLoadAttemptSuccess/kLoadAttemptFail below).
auto* load_attempt_count = monitoring::Counter<2>::New(
    "/tensorflow/cc/saved_model/load_attempt_count",
    "The number of times a SavedModel was successfully loaded.", "model_path",
    "status");
// Accumulates total load wall time in microseconds, labeled by model path.
auto* load_latency = monitoring::Counter<1>::New(
    "/tensorflow/cc/saved_model/load_latency",
    "Latency in microseconds for SavedModels that were successfully loaded.",
    "model_path");
// Distribution of per-stage load wall time, labeled by model path and stage
// name (e.g. "restore_graph", "init_graph" — see RestoreSession).
auto* load_latency_by_stage = monitoring::Sampler<2>::New(
    {
        "/tensorflow/cc/saved_model/load_latency_by_stage",  // metric name
        "Distribution of wall time spent (in microseconds) in each stage "
        "(restore graph from disk, run init graph op, etc) when loading the "
        "model",
        "model_path",
        "stage",
    },
    // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes).
    monitoring::Buckets::Exponential(10, 1.8, 33));

// Values recorded in the "status" label of load_attempt_count.
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
66 
// Returns the wall time elapsed since `start_microseconds`, clamped to zero
// so that clock skew can never produce a (huge) underflowed unsigned value.
uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
  const uint64 now_microseconds = EnvTime::NowMicros();
  return now_microseconds < start_microseconds
             ? 0
             : now_microseconds - start_microseconds;
}
73 
74 // Ensure that constant tensors loaded from the saved model have valid shape.
75 // Also ensure that constant nodes have a value assigned to them.
76 // TODO(b/154763635): this is temporary and will be replaced with a better audit
ValidateNode(const NodeDef & node)77 static Status ValidateNode(const NodeDef& node) {
78   const auto node_iterator = node.attr().find("value");
79   if (node_iterator != node.attr().end()) {
80     AttrValue node_value = node_iterator->second;
81     if (node_value.has_tensor()) {
82       const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
83       if (node_shape.num_elements() < 0) {
84         return errors::FailedPrecondition(
85             "Saved model contains node \"", node.name(), "\" (op \"", node.op(),
86             "\") which initializes from a tensor with ",
87             node_shape.num_elements(), " elements");
88       }
89     }
90   } else if (node.op() == "Const") {
91     return errors::FailedPrecondition(
92         "Saved model contains node \"", node.name(),
93         "\" which is a constant tensor but no value has been provided");
94   }
95   return Status::OK();
96 }
97 
ValidateSavedTensors(const GraphDef & graph_def)98 static Status ValidateSavedTensors(const GraphDef& graph_def) {
99   for (const auto& node : graph_def.node()) {
100     TF_RETURN_IF_ERROR(ValidateNode(node));
101   }
102 
103   if (graph_def.has_library()) {
104     const FunctionDefLibrary& library = graph_def.library();
105     for (const auto& function : library.function()) {
106       for (const auto& node : function.node_def()) {
107         TF_RETURN_IF_ERROR(ValidateNode(node));
108       }
109     }
110   }
111 
112   return Status::OK();
113 }
114 
CreateStringTensor(const string & value)115 Tensor CreateStringTensor(const string& value) {
116   Tensor tensor(DT_STRING, TensorShape({}));
117   tensor.scalar<tstring>()() = value;
118   return tensor;
119 }
120 
AddAssetsTensorsToInputs(const StringPiece export_dir,const std::vector<AssetFileDef> & asset_file_defs,std::vector<std::pair<string,Tensor>> * inputs)121 void AddAssetsTensorsToInputs(const StringPiece export_dir,
122                               const std::vector<AssetFileDef>& asset_file_defs,
123                               std::vector<std::pair<string, Tensor>>* inputs) {
124   if (asset_file_defs.empty()) {
125     return;
126   }
127   for (auto& asset_file_def : asset_file_defs) {
128     Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
129         export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
130     inputs->push_back(
131         {asset_file_def.tensor_info().name(), assets_file_path_tensor});
132   }
133 }
134 
135 // Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid
136 // leaving behind non-GC'ed state.
137 //
138 // Detailed motivation behind this approach, from ashankar@:
139 //
140 // Each call to Session::Run() that identifies a new subgraph (based on feeds
141 // and fetches) creates some datastructures that live as long as the session
142 // (the partitioned graph, associated executors etc.).
143 //
144 // A pathological case of this would be if say the initialization op
145 // (main_op/legacy_init_op) involves the use of a large constant. Then we
146 // allocate memory for that large constant that will just stick around till the
147 // session dies. With this Callable mechanism, that memory will be released
148 // right after ReleaseCallable returns.
149 //
150 // However, the resource manager state remains.
RunOnce(const RunOptions & run_options,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata,Session * session)151 Status RunOnce(const RunOptions& run_options,
152                const std::vector<std::pair<string, Tensor>>& inputs,
153                const std::vector<string>& output_tensor_names,
154                const std::vector<string>& target_node_names,
155                std::vector<Tensor>* outputs, RunMetadata* run_metadata,
156                Session* session) {
157   CallableOptions callable_options;
158   std::vector<Tensor> feed_tensors;
159   *callable_options.mutable_run_options() = run_options;
160   for (const auto& input : inputs) {
161     const string& name = input.first;
162     const Tensor& tensor = input.second;
163     callable_options.add_feed(name);
164     feed_tensors.push_back(tensor);
165   }
166   for (const string& output_tensor_name : output_tensor_names) {
167     callable_options.add_fetch(output_tensor_name);
168   }
169   for (const string& target_node_name : target_node_names) {
170     callable_options.add_target(target_node_name);
171   }
172 
173   Session::CallableHandle callable_handle;
174   TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle));
175   const Status run_status = session->RunCallable(callable_handle, feed_tensors,
176                                                  outputs, run_metadata);
177   // Be sure to call ReleaseCallable() regardless of the outcome of
178   // RunCallable().
179   session->ReleaseCallable(callable_handle).IgnoreError();
180   return run_status;
181 }
182 
183 // RunInitOp will return OK if the initialization op was run successfully.
184 // An empty init_op_name indicates that there are no init ops to run.
RunInitOp(const RunOptions & run_options,const string & export_dir,const MetaGraphDef & meta_graph_def,const std::vector<AssetFileDef> & asset_file_defs,Session * session,const string & init_op_name)185 Status RunInitOp(const RunOptions& run_options, const string& export_dir,
186                  const MetaGraphDef& meta_graph_def,
187                  const std::vector<AssetFileDef>& asset_file_defs,
188                  Session* session, const string& init_op_name) {
189   if (!init_op_name.empty()) {
190     LOG(INFO) << "Running initialization op on SavedModel bundle at path: "
191               << export_dir;
192     std::vector<std::pair<string, Tensor>> inputs;
193     AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
194     RunMetadata run_metadata;
195     return RunOnce(run_options, inputs, {}, {init_op_name},
196                    nullptr /* outputs */, &run_metadata, session);
197   }
198   return Status::OK();
199 }
200 
RunRestore(const RunOptions & run_options,const string & export_dir,const StringPiece restore_op_name,const StringPiece variable_filename_const_op_name,const std::vector<AssetFileDef> & asset_file_defs,Session * session)201 Status RunRestore(const RunOptions& run_options, const string& export_dir,
202                   const StringPiece restore_op_name,
203                   const StringPiece variable_filename_const_op_name,
204                   const std::vector<AssetFileDef>& asset_file_defs,
205                   Session* session) {
206   LOG(INFO) << "Restoring SavedModel bundle.";
207   // Find path to variables to be restored in export directory.
208   const string variables_directory =
209       io::JoinPath(export_dir, kSavedModelVariablesDirectory);
210   // Check for saver checkpoints in v2 format. Models exported in the checkpoint
211   // v2 format will have a variables.index file. The corresponding
212   // variables are stored in the variables.data-?????-of-????? files.
213   const string variables_index_path = io::JoinPath(
214       variables_directory, MetaFilename(kSavedModelVariablesFilename));
215   if (!Env::Default()->FileExists(variables_index_path).ok()) {
216     LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
217                  "were restored. File does not exist: "
218               << variables_index_path;
219     return Status::OK();
220   }
221   const string variables_path =
222       io::JoinPath(variables_directory, kSavedModelVariablesFilename);
223 
224   // Add variables to the graph.
225   Tensor variables_path_tensor(DT_STRING, TensorShape({}));
226   variables_path_tensor.scalar<tstring>()() = variables_path;
227 
228   std::vector<std::pair<string, Tensor>> inputs = {
229       {string(variable_filename_const_op_name), variables_path_tensor}};
230 
231   AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
232 
233   RunMetadata run_metadata;
234   return RunOnce(run_options, inputs, {}, {string(restore_op_name)},
235                  nullptr /* outputs */, &run_metadata, session);
236 }
237 
238 }  // namespace
239 
~SavedModelBundleInterface()240 SavedModelBundleInterface::~SavedModelBundleInterface() {}
241 
// Creates a new session from `session_options`, validates the constant
// tensors of `meta_graph`'s graph, and creates that graph in the session.
// On success `*session` owns the new session.
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
                                const MetaGraphDef& meta_graph,
                                std::unique_ptr<Session>* session) {
  Session* session_p = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
  // Hand ownership to the unique_ptr immediately so an early error return
  // below cannot leak the raw session pointer.
  session->reset(session_p);
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
  return (*session)->Create(meta_graph.graph_def());
}
251 
// Core load pipeline for LoadSavedModel (no metrics): read the MetaGraphDef
// matching `tags` from `export_dir`, read debug info, create a session with
// the graph, then restore variables and run the init op.
Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const RunOptions& run_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundle* const bundle) {
  TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
                                                    &bundle->meta_graph_def));
  // Debug info is optional (the helper only reads it "if present").
  TF_RETURN_IF_ERROR(
      ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
  TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
      session_options, bundle->meta_graph_def, &bundle->session));
  TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
                                    export_dir, &bundle->session));
  return Status::OK();
}
267 
LoadSavedModel(const SessionOptions & session_options,const RunOptions & run_options,const string & export_dir,const std::unordered_set<string> & tags,SavedModelBundle * const bundle)268 Status LoadSavedModel(const SessionOptions& session_options,
269                       const RunOptions& run_options, const string& export_dir,
270                       const std::unordered_set<string>& tags,
271                       SavedModelBundle* const bundle) {
272   // TODO(robson): Add tests for the counters.
273   const uint64 start_microseconds = Env::Default()->NowMicros();
274   const Status status = LoadSavedModelInternal(session_options, run_options,
275                                                export_dir, tags, bundle);
276   auto log_and_count = [&](const string& status_str) {
277     LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
278               << " }; Status: " << status_str << ": " << status << ". Took "
279               << GetLatencyMicroseconds(start_microseconds) << " microseconds.";
280     load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
281   };
282   if (status.ok()) {
283     log_and_count(kLoadAttemptSuccess);
284   } else {
285     log_and_count(kLoadAttemptFail);
286   }
287   load_latency->GetCell(export_dir)
288       ->IncrementBy(GetLatencyMicroseconds(start_microseconds));
289   return status;
290 }
291 
292 namespace {
293 // Session wrapper that prevents calls to Session::Create(), Session::Extend(),
294 // and the deprecated partial-run methods.
295 //
296 // Limiting the available methods on a returned Session gives us the option
297 // to replace the Session with a cut-down implementation, without breaking any
298 // users.
299 class LiteSessionWrapper : public Session {
300  public:
LiteSessionWrapper(std::unique_ptr<Session> wrapped)301   explicit LiteSessionWrapper(std::unique_ptr<Session> wrapped)
302       : wrapped_(std::move(wrapped)) {}
303 
Create(const GraphDef & graph)304   Status Create(const GraphDef& graph) override {
305     return errors::Unimplemented("Session::Create()");
306   }
Create(GraphDef && graph)307   Status Create(GraphDef&& graph) override {
308     return errors::Unimplemented("Session::Create()");
309   }
310 
Extend(const GraphDef & graph)311   Status Extend(const GraphDef& graph) override {
312     return errors::Unimplemented("Session::Extend()");
313   }
Extend(GraphDef && graph)314   Status Extend(GraphDef&& graph) override {
315     return errors::Unimplemented("Session::Extend()");
316   }
317 
Run(const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs)318   Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
319              const std::vector<string>& output_tensor_names,
320              const std::vector<string>& target_node_names,
321              std::vector<Tensor>* outputs) override {
322     return wrapped_->Run(inputs, output_tensor_names, target_node_names,
323                          outputs);
324   }
325 
Create(const RunOptions & run_options,const GraphDef & graph)326   Status Create(const RunOptions& run_options, const GraphDef& graph) override {
327     return errors::Unimplemented("Session::Create()");
328   }
Extend(const RunOptions & run_options,const GraphDef & graph)329   Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
330     return errors::Unimplemented("Session::Extend()");
331   }
Create(const RunOptions & run_options,GraphDef && graph)332   Status Create(const RunOptions& run_options, GraphDef&& graph) override {
333     return errors::Unimplemented("Session::Create()");
334   }
Extend(const RunOptions & run_options,GraphDef && graph)335   Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
336     return errors::Unimplemented("Session::Extend()");
337   }
Close(const RunOptions & run_options)338   Status Close(const RunOptions& run_options) override {
339     return wrapped_->Close(run_options);
340   }
341 
Run(const RunOptions & run_options,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata)342   Status Run(const RunOptions& run_options,
343              const std::vector<std::pair<string, Tensor>>& inputs,
344              const std::vector<string>& output_tensor_names,
345              const std::vector<string>& target_node_names,
346              std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
347     return wrapped_->Run(run_options, inputs, output_tensor_names,
348                          target_node_names, outputs, run_metadata);
349   }
350 
PRunSetup(const std::vector<string> & input_names,const std::vector<string> & output_names,const std::vector<string> & target_nodes,string * handle)351   Status PRunSetup(const std::vector<string>& input_names,
352                    const std::vector<string>& output_names,
353                    const std::vector<string>& target_nodes,
354                    string* handle) override {
355     return errors::Unimplemented("Session::PRunSetup()");
356   }
357 
PRun(const string & handle,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_names,std::vector<Tensor> * outputs)358   Status PRun(const string& handle,
359               const std::vector<std::pair<string, Tensor>>& inputs,
360               const std::vector<string>& output_names,
361               std::vector<Tensor>* outputs) override {
362     return errors::Unimplemented("Session::PRun()");
363   }
364 
ListDevices(std::vector<DeviceAttributes> * response)365   Status ListDevices(std::vector<DeviceAttributes>* response) override {
366     return wrapped_->ListDevices(response);
367   }
368 
Close()369   Status Close() override { return wrapped_->Close(); }
370 
MakeCallable(const CallableOptions & callable_options,CallableHandle * out_handle)371   Status MakeCallable(const CallableOptions& callable_options,
372                       CallableHandle* out_handle) override {
373     return wrapped_->MakeCallable(callable_options, out_handle);
374   }
375 
RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata)376   Status RunCallable(CallableHandle handle,
377                      const std::vector<Tensor>& feed_tensors,
378                      std::vector<Tensor>* fetch_tensors,
379                      RunMetadata* run_metadata) override {
380     return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
381                                  run_metadata);
382   }
383 
RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata,const thread::ThreadPoolOptions & threadpool_options)384   Status RunCallable(
385       CallableHandle handle, const std::vector<Tensor>& feed_tensors,
386       std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
387       const thread::ThreadPoolOptions& threadpool_options) override {
388     return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
389                                  run_metadata, threadpool_options);
390   }
391 
ReleaseCallable(CallableHandle handle)392   Status ReleaseCallable(CallableHandle handle) override {
393     return wrapped_->ReleaseCallable(handle);
394   }
395 
396  private:
397   const std::unique_ptr<Session> wrapped_;
398 };
399 }  // namespace
400 
// Restores variables from the checkpoint (if any) and runs the init op (if
// any) in `*session`, recording per-stage wall time in the
// load_latency_by_stage metric under the "restore_graph" and "init_graph"
// stage labels.
Status RestoreSession(const RunOptions& run_options,
                      const MetaGraphDef& meta_graph, const string& export_dir,
                      std::unique_ptr<Session>* session) {
  const uint64 read_start_microseconds = Env::Default()->NowMicros();
  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
  // Only run the restore op when the model exported a saver_def; models
  // without one have nothing to restore.
  if (meta_graph.has_saver_def()) {
    TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
                                  meta_graph.saver_def().restore_op_name(),
                                  meta_graph.saver_def().filename_tensor_name(),
                                  asset_file_defs, session->get()));
  }
  // Record walltime spent in restoring graph from disk, but postpone metric
  // increments until graph init finishes.
  const uint64 restore_graph_walltime =
      GetLatencyMicroseconds(read_start_microseconds);

  const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp(export_dir, meta_graph, &init_op_name));
  // RunInitOp is a no-op when init_op_name is empty.
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
                               asset_file_defs, session->get(), init_op_name));
  load_latency_by_stage->GetCell(export_dir, "restore_graph")
      ->Add(restore_graph_walltime);
  // Record wall time spent in init op.
  load_latency_by_stage->GetCell(export_dir, "init_graph")
      ->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
  return Status::OK();
}
431 
LoadSavedModel(const SessionOptions & session_options,const RunOptions & run_options,const string & export_dir,const std::unordered_set<string> & tags,SavedModelBundleLite * const bundle)432 Status LoadSavedModel(const SessionOptions& session_options,
433                       const RunOptions& run_options, const string& export_dir,
434                       const std::unordered_set<string>& tags,
435                       SavedModelBundleLite* const bundle) {
436   SavedModelBundle legacy_bundle;
437   SessionOptions rewritten_options(session_options);
438   // We disallow calls to Session::Extend() on the returned session, so we can
439   // reduce memory consumption by not storing the original GraphDef.
440   rewritten_options.config.mutable_experimental()
441       ->set_optimize_for_static_graph(true);
442   // Disallowing the `RunOptions.output_partition_graphs` option (typically used
443   // in debugging and tests) allows us to reduce memory consumption further by
444   // not storing the rewritten subgraph for each signature.
445   rewritten_options.config.mutable_experimental()
446       ->set_disable_output_partition_graphs(true);
447   // TODO(mrry): Consider specializing the session creation to reduce peak
448   // RAM consumption by using `Session::Create(GraphDef&&)`.
449   TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
450                                     tags, &legacy_bundle));
451   *bundle = SavedModelBundleLite(
452       absl::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
453       std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
454   return Status::OK();
455 }
456 
MaybeSavedModelDirectory(const string & export_dir)457 bool MaybeSavedModelDirectory(const string& export_dir) {
458   const string saved_model_pb_path =
459       io::JoinPath(export_dir, kSavedModelFilenamePb);
460   const string saved_model_pbtxt_path =
461       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
462   return Env::Default()->FileExists(saved_model_pb_path).ok() ||
463          Env::Default()->FileExists(saved_model_pbtxt_path).ok();
464 }
465 
466 }  // namespace tensorflow
467