1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ 17 #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ 18 19 // Functor definitions for ScatterND ops, must be compilable by nvcc. 20 21 #define EIGEN_USE_THREADS 22 23 #include <atomic> 24 25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 27 #include "tensorflow/core/framework/bounds_check.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/register_types.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/kernels/fill_functor.h" 33 #include "tensorflow/core/kernels/scatter_nd_op.h" 34 #include "tensorflow/core/platform/mutex.h" 35 #include "tensorflow/core/platform/types.h" 36 #include "tensorflow/core/util/util.h" 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 #ifdef TENSORFLOW_USE_SYCL 42 typedef Eigen::SyclDevice SYCLDevice; 43 #endif // TENSORFLOW_USE_SYCL 44 45 class OpKernelContext; 46 47 // Specialization of UpdateExecutor to CPU 48 namespace update_executor { 49 50 template <typename Input, typename Update, typename Output, 51 scatter_nd_op::UpdateOp OP> 52 class UpdateExecutor { 53 public: 54 EIGEN_STRONG_INLINE static void Execute(Input value, Update update, 55 Output output); 56 }; 57 58 template <typename Input, typename Update, typename Output> 59 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ASSIGN> { 60 public: Execute(Input,Update update,Output output)61 EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, 62 Output output) { 63 output = update; 64 } 65 }; 66 67 template <typename Input, typename Update, typename Output> 68 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ADD> { 69 public: Execute(Input,Update update,Output output)70 EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, 71 Output output) { 72 output += update; 73 } 74 }; 75 76 template <typename Input, typename Update, typename Output> 77 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::SUB> { 78 public: Execute(Input,Update update,Output output)79 EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, 80 Output output) { 81 output -= update; 82 } 83 }; 84 85 } // namespace update_executor 86 87 namespace functor { 88 89 // Implementation of update functor for CPU. 90 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM> 91 struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> { 92 Index operator()( 93 const CPUDevice& d, const Index slice_size, 94 const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, 95 typename TTypes<T, 2>::Tensor Tparams, 96 typename TTypes<Index, 2>::ConstTensor Tindices, 97 typename TTypes<T, 2>::ConstTensor Tupdates, 98 typename TTypes<T, 2>::Tensor Toutput) { 99 // error_loc is -1 if there's no out-of-bounds index, 100 // otherwise it is the location of an OOB index in Tindices. 101 Index error_loc = -1; 102 103 const Eigen::DenseIndex batch_size = Tindices.dimension(0); 104 105 Index batch_strides[IXDIM]; 106 for (int dim = IXDIM - 1; dim >= 0; --dim) { 107 if (dim == IXDIM - 1) { 108 batch_strides[dim] = 1; 109 } else { 110 batch_strides[dim] = 111 batch_strides[dim + 1] * output_shape_prefix[dim + 1]; 112 } 113 } 114 115 for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { 116 Index i = 0; 117 bool out_of_bounds = false; 118 for (int dim = 0; dim < IXDIM; ++dim) { 119 const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); 120 out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); 121 i += ix_d * batch_strides[dim]; 122 } 123 if (TF_PREDICT_FALSE(out_of_bounds)) { 124 error_loc = loc; 125 break; 126 } else { 127 auto input_chip = Toutput.template chip<0>(i); 128 auto output_chip = input_chip.device(d); 129 auto update_chip = Tupdates.template chip<0>(loc); 130 update_executor::UpdateExecutor< 131 decltype(input_chip), decltype(update_chip), decltype(output_chip), 132 OP>::Execute(input_chip, update_chip, output_chip); 133 } 134 } 135 136 return error_loc; 137 } 138 }; 139 140 #define REGISTER_SCATTER_ND_FULL(T, Index, op) \ 141 template Index \ 142 ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \ 143 const CPUDevice& d, const Index slice_size, \ 144 const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \ 145 output_shape_prefix, \ 146 typename TTypes<T, 2>::Tensor Tparams, \ 147 typename TTypes<Index, 2>::ConstTensor Tindices, \ 148 typename TTypes<T, 2>::ConstTensor Tupdates, \ 149 typename TTypes<T, 2>::Tensor Toutput) 150 151 #define REGISTER_SCATTER_ND_INDEX(type, op) \ 152 REGISTER_SCATTER_ND_FULL(type, int32, op); \ 153 REGISTER_SCATTER_ND_FULL(type, int64, op) 154 155 #define REGISTER_SCATTER_ND_UPDATE(type) \ 156 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN); 157 158 #define REGISTER_SCATTER_ND_MATH(type) \ 159 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \ 160 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); 161 162 TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE); 163 REGISTER_SCATTER_ND_INDEX(string, scatter_nd_op::UpdateOp::ADD); 164 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH); 165 TF_CALL_bool(REGISTER_SCATTER_ND_MATH); 166 #undef REGISTER_SCATTER_ND_MATH 167 #undef REGISTER_SCATTER_ND_UPDATE 168 #undef REGISTER_SCATTER_ND_INDEX 169 #undef REGISTER_SCATTER_ND_FULL 170 171 // Implementation of update functor for SYCL. 172 #ifdef TENSORFLOW_USE_SYCL 173 174 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM> 175 struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> { 176 Index operator()( 177 const SYCLDevice& d, const Index slice_size, 178 const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, 179 typename TTypes<T, 2>::Tensor Tparams, 180 typename TTypes<Index, 2>::ConstTensor Tindices, 181 typename TTypes<T, 2>::ConstTensor Tupdates, 182 typename TTypes<T, 2>::Tensor Toutput) { 183 // error_loc is -1 if there's no out-of-bounds index, 184 // otherwise it is the location of an OOB index in Tindices. 185 Index error_loc = -1; 186 187 const Eigen::DenseIndex batch_size = Tindices.dimension(0); 188 189 Index batch_strides[IXDIM]; 190 for (int dim = IXDIM - 1; dim >= 0; --dim) { 191 if (dim == IXDIM - 1) { 192 batch_strides[dim] = 1; 193 } else { 194 batch_strides[dim] = 195 batch_strides[dim + 1] * output_shape_prefix[dim + 1]; 196 } 197 } 198 199 for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { 200 Index i = 0; 201 bool out_of_bounds = false; 202 for (int dim = 0; dim < IXDIM; ++dim) { 203 const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); 204 out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); 205 i += ix_d * batch_strides[dim]; 206 } 207 if (TF_PREDICT_FALSE(out_of_bounds)) { 208 error_loc = loc; 209 break; 210 } else { 211 auto input_chip = Toutput.template chip<0>(i); 212 auto output_chip = input_chip.device(d); 213 auto update_chip = Tupdates.template chip<0>(loc); 214 update_executor::UpdateExecutor< 215 decltype(input_chip), decltype(update_chip), decltype(output_chip), 216 OP>::Execute(input_chip, update_chip, output_chip); 217 } 218 } 219 220 return error_loc; 221 } 222 }; 223 224 #define REGISTER_SCATTER_ND_FULL_SYCL(T, Index, op) \ 225 template Index \ 226 ScatterNdFunctor<SYCLDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \ 227 const SYCLDevice& d, const Index slice_size, \ 228 const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \ 229 output_shape_prefix, \ 230 typename TTypes<T, 2>::Tensor Tparams, \ 231 typename TTypes<Index, 2>::ConstTensor Tindices, \ 232 typename TTypes<T, 2>::ConstTensor Tupdates, \ 233 typename TTypes<T, 2>::Tensor Toutput) 234 235 #define REGISTER_SCATTER_ND_INDEX_SYCL(type, op) \ 236 REGISTER_SCATTER_ND_FULL_SYCL(type, int32, op); \ 237 REGISTER_SCATTER_ND_FULL_SYCL(type, int64, op) 238 239 #define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \ 240 REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ASSIGN); 241 242 #define REGISTER_SCATTER_ND_MATH_SYCL(type) \ 243 REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \ 244 REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB); 245 246 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL) 247 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL) 248 REGISTER_SCATTER_ND_UPDATE_SYCL(int32); 249 REGISTER_SCATTER_ND_MATH_SYCL(int32); 250 251 #undef REGISTER_SCATTER_ND_MATH_SYCL 252 #undef REGISTER_SCATTER_ND_UPDATE_SYCL 253 #undef REGISTER_SCATTER_ND_INDEX_SYCL 254 #undef REGISTER_SCATTER_ND_FULL_SYCL 255 256 #endif // TENSORFLOW_USE_SYCL 257 258 } // namespace functor 259 260 } // namespace tensorflow 261 262 #endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ 263