• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/common.h"
16 #include "tensorflow/lite/kernels/internal/common.h"
17 #include "tensorflow/lite/kernels/internal/tensor.h"
18 #include "tensorflow/lite/kernels/kernel_util.h"
19 
20 namespace tflite {
21 namespace ops {
22 namespace custom {
23 namespace table {
24 
// Positions of this op's tensors within the node's input / output index lists.
constexpr int kInputTensor{0};   // values used as lookup keys
constexpr int kTable{1};         // lookup table, typed like the output
constexpr int kOutputTensor{0};  // looked-up results
28 
Prepare(TfLiteContext * context,TfLiteNode * node)29 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
30   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
31   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
32 
33   const TfLiteTensor* input;
34   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
35   const TfLiteTensor* table;
36   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kTable, &table));
37   TfLiteTensor* output;
38   TF_LITE_ENSURE_OK(context,
39                     GetOutputSafe(context, node, kOutputTensor, &output));
40 
41   TF_LITE_ENSURE(context,
42                  input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
43   TF_LITE_ENSURE(context,
44                  output->type == kTfLiteInt8 || output->type == kTfLiteInt16);
45   TF_LITE_ENSURE_TYPES_EQ(context, output->type, table->type);
46 
47   if (input->type == kTfLiteInt16) {
48     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
49   }
50   if (output->type == kTfLiteInt16) {
51     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
52   }
53 
54   TF_LITE_ENSURE_EQ(context, NumDimensions(table), 1);
55   if (input->type == kTfLiteInt8) {
56     TF_LITE_ENSURE_EQ(context, NumElements(table), lut_size<int8_t>());
57   } else {
58     TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt16);
59     TF_LITE_ENSURE_EQ(context, NumElements(table), lut_size<int16_t>());
60   }
61 
62   return context->ResizeTensor(context, output,
63                                TfLiteIntArrayCopy(input->dims));
64 }
65 
66 template <typename InputT, typename OutputT>
Table(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * table,TfLiteTensor * output)67 void Table(TfLiteContext* context, const TfLiteTensor* input,
68            const TfLiteTensor* table, TfLiteTensor* output) {
69   const InputT* input_data = GetTensorData<InputT>(input);
70   const OutputT* table_data = GetTensorData<OutputT>(table);
71   OutputT* output_data = GetTensorData<OutputT>(output);
72 
73   const int num_elements = NumElements(input);
74   for (int i = 0; i < num_elements; i++) {
75     // No need to rescale the input and output, the rescaling and its zero-point
76     // are implicitly included into the table data during its generation.
77     output_data[i] = lut_lookup(input_data[i], table_data);
78   }
79 }
80 
81 template <typename InputT>
EvalTable(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * table,TfLiteTensor * output)82 TfLiteStatus EvalTable(TfLiteContext* context, const TfLiteTensor* input,
83                        const TfLiteTensor* table, TfLiteTensor* output) {
84   switch (output->type) {
85     case kTfLiteInt8:
86       Table<InputT, int8_t>(context, input, table, output);
87       break;
88     case kTfLiteInt16:
89       Table<InputT, int16_t>(context, input, table, output);
90       break;
91     default:
92       TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Table");
93   }
94 
95   return kTfLiteOk;
96 }
97 
Eval(TfLiteContext * context,TfLiteNode * node)98 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
99   const TfLiteTensor* input;
100   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
101   const TfLiteTensor* table;
102   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kTable, &table));
103   TfLiteTensor* output;
104   TF_LITE_ENSURE_OK(context,
105                     GetOutputSafe(context, node, kOutputTensor, &output));
106 
107   switch (input->type) {
108     case kTfLiteInt8:
109       return EvalTable<int8_t>(context, input, table, output);
110     case kTfLiteInt16:
111       return EvalTable<int16_t>(context, input, table, output);
112     default:
113       TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Table");
114   }
115 
116   return kTfLiteOk;
117 }
118 
119 }  // namespace table
120 
Register_TABLE()121 TfLiteRegistration* Register_TABLE() {
122   static TfLiteRegistration r = {nullptr, nullptr, table::Prepare, table::Eval};
123   return &r;
124 }
125 
126 }  // namespace custom
127 }  // namespace ops
128 }  // namespace tflite
129