/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T, scatter_nd_op::UpdateOp Op>
struct LeftUpdate {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val);
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::ASSIGN> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    *out = val;
  }
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::ADD> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    GpuAtomicAdd(out, val);
  }
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    GpuAtomicSub(out, val);
  }
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MAX> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    GpuAtomicMax(out, val);
  }
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::MIN> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    GpuAtomicMin(out, val);
  }
};
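
// Illustration: each LeftUpdate specialization above is a stateless functor,
// so a single scatter update is applied as, for example,
//
//   LeftUpdate<float, scatter_nd_op::UpdateOp::ADD>()(out_ptr, 2.0f);
//
// which atomically adds 2.0f to *out_ptr on the device (out_ptr being a
// hypothetical float* into the output buffer).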

// Specializations for std::complex, updating real and imaginary part
// individually. Even though this is not an atomic op anymore, it is safe
// because there is only one type of op per kernel.
template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
      std::complex<T>* out, const std::complex<T>& val) {
    T* ptr = reinterpret_cast<T*>(out);
    GpuAtomicAdd(ptr, val.real());
    GpuAtomicAdd(ptr + 1, val.imag());
  }
};

template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
      std::complex<T>* out, const std::complex<T>& val) {
    T* ptr = reinterpret_cast<T*>(out);
    GpuAtomicSub(ptr, val.real());
    GpuAtomicSub(ptr + 1, val.imag());
  }
};
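
// Illustration of the layout assumption behind the complex specializations
// above: std::complex<T> stores the real part followed by the imaginary
// part, so
//
//   std::complex<float> z(1.f, 2.f);
//   float* parts = reinterpret_cast<float*>(&z);
//   // parts[0] == 1.f (real), parts[1] == 2.f (imag)
//
// which is why the atomics on ptr and ptr + 1 update the two components of
// a complex accumulator independently.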

}  // namespace

template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
__global__ void ScatterNdOpKernel(
    const Index* indices, const T* updates, T* out,
    const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
    const Eigen::array<int64, IXDIM> batch_strides, const int64 num_indices,
    const Index slice_size) {
  auto update = LeftUpdate<T, op>();

  GPU_1D_KERNEL_LOOP(index, num_indices) {
    Index i = 0;
    bool out_of_bounds = false;
#pragma unroll
    for (int dim = 0; dim < IXDIM; ++dim) {
      int offset = (IXDIM * index + dim);
      const Index ix_d = internal::SubtleMustCopy(ldg(indices + offset));
      out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
      i += ix_d * batch_strides[dim] * slice_size;
    }
    if (!out_of_bounds) {
#pragma unroll
      for (int si = 0; si < slice_size; si++) {
        update(out + i + si, ldg(updates + (index * slice_size + si)));
      }
    }
  }
}
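
// Illustration of the index arithmetic in ScatterNdOpKernel above, using
// hypothetical values: with IXDIM == 2, output_shape_prefix == {4, 5},
// slice_size == 3 and batch_strides == {5, 1}, an index row {2, 3} gives
//
//   i = 2 * 5 * 3 + 3 * 1 * 3 = 39,
//
// so the corresponding update slice is written to out[39..41], the row-major
// flat offset of element (2, 3) in a [4, 5, 3] output.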

namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
  Index operator()(
      const GPUDevice& d, const Index slice_size,
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
      typename TTypes<T, 2>::Tensor Tparams,
      typename TTypes<Index, 2>::ConstTensor Tindices,
      typename TTypes<T, 2>::ConstTensor Tupdates,
      typename TTypes<T, 2>::Tensor Toutput) {
    // TODO(ebrevdo): The performance of this for small indices (large
    // slices) is poor. Write a kernel whose splitting is
    // independent of the slice size. Same for CPU. See the
    // gather_nd kernel for an example.

    const Eigen::DenseIndex batch_size = Tindices.dimension(0);

    // Index batch_strides[IXDIM];
    Eigen::array<int64, IXDIM> batch_strides;
    for (int dim = IXDIM - 1; dim >= 0; --dim) {
      if (dim == IXDIM - 1) {
        batch_strides[dim] = 1;
      } else {
        batch_strides[dim] =
            batch_strides[dim + 1] * output_shape_prefix[dim + 1];
      }
    }
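
    // Example of the stride computation above: with IXDIM == 3 and
    // output_shape_prefix == {4, 5, 6}, the loop produces
    // batch_strides == {30, 6, 1}, i.e. row-major strides over the index
    // dimensions measured in slices rather than scalar elements.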

    GpuLaunchConfig config = GetGpuLaunchConfig(Toutput.size(), d);

    TF_CHECK_OK(GpuLaunchKernel(ScatterNdOpKernel<T, Index, op, IXDIM>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), Tindices.data(), Tupdates.data(),
                                Toutput.data(), output_shape_prefix,
                                batch_strides, batch_size, slice_size));

    return -1;
  }
};
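
// Sketch (based on the operator() signature above, not on the actual call
// site) of how the op-side code in scatter_nd_op.cc is expected to invoke
// this functor:
//
//   functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> scatter;
//   scatter(device, slice_size, output_shape_prefix, Tparams, Tindices,
//           Tupdates, Toutput);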

}  // namespace functor

#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
  template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB);

#define DECLARE_GPU_SPECS_INDEX_MINMAX(T, Index)                     \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MAX) \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::MIN);

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

#define DECLARE_GPU_SPECS_MINMAX(T)         \
  DECLARE_GPU_SPECS_INDEX_MINMAX(T, int32); \
  DECLARE_GPU_SPECS_INDEX_MINMAX(T, int64)
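
// For example, DECLARE_GPU_SPECS(float) below expands to explicit
// instantiations of functor::ScatterNdFunctor<GPUDevice, float, Index, op,
// IXDIM> for Index in {int32, int64}, op in {ASSIGN, ADD, SUB}, and IXDIM
// from 1 to 7; DECLARE_GPU_SPECS_MINMAX(float) does the same for the MAX and
// MIN ops.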

TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int32(DECLARE_GPU_SPECS_MINMAX);
TF_CALL_int64(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS_MINMAX);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS_MINMAX);
TF_CALL_COMPLEX_TYPES(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_MINMAX
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_MINMAX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM