• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite/text_encoder3s.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "utils/base/logging.h"
23 #include "utils/strings/stringpiece.h"
24 #include "utils/tflite/encoder_common.h"
25 #include "utils/tflite/text_encoder_config_generated.h"
26 #include "utils/tokenfree/byte_encoder.h"
27 #include "flatbuffers/flatbuffers.h"
28 #include "flatbuffers/flexbuffers.h"
29 #include "tensorflow/lite/kernels/kernel_util.h"
30 #include "tensorflow/lite/model.h"
31 #include "tensorflow/lite/string_util.h"
32 
33 namespace libtextclassifier3 {
34 namespace {
35 
// Input tensor indices for the op.
constexpr int kInputTextInd = 0;    // string tensor holding the input text
constexpr int kTextLengthInd = 1;   // int32 scalar: number of valid strings
constexpr int kMaxLengthInd = 2;    // int64 scalar: maximum output length
constexpr int kInputAttrInd = 3;    // first of the variadic attribute inputs

// Output tensor indices for the op.
constexpr int kOutputEncodedInd = 0;   // encoded ids, [batch, max_length]
constexpr int kOutputPositionInd = 1;  // per-id positions, [batch, max_length]
constexpr int kOutputLengthsInd = 2;   // int32: actual encoded length
constexpr int kOutputAttrInd = 3;      // first of the variadic attribute outputs
48 
49 // Initializes text encoder object from serialized parameters.
Initialize(TfLiteContext * context,const char * buffer,size_t length)50 void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
51   std::unique_ptr<ByteEncoder> encoder(new ByteEncoder());
52   return encoder.release();
53 }
54 
Free(TfLiteContext * context,void * buffer)55 void Free(TfLiteContext* context, void* buffer) {
56   delete reinterpret_cast<ByteEncoder*>(buffer);
57 }
58 
59 namespace {
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)60 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
61                                  int max_output_length) {
62   TfLiteTensor& output_encoded =
63       context->tensors[node->outputs->data[kOutputEncodedInd]];
64 
65   TF_LITE_ENSURE_OK(
66       context, context->ResizeTensor(
67                    context, &output_encoded,
68                    CreateIntArray({kEncoderBatchSize, max_output_length})));
69   TfLiteTensor& output_positions =
70       context->tensors[node->outputs->data[kOutputPositionInd]];
71 
72   TF_LITE_ENSURE_OK(
73       context, context->ResizeTensor(
74                    context, &output_positions,
75                    CreateIntArray({kEncoderBatchSize, max_output_length})));
76 
77   const int num_output_attrs = node->outputs->size - kOutputAttrInd;
78   for (int i = 0; i < num_output_attrs; ++i) {
79     TfLiteTensor& output =
80         context->tensors[node->outputs->data[kOutputAttrInd + i]];
81     TF_LITE_ENSURE_OK(
82         context, context->ResizeTensor(
83                      context, &output,
84                      CreateIntArray({kEncoderBatchSize, max_output_length})));
85   }
86   return kTfLiteOk;
87 }
88 }  // namespace
89 
Prepare(TfLiteContext * context,TfLiteNode * node)90 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91   // Check that the batch dimension is kEncoderBatchSize.
92   const TfLiteTensor& input_text =
93       context->tensors[node->inputs->data[kInputTextInd]];
94   TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
95   TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
96 
97   TfLiteTensor& output_lengths =
98       context->tensors[node->outputs->data[kOutputLengthsInd]];
99 
100   TfLiteTensor& output_encoded =
101       context->tensors[node->outputs->data[kOutputEncodedInd]];
102   TfLiteTensor& output_positions =
103       context->tensors[node->outputs->data[kOutputPositionInd]];
104   output_encoded.type = kTfLiteInt32;
105   output_positions.type = kTfLiteInt32;
106   output_lengths.type = kTfLiteInt32;
107 
108   TF_LITE_ENSURE_OK(context,
109                     context->ResizeTensor(context, &output_lengths,
110                                           CreateIntArray({kEncoderBatchSize})));
111 
112   // Check that there are enough outputs for attributes.
113   const int num_output_attrs = node->outputs->size - kOutputAttrInd;
114   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
115                     num_output_attrs);
116 
117   // Copy attribute types from input to output tensors.
118   for (int i = 0; i < num_output_attrs; ++i) {
119     TfLiteTensor& input =
120         context->tensors[node->inputs->data[kInputAttrInd + i]];
121     TfLiteTensor& output =
122         context->tensors[node->outputs->data[kOutputAttrInd + i]];
123     output.type = input.type;
124   }
125 
126   const TfLiteTensor& output_length =
127       context->tensors[node->inputs->data[kMaxLengthInd]];
128 
129   if (tflite::IsConstantTensor(&output_length)) {
130     return ResizeOutputTensors(context, node, output_length.data.i64[0]);
131   } else {
132     tflite::SetTensorToDynamic(&output_encoded);
133     tflite::SetTensorToDynamic(&output_positions);
134     for (int i = 0; i < num_output_attrs; ++i) {
135       TfLiteTensor& output_attr =
136           context->tensors[node->outputs->data[kOutputAttrInd + i]];
137       tflite::SetTensorToDynamic(&output_attr);
138     }
139   }
140 
141   return kTfLiteOk;
142 }
143 
Eval(TfLiteContext * context,TfLiteNode * node)144 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
145   if (node->user_data == nullptr) {
146     return kTfLiteError;
147   }
148   auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
149   const TfLiteTensor& input_text =
150       context->tensors[node->inputs->data[kInputTextInd]];
151   const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
152   const int num_strings =
153       context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
154 
155   // Check that the number of strings is not bigger than the input tensor size.
156   TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
157 
158   TfLiteTensor& output_encoded =
159       context->tensors[node->outputs->data[kOutputEncodedInd]];
160   if (tflite::IsDynamicTensor(&output_encoded)) {
161     const TfLiteTensor& output_length =
162         context->tensors[node->inputs->data[kMaxLengthInd]];
163     TF_LITE_ENSURE_OK(
164         context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
165   }
166   TfLiteTensor& output_positions =
167       context->tensors[node->outputs->data[kOutputPositionInd]];
168 
169   std::vector<int> encoded_total;
170   std::vector<int> encoded_positions;
171   std::vector<int> encoded_offsets;
172   encoded_offsets.reserve(num_strings);
173   const int max_output_length = output_encoded.dims->data[1];
174   const int max_encoded_position = max_output_length;
175 
176   for (int i = 0; i < num_strings; ++i) {
177     const auto& strref = tflite::GetString(&input_text, i);
178     std::vector<int64_t> encoded;
179     text_encoder->Encode(
180         libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
181     encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
182     encoded_offsets.push_back(encoded_total.size());
183     for (int i = 0; i < encoded.size(); ++i) {
184       encoded_positions.push_back(std::min(i, max_encoded_position - 1));
185     }
186   }
187 
188   // Copy encoding to output tensor.
189   const int start_offset =
190       std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
191   int output_offset = 0;
192   int32_t* output_buffer = output_encoded.data.i32;
193   int32_t* output_positions_buffer = output_positions.data.i32;
194   for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
195     output_buffer[output_offset] = encoded_total[i];
196     output_positions_buffer[output_offset] = encoded_positions[i];
197   }
198 
199   // Save output encoded length.
200   TfLiteTensor& output_lengths =
201       context->tensors[node->outputs->data[kOutputLengthsInd]];
202   output_lengths.data.i32[0] = output_offset;
203 
204   // Do padding.
205   for (; output_offset < max_output_length; ++output_offset) {
206     output_buffer[output_offset] = 0;
207     output_positions_buffer[output_offset] = 0;
208   }
209 
210   // Process attributes, all checks of sizes and types are done in Prepare.
211   const int num_output_attrs = node->outputs->size - kOutputAttrInd;
212   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
213                     num_output_attrs);
214   for (int i = 0; i < num_output_attrs; ++i) {
215     TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
216         context->tensors[node->inputs->data[kInputAttrInd + i]],
217         encoded_offsets, start_offset, context,
218         &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
219     if (attr_status != kTfLiteOk) {
220       return attr_status;
221     }
222   }
223 
224   return kTfLiteOk;
225 }
226 
227 }  // namespace
228 }  // namespace libtextclassifier3
229 
230 namespace tflite {
231 namespace ops {
232 namespace custom {
233 
// Returns the registration for the TEXT_ENCODER3S custom op. The aggregate
// initializer is positional — {init, free, prepare, invoke} must match the
// leading members of TfLiteRegistration, so keep the order fixed.
TfLiteRegistration* Register_TEXT_ENCODER3S() {
  static TfLiteRegistration registration = {
      /*init=*/libtextclassifier3::Initialize,
      /*free=*/libtextclassifier3::Free,
      /*prepare=*/libtextclassifier3::Prepare,
      /*invoke=*/libtextclassifier3::Eval};
  return &registration;
}
240 
241 }  // namespace custom
242 }  // namespace ops
243 }  // namespace tflite
244