1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ 16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ 17 18 #include "tensorflow/lite/c/builtin_op_data.h" 19 20 #if defined(_MSC_VER) 21 #define __restrict__ __restrict 22 #endif 23 24 namespace tflite { 25 namespace tensor_utils { 26 27 // Limit a float input f between +abs_limit and -abs_limit. 28 float Clip(float f, float abs_limit); 29 30 // Checks if all entries of vector are zero. 31 bool IsZeroVector(const float* vector, int v_size); 32 33 // Quantizes a buffer of floating point values using a symmetric quantization 34 // (i.e. linear quantization without an offset) to 8-bit signed integers. 35 // It also outputs the range (min, max) of the floating point buffer, and the 36 // scaling factor used to quantize the values. 37 void SymmetricQuantizeFloats(const float* values, const int size, 38 int8_t* quantized_values, float* min_value, 39 float* max_value, float* scaling_factor); 40 41 // Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch 42 // dimension composed by input vectors independent from each other). The result 43 // of the multiplication is accumulated to the passed result buffer. 44 // More specifically, for a matrix M of shape [n, i] and a batched-vector 45 // of shape [i, batch] it will first compute the product of shape [n, batch]. 46 // This product will be accumulated to the result buffer, using a stride value 47 // provided in result_stride (the number of elements between consecutive result 48 // values). For example result_stride = 1, will cause the output to look like 49 // this: 50 // [O_1, 0_2, ... O_rows] 51 // but result_stride = 3, will cause it to be arranged like this in memory: 52 // [O_1, x, x, 0_2, x, x, ..., O_rows] 53 void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, 54 int m_cols, const float* vector, 55 int n_batch, float* result, 56 int result_stride); 57 58 // Same as the function above, but the matrix is stored in block compressed 59 // sparse row format with block pattern 1x16 which consists of two arrays: 60 // 1. A matrix array stores non-zero blocks of the matrix in row major. 61 // 2. A ledger array stores nrows groups, one group per row. Each group starts 62 // with 63 // an integer representing the number of non-zero blocks for the 64 // corresponding row and follows with column indexes of the first element 65 // of each non-zero block. 66 // This function assumes that 67 // 1. m_cols is a multiple of 16 so that all blocks are full blocks. 68 // 2. m_cols < 254 * 16 so that block index can be represented by uint8. 69 void SparseMatrixBatchVectorMultiplyAccumulate( 70 const float* matrix, const uint8_t* ledger, int m_rows, int m_cols, 71 const float* vector, int n_batch, float* result, int result_stride); 72 73 // Same as the function above, but for values quantized using symmetric 74 // quantization (e.g. by calling SymmetricQuantizeFloats). 75 // The passed scaling factors is a buffer of the quantization scaling factors 76 // that will be used to dequentize the products into the final result buffer. 77 // These scaling factors are the multiplication of the matrix scaling factor 78 // by the vector's scaling factor, one per batch (i.e. this allows quantizing 79 // each batch in the batch-vector matrix independently). 80 void MatrixBatchVectorMultiplyAccumulate( 81 const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, 82 const int8_t* __restrict__ vectors, const float* scaling_factors, 83 int n_batch, float* __restrict__ result, int result_stride); 84 85 // Same as the function above, but the matrix is stored in block compressed 86 // sparse row format with block pattern 1x16 which consists of two arrays: 87 // 1. A matrix array stores non-zero blocks of the matrix in row major. 88 // 2. A ledger array stores nrows groups, one group per row. Each group starts 89 // with 90 // an integer representing the number of non-zero blocks for the 91 // corresponding row followed by column index of the first element of 92 // each non-zero block. 93 // This function assumes that 94 // 1. m_cols is a multiple of 16 so that all blocks are full blocks. 95 // 2. m_cols < 254 * 16 so that block index can be represented by uint8. 96 void SparseMatrixBatchVectorMultiplyAccumulate( 97 const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows, 98 const int m_cols, const int8_t* __restrict__ vectors, 99 const float* scaling_factors, int n_batch, float* __restrict__ result, 100 int result_stride); 101 102 // Cwise product of two vectors. 103 void VectorVectorCwiseProduct(const float* vector1, const float* vector2, 104 int v_size, float* result); 105 106 // Cwise product and accumulate of two vectors. Since it's a MAC opertation, the 107 // assumption here is that result array is initialized to valid values. 108 void VectorVectorCwiseProductAccumulate(const float* vector1, 109 const float* vector2, int v_size, 110 float* result); 111 112 // Dot product of two vectors. 113 float VectorVectorDotProduct(const float* vector1, const float* vector2, 114 int v_size); 115 116 // Dot product of two batch vectors of size n_batch * v_size: 117 // vector1 = [x_1_1, x_1_2, ..., x_1_vsize, 118 // x_2_1, x_2_2, ..., x_2_vsize, 119 // ... 120 // x_nbatch_1,..., x_nbatch_vsize] 121 // vector2 = [y_1_1, y_1_2, ..., y_1_vsize, 122 // y_2_1, y_2_2, ..., y_2_vsize, 123 // ... 124 // y_nbatch_1,..., y_nbatch_vsize] 125 // Then result will be a vector of n_batch size which will be saved with a 126 // stride of result_stride in memory starting from 'result': 127 // [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize, 128 // x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize, 129 // ... 130 // x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize] 131 void BatchVectorBatchVectorDotProduct(const float* vector1, 132 const float* vector2, int v_size, 133 int n_batch, float* result, 134 int result_stride); 135 136 // Cwise product of a vector and a batch-vector. 137 void VectorBatchVectorCwiseProduct(const float* vector, int v_size, 138 const float* batch_vector, int n_batch, 139 float* result); 140 141 // Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC 142 // operation, the assumption here is that result array is initialized to valid 143 // values. 144 void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, 145 const float* batch_vector, 146 int n_batch, float* result); 147 148 // Add another vector for each batch in the batch vector. 149 void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, 150 float* batch_vector); 151 152 // Batch vector initialization with another vector. 153 void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, 154 float* batch_vector); 155 156 // Apply sigmoid to elements of a vector. 157 void ApplySigmoidToVector(const float* vector, int v_size, float* result); 158 159 // Apply activation function to elements of a vector. 160 void ApplyActivationToVector(const float* vector, int v_size, 161 TfLiteFusedActivation activation, float* result); 162 163 // Copy vector to another vector. 164 void CopyVector(const float* vector, int v_size, float* result); 165 166 // Compute "1.0f - elements of vector" (used in CIFG). 167 void Sub1Vector(const float* vector, int v_size, float* result); 168 169 // Fill vector with 0.f. 170 void ZeroVector(float* vector, int v_size); 171 172 // Multiply all elements of vector with a scalar. 173 void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, 174 float* result); 175 176 // Clip elements of a vector using a abs_limit value. 177 void ClipVector(const float* vector, int v_size, float abs_limit, 178 float* result); 179 180 // Shift left a vector in place with v_size size. 181 void VectorShiftLeft(float* vector, int v_size, float shift_value); 182 183 // Reduce-sum on a float input vector: 184 // input_vector: float pointer to input vector. 185 // output_vector: float pointer to vector. 186 // output_size: output vector size. 187 // reduction_size: number of consecutive elements from input vector which are 188 // added to get one element of output. 189 void ReductionSumVector(const float* input_vector, float* output_vector, 190 int output_size, int reduction_size); 191 192 // Layer norm for each batch. 193 // normalization_epsilon is added to avoid divergence. 194 void MeanStddevNormalization(const float* input_vector, float* output_vector, 195 int v_size, int n_batch, 196 float normalization_epsilon); 197 } // namespace tensor_utils 198 } // namespace tflite 199 200 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ 201