1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/saved_model/reader.h"
17
18 #include <unordered_set>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/metrics.h"
23 #include "tensorflow/cc/saved_model/util.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/lib/io/path.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/protobuf/saved_model.pb.h"
34 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
35
36 namespace tensorflow {
37 namespace {
38
39 // Reads the SavedModel proto from saved_model.pb in `export_dir`.
40 // Returns a failure status when the SavedModel file does not exist.
ReadSavedModel(absl::string_view export_dir,SavedModel * saved_model_proto)41 Status ReadSavedModel(absl::string_view export_dir,
42 SavedModel* saved_model_proto) {
43 LOG(INFO) << "Reading SavedModel from: " << export_dir;
44
45 const std::string saved_model_pb_path =
46 io::JoinPath(export_dir, kSavedModelFilenamePb);
47
48 if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
49 Status result =
50 ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
51 if (result.ok()) {
52 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
53 .IncrementBy(1);
54 }
55 return result;
56 }
57 const std::string saved_model_pbtxt_path =
58 io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
59 if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
60 Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
61 saved_model_proto);
62 if (result.ok()) {
63 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
64 .IncrementBy(1);
65 }
66 return result;
67 }
68 return Status(
69 error::Code::NOT_FOUND,
70 strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
71 "export directory path: ",
72 export_dir));
73 }
74
75 // Swap tensor_content field of Const Op Tensors in the named functions
SwapTensorContent(MetaGraphDef * meta_graph_def)76 static Status SwapTensorContent(MetaGraphDef* meta_graph_def) {
77 GraphDef graph_def = *meta_graph_def->mutable_graph_def();
78 for (auto& function : *meta_graph_def->mutable_graph_def()
79 ->mutable_library()
80 ->mutable_function()) {
81 for (auto& node : (*function.mutable_node_def())) {
82 if (node.op() != "Const") continue;
83 auto node_iterator = node.mutable_attr()->find("value");
84 if (node_iterator == node.mutable_attr()->end()) continue;
85 AttrValue node_value = node_iterator->second;
86 if (!node_value.has_tensor()) continue;
87
88 auto tsize = node_value.mutable_tensor()->tensor_content().size();
89 auto p_type = node_value.mutable_tensor()->dtype();
90 // Swap only when there is something in tensor_content field
91 if (tsize != 0 && DataTypeCanUseMemcpy(p_type)) {
92 Tensor parsed(p_type);
93 DCHECK(parsed.FromProto(*node_value.mutable_tensor()));
94 TF_RETURN_IF_ERROR(ByteSwapTensor(&parsed));
95 (*node.mutable_attr())["value"].mutable_tensor()->set_tensor_content(
96 string(reinterpret_cast<const char*>(parsed.tensor_data().data()),
97 parsed.tensor_data().size()));
98 }
99 }
100 }
101 return Status::OK();
102 }
103
FindMetaGraphDef(const std::unordered_set<string> & tags,SavedModel * saved_model_proto,MetaGraphDef * meta_graph_def)104 Status FindMetaGraphDef(const std::unordered_set<string>& tags,
105 SavedModel* saved_model_proto,
106 MetaGraphDef* meta_graph_def) {
107 LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
108 << " }";
109 for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) {
110 // Get tags from the graph_def.
111 std::unordered_set<string> graph_tags;
112 for (const string& tag : graph_def.meta_info_def().tags()) {
113 graph_tags.insert(tag);
114 }
115 // Match with the set of tags provided.
116 if (graph_tags == tags) {
117 *meta_graph_def = std::move(graph_def);
118 // Correct the endiness of Tensor content on big-endian system
119 if (!port::kLittleEndian) {
120 TF_RETURN_IF_ERROR(SwapTensorContent(meta_graph_def));
121 }
122 return Status::OK();
123 }
124 }
125 return Status(
126 error::Code::NOT_FOUND,
127 strings::StrCat(
128 "Could not find meta graph def matching supplied tags: { ",
129 absl::StrJoin(tags, " "),
130 " }. To inspect available tag-sets in the SavedModel, please "
131 "use the SavedModel CLI: `saved_model_cli`"));
132 }
133 } // namespace
134
ReadMetaGraphDefFromSavedModel(const string & export_dir,const std::unordered_set<string> & tags,MetaGraphDef * const meta_graph_def)135 Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
136 const std::unordered_set<string>& tags,
137 MetaGraphDef* const meta_graph_def) {
138 SavedModel saved_model_proto;
139 TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
140 TF_RETURN_IF_ERROR(
141 FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def));
142 return Status::OK();
143 }
144
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)145 Status ReadSavedModelDebugInfoIfPresent(
146 const string& export_dir,
147 std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
148 LOG(INFO) << "Reading SavedModel debug info (if present) from: "
149 << export_dir;
150
151 const string debug_info_pb_path =
152 io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
153 if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
154 GraphDebugInfo debug_info;
155 TF_RETURN_IF_ERROR(
156 ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
157 *debug_info_proto =
158 absl::make_unique<GraphDebugInfo>(std::move(debug_info));
159 }
160 return Status::OK();
161 }
162
163 } // namespace tensorflow
164