1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Op that looks up rows of a matrix.
//
// Input:
//   Tensor[0]: Row indices to look up; dim.size == 1, int32.
//   Tensor[1]: Matrix of (possibly multi-dimensional) items; dim.size >= 2,
//              float32, uint8 or int8. The first dimension is the row, the
//              second dimension is the column.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[i] == Tensor[1].dim[i] for every i >= 1 (items per row, ...).
//   Each item in the output is a raw byte copy of the corresponding item in
//   the input, or a dequantized float value when a uint8/int8 input is paired
//   with a float32 output.
//   If any index is out of bounds, the op fails with an error.
//
31
32 #include <cassert>
33 #include <cmath>
34 #include <cstdio>
35 #include <cstdlib>
36 #include <cstring>
37 #include <iostream>
38 #include <limits>
39
40 #include "tensorflow/lite/c/builtin_op_data.h"
41 #include "tensorflow/lite/c/common.h"
42 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
43 #include "tensorflow/lite/kernels/kernel_util.h"
44 #include "tensorflow/lite/kernels/op_macros.h"
45
46 namespace tflite {
47 namespace ops {
48 namespace builtin {
49 namespace embedding_lookup {
50
Prepare(TfLiteContext * context,TfLiteNode * node)51 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
52 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
53 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
54
55 const TfLiteTensor* lookup = GetInput(context, node, 0);
56 TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
57 TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
58
59 const TfLiteTensor* value = GetInput(context, node, 1);
60 TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
61
62 TfLiteTensor* output = GetOutput(context, node, 0);
63 TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
64
65 outputSize->data[0] = SizeOfDimension(lookup, 0);
66 outputSize->data[1] = SizeOfDimension(value, 1);
67 for (int i = 2; i < NumDimensions(value); i++) {
68 outputSize->data[i] = SizeOfDimension(value, i);
69 }
70 return context->ResizeTensor(context, output, outputSize);
71 }
72
EvalSimple(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * lookup,const TfLiteTensor * value,TfLiteTensor * output)73 TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
74 const TfLiteTensor* lookup, const TfLiteTensor* value,
75 TfLiteTensor* output) {
76 const int row_size = SizeOfDimension(value, 0);
77 const int row_bytes = value->bytes / row_size;
78
79 char* output_raw = GetTensorData<char>(output);
80 const char* value_raw = GetTensorData<char>(value);
81 const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
82 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
83 int idx = lookup_data[i];
84 if (idx >= row_size || idx < 0) {
85 context->ReportError(context,
86 "Embedding Lookup: index out of bounds. "
87 "Got %d, and bounds are [0, %d]",
88 idx, row_size - 1);
89 return kTfLiteError;
90 } else {
91 std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes,
92 row_bytes);
93 }
94 }
95
96 return kTfLiteOk;
97 }
98
EvalHybrid(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * lookup,const TfLiteTensor * value,TfLiteTensor * output)99 TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
100 const TfLiteTensor* lookup, const TfLiteTensor* value,
101 TfLiteTensor* output) {
102 const int row_size = SizeOfDimension(value, 0);
103 const double scaling_factor = value->params.scale;
104
105 // col_size after we flatten tensor into 2D.
106 int col_size = 1;
107 for (int i = 1; i < NumDimensions(value); i++) {
108 col_size *= SizeOfDimension(value, i);
109 }
110
111 float* output_ptr = GetTensorData<float>(output);
112 const int8_t* value_ptr = GetTensorData<int8_t>(value);
113 const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
114
115 for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
116 int idx = lookup_data[i];
117 if (idx >= row_size || idx < 0) {
118 context->ReportError(context,
119 "Embedding Lookup: index out of bounds. "
120 "Got %d, and bounds are [0, %d]",
121 idx, row_size - 1);
122 return kTfLiteError;
123 } else {
124 // Dequantize embedding values.
125 // TODO(alanchiao): refactor scalar multiply into separate function
126 // for ease of adding a neon equivalent if ever necessary.
127 for (int j = 0; j < col_size; j++) {
128 output_ptr[j + i * col_size] =
129 value_ptr[j + idx * col_size] * scaling_factor;
130 }
131 }
132 }
133
134 return kTfLiteOk;
135 }
136
Eval(TfLiteContext * context,TfLiteNode * node)137 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
138 const TfLiteTensor* lookup = GetInput(context, node, 0);
139 const TfLiteTensor* value = GetInput(context, node, 1);
140 TfLiteTensor* output = GetOutput(context, node, 0);
141 switch (value->type) {
142 case kTfLiteFloat32:
143 return EvalSimple(context, node, lookup, value, output);
144 case kTfLiteUInt8:
145 case kTfLiteInt8:
146 if (output->type == kTfLiteFloat32) {
147 return EvalHybrid(context, node, lookup, value, output);
148 } else {
149 return EvalSimple(context, node, lookup, value, output);
150 }
151 default:
152 context->ReportError(context, "Type not currently supported.");
153 return kTfLiteError;
154 }
155 }
156
157 } // namespace embedding_lookup
158
Register_EMBEDDING_LOOKUP()159 TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
160 static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
161 embedding_lookup::Eval};
162 return &r;
163 }
164
165 } // namespace builtin
166 } // namespace ops
167 } // namespace tflite
168