/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace functor {
template <typename Device, typename T>
struct SelectScalarHandler;
}  // namespace functor

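// The "Select" op chooses elements from 't' ("then") or 'e' ("else")
// according to 'condition'. Compute() dispatches on the shape of 'condition':
//   - scalar 'condition': the output is all of 't' or all of 'e';
//   - vector 'condition' with higher-rank 't'/'e': each condition element
//     selects an entire outer row of 't' or 'e' (batch broadcasting);
//   - otherwise: elementwise selection over identically shaped inputs.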
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* cond;
    const Tensor* then;
    const Tensor* else_;
    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));

    if (TensorShapeUtils::IsScalar(cond->shape())) {
      ComputeScalar(ctx, cond, then, else_);
      return;
    }

    bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
                         !TensorShapeUtils::IsVector(then->shape()));

    if (broadcasting) {
      ComputeBroadcasting(ctx, cond, then, else_);
    } else {
      ComputeElementwise(ctx, cond, then, else_);
    }
  }

 protected:
  void ComputeBroadcasting(OpKernelContext* ctx, const Tensor* cond,
                           const Tensor* then, const Tensor* else_) {
    // Preliminary validation of sizes.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(cond->shape()),
        errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
                                cond->shape().DebugString()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(cond->NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("cond vector larger than ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));
    OP_REQUIRES(
        ctx,
        FastBoundsCheck(then->flat_outer_dims<T>().dimension(1),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("flat outer dims dim 1 size >= ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then->shape()),
                errors::InvalidArgument(
                    "'then' must be at least a vector, but saw shape: ",
                    then->shape().DebugString()));
    OP_REQUIRES(
        ctx, then->shape().dim_size(0) == cond->NumElements(),
        errors::InvalidArgument(
            "Number of batches of 'then' must match size of 'cond', but saw: ",
            then->shape().dim_size(0), " vs. ", cond->NumElements()));
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::BatchSelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
           cond->vec<bool>(), then->flat_outer_dims<T>(),
           else_->flat_outer_dims<T>());
    }
  }

  void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
                          const Tensor* then, const Tensor* else_) {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));
    if (output->NumElements() > 0) {
      functor::SelectFunctor<Device, T> func;
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
           then->flat<T>(), else_->flat<T>());
    }
  }

  void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
                     const Tensor* then, const Tensor* else_) {
    OP_REQUIRES(
        ctx, then->shape().IsSameSize(else_->shape()),
        errors::InvalidArgument(
            "'then' and 'else' must have the same size, but received: ",
            then->shape().DebugString(), " vs. ",
            else_->shape().DebugString()));

    functor::SelectScalarHandler<Device, T> handler;
    handler(ctx, cond, then, else_);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};

#define REGISTER_SELECT(type)                                      \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SelectOp<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_SELECT);
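
// For reference, REGISTER_SELECT(float) expands to roughly:
//   REGISTER_KERNEL_BUILDER(
//       Name("Select").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       SelectOp<CPUDevice, float>);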

#if GOOGLE_CUDA

// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type)                                  \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SelectOp<GPUDevice, type>);

REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
REGISTER_SELECT_GPU(float);
REGISTER_SELECT_GPU(double);
REGISTER_SELECT_GPU(int32);
REGISTER_SELECT_GPU(int64);
REGISTER_SELECT_GPU(complex64);
REGISTER_SELECT_GPU(complex128);

#undef REGISTER_SELECT_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SELECT_SYCL(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      SelectOp<SYCLDevice, type>);

REGISTER_SELECT_SYCL(float);
REGISTER_SELECT_SYCL(double);
REGISTER_SELECT_SYCL(int32);
REGISTER_SELECT_SYCL(int64);
#undef REGISTER_SELECT_SYCL
#endif  // TENSORFLOW_USE_SYCL

namespace functor {

// CPU Specializations of Select functors.
template <typename Device, typename T>
struct SelectFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
struct SelectScalarHandler {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                            {"t", "e"}, "output", then->shape(), &output));

    if (output->NumElements() > 0) {
      functor::SelectScalarFunctor<Device, T> func;
      TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
      func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
           then->flat<T>(), else_->flat<T>());
    }
  }
};

// Specialization for CPU device. Forwards either the 'then' or the 'else'
// input to the output, depending on the `cond` value.
// TODO(sjhwang): Consider specializing for GPUDevice as well by using
// GPUDevice::memcpyDeviceToHost() to fetch bool value.
template <typename T>
struct SelectScalarHandler<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
                  const Tensor* else_) {
    if (cond->scalar<bool>()()) {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
    } else {
      OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
    }
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct SelectScalarFunctorBase {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    out.device(d) = cond() ? then_flat : else_flat;
  }
};

template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
    : SelectScalarFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

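// Generic batch-select implementation: reshapes 'cond_vec' to [batch, 1],
// broadcasts it across the flattened inner dimension, and selects between
// the corresponding rows of 'then' and 'else'.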
template <typename Device, typename T>
struct BatchSelectFunctorBase {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const Eigen::DenseIndex batch = cond_vec.size();
    const Eigen::DenseIndex all_but_batch = then_flat_outer_dims.dimension(1);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<Eigen::DenseIndex, 2> broadcast_dims{{1, all_but_batch}};
    Eigen::Tensor<Eigen::DenseIndex, 2>::Dimensions reshape_dims{{batch, 1}};
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(1, all_but_batch);
    Eigen::IndexList<Eigen::DenseIndex, Eigen::type2index<1> > reshape_dims;
    reshape_dims.set(0, batch);
#endif

    Assign(d, output_flat_outer_dims,
           cond_vec.reshape(reshape_dims)
               .broadcast(broadcast_dims)
               .select(then_flat_outer_dims, else_flat_outer_dims));
  }
};

// A fast CPU implementation that loops over batches directly instead of
// using Eigen broadcasting.
template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims) {
    const size_t batch = cond_vec.size();
    const size_t batch_size = then_flat_outer_dims.size() / batch;
    T* output = output_flat_outer_dims.data();
    const bool* c = cond_vec.data();
    const T* t = then_flat_outer_dims.data();
    const T* e = else_flat_outer_dims.data();

    auto work = [batch_size, output, c, t, e](int64 start, int64 end) {
      for (size_t i = start; i < end; ++i) {
        size_t offset = i * batch_size;
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&t[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&e[offset + batch_size]));
        port::prefetch<port::PREFETCH_HINT_NTA>(
            reinterpret_cast<const void*>(&c[i + 1]));
        if (c[i]) {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = t[offset + j];
          }
        } else {
          for (size_t j = 0; j < batch_size; ++j) {
            output[offset + j] = e[offset + j];
          }
        }
      }
    };
    auto cost = Eigen::TensorOpCost(sizeof(T) * batch_size * 2,  // ld bytes
                                    sizeof(T) * batch_size,      // st bytes
                                    batch_size);  // compute cycles
    d.parallelFor(batch, cost, work);
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
    : BatchSelectFunctorBase<SYCLDevice, T> {};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor

}  // namespace tensorflow