/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_

#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace scatter_op {

enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };

namespace internal {

template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p.setConstant(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + static_cast<Update>(-u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p * u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p / u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MIN> {
  // This method requires that Params and Update are tensor types.
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMin(u);
  }
  // Same thing, but for Update being a scalar type.
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMin(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MAX> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMax(u);
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMax(u);
  }
};

}  // namespace internal
}  // namespace scatter_op
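
// Assign<Op> maps each UpdateOp to the corresponding Eigen expression at
// compile time. A minimal usage sketch (the names `params`, `updates`,
// `index`, and `i` are illustrative, mirroring the functors below):
//
//   // Accumulate row i of updates into row index of params.
//   scatter_op::internal::Assign<scatter_op::UpdateOp::ADD>::Run(
//       params.template chip<0>(index), updates.template chip<0>(i));
//
//   // Broadcast a single scalar value into row index of params.
//   scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
//       params.template chip<0>(index), update());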

namespace functor {
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};
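
// Usage sketch (illustrative only; the kernel-side names `params`, `updates`,
// and `indices` are assumptions): a scatter kernel's Compute() instantiates
// the functor and treats a non-negative return value as the position of an
// out-of-range index, e.g.
//
//   functor::ScatterFunctor<CPUDevice, T, Index, scatter_op::UpdateOp::ADD>
//       scatter_functor;
//   const Index bad_i = scatter_functor(c, c->eigen_device<CPUDevice>(),
//                                       params.matrix<T>(), updates.matrix<T>(),
//                                       indices.flat<Index>());
//   OP_REQUIRES(c, bad_i < 0,
//               errors::InvalidArgument("indices[", bad_i, "] is out of range"));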

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index ParallelExecute(OpKernelContext* c, const Device& d,
                        typename TTypes<T>::Matrix params,
                        typename TTypes<T>::ConstMatrix updates,
                        typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index kMaxLocks = 1024;
    const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
    // To reduce the number of locks and the memory usage, we divide the whole
    // index space into kMaxLocks regions with each lock serializing access to
    // a region.
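    // For example, with limit = 10000 rows and kMaxLocks = 1024,
    // entries_per_lock = (10000 + 1023) / 1024 = 10, so rows 0..9 share
    // lock 0, rows 10..19 share lock 1, and so on.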
    mutex accessed[kMaxLocks];
    std::atomic<Index> bad_index(-1);
    auto ParallelScatter = [&](Index start, Index end) {
      for (Index i = start; i < end; ++i) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) {
          bad_index = i;
          return;
        }
        const Index lock_id = index / entries_per_lock;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        {
          mutex_lock l(accessed[lock_id]);
          scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                                updates.template chip<0>(i));
        }
      }
    };
    const float kMovingCost = 2.5f;
    float shard_cost = kMovingCost * params.dimension(1);
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost,
          ParallelScatter);  // TODO: Come up with a good cost estimate.
    return bad_index;
  }
  Index SerialExecute(OpKernelContext* c, const Device& d,
                      typename TTypes<T>::Matrix params,
                      typename TTypes<T>::ConstMatrix updates,
                      typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; ++i) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in
      // between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }

  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
#ifdef PLATFORM_GOOGLE
    // The parallel version is significantly slower internally. Only call the
    // serial version for now.
    // TODO(penporn): Avoid locking in parallelization (sort beforehand).
    return SerialExecute(c, d, params, updates, indices);
#else
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index min_n_threshold = 1024;
    const Index ser_par_ratio = 10000;
    // For parallelizing the updates, duplicate entries need to be handled
    // correctly: multiple updates to the same index have to be serialized.
    // This can lead to lock contention, which may nullify the benefits of
    // parallelization. Assuming a uniform random distribution of the indices,
    // we use a rough heuristic to decide whether the updates execute serially
    // or in parallel. Also, if N is small, the overhead of parallel execution
    // outweighs its benefits, hence the additional check on N.
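    // For example, N = 500 updates takes the serial path via the first
    // condition (N < 1024); N = 100000000 updates into a table of
    // limit = 1000 rows gives N / limit = 100000 > ser_par_ratio, so heavy
    // index duplication is expected and the serial path is chosen as well.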
    const bool execute_serial =
        ((N < min_n_threshold) || ((N / limit) > ser_par_ratio));
    if (execute_serial)
      return SerialExecute(c, d, params, updates, indices);
    else
      return ParallelExecute(c, d, params, updates, indices);
#endif  // PLATFORM_GOOGLE
  }
};

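// Variant elements are type-erased values that may own heap storage, so they
// are copied element by element through operator= rather than with a raw
// memmove. Only the ASSIGN specializations below are provided for Variant.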
template <typename Device, typename Index>
struct ScatterFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   typename TTypes<Variant>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    DCHECK_EQ(N, updates.dimension(0));
    DCHECK_EQ(cols, updates.dimension(1));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      for (int j = 0; j < cols; ++j) {
        const Variant& to_scatter = updates(i, j);
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};

template <typename Index>
struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};

template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
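    // Fast path: for element types other than tstring, each selected row is
    // copied with a single memmove; tstring falls back to element-wise
    // assignment below so string buffers are copied correctly rather than
    // byte-copied.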
    if (!std::is_same<T, tstring>::value) {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices);
};
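
// Usage sketch (illustrative only; the kernel-side names `params`, `update`,
// and `indices` are assumptions): the scalar variant broadcasts one update
// value into every selected row, e.g.
//
//   functor::ScatterScalarFunctor<CPUDevice, T, Index,
//                                 scatter_op::UpdateOp::MIN> scatter_functor;
//   const Index bad_i = scatter_functor(c, c->eigen_device<CPUDevice>(),
//                                       params.matrix<T>(), update.scalar<T>(),
//                                       indices.flat<Index>());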

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<op>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename Device, typename Index>
struct ScatterScalarFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   const typename TTypes<Variant>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    const Variant& to_scatter = update();
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      for (Index j = 0; j < cols; ++j) {
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterScalarFunctor<CPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
template <typename Index>
struct ScatterScalarFunctor<GPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};

template <typename T, typename Index>
struct ScatterScalarFunctorBase<CPUDevice, T, Index,
                                scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor<CPUDevice, T, Index, op>
    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_