/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_

#include <algorithm>
#include <cstdint>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_utils_common.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {
namespace batch_matmul {

// Determine which dimension is the broadcast dimension.
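// For example, broadcast_dim(1, 4) and broadcast_dim(4, 1) both return 4;
// mismatched non-unit dimensions trip the TFLITE_DCHECK_EQ below.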
inline int broadcast_dim(int lhs_dim, int rhs_dim) {
  if (lhs_dim == rhs_dim) return lhs_dim;
  if (lhs_dim == 1) return rhs_dim;
  TFLITE_DCHECK_EQ(rhs_dim, 1);
  return lhs_dim;
}

// Compute the "extent" for iterating on this dimension.
// If we are broadcasting, then don't advance (i.e. return 0).
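// For example, for a shape of [2, 3, 4, 5, 6], extent(shape, 1) is
// 4 * 5 * 6 = 120 elements (the row-major stride of dimension 1), while any
// dimension of size 1 yields 0 so the broadcast operand is re-read rather
// than advanced.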
inline int extent(const RuntimeShape& shape, int x) {
  if (shape.Dims(x) == 1) {
    return 0;
  }
  int prod = 1;
  for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
    prod *= shape.Dims(i);
  }
  return prod;
}

}  // namespace batch_matmul

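// Reference float batch matmul. Both shapes are extended to rank 5: three
// leading batch dimensions (each broadcastable against the other operand)
// followed by the two matrix dimensions.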
inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
                        const RuntimeShape& rhs_shape, const float* rhs_data,
                        const RuntimeShape& output_shape, float* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

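  // In the loops below both operands keep the accumulation dimension
  // innermost: the rhs buffer is addressed as a [rhs_cols, accum_depth]
  // block, and each per-batch output matrix is written at
  // idx = lhs_rows * j + i.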
  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            float total = 0.f;
            for (int k = 0; k < accum_depth; ++k) {
              total +=
                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
            }
            int idx = lhs_rows * j + i;
            out_ptr[idx] = total;
          }
        }
      }
    }
  }
}

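// Hybrid batch matmul: int8 operands with per-batch, per-column scaling
// factors and input offsets; each integer accumulation is scaled back to
// float and accumulated into output_data.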
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape, float* output_data,
                        bool* compute_row_sums) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

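  // Precompute per-row sums of the int8 lhs (the weights). They are used
  // below to fold the rhs zero-point out of the integer accumulation:
  //   sum_k lhs[i, k] * (rhs[j, k] - offset[j])
  //     = sum_k lhs[i, k] * rhs[j, k] - offset[j] * row_sum[i].
  // Callers may cache the sums across invocations via compute_row_sums.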
  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    tensor_utils::ReductionSumVector(
        lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          for (int i = 0; i < lhs_rows; ++i) {
            int32_t total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total +=
                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
            }
            int32_t row_sum = woff_ptr2[i];
            total -= row_sum * batch_offset;
            int idx = lhs_rows * j + i;
            out_ptr[idx] += batch_scaling_factor * total;
          }
        }
      }
    }
  }
}

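// Fully quantized batch matmul. T is the quantized element type (e.g. int8_t)
// and AccumT a wider accumulator type (e.g. int32_t); zero-point offsets, the
// output multiplier/shift and the activation bounds come from params.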
template <typename T, typename AccumT>
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const T* lhs_data,
                        const RuntimeShape& rhs_shape, const T* rhs_data,
                        const RuntimeShape& output_shape, T* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32_t input_offset = params.input_offset;
  const int32_t filter_offset = params.weights_offset;
  const int32_t output_offset = params.output_offset;
  const int32_t output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32_t output_activation_min = params.quantized_activation_min;
  const int32_t output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

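  // In the inner loop below the operand zero-points (filter_offset,
  // input_offset) are added back per element, and the AccumT total is
  // requantized with MultiplyByQuantizedMultiplier, shifted by output_offset
  // and clamped to the activation range before being written back as T.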
  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        T* out_ptr = output_data +
                     ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
                         lhs_rows * rhs_cols;

        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            AccumT total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              AccumT lhs_val = lhs_ptr2[accum_depth * i + k];
              AccumT rhs_val = rhs_ptr2[accum_depth * j + k];
              total += (lhs_val + filter_offset) * (rhs_val + input_offset);
            }
            int32_t total_scaled = MultiplyByQuantizedMultiplier(
                total, output_multiplier, output_shift);
            total_scaled += output_offset;
            total_scaled = std::max(total_scaled, output_activation_min);
            total_scaled = std::min(total_scaled, output_activation_max);
            const int idx = lhs_rows * j + i;
            out_ptr[idx] = static_cast<T>(total_scaled);
          }
        }
      }
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_