1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
16
17 #include <fstream>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/core/api/error_reporter.h"
26 #include "tensorflow/lite/core/api/op_resolver.h"
27 #include "tensorflow/lite/interpreter.h"
28 #include "tensorflow/lite/kernels/register.h"
29 #include "tensorflow/lite/model.h"
30 #include "tensorflow/lite/op_resolver.h"
31 #include "tensorflow/lite/schema/schema_generated.h"
32 #include "tensorflow/lite/string_util.h"
33 #include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
34 #include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
35 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
36 #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
37 #include "tensorflow/lite/tools/optimize/calibration/node_info_delegate.h"
38
39 namespace tflite {
40 namespace optimize {
41 namespace calibration {
42
43 namespace {
44
45 // Calibrator is used to hold information that can be accessed during kernel
46 // invocations.
47 // TfLite kernel invocations are C functions and cannot look at the global
48 // structure of the graph. Calibrator allows the kernel invoke functions to
49 // access the global structure of graph and know which node is currently being
50 // executed. This also allows us to write a simple kernel invoke wrapper
51 // (see LoggingEval) that can work for most builtin ops.
52 class Calibrator {
53 public:
Calibrator(const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_ptr_opinfo_map,std::unique_ptr<LoggingOpResolver> logging_op_resolver)54 Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
55 node_ptr_opinfo_map,
56 std::unique_ptr<LoggingOpResolver> logging_op_resolver)
57 : node_ptr_opinfo_map_(node_ptr_opinfo_map),
58 logging_op_resolver_(std::move(logging_op_resolver)) {
59 logger_ = absl::make_unique<Logger>();
60 }
61
62 // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
63 KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;
64
65 // Gets the instance of logger associated with the current context.
GetLogger() const66 Logger* GetLogger() const { return logger_.get(); }
67
68 // Gets the operator information about the given TfLiteNode.
GetOpInfo(const TfLiteNode * node) const69 const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
70 return node_ptr_opinfo_map_.at(node);
71 }
72
73 private:
74 std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
75 std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
76 const std::unordered_map<int, OperatorInfo> index_opinfo_;
77 std::unique_ptr<Logger> logger_;
78 };
79
GetKernelInvoke(const TfLiteNode * node) const80 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
81 auto op_info = node_ptr_opinfo_map_.at(node);
82 return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
83 1);
84 }
85
86 // A registry of |Calibrator| objects per |TfLiteContext|.
87 // This global registry is needed to access |Calibrator| objects in the kernel
88 // invoke functions i.e. |TfLiteRegistration.invoke|.
89 // Kernel invoke functions are C functions that have limited access to
90 // |TfLiteContext|. Kernel invoke functions don't have access to global state of
91 // graph. That means during a kernel invocation, the function cannot know which
92 // node it was invoked for. E.g. in case of a model with |Conv| op at two
93 // locations, there is no easy way for the Conv.invoke function to disambiguate
94 // the calls.
95 //
96 // For calibration we solve this problem by creating a map of calibrators
97 // per |TfLiteContext|. This map is |GlobalCalibrationRegistry|.
98 //
99 // This registry is then accessed using a global getter function:
100 // |GetCalibratorRegistry|.
101 // E.g.
102 // TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
103 // .... code ....
104 // auto registry = GetCalibratorRegistry();
105 // auto calibrator = registry->GetCalibrator(context);
106 // ..... code ....
107 // }
108 //
109 // This way the kernel invoke functions can get the access to the Calibrator
110 // object associated with the |TfLiteContext|.
111 class GlobalCalibratorRegistry {
112 public:
113 // Get the |Calibrator| associated with given context, returns null if no
114 // calibrator is associated with the given context.
GetCalibrator(const TfLiteContext * context) const115 Calibrator* GetCalibrator(const TfLiteContext* context) const {
116 if (calibrator_registry_.find(context) == calibrator_registry_.cend()) {
117 return nullptr;
118 }
119 return calibrator_registry_.at(context).get();
120 }
121
122 // Removes the association between calibrator and context.
123 // Note: This deletes the calibrator as well.
RemoveCalibrator(const TfLiteContext * context)124 void RemoveCalibrator(const TfLiteContext* context) {
125 calibrator_registry_.erase(context);
126 }
127
128 // Creates an instance of |Calibrator|.
129 // Registry owns the |Calibrator| object which can be deleted by calling
130 // |RemoveCalibrator|.
CreateCalibrator(const TfLiteContext * context,const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_to_opinfo,std::unique_ptr<LoggingOpResolver> logging_op_resolver,Calibrator ** calibrator_ptr,ErrorReporter * reporter)131 TfLiteStatus CreateCalibrator(
132 const TfLiteContext* context,
133 const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
134 std::unique_ptr<LoggingOpResolver> logging_op_resolver,
135 Calibrator** calibrator_ptr, ErrorReporter* reporter) {
136 if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
137 reporter->Report(
138 "Failed to create calibrator, context already registered.");
139 return kTfLiteError;
140 }
141 std::unique_ptr<Calibrator> calibrator = absl::make_unique<Calibrator>(
142 node_to_opinfo, std::move(logging_op_resolver));
143 calibrator_registry_[context] = std::move(calibrator);
144 *calibrator_ptr = calibrator_registry_.at(context).get();
145 return kTfLiteOk;
146 }
147
148 private:
149 std::unordered_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
150 calibrator_registry_;
151 };
152
GetCalibratorRegistry()153 GlobalCalibratorRegistry* GetCalibratorRegistry() {
154 static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
155 return registry;
156 }
157
158 // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
159 // invokes the wrapped implementation and then logs the outputs.
LoggingEval(TfLiteContext * context,TfLiteNode * node)160 TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
161 Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(context);
162
163 if (!calibrator) {
164 context->ReportError(context, "No calibrator found for context.");
165 return kTfLiteError;
166 }
167
168 auto kernel_invoke = calibrator->GetKernelInvoke(node);
169 auto logger = calibrator->GetLogger();
170 auto op_info = calibrator->GetOpInfo(node);
171
172 for (int i : op_info.loggable_inputs) {
173 auto tensor = context->tensors[i];
174 logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
175 }
176
177 auto status = kernel_invoke(context, node);
178 // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
179 // once as an input and second time as output. This doesn't change the min max
180 // values but is inefficient.
181 // Using moving average will also break this.
182
183 for (int i : op_info.loggable_outputs) {
184 auto tensor = context->tensors[i];
185 logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
186 }
187
188 return status;
189 }
190
191 // Returns the loggable tensors. Not all inputs and outputs need to be logged.
192 // For example, const weight tensors which have buffers associated with them
193 // don't need to be logged.
GetLoggableTensorIndices(const std::vector<int> & tensor_indices,const flatbuffers::Vector<flatbuffers::Offset<Tensor>> * tensors,const flatbuffers::Vector<flatbuffers::Offset<Buffer>> * tensor_buffers)194 std::vector<int> GetLoggableTensorIndices(
195 const std::vector<int>& tensor_indices,
196 const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
197 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
198 std::vector<int> loggable;
199 for (auto tensor_index : tensor_indices) {
200 auto tensor = tensors->Get(tensor_index);
201 auto buffer_index = tensor->buffer();
202 const bool has_no_buffer =
203 (tensor_buffers->Get(buffer_index) == nullptr) ||
204 (tensor_buffers->Get(buffer_index)->data() == nullptr) ||
205 (tensor_buffers->Get(buffer_index)->data()->size() == 0);
206 if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
207 loggable.push_back(tensor_index);
208 }
209 }
210 return loggable;
211 }
212
213 // Creates a mapping between the static model graph and the runtime TfLiteNode*
214 // nodes in the graph for the given context.
215 // This is done by querying the TfLiteContext for node and registrations using
216 // the |NodeInfoDelegateObserver|.
GetNodeOpInfoMapAndContext(const std::unordered_map<int,OperatorInfo> & node_to_opinfo,tflite::Interpreter * const interpreter,std::unordered_map<const TfLiteNode *,OperatorInfo> * node_ptr_opinfo_map,const TfLiteContext ** context)217 TfLiteStatus GetNodeOpInfoMapAndContext(
218 const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
219 tflite::Interpreter* const interpreter,
220 std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
221 const TfLiteContext** context
222
223 ) {
224 NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
225 node_ptr_opinfo_map);
226 NodeInfoDelegateParams delegate_params;
227 delegate_params.delegate_observer = &delegate_observer;
228 TfLiteDelegate logging_delegate = CreateNodeInfoDelegate(&delegate_params);
229
230 auto modify_status = interpreter->ModifyGraphWithDelegate(&logging_delegate);
231 if (modify_status != kTfLiteOk) {
232 return kTfLiteError;
233 }
234 *context = delegate_observer.GetContext();
235 return kTfLiteOk;
236 }
237
GetOpName(const tflite::OperatorCode & opcode)238 string GetOpName(const tflite::OperatorCode& opcode) {
239 if (opcode.custom_code() != nullptr) {
240 return opcode.custom_code()->str();
241 }
242 return tflite::EnumNamesBuiltinOperator()[opcode.builtin_code()];
243 }
244
245 // A |CalibrationReader| that owns the Calibrator.
246 class Reader : public CalibrationReader {
247 public:
Reader(const TfLiteContext * context,const Logger * logger)248 Reader(const TfLiteContext* context, const Logger* logger)
249 : CalibrationReader(logger), context_(context) {}
250
~Reader()251 ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }
252
253 private:
254 const TfLiteContext* context_;
255 };
256
257 } // namespace
258
BuildLoggingInterpreter(const FlatBufferModel & model,const OpResolver & op_resolver,std::unique_ptr<Interpreter> * interpreter,std::unique_ptr<CalibrationReader> * calibration_reader)259 TfLiteStatus BuildLoggingInterpreter(
260 const FlatBufferModel& model, const OpResolver& op_resolver,
261 std::unique_ptr<Interpreter>* interpreter,
262 std::unique_ptr<CalibrationReader>* calibration_reader) {
263 auto tflite_model = model.GetModel();
264 auto subgraphs = tflite_model->subgraphs();
265 auto tensor_buffers = tflite_model->buffers();
266
267 if (subgraphs->size() != 1) {
268 model.error_reporter()->Report(
269 "Only models with a single subgraph are supported, model had %d "
270 "subgraphs",
271 subgraphs->size());
272 return kTfLiteError;
273 }
274
275 // Populate the node index to operator info map.
276 // We want to collect this information so we can use it during runtime to
277 // log details of which inputs and outputs.
278 // At runtime TFLite kernel invoke functions can only look into their
279 // own node in the graph (TFLiteNode*) and some limited context information.
280 auto primary_subgraph = subgraphs->Get(0);
281 auto operator_codes = tflite_model->operator_codes();
282 auto operators = primary_subgraph->operators();
283 auto tensors = primary_subgraph->tensors();
284 std::unordered_map<int, OperatorInfo> node_to_opinfo;
285 BuiltinOpsSet op_and_versions;
286
287 for (size_t i = 0; i < operators->size(); i++) {
288 OperatorInfo op_info;
289 op_info.node_index = i;
290 auto op = operators->Get(i);
291 auto operator_code = operator_codes->Get(op->opcode_index());
292 op_info.builtin_op_code = operator_code->builtin_code();
293 op_info.name = GetOpName(*operator_code);
294 op_info.is_custom_op = operator_code->custom_code() != nullptr;
295
296 auto op_inputs = op->inputs();
297 auto op_outputs = op->outputs();
298 op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
299 op_info.outputs = std::vector<int>(op_outputs->begin(), op_outputs->end());
300 op_info.loggable_inputs =
301 GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
302 op_info.loggable_outputs =
303 GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
304 if (!op_info.is_custom_op) {
305 op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
306 operator_code->version());
307 } else {
308 op_info.registration =
309 op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
310 }
311 node_to_opinfo[i] = op_info;
312 op_and_versions.insert({op_info.builtin_op_code, operator_code->version()});
313 }
314
315 // Prepare the logging op resolver to use |LoggingEval| for kernel
316 // invocations.
317 auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
318 op_and_versions, op_resolver, LoggingEval);
319 tflite::InterpreterBuilder(model, *logging_op_resolver)(interpreter);
320
321 if (!(*interpreter)) {
322 model.error_reporter()->Report("Failed to construct interpreter");
323 return kTfLiteError;
324 }
325
326 // Compute the mapping between runtime and static graph structure, i.e.
327 // (TfLiteContext, TfLiteNode) -> OperatorInfo
328 std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
329 const TfLiteContext* context = nullptr;
330 GetNodeOpInfoMapAndContext(node_to_opinfo, interpreter->get(),
331 &node_ptr_opinfo_map, &context);
332
333 Calibrator* calibrator = nullptr;
334 // Register a calibrator object for the context. This can be accessed
335 // during invocations by the logging kernels.
336 TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
337 context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
338 model.error_reporter()));
339 *calibration_reader = std::unique_ptr<CalibrationReader>(
340 new Reader(context, calibrator->GetLogger()));
341
342 return kTfLiteOk;
343 }
344
345 } // namespace calibration
346 } // namespace optimize
347 } // namespace tflite
348