/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <utility>
17 
18 #include "tensorflow/core/util/tensor_format.h"
19 
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 
25 #define EnumStringPair(val) \
26   { val, #val }
27 
28 std::pair<TensorFormat, const char*> test_data_formats[] = {
29     EnumStringPair(FORMAT_NHWC),        EnumStringPair(FORMAT_NCHW),
30     EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W),
31     EnumStringPair(FORMAT_HWNC),        EnumStringPair(FORMAT_HWCN),
32 };
33 
34 std::pair<FilterTensorFormat, const char*> test_filter_formats[] = {
35     EnumStringPair(FORMAT_HWIO),
36     EnumStringPair(FORMAT_OIHW),
37     EnumStringPair(FORMAT_OIHW_VECT_I),
38 };
39 
// This is an alternative way of specifying the tensor dimension indexes for
// each tensor format. For now it can be used as a cross-check of the existing
// functions, but later it could replace them.
43 
// Represents the dimension indexes of an activations tensor format.
//
// Each field stores the position of one logical dimension within the tensor
// shape. The struct is a plain aggregate so it can be brace-initialized in
// constant tables; -1 is the value those tables use for "no such dimension".
struct TensorDimMap {
  // Index of the 'N' dimension.
  int n() const { return dim_n; }
  // Index of the 'H' dimension.
  int h() const { return dim_h; }
  // Index of the 'W' dimension.
  int w() const { return dim_w; }
  // Index of the 'C' dimension.
  int c() const { return dim_c; }
  // Index of the spatial dimension numbered `spatial_index`. Returns -1 when
  // `spatial_index` falls outside `spatial_dim` instead of reading out of
  // bounds.
  int spatial(int spatial_index) const {
    return (spatial_index >= 0 &&
            spatial_index <
                static_cast<int>(sizeof(spatial_dim) / sizeof(spatial_dim[0])))
               ? spatial_dim[spatial_index]
               : -1;
  }

  int dim_n, dim_h, dim_w, dim_c;
  int spatial_dim[3];
};
55 
// Represents the dimension indexes of a filter tensor format.
//
// Each field stores the position of one logical dimension within the filter
// shape. The struct is a plain aggregate so it can be brace-initialized in
// constant tables; -1 is the value those tables use for "no such dimension".
struct FilterDimMap {
  // Index of the 'H' dimension.
  int h() const { return dim_h; }
  // Index of the 'W' dimension.
  int w() const { return dim_w; }
  // Index of the 'I' dimension.
  int i() const { return dim_i; }
  // Index of the 'O' dimension.
  int o() const { return dim_o; }
  // Index of the spatial dimension numbered `spatial_index`. Returns -1 when
  // `spatial_index` falls outside `spatial_dim` instead of reading out of
  // bounds.
  int spatial(int spatial_index) const {
    return (spatial_index >= 0 &&
            spatial_index <
                static_cast<int>(sizeof(spatial_dim) / sizeof(spatial_dim[0])))
               ? spatial_dim[spatial_index]
               : -1;
  }

  int dim_h, dim_w, dim_i, dim_o;
  int spatial_dim[3];
};
67 
68 // clang-format off
69 
70 // Predefined constants specifying the actual dimension indexes for each
71 // supported tensor and filter format.
72 struct DimMaps {
73 #define StaCoExTensorDm static constexpr TensorDimMap
74   //                                'N', 'H', 'W', 'C'    0,  1,  2
75   StaCoExTensorDm kTdmInvalid =   { -1,  -1,  -1,  -1, { -1, -1, -1 } };
76   // These arrays are indexed by the number of spatial dimensions in the format.
77   StaCoExTensorDm kTdmNHWC[4] = { kTdmInvalid,
78                                   {  0,  -1,   1,   2, {  1, -1, -1 } },  // 1D
79                                   {  0,   1,   2,   3, {  1,  2, -1 } },  // 2D
80                                   {  0,   2,   3,   4, {  1,  2,  3 } }   // 3D
81                                 };
82   StaCoExTensorDm kTdmNCHW[4] = { kTdmInvalid,
83                                   {  0,  -1,   2,   1, {  2, -1, -1 } },
84                                   {  0,   2,   3,   1, {  2,  3, -1 } },
85                                   {  0,   3,   4,   1, {  2,  3,  4 } }
86                                 };
87   StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid,
88                                   {  1,  -1,   0,   2, {  0, -1, -1 } },
89                                   {  2,   0,   1,   3, {  0,  1, -1 } },
90                                   {  3,   1,   2,   4, {  0,  1,  2 } }
91                                 };
92   StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid,
93                                   {  2,  -1,   0,   1, {  0, -1, -1 } },
94                                   {  3,   0,   1,   2, {  0,  1, -1 } },
95                                   {  4,   1,   2,   3, {  0,  1,  2 } }
96                                 };
97 #undef StaCoExTensorDm
98 #define StaCoExFilterDm static constexpr FilterDimMap
99   //                                'H', 'W', 'I', 'O'    0   1   2
100   StaCoExFilterDm kFdmInvalid =   { -1,  -1,  -1,  -1, { -1, -1, -1 } };
101   StaCoExFilterDm kFdmHWIO[4] = { kFdmInvalid,
102                                   { -1,   0,   1,   2, {  0, -1, -1 } },
103                                   {  0,   1,   2,   3, {  0,  1, -1 } },
104                                   {  1,   2,   3,   4, {  0,  1,  2 } }
105                                 };
106   StaCoExFilterDm kFdmOIHW[4] = { kFdmInvalid,
107                                   { -1,   2,   1,   0, {  2, -1, -1 } },
108                                   {  2,   3,   1,   0, {  2,  3, -1 } },
109                                   {  3,   4,   1,   0, {  2,  3,  4 } }
110                                 };
111 #undef StaCoExFilterDm
112 };
113 
114 inline constexpr const TensorDimMap&
GetTensorDimMap(const int num_spatial_dims,const TensorFormat format)115 GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) {
116   return
117       (format == FORMAT_NHWC ||
118        format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] :
119       (format == FORMAT_NCHW ||
120        format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] :
121       (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] :
122       (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims]
123                               : DimMaps::kTdmInvalid;
124 }
125 
126 inline constexpr const FilterDimMap&
GetFilterDimMap(const int num_spatial_dims,const FilterTensorFormat format)127 GetFilterDimMap(const int num_spatial_dims,
128                 const FilterTensorFormat format) {
129   return
130       (format == FORMAT_HWIO) ? DimMaps::kFdmHWIO[num_spatial_dims] :
131       (format == FORMAT_OIHW ||
132        format == FORMAT_OIHW_VECT_I) ? DimMaps::kFdmOIHW[num_spatial_dims]
133                                      : DimMaps::kFdmInvalid;
134 }
135 // clang-format on
136 
137 constexpr TensorDimMap DimMaps::kTdmInvalid;
138 constexpr TensorDimMap DimMaps::kTdmNHWC[4];
139 constexpr TensorDimMap DimMaps::kTdmNCHW[4];
140 constexpr TensorDimMap DimMaps::kTdmHWNC[4];
141 constexpr TensorDimMap DimMaps::kTdmHWCN[4];
142 constexpr FilterDimMap DimMaps::kFdmInvalid;
143 constexpr FilterDimMap DimMaps::kFdmHWIO[4];
144 constexpr FilterDimMap DimMaps::kFdmOIHW[4];
145 
// Checks that each listed format round-trips between its enumerator and its
// string name: the enumerator's spelling minus the "FORMAT_" prefix must
// parse back to the same enumerator, and ToString() of the parsed value must
// reproduce that string.
TEST(TensorFormatTest, FormatEnumsAndStrings) {
  const string prefix = "FORMAT_";
  for (const auto& test_data_format : test_data_formats) {
    const char* stringified_format_enum = test_data_format.second;
    LOG(INFO) << stringified_format_enum << " = " << test_data_format.first;
    // Skip past the "FORMAT_" prefix of the enumerator's spelling.
    const string expected_format_str =
        stringified_format_enum + prefix.size();
    TensorFormat format;
    // Fatal on failure: `format` is declared without an initializer, so the
    // checks below must not read it unless parsing reported success.
    ASSERT_TRUE(FormatFromString(expected_format_str, &format));
    const string format_str = ToString(format);
    EXPECT_EQ(expected_format_str, format_str);
    EXPECT_EQ(test_data_format.first, format);
  }
  for (const auto& test_filter_format : test_filter_formats) {
    const char* stringified_format_enum = test_filter_format.second;
    LOG(INFO) << stringified_format_enum << " = " << test_filter_format.first;
    // Skip past the "FORMAT_" prefix of the enumerator's spelling.
    const string expected_format_str =
        stringified_format_enum + prefix.size();
    FilterTensorFormat format;
    // Fatal on failure, for the same reason as in the loop above.
    ASSERT_TRUE(FilterFormatFromString(expected_format_str, &format));
    const string format_str = ToString(format);
    EXPECT_EQ(expected_format_str, format_str);
    EXPECT_EQ(test_filter_format.first, format);
  }
}
169 
170 template <int num_spatial_dims>
RunDimensionIndexesTest()171 void RunDimensionIndexesTest() {
172   for (auto& test_data_format : test_data_formats) {
173     TensorFormat format = test_data_format.first;
174     auto& tdm = GetTensorDimMap(num_spatial_dims, format);
175     int num_dims = GetTensorDimsFromSpatialDims(num_spatial_dims, format);
176     LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims
177               << ", num_dims=" << num_dims;
178     EXPECT_EQ(GetTensorBatchDimIndex(num_dims, format), tdm.n());
179     EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'N'), tdm.n());
180     EXPECT_EQ(GetTensorFeatureDimIndex(num_dims, format), tdm.c());
181     EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, 'C'), tdm.c());
182     for (int i = 0; i < num_spatial_dims; ++i) {
183       EXPECT_EQ(GetTensorSpatialDimIndex(num_dims, format, i), tdm.spatial(i));
184       EXPECT_EQ(GetTensorDimIndex<num_spatial_dims>(format, '0' + i),
185                 tdm.spatial(i));
186     }
187   }
188   for (auto& test_filter_format : test_filter_formats) {
189     FilterTensorFormat format = test_filter_format.first;
190     auto& fdm = GetFilterDimMap(num_spatial_dims, format);
191     int num_dims = GetFilterTensorDimsFromSpatialDims(num_spatial_dims, format);
192     LOG(INFO) << ToString(format) << ", num_spatial_dims=" << num_spatial_dims
193               << ", num_dims=" << num_dims;
194     EXPECT_EQ(GetFilterTensorOutputChannelsDimIndex(num_dims, format), fdm.o());
195     EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'O'), fdm.o());
196     EXPECT_EQ(GetFilterTensorInputChannelsDimIndex(num_dims, format), fdm.i());
197     EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, 'I'), fdm.i());
198     for (int i = 0; i < num_spatial_dims; ++i) {
199       EXPECT_EQ(GetFilterTensorSpatialDimIndex(num_dims, format, i),
200                 fdm.spatial(i));
201       EXPECT_EQ(GetFilterDimIndex<num_spatial_dims>(format, '0' + i),
202                 fdm.spatial(i));
203     }
204   }
205 }
206 
// Runs the dimension-index cross-check for tensors and filters with one, two
// and three spatial dimensions.
TEST(TensorFormatTest, DimensionIndexes) {
  RunDimensionIndexesTest<1>();
  RunDimensionIndexesTest<2>();
  RunDimensionIndexesTest<3>();
}
212 
213 }  // namespace tensorflow
214