• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <algorithm>
#include <complex>
#include <vector>

#include "third_party/fft2d/fft2d.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
31 
32 namespace tflite {
33 namespace ops {
34 namespace builtin {
35 namespace rfft2d {
36 
37 using std::complex;
38 
// Indices of the op's input/output tensors and of its two temporaries
// (the latter index into node->temporaries).
constexpr int kInputTensor = 0;
constexpr int kFftLengthTensor = 1;
constexpr int kOutputTensor = 0;
constexpr int kFftIntegerWorkingAreaTensor = 0;
constexpr int kFftDoubleWorkingAreaTensor = 1;
// Sentinel meaning "temporary tensor has not been created yet".
constexpr int kTensorNotAllocated = -1;
45 
// Per-op state: TF Lite tensor ids of the two scratch buffers used by the
// fft2d routines. Created lazily (once) in InitTemporaryTensors.
struct OpData {
  // IDs are the arbitrary identifiers used by TF Lite to identify and access
  // memory buffers.
  int fft_integer_working_area_id = kTensorNotAllocated;
  int fft_double_working_area_id = kTensorNotAllocated;
};
52 
// Returns true iff `v` is a nonzero power of two (exactly one bit set).
bool IsPowerOfTwo(uint32_t v) { return v != 0 && (v & (v - 1)) == 0; }
54 
// Lazily creates the two temporary tensors (an int32 working area and a
// double working area) that the fft2d routines require, registering them in
// node->temporaries and recording their ids in OpData.
static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
                                         TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  // The prepare function may be executed multiple times. But temporary tensors
  // only need to be initiated once.
  if (data->fft_integer_working_area_id != kTensorNotAllocated &&
      data->fft_double_working_area_id != kTensorNotAllocated) {
    return kTfLiteOk;
  }

  TfLiteIntArrayFree(node->temporaries);
  // Create two temporary tensors.
  node->temporaries = TfLiteIntArrayCreate(2);
  int first_new_index;
  TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index));
  node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index;
  data->fft_integer_working_area_id = first_new_index;
  node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1;
  data->fft_double_working_area_id = first_new_index + 1;

  // Set up FFT integer working area buffer.
  TfLiteTensor* fft_integer_working_area;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
                                &fft_integer_working_area));
  fft_integer_working_area->type = kTfLiteInt32;
  // If fft_length is not a constant tensor, fft_integer_working_area will be
  // set to dynamic later in Prepare.
  fft_integer_working_area->allocation_type = kTfLiteArenaRw;

  // Set up FFT double working area buffer.
  TfLiteTensor* fft_double_working_area;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
                                     &fft_double_working_area));
  // fft_double_working_area is a double tensor. Ideally, double should be
  // added into tflite data types. However, since fft_double_working_area is a
  // temporary tensor, and there are no ops having double input/output tensors
  // in tflite at this point, adding double as a tflite data type may confuse
  // users that double is supported. As a result, kTfLiteInt64 is used here
  // for memory allocation (same 8-byte element size). And it will be cast
  // into double in Eval when being used.
  fft_double_working_area->type = kTfLiteInt64;
  // If fft_length is not a constant tensor, fft_double_working_area will be
  // set to dynamic later in Prepare.
  fft_double_working_area->allocation_type = kTfLiteArenaRw;

  return kTfLiteOk;
}
104 
// Resizes the output tensor to [..., fft_height, fft_width / 2 + 1] and the
// two fft2d working-area temporaries according to the requested fft_length.
TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
                                             TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const int num_dims = NumDimensions(input);
  TF_LITE_ENSURE(context, num_dims >= 2);
  const TfLiteTensor* fft_length;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
  const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
  // The lib, fft2d, can only handle fft_lengths of power of 2.
  // NOTE(review): the int32 lengths are implicitly converted to uint32 here;
  // every negative value fails the check except INT32_MIN, which would pass
  // as 0x80000000 — confirm callers cannot supply negative fft_lengths.
  TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
  TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1]));

  int fft_height, fft_width;
  fft_height = fft_length_data[0];
  fft_width = fft_length_data[1];
  // Working-area sizes below presumably follow fft2d's rdft2d requirements
  // (ip: >= 2 + sqrt(max(n1, n2/2)) ints; w: >= max(n1/2, n2/4) + n2/4
  // doubles) — verify against third_party/fft2d's documentation.
  int fft_working_length = std::max(fft_height, fft_width / 2);
  int half_fft_working_length = fft_working_length / 2;

  // Resize output tensor: input shape with the innermost two dims replaced
  // by [fft_height, fft_width / 2 + 1].
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
  output_shape->data[num_dims - 2] = fft_length_data[0];
  output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));

  // Resize temporary tensors, fft_integer_working_area.
  TfLiteTensor* fft_integer_working_area;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
                                &fft_integer_working_area));
  TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
  fft_integer_working_area_shape->data[0] =
      2 + static_cast<int>(sqrt(fft_working_length));
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area,
                                              fft_integer_working_area_shape));

  // Resize temporary tensors, fft_double_working_area.
  TfLiteTensor* fft_double_working_area;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
                                     &fft_double_working_area));
  TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
  fft_double_working_area_shape->data[0] =
      half_fft_working_length + fft_width / 4;
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area,
                                              fft_double_working_area_shape));

  return kTfLiteOk;
}
158 
Init(TfLiteContext * context,const char * buffer,size_t length)159 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
160   auto* data = new OpData;
161   return data;
162 }
163 
Free(TfLiteContext * context,void * buffer)164 void Free(TfLiteContext* context, void* buffer) {
165   delete reinterpret_cast<OpData*>(buffer);
166 }
167 
// Validates input/fft_length tensors, creates the temporaries, sets the
// output type to complex64, and resizes output/temporaries when fft_length
// is constant (otherwise defers sizing to Eval by marking them dynamic).
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Check type and shape of the input tensor: float32 with rank >= 2.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
  if (input->type != kTfLiteFloat32) {
    context->ReportError(context,
                         "Type '%s' for input is not supported by rfft2d.",
                         TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  // Check type and shape of the fft_length tensor: an int32 vector with
  // exactly two elements, [fft_height, fft_width].
  const TfLiteTensor* fft_length;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
  const RuntimeShape fft_length_shape = GetTensorShape(fft_length);

  TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
  TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2);
  if (fft_length->type != kTfLiteInt32) {
    context->ReportError(context,
                         "Type '%s' for fft_length is not supported by rfft2d.",
                         TfLiteTypeGetName(fft_length->type));
    return kTfLiteError;
  }

  // Setup temporary tensors for fft computation.
  TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));

  // Set output type
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = kTfLiteComplex64;

  // Exit early if fft_length is a non-const tensor. Set output tensor and
  // temporary tensors to dynamic, so that their tensor sizes can be determined
  // in Eval.
  if (!IsConstantTensor(fft_length)) {
    TfLiteTensor* fft_integer_working_area;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
                                  &fft_integer_working_area));
    TfLiteTensor* fft_double_working_area;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
                                  &fft_double_working_area));
    SetTensorToDynamic(fft_integer_working_area);
    SetTensorToDynamic(fft_double_working_area);
    SetTensorToDynamic(output);
    return kTfLiteOk;
  }

  TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
  return kTfLiteOk;
}
228 
229 // Reorder the result so that it matches the pattern of tf.signal.rfft2d.
230 // In tf.signal.fft2d the frequency matrix of a 4x4 input is
231 //    [[F(0, 0),  F(0, 1/4),   F(0, 2/4)],
232 //    [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
233 //    [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
234 //    [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
235 // While in rdft2d, the frequency matrix of a 4x4 input is
236 //    [[(F(0, 0), F(0, -2/4))       F(0, -1/4),   0],
237 //     [ F(-1/4, 0),                F(-1/4, -1/4), 0],
238 //     [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
239 //     [ j*F(-3/4, -2/4),           F(-3/4, -1/4), 0]]
240 // Since real fft has the property that
241 //   Real(u,v) = Real(-u, -v)
242 //   Img(u,v) = - Img(-u, -v)
243 // Result of rdft2d can be reordered and match the pattern of tf.signal.rfft2d.
244 // For example,
245 //   Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0)
246 //   Img(-3/4, 0) = Img(1/4, 0) = -Img(-1/4, 0)
// In-place reorders rdft2d's packed output (see the comment block above) into
// tf.signal.rfft2d's layout, using the conjugate symmetry of a real FFT.
// Each row of fft_input_output has fft_width + 2 doubles (interleaved re/im).
void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
  int fft_height_half;
  ruy::profiler::ScopeLabel label("Rfft2dReorder");
  double real, img;

  fft_height_half = fft_height >> 1;
  // Use 4x4 input as an example, reorder the frequency matrix from
  //    [[(F(0, 0), F(0, -2/4))       F(0, -1/4),   0],
  //     [ F(-1/4, 0),                F(-1/4, -1/4), 0],
  //     [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
  //     [ j*F(-3/4, -2/4),           F(-3/4, -1/4), 0]]
  // to
  //    [[F(0, 0),  F(0, -1/4),   F(0, -2/4)],
  //    [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
  //    [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
  //    [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
  //
  // Rows past the Nyquist row carry packed terms in their first two slots;
  // unpack them into the last column of both the row and its mirror row
  // (fft_height - i), then restore the first column via symmetry.
  for (int i = fft_height_half + 1; i < fft_height; ++i) {
    real = fft_input_output[i][0];
    img = fft_input_output[i][1];
    fft_input_output[i][fft_width] = img;
    fft_input_output[i][fft_width + 1] = real;
    fft_input_output[fft_height - i][fft_width] = img;
    fft_input_output[fft_height - i][fft_width + 1] = -real;
    fft_input_output[i][0] = fft_input_output[fft_height - i][0];
    fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
  }

  // Fix up row 0 and the Nyquist row: their DC and Nyquist bins are purely
  // real, so the packed imaginary slots move to the last column and the
  // vacated slots are zeroed.
  double temp = fft_input_output[0][1];
  fft_input_output[0][fft_width + 1] = 0;
  fft_input_output[0][1] = 0;
  fft_input_output[fft_height_half][fft_width] =
      fft_input_output[fft_height_half][1];
  fft_input_output[fft_height_half][fft_width + 1] = 0;
  fft_input_output[fft_height_half][1] = 0;
  fft_input_output[0][fft_width] = temp;

  // Reorder the frequency matrix from
  //    [[F(0, 0),  F(0, -1/4),   F(0, -2/4)],
  //    [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
  //    [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
  //    [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
  // to
  //    [[F(0, 0),  F(0, 1/4),   F(0, 2/4)],
  //    [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
  //    [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
  //    [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
  // i.e. conjugate every entry by negating the imaginary (odd-index) slots.
  for (int i = 0; i < fft_height; ++i) {
    for (int j = 1; j < fft_width + 2; j += 2) {
      fft_input_output[i][j] = -fft_input_output[i][j];
    }
  }
}
299 
// Runs the forward 2D real FFT in place on fft_input_output via fft2d's
// rdft2d, then reorders the packed result into tf.signal.rfft2d layout.
void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output,
                int* fft_integer_working_area_data,
                double* fft_double_working_area_data) {
  ruy::profiler::ScopeLabel label("Rfft2dImpl");

  // Working data areas for the FFT routines.
  // NOTE(review): the dynamic working area is passed as nullptr — presumably
  // this build of fft2d allocates it internally when null; verify against
  // third_party/fft2d/fft2d.h.
  double* fft_dynamic_working_area = nullptr;
  const int kForwardFft = 1;  // isgn > 0 selects the forward transform.
  rdft2d(fft_height, fft_width, kForwardFft, fft_input_output,
         fft_dynamic_working_area, fft_integer_working_area_data,
         fft_double_working_area_data);
  Rfft2dReorder(fft_height, fft_width, fft_input_output);
}
313 
// Copies the real-valued input slice into the FFT working buffer, cropping
// and/or zero-padding it to an fft_height x (fft_width + 2) double matrix
// (rdft2d needs two extra columns per row for its packed in-place format).
void PrepareInputBuffer(const float* input_data, int input_height,
                        int input_width, int fft_height, int fft_width,
                        double** fft_input_output) {
  const int copy_rows = std::min(input_height, fft_height);
  const int copy_cols = std::min(input_width, fft_width);
  const int row_length = fft_width + 2;
  for (int row = 0; row < fft_height; ++row) {
    double* dst = fft_input_output[row];
    int col = 0;
    if (row < copy_rows) {
      const float* src = input_data + row * input_width;
      for (; col < copy_cols; ++col) {
        dst[col] = src[col];
      }
    }
    // Zero-fill the remainder of the row (and whole rows past the input).
    for (; col < row_length; ++col) {
      dst[col] = 0;
    }
  }
}
337 
// Packs the reordered FFT result (interleaved re/im doubles per row) into the
// complex64 output, keeping fft_width / 2 + 1 frequency bins per row.
void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
                         int fft_width, double** fft_input_output) {
  const int bins_per_row = fft_width / 2 + 1;
  complex<float>* out = output_data;
  for (int row = 0; row < fft_height; ++row) {
    const double* src = fft_input_output[row];
    for (int bin = 0; bin < bins_per_row; ++bin) {
      *out++ = complex<float>(src[2 * bin], src[2 * bin + 1]);
    }
  }
}
348 
Rfft2dHelper(TfLiteContext * context,TfLiteNode * node)349 TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
350   const TfLiteTensor* input;
351   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
352   const float* input_data = GetTensorData<float>(input);
353   const TfLiteTensor* fft_length;
354   TF_LITE_ENSURE_OK(context,
355                     GetInputSafe(context, node, kFftLengthTensor, &fft_length));
356   const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
357   TfLiteTensor* output;
358   TF_LITE_ENSURE_OK(context,
359                     GetOutputSafe(context, node, kOutputTensor, &output));
360   complex<float>* output_data = GetTensorData<complex<float>>(output);
361 
362   int fft_height, fft_width;
363   fft_height = fft_length_data[0];
364   fft_width = fft_length_data[1];
365 
366   // FFT is processed for every slice on the inner most 2 dimensions.
367   // Count the number of slices in the input tensor.
368   const RuntimeShape input_shape = GetTensorShape(input);
369   const int input_dims_count = input_shape.DimensionsCount();
370   const auto* input_dims_data = input_shape.DimsData();
371   int num_slices = 1;
372   for (int i = 0; i < input_dims_count - 2; ++i) {
373     num_slices *= input_dims_data[i];
374   }
375 
376   int input_height = input_dims_data[input_dims_count - 2];
377   int input_width = input_dims_data[input_dims_count - 1];
378   int input_slice_size = input_height * input_width;
379   int output_slice_size = fft_height * (fft_width / 2 + 1);
380 
381   // Create input/output buffer for FFT
382   double** fft_input_output = new double*[fft_height];
383   for (int i = 0; i < fft_height; ++i) {
384     fft_input_output[i] = new double[fft_width + 2];
385   }
386 
387   // Get buffer for integer working area.
388   TfLiteTensor* fft_integer_working_area;
389   TF_LITE_ENSURE_OK(
390       context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
391                                 &fft_integer_working_area));
392   int* fft_integer_working_area_data =
393       GetTensorData<int>(fft_integer_working_area);
394 
395   // Get buffer for double working area.
396   TfLiteTensor* fft_double_working_area;
397   TF_LITE_ENSURE_OK(context,
398                     GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
399                                      &fft_double_working_area));
400   // Get double value out of the memory of fft_double_working_area_data.
401   double* fft_double_working_area_data = reinterpret_cast<double*>(
402       GetTensorData<int64_t>(fft_double_working_area));
403 
404   // Process every slice in the input buffer
405   for (int i = 0; i < num_slices; ++i) {
406     PrepareInputBuffer(input_data, input_height, input_width, fft_height,
407                        fft_width, fft_input_output);
408     memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes);
409     memset(fft_double_working_area_data, 0, fft_double_working_area->bytes);
410     Rfft2dImpl(fft_height, fft_width, fft_input_output,
411                fft_integer_working_area_data, fft_double_working_area_data);
412     PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output);
413     input_data += input_slice_size;
414     output_data += output_slice_size;
415   }
416 
417   // Delete the input buffer
418   for (int i = 0; i < fft_height; ++i) {
419     delete[] fft_input_output[i];
420   }
421   delete[] fft_input_output;
422 
423   return kTfLiteOk;
424 }
425 
// Verifies (or, for non-const fft_length, computes) the output shape, then
// delegates the per-slice FFT work to Rfft2dHelper.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fft_length;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
  const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (output->type != kTfLiteComplex64) {
    context->ReportError(context,
                         "Type '%s' for output is not supported by rfft2d.",
                         TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }

  // Resize the output tensor if the fft_length tensor is not constant.
  // Otherwise, check if the output shape is correct: the input shape with
  // the innermost two dims replaced by [fft_height, fft_width / 2 + 1].
  if (!IsConstantTensor(fft_length)) {
    TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
  } else {
    int num_dims_output = NumDimensions(output);
    const RuntimeShape output_shape = GetTensorShape(output);
    TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input));
    TF_LITE_ENSURE(context, num_dims_output >= 2);
    TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2),
                      fft_length_data[0]);
    TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1),
                      fft_length_data[1] / 2 + 1);
  }

  return Rfft2dHelper(context, node);
}
461 
462 }  // namespace rfft2d
463 
Register_RFFT2D()464 TfLiteRegistration* Register_RFFT2D() {
465   static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare,
466                                  rfft2d::Eval};
467   return &r;
468 }
469 
470 }  // namespace builtin
471 }  // namespace ops
472 }  // namespace tflite
473