• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
17 #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
18 
19 // Functor definitions for ScatterND ops, must be compilable by nvcc.
20 
21 #define EIGEN_USE_THREADS
22 
23 #include <atomic>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 
27 #include "tensorflow/core/framework/bounds_check.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/kernels/fill_functor.h"
33 #include "tensorflow/core/kernels/scatter_nd_op.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/util.h"
37 
38 namespace tensorflow {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 
42 class OpKernelContext;
43 
44 // Specialization of UpdateExecutor to CPU
45 namespace update_executor {
46 
// Primary template for applying one scatter update to one output slice.
//
// Template parameters:
//   T      - type of the device object handed to Execute() (note: this is the
//            device type, not the tensor element type).
//   Input  - type of the current-value expression.
//   Update - type of the update expression.
//   Output - type of the destination expression.
//   OP     - the scatter_nd_op::UpdateOp selecting how `update` is combined
//            into `output`.
//
// Execute() is only declared here; the behaviour for ASSIGN, ADD, SUB, MIN and
// MAX is supplied by the partial specializations that follow in this
// namespace.
template <typename T, typename Input, typename Update, typename Output,
          scatter_nd_op::UpdateOp OP>
class UpdateExecutor {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input value,
                                          Update update, Output output);
};
54 
55 template <typename T, typename Input, typename Update, typename Output>
56 class UpdateExecutor<T, Input, Update, Output,
57                      scatter_nd_op::UpdateOp::ASSIGN> {
58  public:
Execute(const T & device,Input,Update update,Output output)59   EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
60                                           Update update, Output output) {
61     output.device(device) = update;
62   }
63 };
64 
65 template <typename T, typename Input, typename Update, typename Output>
66 class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
67  public:
Execute(const T & device,Input,Update update,Output output)68   EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
69                                           Update update, Output output) {
70     output.device(device) += update;
71   }
72 };
73 
74 template <typename T, typename Input, typename Update, typename Output>
75 class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
76  public:
Execute(const T & device,Input,Update update,Output output)77   EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
78                                           Update update, Output output) {
79     output.device(device) -= update;
80   }
81 };
82 
83 template <typename T, typename Input, typename Update, typename Output>
84 class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MIN> {
85  public:
Execute(const T & device,Input,Update update,Output output)86   EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
87                                           Update update, Output output) {
88     output.device(device) = output.cwiseMin(update);
89   }
90 };
91 
92 template <typename T, typename Input, typename Update, typename Output>
93 class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
94  public:
Execute(const T & device,Input,Update update,Output output)95   EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
96                                           Update update, Output output) {
97     output.device(device) = output.cwiseMax(update);
98   }
99 };
100 
101 }  // namespace update_executor
102 
103 namespace functor {
104 
105 // Implementation of update functor for CPU.
106 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
107 struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
108   Index operator()(
109       const CPUDevice& d, const Index slice_size,
110       const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
111       typename TTypes<T, 2>::Tensor Tparams,
112       typename TTypes<Index, 2>::ConstTensor Tindices,
113       typename TTypes<T, 2>::ConstTensor Tupdates,
114       typename TTypes<T, 2>::Tensor Toutput) {
115     // error_loc is -1 if there's no out-of-bounds index,
116     // otherwise it is the location of an OOB index in Tindices.
117     Index error_loc = -1;
118 
119     const Eigen::DenseIndex batch_size = Tindices.dimension(0);
120 
121     Index batch_strides[IXDIM];
122     for (int dim = IXDIM - 1; dim >= 0; --dim) {
123       if (dim == IXDIM - 1) {
124         batch_strides[dim] = 1;
125       } else {
126         batch_strides[dim] =
127             batch_strides[dim + 1] * output_shape_prefix[dim + 1];
128       }
129     }
130 
131     for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
132       Index i = 0;
133       bool out_of_bounds = false;
134       for (int dim = 0; dim < IXDIM; ++dim) {
135         const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
136         out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
137         i += ix_d * batch_strides[dim];
138       }
139       if (TF_PREDICT_FALSE(out_of_bounds)) {
140         error_loc = loc;
141         break;
142       } else {
143         auto input_chip = Toutput.template chip<0>(i);
144         auto output_chip = input_chip;
145         auto update_chip = Tupdates.template chip<0>(loc);
146         update_executor::UpdateExecutor<
147             CPUDevice, decltype(input_chip), decltype(update_chip),
148             decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
149                                                 output_chip);
150       }
151     }
152 
153     return error_loc;
154   }
155 };
156 
// Explicitly instantiates ScatterNdFunctor<CPUDevice, ...>::operator() for one
// (element type, index type, update op) combination at rank
// CPU_PROVIDED_IXDIM. CPU_PROVIDED_IXDIM is not defined in this header; it is
// presumably supplied by whatever includes it -- TODO confirm.
#define REGISTER_SCATTER_ND_FULL(T, Index, op)                               \
  template Index                                                             \
  ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
      const CPUDevice& d, const Index slice_size,                            \
      const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>              \
          output_shape_prefix,                                               \
      typename TTypes<T, 2>::Tensor Tparams,                                 \
      typename TTypes<Index, 2>::ConstTensor Tindices,                       \
      typename TTypes<T, 2>::ConstTensor Tupdates,                           \
      typename TTypes<T, 2>::Tensor Toutput)

// Instantiates the functor for both supported index types (int32 and int64).
#define REGISTER_SCATTER_ND_INDEX(type, op)  \
  REGISTER_SCATTER_ND_FULL(type, int32, op); \
  REGISTER_SCATTER_ND_FULL(type, int64, op)

// Instantiates the ASSIGN op for `type`.
#define REGISTER_SCATTER_ND_UPDATE(type) \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);

// Instantiates the ADD and SUB ops for `type`.
#define REGISTER_SCATTER_ND_MATH(type)                           \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);

// Instantiates the MAX and MIN ops for `type`.
#define REGISTER_SCATTER_ND_MIN_MAX(type)                        \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN);

// Instantiations. Each TF_CALL_* macro (defined outside this file) applies
// the given registration macro to its own set of element types.
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
// tstring is given ADD explicitly here, without SUB.
REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX);
TF_CALL_bool(REGISTER_SCATTER_ND_MATH);

// The registration helpers are private to this header; undefine them so they
// do not leak into code that follows the include.
#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_MIN_MAX
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
194 }  // namespace functor
195 
196 }  // namespace tensorflow
197 
198 #endif  // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
199