/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/lite/python/optimize/calibration_wrapper.h"
16
17 #include <functional>
18 #include <memory>
19 #include <optional>
20 #include <sstream>
21 #include <string>
22 #include <utility>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/interpreter.h"
29 #include "tensorflow/lite/kernels/register.h"
30 #include "tensorflow/lite/model.h"
31 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
32 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
33 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
34 #include "tensorflow/lite/shared_library.h"
35 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
36 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
37 #include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
38 #include "tensorflow/lite/tools/optimize/quantize_model.h"
39
40 #define TFLITE_PY_CHECK(x) \
41 if ((x) != kTfLiteOk) { \
42 return error_reporter_->exception(); \
43 }
44
45 #define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
46 if (!interpreter_) { \
47 PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
48 return nullptr; \
49 }
50
51 namespace tflite {
52 namespace calibration_wrapper {
53
54 namespace {
55
56 using python_utils::PyDecrefDeleter;
57
CreateMutableModel(const tflite::Model & model)58 std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
59 auto copied_model = std::make_unique<tflite::ModelT>();
60 model.UnPackTo(copied_model.get(), nullptr);
61 return copied_model;
62 }
63
NoOpModel(const tflite::FlatBufferModel & model)64 bool NoOpModel(const tflite::FlatBufferModel& model) {
65 return model->subgraphs()->size() == 1 &&
66 (!model->subgraphs()->begin()->operators() ||
67 model->subgraphs()->begin()->operators()->size() == 0);
68 }
69
// Maps a runtime TfLiteType to the corresponding flatbuffer schema TensorType.
// The switch is intentionally exhaustive with no default case: introducing a
// new TfLiteType then fails to compile here, forcing this mapping to be kept
// up to date.
inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
  switch (type) {
    case kTfLiteNoType:
      return TensorType_FLOAT32;  // TODO(b/129336260): No schema type for none.
    case kTfLiteFloat32:
      return TensorType_FLOAT32;
    case kTfLiteFloat16:
      return TensorType_FLOAT16;
    case kTfLiteFloat64:
      return TensorType_FLOAT64;
    case kTfLiteInt32:
      return TensorType_INT32;
    case kTfLiteUInt32:
      return TensorType_UINT32;
    case kTfLiteUInt8:
      return TensorType_UINT8;
    case kTfLiteInt8:
      return TensorType_INT8;
    case kTfLiteInt64:
      return TensorType_INT64;
    case kTfLiteUInt64:
      return TensorType_UINT64;
    case kTfLiteString:
      return TensorType_STRING;
    case kTfLiteBool:
      return TensorType_BOOL;
    case kTfLiteInt16:
      return TensorType_INT16;
    case kTfLiteUInt16:
      return TensorType_UINT16;
    case kTfLiteComplex64:
      return TensorType_COMPLEX64;
    case kTfLiteComplex128:
      return TensorType_COMPLEX128;
    case kTfLiteResource:
      return TensorType_RESOURCE;
    case kTfLiteVariant:
      return TensorType_VARIANT;
  }
  // No default to get compiler error when new type is introduced.
}
111
RegisterCustomOpByName(const char * registerer_name,tflite::MutableOpResolver * resolver)112 bool RegisterCustomOpByName(const char* registerer_name,
113 tflite::MutableOpResolver* resolver) {
114 // Registerer functions take a pointer to a BuiltinOpResolver as an input
115 // parameter and return void.
116 // TODO(b/137576229): We should implement this functionality in a more
117 // principled way.
118 typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
119
120 // Look for the Registerer function by name.
121 RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
122 SharedLibrary::GetSymbol(registerer_name));
123
124 // Fail in an informative way if the function was not found.
125 if (registerer == nullptr) {
126 PyErr_Format(PyExc_ValueError,
127 "Looking up symbol '%s' failed with error '%s'.",
128 registerer_name, SharedLibrary::GetError());
129 return false;
130 }
131
132 // Call the registerer with the resolver.
133 registerer(resolver);
134 return true;
135 }
136
137 // Returns the dimension from the stored list in the PyObject. If the given
138 // PyObject is not a list, it will return absl::optional and set the Python
139 // error message to notify users.
ConvertInputShapeToVector(PyObject * input_shapes,size_t index)140 std::optional<std::vector<int>> ConvertInputShapeToVector(
141 PyObject* input_shapes, size_t index) {
142 PyObject* shape = PyList_GetItem(input_shapes, index);
143 if (!shape || !PyList_Check(shape)) {
144 PyErr_Format(PyExc_ValueError,
145 "Invalid %ld input shape: expected to be a list.", index);
146 return std::nullopt;
147 }
148 size_t size = PyList_Size(shape);
149 std::vector<int> dims(size);
150 for (size_t dim_index = 0; dim_index < size; ++dim_index) {
151 PyObject* dim = PyList_GetItem(shape, dim_index);
152 dims[dim_index] = PyLong_AsLong(dim);
153 }
154 return dims;
155 }
156
157 } // namespace
158
// Deserializes the TFLite model in `data` (a Python bytes object), runs
// optimize::AddIntermediateTensorsToFusedOp on a mutable copy so calibration
// can observe intermediate values of fused ops, and returns the resulting
// model as a Python bytes object. Returns nullptr with a Python error set on
// failure.
PyObject* AddIntermediateTensors(PyObject* data) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  // `buf` borrows the bytes owned by `data`; no copy is made here.
  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    PyErr_Format(PyExc_ValueError, "Invalid model");
    return nullptr;
  }
  flatbuffers::FlatBufferBuilder builder;
  auto tflite_model = CreateMutableModel(*model->GetModel());
  if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
      kTfLiteOk) {
    // Converts the reported error into a Python exception before bailing out.
    error_reporter->exception();
    return nullptr;
  }

  if (builder.GetSize()) {
    // A rewritten model was serialized into `builder`; hand those bytes back.
    return python_utils::ConvertToPyString(
        reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
        builder.GetSize());
  } else {
    // When AddIntermediateTensorsToFusedOp early returns, return the model as
    // it is.
    return python_utils::ConvertToPyString(buf, length);
  }
}
194
// Takes ownership of the interpreter, resolver, error reporter, model,
// calibration reader and (optionally null) serialized model string.
// NOTE(review): the member-initializer list is written in a different order
// than the parameters; actual initialization follows the members' declaration
// order in the header, so this is presumed intentional — confirm there are no
// inter-member init dependencies.
CalibrationWrapper::CalibrationWrapper(
    std::unique_ptr<tflite::Interpreter> interpreter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
        error_reporter,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
    std::unique_ptr<std::string> model_str)
    : interpreter_(std::move(interpreter)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      model_(std::move(model)),
      reader_(std::move(reader)),
      model_str_(std::move(model_str)) {}
209
~CalibrationWrapper()210 CalibrationWrapper::~CalibrationWrapper() {}
211
// Allocates all interpreter tensors and resets variable tensors. Returns
// Py_None on success; on failure returns nullptr with a Python exception set
// via the error reporter (TFLITE_PY_CHECK) or PyErr_SetString.
PyObject* CalibrationWrapper::Prepare() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
218
// Allocates tensors for the signature identified by `signature_key` and
// resets the interpreter's variable tensors. Raises ValueError when no
// signature runner exists for that key; returns Py_None on success.
PyObject* CalibrationWrapper::Prepare(std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  SignatureRunner* runner =
      interpreter_->GetSignatureRunner(signature_key.c_str());
  if (runner == nullptr) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  TFLITE_PY_CHECK(runner->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
232
// Resizes the input tensors of the subgraph bound to `signature_key`
// according to `input_shapes` — a Python list with one dimension-list per
// subgraph input — then delegates to Prepare(signature_key) to allocate.
// Returns Py_None on success, nullptr with a Python error otherwise.
PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes,
                                      std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_shapes)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected shapes to be a list.");
    return nullptr;
  }
  const int subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key.c_str());
  if (subgraph_index == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  auto* subgraph = interpreter_->subgraph(subgraph_index);

  // One shape entry is required per subgraph input, in input order.
  const size_t inputs_size = PyList_Size(input_shapes);
  if (inputs_size != subgraph->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected %ld items got %ld items.",
                 subgraph->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; ++i) {
    // ConvertInputShapeToVector sets the Python error on failure.
    std::optional<std::vector<int>> dims =
        ConvertInputShapeToVector(input_shapes, i);
    if (!dims.has_value()) {
      return nullptr;
    }
    int input_tensor_idx = subgraph->inputs()[i];
    if (subgraph->ResizeInputTensor(input_tensor_idx, *dims) != kTfLiteOk) {
      PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
      return nullptr;
    }
  }

  return Prepare(signature_key);
}
273
Prepare(PyObject * input_shapes)274 PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) {
275 TFLITE_PY_ENSURE_VALID_INTERPRETER();
276 if (!PyList_Check(input_shapes)) {
277 PyErr_Format(PyExc_ValueError,
278 "Invalid input shapes: expected shapes to be a list.");
279 return nullptr;
280 }
281
282 const size_t inputs_size = PyList_Size(input_shapes);
283 if (inputs_size != interpreter_->inputs().size()) {
284 PyErr_Format(PyExc_ValueError,
285 "Invalid input shapes: expected %ld items got %ld items.",
286 interpreter_->inputs().size(), inputs_size);
287 return nullptr;
288 }
289
290 for (size_t i = 0; i < inputs_size; ++i) {
291 std::optional<std::vector<int>> dims =
292 ConvertInputShapeToVector(input_shapes, i);
293 if (!dims.has_value()) {
294 return nullptr;
295 }
296 int input_tensor_idx = interpreter_->inputs()[i];
297 if (interpreter_->ResizeInputTensor(input_tensor_idx, *dims) != kTfLiteOk) {
298 PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
299 return nullptr;
300 }
301 }
302
303 return Prepare();
304 }
305
// Copies each entry of the Python list `input_value` into the corresponding
// input tensor of the subgraph bound to `signature_key`, then invokes that
// subgraph so the logging kernels can collect calibration statistics.
// Returns Py_None on success, nullptr with a Python error otherwise.
PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value,
                                         std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_value)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input type: expected input to be a list.");
    return nullptr;
  }
  const int subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key.c_str());
  if (subgraph_index == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  const size_t inputs_size = PyList_Size(input_value);

  // One list entry is required per subgraph input, in input order.
  auto* subgraph = interpreter_->subgraph(subgraph_index);
  if (inputs_size != subgraph->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input size: expected %ld items got %ld items.",
                 subgraph->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; ++i) {
    PyObject* input = PyList_GetItem(input_value, i);
    if (!input) {
      return nullptr;
    }
    int input_tensor_idx = subgraph->inputs()[i];
    // SetTensor validates dtype/shape and sets the Python error on failure.
    if (!SetTensor(input_tensor_idx, input, signature_key)) {
      return nullptr;
    }
  }

  TFLITE_PY_CHECK(subgraph->Invoke());
  Py_RETURN_NONE;
}
345
FeedTensor(PyObject * input_value)346 PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
347 TFLITE_PY_ENSURE_VALID_INTERPRETER();
348 if (!PyList_Check(input_value)) {
349 PyErr_Format(PyExc_ValueError,
350 "Invalid input type: expected input to be a list.");
351 return nullptr;
352 }
353
354 const size_t inputs_size = PyList_Size(input_value);
355
356 if (inputs_size != interpreter_->inputs().size()) {
357 PyErr_Format(PyExc_ValueError,
358 "Invalid input size: expected %ld items got %ld items.",
359 interpreter_->inputs().size(), inputs_size);
360 return nullptr;
361 }
362
363 for (size_t i = 0; i < inputs_size; ++i) {
364 PyObject* input = PyList_GetItem(input_value, i);
365 if (!input) {
366 return nullptr;
367 }
368 int input_tensor_idx = interpreter_->inputs()[i];
369 if (!SetTensor(input_tensor_idx, input)) {
370 return nullptr;
371 }
372 }
373
374 TFLITE_PY_CHECK(interpreter_->Invoke());
375 Py_RETURN_NONE;
376 }
377
// Writes `value` into tensor `index` of the subgraph bound to
// `signature_key`. The value is first converted to a C-contiguous numpy
// array, then checked against the tensor's dtype and shape; dimensions the
// model declares as unknown (-1 in dims_signature) may differ and trigger a
// strict resize + reallocation. Returns Py_None on success, nullptr with a
// Python error otherwise.
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value,
                                        std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  // PyArray_FromAny produces a new reference; PyDecrefDeleter releases it.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  const int subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key.c_str());
  if (subgraph_index == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  auto* subgraph = interpreter_->subgraph(subgraph_index);
  const TfLiteTensor* tensor = subgraph->tensor(index);

  // The numpy array's dtype must map to exactly the tensor's TfLiteType.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: "
                 "Got value of type %s "
                 "but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension count mismatch, expected %d "
                 "but found %d",
                 tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); ++j) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      // Dimension is dynamic (-1 in dims_signature): accept any size here
      // and resize the tensor below.
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(subgraph->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(subgraph->AllocateTensors());
  }

  // Re-read the updated tensor after the allocation is done.
  tensor = subgraph->tensor(index);

  size_t size = PyArray_NBYTES(array);

  // String tensors are variable-length; route them through DynamicBuffer
  // instead of a raw memcpy.
  if (tensor->type == kTfLiteString) {
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(subgraph->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
466
// Writes `value` into the interpreter tensor at `index`. The value is first
// converted to a C-contiguous numpy array, then checked against the tensor's
// dtype and shape; dimensions the model declares as unknown (-1 in
// dims_signature) may differ and trigger a strict resize + reallocation.
// Returns Py_None on success, nullptr with a Python error otherwise.
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // PyArray_FromAny produces a new reference; PyDecrefDeleter releases it.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  const TfLiteTensor* tensor = interpreter_->tensor(index);

  // The numpy array's dtype must map to exactly the tensor's TfLiteType.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: "
                 "Got value of type %s "
                 "but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(
        PyExc_ValueError,
        "Cannot set tensor: Dimension count mismatch, expected %d but found %d",
        tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); ++j) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      // Dimension is dynamic (-1 in dims_signature): accept any size here
      // and resize the tensor below.
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  }

  // Re-read the updated tensor after the allocation is done.
  tensor = interpreter_->tensor(index);

  size_t size = PyArray_NBYTES(array);

  // String tensors are variable-length; route them through DynamicBuffer
  // instead of a raw memcpy.
  if (tensor->type == kTfLiteString) {
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(interpreter_->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
546
Calibrate()547 PyObject* CalibrationWrapper::Calibrate() {
548 auto tflite_model = CreateMutableModel(*model_->GetModel());
549 reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
550 flatbuffers::FlatBufferBuilder builder;
551 auto loc = tflite::Model::Pack(builder, tflite_model.get());
552 tflite::FinishModelBuffer(builder, loc);
553 return python_utils::ConvertToPyString(
554 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
555 builder.GetSize());
556 }
557
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,int activations_py_type,int bias_py_type)558 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
559 int output_py_type,
560 bool allow_float,
561 int activations_py_type,
562 int bias_py_type) {
563 return QuantizeModel(input_py_type, output_py_type, allow_float,
564 activations_py_type, bias_py_type,
565 /*disable_per_channel=*/false);
566 }
567
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,int activations_py_type,int bias_py_type,bool disable_per_channel)568 PyObject* CalibrationWrapper::QuantizeModel(
569 int input_py_type, int output_py_type, bool allow_float,
570 int activations_py_type, int bias_py_type, bool disable_per_channel) {
571 if (NoOpModel(*model_)) {
572 return python_utils::ConvertToPyString(model_str_->data(),
573 model_str_->size());
574 }
575
576 TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
577 TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
578 TfLiteType activations_type =
579 python_utils::TfLiteTypeFromPyType(activations_py_type);
580 TfLiteType bias_type = python_utils::TfLiteTypeFromPyType(bias_py_type);
581
582 if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
583 PyErr_SetString(PyExc_ValueError,
584 "Input/output type cannot be kTfLiteNoType");
585 return nullptr;
586 }
587 auto tflite_model = CreateMutableModel(*model_->GetModel());
588 reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
589 flatbuffers::FlatBufferBuilder builder;
590 auto status = kTfLiteOk;
591
592 status = tflite::optimize::QuantizeModelAllOperators(
593 &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
594 TfLiteTypeToSchemaType(output_type), allow_float,
595 TfLiteTypeToSchemaType(activations_type),
596 TfLiteTypeToSchemaType(bias_type), disable_per_channel,
597 error_reporter_.get());
598
599 if (status != kTfLiteOk) {
600 error_reporter_->exception();
601 return nullptr;
602 }
603
604 return python_utils::ConvertToPyString(
605 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
606 builder.GetSize());
607 }
608
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,const char * operator_output_name)609 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
610 int output_py_type,
611 bool allow_float,
612 const char* operator_output_name) {
613 string op_name = std::string(operator_output_name);
614
615 TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
616 TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
617 if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
618 PyErr_SetString(PyExc_ValueError,
619 "Input/output type cannot be kTfLiteNoType");
620 return nullptr;
621 }
622 auto tflite_model = CreateMutableModel(*model_->GetModel());
623 reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
624 flatbuffers::FlatBufferBuilder builder;
625 auto status = tflite::optimize::QuantizeModel(
626 &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
627 TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
628 /*activations_type=*/TensorType_INT8, /*bias_type=*/TensorType_INT32,
629 error_reporter_.get());
630 if (status != kTfLiteOk) {
631 error_reporter_->exception();
632 return nullptr;
633 }
634
635 return python_utils::ConvertToPyString(
636 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
637 builder.GetSize());
638 }
639
// Factory: builds a CalibrationWrapper from the serialized model in `data`
// (a Python bytes object). Custom ops are registered both by symbol name
// (`registerers_by_name`) and by direct function pointers
// (`registerers_by_func`), then a logging interpreter is built whose kernels
// record calibration statistics into `reader`. On failure returns nullptr
// and fills `*error_msg`; the caller owns the returned wrapper.
/*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  // `buf` borrows the bytes owned by `data`; no copy is made here.
  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    *error_msg = "Failed to convert from python string";
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    *error_msg = "Invalid model";
    return nullptr;
  }
  auto resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get())) {
      *error_msg =
          absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                          registerer.c_str(), SharedLibrary::GetError());
      return nullptr;
    }
  }
  // Function-pointer registerers receive the resolver address as uintptr_t.
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  std::unique_ptr<tflite::Interpreter> interpreter;
  std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
  auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
      *model, *resolver, &interpreter, &reader);
  if (status != kTfLiteOk) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  // Keep a copy of the serialized model only for the no-op case, where
  // QuantizeModel returns the original bytes unchanged.
  auto model_str = std::make_unique<std::string>(buf, length);
  // If we are not going to use this string during quantization, reset the
  // pointer and release the memory.
  if (!NoOpModel(*model)) {
    model_str.reset();
  }

  auto wrapper = new CalibrationWrapper(
      std::move(interpreter), std::move(resolver), std::move(error_reporter),
      std::move(model), std::move(reader), std::move(model_str));
  return wrapper;
}
694
695 } // namespace calibration_wrapper
696 } // namespace tflite
697