• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <psimd.h>
10 
11 #include <qnnpack/sdwconv.h>
12 
pytorch_sdwconv_ukernel_up4x9__psimd(size_t channels,size_t output_width,const float ** input,const float * weights,float * output,size_t input_stride,size_t output_increment,const struct pytorch_qnnp_fp32_clamping_params clamping_params[restrict static1])13 void pytorch_sdwconv_ukernel_up4x9__psimd(
14     size_t channels,
15     size_t output_width,
16     const float** input,
17     const float* weights,
18     float* output,
19     size_t input_stride,
20     size_t output_increment,
21     const struct pytorch_qnnp_fp32_clamping_params
22         clamping_params[restrict static 1]) {
23   const psimd_f32 vmax = psimd_splat_f32(clamping_params->max);
24   const psimd_f32 vmin = psimd_splat_f32(clamping_params->min);
25   do {
26     const float* i0 = input[0];
27     const float* i1 = input[1];
28     const float* i2 = input[2];
29     const float* i3 = input[3];
30     const float* i4 = input[4];
31     const float* i5 = input[5];
32     const float* i6 = input[6];
33     const float* i7 = input[7];
34     const float* i8 = input[8];
35 
36     input = (const float**)((uintptr_t)input + input_stride);
37 
38     size_t c = channels;
39     const float* w = weights;
40     for (; c >= 4; c -= 4) {
41       psimd_f32 vacc = psimd_load_f32(w);
42 
43       const psimd_f32 vi0 = psimd_load_f32(i0);
44       i0 += 4;
45       const psimd_f32 vk0 = psimd_load_f32(w + 8);
46       vacc += vi0 * vk0;
47 
48       const psimd_f32 vi1 = psimd_load_f32(i1);
49       i1 += 4;
50       const psimd_f32 vk1 = psimd_load_f32(w + 12);
51       psimd_f32 vacc2 = vi1 * vk1;
52 
53       const psimd_f32 vi2 = psimd_load_f32(i2);
54       i2 += 4;
55       const psimd_f32 vk2 = psimd_load_f32(w + 16);
56       vacc += vi2 * vk2;
57 
58       const psimd_f32 vi3 = psimd_load_f32(i3);
59       i3 += 4;
60       const psimd_f32 vk3 = psimd_load_f32(w + 20);
61       vacc2 += vi3 * vk3;
62 
63       const psimd_f32 vi4 = psimd_load_f32(i4);
64       i4 += 4;
65       const psimd_f32 vk4 = psimd_load_f32(w + 24);
66       vacc += vi4 * vk4;
67 
68       const psimd_f32 vi5 = psimd_load_f32(i5);
69       i5 += 4;
70       const psimd_f32 vk5 = psimd_load_f32(w + 28);
71       vacc2 += vi5 * vk5;
72 
73       const psimd_f32 vi6 = psimd_load_f32(i6);
74       i6 += 4;
75       const psimd_f32 vk6 = psimd_load_f32(w + 32);
76       vacc += vi6 * vk6;
77 
78       const psimd_f32 vi7 = psimd_load_f32(i7);
79       i7 += 4;
80       const psimd_f32 vk7 = psimd_load_f32(w + 36);
81       vacc2 += vi7 * vk7;
82 
83       const psimd_f32 vi8 = psimd_load_f32(i8);
84       i8 += 4;
85       const psimd_f32 vk8 = psimd_load_f32(w + 40);
86       vacc += vi8 * vk8;
87 
88       vacc += vacc2;
89 
90       vacc = psimd_min_f32(vacc, vmax);
91       vacc = psimd_max_f32(vacc, vmin);
92 
93       psimd_store_f32(output, vacc);
94       w += 44;
95     }
96     if (c != 0) {
97       psimd_f32 vacc = psimd_load_f32(w);
98       c *= sizeof(float);
99 
100       i0 = (const float*)((uintptr_t)i0 - c);
101       const psimd_f32 vi0 = psimd_load_f32(i0);
102       const psimd_f32 vk0 = psimd_load_f32(w + 8);
103       vacc += vi0 * vk0;
104 
105       i1 = (const float*)((uintptr_t)i1 - c);
106       const psimd_f32 vi1 = psimd_load_f32(i1);
107       const psimd_f32 vk1 = psimd_load_f32(w + 12);
108       psimd_f32 vacc2 = vi1 * vk1;
109 
110       i2 = (const float*)((uintptr_t)i2 - c);
111       const psimd_f32 vi2 = psimd_load_f32(i2);
112       const psimd_f32 vk2 = psimd_load_f32(w + 16);
113       vacc += vi2 * vk2;
114 
115       i3 = (const float*)((uintptr_t)i3 - c);
116       const psimd_f32 vi3 = psimd_load_f32(i3);
117       const psimd_f32 vk3 = psimd_load_f32(w + 20);
118       vacc2 += vi3 * vk3;
119 
120       i4 = (const float*)((uintptr_t)i4 - c);
121       const psimd_f32 vi4 = psimd_load_f32(i4);
122       const psimd_f32 vk4 = psimd_load_f32(w + 24);
123       vacc += vi4 * vk4;
124 
125       i5 = (const float*)((uintptr_t)i5 - c);
126       const psimd_f32 vi5 = psimd_load_f32(i5);
127       const psimd_f32 vk5 = psimd_load_f32(w + 28);
128       vacc2 += vi5 * vk5;
129 
130       i6 = (const float*)((uintptr_t)i6 - c);
131       const psimd_f32 vi6 = psimd_load_f32(i6);
132       const psimd_f32 vk6 = psimd_load_f32(w + 32);
133       vacc += vi6 * vk6;
134 
135       i7 = (const float*)((uintptr_t)i7 - c);
136       const psimd_f32 vi7 = psimd_load_f32(i7);
137       const psimd_f32 vk7 = psimd_load_f32(w + 36);
138       vacc2 += vi7 * vk7;
139 
140       i8 = (const float*)((uintptr_t)i8 - c);
141       const psimd_f32 vi8 = psimd_load_f32(i8);
142       const psimd_f32 vk8 = psimd_load_f32(w + 40);
143       vacc += vi8 * vk8;
144 
145       vacc += vacc2;
146 
147       vacc = psimd_min_f32(vacc, vmax);
148       vacc = psimd_max_f32(vacc, vmin);
149 
150       output = (float*)((uintptr_t)output - c);
151       psimd_store_f32(output, vacc);
152     }
153 
154     output = (float*)((uintptr_t)output + output_increment);
155   } while (--output_width != 0);
156 }
157