• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
25 
26 #include "input.hpp"
27 #include "arm.hpp"
28 
29 namespace winograd
30 {
31 
32 template <>
transform_tile(const int n_channels,const __fp16 * const input_base,const int input_row_stride,const int input_col_stride,__fp16 * outptr,const int matrix_stride)33 void InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>::transform_tile(
34     const int n_channels,
35     const __fp16* const input_base,
36     const int input_row_stride,
37     const int input_col_stride,
38     __fp16* outptr,
39     const int matrix_stride
40 )
41 {
42     constexpr int inner_tile_rows = 4, inner_tile_cols = 4;
43 
44     // Get pointers into the input tile
45     const __fp16 *x_ptrs[inner_tile_rows][inner_tile_cols];
46     for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
47     {
48         // Get a pointer into the row
49         const __fp16* const row_ptr = input_base + xi*input_row_stride;
50 
51         for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
52         {
53             x_ptrs[i][j] = row_ptr + xj*input_col_stride;
54         }
55     }
56 
57     // Matrices used/computed in this kernel.
58     __fp16 x[inner_tile_rows][inner_tile_cols];
59     __fp16 XTx[inner_tile_rows][inner_tile_cols];
60     __fp16 U[inner_tile_rows][inner_tile_cols];
61 
62     for (int i = 0; i < inner_tile_rows; i++)
63     {
64         for (int j = 0; j < inner_tile_cols; j++)
65         {
66             x[i][j] = XTx[i][j] = 0.0f;
67         }
68     }
69 
70     // Perform the Winograd input transformation for each channel in the input
71     // tensor.
72     int channels_remaining = n_channels;
73 #ifdef __aarch64__
74     for (; channels_remaining >= 8; channels_remaining -= 8)
75   {
76     // Matrices used/computed in this kernel.
77     float16x8_t x[inner_tile_rows][inner_tile_cols];
78     float16x8_t XTx[inner_tile_rows][inner_tile_cols];
79     float16x8_t U[inner_tile_rows][inner_tile_cols];
80 
81     for (int i = 0; i < inner_tile_rows; i++)
82     {
83       for (int j = 0; j < inner_tile_cols; j++)
84       {
85         x[i][j] = vdupq_n_f16(0.0f);
86         XTx[i][j] = vdupq_n_f16(0.0f);
87       }
88     }
89 
90     // Load x
91     for (int i = 0; i < inner_tile_rows; i++)
92     {
93       for (int j = 0; j < inner_tile_cols; j++)
94       {
95         x[i][j] = vld1q_f16(x_ptrs[i][j]);
96         x_ptrs[i][j] += 8;
97       }
98     }
99 
100     // Compute XT . x
101     for (int j = 0; j < inner_tile_cols; j++)
102     {
103       // XTx[0][j] = x[0][j] - x[2][j];
104       XTx[0][j] = vsubq_f16(x[0][j], x[2][j]);
105 
106       // XTx[1][j] = x[1][j] + x[2][j];
107       XTx[1][j] = vaddq_f16(x[1][j], x[2][j]);
108 
109       // XTx[2][j] = x[2][j] - x[1][j];
110       XTx[2][j] = vsubq_f16(x[2][j], x[1][j]);
111 
112       // XTx[3][j] = x[1][j] - x[3][j];
113       XTx[3][j] = vsubq_f16(x[1][j], x[3][j]);
114     }
115 
116     // Compute U = XT . x . X
117     for (int i = 0; i < inner_tile_rows; i++)
118     {
119       // U[i][0] = XTx[i][0] - XTx[i][2];
120       U[i][0] = vsubq_f16(XTx[i][0], XTx[i][2]);
121 
122       // U[i][1] = XTx[i][1] + XTx[i][2];
123       U[i][1] = vaddq_f16(XTx[i][1], XTx[i][2]);
124 
125       // U[i][2] = XTx[i][2] - XTx[i][1];
126       U[i][2] = vsubq_f16(XTx[i][2], XTx[i][1]);
127 
128       // U[i][3] = XTx[i][1] - XTx[i][3];
129       U[i][3] = vsubq_f16(XTx[i][1], XTx[i][3]);
130     }
131 
132     // Store the transformed matrix
133     for (int i = 0, m = 0; i < inner_tile_rows; i++)
134     {
135       for (int j = 0; j < inner_tile_cols; j++, m++)
136       {
137         vst1q_f16(outptr + m*matrix_stride, U[i][j]);
138       }
139     }
140     outptr += 8;
141   }
142 #endif  // __aarch64__
143 #ifdef __arm_any__
144     for (; channels_remaining >= 4; channels_remaining -= 4)
145   {
146     // Matrices used/computed in this kernel.
147     float16x4_t x[inner_tile_rows][inner_tile_cols];
148     float16x4_t XTx[inner_tile_rows][inner_tile_cols];
149     float16x4_t U[inner_tile_rows][inner_tile_cols];
150 
151     for (int i = 0; i < inner_tile_rows; i++)
152     {
153       for (int j = 0; j < inner_tile_cols; j++)
154       {
155         x[i][j] = vdup_n_f16(0.0f);
156         XTx[i][j] = vdup_n_f16(0.0f);
157       }
158     }
159 
160     // Load x
161     for (int i = 0; i < inner_tile_rows; i++)
162     {
163       for (int j = 0; j < inner_tile_cols; j++)
164       {
165         x[i][j] = vld1_f16(x_ptrs[i][j]);
166         x_ptrs[i][j] += 4;
167       }
168     }
169 
170     // Compute XT . x
171     for (int j = 0; j < inner_tile_cols; j++)
172     {
173       // XTx[0][j] = x[0][j] - x[2][j];
174       XTx[0][j] = vsub_f16(x[0][j], x[2][j]);
175 
176       // XTx[1][j] = x[1][j] + x[2][j];
177       XTx[1][j] = vadd_f16(x[1][j], x[2][j]);
178 
179       // XTx[2][j] = x[2][j] - x[1][j];
180       XTx[2][j] = vsub_f16(x[2][j], x[1][j]);
181 
182       // XTx[3][j] = x[1][j] - x[3][j];
183       XTx[3][j] = vsub_f16(x[1][j], x[3][j]);
184     }
185 
186     // Compute U = XT . x . X
187     for (int i = 0; i < inner_tile_rows; i++)
188     {
189       // U[i][0] = XTx[i][0] - XTx[i][2];
190       U[i][0] = vsub_f16(XTx[i][0], XTx[i][2]);
191 
192       // U[i][1] = XTx[i][1] + XTx[i][2];
193       U[i][1] = vadd_f16(XTx[i][1], XTx[i][2]);
194 
195       // U[i][2] = XTx[i][2] - XTx[i][1];
196       U[i][2] = vsub_f16(XTx[i][2], XTx[i][1]);
197 
198       // U[i][3] = XTx[i][1] - XTx[i][3];
199       U[i][3] = vsub_f16(XTx[i][1], XTx[i][3]);
200     }
201 
202     // Store the transformed matrix
203     for (int i = 0, m = 0; i < inner_tile_rows; i++)
204     {
205       for (int j = 0; j < inner_tile_cols; j++, m++)
206       {
207         vst1_f16(outptr + m*matrix_stride, U[i][j]);
208       }
209     }
210     outptr += 4;
211   }
212 #endif  // __arm_any__
213     for (; channels_remaining; channels_remaining--)
214     {
215         // Load x
216         for (int i = 0; i < inner_tile_rows; i++)
217         {
218             for (int j = 0; j < inner_tile_cols; j++)
219             {
220                 x[i][j] = *(x_ptrs[i][j]++);
221             }
222         }
223 
224         // Compute XT . x
225         for (int j = 0; j < inner_tile_cols; j++)
226         {
227             XTx[0][j] = x[0][j] - x[2][j];
228             XTx[1][j] = x[1][j] + x[2][j];
229             XTx[2][j] = x[2][j] - x[1][j];
230             XTx[3][j] = x[1][j] - x[3][j];
231         }
232 
233         // Compute U = XT . x . X
234         for (int i = 0; i < inner_tile_rows; i++)
235         {
236             U[i][0] = XTx[i][0] - XTx[i][2];
237             U[i][1] = XTx[i][1] + XTx[i][2];
238             U[i][2] = XTx[i][2] - XTx[i][1];
239             U[i][3] = XTx[i][1] - XTx[i][3];
240         }
241 
242         // Store the transformed matrix
243         for (int i = 0, m = 0; i < inner_tile_rows; i++)
244         {
245             for (int j = 0; j < inner_tile_cols; j++, m++)
246             {
247                 *(outptr + m*matrix_stride) = U[i][j];
248             }
249         }
250         outptr++;
251     }
252 }
253 
// Explicit instantiation definition: requests that InputTransform be
// instantiated in this translation unit for exactly this template-argument
// list (4, 4, __fp16, __fp16, WinogradRoots::Integers).
254 template class InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>;
255 
256 }  // namespace
257 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
258