/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace table {

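// TABLE is a custom op that maps each element of a quantized input tensor
// through a precomputed one-dimensional lookup table. The table encodes an
// arbitrary element-wise function, including any rescaling, so evaluation
// reduces to a per-element table lookup.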
constexpr int kInputTensor = 0;
constexpr int kTable = 1;
constexpr int kOutputTensor = 0;

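// Validates the input, table, and output tensors and resizes the output to
// match the input shape.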
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* table;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kTable, &table));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

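  // Only int8 and int16 quantized tensors are supported, and the output type
  // must match the table type.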
  TF_LITE_ENSURE(context,
                 input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
  TF_LITE_ENSURE(context,
                 output->type == kTfLiteInt8 || output->type == kTfLiteInt16);
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, table->type);

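  // Int16 quantization in TFLite is symmetric, so int16 tensors must have a
  // zero point of 0.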
  if (input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
  }
  if (output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

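  // The table must be a 1-D tensor whose length matches the lookup table size
  // expected for the input type, as returned by lut_size<T>().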
  TF_LITE_ENSURE_EQ(context, NumDimensions(table), 1);
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, NumElements(table), lut_size<int8_t>());
  } else {
    TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt16);
    TF_LITE_ENSURE_EQ(context, NumElements(table), lut_size<int16_t>());
  }

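  // The output has the same shape as the input.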
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

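// Performs the element-wise lookup: each input value is mapped through the
// table to produce the corresponding output value.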
template <typename InputT, typename OutputT>
void Table(TfLiteContext* context, const TfLiteTensor* input,
           const TfLiteTensor* table, TfLiteTensor* output) {
  const InputT* input_data = GetTensorData<InputT>(input);
  const OutputT* table_data = GetTensorData<OutputT>(table);
  OutputT* output_data = GetTensorData<OutputT>(output);

  const int num_elements = NumElements(input);
  for (int i = 0; i < num_elements; i++) {
    // There is no need to rescale the input or output: the rescaling and the
    // zero points are folded into the table data when the table is generated.
    output_data[i] = lut_lookup(input_data[i], table_data);
  }
}

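// Dispatches the table lookup on the output type once the input type has been
// resolved by Eval.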
template <typename InputT>
TfLiteStatus EvalTable(TfLiteContext* context, const TfLiteTensor* input,
                       const TfLiteTensor* table, TfLiteTensor* output) {
  switch (output->type) {
    case kTfLiteInt8:
      Table<InputT, int8_t>(context, input, table, output);
      break;
    case kTfLiteInt16:
      Table<InputT, int16_t>(context, input, table, output);
      break;
    default:
      TF_LITE_UNSUPPORTED_TYPE(context, output->type, "Table");
  }

  return kTfLiteOk;
}

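// Dispatches on the input type; the output type is handled in EvalTable.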
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* table;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kTable, &table));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (input->type) {
    case kTfLiteInt8:
      return EvalTable<int8_t>(context, input, table, output);
    case kTfLiteInt16:
      return EvalTable<int16_t>(context, input, table, output);
    default:
      TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Table");
  }

  return kTfLiteOk;
}

}  // namespace table

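// Registers the TABLE custom op. Only Prepare and Eval are provided; the
// init and free hooks are null because the op keeps no persistent state.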
TfLiteRegistration* Register_TABLE() {
  static TfLiteRegistration r = {nullptr, nullptr, table::Prepare, table::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite