• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/python/toco_python_api.h"
16 
17 #include <fstream>
18 #include <map>
19 #include <string>
20 #include <vector>
21 
22 #include "google/protobuf/text_format.h"
23 #include "tensorflow/c/kernels.h"
24 #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h"
25 #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
26 #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
27 #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
28 #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/framework/op_def.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/c/common.h"
33 #include "tensorflow/lite/core/api/error_reporter.h"
34 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
35 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
36 #include "tensorflow/lite/schema/schema_generated.h"
37 #include "tensorflow/lite/toco/import_tensorflow.h"
38 #include "tensorflow/lite/toco/logging/conversion_log_util.h"
39 #include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
40 #include "tensorflow/lite/toco/model_flags.pb.h"
41 #include "tensorflow/lite/toco/toco_convert.h"
42 #include "tensorflow/lite/toco/toco_flags.pb.h"
43 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
44 #include "tensorflow/lite/toco/toco_port.h"
45 #include "tensorflow/lite/toco/toco_tooling.h"
46 #include "tensorflow/lite/toco/toco_types.h"
47 #include "tensorflow/lite/toco/tooling_util.h"
48 #include "tensorflow/lite/toco/types.pb.h"
49 
50 namespace toco {
51 using mlir::lite::StringSet;
52 
PopulateConversionLogHelper(const toco::ModelFlags & model_flags,toco::TocoFlags * toco_flags,const std::string & input_contents_txt,const std::string & output_file_contents_txt,const std::string & error_message,GraphVizDumpOptions * dump_options)53 void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
54                                  toco::TocoFlags* toco_flags,
55                                  const std::string& input_contents_txt,
56                                  const std::string& output_file_contents_txt,
57                                  const std::string& error_message,
58                                  GraphVizDumpOptions* dump_options) {
59   // Make sure the graphviz file will be dumped under the same folder.
60   dump_options->dump_graphviz = toco_flags->conversion_summary_dir();
61   // Here we construct the `toco::Model` class based on the input graph def,
62   // it will then be used to populate the conversion log.
63   // TODO(haoliang): Don't depend on `toco::Model`.
64   std::unique_ptr<toco::Model> imported_model =
65       toco::Import(*toco_flags, model_flags, input_contents_txt);
66   // Dump pre-conversion toco logs.
67   TocoConversionLog toco_log_before;
68   PopulateConversionLog(*imported_model, &toco_log_before);
69   std::ofstream osstream_before(toco_flags->conversion_summary_dir() +
70                                 "/toco_log_before.pb");
71   toco_log_before.SerializeToOstream(&osstream_before);
72   osstream_before.close();
73   toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);
74 
75   // Populate the post-conversion log, for convenient initiate the
76   // `toco::Model` class from the generated flatbuffer.
77   toco_flags->set_input_format(toco::FileFormat::TFLITE);
78   std::unique_ptr<toco::Model> flatbuffer_model =
79       toco::Import(*toco_flags, model_flags, output_file_contents_txt);
80   // Dump post-conversion toco logs.
81   TocoConversionLog toco_log_after;
82   PopulateConversionLog(*flatbuffer_model, &toco_log_after);
83   // Make sure we sanitize the error message.
84   toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
85   std::ofstream ostream_after(toco_flags->conversion_summary_dir() +
86                               "/toco_log_after.pb");
87   toco_log_after.SerializeToOstream(&ostream_after);
88   ostream_after.close();
89   toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
90 }
91 
92 // NOTE(aselle): We are using raw PyObject's here because we want to make
93 // sure we input and output bytes rather than unicode strings for Python3.
TocoConvert(PyObject * model_flags_proto_txt_raw,PyObject * toco_flags_proto_txt_raw,PyObject * input_contents_txt_raw,bool extended_return,PyObject * debug_info_txt_raw,bool enable_mlir_converter)94 PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
95                       PyObject* toco_flags_proto_txt_raw,
96                       PyObject* input_contents_txt_raw, bool extended_return,
97                       PyObject* debug_info_txt_raw,
98                       bool enable_mlir_converter) {
99   // Use Python C API to validate and convert arguments. In py3 (bytes),
100   // in py2 (str).
101   auto ConvertArg = [&](PyObject* obj, bool* error) {
102     char* buf;
103     Py_ssize_t len;
104     if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
105       *error = true;
106       return std::string();
107     } else {
108       *error = false;
109       return std::string(buf, len);
110     }
111   };
112 
113   bool error;
114   std::string model_flags_proto_txt =
115       ConvertArg(model_flags_proto_txt_raw, &error);
116   if (error) {
117     PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
118     return nullptr;
119   }
120   std::string toco_flags_proto_txt =
121       ConvertArg(toco_flags_proto_txt_raw, &error);
122   if (error) {
123     PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
124     return nullptr;
125   }
126 
127   // Use TOCO to produce new outputs.
128   toco::ModelFlags model_flags;
129   if (!model_flags.ParseFromString(model_flags_proto_txt)) {
130     PyErr_SetString(PyExc_ValueError,
131                     "Failed to convert Model to Python String.");
132     return nullptr;
133   }
134   toco::TocoFlags toco_flags;
135   if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
136     PyErr_SetString(PyExc_ValueError,
137                     "Failed to convert Toco to Python String.");
138     return nullptr;
139   }
140 
141   tensorflow::GraphDebugInfo debug_info;
142   if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
143     std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
144     if (error) {
145       PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
146       return nullptr;
147     }
148     if (!debug_info.ParseFromString(debug_info_txt)) {
149       PyErr_SetString(PyExc_ValueError,
150                       "Failed to convert DebugInfo to Python String.");
151       return nullptr;
152     }
153   }
154 
155   tensorflow::GraphDef graph_def;
156   std::string input_contents_txt;
157   if (model_flags.saved_model_dir().empty()) {
158     input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
159     if (error) {
160       PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
161       return nullptr;
162     }
163     if (!graph_def.ParseFromString(input_contents_txt)) {
164       PyErr_SetString(PyExc_ValueError,
165                       "Failed to convert GraphDef to Python String.");
166       return nullptr;
167     }
168   }
169 
170   auto& dump_options = *GraphVizDumpOptions::singleton();
171   if (toco_flags.has_dump_graphviz_dir()) {
172     dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
173   }
174   if (toco_flags.has_dump_graphviz_include_video()) {
175     dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
176   }
177 
178   std::string output_file_contents_txt;
179   tensorflow::Status status;
180   int64_t arithmetic_ops_count;
181 
182   // Convert model.
183   if (enable_mlir_converter) {
184     if (!model_flags.saved_model_dir().empty()) {
185       status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
186           model_flags, toco_flags, &output_file_contents_txt);
187     } else {
188       tensorflow::GraphDef graph_def;
189       if (!graph_def.ParseFromString(input_contents_txt)) {
190         PyErr_SetString(PyExc_ValueError,
191                         "Failed to convert GraphDef to Python String.");
192         return nullptr;
193       }
194 
195       status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
196           model_flags, toco_flags, debug_info, graph_def,
197           &output_file_contents_txt);
198       if (!toco_flags.conversion_summary_dir().empty()) {
199         PopulateConversionLogHelper(
200             model_flags, &toco_flags, input_contents_txt,
201             output_file_contents_txt, status.error_message(), &dump_options);
202       }
203     }
204   } else {
205     status = Convert(input_contents_txt, toco_flags, model_flags,
206                      &output_file_contents_txt, &arithmetic_ops_count);
207   }
208 
209   if (!status.ok()) {
210     PyErr_SetString(PyExc_Exception, status.error_message().c_str());
211     return nullptr;
212   }
213   if (extended_return && !enable_mlir_converter) {
214     PyObject* dict = PyDict_New();
215     PyDict_SetItemString(
216         dict, "flatbuffer",
217         ::tflite::python_utils::ConvertToPyString(
218             output_file_contents_txt.data(), output_file_contents_txt.size()));
219     PyDict_SetItemString(dict, "arithmetic_ops",
220                          PyLong_FromLong(arithmetic_ops_count));
221     return dict;
222   }
223   // Convert arguments back to byte (py3) or str (py2)
224   return ::tflite::python_utils::ConvertToPyString(
225       output_file_contents_txt.data(), output_file_contents_txt.size());
226 }
227 
FromTocoDataTypeToTflitToTensorType(int inference_type)228 tflite::TensorType FromTocoDataTypeToTflitToTensorType(int inference_type) {
229   switch (inference_type) {
230     case toco::IODataType::QUANTIZED_INT16:
231       return tflite::TensorType_INT16;
232     case toco::IODataType::QUANTIZED_UINT8:
233       return tflite::TensorType_UINT8;
234     case toco::IODataType::UINT8:
235       return tflite::TensorType_UINT8;
236     case toco::IODataType::QUANTIZED_INT8:
237       return tflite::TensorType_INT8;
238     case toco::IODataType::INT8:
239       return tflite::TensorType_INT8;
240     default:
241       return tflite::TensorType_FLOAT32;
242   }
243 }
244 
ToStringSet(PyObject * py_denylist,StringSet * string_set)245 int ToStringSet(PyObject* py_denylist, StringSet* string_set) {
246   using tflite::python_utils::ConvertFromPyString;
247   // Ensure op_denylist is non null
248   if (!py_denylist) {
249     return 0;
250   }
251   if (PyList_Check(py_denylist)) {
252     for (int i = 0; i < PyList_GET_SIZE(py_denylist); ++i) {
253       PyObject* value = PyList_GetItem(py_denylist, i);
254       char* str_buf;
255       Py_ssize_t length;
256       if (ConvertFromPyString(value, &str_buf, &length) == -1) {
257         return -1;
258       }
259       string_set->emplace(str_buf, length);
260     }
261   }
262   if (PySet_Check(py_denylist)) {
263     auto* tmp = PySet_New(py_denylist);
264     while (PySet_GET_SIZE(tmp)) {
265       PyObject* value = PySet_Pop(tmp);
266       char* str_buf;
267       Py_ssize_t length;
268       if (ConvertFromPyString(value, &str_buf, &length) == -1) {
269         return -1;
270       }
271       string_set->emplace(str_buf, length);
272     }
273   }
274   return 0;
275 }
276 
MlirQuantizeModel(PyObject * data,bool disable_per_channel,bool fully_quantize,int inference_type,int input_data_type,int output_data_type,bool enable_numeric_verify,bool enable_whole_model_verify,PyObject * op_denylist,PyObject * node_denylist)277 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
278                             bool fully_quantize, int inference_type,
279                             int input_data_type, int output_data_type,
280                             bool enable_numeric_verify,
281                             bool enable_whole_model_verify,
282                             PyObject* op_denylist, PyObject* node_denylist) {
283   using tflite::interpreter_wrapper::PythonErrorReporter;
284   char* buf = nullptr;
285   Py_ssize_t length;
286   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
287 
288   if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
289     PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
290     return nullptr;
291   }
292 
293   StringSet denylisted_ops;
294   StringSet denylisted_nodes;
295   if (ToStringSet(op_denylist, &denylisted_ops) == -1) {
296     PyErr_Format(PyExc_ValueError, "Failed to convert op denylist PyObject");
297     return nullptr;
298   }
299   if (ToStringSet(node_denylist, &denylisted_nodes) == -1) {
300     PyErr_Format(PyExc_ValueError, "Failed to convert node denylist PyObject");
301     return nullptr;
302   }
303 
304   std::unique_ptr<tflite::FlatBufferModel> model =
305       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
306                                                error_reporter.get());
307   if (!model) {
308     PyErr_Format(PyExc_ValueError, "Invalid model");
309     return nullptr;
310   }
311   auto tflite_model = absl::make_unique<tflite::ModelT>();
312   model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
313 
314   tflite::TensorType inference_tensor_type =
315       FromTocoDataTypeToTflitToTensorType(inference_type);
316   tflite::TensorType input_type =
317       FromTocoDataTypeToTflitToTensorType(input_data_type);
318   tflite::TensorType output_type =
319       FromTocoDataTypeToTflitToTensorType(output_data_type);
320 
321   flatbuffers::FlatBufferBuilder builder;
322   auto status = mlir::lite::QuantizeModel(
323       *tflite_model, input_type, output_type, inference_tensor_type, {},
324       disable_per_channel, fully_quantize, &builder, error_reporter.get(),
325       enable_numeric_verify, enable_whole_model_verify,
326       /*legacy_float_scale=*/true, denylisted_ops, denylisted_nodes);
327 
328   if (status != kTfLiteOk) {
329     error_reporter->exception();
330     return nullptr;
331   }
332   return tflite::python_utils::ConvertToPyString(
333       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
334       builder.GetSize());
335 }
336 
MlirSparsifyModel(PyObject * data)337 PyObject* MlirSparsifyModel(PyObject* data) {
338   using tflite::interpreter_wrapper::PythonErrorReporter;
339   char* buf = nullptr;
340   Py_ssize_t length;
341   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
342 
343   if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
344     PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
345     return nullptr;
346   }
347   std::unique_ptr<tflite::FlatBufferModel> model =
348       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
349                                                error_reporter.get());
350   if (!model) {
351     PyErr_Format(PyExc_ValueError, "Invalid model");
352     return nullptr;
353   }
354   auto tflite_model = absl::make_unique<tflite::ModelT>();
355   model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
356 
357   flatbuffers::FlatBufferBuilder builder;
358   auto status =
359       mlir::lite::SparsifyModel(*tflite_model, &builder, error_reporter.get());
360 
361   if (status != kTfLiteOk) {
362     error_reporter->exception();
363     return nullptr;
364   }
365   return tflite::python_utils::ConvertToPyString(
366       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
367       builder.GetSize());
368 }
369 
// Registers each entry of `list` — a Python list of tensorflow::OpDef protos
// in text format — with the global TensorFlow op registry, and installs a
// do-nothing CPU kernel for each op so graphs containing these custom ops can
// be imported.  Returns Py_True on success, or nullptr with a Python
// exception set on failure.
PyObject* RegisterCustomOpdefs(PyObject* list) {
  if (!PyList_Check(list)) {
    PyErr_SetString(PyExc_TypeError, "Expected list in argument");
    return nullptr;
  }

  int64_t size = PyList_Size(list);
  for (int i = 0; i < size; ++i) {
    // Get character array from Python object.
    char* tf_opdefs;
    Py_ssize_t len;
    if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
                                                  &tf_opdefs, &len) == -1) {
      PyErr_Format(PyExc_ValueError,
                   "Failed to convert Python string at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }

    // Parse op def from character array.
    tensorflow::OpDef opdef;
    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
      PyErr_Format(
          PyExc_ValueError,
          "Failed to parse opdefs at index %d of custom op defs argument: %s",
          i, tf_opdefs);
      return nullptr;
    }

    // Register extra opdefs to TensorFlow global op registry.
    // The lambda captures `opdef` by value because registration may run after
    // this loop iteration's locals are gone.
    tensorflow::OpRegistry::Global()->Register(
        [opdef](
            tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
          *op_reg_data = tensorflow::OpRegistrationData(opdef);
          return tensorflow::Status::OK();
        });

    // Register the corresponding fake op kernel.
    const char* node_name = opdef.name().c_str();
    const char* op_name = opdef.name().c_str();
    const char* device_name = "CPU";
    // The kernel never actually executes during conversion; an empty compute
    // function is enough to satisfy kernel lookup.
    static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    };

    // NOTE(review): `builder` is presumably consumed by
    // TF_RegisterKernelBuilder below — confirm against the TF C kernels API.
    TF_KernelBuilder* builder =
        TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
                            fake_compute_func, /*delete_func=*/nullptr);

    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    if (TF_GetCode(status) != TF_OK) {
      TF_DeleteStatus(status);
      PyErr_Format(PyExc_ValueError,
                   "Failed to register fake op kernel at index %d of custom op "
                   "defs argument",
                   i);
      return nullptr;
    }
    TF_DeleteStatus(status);
  }

  Py_RETURN_TRUE;
}
434 
RetrieveCollectedErrors()435 const std::vector<std::string> RetrieveCollectedErrors() {
436   mlir::TFL::ErrorCollector* collector =
437       mlir::TFL::ErrorCollector::GetErrorCollector();
438   std::vector<std::string> collected_errors;
439   for (const auto& error_data : collector->CollectedErrors()) {
440     collected_errors.push_back(error_data.SerializeAsString());
441   }
442   collector->Clear();
443   return collected_errors;
444 }
445 
446 }  // namespace toco
447