• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/logging/conversion_log_util.h"

#ifdef __linux__
#include <sys/utsname.h>
#endif

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/export.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/version.h"
32 
33 namespace toco {
34 
35 namespace {
36 
TryGetOperatorName(const Operator & op)37 std::string TryGetOperatorName(const Operator& op) {
38   std::string op_name;
39   if (!op.tensorflow_node_def.empty()) {
40     // Parse op name from serialized NodeDef.
41     tensorflow::NodeDef node_def;
42     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
43       LOG(ERROR) << "Failed to parse Tensorflow NodeDef";
44     } else {
45       op_name = node_def.op();
46       if (!op_name.empty()) return op_name;
47     }
48   }
49   if (op.type == OperatorType::kUnsupported) {
50     // If we failed to get op name from serialized NodeDef (either because
51     // the tensorflow_node_def is an empty string, or we failed to parse
52     // from it), fall back to use 'tensorflow_op' field if this op is a
53     // TensorflowUnsupportedOperator.
54     const TensorFlowUnsupportedOperator& unsupported_op =
55         static_cast<const TensorFlowUnsupportedOperator&>(op);
56     if (!unsupported_op.tensorflow_op.empty()) {
57       op_name = unsupported_op.tensorflow_op;
58       return op_name;
59     }
60   }
61   // If this is a built-in op.
62   op_name = OperatorTypeName(op.type);
63   return op_name;
64 }
65 
GetOSVersion()66 std::string GetOSVersion() {
67   std::string os_info;
68 #ifdef __linux__
69   utsname info;
70   if (uname(&info)) {
71     // Failed
72     LOG(ERROR) << "Cannot get OS info.";
73     return "";
74   }
75   os_info =
76       std::string(info.sysname) + ";OSVer=" + std::string(info.release) + ";";
77 #endif
78   return os_info;
79 }
80 
ShapeToStringNoSpace(const Shape & shape)81 std::string ShapeToStringNoSpace(const Shape& shape) {
82   if (shape.dimensions_count() == 0) {
83     return "[]";
84   }
85 
86   return absl::StrCat("[", absl::StrJoin(shape.dims(), ","), "]");
87 }
88 
GetOperatorSignature(const Model & model,const Operator & op,const std::map<OperatorType,std::unique_ptr<tflite::BaseOperator>> & op_types_map)89 std::string GetOperatorSignature(
90     const Model& model, const Operator& op,
91     const std::map<OperatorType, std::unique_ptr<tflite::BaseOperator>>&
92         op_types_map) {
93   // The signature of an op has the following schema:
94   // INPUT:SHAPE::TYPE::OUTPUT:SHAPE::TYPE::NAME:VERSION:
95   std::string op_signature;
96   constexpr char delimiter[] = "::";
97 
98   // Get input shapes and types.
99   op_signature.append("INPUT:");
100   for (const auto& input : op.inputs) {
101     const auto& array = model.GetArray(input);
102     if (array.has_shape()) {
103       op_signature.append(ShapeToStringNoSpace(array.shape()));
104     } else {
105       op_signature.append("None");
106     }
107     op_signature.append(delimiter);
108     op_signature.append(ArrayDataTypeName(array.data_type) + delimiter);
109   }
110   // Get output shapes and types.
111   op_signature.append("OUTPUT:");
112   for (const auto& output : op.outputs) {
113     const auto& array = model.GetArray(output);
114     if (array.has_shape()) {
115       op_signature.append(ShapeToStringNoSpace(array.shape()));
116     } else {
117       op_signature.append("None");
118     }
119     op_signature.append(delimiter);
120     op_signature.append(ArrayDataTypeName(array.data_type) + delimiter);
121   }
122   // Append Op name.
123   op_signature.append("NAME:");
124   op_signature.append(TryGetOperatorName(op) + delimiter);
125   // Append Op version.
126   op_signature.append("VERSION:");
127   OperatorSignature toco_op_signature;
128   toco_op_signature.op = &op;
129   toco_op_signature.model = &model;
130   if (op_types_map.find(op.type) != op_types_map.end()) {
131     const int version = op_types_map.at(op.type)->GetVersion(toco_op_signature);
132     op_signature.append(std::to_string(version));
133   } else {
134     op_signature.append("None");
135   }
136   return op_signature;
137 }
138 
139 }  // namespace
140 
GetOperatorNames(const Model & model)141 std::vector<std::string> GetOperatorNames(const Model& model) {
142   std::vector<std::string> op_names;
143   for (const auto& op : model.operators) {
144     op_names.push_back(TryGetOperatorName(*op));
145   }
146   return op_names;
147 }
148 
CountOperatorsByType(const Model & model,std::map<std::string,int> * built_in_ops,std::map<std::string,int> * custom_ops,std::map<std::string,int> * select_ops)149 void CountOperatorsByType(const Model& model,
150                           std::map<std::string, int>* built_in_ops,
151                           std::map<std::string, int>* custom_ops,
152                           std::map<std::string, int>* select_ops) {
153   for (const auto& op : model.operators) {
154     OperatorSignature op_signature = {op.get(), &model};
155     const auto ops_by_type =
156         tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/);
157     tflite::details::OperatorKey op_key(op_signature, ops_by_type,
158                                         true /*enable_select_tf_ops*/);
159 
160     const std::string op_name = TryGetOperatorName(*op);
161     if (op_key.is_custom_op()) {
162       (*custom_ops)[op_name]++;
163     } else if (op_key.is_flex_op()) {
164       (*select_ops)[op_name]++;
165     } else {
166       (*built_in_ops)[op_name]++;
167     }
168   }
169 }
170 
GetInputAndOutputTypes(const Model & model,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * input_types,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * output_types)171 void GetInputAndOutputTypes(
172     const Model& model,
173     TFLITE_PROTO_NS::RepeatedPtrField<std::string>* input_types,
174     TFLITE_PROTO_NS::RepeatedPtrField<std::string>* output_types) {
175   for (const auto& input_array : model.flags.input_arrays()) {
176     const Array& array = model.GetArray(input_array.name());
177     input_types->Add(ArrayDataTypeName(array.data_type));
178   }
179   for (const auto& output_array : model.flags.output_arrays()) {
180     const Array& array = model.GetArray(output_array);
181     output_types->Add(ArrayDataTypeName(array.data_type));
182   }
183 }
184 
GetTfLiteVersion()185 std::string GetTfLiteVersion() { return TFLITE_VERSION_STRING; }
186 
GetCachedOSVersion()187 std::string GetCachedOSVersion() {
188   static std::string* version = new std::string(GetOSVersion());
189   return *version;
190 }
191 
GetOpSignatures(const Model & model,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * op_signatures)192 void GetOpSignatures(
193     const Model& model,
194     TFLITE_PROTO_NS::RepeatedPtrField<std::string>* op_signatures) {
195   const auto& op_types_map =
196       tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/);
197   for (const auto& op : model.operators) {
198     op_signatures->Add(GetOperatorSignature(model, *op, op_types_map));
199   }
200 }
201 
// Placeholder for a model fingerprint; currently always returns an empty
// string.
std::string GetModelHash(const Model& model) {
  // TODO(b/123519920): Implement the hash function for Model.
  // Need to consider different implementations for public/private models.
  (void)model;  // Unused until the hash is implemented.
  return std::string();
}
207 
// Scans `error_message` for the sections that list missing ops and returns
// only those sections, concatenated, discarding all other error details.
//
// Two sections are recognized:
//   - the flex op list, starting with
//     "Ops that can be supported by the flex runtime";
//   - the custom op list, starting with "Ops that need custom implementation".
// Each section runs from its header up to and including the next '.'. If no
// '.' follows the header, the section extends to the end of the message.
// The flex section (if present) is always emitted before the custom section
// (if present). Returns an empty string when neither header is found.
std::string SanitizeErrorMessage(const std::string& error_message) {
  const std::string flex_header =
      "Ops that can be supported by the flex runtime";
  const std::string custom_header = "Ops that need custom implementation";

  // Returns the text from `header` through the next '.', or through the end
  // of the message if no '.' follows. Returns "" if `header` is absent.
  const auto extract_section = [&error_message](const std::string& header) {
    const size_t start = error_message.find(header);
    if (start == std::string::npos) return std::string();
    const size_t dot = error_message.find('.', start);
    // Without this check, `dot - start + 1` wraps around to 0 when start == 0
    // and no '.' follows, which would silently drop the whole section.
    const size_t length =
        dot == std::string::npos ? std::string::npos : dot - start + 1;
    return error_message.substr(start, length);
  };

  std::string pruned_message = extract_section(flex_header);
  pruned_message.append(extract_section(custom_header));
  return pruned_message;
}
228 
PopulateConversionLog(const Model & model,TocoConversionLog * log)229 void PopulateConversionLog(const Model& model, TocoConversionLog* log) {
230   // Get the list of ops after conversion.
231   const std::vector<std::string> op_names = GetOperatorNames(model);
232   for (const auto& op_name : op_names) {
233     log->add_op_list(op_name);
234   }
235 
236   // Get op signatures.
237   TFLITE_PROTO_NS::RepeatedPtrField<std::string> op_signatures;
238   GetOpSignatures(model, &op_signatures);
239   log->mutable_op_signatures()->CopyFrom(op_signatures);
240 
241   // Get op counts by category: custom, built-in or select.
242   std::map<std::string, int> custom_ops, select_ops, built_in_ops;
243   CountOperatorsByType(model, &built_in_ops, &custom_ops, &select_ops);
244   log->mutable_custom_ops()->insert(custom_ops.cbegin(), custom_ops.cend());
245   log->mutable_built_in_ops()->insert(built_in_ops.cbegin(),
246                                       built_in_ops.cend());
247   log->mutable_select_ops()->insert(select_ops.cbegin(), select_ops.cend());
248 
249   // Get the model's input and output types.
250   TFLITE_PROTO_NS::RepeatedPtrField<std::string> input_types, output_types;
251   GetInputAndOutputTypes(model, &input_types, &output_types);
252   log->mutable_input_tensor_types()->CopyFrom(input_types);
253   log->mutable_output_tensor_types()->CopyFrom(output_types);
254 
255   log->set_log_generation_ts(absl::ToUnixMicros(absl::Now()));
256 
257   log->set_model_size(model.operators.size());
258   log->set_tf_lite_version(GetTfLiteVersion());
259   log->set_os_version(GetCachedOSVersion());
260   log->set_model_hash(GetModelHash(model));
261   // TODO(b/123519920): Populate TOCO error logs.
262   // Currently we will focus on external installation of TOCO via pip, where
263   // the C++ TOCO binary is invoked via subprocess command, this will make our
264   // life easier collecting the error logs emitted by TOCO. However, note that
265   // if a user directly invokes the C++ TOCO binary, this log might not be
266   // available.
267 }
268 
269 }  // namespace toco
270