1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/optimize/calibration_wrapper.h"
16
17 #include <memory>
18 #include <sstream>
19 #include <string>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/kernels/register.h"
24 #include "tensorflow/lite/model.h"
25 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
26 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
27 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
28 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
29 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
30 #include "tensorflow/lite/tools/optimize/quantize_model.h"
31
// Evaluates `x` (an expression compared against kTfLiteOk) and, when it is not
// kTfLiteOk, returns `error_reporter_->exception()` from the enclosing function.
// Used inside CalibrationWrapper member functions that return PyObject*.
//
// Wrapped in do { ... } while (0) so the macro is a single statement: it can
// be followed by `;` and used safely as the body of an unbraced if/else without
// the trailing `;` detaching a following `else`.
#define TFLITE_PY_CHECK(x)                 \
  do {                                     \
    if ((x) != kTfLiteOk) {                \
      return error_reporter_->exception(); \
    }                                      \
  } while (0)
36
// Guards CalibrationWrapper member functions that dereference interpreter_:
// if interpreter_ is null, sets a Python ValueError and returns nullptr from
// the enclosing function.
//
// Wrapped in do { ... } while (0) so the macro is a single statement that is
// safe to use as the body of an unbraced if/else.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                                 \
  do {                                                                       \
    if (!interpreter_) {                                                     \
      PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
      return nullptr;                                                        \
    }                                                                        \
  } while (0)
42
43 namespace tflite {
44 namespace calibration_wrapper {
45
46 namespace {
47
48 using python_utils::PyDecrefDeleter;
49
CreateMutableModel(const tflite::Model & model)50 std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
51 std::unique_ptr<tflite::ModelT> copied_model =
52 absl::make_unique<tflite::ModelT>();
53 model.UnPackTo(copied_model.get(), nullptr);
54 return copied_model;
55 }
56
57 } // namespace
58
CalibrationWrapper(std::unique_ptr<tflite::Interpreter> interpreter,std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter> error_reporter,std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader)59 CalibrationWrapper::CalibrationWrapper(
60 std::unique_ptr<tflite::Interpreter> interpreter,
61 std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
62 std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
63 error_reporter,
64 std::unique_ptr<tflite::FlatBufferModel> model,
65 std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader)
66 : interpreter_(std::move(interpreter)),
67 error_reporter_(std::move(error_reporter)),
68 resolver_(std::move(resolver)),
69 model_(std::move(model)),
70 reader_(std::move(reader)) {}
71
~CalibrationWrapper()72 CalibrationWrapper::~CalibrationWrapper() {}
73
Prepare()74 PyObject* CalibrationWrapper::Prepare() {
75 TFLITE_PY_ENSURE_VALID_INTERPRETER();
76 TFLITE_PY_CHECK(interpreter_->AllocateTensors());
77 TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
78 Py_RETURN_NONE;
79 }
80
FeedTensor(PyObject * input_value)81 PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
82 TFLITE_PY_ENSURE_VALID_INTERPRETER();
83 if (!PyList_Check(input_value)) {
84 PyErr_Format(PyExc_ValueError,
85 "Invalid input type: expected input to be a list.");
86 return nullptr;
87 }
88
89 const size_t inputs_size = PyList_Size(input_value);
90
91 if (inputs_size != interpreter_->inputs().size()) {
92 PyErr_Format(PyExc_ValueError,
93 "Invalid input size: expected %ld items got %ld items.",
94 interpreter_->inputs().size(), inputs_size);
95 return nullptr;
96 }
97
98 for (size_t i = 0; i < inputs_size; i++) {
99 PyObject* input = PyList_GetItem(input_value, i);
100 if (!input) {
101 return nullptr;
102 }
103 int input_tensor_idx = interpreter_->inputs()[i];
104 if (!SetTensor(input_tensor_idx, input)) {
105 return nullptr;
106 }
107 }
108
109 TFLITE_PY_CHECK(interpreter_->Invoke());
110 Py_RETURN_NONE;
111 }
112
SetTensor(int index,PyObject * value)113 PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
114 TFLITE_PY_ENSURE_VALID_INTERPRETER();
115
116 std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
117 PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
118 if (!array_safe) {
119 PyErr_SetString(PyExc_ValueError,
120 "Failed to convert value into readable tensor.");
121 return nullptr;
122 }
123
124 PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
125 const TfLiteTensor* tensor = interpreter_->tensor(index);
126
127 if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
128 PyErr_Format(PyExc_ValueError,
129 "Cannot set tensor:"
130 " Got tensor of type %d"
131 " but expected type %d for input %d, name: %s ",
132 python_utils::TfLiteTypeFromPyArray(array), tensor->type,
133 index, tensor->name);
134 return nullptr;
135 }
136
137 if (PyArray_NDIM(array) != tensor->dims->size) {
138 PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
139 return nullptr;
140 }
141
142 for (int j = 0; j < PyArray_NDIM(array); j++) {
143 if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
144 PyErr_SetString(PyExc_ValueError,
145 "Cannot set tensor: Dimension mismatch");
146 return nullptr;
147 }
148 }
149
150 size_t size = PyArray_NBYTES(array);
151 if (size != tensor->bytes) {
152 PyErr_Format(PyExc_ValueError,
153 "numpy array had %zu bytes but expected %zu bytes.", size,
154 tensor->bytes);
155 return nullptr;
156 }
157 memcpy(tensor->data.raw, PyArray_DATA(array), size);
158 Py_RETURN_NONE;
159 }
160
QuantizeModel()161 PyObject* CalibrationWrapper::QuantizeModel() {
162 auto tflite_model = CreateMutableModel(*model_->GetModel());
163 reader_->AddCalibrationToModel(tflite_model.get());
164 flatbuffers::FlatBufferBuilder builder;
165 auto status = tflite::optimize::QuantizeModel(&builder, tflite_model.get(),
166 error_reporter_.get());
167 if (status != kTfLiteOk) {
168 error_reporter_->exception();
169 return nullptr;
170 }
171
172 return python_utils::ConvertToPyString(
173 reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
174 builder.GetSize());
175 }
176
CreateWrapperCPPFromBuffer(PyObject * data)177 /*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
178 PyObject* data) {
179 using tflite::interpreter_wrapper::PythonErrorReporter;
180 char* buf = nullptr;
181 Py_ssize_t length;
182 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
183 ::tflite::python::ImportNumpy();
184
185 if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
186 return nullptr;
187 }
188 std::unique_ptr<tflite::FlatBufferModel> model =
189 tflite::FlatBufferModel::BuildFromBuffer(buf, length,
190 error_reporter.get());
191 if (!model) {
192 PyErr_Format(PyExc_ValueError, "Invalid model");
193 return nullptr;
194 }
195 auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
196 std::unique_ptr<tflite::Interpreter> interpreter;
197 std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
198 auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
199 *model, *resolver, &interpreter, &reader);
200 if (status != kTfLiteOk) {
201 error_reporter->exception();
202 return nullptr;
203 }
204
205 auto wrapper = new CalibrationWrapper(
206 std::move(interpreter), std::move(resolver), std::move(error_reporter),
207 std::move(model), std::move(reader));
208 return wrapper;
209 }
210
211 } // namespace calibration_wrapper
212 } // namespace tflite
213