1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/winograd_transform.h"
17 #include "tensorflow/core/platform/test.h"
18
19 namespace tensorflow {
20 namespace {
21
// Writes the Kronecker product of `matrix` with itself into `matrix_out`.
//
// `matrix` is read as a row-major `rows` x `cols` matrix A. `matrix_out` must
// have room for (rows * rows) * (cols * cols) floats. It receives the
// row-major (rows * rows) x (cols * cols) matrix kron(A, A), whose entry at
// row (i * rows + k), column (j * cols + l) equals A[i][j] * A[k][l].
static void ComputeKroneckerProduct(const int rows, const int cols,
                                    const float* matrix, float* matrix_out) {
  // Row stride of the output matrix.
  const int out_cols = cols * cols;
  for (int outer_row = 0; outer_row < rows; ++outer_row) {
    for (int outer_col = 0; outer_col < cols; ++outer_col) {
      // Every entry of the (outer_row, outer_col) output tile is a copy of A
      // scaled by this coefficient.
      const float scale = matrix[outer_row * cols + outer_col];
      for (int inner_row = 0; inner_row < rows; ++inner_row) {
        const int out_row = outer_row * rows + inner_row;
        float* dst = matrix_out + out_row * out_cols + outer_col * cols;
        const float* src = matrix + inner_row * cols;
        for (int inner_col = 0; inner_col < cols; ++inner_col) {
          dst[inner_col] = src[inner_col] * scale;
        }
      }
    }
  }
}
39
TEST(DeepConv2DTransformTest, Basic) {
  // Checks ComputeKroneckerProduct() on a small hand-computed case. The
  // Kronecker product of
  //
  //   [1.0 2.0]
  //   [3.0 4.0]
  //
  // with itself is the 4x4 matrix listed row by row in `expected` below.
  constexpr int kRows = 2;
  constexpr int kCols = 2;
  constexpr int kNumElements = (kRows * kRows) * (kCols * kCols);

  const float input[kRows * kCols] = {1, 2, 3, 4};
  const float expected[kNumElements] = {1, 2,  2,  4,   //
                                        3, 4,  6,  8,   //
                                        3, 6,  4,  8,   //
                                        9, 12, 12, 16};

  float actual[kNumElements];
  ComputeKroneckerProduct(kRows, kCols, input, actual);

  for (int idx = 0; idx < kNumElements; ++idx) {
    EXPECT_FLOAT_EQ(actual[idx], expected[idx]);
  }
}
65
TEST(DeepConv2DTransformTest, WingradFilterTransformMatrix) {
  // Checks that WinogradTransform<float>::GetFilterTransformMatrix() returns
  // the Kronecker product of the following 4x3 matrix with itself:
  //
  //   [ 1     0    0  ]
  //   [ 1/2  1/2  1/2 ]
  //   [ 1/2 -1/2  1/2 ]
  //   [ 0     0    1  ]
  //
  // The reference product is built with ComputeKroneckerProduct().
  constexpr int kRows = 4;
  constexpr int kCols = 3;
  constexpr int kKronRows = kRows * kRows;
  constexpr int kKronCols = kCols * kCols;
  constexpr int kNumElements = kKronRows * kKronCols;

  const float filter_transform[kRows * kCols] = {1,   0,    0,    //
                                                 0.5, 0.5,  0.5,  //
                                                 0.5, -0.5, 0.5,  //
                                                 0,   0,    1};
  float expected[kNumElements];
  ComputeKroneckerProduct(kRows, kCols, filter_transform, expected);

  float actual[kNumElements];
  WinogradTransform<float> transform;
  transform.GetFilterTransformMatrix(kKronRows, kKronCols, actual);

  for (int idx = 0; idx < kNumElements; ++idx) {
    EXPECT_FLOAT_EQ(expected[idx], actual[idx]);
  }
}
96
TEST(DeepConv2DTransformTest, WingradInputTransformMatrix) {
  // Test that the input transform matrix returned by
  // WinogradTransform<float>::GetInputTransformMatrix() is the Kronecker
  // product of the following 4x4 matrix with itself:
  //
  //   [1  0 -1  0]
  //   [0  1  1  0]
  //   [0 -1  1  0]
  //   [0  1  0 -1]
  //
  // The reference product is computed with ComputeKroneckerProduct().
  const int rows = 4;
  const int cols = 4;

  // Read-only source matrix, row-major.
  const float transform_matrix[] = {1, 0, -1, 0, 0, 1, 1, 0,
                                    0, -1, 1, 0, 0, 1, 0, -1};

  const int kron_rows = rows * rows;
  const int kron_cols = cols * cols;

  float transform_matrix_kron[kron_rows * kron_cols];

  ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
                          &transform_matrix_kron[0]);

  float transform_matrix_test[kron_rows * kron_cols];
  WinogradTransform<float> t;
  t.GetInputTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);

  for (int i = 0; i < kron_rows * kron_cols; ++i) {
    EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
  }
}
128
TEST(DeepConv2DTransformTest, WingradOutputTransformMatrix) {
  // Test that the output transform matrix returned by
  // WinogradTransform<float>::GetOutputTransformMatrix() is the Kronecker
  // product of the following 2x4 matrix with itself:
  //
  //   [1  1  1  0]
  //   [0  1 -1 -1]
  //
  // The reference product is computed with ComputeKroneckerProduct().
  const int rows = 2;
  const int cols = 4;

  // Read-only source matrix, row-major.
  const float transform_matrix[] = {1, 1, 1, 0, 0, 1, -1, -1};

  const int kron_rows = rows * rows;
  const int kron_cols = cols * cols;

  float transform_matrix_kron[kron_rows * kron_cols];

  ComputeKroneckerProduct(rows, cols, &transform_matrix[0],
                          &transform_matrix_kron[0]);

  float transform_matrix_test[kron_rows * kron_cols];
  WinogradTransform<float> t;
  t.GetOutputTransformMatrix(kron_rows, kron_cols, &transform_matrix_test[0]);

  for (int i = 0; i < kron_rows * kron_cols; ++i) {
    EXPECT_FLOAT_EQ(transform_matrix_kron[i], transform_matrix_test[i]);
  }
}
157
158 } // namespace
159 } // namespace tensorflow
160