/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL
#ifndef INTEL_MKL_ML_ONLY

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::stream;
using mkldnn::view;

namespace tensorflow {

namespace {

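// Copies the elements of a DT_INT32 or DT_INT64 tensor into an int64 vector.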
gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
  gtl::InlinedVector<int64, 4> out;
  if (tensor.dtype() == DT_INT32) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int32>()(i));
    }
  } else if (tensor.dtype() == DT_INT64) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int64>()(i));
    }
  } else {
    // tensor must be either int32 or int64
    DCHECK(false);
  }
  return out;
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

// A version of SharedValidation (slice_op.h) written for input that is in
// either Mkl layout or TensorFlow layout. This shared code validates the
// input shapes and checks for the identity case, and does not depend on the
// type T. We do this to reduce code size by not duplicating it for every T
// (float, double, int32, etc.).
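//
// For example, slicing a [4, 5] input with begin = {1, 2} and size = {2, -1}
// selects rows 1..2 and columns 2..4, since a size of -1 expands to
// dim_size(i) - begin[i].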
static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
                              gtl::InlinedVector<int64, 4>* begin,
                              gtl::InlinedVector<int64, 4>* size) {
  const int kInputTensorIndex = 0;
  const int kInputBeginIndex = 1;
  const int kInputSizeIndex = 2;
  const Tensor& input = MklGetInput(context, kInputTensorIndex);
  const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
  const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);

  MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
  GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
  GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
  GetMklShape(context, kInputSizeIndex, &size_mkl_shape);

  // Begin and size tensors cannot be in MklDnn layout.
  DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
  DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);

  TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
                                   ? input_mkl_shape.GetTfShape()
                                   : input.shape();
  const int input_dims = input_tf_shape.dims();

  OP_REQUIRES(
      context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
                   context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
                   begin_tensor.NumElements() == input_dims &&
                   size_tensor.NumElements() == input_dims,
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  *begin = IntTensorToInt64Vec(begin_tensor);
  *size = IntTensorToInt64Vec(size_tensor);
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
      (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
    }
  }

  *is_identity = true;
  for (int i = 0; i < input_dims; ++i) {
    int64 b = (*begin)[i];
    int64 s = (*size)[i];
    if (input_tf_shape.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 ", "(got ", s,
                                  ") when ", "input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i),
                                          "], but got ", b));
      OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i) - b,
                                          "], but ", "got ", s));
    }
    const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
    (*is_identity) &= take_all;
  }
}

// A version of the SharedSliceCommonCases function written for an input
// tensor that may be in MklDnn layout or in TensorFlow layout.
template <typename T>
static void CheckCommonCasesForMklInputs(OpKernelContext* context,
                                         gtl::InlinedVector<int64, 4>* begin,
                                         gtl::InlinedVector<int64, 4>* size,
                                         bool* done) {
  bool is_identity = true;
  *done = false;

  ValidateMklInputs(context, &is_identity, begin, size);
  if (!context->status().ok()) return;

  const Tensor& input = MklGetInput(context, 0);
  MklDnnShape input_mkl_shape;
  GetMklShape(context, 0, &input_mkl_shape);

  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    // In this case the Mkl metadata tensor can simply be forwarded from the
    // input to the output.
    AllocateOutputSetMklShape(context, 0, input_mkl_shape);
    *done = true;
  }
}

// This structure aggregates multiple inputs to the Slice methods.
struct MklSliceParams {
  // 'from' and 'to' are the source and destination memory of the reorder.
  const memory* from;
  const memory* to;

  // 'begin_dims' and 'size_dims' are the offsets and lengths passed to the
  // view primitive.
  memory::dims begin_dims;
  memory::dims size_dims;

  MklSliceParams(const memory* from, const memory* to, memory::dims begin_dims,
                 memory::dims size_dims)
      : from(from), to(to), begin_dims(begin_dims), size_dims(size_dims) {}
};

// This implements the shared interface of Slice reorders.
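// A primitive holds a view primitive descriptor over the source memory and a
// reorder that copies the selected sub-region into the destination memory.
// The real data handles are bound each time Execute() is called, so a cached
// primitive can be reused across iterations.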
template <typename T>
class MklSlicePrimitive : public MklPrimitive {
 public:
  explicit MklSlicePrimitive(const MklSliceParams& sliceParams) {
    context_.slice_stream.reset(new stream(stream::kind::eager));
    Setup(sliceParams);
  }

  ~MklSlicePrimitive() {}

  void Execute(const MklSliceParams& sliceParams) {
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
    context_.slice_stream->submit(context_.slice_primitives);

    // Reset the data handles to DummyData so that the primitive kept in the
    // cache pool stays stateless. Otherwise, if the result of the previous
    // iteration were kept, problems in the current iteration would not
    // surface immediately and stale data could be reused.
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
    return;
  }

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

 private:
  struct SliceContext {
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    std::shared_ptr<reorder::primitive_desc> reorder_pd;
    std::shared_ptr<view::primitive_desc> view_pd;
    std::shared_ptr<mkldnn::stream> slice_stream;
    std::vector<mkldnn::primitive> slice_primitives;
    SliceContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

  engine cpu_engine_ = engine(engine::cpu, 0);

  void Setup(const MklSliceParams& sliceParams) {
    // DummyData is never used in the computation; the real data handles are
    // bound in Execute() before the primitive is run.
    context_.src_mem.reset(new memory(
        {sliceParams.from->get_primitive_desc().desc(), cpu_engine_},
        DummyData));
    context_.dst_mem.reset(new memory(
        {sliceParams.to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
    auto src_pd = context_.src_mem->get_primitive_desc();
    auto dst_pd = context_.dst_mem->get_primitive_desc();
    context_.view_pd =
        std::make_shared<view::primitive_desc>(view::primitive_desc(
            src_pd, sliceParams.size_dims, sliceParams.begin_dims));
    context_.reorder_pd =
        std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
            context_.view_pd->dst_primitive_desc(), dst_pd));
    context_.reorder_prim = std::make_shared<mkldnn::reorder>(
        reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
    context_.slice_primitives.push_back(*context_.reorder_prim);
  }
};

template <typename T>
class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
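  // Returns a cached slice-reorder primitive for the given parameters,
  // creating and caching a new one on first use.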
  static MklSlicePrimitive<T>* Get(const MklSliceParams& sliceParams) {
    auto reorderPrim = static_cast<MklSlicePrimitive<T>*>(
        MklSlicePrimitiveFactory<T>::GetInstance().GetReorder(sliceParams));
    if (reorderPrim == nullptr) {
      reorderPrim = new MklSlicePrimitive<T>(sliceParams);
      MklSlicePrimitiveFactory<T>::GetInstance().SetReorder(sliceParams,
                                                            reorderPrim);
    }
    return reorderPrim;
  }

  static MklSlicePrimitiveFactory& GetInstance() {
    static MklSlicePrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklSlicePrimitiveFactory() {}
  ~MklSlicePrimitiveFactory() {}

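  // The cache key encodes the source and destination formats, data types,
  // dimensions and strides, as well as the slice offsets and sizes, so a
  // cached reorder is only reused for an identical slice configuration.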
  static string CreateKey(const MklSliceParams& sliceParams) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = sliceParams.from->get_primitive_desc().desc().data;
    auto const& to_desc = sliceParams.to->get_primitive_desc().desc().data;
    const int kIdxFirstStride = 0;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
    memory::dims from_strides(
        from_desc.layout_desc.blocking.strides[kIdxFirstStride],
        &from_desc.layout_desc.blocking
             .strides[kIdxFirstStride][from_desc.ndims]);
    memory::dims to_strides(
        to_desc.layout_desc.blocking.strides[kIdxFirstStride],
        &to_desc.layout_desc.blocking.strides[kIdxFirstStride][to_desc.ndims]);
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(static_cast<int>(from_desc.format));
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides);
    key_creator.AddAsKey(static_cast<int>(to_desc.format));
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides);
    key_creator.AddAsKey(sliceParams.begin_dims);
    key_creator.AddAsKey(sliceParams.size_dims);
    return key_creator.GetKey();
  }

  MklPrimitive* GetReorder(const MklSliceParams& sliceParams) {
    string key = CreateKey(sliceParams);
    return this->GetOp(key);
  }

  void SetReorder(const MklSliceParams& sliceParams, MklPrimitive* op) {
    string key = CreateKey(sliceParams);
    this->SetOp(key, op);
  }
};

// MKL-DNN implementation of Slice.
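// The kernel first validates the inputs and handles the identity and
// empty-slice cases directly; otherwise it copies the requested sub-region
// through a cached view + reorder primitive.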
template <typename Device, typename T>
class MklSliceOp : public OpKernel {
 public:
  explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  ~MklSliceOp() {}

  void Compute(OpKernelContext* context) override {
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    bool done = false;

    CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
    if (!context->status().ok() || done) return;

    // Although MKL-DNN supports tensors with more than 8 (and up to 12)
    // dimensions, we mimic the behavior of the Eigen Slice op for CPU and
    // reject inputs with 8 or more dimensions.
    if (begin.size() >= 8) {
      OP_REQUIRES(
          context, false,
          errors::Unimplemented("MklSliceOp : Unhandled input dimensions"));
    }

    ComputeMklSlice(context, begin, size);
  }

 private:
  // Slice op implemented using MKL-DNN APIs.
  void ComputeMklSlice(OpKernelContext* context,
                       const gtl::InlinedVector<int64, 4>& begin,
                       const gtl::InlinedVector<int64, 4>& size) {
    try {
      // MKL-DNN API usage below is guided by the description at:
      // https://github.com/01org/mkl-dnn/issues/69
      //
      // The relevant part of the description is copied below:
      //
      // Let's say you want to copy a part of memory into another buffer (and
      // probably change the format). Then your steps are:
      //
      // 1. create memory primitive descriptor in_mem_pd and memory primitive
      //    in_mem_p for the entire source data. create view primitive
      //    descriptor in_submem_pd based on in_mem_pd, initial offsets,
      //    and sub-sizes
      // 2. create memory primitive descriptor out_mem_pd and memory primitive
      //    out_mem_p for the output (the logical sizes should match sub-sizes
      //    used in step 1, but the format might be arbitrary)
      // 3. create reorder primitive descriptor reorder_pd based on
      //    in_submem_pd and out_mem_pd. create reorder primitive itself based
      //    on reorder_pd, in_mem_p, and out_mem_p.
      //
      // Please notice that there is no view primitive. There is only a view
      // primitive descriptor. And the reorder uses the source memory as input
      // but traverses it according to the view in_submem_pd.

      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Populate the offsets and sizes in memory::dims format based on the
      // begin and size vectors.
      memory::dims begin_dims = {};
      begin_dims.resize(begin.size());
      for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
      memory::dims size_dims = {};
      bool empty = false;
      size_dims.resize(size.size());
      for (size_t i = 0; i < size.size(); ++i) {
        size_dims[i] = size[i];
        if (size_dims[i] == 0) empty = true;
      }

      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;

      // If the slice selects no elements in some dimension, the result is
      // empty. Just return an empty output tensor and a dummy Mkl-shape
      // tensor.
      if (empty) {  // for empty dims
        auto shape_to = MklDnnDimsToTFShape(size_dims);
        AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
                                  output_mkl_shape);
        return;
      }

      // Step 1 (as per above description) - Create memory for user data.
      // We use blocked format here to describe the input tensor.
      const Tensor& input_tensor = MklGetInput(context, 0);
      MklDnnShape input_mkl_shape;
      GetMklShape(context, 0, &input_mkl_shape);

      if (input_mkl_shape.IsMklTensor()) {
        auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
        auto input_tf_format =
            MklDnnDataFormatToTFDataFormat(input_mkl_format);
        begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
        size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
        auto input_md = input_mkl_shape.GetMklLayout();
        src.SetUsrMem(input_md, &input_tensor);
      } else {
        // Initialize input dimensions and strides to be used when the input
        // is not in MklDnn layout.
        memory::dims input_dims, input_strides;
        input_dims = TFShapeToMklDnnDims(input_tensor.shape());
        input_strides = CalculateTFStrides(input_dims);
        // Create input memory descriptor.
        auto input_md =
            MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
        src.SetUsrMem(input_md, &input_tensor);
      }

      // Step 2 - Create memory for the output.
      auto output_strides = CalculateTFStrides(size_dims);
      auto output_md =
          MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
      auto output_pd = memory::primitive_desc(output_md, cpu_engine);
      AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
                           &output_tensor, &output_mkl_shape);
      DCHECK(output_tensor);
      DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
      output.SetUsrMem(output_md, output_tensor);

      // Step 3 - Create the reorder primitive.
      MklSliceParams sliceParams(src.GetUsrMem(), output.GetUsrMem(),
                                 begin_dims, size_dims);
      MklSlicePrimitive<T>* reorder_prim =
          MklSlicePrimitiveFactory<T>::Get(sliceParams);
      // Execute the slice reorder.
      reorder_prim->Execute(sliceParams);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
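  // Allocates the output tensor and its Mkl shape metadata. When the input is
  // in Mkl layout, the output is allocated as a flat buffer sized from the
  // output primitive descriptor; otherwise a plain TensorFlow-shaped tensor
  // is allocated.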
  void AllocateOutputTensor(OpKernelContext* context,
                            const MklDnnShape& input_mkl_shape,
                            memory::primitive_desc* output_pd,
                            const memory::dims& output_dims,
                            Tensor** output_tensor,
                            MklDnnShape* output_mkl_shape) {
    DCHECK(output_tensor);
    DCHECK(output_mkl_shape);

    TensorShape output_tf_shape;

    if (input_mkl_shape.IsMklTensor()) {
      // Since the input tensor is in Mkl layout, the output tensor will be in
      // Mkl layout as well.

      // Allocate the shape of the Mkl tensor.
      output_mkl_shape->SetMklTensor(true);
      output_mkl_shape->SetMklLayout(output_pd);
      output_mkl_shape->SetElemType(MklDnnType<T>());
      output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(),
                                    output_dims,
                                    input_mkl_shape.GetTfDataFormat());

      output_tf_shape.AddDim(output_pd->get_size() / sizeof(T));
    } else {
      // If the input is not in Mkl layout, then the output won't be either.
      output_mkl_shape->SetMklTensor(false);
      output_tf_shape = MklDnnDimsToTFShape(output_dims);
    }

    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              *output_mkl_shape);
  }
};

// MKL-DNN Slice registration
#define REGISTER_MKL_SLICE(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_MklSlice")                         \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .HostMemory("begin")                  \
                              .HostMemory("size")                   \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklSliceOp<CPUDevice, type>);

TF_CALL_float(REGISTER_MKL_SLICE);
#undef REGISTER_MKL_SLICE

}  // namespace tensorflow

#endif  // INTEL_MKL_ML_ONLY
#endif  // INTEL_MKL