1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/compatibility.h"
19 #include "tensorflow/lite/kernels/internal/quantization_util.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/kernels/op_macros.h"
23 #include "tensorflow/lite/micro/kernels/kernel_util.h"
24
25 /*
26 * The circular buffer custom operator is used to implement strided streaming
27 * convolutions on TFLite Micro. Each time this operator is invoked, it checks
28 * whether or not to run, based on a predetermined stride in time. If the op
29 * runs, it inserts the input into the end of the output buffer and shifts the
30 * output values towards the start of the buffer. It discards the oldest value
31 * in the output buffer.
32 *
 * Input: [<input N+1>]
34 * Before shifting:
35 * Output: [<input 1>, <input 2>, <input ...>, <input N>]
36 *
37 * After shifting:
38 * Output: [<input 2>, <input 3>, <input ...>, <input N+1>]
39 *
40 * We make some assumptions in this custom operator:
41 * - Input shape must be [1, 1, 1, depth]
42 * - Output shape must be [1, num_slots, 1, depth]
43 * - Input and output types must match.
44 * - Input and output quantization params must be identical.
45 */
46 namespace tflite {
47 namespace ops {
48 namespace micro {
49 namespace circular_buffer {
50
51 namespace {
52
// The CircularBuffer op has one input and one output tensor. These are the
// indices passed to GetInput/GetOutput (and their Eval counterparts) to fetch
// them from the node.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

// Status value that Eval casts to TfLiteStatus and returns on invocations
// where the stride period has not yet elapsed.
// TODO(b/149795762): Add this to TfLiteStatus enum.
constexpr int kTfLiteAbort = -9;
59
// These fields control the stride period of a strided streaming model. On
// every invocation Eval pre-decrements cycles_until_run and returns
// kTfLiteAbort while the decremented value is non-zero. When it reaches zero,
// cycles_until_run is reset to cycles_max and Eval returns kTfLiteOk.
struct OpData {
  // Invocations left before Eval next returns kTfLiteOk.
  int cycles_until_run;
  // Stride period; Prepare sets it to 1 or 2.
  int cycles_max;
};
67
68 } // namespace
69
Init(TfLiteContext * context,const char * buffer,size_t length)70 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
71 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
72 return context->AllocatePersistentBuffer(context, sizeof(OpData));
73 }
74
Prepare(TfLiteContext * context,TfLiteNode * node)75 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
76 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
77 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
78
79 TFLITE_DCHECK(node->user_data != nullptr);
80 OpData* op_data = static_cast<OpData*>(node->user_data);
81
82 TF_LITE_ENSURE(context, input != nullptr);
83 TF_LITE_ENSURE(context, output != nullptr);
84 TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
85 TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
86 TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
87 TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
88
89 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
90
91 // The circular buffer custom operator currently only supports int8.
92 TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
93
94 // The last circular buffer layer simply accumulates outputs, and does not run
95 // periodically.
96 // TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
97 static int cb_prepare_count = 0;
98 cb_prepare_count++;
99 // These checks specifically work for the only two streaming models supported
100 // on TFLM. They use the shape of the output tensor along with the layer
101 // number to determine if the circular buffer period should be 1 or 2.
102
103 // These models are outlined int the following documents:
104 // https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
105 // https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
106 if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
107 (cb_prepare_count == 5 && output->dims->data[2] == 2 &&
108 output->dims->data[3] == 96)) {
109 op_data->cycles_max = 1;
110 cb_prepare_count = 0;
111 } else {
112 op_data->cycles_max = 2;
113 }
114 op_data->cycles_until_run = op_data->cycles_max;
115 node->user_data = op_data;
116
117 return kTfLiteOk;
118 }
119
120 // Shifts buffer over by the output depth, and write new input to end of buffer.
121 // num_slots is the number of samples stored in the output buffer.
122 // depth is the size of each sample.
EvalInt8(const int8_t * input,int num_slots,int depth,int8_t * output)123 void EvalInt8(const int8_t* input, int num_slots, int depth, int8_t* output) {
124 memmove(output, &output[depth], (num_slots - 1) * depth);
125 memcpy(&output[(num_slots - 1) * depth], input, depth);
126 }
127
Eval(TfLiteContext * context,TfLiteNode * node)128 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
129 const TfLiteEvalTensor* input =
130 tflite::micro::GetEvalInput(context, node, kInputTensor);
131 TfLiteEvalTensor* output =
132 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
133
134 TFLITE_DCHECK(node->user_data != nullptr);
135 OpData* data = reinterpret_cast<OpData*>(node->user_data);
136
137 int num_slots = output->dims->data[1];
138 int depth = output->dims->data[2] * output->dims->data[3];
139
140 if (input->type == kTfLiteInt8) {
141 EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
142 tflite::micro::GetTensorData<int8_t>(output));
143 } else {
144 TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
145 TfLiteTypeGetName(input->type), input->type);
146 return kTfLiteError;
147 }
148
149 if (--data->cycles_until_run != 0) {
150 // Signal the interpreter to end current run if the delay before op invoke
151 // has not been reached.
152 // TODO(b/149795762): Add kTfLiteAbort to TfLiteStatus enum.
153 return static_cast<TfLiteStatus>(kTfLiteAbort);
154 }
155
156 data->cycles_until_run = data->cycles_max;
157
158 return kTfLiteOk;
159 }
160
161 } // namespace circular_buffer
162
Register_CIRCULAR_BUFFER()163 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
164 static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
165 /*free=*/nullptr,
166 /*prepare=*/circular_buffer::Prepare,
167 /*invoke=*/circular_buffer::Eval,
168 /*profiling_string=*/nullptr,
169 /*builtin_code=*/0,
170 /*custom_name=*/nullptr,
171 /*version=*/0};
172 return &r;
173 }
174
175 } // namespace micro
176 } // namespace ops
177 } // namespace tflite
178