1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/optimize/calibration_wrapper.h"
16
17 #include <memory>
18 #include <sstream>
19 #include <string>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/interpreter.h"
24 #include "tensorflow/lite/kernels/register.h"
25 #include "tensorflow/lite/model.h"
26 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
27 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
28 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
29 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
30 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
31 #include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
32 #include "tensorflow/lite/tools/optimize/quantize_model.h"
33
// Evaluates `x` (a TfLiteStatus expression); on any non-OK status, converts
// the messages buffered in the wrapper's error reporter into a Python
// exception and returns it from the enclosing function.
#define TFLITE_PY_CHECK(x) \
  if ((x) != kTfLiteOk) { \
    return error_reporter_->exception(); \
  }

// Guards wrapper methods that need a live interpreter: raises ValueError and
// returns nullptr from the enclosing function when interpreter_ is null.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
  if (!interpreter_) { \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr; \
  }
44
45 namespace tflite {
46 namespace calibration_wrapper {
47
48 namespace {
49
50 using python_utils::PyDecrefDeleter;
51
CreateMutableModel(const tflite::Model & model)52 std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
53 auto copied_model = absl::make_unique<tflite::ModelT>();
54 model.UnPackTo(copied_model.get(), nullptr);
55 return copied_model;
56 }
57
NoOpModel(const tflite::FlatBufferModel & model)58 bool NoOpModel(const tflite::FlatBufferModel& model) {
59 return model->subgraphs()->size() == 1 &&
60 (!model->subgraphs()->begin()->operators() ||
61 model->subgraphs()->begin()->operators()->size() == 0);
62 }
63
TfLiteTypeToSchemaType(TfLiteType type)64 inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
65 switch (type) {
66 case kTfLiteNoType:
67 return TensorType_FLOAT32; // TODO(b/129336260): No schema type for none.
68 case kTfLiteFloat32:
69 return TensorType_FLOAT32;
70 case kTfLiteFloat16:
71 return TensorType_FLOAT16;
72 case kTfLiteFloat64:
73 return TensorType_FLOAT64;
74 case kTfLiteInt32:
75 return TensorType_INT32;
76 case kTfLiteUInt32:
77 return TensorType_UINT32;
78 case kTfLiteUInt8:
79 return TensorType_UINT8;
80 case kTfLiteInt8:
81 return TensorType_INT8;
82 case kTfLiteInt64:
83 return TensorType_INT64;
84 case kTfLiteUInt64:
85 return TensorType_UINT64;
86 case kTfLiteString:
87 return TensorType_STRING;
88 case kTfLiteBool:
89 return TensorType_BOOL;
90 case kTfLiteInt16:
91 return TensorType_INT16;
92 case kTfLiteComplex64:
93 return TensorType_COMPLEX64;
94 case kTfLiteComplex128:
95 return TensorType_COMPLEX128;
96 case kTfLiteResource:
97 return TensorType_RESOURCE;
98 case kTfLiteVariant:
99 return TensorType_VARIANT;
100 }
101 // No default to get compiler error when new type is introduced.
102 }
103
104 } // namespace
105
AddIntermediateTensors(PyObject * data)106 PyObject* AddIntermediateTensors(PyObject* data) {
107 using tflite::interpreter_wrapper::PythonErrorReporter;
108 char* buf = nullptr;
109 Py_ssize_t length;
110 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
111 ::tflite::python::ImportNumpy();
112
113 if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
114 return nullptr;
115 }
116 std::unique_ptr<tflite::FlatBufferModel> model =
117 tflite::FlatBufferModel::BuildFromBuffer(buf, length,
118 error_reporter.get());
119 if (!model) {
120 PyErr_Format(PyExc_ValueError, "Invalid model");
121 return nullptr;
122 }
123 flatbuffers::FlatBufferBuilder builder;
124 auto tflite_model = CreateMutableModel(*model->GetModel());
125 if (optimize::AddIntermediateTensorsToFusedOp(&builder, tflite_model.get()) !=
126 kTfLiteOk) {
127 error_reporter->exception();
128 return nullptr;
129 }
130
131 if (builder.GetSize()) {
132 return python_utils::ConvertToPyString(
133 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
134 builder.GetSize());
135 } else {
136 // When AddIntermediateTensorsToFusedOp early returns, return the model as
137 // it is.
138 return python_utils::ConvertToPyString(buf, length);
139 }
140 }
141
// Takes ownership of everything calibration needs: the interpreter wired for
// calibration logging, the resolver and error reporter it uses, the source
// FlatBufferModel, the reader that collects logged tensor statistics, and —
// only for no-op models — the raw serialized model string (see
// CreateWrapperCPPFromBuffer, which resets it otherwise).
CalibrationWrapper::CalibrationWrapper(
    std::unique_ptr<tflite::Interpreter> interpreter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
        error_reporter,
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
    std::unique_ptr<std::string> model_str)
    : interpreter_(std::move(interpreter)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      model_(std::move(model)),
      reader_(std::move(reader)),
      model_str_(std::move(model_str)) {}

CalibrationWrapper::~CalibrationWrapper() {}
158
// Allocates tensors and resets variable tensors so the interpreter is ready
// for FeedTensor() calls. Returns Py_None on success, or nullptr with a
// Python exception set on failure (via the TFLITE_PY_* macros).
PyObject* CalibrationWrapper::Prepare() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
165
Prepare(PyObject * input_shapes)166 PyObject* CalibrationWrapper::Prepare(PyObject* input_shapes) {
167 TFLITE_PY_ENSURE_VALID_INTERPRETER();
168 if (!PyList_Check(input_shapes)) {
169 PyErr_Format(PyExc_ValueError,
170 "Invalid input shapes: expected shapes to be a list.");
171 return nullptr;
172 }
173
174 const size_t inputs_size = PyList_Size(input_shapes);
175 if (inputs_size != interpreter_->inputs().size()) {
176 PyErr_Format(PyExc_ValueError,
177 "Invalid input shapes: expected %ld items got %ld items.",
178 interpreter_->inputs().size(), inputs_size);
179 return nullptr;
180 }
181
182 for (size_t i = 0; i < inputs_size; i++) {
183 PyObject* shape = PyList_GetItem(input_shapes, i);
184 if (!shape || !PyList_Check(shape)) {
185 PyErr_Format(PyExc_ValueError,
186 "Invalid %ld input shape: expected to be a list.", i);
187 return nullptr;
188 }
189 std::vector<int> dims;
190 for (size_t dim_index = 0; dim_index < PyList_Size(shape); ++dim_index) {
191 PyObject* dim = PyList_GetItem(shape, dim_index);
192 dims.push_back(PyLong_AsLong(dim));
193 }
194 int input_tensor_idx = interpreter_->inputs()[i];
195 if (interpreter_->ResizeInputTensor(input_tensor_idx, dims) != kTfLiteOk) {
196 PyErr_Format(PyExc_ValueError, "Failed to resize %ld input tensor.", i);
197 return nullptr;
198 }
199 }
200
201 return Prepare();
202 }
203
FeedTensor(PyObject * input_value)204 PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
205 TFLITE_PY_ENSURE_VALID_INTERPRETER();
206 if (!PyList_Check(input_value)) {
207 PyErr_Format(PyExc_ValueError,
208 "Invalid input type: expected input to be a list.");
209 return nullptr;
210 }
211
212 const size_t inputs_size = PyList_Size(input_value);
213
214 if (inputs_size != interpreter_->inputs().size()) {
215 PyErr_Format(PyExc_ValueError,
216 "Invalid input size: expected %ld items got %ld items.",
217 interpreter_->inputs().size(), inputs_size);
218 return nullptr;
219 }
220
221 for (size_t i = 0; i < inputs_size; i++) {
222 PyObject* input = PyList_GetItem(input_value, i);
223 if (!input) {
224 return nullptr;
225 }
226 int input_tensor_idx = interpreter_->inputs()[i];
227 if (!SetTensor(input_tensor_idx, input)) {
228 return nullptr;
229 }
230 }
231
232 TFLITE_PY_CHECK(interpreter_->Invoke());
233 Py_RETURN_NONE;
234 }
235
// Copies the Python value `value` (converted to a C-contiguous NumPy array)
// into the interpreter tensor at `index`, resizing the tensor first when the
// model declares unknown (-1) dimensions. Returns Py_None on success, or
// nullptr with a Python exception set on failure.
PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // PyArray_FromAny returns a new reference; array_safe releases it on every
  // exit path.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  const TfLiteTensor* tensor = interpreter_->tensor(index);

  // The element type must match exactly; no implicit casting is performed.
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), index, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(
        PyExc_ValueError,
        "Cannot set tensor: Dimension count mismatch, expected %d but found %d",
        tensor->dims->size, PyArray_NDIM(array));
    return nullptr;
  }

  std::vector<int> dims(PyArray_NDIM(array));
  bool has_unknown_dims = false;
  for (int j = 0; j < PyArray_NDIM(array); j++) {
    // Ensure the calibration data input shape is the same as the model input
    // shape unless the dimension is unknown.
    if (tensor->dims_signature != nullptr &&
        tensor->dims_signature->size == tensor->dims->size &&
        tensor->dims_signature->data[j] == -1) {
      has_unknown_dims = true;
    } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Size mismatch, expected %d for dim "
                   "%d but found %ld",
                   tensor->dims->data[j], j, PyArray_SHAPE(array)[j]);
      return nullptr;
    }
    dims[j] = PyArray_SHAPE(array)[j];
  }

  // Resize the input tensor if there are unknown dimensions.
  if (has_unknown_dims) {
    // Does strict checking on the `ResizeInputTensor` call.
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(index, dims));
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  }

  // Re-fetch the tensor: AllocateTensors() above may have reallocated it.
  tensor = interpreter_->tensor(index);

  size_t size = PyArray_NBYTES(array);

  if (tensor->type == kTfLiteString) {
    // NOTE(review): the whole array buffer is stored as ONE string entry —
    // presumably calibration only needs scalar string inputs; confirm with
    // callers before relying on this for multi-string tensors.
    tflite::DynamicBuffer buffer;
    buffer.AddString(reinterpret_cast<const char*>(PyArray_BYTES(array)), size);
    buffer.WriteToTensor(interpreter_->tensor(index), /*new_shape=*/nullptr);
    Py_RETURN_NONE;
  }

  if (size != tensor->bytes) {
    PyErr_Format(PyExc_ValueError,
                 "numpy array had %zu bytes but expected %zu bytes.", size,
                 tensor->bytes);
    return nullptr;
  }
  // Raw byte copy is safe: type, rank, and every dimension were checked (or
  // the tensor was resized to match) above.
  memcpy(tensor->data.raw, PyArray_DATA(array), size);
  Py_RETURN_NONE;
}
314
Calibrate()315 PyObject* CalibrationWrapper::Calibrate() {
316 auto tflite_model = CreateMutableModel(*model_->GetModel());
317 reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
318 flatbuffers::FlatBufferBuilder builder;
319 auto loc = tflite::Model::Pack(builder, tflite_model.get());
320 tflite::FinishModelBuffer(builder, loc);
321 return python_utils::ConvertToPyString(
322 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
323 builder.GetSize());
324 }
325
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,int activations_py_type)326 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
327 int output_py_type,
328 bool allow_float,
329 int activations_py_type) {
330 if (NoOpModel(*model_)) {
331 return python_utils::ConvertToPyString(model_str_->data(),
332 model_str_->size());
333 }
334
335 TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
336 TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
337 TfLiteType activations_type =
338 python_utils::TfLiteTypeFromPyType(activations_py_type);
339
340 if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
341 PyErr_SetString(PyExc_ValueError,
342 "Input/output type cannot be kTfLiteNoType");
343 return nullptr;
344 }
345 auto tflite_model = CreateMutableModel(*model_->GetModel());
346 reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
347 flatbuffers::FlatBufferBuilder builder;
348 auto status = kTfLiteOk;
349
350 status = tflite::optimize::QuantizeModelAllOperators(
351 &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
352 TfLiteTypeToSchemaType(output_type), allow_float,
353 TfLiteTypeToSchemaType(activations_type), error_reporter_.get());
354
355 if (status != kTfLiteOk) {
356 error_reporter_->exception();
357 return nullptr;
358 }
359
360 return python_utils::ConvertToPyString(
361 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
362 builder.GetSize());
363 }
364
QuantizeModel(int input_py_type,int output_py_type,bool allow_float,const char * operator_output_name)365 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
366 int output_py_type,
367 bool allow_float,
368 const char* operator_output_name) {
369 string op_name = std::string(operator_output_name);
370
371 TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
372 TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
373 if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
374 PyErr_SetString(PyExc_ValueError,
375 "Input/output type cannot be kTfLiteNoType");
376 return nullptr;
377 }
378 auto tflite_model = CreateMutableModel(*model_->GetModel());
379 reader_->AddCalibrationToModel(tflite_model.get(), /*update=*/false);
380 flatbuffers::FlatBufferBuilder builder;
381 auto status = tflite::optimize::QuantizeModel(
382 &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
383 TfLiteTypeToSchemaType(output_type), allow_float, {op_name},
384 TensorType_INT8, error_reporter_.get());
385 if (status != kTfLiteOk) {
386 error_reporter_->exception();
387 return nullptr;
388 }
389
390 return python_utils::ConvertToPyString(
391 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
392 builder.GetSize());
393 }
394
CreateWrapperCPPFromBuffer(PyObject * data)395 /*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
396 PyObject* data) {
397 using tflite::interpreter_wrapper::PythonErrorReporter;
398 char* buf = nullptr;
399 Py_ssize_t length;
400 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
401 ::tflite::python::ImportNumpy();
402
403 if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
404 return nullptr;
405 }
406 std::unique_ptr<tflite::FlatBufferModel> model =
407 tflite::FlatBufferModel::BuildFromBuffer(buf, length,
408 error_reporter.get());
409 if (!model) {
410 PyErr_Format(PyExc_ValueError, "Invalid model");
411 return nullptr;
412 }
413 auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
414 std::unique_ptr<tflite::Interpreter> interpreter;
415 std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
416 auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
417 *model, *resolver, &interpreter, &reader);
418 if (status != kTfLiteOk) {
419 error_reporter->exception();
420 return nullptr;
421 }
422
423 auto model_str = std::make_unique<std::string>(buf, length);
424 // If we are not going to use this string during quantization, reset the
425 // pointer and release the memory.
426 if (!NoOpModel(*model)) {
427 model_str.reset();
428 }
429
430 auto wrapper = new CalibrationWrapper(
431 std::move(interpreter), std::move(resolver), std::move(error_reporter),
432 std::move(model), std::move(reader), std::move(model_str));
433 return wrapper;
434 }
435
436 } // namespace calibration_wrapper
437 } // namespace tflite
438