1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "tensorflow/core/util/tensor_format.h"
19
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24
// Builds a brace-initializer that pairs an enumerator with its own spelling,
// e.g. EnumStringPair(FORMAT_NHWC) expands to {FORMAT_NHWC, "FORMAT_NHWC"}.
#define EnumStringPair(val) {val, #val}
27
28 std::pair<TensorFormat, const char*> test_data_formats[] = {
29 EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW),
30 EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W),
31 EnumStringPair(FORMAT_HWNC), EnumStringPair(FORMAT_HWCN),
32 };
33
34 std::pair<FilterTensorFormat, const char*> test_filter_formats[] = {
35 EnumStringPair(FORMAT_HWIO),
36 EnumStringPair(FORMAT_OIHW),
37 EnumStringPair(FORMAT_OIHW_VECT_I),
38 };
39
40 // This is an alternative way of specifying the tensor dimension indexes for
41 // each tensor format. For now it can be used as a cross-check of the existing
42 // functions, but later could replace them.
43
// Represents the dimension indexes of an activations tensor format.
// An index of -1 marks a dimension that is absent for the layout/rank.
// The accessors are constexpr so the constexpr tables built from this type
// can also be queried at compile time (e.g. inside static_assert).
struct TensorDimMap {
  constexpr int n() const { return dim_n; }
  constexpr int h() const { return dim_h; }
  constexpr int w() const { return dim_w; }
  constexpr int c() const { return dim_c; }
  // Index of the `spatial_index`-th spatial dimension; `spatial_index` must be
  // in [0, 3) since spatial_dim holds three entries.
  constexpr int spatial(int spatial_index) const {
    return spatial_dim[spatial_index];
  }

  int dim_n, dim_h, dim_w, dim_c;
  int spatial_dim[3];
};
55
// Represents the dimension indexes of a filter tensor format.
// An index of -1 marks a dimension that is absent for the layout/rank.
// The accessors are constexpr so the constexpr tables built from this type
// can also be queried at compile time (e.g. inside static_assert).
struct FilterDimMap {
  constexpr int h() const { return dim_h; }
  constexpr int w() const { return dim_w; }
  constexpr int i() const { return dim_i; }
  constexpr int o() const { return dim_o; }
  // Index of the `spatial_index`-th spatial dimension; `spatial_index` must be
  // in [0, 3) since spatial_dim holds three entries.
  constexpr int spatial(int spatial_index) const {
    return spatial_dim[spatial_index];
  }

  int dim_h, dim_w, dim_i, dim_o;
  int spatial_dim[3];
};
67
68 // clang-format off
69
70 // Predefined constants specifying the actual dimension indexes for each
71 // supported tensor and filter format.
72 struct DimMaps {
73 #define StaCoExTensorDm static constexpr TensorDimMap
74 // 'N', 'H', 'W', 'C' 0, 1, 2
75 StaCoExTensorDm kTdmInvalid = { -1, -1, -1, -1, { -1, -1, -1 } };
76 // These arrays are indexed by the number of spatial dimensions in the format.
77 StaCoExTensorDm kTdmNHWC[4] = { kTdmInvalid,
78 { 0, -1, 1, 2, { 1, -1, -1 } }, // 1D
79 { 0, 1, 2, 3, { 1, 2, -1 } }, // 2D
80 { 0, 2, 3, 4, { 1, 2, 3 } } // 3D
81 };
82 StaCoExTensorDm kTdmNCHW[4] = { kTdmInvalid,
83 { 0, -1, 2, 1, { 2, -1, -1 } },
84 { 0, 2, 3, 1, { 2, 3, -1 } },
85 { 0, 3, 4, 1, { 2, 3, 4 } }
86 };
87 StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid,
88 { 1, -1, 0, 2, { 0, -1, -1 } },
89 { 2, 0, 1, 3, { 0, 1, -1 } },
90 { 3, 1, 2, 4, { 0, 1, 2 } }
91 };
92 StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid,
93 { 2, -1, 0, 1, { 0, -1, -1 } },
94 { 3, 0, 1, 2, { 0, 1, -1 } },
95 { 4, 1, 2, 3, { 0, 1, 2 } }
96 };
97 #undef StaCoExTensorDm
98 #define StaCoExFilterDm static constexpr FilterDimMap
99 // 'H', 'W', 'I', 'O' 0 1 2
100 StaCoExFilterDm kFdmInvalid = { -1, -1, -1, -1, { -1, -1, -1 } };
101 StaCoExFilterDm kFdmHWIO[4] = { kFdmInvalid,
102 { -1, 0, 1, 2, { 0, -1, -1 } },
103 { 0, 1, 2, 3, { 0, 1, -1 } },
104 { 1, 2, 3, 4, { 0, 1, 2 } }
105 };
106 StaCoExFilterDm kFdmOIHW[4] = { kFdmInvalid,
107 { -1, 2, 1, 0, { 2, -1, -1 } },
108 { 2, 3, 1, 0, { 2, 3, -1 } },
109 { 3, 4, 1, 0, { 2, 3, 4 } }
110 };
111 #undef StaCoExFilterDm
112 };
113
114 inline constexpr const TensorDimMap&
GetTensorDimMap(const int num_spatial_dims,const TensorFormat format)115 GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) {
116 return
117 (format == FORMAT_NHWC ||
118 format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] :
119 (format == FORMAT_NCHW ||
120 format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] :
121 (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] :
122 (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims]
123 : DimMaps::kTdmInvalid;
124 }
125
126 inline constexpr const FilterDimMap&
GetFilterDimMap(const int num_spatial_dims,const FilterTensorFormat format)127 GetFilterDimMap(const int num_spatial_dims,
128 const FilterTensorFormat format) {
129 return
130 (format == FORMAT_HWIO) ? DimMaps::kFdmHWIO[num_spatial_dims] :
131 (format == FORMAT_OIHW ||
132 format == FORMAT_OIHW_VECT_I) ? DimMaps::kFdmOIHW[num_spatial_dims]
133 : DimMaps::kFdmInvalid;
134 }
135 // clang-format on
136
// Namespace-scope definitions of the DimMaps static constexpr members.
// GetTensorDimMap/GetFilterDimMap return references to these members, which
// ODR-uses them; before C++17 that requires an out-of-class definition. From
// C++17 on such members are implicitly inline, so these lines are redundant
// there but still valid.
constexpr TensorDimMap DimMaps::kTdmInvalid;
constexpr TensorDimMap DimMaps::kTdmNHWC[4];
constexpr TensorDimMap DimMaps::kTdmNCHW[4];
constexpr TensorDimMap DimMaps::kTdmHWNC[4];
constexpr TensorDimMap DimMaps::kTdmHWCN[4];
constexpr FilterDimMap DimMaps::kFdmInvalid;
constexpr FilterDimMap DimMaps::kFdmHWIO[4];
constexpr FilterDimMap DimMaps::kFdmOIHW[4];
145
// Round-trips every activation and filter format through its string form.
// The enumerator name minus the "FORMAT_" prefix must parse back to the same
// enum value, and ToString() must reproduce that suffix.
TEST(TensorFormatTest, FormatEnumsAndStrings) {
  const string prefix = "FORMAT_";
  for (const auto& test_data_format : test_data_formats) {
    const string stringified_format_enum = test_data_format.second;
    LOG(INFO) << stringified_format_enum << " = " << test_data_format.first;
    // Verify the prefix before stripping it, so a malformed table entry fails
    // the test instead of reading past the end of the string.
    ASSERT_EQ(0, stringified_format_enum.compare(0, prefix.size(), prefix))
        << stringified_format_enum;
    const string expected_format_str =
        stringified_format_enum.substr(prefix.size());
    TensorFormat format;
    // Only trust `format` when parsing succeeded; otherwise it may be
    // uninitialized, so skip the checks that would read it.
    if (!FormatFromString(expected_format_str, &format)) {
      ADD_FAILURE() << "FormatFromString failed for " << expected_format_str;
      continue;
    }
    EXPECT_EQ(expected_format_str, ToString(format));
    EXPECT_EQ(test_data_format.first, format);
  }
  for (const auto& test_filter_format : test_filter_formats) {
    const string stringified_format_enum = test_filter_format.second;
    LOG(INFO) << stringified_format_enum << " = " << test_filter_format.first;
    ASSERT_EQ(0, stringified_format_enum.compare(0, prefix.size(), prefix))
        << stringified_format_enum;
    const string expected_format_str =
        stringified_format_enum.substr(prefix.size());
    FilterTensorFormat format;
    // Same as above: never read `format` after a failed parse.
    if (!FilterFormatFromString(expected_format_str, &format)) {
      ADD_FAILURE() << "FilterFormatFromString failed for "
                    << expected_format_str;
      continue;
    }
    EXPECT_EQ(expected_format_str, ToString(format));
    EXPECT_EQ(test_filter_format.first, format);
  }
}
169
170 template <int num_spatial_dims>
RunDimensionIndexesTest()171 void RunDimensionIndexesTest() {
172 for (auto& test_data_format : test_data_formats) {
173 TensorFormat format = test_data_format.first;
174 auto& tdm = GetTensorDimMap(num_spatial_dims, format);
175 int num_dims = GetTensorDimsFromSpatialDims(num_spatial_dims, format);
176 LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims
177 << ", num_dims=" << num_dims;
178 EXPECT_EQ(GetTensorBatchDimIndex(num_dims, format), tdm.n());
179 EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'N'), tdm.n());
180 EXPECT_EQ(GetTensorFeatureDimIndex(num_dims, format), tdm.c());
181 EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'C'), tdm.c());
182 for (int i = 0; i < num_spatial_dims; ++i) {
183 EXPECT_EQ(GetTensorSpatialDimIndex(num_dims, format, i), tdm.spatial(i));
184 EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, '0' + i),
185 tdm.spatial(i));
186 }
187 }
188 for (auto& test_filter_format : test_filter_formats) {
189 FilterTensorFormat format = test_filter_format.first;
190 auto& fdm = GetFilterDimMap(num_spatial_dims, format);
191 int num_dims = GetFilterTensorDimsFromSpatialDims(num_spatial_dims, format);
192 LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims
193 << ", num_dims=" << num_dims;
194 EXPECT_EQ(GetFilterTensorOutputChannelsDimIndex(num_dims, format), fdm.o());
195 EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'O'), fdm.o());
196 EXPECT_EQ(GetFilterTensorInputChannelsDimIndex(num_dims, format), fdm.i());
197 EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'I'), fdm.i());
198 for (int i = 0; i < num_spatial_dims; ++i) {
199 EXPECT_EQ(GetFilterTensorSpatialDimIndex(num_dims, format, i),
200 fdm.spatial(i));
201 EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, '0' + i),
202 fdm.spatial(i));
203 }
204 }
205 }
206
// Runs the dimension-index cross-check for 1, 2 and 3 spatial dimensions: the
// ranks for which the DimMaps tables hold non-invalid entries.
TEST(TensorFormatTest, DimensionIndexes) {
  RunDimensionIndexesTest<1>();
  RunDimensionIndexesTest<2>();
  RunDimensionIndexesTest<3>();
}
212
213 } // namespace tensorflow
214