• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/ops_util.h"
17 
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/framework/kernel_shape_util.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 namespace {
25 
26 class OpsUtilTest : public ::testing::Test {
27  protected:
OpsUtilTest()28   OpsUtilTest() {}
~OpsUtilTest()29   ~OpsUtilTest() override {}
30 
31   // Padding structure.
32   struct padding_struct {
33     // Input parameters.
34     struct {
35       int in_height;
36       int in_width;
37       int filter_height;
38       int filter_width;
39       int row_stride;
40       int col_stride;
41       Padding padding;
42     } input;
43     // Output.
44     struct {
45       int new_height;
46       int new_width;
47       int pad_top;
48       int pad_bottom;
49       int pad_left;
50       int pad_right;
51     } output;
52   };
53 
54   // Broadcast structure.
55   struct bcast_struct {
56     // Input parameters.
57     struct {
58       int index;     // Current index.
59       int in_size;   // Size of the dimension.
60       int ksize;     // Kernel size.
61       int stride;    // Stride.
62       int pad_size;  // Padding size.
63     } input;
64     // Output.
65     struct {
66       int new_index;  // New starting index.
67       int new_size;   // New broadcast size.
68     } output;
69   };
70 
VerifyGet2dOutputSizeBoundaries(padding_struct pad_struct,error::Code code)71   static void VerifyGet2dOutputSizeBoundaries(padding_struct pad_struct,
72                                               error::Code code) {
73     int64_t new_height, new_width, pad_rows, pad_cols;
74     Status status = GetWindowedOutputSize(
75         pad_struct.input.in_height, pad_struct.input.filter_height,
76         pad_struct.input.row_stride, pad_struct.input.padding, &new_height,
77         &pad_rows);
78     EXPECT_EQ(status.code(), code) << status;
79     status = GetWindowedOutputSize(
80         pad_struct.input.in_width, pad_struct.input.filter_width,
81         pad_struct.input.col_stride, pad_struct.input.padding, &new_width,
82         &pad_cols);
83     EXPECT_EQ(status.code(), code) << status;
84   }
85 
VerifyGet2dOutputSizeValues(padding_struct pad_struct,error::Code code)86   static void VerifyGet2dOutputSizeValues(padding_struct pad_struct,
87                                           error::Code code) {
88     int64_t new_height, new_width, pad_rows, pad_cols;
89     Status status = GetWindowedOutputSize(
90         pad_struct.input.in_height, pad_struct.input.filter_height,
91         pad_struct.input.row_stride, pad_struct.input.padding, &new_height,
92         &pad_rows);
93     EXPECT_EQ(status.code(), code) << status;
94     status = GetWindowedOutputSize(
95         pad_struct.input.in_width, pad_struct.input.filter_width,
96         pad_struct.input.col_stride, pad_struct.input.padding, &new_width,
97         &pad_cols);
98     EXPECT_EQ(status.code(), code) << status;
99     EXPECT_EQ(pad_struct.output.new_height, new_height);
100     EXPECT_EQ(pad_struct.output.new_width, new_width);
101     EXPECT_EQ(pad_struct.output.pad_top, pad_rows);
102     EXPECT_EQ(pad_struct.output.pad_left, pad_cols);
103   }
104 
VerifyGet2dOutputVerboseSizeValues(padding_struct pad_struct,error::Code code)105   static void VerifyGet2dOutputVerboseSizeValues(padding_struct pad_struct,
106                                                  error::Code code) {
107     int64_t new_height, new_width, pad_top, pad_bottom, pad_left, pad_right;
108     Status status = GetWindowedOutputSizeVerbose(
109         pad_struct.input.in_height, pad_struct.input.filter_height,
110         pad_struct.input.row_stride, pad_struct.input.padding, &new_height,
111         &pad_top, &pad_bottom);
112     EXPECT_EQ(status.code(), code) << status;
113     status = GetWindowedOutputSizeVerbose(
114         pad_struct.input.in_width, pad_struct.input.filter_width,
115         pad_struct.input.col_stride, pad_struct.input.padding, &new_width,
116         &pad_left, &pad_right);
117     EXPECT_EQ(status.code(), code) << status;
118     EXPECT_EQ(pad_struct.output.new_height, new_height);
119     EXPECT_EQ(pad_struct.output.new_width, new_width);
120     EXPECT_EQ(pad_struct.output.pad_top, pad_top);
121     EXPECT_EQ(pad_struct.output.pad_bottom, pad_bottom);
122     EXPECT_EQ(pad_struct.output.pad_left, pad_left);
123     EXPECT_EQ(pad_struct.output.pad_right, pad_right);
124   }
125 
VerifyBoundaries(bcast_struct bcast,error::Code code)126   static void VerifyBoundaries(bcast_struct bcast, error::Code code) {
127     int new_index, new_size;
128     Status status = GetBroadcastSize(
129         bcast.input.index, bcast.input.in_size, bcast.input.ksize,
130         bcast.input.stride, bcast.input.pad_size, &new_index, &new_size);
131     EXPECT_EQ(status.code(), code) << status;
132   }
133 
VerifyBcastValues(bcast_struct bcast)134   static void VerifyBcastValues(bcast_struct bcast) {
135     int new_index, new_size;
136     EXPECT_EQ(Status::OK(),
137               GetBroadcastSize(bcast.input.index, bcast.input.in_size,
138                                bcast.input.ksize, bcast.input.stride,
139                                bcast.input.pad_size, &new_index, &new_size));
140     EXPECT_EQ(bcast.output.new_index, new_index);
141     EXPECT_EQ(bcast.output.new_size, new_size);
142   }
143 };
144 
// A 3x3 filter over a 1x1 input with VALID padding and stride 1 must be
// rejected with INVALID_ARGUMENT in both dimensions (presumably because the
// resulting output size would be negative, as the test name suggests).
TEST_F(OpsUtilTest, Get2dOutputSizeNegativeSizeTest) {
  const padding_struct too_small_input = {{1, 1, 3, 3, 1, 1, VALID},
                                          {-1, -1, 0, 0, 0, 0}};
  VerifyGet2dOutputSizeBoundaries(too_small_input, error::INVALID_ARGUMENT);
}
149 
// 3x3 input, 2x2 filter, stride 2 in both dimensions: expects a 2x2 output
// for SAME padding and a 1x1 output for VALID padding.
TEST_F(OpsUtilTest, Get2dOutputSizeSquareFilterTest) {
  const padding_struct cases[] = {
      {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 0, 0, 0}},
      {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}},
  };
  for (const padding_struct& test_case : cases) {
    VerifyGet2dOutputSizeValues(test_case, error::OK);
  }
}
156 
// 4x5 input, 1x2 filter, unit strides: expects a 4x5 output for SAME padding
// and a 4x4 output for VALID padding.
TEST_F(OpsUtilTest, Get2dOutputSizeNonSquareFilterTest) {
  const padding_struct cases[] = {
      {{4, 5, 1, 2, 1, 1, SAME}, {4, 5, 0, 0, 0, 0}},
      {{4, 5, 1, 2, 1, 1, VALID}, {4, 4, 0, 0, 0, 0}},
  };
  for (const padding_struct& test_case : cases) {
    VerifyGet2dOutputSizeValues(test_case, error::OK);
  }
}
163 
// 4x4 input, 2x2 filter, VALID padding with different row and column strides:
// strides (1, 2) expect a 3x2 output, strides (2, 1) expect a 2x3 output.
TEST_F(OpsUtilTest, Get2dOutputSizeUnevenStrideTest) {
  const padding_struct cases[] = {
      {{4, 4, 2, 2, 1, 2, VALID}, {3, 2, 0, 0, 0, 0}},
      {{4, 4, 2, 2, 2, 1, VALID}, {2, 3, 0, 0, 0, 0}},
  };
  for (const padding_struct& test_case : cases) {
    VerifyGet2dOutputSizeValues(test_case, error::OK);
  }
}
170 
// Verbose variant on a 3x3 input, 2x2 filter, stride 2: SAME padding expects a
// 2x2 output with one unit of padding at the bottom and at the right; VALID
// padding expects a 1x1 output with no padding.
TEST_F(OpsUtilTest, Get2dOutputSizeVerbose) {
  const padding_struct cases[] = {
      {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 1, 0, 1}},
      {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}},
  };
  for (const padding_struct& test_case : cases) {
    VerifyGet2dOutputVerboseSizeValues(test_case, error::OK);
  }
}
177 
178 // Test index * stride > in_size fails with INVALID_ARGUMENT.
TEST_F(OpsUtilTest,GetBroadcastTestBadIndex)179 TEST_F(OpsUtilTest, GetBroadcastTestBadIndex) {
180   bcast_struct bcast = {{2, 3, 1, 2, 0}, {0, 3}};
181   VerifyBoundaries(bcast, error::INVALID_ARGUMENT);
182 }
183 
184 // in_size = 3, ksize = 3, stride = 1, pad_size = 0
TEST_F(OpsUtilTest,GetBroadcastTest3_3_1_0)185 TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_0) {
186   bcast_struct bcast[] = {
187       {{0, 3, 3, 1, 0}, {0, 3}},
188       {{1, 3, 3, 1, 0}, {1, 2}},
189       {{2, 3, 3, 1, 0}, {2, 1}},
190   };
191   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
192     VerifyBcastValues(bcast[i]);
193   }
194 }
195 
196 // in_size = 3, ksize = 3, stride = 1, pad_size = 1
TEST_F(OpsUtilTest,GetBroadcastTest3_3_1_1)197 TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_1) {
198   bcast_struct bcast[] = {
199       {{0, 3, 3, 1, 1}, {0, 2}},
200       {{1, 3, 3, 1, 1}, {0, 3}},
201       {{2, 3, 3, 1, 1}, {1, 2}},
202   };
203   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
204     VerifyBcastValues(bcast[i]);
205   }
206 }
207 
208 // in_size = 3, ksize = 3, stride = 1, pad_size = 2
TEST_F(OpsUtilTest,GetBroadcastTest3_3_1_2)209 TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_2) {
210   bcast_struct bcast[] = {
211       {{0, 3, 3, 1, 2}, {0, 1}},
212       {{1, 3, 3, 1, 2}, {0, 2}},
213       {{2, 3, 3, 1, 2}, {0, 3}},
214   };
215   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
216     VerifyBcastValues(bcast[i]);
217   }
218 }
219 
220 // in_size = 3, ksize = 3, stride = 2, pad_size = 0
TEST_F(OpsUtilTest,GetBroadcastTest3_3_2_0)221 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_0) {
222   bcast_struct bcast[] = {
223       {{0, 3, 3, 2, 0}, {0, 3}},
224       {{1, 3, 3, 2, 0}, {2, 1}},
225   };
226   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
227     VerifyBcastValues(bcast[i]);
228   }
229 }
230 
231 // in_size = 3, ksize = 3, stride = 2, pad_size = 1
TEST_F(OpsUtilTest,GetBroadcastTest3_3_2_1)232 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_1) {
233   bcast_struct bcast[] = {
234       {{0, 3, 3, 2, 1}, {0, 2}},
235       {{1, 3, 3, 2, 1}, {1, 2}},
236   };
237   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
238     VerifyBcastValues(bcast[i]);
239   }
240 }
241 
242 // in_size = 3, ksize = 3, stride = 2, pad_size = 2
TEST_F(OpsUtilTest,GetBroadcastTest3_3_2_2)243 TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_2) {
244   bcast_struct bcast[] = {
245       {{0, 3, 3, 2, 2}, {0, 1}},
246   };
247   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
248     VerifyBcastValues(bcast[i]);
249   }
250 }
251 
252 // in_size = 3, ksize = 3, stride = 3, pad_size = 0
TEST_F(OpsUtilTest,GetBroadcastTest3_3_3_0)253 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_0) {
254   bcast_struct bcast[] = {
255       {{0, 3, 3, 3, 0}, {0, 3}},
256   };
257   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
258     VerifyBcastValues(bcast[i]);
259   }
260 }
261 
262 // in_size = 3, ksize = 3, stride = 3, pad_size = 1
TEST_F(OpsUtilTest,GetBroadcastTest3_3_3_1)263 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_1) {
264   bcast_struct bcast[] = {
265       {{0, 3, 3, 3, 1}, {0, 2}},
266       {{1, 3, 3, 3, 1}, {2, 1}},
267   };
268   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
269     VerifyBcastValues(bcast[i]);
270   }
271 }
272 
273 // in_size = 3, ksize = 3, stride = 3, pad_size = 2
TEST_F(OpsUtilTest,GetBroadcastTest3_3_3_2)274 TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) {
275   bcast_struct bcast[] = {
276       {{0, 3, 3, 3, 2}, {0, 1}},
277   };
278   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
279     VerifyBcastValues(bcast[i]);
280   }
281 }
282 
283 // in_size = 3, ksize = 1, stride = 2, pad_size = 0
TEST_F(OpsUtilTest,GetBroadcastTest3_1_2_0)284 TEST_F(OpsUtilTest, GetBroadcastTest3_1_2_0) {
285   bcast_struct bcast[] = {
286       {{0, 3, 1, 2, 0}, {0, 1}},
287       {{1, 3, 1, 2, 0}, {2, 1}},
288   };
289   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
290     VerifyBcastValues(bcast[i]);
291   }
292 }
293 
294 // in_size = 3, ksize = 2, stride = 3, pad_size = 0
TEST_F(OpsUtilTest,GetBroadcastTest3_2_3_0)295 TEST_F(OpsUtilTest, GetBroadcastTest3_2_3_0) {
296   bcast_struct bcast[] = {
297       {{0, 3, 2, 3, 0}, {0, 2}},
298   };
299   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
300     VerifyBcastValues(bcast[i]);
301   }
302 }
303 
304 // in_size = 3, ksize = 2, stride = 3, pad_size = 1
TEST_F(OpsUtilTest,GetBroadcastTest3_2_3_1)305 TEST_F(OpsUtilTest, GetBroadcastTest3_2_3_1) {
306   bcast_struct bcast[] = {
307       {{0, 3, 2, 3, 1}, {0, 1}},
308       {{1, 3, 2, 3, 1}, {2, 1}},
309   };
310   for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
311     VerifyBcastValues(bcast[i]);
312   }
313 }
314 
// '/' and ' ' are expected to be replaced by '_', while letters, digits, '_'
// and '-' are expected to pass through unchanged.
TEST_F(OpsUtilTest, SanitizeThreadSuffix) {
  const auto sanitized = SanitizeThreadSuffix("/aBc123_-  /");
  EXPECT_EQ("_aBc123_-___", sanitized);
}
318 
// A dim-0 slice of a 1D float tensor that starts at element 0 must be reported
// as aligned by IsDim0SliceAligned.
TEST_F(OpsUtilTest, Aligned1DSlice) {
#if EIGEN_MAX_ALIGN_BYTES == 0
  // When EIGEN_MAX_ALIGN_BYTES is 0, a 1D tensor is always aligned.
  Tensor t(DT_FLOAT, TensorShape({3}));
  int64_t start = 0;
  int64_t end = 1;
  EXPECT_TRUE(IsDim0SliceAligned<float>(t.shape(), start, end));
#else
  Tensor t(DT_FLOAT, TensorShape({EIGEN_MAX_ALIGN_BYTES * 2}));
  int64_t start = 0;
  int64_t end = EIGEN_MAX_ALIGN_BYTES;
  EXPECT_TRUE(IsDim0SliceAligned<float>(t.shape(), start, end));
  // Checks sliced 1D tensor is aligned for sanity.
  Tensor sliced;
  CHECK(sliced.CopyFrom(t.Slice(start, end), TensorShape({end - start})));
  EXPECT_TRUE(sliced.IsAligned());
#endif
}
339 
#if EIGEN_MAX_ALIGN_BYTES > 0
// A dim-0 slice of a 1D float tensor that starts at element 1 must be reported
// as misaligned, and the actual slice must agree.
TEST_F(OpsUtilTest, Misaligned1DSlice) {
  const int64_t slice_begin = 1;
  const int64_t slice_end = EIGEN_MAX_ALIGN_BYTES + 1;
  Tensor t(DT_FLOAT, TensorShape({EIGEN_MAX_ALIGN_BYTES * 2}));
  EXPECT_FALSE(IsDim0SliceAligned<float>(t.shape(), slice_begin, slice_end));
  // Checks sliced 1D tensor is misaligned for sanity.
  Tensor sliced;
  CHECK(sliced.CopyFrom(t.Slice(slice_begin, slice_end),
                        TensorShape({slice_end - slice_begin})));
  EXPECT_FALSE(sliced.IsAligned());
}
#endif
353 
// A dim-0 slice [1, 2) of a 2D float tensor must be reported as aligned when
// the inner dimension size is EIGEN_MAX_ALIGN_BYTES (or when alignment is
// disabled altogether).
TEST_F(OpsUtilTest, Aligned2DSliceOfDim0) {
#if EIGEN_MAX_ALIGN_BYTES == 0
  // When EIGEN_MAX_ALIGN_BYTES is 0 and the size of the first dimension is
  // nonzero, a multidimensional tensor is always aligned.
  Tensor t(DT_FLOAT, TensorShape({3, 4}));
  int64_t start = 1;
  int64_t end = 2;
  EXPECT_TRUE(IsDim0SliceAligned<float>(t.shape(), start, end));
#else
  // For multidimensional tensors, alignment is dictated by inner_dim_size.
  int64_t inner_dim_size = EIGEN_MAX_ALIGN_BYTES;
  Tensor t(DT_FLOAT, TensorShape({3, inner_dim_size}));
  int64_t start = 1;
  int64_t end = 2;
  EXPECT_TRUE(IsDim0SliceAligned<float>(t.shape(), start, end));
  // Checks sliced 2D is aligned, for sanity.
  Tensor sliced;
  CHECK(sliced.CopyFrom(t.Slice(start, end), TensorShape({1, inner_dim_size})));
  EXPECT_TRUE(sliced.IsAligned());
#endif
}
377 
#if EIGEN_MAX_ALIGN_BYTES > 0
// A dim-0 slice [1, 2) of a 2D float tensor whose inner dimension size is
// EIGEN_MAX_ALIGN_BYTES + 1 must be reported as misaligned, and the actual
// slice must agree.
TEST_F(OpsUtilTest, Misaligned2DSliceOfDim0) {
  // For multidimensional tensors, alignment is dictated by the inner size.
  const int64_t row_width = EIGEN_MAX_ALIGN_BYTES + 1;
  const int64_t slice_begin = 1;
  const int64_t slice_end = 2;
  Tensor t(DT_FLOAT, TensorShape({3, row_width}));
  EXPECT_FALSE(IsDim0SliceAligned<float>(t.shape(), slice_begin, slice_end));
  // Checks sliced 2D is misaligned, for sanity.
  Tensor sliced;
  CHECK(sliced.CopyFrom(t.Slice(slice_begin, slice_end),
                        TensorShape({1, row_width})));
  EXPECT_FALSE(sliced.IsAligned());
}
#endif
393 
// A shape with no dimensions must be reported as not aligned for the slice
// [1, 2).
TEST_F(OpsUtilTest, MisalignedEmptyShape) {
  const int64_t slice_begin = 1;
  const int64_t slice_end = 2;
  TensorShape rank0_shape({});
  EXPECT_FALSE(IsDim0SliceAligned<float>(rank0_shape, slice_begin, slice_end));
}
401 
// A shape whose first dimension has size 0 must be reported as not aligned
// for the slice [0, 1).
TEST_F(OpsUtilTest, MisalignedEmptyDim0) {
  const int64_t slice_begin = 0;
  const int64_t slice_end = 1;
  TensorShape empty_dim0_shape({0, 1, 2});
  EXPECT_FALSE(
      IsDim0SliceAligned<float>(empty_dim0_shape, slice_begin, slice_end));
}
409 
410 }  // namespace
411 }  // namespace tensorflow
412