• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
17 #define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
18 
19 #include "tensorflow/core/kernels/deep_conv2d.h"
20 
21 namespace tensorflow {
22 
23 // Winograd DeepConv2DTransform implementation for 3x3 filters.
24 // Details:
25 // *) Arithmetic complexity of computations: Shmuel Winograd
26 // *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray
27 
28 template <typename T>
29 class WinogradTransform : public DeepConv2DTransform<T> {
30  public:
31   typedef typename DeepConv2DTransform<T>::Shape Shape;
32 
WinogradTransform()33   WinogradTransform()
34       : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {}
35 
36   virtual void GetFilterTransformMatrix(const int64 rows, const int64 cols,
37                                         T* transform_matrix) const;
38 
39   virtual void GetInputTransformMatrix(const int64 rows, const int64 cols,
40                                        T* transform_matrix) const;
41 
42   virtual void GetOutputTransformMatrix(const int64 rows, const int64 cols,
43                                         T* transform_matrix) const;
44 
filter_shape()45   virtual const Shape& filter_shape() const { return filter_shape_; }
input_shape()46   virtual const Shape& input_shape() const { return input_shape_; }
output_shape()47   virtual const Shape& output_shape() const { return output_shape_; }
48 
49  private:
50   const Shape filter_shape_;
51   const Shape input_shape_;
52   const Shape output_shape_;
53 };
54 
55 // The filter transform matrix is the kronecker product 'M * M' of the
56 // following matrix 'M':
57 //
58 //   [ 1    0   0   ]
59 //   [ 1/2  1/2 1/2 ]
60 //   [ 1/2 -1/2 1/2 ]
61 //   [ 0    0   1   ]
62 //
63 // The data layout of 'transform_matrix':
64 //   [input_tile_spatial_size, filter_spatial_size]
65 //
66 template <typename T>
GetFilterTransformMatrix(const int64 rows,const int64 cols,T * transform_matrix)67 void WinogradTransform<T>::GetFilterTransformMatrix(const int64 rows,
68                                                     const int64 cols,
69                                                     T* transform_matrix) const {
70   CHECK_GT(rows, 0);
71   CHECK_GT(cols, 0);
72   memset(transform_matrix, 0, sizeof(T) * rows * cols);
73 
74   // Sub matrix [0,0]
75   transform_matrix[0 * cols + 0] = T(1.0);
76 
77   transform_matrix[1 * cols + 0] = T(0.5);
78   transform_matrix[1 * cols + 1] = T(0.5);
79   transform_matrix[1 * cols + 2] = T(0.5);
80 
81   transform_matrix[2 * cols + 0] = T(0.5);
82   transform_matrix[2 * cols + 1] = T(-0.5);
83   transform_matrix[2 * cols + 2] = T(0.5);
84 
85   transform_matrix[3 * cols + 2] = T(1.0);
86 
87   // Sub matrix [1,0]
88   transform_matrix[4 * cols + 0] = T(0.5);
89 
90   transform_matrix[5 * cols + 0] = T(0.25);
91   transform_matrix[5 * cols + 1] = T(0.25);
92   transform_matrix[5 * cols + 2] = T(0.25);
93 
94   transform_matrix[6 * cols + 0] = T(0.25);
95   transform_matrix[6 * cols + 1] = T(-0.25);
96   transform_matrix[6 * cols + 2] = T(0.25);
97 
98   transform_matrix[7 * cols + 2] = T(0.5);
99 
100   // Sub matrix [1,1]
101   transform_matrix[4 * cols + 3] = T(0.5);
102 
103   transform_matrix[5 * cols + 3] = T(0.25);
104   transform_matrix[5 * cols + 4] = T(0.25);
105   transform_matrix[5 * cols + 5] = T(0.25);
106 
107   transform_matrix[6 * cols + 3] = T(0.25);
108   transform_matrix[6 * cols + 4] = T(-0.25);
109   transform_matrix[6 * cols + 5] = T(0.25);
110 
111   transform_matrix[7 * cols + 5] = T(0.5);
112 
113   // Sub matrix [1,2]
114   transform_matrix[4 * cols + 6] = T(0.5);
115 
116   transform_matrix[5 * cols + 6] = T(0.25);
117   transform_matrix[5 * cols + 7] = T(0.25);
118   transform_matrix[5 * cols + 8] = T(0.25);
119 
120   transform_matrix[6 * cols + 6] = T(0.25);
121   transform_matrix[6 * cols + 7] = T(-0.25);
122   transform_matrix[6 * cols + 8] = T(0.25);
123 
124   transform_matrix[7 * cols + 8] = T(0.5);
125 
126   // Sub matrix [2,0]
127   transform_matrix[8 * cols + 0] = T(0.5);
128 
129   transform_matrix[9 * cols + 0] = T(0.25);
130   transform_matrix[9 * cols + 1] = T(0.25);
131   transform_matrix[9 * cols + 2] = T(0.25);
132 
133   transform_matrix[10 * cols + 0] = T(0.25);
134   transform_matrix[10 * cols + 1] = T(-0.25);
135   transform_matrix[10 * cols + 2] = T(0.25);
136 
137   transform_matrix[11 * cols + 2] = T(0.5);
138 
139   // Sub matrix [2,1]
140   transform_matrix[8 * cols + 3] = T(-0.5);
141 
142   transform_matrix[9 * cols + 3] = T(-0.25);
143   transform_matrix[9 * cols + 4] = T(-0.25);
144   transform_matrix[9 * cols + 5] = T(-0.25);
145 
146   transform_matrix[10 * cols + 3] = T(-0.25);
147   transform_matrix[10 * cols + 4] = T(0.25);
148   transform_matrix[10 * cols + 5] = T(-0.25);
149 
150   transform_matrix[11 * cols + 5] = T(-0.5);
151 
152   // Sub matrix [2,2]
153   transform_matrix[8 * cols + 6] = T(0.5);
154 
155   transform_matrix[9 * cols + 6] = T(0.25);
156   transform_matrix[9 * cols + 7] = T(0.25);
157   transform_matrix[9 * cols + 8] = T(0.25);
158 
159   transform_matrix[10 * cols + 6] = T(0.25);
160   transform_matrix[10 * cols + 7] = T(-0.25);
161   transform_matrix[10 * cols + 8] = T(0.25);
162 
163   transform_matrix[11 * cols + 8] = T(0.5);
164 
165   // Sub matrix [3,2]
166   transform_matrix[12 * cols + 6] = T(1.0);
167 
168   transform_matrix[13 * cols + 6] = T(0.5);
169   transform_matrix[13 * cols + 7] = T(0.5);
170   transform_matrix[13 * cols + 8] = T(0.5);
171 
172   transform_matrix[14 * cols + 6] = T(0.5);
173   transform_matrix[14 * cols + 7] = T(-0.5);
174   transform_matrix[14 * cols + 8] = T(0.5);
175 
176   transform_matrix[15 * cols + 8] = T(1.0);
177 }
178 
179 // The input transform matrix is the kronecker product 'M * M' of the
180 // following matrix 'M':
181 //
182 //   [1   0  -1   0]
183 //   [0   1   1   0]
184 //   [0  -1   1   0]
185 //   [0   1   0  -1]
186 //
187 // Data layout of 'transform_matrix':
188 //   [tile_spatial_size, tile_spatial_size]
189 //
190 template <typename T>
GetInputTransformMatrix(const int64 rows,const int64 cols,T * transform_matrix)191 void WinogradTransform<T>::GetInputTransformMatrix(const int64 rows,
192                                                    const int64 cols,
193                                                    T* transform_matrix) const {
194   CHECK_GT(rows, 0);
195   CHECK_GT(cols, 0);
196   memset(transform_matrix, 0, sizeof(T) * rows * cols);
197 
198   // Sub matrix [0,0]
199   transform_matrix[0 * cols + 0] = T(1.0);
200   transform_matrix[0 * cols + 2] = T(-1.0);
201 
202   transform_matrix[1 * cols + 1] = T(1.0);
203   transform_matrix[1 * cols + 2] = T(1.0);
204 
205   transform_matrix[2 * cols + 1] = T(-1.0);
206   transform_matrix[2 * cols + 2] = T(1.0);
207 
208   transform_matrix[3 * cols + 1] = T(1.0);
209   transform_matrix[3 * cols + 3] = T(-1.0);
210 
211   // Sub matrix [0,2]
212   transform_matrix[0 * cols + 8] = T(-1.0);
213   transform_matrix[0 * cols + 10] = T(1.0);
214 
215   transform_matrix[1 * cols + 9] = T(-1.0);
216   transform_matrix[1 * cols + 10] = T(-1.0);
217 
218   transform_matrix[2 * cols + 9] = T(1.0);
219   transform_matrix[2 * cols + 10] = T(-1.0);
220 
221   transform_matrix[3 * cols + 9] = T(-1.0);
222   transform_matrix[3 * cols + 11] = T(1.0);
223 
224   // Sub matrix [1,1]
225   transform_matrix[4 * cols + 4] = T(1.0);
226   transform_matrix[4 * cols + 6] = T(-1.0);
227 
228   transform_matrix[5 * cols + 5] = T(1.0);
229   transform_matrix[5 * cols + 6] = T(1.0);
230 
231   transform_matrix[6 * cols + 5] = T(-1.0);
232   transform_matrix[6 * cols + 6] = T(1.0);
233 
234   transform_matrix[7 * cols + 5] = T(1.0);
235   transform_matrix[7 * cols + 7] = T(-1.0);
236 
237   // Sub matrix [1,2]
238   transform_matrix[4 * cols + 8] = T(1.0);
239   transform_matrix[4 * cols + 10] = T(-1.0);
240 
241   transform_matrix[5 * cols + 9] = T(1.0);
242   transform_matrix[5 * cols + 10] = T(1.0);
243 
244   transform_matrix[6 * cols + 9] = T(-1.0);
245   transform_matrix[6 * cols + 10] = T(1.0);
246 
247   transform_matrix[7 * cols + 9] = T(1.0);
248   transform_matrix[7 * cols + 11] = T(-1.0);
249 
250   // Sub matrix [2,1]
251   transform_matrix[8 * cols + 4] = T(-1.0);
252   transform_matrix[8 * cols + 6] = T(1.0);
253 
254   transform_matrix[9 * cols + 5] = T(-1.0);
255   transform_matrix[9 * cols + 6] = T(-1.0);
256 
257   transform_matrix[10 * cols + 5] = T(1.0);
258   transform_matrix[10 * cols + 6] = T(-1.0);
259 
260   transform_matrix[11 * cols + 5] = T(-1.0);
261   transform_matrix[11 * cols + 7] = T(1.0);
262 
263   // Sub matrix [2,2]
264   transform_matrix[8 * cols + 8] = T(1.0);
265   transform_matrix[8 * cols + 10] = T(-1.0);
266 
267   transform_matrix[9 * cols + 9] = T(1.0);
268   transform_matrix[9 * cols + 10] = T(1.0);
269 
270   transform_matrix[10 * cols + 9] = T(-1.0);
271   transform_matrix[10 * cols + 10] = T(1.0);
272 
273   transform_matrix[11 * cols + 9] = T(1.0);
274   transform_matrix[11 * cols + 11] = T(-1.0);
275 
276   // Sub matrix [3,1]
277   transform_matrix[12 * cols + 4] = T(1.0);
278   transform_matrix[12 * cols + 6] = T(-1.0);
279 
280   transform_matrix[13 * cols + 5] = T(1.0);
281   transform_matrix[13 * cols + 6] = T(1.0);
282 
283   transform_matrix[14 * cols + 5] = T(-1.0);
284   transform_matrix[14 * cols + 6] = T(1.0);
285 
286   transform_matrix[15 * cols + 5] = T(1.0);
287   transform_matrix[15 * cols + 7] = T(-1.0);
288 
289   // Sub matrix [3,3]
290   transform_matrix[12 * cols + 12] = T(-1.0);
291   transform_matrix[12 * cols + 14] = T(1.0);
292 
293   transform_matrix[13 * cols + 13] = T(-1.0);
294   transform_matrix[13 * cols + 14] = T(-1.0);
295 
296   transform_matrix[14 * cols + 13] = T(1.0);
297   transform_matrix[14 * cols + 14] = T(-1.0);
298 
299   transform_matrix[15 * cols + 13] = T(-1.0);
300   transform_matrix[15 * cols + 15] = T(1.0);
301 };
302 
303 // The output transform matrix is the kronecker product 'M * M' of the
304 // following matrix 'M':
305 //
306 //   [1  1  1  0]
307 //   [0  1 -1 -1]
308 //
309 // Data layout of 'transform_matrix':
310 //   [out_tile_spatial_size, tile_spatial_size]
311 //
312 template <typename T>
GetOutputTransformMatrix(const int64 rows,const int64 cols,T * transform_matrix)313 void WinogradTransform<T>::GetOutputTransformMatrix(const int64 rows,
314                                                     const int64 cols,
315                                                     T* transform_matrix) const {
316   CHECK_GT(rows, 0);
317   CHECK_GT(cols, 0);
318   memset(transform_matrix, 0, sizeof(T) * rows * cols);
319 
320   // Sub matrix [0,0]
321   transform_matrix[0 * cols + 0] = T(1.0);
322   transform_matrix[0 * cols + 1] = T(1.0);
323   transform_matrix[0 * cols + 2] = T(1.0);
324 
325   transform_matrix[1 * cols + 1] = T(1.0);
326   transform_matrix[1 * cols + 2] = T(-1.0);
327   transform_matrix[1 * cols + 3] = T(-1.0);
328 
329   // Sub matrix [0,1]
330   transform_matrix[0 * cols + 4] = T(1.0);
331   transform_matrix[0 * cols + 5] = T(1.0);
332   transform_matrix[0 * cols + 6] = T(1.0);
333 
334   transform_matrix[1 * cols + 5] = T(1.0);
335   transform_matrix[1 * cols + 6] = T(-1.0);
336   transform_matrix[1 * cols + 7] = T(-1.0);
337 
338   // Sub matrix [0,2]
339   transform_matrix[0 * cols + 8] = T(1.0);
340   transform_matrix[0 * cols + 9] = T(1.0);
341   transform_matrix[0 * cols + 10] = T(1.0);
342 
343   transform_matrix[1 * cols + 9] = T(1.0);
344   transform_matrix[1 * cols + 10] = T(-1.0);
345   transform_matrix[1 * cols + 11] = T(-1.0);
346 
347   // Sub matrix [1,1]
348   transform_matrix[2 * cols + 4] = T(1.0);
349   transform_matrix[2 * cols + 5] = T(1.0);
350   transform_matrix[2 * cols + 6] = T(1.0);
351 
352   transform_matrix[3 * cols + 5] = T(1.0);
353   transform_matrix[3 * cols + 6] = T(-1.0);
354   transform_matrix[3 * cols + 7] = T(-1.0);
355 
356   // Sub matrix [1,2]
357   transform_matrix[2 * cols + 8] = T(-1.0);
358   transform_matrix[2 * cols + 9] = T(-1.0);
359   transform_matrix[2 * cols + 10] = T(-1.0);
360 
361   transform_matrix[3 * cols + 9] = T(-1.0);
362   transform_matrix[3 * cols + 10] = T(1.0);
363   transform_matrix[3 * cols + 11] = T(1.0);
364 
365   // Sub matrix [1,3]
366   transform_matrix[2 * cols + 12] = T(-1.0);
367   transform_matrix[2 * cols + 13] = T(-1.0);
368   transform_matrix[2 * cols + 14] = T(-1.0);
369 
370   transform_matrix[3 * cols + 13] = T(-1.0);
371   transform_matrix[3 * cols + 14] = T(1.0);
372   transform_matrix[3 * cols + 15] = T(1.0);
373 };
374 
375 }  // namespace tensorflow
376 
377 #endif  // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
378