1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/logging/conversion_log_util.h"
16
17 #ifdef __linux__
18 #include <sys/utsname.h>
19 #endif
20
21 #include <vector>
22
23 #include "absl/strings/str_cat.h"
24 #include "absl/time/clock.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/lite/toco/model.h"
28 #include "tensorflow/lite/toco/tflite/export.h"
29 #include "tensorflow/lite/toco/tflite/operator.h"
30 #include "tensorflow/lite/toco/tooling_util.h"
31 #include "tensorflow/lite/version.h"
32
33 namespace toco {
34
35 namespace {
36
TryGetOperatorName(const Operator & op)37 std::string TryGetOperatorName(const Operator& op) {
38 std::string op_name;
39 if (!op.tensorflow_node_def.empty()) {
40 // Parse op name from serialized NodeDef.
41 tensorflow::NodeDef node_def;
42 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
43 LOG(ERROR) << "Failed to parse Tensorflow NodeDef";
44 } else {
45 op_name = node_def.op();
46 if (!op_name.empty()) return op_name;
47 }
48 }
49 if (op.type == OperatorType::kUnsupported) {
50 // If we failed to get op name from serialized NodeDef (either because
51 // the tensorflow_node_def is an empty string, or we failed to parse
52 // from it), fall back to use 'tensorflow_op' field if this op is a
53 // TensorflowUnsupportedOperator.
54 const TensorFlowUnsupportedOperator& unsupported_op =
55 static_cast<const TensorFlowUnsupportedOperator&>(op);
56 if (!unsupported_op.tensorflow_op.empty()) {
57 op_name = unsupported_op.tensorflow_op;
58 return op_name;
59 }
60 }
61 // If this is a built-in op.
62 op_name = OperatorTypeName(op.type);
63 return op_name;
64 }
65
GetOSVersion()66 std::string GetOSVersion() {
67 std::string os_info;
68 #ifdef __linux__
69 utsname info;
70 if (uname(&info)) {
71 // Failed
72 LOG(ERROR) << "Cannot get OS info.";
73 return "";
74 }
75 os_info =
76 std::string(info.sysname) + ";OSVer=" + std::string(info.release) + ";";
77 #endif
78 return os_info;
79 }
80
ShapeToStringNoSpace(const Shape & shape)81 std::string ShapeToStringNoSpace(const Shape& shape) {
82 if (shape.dimensions_count() == 0) {
83 return "[]";
84 }
85
86 return absl::StrCat("[", absl::StrJoin(shape.dims(), ","), "]");
87 }
88
GetOperatorSignature(const Model & model,const Operator & op,const std::map<OperatorType,std::unique_ptr<tflite::BaseOperator>> & op_types_map)89 std::string GetOperatorSignature(
90 const Model& model, const Operator& op,
91 const std::map<OperatorType, std::unique_ptr<tflite::BaseOperator>>&
92 op_types_map) {
93 // The signature of an op has the following schema:
94 // INPUT:SHAPE::TYPE::OUTPUT:SHAPE::TYPE::NAME:VERSION:
95 std::string op_signature;
96 constexpr char delimiter[] = "::";
97
98 // Get input shapes and types.
99 op_signature.append("INPUT:");
100 for (const auto& input : op.inputs) {
101 const auto& array = model.GetArray(input);
102 if (array.has_shape()) {
103 op_signature.append(ShapeToStringNoSpace(array.shape()));
104 } else {
105 op_signature.append("None");
106 }
107 op_signature.append(delimiter);
108 op_signature.append(ArrayDataTypeName(array.data_type) + delimiter);
109 }
110 // Get output shapes and types.
111 op_signature.append("OUTPUT:");
112 for (const auto& output : op.outputs) {
113 const auto& array = model.GetArray(output);
114 if (array.has_shape()) {
115 op_signature.append(ShapeToStringNoSpace(array.shape()));
116 } else {
117 op_signature.append("None");
118 }
119 op_signature.append(delimiter);
120 op_signature.append(ArrayDataTypeName(array.data_type) + delimiter);
121 }
122 // Append Op name.
123 op_signature.append("NAME:");
124 op_signature.append(TryGetOperatorName(op) + delimiter);
125 // Append Op version.
126 op_signature.append("VERSION:");
127 OperatorSignature toco_op_signature;
128 toco_op_signature.op = &op;
129 toco_op_signature.model = &model;
130 if (op_types_map.find(op.type) != op_types_map.end()) {
131 const int version = op_types_map.at(op.type)->GetVersion(toco_op_signature);
132 op_signature.append(std::to_string(version));
133 } else {
134 op_signature.append("None");
135 }
136 return op_signature;
137 }
138
139 } // namespace
140
GetOperatorNames(const Model & model)141 std::vector<std::string> GetOperatorNames(const Model& model) {
142 std::vector<std::string> op_names;
143 for (const auto& op : model.operators) {
144 op_names.push_back(TryGetOperatorName(*op));
145 }
146 return op_names;
147 }
148
CountOperatorsByType(const Model & model,std::map<std::string,int> * built_in_ops,std::map<std::string,int> * custom_ops,std::map<std::string,int> * select_ops)149 void CountOperatorsByType(const Model& model,
150 std::map<std::string, int>* built_in_ops,
151 std::map<std::string, int>* custom_ops,
152 std::map<std::string, int>* select_ops) {
153 for (const auto& op : model.operators) {
154 OperatorSignature op_signature = {op.get(), &model};
155 const auto ops_by_type =
156 tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/);
157 tflite::details::OperatorKey op_key(op_signature, ops_by_type,
158 true /*enable_select_tf_ops*/);
159
160 const std::string op_name = TryGetOperatorName(*op);
161 if (op_key.is_custom_op()) {
162 (*custom_ops)[op_name]++;
163 } else if (op_key.is_flex_op()) {
164 (*select_ops)[op_name]++;
165 } else {
166 (*built_in_ops)[op_name]++;
167 }
168 }
169 }
170
GetInputAndOutputTypes(const Model & model,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * input_types,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * output_types)171 void GetInputAndOutputTypes(
172 const Model& model,
173 TFLITE_PROTO_NS::RepeatedPtrField<std::string>* input_types,
174 TFLITE_PROTO_NS::RepeatedPtrField<std::string>* output_types) {
175 for (const auto& input_array : model.flags.input_arrays()) {
176 const Array& array = model.GetArray(input_array.name());
177 input_types->Add(ArrayDataTypeName(array.data_type));
178 }
179 for (const auto& output_array : model.flags.output_arrays()) {
180 const Array& array = model.GetArray(output_array);
181 output_types->Add(ArrayDataTypeName(array.data_type));
182 }
183 }
184
GetTfLiteVersion()185 std::string GetTfLiteVersion() { return TFLITE_VERSION_STRING; }
186
GetCachedOSVersion()187 std::string GetCachedOSVersion() {
188 static std::string* version = new std::string(GetOSVersion());
189 return *version;
190 }
191
GetOpSignatures(const Model & model,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * op_signatures)192 void GetOpSignatures(
193 const Model& model,
194 TFLITE_PROTO_NS::RepeatedPtrField<std::string>* op_signatures) {
195 const auto& op_types_map =
196 tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/);
197 for (const auto& op : model.operators) {
198 op_signatures->Add(GetOperatorSignature(model, *op, op_types_map));
199 }
200 }
201
// Returns a hash identifying `model`. Not implemented yet: always returns an
// empty string regardless of the model's contents.
std::string GetModelHash(const Model& model) {
  // TODO(b/123519920): Implement the hash function for Model.
  // Need to consider different implementations for public/private models.
  return std::string();
}
207
// Scans `error_message` for the sections that list unsupported ops and
// returns only those sections, pruning away all other information.
//
// Two sections are recognized. Each starts at a fixed header and runs up to
// and including the first '.' after that header:
//   * "Ops that can be supported by the flex runtime..." (emitted first)
//   * "Ops that need custom implementation..."          (emitted second)
// If a header is present but no '.' follows it, the section extends to the
// end of the message. The sections are concatenated with no separator, and
// the result is empty when neither header occurs.
std::string SanitizeErrorMessage(const std::string& error_message) {
  const std::string flex_header =
      "Ops that can be supported by the flex runtime";
  const std::string custom_header = "Ops that need custom implementation";

  std::string pruned_message;

  // Appends the section beginning at the first occurrence of `header`, if
  // any, to `pruned_message`.
  const auto append_section = [&error_message,
                               &pruned_message](const std::string& header) {
    const size_t begin = error_message.find(header);
    if (begin == std::string::npos) return;
    const size_t period = error_message.find('.', begin);
    // Without this check `period - begin + 1` wraps around when no '.' is
    // found. For a header at offset 0 that wrapped length is 0, which
    // silently dropped the whole section.
    const size_t length = (period == std::string::npos)
                              ? std::string::npos
                              : period - begin + 1;
    pruned_message.append(error_message, begin, length);
  };

  append_section(flex_header);
  append_section(custom_header);
  return pruned_message;
}
228
PopulateConversionLog(const Model & model,TocoConversionLog * log)229 void PopulateConversionLog(const Model& model, TocoConversionLog* log) {
230 // Get the list of ops after conversion.
231 const std::vector<std::string> op_names = GetOperatorNames(model);
232 for (const auto& op_name : op_names) {
233 log->add_op_list(op_name);
234 }
235
236 // Get op signatures.
237 TFLITE_PROTO_NS::RepeatedPtrField<std::string> op_signatures;
238 GetOpSignatures(model, &op_signatures);
239 log->mutable_op_signatures()->CopyFrom(op_signatures);
240
241 // Get op counts by category: custom, built-in or select.
242 std::map<std::string, int> custom_ops, select_ops, built_in_ops;
243 CountOperatorsByType(model, &built_in_ops, &custom_ops, &select_ops);
244 log->mutable_custom_ops()->insert(custom_ops.cbegin(), custom_ops.cend());
245 log->mutable_built_in_ops()->insert(built_in_ops.cbegin(),
246 built_in_ops.cend());
247 log->mutable_select_ops()->insert(select_ops.cbegin(), select_ops.cend());
248
249 // Get the model's input and output types.
250 TFLITE_PROTO_NS::RepeatedPtrField<std::string> input_types, output_types;
251 GetInputAndOutputTypes(model, &input_types, &output_types);
252 log->mutable_input_tensor_types()->CopyFrom(input_types);
253 log->mutable_output_tensor_types()->CopyFrom(output_types);
254
255 log->set_log_generation_ts(absl::ToUnixMicros(absl::Now()));
256
257 log->set_model_size(model.operators.size());
258 log->set_tf_lite_version(GetTfLiteVersion());
259 log->set_os_version(GetCachedOSVersion());
260 log->set_model_hash(GetModelHash(model));
261 // TODO(b/123519920): Populate TOCO error logs.
262 // Currently we will focus on external installation of TOCO via pip, where
263 // the C++ TOCO binary is invoked via subprocess command, this will make our
264 // life easier collecting the error logs emitted by TOCO. However, note that
265 // if a user directly invokes the C++ TOCO binary, this log might not be
266 // available.
267 }
268
269 } // namespace toco
270