1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
17
18 #include "tensorflow/cc/saved_model/signature_constants.h"
19 #include "tensorflow/core/framework/types.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/core/stringpiece.h"
22 #include "tensorflow/core/platform/protobuf.h"
23
24 namespace tensorflow {
25
26 namespace {
27 template <class T>
FindInProtobufMap(StringPiece description,const protobuf::Map<string,T> & map,const string & key,const T ** value)28 Status FindInProtobufMap(StringPiece description,
29 const protobuf::Map<string, T>& map, const string& key,
30 const T** value) {
31 const auto it = map.find(key);
32 if (it == map.end()) {
33 return errors::NotFound("Could not find ", description, " for key: ", key);
34 }
35 *value = &it->second;
36 return Status::OK();
37 }
38
39 // Looks up the TensorInfo for the given key in the given map and verifies that
40 // its datatype matches the given correct datatype.
VerifyTensorInfoForKeyInMap(const protobuf::Map<string,TensorInfo> & map,const string & key,DataType correct_dtype)41 bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map,
42 const string& key, DataType correct_dtype) {
43 const TensorInfo* tensor_info;
44 const Status& status = FindInProtobufMap("", map, key, &tensor_info);
45 if (!status.ok()) {
46 return false;
47 }
48 if (tensor_info->dtype() != correct_dtype) {
49 return false;
50 }
51 return true;
52 }
53
IsValidPredictSignature(const SignatureDef & signature_def)54 bool IsValidPredictSignature(const SignatureDef& signature_def) {
55 if (signature_def.method_name() != kPredictMethodName) {
56 return false;
57 }
58 if (signature_def.inputs().empty()) {
59 return false;
60 }
61 if (signature_def.outputs().empty()) {
62 return false;
63 }
64 return true;
65 }
66
IsValidRegressionSignature(const SignatureDef & signature_def)67 bool IsValidRegressionSignature(const SignatureDef& signature_def) {
68 if (signature_def.method_name() != kRegressMethodName) {
69 return false;
70 }
71 if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs,
72 DT_STRING)) {
73 return false;
74 }
75 if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs,
76 DT_FLOAT)) {
77 return false;
78 }
79 return true;
80 }
81
IsValidClassificationSignature(const SignatureDef & signature_def)82 bool IsValidClassificationSignature(const SignatureDef& signature_def) {
83 if (signature_def.method_name() != kClassifyMethodName) {
84 return false;
85 }
86 if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs,
87 DT_STRING)) {
88 return false;
89 }
90 if (signature_def.outputs().empty()) {
91 return false;
92 }
93 for (auto const& output : signature_def.outputs()) {
94 const string& key = output.first;
95 const TensorInfo& tensor_info = output.second;
96 if (key == kClassifyOutputClasses) {
97 if (tensor_info.dtype() != DT_STRING) {
98 return false;
99 }
100 } else if (key == kClassifyOutputScores) {
101 if (tensor_info.dtype() != DT_FLOAT) {
102 return false;
103 }
104 } else {
105 return false;
106 }
107 }
108 return true;
109 }
110
111 } // namespace
112
FindSignatureDefByKey(const MetaGraphDef & meta_graph_def,const string & signature_def_key,const SignatureDef ** signature_def)113 Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
114 const string& signature_def_key,
115 const SignatureDef** signature_def) {
116 return FindInProtobufMap("SignatureDef", meta_graph_def.signature_def(),
117 signature_def_key, signature_def);
118 }
119
FindInputTensorInfoByKey(const SignatureDef & signature_def,const string & tensor_info_key,const TensorInfo ** tensor_info)120 Status FindInputTensorInfoByKey(const SignatureDef& signature_def,
121 const string& tensor_info_key,
122 const TensorInfo** tensor_info) {
123 return FindInProtobufMap("input TensorInfo", signature_def.inputs(),
124 tensor_info_key, tensor_info);
125 }
126
FindOutputTensorInfoByKey(const SignatureDef & signature_def,const string & tensor_info_key,const TensorInfo ** tensor_info)127 Status FindOutputTensorInfoByKey(const SignatureDef& signature_def,
128 const string& tensor_info_key,
129 const TensorInfo** tensor_info) {
130 return FindInProtobufMap("output TensorInfo", signature_def.outputs(),
131 tensor_info_key, tensor_info);
132 }
133
FindInputTensorNameByKey(const SignatureDef & signature_def,const string & tensor_info_key,string * name)134 Status FindInputTensorNameByKey(const SignatureDef& signature_def,
135 const string& tensor_info_key, string* name) {
136 const TensorInfo* tensor_info;
137 TF_RETURN_IF_ERROR(
138 FindInputTensorInfoByKey(signature_def, tensor_info_key, &tensor_info));
139 *name = tensor_info->name();
140 return Status::OK();
141 }
142
FindOutputTensorNameByKey(const SignatureDef & signature_def,const string & tensor_info_key,string * name)143 Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
144 const string& tensor_info_key, string* name) {
145 const TensorInfo* tensor_info;
146 TF_RETURN_IF_ERROR(
147 FindOutputTensorInfoByKey(signature_def, tensor_info_key, &tensor_info));
148 *name = tensor_info->name();
149 return Status::OK();
150 }
151
IsValidSignature(const SignatureDef & signature_def)152 bool IsValidSignature(const SignatureDef& signature_def) {
153 return IsValidClassificationSignature(signature_def) ||
154 IsValidRegressionSignature(signature_def) ||
155 IsValidPredictSignature(signature_def);
156 }
157
158 } // namespace tensorflow
159