/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/types.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace {

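// Returns true if `ptr` is aligned to a 16-byte boundary.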
bool Is16BytesAligned(void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
}

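// Computes out = lhs * rhs (with optional transposes of either operand) as an
// Eigen tensor contraction, evaluated on the intra-op thread pool. The
// Alignment parameter selects aligned or unaligned tensor maps.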
template <typename T, Eigen::AlignmentType Alignment>
void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64_t m,
            int64_t n, int64_t k, int32_t transpose_lhs,
            int32_t transpose_rhs) {
  const xla::ExecutableRunOptions* run_options =
      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);

  int64_t lhs_rows = m;
  int64_t lhs_cols = k;
  if (transpose_lhs) {
    std::swap(lhs_rows, lhs_cols);
  }

  int64_t rhs_rows = k;
  int64_t rhs_cols = n;
  if (transpose_rhs) {
    std::swap(rhs_rows, rhs_cols);
  }

  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
                                                                 lhs_cols);
  const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
                                                                 rhs_cols);
  Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);

  typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
  int lhs_contract_dim = transpose_lhs ? 0 : 1;
  int rhs_contract_dim = transpose_rhs ? 1 : 0;
  const Eigen::array<DimPair, 1> dims(
      {DimPair(lhs_contract_dim, rhs_contract_dim)});

  // Matrix multiply is a special case of the "contract" operation where
  // the contraction is performed along dimension 1 of the lhs and dimension
  // 0 of the rhs.
  XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
  C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims);
}

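// Dispatches to the aligned MatMul instantiation only when all three buffers
// are 16-byte aligned; otherwise falls back to the unaligned version.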
template <typename T>
void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
                    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
                    int32_t transpose_rhs) {
  bool all_buffers_16b_aligned =
      Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);

  if (!all_buffers_16b_aligned) {
    MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
                                transpose_lhs, transpose_rhs);
    return;
  }

  MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
                              transpose_lhs, transpose_rhs);
}

}  // namespace

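// XLA CPU runtime entry points; each one forwards to MatMulDispatch for the
// corresponding element type.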
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
    const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
    Eigen::half* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
                              transpose_lhs, transpose_rhs);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
    const void* run_options_ptr, float* out, float* lhs, float* rhs, int64_t m,
    int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
                        transpose_rhs);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
    const void* run_options_ptr, double* out, double* lhs, double* rhs,
    int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
    int32_t transpose_rhs) {
  MatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
                         transpose_rhs);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64(
    const void* run_options_ptr, std::complex<float>* out,
    std::complex<float>* lhs, std::complex<float>* rhs, int64_t m, int64_t n,
    int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<std::complex<float>>(run_options_ptr, out, lhs, rhs, m, n, k,
                                      transpose_lhs, transpose_rhs);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128(
    const void* run_options_ptr, std::complex<double>* out,
    std::complex<double>* lhs, std::complex<double>* rhs, int64_t m, int64_t n,
    int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<std::complex<double>>(run_options_ptr, out, lhs, rhs, m, n, k,
                                       transpose_lhs, transpose_rhs);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32(
    const void* run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs,
    tensorflow::int32* rhs, int64_t m, int64_t n, int64_t k,
    int32_t transpose_lhs, int32_t transpose_rhs) {
  MatMulDispatch<tensorflow::int32>(run_options_ptr, out, lhs, rhs, m, n, k,
                                    transpose_lhs, transpose_rhs);
}