1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
17 #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
18
19 #include <limits>
20 #include <vector>
21
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24
25 namespace tflite {
26 namespace strided_slice {
27
28 // Use until std::clamp() is available from C++17.
Clamp(const int v,const int lo,const int hi)29 inline int Clamp(const int v, const int lo, const int hi) {
30 TFLITE_DCHECK(!(hi < lo));
31 if (hi < v) return hi;
32 if (v < lo) return lo;
33 return v;
34 }
35
StridedSlicePadIndices(tflite::StridedSliceParams * p,int dim_count)36 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
37 int dim_count) {
38 // Add indices and mask bits to fully include extra dimensions
39 TFLITE_CHECK_LE(dim_count, 5);
40 TFLITE_CHECK_GE(dim_count, p->start_indices_count);
41 TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
42 TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
43
44 const int pad_count = dim_count - p->start_indices_count;
45
46 // Pad indices at start, so move arrays by pad_count.
47 for (int i = p->start_indices_count - 1; i >= 0; --i) {
48 p->strides[i + pad_count] = p->strides[i];
49 p->start_indices[i + pad_count] = p->start_indices[i];
50 p->stop_indices[i + pad_count] = p->stop_indices[i];
51 }
52 for (int i = 0; i < pad_count; ++i) {
53 p->start_indices[i] = 0;
54 p->stop_indices[i] = 1;
55 p->strides[i] = 1;
56 }
57
58 // Pad masks with 0s or 1s as required.
59 p->shrink_axis_mask <<= pad_count;
60 p->ellipsis_mask <<= pad_count;
61 p->new_axis_mask <<= pad_count;
62 p->begin_mask <<= pad_count;
63 p->end_mask <<= pad_count;
64 p->begin_mask |= (1 << pad_count) - 1;
65 p->end_mask |= (1 << pad_count) - 1;
66
67 p->start_indices_count = dim_count;
68 p->stop_indices_count = dim_count;
69 p->strides_count = dim_count;
70 }
71
// Return the index of the first element visited along that axis. The result is
// an integer in [0, axis_size] for forward iteration (stride > 0), or in
// [-1, axis_size - 1] otherwise (stride <= 0). A value of axis_size (forward)
// or -1 (backward) denotes an empty range. Returns 0 when the axis is empty.
StartForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis)75 inline int StartForAxis(const tflite::StridedSliceParams& params,
76 const RuntimeShape& input_shape, int axis) {
77 const auto begin_mask = params.begin_mask;
78 const auto* start_indices = params.start_indices;
79 const auto* strides = params.strides;
80 const int axis_size = input_shape.Dims(axis);
81 if (axis_size == 0) {
82 return 0;
83 }
84 // Begin with the specified index.
85 int start = start_indices[axis];
86
87 // begin_mask override
88 if (begin_mask & 1 << axis) {
89 if (strides[axis] > 0) {
90 // Forward iteration - use the first element. These values will get
91 // clamped below (Note: We could have set them to 0 and axis_size-1, but
92 // use lowest() and max() to maintain symmetry with StopForAxis())
93 start = std::numeric_limits<int>::lowest();
94 } else {
95 // Backward iteration - use the last element.
96 start = std::numeric_limits<int>::max();
97 }
98 }
99
100 // Handle negative indices
101 if (start < 0) {
102 start += axis_size;
103 }
104
105 // Clamping
106 if (strides[axis] > 0) {
107 // Forward iteration
108 start = Clamp(start, 0, axis_size);
109 } else {
110 // Backward iteration
111 start = Clamp(start, -1, axis_size - 1);
112 }
113
114 return start;
115 }
116
// Return the "real" index for the end of iteration along that axis. This is an
// "end" in the traditional C sense: it points one past the last element
// visited. For example, when iterating forward over all elements of a 1D array
// of size 4, this function returns 4 as the stop, one past the real indices
// 0, 1, 2 and 3. For backward iteration (stride <= 0) the stop lies one before
// the last element visited and may be -1.
StopForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis,int start_for_axis)122 inline int StopForAxis(const tflite::StridedSliceParams& params,
123 const RuntimeShape& input_shape, int axis,
124 int start_for_axis) {
125 const auto end_mask = params.end_mask;
126 const auto shrink_axis_mask = params.shrink_axis_mask;
127 const auto* stop_indices = params.stop_indices;
128 const auto* strides = params.strides;
129 const int axis_size = input_shape.Dims(axis);
130 if (axis_size == 0) {
131 return 0;
132 }
133
134 // Begin with the specified index
135 const bool shrink_axis = shrink_axis_mask & (1 << axis);
136 int stop = stop_indices[axis];
137
138 // When shrinking an axis, the end position does not matter (and can be
139 // incorrect when negative indexing is used, see Issue #19260). Always use
140 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
141 // already been adjusted for negative indices.
142 if (shrink_axis) {
143 return start_for_axis + 1;
144 }
145
146 // end_mask override
147 if (end_mask & (1 << axis)) {
148 if (strides[axis] > 0) {
149 // Forward iteration - use the last element. These values will get
150 // clamped below
151 stop = std::numeric_limits<int>::max();
152 } else {
153 // Backward iteration - use the first element.
154 stop = std::numeric_limits<int>::lowest();
155 }
156 }
157
158 // Handle negative indices
159 if (stop < 0) {
160 stop += axis_size;
161 }
162
163 // Clamping
164 // Because the end index points one past the last element, we need slightly
165 // different clamping ranges depending on the direction.
166 if (strides[axis] > 0) {
167 // Forward iteration
168 stop = Clamp(stop, 0, axis_size);
169 } else {
170 // Backward iteration
171 stop = Clamp(stop, -1, axis_size - 1);
172 }
173
174 return stop;
175 }
176
// Returns true once `index` has reached or passed `stop` in the direction of
// travel given by `stride`, i.e. when iteration along the axis should end.
// A positive stride moves upward; any other stride is treated as downward.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return index >= stop;
  }
  return index <= stop;
}
181
BuildStridedSliceParams(int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides)182 inline tflite::StridedSliceParams BuildStridedSliceParams(
183 int begin_mask, int end_mask, int shrink_axis_mask,
184 const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
185 const std::vector<int>& strides) {
186 tflite::StridedSliceParams op_params;
187 const int dims_count = start_indices.size();
188
189 op_params.start_indices_count = dims_count;
190 op_params.stop_indices_count = dims_count;
191 op_params.strides_count = dims_count;
192 for (int i = 0; i < dims_count; ++i) {
193 op_params.start_indices[i] = start_indices[i];
194 op_params.stop_indices[i] = stop_indices[i];
195 op_params.strides[i] = strides[i];
196 }
197
198 op_params.begin_mask = begin_mask;
199 op_params.ellipsis_mask = 0;
200 op_params.end_mask = end_mask;
201 op_params.new_axis_mask = 0;
202 op_params.shrink_axis_mask = shrink_axis_mask;
203
204 return op_params;
205 }
206
207 } // namespace strided_slice
208
209 } // namespace tflite
210
211 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
212