1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <psimd.h>
10
11 #include <qnnpack/sdwconv.h>
12
pytorch_sdwconv_ukernel_up4x9__psimd(size_t channels,size_t output_width,const float ** input,const float * weights,float * output,size_t input_stride,size_t output_increment,const struct pytorch_qnnp_fp32_clamping_params clamping_params[restrict static1])13 void pytorch_sdwconv_ukernel_up4x9__psimd(
14 size_t channels,
15 size_t output_width,
16 const float** input,
17 const float* weights,
18 float* output,
19 size_t input_stride,
20 size_t output_increment,
21 const struct pytorch_qnnp_fp32_clamping_params
22 clamping_params[restrict static 1]) {
23 const psimd_f32 vmax = psimd_splat_f32(clamping_params->max);
24 const psimd_f32 vmin = psimd_splat_f32(clamping_params->min);
25 do {
26 const float* i0 = input[0];
27 const float* i1 = input[1];
28 const float* i2 = input[2];
29 const float* i3 = input[3];
30 const float* i4 = input[4];
31 const float* i5 = input[5];
32 const float* i6 = input[6];
33 const float* i7 = input[7];
34 const float* i8 = input[8];
35
36 input = (const float**)((uintptr_t)input + input_stride);
37
38 size_t c = channels;
39 const float* w = weights;
40 for (; c >= 4; c -= 4) {
41 psimd_f32 vacc = psimd_load_f32(w);
42
43 const psimd_f32 vi0 = psimd_load_f32(i0);
44 i0 += 4;
45 const psimd_f32 vk0 = psimd_load_f32(w + 8);
46 vacc += vi0 * vk0;
47
48 const psimd_f32 vi1 = psimd_load_f32(i1);
49 i1 += 4;
50 const psimd_f32 vk1 = psimd_load_f32(w + 12);
51 psimd_f32 vacc2 = vi1 * vk1;
52
53 const psimd_f32 vi2 = psimd_load_f32(i2);
54 i2 += 4;
55 const psimd_f32 vk2 = psimd_load_f32(w + 16);
56 vacc += vi2 * vk2;
57
58 const psimd_f32 vi3 = psimd_load_f32(i3);
59 i3 += 4;
60 const psimd_f32 vk3 = psimd_load_f32(w + 20);
61 vacc2 += vi3 * vk3;
62
63 const psimd_f32 vi4 = psimd_load_f32(i4);
64 i4 += 4;
65 const psimd_f32 vk4 = psimd_load_f32(w + 24);
66 vacc += vi4 * vk4;
67
68 const psimd_f32 vi5 = psimd_load_f32(i5);
69 i5 += 4;
70 const psimd_f32 vk5 = psimd_load_f32(w + 28);
71 vacc2 += vi5 * vk5;
72
73 const psimd_f32 vi6 = psimd_load_f32(i6);
74 i6 += 4;
75 const psimd_f32 vk6 = psimd_load_f32(w + 32);
76 vacc += vi6 * vk6;
77
78 const psimd_f32 vi7 = psimd_load_f32(i7);
79 i7 += 4;
80 const psimd_f32 vk7 = psimd_load_f32(w + 36);
81 vacc2 += vi7 * vk7;
82
83 const psimd_f32 vi8 = psimd_load_f32(i8);
84 i8 += 4;
85 const psimd_f32 vk8 = psimd_load_f32(w + 40);
86 vacc += vi8 * vk8;
87
88 vacc += vacc2;
89
90 vacc = psimd_min_f32(vacc, vmax);
91 vacc = psimd_max_f32(vacc, vmin);
92
93 psimd_store_f32(output, vacc);
94 w += 44;
95 }
96 if (c != 0) {
97 psimd_f32 vacc = psimd_load_f32(w);
98 c *= sizeof(float);
99
100 i0 = (const float*)((uintptr_t)i0 - c);
101 const psimd_f32 vi0 = psimd_load_f32(i0);
102 const psimd_f32 vk0 = psimd_load_f32(w + 8);
103 vacc += vi0 * vk0;
104
105 i1 = (const float*)((uintptr_t)i1 - c);
106 const psimd_f32 vi1 = psimd_load_f32(i1);
107 const psimd_f32 vk1 = psimd_load_f32(w + 12);
108 psimd_f32 vacc2 = vi1 * vk1;
109
110 i2 = (const float*)((uintptr_t)i2 - c);
111 const psimd_f32 vi2 = psimd_load_f32(i2);
112 const psimd_f32 vk2 = psimd_load_f32(w + 16);
113 vacc += vi2 * vk2;
114
115 i3 = (const float*)((uintptr_t)i3 - c);
116 const psimd_f32 vi3 = psimd_load_f32(i3);
117 const psimd_f32 vk3 = psimd_load_f32(w + 20);
118 vacc2 += vi3 * vk3;
119
120 i4 = (const float*)((uintptr_t)i4 - c);
121 const psimd_f32 vi4 = psimd_load_f32(i4);
122 const psimd_f32 vk4 = psimd_load_f32(w + 24);
123 vacc += vi4 * vk4;
124
125 i5 = (const float*)((uintptr_t)i5 - c);
126 const psimd_f32 vi5 = psimd_load_f32(i5);
127 const psimd_f32 vk5 = psimd_load_f32(w + 28);
128 vacc2 += vi5 * vk5;
129
130 i6 = (const float*)((uintptr_t)i6 - c);
131 const psimd_f32 vi6 = psimd_load_f32(i6);
132 const psimd_f32 vk6 = psimd_load_f32(w + 32);
133 vacc += vi6 * vk6;
134
135 i7 = (const float*)((uintptr_t)i7 - c);
136 const psimd_f32 vi7 = psimd_load_f32(i7);
137 const psimd_f32 vk7 = psimd_load_f32(w + 36);
138 vacc2 += vi7 * vk7;
139
140 i8 = (const float*)((uintptr_t)i8 - c);
141 const psimd_f32 vi8 = psimd_load_f32(i8);
142 const psimd_f32 vk8 = psimd_load_f32(w + 40);
143 vacc += vi8 * vk8;
144
145 vacc += vacc2;
146
147 vacc = psimd_min_f32(vacc, vmax);
148 vacc = psimd_max_f32(vacc, vmin);
149
150 output = (float*)((uintptr_t)output - c);
151 psimd_store_f32(output, vacc);
152 }
153
154 output = (float*)((uintptr_t)output + output_increment);
155 } while (--output_width != 0);
156 }
157