• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/tflite/encoder_common.h"

#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <vector>

#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

CreateIntArray(const std::initializer_list<int> & values)24 TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
25   TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
26   int index = 0;
27   for (const int size : values) {
28     array_size->data[index++] = size;
29   }
30   return array_size;
31 }
32 
CopyValuesToTensorAndPadOrTruncate(const TfLiteTensor & in,const std::vector<int> & encoding_end_offsets,int start_offset,TfLiteContext * context,TfLiteTensor * out)33 TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
34     const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
35     int start_offset, TfLiteContext* context, TfLiteTensor* out) {
36   TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
37   TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
38   const int output_size = out->dims->data[1];
39   int output_offset = 0;
40   for (int value_index = 0;
41        value_index < encoding_end_offsets.size() && output_offset < output_size;
42        ++value_index) {
43     // Calculate how many elements need to be set with this value.
44     // The low bound depends on the offset from the beginning. If this is 0, it
45     // means that this value it truncated.
46     // The upper bound depends on how many elements are in the output tensor.
47     const int from_this_element =
48         std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
49                                  output_offset),
50                  output_size - output_offset);
51     if (from_this_element == 0) {
52       continue;
53     }
54 
55     switch (in.type) {
56       case kTfLiteInt32: {
57         std::fill(out->data.i32 + output_offset,
58                   out->data.i32 + output_offset + from_this_element,
59                   in.data.i32[value_index]);
60       } break;
61       case kTfLiteFloat32: {
62         std::fill(out->data.f + output_offset,
63                   out->data.f + output_offset + from_this_element,
64                   in.data.f[value_index]);
65       } break;
66       default:
67         context->ReportError(
68             (context), __FILE__ " Not supported attribute type %d", in.type);
69         return kTfLiteError;
70     }
71     output_offset += from_this_element;
72   }
73   // Do final padding.
74   switch (in.type) {
75     case kTfLiteInt32: {
76       const int32_t value =
77           (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
78       std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
79                 value);
80     } break;
81     case kTfLiteFloat32: {
82       const float value =
83           (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
84       std::fill(out->data.f + output_offset, out->data.f + output_size, value);
85     } break;
86     default:
87       break;
88   }
89   return kTfLiteOk;
90 }
91 
ResizeOutputTensor(const int max_output_length,TfLiteTensor * tensor,TfLiteContext * context)92 TfLiteStatus ResizeOutputTensor(const int max_output_length,
93                                 TfLiteTensor* tensor, TfLiteContext* context) {
94   TF_LITE_ENSURE_OK(
95       context, context->ResizeTensor(
96                    context, tensor,
97                    CreateIntArray({kEncoderBatchSize, max_output_length})));
98   return kTfLiteOk;
99 }
100 
CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,const std::vector<int32_t> & data,const int32_t padding_value,TfLiteTensor * output_tensor)101 int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
102                                      const std::vector<int32_t>& data,
103                                      const int32_t padding_value,
104                                      TfLiteTensor* output_tensor) {
105   const int num_skip =
106       std::max(0, static_cast<int>(data.size()) - max_output_length);
107   int output_offset = 0;
108   int32_t* output_buffer = output_tensor->data.i32;
109   for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
110     output_buffer[output_offset] = data[i];
111   }
112 
113   // Do padding.
114   for (; output_offset < max_output_length; ++output_offset) {
115     output_buffer[output_offset] = padding_value;
116   }
117 
118   // Return number of skipped entries from the beginning.
119   return num_skip;
120 }
121 
}  // namespace libtextclassifier3