/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Check whether updates.shape = indices.shape + params.shape[1:]
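// For example, params.shape = [P0, D1, D2] and indices.shape = [N0, N1]
// require updates.shape = [N0, N1, D1, D2]; a scalar (0-D) updates tensor is
// also accepted and is broadcast to every indexed slice.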
static bool ValidShapes(const Tensor& params, const Tensor& updates,
                        const Tensor& indices) {
  if (updates.dims() == 0) return true;
  if (updates.dims() != indices.dims() + params.dims() - 1) return false;
  for (int d = 0; d < indices.dims(); d++) {
    if (updates.dim_size(d) != indices.dim_size(d)) {
      return false;
    }
  }
  for (int d = 1; d < params.dims(); d++) {
    if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
      return false;
    }
  }
  return true;
}

static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
                                 const Tensor& indices, const Tensor& updates) {
  OP_REQUIRES(c, params.IsInitialized(),
              errors::FailedPrecondition("Null ref for params"));
  OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
              errors::InvalidArgument("params must be at least 1-D, got shape ",
                                      params.shape().DebugString()));
  OP_REQUIRES(
      c, ValidShapes(params, updates, indices),
      errors::InvalidArgument("Must have updates.shape = indices.shape + "
                              "params.shape[1:] or updates.shape = [], got ",
                              "updates.shape ", updates.shape().DebugString(),
                              ", indices.shape ", indices.shape().DebugString(),
                              ", params.shape ", params.shape().DebugString()));
}

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
 public:
  // QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
  // etc. here. Should we have the framework do some sort of
  // integer promotion automatically, or should that be something
  // that users have to do explicitly with a conversion operator
  // in the graph?
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
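      // flat_outer_dims() presents params as a rank-2 view of shape
      // [params.dim_size(0), slice_size], so each index selects one row.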
      auto params_flat = params.flat_outer_dims<T>();

      if (TensorShapeUtils::IsScalar(updates.shape()) ||
          IsLegacyScalar(updates.shape())) {
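        // Scalar update: the single value is applied (per `op`) to every
        // indexed row of params.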
        const auto update = updates.scalar<T>();
        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      } else {
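        // Reshape updates to [N, slice_size] so that row i is applied to
        // params row indices(i).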
        auto updates_flat =
            updates.shaped<T, 2>({N, updates.NumElements() / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      }
    }
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
 public:
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
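      // The SYCL path copies the indices back to the host so they can be
      // read on the CPU when the scatter is launched.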
      auto index_size = indices.NumElements() * sizeof(Index);
      Tensor indices_host = Tensor(indices.dtype(), indices.shape());

      auto src_ptr = GetBase(&indices);
      auto dst_ptr = GetBase(&indices_host);

      c->eigen_sycl_device().memcpyDeviceToHost(
          dst_ptr, static_cast<const Index*>(src_ptr), index_size);

      auto indices_flat = indices_host.flat<Index>();
      auto params_flat = params.flat_outer_dims<T>();

      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctorSYCL<T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      } else {
        auto updates_flat =
            updates.shaped<T, 2>({N, updates.NumElements() / N});

        functor::ScatterFunctorSYCL<T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      }
    }
  }
};
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(Name(name)                                   \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          ScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);

#define REGISTER_SCATTER_MINMAX(type, dev)                                      \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);

#define REGISTER_SCATTER_UPDATE(type, dev)            \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
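
// For example, REGISTER_SCATTER_UPDATE(float, CPU) registers "ScatterUpdate"
// kernels for float params on the CPU device, once with int32 and once with
// int64 indices.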

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);

#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);

#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);

#endif  // GOOGLE_CUDA

// Registers SYCL kernels.
#if TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
  REGISTER_SCATTER_ARITHMETIC(type, SYCL);

#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL);

#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);

#undef REGISTER_SCATTER_ARITHMETIC_SYCL
#undef REGISTER_SCATTER_MINMAX_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_ARITHMETIC_GPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_MINMAX_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow