/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/python/toco_python_api.h"

#include <fstream>
#include <map>
#include <string>
#include <vector>

#include "google/protobuf/text_format.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/logging/conversion_log_util.h"
#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_convert.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/toco/types.pb.h"

namespace toco {

void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
                                 toco::TocoFlags* toco_flags,
                                 const std::string& input_contents_txt,
                                 const std::string& output_file_contents_txt,
                                 const std::string& error_message,
                                 GraphVizDumpOptions* dump_options) {
  // Make sure the graphviz file will be dumped under the same folder.
  dump_options->dump_graphviz = toco_flags->conversion_summary_dir();
  // Here we construct the `toco::Model` class based on the input graph def;
  // it will then be used to populate the conversion log.
  // TODO(haoliang): Don't depend on `toco::Model`.
  std::unique_ptr<toco::Model> imported_model =
      toco::Import(*toco_flags, model_flags, input_contents_txt);
  // Dump pre-conversion toco logs.
  TocoConversionLog toco_log_before;
  PopulateConversionLog(*imported_model, &toco_log_before);
  std::ofstream osstream_before(toco_flags->conversion_summary_dir() +
                                "/toco_log_before.pb");
  toco_log_before.SerializeToOstream(&osstream_before);
  osstream_before.close();
  toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);

  // Populate the post-conversion log; for convenience, initialize the
  // `toco::Model` class from the generated flatbuffer.
  toco_flags->set_input_format(toco::FileFormat::TFLITE);
  std::unique_ptr<toco::Model> flatbuffer_model =
      toco::Import(*toco_flags, model_flags, output_file_contents_txt);
  // Dump post-conversion toco logs.
  TocoConversionLog toco_log_after;
  PopulateConversionLog(*flatbuffer_model, &toco_log_after);
  // Make sure we sanitize the error message.
  toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
  std::ofstream ostream_after(toco_flags->conversion_summary_dir() +
                              "/toco_log_after.pb");
  toco_log_after.SerializeToOstream(&ostream_after);
  ostream_after.close();
  toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
}

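// A dumped TocoConversionLog can be inspected offline; a minimal sketch in
// Python, assuming the standard protobuf bindings and a generated
// `toco_conversion_log_pb2` module (module name is an assumption):
//
//   log = toco_conversion_log_pb2.TocoConversionLog()
//   with open(summary_dir + "/toco_log_before.pb", "rb") as f:
//     log.ParseFromString(f.read())
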
// NOTE(aselle): We are using raw PyObjects here because we want to make
// sure we input and output bytes rather than unicode strings for Python 3.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                      PyObject* toco_flags_proto_txt_raw,
                      PyObject* input_contents_txt_raw, bool extended_return,
                      PyObject* debug_info_txt_raw,
                      bool enable_mlir_converter) {
  // Use the Python C API to validate and convert arguments: bytes in
  // Python 3, str in Python 2.
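  // ConvertFromPyString exposes a pointer into the object's internal buffer;
  // copying it into a std::string below keeps the data valid independent of
  // the PyObject's lifetime.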
  auto ConvertArg = [&](PyObject* obj, bool* error) {
    char* buf;
    Py_ssize_t len;
    if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
      *error = true;
      return std::string();
    } else {
      *error = false;
      return std::string(buf, len);
    }
  };

  bool error;
  std::string model_flags_proto_txt =
      ConvertArg(model_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
    return nullptr;
  }
  std::string toco_flags_proto_txt =
      ConvertArg(toco_flags_proto_txt_raw, &error);
  if (error) {
    PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
    return nullptr;
  }

  // Use TOCO to produce new outputs.
  toco::ModelFlags model_flags;
  if (!model_flags.ParseFromString(model_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to parse ModelFlags from the Python string.");
    return nullptr;
  }
  toco::TocoFlags toco_flags;
  if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to parse TocoFlags from the Python string.");
    return nullptr;
  }

138
139 tensorflow::GraphDebugInfo debug_info;
140 if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
141 std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
142 if (error) {
143 PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
144 return nullptr;
145 }
146 if (!debug_info.ParseFromString(debug_info_txt)) {
147 PyErr_SetString(PyExc_ValueError,
148 "Failed to convert DebugInfo to Python String.");
149 return nullptr;
150 }
151 }
152
153 tensorflow::GraphDef graph_def;
154 std::string input_contents_txt;
155 if (model_flags.saved_model_dir().empty()) {
156 input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
157 if (error) {
158 PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
159 return nullptr;
160 }
161 if (!graph_def.ParseFromString(input_contents_txt)) {
162 PyErr_SetString(PyExc_ValueError,
163 "Failed to convert GraphDef to Python String.");
164 return nullptr;
165 }
166 }

  auto& dump_options = *GraphVizDumpOptions::singleton();
  if (toco_flags.has_dump_graphviz_dir()) {
    dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
  }
  if (toco_flags.has_dump_graphviz_include_video()) {
    dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
  }

  std::string output_file_contents_txt;
  tensorflow::Status status;
  int64 arithmetic_ops_count;

  // Convert model.
  if (enable_mlir_converter) {
    if (!model_flags.saved_model_dir().empty()) {
      status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
          model_flags, toco_flags, &output_file_contents_txt);
    } else {
      // `graph_def` was already parsed from `input_contents_txt` above; the
      // saved model directory is empty on this path, so no re-parse is needed.
      status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
          model_flags, toco_flags, debug_info, graph_def,
          &output_file_contents_txt);
      if (!toco_flags.conversion_summary_dir().empty()) {
        PopulateConversionLogHelper(
            model_flags, &toco_flags, input_contents_txt,
            output_file_contents_txt, status.error_message(), &dump_options);
      }
    }
  } else {
    status = Convert(input_contents_txt, toco_flags, model_flags,
                     &output_file_contents_txt, &arithmetic_ops_count);
  }

  if (!status.ok()) {
    PyErr_SetString(PyExc_Exception, status.error_message().c_str());
    return nullptr;
  }
  if (extended_return && !enable_mlir_converter) {
    PyObject* dict = PyDict_New();
    PyDict_SetItemString(
        dict, "flatbuffer",
        ::tflite::python_utils::ConvertToPyString(
            output_file_contents_txt.data(), output_file_contents_txt.size()));
    PyDict_SetItemString(dict, "arithmetic_ops",
                         PyLong_FromLong(arithmetic_ops_count));
    return dict;
  }
  // Convert the result back to bytes (py3) or str (py2).
  return ::tflite::python_utils::ConvertToPyString(
      output_file_contents_txt.data(), output_file_contents_txt.size());
}
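
// Illustrative invocation from the Python layer (a sketch; the wrapper module
// name is an assumption). Despite the `_txt` parameter names, the flag protos
// are passed as serialized bytes, not text format:
//
//   flatbuffer = _pywrap_toco_api.TocoConvert(
//       model_flags.SerializeToString(),
//       toco_flags.SerializeToString(),
//       graph_def.SerializeToString(),
//       False,  # extended_return
//       None,   # debug_info_txt_raw
//       True)   # enable_mlir_converter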

PyObject* TocoGetPotentiallySupportedOps() {
  std::vector<std::string> supported_ops = toco::GetPotentiallySupportedOps();
  PyObject* list = PyList_New(supported_ops.size());
  for (size_t i = 0; i < supported_ops.size(); ++i) {
    const std::string& op = supported_ops[i];
    PyObject* op_dict = PyDict_New();
    PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str()));
    // PyList_SetItem steals the reference to `op_dict`, so no DECREF needed.
    PyList_SetItem(list, i, op_dict);
  }
  return list;
}

PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
                            bool fully_quantize, int inference_type,
                            bool enable_numeric_verify) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  auto tflite_model = absl::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  tflite::TensorType inference_tensor_type;
  switch (inference_type) {
    case toco::IODataType::QUANTIZED_INT16:
      inference_tensor_type = tflite::TensorType_INT16;
      break;
    case toco::IODataType::QUANTIZED_UINT8:
      inference_tensor_type = tflite::TensorType_UINT8;
      break;
    case toco::IODataType::INT8:
      inference_tensor_type = tflite::TensorType_INT8;
      break;
    default:
      // Returning nullptr without setting a Python exception would raise
      // SystemError in the interpreter, so report the bad argument here.
      PyErr_Format(PyExc_ValueError, "Unsupported inference type: %d",
                   inference_type);
      return nullptr;
  }
  tflite::TensorType inference_io_type =
      fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;
  flatbuffers::FlatBufferBuilder builder;
  auto status = mlir::lite::QuantizeModel(
      *tflite_model, inference_io_type, inference_io_type,
      inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
      error_reporter.get(), enable_numeric_verify);

  if (status != kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}
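
// Illustrative call from the Python layer (a sketch; the wrapper module name
// is an assumption):
//
//   quantized_bytes = _pywrap_toco_api.MlirQuantizeModel(
//       flatbuffer_bytes,  # serialized TFLite model
//       False,             # disable_per_channel
//       True,              # fully_quantize
//       inference_type,    # a toco::IODataType enum value
//       False)             # enable_numeric_verify
//
// When `fully_quantize` is false, the model's inputs and outputs stay float32
// and only the interior of the graph is quantized.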

PyObject* MlirSparsifyModel(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  auto tflite_model = absl::make_unique<tflite::ModelT>();
  model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

  flatbuffers::FlatBufferBuilder builder;
  auto status =
      mlir::lite::SparsifyModel(*tflite_model, &builder, error_reporter.get());

  if (status != kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }
  return tflite::python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}
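
// Like MlirQuantizeModel, this takes the serialized flatbuffer as bytes and
// returns the transformed flatbuffer as bytes; errors surface as Python
// exceptions via PythonErrorReporter.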

PyObject* RegisterCustomOpdefs(PyObject* list) {
  if (!PyList_Check(list)) {
    PyErr_SetString(PyExc_TypeError, "Expected list in argument");
    return nullptr;
  }

  int64 size = PyList_Size(list);
  for (int i = 0; i < size; ++i) {
    // Get character array from Python object.
    char* tf_opdefs;
    Py_ssize_t len;
    if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
                                                  &tf_opdefs, &len) == -1) {
      PyErr_Format(PyExc_ValueError,
                   "Failed to convert Python string at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }

    // Parse op def from character array.
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
      PyErr_Format(
          PyExc_ValueError,
          "Failed to parse opdefs at index %d of custom op defs argument: %s",
          i, tf_opdefs);
      return nullptr;
    }

    // Register extra opdefs to TensorFlow global op registry.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](
            tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return tensorflow::Status::OK();
        });

    // Register the corresponding fake op kernel. The compute function is a
    // no-op; the kernel only needs to exist so the op counts as registered.
    const char* node_name = opdef.name().c_str();
    const char* op_name = opdef.name().c_str();
    const char* device_name = "CPU";
    static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    };

    TF_KernelBuilder* builder =
        TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
                            fake_compute_func, /*delete_func=*/nullptr);

    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    if (TF_GetCode(status) != TF_OK) {
      TF_DeleteStatus(status);
      PyErr_Format(PyExc_ValueError,
                   "Failed to register fake op kernel at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }
    TF_DeleteStatus(status);
  }

  Py_RETURN_TRUE;
}
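
// Illustrative call from Python (a sketch; the wrapper module name and op def
// are assumptions). Each list element must be a text-format tensorflow.OpDef:
//
//   _pywrap_toco_api.RegisterCustomOpdefs([
//       "name: 'MyCustomOp' "
//       "input_arg { name: 'x' type: DT_FLOAT } "
//       "output_arg { name: 'y' type: DT_FLOAT }",
//   ])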

}  // namespace toco