• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/optimize/calibration_wrapper.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <optional>
20 #include <sstream>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/interpreter.h"
29 #include "tensorflow/lite/kernels/register.h"
30 #include "tensorflow/lite/model.h"
31 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
32 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
33 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
34 #include "tensorflow/lite/shared_library.h"
35 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
36 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
37 #include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
38 #include "tensorflow/lite/tools/optimize/quantize_model.h"
39 
40 #define TFLITE_PY_CHECK(x)               \
41   if ((x) != kTfLiteOk) {                \
42     return error_reporter_->exception(); \
43   }
44 
45 #define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
46   if (!interpreter_) {                                                     \
47     PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
48     return nullptr;                                                        \
49   }
50 
51 namespace tflite {
52 namespace calibration_wrapper {
53 
54 namespace {
55 
56 using python_utils::PyDecrefDeleter;
57 
CreateMutableModel(const tflite::Model & model)58 std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
59   auto copied_model = std::make_unique<tflite::ModelT>();
60   model.UnPackTo(copied_model.get(), nullptr);
61   return copied_model;
62 }
63 
NoOpModel(const tflite::FlatBufferModel & model)64 bool NoOpModel(const tflite::FlatBufferModel& model) {
65   return model->subgraphs()->size() == 1 &&
66          (!model->subgraphs()->begin()->operators() ||
67           model->subgraphs()->begin()->operators()->size() == 0);
68 }
69 
TfLiteTypeToSchemaType(TfLiteType type)70 inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
71   switch (type) {
72     case kTfLiteNoType:
73       return TensorType_FLOAT32;  // TODO(b/129336260): No schema type for none.
74     case kTfLiteFloat32:
75       return TensorType_FLOAT32;
76     case kTfLiteFloat16:
77       return TensorType_FLOAT16;
78     case kTfLiteFloat64:
79       return TensorType_FLOAT64;
80     case kTfLiteInt32:
81       return TensorType_INT32;
82     case kTfLiteUInt32:
83       return TensorType_UINT32;
84     case kTfLiteUInt8:
85       return TensorType_UINT8;
86     case kTfLiteInt8:
87       return TensorType_INT8;
88     case kTfLiteInt64:
89       return TensorType_INT64;
90     case kTfLiteUInt64:
91       return TensorType_UINT64;
92     case kTfLiteString:
93       return TensorType_STRING;
94     case kTfLiteBool:
95       return TensorType_BOOL;
96     case kTfLiteInt16:
97       return TensorType_INT16;
98     case kTfLiteUInt16:
99       return TensorType_UINT16;
100     case kTfLiteComplex64:
101       return TensorType_COMPLEX64;
102     case kTfLiteComplex128:
103       return TensorType_COMPLEX128;
104     case kTfLiteResource:
105       return TensorType_RESOURCE;
106     case kTfLiteVariant:
107       return TensorType_VARIANT;
108   }
109   // No default to get compiler error when new type is introduced.
110 }
111 
RegisterCustomOpByName(const char * registerer_name,tflite::MutableOpResolver * resolver)112 bool RegisterCustomOpByName(const char* registerer_name,
113                             tflite::MutableOpResolver* resolver) {
114   // Registerer functions take a pointer to a BuiltinOpResolver as an input
115   // parameter and return void.
116   // TODO(b/137576229): We should implement this functionality in a more
117   // principled way.
118   typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
119 
120   // Look for the Registerer function by name.
121   RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
122       SharedLibrary::GetSymbol(registerer_name));
123 
124   // Fail in an informative way if the function was not found.
125   if (registerer == nullptr) {
126     PyErr_Format(PyExc_ValueError,
127                  "Looking up symbol '%s' failed with error '%s'.",
128                  registerer_name, SharedLibrary::GetError());
129     return false;
130   }
131 
132   // Call the registerer with the resolver.
133   registerer(resolver);
134   return true;
135 }
136 
137 // Returns the dimension from the stored list in the PyObject. If the given
138 // PyObject is not a list, it will return absl::optional and set the Python
139 // error message to notify users.
ConvertInputShapeToVector(PyObject * input_shapes,size_t index)140 std::optional<std::vector<int>> ConvertInputShapeToVector(
141     PyObject* input_shapes, size_t index) {
142   PyObject* shape = PyList_GetItem(input_shapes, index);
143   if (!shape || !PyList_Check(shape)) {
144     PyErr_Format(PyExc_ValueError,
145                  "Invalid %ld input shape: expected to be a list.", index);
146     return std::nullopt;
147   }
148   size_t size = PyList_Size(shape);
149   std::vector<int> dims(size);
150   for (size_t dim_index = 0; dim_index < size; ++dim_index) {
151     PyObject* dim = PyList_GetItem(shape, dim_index);
152     dims[dim_index] = PyLong_AsLong(dim);
153   }
154   return dims;
155 }
156 
157 }  // namespace
158 
AddIntermediateTensors(PyObject * data)159 PyObject* AddIntermediateTensors(PyObject* data) {
160   using tflite::interpreter_wrapper::PythonErrorReporter;
161   char* buf = nullptr;
162   Py_ssize_t length;
163   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
164   ::tflite::python::ImportNumpy();
165 
166   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
167     return nullptr;
168   }
169   std::unique_ptr<tflite::FlatBufferModel> model =
170       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
171                                                error_reporter.get());
172   if (!model) {
173     PyErr_Format(PyExc_ValueError, "Invalid model");
174     return nullptr;
175   }
176   flatbuffers::FlatBufferBuilder builder;
177   auto tflite_model = CreateMutableModel(*model->GetModel());
178   if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
179       kTfLiteOk) {
180     error_reporter->exception();
181     return nullptr;
182   }
183 
184   if (builder.GetSize()) {
185     return python_utils::ConvertToPyString(
186         reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
187         builder.GetSize());
188   } else {
189     // When AddIntermediateTensorsToFusedOp early returns, return the model as
190     // it is.
191     return python_utils::ConvertToPyString(buf, length);
192   }
193 }
194 
// Takes ownership of everything needed to run calibration: the interpreter
// (already wired for calibration logging by the factory below), its op
// resolver, the error reporter used to surface TfLite errors as Python
// exceptions, the source flatbuffer model, the reader that collects logged
// tensor statistics, and optionally the raw serialized model string (kept
// only for the no-op-model fast path in QuantizeModel).
CalibrationWrapper::CalibrationWrapper(
    std::unique_ptr<tflite::Interpreter> interpreter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
        error_reporter,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
    std::unique_ptr<std::string> model_str)
    : interpreter_(std::move(interpreter)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      model_(std::move(model)),
      reader_(std::move(reader)),
      model_str_(std::move(model_str)) {}

// All members are unique_ptrs; default destruction order releases them.
CalibrationWrapper::~CalibrationWrapper() {}
211 
// Allocates tensors for the interpreter and resets variable tensors so each
// calibration run starts from a clean state. Returns Py_None on success, or
// nullptr with a Python exception set (via the macros) on failure.
PyObject* CalibrationWrapper::Prepare() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
218 
// Same as Prepare() but allocates tensors through the SignatureRunner for
// `signature_key`. Raises ValueError when the key does not name a signature.
PyObject* CalibrationWrapper::Prepare(std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  SignatureRunner* runner =
      interpreter_->GetSignatureRunner(signature_key.c_str());
  if (runner == nullptr) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  TFLITE_PY_CHECK(runner->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
232 
// Resizes the inputs of the subgraph selected by `signature_key` to the
// shapes in `input_shapes` (a Python list of per-input dimension lists),
// then delegates to Prepare(signature_key) to allocate tensors.
PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes,
                                      std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_shapes)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected shapes to be a list.");
    return nullptr;
  }
  const int subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key.c_str());
  if (subgraph_index == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  auto* subgraph = interpreter_->subgraph(subgraph_index);

  // Exactly one shape must be provided per subgraph input.
  const size_t inputs_size = PyList_Size(input_shapes);
  if (inputs_size != subgraph->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected %ld items got %ld items.",
                 subgraph->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; ++i) {
    std::optional<std::vector<int>> dims =
        ConvertInputShapeToVector(input_shapes, i);
    if (!dims.has_value()) {
      // Python error already set by ConvertInputShapeToVector.
      return nullptr;
    }
    int input_tensor_idx = subgraph->inputs()[i];
    if (subgraph->ResizeInputTensor(input_tensor_idx, *dims) != kTfLiteOk) {
      PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
      return nullptr;
    }
  }

  return Prepare(signature_key);
}
273 
// Resizes the primary subgraph's inputs to the shapes in `input_shapes`
// (a Python list of per-input dimension lists), then delegates to Prepare()
// to allocate tensors.
PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_shapes)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected shapes to be a list.");
    return nullptr;
  }

  // Exactly one shape must be provided per model input.
  const size_t inputs_size = PyList_Size(input_shapes);
  if (inputs_size != interpreter_->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input shapes: expected %ld items got %ld items.",
                 interpreter_->inputs().size(), inputs_size);
    return nullptr;
  }

  for (size_t i = 0; i < inputs_size; ++i) {
    std::optional<std::vector<int>> dims =
        ConvertInputShapeToVector(input_shapes, i);
    if (!dims.has_value()) {
      // Python error already set by ConvertInputShapeToVector.
      return nullptr;
    }
    int input_tensor_idx = interpreter_->inputs()[i];
    if (interpreter_->ResizeInputTensor(input_tensor_idx, *dims) != kTfLiteOk) {
      PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
      return nullptr;
    }
  }

  return Prepare();
}
305 
// Feeds one batch of calibration data (a Python list, one entry per input of
// the signature's subgraph) into the model and invokes it so the calibration
// logger records tensor statistics.
PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value,
                                         std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_value)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input type: expected input to be a list.");
    return nullptr;
  }
  const int subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key.c_str());
  if (subgraph_index == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  const size_t inputs_size = PyList_Size(input_value);

  auto* subgraph = interpreter_->subgraph(subgraph_index);
  if (inputs_size != subgraph->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input size: expected %ld items got %ld items.",
                 subgraph->inputs().size(), inputs_size);
    return nullptr;
  }

  // Copy each provided value into the corresponding input tensor.
  for (size_t i = 0; i < inputs_size; ++i) {
    PyObject* input = PyList_GetItem(input_value, i);
    if (!input) {
      return nullptr;
    }
    int input_tensor_idx = subgraph->inputs()[i];
    if (!SetTensor(input_tensor_idx, input, signature_key)) {
      // Python error already set by SetTensor.
      return nullptr;
    }
  }

  // Running the subgraph triggers the calibration logging kernels.
  TFLITE_PY_CHECK(subgraph->Invoke());
  Py_RETURN_NONE;
}
345 
// Feeds one batch of calibration data (a Python list, one entry per model
// input) into the primary subgraph and invokes it so the calibration logger
// records tensor statistics.
PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (!PyList_Check(input_value)) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input type: expected input to be a list.");
    return nullptr;
  }

  const size_t inputs_size = PyList_Size(input_value);

  if (inputs_size != interpreter_->inputs().size()) {
    PyErr_Format(PyExc_ValueError,
                 "Invalid input size: expected %ld items got %ld items.",
                 interpreter_->inputs().size(), inputs_size);
    return nullptr;
  }

  // Copy each provided value into the corresponding input tensor.
  for (size_t i = 0; i < inputs_size; ++i) {
    PyObject* input = PyList_GetItem(input_value, i);
    if (!input) {
      return nullptr;
    }
    int input_tensor_idx = interpreter_->inputs()[i];
    if (!SetTensor(input_tensor_idx, input)) {
      // Python error already set by SetTensor.
      return nullptr;
    }
  }

  // Running the interpreter triggers the calibration logging kernels.
  TFLITE_PY_CHECK(interpreter_->Invoke());
  Py_RETURN_NONE;
}
377 
// Copies `value` (converted to a C-contiguous numpy array) into tensor
// `index` of the subgraph selected by `signature_key`. Validates dtype and
// shape against the tensor; dimensions marked -1 in dims_signature may
// differ and trigger a resize + reallocation. Returns Py_None on success,
// nullptr with a Python exception set otherwise.
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value,
                                        std::string signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  // PyDecrefDeleter releases the array reference on every exit path.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  const int subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key.c_str());
  if (subgraph_index == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid signature key: %s",
                 signature_key.c_str());
    return nullptr;
  }
  auto* subgraph = interpreter_->subgraph(subgraph_index);
  const TfLiteTensor* tensor = subgraph->tensor(index);

  // The numpy dtype must map exactly to the tensor's TfLite type.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: "
                 "Got value of type %s "
                 "but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension count mismatch, expected %d "
                 "but found %d",
                 tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); ++j) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(subgraph->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(subgraph->AllocateTensors());
  }

  // Re-read the updated tensor after the allocation is done; the old pointer
  // may be stale after AllocateTensors().
  tensor = subgraph->tensor(index);

  size_t size = PyArray_NBYTES(array);

  // String tensors use TfLite's dynamic string buffer format rather than a
  // raw byte copy.
  if (tensor->type == kTfLiteString) {
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(subgraph->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
466 
// Copies `value` (converted to a C-contiguous numpy array) into tensor
// `index` of the primary subgraph. Validates dtype and shape against the
// tensor; dimensions marked -1 in dims_signature may differ and trigger a
// resize + reallocation. Returns Py_None on success, nullptr with a Python
// exception set otherwise.
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // PyDecrefDeleter releases the array reference on every exit path.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  const TfLiteTensor* tensor = interpreter_->tensor(index);

  // The numpy dtype must map exactly to the tensor's TfLite type.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: "
                 "Got value of type %s "
                 "but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(
        PyExc_ValueError,
        "Cannot set tensor: Dimension count mismatch, expected %d but found %d",
        tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); ++j) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  }

  // Re-read the updated tensor after the allocation is done; the old pointer
  // may be stale after AllocateTensors().
  tensor = interpreter_->tensor(index);

  size_t size = PyArray_NBYTES(array);

  // String tensors use TfLite's dynamic string buffer format rather than a
  // raw byte copy.
  if (tensor->type == kTfLiteString) {
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(interpreter_->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
546 
Calibrate()547 PyObject* CalibrationWrapper::Calibrate() {
548   auto tflite_model = CreateMutableModel(*model_->GetModel());
549   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
550   flatbuffers::FlatBufferBuilder builder;
551   auto loc = tflite::Model::Pack(builder, tflite_model.get());
552   tflite::FinishModelBuffer(builder, loc);
553   return python_utils::ConvertToPyString(
554       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
555       builder.GetSize());
556 }
557 
// Convenience overload that quantizes with per-channel quantization enabled
// (the default); forwards to the full overload below.
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
                                            int output_py_type,
                                            bool allow_float,
                                            int activations_py_type,
                                            int bias_py_type) {
  return QuantizeModel(input_py_type, output_py_type, allow_float,
                       activations_py_type, bias_py_type,
                       /*disable_per_channel=*/false);
}
567 
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,int activations_py_type,int bias_py_type,bool disable_per_channel)568 PyObject* CalibrationWrapper::QuantizeModel(
569     int input_py_type, int output_py_type, bool allow_float,
570     int activations_py_type, int bias_py_type, bool disable_per_channel) {
571   if (NoOpModel(*model_)) {
572     return python_utils::ConvertToPyString(model_str_->data(),
573                                            model_str_->size());
574   }
575 
576   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
577   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
578   TfLiteType activations_type =
579       python_utils::TfLiteTypeFromPyType(activations_py_type);
580   TfLiteType bias_type = python_utils::TfLiteTypeFromPyType(bias_py_type);
581 
582   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
583     PyErr_SetString(PyExc_ValueError,
584                     "Input/output type cannot be kTfLiteNoType");
585     return nullptr;
586   }
587   auto tflite_model = CreateMutableModel(*model_->GetModel());
588   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
589   flatbuffers::FlatBufferBuilder builder;
590   auto status = kTfLiteOk;
591 
592   status = tflite::optimize::QuantizeModelAllOperators(
593       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
594       TfLiteTypeToSchemaType(output_type), allow_float,
595       TfLiteTypeToSchemaType(activations_type),
596       TfLiteTypeToSchemaType(bias_type), disable_per_channel,
597       error_reporter_.get());
598 
599   if (status != kTfLiteOk) {
600     error_reporter_->exception();
601     return nullptr;
602   }
603 
604   return python_utils::ConvertToPyString(
605       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
606       builder.GetSize());
607 }
608 
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,const char * operator_output_name)609 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
610                                             int output_py_type,
611                                             bool allow_float,
612                                             const char* operator_output_name) {
613   string op_name = std::string(operator_output_name);
614 
615   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
616   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
617   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
618     PyErr_SetString(PyExc_ValueError,
619                     "Input/output type cannot be kTfLiteNoType");
620     return nullptr;
621   }
622   auto tflite_model = CreateMutableModel(*model_->GetModel());
623   reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
624   flatbuffers::FlatBufferBuilder builder;
625   auto status = tflite::optimize::QuantizeModel(
626       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
627       TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
628       /*activations_type=*/TensorType_INT8, /*bias_type=*/TensorType_INT32,
629       error_reporter_.get());
630   if (status != kTfLiteOk) {
631     error_reporter_->exception();
632     return nullptr;
633   }
634 
635   return python_utils::ConvertToPyString(
636       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
637       builder.GetSize());
638 }
639 
// Factory: builds a CalibrationWrapper from serialized model bytes. Registers
// custom ops (by symbol name and by function pointer), constructs a logging
// interpreter plus a calibration reader, and hands ownership of all pieces to
// the wrapper. On failure returns nullptr and describes the error in
// *error_msg (no Python exception is set here).
/*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  using tflite::interpreter_wrapper::PythonErrorReporter;
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  ::tflite::python::ImportNumpy();

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    *error_msg = "Failed to convert from python string";
    return nullptr;
  }
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(buf, length,
                                               error_reporter.get());
  if (!model) {
    *error_msg = "Invalid model";
    return nullptr;
  }
  // Register custom ops looked up by symbol name first, then by caller-supplied
  // registration functions.
  auto resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get())) {
      *error_msg =
          absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                          registerer.c_str(), SharedLibrary::GetError());
      return nullptr;
    }
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  // Build an interpreter instrumented with calibration logging and the reader
  // that will collect the logged statistics.
  std::unique_ptr<tflite::Interpreter> interpreter;
  std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
  auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
      *model, *resolver, &interpreter, &reader);
  if (status != kTfLiteOk) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  auto model_str = std::make_unique<std::string>(buf, length);
  // If we are not going to use this string during quantization, reset the
  // pointer and release the memory.
  if (!NoOpModel(*model)) {
    model_str.reset();
  }

  auto wrapper = new CalibrationWrapper(
      std::move(interpreter), std::move(resolver), std::move(error_reporter),
      std::move(model), std::move(reader), std::move(model_str));
  return wrapper;
}
694 
695 }  // namespace calibration_wrapper
696 }  // namespace tflite
697