• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
16 
17 #include <fstream>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/core/api/error_reporter.h"
26 #include "tensorflow/lite/core/api/op_resolver.h"
27 #include "tensorflow/lite/interpreter.h"
28 #include "tensorflow/lite/kernels/register.h"
29 #include "tensorflow/lite/model.h"
30 #include "tensorflow/lite/op_resolver.h"
31 #include "tensorflow/lite/schema/schema_generated.h"
32 #include "tensorflow/lite/string_util.h"
33 #include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
34 #include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
35 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
36 #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
37 #include "tensorflow/lite/tools/optimize/calibration/node_info_delegate.h"
38 
39 namespace tflite {
40 namespace optimize {
41 namespace calibration {
42 
43 namespace {
44 
45 // Calibrator is used to hold information that can be accessed during kernel
46 // invocations.
47 // TfLite kernel invocations are C functions and cannot look at the global
48 // structure of the graph. Calibrator allows the kernel invoke functions to
49 // access the global structure of graph and know which node is currently being
50 // executed. This also allows us to write a simple kernel invoke wrapper
51 // (see LoggingEval) that can work for most builtin ops.
52 class Calibrator {
53  public:
Calibrator(const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_ptr_opinfo_map,std::unique_ptr<LoggingOpResolver> logging_op_resolver)54   Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
55                  node_ptr_opinfo_map,
56              std::unique_ptr<LoggingOpResolver> logging_op_resolver)
57       : node_ptr_opinfo_map_(node_ptr_opinfo_map),
58         logging_op_resolver_(std::move(logging_op_resolver)) {
59     logger_ = absl::make_unique<Logger>();
60   }
61 
62   // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
63   KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;
64 
65   // Gets the instance of logger associated with the current context.
GetLogger() const66   Logger* GetLogger() const { return logger_.get(); }
67 
68   // Gets the operator information about the given TfLiteNode.
GetOpInfo(const TfLiteNode * node) const69   const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
70     return node_ptr_opinfo_map_.at(node);
71   }
72 
73  private:
74   std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
75   std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
76   const std::unordered_map<int, OperatorInfo> index_opinfo_;
77   std::unique_ptr<Logger> logger_;
78 };
79 
GetKernelInvoke(const TfLiteNode * node) const80 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
81   auto op_info = node_ptr_opinfo_map_.at(node);
82   return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
83                                                       1);
84 }
85 
86 // A registry of |Calibrator| objects per |TfLiteContext|.
87 // This global registry is needed to access |Calibrator| objects in the kernel
88 // invoke functions i.e. |TfLiteRegistration.invoke|.
89 // Kernel invoke functions are C functions that have limited access to
90 // |TfLiteContext|. Kernel invoke functions don't have access to global state of
91 // graph. That means during a kernel invocation, the function cannot know which
92 // node it was invoked for. E.g. in case of a model with |Conv| op at two
93 // locations, there is no easy way for the Conv.invoke function to disambiguate
94 // the calls.
95 //
96 // For calibration we solve this problem by creating a map of calibrators
97 // per |TfLiteContext|. This map is |GlobalCalibrationRegistry|.
98 //
99 // This registry is then accessed using a global getter function:
100 // |GetCalibratorRegistry|.
101 // E.g.
102 // TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
103 //   .... code ....
104 //   auto registry = GetCalibratorRegistry();
105 //   auto calibrator = registry->GetCalibrator(context);
106 //   ..... code ....
107 //  }
108 //
109 // This way the kernel invoke functions can get the access to the Calibrator
110 // object associated with the |TfLiteContext|.
111 class GlobalCalibratorRegistry {
112  public:
113   // Get the |Calibrator| associated with given context, returns null if no
114   // calibrator is associated with the given context.
GetCalibrator(const TfLiteContext * context) const115   Calibrator* GetCalibrator(const TfLiteContext* context) const {
116     if (calibrator_registry_.find(context) == calibrator_registry_.cend()) {
117       return nullptr;
118     }
119     return calibrator_registry_.at(context).get();
120   }
121 
122   // Removes the association between calibrator and context.
123   // Note: This deletes the calibrator as well.
RemoveCalibrator(const TfLiteContext * context)124   void RemoveCalibrator(const TfLiteContext* context) {
125     calibrator_registry_.erase(context);
126   }
127 
128   // Creates an instance of |Calibrator|.
129   // Registry owns the |Calibrator| object which can be deleted by calling
130   // |RemoveCalibrator|.
CreateCalibrator(const TfLiteContext * context,const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_to_opinfo,std::unique_ptr<LoggingOpResolver> logging_op_resolver,Calibrator ** calibrator_ptr,ErrorReporter * reporter)131   TfLiteStatus CreateCalibrator(
132       const TfLiteContext* context,
133       const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
134       std::unique_ptr<LoggingOpResolver> logging_op_resolver,
135       Calibrator** calibrator_ptr, ErrorReporter* reporter) {
136     if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
137       reporter->Report(
138           "Failed to create calibrator, context already registered.");
139       return kTfLiteError;
140     }
141     std::unique_ptr<Calibrator> calibrator = absl::make_unique<Calibrator>(
142         node_to_opinfo, std::move(logging_op_resolver));
143     calibrator_registry_[context] = std::move(calibrator);
144     *calibrator_ptr = calibrator_registry_.at(context).get();
145     return kTfLiteOk;
146   }
147 
148  private:
149   std::unordered_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
150       calibrator_registry_;
151 };
152 
GetCalibratorRegistry()153 GlobalCalibratorRegistry* GetCalibratorRegistry() {
154   static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
155   return registry;
156 }
157 
158 // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
159 // invokes the wrapped implementation and then logs the outputs.
LoggingEval(TfLiteContext * context,TfLiteNode * node)160 TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
161   Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(context);
162 
163   if (!calibrator) {
164     context->ReportError(context, "No calibrator found for context.");
165     return kTfLiteError;
166   }
167 
168   auto kernel_invoke = calibrator->GetKernelInvoke(node);
169   auto logger = calibrator->GetLogger();
170   auto op_info = calibrator->GetOpInfo(node);
171 
172   for (int i : op_info.loggable_inputs) {
173     auto tensor = context->tensors[i];
174     logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
175   }
176 
177   auto status = kernel_invoke(context, node);
178   // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
179   // once as an input and second time as output. This doesn't change the min max
180   // values but is inefficient.
181   // Using moving average will also break this.
182 
183   for (int i : op_info.loggable_outputs) {
184     auto tensor = context->tensors[i];
185     logger->LogTensorValue(i, tensor.data.f, tensor.bytes / sizeof(float));
186   }
187 
188   return status;
189 }
190 
191 // Returns the loggable tensors. Not all inputs and outputs need to be logged.
192 // For example, const weight tensors which have buffers associated with them
193 // don't need to be logged.
GetLoggableTensorIndices(const std::vector<int> & tensor_indices,const flatbuffers::Vector<flatbuffers::Offset<Tensor>> * tensors,const flatbuffers::Vector<flatbuffers::Offset<Buffer>> * tensor_buffers)194 std::vector<int> GetLoggableTensorIndices(
195     const std::vector<int>& tensor_indices,
196     const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
197     const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
198   std::vector<int> loggable;
199   for (auto tensor_index : tensor_indices) {
200     auto tensor = tensors->Get(tensor_index);
201     auto buffer_index = tensor->buffer();
202     const bool has_no_buffer =
203         (tensor_buffers->Get(buffer_index) == nullptr) ||
204         (tensor_buffers->Get(buffer_index)->data() == nullptr) ||
205         (tensor_buffers->Get(buffer_index)->data()->size() == 0);
206     if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
207       loggable.push_back(tensor_index);
208     }
209   }
210   return loggable;
211 }
212 
213 // Creates a mapping between the static model graph and the runtime TfLiteNode*
214 // nodes in the graph for the given context.
215 // This is done by querying the TfLiteContext for node and registrations using
216 // the |NodeInfoDelegateObserver|.
GetNodeOpInfoMapAndContext(const std::unordered_map<int,OperatorInfo> & node_to_opinfo,tflite::Interpreter * const interpreter,std::unordered_map<const TfLiteNode *,OperatorInfo> * node_ptr_opinfo_map,const TfLiteContext ** context)217 TfLiteStatus GetNodeOpInfoMapAndContext(
218     const std::unordered_map<int, OperatorInfo>& node_to_opinfo,
219     tflite::Interpreter* const interpreter,
220     std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
221     const TfLiteContext** context
222 
223 ) {
224   NodeInfoDelegateObserver delegate_observer(node_to_opinfo,
225                                              node_ptr_opinfo_map);
226   NodeInfoDelegateParams delegate_params;
227   delegate_params.delegate_observer = &delegate_observer;
228   TfLiteDelegate logging_delegate = CreateNodeInfoDelegate(&delegate_params);
229 
230   auto modify_status = interpreter->ModifyGraphWithDelegate(&logging_delegate);
231   if (modify_status != kTfLiteOk) {
232     return kTfLiteError;
233   }
234   *context = delegate_observer.GetContext();
235   return kTfLiteOk;
236 }
237 
GetOpName(const tflite::OperatorCode & opcode)238 string GetOpName(const tflite::OperatorCode& opcode) {
239   if (opcode.custom_code() != nullptr) {
240     return opcode.custom_code()->str();
241   }
242   return tflite::EnumNamesBuiltinOperator()[opcode.builtin_code()];
243 }
244 
245 // A |CalibrationReader| that owns the Calibrator.
246 class Reader : public CalibrationReader {
247  public:
Reader(const TfLiteContext * context,const Logger * logger)248   Reader(const TfLiteContext* context, const Logger* logger)
249       : CalibrationReader(logger), context_(context) {}
250 
~Reader()251   ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }
252 
253  private:
254   const TfLiteContext* context_;
255 };
256 
257 }  // namespace
258 
BuildLoggingInterpreter(const FlatBufferModel & model,const OpResolver & op_resolver,std::unique_ptr<Interpreter> * interpreter,std::unique_ptr<CalibrationReader> * calibration_reader)259 TfLiteStatus BuildLoggingInterpreter(
260     const FlatBufferModel& model, const OpResolver& op_resolver,
261     std::unique_ptr<Interpreter>* interpreter,
262     std::unique_ptr<CalibrationReader>* calibration_reader) {
263   auto tflite_model = model.GetModel();
264   auto subgraphs = tflite_model->subgraphs();
265   auto tensor_buffers = tflite_model->buffers();
266 
267   if (subgraphs->size() != 1) {
268     model.error_reporter()->Report(
269         "Only models with a single subgraph are supported, model had %d "
270         "subgraphs",
271         subgraphs->size());
272     return kTfLiteError;
273   }
274 
275   // Populate the node index to operator info map.
276   // We want to collect this information so we can use it during runtime to
277   // log details of which inputs and outputs.
278   // At runtime TFLite kernel invoke functions can only look into their
279   // own node in the graph (TFLiteNode*) and some limited context information.
280   auto primary_subgraph = subgraphs->Get(0);
281   auto operator_codes = tflite_model->operator_codes();
282   auto operators = primary_subgraph->operators();
283   auto tensors = primary_subgraph->tensors();
284   std::unordered_map<int, OperatorInfo> node_to_opinfo;
285   BuiltinOpsSet op_and_versions;
286 
287   for (size_t i = 0; i < operators->size(); i++) {
288     OperatorInfo op_info;
289     op_info.node_index = i;
290     auto op = operators->Get(i);
291     auto operator_code = operator_codes->Get(op->opcode_index());
292     op_info.builtin_op_code = operator_code->builtin_code();
293     op_info.name = GetOpName(*operator_code);
294     op_info.is_custom_op = operator_code->custom_code() != nullptr;
295 
296     auto op_inputs = op->inputs();
297     auto op_outputs = op->outputs();
298     op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
299     op_info.outputs = std::vector<int>(op_outputs->begin(), op_outputs->end());
300     op_info.loggable_inputs =
301         GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
302     op_info.loggable_outputs =
303         GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
304     if (!op_info.is_custom_op) {
305       op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
306                                                 operator_code->version());
307     } else {
308       op_info.registration =
309           op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
310     }
311     node_to_opinfo[i] = op_info;
312     op_and_versions.insert({op_info.builtin_op_code, operator_code->version()});
313   }
314 
315   // Prepare the logging op resolver to use |LoggingEval| for kernel
316   // invocations.
317   auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
318       op_and_versions, op_resolver, LoggingEval);
319   tflite::InterpreterBuilder(model, *logging_op_resolver)(interpreter);
320 
321   if (!(*interpreter)) {
322     model.error_reporter()->Report("Failed to construct interpreter");
323     return kTfLiteError;
324   }
325 
326   // Compute the mapping between runtime and static graph structure, i.e.
327   // (TfLiteContext, TfLiteNode) -> OperatorInfo
328   std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
329   const TfLiteContext* context = nullptr;
330   GetNodeOpInfoMapAndContext(node_to_opinfo, interpreter->get(),
331                              &node_ptr_opinfo_map, &context);
332 
333   Calibrator* calibrator = nullptr;
334   // Register a calibrator object for the context. This can be accessed
335   // during invocations by the logging kernels.
336   TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
337       context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
338       model.error_reporter()));
339   *calibration_reader = std::unique_ptr<CalibrationReader>(
340       new Reader(context, calibrator->GetLogger()));
341 
342   return kTfLiteOk;
343 }
344 
345 }  // namespace calibration
346 }  // namespace optimize
347 }  // namespace tflite
348