• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
17 
18 #include <algorithm>
19 #include <cstdint>
20 
21 #include "tensorflow/lite/kernels/internal/common.h"
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/tensor_utils_common.h"
24 #include "tensorflow/lite/kernels/internal/types.h"
25 
26 namespace tflite {
27 namespace reference_ops {
28 namespace batch_matmul {
29 
30 // Determine which dimension is the broadcast dimension.
broadcast_dim(int lhs_dim,int rhs_dim)31 inline int broadcast_dim(int lhs_dim, int rhs_dim) {
32   if (lhs_dim == rhs_dim) return lhs_dim;
33   if (lhs_dim == 1) return rhs_dim;
34   TFLITE_DCHECK_EQ(rhs_dim, 1);
35   return lhs_dim;
36 }
37 
38 // Compute the "extent" for iterating on this dimension.
39 // If we are broadcasting, then don't advance (i.e return 0).
extent(const RuntimeShape & shape,int x)40 inline int extent(const RuntimeShape& shape, int x) {
41   if (shape.Dims(x) == 1) {
42     return 0;
43   }
44   int prod = 1;
45   for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
46     prod *= shape.Dims(i);
47   }
48   return prod;
49 }
50 
51 }  // namespace batch_matmul
52 
BatchMatMul(const RuntimeShape & lhs_shape,const float * lhs_data,const RuntimeShape & rhs_shape,const float * rhs_data,const RuntimeShape & output_shape,float * output_data)53 inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
54                         const RuntimeShape& rhs_shape, const float* rhs_data,
55                         const RuntimeShape& output_shape, float* output_data) {
56   const RuntimeShape extended_lhs_shape =
57       RuntimeShape::ExtendedShape(5, lhs_shape);
58   const RuntimeShape extended_rhs_shape =
59       RuntimeShape::ExtendedShape(5, rhs_shape);
60 
61   const int batch_dim0 = batch_matmul::broadcast_dim(
62       extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
63   const int batch_dim1 = batch_matmul::broadcast_dim(
64       extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
65   const int batch_dim2 = batch_matmul::broadcast_dim(
66       extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
67 
68   const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
69   const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
70   const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
71   const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
72   const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
73   const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);
74 
75   // Set params for each matrix multiply.
76   const int lhs_rows = extended_lhs_shape.Dims(3);
77   const int rhs_cols = extended_rhs_shape.Dims(4);
78   const int accum_depth = extended_lhs_shape.Dims(4);
79 
80   for (int b0 = 0; b0 < batch_dim0; ++b0) {
81     const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
82     const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
83     for (int b1 = 0; b1 < batch_dim1; ++b1) {
84       const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
85       const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
86       for (int b2 = 0; b2 < batch_dim2; ++b2) {
87         const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
88         const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
89         float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
90                                         b1 * batch_dim2 + b2) *
91                                            lhs_rows * rhs_cols;
92         for (int j = 0; j < rhs_cols; ++j) {
93           for (int i = 0; i < lhs_rows; ++i) {
94             float total = 0.f;
95             for (int k = 0; k < accum_depth; ++k) {
96               total +=
97                   lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
98             }
99             int idx = lhs_rows * j + i;
100             out_ptr[idx] = total;
101           }
102         }
103       }
104     }
105   }
106 }
107 
// Hybrid kernel: int8-quantized LHS (weights) times dynamically-quantized
// int8 RHS (activations), accumulating a float result.
//
// Per-column quantization data for the RHS:
//   scaling_factors — one float scale per RHS column, per batch.
//   input_offset    — one int32 zero-point per RHS column, per batch.
// row_sums caches the per-row sums of the LHS (used to fold out the RHS
// zero-point); it is recomputed only when compute_row_sums is null or points
// to true, and the flag is then cleared so subsequent calls reuse the cache.
//
// NOTE(review): results are accumulated into output_data with `+=` below —
// this assumes the caller has zeroed the output buffer; verify at call site.
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape, float* output_data,
                        bool* compute_row_sums) {
  // Pad both operands to rank 5 (three batch dims + two matrix dims).
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Resolve broadcast sizes for the three batch dimensions.
  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  // Pointer strides per batch dimension; zero on the broadcast side.
  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  // Strides for the per-column quantization arrays: input_offset and
  // scaling_factors advance by rhs_cols per non-broadcast RHS batch, and
  // row_sums advances by lhs_rows per non-broadcast LHS batch. Both mirror
  // the broadcast pattern of their matrix (ext == 0 means stay put).
  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

  // (Re)build the cached LHS row sums when requested (or unconditionally if
  // no flag was supplied).
  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    // NOTE(review): dim 0 is skipped here — presumably the weights' outermost
    // batch dim is always 1 for hybrid kernels; verify against callers.
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    tensor_utils::ReductionSumVector(
        lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
    if (compute_row_sums) {
      *compute_row_sums = false;  // Cache is now valid; skip next time.
    }
  }

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    // Offsets and scales share the RHS broadcast pattern; row sums share the
    // LHS pattern.
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        // Output batches are dense (never broadcast).
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          // Per-column scale and zero-point of the dynamically-quantized RHS.
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          for (int i = 0; i < lhs_rows; ++i) {
            // Integer dot product along the depth axis. RHS is addressed
            // k-major ([cols, depth]) — presumably pre-transposed by the
            // caller, matching the float kernel above.
            int32_t total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total +=
                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
            }
            // Fold out the RHS zero-point: sum_k lhs[i,k] * offset.
            // NOTE(review): row_sum * batch_offset is float math whose result
            // is implicitly truncated back into the int32 accumulator.
            int32_t row_sum = woff_ptr2[i];
            total -= row_sum * batch_offset;
            // Column-major position within the output matrix; accumulated,
            // not overwritten.
            int idx = lhs_rows * j + i;
            out_ptr[idx] += batch_scaling_factor * total;
          }
        }
      }
    }
  }
}
197 
198 template <typename T, typename AccumT>
BatchMatMul(const FullyConnectedParams & params,const RuntimeShape & lhs_shape,const T * lhs_data,const RuntimeShape & rhs_shape,const T * rhs_data,const RuntimeShape & output_shape,T * output_data)199 inline void BatchMatMul(const FullyConnectedParams& params,
200                         const RuntimeShape& lhs_shape, const T* lhs_data,
201                         const RuntimeShape& rhs_shape, const T* rhs_data,
202                         const RuntimeShape& output_shape, T* output_data) {
203   const RuntimeShape extended_lhs_shape =
204       RuntimeShape::ExtendedShape(5, lhs_shape);
205   const RuntimeShape extended_rhs_shape =
206       RuntimeShape::ExtendedShape(5, rhs_shape);
207 
208   const int batch_dim0 = batch_matmul::broadcast_dim(
209       extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
210   const int batch_dim1 = batch_matmul::broadcast_dim(
211       extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
212   const int batch_dim2 = batch_matmul::broadcast_dim(
213       extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
214 
215   const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
216   const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
217   const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
218   const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
219   const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
220   const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);
221 
222   // Set params for each matrix multiply.
223   const int lhs_rows = extended_lhs_shape.Dims(3);
224   const int rhs_cols = extended_rhs_shape.Dims(4);
225   const int accum_depth = extended_lhs_shape.Dims(4);
226 
227   const int32_t input_offset = params.input_offset;
228   const int32_t filter_offset = params.weights_offset;
229   const int32_t output_offset = params.output_offset;
230   const int32_t output_multiplier = params.output_multiplier;
231   const int output_shift = params.output_shift;
232   const int32_t output_activation_min = params.quantized_activation_min;
233   const int32_t output_activation_max = params.quantized_activation_max;
234   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
235 
236   for (int b0 = 0; b0 < batch_dim0; ++b0) {
237     const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
238     const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
239     for (int b1 = 0; b1 < batch_dim1; ++b1) {
240       const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
241       const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
242       for (int b2 = 0; b2 < batch_dim2; ++b2) {
243         const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
244         const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
245         T* out_ptr = output_data +
246                      ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
247                          lhs_rows * rhs_cols;
248 
249         for (int j = 0; j < rhs_cols; ++j) {
250           for (int i = 0; i < lhs_rows; ++i) {
251             AccumT total = 0;
252             for (int k = 0; k < accum_depth; ++k) {
253               AccumT lhs_val = lhs_ptr2[accum_depth * i + k];
254               AccumT rhs_val = rhs_ptr2[accum_depth * j + k];
255               total += (lhs_val + filter_offset) * (rhs_val + input_offset);
256             }
257             int32_t total_scaled = MultiplyByQuantizedMultiplier(
258                 total, output_multiplier, output_shift);
259             total_scaled += output_offset;
260             total_scaled = std::max(total_scaled, output_activation_min);
261             total_scaled = std::min(total_scaled, output_activation_max);
262             const int idx = lhs_rows * j + i;
263             out_ptr[idx] = static_cast<T>(total_scaled);
264           }
265         }
266       }
267     }
268   }
269 }
270 
271 }  // namespace reference_ops
272 }  // namespace tflite
273 
274 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
275