/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/optimize/calibration_wrapper.h"

#include <memory>
#include <sstream>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"

#define TFLITE_PY_CHECK(x)               \
  if ((x) != kTfLiteOk) {                \
    return error_reporter_->exception(); \
  }

#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
  if (!interpreter_) {                                                     \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr;                                                        \
  }

namespace tflite {
namespace calibration_wrapper {

namespace {

using python_utils::PyDecrefDeleter;

std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
  auto copied_model = absl::make_unique<tflite::ModelT>();
  model.UnPackTo(copied_model.get(), nullptr);
  return copied_model;
}
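
// Illustrative round-trip (a sketch, not called anywhere in this file; the
// `flatbuffer_model` variable is hypothetical): CreateMutableModel unpacks
// the read-only flatbuffer view into the generated object API so callers can
// edit it and re-pack it, which is what Calibrate() and QuantizeModel() below
// rely on.
//
//   std::unique_ptr<tflite::ModelT> mutable_model =
//       CreateMutableModel(*flatbuffer_model->GetModel());
//   mutable_model->description = "calibrated";  // object API is editable
//   flatbuffers::FlatBufferBuilder fbb;
//   tflite::FinishModelBuffer(fbb,
//                             tflite::Model::Pack(fbb, mutable_model.get()));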

bool NoOpModel(const tflite::FlatBufferModel& model) {
  return model->subgraphs()->size() == 1 &&
         (!model->subgraphs()->begin()->operators() ||
          model->subgraphs()->begin()->operators()->size() == 0);
}

inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
  switch (type) {
    case kTfLiteNoType:
      return TensorType_FLOAT32;  // TODO(b/129336260): No schema type for none.
    case kTfLiteFloat32:
      return TensorType_FLOAT32;
    case kTfLiteFloat16:
      return TensorType_FLOAT16;
    case kTfLiteFloat64:
      return TensorType_FLOAT64;
    case kTfLiteInt32:
      return TensorType_INT32;
    case kTfLiteUInt32:
      return TensorType_UINT32;
    case kTfLiteUInt8:
      return TensorType_UINT8;
    case kTfLiteInt8:
      return TensorType_INT8;
    case kTfLiteInt64:
      return TensorType_INT64;
    case kTfLiteUInt64:
      return TensorType_UINT64;
    case kTfLiteString:
      return TensorType_STRING;
    case kTfLiteBool:
      return TensorType_BOOL;
    case kTfLiteInt16:
      return TensorType_INT16;
    case kTfLiteComplex64:
      return TensorType_COMPLEX64;
    case kTfLiteComplex128:
      return TensorType_COMPLEX128;
    case kTfLiteResource:
      return TensorType_RESOURCE;
    case kTfLiteVariant:
      return TensorType_VARIANT;
  }
  // No default to get compiler error when new type is introduced.
}
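
// Spot-check of the mapping above (directly verifiable against the switch):
//   TfLiteTypeToSchemaType(kTfLiteInt8)   == TensorType_INT8
//   TfLiteTypeToSchemaType(kTfLiteNoType) == TensorType_FLOAT32  // fallback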

bool RegisterCustomOpByName(const char* registerer_name,
                            tflite::MutableOpResolver* resolver) {
  // Registerer functions take a pointer to a BuiltinOpResolver as an input
  // parameter and return void.
  // TODO(b/137576229): We should implement this functionality in a more
  // principled way.
  typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);

  // Look for the Registerer function by name.
  RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
      SharedLibrary::GetSymbol(registerer_name));

  // Fail in an informative way if the function was not found.
  if (registerer == nullptr) {
    PyErr_Format(PyExc_ValueError,
                 "Looking up symbol '%s' failed with error '%s'.",
                 registerer_name, SharedLibrary::GetError());
    return false;
  }

  // Call the registerer with the resolver.
  registerer(resolver);
  return true;
}
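
// For reference, a registerer resolved by name is expected to have this
// shape (the symbol name and the AddCustom() arguments below are
// hypothetical):
//
//   extern "C" void TF_RegisterMyCustomOps(tflite::MutableOpResolver* r) {
//     r->AddCustom("MyCustomOp", GetMyCustomOpRegistration());
//   }
//
// RegisterCustomOpByName("TF_RegisterMyCustomOps", &resolver) then finds the
// symbol via SharedLibrary::GetSymbol() and calls it with the resolver.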

}  // namespace

PyObject* AddIntermediateTensors(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  flatbuffers::FlatBufferBuilder builder;
  auto tflite_model = CreateMutableModel(*model->GetModel());
  if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
      kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }

  if (builder.GetSize()) {
    return python_utils::ConvertToPyString(
        reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
        builder.GetSize());
  } else {
    // When AddIntermediateTensorsToFusedOp early returns, return the model as
    // it is.
    return python_utils::ConvertToPyString(buf, length);
  }
}
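
// Caller-side sketch (the buffer variables are hypothetical): this function
// is exposed to Python, but from C++ it can be exercised directly with a
// bytes object.
//
//   PyObject* serialized = PyBytes_FromStringAndSize(model_buf, model_len);
//   PyObject* result = AddIntermediateTensors(serialized);
//   // nullptr means a Python exception is set; otherwise `result` holds a
//   // serialized model with intermediate tensors added to fused ops, or the
//   // original bytes when no rewriting was needed.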

CalibrationWrapper::CalibrationWrapper(
    std::unique_ptr<tflite::Interpreter> interpreter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
        error_reporter,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
    std::unique_ptr<std::string> model_str)
    : interpreter_(std::move(interpreter)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      model_(std::move(model)),
      reader_(std::move(reader)),
      model_str_(std::move(model_str)) {}

CalibrationWrapper::~CalibrationWrapper() {}

PyObject* CalibrationWrapper::Prepare() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_shapes)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected shapes to be a list.");
    return nullptr;
  }

  const size_t inputs_size = PyList_Size(input_shapes);
  if (inputs_size != interpreter_->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected %ld items got %ld items.",
                 interpreter_->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; i++) {
    PyObject* shape = PyList_GetItem(input_shapes, i);
    if (!shape || !PyList_Check(shape)) {
      PyErr_Format(PyExc_ValueError,
                   "Invalid %ld input shape: expected to be a list.", i);
      return nullptr;
    }
    std::vector<int> dims;
    for (size_t dim_index = 0; dim_index < PyList_Size(shape); ++dim_index) {
      PyObject* dim = PyList_GetItem(shape, dim_index);
      dims.push_back(PyLong_AsLong(dim));
    }
    int input_tensor_idx = interpreter_->inputs()[i];
    if (interpreter_->ResizeInputTensor(input_tensor_idx, dims) != kTfLiteOk) {
      PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
      return nullptr;
    }
  }

  return Prepare();
}
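
// Argument sketch for the overload above (shape values hypothetical): for a
// model with two inputs, the expected argument is a list of integer lists,
// one per input, e.g. built with the CPython API as
//
//   PyObject* shapes =
//       Py_BuildValue("[[i,i,i,i],[i,i]]", 1, 224, 224, 3, 1, 10);
//   wrapper->Prepare(shapes);  // resizes both inputs, then allocates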

PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_value)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input type: expected input to be a list.");
    return nullptr;
  }

  const size_t inputs_size = PyList_Size(input_value);

  if (inputs_size != interpreter_->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input size: expected %ld items got %ld items.",
                 interpreter_->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; i++) {
    PyObject* input = PyList_GetItem(input_value, i);
    if (!input) {
      return nullptr;
    }
    int input_tensor_idx = interpreter_->inputs()[i];
    if (!SetTensor(input_tensor_idx, input)) {
      return nullptr;
    }
  }

  TFLITE_PY_CHECK(interpreter_->Invoke());
  Py_RETURN_NONE;
}
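
// Calibration-loop sketch (`wrapper` and `samples` are hypothetical; each
// sample is a Python list holding one ndarray per model input): every
// FeedTensor() call copies the arrays into the input tensors and runs
// Invoke(), letting the logging kernels installed by
// BuildLoggingInterpreter() record per-tensor ranges.
//
//   for (PyObject* sample : samples) {
//     if (wrapper->FeedTensor(sample) == nullptr) {
//       return nullptr;  // a Python exception is already set
//     }
//   }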

PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  const TfLiteTensor* tensor = interpreter_->tensor(index);

  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(
        PyExc_ValueError,
        "Cannot set tensor: Dimension count mismatch, expected %d but found %d",
        tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); j++) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  }

  tensor = interpreter_->tensor(index);

  size_t size = PyArray_NBYTES(array);

  if (tensor->type == kTfLiteString) {
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(interpreter_->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
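
// Note on dynamic shapes: when an input dimension is unknown in the model
// (a dims_signature entry of -1, commonly the batch dimension), the loop
// above accepts whatever size the numpy array provides, and the tensor is
// resized via ResizeInputTensorStrict() before the memcpy. So an input with
// signature [-1, 224, 224, 3] (shape values illustrative) can take batches
// of any size without an explicit Prepare(input_shapes) call.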

PyObject* CalibrationWrapper::Calibrate() {
  auto tflite_model = CreateMutableModel(*model_->GetModel());
  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
  flatbuffers::FlatBufferBuilder builder;
  auto loc = tflite::Model::Pack(builder, tflite_model.get());
  tflite::FinishModelBuffer(builder, loc);
  return python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}
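
// What Calibrate() returns, in short: the same model, serialized with the
// recorded min/max statistics attached to the logged tensors but with all
// tensor types unchanged; the bytes are a valid .tflite flatbuffer that a
// separate quantization pass can consume later.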

PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            int activations_py_type) {
  return QuantizeModel(input_py_type, output_py_type, allow_float,
                       activations_py_type, /*disable_per_channel=*/false);
}

PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            int activations_py_type,
                                            bool disable_per_channel) {
  if (NoOpModel(*model_)) {
    return python_utils::ConvertToPyString(model_str_->data(),
                                           model_str_->size());
  }

  TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
  TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
  TfLiteType activations_type =
      python_utils::TfLiteTypeFromPyType(activations_py_type);

  if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
    PyErr_SetString(PyExc_ValueError,
                    "Input/output type cannot be kTfLiteNoType");
    return nullptr;
  }
  auto tflite_model = CreateMutableModel(*model_->GetModel());
  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
  flatbuffers::FlatBufferBuilder builder;
  auto status = kTfLiteOk;

  status = tflite::optimize::QuantizeModelAllOperators(
      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
      TfLiteTypeToSchemaType(output_type), allow_float,
      TfLiteTypeToSchemaType(activations_type), disable_per_channel,
      error_reporter_.get());

  if (status != kTfLiteOk) {
    error_reporter_->exception();
    return nullptr;
  }

  return python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}
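
// Note: disable_per_channel=true asks the quantizer to use a single
// per-tensor scale for weights that would otherwise get one scale per output
// channel (e.g. convolution filters), which can help when a target kernel or
// delegate lacks per-channel support, usually at some cost in accuracy.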

PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            const char* operator_output_name) {
  string op_name = std::string(operator_output_name);

  TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
  TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
  if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
    PyErr_SetString(PyExc_ValueError,
                    "Input/output type cannot be kTfLiteNoType");
    return nullptr;
  }
  auto tflite_model = CreateMutableModel(*model_->GetModel());
  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
  flatbuffers::FlatBufferBuilder builder;
  auto status = tflite::optimize::QuantizeModel(
      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
      TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
      TensorType_INT8, error_reporter_.get());
  if (status != kTfLiteOk) {
    error_reporter_->exception();
    return nullptr;
  }

  return python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}
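
// This overload is the selective path: only the operator whose output tensor
// is named `operator_output_name` is quantized (note the {op_name} allowlist
// passed to QuantizeModel above), leaving the rest of the graph in float;
// useful for isolating the accuracy impact of quantizing a single op.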

/*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    *error_msg = "Failed to convert from python string";
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    *error_msg = "Invalid model";
    return nullptr;
  }
  auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get())) {
      *error_msg =
          absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                          registerer.c_str(), SharedLibrary::GetError());
      return nullptr;
    }
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  std::unique_ptr<tflite::Interpreter> interpreter;
  std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
  auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
      *model, *resolver, &interpreter, &reader);
  if (status != kTfLiteOk) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  auto model_str = std::make_unique<std::string>(buf, length);
  // If we are not going to use this string during quantization, reset the
  // pointer and release the memory.
  if (!NoOpModel(*model)) {
    model_str.reset();
  }

  auto wrapper = new CalibrationWrapper(
      std::move(interpreter), std::move(resolver), std::move(error_reporter),
      std::move(model), std::move(reader), std::move(model_str));
  return wrapper;
}
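
// End-to-end sketch of the intended call sequence (all variable names are
// hypothetical; error handling elided):
//
//   std::string error;
//   CalibrationWrapper* wrapper =
//       CalibrationWrapper::CreateWrapperCPPFromBuffer(
//           serialized_model_bytes, /*registerers_by_name=*/{},
//           /*registerers_by_func=*/{}, &error);
//   wrapper->Prepare();                        // allocate tensors
//   for (PyObject* sample : samples) {
//     wrapper->FeedTensor(sample);             // log ranges per sample
//   }
//   PyObject* quantized = wrapper->QuantizeModel(
//       /*input_py_type=*/NPY_INT8, /*output_py_type=*/NPY_INT8,
//       /*allow_float=*/false, /*activations_py_type=*/NPY_INT8);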

}  // namespace calibration_wrapper
}  // namespace tflite