• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#ifndef TENSORFLOW_CORE_UTIL_MKL_TYPES_H_
#define TENSORFLOW_CORE_UTIL_MKL_TYPES_H_
#ifdef INTEL_MKL

namespace tensorflow {
// MKL-DNN 0.x is no longer supported, so all macros specific to it have been
// removed. This file will be removed once the remaining MKL-DNN 0.x related
// source code is cleaned up and all of the MKL-DNN 1.x macros below have been
// replaced by the code they expand to.
//
// Notes for maintainers:
//  * Preprocessor macros ignore C++ namespaces; defining them inside
//    `namespace tensorflow` does not limit their visibility to it.
//  * Every macro below is defined only when ENABLE_MKLDNN_V1 is defined.
//  * Several macros accept parameters they never use (`engine`, `op`, `fmt`,
//    `mem_desc`, `idx`, `input_md`, `src_mkl_shape`). Presumably they are kept
//    so existing call sites can keep their argument lists -- confirm before
//    removing any of them.
//  * Macros that expand to a complete expression parenthesize the arguments
//    they use as operands, and comparison macros parenthesize their whole
//    body, so uses such as `!IS_SRC_REORDER_NEEDED(md, pd, op)` or
//    `GET_SRC_DESC_FROM_OP_PD(*pd_ptr)` parse as intended. Bodies that are
//    argument lists, braced initializers, type/token spellings, or
//    constructor calls (which a caller could place after `new`) are
//    deliberately left without surrounding parentheses.

#ifdef ENABLE_MKLDNN_V1
#define ADD_MD add_md
#define ALGORITHM mkldnn::algorithm
#define ALGORITHM_UNDEF ALGORITHM::undef
#define BN_FLAGS mkldnn::normalization_flags
#define CPU_STREAM(engine) stream(engine)
// Expands to a comma-separated argument list, not a single expression.
#define DATA_WITH_ENGINE(data, engine) data, engine
#define DST_MD dst_md
#define ENGINE_CPU engine::kind::cpu
// The two macros below expand to comma-separated argument lists.
#define GET_CHECK_REORDER_MEM_ARGS(md, tensor, net, net_args, engine) \
  md, tensor, net, net_args, engine
#define GET_CHECK_REORDER_TO_OP_MEM_ARGS(md, tensor, net, net_args, engine) \
  md, tensor, net, net_args, engine
// Token sequence meant to follow `.` or `->`, e.g. `mem.GET_DESC`.
#define GET_DESC get_desc()
// Ignores its argument and always yields the blocked format.
#define GET_FORMAT_FROM_SHAPE(src_mkl_shape) MklTensorFormat::FORMAT_BLOCKED
#define GET_BLOCK_STRIDES(strides, idx) strides
// Expands to a braced initializer; usable only where one is allowed.
#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
  { {dims}, MklDnnType<type>(), fm }
#define GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr) (mem_ptr)->get_desc()
#define GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(mem_ptr) \
  GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr)
#define GET_MEMORY_SIZE_FROM_MD(md, engine) (md).get_size()
#define GET_SRC_DESC_FROM_OP_PD(op_pd) (op_pd)->src_desc()
#define GET_DST_DESC_FROM_OP_PD(op_pd) (op_pd)->dst_desc()
#define GET_BIAS_DESC_FROM_OP_PD(op_pd) (op_pd)->bias_desc()
#define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) (op_pd)->diff_dst_desc()
#define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) (op_pd)->workspace_desc()
#define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt)
#define GET_TF_DATA_FORMAT(shape, mem_desc) (shape).GetTfDataFormat()
#define GET_USR_MEM_PRIM_DESC(src) (src).GetUsrMemDesc()
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) (op_pd)->weights_desc()
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) \
  GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
// Reorder predicates: true when the given descriptor differs from the one the
// primitive descriptor expects. Fully parenthesized so they compose safely
// with `!`, `==`, and other operators.
#define IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, op_pd, op) \
  ((diff_dst_md) != (op_pd)->diff_dst_desc())
#define IS_DIFF_FILTER_REORDER_NEEDED(diff_filter_md, fmt, op_pd, op) \
  ((diff_filter_md) != (op_pd)->diff_weights_desc())
#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op) \
  ((filter_md) != (op_pd)->weights_desc())
#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op) \
  ((src_md) != (op_pd)->src_desc())
#define IS_WEIGHTS_REORDER_NEEDED(weights_md, op_pd, op) \
  ((weights_md) != (op_pd)->weights_desc())
// Constructor-call macros: intentionally not wrapped in parentheses so they
// stay valid after `new`.
#define MEMORY_CONSTRUCTOR(mem_desc, engine, data) \
  memory(mem_desc, engine, data)
#define MEMORY_CONSTRUCTOR_PD(mem_desc, engine, data) \
  MEMORY_CONSTRUCTOR(mem_desc, engine, data)
#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
  memory(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine, data)
#define MEMORY_CONSTRUCTOR_USING_MD(md, engine, data) memory(md, engine, data)
#define MEMORY_CONSTRUCTOR_WITH_MEM_PD(mem_ptr, cpu_engine, data) \
  memory(GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr), cpu_engine, data)
#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_desc, engine) \
  memory(mem_desc, engine)
#define MEMORY_DATA_TYPE_UNDEF memory::data_type::undef
#define MEMORY_DESC memory::desc
#define MEMORY_FORMAT mkldnn::memory::format_tag
#define MEMORY_FORMAT_DESC format_desc
#define MEMORY_FORMAT_UNDEF mkldnn::memory::format_tag::undef
#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
  memory::desc({dims}, MklDnnType<type>(), fm)
// Expands to a comma-separated argument list.
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MEMORY_PRIMITIVE_DESC memory::desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
#define MKL_FMT_TAG mkl_fmt_tag
#define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
#define MKL_TENSOR_FORMAT_IN_C MKL_TENSOR_FORMAT
#define MKL_TENSOR_FORMAT_INVALID MklTensorFormat::FORMAT_INVALID
#define MKL_TENSOR_FORMAT_NC MklTensorFormat::FORMAT_NC
#define MKL_TENSOR_FORMAT_NCHW MklTensorFormat::FORMAT_NCHW
#define MKL_TENSOR_FORMAT_NCDHW MklTensorFormat::FORMAT_NCDHW
#define MKL_TENSOR_FORMAT_NDHWC MklTensorFormat::FORMAT_NDHWC
#define MKL_TENSOR_FORMAT_NHWC MklTensorFormat::FORMAT_NHWC
#define MKL_TENSOR_FORMAT_TNC MklTensorFormat::FORMAT_TNC
#define MKL_TENSOR_FORMAT_X MklTensorFormat::FORMAT_X
// NOTE(review): "undefined" is aliased to the blocked format here, not to
// FORMAT_INVALID -- confirm this is intentional before relying on it.
#define MKL_TENSOR_FORMAT_UNDEF MKL_TENSOR_FORMAT_BLOCKED
#define NET_ARGS_PTR &net_args
#define OUTPUT_TF_MD output_tf_md
// Token sequences meant to follow `.` or `->` on a primitive descriptor.
#define PRIMITIVE_DESC_BIAS bias_desc()
#define PRIMITIVE_DESC_DIFF_DST diff_dst_desc()
#define PRIMITIVE_DESC_DIFF_SRC diff_src_desc()
#define PRIMITIVE_DESC_DIFF_WEIGHTS diff_weights_desc()
#define PRIMITIVE_DESC_DST dst_desc()
#define PRIMITIVE_DESC_SRC src_desc()
#define PRIMITIVE_DESC_WORKSPACE workspace_desc()
#define PRIMITIVE_DESC_WEIGHTS weights_desc()
// The same engine is passed for both the source and destination sides.
#define REORDER_PD_CONSTRUCTOR(src_md, dst_md, engine) \
  ReorderPd(engine, src_md, engine, dst_md)
#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_md, dst_md, engine, prim_attr) \
  ReorderPd(engine, src_md, engine, dst_md, prim_attr)
// True when the MKL shape's TF data format is the blocked format.
#define SKIP_INPUT_REORDER(input_mkl_shape, input_md) \
  ((input_mkl_shape).GetTfDataFormat() == MKL_TENSOR_FORMAT_BLOCKED)
#define SUMMAND_MD summand_md
#define TENSOR_FORMAT MKL_TENSOR_FORMAT
#define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC
#define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS

#endif  // ENABLE_MKLDNN_V1

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_UTIL_MKL_TYPES_H_
128