• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #include "input.hpp"
26 #include "arm.hpp"
27 
28 namespace winograd
29 {
30 
31 template <>
transform_tile(const int n_channels,const float * const input_base,const int input_row_stride,const int input_col_stride,float * outptr,const int matrix_stride)32 void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile(
33   const int n_channels,
34   const float* const input_base,
35   const int input_row_stride,
36   const int input_col_stride,
37   float* outptr,
38   const int matrix_stride
39 )
40 {
41   constexpr int inner_tile_rows = 4, inner_tile_cols = 4;
42 
43   // Get pointers into the input tile
44   const float *x_ptrs[inner_tile_rows][inner_tile_cols];
45   for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
46   {
47     // Get a pointer into the row
48     const float* const row_ptr = input_base + xi*input_row_stride;
49 
50     for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
51     {
52       x_ptrs[i][j] = row_ptr + xj*input_col_stride;
53     }
54   }
55 
56   // Matrices used/computed in this kernel.
57   float x[inner_tile_rows][inner_tile_cols];
58   float XTx[inner_tile_rows][inner_tile_cols];
59   float U[inner_tile_rows][inner_tile_cols];
60 
61   for (int i = 0; i < inner_tile_rows; i++)
62   {
63     for (int j = 0; j < inner_tile_cols; j++)
64     {
65       x[i][j] = XTx[i][j] = 0.0f;
66     }
67   }
68 
69   // Perform the Winograd input transformation for each channel in the input
70   // tensor.
71   int channels_remaining = n_channels;
72 #ifdef __aarch64__
73   for (; channels_remaining >= 4; channels_remaining -= 4)
74   {
75     // Matrices used/computed in this kernel.
76     float32x4_t x[inner_tile_rows][inner_tile_cols];
77     float32x4_t XTx[inner_tile_rows][inner_tile_cols];
78     float32x4_t U[inner_tile_rows][inner_tile_cols];
79 
80     for (int i = 0; i < inner_tile_rows; i++)
81     {
82       for (int j = 0; j < inner_tile_cols; j++)
83       {
84         x[i][j] = vdupq_n_f32(0.0f);
85         XTx[i][j] = vdupq_n_f32(0.0f);
86       }
87     }
88 
89     // Load x
90     for (int i = 0; i < inner_tile_rows; i++)
91     {
92       for (int j = 0; j < inner_tile_cols; j++)
93       {
94         x[i][j] = vld1q_f32(x_ptrs[i][j]);
95         x_ptrs[i][j] += 4;
96       }
97     }
98 
99     // Compute XT . x
100     for (int j = 0; j < inner_tile_cols; j++)
101     {
102       // XTx[0][j] = x[0][j] - x[2][j];
103       XTx[0][j] = vsubq_f32(x[0][j], x[2][j]);
104 
105       // XTx[1][j] = x[1][j] + x[2][j];
106       XTx[1][j] = vaddq_f32(x[1][j], x[2][j]);
107 
108       // XTx[2][j] = x[2][j] - x[1][j];
109       XTx[2][j] = vsubq_f32(x[2][j], x[1][j]);
110 
111       // XTx[3][j] = x[1][j] - x[3][j];
112       XTx[3][j] = vsubq_f32(x[1][j], x[3][j]);
113     }
114 
115     // Compute U = XT . x . X
116     for (int i = 0; i < inner_tile_rows; i++)
117     {
118       // U[i][0] = XTx[i][0] - XTx[i][2];
119       U[i][0] = vsubq_f32(XTx[i][0], XTx[i][2]);
120 
121       // U[i][1] = XTx[i][1] + XTx[i][2];
122       U[i][1] = vaddq_f32(XTx[i][1], XTx[i][2]);
123 
124       // U[i][2] = XTx[i][2] - XTx[i][1];
125       U[i][2] = vsubq_f32(XTx[i][2], XTx[i][1]);
126 
127       // U[i][3] = XTx[i][1] - XTx[i][3];
128       U[i][3] = vsubq_f32(XTx[i][1], XTx[i][3]);
129     }
130 
131     // Store the transformed matrix
132     for (int i = 0, m = 0; i < inner_tile_rows; i++)
133     {
134       for (int j = 0; j < inner_tile_cols; j++, m++)
135       {
136         vst1q_f32(outptr + m*matrix_stride, U[i][j]);
137       }
138     }
139     outptr += 4;
140   }
141 #endif  // __aarch64__
142 #ifdef __arm_any__
143   for (; channels_remaining >= 2; channels_remaining -= 2)
144   {
145     // Matrices used/computed in this kernel.
146     float32x2_t x[inner_tile_rows][inner_tile_cols];
147     float32x2_t XTx[inner_tile_rows][inner_tile_cols];
148     float32x2_t U[inner_tile_rows][inner_tile_cols];
149 
150     for (int i = 0; i < inner_tile_rows; i++)
151     {
152       for (int j = 0; j < inner_tile_cols; j++)
153       {
154         x[i][j] = vdup_n_f32(0.0f);
155         XTx[i][j] = vdup_n_f32(0.0f);
156       }
157     }
158 
159     // Load x
160     for (int i = 0; i < inner_tile_rows; i++)
161     {
162       for (int j = 0; j < inner_tile_cols; j++)
163       {
164         x[i][j] = vld1_f32(x_ptrs[i][j]);
165         x_ptrs[i][j] += 2;
166       }
167     }
168 
169     // Compute XT . x
170     for (int j = 0; j < inner_tile_cols; j++)
171     {
172       // XTx[0][j] = x[0][j] - x[2][j];
173       XTx[0][j] = vsub_f32(x[0][j], x[2][j]);
174 
175       // XTx[1][j] = x[1][j] + x[2][j];
176       XTx[1][j] = vadd_f32(x[1][j], x[2][j]);
177 
178       // XTx[2][j] = x[2][j] - x[1][j];
179       XTx[2][j] = vsub_f32(x[2][j], x[1][j]);
180 
181       // XTx[3][j] = x[1][j] - x[3][j];
182       XTx[3][j] = vsub_f32(x[1][j], x[3][j]);
183     }
184 
185     // Compute U = XT . x . X
186     for (int i = 0; i < inner_tile_rows; i++)
187     {
188       // U[i][0] = XTx[i][0] - XTx[i][2];
189       U[i][0] = vsub_f32(XTx[i][0], XTx[i][2]);
190 
191       // U[i][1] = XTx[i][1] + XTx[i][2];
192       U[i][1] = vadd_f32(XTx[i][1], XTx[i][2]);
193 
194       // U[i][2] = XTx[i][2] - XTx[i][1];
195       U[i][2] = vsub_f32(XTx[i][2], XTx[i][1]);
196 
197       // U[i][3] = XTx[i][1] - XTx[i][3];
198       U[i][3] = vsub_f32(XTx[i][1], XTx[i][3]);
199     }
200 
201     // Store the transformed matrix
202     for (int i = 0, m = 0; i < inner_tile_rows; i++)
203     {
204       for (int j = 0; j < inner_tile_cols; j++, m++)
205       {
206         vst1_f32(outptr + m*matrix_stride, U[i][j]);
207       }
208     }
209     outptr += 2;
210   }
211 #endif  // __arm_any__
212   for (; channels_remaining; channels_remaining--)
213   {
214     // Load x
215     for (int i = 0; i < inner_tile_rows; i++)
216     {
217       for (int j = 0; j < inner_tile_cols; j++)
218       {
219         x[i][j] = *(x_ptrs[i][j]++);
220       }
221     }
222 
223     // Compute XT . x
224     for (int j = 0; j < inner_tile_cols; j++)
225     {
226       XTx[0][j] = x[0][j] - x[2][j];
227       XTx[1][j] = x[1][j] + x[2][j];
228       XTx[2][j] = x[2][j] - x[1][j];
229       XTx[3][j] = x[1][j] - x[3][j];
230     }
231 
232     // Compute U = XT . x . X
233     for (int i = 0; i < inner_tile_rows; i++)
234     {
235       U[i][0] = XTx[i][0] - XTx[i][2];
236       U[i][1] = XTx[i][1] + XTx[i][2];
237       U[i][2] = XTx[i][2] - XTx[i][1];
238       U[i][3] = XTx[i][1] - XTx[i][3];
239     }
240 
241     // Store the transformed matrix
242     for (int i = 0, m = 0; i < inner_tile_rows; i++)
243     {
244       for (int j = 0; j < inner_tile_cols; j++, m++)
245       {
246         *(outptr + m*matrix_stride) = U[i][j];
247       }
248     }
249     outptr++;
250   }
251 }
252 
253 template class InputTransform<4, 4, float, float, WinogradRoots::Integers>;
254 
255 }  // namespace
256