/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NELocallyConnectedMatrixMultiplyKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
#include <tuple>

namespace arm_compute
{
class Coordinates;

namespace
{
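// Multiplies the vector input0 by a weight matrix taken from input1 and writes one row
// of output using FP16 NEON arithmetic. The matrix is selected per window position via
// Coordinates(id[0], 0, id[1]), i.e. each output y index reads its own weight slice, as
// a locally connected layer requires. The X range of the output is split across threads.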
void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator ina(input0, win_a);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float16x8_t acc0 = vdupq_n_f16(0.f);
        float16x8_t acc1 = vdupq_n_f16(0.f);
        float16x8_t acc2 = vdupq_n_f16(0.f);
        float16x8_t acc3 = vdupq_n_f16(0.f);

        auto vec_a    = reinterpret_cast<const float16_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float16_t *>(input1->ptr_to_element(Coordinates(id[0], 0, id[1])));

        const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;

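        // Main loop: consume four elements of vec_a per iteration (two unrolled steps of
        // two rows each), broadcasting each scalar over 32 columns of matrix_b with
        // vmulq_lane_f16 and accumulating into four float16x8_t registers.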
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            const float16x4_t a0l = vld1_f16(vec_a);

            float16x8_t b00 = vld1q_f16(matrix_b);
            float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

            float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

            matrix_b += 2 * in_b_stride;

            b00 = vld1q_f16(matrix_b);
            b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
            b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
            b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
            b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
            acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
            acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
            acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
            acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

            vec_a += 4;
            matrix_b += 2 * in_b_stride;
        }

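        // Leftover loop: fewer than four elements of vec_a remain, so accumulate one
        // row of matrix_b at a time using a scalar broadcast (vmulq_n_f16).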
        for(; vec_a < vec_a_end_addr;)
        {
            const float16_t   a0  = *vec_a;
            const float16x8_t b00 = vld1q_f16(matrix_b);
            const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
            const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
            const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

            acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
            acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
            acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
            acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        const auto vec_out = reinterpret_cast<float16_t *>(out.ptr());

        vst1q_f16(vec_out + 0, acc0);
        vst1q_f16(vec_out + 8, acc1);
        vst1q_f16(vec_out + 16, acc2);
        vst1q_f16(vec_out + 24, acc3);
    },
    ina, out);
#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
    ARM_COMPUTE_UNUSED(input0);
    ARM_COMPUTE_UNUSED(input1);
    ARM_COMPUTE_UNUSED(output);
    ARM_COMPUTE_UNUSED(window);
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR("Not supported, recompile with -march=armv8.2-a+fp16+simd.");
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
}

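// FP32 variant of the vector-matrix multiply above: same threading and windowing
// scheme, with float32x4_t accumulators, fused multiply-accumulate (vmlaq_lane_f32)
// and explicit cache prefetching on 32-bit Arm.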
void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info)
{
    const auto width_matrix_b  = static_cast<int>(output->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));

    Iterator ina(input0, win_a);
    Iterator out(output, win_out);

    execute_window_loop(win_out, [&](const Coordinates & id)
    {
        if(id.x() > width_matrix_b)
        {
            return;
        }

        float32x4_t acc0 = vdupq_n_f32(0.f);
        float32x4_t acc1 = vdupq_n_f32(0.f);
        float32x4_t acc2 = vdupq_n_f32(0.f);
        float32x4_t acc3 = vdupq_n_f32(0.f);

        auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const float *>(input1->ptr_to_element(Coordinates(id[0], 0, id[1])));

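        // Prefetch the start of the vector and the first two matrix rows (32-bit Arm only)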
#if __arm__
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
        asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

        const float *vec_a_end_addr = vec_a + num_elems_vec_a;

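        // Main loop: consume four elements of vec_a per iteration (two unrolled steps of
        // two rows each), broadcasting each scalar over 16 columns of matrix_b with
        // vmlaq_lane_f32 while prefetching the rows needed next.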
        for(; vec_a <= (vec_a_end_addr - 4);)
        {
            float32x2_t a0l = vld1_f32(vec_a);

            float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
            asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;

            a0l = vld1_f32(vec_a);

            b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
            b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
            b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
            b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

            acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
            acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
            acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
            acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

            acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
            acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
            acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
            acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

            vec_a += 2;
            matrix_b += 2 * in_b_stride;
        }

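        // Leftover loop: accumulate the remaining elements of vec_a one row at a time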
        for(; vec_a < vec_a_end_addr;)
        {
            const float a0 = *vec_a;

            const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
            const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
            const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
            const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

            acc0 = vmlaq_n_f32(acc0, b00, a0);
            acc1 = vmlaq_n_f32(acc1, b01, a0);
            acc2 = vmlaq_n_f32(acc2, b02, a0);
            acc3 = vmlaq_n_f32(acc3, b03, a0);

            vec_a += 1;
            matrix_b += in_b_stride;
        }

        const auto vec_out = reinterpret_cast<float *>(out.ptr());

        vst1q_f32(vec_out + 0, acc0);
        vst1q_f32(vec_out + 4, acc1);
        vst1q_f32(vec_out + 8, acc2);
        vst1q_f32(vec_out + 12, acc3);
    },
    ina, out);
}

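// Check that all tensors share one of the supported data types (F16/F32) and that the
// vector length of input0 matches the matrix height of input1.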
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
    ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));

    return Status{};
}

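// Build the maximal execution window (16 output elements per step along X) and request
// the horizontal access each tensor must allow; fails if the required padding is missing.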
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
{
    constexpr unsigned int num_elems_processed_per_iteration_x = 16;

    Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));

    AccessWindowHorizontal input0_access(input0, 0, num_elems_processed_per_iteration_x);
    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration_x);
    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);

    bool window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);

    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};

    return std::make_tuple(err, win);
}
} // namespace

NELocallyConnectedMatrixMultiplyKernel::NELocallyConnectedMatrixMultiplyKernel()
    : _input0(nullptr), _input1(nullptr), _output(nullptr)
{
}

void NELocallyConnectedMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));

    _input0 = input0;
    _input1 = input1;
    _output = output;

    // Configure kernel window
    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());

    ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));

    INEKernel::configure(std::get<1>(win_config));
}

Status NELocallyConnectedMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
    ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get())));

    return Status{};
}

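// Dispatch to the implementation matching the data type fixed at configure() time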
void NELocallyConnectedMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    switch(_input0->info()->data_type())
    {
        case DataType::F16:
        {
            vector_matrix_multiply_f16(_input0, _input1, _output, window, info);
            break;
        }
        case DataType::F32:
        {
            vector_matrix_multiply_f32(_input0, _input1, _output, window, info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Data type not supported");
            break;
        }
    }
}
} // namespace arm_compute