• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/utils/helpers/tensor_transform.h"
25 
26 #include "bit_ops.h"
27 
28 namespace arm_compute
29 {
30 namespace helpers
31 {
32 namespace tensor_transform
33 {
calculate_stride_on_index(int index,Coordinates strides)34 int calculate_stride_on_index(int index, Coordinates strides)
35 {
36     return index >= static_cast<int>(strides.num_dimensions()) ? 1 : strides[index];
37 }
38 
calculate_start_on_index(TensorShape input_shape,int index,Coordinates starts,Coordinates strides,int32_t begin_mask)39 int calculate_start_on_index(TensorShape input_shape, int index, Coordinates starts, Coordinates strides, int32_t begin_mask)
40 {
41     // Early exit
42     if(index >= static_cast<int>(starts.num_dimensions()))
43     {
44         return 0;
45     }
46 
47     // Get stride
48     const int stride = calculate_stride_on_index(index, strides);
49 
50     // Calculate start
51     int start = starts[index];
52 
53     // Reset in case of begin mask present
54     if(arm_compute::helpers::bit_ops::is_bit_set(begin_mask, index))
55     {
56         start = stride > 0 ? std::numeric_limits<int>::lowest() : std::numeric_limits<int>::max();
57     }
58 
59     // Account negative start points
60     const int dim_size = input_shape[index];
61     if(start < 0)
62     {
63         start += dim_size;
64     }
65 
66     // Final clamp
67     start = utility::clamp(start, 0, dim_size - 1);
68 
69     return start;
70 }
71 
calculate_end_on_index(TensorShape input_shape,int index,int start_on_index,Coordinates ends,Coordinates strides,int32_t end_mask,int32_t shrink_axis_mask)72 int calculate_end_on_index(TensorShape input_shape, int index, int start_on_index,
73                            Coordinates ends, Coordinates strides,
74                            int32_t end_mask, int32_t shrink_axis_mask)
75 {
76     // Early exit
77     if(index >= static_cast<int>(ends.num_dimensions()))
78     {
79         return input_shape[index];
80     }
81 
82     const int  stride      = calculate_stride_on_index(index, strides);
83     const bool shrink_axis = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, index);
84 
85     // Calculate start
86     int stop = ends[index];
87 
88     // Shrink dimension
89     if(shrink_axis)
90     {
91         if(start_on_index == std::numeric_limits<int>::max())
92         {
93             stop = start_on_index;
94         }
95         else
96         {
97             stop = start_on_index + 1;
98         }
99     }
100 
101     // Reset in case of begin mask present
102     if(arm_compute::helpers::bit_ops::is_bit_set(end_mask, index) && !shrink_axis)
103     {
104         stop = (stride > 0) ? std::numeric_limits<int>::max() : std::numeric_limits<int>::lowest();
105     }
106 
107     // Account negative end points
108     const int dim_size = input_shape[index];
109     if(stop < 0)
110     {
111         stop += dim_size;
112     }
113 
114     // Final clamp
115     stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1);
116 
117     return stop;
118 }
119 
calculate_strided_slice_coords(TensorShape input_shape,Coordinates starts,Coordinates ends,Coordinates strides,int32_t begin_mask,int32_t end_mask,int32_t shrink_axis_mask)120 std::tuple<Coordinates, Coordinates, Coordinates> calculate_strided_slice_coords(TensorShape input_shape,
121                                                                                  Coordinates starts, Coordinates ends, Coordinates strides,
122                                                                                  int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask)
123 {
124     Coordinates starts_abs{};
125     Coordinates ends_abs{};
126     Coordinates final_strides{};
127 
128     for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
129     {
130         const int start_i = calculate_start_on_index(input_shape, i, starts, strides, begin_mask);
131         starts_abs.set(i, start_i);
132         ends_abs.set(i, calculate_end_on_index(input_shape, i, start_i, ends, strides, end_mask, shrink_axis_mask));
133         final_strides.set(i, calculate_stride_on_index(i, strides));
134     }
135 
136     return std::make_tuple(starts_abs, ends_abs, final_strides);
137 }
138 
compute_strided_slice_output_shape(TensorShape input_shape,Coordinates starts,Coordinates ends,Coordinates strides,int32_t begin_mask,int32_t end_mask,int32_t shrink_axis_mask,bool return_unshrinked)139 TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends, Coordinates strides,
140                                                int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask, bool return_unshrinked)
141 {
142     unsigned int index = 0;
143 
144     TensorShape output_shape;
145     for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
146     {
147         const int stride = calculate_stride_on_index(index, strides);
148         const int start  = calculate_start_on_index(input_shape, i, starts, strides, begin_mask);
149         const int end    = calculate_end_on_index(input_shape, i, start, ends, strides, end_mask, shrink_axis_mask);
150         const int range  = end - start;
151 
152         const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
153         if(return_unshrinked || !is_shrink)
154         {
155             if((range == 0) ||               // Zero range
156                (range < 0 && stride >= 0) || // Negative range with positive stride
157                (range > 0 && stride <= 0))   // Positive range with negative stride
158             {
159                 output_shape.set(index, 0);
160                 return output_shape;
161             }
162             else
163             {
164                 int dim = range / stride + (range % stride != 0 ? 1 : 0);
165                 output_shape.set(index++, dim);
166             }
167         }
168     }
169     return output_shape;
170 }
171 
construct_slice_end_mask(Coordinates ends)172 int32_t construct_slice_end_mask(Coordinates ends)
173 {
174     // Create end mask
175     int32_t end_mask = 0;
176     for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
177     {
178         if(ends[i] < 0)
179         {
180             end_mask |= 1 << i;
181         }
182     }
183 
184     return end_mask;
185 }
186 } // namespace tensor_transform
187 } // namespace helpers
188 } // namespace arm_compute
189