1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstdarg>
#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
#include "tensorflow/lite/version.h"
26
27 namespace tflite {
28
29 namespace {
30
31 // Dump details of the given tensor.
dump_tensor_detail(std::stringstream & out_stream,const tflite::Tensor * tensor,const int tensor_idx)32 void dump_tensor_detail(std::stringstream& out_stream,
33 const tflite::Tensor* tensor, const int tensor_idx) {
34 out_stream << "T#" << tensor_idx;
35 out_stream << "(" << tensor->name()->str() << ") ";
36 // Prints `shape_signature` instead of `shape` if it's available since it
37 // supports dynamic shapes.
38 if (tensor->shape_signature()) {
39 out_stream << "shape_signature:[";
40 for (int i = 0; i < tensor->shape_signature()->Length(); ++i) {
41 const int j = tensor->shape_signature()->Get(i);
42 out_stream << j;
43 if (i != tensor->shape_signature()->Length() - 1) {
44 out_stream << ", ";
45 }
46 }
47 out_stream << "]";
48 } else {
49 out_stream << "shape:[";
50 for (int i = 0; i < tensor->shape()->Length(); ++i) {
51 const int j = tensor->shape()->Get(i);
52 out_stream << j;
53 if (i != tensor->shape()->Length() - 1) {
54 out_stream << ", ";
55 }
56 }
57 out_stream << "]";
58 }
59 out_stream << ", type:" << EnumNameTensorType(tensor->type());
60 out_stream << "\n";
61 }
62
63 // Dump list of input or output tensors.
dump_tensor_list(std::stringstream & out_stream,const flatbuffers::Vector<int32_t> * tensors,bool verbose=false)64 void dump_tensor_list(std::stringstream& out_stream,
65 const flatbuffers::Vector<int32_t>* tensors,
66 bool verbose = false) {
67 for (int i = 0; i < tensors->Length(); ++i) {
68 const int tensor_idx = tensors->Get(i);
69 if (verbose) {
70 out_stream << "tensor #" << tensor_idx;
71 } else {
72 out_stream << "T#" << tensor_idx;
73 }
74 if (i != tensors->Length() - 1) {
75 if (verbose) {
76 out_stream << " and ";
77 } else {
78 out_stream << ", ";
79 }
80 }
81 }
82 }
83
84 // Returns the string representation of the given OperatorCode.
get_op_name(const OperatorCode * op_code)85 const std::string get_op_name(const OperatorCode* op_code) {
86 auto builtin_code = GetBuiltinCode(op_code);
87 if (builtin_code != BuiltinOperator_CUSTOM) {
88 return EnumNameBuiltinOperator(builtin_code);
89 } else {
90 return op_code->custom_code()->str();
91 }
92 }
93
94 // Dump the given Operator node.
dump_node(std::stringstream & out_stream,const int node_no,const OperatorCode * op_code,const Operator * op,const SubGraph * subgraph)95 void dump_node(std::stringstream& out_stream, const int node_no,
96 const OperatorCode* op_code, const Operator* op,
97 const SubGraph* subgraph) {
98 out_stream << "Op#" << node_no << " " << get_op_name(op_code);
99 out_stream << "(";
100 dump_tensor_list(out_stream, op->inputs());
101 out_stream << ") -> [";
102 dump_tensor_list(out_stream, op->outputs());
103 out_stream << "]\n";
104 }
105
106 // Dump the summary of the given TFLite flatbuffer model. It's printed at the
107 // beginning of the analyzer output.
dump_model_summary(std::stringstream & out_stream,const::tflite::Model * model)108 void dump_model_summary(std::stringstream& out_stream,
109 const ::tflite::Model* model) {
110 auto* subgraphs = model->subgraphs();
111 out_stream
112 << "Your TFLite model has ‘" << subgraphs->Length()
113 << "’ subgraph(s). In the subgraph description below,\nT# represents the "
114 "Tensor numbers. ";
115 if (subgraphs->Length() > 0 && subgraphs->Get(0)->operators()->Length() > 0) {
116 const Operator* first_op = subgraphs->Get(0)->operators()->Get(0);
117 const OperatorCode* first_op_code =
118 model->operator_codes()->Get(first_op->opcode_index());
119 out_stream << "For example, in Subgraph#0, the "
120 << get_op_name(first_op_code) << " op takes\n";
121 dump_tensor_list(out_stream, first_op->inputs(), /*verbose=*/true);
122 out_stream << " as input and produces ";
123 dump_tensor_list(out_stream, first_op->outputs(), /*verbose=*/true);
124 out_stream << " as output.\n\n";
125 }
126 }
127
128 } // namespace
129
130 class StreamErrorReporter : public ErrorReporter {
131 public:
StreamErrorReporter(std::stringstream * out_stream)132 explicit StreamErrorReporter(std::stringstream* out_stream)
133 : out_stream_(out_stream) {}
Report(const char * format,va_list args)134 int Report(const char* format, va_list args) override {
135 char buffer[1024];
136 int size = vsnprintf(buffer, sizeof(buffer), format, args);
137 *out_stream_ << buffer;
138 return size;
139 }
140
141 private:
142 std::stringstream* out_stream_;
143 };
144
model_analyzer(const std::string & model_file_or_buffer,bool input_is_filepath,bool check_gpu_compatibility)145 std::string model_analyzer(const std::string& model_file_or_buffer,
146 bool input_is_filepath,
147 bool check_gpu_compatibility) {
148 std::stringstream out_stream;
149 StreamErrorReporter error_reporter(&out_stream);
150 std::unique_ptr<FlatBufferModel> fb_model;
151 if (input_is_filepath) {
152 fb_model = FlatBufferModel::BuildFromFile(model_file_or_buffer.c_str(),
153 &error_reporter);
154 if (!fb_model) {
155 out_stream << "Failed to mmap model " << model_file_or_buffer;
156 return out_stream.str();
157 }
158 } else {
159 fb_model = FlatBufferModel::BuildFromBuffer(model_file_or_buffer.c_str(),
160 model_file_or_buffer.size(),
161 &error_reporter);
162 if (!fb_model) {
163 out_stream << "Failed to mmap the given model buffer.";
164 return out_stream.str();
165 }
166 }
167 const ::tflite::Model* model = fb_model->GetModel();
168 auto* subgraphs = model->subgraphs();
169
170 dump_model_summary(out_stream, model);
171
172 bool model_is_gpu_compatibile = true;
173 for (int i = 0; i < subgraphs->Length(); ++i) {
174 std::vector<int> gpu_incompatibile_nodes;
175 const SubGraph* subgraph = subgraphs->Get(i);
176 out_stream << "Subgraph#" << i;
177 if (subgraph->name()) {
178 out_stream << " " << subgraph->name()->str();
179 }
180 out_stream << "(";
181 dump_tensor_list(out_stream, subgraph->inputs());
182 out_stream << ") -> [";
183 dump_tensor_list(out_stream, subgraph->outputs());
184 out_stream << "]\n";
185 for (int j = 0; j < subgraph->operators()->Length(); ++j) {
186 const Operator* op = subgraph->operators()->Get(j);
187 const OperatorCode* op_code =
188 model->operator_codes()->Get(op->opcode_index());
189 out_stream << " "; // indents for operators
190 dump_node(out_stream, /*node_no=*/j, op_code, op, subgraph);
191 if (check_gpu_compatibility) {
192 auto status =
193 CheckGpuDelegateCompatibility(op_code, op, subgraph, model);
194 if (!status.ok()) {
195 gpu_incompatibile_nodes.push_back(j);
196 out_stream << "GPU COMPATIBILITY WARNING: " << status.message()
197 << "\n";
198 }
199 }
200 }
201 if (!gpu_incompatibile_nodes.empty()) {
202 model_is_gpu_compatibile = false;
203 out_stream << "\nGPU COMPATIBILITY WARNING: Subgraph#" << i
204 << " has GPU delegate compatibility issues at nodes "
205 << absl::StrJoin(gpu_incompatibile_nodes, ", ")
206 << " with TFLite runtime version " << TF_VERSION_STRING
207 << "\n";
208 }
209
210 // Dump Subgraph Tensors.
211 out_stream << "\nTensors of Subgraph#" << i << "\n";
212 auto tensors = subgraph->tensors();
213 for (int j = 0; j < tensors->Length(); ++j) {
214 auto tensor = tensors->Get(j);
215 out_stream << " "; // indents for tensors
216 dump_tensor_detail(out_stream, tensor, j);
217 }
218 }
219 if (check_gpu_compatibility && model_is_gpu_compatibile) {
220 out_stream
221 << "\nYour model looks compatibile with GPU delegate"
222 << " with TFLite runtime version " << TF_VERSION_STRING
223 << ".\nBut it doesn't guarantee that your model works well with GPU "
224 "delegate.\nThere could be some runtime incompatibililty happen.\n";
225 }
226 return out_stream.str();
227 }
228
229 } // namespace tflite
230