• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/math_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/kernels/cwise_ops.h"
23 #include "tensorflow/core/kernels/cwise_ops_common.h"
24 #include "tensorflow/core/kernels/relu_op_functor.h"
25 
26 namespace tensorflow {
27 
// Forward declaration of the kernel class; the compute-function registry
// below names it as a friend before the class itself is defined.
template <typename T>
class UnaryOpsComposition;

// Forward declaration of the per-type set of supported unary ops. Only the
// explicit specializations further down in this file are defined.
template <typename T>
struct UnaryOpsCompositionSupport;
33 
34 template <typename T>
35 struct UnaryOpsCompositionBase {
36   using InputBuffer = typename TTypes<T>::ConstFlat;
37   using OutputBuffer = typename TTypes<T>::Flat;
38 
39   using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*);
40 
41   struct ComputeFnRegistration {
42     ComputeFn compute_fn;
43     int cost;
44   };
45 
HasComputeFntensorflow::UnaryOpsCompositionBase46   bool HasComputeFn(const string& name) {
47     return compute_fns.find(name) != compute_fns.end();
48   }
49 
50  protected:
RegisterComputeFntensorflow::UnaryOpsCompositionBase51   void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) {
52     VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost;
53     compute_fns[name] = {compute_fn, cost};
54   }
55 
56  private:
57   friend class UnaryOpsComposition<T>;
58 
ExportComputeFnstensorflow::UnaryOpsCompositionBase59   Status ExportComputeFns(const std::vector<string>& op_names,
60                           std::vector<ComputeFn>* fns, int* cost) {
61     for (const string& op_name : op_names) {
62       auto it = compute_fns.find(op_name);
63       if (it == compute_fns.end())
64         return errors::InvalidArgument(
65             "Do not have a compute function registered for op: ", op_name);
66 
67       const ComputeFnRegistration& reg = it->second;
68       fns->push_back(reg.compute_fn);
69       *cost += reg.cost;
70     }
71 
72     return OkStatus();
73   }
74 
75   std::unordered_map<string, ComputeFnRegistration> compute_fns;
76 };
77 
78 template <typename T>
79 class UnaryOpsComposition : public OpKernel {
80  public:
81   using Kernel = UnaryOpsComposition<T>;
82 
83   using Scalar = T;
84   using Packet = typename Eigen::internal::packet_traits<T>::type;
85 
86   using Support = UnaryOpsCompositionSupport<T>;
87 
88   using InputBuffer = typename Support::InputBuffer;
89   using OutputBuffer = typename Support::OutputBuffer;
90   using ComputeFn = typename Support::ComputeFn;
91 
UnaryOpsComposition(OpKernelConstruction * context)92   explicit UnaryOpsComposition(OpKernelConstruction* context)
93       : OpKernel(context) {
94     OP_REQUIRES_OK(context, context->GetAttr("op_names", &op_names_));
95 
96     OP_REQUIRES(context, !op_names_.empty(),
97                 errors::InvalidArgument(
98                     "Unary op composition must have at least one op"));
99 
100     OP_REQUIRES_OK(context,
101                    support_.ExportComputeFns(op_names_, &fns_, &cost_));
102 
103     VLOG(2) << "Composed unary op: [" << absl::StrJoin(op_names_, ", ")
104             << "]; cost=" << cost_;
105   }
106 
Compute(OpKernelContext * ctx)107   void Compute(OpKernelContext* ctx) override {
108     const Tensor& in = ctx->input(0);
109     Tensor* out = nullptr;
110     OP_REQUIRES_OK(
111         ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out));
112 
113     InputBuffer in_flat = in.flat<T>();
114     OutputBuffer out_flat = out->flat<T>();
115 
116     const std::size_t num_fns = fns_.size();
117     auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64_t begin,
118                                                             int64_t end) {
119       int64_t len = end - begin;
120       const InputBuffer in_slice(in_flat.data() + begin, len);
121       const InputBuffer scratch_slice(out_flat.data() + begin, len);
122       OutputBuffer out_slice(out_flat.data() + begin, len);
123 
124       fns_[0](in_slice, &out_slice);
125       for (int i = 1; i < num_fns; ++i) {
126         fns_[i](scratch_slice, &out_slice);
127       }
128     };
129 
130     const CPUDevice& device = ctx->eigen_device<CPUDevice>();
131     const int kOverheadCycles = static_cast<int>(num_fns) * 10;
132     Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns,
133                              /*bytes_stored=*/sizeof(T) * num_fns,
134                              kOverheadCycles + cost_);
135     device.parallelFor(in.NumElements(), cost, AlignBlockSize,
136                        std::move(compute_fn));
137   }
138 
139  private:
140   static constexpr int kPacketSize =
141       Eigen::internal::unpacket_traits<Packet>::size;
142 
AlignBlockSize(int64_t block_size)143   static inline int64_t AlignBlockSize(int64_t block_size) {
144     // Align block size to packet size and account for unrolling in run above.
145     if (block_size >= 16 * kPacketSize) {
146       return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1);
147     }
148     // Aligning to 4 * PacketSize would increase block size by more than 25%.
149     return (block_size + kPacketSize - 1) & ~(kPacketSize - 1);
150   }
151 
152   Support support_;
153 
154   std::vector<string> op_names_;
155   std::vector<ComputeFn> fns_;
156   int cost_ = 0;
157 };
158 
// Defines, for the unary op `op_name`, two static helpers backed by
// `functor_type` (which must expose `in_type`, `out_type` and `func`):
//   Compute<op_name>(in, out): assigns `in.unaryExpr(func())` to `*out`.
//   Cost<op_name>():           returns Eigen's `functor_traits<func>::Cost`.
// Compilation fails if the functor's input and output types differ.
// Must be expanded in a class scope providing InputBuffer and OutputBuffer.
#define REGISTER_COMPUTE_FN_HELPER(op_name, functor_type)                     \
  static_assert(                                                              \
      std::is_same<functor_type::in_type, functor_type::out_type>::value,     \
      "Functor must have same input and output types");                       \
                                                                              \
  static inline void Compute##op_name(const InputBuffer& in,                  \
                                      OutputBuffer* out) {                    \
    auto unary_op = functor_type::func();                                     \
    *out = in.unaryExpr(unary_op);                                            \
  }                                                                           \
                                                                              \
  static inline int Cost##op_name() {                                         \
    using Traits = Eigen::internal::functor_traits<functor_type::func>;       \
    return Traits::Cost;                                                      \
  }
170 
// Defines static Compute*/Cost* helpers for Relu, Relu6, Elu and Selu. These
// ops are evaluated with the functor::Relu/Relu6/Elu/Selu functors on an
// Eigen::DefaultDevice rather than through a UnaryOp functor.
// Must be expanded in a class scope providing T, InputBuffer and
// OutputBuffer.
#define REGISTER_RELU_HELPER()                                                \
  template <typename Functor>                                                 \
  using functor_traits = Eigen::internal::functor_traits<Functor>;            \
                                                                              \
  static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) {  \
    Eigen::DefaultDevice device;                                              \
    functor::Relu<Eigen::DefaultDevice, T> relu;                              \
    relu(device, in, *out);                                                   \
  }                                                                           \
                                                                              \
  static inline int CostRelu() {                                              \
    using MaxOp = Eigen::internal::scalar_max_op<T>;                          \
    return functor_traits<MaxOp>::Cost;                                       \
  }                                                                           \
                                                                              \
  static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \
    Eigen::DefaultDevice device;                                              \
    functor::Relu6<Eigen::DefaultDevice, T> relu6;                            \
    relu6(device, in, *out);                                                  \
  }                                                                           \
                                                                              \
  static inline int CostRelu6() {                                             \
    using MaxOp = Eigen::internal::scalar_max_op<T>;                          \
    using MinOp = Eigen::internal::scalar_min_op<T>;                          \
    const int max_cost = functor_traits<MaxOp>::Cost;                         \
    const int min_cost = functor_traits<MinOp>::Cost;                         \
    return max_cost + min_cost;                                               \
  }                                                                           \
                                                                              \
  static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) {   \
    Eigen::DefaultDevice device;                                              \
    functor::Elu<Eigen::DefaultDevice, T> elu;                                \
    elu(device, in, *out);                                                    \
  }                                                                           \
                                                                              \
  static inline int CostElu() {                                               \
    using ExpOp = Eigen::internal::scalar_exp_op<T>;                          \
    const int exp_cost = functor_traits<ExpOp>::Cost;                         \
    const int mul_cost = Eigen::NumTraits<T>::MulCost;                        \
    return exp_cost + mul_cost;                                               \
  }                                                                           \
                                                                              \
  static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) {  \
    Eigen::DefaultDevice device;                                              \
    functor::Selu<Eigen::DefaultDevice, T> selu;                              \
    selu(device, in, *out);                                                   \
  }                                                                           \
                                                                              \
  static inline int CostSelu() {                                              \
    using ExpOp = Eigen::internal::scalar_exp_op<T>;                          \
    const int exp_cost = functor_traits<ExpOp>::Cost;                         \
    const int mul_cost = Eigen::NumTraits<T>::MulCost;                        \
    return 2 * (exp_cost + mul_cost);                                         \
  }
212 
// Registers the op `op` under its stringified name, pairing the in-scope
// `Compute<op>` function with the value returned by `Cost<op>()`.
#define REGISTER_COMPUTE_FN(op) \
  RegisterComputeFn(#op, Compute##op, Cost##op());
215 
216 template <>
217 struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> {
218   using T = float;
219 
UnaryOpsCompositionSupporttensorflow::UnaryOpsCompositionSupport220   UnaryOpsCompositionSupport() {
221     // UnaryOp functors.
222     REGISTER_COMPUTE_FN(Abs);
223     REGISTER_COMPUTE_FN(Acos);
224     REGISTER_COMPUTE_FN(Acosh);
225     REGISTER_COMPUTE_FN(Asin);
226     REGISTER_COMPUTE_FN(Asinh);
227     REGISTER_COMPUTE_FN(Atan);
228     REGISTER_COMPUTE_FN(Atanh);
229     REGISTER_COMPUTE_FN(Ceil);
230     REGISTER_COMPUTE_FN(Cos);
231     REGISTER_COMPUTE_FN(Cosh);
232     REGISTER_COMPUTE_FN(Expm1);
233     REGISTER_COMPUTE_FN(Exp);
234     REGISTER_COMPUTE_FN(Floor);
235     REGISTER_COMPUTE_FN(Inv);
236     REGISTER_COMPUTE_FN(Log);
237     REGISTER_COMPUTE_FN(Log1p);
238     REGISTER_COMPUTE_FN(Neg);
239     REGISTER_COMPUTE_FN(Reciprocal);
240     REGISTER_COMPUTE_FN(Rint);
241     REGISTER_COMPUTE_FN(Round);
242     REGISTER_COMPUTE_FN(Rsqrt);
243     REGISTER_COMPUTE_FN(Sigmoid);
244     REGISTER_COMPUTE_FN(Sin);
245     REGISTER_COMPUTE_FN(Sinh);
246     REGISTER_COMPUTE_FN(Sqrt);
247     REGISTER_COMPUTE_FN(Square);
248     REGISTER_COMPUTE_FN(Tan);
249     REGISTER_COMPUTE_FN(Tanh);
250 
251     // Additional compute functions not defined via UnaryOp functors.
252     REGISTER_COMPUTE_FN(Elu);
253     REGISTER_COMPUTE_FN(Relu);
254     REGISTER_COMPUTE_FN(Relu6);
255     REGISTER_COMPUTE_FN(Selu);
256   }
257 
258   REGISTER_RELU_HELPER();
259 
260   // clang-format off
261   REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
262   REGISTER_COMPUTE_FN_HELPER(Acos,       functor::acos<T>);
263   REGISTER_COMPUTE_FN_HELPER(Acosh,      functor::acosh<T>);
264   REGISTER_COMPUTE_FN_HELPER(Asin,       functor::asin<T>);
265   REGISTER_COMPUTE_FN_HELPER(Asinh,      functor::asinh<T>);
266   REGISTER_COMPUTE_FN_HELPER(Atan,       functor::atan<T>);
267   REGISTER_COMPUTE_FN_HELPER(Atanh,      functor::atanh<T>);
268   REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
269   REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
270   REGISTER_COMPUTE_FN_HELPER(Cosh,       functor::cosh<T>);
271   REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
272   REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
273   REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
274   REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
275   REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
276   REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
277   REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
278   REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
279   REGISTER_COMPUTE_FN_HELPER(Rint,       functor::rint<T>);
280   REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
281   REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
282   REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
283   REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
284   REGISTER_COMPUTE_FN_HELPER(Sinh,       functor::sinh<T>);
285   REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
286   REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
287   REGISTER_COMPUTE_FN_HELPER(Tan,        functor::tan<T>);
288   REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
289   // clang-format on
290 };
291 
292 template <>
293 struct UnaryOpsCompositionSupport<Eigen::half>
294     : UnaryOpsCompositionBase<Eigen::half> {
295   using T = Eigen::half;
296 
UnaryOpsCompositionSupporttensorflow::UnaryOpsCompositionSupport297   UnaryOpsCompositionSupport() {
298     REGISTER_COMPUTE_FN(Abs);
299     REGISTER_COMPUTE_FN(Ceil);
300     REGISTER_COMPUTE_FN(Cos);
301     REGISTER_COMPUTE_FN(Expm1);
302     REGISTER_COMPUTE_FN(Exp);
303     REGISTER_COMPUTE_FN(Floor);
304     REGISTER_COMPUTE_FN(Inv);
305     REGISTER_COMPUTE_FN(Log);
306     REGISTER_COMPUTE_FN(Log1p);
307     REGISTER_COMPUTE_FN(Neg);
308     REGISTER_COMPUTE_FN(Reciprocal);
309     REGISTER_COMPUTE_FN(Round);
310     REGISTER_COMPUTE_FN(Rsqrt);
311     REGISTER_COMPUTE_FN(Sigmoid);
312     REGISTER_COMPUTE_FN(Sin);
313     REGISTER_COMPUTE_FN(Sqrt);
314     REGISTER_COMPUTE_FN(Square);
315     REGISTER_COMPUTE_FN(Tanh);
316     // Additional compute functions not defined via UnaryOp functors.
317     REGISTER_COMPUTE_FN(Elu);
318     REGISTER_COMPUTE_FN(Relu);
319     REGISTER_COMPUTE_FN(Relu6);
320     REGISTER_COMPUTE_FN(Selu);
321   }
322 
323   REGISTER_RELU_HELPER();
324 
325   // clang-format off
326   REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
327   REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
328   REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
329   REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
330   REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
331   REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
332   REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
333   REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
334   REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
335   REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
336   REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
337   REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
338   REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
339   REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
340   REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
341   REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
342   REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
343   REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
344   // clang-format on
345 };
346 
347 template <>
348 struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> {
349   using T = double;
350 
UnaryOpsCompositionSupporttensorflow::UnaryOpsCompositionSupport351   UnaryOpsCompositionSupport() {
352     REGISTER_COMPUTE_FN(Abs);
353     REGISTER_COMPUTE_FN(Acos);
354     REGISTER_COMPUTE_FN(Acosh);
355     REGISTER_COMPUTE_FN(Asin);
356     REGISTER_COMPUTE_FN(Asinh);
357     REGISTER_COMPUTE_FN(Atan);
358     REGISTER_COMPUTE_FN(Atanh);
359     REGISTER_COMPUTE_FN(Ceil);
360     REGISTER_COMPUTE_FN(Cos);
361     REGISTER_COMPUTE_FN(Cosh);
362     REGISTER_COMPUTE_FN(Expm1);
363     REGISTER_COMPUTE_FN(Exp);
364     REGISTER_COMPUTE_FN(Floor);
365     REGISTER_COMPUTE_FN(Inv);
366     REGISTER_COMPUTE_FN(Log);
367     REGISTER_COMPUTE_FN(Log1p);
368     REGISTER_COMPUTE_FN(Neg);
369     REGISTER_COMPUTE_FN(Reciprocal);
370     REGISTER_COMPUTE_FN(Rint);
371     REGISTER_COMPUTE_FN(Round);
372     REGISTER_COMPUTE_FN(Rsqrt);
373     REGISTER_COMPUTE_FN(Sigmoid);
374     REGISTER_COMPUTE_FN(Sin);
375     REGISTER_COMPUTE_FN(Sinh);
376     REGISTER_COMPUTE_FN(Sqrt);
377     REGISTER_COMPUTE_FN(Square);
378     REGISTER_COMPUTE_FN(Tan);
379     REGISTER_COMPUTE_FN(Tanh);
380     // Additional compute functions not defined via UnaryOp functors.
381     REGISTER_COMPUTE_FN(Elu);
382     REGISTER_COMPUTE_FN(Relu);
383     REGISTER_COMPUTE_FN(Relu6);
384     REGISTER_COMPUTE_FN(Selu);
385   }
386 
387   REGISTER_RELU_HELPER();
388 
389   // clang-format off
390   REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
391   REGISTER_COMPUTE_FN_HELPER(Acos,       functor::acos<T>);
392   REGISTER_COMPUTE_FN_HELPER(Acosh,      functor::acosh<T>);
393   REGISTER_COMPUTE_FN_HELPER(Asin,       functor::asin<T>);
394   REGISTER_COMPUTE_FN_HELPER(Asinh,      functor::asinh<T>);
395   REGISTER_COMPUTE_FN_HELPER(Atan,       functor::atan<T>);
396   REGISTER_COMPUTE_FN_HELPER(Atanh,      functor::atanh<T>);
397   REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
398   REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
399   REGISTER_COMPUTE_FN_HELPER(Cosh,       functor::cosh<T>);
400   REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
401   REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
402   REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
403   REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
404   REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
405   REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
406   REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
407   REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
408   REGISTER_COMPUTE_FN_HELPER(Rint,       functor::rint<T>);
409   REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
410   REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
411   REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
412   REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
413   REGISTER_COMPUTE_FN_HELPER(Sinh,       functor::sinh<T>);
414   REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
415   REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
416   REGISTER_COMPUTE_FN_HELPER(Tan,        functor::tan<T>);
417   REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
418   // clang-format on
419 };
420 
421 // Register the CPU kernels.
422 #define REGISTER_CPU(T)                                                       \
423   REGISTER_KERNEL_BUILDER(                                                    \
424       Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
425       UnaryOpsComposition<T>);
426 
427 REGISTER_CPU(float);
428 REGISTER_CPU(Eigen::half);
429 REGISTER_CPU(double);
430 
431 #undef REGISTER_CPU
432 
433 }  // namespace tensorflow
434