/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"

#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"
#include "tensorflow/lite/tools/optimize/calibration/logging_op.h"
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"

namespace tflite {
namespace optimize {
namespace calibration {

namespace {

// Calibrator is used to hold information that can be accessed during kernel
// invocations.
// TfLite kernel invocations are C functions and cannot look at the global
// structure of the graph. Calibrator allows the kernel invoke functions to
// access the global structure of the graph and know which node is currently
// being executed. This also allows us to write a simple kernel invoke wrapper
// (see LoggingEval) that can work for most builtin ops.
class Calibrator {
 public:
  Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
                 node_ptr_opinfo_map,
             std::unique_ptr<LoggingOpResolver> logging_op_resolver,
             ErrorReporter* error_reporter)
      : node_ptr_opinfo_map_(node_ptr_opinfo_map),
        logging_op_resolver_(std::move(logging_op_resolver)),
        error_reporter_(error_reporter) {
    logger_ = absl::make_unique<Logger>();
  }

  // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
  KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;

  // Gets the instance of logger associated with the current context.
  Logger* GetLogger() const { return logger_.get(); }

  // Gets the error reporter.
  ErrorReporter* GetErrorReporter() const { return error_reporter_; }

  // Gets the operator information about the given TfLiteNode.
  const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
    return node_ptr_opinfo_map_.at(node);
  }

  std::vector<const TfLiteNode*> GetNodesUnderCalibration() {
    std::vector<const TfLiteNode*> nodes;
    for (const auto& entry : node_ptr_opinfo_map_) {
      nodes.push_back(entry.first);
    }
    return nodes;
  }

 private:
  std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
  std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
  const std::unordered_map<int, OperatorInfo> index_opinfo_;
  std::unique_ptr<Logger> logger_;
  ErrorReporter* error_reporter_;
};

KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
  auto op_info = node_ptr_opinfo_map_.at(node);
  if (op_info.is_custom_op) {
    return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
                                                        op_info.version);
  }
  return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
                                                      op_info.version);
}

// A registry of |Calibrator| objects per |TfLiteContext|.
// This global registry is needed to access |Calibrator| objects in the kernel
// invoke functions, i.e. |TfLiteRegistration.invoke|.
// Kernel invoke functions are C functions that have limited access to
// |TfLiteContext| and no access to the global state of the graph. That means
// that during a kernel invocation, the function cannot know which node it was
// invoked for. E.g. in the case of a model with a |Conv| op at two locations,
// there is no easy way for the Conv.invoke function to disambiguate the
// calls.
//
// For calibration we solve this problem by creating a map of calibrators
// per |TfLiteContext|. This map is |GlobalCalibratorRegistry|.
//
// This registry is then accessed using a global getter function:
// |GetCalibratorRegistry|.
// E.g.
// TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
//   .... code ....
//   auto registry = GetCalibratorRegistry();
//   auto calibrator = registry->GetCalibrator(node);
//   ..... code ....
// }
//
// This way the kernel invoke functions can get access to the Calibrator
// object associated with the |TfLiteContext|.
class GlobalCalibratorRegistry {
 public:
  // Returns the |Calibrator| associated with the given node, or nullptr if no
  // calibrator is registered for that node.
  Calibrator* GetCalibrator(const TfLiteNode* node) const {
    if (node_to_calibrator_.find(node) == node_to_calibrator_.cend()) {
      return nullptr;
    }
    return node_to_calibrator_.at(node);
  }

  // Removes the association between calibrator and context.
  // Note: This deletes the calibrator as well.
  void RemoveCalibrator(const TfLiteContext* context) {
    Calibrator* calibrator = calibrator_registry_.at(context).get();
    auto nodes = calibrator->GetNodesUnderCalibration();
    for (auto node : nodes) {
      node_to_calibrator_.erase(node);
    }
    calibrator_registry_.erase(context);
  }

  // Creates an instance of |Calibrator|.
  // The registry owns the |Calibrator| object, which can be deleted by
  // calling |RemoveCalibrator|.
  TfLiteStatus CreateCalibrator(
      const TfLiteContext* context,
      const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
      std::unique_ptr<LoggingOpResolver> logging_op_resolver,
      Calibrator** calibrator_ptr, ErrorReporter* reporter) {
    if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
      reporter->Report(
          "Failed to create calibrator, context already registered.");
      return kTfLiteError;
    }
    auto calibrator = absl::make_unique<Calibrator>(
        node_to_opinfo, std::move(logging_op_resolver), reporter);
    calibrator_registry_[context] = std::move(calibrator);
    *calibrator_ptr = calibrator_registry_.at(context).get();
    for (const auto& entry : node_to_opinfo) {
      node_to_calibrator_[entry.first] = *calibrator_ptr;
    }
    return kTfLiteOk;
  }

 private:
  absl::flat_hash_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
      calibrator_registry_;
  absl::flat_hash_map<const TfLiteNode*, Calibrator*> node_to_calibrator_;
};

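// Returns the process-wide registry instance. It is created on first use and
// never destroyed (the usual leaky-singleton pattern), so it outlives every
// interpreter that registers a calibrator with it.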
GlobalCalibratorRegistry* GetCalibratorRegistry() {
  static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
  return registry;
}

// Returns the logging kernel for the given op if there is one, otherwise
// nullptr.
// TODO(jianlijianli): extend this to support multiple recipes for the same
// model.
logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
                                           TfLiteNode* node,
                                           int builtin_op_code) {
  switch (builtin_op_code) {
    case BuiltinOperator_LSTM: {
      if (node->intermediates->size == 12) {
        return tflite::optimize::calibration::custom::lstm_logging_kernel;
      }
      return tflite::optimize::calibration::builtin::lstm_logging_kernel;
    }
    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
      return tflite::optimize::calibration::builtin::
          unidirectional_sequence_lstm_logging_kernel;
    default:
      return nullptr;
  }
}

// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
// invokes the wrapped implementation and then logs the outputs.
TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
  Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(node);

  if (!calibrator) {
    context->ReportError(context, "No calibrator found for context.");
    return kTfLiteError;
  }

  auto kernel_invoke = calibrator->GetKernelInvoke(node);
  auto logger = calibrator->GetLogger();
  auto op_info = calibrator->GetOpInfo(node);
  auto error_reporter = calibrator->GetErrorReporter();

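  // Log the values of every loggable (float) input tensor before the op runs.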
  for (int i : op_info.loggable_inputs) {
    auto tensor = context->tensors[i];
    TF_LITE_ENSURE_STATUS(
        logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
                               tensor.bytes / sizeof(float), error_reporter));
  }
  auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
  auto kernel_invoke_intermediate =
      GetLoggingEvalFunc(context, node, builtin_op_code);
  if (kernel_invoke_intermediate == nullptr) {
    TF_LITE_ENSURE_STATUS(kernel_invoke(context, node));
  } else {
    TF_LITE_ENSURE_STATUS(
        kernel_invoke_intermediate(context, op_info.subgraph_index, node,
                                   calibrator->GetLogger(), error_reporter));
  }

  // TODO(shashishekhar): An intermediate tensor in the graph will get logged
  // twice: once as an input and a second time as an output. This doesn't
  // change the min/max values but is inefficient.
  // Using a moving average would also break this.

  // Log the inputs again to make sure the state tensors are captured after
  // the LSTM cell.
  for (int i : op_info.loggable_inputs) {
    auto tensor = context->tensors[i];
    TF_LITE_ENSURE_STATUS(
        logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
                               tensor.bytes / sizeof(float), error_reporter));
  }

  for (int i : op_info.loggable_outputs) {
    auto tensor = context->tensors[i];
    TF_LITE_ENSURE_STATUS(
        logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
                               tensor.bytes / sizeof(float), error_reporter));
  }

  return kTfLiteOk;
}

// Returns the loggable tensors. Not all inputs and outputs need to be logged.
// For example, const weight tensors which have buffers associated with them
// don't need to be logged.
std::vector<int> GetLoggableTensorIndices(
    const std::vector<int>& tensor_indices,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
  std::vector<int> loggable;
  for (auto tensor_index : tensor_indices) {
    if (tensor_index == kTfLiteOptionalTensor) {
      continue;
    }
    auto tensor = tensors->Get(tensor_index);
    auto buffer_index = tensor->buffer();
    const bool has_no_buffer =
        (tensor_buffers->Get(buffer_index) == nullptr) ||
        (tensor_buffers->Get(buffer_index)->data() == nullptr) ||
        (tensor_buffers->Get(buffer_index)->data()->size() == 0);
    if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
      loggable.push_back(tensor_index);
    }
  }
  return loggable;
}

// Creates a mapping between the static model graph and the runtime TfLiteNode*
// nodes in the graph for the given context.
// This is done by querying the interpreter for the node and registration of
// each operator recorded in |node_to_opinfo|.
TfLiteStatus GetNodeOpInfoMapAndContext(
    const absl::flat_hash_map<std::tuple<int, int>, OperatorInfo>&
        node_to_opinfo,
    tflite::Interpreter* const interpreter,
    std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
    TfLiteContext** context) {
  *context = interpreter->primary_subgraph().context();

  // Since we only consider the primary subgraph while populating
  // node_to_opinfo, do the same here.
  // Because the Flex delegate can merge multiple op nodes into one delegate
  // node when they appear in a row, the size of the execution plan can be
  // smaller than the number of op nodes in the graph.
  TF_LITE_ENSURE(*context,
                 interpreter->execution_plan().size() <= node_to_opinfo.size());
  for (const auto& entry : node_to_opinfo) {
    auto op_info = entry.second;
    int subgraph_index, op_index;
    std::tie(subgraph_index, op_index) = entry.first;
    const auto* node_and_reg =
        interpreter->node_and_registration(subgraph_index, op_index);
    op_info.registration = &node_and_reg->second;
    node_ptr_opinfo_map->insert({&node_and_reg->first, op_info});
  }
  return kTfLiteOk;
}

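// Returns the name of the operator: the custom code string for custom ops,
// otherwise the enum name of the builtin operator.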
string GetOpName(const tflite::OperatorCode& opcode) {
  if (opcode.custom_code() != nullptr) {
    return opcode.custom_code()->str();
  }
  return tflite::EnumNamesBuiltinOperator()[GetBuiltinCode(&opcode)];
}

// A |CalibrationReader| that owns the Calibrator.
class Reader : public CalibrationReader {
 public:
  Reader(const TfLiteContext* context, const Logger* logger)
      : CalibrationReader(logger), context_(context) {}

  ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }

 private:
  const TfLiteContext* context_;
};

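// Returns false for ops that legitimately have no input tensors (CALL_ONCE
// and VAR_HANDLE), so a missing inputs vector is not treated as an error for
// them.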
bool HasInputs(BuiltinOperator code) {
  switch (code) {
    case BuiltinOperator_CALL_ONCE:
    case BuiltinOperator_VAR_HANDLE:
      return false;
    default:
      return true;
  }
}

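// Returns false for ops that legitimately have no output tensors
// (ASSIGN_VARIABLE and CALL_ONCE).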
bool HasOutputs(BuiltinOperator code) {
  switch (code) {
    case BuiltinOperator_ASSIGN_VARIABLE:
    case BuiltinOperator_CALL_ONCE:
      return false;
    default:
      return true;
  }
}

}  // namespace

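// A minimal usage sketch (not part of this file; |model| and the
// representative-data loop are assumed to be supplied by the caller):
//
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   std::unique_ptr<CalibrationReader> reader;
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   if (BuildLoggingInterpreter(*model, resolver, &interpreter, &reader) !=
//       kTfLiteOk) {
//     return;  // Handle the error.
//   }
//   interpreter->AllocateTensors();
//   // ... fill input tensors and call interpreter->Invoke() for each
//   // representative example; |reader| then exposes the logged min/max
//   // statistics ...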
TfLiteStatus BuildLoggingInterpreter(
    const FlatBufferModel& model, const OpResolver& op_resolver,
    std::unique_ptr<Interpreter>* interpreter,
    std::unique_ptr<CalibrationReader>* calibration_reader) {
  return BuildLoggingInterpreter(model.GetModel(), model.error_reporter(),
                                 op_resolver, interpreter, calibration_reader);
}

TfLiteStatus BuildLoggingInterpreter(
    const tflite::Model* tflite_model, ErrorReporter* error_reporter,
    const OpResolver& op_resolver, std::unique_ptr<Interpreter>* interpreter,
    std::unique_ptr<CalibrationReader>* calibration_reader) {
  if (error_reporter == nullptr) {
    // Make sure error_reporter is valid.
    error_reporter = DefaultErrorReporter();
  }
  auto subgraphs = tflite_model->subgraphs();
  auto tensor_buffers = tflite_model->buffers();

  // Populate the node index to operator info map.
  // We want to collect this information so we can use it at runtime to
  // determine which inputs and outputs to log.
  // At runtime, TFLite kernel invoke functions can only look into their
  // own node in the graph (TfLiteNode*) and some limited context information.
  absl::flat_hash_map<std::tuple<int, int>, OperatorInfo> node_to_opinfo;
  BuiltinOpsSet builtin_op_and_versions;
  CustomOpsSet custom_op_and_versions;

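  // Walk every operator in every subgraph and record the static information
  // (op code, version, loggable input/output tensor indices) needed to log
  // tensor values at runtime.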
  for (size_t subgraph_index = 0; subgraph_index < subgraphs->size();
       subgraph_index++) {
    auto subgraph = subgraphs->Get(subgraph_index);
    auto operator_codes = tflite_model->operator_codes();
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!operators) {
      continue;
    }

    for (size_t i = 0; i < operators->size(); i++) {
      OperatorInfo op_info;
      op_info.subgraph_index = subgraph_index;
      op_info.node_index = i;
      auto op = operators->Get(i);
      auto operator_code = operator_codes->Get(op->opcode_index());
      op_info.builtin_op_code = GetBuiltinCode(operator_code);
      op_info.name = GetOpName(*operator_code);
      op_info.is_custom_op = operator_code->custom_code() != nullptr;
      op_info.version = operator_code->version();

      auto op_inputs = op->inputs();
      auto op_outputs = op->outputs();
      if (op_inputs) {
        op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
      } else if (HasInputs(op_info.builtin_op_code)) {
        TF_LITE_REPORT_ERROR(error_reporter, "Op %s missing inputs",
                             op_info.name.c_str());
        return kTfLiteError;
      }
      if (op_outputs) {
        op_info.outputs =
            std::vector<int>(op_outputs->begin(), op_outputs->end());
      } else if (HasOutputs(op_info.builtin_op_code)) {
        TF_LITE_REPORT_ERROR(error_reporter, "Op %s missing outputs",
                             op_info.name.c_str());
        return kTfLiteError;
      }
      op_info.loggable_inputs =
          GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
      op_info.loggable_outputs =
          GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
      if (op_info.is_custom_op) {
        op_info.registration =
            op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
        custom_op_and_versions.insert(
            {op_info.name.c_str(), operator_code->version()});
      } else {
        op_info.registration = op_resolver.FindOp(GetBuiltinCode(operator_code),
                                                  operator_code->version());
        builtin_op_and_versions.insert(
            {op_info.builtin_op_code, operator_code->version()});
      }
      std::tuple<int, int> key{subgraph_index, i};
      node_to_opinfo[key] = op_info;
    }
  }

  // Prepare the logging op resolver to use |LoggingEval| for kernel
  // invocations.
  auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
      builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
      error_reporter);
  tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
                             error_reporter)(interpreter);

  if (!(*interpreter)) {
    error_reporter->Report("Failed to construct interpreter");
    return kTfLiteError;
  }

  // Compute the mapping between runtime and static graph structure, i.e.
  // (TfLiteContext, TfLiteNode) -> OperatorInfo.
  std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
  TfLiteContext* context = nullptr;
  TF_LITE_ENSURE_STATUS(GetNodeOpInfoMapAndContext(
      node_to_opinfo, interpreter->get(), &node_ptr_opinfo_map, &context));

  Calibrator* calibrator = nullptr;
  // Register a calibrator object for the context. This can be accessed
  // during invocations by the logging kernels.
  TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
      context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
      error_reporter));
  *calibration_reader = std::unique_ptr<CalibrationReader>(
      new Reader(context, calibrator->GetLogger()));

  return kTfLiteOk;
}

}  // namespace calibration
}  // namespace optimize
}  // namespace tflite