/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "utils/tflite/text_encoder3s.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include "utils/base/logging.h"
23 #include "utils/strings/stringpiece.h"
24 #include "utils/tflite/encoder_common.h"
25 #include "utils/tflite/text_encoder_config_generated.h"
26 #include "utils/tokenfree/byte_encoder.h"
27 #include "flatbuffers/flatbuffers.h"
28 #include "flatbuffers/flexbuffers.h"
29 #include "tensorflow/lite/kernels/kernel_util.h"
30 #include "tensorflow/lite/model.h"
31 #include "tensorflow/lite/string_util.h"
32
33 namespace libtextclassifier3 {
34 namespace {
35
// Positions of the op's input tensors within node->inputs.
// Input 0: string tensor holding the texts to encode.
constexpr int kInputTextInd{0};
// Input 1: number of strings to encode (read as int32 in Eval).
constexpr int kTextLengthInd{1};
// Input 2: maximum output length (read as int64 in Prepare and Eval).
constexpr int kMaxLengthInd{2};
// Input 3 and onwards: attribute tensors, one per attribute output.
constexpr int kInputAttrInd{3};

// Positions of the op's output tensors within node->outputs.
// Output 0: encoded values, shape {kEncoderBatchSize, max length}.
constexpr int kOutputEncodedInd{0};
// Output 1: per-value positions, same shape as the encoded output.
constexpr int kOutputPositionInd{1};
// Output 2: number of encoded values written, shape {kEncoderBatchSize}.
constexpr int kOutputLengthsInd{2};
// Output 3 and onwards: attribute tensors, one per attribute input.
constexpr int kOutputAttrInd{3};
48
49 // Initializes text encoder object from serialized parameters.
Initialize(TfLiteContext * context,const char * buffer,size_t length)50 void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
51 std::unique_ptr<ByteEncoder> encoder(new ByteEncoder());
52 return encoder.release();
53 }
54
Free(TfLiteContext * context,void * buffer)55 void Free(TfLiteContext* context, void* buffer) {
56 delete reinterpret_cast<ByteEncoder*>(buffer);
57 }
58
59 namespace {
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)60 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
61 int max_output_length) {
62 TfLiteTensor& output_encoded =
63 context->tensors[node->outputs->data[kOutputEncodedInd]];
64
65 TF_LITE_ENSURE_OK(
66 context, context->ResizeTensor(
67 context, &output_encoded,
68 CreateIntArray({kEncoderBatchSize, max_output_length})));
69 TfLiteTensor& output_positions =
70 context->tensors[node->outputs->data[kOutputPositionInd]];
71
72 TF_LITE_ENSURE_OK(
73 context, context->ResizeTensor(
74 context, &output_positions,
75 CreateIntArray({kEncoderBatchSize, max_output_length})));
76
77 const int num_output_attrs = node->outputs->size - kOutputAttrInd;
78 for (int i = 0; i < num_output_attrs; ++i) {
79 TfLiteTensor& output =
80 context->tensors[node->outputs->data[kOutputAttrInd + i]];
81 TF_LITE_ENSURE_OK(
82 context, context->ResizeTensor(
83 context, &output,
84 CreateIntArray({kEncoderBatchSize, max_output_length})));
85 }
86 return kTfLiteOk;
87 }
88 } // namespace
89
Prepare(TfLiteContext * context,TfLiteNode * node)90 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91 // Check that the batch dimension is kEncoderBatchSize.
92 const TfLiteTensor& input_text =
93 context->tensors[node->inputs->data[kInputTextInd]];
94 TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
95 TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
96
97 TfLiteTensor& output_lengths =
98 context->tensors[node->outputs->data[kOutputLengthsInd]];
99
100 TfLiteTensor& output_encoded =
101 context->tensors[node->outputs->data[kOutputEncodedInd]];
102 TfLiteTensor& output_positions =
103 context->tensors[node->outputs->data[kOutputPositionInd]];
104 output_encoded.type = kTfLiteInt32;
105 output_positions.type = kTfLiteInt32;
106 output_lengths.type = kTfLiteInt32;
107
108 TF_LITE_ENSURE_OK(context,
109 context->ResizeTensor(context, &output_lengths,
110 CreateIntArray({kEncoderBatchSize})));
111
112 // Check that there are enough outputs for attributes.
113 const int num_output_attrs = node->outputs->size - kOutputAttrInd;
114 TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
115 num_output_attrs);
116
117 // Copy attribute types from input to output tensors.
118 for (int i = 0; i < num_output_attrs; ++i) {
119 TfLiteTensor& input =
120 context->tensors[node->inputs->data[kInputAttrInd + i]];
121 TfLiteTensor& output =
122 context->tensors[node->outputs->data[kOutputAttrInd + i]];
123 output.type = input.type;
124 }
125
126 const TfLiteTensor& output_length =
127 context->tensors[node->inputs->data[kMaxLengthInd]];
128
129 if (tflite::IsConstantTensor(&output_length)) {
130 return ResizeOutputTensors(context, node, output_length.data.i64[0]);
131 } else {
132 tflite::SetTensorToDynamic(&output_encoded);
133 tflite::SetTensorToDynamic(&output_positions);
134 for (int i = 0; i < num_output_attrs; ++i) {
135 TfLiteTensor& output_attr =
136 context->tensors[node->outputs->data[kOutputAttrInd + i]];
137 tflite::SetTensorToDynamic(&output_attr);
138 }
139 }
140
141 return kTfLiteOk;
142 }
143
Eval(TfLiteContext * context,TfLiteNode * node)144 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
145 if (node->user_data == nullptr) {
146 return kTfLiteError;
147 }
148 auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
149 const TfLiteTensor& input_text =
150 context->tensors[node->inputs->data[kInputTextInd]];
151 const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
152 const int num_strings =
153 context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
154
155 // Check that the number of strings is not bigger than the input tensor size.
156 TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
157
158 TfLiteTensor& output_encoded =
159 context->tensors[node->outputs->data[kOutputEncodedInd]];
160 if (tflite::IsDynamicTensor(&output_encoded)) {
161 const TfLiteTensor& output_length =
162 context->tensors[node->inputs->data[kMaxLengthInd]];
163 TF_LITE_ENSURE_OK(
164 context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
165 }
166 TfLiteTensor& output_positions =
167 context->tensors[node->outputs->data[kOutputPositionInd]];
168
169 std::vector<int> encoded_total;
170 std::vector<int> encoded_positions;
171 std::vector<int> encoded_offsets;
172 encoded_offsets.reserve(num_strings);
173 const int max_output_length = output_encoded.dims->data[1];
174 const int max_encoded_position = max_output_length;
175
176 for (int i = 0; i < num_strings; ++i) {
177 const auto& strref = tflite::GetString(&input_text, i);
178 std::vector<int64_t> encoded;
179 text_encoder->Encode(
180 libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
181 encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
182 encoded_offsets.push_back(encoded_total.size());
183 for (int i = 0; i < encoded.size(); ++i) {
184 encoded_positions.push_back(std::min(i, max_encoded_position - 1));
185 }
186 }
187
188 // Copy encoding to output tensor.
189 const int start_offset =
190 std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
191 int output_offset = 0;
192 int32_t* output_buffer = output_encoded.data.i32;
193 int32_t* output_positions_buffer = output_positions.data.i32;
194 for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
195 output_buffer[output_offset] = encoded_total[i];
196 output_positions_buffer[output_offset] = encoded_positions[i];
197 }
198
199 // Save output encoded length.
200 TfLiteTensor& output_lengths =
201 context->tensors[node->outputs->data[kOutputLengthsInd]];
202 output_lengths.data.i32[0] = output_offset;
203
204 // Do padding.
205 for (; output_offset < max_output_length; ++output_offset) {
206 output_buffer[output_offset] = 0;
207 output_positions_buffer[output_offset] = 0;
208 }
209
210 // Process attributes, all checks of sizes and types are done in Prepare.
211 const int num_output_attrs = node->outputs->size - kOutputAttrInd;
212 TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
213 num_output_attrs);
214 for (int i = 0; i < num_output_attrs; ++i) {
215 TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
216 context->tensors[node->inputs->data[kInputAttrInd + i]],
217 encoded_offsets, start_offset, context,
218 &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
219 if (attr_status != kTfLiteOk) {
220 return attr_status;
221 }
222 }
223
224 return kTfLiteOk;
225 }
226
227 } // namespace
228 } // namespace libtextclassifier3
229
230 namespace tflite {
231 namespace ops {
232 namespace custom {
233
Register_TEXT_ENCODER3S()234 TfLiteRegistration* Register_TEXT_ENCODER3S() {
235 static TfLiteRegistration registration = {
236 libtextclassifier3::Initialize, libtextclassifier3::Free,
237 libtextclassifier3::Prepare, libtextclassifier3::Eval};
238 return ®istration;
239 }
240
241 } // namespace custom
242 } // namespace ops
243 } // namespace tflite
244