/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/kernels/kernel_runner.h"

#include "tensorflow/lite/micro/micro_error_reporter.h"

namespace tflite {
namespace micro {

namespace {
constexpr size_t kBufferAlignment = 16;
}  // namespace

// TODO(b/161841696): Consider moving away from global arena buffers:
constexpr int KernelRunner::kNumScratchBuffers_;
constexpr int KernelRunner::kKernelRunnerBufferSize_;
uint8_t KernelRunner::kKernelRunnerBuffer_[];

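// Constructs a runner for a single kernel: creates a private allocator over
// the static kKernelRunnerBuffer_ arena and wires the TfLiteContext and
// TfLiteNode fields that kernel implementations expect to be populated.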
KernelRunner::KernelRunner(const TfLiteRegistration& registration,
                           TfLiteTensor* tensors, int tensors_size,
                           TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                           void* builtin_data)
    : allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
                                               kKernelRunnerBuffer_,
                                               kKernelRunnerBufferSize_)),
      registration_(registration),
      tensors_(tensors) {
  // Prepare TfLiteContext:
  context_.impl_ = static_cast<void*>(this);
  context_.ReportError = ReportOpError;
  context_.recommended_num_threads = 1;
  context_.GetTensor = GetTensor;
  context_.GetEvalTensor = GetEvalTensor;
  context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
  context_.GetScratchBuffer = GetScratchBuffer;

  // Prepare TfLiteNode:
  node_.inputs = inputs;
  node_.outputs = outputs;
  node_.builtin_data = builtin_data;
}

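// Runs the kernel's optional init and prepare stages, mirroring what the
// interpreter does before the first inference.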
TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
                                          size_t length) {
  if (registration_.init) {
    node_.user_data = registration_.init(&context_, init_data, length);
  }
  if (registration_.prepare) {
    TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
  }
  return kTfLiteOk;
}

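// Invokes the kernel once; fails fast if the registration has no invoke
// function pointer.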
TfLiteStatus KernelRunner::Invoke() {
  if (registration_.invoke == nullptr) {
    MicroPrintf("TfLiteRegistration missing invoke function pointer!");
    return kTfLiteError;
  }
  return registration_.invoke(&context_, &node_);
}

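// TfLiteContext callback: returns the full TfLiteTensor at the given index
// from the tensor array supplied at construction time.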
TfLiteTensor* KernelRunner::GetTensor(const struct TfLiteContext* context,
                                      int tensor_index) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  return &runner->tensors_[tensor_index];
}

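// TfLiteContext callback: materializes a TfLiteEvalTensor view of the
// corresponding TfLiteTensor, allocated from the runner's temp arena.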
TfLiteEvalTensor* KernelRunner::GetEvalTensor(
    const struct TfLiteContext* context, int tensor_index) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  TfLiteEvalTensor* eval_tensor =
      reinterpret_cast<TfLiteEvalTensor*>(runner->allocator_->AllocateTemp(
          sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
  TFLITE_DCHECK(eval_tensor != nullptr);

  // In unit tests, the TfLiteTensor pointer contains the source of truth for
  // buffers and values:
  eval_tensor->data = runner->tensors_[tensor_index].data;
  eval_tensor->dims = runner->tensors_[tensor_index].dims;
  eval_tensor->type = runner->tensors_[tensor_index].type;
  return eval_tensor;
}

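// TfLiteContext callback: persistent allocations come from the tail of the
// runner's arena and live for the lifetime of the runner.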
void* KernelRunner::AllocatePersistentBuffer(TfLiteContext* context,
                                             size_t bytes) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  return runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
}

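// TfLiteContext callback: reserves a scratch buffer and hands back its index,
// up to a fixed limit of kNumScratchBuffers_ requests.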
TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
                                                       size_t bytes,
                                                       int* buffer_index) {
  TFLITE_DCHECK(context != nullptr);
  TFLITE_DCHECK(buffer_index != nullptr);

  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
    MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
                kNumScratchBuffers_);
    return kTfLiteError;
  }

  // For tests, we allocate scratch buffers from the tail and keep them around
  // for the lifetime of the model. This means that the arena size in the tests
  // will be larger than it would be if the scratch buffers could share memory.
  runner->scratch_buffers_[runner->scratch_buffer_count_] =
      runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
  TFLITE_DCHECK(runner->scratch_buffers_[runner->scratch_buffer_count_] !=
                nullptr);

  *buffer_index = runner->scratch_buffer_count_++;
  return kTfLiteOk;
}

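// TfLiteContext callback: returns the scratch buffer previously reserved via
// RequestScratchBufferInArena, or nullptr for an out-of-range index.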
void* KernelRunner::GetScratchBuffer(TfLiteContext* context,
                                     int buffer_index) {
  TFLITE_DCHECK(context != nullptr);
  KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
  TFLITE_DCHECK(runner != nullptr);

  TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
  if (buffer_index >= runner->scratch_buffer_count_) {
    return nullptr;
  }
  return runner->scratch_buffers_[buffer_index];
}

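// TfLiteContext callback: forwards kernel error messages to the shared
// MicroErrorReporter.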
void KernelRunner::ReportOpError(struct TfLiteContext* context,
                                 const char* format, ...) {
  va_list args;
  va_start(args, format);
  GetMicroErrorReporter()->Report(format, args);
  va_end(args);
}

}  // namespace micro
}  // namespace tflite
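
// Illustrative sketch (not part of the original file) of how a kernel unit
// test might drive KernelRunner. The FULLY_CONNECTED registration and the
// tensors, kTensorsSize, and params variables are placeholders assumed to be
// set up by the test:
//
//   int inputs_array_data[] = {2, 0, 1};  // {size, tensor indices...}
//   TfLiteIntArray* inputs =
//       tflite::testing::IntArrayFromInts(inputs_array_data);
//   int outputs_array_data[] = {1, 2};
//   TfLiteIntArray* outputs =
//       tflite::testing::IntArrayFromInts(outputs_array_data);
//
//   const TfLiteRegistration registration = tflite::Register_FULLY_CONNECTED();
//   tflite::micro::KernelRunner runner(registration, tensors, kTensorsSize,
//                                      inputs, outputs,
//                                      /*builtin_data=*/&params);
//   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
//                           runner.InitAndPrepare(/*init_data=*/nullptr,
//                                                 /*length=*/0));
//   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());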