1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/math_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/kernels/cwise_ops.h" 23 #include "tensorflow/core/kernels/cwise_ops_common.h" 24 #include "tensorflow/core/kernels/relu_op_functor.h" 25 26 namespace tensorflow { 27 28 template <typename T> 29 class UnaryOpsComposition; // forward declare kernel 30 31 template <typename T> 32 struct UnaryOpsCompositionSupport; 33 34 template <typename T> 35 struct UnaryOpsCompositionBase { 36 using InputBuffer = typename TTypes<T>::ConstFlat; 37 using OutputBuffer = typename TTypes<T>::Flat; 38 39 using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*); 40 41 struct ComputeFnRegistration { 42 ComputeFn compute_fn; 43 int cost; 44 }; 45 HasComputeFntensorflow::UnaryOpsCompositionBase46 bool HasComputeFn(const string& name) { 47 return compute_fns.find(name) != compute_fns.end(); 48 } 49 50 protected: RegisterComputeFntensorflow::UnaryOpsCompositionBase51 void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) { 52 VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost; 53 compute_fns[name] = {compute_fn, cost}; 54 } 55 56 private: 57 friend class UnaryOpsComposition<T>; 58 ExportComputeFnstensorflow::UnaryOpsCompositionBase59 Status ExportComputeFns(const std::vector<string>& op_names, 60 std::vector<ComputeFn>* fns, int* cost) { 61 for (const string& op_name : op_names) { 62 auto it = compute_fns.find(op_name); 63 if (it == compute_fns.end()) 64 return errors::InvalidArgument( 65 "Do not have a compute function registered for op: ", op_name); 66 67 const ComputeFnRegistration& reg = it->second; 68 fns->push_back(reg.compute_fn); 69 *cost += reg.cost; 70 } 71 72 return Status::OK(); 73 } 74 75 std::unordered_map<string, ComputeFnRegistration> compute_fns; 76 }; 77 78 template <typename T> 79 class UnaryOpsComposition : public OpKernel { 80 public: 81 using Kernel = UnaryOpsComposition<T>; 82 83 using Scalar = T; 84 using Packet = typename Eigen::internal::packet_traits<T>::type; 85 86 using Support = UnaryOpsCompositionSupport<T>; 87 88 using InputBuffer = typename Support::InputBuffer; 89 using OutputBuffer = typename Support::OutputBuffer; 90 using ComputeFn = typename Support::ComputeFn; 91 UnaryOpsComposition(OpKernelConstruction * context)92 explicit UnaryOpsComposition(OpKernelConstruction* context) 93 : OpKernel(context) { 94 OP_REQUIRES_OK(context, context->GetAttr("op_names", &op_names_)); 95 96 OP_REQUIRES(context, !op_names_.empty(), 97 errors::InvalidArgument( 98 "Unary op composition must have at least one op")); 99 100 OP_REQUIRES_OK(context, 101 support_.ExportComputeFns(op_names_, &fns_, &cost_)); 102 103 VLOG(2) << "Composed unary op: [" << str_util::Join(op_names_, ", ") 104 << "]; cost=" << cost_; 105 } 106 Compute(OpKernelContext * ctx)107 void Compute(OpKernelContext* ctx) override { 108 const Tensor& in = ctx->input(0); 109 Tensor* out = nullptr; 110 OP_REQUIRES_OK( 111 ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out)); 112 113 InputBuffer in_flat = in.flat<T>(); 114 OutputBuffer out_flat = out->flat<T>(); 115 116 const std::size_t num_fns = fns_.size(); 117 auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64 begin, 118 int64 end) { 119 int64 len = end - begin; 120 const InputBuffer in_slice(in_flat.data() + begin, len); 121 const InputBuffer scratch_slice(out_flat.data() + begin, len); 122 OutputBuffer out_slice(out_flat.data() + begin, len); 123 124 fns_[0](in_slice, &out_slice); 125 for (int i = 1; i < num_fns; ++i) { 126 fns_[i](scratch_slice, &out_slice); 127 } 128 }; 129 130 const CPUDevice& device = ctx->eigen_device<CPUDevice>(); 131 const int kOverheadCycles = static_cast<int>(num_fns) * 10; 132 Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns, 133 /*bytes_stored=*/sizeof(T) * num_fns, 134 kOverheadCycles + cost_); 135 device.parallelFor(in.NumElements(), cost, AlignBlockSize, 136 std::move(compute_fn)); 137 } 138 139 private: 140 static const int kPacketSize = Eigen::internal::unpacket_traits<Packet>::size; 141 AlignBlockSize(int64 block_size)142 static inline int64 AlignBlockSize(int64 block_size) { 143 // Align block size to packet size and account for unrolling in run above. 144 if (block_size >= 16 * kPacketSize) { 145 return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1); 146 } 147 // Aligning to 4 * PacketSize would increase block size by more than 25%. 148 return (block_size + kPacketSize - 1) & ~(kPacketSize - 1); 149 } 150 151 Support support_; 152 153 std::vector<string> op_names_; 154 std::vector<ComputeFn> fns_; 155 int cost_ = 0; 156 }; 157 158 // Register compute functions for UnaryOp functors. 159 #define REGISTER_COMPUTE_FN_HELPER(name, functor) \ 160 static_assert(std::is_same<functor::in_type, functor::out_type>::value, \ 161 "Functor must have same input and output types"); \ 162 \ 163 static inline void Compute##name(const InputBuffer& in, OutputBuffer* out) { \ 164 *out = in.unaryExpr(functor::func()); \ 165 } \ 166 static inline int Cost##name() { \ 167 return Eigen::internal::functor_traits<functor::func>::Cost; \ 168 } 169 170 // Register compute function for the Relu/Relu6/Elu/Selu. 171 #define REGISTER_RELU_HELPER() \ 172 template <typename T> \ 173 using functor_traits = Eigen::internal::functor_traits<T>; \ 174 \ 175 static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) { \ 176 auto relu = functor::Relu<Eigen::DefaultDevice, T>(); \ 177 relu(Eigen::DefaultDevice(), in, *out); \ 178 } \ 179 \ 180 static inline int CostRelu() { \ 181 return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost; \ 182 } \ 183 \ 184 static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \ 185 auto relu6 = functor::Relu6<Eigen::DefaultDevice, T>(); \ 186 relu6(Eigen::DefaultDevice(), in, *out); \ 187 } \ 188 \ 189 static inline int CostRelu6() { \ 190 return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost + \ 191 functor_traits<Eigen::internal::scalar_min_op<T>>::Cost; \ 192 } \ 193 static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) { \ 194 auto elu = functor::Elu<Eigen::DefaultDevice, T>(); \ 195 elu(Eigen::DefaultDevice(), in, *out); \ 196 } \ 197 \ 198 static inline int CostElu() { \ 199 return functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + \ 200 Eigen::NumTraits<T>::MulCost; \ 201 } \ 202 static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) { \ 203 auto selu = functor::Selu<Eigen::DefaultDevice, T>(); \ 204 selu(Eigen::DefaultDevice(), in, *out); \ 205 } \ 206 \ 207 static inline int CostSelu() { \ 208 return 2 * (functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + \ 209 Eigen::NumTraits<T>::MulCost); \ 210 } 211 212 #define REGISTER_COMPUTE_FN(func) \ 213 RegisterComputeFn(#func, Compute##func, Cost##func()); 214 215 template <> 216 struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> { 217 using T = float; 218 UnaryOpsCompositionSupporttensorflow::UnaryOpsCompositionSupport219 UnaryOpsCompositionSupport() { 220 // UnaryOp functors. 221 REGISTER_COMPUTE_FN(Abs); 222 REGISTER_COMPUTE_FN(Acos); 223 REGISTER_COMPUTE_FN(Acosh); 224 REGISTER_COMPUTE_FN(Asin); 225 REGISTER_COMPUTE_FN(Asinh); 226 REGISTER_COMPUTE_FN(Atan); 227 REGISTER_COMPUTE_FN(Atanh); 228 REGISTER_COMPUTE_FN(Ceil); 229 REGISTER_COMPUTE_FN(Cos); 230 REGISTER_COMPUTE_FN(Cosh); 231 REGISTER_COMPUTE_FN(Expm1); 232 REGISTER_COMPUTE_FN(Exp); 233 REGISTER_COMPUTE_FN(Floor); 234 REGISTER_COMPUTE_FN(Inv); 235 REGISTER_COMPUTE_FN(Log); 236 REGISTER_COMPUTE_FN(Log1p); 237 REGISTER_COMPUTE_FN(Neg); 238 REGISTER_COMPUTE_FN(Reciprocal); 239 REGISTER_COMPUTE_FN(Rint); 240 REGISTER_COMPUTE_FN(Round); 241 REGISTER_COMPUTE_FN(Rsqrt); 242 REGISTER_COMPUTE_FN(Sigmoid); 243 REGISTER_COMPUTE_FN(Sin); 244 REGISTER_COMPUTE_FN(Sinh); 245 REGISTER_COMPUTE_FN(Sqrt); 246 REGISTER_COMPUTE_FN(Square); 247 REGISTER_COMPUTE_FN(Tan); 248 REGISTER_COMPUTE_FN(Tanh); 249 250 // Additional compute functions not defined via UnaryOp functors. 251 REGISTER_COMPUTE_FN(Elu); 252 REGISTER_COMPUTE_FN(Relu); 253 REGISTER_COMPUTE_FN(Relu6); 254 REGISTER_COMPUTE_FN(Selu); 255 } 256 257 REGISTER_RELU_HELPER(); 258 259 // clang-format off 260 REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>); 261 REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>); 262 REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>); 263 REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>); 264 REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>); 265 REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>); 266 REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>); 267 REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>); 268 REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>); 269 REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>); 270 REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>); 271 REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>); 272 REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>); 273 REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>); 274 REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>); 275 REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>); 276 REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>); 277 REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>); 278 REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>); 279 REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>); 280 REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>); 281 REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>); 282 REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>); 283 REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>); 284 REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>); 285 REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>); 286 REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>); 287 REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>); 288 // clang-format on 289 }; 290 291 template <> 292 struct UnaryOpsCompositionSupport<Eigen::half> 293 : UnaryOpsCompositionBase<Eigen::half> { 294 using T = Eigen::half; 295 UnaryOpsCompositionSupporttensorflow::UnaryOpsCompositionSupport296 UnaryOpsCompositionSupport() { 297 REGISTER_COMPUTE_FN(Abs); 298 REGISTER_COMPUTE_FN(Ceil); 299 REGISTER_COMPUTE_FN(Cos); 300 REGISTER_COMPUTE_FN(Expm1); 301 REGISTER_COMPUTE_FN(Exp); 302 REGISTER_COMPUTE_FN(Floor); 303 REGISTER_COMPUTE_FN(Inv); 304 REGISTER_COMPUTE_FN(Log); 305 REGISTER_COMPUTE_FN(Log1p); 306 REGISTER_COMPUTE_FN(Neg); 307 REGISTER_COMPUTE_FN(Reciprocal); 308 REGISTER_COMPUTE_FN(Round); 309 REGISTER_COMPUTE_FN(Rsqrt); 310 REGISTER_COMPUTE_FN(Sigmoid); 311 REGISTER_COMPUTE_FN(Sin); 312 REGISTER_COMPUTE_FN(Sqrt); 313 REGISTER_COMPUTE_FN(Square); 314 REGISTER_COMPUTE_FN(Tanh); 315 // Additional compute functions not defined via UnaryOp functors. 316 REGISTER_COMPUTE_FN(Elu); 317 REGISTER_COMPUTE_FN(Relu); 318 REGISTER_COMPUTE_FN(Relu6); 319 REGISTER_COMPUTE_FN(Selu); 320 } 321 322 REGISTER_RELU_HELPER(); 323 324 // clang-format off 325 REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>); 326 REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>); 327 REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>); 328 REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>); 329 REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>); 330 REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>); 331 REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>); 332 REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>); 333 REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>); 334 REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>); 335 REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>); 336 REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>); 337 REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>); 338 REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>); 339 REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>); 340 REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>); 341 REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>); 342 REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>); 343 // clang-format on 344 }; 345 346 template <> 347 struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> { 348 using T = double; 349 UnaryOpsCompositionSupporttensorflow::UnaryOpsCompositionSupport350 UnaryOpsCompositionSupport() { 351 REGISTER_COMPUTE_FN(Abs); 352 REGISTER_COMPUTE_FN(Acos); 353 REGISTER_COMPUTE_FN(Acosh); 354 REGISTER_COMPUTE_FN(Asin); 355 REGISTER_COMPUTE_FN(Asinh); 356 REGISTER_COMPUTE_FN(Atan); 357 REGISTER_COMPUTE_FN(Atanh); 358 REGISTER_COMPUTE_FN(Ceil); 359 REGISTER_COMPUTE_FN(Cos); 360 REGISTER_COMPUTE_FN(Cosh); 361 REGISTER_COMPUTE_FN(Expm1); 362 REGISTER_COMPUTE_FN(Exp); 363 REGISTER_COMPUTE_FN(Floor); 364 REGISTER_COMPUTE_FN(Inv); 365 REGISTER_COMPUTE_FN(Log); 366 REGISTER_COMPUTE_FN(Log1p); 367 REGISTER_COMPUTE_FN(Neg); 368 REGISTER_COMPUTE_FN(Reciprocal); 369 REGISTER_COMPUTE_FN(Rint); 370 REGISTER_COMPUTE_FN(Round); 371 REGISTER_COMPUTE_FN(Rsqrt); 372 REGISTER_COMPUTE_FN(Sigmoid); 373 REGISTER_COMPUTE_FN(Sin); 374 REGISTER_COMPUTE_FN(Sinh); 375 REGISTER_COMPUTE_FN(Sqrt); 376 REGISTER_COMPUTE_FN(Square); 377 REGISTER_COMPUTE_FN(Tan); 378 REGISTER_COMPUTE_FN(Tanh); 379 // Additional compute functions not defined via UnaryOp functors. 380 REGISTER_COMPUTE_FN(Elu); 381 REGISTER_COMPUTE_FN(Relu); 382 REGISTER_COMPUTE_FN(Relu6); 383 REGISTER_COMPUTE_FN(Selu); 384 } 385 386 REGISTER_RELU_HELPER(); 387 388 // clang-format off 389 REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>); 390 REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>); 391 REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>); 392 REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>); 393 REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>); 394 REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>); 395 REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>); 396 REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>); 397 REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>); 398 REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>); 399 REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>); 400 REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>); 401 REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>); 402 REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>); 403 REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>); 404 REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>); 405 REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>); 406 REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>); 407 REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>); 408 REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>); 409 REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>); 410 REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>); 411 REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>); 412 REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>); 413 REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>); 414 REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>); 415 REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>); 416 REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>); 417 // clang-format on 418 }; 419 420 // Register the CPU kernels. 421 #define REGISTER_CPU(T) \ 422 REGISTER_KERNEL_BUILDER( \ 423 Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 424 UnaryOpsComposition<T>); 425 426 REGISTER_CPU(float); 427 REGISTER_CPU(Eigen::half); 428 REGISTER_CPU(double); 429 430 #undef REGISTER_CPU 431 432 } // namespace tensorflow 433