/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_

// Note: This file is a copy-paste version of neon_tensor_utils.h; the only
// differences are in MatrixBatchVectorMultiplyAccumulate and
// SparseMatrixBatchVectorMultiplyAccumulate (the other functions do not have
// SSE implementations yet).

// Note: Most of the functions below use NEON_OR_PORTABLE, whose NEON path is
// compiled for x86 through the Intel NEON_2_SSE translator library. When a
// native SSE version of a function is implemented, switch the corresponding
// call to SSE_OR_PORTABLE.

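// For reference, the dispatch macros from neon_check.h and sse_check.h expand
// roughly along the lines sketched below (illustrative only; the exact
// definitions and feature checks live in those headers, and the non-SSE
// fallback may itself route through NEON_OR_PORTABLE):
//
//   #ifdef __SSSE3__
//   #define SSE_OR_PORTABLE(funcname, ...) Sse##funcname(__VA_ARGS__)
//   #else
//   #define SSE_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__)
//   #endif
//
//   #ifdef USE_NEON
//   #define NEON_OR_PORTABLE(funcname, ...) Neon##funcname(__VA_ARGS__)
//   #else
//   #define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__)
//   #endif
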
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_check.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"

namespace tflite {
namespace tensor_utils {

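// Float path: there is no native SSE kernel for this overload, so it
// dispatches to the NEON implementation (built via NEON_2_SSE on x86) or to
// the portable fallback.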
void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                         int m_cols, const float* vector,
                                         int n_batch, float* result) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                   vector, n_batch, result);
}

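// Hybrid path (int8 weights and inputs, per-batch float scaling factors):
// this overload and the two that follow have native SSE kernels and dispatch
// through SSE_OR_PORTABLE.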
void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result) {
  SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                  vectors, scaling_factors, n_batch, result);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, const float* per_channel_scale,
    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
    bool* compute_row_sums, CpuBackendContext* context) {
  SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                  vectors, scaling_factors, n_batch, result, per_channel_scale,
                  input_offset, scratch, row_sums, compute_row_sums, context);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    int32_t* __restrict__ scratch, float* __restrict__ result,
    CpuBackendContext* __restrict__ context) {
  SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                  vectors, scaling_factors, n_batch, scratch, result, context);
}

void SparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
                   segments, indices, m_rows, m_cols, vector, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
                   m_rows, m_cols, vector, n_batch, result);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    const int m_rows, const int m_cols, const int8_t* __restrict__ vectors,
    const float* __restrict__ scaling_factors, int n_batch,
    float* __restrict__ result) {
  SSE_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
                  m_rows, m_cols, vectors, scaling_factors, n_batch, result);
}

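// The fully integer variants below forward directly to the portable
// implementations; they have no SSE or NEON specialization in this header.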
void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* input_zeropoint_times_weights,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int16_t* output, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulate(
      input, input_zeropoint_times_weights, input_to_gate_weights, multiplier,
      shift, n_batch, n_input, n_output, output_zp, scratch, output, context);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* input, const int32_t* input_zeropoint_times_weights,
    const int8_t* input_to_gate_weights, int32_t multiplier, int32_t shift,
    int32_t n_batch, int32_t n_input, int32_t n_output, int32_t output_zp,
    int32_t* scratch, int8_t* output, CpuBackendContext* context) {
  PortableMatrixBatchVectorMultiplyAccumulate(
      input, input_zeropoint_times_weights, input_to_gate_weights, multiplier,
      shift, n_batch, n_input, n_output, output_zp, scratch, output, context);
}

void MatrixBatchVectorMultiply(const int8_t* input, int32_t input_zeropoint,
                               const int8_t* input_to_gate_weights,
                               int32_t input_to_gate_effective_scale_a,
                               int32_t input_to_gate_effective_scale_b,
                               int32_t n_batch, int32_t n_input, int32_t n_cell,
                               int8_t* gate_output, int8_t gate_output_zp) {
  PortableMatrixBatchVectorMultiply(
      input, input_zeropoint, input_to_gate_weights,
      input_to_gate_effective_scale_a, input_to_gate_effective_scale_b, n_batch,
      n_input, n_cell, gate_output, gate_output_zp);
}

void MatrixBatchVectorMultiply(const int16_t* hidden,
                               const int8_t* hidden_to_output_weights,
                               int32_t proj_effective_scale_a,
                               int32_t proj_effective_scale_b,
                               const int32_t* gate_bias, int32_t n_batch,
                               int32_t n_hidden, int32_t n_output,
                               int32_t output_zp, int8_t* proj_output) {
  PortableMatrixBatchVectorMultiply(hidden, hidden_to_output_weights,
                                    proj_effective_scale_a,
                                    proj_effective_scale_b, gate_bias, n_batch,
                                    n_hidden, n_output, output_zp, proj_output);
}

void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
                                    int32_t n_row, int32_t n_col,
                                    int32_t* output) {
  PortableMatrixScalarMultiplyAccumulate(matrix, scalar, n_row, n_col, output);
}

void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
                    const int32_t* bias, int32_t layer_norm_scale_a,
                    int32_t layer_norm_scale_b, int32_t variance_limit,
                    int n_batch, int n_input, int16_t* output) {
  PortableApplyLayerNorm(input, layer_norm_weights, bias, layer_norm_scale_a,
                         layer_norm_scale_b, variance_limit, n_batch, n_input,
                         output);
}

void ApplyLayerNormFloat(const int16_t* input,
                         const int16_t* layer_norm_weights,
                         int32_t layer_norm_scale_a, int32_t layer_norm_scale_b,
                         const int32_t* bias, int n_batch, int n_input,
                         int16_t* output) {
  PortableApplyLayerNormFloat(input, layer_norm_weights, layer_norm_scale_a,
                              layer_norm_scale_b, bias, n_batch, n_input,
                              output);
}

void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input,
                  int16_t* output) {
  PortableApplySigmoid(input, n_batch, n_input, output);
}

void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                       int16_t* output) {
  PortableApplySigmoidFloat(input, n_batch, n_input, output);
}

void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch,
               int32_t n_input, int16_t* output) {
  PortableApplyTanh(integer_bits, input, n_batch, n_input, output);
}

void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input,
                    int32_t integer_bits, int16_t* output) {
  PortableApplyTanhFloat(input, n_batch, n_input, integer_bits, output);
}

void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int shift, int16_t* output) {
  PortableCwiseMul(input_1, input_2, n_batch, n_input, shift, output);
}

void CwiseMul(const int16_t* input_1, const int16_t* input_2,
              int32_t multiplier, int32_t shift, int32_t n_batch,
              int32_t n_input, int32_t output_zp, int8_t* output) {
  PortableCwiseMul(input_1, input_2, multiplier, shift, n_batch, n_input,
                   output_zp, output);
}

void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
              int n_input, int16_t* output) {
  PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
}

void CwiseClipping(float* vector, const int v_size,
                   const float clipping_value) {
  PortableCwiseClipping(vector, v_size, clipping_value);
}

void CwiseClipping(int16_t* vector, const int v_size,
                   const int16_t clipping_value) {
  PortableCwiseClipping(vector, v_size, clipping_value);
}

void CwiseClipping(int8_t* vector, const int v_size,
                   const int8_t clipping_value) {
  PortableCwiseClipping(vector, v_size, clipping_value);
}

void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
                                      const int16_t* vector2, int v_size,
                                      int n_batch, int32_t* result) {
  PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
                                           result);
}

void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
                                             const int16_t* batch_vector,
                                             int n_batch, int32_t multiplier,
                                             int shift, int16_t* result) {
  NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size,
                   batch_vector, n_batch, multiplier, shift, result);
}

float VectorVectorDotProduct(const float* vector1, const float* vector2,
                             int v_size) {
  return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
}

void Sub1Vector(const float* vector, int v_size, float* result) {
  NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
}

void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
  PortableSub1Vector(vector, v_size, result);
}

// Check if all entries of a vector are zero for float.
bool IsZeroVector(const float* vector, int v_size) {
  return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
}

// Check if all entries of a vector are zero for int8.
bool IsZeroVector(const int8_t* vector, int v_size) {
  return PortableIsZeroVector(vector, v_size);
}

void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
                          float* result) {
  NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float* min_value,
                             float* max_value, float* scaling_factor) {
  NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
                   min_value, max_value, scaling_factor);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float min_value,
                             float max_value, float* scaling_factor) {
  NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
                   min_value, max_value, scaling_factor);
}

void AsymmetricQuantizeFloats(const float* values, const int size,
                              int8_t* quantized_values, float* scaling_factor,
                              int32_t* offset) {
  NEON_OR_PORTABLE(AsymmetricQuantizeFloats, values, size, quantized_values,
                   scaling_factor, offset);
}

void ReductionSumVector(const float* input_vector, float* output_vector,
                        int output_size, int reduction_size) {
  NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
                   reduction_size);
}

void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size) {
  PortableReductionSumVector(input_vector, output_vector, output_size,
                             reduction_size);
}

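// The int8 reduction has a native SSE kernel and dispatches through
// SSE_OR_PORTABLE.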
void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                        int output_size, int reduction_size) {
  SSE_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
                  reduction_size);
}

void MeanStddevNormalization(const float* __restrict__ input_vector,
                             float* __restrict__ output_vector, int v_size,
                             int n_batch) {
  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
}

void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
                          const int8_t* recurrent, int8_t recurrent_zp,
                          int32_t input_effective_scale_a,
                          int32_t input_effective_scale_b,
                          int32_t recurrent_effective_scale_a,
                          int32_t recurrent_effective_scale_b, int32_t n_batch,
                          int32_t n_cell, int16_t* output) {
  PortableTwoGateSaturatingAdd(
      input, input_zp, recurrent, recurrent_zp, input_effective_scale_a,
      input_effective_scale_b, recurrent_effective_scale_a,
      recurrent_effective_scale_b, n_batch, n_cell, output);
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_