/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

#define EIGEN_USE_THREADS
using CPUDevice = Eigen::ThreadPoolDevice;

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
// you need to skip in order to make it over from one side of a dimension
// to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension that the roll starts to wrap
// back to the front
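//
// Illustrative example (values as built by RollOp::Compute below): rolling a
// 1-D tensor of size 5 by a shift of 2 gives dim_size = {5}, dim_range = {5},
// and threshold = {(5 - 2) % 5} = {3}, so the elements at indices 3 and 4
// wrap around to the front of the output.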
template <typename T>
void DoRoll(OpKernelContext* context, const int64 num_elements,
            const int num_dims, const gtl::ArraySlice<int>& dim_size,
            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
            const gtl::ArraySlice<int64>& dim_range) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                  int64 start, int64 end) {
    // array of indices for each dimension
    gtl::InlinedVector<int, 4> indices(num_dims);
    int offset = 0;  // the shift along the flattened tensor for current
                     // element
    // initialize indices and offset
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      const int shifted_indx = (indx + shift) % dim_size[i];
      offset += (shifted_indx - indx) * stride;
    }

    for (int64 i = start; i < end; i++) {
      output[i + offset] = input[i];
      // create next combination of indices
      // while at it adjust offset if needed
      for (int j = num_dims - 1; j >= 0; j--) {
        const int indx = (indices[j] + 1) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {  // we've reached the threshold
            // dim_range[j] = threshold[j] + shift[j]
            // offset = shift[j] + ... other offsets
            // offset - dim_range[j] = -threshold[j] + ... other offsets
            // thus we undo our previous offset as well as add a new offset of
            // -threshold[j] in one operation
            offset -= dim_range[j];  // now wraps around
          }
          break;  // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          offset += dim_range[j];  // indx became 0 so reverse wrap around
        }
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // 15 - experimentally determined with float and bool types
  const int cost_per_element = 15 * sizeof(T);  // rough estimate
  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
        cost_per_element, std::move(work));
}

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
// you need to skip in order to make it over from one side of a dimension
// to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension that the roll starts to wrap
// back to the front
// isd - inner shift dimension
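//
// Illustrative example: for an input of shape [2, 3, 5] shifted along axis 1
// only, axis 1 is the innermost shifted dimension, so isd = 1 and
// dim_range[isd] = 15; each run of 15 contiguous elements is copied as two
// memcpy groups, split at isd index threshold[isd].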
// Use memcpy to copy memory in groups when the data type supports memcpy
template <typename T>
void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
                      const T* input, T* output,
                      const gtl::ArraySlice<int>& threshold,
                      const gtl::ArraySlice<int64>& dim_range,
                      const int64 isd) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
                  int64 start, int64 end) {
    // the number of indices over in the flattened tensor you need to skip in
    // order to make it over from one side of the isd to the other
    const int64 isd_range = std::max<int>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64 isd_stride = isd_range / std::max<int>(dim_size[isd], 1);

    // start and end represent the i-th group currently so we will convert
    // them into numbers representing the i-th elements.
    // there are 2 groups per isd slice: one for all elements before
    // threshold[isd] and another for all elements after threshold[isd].
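    // e.g. (illustrative) with isd_range = 5, isd_stride = 1 and
    // threshold[isd] = 3: group 0 starts at element 0, group 1 at element 3,
    // group 2 at element 5, and so on, alternating between the pre-threshold
    // and post-threshold runs of each isd slice.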
    const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
    const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
    start = (start / 2) * isd_range + start_remainder;
    end = (end / 2) * isd_range + end_remainder;

    const T* in_ptr = &input[0];
    T* out_ptr = &output[0];
    in_ptr += start;
    out_ptr += start;

    // array of indices for each dimension
    // indices = [i, j, k, l, m, n]
    gtl::InlinedVector<int, 4> indices(num_dims);
    // the offset needed to make all inner non-shifting dimensions become 0
    int64 remainder_offset = 0;
    // initialize indices
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      int out_indx = (indx + shift) % dim_size[i];
      if (i > isd) {
        // trailing zeroes for indices after the inner shifted dimension
        out_indx = 0;
        remainder_offset += (out_indx - indx) * stride;
      }
      out_ptr += (out_indx - indx) * stride;
    }
    // set trailing zeroes for indices after the inner shifted dimension
    for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;

    // the number of indices in the isd dimension the next group will skip
    // to make it to the next threshold or end point
    int isd_indx_skip = 0;
    // the size of the next group
    int64 group_size = 0;
    // initialize isd_indx_skip and group_size
    if (indices[isd] < threshold[isd]) {
      isd_indx_skip = threshold[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    } else {
      isd_indx_skip = dim_size[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    }

    int64 i = start;
    while (i < end) {
      // copy group of elements
      memcpy(out_ptr, in_ptr, group_size * sizeof(T));

      // shift i and the pointers over to the next group position
      i += group_size;
      out_ptr += group_size;
      in_ptr += group_size;

      // produce next combination of indices and adjust the out_ptr position
      // to fix the offset if necessary
      // the isd (inner shift dim) should skip to next threshold or endpoint
      // all dimensions to the left increment by 1 when a digit is carried
      // all dimensions to the right remain set to 0
      //            +1 +1 +1 +isd_indx_skip
      // indices = [i, j, k, l, 0, 0]
      //                     ^isd
      for (int j = isd; j >= 0; j--) {
        int inc = 1;
        if (j == isd) inc = isd_indx_skip;
        const int indx = (indices[j] + inc) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {
            out_ptr -= dim_range[j];  // now wraps around
          }
          break;  // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          out_ptr += dim_range[j];  // indx became 0 so reverse wrap around
        }
      }

      // set isd_indx_skip and group_size for next iteration
      if (indices[isd] < threshold[isd]) {
        isd_indx_skip = threshold[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      } else {
        isd_indx_skip = dim_size[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  const int64 ave_group_size = dim_range[isd] / 2;
  const int total_work = 2 * num_elements / std::max<int>(dim_range[isd], 1);
  // 25000 - experimentally determined with float and bool types
  const int cost_per_group = 25000 * sizeof(T) * ave_group_size;
  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
        cost_per_group, std::move(work));
}

template <typename Device, typename T, typename Tshift, typename Taxis>
class RollOp : public OpKernel {
 public:
  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input = context->input(0);
    const Tensor& shift = context->input(1);
    const Tensor& axis = context->input(2);

    auto shift_flat = shift.flat<Tshift>();
    auto axis_flat = axis.flat<Taxis>();

    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(context, shift.shape().dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift.shape().DebugString()));
    OP_REQUIRES(context, axis.shape().dims() <= 1,
                errors::InvalidArgument(
                    "axis must be a scalar or a 1-D vector. Found: ",
                    axis.shape().DebugString()));
    OP_REQUIRES(
        context, shift.shape() == axis.shape(),
        errors::InvalidArgument("shift and axis must have the same size"));
    const int64 num_elements = input.NumElements();
    const int num_shifts = static_cast<int>(shift_flat.size());
    const int num_dims = input.dims();

    // if there are any duplicate axes, shift_mod_sum will have the
    // total modulo sum of shifts for each dimension
    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
    for (int i = 0; i < num_shifts; i++) {
      int axis = axis_flat(i);
      if (axis < 0) {
        axis += num_dims;
      }
      OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
                  errors::InvalidArgument("axis ", axis, " is out of range"));
      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
      // modulo that works with negatives: ((x % y) + y) % y
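      // e.g. (illustrative) shift = -2 with ds = 5: ((-2 % 5) + 5) % 5 = 3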
      shift_mod_sum[axis] = (sum % ds + ds) % ds;
    }
    // the size of each dimension
    gtl::InlinedVector<int, 4> dim_size(num_dims);
    // threshold[i] is the index that the roll starts to wrap back to the front
    gtl::InlinedVector<int, 4> threshold(num_dims);
    // dim_range is the number of indices over in the flattened tensor
    // you need to skip in order to make it over from one side of a dimension
    // to the other. Used to make the shifts wrap around after a threshold.
    gtl::InlinedVector<int64, 4> dim_range(num_dims);
    int64 dim_size_prod = 1;  // dimension size product
    // inner shift dimension (innermost shifted dimension)
    int64 isd = 0;
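    // Illustrative example: for an input of shape [2, 3, 5] with
    // shift_mod_sum = {0, 1, 0}, the loop below yields dim_size = {2, 3, 5},
    // threshold = {0, 2, 0}, dim_range = {30, 15, 5}, and isd = 1.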
    for (int i = num_dims - 1; i >= 0; i--) {
      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
      dim_size[i] = ds;
      threshold[i] = (ds - shift_mod_sum[i]) % ds;
      dim_size_prod *= static_cast<int64>(input.dim_size(i));
      dim_range[i] = dim_size_prod;
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto input_flat = input.flat<T>().data();
    auto output_flat = output->flat<T>().data();

    if (std::is_same<Device, CPUDevice>::value) {
      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        // V2 copies memory in groups instead of element by element
        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
                            input_flat, output_flat, threshold, dim_range,
                            isd);
      } else {
        // in case memcpy does not work for the current data type
        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
                  output_flat, threshold, dim_range);
      }
    }
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int64>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int64>)

TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
}  // namespace tensorflow