/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Check whether updates.shape = indices.shape + params.shape[1:]
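// For example (shapes illustrative): with params.shape = [P, D1, D2] and
// indices.shape = [I0, I1], a non-scalar update must have
// updates.shape = [I0, I1, D1, D2].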
static bool ValidShapes(const Tensor& params, const Tensor& updates,
                        const Tensor& indices) {
  if (updates.dims() == 0) return true;
  if (updates.dims() != indices.dims() + params.dims() - 1) return false;
  for (int d = 0; d < indices.dims(); d++) {
    if (updates.dim_size(d) != indices.dim_size(d)) {
      return false;
    }
  }
  for (int d = 1; d < params.dims(); d++) {
    if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
      return false;
    }
  }
  return true;
}

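// Checks that params is an initialized ref of rank >= 1 and that the
// params/indices/updates shapes are mutually consistent.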
static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
                                 const Tensor& indices, const Tensor& updates) {
  OP_REQUIRES(c, params.IsInitialized(),
              errors::FailedPrecondition("Null ref for params"));
  OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
              errors::InvalidArgument("params must be at least 1-D, got shape ",
                                      params.shape().DebugString()));
  OP_REQUIRES(
      c, ValidShapes(params, updates, indices),
      errors::InvalidArgument("Must have updates.shape = indices.shape + "
                              "params.shape[1:] or updates.shape = [], got ",
                              "updates.shape ", updates.shape().DebugString(),
                              ", indices.shape ", indices.shape().DebugString(),
                              ", params.shape ", params.shape().DebugString()));
}

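// Kernel behind the ref-variable scatter ops (ScatterUpdate, ScatterAdd,
// ScatterSub, ...): applies `updates` to the rows of the `params` ref tensor
// selected by `indices`, holding the variable's mutex while updating when the
// "use_locking" attr is set.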
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
 public:
  //   QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
  //   etc. here.  Should we have the framework do some sort of
  //   integer promotion automatically, or should that be something
  //   that users have to do explicitly with a conversion operator
  //   in the graph?
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params.flat_outer_dims<T>();

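      // A scalar `updates` is broadcast to every indexed row; otherwise the
      // updates are viewed as an [N, slice_size] matrix and applied row by
      // row.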
      if (TensorShapeUtils::IsScalar(updates.shape()) ||
          IsLegacyScalar(updates.shape())) {
        const auto update = updates.scalar<T>();
        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      } else {
        auto updates_flat =
            updates.shaped<T, 2>({N, updates.NumElements() / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      }
    }
  }
};

#ifdef TENSORFLOW_USE_SYCL
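// Partial specialization for SYCL devices: same control flow as the generic
// kernel, except that the indices are copied back to the host before
// dispatching to the SYCL scatter functors.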
template <typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
 public:
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
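      // Copy the indices from the SYCL device to a host-side tensor so the
      // functors below can read them on the CPU.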
      auto index_size = indices.NumElements() * sizeof(Index);
      Tensor indices_host = Tensor(indices.dtype(), indices.shape());

      auto src_ptr = GetBase(&indices);
      auto dst_ptr = GetBase(&indices_host);

      c->eigen_sycl_device().memcpyDeviceToHost(
          dst_ptr, static_cast<const Index*>(src_ptr), index_size);

      auto indices_flat = indices_host.flat<Index>();
      auto params_flat = params.flat_outer_dims<T>();

      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctorSYCL<T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      } else {
        auto updates_flat =
            updates.shaped<T, 2>({N, updates.NumElements() / N});

        functor::ScatterFunctorSYCL<T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params.dim_size(0), ")"));
      }
    }
  }
};
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(Name(name)                                   \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          ScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
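// For example, REGISTER_SCATTER_KERNEL(float, CPU, "ScatterAdd",
// scatter_op::UpdateOp::ADD) registers
// ScatterUpdateOp<CPUDevice, float, int32, scatter_op::UpdateOp::ADD> (plus
// the int64-index variant) as the "ScatterAdd" kernel on DEVICE_CPU.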

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);

#define REGISTER_SCATTER_MINMAX(type, dev)                                     \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);

#define REGISTER_SCATTER_UPDATE(type, dev)            \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);

#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);

#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);

#endif  // GOOGLE_CUDA

// Registers SYCL kernels.
#if TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
  REGISTER_SCATTER_ARITHMETIC(type, SYCL);

#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL);

#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);

#undef REGISTER_SCATTER_ARITHMETIC_SYCL
#undef REGISTER_SCATTER_MINMAX_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_ARITHMETIC_GPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_MINMAX_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow