/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

#define EIGEN_USE_THREADS
using CPUDevice = Eigen::ThreadPoolDevice;

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//    you need to skip in order to make it over from one side of a dimension
//    to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension that the roll starts to wrap
//    back to the front
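//
// Illustration (added comment, not part of the original source): rolling a
// [2, 3] tensor by 1 along axis 1 gives dim_size = {2, 3}, dim_range = {6, 3}
// and threshold = {0, 2}, so input element [i][j] ends up at output element
// [i][(j + 1) % 3].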
template <typename T>
void DoRoll(OpKernelContext* context, const int64 num_elements,
            const int num_dims, const gtl::ArraySlice<int>& dim_size,
            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
            const gtl::ArraySlice<int64>& dim_range) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                  int64 start, int64 end) {
    // array of indices for each dimension
    gtl::InlinedVector<int, 4> indices(num_dims);
    int offset = 0;  // the shift along the flattened tensor for current element
    // initialize indices and offset
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      const int shifted_indx = (indx + shift) % dim_size[i];
      offset += (shifted_indx - indx) * stride;
    }
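    // Illustration (added comment, not in the original source): for a 1-D
    // tensor of 5 elements rolled by 2, threshold == 3 and shift == 2, so at
    // start == 0 the initial offset is 2 and input[0] is written to output[2].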

    for (int64 i = start; i < end; i++) {
      output[i + offset] = input[i];
      // create next combination of indices
      // while at it adjust offset if needed
      for (int j = num_dims - 1; j >= 0; j--) {
        const int indx = (indices[j] + 1) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {  // we've reached the threshold
            // dim_range[j] = threshold[j] + shift[j]
            // offset = shift[j] + ... other offsets
            // offset - dim_range[j] = -threshold[j] + ... other offsets
            // thus we undo our previous offset as well as add a new offset of
            // -threshold[j] in one operation
            offset -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          offset += dim_range[j];        // indx became 0 so reverse wrap around
        }
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // 15 - experimentally determined with float and bool types
  const int cost_per_element = 15 * sizeof(T);  // rough estimate
  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
        cost_per_element, std::move(work));
}

// dim_size - the size of each dimension
// dim_range - the number of indices over in the flattened tensor
//    you need to skip in order to make it over from one side of a dimension
//    to the other. Used to make the shifts wrap around after a threshold.
// threshold - the index for each dimension that the roll starts to wrap
//    back to the front
// isd - inner shift dimension
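//
// Illustration (added comment, not part of the original source): rolling a
// [2, 3] tensor by 1 along axis 1 gives isd == 1 and threshold[isd] == 2, so
// each row is copied with two memcpy calls: elements at columns [0, 1] into
// output columns [1, 2], and the element at column [2] into column [0].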
template <typename T>
// Use memcpy to copy memory in groups when the data type supports memcpy
void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
                      const T* input, T* output,
                      const gtl::ArraySlice<int>& threshold,
                      const gtl::ArraySlice<int64>& dim_range,
                      const int64 isd) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
                  int64 start, int64 end) {
    // the number of indices over in the flattened tensor you need to skip in
    // order to make it over from one side of the isd to the other
    const int64 isd_range = std::max<int>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64 isd_stride = isd_range / std::max<int>(dim_size[isd], 1);

    // start and end currently index groups, so convert them into element
    // indices. There are 2 groups per isd: one for all elements before
    // threshold[isd] and another for all elements from threshold[isd] onward.
    const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
    const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
    start = (start / 2) * isd_range + start_remainder;
    end = (end / 2) * isd_range + end_remainder;
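    // Illustration (added comment, not in the original source): for a [2, 3]
    // tensor rolled by 1 along axis 1, isd_range == 3, isd_stride == 1 and
    // threshold[isd] == 2, so group indices 0..3 map to element offsets
    // 0, 2, 3 and 5.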

    const T* in_ptr = &input[0];
    T* out_ptr = &output[0];
    in_ptr += start;
    out_ptr += start;

    // array of indices for each dimension
    // indicies = [i, j, k, l, m, n]
    gtl::InlinedVector<int, 4> indicies(num_dims);
    // the offset needed to make all inner non-shifting dimensions become 0
    int64 remainder_offset = 0;
    // initialize indicies
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indicies[i] = indx;
      // calculate dimension index after the shift
      int out_indx = (indx + shift) % dim_size[i];
      if (i > isd) {
        // trailing zeroes for indices after the inner shifted dimension
        out_indx = 0;
        remainder_offset += (out_indx - indx) * stride;
      }
      out_ptr += (out_indx - indx) * stride;
    }
    // set trailing zeroes for indices after the inner shifted dimension
    for (int i = num_dims - 1; i > isd; i--) indicies[i] = 0;

    // the number of indices in the isd dimension the next group will skip
    // to make it to the next threshold or end point
    int isd_indx_skip = 0;
    // the size of the next group
    int64 group_size = 0;
    // initialize isd_indx_skip and group_size
    if (indicies[isd] < threshold[isd]) {
      isd_indx_skip = threshold[isd] - indicies[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    } else {
      isd_indx_skip = dim_size[isd] - indicies[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    }

    int64 i = start;
    while (i < end) {
      // copy group of elements
      memcpy(out_ptr, in_ptr, group_size * sizeof(T));

      // shift i and the pointers over to the next group position
      i += group_size;
      out_ptr += group_size;
      in_ptr += group_size;

      // produce next combination of indices and adjust the out_ptr position
      // to fix the offset if necessary
      // the isd (inner shift dim) should skip to next threshold or endpoint
      // all dimensions to the left increment by 1 when a digit is carried
      // all dimensions to the right remain set to 0
      //            +1 +1 +1 +isd_indx_skip
      // indicies = [i, j, k, l, 0, 0]
      //                      ^isd
      for (int j = isd; j >= 0; j--) {
        int inc = 1;
        if (j == isd) inc = isd_indx_skip;
        const int indx = (indicies[j] + inc) % dim_size[j];
        indicies[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {
            out_ptr -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          out_ptr += dim_range[j];       // indx became 0 so reverse wrap around
        }
      }

      // set isd_indx_skip and group_size for next iteration
      if (indicies[isd] < threshold[isd]) {
        isd_indx_skip = threshold[isd] - indicies[isd];
        group_size = isd_indx_skip * isd_stride;
      } else {
        isd_indx_skip = dim_size[isd] - indicies[isd];
        group_size = isd_indx_skip * isd_stride;
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  const int64 ave_group_size = dim_range[isd] / 2;
  const int total_work = 2 * num_elements / std::max<int>(dim_range[isd], 1);
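  // Illustration (added comment, not in the original source): for the [2, 3]
  // example rolled along axis 1, dim_range[isd] == 3, so total_work ==
  // 2 * 6 / 3 == 4, i.e. two memcpy groups per row.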
  // 25000 - experimentally determined with float and bool types
  const int cost_per_group = 25000 * sizeof(T) * ave_group_size;
  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
        cost_per_group, std::move(work));
}

template <typename Device, typename T, typename Tshift, typename Taxis>
class RollOp : public OpKernel {
 public:
  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input = context->input(0);
    const Tensor& shift = context->input(1);
    const Tensor& axis = context->input(2);

    auto shift_flat = shift.flat<Tshift>();
    auto axis_flat = axis.flat<Taxis>();

    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(context, shift.shape().dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift.shape().DebugString()));
    OP_REQUIRES(context, axis.shape().dims() <= 1,
                errors::InvalidArgument(
                    "axis must be a scalar or a 1-D vector. Found: ",
                    axis.shape().DebugString()));
    OP_REQUIRES(
        context, shift.shape() == axis.shape(),
        errors::InvalidArgument("shift and axis must have the same size"));
    const int64 num_elements = input.NumElements();
    const int num_shifts = static_cast<int>(shift_flat.size());
    const int num_dims = input.dims();

    // if there are any duplicate axes, shift_mod_sum will have the
    // total modulo sum of shifts for each dimension
    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
    for (int i = 0; i < num_shifts; i++) {
      int axis = axis_flat(i);
      if (axis < 0) {
        axis += num_dims;
      }
      OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
                  errors::InvalidArgument("axis ", axis, " is out of range"));
      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
      // modulo that works with negatives: ((x % y) + y) % y
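      // e.g. (added comment) a shift of -1 on a dimension of size 3 becomes
      // ((-1 % 3) + 3) % 3 == 2, the same as rolling forward by two.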
      shift_mod_sum[axis] = (sum % ds + ds) % ds;
    }
    // the size of each dimension
    gtl::InlinedVector<int, 4> dim_size(num_dims);
    // threshold[i] is the index that the roll starts to wrap back to the front
    gtl::InlinedVector<int, 4> threshold(num_dims);
    // dim_range is the number of indices over in the flattened tensor
    // you need to skip in order to make it over from one side of a dimension
    // to the other. Used to make the shifts wrap around after a threshold.
    gtl::InlinedVector<int64, 4> dim_range(num_dims);
    int64 dim_size_prod = 1;  // dimension size product
    // inner shift dimension (innermost shifted dimension)
    int64 isd = 0;
    for (int i = num_dims - 1; i >= 0; i--) {
      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
      dim_size[i] = ds;
      threshold[i] = (ds - shift_mod_sum[i]) % ds;
      dim_size_prod *= static_cast<int64>(input.dim_size(i));
      dim_range[i] = dim_size_prod;
    }
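    // Illustration (added comment, not in the original source): for a
    // [4, 2, 3] tensor with a nonzero shift only on axis 1, this loop yields
    // dim_range = {24, 6, 3} and isd == 1, the innermost shifted dimension.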

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto input_flat = input.flat<T>().data();
    auto output_flat = output->flat<T>().data();

    if (std::is_same<Device, CPUDevice>::value) {
      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        // V2 copies memory in groups instead of element by element
        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
                            input_flat, output_flat, threshold, dim_range, isd);
      } else {
        // in case memcpy does not work for the current data type
        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
                  output_flat, threshold, dim_range);
      }
    }
  }
};
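
// Example usage from Python (a sketch added here; assumes tf.roll is bound to
// this kernel): tf.roll([1, 2, 3, 4, 5], shift=2, axis=0) evaluates to
// [4, 5, 1, 2, 3].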

// Register the CPU kernels.
#define REGISTER_CPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int64>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int64>)

TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
}  // namespace tensorflow