/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Shared methods for the text and token encoders.

#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <vector>

#include "tensorflow/lite/model.h"

namespace libtextclassifier3 {

// The input rank for the encoder ops is 2: the first dimension is reserved
// for batching (and is always set to 1 during inference), and the second
// dimension indexes the input values (texts or token lengths).
constexpr int kEncoderInputRank = 2;
constexpr int kEncoderBatchSize = 1;

// Creates a TensorFlow Lite int array from an initializer list.
TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values);

// Copies values associated with the input to the output.
// Typically we have an attribute value associated with each item in the
// input, e.g. a user id per message in the conversation. This aligns and
// replicates the attribute values with the encoded input, e.g. replicates
// the same user id for each token or sentence piece of the corresponding
// item. As the input for the whole conversation is concatenated and
// (potentially) trimmed, `encoding_end_offsets` indicates where each item
// ends and `start_offset` indicates how many elements at the beginning were
// dropped.
TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
    const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
    int start_offset, TfLiteContext* context, TfLiteTensor* out);

// Resizes an output tensor to shape {kEncoderBatchSize, max_output_length}.
TfLiteStatus ResizeOutputTensor(const int max_output_length,
                                TfLiteTensor* tensor, TfLiteContext* context);

// Copies a slice of data to the output tensor.
// If the size of the data is smaller than `max_output_length`, the output is
// padded with `padding_value`. If the size of the data is larger than
// `max_output_length`, entries at the beginning are dropped to fit into the
// limit.
int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
                                     const std::vector<int32_t>& data,
                                     const int32_t padding_value,
                                     TfLiteTensor* output_tensor);
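// A minimal usage sketch (illustrative only; `context`, `user_id_tensor`,
// `output_tensor`, and the offsets below are hypothetical and not part of
// this header). Suppose three messages encode to 2, 3, and 2 sentence
// pieces, so the items end at offsets {2, 5, 7} in the concatenated
// encoding; with `start_offset` = 1, the first encoded element was trimmed,
// and replication of the per-message user ids begins there:
//
//   const std::vector<int> encoding_end_offsets = {2, 5, 7};
//   TF_LITE_ENSURE_OK(
//       context, CopyValuesToTensorAndPadOrTruncate(
//                    user_id_tensor, encoding_end_offsets,
//                    /*start_offset=*/1, context, &output_tensor));
//
// Similarly, copying {3, 1, 4} with `max_output_length` = 5 and a padding
// value of 0 is expected to leave {3, 1, 4, 0, 0} in `output_tensor`:
//
//   const std::vector<int32_t> data = {3, 1, 4};
//   CopyDataToTensorAndPadOrTruncate(/*max_output_length=*/5, data,
//                                    /*padding_value=*/0, &output_tensor);

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_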