/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::stream;

namespace tensorflow {

namespace {
gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
  gtl::InlinedVector<int64, 4> out;
  if (tensor.dtype() == DT_INT32) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int32>()(i));
    }
  } else if (tensor.dtype() == DT_INT64) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int64>()(i));
    }
  } else {
    // tensor must be either int32 or int64
    DCHECK(false);
  }
  return out;
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
// A version of SharedValidation (slice_op.h) written for input that is in
// either Mkl layout or TensorFlow layout. It validates the input shapes and
// checks whether the slice is an identity, and it does not depend on the
// element type T. Keeping it type-independent avoids duplicating this logic
// for every T (float, double, int32, etc.) and reduces code size.
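// For example, for an input of shape [4, 5], begin = [1, 0] and size = [2, -1]
// select rows 1..2 and all 5 columns (a size of -1 means "all remaining
// elements in that dimension").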
static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
                              gtl::InlinedVector<int64, 4>* begin,
                              gtl::InlinedVector<int64, 4>* size) {
  const int kInputTensorIndex = 0;
  const int kInputBeginIndex = 1;
  const int kInputSizeIndex = 2;
  const Tensor& input = MklGetInput(context, kInputTensorIndex);
  const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
  const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);

  MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
  GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
  GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
  GetMklShape(context, kInputSizeIndex, &size_mkl_shape);

  // Begin and size tensors cannot be in MklDnn layout.
  DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
  DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);

  TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
                                   ? input_mkl_shape.GetTfShape()
                                   : input.shape();
  const int input_dims = input_tf_shape.dims();

  OP_REQUIRES(
      context,
      TensorShapeUtils::IsVector(begin_tensor.shape()) &&
          TensorShapeUtils::IsVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input_dims &&
          size_tensor.NumElements() == input_dims,
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  *begin = IntTensorToInt64Vec(begin_tensor);
  *size = IntTensorToInt64Vec(size_tensor);
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
      (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
    }
  }

  *is_identity = true;
  for (int i = 0; i < input_dims; ++i) {
    int64 b = (*begin)[i];
    int64 s = (*size)[i];
    if (input_tf_shape.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 ", "(got ", s,
                                  ") when ", "input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i),
                                          "], but got ", b));
      OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i) - b,
                                          "], but ", "got ", s));
    }
    const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
    (*is_identity) &= take_all;
  }
}

// A version of the SharedSliceCommonCases function written for an input
// tensor that may be in MklDnn layout or in TensorFlow layout.
template <typename T>
static void CheckCommonCasesForMklInputs(OpKernelContext* context,
                                         gtl::InlinedVector<int64, 4>* begin,
                                         gtl::InlinedVector<int64, 4>* size,
                                         bool* done) {
  bool is_identity = true;
  *done = false;

  ValidateMklInputs(context, &is_identity, begin, size);
  if (!context->status().ok()) return;

  const Tensor& input = MklGetInput(context, 0);
  MklDnnShape input_mkl_shape;
  GetMklShape(context, 0, &input_mkl_shape);

  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    // Mkl metadata tensor in this case can just be forwarded from input to
    // output.
    AllocateOutputSetMklShape(context, 0, input_mkl_shape);
    *done = true;
  }
}

// This structure aggregates multiple inputs to the Slice methods.
struct MklSliceParams {
  // `from` and `to` are the source and destination memory for the reorder.
  const memory* from;
  const memory* to;

  // `begin_dims` and `size_dims` are the offsets and sizes passed to the
  // submemory (view) descriptor.
  memory::dims begin_dims;
  memory::dims size_dims;

  MklSliceParams(const memory* from, const memory* to, memory::dims begin_dims,
                 memory::dims size_dims)
      : from(from), to(to), begin_dims(begin_dims), size_dims(size_dims) {}
};
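// A minimal usage sketch (the memories and dims here are hypothetical): to
// copy a 2x2x4x4 region starting at offset {0, 1, 0, 0} from `src_mem` into
// `dst_mem`:
//   MklSliceParams params(&src_mem, &dst_mem, /*begin_dims=*/{0, 1, 0, 0},
//                         /*size_dims=*/{2, 2, 4, 4});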

// This implements the shared interface of Slice reorders.
template <typename T>
class MklSlicePrimitive : public MklPrimitive {
 public:
  explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    Setup(sliceParams);
  }

  ~MklSlicePrimitive() {}

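  // Binds the real input/output buffers from `sliceParams` to the cached
  // memories and runs the reorder primitive on the given stream.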
  void Execute(const MklSliceParams& sliceParams,
               std::shared_ptr<stream> slice_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
                                      *slice_stream);
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
                                      *slice_stream);
#else
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif  // ENABLE_MKLDNN_THREADPOOL

    execute_primitives(context_.slice_primitives, slice_stream,
                       context_.slice_primitives_args);

    // Reset the data handles to DummyData so the cached primitive stays
    // stateless. Otherwise, if the handles from the previous iteration were
    // kept, errors in the current iteration would not surface immediately and
    // stale data could be reused.
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
    return;
  }

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

 private:
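  // Holds the cached memories, the reorder primitive, and the execution
  // arguments for one slice configuration.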
  struct SliceContext {
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    std::shared_ptr<reorder::primitive_desc> reorder_pd;
    std::shared_ptr<mkldnn::stream> slice_stream;
    std::vector<mkldnn::primitive> slice_primitives;
    std::shared_ptr<mkldnn::memory> src_sub_mem;
    std::vector<std::unordered_map<int, memory>> slice_primitives_args;
    SliceContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

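  // Creates the submemory (view) descriptor over the source, the reorder
  // primitive that copies the viewed region into the destination, and the
  // argument map used at execution time.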
  void Setup(const MklSliceParams& sliceParams) {
    // DummyData is never used in the computation; the real data handles are
    // set just before execution.
    context_.src_mem.reset(
        new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));

    auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
        sliceParams.size_dims, sliceParams.begin_dims);
    context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
    context_.reorder_pd = std::make_shared<reorder::primitive_desc>(
        reorder::primitive_desc(*context_.src_sub_mem, *context_.dst_mem));
    context_.reorder_prim =
        std::make_shared<mkldnn::reorder>(reorder(*context_.reorder_pd));

    context_.slice_primitives_args.push_back(
        {{MKLDNN_ARG_SRC, *context_.src_mem},
         {MKLDNN_ARG_DST, *context_.dst_mem}});
    context_.slice_primitives.push_back(*context_.reorder_prim);
  }
};

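// Caches MklSlicePrimitive objects keyed on the slice parameters, so a reorder
// built for a given combination of formats, offsets, and sizes can be reused
// across iterations instead of being recreated.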
template <typename T>
class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklSlicePrimitive<T>* Get(const MklSliceParams& sliceParams) {
    auto reorderPrim = static_cast<MklSlicePrimitive<T>*>(
        MklSlicePrimitiveFactory<T>::GetInstance().GetReorder(sliceParams));
    if (reorderPrim == nullptr) {
      reorderPrim = new MklSlicePrimitive<T>(sliceParams);
      MklSlicePrimitiveFactory<T>::GetInstance().SetReorder(sliceParams,
                                                            reorderPrim);
    }
    return reorderPrim;
  }

  static MklSlicePrimitiveFactory& GetInstance() {
    static MklSlicePrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklSlicePrimitiveFactory() {}
  ~MklSlicePrimitiveFactory() {}

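  // Builds the cache key from the source/destination descriptors (data type,
  // dims, strides) and the slice offsets and sizes.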
  static string CreateKey(const MklSliceParams& sliceParams) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = sliceParams.from->get_desc().data;
    auto const& to_desc = sliceParams.to->get_desc().data;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);

    // MKL-DNN v1.x removed the "view" primitive; a submemory descriptor
    // provides the same capability.
    auto from_strides = from_desc.format_desc.blocking.strides;
    auto to_strides = to_desc.format_desc.blocking.strides;
    memory::dims from_strides_outer_blocks(from_strides,
                                           &from_strides[from_desc.ndims]);
    memory::dims to_strides_outer_blocks(to_strides,
                                         &to_strides[to_desc.ndims]);

    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides_outer_blocks);
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides_outer_blocks);
    key_creator.AddAsKey(sliceParams.begin_dims);
    key_creator.AddAsKey(sliceParams.size_dims);
    return key_creator.GetKey();
  }

  MklPrimitive* GetReorder(const MklSliceParams& sliceParams) {
    string key = CreateKey(sliceParams);
    return this->GetOp(key);
  }

  void SetReorder(const MklSliceParams& sliceParams, MklPrimitive* op) {
    string key = CreateKey(sliceParams);
    this->SetOp(key, op);
  }
};

// MKL-DNN implementation of Slice
template <typename Device, typename T>
class MklSliceOp : public OpKernel {
 public:
  explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  ~MklSliceOp() {}

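  // Validates begin/size, returns early when the slice is an identity (the
  // input is forwarded), and otherwise performs the slice with an MKL-DNN
  // reorder over a submemory view.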
  void Compute(OpKernelContext* context) override {
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    bool done = false;

    CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);

    if (!context->status().ok() || done == true) return;

    // Although MKL-DNN supports tensors with more than 8 (and up to 12)
    // dimensions, we mimic the behavior of the Eigen Slice op for CPU, which
    // does not handle tensors with 8 or more dimensions.
    if (begin.size() >= 8) {
      OP_REQUIRES(
          context, false,
          errors::Unimplemented("MklSliceOp : Unhandled input dimensions"));
    }

    ComputeMklSlice(context, begin, size);
  }

 private:
  // Slice op implemented using MKL-DNN APIs.
  void ComputeMklSlice(OpKernelContext* context,
                       const gtl::InlinedVector<int64, 4>& begin,
                       const gtl::InlinedVector<int64, 4>& size) {
    try {
      // MKL-DNN API usage below is guided by the description at:
      // https://github.com/01org/mkl-dnn/issues/69
      //
      // The relevant part of the description is copied below:
      //
      // Let's say you want to copy a part of memory into another buffer (and
      // probably change the format). Then your steps are:
      //
      // 1. create memory primitive descriptor in_mem_pd and memory primitive
      //    in_mem_p for the entire source data. create view primitive
      //    descriptor in_submem_pd based on in_mem_pd, initial offsets,
      //    and sub-sizes
      // 2. create memory primitive descriptor out_mem_pd and memory primitive
      //    out_mem_p for the output (the logical sizes should match sub-sizes
      //    used in step 1, but the format might be arbitrary)
      // 3. create reorder primitive descriptor reorder_pd based on
      //    in_submem_pd and out_mem_pd. create reorder primitive itself based
      //    on reorder_pd, in_mem_p, and out_mem_p.
      //
      // Note that there is no view primitive, only a view primitive
      // descriptor. The reorder uses the source memory as input but traverses
      // it according to the view in_submem_pd.

      auto cpu_engine = engine(engine::kind::cpu, 0);
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Populate the offsets and sizes in memory::dims format from the input
      // vectors.
      memory::dims begin_dims = {};
      begin_dims.resize(begin.size());
      for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
      memory::dims size_dims = {};
      bool empty = false;
      size_dims.resize(size.size());
      for (size_t i = 0; i < size.size(); ++i) {
        size_dims[i] = size[i];
        if (size_dims[i] == 0) empty = true;
      }

      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;

      // If any requested slice size is 0, the result is empty. Just return an
      // empty output tensor and a dummy Mkl-shape tensor.
      if (empty) {  // for empty dims
        auto shape_to = MklDnnDimsToTFShape(size_dims);
        AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
                                  output_mkl_shape);
        return;
      }

      // Step 1 (as per the above description) - Create memory for the user
      // data. We use the blocked format here to describe the input tensor.
      const Tensor& input_tensor = MklGetInput(context, 0);
      memory::dims input_dims, input_strides;
      MklDnnShape input_mkl_shape;
      GetMklShape(context, 0, &input_mkl_shape);

      if (input_mkl_shape.IsMklTensor()) {
        auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
        auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);

        bool is_slice2d = (input_mkl_shape.GetDimension() == 4);
        begin_dims = is_slice2d
                         ? MklDnnDimsInNCHW(begin_dims, input_tf_format)
                         : MklDnnDimsInNCDHW(begin_dims, input_tf_format);
        size_dims = is_slice2d ? MklDnnDimsInNCHW(size_dims, input_tf_format)
                               : MklDnnDimsInNCDHW(size_dims, input_tf_format);
        auto input_md = input_mkl_shape.GetMklLayout();
        src.SetUsrMem(input_md, &input_tensor);

        // Handle the data format safely by converting to the blocked format;
        // compute the parameters of the reorder primitive first.
        input_dims = input_mkl_shape.GetSizesAsMklDnnDims();
        input_strides = CalculateTFStrides(input_dims);
      } else {
        // Initialize the input dimensions and strides to be used when the
        // input is not in MklDnn layout.
        input_dims = TFShapeToMklDnnDims(input_tensor.shape());
        input_strides = CalculateTFStrides(input_dims);
        // Create the input memory descriptor.
        auto input_md =
            MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
        src.SetUsrMem(input_md, &input_tensor);
      }

      // If the input format is not the blocked format, execute a reorder;
      // otherwise this is a no-op.
      auto op_md =
          MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
      src.CheckReorderToOpMem(op_md, cpu_engine, context);

      // Step 2 - Create memory for the output.
      auto output_strides = CalculateTFStrides(size_dims);
      auto output_md =
          MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
      auto output_pd = output_md;
      AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
                           &output_tensor, &output_mkl_shape);
      DCHECK(output_tensor);
      DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
      output.SetUsrMem(output_md, output_tensor);

      // Step 3 - create reorder primitive.
      MklSliceParams sliceParams(&src.GetOpMem(), output.GetUsrMem(),
                                 begin_dims, size_dims);
      MklSlicePrimitive<T>* reorder_prim =
          MklSlicePrimitiveFactory<T>::Get(sliceParams);
      // Execute slice reorder.
      std::shared_ptr<stream> slice_stream;
      slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
      reorder_prim->Execute(sliceParams, slice_stream);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
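  // Allocates the output tensor and its Mkl shape metadata. If the input is
  // in Mkl layout, the output is also allocated in Mkl layout; otherwise a
  // plain TensorFlow tensor is allocated.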
  void AllocateOutputTensor(OpKernelContext* context,
                            const MklDnnShape& input_mkl_shape,
                            memory::desc* output_pd,
                            const memory::dims& output_dims,
                            Tensor** output_tensor,
                            MklDnnShape* output_mkl_shape) {
    DCHECK(output_tensor);
    DCHECK(output_mkl_shape);

    TensorShape output_tf_shape;

    if (input_mkl_shape.IsMklTensor()) {
      // Since the input tensor is in Mkl layout, the output tensor will be in
      // Mkl layout as well.

      // Allocate the shape of the Mkl tensor.
      output_mkl_shape->SetMklTensor(true);
      output_mkl_shape->SetMklLayout(output_pd);
      output_mkl_shape->SetElemType(MklDnnType<T>());
      output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(),
                                    output_dims,
                                    input_mkl_shape.GetTfDataFormat());

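      // The Mkl layout may include padding, so size the flat TensorFlow
      // buffer by the number of elements in the Mkl memory descriptor rather
      // than by the logical output dims.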
      output_tf_shape.AddDim(output_pd->get_size() / sizeof(T));
    } else {
      // If input is not in Mkl layout, then output won't be in Mkl layout.
      output_mkl_shape->SetMklTensor(false);
      output_tf_shape = MklDnnDimsToTFShape(output_dims);
    }

    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              *output_mkl_shape);
  }
};

// MKL-DNN Slice registration
#define REGISTER_MKL_SLICE(type)                               \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklSlice")                                        \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .HostMemory("begin")                                 \
          .HostMemory("size")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklSliceOp<CPUDevice, type>);

TF_CALL_float(REGISTER_MKL_SLICE);
TF_CALL_bfloat16(REGISTER_MKL_SLICE);
#undef REGISTER_MKL_SLICE

}  // namespace tensorflow

#endif  // INTEL_MKL