• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
16 
17 #include <fstream>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/memory/memory.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/core/api/error_reporter.h"
29 #include "tensorflow/lite/core/api/op_resolver.h"
30 #include "tensorflow/lite/interpreter.h"
31 #include "tensorflow/lite/kernels/kernel_util.h"
32 #include "tensorflow/lite/kernels/register.h"
33 #include "tensorflow/lite/model.h"
34 #include "tensorflow/lite/op_resolver.h"
35 #include "tensorflow/lite/schema/schema_generated.h"
36 #include "tensorflow/lite/schema/schema_utils.h"
37 #include "tensorflow/lite/stderr_reporter.h"
38 #include "tensorflow/lite/string_util.h"
39 #include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"
40 #include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
41 #include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
42 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
43 #include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"
44 #include "tensorflow/lite/tools/optimize/calibration/logging_op.h"
45 #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
46 
47 namespace tflite {
48 namespace optimize {
49 namespace calibration {
50 
51 namespace {
52 
53 // Calibrator is used to hold information that can be accessed during kernel
54 // invocations.
55 // TfLite kernel invocations are C functions and cannot look at the global
56 // structure of the graph. Calibrator allows the kernel invoke functions to
57 // access the global structure of graph and know which node is currently being
58 // executed. This also allows us to write a simple kernel invoke wrapper
59 // (see LoggingEval) that can work for most builtin ops.
60 class Calibrator {
61  public:
Calibrator(const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_ptr_opinfo_map,std::unique_ptr<LoggingOpResolver> logging_op_resolver,ErrorReporter * error_reporter)62   Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
63                  node_ptr_opinfo_map,
64              std::unique_ptr<LoggingOpResolver> logging_op_resolver,
65              ErrorReporter* error_reporter)
66       : node_ptr_opinfo_map_(node_ptr_opinfo_map),
67         logging_op_resolver_(std::move(logging_op_resolver)),
68         error_reporter_(error_reporter) {
69     logger_ = absl::make_unique<Logger>();
70   }
71 
72   // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
73   KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;
74 
75   // Gets the instance of logger associated with the current context.
GetLogger() const76   Logger* GetLogger() const { return logger_.get(); }
77 
78   // Gets the error reporter.
GetErrorReporter() const79   ErrorReporter* GetErrorReporter() const { return error_reporter_; }
80 
81   // Gets the operator information about the given TfLiteNode.
GetOpInfo(const TfLiteNode * node) const82   const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
83     return node_ptr_opinfo_map_.at(node);
84   }
85 
GetNodesUnderCalibration()86   std::vector<const TfLiteNode*> GetNodesUnderCalibration() {
87     std::vector<const TfLiteNode*> nodes;
88     for (const auto& entry : node_ptr_opinfo_map_) {
89       nodes.push_back(entry.first);
90     }
91     return nodes;
92   }
93 
94  private:
95   std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
96   std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
97   const std::unordered_map<int, OperatorInfo> index_opinfo_;
98   std::unique_ptr<Logger> logger_;
99   ErrorReporter* error_reporter_;
100 };
101 
GetKernelInvoke(const TfLiteNode * node) const102 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
103   auto op_info = node_ptr_opinfo_map_.at(node);
104   if (op_info.is_custom_op) {
105     return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
106                                                         op_info.version);
107   }
108   return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
109                                                       op_info.version);
110 }
111 
112 // A registry of |Calibrator| objects per |TfLiteContext|.
113 // This global registry is needed to access |Calibrator| objects in the kernel
114 // invoke functions i.e. |TfLiteRegistration.invoke|.
115 // Kernel invoke functions are C functions that have limited access to
116 // |TfLiteContext|. Kernel invoke functions don't have access to global state of
117 // graph. That means during a kernel invocation, the function cannot know which
118 // node it was invoked for. E.g. in case of a model with |Conv| op at two
119 // locations, there is no easy way for the Conv.invoke function to disambiguate
120 // the calls.
121 //
122 // For calibration we solve this problem by creating a map of calibrators
123 // per |TfLiteContext|. This map is |GlobalCalibrationRegistry|.
124 //
125 // This registry is then accessed using a global getter function:
126 // |GetCalibratorRegistry|.
127 // E.g.
128 // TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
129 //   .... code ....
130 //   auto registry = GetCalibratorRegistry();
131 //   auto calibrator = registry->GetCalibrator(context);
132 //   ..... code ....
133 //  }
134 //
135 // This way the kernel invoke functions can get the access to the Calibrator
136 // object associated with the |TfLiteContext|.
137 class GlobalCalibratorRegistry {
138  public:
139   // Get the |Calibrator| associated with given context, returns null if no
140   // calibrator is associated with the given context.
GetCalibrator(const TfLiteNode * node) const141   Calibrator* GetCalibrator(const TfLiteNode* node) const {
142     if (node_to_calibrator_.find(node) == node_to_calibrator_.cend()) {
143       return nullptr;
144     }
145     return node_to_calibrator_.at(node);
146   }
147 
148   // Removes the association between calibrator and context.
149   // Note: This deletes the calibrator as well.
RemoveCalibrator(const TfLiteContext * context)150   void RemoveCalibrator(const TfLiteContext* context) {
151     Calibrator* calibrator = calibrator_registry_.at(context).get();
152     auto nodes = calibrator->GetNodesUnderCalibration();
153     for (auto node : nodes) {
154       node_to_calibrator_.erase(node);
155     }
156     calibrator_registry_.erase(context);
157   }
158 
159   // Creates an instance of |Calibrator|.
160   // Registry owns the |Calibrator| object which can be deleted by calling
161   // |RemoveCalibrator|.
CreateCalibrator(const TfLiteContext * context,const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_to_opinfo,std::unique_ptr<LoggingOpResolver> logging_op_resolver,Calibrator ** calibrator_ptr,ErrorReporter * reporter)162   TfLiteStatus CreateCalibrator(
163       const TfLiteContext* context,
164       const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
165       std::unique_ptr<LoggingOpResolver> logging_op_resolver,
166       Calibrator** calibrator_ptr, ErrorReporter* reporter) {
167     if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
168       reporter->Report(
169           "Failed to create calibrator, context already registered.");
170       return kTfLiteError;
171     }
172     auto calibrator = absl::make_unique<Calibrator>(
173         node_to_opinfo, std::move(logging_op_resolver), reporter);
174     calibrator_registry_[context] = std::move(calibrator);
175     *calibrator_ptr = calibrator_registry_.at(context).get();
176     for (const auto& entry : node_to_opinfo) {
177       node_to_calibrator_[entry.first] = *calibrator_ptr;
178     }
179     return kTfLiteOk;
180   }
181 
182  private:
183   absl::flat_hash_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
184       calibrator_registry_;
185   absl::flat_hash_map<const TfLiteNode*, Calibrator*> node_to_calibrator_;
186 };
187 
GetCalibratorRegistry()188 GlobalCalibratorRegistry* GetCalibratorRegistry() {
189   static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
190   return registry;
191 }
192 
193 // Get the logging kernel if there are any.
194 // TODO(jianlijianli): extend this to support multiple recipe for the same
195 // model.
GetLoggingEvalFunc(TfLiteContext * context,TfLiteNode * node,int builtin_op_code)196 logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
197                                            TfLiteNode* node,
198                                            int builtin_op_code) {
199   switch (builtin_op_code) {
200     case BuiltinOperator_LSTM: {
201       if (node->intermediates->size == 12) {
202         return tflite::optimize::calibration::custom::lstm_logging_kernel;
203       }
204       return tflite::optimize::calibration::builtin::lstm_logging_kernel;
205     }
206     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
207       return tflite::optimize::calibration::builtin::
208           unidirectional_sequence_lstm_logging_kernel;
209     default:
210       return nullptr;
211   }
212 }
213 
214 // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
215 // invokes the wrapped implementation and then logs the outputs.
LoggingEval(TfLiteContext * context,TfLiteNode * node)216 TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
217   Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(node);
218 
219   if (!calibrator) {
220     context->ReportError(context, "No calibrator found for context.");
221     return kTfLiteError;
222   }
223 
224   auto kernel_invoke = calibrator->GetKernelInvoke(node);
225   auto logger = calibrator->GetLogger();
226   auto op_info = calibrator->GetOpInfo(node);
227   auto error_reporter = calibrator->GetErrorReporter();
228 
229   for (int i : op_info.loggable_inputs) {
230     auto tensor = context->tensors[i];
231     TF_LITE_ENSURE_STATUS(
232         logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
233                                tensor.bytes / sizeof(float), error_reporter));
234   }
235   auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
236   auto kernel_invoke_intermediate =
237       GetLoggingEvalFunc(context, node, builtin_op_code);
238   if (kernel_invoke_intermediate == nullptr) {
239     TF_LITE_ENSURE_STATUS(kernel_invoke(context, node));
240   } else {
241     TF_LITE_ENSURE_STATUS(
242         kernel_invoke_intermediate(context, op_info.subgraph_index, node,
243                                    calibrator->GetLogger(), error_reporter));
244   }
245 
246   // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
247   // once as an input and second time as output. This doesn't change the min max
248   // values but is inefficient.
249   // Using moving average will also break this.
250 
251   // Log input again to make sure the state tensors are captured after lstm
252   // cell.
253   for (int i : op_info.loggable_inputs) {
254     auto tensor = context->tensors[i];
255     TF_LITE_ENSURE_STATUS(
256         logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
257                                tensor.bytes / sizeof(float), error_reporter));
258   }
259 
260   for (int i : op_info.loggable_outputs) {
261     auto tensor = context->tensors[i];
262     TF_LITE_ENSURE_STATUS(
263         logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
264                                tensor.bytes / sizeof(float), error_reporter));
265   }
266 
267   return kTfLiteOk;
268 }
269 
270 // Returns the loggable tensors. Not all inputs and outputs need to be logged.
271 // For example, const weight tensors which have buffers associated with them
272 // don't need to be logged.
GetLoggableTensorIndices(const std::vector<int> & tensor_indices,const flatbuffers::Vector<flatbuffers::Offset<Tensor>> * tensors,const flatbuffers::Vector<flatbuffers::Offset<Buffer>> * tensor_buffers)273 std::vector<int> GetLoggableTensorIndices(
274     const std::vector<int>& tensor_indices,
275     const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
276     const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
277   std::vector<int> loggable;
278   for (auto tensor_index : tensor_indices) {
279     if (tensor_index == kTfLiteOptionalTensor) {
280       continue;
281     }
282     auto tensor = tensors->Get(tensor_index);
283     auto buffer_index = tensor->buffer();
284     const bool has_no_buffer =
285         (tensor_buffers->Get(buffer_index) == nullptr) ||
286         (tensor_buffers->Get(buffer_index)->data() == nullptr) ||
287         (tensor_buffers->Get(buffer_index)->data()->size() == 0);
288     if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
289       loggable.push_back(tensor_index);
290     }
291   }
292   return loggable;
293 }
294 
295 // Creates a mapping between the static model graph and the runtime TfLiteNode*
296 // nodes in the graph for the given context.
297 // This is done by querying the TfLiteContext for node and registrations using
298 // the |NodeInfoDelegateObserver|.
// Creates a mapping between the static model graph and the runtime TfLiteNode*
// nodes in the graph for the given context: fills |node_ptr_opinfo_map| with
// {runtime node pointer -> OperatorInfo} and sets |*context| to the primary
// subgraph's TfLiteContext.
TfLiteStatus GetNodeOpInfoMapAndContext(
    const absl::flat_hash_map<std::tuple<int, int>, OperatorInfo>&
        node_to_opinfo,
    tflite::Interpreter* const interpreter,
    std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
    TfLiteContext** context) {
  *context = interpreter->primary_subgraph().context();

  // Since we only consider the primary subgraph while populating
  // node_to_opinfo, do the same here.
  // Because Flex delegate can merge multiple op nodes into one Delegate node if
  // they are located in a row, the size of the execution plan can be lesser
  // than the size of the graph's op nodes.
  TF_LITE_ENSURE(*context,
                 interpreter->execution_plan().size() <= node_to_opinfo.size());
  for (const auto& entry : node_to_opinfo) {
    // Copy the static OperatorInfo so the runtime registration pointer can be
    // attached without mutating the source map.
    auto op_info = entry.second;
    int subgraph_index, op_index;
    std::tie(subgraph_index, op_index) = entry.first;
    const auto* node_and_reg =
        interpreter->node_and_registration(subgraph_index, op_index);
    op_info.registration = &node_and_reg->second;
    // Key by the runtime node pointer: kernel wrappers only see TfLiteNode*.
    node_ptr_opinfo_map->insert({&node_and_reg->first, op_info});
  }
  return kTfLiteOk;
}
325 
// Returns the op's display name: the custom-code string for custom ops,
// otherwise the builtin operator's enum name.
string GetOpName(const tflite::OperatorCode& opcode) {
  const auto* custom_code = opcode.custom_code();
  if (custom_code != nullptr) {
    return custom_code->str();
  }
  return tflite::EnumNamesBuiltinOperator()[GetBuiltinCode(&opcode)];
}
332 
333 // A |CalibrationReader| that owns the Calibrator.
// A |CalibrationReader| that owns the Calibrator.
class Reader : public CalibrationReader {
 public:
  // Does not copy |context| or |logger|; both must outlive this Reader.
  // The calibrator (and its logger) stays alive in the global registry until
  // the destructor unregisters it.
  Reader(const TfLiteContext* context, const Logger* logger)
      : CalibrationReader(logger), context_(context) {}

  // Unregisters — and thereby deletes — the calibrator for |context_|.
  ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }

 private:
  const TfLiteContext* context_;
};
344 
HasInputs(BuiltinOperator code)345 bool HasInputs(BuiltinOperator code) {
346   switch (code) {
347     case BuiltinOperator_CALL_ONCE:
348     case BuiltinOperator_VAR_HANDLE:
349       return false;
350     default:
351       return true;
352   }
353 }
354 
HasOutputs(BuiltinOperator code)355 bool HasOutputs(BuiltinOperator code) {
356   switch (code) {
357     case BuiltinOperator_ASSIGN_VARIABLE:
358     case BuiltinOperator_CALL_ONCE:
359       return false;
360     default:
361       return true;
362   }
363 }
364 
365 }  // namespace
366 
// Convenience overload: forwards to the flatbuffer-level overload below,
// using the model's own error reporter.
TfLiteStatus BuildLoggingInterpreter(
    const FlatBufferModel& model, const OpResolver& op_resolver,
    std::unique_ptr<Interpreter>* interpreter,
    std::unique_ptr<CalibrationReader>* calibration_reader) {
  return BuildLoggingInterpreter(model.GetModel(), model.error_reporter(),
                                 op_resolver, interpreter, calibration_reader);
}
374 
// Builds an interpreter whose kernels are wrapped with |LoggingEval| and a
// |CalibrationReader| that exposes the logged min/max statistics.
// On success, |*interpreter| and |*calibration_reader| are populated; the
// reader owns the registered calibrator and unregisters it on destruction.
TfLiteStatus BuildLoggingInterpreter(
    const tflite::Model* tflite_model, ErrorReporter* error_reporter,
    const OpResolver& op_resolver, std::unique_ptr<Interpreter>* interpreter,
    std::unique_ptr<CalibrationReader>* calibration_reader) {
  if (error_reporter == nullptr) {
    // Make sure error_reporter is valid.
    error_reporter = DefaultErrorReporter();
  }
  auto subgraphs = tflite_model->subgraphs();
  auto tensor_buffers = tflite_model->buffers();

  // Populate the node index to operator info map.
  // We want to collect this information so we can use it during runtime to
  // log details of which inputs and outputs.
  // At runtime TFLite kernel invoke functions can only look into their
  // own node in the graph (TFLiteNode*) and some limited context information.
  absl::flat_hash_map<std::tuple<int, int>, OperatorInfo> node_to_opinfo;
  BuiltinOpsSet builtin_op_and_versions;
  CustomOpsSet custom_op_and_versions;

  for (size_t subgraph_index = 0; subgraph_index < subgraphs->size();
       subgraph_index++) {
    auto subgraph = subgraphs->Get(subgraph_index);
    auto operator_codes = tflite_model->operator_codes();
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!operators) {
      // Subgraph without ops: nothing to calibrate here.
      continue;
    }

    for (size_t i = 0; i < operators->size(); i++) {
      OperatorInfo op_info;
      op_info.subgraph_index = subgraph_index;
      op_info.node_index = i;
      auto op = operators->Get(i);
      auto operator_code = operator_codes->Get(op->opcode_index());
      op_info.builtin_op_code = GetBuiltinCode(operator_code);
      op_info.name = GetOpName(*operator_code);
      op_info.is_custom_op = operator_code->custom_code() != nullptr;
      op_info.version = operator_code->version();

      auto op_inputs = op->inputs();
      auto op_outputs = op->outputs();
      if (op_inputs) {
        op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
      } else if (HasInputs(op_info.builtin_op_code)) {
        // Missing inputs is only legal for a few builtins (see HasInputs).
        TF_LITE_REPORT_ERROR(error_reporter, "Op %s missing inputs",
                             op_info.name.c_str());
        return kTfLiteError;
      }
      if (op_outputs) {
        op_info.outputs =
            std::vector<int>(op_outputs->begin(), op_outputs->end());
      } else if (HasOutputs(op_info.builtin_op_code)) {
        // Likewise, missing outputs is only legal for a few builtins.
        TF_LITE_REPORT_ERROR(error_reporter, "Op %s missing outputs",
                             op_info.name.c_str());
        return kTfLiteError;
      }
      // Only float32 tensors without constant buffers are worth logging.
      op_info.loggable_inputs =
          GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
      op_info.loggable_outputs =
          GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
      if (op_info.is_custom_op) {
        op_info.registration =
            op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
        custom_op_and_versions.insert(
            {op_info.name.c_str(), operator_code->version()});
      } else {
        op_info.registration = op_resolver.FindOp(GetBuiltinCode(operator_code),
                                                  operator_code->version());
        builtin_op_and_versions.insert(
            {op_info.builtin_op_code, operator_code->version()});
      }
      std::tuple<int, int> key{subgraph_index, i};
      node_to_opinfo[key] = op_info;
    }
  }

  // Prepare the logging op resolver to use |LoggingEval| for kernel
  // invocations.
  auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
      builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
      error_reporter);
  tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
                             error_reporter)(interpreter);

  if (!(*interpreter)) {
    error_reporter->Report("Failed to construct interpreter");
    return kTfLiteError;
  }

  // Compute the mapping between runtime and static graph structure, i.e.
  // (TfLiteContext, TfLiteNode) -> OperatorInfo
  std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
  TfLiteContext* context = nullptr;
  TF_LITE_ENSURE_STATUS(GetNodeOpInfoMapAndContext(
      node_to_opinfo, interpreter->get(), &node_ptr_opinfo_map, &context));

  Calibrator* calibrator = nullptr;
  // Register a calibrator object for the context. This can be accessed
  // during invocations by the logging kernels.
  TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
      context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
      error_reporter));
  // The Reader owns the calibrator registration and removes it when the
  // caller destroys the reader.
  *calibration_reader = std::unique_ptr<CalibrationReader>(
      new Reader(context, calibrator->GetLogger()));

  return kTfLiteOk;
}
484 
485 }  // namespace calibration
486 }  // namespace optimize
487 }  // namespace tflite
488