1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
17
18 #include "tensorflow/lite/micro/micro_error_reporter.h"
19
20 namespace tflite {
21 namespace micro {
22
namespace {
// Byte alignment requested for every persistent buffer and scratch buffer
// that this file carves out of the kernel runner's arena.
constexpr size_t kBufferAlignment{16};
}  // namespace
26
27 // TODO(b/161841696): Consider moving away from global arena buffers:
28 constexpr int KernelRunner::kNumScratchBuffers_;
29 constexpr int KernelRunner::kKernelRunnerBufferSize_;
30 uint8_t KernelRunner::kKernelRunnerBuffer_[];
31
KernelRunner(const TfLiteRegistration & registration,TfLiteTensor * tensors,int tensors_size,TfLiteIntArray * inputs,TfLiteIntArray * outputs,void * builtin_data)32 KernelRunner::KernelRunner(const TfLiteRegistration& registration,
33 TfLiteTensor* tensors, int tensors_size,
34 TfLiteIntArray* inputs, TfLiteIntArray* outputs,
35 void* builtin_data)
36 : allocator_(SimpleMemoryAllocator::Create(GetMicroErrorReporter(),
37 kKernelRunnerBuffer_,
38 kKernelRunnerBufferSize_)),
39 registration_(registration),
40 tensors_(tensors) {
41 // Prepare TfLiteContext:
42 context_.impl_ = static_cast<void*>(this);
43 context_.ReportError = ReportOpError;
44 context_.recommended_num_threads = 1;
45 context_.GetTensor = GetTensor;
46 context_.GetEvalTensor = GetEvalTensor;
47 context_.AllocatePersistentBuffer = AllocatePersistentBuffer;
48 context_.RequestScratchBufferInArena = RequestScratchBufferInArena;
49 context_.GetScratchBuffer = GetScratchBuffer;
50
51 // Prepare TfLiteNode:
52 node_.inputs = inputs;
53 node_.outputs = outputs;
54 node_.builtin_data = builtin_data;
55 }
56
InitAndPrepare(const char * init_data,size_t length)57 TfLiteStatus KernelRunner::InitAndPrepare(const char* init_data,
58 size_t length) {
59 if (registration_.init) {
60 node_.user_data = registration_.init(&context_, init_data, length);
61 }
62 if (registration_.prepare) {
63 TF_LITE_ENSURE_STATUS(registration_.prepare(&context_, &node_));
64 }
65 return kTfLiteOk;
66 }
67
Invoke()68 TfLiteStatus KernelRunner::Invoke() {
69 if (registration_.invoke == nullptr) {
70 MicroPrintf("TfLiteRegistration missing invoke function pointer!");
71 return kTfLiteError;
72 }
73 return registration_.invoke(&context_, &node_);
74 }
75
GetTensor(const struct TfLiteContext * context,int tensor_index)76 TfLiteTensor* KernelRunner::GetTensor(const struct TfLiteContext* context,
77 int tensor_index) {
78 TFLITE_DCHECK(context != nullptr);
79 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
80 TFLITE_DCHECK(runner != nullptr);
81
82 return &runner->tensors_[tensor_index];
83 }
84
GetEvalTensor(const struct TfLiteContext * context,int tensor_index)85 TfLiteEvalTensor* KernelRunner::GetEvalTensor(
86 const struct TfLiteContext* context, int tensor_index) {
87 TFLITE_DCHECK(context != nullptr);
88 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
89 TFLITE_DCHECK(runner != nullptr);
90
91 TfLiteEvalTensor* eval_tensor =
92 reinterpret_cast<TfLiteEvalTensor*>(runner->allocator_->AllocateTemp(
93 sizeof(TfLiteEvalTensor), alignof(TfLiteEvalTensor)));
94 TFLITE_DCHECK(eval_tensor != nullptr);
95
96 // In unit tests, the TfLiteTensor pointer contains the source of truth for
97 // buffers and values:
98 eval_tensor->data = runner->tensors_[tensor_index].data;
99 eval_tensor->dims = runner->tensors_[tensor_index].dims;
100 eval_tensor->type = runner->tensors_[tensor_index].type;
101 return eval_tensor;
102 }
103
AllocatePersistentBuffer(TfLiteContext * context,size_t bytes)104 void* KernelRunner::AllocatePersistentBuffer(TfLiteContext* context,
105 size_t bytes) {
106 TFLITE_DCHECK(context != nullptr);
107 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
108 TFLITE_DCHECK(runner != nullptr);
109
110 return runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
111 }
112
RequestScratchBufferInArena(TfLiteContext * context,size_t bytes,int * buffer_index)113 TfLiteStatus KernelRunner::RequestScratchBufferInArena(TfLiteContext* context,
114 size_t bytes,
115 int* buffer_index) {
116 TFLITE_DCHECK(context != nullptr);
117 TFLITE_DCHECK(buffer_index != nullptr);
118
119 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
120 TFLITE_DCHECK(runner != nullptr);
121
122 if (runner->scratch_buffer_count_ == kNumScratchBuffers_) {
123 MicroPrintf("Exceeded the maximum number of scratch tensors allowed (%d).",
124 kNumScratchBuffers_);
125 return kTfLiteError;
126 }
127
128 // For tests, we allocate scratch buffers from the tail and keep them around
129 // for the lifetime of model. This means that the arena size in the tests will
130 // be more than what we would have if the scratch buffers could share memory.
131 runner->scratch_buffers_[runner->scratch_buffer_count_] =
132 runner->allocator_->AllocateFromTail(bytes, kBufferAlignment);
133 TFLITE_DCHECK(runner->scratch_buffers_[runner->scratch_buffer_count_] !=
134 nullptr);
135
136 *buffer_index = runner->scratch_buffer_count_++;
137 return kTfLiteOk;
138 }
139
GetScratchBuffer(TfLiteContext * context,int buffer_index)140 void* KernelRunner::GetScratchBuffer(TfLiteContext* context, int buffer_index) {
141 TFLITE_DCHECK(context != nullptr);
142 KernelRunner* runner = reinterpret_cast<KernelRunner*>(context->impl_);
143 TFLITE_DCHECK(runner != nullptr);
144
145 TFLITE_DCHECK(runner->scratch_buffer_count_ <= kNumScratchBuffers_);
146 if (buffer_index >= runner->scratch_buffer_count_) {
147 return nullptr;
148 }
149 return runner->scratch_buffers_[buffer_index];
150 }
151
ReportOpError(struct TfLiteContext * context,const char * format,...)152 void KernelRunner::ReportOpError(struct TfLiteContext* context,
153 const char* format, ...) {
154 va_list args;
155 va_start(args, format);
156 GetMicroErrorReporter()->Report(format, args);
157 va_end(args);
158 }
159
160 } // namespace micro
161 } // namespace tflite
162