• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
17 #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
18 
19 #include <limits>
20 #include <vector>
21 
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 
25 namespace tflite {
26 namespace strided_slice {
27 
28 // Use until std::clamp() is available from C++17.
Clamp(const int v,const int lo,const int hi)29 inline int Clamp(const int v, const int lo, const int hi) {
30   TFLITE_DCHECK(!(hi < lo));
31   if (hi < v) return hi;
32   if (v < lo) return lo;
33   return v;
34 }
35 
StridedSlicePadIndices(tflite::StridedSliceParams * p,int dim_count)36 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
37                                    int dim_count) {
38   // Add indices and mask bits to fully include extra dimensions
39   TFLITE_CHECK_LE(dim_count, 5);
40   TFLITE_CHECK_GE(dim_count, p->start_indices_count);
41   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
42   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
43 
44   const int pad_count = dim_count - p->start_indices_count;
45 
46   // Pad indices at start, so move arrays by pad_count.
47   for (int i = p->start_indices_count - 1; i >= 0; --i) {
48     p->strides[i + pad_count] = p->strides[i];
49     p->start_indices[i + pad_count] = p->start_indices[i];
50     p->stop_indices[i + pad_count] = p->stop_indices[i];
51   }
52   for (int i = 0; i < pad_count; ++i) {
53     p->start_indices[i] = 0;
54     p->stop_indices[i] = 1;
55     p->strides[i] = 1;
56   }
57 
58   // Pad masks with 0s or 1s as required.
59   p->shrink_axis_mask <<= pad_count;
60   p->ellipsis_mask <<= pad_count;
61   p->new_axis_mask <<= pad_count;
62   p->begin_mask <<= pad_count;
63   p->end_mask <<= pad_count;
64   p->begin_mask |= (1 << pad_count) - 1;
65   p->end_mask |= (1 << pad_count) - 1;
66 
67   p->start_indices_count = dim_count;
68   p->stop_indices_count = dim_count;
69   p->strides_count = dim_count;
70 }
71 
72 // Return the index for the first element along that axis. This index will be a
73 // positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
74 // that can be used to index directly into the data.
StartForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis)75 inline int StartForAxis(const tflite::StridedSliceParams& params,
76                         const RuntimeShape& input_shape, int axis) {
77   const auto begin_mask = params.begin_mask;
78   const auto* start_indices = params.start_indices;
79   const auto* strides = params.strides;
80   const int axis_size = input_shape.Dims(axis);
81   if (axis_size == 0) {
82     return 0;
83   }
84   // Begin with the specified index.
85   int start = start_indices[axis];
86 
87   // begin_mask override
88   if (begin_mask & 1 << axis) {
89     if (strides[axis] > 0) {
90       // Forward iteration - use the first element. These values will get
91       // clamped below (Note: We could have set them to 0 and axis_size-1, but
92       // use lowest() and max() to maintain symmetry with StopForAxis())
93       start = std::numeric_limits<int>::lowest();
94     } else {
95       // Backward iteration - use the last element.
96       start = std::numeric_limits<int>::max();
97     }
98   }
99 
100   // Handle negative indices
101   if (start < 0) {
102     start += axis_size;
103   }
104 
105   // Clamping
106   if (strides[axis] > 0) {
107     // Forward iteration
108     start = Clamp(start, 0, axis_size);
109   } else {
110     // Backward iteration
111     start = Clamp(start, -1, axis_size - 1);
112   }
113 
114   return start;
115 }
116 
117 // Return the "real" index for the end of iteration along that axis. This is an
118 // "end" in the traditional C sense, in that it points to one past the last
119 // element. ie. So if you were iterating through all elements of a 1D array of
120 // size 4, this function would return 4 as the stop, because it is one past the
121 // "real" indices of 0, 1, 2 & 3.
StopForAxis(const tflite::StridedSliceParams & params,const RuntimeShape & input_shape,int axis,int start_for_axis)122 inline int StopForAxis(const tflite::StridedSliceParams& params,
123                        const RuntimeShape& input_shape, int axis,
124                        int start_for_axis) {
125   const auto end_mask = params.end_mask;
126   const auto shrink_axis_mask = params.shrink_axis_mask;
127   const auto* stop_indices = params.stop_indices;
128   const auto* strides = params.strides;
129   const int axis_size = input_shape.Dims(axis);
130   if (axis_size == 0) {
131     return 0;
132   }
133 
134   // Begin with the specified index
135   const bool shrink_axis = shrink_axis_mask & (1 << axis);
136   int stop = stop_indices[axis];
137 
138   // When shrinking an axis, the end position does not matter (and can be
139   // incorrect when negative indexing is used, see Issue #19260). Always use
140   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
141   // already been adjusted for negative indices.
142   if (shrink_axis) {
143     return start_for_axis + 1;
144   }
145 
146   // end_mask override
147   if (end_mask & (1 << axis)) {
148     if (strides[axis] > 0) {
149       // Forward iteration - use the last element. These values will get
150       // clamped below
151       stop = std::numeric_limits<int>::max();
152     } else {
153       // Backward iteration - use the first element.
154       stop = std::numeric_limits<int>::lowest();
155     }
156   }
157 
158   // Handle negative indices
159   if (stop < 0) {
160     stop += axis_size;
161   }
162 
163   // Clamping
164   // Because the end index points one past the last element, we need slightly
165   // different clamping ranges depending on the direction.
166   if (strides[axis] > 0) {
167     // Forward iteration
168     stop = Clamp(stop, 0, axis_size);
169   } else {
170     // Backward iteration
171     stop = Clamp(stop, -1, axis_size - 1);
172   }
173 
174   return stop;
175 }
176 
// Returns true once `index` has reached or moved past `stop` in the direction
// of travel, i.e. when iteration along the axis is finished. A positive stride
// walks upward; any other stride (including 0) walks downward.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return !(index < stop);
  }
  return !(stop < index);
}
181 
BuildStridedSliceParams(int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides)182 inline tflite::StridedSliceParams BuildStridedSliceParams(
183     int begin_mask, int end_mask, int shrink_axis_mask,
184     const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
185     const std::vector<int>& strides) {
186   tflite::StridedSliceParams op_params;
187   const int dims_count = start_indices.size();
188 
189   op_params.start_indices_count = dims_count;
190   op_params.stop_indices_count = dims_count;
191   op_params.strides_count = dims_count;
192   for (int i = 0; i < dims_count; ++i) {
193     op_params.start_indices[i] = start_indices[i];
194     op_params.stop_indices[i] = stop_indices[i];
195     op_params.strides[i] = strides[i];
196   }
197 
198   op_params.begin_mask = begin_mask;
199   op_params.ellipsis_mask = 0;
200   op_params.end_mask = end_mask;
201   op_params.new_axis_mask = 0;
202   op_params.shrink_axis_mask = shrink_axis_mask;
203 
204   return op_params;
205 }
206 
207 }  // namespace strided_slice
208 
209 }  // namespace tflite
210 
211 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
212