/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_

// Functor definitions for ScatterND ops, must be compilable by nvcc.

#define EIGEN_USE_THREADS

#include <atomic>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class OpKernelContext;

// Specialization of UpdateExecutor to CPU
namespace update_executor {

template <typename T, typename Input, typename Update, typename Output,
          scatter_nd_op::UpdateOp OP>
class UpdateExecutor {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input value,
                                          Update update, Output output);
};

template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output,
                     scatter_nd_op::UpdateOp::ASSIGN> {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
                                          Update update, Output output) {
    output.device(device) = update;
  }
};

template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
                                          Update update, Output output) {
    output.device(device) += update;
  }
};

template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
                                          Update update, Output output) {
    output.device(device) -= update;
  }
};

template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MIN> {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
                                          Update update, Output output) {
    output.device(device) = output.cwiseMin(update);
  }
};

template <typename T, typename Input, typename Update, typename Output>
class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
 public:
  EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */,
                                          Update update, Output output) {
    output.device(device) = output.cwiseMax(update);
  }
};

}  // namespace update_executor
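// Illustrative sketch (not part of the original source): the functor below
// picks one of the UpdateExecutor specializations above purely at compile
// time through the OP template parameter. An ADD update applied to a single
// output slice reduces to roughly
//
//   auto out_chip = Toutput.template chip<0>(i);     // target slice
//   auto upd_chip = Tupdates.template chip<0>(loc);  // update slice
//   update_executor::UpdateExecutor<
//       CPUDevice, decltype(out_chip), decltype(upd_chip), decltype(out_chip),
//       scatter_nd_op::UpdateOp::ADD>::Execute(d, out_chip, upd_chip,
//                                              out_chip);
//
// which is equivalent to `out_chip.device(d) += upd_chip;`. The names
// `out_chip` and `upd_chip` are placeholders mirroring the functor body below.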
namespace functor {

// Implementation of update functor for CPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
  Index operator()(
      const CPUDevice& d, const Index slice_size,
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
      typename TTypes<T, 2>::Tensor Tparams,
      typename TTypes<Index, 2>::ConstTensor Tindices,
      typename TTypes<T, 2>::ConstTensor Tupdates,
      typename TTypes<T, 2>::Tensor Toutput) {
    // error_loc is -1 if there's no out-of-bounds index,
    // otherwise it is the location of an OOB index in Tindices.
    Index error_loc = -1;

    const Eigen::DenseIndex batch_size = Tindices.dimension(0);

    // Row-major strides over the leading IXDIM output dimensions, used to
    // flatten each index row into a single slice offset into Toutput.
    Index batch_strides[IXDIM];
    for (int dim = IXDIM - 1; dim >= 0; --dim) {
      if (dim == IXDIM - 1) {
        batch_strides[dim] = 1;
      } else {
        batch_strides[dim] =
            batch_strides[dim + 1] * output_shape_prefix[dim + 1];
      }
    }

    for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
      Index i = 0;
      bool out_of_bounds = false;
      for (int dim = 0; dim < IXDIM; ++dim) {
        const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
        out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
        i += ix_d * batch_strides[dim];
      }
      if (TF_PREDICT_FALSE(out_of_bounds)) {
        error_loc = loc;
        break;
      } else {
        auto input_chip = Toutput.template chip<0>(i);
        auto output_chip = input_chip;
        auto update_chip = Tupdates.template chip<0>(loc);
        update_executor::UpdateExecutor<
            CPUDevice, decltype(input_chip), decltype(update_chip),
            decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
                                                output_chip);
      }
    }

    return error_loc;
  }
};
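// Worked example of the index flattening above (illustrative only): with
// IXDIM == 2 and output_shape_prefix == {4, 5}, the loop computes
// batch_strides == {5, 1}, so an index row Tindices(loc, :) == (2, 3) selects
// the flattened output slice i = 2 * 5 + 3 * 1 = 13. Any coordinate that
// fails FastBoundsCheck (e.g. a 5 or larger in the second position) makes
// operator() stop and return `loc` as the error location instead of applying
// the update.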
#define REGISTER_SCATTER_ND_FULL(T, Index, op)                               \
  template Index                                                             \
  ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
      const CPUDevice& d, const Index slice_size,                            \
      const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>              \
          output_shape_prefix,                                               \
      typename TTypes<T, 2>::Tensor Tparams,                                 \
      typename TTypes<Index, 2>::ConstTensor Tindices,                       \
      typename TTypes<T, 2>::ConstTensor Tupdates,                           \
      typename TTypes<T, 2>::Tensor Toutput)

#define REGISTER_SCATTER_ND_INDEX(type, op)  \
  REGISTER_SCATTER_ND_FULL(type, int32, op); \
  REGISTER_SCATTER_ND_FULL(type, int64, op)

#define REGISTER_SCATTER_ND_UPDATE(type) \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);

#define REGISTER_SCATTER_ND_MATH(type)                           \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND_MIN_MAX(type)                        \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \
  REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN);

TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX);
TF_CALL_bool(REGISTER_SCATTER_ND_MATH);

#undef REGISTER_SCATTER_ND_MATH
#undef REGISTER_SCATTER_ND_MIN_MAX
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL
}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_