1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"

#include <cstdlib>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/types.h"
20
21 using tensorflow::int64;
22
23 #ifdef INTEL_MKL
24 #include <omp.h>
25 #include "mkldnn.hpp"
26 #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
27
28 namespace {
29
30 // Downcast an int64 to int and check if value is in range.
ToInt(int64 input)31 int ToInt(int64 input) {
32 int output = static_cast<int>(input);
33 if (static_cast<int64>(output) != input) {
34 std::cerr << "Error occurred in downcasting int64 to int32: Value " << input
35 << " is out-of-range for type int32. \n";
36 exit(1);
37 }
38 return output;
39 }
40
41 using mkldnn::convolution_direct;
42 using mkldnn::convolution_forward;
43 using mkldnn::engine;
44 using mkldnn::memory;
45 using mkldnn::padding_kind;
46 using mkldnn::primitive;
47 using mkldnn::prop_kind;
48 using mkldnn::reorder;
49 using mkldnn::stream;
50
51 template <typename EigenDevice, typename ScalarType>
MKLConvImpl(const EigenDevice & device,ScalarType * out,ScalarType * lhs,ScalarType * rhs,int64 input_batch,int64 input_rows,int64 input_cols,int64 input_channels,int64 kernel_rows,int64 kernel_cols,int64 kernel_channels,int64 kernel_filters,int64 output_rows,int64 output_cols,int64 row_stride,int64 col_stride,int64 padding_top,int64 padding_bottom,int64 padding_left,int64 padding_right,int64 lhs_row_dilation,int64 lhs_col_dilation,int64 rhs_row_dilation,int64 rhs_col_dilation)52 void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
53 ScalarType* rhs, int64 input_batch, int64 input_rows,
54 int64 input_cols, int64 input_channels, int64 kernel_rows,
55 int64 kernel_cols, int64 kernel_channels, int64 kernel_filters,
56 int64 output_rows, int64 output_cols, int64 row_stride,
57 int64 col_stride, int64 padding_top, int64 padding_bottom,
58 int64 padding_left, int64 padding_right,
59 int64 lhs_row_dilation, int64 lhs_col_dilation,
60 int64 rhs_row_dilation, int64 rhs_col_dilation) {
61 auto cpu_engine = engine(engine::cpu, 0);
62
63 // Create a vector primitive to hold the network.
64 std::vector<primitive> net;
65
66 // Since memory::dims takes int for each dimension, we downcast the int64
67 // values to int using the ToInt function defined above.
68 memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels),
69 ToInt(input_rows), ToInt(input_cols)};
70 memory::dims conv1_weights_dim = {ToInt(kernel_filters),
71 ToInt(kernel_channels), ToInt(kernel_rows),
72 ToInt(kernel_cols)};
73 memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters),
74 ToInt(output_rows), ToInt(output_cols)};
75 memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)};
76 // Note: In MKL_DNN dilation starts from 0.
77 memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1),
78 ToInt(rhs_col_dilation - 1)};
79 memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)};
80 memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)};
81
82 // Create memory for user data. Input and output data have format of NHWC and
83 // kernel data has format of HWIO.
84 // Note that as a convention in MKL-DNN, the dimensions of the data is always
85 // described in NCHW/IOHW, regardless of the actual layout of the data.
86 auto user_src_memory =
87 memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc},
88 cpu_engine},
89 lhs);
90 auto user_weights_memory = memory(
91 {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio},
92 cpu_engine},
93 rhs);
94 auto user_dst_memory =
95 memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc},
96 cpu_engine},
97 out);
98
99 // Create memory descriptors for convolution data with no specified format for
100 // best performance.
101 auto conv1_src_mem_desc = memory::desc(
102 {conv1_src_dim}, memory::data_type::f32, memory::format::any);
103 auto conv1_weights_mem_desc = memory::desc(
104 {conv1_weights_dim}, memory::data_type::f32, memory::format::any);
105 auto conv1_dst_mem_desc = memory::desc(
106 {conv1_dst_dim}, memory::data_type::f32, memory::format::any);
107
108 // Create a convolution.
109 auto conv1_desc = convolution_forward::desc(
110 prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc,
111 conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates,
112 conv1_padding_l, conv1_padding_r, padding_kind::zero);
113 auto conv1_prim_desc =
114 convolution_forward::primitive_desc(conv1_desc, cpu_engine);
115
116 // Create reorders for data and weights if layout requested by convolution is
117 // different from NCHW/OIHW.
118 auto conv1_src_memory = user_src_memory;
119 if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) !=
120 user_src_memory.get_primitive_desc()) {
121 conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc());
122 net.push_back(reorder(user_src_memory, conv1_src_memory));
123 }
124
125 auto conv1_weights_memory = user_weights_memory;
126 if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) !=
127 user_weights_memory.get_primitive_desc()) {
128 conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc());
129 net.push_back(reorder(user_weights_memory, conv1_weights_memory));
130 }
131
132 // Check if output need layout conversion. If yes, create memory for
133 // intermediate layer of conv1_dst_memory.
134 bool need_output_conversion =
135 (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) !=
136 user_dst_memory.get_primitive_desc());
137 auto conv1_dst_memory = need_output_conversion
138 ? memory(conv1_prim_desc.dst_primitive_desc())
139 : user_dst_memory;
140
141 // Create convolution primitive and add it to net.
142 net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory,
143 conv1_weights_memory, conv1_dst_memory));
144 if (need_output_conversion) {
145 net.push_back(reorder(conv1_dst_memory, user_dst_memory));
146 }
147 stream(stream::kind::eager).submit(net).wait();
148 }
149 } // namespace
150 #endif // INTEL_MKL
151
__xla_cpu_runtime_MKLConvF32(const void * run_options_ptr,float * out,float * lhs,float * rhs,int64 input_batch,int64 input_rows,int64 input_cols,int64 input_channels,int64 kernel_rows,int64 kernel_cols,int64 kernel_channels,int64 kernel_filters,int64 output_rows,int64 output_cols,int64 row_stride,int64 col_stride,int64 padding_top,int64 padding_bottom,int64 padding_left,int64 padding_right,int64 lhs_row_dilation,int64 lhs_col_dilation,int64 rhs_row_dilation,int64 rhs_col_dilation)152 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConvF32(
153 const void* run_options_ptr, float* out, float* lhs, float* rhs,
154 int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels,
155 int64 kernel_rows, int64 kernel_cols, int64 kernel_channels,
156 int64 kernel_filters, int64 output_rows, int64 output_cols,
157 int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
158 int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
159 int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
160 #ifdef INTEL_MKL
161 // Since MKL_DNN cannot handle transposed convolution, this is handled by
162 // Eigen.
163 if (lhs_row_dilation > 1 || lhs_col_dilation > 1) {
164 __xla_cpu_runtime_EigenConvF32(
165 run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols,
166 input_channels, kernel_rows, kernel_cols, kernel_channels,
167 kernel_filters, output_rows, output_cols, row_stride, col_stride,
168 padding_top, padding_bottom, padding_left, padding_right,
169 lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
170 } else {
171 MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols,
172 input_channels, kernel_rows, kernel_cols, kernel_channels,
173 kernel_filters, output_rows, output_cols, row_stride,
174 col_stride, padding_top, padding_bottom, padding_left,
175 padding_right, lhs_row_dilation, lhs_col_dilation,
176 rhs_row_dilation, rhs_col_dilation);
177 }
178 #else
179 std::cerr << "Attempt to call MKL Conv2D runtime library without defining "
180 "INTEL_MKL. Add --config=mkl to build with MKL.";
181 exit(1);
182 #endif // INTEL_MKL
183 }
184