/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_MKL_TYPES_H_
#define TENSORFLOW_CORE_UTIL_MKL_TYPES_H_
#ifdef INTEL_MKL

namespace tensorflow {
// MKL DNN 0.x will not be supported, so all of its related macros have been
// removed. This file will be removed once the MKL DNN 0.x related source code
// is cleaned up and all MKL DNN 1.x related macros have been replaced.
//
// The macros below have no effect unless both INTEL_MKL and ENABLE_MKLDNN_V1
// are defined. Notes for maintainers:
//  * Several macros accept parameters they ignore (e.g. `engine`, `op`,
//    `fmt`, `idx`, `mem_desc`, `input_md`, `src_mkl_shape`). Presumably this
//    keeps existing call sites source-compatible -- confirm before removing
//    any of them.
//  * Macros that expand to a comparison, or that apply `.` / `->` to an
//    argument, parenthesize their arguments (and, for comparisons, the whole
//    expansion) so they compose safely with surrounding operators, e.g.
//    `!IS_SRC_REORDER_NEEDED(md, pd, op)` or `GET_TF_DATA_FORMAT(*shape, md)`.
//  * Macros that expand to a constructor call (`memory(...)`, `ReorderPd(...)`)
//    are deliberately NOT wrapped in parentheses: `new (memory(...))` would no
//    longer parse as a plain new-expression.
//  * Macros that expand to a trailing member call (e.g. `GET_DESC`,
//    `PRIMITIVE_DESC_SRC`) or to a comma-separated argument list (e.g.
//    `DATA_WITH_ENGINE`), and `type` parameters used as template arguments,
//    cannot be parenthesized.

#ifdef ENABLE_MKLDNN_V1
#define ADD_MD add_md
#define ALGORITHM mkldnn::algorithm
#define ALGORITHM_UNDEF ALGORITHM::undef
#define BN_FLAGS mkldnn::normalization_flags
#define CPU_STREAM(engine) stream(engine)
#define DATA_WITH_ENGINE(data, engine) data, engine
#define DST_MD dst_md
#define ENGINE_CPU engine::kind::cpu
#define GET_CHECK_REORDER_MEM_ARGS(md, tensor, net, net_args, engine) \
  md, tensor, net, net_args, engine
#define GET_CHECK_REORDER_TO_OP_MEM_ARGS(md, tensor, net, net_args, engine) \
  md, tensor, net, net_args, engine
#define GET_DESC get_desc()
// The shape argument is ignored; this always yields FORMAT_BLOCKED.
#define GET_FORMAT_FROM_SHAPE(src_mkl_shape) MklTensorFormat::FORMAT_BLOCKED
// The index argument is ignored; the strides are passed through unchanged.
#define GET_BLOCK_STRIDES(strides, idx) strides
#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
  { {dims}, MklDnnType<type>(), fm }
#define GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr) (mem_ptr)->get_desc()
#define GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(mem_ptr) \
  GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr)
#define GET_MEMORY_SIZE_FROM_MD(md, engine) (md).get_size()
#define GET_SRC_DESC_FROM_OP_PD(op_pd) (op_pd)->src_desc()
#define GET_DST_DESC_FROM_OP_PD(op_pd) (op_pd)->dst_desc()
#define GET_BIAS_DESC_FROM_OP_PD(op_pd) (op_pd)->bias_desc()
#define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) (op_pd)->diff_dst_desc()
#define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) (op_pd)->workspace_desc()
#define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt)
#define GET_TF_DATA_FORMAT(shape, mem_desc) (shape).GetTfDataFormat()
#define GET_USR_MEM_PRIM_DESC(src) (src).GetUsrMemDesc()
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) (op_pd)->weights_desc()
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op) \
  GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
// Reorder predicates: true when the given memory descriptor differs from the
// one the primitive descriptor expects.
#define IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, op_pd, op) \
  ((diff_dst_md) != (op_pd)->diff_dst_desc())
#define IS_DIFF_FILTER_REORDER_NEEDED(diff_filter_md, fmt, op_pd, op) \
  ((diff_filter_md) != (op_pd)->diff_weights_desc())
#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op) \
  ((filter_md) != (op_pd)->weights_desc())
#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op) \
  ((src_md) != (op_pd)->src_desc())
#define IS_WEIGHTS_REORDER_NEEDED(weights_md, op_pd, op) \
  ((weights_md) != (op_pd)->weights_desc())
#define MEMORY_CONSTRUCTOR(mem_desc, engine, data) \
  memory(mem_desc, engine, data)
#define MEMORY_CONSTRUCTOR_PD(mem_desc, engine, data) \
  MEMORY_CONSTRUCTOR(mem_desc, engine, data)
#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
  memory(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine, data)
#define MEMORY_CONSTRUCTOR_USING_MD(md, engine, data) memory(md, engine, data)
#define MEMORY_CONSTRUCTOR_WITH_MEM_PD(mem_ptr, cpu_engine, data) \
  memory(GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr), cpu_engine, data)
#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_desc, engine) \
  memory(mem_desc, engine)
#define MEMORY_DATA_TYPE_UNDEF memory::data_type::undef
#define MEMORY_DESC memory::desc
#define MEMORY_FORMAT mkldnn::memory::format_tag
#define MEMORY_FORMAT_DESC format_desc
#define MEMORY_FORMAT_UNDEF mkldnn::memory::format_tag::undef
#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
  memory::desc({dims}, MklDnnType<type>(), fm)
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MEMORY_PRIMITIVE_DESC memory::desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
#define MKL_FMT_TAG mkl_fmt_tag
#define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
#define MKL_TENSOR_FORMAT_IN_C MKL_TENSOR_FORMAT
#define MKL_TENSOR_FORMAT_INVALID MklTensorFormat::FORMAT_INVALID
#define MKL_TENSOR_FORMAT_NC MklTensorFormat::FORMAT_NC
#define MKL_TENSOR_FORMAT_NCHW MklTensorFormat::FORMAT_NCHW
#define MKL_TENSOR_FORMAT_NCDHW MklTensorFormat::FORMAT_NCDHW
#define MKL_TENSOR_FORMAT_NDHWC MklTensorFormat::FORMAT_NDHWC
#define MKL_TENSOR_FORMAT_NHWC MklTensorFormat::FORMAT_NHWC
#define MKL_TENSOR_FORMAT_TNC MklTensorFormat::FORMAT_TNC
#define MKL_TENSOR_FORMAT_X MklTensorFormat::FORMAT_X
// NOTE(review): UNDEF maps to FORMAT_BLOCKED rather than FORMAT_INVALID.
// Preserved as-is; confirm this is intentional.
#define MKL_TENSOR_FORMAT_UNDEF MKL_TENSOR_FORMAT_BLOCKED
#define NET_ARGS_PTR (&net_args)
#define OUTPUT_TF_MD output_tf_md
#define PRIMITIVE_DESC_BIAS bias_desc()
#define PRIMITIVE_DESC_DIFF_DST diff_dst_desc()
#define PRIMITIVE_DESC_DIFF_SRC diff_src_desc()
#define PRIMITIVE_DESC_DIFF_WEIGHTS diff_weights_desc()
#define PRIMITIVE_DESC_DST dst_desc()
#define PRIMITIVE_DESC_SRC src_desc()
#define PRIMITIVE_DESC_WORKSPACE workspace_desc()
#define PRIMITIVE_DESC_WEIGHTS weights_desc()
#define REORDER_PD_CONSTRUCTOR(src_md, dst_md, engine) \
  ReorderPd(engine, src_md, engine, dst_md)
#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_md, dst_md, engine, prim_attr) \
  ReorderPd(engine, src_md, engine, dst_md, prim_attr)
// True when the input's TF data format is MKL_TENSOR_FORMAT_BLOCKED; the
// memory-descriptor argument is ignored.
#define SKIP_INPUT_REORDER(input_mkl_shape, input_md) \
  ((input_mkl_shape).GetTfDataFormat() == MKL_TENSOR_FORMAT_BLOCKED)
#define SUMMAND_MD summand_md
#define TENSOR_FORMAT MKL_TENSOR_FORMAT
#define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC
#define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS

#endif  // ENABLE_MKLDNN_V1

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_UTIL_MKL_TYPES_H_