/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"

namespace tensorflow {
namespace {

auto* load_attempt_count = monitoring::Counter<2>::New(
    "/tensorflow/cc/saved_model/load_attempt_count",
    "The number of SavedModel load attempts, partitioned by final status "
    "(success or failure).",
    "model_path", "status");
auto* load_latency = monitoring::Counter<1>::New(
    "/tensorflow/cc/saved_model/load_latency",
    "Latency in microseconds for SavedModels that were successfully loaded.",
    "model_path");
auto* load_latency_by_stage = monitoring::Sampler<2>::New(
    {
        "/tensorflow/cc/saved_model/load_latency_by_stage",  // metric name
        "Distribution of wall time spent (in microseconds) in each stage "
        "(restore graph from disk, run init graph op, etc) when loading the "
        "model",
        "model_path",
        "stage",
    },
    // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes).
    monitoring::Buckets::Exponential(10, 1.8, 33));
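// As a rough check on the "~20 minutes" figure above: the largest finite
// bucket edge is on the order of 10 * 1.8^32 microseconds, i.e. about
// 1.5e9 us, or roughly 25 minutes.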

constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";

uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
  const uint64 end_microseconds = EnvTime::NowMicros();
  // Avoid underflow if the clock skews backwards between the two reads.
  if (end_microseconds < start_microseconds) return 0;
  return end_microseconds - start_microseconds;
}

// Ensure that constant tensors loaded from the saved model have valid shape.
// Also ensure that constant nodes have a value assigned to them.
// TODO(b/154763635): this is temporary and will be replaced with a better audit
static Status ValidateNode(const NodeDef& node) {
  const auto node_iterator = node.attr().find("value");
  if (node_iterator != node.attr().end()) {
    AttrValue node_value = node_iterator->second;
    if (node_value.has_tensor()) {
      const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
      if (node_shape.num_elements() < 0) {
        return errors::FailedPrecondition(
            "Saved model contains node \"", node.name(), "\" (op \"", node.op(),
            "\") which initializes from a tensor with ",
            node_shape.num_elements(), " elements");
      }
    }
  } else if (node.op() == "Const") {
    return errors::FailedPrecondition(
        "Saved model contains node \"", node.name(),
        "\" which is a constant tensor but no value has been provided");
  }
  return Status::OK();
}
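
// For example, a Const node with no "value" attr, such as (textproto, for
// illustration):
//   node { name: "c" op: "Const" }
// fails the second check, while a "value" tensor whose shape proto is not
// fully defined (or whose element count overflows) yields a negative
// num_elements() and fails the first.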

static Status ValidateSavedTensors(const GraphDef& graph_def) {
  for (const auto& node : graph_def.node()) {
    TF_RETURN_IF_ERROR(ValidateNode(node));
  }

  if (graph_def.has_library()) {
    const FunctionDefLibrary& library = graph_def.library();
    for (const auto& function : library.function()) {
      for (const auto& node : function.node_def()) {
        TF_RETURN_IF_ERROR(ValidateNode(node));
      }
    }
  }

  return Status::OK();
}

Tensor CreateStringTensor(const string& value) {
  Tensor tensor(DT_STRING, TensorShape({}));
  tensor.scalar<tstring>()() = value;
  return tensor;
}

void AddAssetsTensorsToInputs(const StringPiece export_dir,
                              const std::vector<AssetFileDef>& asset_file_defs,
                              std::vector<std::pair<string, Tensor>>* inputs) {
  if (asset_file_defs.empty()) {
    return;
  }
  for (auto& asset_file_def : asset_file_defs) {
    Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
        export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
    inputs->push_back(
        {asset_file_def.tensor_info().name(), assets_file_path_tensor});
  }
}

// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid
// leaving behind non-GC'ed state.
//
// Detailed motivation behind this approach, from ashankar@:
//
// Each call to Session::Run() that identifies a new subgraph (based on feeds
// and fetches) creates some data structures that live as long as the session
// (the partitioned graph, associated executors etc.).
//
// A pathological case of this would be if, say, the initialization op
// (main_op/legacy_init_op) involves the use of a large constant. Then we
// allocate memory for that large constant that will just stick around till the
// session dies. With this Callable mechanism, that memory will be released
// right after ReleaseCallable returns.
//
// However, the resource manager state remains.
Status RunOnce(const RunOptions& run_options,
               const std::vector<std::pair<string, Tensor>>& inputs,
               const std::vector<string>& output_tensor_names,
               const std::vector<string>& target_node_names,
               std::vector<Tensor>* outputs, RunMetadata* run_metadata,
               Session* session) {
  CallableOptions callable_options;
  std::vector<Tensor> feed_tensors;
  *callable_options.mutable_run_options() = run_options;
  for (const auto& input : inputs) {
    const string& name = input.first;
    const Tensor& tensor = input.second;
    callable_options.add_feed(name);
    feed_tensors.push_back(tensor);
  }
  for (const string& output_tensor_name : output_tensor_names) {
    callable_options.add_fetch(output_tensor_name);
  }
  for (const string& target_node_name : target_node_names) {
    callable_options.add_target(target_node_name);
  }

  Session::CallableHandle callable_handle;
  TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle));
  const Status run_status = session->RunCallable(callable_handle, feed_tensors,
                                                 outputs, run_metadata);
  // Be sure to call ReleaseCallable() regardless of the outcome of
  // RunCallable().
  session->ReleaseCallable(callable_handle).IgnoreError();
  return run_status;
}
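
// For illustration: a call such as (tensor and op names here are
// hypothetical)
//
//   RunMetadata run_metadata;
//   Status s = RunOnce(run_options, {{"filename:0", path_tensor}},
//                      /*output_tensor_names=*/{}, {"save/restore_all"},
//                      /*outputs=*/nullptr, &run_metadata, session);
//
// behaves like the corresponding Session::Run() call, except that the
// executor state created for this feed/fetch signature is torn down as soon
// as ReleaseCallable() returns, rather than living as long as the session.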

// RunInitOp will return OK if the initialization op was run successfully.
// An empty init_op_name indicates that there are no init ops to run.
Status RunInitOp(const RunOptions& run_options, const string& export_dir,
                 const MetaGraphDef& meta_graph_def,
                 const std::vector<AssetFileDef>& asset_file_defs,
                 Session* session, const string& init_op_name) {
  if (!init_op_name.empty()) {
    LOG(INFO) << "Running initialization op on SavedModel bundle at path: "
              << export_dir;
    std::vector<std::pair<string, Tensor>> inputs;
    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
    RunMetadata run_metadata;
    return RunOnce(run_options, inputs, {}, {init_op_name},
                   nullptr /* outputs */, &run_metadata, session);
  }
  return Status::OK();
}

Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,
                  const std::vector<AssetFileDef>& asset_file_defs,
                  Session* session) {
  LOG(INFO) << "Restoring SavedModel bundle.";
  // Find the path to the variables to be restored in the export directory.
  const string variables_directory =
      io::JoinPath(export_dir, kSavedModelVariablesDirectory);
  // Check for saver checkpoints in v2 format. Models exported in the
  // checkpoint v2 format will have a variables.index file. The corresponding
  // variables are stored in the variables.data-?????-of-????? files.
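  // For example, a typical export directory looks like (shard count
  // illustrative):
  //   <export_dir>/variables/variables.index
  //   <export_dir>/variables/variables.data-00000-of-00001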
  const string variables_index_path = io::JoinPath(
      variables_directory, MetaFilename(kSavedModelVariablesFilename));
  if (!Env::Default()->FileExists(variables_index_path).ok()) {
    LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
                 "were restored. File does not exist: "
              << variables_index_path;
    return Status::OK();
  }
  const string variables_path =
      io::JoinPath(variables_directory, kSavedModelVariablesFilename);

  // Add variables to the graph.
  Tensor variables_path_tensor(DT_STRING, TensorShape({}));
  variables_path_tensor.scalar<tstring>()() = variables_path;

  std::vector<std::pair<string, Tensor>> inputs = {
      {string(variable_filename_const_op_name), variables_path_tensor}};

  AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);

  RunMetadata run_metadata;
  return RunOnce(run_options, inputs, {}, {string(restore_op_name)},
                 nullptr /* outputs */, &run_metadata, session);
}

}  // namespace

SavedModelBundleInterface::~SavedModelBundleInterface() {}

Status LoadMetagraphIntoSession(const SessionOptions& session_options,
                                const MetaGraphDef& meta_graph,
                                std::unique_ptr<Session>* session) {
  Session* session_p = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
  session->reset(session_p);
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
  return (*session)->Create(meta_graph.graph_def());
}

Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const RunOptions& run_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundle* const bundle) {
  TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
                                                    &bundle->meta_graph_def));
  TF_RETURN_IF_ERROR(
      ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
  TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
      session_options, bundle->meta_graph_def, &bundle->session));
  TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
                                    export_dir, &bundle->session));
  return Status::OK();
}

Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundle* const bundle) {
  // TODO(robson): Add tests for the counters.
  const uint64 start_microseconds = Env::Default()->NowMicros();
  const Status status = LoadSavedModelInternal(session_options, run_options,
                                               export_dir, tags, bundle);
  auto log_and_count = [&](const string& status_str) {
    LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
              << " }; Status: " << status_str << ": " << status << ". Took "
              << GetLatencyMicroseconds(start_microseconds) << " microseconds.";
    load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
  };
  if (status.ok()) {
    log_and_count(kLoadAttemptSuccess);
  } else {
    log_and_count(kLoadAttemptFail);
  }
  load_latency->GetCell(export_dir)
      ->IncrementBy(GetLatencyMicroseconds(start_microseconds));
  return status;
}
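
// A minimal usage sketch (the export path is illustrative; "serve" is the
// conventional serving tag, kSavedModelTagServe):
//
//   SavedModelBundle bundle;
//   TF_CHECK_OK(LoadSavedModel(SessionOptions(), RunOptions(),
//                              "/tmp/my_saved_model", {"serve"}, &bundle));
//   // bundle.session now holds the loaded graph with variables restored;
//   // bundle.meta_graph_def.signature_def() describes its signatures.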

namespace {
// Session wrapper that prevents calls to Session::Create(), Session::Extend(),
// and the deprecated partial-run methods.
//
// Limiting the available methods on a returned Session gives us the option
// to replace the Session with a cut-down implementation, without breaking any
// users.
class LiteSessionWrapper : public Session {
 public:
  explicit LiteSessionWrapper(std::unique_ptr<Session> wrapped)
      : wrapped_(std::move(wrapped)) {}

  Status Create(const GraphDef& graph) override {
    return errors::Unimplemented("Session::Create()");
  }
  Status Create(GraphDef&& graph) override {
    return errors::Unimplemented("Session::Create()");
  }

  Status Extend(const GraphDef& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }
  Status Extend(GraphDef&& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }

  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs) override {
    return wrapped_->Run(inputs, output_tensor_names, target_node_names,
                         outputs);
  }

  Status Create(const RunOptions& run_options, const GraphDef& graph) override {
    return errors::Unimplemented("Session::Create()");
  }
  Status Extend(const RunOptions& run_options, const GraphDef& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }
  Status Create(const RunOptions& run_options, GraphDef&& graph) override {
    return errors::Unimplemented("Session::Create()");
  }
  Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }
  Status Close(const RunOptions& run_options) override {
    return wrapped_->Close(run_options);
  }

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
    return wrapped_->Run(run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
  }

  Status PRunSetup(const std::vector<string>& input_names,
                   const std::vector<string>& output_names,
                   const std::vector<string>& target_nodes,
                   string* handle) override {
    return errors::Unimplemented("Session::PRunSetup()");
  }

  Status PRun(const string& handle,
              const std::vector<std::pair<string, Tensor>>& inputs,
              const std::vector<string>& output_names,
              std::vector<Tensor>* outputs) override {
    return errors::Unimplemented("Session::PRun()");
  }

  Status ListDevices(std::vector<DeviceAttributes>* response) override {
    return wrapped_->ListDevices(response);
  }

  Status Close() override { return wrapped_->Close(); }

  Status MakeCallable(const CallableOptions& callable_options,
                      CallableHandle* out_handle) override {
    return wrapped_->MakeCallable(callable_options, out_handle);
  }

  Status RunCallable(CallableHandle handle,
                     const std::vector<Tensor>& feed_tensors,
                     std::vector<Tensor>* fetch_tensors,
                     RunMetadata* run_metadata) override {
    return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
                                 run_metadata);
  }

  Status RunCallable(
      CallableHandle handle, const std::vector<Tensor>& feed_tensors,
      std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
      const thread::ThreadPoolOptions& threadpool_options) override {
    return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
                                 run_metadata, threadpool_options);
  }

  Status ReleaseCallable(CallableHandle handle) override {
    return wrapped_->ReleaseCallable(handle);
  }

 private:
  const std::unique_ptr<Session> wrapped_;
};
}  // namespace

Status RestoreSession(const RunOptions& run_options,
                      const MetaGraphDef& meta_graph, const string& export_dir,
                      std::unique_ptr<Session>* session) {
  const uint64 read_start_microseconds = Env::Default()->NowMicros();
  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
  if (meta_graph.has_saver_def()) {
    TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
                                  meta_graph.saver_def().restore_op_name(),
                                  meta_graph.saver_def().filename_tensor_name(),
                                  asset_file_defs, session->get()));
  }
  // Record wall time spent restoring the graph from disk, but postpone the
  // metric increments until graph init finishes.
  const uint64 restore_graph_walltime =
      GetLatencyMicroseconds(read_start_microseconds);

  const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp(export_dir, meta_graph, &init_op_name));
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
                               asset_file_defs, session->get(), init_op_name));
  load_latency_by_stage->GetCell(export_dir, "restore_graph")
      ->Add(restore_graph_walltime);
  // Record wall time spent in the init op.
  load_latency_by_stage->GetCell(export_dir, "init_graph")
      ->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
  return Status::OK();
}
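
// RestoreSession() pairs with LoadMetagraphIntoSession(); a sketch of the
// sequence (mirroring LoadSavedModelInternal() above):
//
//   std::unique_ptr<Session> session;
//   TF_RETURN_IF_ERROR(
//       LoadMetagraphIntoSession(session_options, meta_graph, &session));
//   TF_RETURN_IF_ERROR(
//       RestoreSession(run_options, meta_graph, export_dir, &session));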

Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundleLite* const bundle) {
  SavedModelBundle legacy_bundle;
  SessionOptions rewritten_options(session_options);
  // We disallow calls to Session::Extend() on the returned session, so we can
  // reduce memory consumption by not storing the original GraphDef.
  rewritten_options.config.mutable_experimental()
      ->set_optimize_for_static_graph(true);
  // Disallowing the `RunOptions.output_partition_graphs` option (typically
  // used in debugging and tests) allows us to reduce memory consumption
  // further by not storing the rewritten subgraph for each signature.
  rewritten_options.config.mutable_experimental()
      ->set_disable_output_partition_graphs(true);
  // TODO(mrry): Consider specializing the session creation to reduce peak
  // RAM consumption by using `Session::Create(GraphDef&&)`.
  TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
                                    tags, &legacy_bundle));
  *bundle = SavedModelBundleLite(
      absl::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
      std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
  return Status::OK();
}
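
// A sketch of loading the lite bundle (path illustrative; accessors as
// declared for SavedModelBundleLite in loader.h):
//
//   SavedModelBundleLite bundle;
//   TF_CHECK_OK(LoadSavedModel(SessionOptions(), RunOptions(),
//                              "/tmp/my_saved_model", {"serve"}, &bundle));
//   Session* session = bundle.GetSession();
//   const auto& signatures = bundle.GetSignatures();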

bool MaybeSavedModelDirectory(const string& export_dir) {
  const string saved_model_pb_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  const string saved_model_pbtxt_path =
      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  return Env::Default()->FileExists(saved_model_pb_path).ok() ||
         Env::Default()->FileExists(saved_model_pbtxt_path).ok();
}

}  // namespace tensorflow