• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/array_ops.cc.
17 
18 #ifdef INTEL_MKL
19 #ifndef INTEL_MKL_ML_ONLY
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/kernels/ops_util.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/platform/prefetch.h"
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 
30 #include "mkldnn.hpp"
31 #include "tensorflow/core/util/mkl_util.h"
32 
33 using mkldnn::stream;
34 using mkldnn::view;
35 
36 namespace tensorflow {
37 
38 namespace {
39 
IntTensorToInt64Vec(const Tensor & tensor)40 gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
41   gtl::InlinedVector<int64, 4> out;
42   if (tensor.dtype() == DT_INT32) {
43     for (int64 i = 0; i < tensor.NumElements(); ++i) {
44       out.push_back(tensor.flat<int32>()(i));
45     }
46   } else if (tensor.dtype() == DT_INT64) {
47     for (int64 i = 0; i < tensor.NumElements(); ++i) {
48       out.push_back(tensor.flat<int64>()(i));
49     }
50   } else {
51     // tensor must be either int32 or int64
52     DCHECK(false);
53   }
54   return out;
55 }
56 
57 }  // namespace
58 
59 typedef Eigen::ThreadPoolDevice CPUDevice;
60 
61 // A version of SharedValidation (slice_op.h) written for input that is in
62 // either Mkl layout or Tensorflow layout. A shared code to validate input
63 // shapes and check for identity, which is not dependent on the type of T.
64 // We do this to reduce code size by not duplicating all this for all T
65 // (float, double, int32, etc.)
ValidateMklInputs(OpKernelContext * context,bool * is_identity,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * size)66 static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
67                               gtl::InlinedVector<int64, 4>* begin,
68                               gtl::InlinedVector<int64, 4>* size) {
69   const int kInputTensorIndex = 0;
70   const int kInputBeginIndex = 1;
71   const int kInputSizeIndex = 2;
72   const Tensor& input = MklGetInput(context, kInputTensorIndex);
73   const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
74   const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
75 
76   MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
77   GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
78   GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
79   GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
80 
81   // Begin and size tensors cannot be in MklDnn layout.
82   DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
83   DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
84 
85   TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
86                                    ? input_mkl_shape.GetTfShape()
87                                    : input.shape();
88   const int input_dims = input_tf_shape.dims();
89 
90   OP_REQUIRES(
91       context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
92                    context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
93                    begin_tensor.NumElements() == input_dims &&
94                    size_tensor.NumElements() == input_dims,
95       errors::InvalidArgument(
96           "Expected begin and size arguments to be 1-D tensors of size ",
97           input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
98           " and ", size_tensor.shape().DebugString(), " instead."));
99 
100   *begin = IntTensorToInt64Vec(begin_tensor);
101   *size = IntTensorToInt64Vec(size_tensor);
102   for (int i = 0; i < input_dims; ++i) {
103     if ((*size)[i] == -1) {
104       // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
105       (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
106     }
107   }
108 
109   *is_identity = true;
110   for (int i = 0; i < input_dims; ++i) {
111     int64 b = (*begin)[i];
112     int64 s = (*size)[i];
113     if (input_tf_shape.dim_size(i) == 0) {
114       OP_REQUIRES(
115           context, b == 0 && s == 0,
116           errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
117                                   ") and size[", i, "] == 0 ", "(got ", s,
118                                   ") when ", "input.dim_size(", i, ") == 0"));
119     } else {
120       OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
121                   errors::InvalidArgument("Expected begin[", i, "] in [0, ",
122                                           input_tf_shape.dim_size(i),
123                                           "], but got ", b));
124       OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
125                   errors::InvalidArgument("Expected size[", i, "] in [0, ",
126                                           input_tf_shape.dim_size(i) - b,
127                                           "], but ", "got ", s));
128     }
129     const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
130     (*is_identity) &= take_all;
131   }
132 }
133 
134 // A version of SharedSliceCommonCases function written for input tensor
135 // that may be in MklDnn layout or in Tensorflow layout.
136 template <typename T>
CheckCommonCasesForMklInputs(OpKernelContext * context,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * size,bool * done)137 static void CheckCommonCasesForMklInputs(OpKernelContext* context,
138                                          gtl::InlinedVector<int64, 4>* begin,
139                                          gtl::InlinedVector<int64, 4>* size,
140                                          bool* done) {
141   bool is_identity = true;
142   *done = false;
143 
144   ValidateMklInputs(context, &is_identity, begin, size);
145   if (!context->status().ok()) return;
146 
147   const Tensor& input = MklGetInput(context, 0);
148   MklDnnShape input_mkl_shape;
149   GetMklShape(context, 0, &input_mkl_shape);
150 
151   if (is_identity) {
152     VLOG(1) << "Slice identity";
153     context->set_output(0, input);
154     // Mkl metadata tensor in this case can just be forwarded from input to
155     // output.
156     AllocateOutputSetMklShape(context, 0, input_mkl_shape);
157     *done = true;
158   }
159 }
160 
161 // This structure aggregates multiple inputs to Slice methods.
162 struct MklSliceParams {
163   // Parameters from & to represents memory pointing to reorder.
164   const memory* from;
165   const memory* to;
166 
167   // Parameters begin_dims & size_dims represents offset and length
168   // passed to view primitive.
169   memory::dims begin_dims;
170   memory::dims size_dims;
171 
MklSliceParamstensorflow::MklSliceParams172   MklSliceParams(const memory* from, const memory* to, memory::dims begin_dims,
173                  memory::dims size_dims)
174       : from(from), to(to), begin_dims(begin_dims), size_dims(size_dims) {}
175 };
176 
177 // This implements the shared interface of Slice reorders.
178 template <typename T>
179 class MklSlicePrimitive : public MklPrimitive {
180  public:
MklSlicePrimitive(const MklSliceParams & sliceParams)181   explicit MklSlicePrimitive(const MklSliceParams& sliceParams) {
182     context_.slice_stream.reset(new stream(stream::kind::eager));
183     Setup(sliceParams);
184   }
185 
~MklSlicePrimitive()186   ~MklSlicePrimitive() {}
187 
Execute(const MklSliceParams & sliceParams)188   void Execute(const MklSliceParams& sliceParams) {
189     context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
190     context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
191     context_.slice_stream->submit(context_.slice_primitives);
192 
193     // We should set it back to DummyData so as to make the primitive
194     // in cache pool stateless. Otherwise, if the result for previous
195     // iteration is kept, problems of current iteration won't be
196     // thrown immediately, and wrong data would be reused.
197     context_.src_mem->set_data_handle(DummyData);
198     context_.dst_mem->set_data_handle(DummyData);
199     return;
200   }
201 
GetPrimitive()202   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
203 
204  private:
205   struct SliceContext {
206     std::shared_ptr<mkldnn::memory> src_mem;
207     std::shared_ptr<mkldnn::memory> dst_mem;
208     std::shared_ptr<primitive> reorder_prim;
209     std::shared_ptr<reorder::primitive_desc> reorder_pd;
210     std::shared_ptr<view::primitive_desc> view_pd;
211     std::shared_ptr<mkldnn::stream> slice_stream;
212     std::vector<mkldnn::primitive> slice_primitives;
SliceContexttensorflow::MklSlicePrimitive::SliceContext213     SliceContext()
214         : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
215   } context_;
216 
217   engine cpu_engine_ = engine(engine::cpu, 0);
218 
Setup(const MklSliceParams & sliceParams)219   void Setup(const MklSliceParams& sliceParams) {
220     // Actually, this DummyData will not be used in computation,
221     // because the real data will be filled before real execution.
222     context_.src_mem.reset(
223         new memory({sliceParams.from->get_primitive_desc().desc(), cpu_engine_},
224                    DummyData));
225     context_.dst_mem.reset(new memory(
226         {sliceParams.to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
227     auto src_pd = context_.src_mem->get_primitive_desc();
228     auto dst_pd = context_.dst_mem->get_primitive_desc();
229     context_.view_pd =
230         std::make_shared<view::primitive_desc>(view::primitive_desc(
231             src_pd, sliceParams.size_dims, sliceParams.begin_dims));
232     context_.reorder_pd =
233         std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
234             context_.view_pd->dst_primitive_desc(), dst_pd));
235     context_.reorder_prim = std::make_shared<mkldnn::reorder>(
236         reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
237     context_.slice_primitives.push_back(*context_.reorder_prim);
238   }
239 };
240 
241 template <typename T>
242 class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
243  public:
Get(const MklSliceParams & sliceParams)244   static MklSlicePrimitive<T>* Get(const MklSliceParams& sliceParams) {
245     auto reorderPrim = static_cast<MklSlicePrimitive<T>*>(
246         MklSlicePrimitiveFactory<T>::GetInstance().GetReorder(sliceParams));
247     if (reorderPrim == nullptr) {
248       reorderPrim = new MklSlicePrimitive<T>(sliceParams);
249       MklSlicePrimitiveFactory<T>::GetInstance().SetReorder(sliceParams,
250                                                             reorderPrim);
251     }
252     return reorderPrim;
253   }
254 
GetInstance()255   static MklSlicePrimitiveFactory& GetInstance() {
256     static MklSlicePrimitiveFactory instance_;
257     return instance_;
258   }
259 
260  private:
MklSlicePrimitiveFactory()261   MklSlicePrimitiveFactory() {}
~MklSlicePrimitiveFactory()262   ~MklSlicePrimitiveFactory() {}
263 
CreateKey(const MklSliceParams & sliceParams)264   static string CreateKey(const MklSliceParams& sliceParams) {
265     string prefix = "reorder";
266     FactoryKeyCreator key_creator;
267     auto const& from_desc = sliceParams.from->get_primitive_desc().desc().data;
268     auto const& to_desc = sliceParams.to->get_primitive_desc().desc().data;
269     const int kIdxFirstStride = 0;
270     memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
271     memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
272     memory::dims from_strides(
273         from_desc.layout_desc.blocking.strides[kIdxFirstStride],
274         &from_desc.layout_desc.blocking
275              .strides[kIdxFirstStride][from_desc.ndims]);
276     memory::dims to_strides(
277         to_desc.layout_desc.blocking.strides[kIdxFirstStride],
278         &to_desc.layout_desc.blocking.strides[kIdxFirstStride][to_desc.ndims]);
279     key_creator.AddAsKey(prefix);
280     key_creator.AddAsKey(static_cast<int>(from_desc.format));
281     key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
282     key_creator.AddAsKey(from_dims);
283     key_creator.AddAsKey(from_strides);
284     key_creator.AddAsKey(static_cast<int>(to_desc.format));
285     key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
286     key_creator.AddAsKey(to_dims);
287     key_creator.AddAsKey(to_strides);
288     key_creator.AddAsKey(sliceParams.begin_dims);
289     key_creator.AddAsKey(sliceParams.size_dims);
290     return key_creator.GetKey();
291   }
292 
GetReorder(const MklSliceParams & sliceParams)293   MklPrimitive* GetReorder(const MklSliceParams& sliceParams) {
294     string key = CreateKey(sliceParams);
295     return this->GetOp(key);
296   }
297 
SetReorder(const MklSliceParams & sliceParams,MklPrimitive * op)298   void SetReorder(const MklSliceParams& sliceParams, MklPrimitive* op) {
299     string key = CreateKey(sliceParams);
300     this->SetOp(key, op);
301   }
302 };
303 
304 // MKL-DNN implementation of Slice
305 template <typename Device, typename T>
306 class MklSliceOp : public OpKernel {
307  public:
MklSliceOp(OpKernelConstruction * context)308   explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
309 
~MklSliceOp()310   ~MklSliceOp() {}
311 
Compute(OpKernelContext * context)312   void Compute(OpKernelContext* context) override {
313     gtl::InlinedVector<int64, 4> begin;
314     gtl::InlinedVector<int64, 4> size;
315     bool done = false;
316 
317     CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
318     if (!context->status().ok() || done == true) return;
319 
320     // Though MKL-DNN supports more than 8 dimension and
321     // less than 12 dimension tensor.
322     // But we are mimicking functionality of Eigen Slice op for CPU.
323     if (begin.size() >= 8) {
324       OP_REQUIRES(
325           context, false,
326           errors::Unimplemented("MklSliceOp : Unhandled input dimensions"));
327     }
328 
329     ComputeMklSlice(context, begin, size);
330   }
331 
332  private:
333   // Slice op implemented using MKL-DNN APIs.
ComputeMklSlice(OpKernelContext * context,const gtl::InlinedVector<int64,4> & begin,const gtl::InlinedVector<int64,4> & size)334   void ComputeMklSlice(OpKernelContext* context,
335                        const gtl::InlinedVector<int64, 4>& begin,
336                        const gtl::InlinedVector<int64, 4>& size) {
337     try {
338       // MKL-DNN API usage below is guided by description at:
339       //  https://github.com/01org/mkl-dnn/issues/69
340       //
341       // Relevant part of the description is copied below:
342       //
343       // Let's say you want to copy a part of memory into another buffer (and
344       // probably change the format). Then your steps are:
345       //
346       // 1. create memory primitive descriptor in_mem_pd and memory primitive
347       //    in_mem_p for the entire source data. create view primitive
348       //    descriptor in_submem_pd based on in_mem_pd, initial offsets,
349       //    and sub-sizes
350       // 2. create memory primitive descriptor out_mem_pd and memory primitive
351       //    out_mem_p for the output (the logical sizes should match sub-sizes
352       //    used in step 1, but the format might be arbitrary)
353       // 3. create reorder primitive descriptor reorder_pd based on in_submem_pd
354       //    and out_mem_pd. create reorder primitive itself based on reorder_pd,
355       //    in_mem_p, and out_mem_p.
356       //
357       // Please notice that there is no view primitive. There is only view
358       // primitive descriptor. And the reorder uses source memory as input but
359       // traverses it according to a view in_submem_pd.
360 
361       auto cpu_engine = engine(engine::cpu, 0);
362       MklDnnData<T> src(&cpu_engine);
363       MklDnnData<T> output(&cpu_engine);
364 
365       // Populate offsets and sizes in memory::dims format based on vector.
366       memory::dims begin_dims = {};
367       begin_dims.resize(begin.size());
368       for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
369       memory::dims size_dims = {};
370       bool empty = false;
371       size_dims.resize(size.size());
372       for (size_t i = 0; i < size.size(); ++i) {
373         size_dims[i] = size[i];
374         if (size_dims[i] == 0) empty = true;
375       }
376 
377       Tensor* output_tensor = nullptr;
378       MklDnnShape output_mkl_shape;
379 
380       // If no dimension is selected in slice, the result should be empty.
381       // Just return an empty output tensor, and a dummy Mkl-shape tensor.
382       if (empty) {  // for empty dims
383         auto shape_to = MklDnnDimsToTFShape(size_dims);
384         AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
385                                   output_mkl_shape);
386         return;
387       }
388 
389       // Step 1 (as per above description) - Create memory for user data.
390       // We use blocked format here to describe input tensor.
391       const Tensor& input_tensor = MklGetInput(context, 0);
392       MklDnnShape input_mkl_shape;
393       GetMklShape(context, 0, &input_mkl_shape);
394 
395       if (input_mkl_shape.IsMklTensor()) {
396         auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
397         auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
398         begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
399         size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
400         auto input_md = input_mkl_shape.GetMklLayout();
401         src.SetUsrMem(input_md, &input_tensor);
402       } else {
403         // Initialize input dimensions and strides to be used when input is not
404         // in MklDnn layout.
405         memory::dims input_dims, input_strides;
406         input_dims = TFShapeToMklDnnDims(input_tensor.shape());
407         input_strides = CalculateTFStrides(input_dims);
408         // Create input memory descriptor.
409         auto input_md =
410             MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
411         src.SetUsrMem(input_md, &input_tensor);
412       }
413 
414       // Step 2 - Create memory for output.
415       auto output_strides = CalculateTFStrides(size_dims);
416       auto output_md =
417           MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
418       auto output_pd = memory::primitive_desc(output_md, cpu_engine);
419       AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
420                            &output_tensor, &output_mkl_shape);
421       DCHECK(output_tensor);
422       DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
423       output.SetUsrMem(output_md, output_tensor);
424 
425       // Step 3 - create reorder primitive.
426       MklSliceParams sliceParams(src.GetUsrMem(), output.GetUsrMem(),
427                                  begin_dims, size_dims);
428       MklSlicePrimitive<T>* reorder_prim =
429           MklSlicePrimitiveFactory<T>::Get(sliceParams);
430       // Execute slice reorder.
431       reorder_prim->Execute(sliceParams);
432     } catch (mkldnn::error& e) {
433       string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
434                          string(e.message) + ", in file " + string(__FILE__) +
435                          ":" + std::to_string(__LINE__);
436       OP_REQUIRES_OK(
437           context,
438           errors::Aborted("Operation received an exception:", error_msg));
439     }
440   }
441 
442  private:
AllocateOutputTensor(OpKernelContext * context,const MklDnnShape & input_mkl_shape,memory::primitive_desc * output_pd,const memory::dims & output_dims,Tensor ** output_tensor,MklDnnShape * output_mkl_shape)443   void AllocateOutputTensor(OpKernelContext* context,
444                             const MklDnnShape& input_mkl_shape,
445                             memory::primitive_desc* output_pd,
446                             const memory::dims& output_dims,
447                             Tensor** output_tensor,
448                             MklDnnShape* output_mkl_shape) {
449     DCHECK(output_tensor);
450     DCHECK(output_mkl_shape);
451 
452     TensorShape output_tf_shape;
453 
454     if (input_mkl_shape.IsMklTensor()) {
455       // Since input tensor is in Mkl layout, output tensor will be in Mkl
456       // layout.
457 
458       // Allocate shape of Mkl tensor.
459       output_mkl_shape->SetMklTensor(true);
460       output_mkl_shape->SetMklLayout(output_pd);
461       output_mkl_shape->SetElemType(MklDnnType<T>());
462       output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
463                                     input_mkl_shape.GetTfDataFormat());
464 
465       output_tf_shape.AddDim(output_pd->get_size() / sizeof(T));
466     } else {
467       // If input is not in Mkl layout, then output won't be in Mkl layout.
468       output_mkl_shape->SetMklTensor(false);
469       output_tf_shape = MklDnnDimsToTFShape(output_dims);
470     }
471 
472     AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
473                               *output_mkl_shape);
474   }
475 };
476 
477 // MKL-DNN Slice registration
478 #define REGISTER_MKL_SLICE(type)                                    \
479   REGISTER_KERNEL_BUILDER(Name("_MklSlice")                         \
480                               .Device(DEVICE_CPU)                   \
481                               .TypeConstraint<type>("T")            \
482                               .HostMemory("begin")                  \
483                               .HostMemory("size")                   \
484                               .Label(mkl_op_registry::kMklOpLabel), \
485                           MklSliceOp<CPUDevice, type>);
486 
487 TF_CALL_float(REGISTER_MKL_SLICE);
488 #undef REGISTER_MKL_SLICE
489 
490 }  // namespace tensorflow
491 
492 #endif  // INTEL_MKL_ML_ONLY
493 #endif  // INTEL_MKL
494