/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up items from a matrix.
//
// Input:
//     Tensor[0]: Row numbers to look up, dim.size == 1, int32.
//     Tensor[1]: 2-dimensional matrix of multi-dimensional items,
//                dim.size >= 2, any data type.
//                The first dimension is the row, the second is the column.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[1] == Tensor[1].dim[1], the number of items per row.
//   Each item in the output is a raw bytes copy of the corresponding item in
//   the input, or a dequantized value in the case of a uint8/int8 input.
//   When an index is out of bounds, the op returns an error.
//
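// For example (illustrative shapes only): with lookup = [2, 0] and a float32
// value tensor of shape [3, 4], the output has shape [2, 4]; output row 0 is
// a copy of row 2 of the value tensor and output row 1 is a copy of row 0.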
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace embedding_lookup {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* lookup = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);

  const TfLiteTensor* value = GetInput(context, node, 1);
  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);

  TfLiteTensor* output = GetOutput(context, node, 0);
  TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));

  outputSize->data[0] = SizeOfDimension(lookup, 0);
  outputSize->data[1] = SizeOfDimension(value, 1);
  for (int i = 2; i < NumDimensions(value); i++) {
    outputSize->data[i] = SizeOfDimension(value, i);
  }
  return context->ResizeTensor(context, output, outputSize);
}

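// Copies each looked-up row from 'value' to 'output' as raw bytes. The row
// stride is computed as value->bytes / row_size, so this path works for any
// fixed-size element type without interpreting individual elements.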
TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
                        const TfLiteTensor* lookup, const TfLiteTensor* value,
                        TfLiteTensor* output) {
  const int row_size = SizeOfDimension(value, 0);
  const int row_bytes = value->bytes / row_size;

  char* output_raw = GetTensorData<char>(output);
  const char* value_raw = GetTensorData<char>(value);
  const int32_t* lookup_data = GetTensorData<int32_t>(lookup);
  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
    int idx = lookup_data[i];
    if (idx >= row_size || idx < 0) {
      context->ReportError(context,
                           "Embedding Lookup: index out of bounds. "
                           "Got %d, and bounds are [0, %d]",
                           idx, row_size - 1);
      return kTfLiteError;
    } else {
      std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes,
                  row_bytes);
    }
  }

  return kTfLiteOk;
}

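// Dequantizes looked-up rows into a float32 output. The stored values are
// read as int8 and multiplied by the tensor's quantization scale
// (value->params.scale); no zero point is applied, i.e. symmetric
// quantization is assumed here.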
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        const TfLiteTensor* lookup, const TfLiteTensor* value,
                        TfLiteTensor* output) {
  const int row_size = SizeOfDimension(value, 0);
  const double scaling_factor = value->params.scale;

  // col_size after we flatten tensor into 2D.
  int col_size = 1;
  for (int i = 1; i < NumDimensions(value); i++) {
    col_size *= SizeOfDimension(value, i);
  }

  float* output_ptr = GetTensorData<float>(output);
  const int8_t* value_ptr = GetTensorData<int8_t>(value);
  const int32_t* lookup_data = GetTensorData<int32_t>(lookup);

  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
    int idx = lookup_data[i];
    if (idx >= row_size || idx < 0) {
      context->ReportError(context,
                           "Embedding Lookup: index out of bounds. "
                           "Got %d, and bounds are [0, %d]",
                           idx, row_size - 1);
      return kTfLiteError;
    } else {
      // Dequantize embedding values.
      // TODO(alanchiao): refactor scalar multiply into separate function
      // for ease of adding a neon equivalent if ever necessary.
      for (int j = 0; j < col_size; j++) {
        output_ptr[j + i * col_size] =
            value_ptr[j + idx * col_size] * scaling_factor;
      }
    }
  }

  return kTfLiteOk;
}

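// Dispatches on the type of the value tensor: float32 rows are copied
// verbatim; quantized (uint8/int8) rows are dequantized to float32 when the
// output tensor is float32, and copied verbatim otherwise.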
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* lookup = GetInput(context, node, 0);
  const TfLiteTensor* value = GetInput(context, node, 1);
  TfLiteTensor* output = GetOutput(context, node, 0);
  switch (value->type) {
    case kTfLiteFloat32:
      return EvalSimple(context, node, lookup, value, output);
    case kTfLiteUInt8:
    case kTfLiteInt8:
      if (output->type == kTfLiteFloat32) {
        return EvalHybrid(context, node, lookup, value, output);
      } else {
        return EvalSimple(context, node, lookup, value, output);
      }
    default:
      context->ReportError(context, "Type not currently supported.");
      return kTfLiteError;
  }
}

}  // namespace embedding_lookup

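// The first two entries of TfLiteRegistration (init, free) are null: this
// kernel allocates no per-node state, so no init/free hooks are needed.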
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
  static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
                                 embedding_lookup::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite