• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// See docs in ../ops/math_ops.cc.
17 
18 #ifdef INTEL_MKL
19 #define EIGEN_USE_THREADS
20 
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/mkl_util.h"
29 
30 using mkldnn::stream;
31 using mkldnn::sum;
32 
33 namespace tensorflow {
34 typedef Eigen::ThreadPoolDevice CPUDevice;
35 
36 template <typename Device, typename T>
37 class MklAddNOp : public OpKernel {
38  public:
~MklAddNOp()39   ~MklAddNOp() {}
MklAddNOp(OpKernelConstruction * context)40   explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {}
41 
GetTensorShape(OpKernelContext * ctx,size_t src_index)42   TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) {
43     const Tensor& src_tensor = MklGetInput(ctx, src_index);
44     MklDnnShape src_mkl_shape;
45     GetMklShape(ctx, src_index, &src_mkl_shape);
46     return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
47                                        : src_tensor.shape();
48   }
49 
CheckInputShape(OpKernelContext * ctx)50   bool CheckInputShape(OpKernelContext* ctx) {
51     const int num_inputs = ctx->num_inputs() / 2;
52     const TensorShape src0_shape = GetTensorShape(ctx, 0);
53 
54     for (size_t i = 1; i < num_inputs; ++i) {
55       if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) {
56         ctx->SetStatus(errors::InvalidArgument(
57             "Inputs to operation ", this->name(), " of type ",
58             this->type_string(),
59             " must have the same size and shape.  Input 0: ",
60             src0_shape.DebugString(), " != input : ", i,
61             GetTensorShape(ctx, i).DebugString()));
62 
63         return false;
64       }
65     }
66 
67     return true;
68   }
69 
70   // Return first tensor index which is in MKL layout, or -1 with no MKL input.
FindMKLInputIndex(OpKernelContext * ctx)71   int FindMKLInputIndex(OpKernelContext* ctx) {
72     int mkl_index = -1;
73     const int num_inputs = ctx->num_inputs() / 2;
74 
75     MklDnnShape src_mkl_shape;
76     for (size_t i = 0; i < num_inputs; ++i) {
77       GetMklShape(ctx, i, &src_mkl_shape);
78       if (src_mkl_shape.IsMklTensor()) {
79         mkl_index = i;
80         break;
81       }
82     }
83 
84     return mkl_index;
85   }
86 
ComputeScalar(OpKernelContext * ctx)87   void ComputeScalar(OpKernelContext* ctx) {
88     const int num_inputs = ctx->num_inputs() / 2;
89     const size_t kOutputIdx = 0;
90     TensorShape output_tf_shape;
91     MklDnnShape output_mkl_shape;
92     Tensor* dst_tensor = nullptr;
93 
94     T sum = static_cast<T>(0);
95     for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
96       const Tensor& src_tensor = MklGetInput(ctx, src_idx);
97       T* src_i = const_cast<T*>(src_tensor.flat<T>().data());
98       sum += src_i[0];
99     }
100 
101     output_mkl_shape.SetMklTensor(false);
102     output_tf_shape = MklGetInput(ctx, kOutputIdx).shape();
103     AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
104                               output_mkl_shape);
105 
106     T* out_o = dst_tensor->flat<T>().data();
107     out_o[0] = sum;
108   }
109 
Compute(OpKernelContext * ctx)110   void Compute(OpKernelContext* ctx) override {
111     // Each input tensor in MKL layout has additional meta-tensor carrying
112     // layout information. So the number of actual tensors is half the total
113     // number of inputs.
114     const int num_inputs = ctx->num_inputs() / 2;
115 
116     MklDnnShape mkl_shape;
117     const size_t kSrc0Idx = 0;
118     const size_t kOutputIdx = 0;
119 
120     if (num_inputs == 1) {
121       GetMklShape(ctx, kSrc0Idx, &mkl_shape);
122       bool input_in_mkl_format = mkl_shape.IsMklTensor();
123 
124       if (input_in_mkl_format) {
125         ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
126       } else {
127         ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx);
128       }
129       return;
130     }
131 
132     // Check if the input shape is same
133     if (!CheckInputShape(ctx)) return;
134 
135     try {
136       TensorShape output_tf_shape;
137       MklDnnShape output_mkl_shape;
138       const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx);
139 
140       Tensor* dst_tensor = nullptr;
141 
142       // Nothing to compute, return.
143       if (src_tensor.shape().num_elements() == 0) {
144         output_mkl_shape.SetMklTensor(false);
145         output_tf_shape = src_tensor.shape();
146         AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
147                                   output_mkl_shape);
148         return;
149       }
150 
151       if (src_tensor.dims() == 0) {
152         ComputeScalar(ctx);
153         return;
154       }
155 
156       auto cpu_engine = engine(engine::kind::cpu, 0);
157       std::vector<float> coeff(num_inputs, 1.0);
158       std::vector<memory::desc> srcs_pd;
159       std::vector<memory> inputs;
160 
161       MklDnnData<T> dst(&cpu_engine);
162       MklDnnData<T> src(&cpu_engine);
163       bool has_mkl_input = false;
164       int mkl_input_index = FindMKLInputIndex(ctx);
165       MklTensorFormat mkl_data_format;
166       TensorFormat tf_data_format;
167       memory::format_tag dnn_fmt = memory::format_tag::any;
168       if (mkl_input_index >= 0) {
169         has_mkl_input = true;
170         GetMklShape(ctx, mkl_input_index, &mkl_shape);
171         // MKL input has the data format information.
172         mkl_data_format = mkl_shape.GetTfDataFormat();
173         tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format);
174         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
175       }
176 
177       std::shared_ptr<stream> fwd_cpu_stream;
178       MklDnnThreadPool eigen_tp(ctx);
179       fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
180 
181       // Create memory descriptor for MKL-DNN.
182       // If all input in Tensorflow format, create block memory descriptor,
183       // else convert TF format to MKL memory descriptor
184       for (int src_idx = 0; src_idx < num_inputs; ++src_idx) {
185         MklDnnShape src_mkl_shape;
186         GetMklShape(ctx, src_idx, &src_mkl_shape);
187         memory::desc md({}, memory::data_type::undef,
188                         memory::format_tag::undef);
189         const Tensor& src_tensor = MklGetInput(ctx, src_idx);
190 
191         if (src_mkl_shape.IsMklTensor()) {
192           md = src_mkl_shape.GetMklLayout();
193         } else {
194           if (has_mkl_input) {
195             memory::dims src_dims;
196             if (src_tensor.dims() == 4) {
197               src_dims =
198                   TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tf_data_format);
199             } else {
200               DCHECK(src_tensor.dims() == 5);
201               src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
202                                                     tf_data_format);
203             }
204             md = memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
205           } else {
206             // Create block memory descriptor for TensorFlow format input.
207             auto dims = TFShapeToMklDnnDims(src_tensor.shape());
208             auto strides = CalculateTFStrides(dims);
209             md = MklDnnData<T>::CreateBlockedMemDesc(dims, strides);
210           }
211         }
212         srcs_pd.push_back(memory::desc(md));
213         src.SetUsrMem(md, &src_tensor);
214         src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
215         inputs.push_back(src.GetOpMem());
216       }
217 
218       auto sum_pd = sum::primitive_desc(coeff, srcs_pd, cpu_engine);
219       output_mkl_shape.SetMklTensor(has_mkl_input);
220       auto output_pd = sum_pd.dst_desc();
221       dst.SetUsrMem(output_pd);
222 
223       if (has_mkl_input) {
224         output_mkl_shape.SetMklLayout(&output_pd);
225         output_mkl_shape.SetElemType(MklDnnType<T>());
226         output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(),
227                                      mkl_shape.GetSizesAsMklDnnDims(),
228                                      mkl_shape.GetTfDataFormat());
229         output_tf_shape.AddDim((output_pd.get_size() / sizeof(T)));
230       } else {
231         // All inputs have TF shapes, get the shape from first one.
232         output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape();
233       }
234       AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
235                                 output_mkl_shape);
236       dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
237 
238       // Create Sum op, and submit net for execution.
239       std::vector<primitive> net;
240       mkldnn::sum sum_op(sum_pd);
241       std::unordered_map<int, memory> net_args = {
242           {MKLDNN_ARG_DST, dst.GetOpMem()}};
243       for (int i = 0; i < num_inputs; ++i) {
244         net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, inputs[i]});
245       }
246       sum_op.execute(*fwd_cpu_stream, net_args);
247     } catch (mkldnn::error& e) {
248       string error_msg = "Status: " + std::to_string(e.status) +
249                          ", message: " + string(e.message) + ", in file " +
250                          string(__FILE__) + ":" + std::to_string(__LINE__);
251       OP_REQUIRES_OK(
252           ctx, errors::Aborted("Operation received an exception:", error_msg));
253     }
254   }
255 };
256 
257 #define REGISTER_MKL_CPU(T)                                    \
258   REGISTER_KERNEL_BUILDER(                                     \
259       Name("_MklAddN")                                         \
260           .Device(DEVICE_CPU)                                  \
261           .TypeConstraint<T>("T")                              \
262           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
263       MklAddNOp<CPUDevice, T>);
264 
265 TF_CALL_float(REGISTER_MKL_CPU);
266 TF_CALL_bfloat16(REGISTER_MKL_CPU);
267 #undef REGISTER_MKL_CPU
268 }  // namespace tensorflow
269 #endif  // INTEL_MKL
270