1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
17
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/platform/dynamic_annotations.h"
20 #include "tensorflow/core/platform/types.h"
21
22 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
23 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
24 #endif
25
26 namespace {
27
Is16BytesAligned(void * ptr)28 bool Is16BytesAligned(void* ptr) {
29 return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
30 }
31
32 template <typename T, Eigen::AlignmentType Alignment>
MatMul(const void * run_options_ptr,T * out,T * lhs,T * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)33 void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m,
34 int64_t n, int64_t k, int32_t transpose_lhs,
35 int32_t transpose_rhs) {
36 int64_t lhs_rows = m;
37 int64_t lhs_cols = k;
38 if (transpose_lhs) {
39 std::swap(lhs_rows, lhs_cols);
40 }
41
42 int64_t rhs_rows = k;
43 int64_t rhs_cols = n;
44 if (transpose_rhs) {
45 std::swap(rhs_rows, rhs_cols);
46 }
47
48 const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
49 lhs_cols);
50 const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
51 rhs_cols);
52 Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);
53
54 typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
55 int lhs_contract_dim = transpose_lhs ? 0 : 1;
56 int rhs_contract_dim = transpose_rhs ? 1 : 0;
57 const Eigen::array<DimPair, 1> dims(
58 {DimPair(lhs_contract_dim, rhs_contract_dim)});
59
60 // Matrix multiply is a special case of the "contract" operation where
61 // the contraction is performed along dimension 1 of the lhs and dimension
62 // 0 of the rhs.
63 C = A.contract(B, dims);
64 }
65
66 template <typename T>
SingleThreadedMatMulDispatch(const void * run_options_ptr,T * out,T * lhs,T * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)67 void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs,
68 T* rhs, int64_t m, int64_t n, int64_t k,
69 int32_t transpose_lhs,
70 int32_t transpose_rhs) {
71 bool all_buffers_16b_aligned =
72 Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);
73
74 if (!all_buffers_16b_aligned) {
75 MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
76 transpose_lhs, transpose_rhs);
77 }
78
79 MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
80 transpose_lhs, transpose_rhs);
81 }
82
83 } // namespace
84
85 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF16(const void * run_options_ptr,Eigen::half * out,Eigen::half * lhs,Eigen::half * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)86 __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
87 const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
88 Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
89 int32_t transpose_rhs) {
90 SingleThreadedMatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m,
91 n, k, transpose_lhs, transpose_rhs);
92 }
93
94 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void * run_options_ptr,float * out,float * lhs,float * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)95 __xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr,
96 float* out, float* lhs,
97 float* rhs, int64_t m, int64_t n,
98 int64_t k, int32_t transpose_lhs,
99 int32_t transpose_rhs) {
100 SingleThreadedMatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k,
101 transpose_lhs, transpose_rhs);
102 }
103
104 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void * run_options_ptr,double * out,double * lhs,double * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)105 __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
106 double* out, double* lhs,
107 double* rhs, int64_t m,
108 int64_t n, int64_t k,
109 int32_t transpose_lhs,
110 int32_t transpose_rhs) {
111 SingleThreadedMatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k,
112 transpose_lhs, transpose_rhs);
113 }
114
115 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulC64(const void * run_options_ptr,std::complex<float> * out,std::complex<float> * lhs,std::complex<float> * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)116 __xla_cpu_runtime_EigenSingleThreadedMatMulC64(
117 const void* run_options_ptr, std::complex<float>* out,
118 std::complex<float>* lhs, std::complex<float>* rhs, int64_t m, int64_t n,
119 int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
120 SingleThreadedMatMulDispatch<std::complex<float>>(
121 run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
122 }
123
124 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulC128(const void * run_options_ptr,std::complex<double> * out,std::complex<double> * lhs,std::complex<double> * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)125 __xla_cpu_runtime_EigenSingleThreadedMatMulC128(
126 const void* run_options_ptr, std::complex<double>* out,
127 std::complex<double>* lhs, std::complex<double>* rhs, int64_t m, int64_t n,
128 int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
129 SingleThreadedMatMulDispatch<std::complex<double>>(
130 run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
131 }
132
133 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void * run_options_ptr,tensorflow::int32 * out,tensorflow::int32 * lhs,tensorflow::int32 * rhs,int64_t m,int64_t n,int64_t k,int32_t transpose_lhs,int32_t transpose_rhs)134 __xla_cpu_runtime_EigenSingleThreadedMatMulS32(
135 const void* run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs,
136 tensorflow::int32* rhs, int64_t m, int64_t n, int64_t k,
137 int32_t transpose_lhs, int32_t transpose_rhs) {
138 SingleThreadedMatMulDispatch<tensorflow::int32>(
139 run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
140 }
141