• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/optimize/calibration_wrapper.h"
16 
17 #include <memory>
18 #include <sstream>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/lite/interpreter.h"
23 #include "tensorflow/lite/kernels/register.h"
24 #include "tensorflow/lite/model.h"
25 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
26 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
27 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
28 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
29 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
30 #include "tensorflow/lite/tools/optimize/quantize_model.h"
31 
// Returns from the enclosing function with error_reporter_->exception() when
// the TfLiteStatus expression `x` is not kTfLiteOk. Intended for use inside
// CalibrationWrapper methods: it refers to the `error_reporter_` member.
//
// The do { ... } while (0) wrapper makes the macro expand to exactly one
// statement, so `if (c) TFLITE_PY_CHECK(x); else ...` parses as intended and
// the caller's trailing semicolon does not leave a stray empty statement.
#define TFLITE_PY_CHECK(x)                 \
  do {                                     \
    if ((x) != kTfLiteOk) {                \
      return error_reporter_->exception(); \
    }                                      \
  } while (0)

// Sets a Python ValueError and returns nullptr from the enclosing function
// when the `interpreter_` member is null. Wrapped in do/while(0) for the same
// single-statement reason as TFLITE_PY_CHECK.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                                 \
  do {                                                                       \
    if (!interpreter_) {                                                     \
      PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
      return nullptr;                                                        \
    }                                                                        \
  } while (0)
42 
43 namespace tflite {
44 namespace calibration_wrapper {
45 
46 namespace {
47 
48 using python_utils::PyDecrefDeleter;
49 
CreateMutableModel(const tflite::Model & model)50 std::unique_ptr<tflite::ModelT> CreateMutableModel(const tflite::Model& model) {
51   std::unique_ptr<tflite::ModelT> copied_model =
52       absl::make_unique<tflite::ModelT>();
53   model.UnPackTo(copied_model.get(), nullptr);
54   return copied_model;
55 }
56 
57 }  // namespace
58 
CalibrationWrapper(std::unique_ptr<tflite::Interpreter> interpreter,std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter> error_reporter,std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader)59 CalibrationWrapper::CalibrationWrapper(
60     std::unique_ptr<tflite::Interpreter> interpreter,
61     std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
62     std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
63         error_reporter,
64     std::unique_ptr<tflite::FlatBufferModel> model,
65     std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader)
66     : interpreter_(std::move(interpreter)),
67       error_reporter_(std::move(error_reporter)),
68       resolver_(std::move(resolver)),
69       model_(std::move(model)),
70       reader_(std::move(reader)) {}
71 
~CalibrationWrapper()72 CalibrationWrapper::~CalibrationWrapper() {}
73 
Prepare()74 PyObject* CalibrationWrapper::Prepare() {
75   TFLITE_PY_ENSURE_VALID_INTERPRETER();
76   TFLITE_PY_CHECK(interpreter_->AllocateTensors());
77   TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
78   Py_RETURN_NONE;
79 }
80 
FeedTensor(PyObject * input_value)81 PyObject* CalibrationWrapper::FeedTensor(PyObject* input_value) {
82   TFLITE_PY_ENSURE_VALID_INTERPRETER();
83   if (!PyList_Check(input_value)) {
84     PyErr_Format(PyExc_ValueError,
85                  "Invalid input type: expected input to be a list.");
86     return nullptr;
87   }
88 
89   const size_t inputs_size = PyList_Size(input_value);
90 
91   if (inputs_size != interpreter_->inputs().size()) {
92     PyErr_Format(PyExc_ValueError,
93                  "Invalid input size: expected %ld items got %ld items.",
94                  interpreter_->inputs().size(), inputs_size);
95     return nullptr;
96   }
97 
98   for (size_t i = 0; i < inputs_size; i++) {
99     PyObject* input = PyList_GetItem(input_value, i);
100     if (!input) {
101       return nullptr;
102     }
103     int input_tensor_idx = interpreter_->inputs()[i];
104     if (!SetTensor(input_tensor_idx, input)) {
105       return nullptr;
106     }
107   }
108 
109   TFLITE_PY_CHECK(interpreter_->Invoke());
110   Py_RETURN_NONE;
111 }
112 
SetTensor(int index,PyObject * value)113 PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
114   TFLITE_PY_ENSURE_VALID_INTERPRETER();
115 
116   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
117       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
118   if (!array_safe) {
119     PyErr_SetString(PyExc_ValueError,
120                     "Failed to convert value into readable tensor.");
121     return nullptr;
122   }
123 
124   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
125   const TfLiteTensor* tensor = interpreter_->tensor(index);
126 
127   if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
128     PyErr_Format(PyExc_ValueError,
129                  "Cannot set tensor:"
130                  " Got tensor of type %d"
131                  " but expected type %d for input %d, name: %s ",
132                  python_utils::TfLiteTypeFromPyArray(array), tensor->type,
133                  index, tensor->name);
134     return nullptr;
135   }
136 
137   if (PyArray_NDIM(array) != tensor->dims->size) {
138     PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
139     return nullptr;
140   }
141 
142   for (int j = 0; j < PyArray_NDIM(array); j++) {
143     if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
144       PyErr_SetString(PyExc_ValueError,
145                       "Cannot set tensor: Dimension mismatch");
146       return nullptr;
147     }
148   }
149 
150   size_t size = PyArray_NBYTES(array);
151   if (size != tensor->bytes) {
152     PyErr_Format(PyExc_ValueError,
153                  "numpy array had %zu bytes but expected %zu bytes.", size,
154                  tensor->bytes);
155     return nullptr;
156   }
157   memcpy(tensor->data.raw, PyArray_DATA(array), size);
158   Py_RETURN_NONE;
159 }
160 
QuantizeModel()161 PyObject* CalibrationWrapper::QuantizeModel() {
162   auto tflite_model = CreateMutableModel(*model_->GetModel());
163   reader_->AddCalibrationToModel(tflite_model.get());
164   flatbuffers::FlatBufferBuilder builder;
165   auto status = tflite::optimize::QuantizeModel(&builder, tflite_model.get(),
166                                                 error_reporter_.get());
167   if (status != kTfLiteOk) {
168     error_reporter_->exception();
169     return nullptr;
170   }
171 
172   return python_utils::ConvertToPyString(
173       reinterpret_cast<const char*>(builder.GetCurrentBufferPointer()),
174       builder.GetSize());
175 }
176 
CreateWrapperCPPFromBuffer(PyObject * data)177 /*static*/ CalibrationWrapper* CalibrationWrapper::CreateWrapperCPPFromBuffer(
178     PyObject* data) {
179   using tflite::interpreter_wrapper::PythonErrorReporter;
180   char* buf = nullptr;
181   Py_ssize_t length;
182   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
183   ::tflite::python::ImportNumpy();
184 
185   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
186     return nullptr;
187   }
188   std::unique_ptr<tflite::FlatBufferModel> model =
189       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
190                                                error_reporter.get());
191   if (!model) {
192     PyErr_Format(PyExc_ValueError, "Invalid model");
193     return nullptr;
194   }
195   auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
196   std::unique_ptr<tflite::Interpreter> interpreter;
197   std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader;
198   auto status = tflite::optimize::calibration::BuildLoggingInterpreter(
199       *model, *resolver, &interpreter, &reader);
200   if (status != kTfLiteOk) {
201     error_reporter->exception();
202     return nullptr;
203   }
204 
205   auto wrapper = new CalibrationWrapper(
206       std::move(interpreter), std::move(resolver), std::move(error_reporter),
207       std::move(model), std::move(reader));
208   return wrapper;
209 }
210 
211 }  // namespace calibration_wrapper
212 }  // namespace tflite
213