/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
}  // namespace functor

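// The Select op (the kernel behind the three-argument tf.where): emits
// elements of 't' or 'e' depending on 'condition'. For example, with a vector
// condition [true, false], t = [[1, 2], [3, 4]] and e = [[5, 6], [7, 8]],
// the output is [[1, 2], [7, 8]]: row i comes from 't' when condition[i] is
// true, and from 'e' otherwise.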
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond;
    const Tensor* then;
    const Tensor* else_;
    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));

    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

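    // A vector 'cond' paired with higher-rank 'then'/'else' selects whole
    // outer-dimension slices (batches); otherwise selection is elementwise.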
    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }

 protected:
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ", cond->NumElements()));
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

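    // Reuse the buffer of 't' or 'e' for the output when one of them can be
    // forwarded; otherwise allocate a fresh output tensor.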
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::BatchSelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
           cond->vec<bool>(), then->flat_outer_dims<T>(),
           else_->flat_outer_dims<T>());
    }
  }

  void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
                          const Tensor* then, const Tensor* else_) {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::SelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
           then->flat<T>(), else_->flat<T>());
    }
  }

  void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
                     const Tensor* then, const Tensor* else_) {
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    functor::SelectScalarHandler<Device, T> handler;
    handler(ctx, cond, then, else_);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};

#define REGISTER_SELECT(type)                                      \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectOp<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);

#if GOOGLE_CUDA

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SELECT_SYCL(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      SelectOp<SYCLDevice, type>);

REGISTER_SELECT_SYCL(float);
REGISTER_SELECT_SYCL(double);
REGISTER_SELECT_SYCL(int32);
REGISTER_SELECT_SYCL(int64);
#undef REGISTER_SELECT_SYCL
#endif  // TENSORFLOW_USE_SYCL

namespace functor {

// CPU Specializations of Select functors.
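// SelectFunctorBase lowers elementwise selection onto Eigen's select():
// out(i) = cond_flat(i) ? then_flat(i) : else_flat(i).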
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
struct SelectScalarHandler {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }
};

// Specialization for CPU device. Forwards the input to the output depending
// on the `cond` value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    if (cond->scalar<bool>()()) {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
    } else {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
    }
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct SelectScalarFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    out.device(d) = cond() ? then_flat : else_flat;
  }
};

template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
    : SelectScalarFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

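    // Reshape cond_vec from [batch] to [batch, 1], then broadcast it to
    // [batch, all_but_batch] so it lines up with the flattened inputs.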
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<Eigen::DenseIndex, 2> broadcast_dims{{1, all_but_batch}};
    Eigen::Tensor<Eigen::DenseIndex, 2>::Dimensions reshape_dims{{batch, 1}};
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);
#endif

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

// A faster CPU implementation that uses an explicit loop over batches
// instead of Eigen broadcasting.
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const size_t batch = cond_vec.size();
    const size_t batch_size = then_flat_outer_dims.size() / batch;
    T* output = output_flat_outer_dims.data();
    const bool* c = cond_vec.data();
    const T* t = then_flat_outer_dims.data();
    const T* e = else_flat_outer_dims.data();

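    // Each unit of work copies one batch row from 't' or 'e' into 'output',
    // prefetching the next row of both inputs and the next cond value.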
    auto work = [batch_size, output, c, t, e](int64 start, int64 end) {
      for (size_t i = start; i < end; ++i) {
        size_t offset = i * batch_size;
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&t[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&e[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&c[i + 1]));
        if (c[i]) {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = t[offset + j];
          }
        } else {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = e[offset + j];
          }
        }
      }
    };
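    // Rough per-batch cost estimate; parallelFor uses it to pick shard sizes.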
    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
                                    sizeof(T) * batch_size,      // st bytes
                                    batch_size);  // compute cycles
    d.parallelFor(batch, cost, work);
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
    : BatchSelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor

}  // namespace tensorflow