• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/reader.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/cc/saved_model/constants.h"
22 #include "tensorflow/cc/saved_model/metrics.h"
23 #include "tensorflow/cc/saved_model/util.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/lib/io/path.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/protobuf/saved_model.pb.h"
34 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
35 
36 namespace tensorflow {
37 namespace {
38 
39 // Reads the SavedModel proto from saved_model.pb in `export_dir`.
40 // Returns a failure status when the SavedModel file does not exist.
ReadSavedModel(absl::string_view export_dir,SavedModel * saved_model_proto)41 Status ReadSavedModel(absl::string_view export_dir,
42                       SavedModel* saved_model_proto) {
43   LOG(INFO) << "Reading SavedModel from: " << export_dir;
44 
45   const std::string saved_model_pb_path =
46       io::JoinPath(export_dir, kSavedModelFilenamePb);
47 
48   if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
49     Status result =
50         ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
51     if (result.ok()) {
52       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
53           .IncrementBy(1);
54     }
55     return result;
56   }
57   const std::string saved_model_pbtxt_path =
58       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
59   if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
60     Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
61                                   saved_model_proto);
62     if (result.ok()) {
63       metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
64           .IncrementBy(1);
65     }
66     return result;
67   }
68   return Status(
69       error::Code::NOT_FOUND,
70       strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied "
71                       "export directory path: ",
72                       export_dir));
73 }
74 
75 // Swap tensor_content field of Const Op Tensors in the named functions
SwapTensorContent(MetaGraphDef * meta_graph_def)76 static Status SwapTensorContent(MetaGraphDef* meta_graph_def) {
77   GraphDef graph_def = *meta_graph_def->mutable_graph_def();
78   for (auto& function : *meta_graph_def->mutable_graph_def()
79                              ->mutable_library()
80                              ->mutable_function()) {
81     for (auto& node : (*function.mutable_node_def())) {
82       if (node.op() != "Const") continue;
83       auto node_iterator = node.mutable_attr()->find("value");
84       if (node_iterator == node.mutable_attr()->end()) continue;
85       AttrValue node_value = node_iterator->second;
86       if (!node_value.has_tensor()) continue;
87 
88       auto tsize = node_value.mutable_tensor()->tensor_content().size();
89       auto p_type = node_value.mutable_tensor()->dtype();
90       // Swap only when there is something in tensor_content field
91       if (tsize != 0 && DataTypeCanUseMemcpy(p_type)) {
92         Tensor parsed(p_type);
93         DCHECK(parsed.FromProto(*node_value.mutable_tensor()));
94         TF_RETURN_IF_ERROR(ByteSwapTensor(&parsed));
95         (*node.mutable_attr())["value"].mutable_tensor()->set_tensor_content(
96             string(reinterpret_cast<const char*>(parsed.tensor_data().data()),
97                    parsed.tensor_data().size()));
98       }
99     }
100   }
101   return Status::OK();
102 }
103 
FindMetaGraphDef(const std::unordered_set<string> & tags,SavedModel * saved_model_proto,MetaGraphDef * meta_graph_def)104 Status FindMetaGraphDef(const std::unordered_set<string>& tags,
105                         SavedModel* saved_model_proto,
106                         MetaGraphDef* meta_graph_def) {
107   LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ")
108             << " }";
109   for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) {
110     // Get tags from the graph_def.
111     std::unordered_set<string> graph_tags;
112     for (const string& tag : graph_def.meta_info_def().tags()) {
113       graph_tags.insert(tag);
114     }
115     // Match with the set of tags provided.
116     if (graph_tags == tags) {
117       *meta_graph_def = std::move(graph_def);
118       // Correct the endiness of Tensor content on big-endian system
119       if (!port::kLittleEndian) {
120         TF_RETURN_IF_ERROR(SwapTensorContent(meta_graph_def));
121       }
122       return Status::OK();
123     }
124   }
125   return Status(
126       error::Code::NOT_FOUND,
127       strings::StrCat(
128           "Could not find meta graph def matching supplied tags: { ",
129           absl::StrJoin(tags, " "),
130           " }. To inspect available tag-sets in the SavedModel, please "
131           "use the SavedModel CLI: `saved_model_cli`"));
132 }
133 }  // namespace
134 
ReadMetaGraphDefFromSavedModel(const string & export_dir,const std::unordered_set<string> & tags,MetaGraphDef * const meta_graph_def)135 Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
136                                       const std::unordered_set<string>& tags,
137                                       MetaGraphDef* const meta_graph_def) {
138   SavedModel saved_model_proto;
139   TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
140   TF_RETURN_IF_ERROR(
141       FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def));
142   return Status::OK();
143 }
144 
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)145 Status ReadSavedModelDebugInfoIfPresent(
146     const string& export_dir,
147     std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
148   LOG(INFO) << "Reading SavedModel debug info (if present) from: "
149             << export_dir;
150 
151   const string debug_info_pb_path =
152       io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
153   if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
154     GraphDebugInfo debug_info;
155     TF_RETURN_IF_ERROR(
156         ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
157     *debug_info_proto =
158         absl::make_unique<GraphDebugInfo>(std::move(debug_info));
159   }
160   return Status::OK();
161 }
162 
163 }  // namespace tensorflow
164