/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/python/toco_python_api.h"

#include <fstream>
#include <map>
#include <string>
#include <vector>

#include "google/protobuf/text_format.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/logging/conversion_log_util.h"
#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_convert.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/toco/types.pb.h"

namespace toco {
PopulateConversionLogHelper(const toco::ModelFlags & model_flags,toco::TocoFlags * toco_flags,const std::string & input_contents_txt,const std::string & output_file_contents_txt,const std::string & error_message,GraphVizDumpOptions * dump_options)51 void PopulateConversionLogHelper(const toco::ModelFlags& model_flags,
52                                  toco::TocoFlags* toco_flags,
53                                  const std::string& input_contents_txt,
54                                  const std::string& output_file_contents_txt,
55                                  const std::string& error_message,
56                                  GraphVizDumpOptions* dump_options) {
57   // Make sure the graphviz file will be dumped under the same folder.
58   dump_options->dump_graphviz = toco_flags->conversion_summary_dir();
59   // Here we construct the `toco::Model` class based on the input graph def,
60   // it will then be used to populate the conversion log.
61   // TODO(haoliang): Don't depend on `toco::Model`.
62   std::unique_ptr<toco::Model> imported_model =
63       toco::Import(*toco_flags, model_flags, input_contents_txt);
64   // Dump pre-conversion toco logs.
65   TocoConversionLog toco_log_before;
66   PopulateConversionLog(*imported_model, &toco_log_before);
67   std::ofstream osstream_before(toco_flags->conversion_summary_dir() +
68                                 "/toco_log_before.pb");
69   toco_log_before.SerializeToOstream(&osstream_before);
70   osstream_before.close();
71   toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model);
72 
73   // Populate the post-conversion log, for convenient initiate the
74   // `toco::Model` class from the generated flatbuffer.
75   toco_flags->set_input_format(toco::FileFormat::TFLITE);
76   std::unique_ptr<toco::Model> flatbuffer_model =
77       toco::Import(*toco_flags, model_flags, output_file_contents_txt);
78   // Dump post-conversion toco logs.
79   TocoConversionLog toco_log_after;
80   PopulateConversionLog(*flatbuffer_model, &toco_log_after);
81   // Make sure we sanitize the error message.
82   toco_log_after.set_toco_err_logs(SanitizeErrorMessage(error_message));
83   std::ofstream ostream_after(toco_flags->conversion_summary_dir() +
84                               "/toco_log_after.pb");
85   toco_log_after.SerializeToOstream(&ostream_after);
86   ostream_after.close();
87   toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model);
88 }
89 
90 // NOTE(aselle): We are using raw PyObject's here because we want to make
91 // sure we input and output bytes rather than unicode strings for Python3.
TocoConvert(PyObject * model_flags_proto_txt_raw,PyObject * toco_flags_proto_txt_raw,PyObject * input_contents_txt_raw,bool extended_return,PyObject * debug_info_txt_raw,bool enable_mlir_converter)92 PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
93                       PyObject* toco_flags_proto_txt_raw,
94                       PyObject* input_contents_txt_raw, bool extended_return,
95                       PyObject* debug_info_txt_raw,
96                       bool enable_mlir_converter) {
97   // Use Python C API to validate and convert arguments. In py3 (bytes),
98   // in py2 (str).
99   auto ConvertArg = [&](PyObject* obj, bool* error) {
100     char* buf;
101     Py_ssize_t len;
102     if (::tflite::python_utils::ConvertFromPyString(obj, &buf, &len) == -1) {
103       *error = true;
104       return std::string();
105     } else {
106       *error = false;
107       return std::string(buf, len);
108     }
109   };
110 
111   bool error;
112   std::string model_flags_proto_txt =
113       ConvertArg(model_flags_proto_txt_raw, &error);
114   if (error) {
115     PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
116     return nullptr;
117   }
118   std::string toco_flags_proto_txt =
119       ConvertArg(toco_flags_proto_txt_raw, &error);
120   if (error) {
121     PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
122     return nullptr;
123   }
124 
125   // Use TOCO to produce new outputs.
126   toco::ModelFlags model_flags;
127   if (!model_flags.ParseFromString(model_flags_proto_txt)) {
128     PyErr_SetString(PyExc_ValueError,
129                     "Failed to convert Model to Python String.");
130     return nullptr;
131   }
132   toco::TocoFlags toco_flags;
133   if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
134     PyErr_SetString(PyExc_ValueError,
135                     "Failed to convert Toco to Python String.");
136     return nullptr;
137   }
138 
139   tensorflow::GraphDebugInfo debug_info;
140   if (debug_info_txt_raw && debug_info_txt_raw != Py_None) {
141     std::string debug_info_txt = ConvertArg(debug_info_txt_raw, &error);
142     if (error) {
143       PyErr_SetString(PyExc_ValueError, "Input DebugInfo is invalid.");
144       return nullptr;
145     }
146     if (!debug_info.ParseFromString(debug_info_txt)) {
147       PyErr_SetString(PyExc_ValueError,
148                       "Failed to convert DebugInfo to Python String.");
149       return nullptr;
150     }
151   }
152 
153   tensorflow::GraphDef graph_def;
154   std::string input_contents_txt;
155   if (model_flags.saved_model_dir().empty()) {
156     input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
157     if (error) {
158       PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
159       return nullptr;
160     }
161     if (!graph_def.ParseFromString(input_contents_txt)) {
162       PyErr_SetString(PyExc_ValueError,
163                       "Failed to convert GraphDef to Python String.");
164       return nullptr;
165     }
166   }
167 
168   auto& dump_options = *GraphVizDumpOptions::singleton();
169   if (toco_flags.has_dump_graphviz_dir()) {
170     dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
171   }
172   if (toco_flags.has_dump_graphviz_include_video()) {
173     dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
174   }
175 
176   std::string output_file_contents_txt;
177   tensorflow::Status status;
178   int64 arithmetic_ops_count;
179 
180   // Convert model.
181   if (enable_mlir_converter) {
182     if (!model_flags.saved_model_dir().empty()) {
183       status = tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
184           model_flags, toco_flags, &output_file_contents_txt);
185     } else {
186       tensorflow::GraphDef graph_def;
187       if (!graph_def.ParseFromString(input_contents_txt)) {
188         PyErr_SetString(PyExc_ValueError,
189                         "Failed to convert GraphDef to Python String.");
190         return nullptr;
191       }
192 
193       status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
194           model_flags, toco_flags, debug_info, graph_def,
195           &output_file_contents_txt);
196       if (!toco_flags.conversion_summary_dir().empty()) {
197         PopulateConversionLogHelper(
198             model_flags, &toco_flags, input_contents_txt,
199             output_file_contents_txt, status.error_message(), &dump_options);
200       }
201     }
202   } else {
203     status = Convert(input_contents_txt, toco_flags, model_flags,
204                      &output_file_contents_txt, &arithmetic_ops_count);
205   }
206 
207   if (!status.ok()) {
208     PyErr_SetString(PyExc_Exception, status.error_message().c_str());
209     return nullptr;
210   }
211   if (extended_return && !enable_mlir_converter) {
212     PyObject* dict = PyDict_New();
213     PyDict_SetItemString(
214         dict, "flatbuffer",
215         ::tflite::python_utils::ConvertToPyString(
216             output_file_contents_txt.data(), output_file_contents_txt.size()));
217     PyDict_SetItemString(dict, "arithmetic_ops",
218                          PyLong_FromLong(arithmetic_ops_count));
219     return dict;
220   }
221   // Convert arguments back to byte (py3) or str (py2)
222   return ::tflite::python_utils::ConvertToPyString(
223       output_file_contents_txt.data(), output_file_contents_txt.size());
224 }
225 
TocoGetPotentiallySupportedOps()226 PyObject* TocoGetPotentiallySupportedOps() {
227   std::vector<std::string> supported_ops = toco::GetPotentiallySupportedOps();
228   PyObject* list = PyList_New(supported_ops.size());
229   for (size_t i = 0; i < supported_ops.size(); ++i) {
230     const std::string& op = supported_ops[i];
231     PyObject* op_dict = PyDict_New();
232     PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str()));
233     PyList_SetItem(list, i, op_dict);
234   }
235   return list;
236 }
237 
MlirQuantizeModel(PyObject * data,bool disable_per_channel,bool fully_quantize,int inference_type,bool enable_numeric_verify)238 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
239                             bool fully_quantize, int inference_type,
240                             bool enable_numeric_verify) {
241   using tflite::interpreter_wrapper::PythonErrorReporter;
242   char* buf = nullptr;
243   Py_ssize_t length;
244   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
245 
246   if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
247     PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
248     return nullptr;
249   }
250   std::unique_ptr<tflite::FlatBufferModel> model =
251       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
252                                                error_reporter.get());
253   if (!model) {
254     PyErr_Format(PyExc_ValueError, "Invalid model");
255     return nullptr;
256   }
257   auto tflite_model = absl::make_unique<tflite::ModelT>();
258   model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
259 
260   tflite::TensorType inference_tensor_type;
261   switch (inference_type) {
262     case toco::IODataType::QUANTIZED_INT16:
263       inference_tensor_type = tflite::TensorType_INT16;
264       break;
265     case toco::IODataType::QUANTIZED_UINT8:
266       inference_tensor_type = tflite::TensorType_UINT8;
267       break;
268     case toco::IODataType::INT8:
269       inference_tensor_type = tflite::TensorType_INT8;
270       break;
271     default:
272       return nullptr;
273   }
274   tflite::TensorType inference_io_type =
275       fully_quantize ? inference_tensor_type : tflite::TensorType_FLOAT32;
276   flatbuffers::FlatBufferBuilder builder;
277   auto status = mlir::lite::QuantizeModel(
278       *tflite_model, inference_io_type, inference_io_type,
279       inference_tensor_type, {}, disable_per_channel, fully_quantize, &builder,
280       error_reporter.get(), enable_numeric_verify);
281 
282   if (status != kTfLiteOk) {
283     error_reporter->exception();
284     return nullptr;
285   }
286   return tflite::python_utils::ConvertToPyString(
287       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
288       builder.GetSize());
289 }
290 
MlirSparsifyModel(PyObject * data)291 PyObject* MlirSparsifyModel(PyObject* data) {
292   using tflite::interpreter_wrapper::PythonErrorReporter;
293   char* buf = nullptr;
294   Py_ssize_t length;
295   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
296 
297   if (tflite::python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
298     PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject");
299     return nullptr;
300   }
301   std::unique_ptr<tflite::FlatBufferModel> model =
302       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
303                                                error_reporter.get());
304   if (!model) {
305     PyErr_Format(PyExc_ValueError, "Invalid model");
306     return nullptr;
307   }
308   auto tflite_model = absl::make_unique<tflite::ModelT>();
309   model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
310 
311   flatbuffers::FlatBufferBuilder builder;
312   auto status =
313       mlir::lite::SparsifyModel(*tflite_model, &builder, error_reporter.get());
314 
315   if (status != kTfLiteOk) {
316     error_reporter->exception();
317     return nullptr;
318   }
319   return tflite::python_utils::ConvertToPyString(
320       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
321       builder.GetSize());
322 }
323 
RegisterCustomOpdefs(PyObject * list)324 PyObject* RegisterCustomOpdefs(PyObject* list) {
325   if (!PyList_Check(list)) {
326     PyErr_SetString(PyExc_TypeError, "Expected list in argument");
327     return nullptr;
328   }
329 
330   int64 size = PyList_Size(list);
331   for (int i = 0; i < size; ++i) {
332     // Get character array from Python object.
333     char* tf_opdefs;
334     Py_ssize_t len;
335     if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i),
336                                                   &tf_opdefs, &len) == -1) {
337       PyErr_Format(PyExc_ValueError,
338                    "Failed to convert Python string at index %d of custom op "
339                    "defs argument",
340                    i);
341       return nullptr;
342     }
343 
344     // Parse op def from character array.
345     tensorflow::OpDef opdef;
346     if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) {
347       PyErr_Format(
348           PyExc_ValueError,
349           "Failed to parse opdefs at index %d of custom op defs argument: %s",
350           i, tf_opdefs);
351       return nullptr;
352     }
353 
354     // Register extra opdefs to TensorFlow global op registry.
355     tensorflow::OpRegistry::Global()->Register(
356         [opdef](
357             tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status {
358           *op_reg_data = tensorflow::OpRegistrationData(opdef);
359           return tensorflow::Status::OK();
360         });
361 
362     // Register the corresponding fake op kernel.
363     const char* node_name = opdef.name().c_str();
364     const char* op_name = opdef.name().c_str();
365     const char* device_name = "CPU";
366     static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
367     };
368 
369     TF_KernelBuilder* builder =
370         TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr,
371                             fake_compute_func, /*delete_func=*/nullptr);
372 
373     TF_Status* status = TF_NewStatus();
374     TF_RegisterKernelBuilder(node_name, builder, status);
375     if (TF_GetCode(status) != TF_OK) {
376       TF_DeleteStatus(status);
377       PyErr_Format(PyExc_ValueError,
378                    "Failed to register fake op kernel at index %d of custom op "
379                    "defs argument",
380                    i);
381       return nullptr;
382     }
383     TF_DeleteStatus(status);
384   }
385 
386   Py_RETURN_TRUE;
387 }

}  // namespace toco