• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/optimize/calibration_wrapper.h"
16 
17 #include <memory>
18 #include <sstream>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/interpreter.h"
24 #include "tensorflow/lite/kernels/register.h"
25 #include "tensorflow/lite/model.h"
26 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
27 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
28 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
29 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
30 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
31 #include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
32 #include "tensorflow/lite/tools/optimize/quantize_model.h"
33 
34 #define TFLITE_PY_CHECK(x)               \
35   if ((x) != kTfLiteOk) {                \
36     return error_reporter_->exception(); \
37   }
38 
39 #define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
40   if (!interpreter_) {                                                     \
41     PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
42     return nullptr;                                                        \
43   }
44 
45 namespace tflite {
46 namespace calibration_wrapper {
47 
48 namespace {
49 
50 using python_utils::PyDecrefDeleter;
51 
CreateMutableModel(const tflite::Model & model)52 std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
53   auto copied_model = absl::make_unique<tflite::ModelT>();
54   model.UnPackTo(copied_model.get(), nullptr);
55   return copied_model;
56 }
57 
NoOpModel(const tflite::FlatBufferModel & model)58 bool NoOpModel(const tflite::FlatBufferModel& model) {
59   return model->subgraphs()->size() == 1 &&
60          (!model->subgraphs()->begin()->operators() ||
61           model->subgraphs()->begin()->operators()->size() == 0);
62 }
63 
TfLiteTypeToSchemaType(TfLiteType type)64 inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
65   switch (type) {
66     case kTfLiteNoType:
67       return TensorType_FLOAT32;  // TODO(b/129336260): No schema type for none.
68     case kTfLiteFloat32:
69       return TensorType_FLOAT32;
70     case kTfLiteFloat16:
71       return TensorType_FLOAT16;
72     case kTfLiteFloat64:
73       return TensorType_FLOAT64;
74     case kTfLiteInt32:
75       return TensorType_INT32;
76     case kTfLiteUInt32:
77       return TensorType_UINT32;
78     case kTfLiteUInt8:
79       return TensorType_UINT8;
80     case kTfLiteInt8:
81       return TensorType_INT8;
82     case kTfLiteInt64:
83       return TensorType_INT64;
84     case kTfLiteUInt64:
85       return TensorType_UINT64;
86     case kTfLiteString:
87       return TensorType_STRING;
88     case kTfLiteBool:
89       return TensorType_BOOL;
90     case kTfLiteInt16:
91       return TensorType_INT16;
92     case kTfLiteComplex64:
93       return TensorType_COMPLEX64;
94     case kTfLiteComplex128:
95       return TensorType_COMPLEX128;
96     case kTfLiteResource:
97       return TensorType_RESOURCE;
98     case kTfLiteVariant:
99       return TensorType_VARIANT;
100   }
101   // No default to get compiler error when new type is introduced.
102 }
103 
104 }  // namespace
105 
AddIntermediateTensors(PyObject * data)106 PyObject* AddIntermediateTensors(PyObject* data) {
107   using tflite::interpreter_wrapper::PythonErrorReporter;
108   char* buf = nullptr;
109   Py_ssize_t length;
110   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
111   ::tflite::python::ImportNumpy();
112 
113   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
114     return nullptr;
115   }
116   std::unique_ptr<tflite::FlatBufferModel> model =
117       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
118                                                error_reporter.get());
119   if (!model) {
120     PyErr_Format(PyExc_ValueError, "Invalid model");
121     return nullptr;
122   }
123   flatbuffers::FlatBufferBuilder builder;
124   auto tflite_model = CreateMutableModel(*model->GetModel());
125   if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
126       kTfLiteOk) {
127     error_reporter->exception();
128     return nullptr;
129   }
130 
131   if (builder.GetSize()) {
132     return python_utils::ConvertToPyString(
133         reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
134         builder.GetSize());
135   } else {
136     // When AddIntermediateTensorsToFusedOp early returns, return the model as
137     // it is.
138     return python_utils::ConvertToPyString(buf, length);
139   }
140 }
141 
// Takes ownership of every component needed for calibration: the logging
// interpreter, its op resolver, the Python-facing error reporter, the source
// model, the calibration-statistics reader, and (optionally) the raw
// serialized model string kept for the no-op-model fast path.
// NOTE(review): the mem-init list writes error_reporter_ before resolver_
// even though the parameters arrive in the opposite order; members are always
// initialized in their header declaration order regardless — confirm the
// header declares them in this order to avoid a -Wreorder warning.
CalibrationWrapper::CalibrationWrapper(
    std::unique_ptr<tflite::Interpreter> interpreter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
        error_reporter,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
    std::unique_ptr<std::string> model_str)
    : interpreter_(std::move(interpreter)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      model_(std::move(model)),
      reader_(std::move(reader)),
      model_str_(std::move(model_str)) {}
156 
~CalibrationWrapper()157 CalibrationWrapper::~CalibrationWrapper() {}
158 
Prepare()159 PyObject* CalibrationWrapper::Prepare() {
160   TFLITE_PY_ENSURE_VALID_INTERPRETER();
161   TFLITE_PY_CHECK(interpreter_->AllocateTensors());
162   TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
163   Py_RETURN_NONE;
164 }
165 
Prepare(PyObject * input_shapes)166 PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) {
167   TFLITE_PY_ENSURE_VALID_INTERPRETER();
168   if (!PyList_Check(input_shapes)) {
169     PyErr_Format(PyExc_ValueError,
170                  "Invalid input shapes: expected shapes to be a list.");
171     return nullptr;
172   }
173 
174   const size_t inputs_size = PyList_Size(input_shapes);
175   if (inputs_size != interpreter_->inputs().size()) {
176     PyErr_Format(PyExc_ValueError,
177                  "Invalid input shapes: expected %ld items got %ld items.",
178                  interpreter_->inputs().size(), inputs_size);
179     return nullptr;
180   }
181 
182   for (size_t i = 0; i < inputs_size; i++) {
183     PyObject* shape = PyList_GetItem(input_shapes, i);
184     if (!shape || !PyList_Check(shape)) {
185       PyErr_Format(PyExc_ValueError,
186                    "Invalid %ld input shape: expected to be a list.", i);
187       return nullptr;
188     }
189     std::vector<int> dims;
190     for (size_t dim_index = 0; dim_index < PyList_Size(shape); ++dim_index) {
191       PyObject* dim = PyList_GetItem(shape, dim_index);
192       dims.push_back(PyLong_AsLong(dim));
193     }
194     int input_tensor_idx = interpreter_->inputs()[i];
195     if (interpreter_->ResizeInputTensor(input_tensor_idx, dims) != kTfLiteOk) {
196       PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
197       return nullptr;
198     }
199   }
200 
201   return Prepare();
202 }
203 
FeedTensor(PyObject * input_value)204 PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
205   TFLITE_PY_ENSURE_VALID_INTERPRETER();
206   if (!PyList_Check(input_value)) {
207     PyErr_Format(PyExc_ValueError,
208                  "Invalid input type: expected input to be a list.");
209     return nullptr;
210   }
211 
212   const size_t inputs_size = PyList_Size(input_value);
213 
214   if (inputs_size != interpreter_->inputs().size()) {
215     PyErr_Format(PyExc_ValueError,
216                  "Invalid input size: expected %ld items got %ld items.",
217                  interpreter_->inputs().size(), inputs_size);
218     return nullptr;
219   }
220 
221   for (size_t i = 0; i < inputs_size; i++) {
222     PyObject* input = PyList_GetItem(input_value, i);
223     if (!input) {
224       return nullptr;
225     }
226     int input_tensor_idx = interpreter_->inputs()[i];
227     if (!SetTensor(input_tensor_idx, input)) {
228       return nullptr;
229     }
230   }
231 
232   TFLITE_PY_CHECK(interpreter_->Invoke());
233   Py_RETURN_NONE;
234 }
235 
// Copies `value` (any numpy-convertible Python object) into the tensor at
// `index`. Validates dtype and rank against the tensor's metadata, checks
// each dimension, and resizes the input first when the model declares an
// unknown (-1) dimension in dims_signature. Returns Py_None on success,
// nullptr with a Python exception set on failure.
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // Convert to a C-contiguous array; PyDecrefDeleter releases the reference
  // on every exit path.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  const TfLiteTensor* tensor = interpreter_->tensor(index);

  // The numpy dtype must match the tensor type exactly; no implicit casting.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(
        PyExc_ValueError,
        "Cannot set tensor: Dimension count mismatch, expected %d but found %d",
        tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); j++) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown (marked -1 in dims_signature),
    // in which case the incoming size is accepted and triggers a resize.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  }

  // AllocateTensors may have reallocated tensor storage; re-fetch the tensor
  // pointer before writing through it.
  tensor = interpreter_->tensor(index);

  size_t size = PyArray_NBYTES(array);

  if (tensor->type == kTfLiteString) {
    // String tensors go through DynamicBuffer. NOTE(review): the entire array
    // buffer is added as one single string — presumably callers pass exactly
    // one string per tensor; confirm with the Python-side calibrator.
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(interpreter_->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
314 
Calibrate()315 PyObject* CalibrationWrapper::Calibrate() {
316   auto tflite_model = CreateMutableModel(*model_->GetModel());
317   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
318   flatbuffers::FlatBufferBuilder builder;
319   auto loc = tflite::Model::Pack(builder, tflite_model.get());
320   tflite::FinishModelBuffer(builder, loc);
321   return python_utils::ConvertToPyString(
322       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
323       builder.GetSize());
324 }
325 
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,int activations_py_type)326 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
327                                             int output_py_type,
328                                             bool allow_float,
329                                             int activations_py_type) {
330   if (NoOpModel(*model_)) {
331     return python_utils::ConvertToPyString(model_str_->data(),
332                                            model_str_->size());
333   }
334 
335   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
336   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
337   TfLiteType activations_type =
338       python_utils::TfLiteTypeFromPyType(activations_py_type);
339 
340   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
341     PyErr_SetString(PyExc_ValueError,
342                     "Input/output type cannot be kTfLiteNoType");
343     return nullptr;
344   }
345   auto tflite_model = CreateMutableModel(*model_->GetModel());
346   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
347   flatbuffers::FlatBufferBuilder builder;
348   auto status = kTfLiteOk;
349 
350   status = tflite::optimize::QuantizeModelAllOperators(
351       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
352       TfLiteTypeToSchemaType(output_type), allow_float,
353       TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
354 
355   if (status != kTfLiteOk) {
356     error_reporter_->exception();
357     return nullptr;
358   }
359 
360   return python_utils::ConvertToPyString(
361       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
362       builder.GetSize());
363 }
364 
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,const char * operator_output_name)365 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
366                                             int output_py_type,
367                                             bool allow_float,
368                                             const char* operator_output_name) {
369   string op_name = std::string(operator_output_name);
370 
371   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
372   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
373   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
374     PyErr_SetString(PyExc_ValueError,
375                     "Input/output type cannot be kTfLiteNoType");
376     return nullptr;
377   }
378   auto tflite_model = CreateMutableModel(*model_->GetModel());
379   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
380   flatbuffers::FlatBufferBuilder builder;
381   auto status = tflite::optimize::QuantizeModel(
382       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
383       TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
384       TensorType_INT8, error_reporter_.get());
385   if (status != kTfLiteOk) {
386     error_reporter_->exception();
387     return nullptr;
388   }
389 
390   return python_utils::ConvertToPyString(
391       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
392       builder.GetSize());
393 }
394 
CreateWrapperCPPFromBuffer(PyObject * data)395 /*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
396     PyObject* data) {
397   using tflite::interpreter_wrapper::PythonErrorReporter;
398   char* buf = nullptr;
399   Py_ssize_t length;
400   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
401   ::tflite::python::ImportNumpy();
402 
403   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
404     return nullptr;
405   }
406   std::unique_ptr<tflite::FlatBufferModel> model =
407       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
408                                                error_reporter.get());
409   if (!model) {
410     PyErr_Format(PyExc_ValueError, "Invalid model");
411     return nullptr;
412   }
413   auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
414   std::unique_ptr<tflite::Interpreter> interpreter;
415   std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
416   auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
417       *model, *resolver, &interpreter, &reader);
418   if (status != kTfLiteOk) {
419     error_reporter->exception();
420     return nullptr;
421   }
422 
423   auto model_str = std::make_unique<std::string>(buf, length);
424   // If we are not going to use this string during quantization, reset the
425   // pointer and release the memory.
426   if (!NoOpModel(*model)) {
427     model_str.reset();
428   }
429 
430   auto wrapper = new CalibrationWrapper(
431       std::move(interpreter), std::move(resolver), std::move(error_reporter),
432       std::move(model), std::move(reader), std::move(model_str));
433   return wrapper;
434 }
435 
436 }  // namespace calibration_wrapper
437 }  // namespace tflite
438