/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/lite/toco/python/toco_python_api.h"
16
17 #include <fstream>
18 #include <map>
19 #include <string>
20 #include <vector>
21
22 #include "google/protobuf/text_format.h"
23 #include "tensorflow/c/kernels.h"
24 #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h"
25 #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
26 #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
27 #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
28 #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/c/common.h"
33 #include "tensorflow/lite/core/api/error_reporter.h"
34 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
35 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
36 #include "tensorflow/lite/schema/schema_generated.h"
37 #include "tensorflow/lite/toco/import_tensorflow.h"
38 #include "tensorflow/lite/toco/logging/conversion_log_util.h"
39 #include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
40 #include "tensorflow/lite/toco/model_flags.pb.h"
41 #include "tensorflow/lite/toco/toco_convert.h"
42 #include "tensorflow/lite/toco/toco_flags.pb.h"
43 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
44 #include "tensorflow/lite/toco/toco_port.h"
45 #include "tensorflow/lite/toco/toco_tooling.h"
46 #include "tensorflow/lite/toco/toco_types.h"
47 #include "tensorflow/lite/toco/tooling_util.h"
48 #include "tensorflow/lite/toco/types.pb.h"
49
50 namespace toco {
51 using mlir::lite::StringSet;
52
PopulateConversionLogHelper(const toco::ModelFlags & model_flags,toco::TocoFlags * toco_flags,const std::string & input_contents_txt,const std::string & output_file_contents_txt,const std::string & error_message,GraphVizDumpOptions * dump_options)53 void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
54 toco::TocoFlags* toco_flags,
55 const std::string& input_contents_txt,
56 const std::string& output_file_contents_txt,
57 const std::string& error_message,
58 GraphVizDumpOptions* dump_options) {
59 // Make sure the graphviz file will be dumped under the same folder.
60 dump_options->dump_graphviz = toco_flags->conversion_summary_dir();
61 // Here we construct the `toco::Model` class based on the input graph def,
62 // it will then be used to populate the conversion log.
63 // TODO(haoliang): Don't depend on `toco::Model`.
64 std::unique_ptr<toco::Model> imported_model =
65 toco::Import(*toco_flags, model_flags, input_contents_txt);
66 // Dump pre-conversion toco logs.
67 TocoConversionLog toco_log_before;
68 PopulateConversionLog(*imported_model, &toco_log_before);
69 std::ofstream osstream_before(toco_flags->conversion_summary_dir() +
70 "/toco_log_before.pb");
71 toco_log_before.SerializeToOstream(&osstream_before);
72 osstream_before.close();
73 toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);
74
75 // Populate the post-conversion log, for convenient initiate the
76 // `toco::Model` class from the generated flatbuffer.
77 toco_flags->set_input_format(toco::FileFormat::TFLITE);
78 std::unique_ptr<toco::Model> flatbuffer_model =
79 toco::Import(*toco_flags, model_flags, output_file_contents_txt);
80 // Dump post-conversion toco logs.
81 TocoConversionLog toco_log_after;
82 PopulateConversionLog(*flatbuffer_model, &toco_log_after);
83 // Make sure we sanitize the error message.
84 toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
85 std::ofstream ostream_after(toco_flags->conversion_summary_dir() +
86 "/toco_log_after.pb");
87 toco_log_after.SerializeToOstream(&ostream_after);
88 ostream_after.close();
89 toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
90 }
91
92 // NOTE(aselle): We are using raw PyObject's here because we want to make
93 // sure we input and output bytes rather than unicode strings for Python3.
TocoConvert(PyObject * model_flags_proto_txt_raw,PyObject * toco_flags_proto_txt_raw,PyObject * input_contents_txt_raw,bool extended_return,PyObject * debug_info_txt_raw,bool enable_mlir_converter)94 PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
95 PyObject* toco_flags_proto_txt_raw,
96 PyObject* input_contents_txt_raw, bool extended_return,
97 PyObject* debug_info_txt_raw,
98 bool enable_mlir_converter) {
99 // Use Python C API to validate and convert arguments. In py3 (bytes),
100 // in py2 (str).
101 auto ConvertArg = [&](PyObject* obj, bool* error) {
102 char* buf;
103 Py_ssize_t len;
104 if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
105 *error = true;
106 return std::string();
107 } else {
108 *error = false;
109 return std::string(buf, len);
110 }
111 };
112
113 bool error;
114 std::string model_flags_proto_txt =
115 ConvertArg(model_flags_proto_txt_raw, &error);
116 if (error) {
117 PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
118 return nullptr;
119 }
120 std::string toco_flags_proto_txt =
121 ConvertArg(toco_flags_proto_txt_raw, &error);
122 if (error) {
123 PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
124 return nullptr;
125 }
126
127 // Use TOCO to produce new outputs.
128 toco::ModelFlags model_flags;
129 if (!model_flags.ParseFromString(model_flags_proto_txt)) {
130 PyErr_SetString(PyExc_ValueError,
131 "Failed to convert Model to Python String.");
132 return nullptr;
133 }
134 toco::TocoFlags toco_flags;
135 if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
136 PyErr_SetString(PyExc_ValueError,
137 "Failed to convert Toco to Python String.");
138 return nullptr;
139 }
140
141 tensorflow::GraphDebugInfo debug_info;
142 if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
143 std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
144 if (error) {
145 PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
146 return nullptr;
147 }
148 if (!debug_info.ParseFromString(debug_info_txt)) {
149 PyErr_SetString(PyExc_ValueError,
150 "Failed to convert DebugInfo to Python String.");
151 return nullptr;
152 }
153 }
154
155 tensorflow::GraphDef graph_def;
156 std::string input_contents_txt;
157 if (model_flags.saved_model_dir().empty()) {
158 input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
159 if (error) {
160 PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
161 return nullptr;
162 }
163 if (!graph_def.ParseFromString(input_contents_txt)) {
164 PyErr_SetString(PyExc_ValueError,
165 "Failed to convert GraphDef to Python String.");
166 return nullptr;
167 }
168 }
169
170 auto& dump_options = *GraphVizDumpOptions::singleton();
171 if (toco_flags.has_dump_graphviz_dir()) {
172 dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
173 }
174 if (toco_flags.has_dump_graphviz_include_video()) {
175 dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
176 }
177
178 std::string output_file_contents_txt;
179 tensorflow::Status status;
180 int64_t arithmetic_ops_count;
181
182 // Convert model.
183 if (enable_mlir_converter) {
184 if (!model_flags.saved_model_dir().empty()) {
185 status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
186 model_flags, toco_flags, &output_file_contents_txt);
187 } else {
188 tensorflow::GraphDef graph_def;
189 if (!graph_def.ParseFromString(input_contents_txt)) {
190 PyErr_SetString(PyExc_ValueError,
191 "Failed to convert GraphDef to Python String.");
192 return nullptr;
193 }
194
195 status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
196 model_flags, toco_flags, debug_info, graph_def,
197 &output_file_contents_txt);
198 if (!toco_flags.conversion_summary_dir().empty()) {
199 PopulateConversionLogHelper(
200 model_flags, &toco_flags, input_contents_txt,
201 output_file_contents_txt, status.error_message(), &dump_options);
202 }
203 }
204 } else {
205 status = Convert(input_contents_txt, toco_flags, model_flags,
206 &output_file_contents_txt, &arithmetic_ops_count);
207 }
208
209 if (!status.ok()) {
210 PyErr_SetString(PyExc_Exception, status.error_message().c_str());
211 return nullptr;
212 }
213 if (extended_return && !enable_mlir_converter) {
214 PyObject* dict = PyDict_New();
215 PyDict_SetItemString(
216 dict, "flatbuffer",
217 ::tflite::python_utils::ConvertToPyString(
218 output_file_contents_txt.data(), output_file_contents_txt.size()));
219 PyDict_SetItemString(dict, "arithmetic_ops",
220 PyLong_FromLong(arithmetic_ops_count));
221 return dict;
222 }
223 // Convert arguments back to byte (py3) or str (py2)
224 return ::tflite::python_utils::ConvertToPyString(
225 output_file_contents_txt.data(), output_file_contents_txt.size());
226 }
227
FromTocoDataTypeToTflitToTensorType(int inference_type)228 tflite::TensorType FromTocoDataTypeToTflitToTensorType(int inference_type) {
229 switch (inference_type) {
230 case toco::IODataType::QUANTIZED_INT16:
231 return tflite::TensorType_INT16;
232 case toco::IODataType::QUANTIZED_UINT8:
233 return tflite::TensorType_UINT8;
234 case toco::IODataType::UINT8:
235 return tflite::TensorType_UINT8;
236 case toco::IODataType::QUANTIZED_INT8:
237 return tflite::TensorType_INT8;
238 case toco::IODataType::INT8:
239 return tflite::TensorType_INT8;
240 default:
241 return tflite::TensorType_FLOAT32;
242 }
243 }
244
ToStringSet(PyObject * py_denylist,StringSet * string_set)245 int ToStringSet(PyObject* py_denylist, StringSet* string_set) {
246 using tflite::python_utils::ConvertFromPyString;
247 // Ensure op_denylist is non null
248 if (!py_denylist) {
249 return 0;
250 }
251 if (PyList_Check(py_denylist)) {
252 for (int i = 0; i < PyList_GET_SIZE(py_denylist); ++i) {
253 PyObject* value = PyList_GetItem(py_denylist, i);
254 char* str_buf;
255 Py_ssize_t length;
256 if (ConvertFromPyString(value, &str_buf, &length) == -1) {
257 return -1;
258 }
259 string_set->emplace(str_buf, length);
260 }
261 }
262 if (PySet_Check(py_denylist)) {
263 auto* tmp = PySet_New(py_denylist);
264 while (PySet_GET_SIZE(tmp)) {
265 PyObject* value = PySet_Pop(tmp);
266 char* str_buf;
267 Py_ssize_t length;
268 if (ConvertFromPyString(value, &str_buf, &length) == -1) {
269 return -1;
270 }
271 string_set->emplace(str_buf, length);
272 }
273 }
274 return 0;
275 }
276
MlirQuantizeModel(PyObject * data,bool disable_per_channel,bool fully_quantize,int inference_type,int input_data_type,int output_data_type,bool enable_numeric_verify,bool enable_whole_model_verify,PyObject * op_denylist,PyObject * node_denylist)277 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
278 bool fully_quantize, int inference_type,
279 int input_data_type, int output_data_type,
280 bool enable_numeric_verify,
281 bool enable_whole_model_verify,
282 PyObject* op_denylist, PyObject* node_denylist) {
283 using tflite::interpreter_wrapper::PythonErrorReporter;
284 char* buf = nullptr;
285 Py_ssize_t length;
286 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
287
288 if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
289 PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
290 return nullptr;
291 }
292
293 StringSet denylisted_ops;
294 StringSet denylisted_nodes;
295 if (ToStringSet(op_denylist, &denylisted_ops) == -1) {
296 PyErr_Format(PyExc_ValueError, "Failed to convert op denylist PyObject");
297 return nullptr;
298 }
299 if (ToStringSet(node_denylist, &denylisted_nodes) == -1) {
300 PyErr_Format(PyExc_ValueError, "Failed to convert node denylist PyObject");
301 return nullptr;
302 }
303
304 std::unique_ptr<tflite::FlatBufferModel> model =
305 tflite::FlatBufferModel::BuildFromBuffer(buf, length,
306 error_reporter.get());
307 if (!model) {
308 PyErr_Format(PyExc_ValueError, "Invalid model");
309 return nullptr;
310 }
311 auto tflite_model = absl::make_unique<tflite::ModelT>();
312 model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
313
314 tflite::TensorType inference_tensor_type =
315 FromTocoDataTypeToTflitToTensorType(inference_type);
316 tflite::TensorType input_type =
317 FromTocoDataTypeToTflitToTensorType(input_data_type);
318 tflite::TensorType output_type =
319 FromTocoDataTypeToTflitToTensorType(output_data_type);
320
321 flatbuffers::FlatBufferBuilder builder;
322 auto status = mlir::lite::QuantizeModel(
323 *tflite_model, input_type, output_type, inference_tensor_type, {},
324 disable_per_channel, fully_quantize, &builder, error_reporter.get(),
325 enable_numeric_verify, enable_whole_model_verify,
326 /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes);
327
328 if (status != kTfLiteOk) {
329 error_reporter->exception();
330 return nullptr;
331 }
332 return tflite::python_utils::ConvertToPyString(
333 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
334 builder.GetSize());
335 }
336
MlirSparsifyModel(PyObject * data)337 PyObject* MlirSparsifyModel(PyObject* data) {
338 using tflite::interpreter_wrapper::PythonErrorReporter;
339 char* buf = nullptr;
340 Py_ssize_t length;
341 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
342
343 if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
344 PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
345 return nullptr;
346 }
347 std::unique_ptr<tflite::FlatBufferModel> model =
348 tflite::FlatBufferModel::BuildFromBuffer(buf, length,
349 error_reporter.get());
350 if (!model) {
351 PyErr_Format(PyExc_ValueError, "Invalid model");
352 return nullptr;
353 }
354 auto tflite_model = absl::make_unique<tflite::ModelT>();
355 model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
356
357 flatbuffers::FlatBufferBuilder builder;
358 auto status =
359 mlir::lite::SparsifyModel(*tflite_model, &builder, error_reporter.get());
360
361 if (status != kTfLiteOk) {
362 error_reporter->exception();
363 return nullptr;
364 }
365 return tflite::python_utils::ConvertToPyString(
366 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
367 builder.GetSize());
368 }
369
RegisterCustomOpdefs(PyObject * list)370 PyObject* RegisterCustomOpdefs(PyObject* list) {
371 if (!PyList_Check(list)) {
372 PyErr_SetString(PyExc_TypeError, "Expected list in argument");
373 return nullptr;
374 }
375
376 int64_t size = PyList_Size(list);
377 for (int i = 0; i < size; ++i) {
378 // Get character array from Python object.
379 char* tf_opdefs;
380 Py_ssize_t len;
381 if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
382 &tf_opdefs, &len) == -1) {
383 PyErr_Format(PyExc_ValueError,
384 "Failed to convert Python string at index %d of custom op "
385 "defs argument",
386 i);
387 return nullptr;
388 }
389
390 // Parse op def from character array.
391 tensorflow::OpDef opdef;
392 if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
393 PyErr_Format(
394 PyExc_ValueError,
395 "Failed to parse opdefs at index %d of custom op defs argument: %s",
396 i, tf_opdefs);
397 return nullptr;
398 }
399
400 // Register extra opdefs to TensorFlow global op registry.
401 tensorflow::OpRegistry::Global()->Register(
402 [opdef](
403 tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
404 *op_reg_data = tensorflow::OpRegistrationData(opdef);
405 return tensorflow::Status::OK();
406 });
407
408 // Register the corresponding fake op kernel.
409 const char* node_name = opdef.name().c_str();
410 const char* op_name = opdef.name().c_str();
411 const char* device_name = "CPU";
412 static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
413 };
414
415 TF_KernelBuilder* builder =
416 TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
417 fake_compute_func, /*delete_func=*/nullptr);
418
419 TF_Status* status = TF_NewStatus();
420 TF_RegisterKernelBuilder(node_name, builder, status);
421 if (TF_GetCode(status) != TF_OK) {
422 TF_DeleteStatus(status);
423 PyErr_Format(PyExc_ValueError,
424 "Failed to register fake op kernel at index %d of custom op "
425 "defs argument",
426 i);
427 return nullptr;
428 }
429 TF_DeleteStatus(status);
430 }
431
432 Py_RETURN_TRUE;
433 }
434
RetrieveCollectedErrors()435 const std::vector<std::string> RetrieveCollectedErrors() {
436 mlir::TFL::ErrorCollector* collector =
437 mlir::TFL::ErrorCollector::GetErrorCollector();
438 std::vector<std::string> collected_errors;
439 for (const auto& error_data : collector->CollectedErrors()) {
440 collected_errors.push_back(error_data.SerializeAsString());
441 }
442 collector->Clear();
443 return collected_errors;
444 }
445
446 } // namespace toco
447