• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
17 #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
18 
19 // Functor definitions for ScatterND ops, must be compilable by nvcc.
20 
21 #define EIGEN_USE_THREADS
22 
23 #include <atomic>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 
27 #include "tensorflow/core/framework/bounds_check.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/kernels/scatter_nd_op.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/util.h"
37 
namespace tensorflow {

// Device aliases used to specialize the scatter_nd functors below.
typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Forward declaration; the full definition comes from op_kernel.h.
class OpKernelContext;
46 
47 // Specialization of UpdateExecutor to CPU
48 namespace update_executor {
49 
50 template <typename Input, typename Update, typename Output,
51           scatter_nd_op::UpdateOp OP>
52 class UpdateExecutor {
53  public:
54   EIGEN_STRONG_INLINE static void Execute(Input value, Update update,
55                                           Output output);
56 };
57 
58 template <typename Input, typename Update, typename Output>
59 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ASSIGN> {
60  public:
Execute(Input,Update update,Output output)61   EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
62                                           Output output) {
63     output = update;
64   }
65 };
66 
67 template <typename Input, typename Update, typename Output>
68 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
69  public:
Execute(Input,Update update,Output output)70   EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
71                                           Output output) {
72     output += update;
73   }
74 };
75 
76 template <typename Input, typename Update, typename Output>
77 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
78  public:
Execute(Input,Update update,Output output)79   EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
80                                           Output output) {
81     output -= update;
82   }
83 };
84 
85 }  // namespace update_executor
86 
87 namespace functor {
88 
89 // Implementation of update functor for CPU.
90 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
91 struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
92   Index operator()(
93       const CPUDevice& d, const Index slice_size,
94       const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
95       typename TTypes<T, 2>::Tensor Tparams,
96       typename TTypes<Index, 2>::ConstTensor Tindices,
97       typename TTypes<T, 2>::ConstTensor Tupdates,
98       typename TTypes<T, 2>::Tensor Toutput) {
99     // error_loc is -1 if there's no out-of-bounds index,
100     // otherwise it is the location of an OOB index in Tindices.
101     Index error_loc = -1;
102 
103     const Eigen::DenseIndex batch_size = Tindices.dimension(0);
104 
105     Index batch_strides[IXDIM];
106     for (int dim = IXDIM - 1; dim >= 0; --dim) {
107       if (dim == IXDIM - 1) {
108         batch_strides[dim] = 1;
109       } else {
110         batch_strides[dim] =
111             batch_strides[dim + 1] * output_shape_prefix[dim + 1];
112       }
113     }
114 
115     for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
116       Index i = 0;
117       bool out_of_bounds = false;
118       for (int dim = 0; dim < IXDIM; ++dim) {
119         const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
120         out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
121         i += ix_d * batch_strides[dim];
122       }
123       if (TF_PREDICT_FALSE(out_of_bounds)) {
124         error_loc = loc;
125         break;
126       } else {
127         auto input_chip = Toutput.template chip<0>(i);
128         auto output_chip = input_chip.device(d);
129         auto update_chip = Tupdates.template chip<0>(loc);
130         update_executor::UpdateExecutor<
131             decltype(input_chip), decltype(update_chip), decltype(output_chip),
132             OP>::Execute(input_chip, update_chip, output_chip);
133       }
134     }
135 
136     return error_loc;
137   }
138 };
139 
// Explicitly instantiates ScatterNdFunctor<CPUDevice, ...> for one
// (element type, index type, update op) triple at the IXDIM supplied by
// the including .cc file via CPU_PROVIDED_IXDIM.
#define REGISTER_SCATTER_ND_FULL(T, Index, op)                               \
  template Index                                                             \
  ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
      const CPUDevice& d, const Index slice_size,                            \
      const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>              \
          output_shape_prefix,                                               \
      typename TTypes<T, 2>::Tensor Tparams,                                 \
      typename TTypes<Index, 2>::ConstTensor Tindices,                       \
      typename TTypes<T, 2>::ConstTensor Tupdates,                           \
      typename TTypes<T, 2>::Tensor Toutput)

// Instantiates for both supported index types.
#define REGISTER_SCATTER_ND_INDEX(type, op)  \
  REGISTER_SCATTER_ND_FULL(type, int32, op); \
  REGISTER_SCATTER_ND_FULL(type, int64, op)

// ASSIGN works for every type (it only copies).
#define REGISTER_SCATTER_ND_UPDATE(type) \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);

// ADD/SUB need arithmetic, so they are registered more narrowly below.
#define REGISTER_SCATTER_ND_MATH(type)                           \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);

TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
// string supports ADD (concatenation via operator+=) but not SUB.
REGISTER_SCATTER_ND_INDEX(string, scatter_nd_op::UpdateOp::ADD);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
TF_CALL_bool(REGISTER_SCATTER_ND_MATH);
#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
170 
171 // Implementation of update functor for SYCL.
172 #ifdef TENSORFLOW_USE_SYCL
173 
174 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
175 struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
176   Index operator()(
177       const SYCLDevice& d, const Index slice_size,
178       const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
179       typename TTypes<T, 2>::Tensor Tparams,
180       typename TTypes<Index, 2>::ConstTensor Tindices,
181       typename TTypes<T, 2>::ConstTensor Tupdates,
182       typename TTypes<T, 2>::Tensor Toutput) {
183     // error_loc is -1 if there's no out-of-bounds index,
184     // otherwise it is the location of an OOB index in Tindices.
185     Index error_loc = -1;
186 
187     const Eigen::DenseIndex batch_size = Tindices.dimension(0);
188 
189     Index batch_strides[IXDIM];
190     for (int dim = IXDIM - 1; dim >= 0; --dim) {
191       if (dim == IXDIM - 1) {
192         batch_strides[dim] = 1;
193       } else {
194         batch_strides[dim] =
195             batch_strides[dim + 1] * output_shape_prefix[dim + 1];
196       }
197     }
198 
199     for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
200       Index i = 0;
201       bool out_of_bounds = false;
202       for (int dim = 0; dim < IXDIM; ++dim) {
203         const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
204         out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
205         i += ix_d * batch_strides[dim];
206       }
207       if (TF_PREDICT_FALSE(out_of_bounds)) {
208         error_loc = loc;
209         break;
210       } else {
211         auto input_chip = Toutput.template chip<0>(i);
212         auto output_chip = input_chip.device(d);
213         auto update_chip = Tupdates.template chip<0>(loc);
214         update_executor::UpdateExecutor<
215             decltype(input_chip), decltype(update_chip), decltype(output_chip),
216             OP>::Execute(input_chip, update_chip, output_chip);
217       }
218     }
219 
220     return error_loc;
221   }
222 };
223 
// Explicitly instantiates ScatterNdFunctor<SYCLDevice, ...> for one
// (element type, index type, update op) triple at the IXDIM supplied by
// the including .cc file via CPU_PROVIDED_IXDIM.
#define REGISTER_SCATTER_ND_FULL_SYCL(T, Index, op)                           \
  template Index                                                              \
  ScatterNdFunctor<SYCLDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
      const SYCLDevice& d, const Index slice_size,                            \
      const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>               \
          output_shape_prefix,                                                \
      typename TTypes<T, 2>::Tensor Tparams,                                  \
      typename TTypes<Index, 2>::ConstTensor Tindices,                        \
      typename TTypes<T, 2>::ConstTensor Tupdates,                            \
      typename TTypes<T, 2>::Tensor Toutput)

// Instantiates for both supported index types.
#define REGISTER_SCATTER_ND_INDEX_SYCL(type, op)  \
  REGISTER_SCATTER_ND_FULL_SYCL(type, int32, op); \
  REGISTER_SCATTER_ND_FULL_SYCL(type, int64, op)

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
  REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ASSIGN);

#define REGISTER_SCATTER_ND_MATH_SYCL(type)                           \
  REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL)
// int32 is not in the GPU number types list, so register it explicitly.
REGISTER_SCATTER_ND_UPDATE_SYCL(int32);
REGISTER_SCATTER_ND_MATH_SYCL(int32);

#undef REGISTER_SCATTER_ND_MATH_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#undef REGISTER_SCATTER_ND_INDEX_SYCL
#undef REGISTER_SCATTER_ND_FULL_SYCL
255 
256 #endif  // TENSORFLOW_USE_SYCL
257 
258 }  // namespace functor
259 
260 }  // namespace tensorflow
261 
262 #endif  // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
263