• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/compatibility.h"
19 #include "tensorflow/lite/kernels/internal/quantization_util.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/kernels/op_macros.h"
23 #include "tensorflow/lite/micro/kernels/kernel_util.h"
24 
25 /*
26  * The circular buffer custom operator is used to implement strided streaming
27  * convolutions on TFLite Micro.  Each time this operator is invoked, it checks
28  * whether or not to run, based on a predetermined stride in time.  If the op
29  * runs, it inserts the input into the end of the output buffer and shifts the
30  * output values towards the start of the buffer.  It discards the oldest value
31  * in the output buffer.
32  *
 * Input: [<input N+1>]
34  * Before shifting:
35  * Output: [<input 1>, <input 2>, <input ...>, <input N>]
36  *
37  * After shifting:
38  * Output: [<input 2>, <input 3>, <input ...>, <input N+1>]
39  *
40  * We make some assumptions in this custom operator:
41  * - Input shape must be [1, 1, 1, depth]
42  * - Output shape must be [1, num_slots, 1, depth]
43  * - Input and output types must match.
44  * - Input and output quantization params must be identical.
45  */
46 namespace tflite {
47 namespace ops {
48 namespace micro {
49 namespace circular_buffer {
50 
namespace {

// The CircularBuffer op has one input and one output tensor.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

// Non-standard status code returned by Eval() to tell the interpreter to end
// the current invocation early (the op is skipping this cycle).
// TODO(b/149795762): Add this to TfLiteStatus enum.
constexpr int kTfLiteAbort = -9;

// These fields control the stride period of a strided streaming model. This op
// returns kTfLiteAbort until cycles_until_run-- is zero.  At this time,
// cycles_until_run is reset to cycles_max.
// One instance per node, allocated from the persistent arena in Init().
struct OpData {
  int cycles_until_run;  // Countdown; the op only completes when this hits 0.
  int cycles_max;        // Stride period; reload value for the countdown.
};

}  // namespace
69 
Init(TfLiteContext * context,const char * buffer,size_t length)70 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
71   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
72   return context->AllocatePersistentBuffer(context, sizeof(OpData));
73 }
74 
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* op_data = static_cast<OpData*>(node->user_data);

  // Validate the shape assumptions from the file header comment:
  // input [1, 1, 1, depth] and output [1, num_slots, 1, depth].  Dims 0, 2
  // and 3 must match between input and output; input dim 1 must be 1.
  TF_LITE_ENSURE(context, input != nullptr);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
  TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);

  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

  // The circular buffer custom operator currently only supports int8.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);

  // The last circular buffer layer simply accumulates outputs, and does not run
  // periodically.
  // TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
  // NOTE: this function-local static counts Prepare() calls across ALL
  // circular buffer nodes in the graph, so the hard-coded checks below depend
  // on both layer order and graph topology.
  static int cb_prepare_count = 0;
  cb_prepare_count++;
  // These checks specifically work for the only two streaming models supported
  // on TFLM. They use the shape of the output tensor along with the layer
  // number to determine if the circular buffer period should be 1 or 2.

  // These models are outlined in the following documents:
  // https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
  // https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
  if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
      (cb_prepare_count == 5 && output->dims->data[2] == 2 &&
       output->dims->data[3] == 96)) {
    op_data->cycles_max = 1;  // Run on every invocation.
    cb_prepare_count = 0;
  } else {
    op_data->cycles_max = 2;  // Run on every second invocation.
  }
  op_data->cycles_until_run = op_data->cycles_max;
  // NOTE(review): redundant -- op_data was read from node->user_data above.
  node->user_data = op_data;

  return kTfLiteOk;
}
119 
120 // Shifts buffer over by the output depth, and write new input to end of buffer.
121 // num_slots is the number of samples stored in the output buffer.
122 // depth is the size of each sample.
// Shifts the output buffer up by one slot (discarding the oldest sample) and
// writes the new input into the freed final slot.
//
// input:     pointer to `depth` int8 values (the newest sample).
// num_slots: number of samples stored in the output buffer.
// depth:     size of each sample, in elements.
// output:    pointer to `num_slots * depth` int8 values, oldest sample first.
void EvalInt8(const int8_t* input, int num_slots, int depth, int8_t* output) {
  // Guard degenerate shapes: a non-positive slot count would otherwise make
  // (num_slots - 1) * depth wrap to a huge size_t for memmove (UB).
  if (num_slots < 1 || depth < 1) {
    return;
  }
  const size_t sample_bytes = static_cast<size_t>(depth) * sizeof(int8_t);
  // Drop the oldest sample by sliding slots [1, num_slots) down to slot 0.
  memmove(output, &output[depth], (num_slots - 1) * sample_bytes);
  // Append the newest sample into the last slot.
  memcpy(&output[(num_slots - 1) * depth], input, sample_bytes);
}
127 
Eval(TfLiteContext * context,TfLiteNode * node)128 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
129   const TfLiteEvalTensor* input =
130       tflite::micro::GetEvalInput(context, node, kInputTensor);
131   TfLiteEvalTensor* output =
132       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
133 
134   TFLITE_DCHECK(node->user_data != nullptr);
135   OpData* data = reinterpret_cast<OpData*>(node->user_data);
136 
137   int num_slots = output->dims->data[1];
138   int depth = output->dims->data[2] * output->dims->data[3];
139 
140   if (input->type == kTfLiteInt8) {
141     EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
142              tflite::micro::GetTensorData<int8_t>(output));
143   } else {
144     TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
145                        TfLiteTypeGetName(input->type), input->type);
146     return kTfLiteError;
147   }
148 
149   if (--data->cycles_until_run != 0) {
150     // Signal the interpreter to end current run if the delay before op invoke
151     // has not been reached.
152     // TODO(b/149795762): Add kTfLiteAbort to TfLiteStatus enum.
153     return static_cast<TfLiteStatus>(kTfLiteAbort);
154   }
155 
156   data->cycles_until_run = data->cycles_max;
157 
158   return kTfLiteOk;
159 }
160 
161 }  // namespace circular_buffer
162 
Register_CIRCULAR_BUFFER()163 TfLiteRegistration* Register_CIRCULAR_BUFFER() {
164   static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
165                                  /*free=*/nullptr,
166                                  /*prepare=*/circular_buffer::Prepare,
167                                  /*invoke=*/circular_buffer::Eval,
168                                  /*profiling_string=*/nullptr,
169                                  /*builtin_code=*/0,
170                                  /*custom_name=*/nullptr,
171                                  /*version=*/0};
172   return &r;
173 }
174 
175 }  // namespace micro
176 }  // namespace ops
177 }  // namespace tflite
178