/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements matmul operations with other kernels baked into the
// processing, to optimize latency and memory usage:
//  - MatMul + BiasAdd + <Activation>
//  - MatMul + FusedBatchNorm + <Activation>
//
// Activation: Relu, Relu6, Elu, etc...
//
// Currently supported only on CPU device.
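//
// A sketch of the fused computation (illustrative, not an exact kernel):
//
//   out = Activation(BiasAdd(MatMul(a, b), bias))
//
// evaluated in a single pass, without materializing the intermediate
// MatMul or BiasAdd results.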

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
#include "tensorflow/core/util/tensor_format.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

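// Launch policy for the fused MatMul contraction. The primary template is
// only declared; each supported device provides its own specialization
// (currently just the CPU one below).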
template <typename Device, typename T>
struct LaunchFusedMatMulOp {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output);
};

template <typename T>
struct LaunchFusedMatMulOp<CPUDevice, T> {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output) {
    auto lhs = a.matrix<T>();
    auto rhs = b.matrix<T>();
    auto out = output->matrix<T>();

    auto& d = context->eigen_device<CPUDevice>();

    // Executes Eigen contraction with the output kernel wrapped into a type
    // erased wrapper to reduce the number of unique template instantiations.
    auto executeWithOutputKernel = [&](auto output_kernel) {
      OutputKernelWrapper output_kernel_wrapper(
          [&output_kernel](
              const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
              const Eigen::TensorContractionParams& params, Eigen::Index i,
              Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
            output_kernel(output_mapper, params, i, j, num_rows, num_cols);
          });

      out.device(d) = lhs.contract(rhs, dim_pair, output_kernel_wrapper);
    };

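    // Initialize bias-add arguments up front; the LeakyRelu fusion
    // additionally carries its alpha coefficient, which the output kernel
    // needs.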
    BiasAddArgs<T> bias_add_args;
    if (BiasAddArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
                                                &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
      }
    }

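    // Dispatch the contraction with the output kernel matching the requested
    // fusion; unrecognized fusions are rejected rather than silently
    // computing a plain matmul.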
    switch (fusion) {
      case FusedComputationType::kBiasAdd:
        executeWithOutputKernel(WithBiasAdd<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithRelu:
        executeWithOutputKernel(WithBiasAddAndRelu<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithRelu6:
        executeWithOutputKernel(WithBiasAddAndRelu6<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithElu:
        executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithLeakyRelu:
        executeWithOutputKernel(WithBiasAddAndLeakyRelu<T>(bias_add_args));
        break;
      case FusedComputationType::kUndefined:
        OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
        break;
      default:
        OP_REQUIRES_OK(context,
                       errors::Internal("Fusion type is not supported"));
    }
  }

 private:
  // Wrap output_kernel into a type erased struct to reduce the number of
  // unique template instantiations for Eigen Tensor contraction expressions.
  //
  // We do not pass std::function directly as an output kernel because it blows
  // up the binary size in debug mode with super long symbol names.
  struct OutputKernelWrapper {
    using OutputKernelFn =
        std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
                           const Eigen::TensorContractionParams&, Eigen::Index,
                           Eigen::Index, Eigen::Index, Eigen::Index)>;

    explicit OutputKernelWrapper(OutputKernelFn fn)
        : output_kernel_fn(std::move(fn)) {}

    void operator()(
        const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
        const Eigen::TensorContractionParams& params, Eigen::Index i,
        Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
      output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
    }

    OutputKernelFn output_kernel_fn;
  };
};

template <typename Device, typename T>
class FusedMatMulOp : public OpKernel {
 public:
  explicit FusedMatMulOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));

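    // Each pattern maps the sequence of ops fused into this MatMul (as given
    // by the op's "fused_ops" attribute) to the FusedComputationType handled
    // by the launcher above.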
    std::vector<FusedComputationPattern> patterns;

    using FCT = FusedComputationType;
    if (std::is_same<Device, CPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
          {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
      };
    }

    OP_REQUIRES_OK(context, InitializeFusedComputation(
                                context, "MatMul", patterns,
                                &fused_computation_, &fused_computation_args_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
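    // dim_pair names the single contracted dimension of each operand: the
    // columns of a (rows if transpose_a) contract against the rows of b
    // (columns if transpose_b).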
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 && b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    auto launch = LaunchFusedMatMulOp<Device, T>();
    launch(ctx, a, b, dim_pair, fused_computation_, fused_computation_args_,
           out);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;

  FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
  FusedComputationArgs fused_computation_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedMatMulOp);
};

// Registration of the CPU implementations.
#define REGISTER_FUSED_CPU_MATMUL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      FusedMatMulOp<CPUDevice, T>);

TF_CALL_float(REGISTER_FUSED_CPU_MATMUL);
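// Note: other dtypes could in principle be registered the same way, e.g.
// TF_CALL_double(REGISTER_FUSED_CPU_MATMUL); only float is enabled here.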

#undef REGISTER_FUSED_CPU_MATMUL

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_