• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
16 
17 #include <gtest/gtest.h>
18 
19 namespace tflite {
20 namespace {
21 
RunStridedSlicePadIndices(std::initializer_list<int> begin,std::initializer_list<int> end,std::initializer_list<int> stride,std::initializer_list<int> expected_begin,std::initializer_list<int> expected_end,std::initializer_list<int> expected_stride)22 void RunStridedSlicePadIndices(std::initializer_list<int> begin,
23                                std::initializer_list<int> end,
24                                std::initializer_list<int> stride,
25                                std::initializer_list<int> expected_begin,
26                                std::initializer_list<int> expected_end,
27                                std::initializer_list<int> expected_stride) {
28   StridedSliceParams op_params;
29   int dims = begin.size();
30   op_params.start_indices_count = dims;
31   op_params.stop_indices_count = dims;
32   op_params.strides_count = dims;
33 
34   for (int i = 0; i < dims; ++i) {
35     op_params.start_indices[i] = begin.begin()[i];
36     op_params.stop_indices[i] = end.begin()[i];
37     op_params.strides[i] = stride.begin()[i];
38   }
39 
40   strided_slice::StridedSlicePadIndices(&op_params, 4);
41 
42   for (int i = 0; i < 4; ++i) {
43     EXPECT_EQ(op_params.start_indices[i], expected_begin.begin()[i]);
44     EXPECT_EQ(op_params.stop_indices[i], expected_end.begin()[i]);
45     EXPECT_EQ(op_params.strides[i], expected_stride.begin()[i]);
46   }
47 }
48 
// Three input dimensions padded to four: the test expects one leading
// dimension with begin 0, end 1 and stride 1 ahead of the original entries.
TEST(RunStridedSlicePadIndices, Pad1) {
  const std::initializer_list<int> input_begin = {1, 2, 3};
  const std::initializer_list<int> input_end = {4, 5, 6};
  const std::initializer_list<int> input_stride = {2, 2, 2};

  const std::initializer_list<int> padded_begin = {0, 1, 2, 3};
  const std::initializer_list<int> padded_end = {1, 4, 5, 6};
  const std::initializer_list<int> padded_stride = {1, 2, 2, 2};

  RunStridedSlicePadIndices(input_begin, input_end, input_stride,
                            padded_begin, padded_end, padded_stride);
}
58 
// Two input dimensions padded to four: the test expects two leading
// dimensions with begin 0, end 1 and stride 1 ahead of the original entries.
TEST(RunStridedSlicePadIndices, Pad2) {
  const std::initializer_list<int> input_begin = {1, 2};
  const std::initializer_list<int> input_end = {4, 5};
  const std::initializer_list<int> input_stride = {2, 2};

  const std::initializer_list<int> padded_begin = {0, 0, 1, 2};
  const std::initializer_list<int> padded_end = {1, 1, 4, 5};
  const std::initializer_list<int> padded_stride = {1, 1, 2, 2};

  RunStridedSlicePadIndices(input_begin, input_end, input_stride,
                            padded_begin, padded_end, padded_stride);
}
68 
// One input dimension padded to four: the test expects three leading
// dimensions with begin 0, end 1 and stride 1 ahead of the original entry.
TEST(RunStridedSlicePadIndices, Pad3) {
  const std::initializer_list<int> input_begin = {1};
  const std::initializer_list<int> input_end = {4};
  const std::initializer_list<int> input_stride = {2};

  const std::initializer_list<int> padded_begin = {0, 0, 0, 1};
  const std::initializer_list<int> padded_end = {1, 1, 1, 4};
  const std::initializer_list<int> padded_stride = {1, 1, 1, 2};

  RunStridedSlicePadIndices(input_begin, input_end, input_stride,
                            padded_begin, padded_end, padded_stride);
}
78 
79 }  // namespace
80 }  // namespace tflite
81