/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/optimize/calibration_wrapper.h"

#include <memory>
#include <sstream>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"

#define TFLITE_PY_CHECK(x)               \
  if ((x) != kTfLiteOk) {                \
    return error_reporter_->exception(); \
  }

#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
  if (!interpreter_) {                                                     \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr;                                                        \
  }

namespace tflite {
namespace calibration_wrapper {

namespace {

using python_utils::PyDecrefDeleter;

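// Makes a deep, mutable copy of the given flatbuffer model so it can be
// modified (e.g. to embed calibration statistics or add tensors).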
std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
  auto copied_model = absl::make_unique<tflite::ModelT>();
  model.UnPackTo(copied_model.get(), nullptr);
  return copied_model;
}

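// Returns true if the model has exactly one subgraph and that subgraph
// contains no operators, i.e. there is nothing to calibrate or quantize.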
bool NoOpModel(const tflite::FlatBufferModel& model) {
  return model->subgraphs()->size() == 1 &&
         (!model->subgraphs()->begin()->operators() ||
          model->subgraphs()->begin()->operators()->size() == 0);
}

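// Maps a runtime TfLiteType to the corresponding TensorType in the flatbuffer
// schema.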
inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
  switch (type) {
    case kTfLiteNoType:
      return TensorType_FLOAT32;  // TODO(b/129336260): No schema type for none.
    case kTfLiteFloat32:
      return TensorType_FLOAT32;
    case kTfLiteFloat16:
      return TensorType_FLOAT16;
    case kTfLiteFloat64:
      return TensorType_FLOAT64;
    case kTfLiteInt32:
      return TensorType_INT32;
    case kTfLiteUInt32:
      return TensorType_UINT32;
    case kTfLiteUInt8:
      return TensorType_UINT8;
    case kTfLiteInt8:
      return TensorType_INT8;
    case kTfLiteInt64:
      return TensorType_INT64;
    case kTfLiteUInt64:
      return TensorType_UINT64;
    case kTfLiteString:
      return TensorType_STRING;
    case kTfLiteBool:
      return TensorType_BOOL;
    case kTfLiteInt16:
      return TensorType_INT16;
    case kTfLiteComplex64:
      return TensorType_COMPLEX64;
    case kTfLiteComplex128:
      return TensorType_COMPLEX128;
    case kTfLiteResource:
      return TensorType_RESOURCE;
    case kTfLiteVariant:
      return TensorType_VARIANT;
  }
  // No default to get compiler error when new type is introduced.
}

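// Looks up a registerer function exported by a shared library and invokes it
// with the given resolver so that its custom ops become available.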
bool RegisterCustomOpByName(const char* registerer_name,
                            tflite::MutableOpResolver* resolver) {
  // Registerer functions take a pointer to a BuiltinOpResolver as an input
  // parameter and return void.
  // TODO(b/137576229): We should implement this functionality in a more
  // principled way.
  typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);

  // Look for the Registerer function by name.
  RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
      SharedLibrary::GetSymbol(registerer_name));

  // Fail in an informative way if the function was not found.
  if (registerer == nullptr) {
    PyErr_Format(PyExc_ValueError,
                 "Looking up symbol '%s' failed with error '%s'.",
                 registerer_name, SharedLibrary::GetError());
    return false;
  }

  // Call the registerer with the resolver.
  registerer(resolver);
  return true;
}

}  // namespace

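// Deserializes a TFLite model from a Python bytes object, adds intermediate
// tensors to fused operators where they are missing, and returns the updated
// model (or the original model if nothing was added) as a Python string.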
PyObject* AddIntermediateTensors(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  flatbuffers::FlatBufferBuilder builder;
  auto tflite_model = CreateMutableModel(*model->GetModel());
  if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
      kTfLiteOk) {
    error_reporter->exception();
    return nullptr;
  }

  if (builder.GetSize()) {
    return python_utils::ConvertToPyString(
        reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
        builder.GetSize());
  } else {
    // When AddIntermediateTensorsToFusedOp returns early, return the model
    // unchanged.
    return python_utils::ConvertToPyString(buf, length);
  }
}

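// Takes ownership of all the pieces needed for calibration: the logging
// interpreter, op resolver, error reporter, flatbuffer model, calibration
// reader, and (for no-op models) the original serialized model string.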
CalibrationWrapper::CalibrationWrapper(
    std::unique_ptr<tflite::Interpreter> interpreter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
        error_reporter,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
    std::unique_ptr<std::string> model_str)
    : interpreter_(std::move(interpreter)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      model_(std::move(model)),
      reader_(std::move(reader)),
      model_str_(std::move(model_str)) {}

CalibrationWrapper::~CalibrationWrapper() {}

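// Allocates tensors and resets variable tensors so the interpreter is ready
// to accept calibration data.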
PyObject* CalibrationWrapper::Prepare() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

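// Resizes each model input to the shape given in `input_shapes` (a list of
// integer lists, one entry per input) and then allocates tensors.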
PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_shapes)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected shapes to be a list.");
    return nullptr;
  }

  const size_t inputs_size = PyList_Size(input_shapes);
  if (inputs_size != interpreter_->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected %ld items, got %ld items.",
                 interpreter_->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; i++) {
    PyObject* shape = PyList_GetItem(input_shapes, i);
    if (!shape || !PyList_Check(shape)) {
      PyErr_Format(PyExc_ValueError,
                   "Invalid shape for input %ld: expected a list.", i);
      return nullptr;
    }
    std::vector<int> dims;
    for (size_t dim_index = 0; dim_index < PyList_Size(shape); ++dim_index) {
      PyObject* dim = PyList_GetItem(shape, dim_index);
      dims.push_back(PyLong_AsLong(dim));
    }
    int input_tensor_idx = interpreter_->inputs()[i];
    if (interpreter_->ResizeInputTensor(input_tensor_idx, dims) != kTfLiteOk) {
      PyErr_Format(PyExc_ValueError, "Failed to resize input tensor %ld.", i);
      return nullptr;
    }
  }

  return Prepare();
}

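// Copies one batch of calibration data (a list with one numpy array per model
// input) into the input tensors and runs a single Invoke() so that
// calibration statistics can be recorded.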
PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_value)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input type: expected input to be a list.");
    return nullptr;
  }

  const size_t inputs_size = PyList_Size(input_value);

  if (inputs_size != interpreter_->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input size: expected %ld items, got %ld items.",
                 interpreter_->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; i++) {
    PyObject* input = PyList_GetItem(input_value, i);
    if (!input) {
      return nullptr;
    }
    int input_tensor_idx = interpreter_->inputs()[i];
    if (!SetTensor(input_tensor_idx, input)) {
      return nullptr;
    }
  }

  TFLITE_PY_CHECK(interpreter_->Invoke());
  Py_RETURN_NONE;
}

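// Converts `value` to a numpy array, validates its type and shape against the
// tensor at `index` (resizing the input when the model has unknown
// dimensions), and copies the data into the tensor.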
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  const TfLiteTensor* tensor = interpreter_->tensor(index);

  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(
        PyExc_ValueError,
        "Cannot set tensor: Dimension count mismatch, expected %d but found %d",
        tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); j++) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  }

  // Re-fetch the tensor pointer: AllocateTensors() may have reallocated it.
  tensor = interpreter_->tensor(index);

  size_t size = PyArray_NBYTES(array);

  if (tensor->type == kTfLiteString) {
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(interpreter_->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}

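// Embeds the calibration statistics collected so far into a mutable copy of
// the model and returns the updated model serialized as a Python string,
// without performing quantization.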
PyObject* CalibrationWrapper::Calibrate() {
  auto tflite_model = CreateMutableModel(*model_->GetModel());
  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
  flatbuffers::FlatBufferBuilder builder;
  auto loc = tflite::Model::Pack(builder, tflite_model.get());
  tflite::FinishModelBuffer(builder, loc);
  return python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

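// Convenience overload that quantizes with per-channel quantization enabled
// (disable_per_channel defaults to false).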
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            int activations_py_type) {
  return QuantizeModel(input_py_type, output_py_type, allow_float,
                       activations_py_type, /*disable_per_channel=*/false);
}

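// Applies the collected calibration statistics and quantizes all operators in
// the model. Returns the original model untouched when it contains no
// operators.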
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            int activations_py_type,
                                            bool disable_per_channel) {
  if (NoOpModel(*model_)) {
    return python_utils::ConvertToPyString(model_str_->data(),
                                           model_str_->size());
  }

  TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
  TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
  TfLiteType activations_type =
      python_utils::TfLiteTypeFromPyType(activations_py_type);

  if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
    PyErr_SetString(PyExc_ValueError,
                    "Input/output type cannot be kTfLiteNoType");
    return nullptr;
  }
  auto tflite_model = CreateMutableModel(*model_->GetModel());
  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
  flatbuffers::FlatBufferBuilder builder;
  auto status = tflite::optimize::QuantizeModelAllOperators(
      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
      TfLiteTypeToSchemaType(output_type), allow_float,
      TfLiteTypeToSchemaType(activations_type), disable_per_channel,
      error_reporter_.get());

  if (status != kTfLiteOk) {
    error_reporter_->exception();
    return nullptr;
  }

  return python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

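// Selective-quantization overload: passes `operator_output_name` through to
// tflite::optimize::QuantizeModel so that quantization is limited to the
// operator identified by that output name, with INT8 activations.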
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            const char* operator_output_name) {
  std::string op_name = std::string(operator_output_name);

  TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
  TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
  if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
    PyErr_SetString(PyExc_ValueError,
                    "Input/output type cannot be kTfLiteNoType");
    return nullptr;
  }
  auto tflite_model = CreateMutableModel(*model_->GetModel());
  reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
  flatbuffers::FlatBufferBuilder builder;
  auto status = tflite::optimize::QuantizeModel(
      &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
      TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
      TensorType_INT8, error_reporter_.get());
  if (status != kTfLiteOk) {
    error_reporter_->exception();
    return nullptr;
  }

  return python_utils::ConvertToPyString(
      reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
      builder.GetSize());
}

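// Builds a CalibrationWrapper from a serialized model held in a Python bytes
// object. Custom ops may be registered either by symbol name or via callback
// functions. Returns nullptr and fills `error_msg` on failure.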
/*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    *error_msg = "Failed to convert from python string";
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    *error_msg = "Invalid model";
    return nullptr;
  }
  auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get())) {
      *error_msg =
          absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                          registerer.c_str(), SharedLibrary::GetError());
      return nullptr;
    }
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  std::unique_ptr<tflite::Interpreter> interpreter;
  std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
  auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
      *model, *resolver, &interpreter, &reader);
  if (status != kTfLiteOk) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  auto model_str = std::make_unique<std::string>(buf, length);
  // If we are not going to use this string during quantization, reset the
  // pointer and release the memory.
  if (!NoOpModel(*model)) {
    model_str.reset();
  }

  auto wrapper = new CalibrationWrapper(
      std::move(interpreter), std::move(resolver), std::move(error_reporter),
      std::move(model), std::move(reader), std::move(model_str));
  return wrapper;
}

}  // namespace calibration_wrapper
}  // namespace tflite