1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
18
19 #include <cstdint>
20
21 #include "ruy/profiler/instrumentation.h" // from @ruy
22 #include "tensorflow/lite/kernels/cpu_backend_context.h"
23 #include "tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h"
24 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
25 #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
26
27 #ifndef TFLITE_WITH_RUY
28 #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
29 #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
30 #include "tensorflow/lite/kernels/cpu_backend_gemm_x86.h"
31 #endif
32
33 namespace tflite {
34
35 namespace cpu_backend_gemm {
36
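// The main entry point is cpu_backend_gemm::Gemm, defined below.
//
// If TFLITE_WITH_RUY is set, Gemm always goes to ruy (GemmImplUsingRuy).
// Otherwise, the backend selected at compile time is:
//
//                    | Quantized (uint8) | Quantized (int8) | Float |
//  TFLITE_WITH_RUY   | Ruy               | Ruy              | Ruy   |
//  !TFLITE_WITH_RUY  | gemmlowp          | Ruy/gemmlowp*    | Eigen |
//  * Ruy if NEON is not available.
//
// On x86 platforms:
//                                  | Quantized (uint8) | Quantized (int8) | Float |
//  (default)                       | gemmlowp          | Ruy              | Eigen |
//  TFLITE_X86_RUY_ENABLED and      | Ruy               | Ruy              | Ruy   |
//  AVX (or above) available        |                   |                  |       |
//
// Independently of these tables, Gemm() forces ruy at runtime whenever the
// context has caching enabled or the operands are not in RowMajor (LHS) x
// ColMajor (RHS) -> ColMajor (destination) order. Otherwise, a single-column
// destination is first offered to a custom GEMV kernel before falling back to
// the table above.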
52
53 #if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
54 /* GEMM dispatch implementation for x86.
55 */
56 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
57 typename DstScalar, QuantizationFlavor quantization_flavor>
58 struct GemmImpl : detail::GemmImplX86<LhsScalar, RhsScalar, AccumScalar,
59 DstScalar, quantization_flavor> {};
60 #else
61 /* Generic implementation using ruy.
62 * Non-ruy implementation will be partial specializations of this template.
63 */
64 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
65 typename DstScalar, QuantizationFlavor quantization_flavor>
66 struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
67 DstScalar, quantization_flavor> {};
68
69 #if !defined(TFLITE_WITH_RUY)
70
71 /* Specializations using gemmlowp */
72 template <typename SrcScalar, typename DstScalar,
73 QuantizationFlavor quantization_flavor>
74 struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
75 quantization_flavor>
76 : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
77 DstScalar, quantization_flavor> {};
78
79 // When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
80 // outside of NEON. We avoid the compilation failure by subspecializing these
81 // cases, rerouting it back to ruy.
82 #if !defined(GEMMLOWP_NEON)
83 template <typename SrcScalar, QuantizationFlavor quantization_flavor>
84 struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
85 quantization_flavor>
86 : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
87 quantization_flavor> {};
88
89 template <typename DstScalar, QuantizationFlavor quantization_flavor>
90 struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
91 quantization_flavor>
92 : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
93 DstScalar, quantization_flavor> {};
94
95 template <QuantizationFlavor quantization_flavor>
96 struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
97 quantization_flavor>
98 : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
99 std::int8_t, quantization_flavor> {};
100 #endif // not GEMMLOWP_NEON
101
102 /* Specializations using Eigen */
103
104 template <>
105 struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
106 : detail::GemmImplUsingEigen {};
107
108 #endif // not TFLITE_WITH_RUY
109
110 #endif // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM
111
112 /* Public entry point */
113
114 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
115 typename DstScalar, QuantizationFlavor quantization_flavor>
Gemm(const MatrixParams<LhsScalar> & lhs_params,const LhsScalar * lhs_data,const MatrixParams<RhsScalar> & rhs_params,const RhsScalar * rhs_data,const MatrixParams<DstScalar> & dst_params,DstScalar * dst_data,const GemmParams<AccumScalar,DstScalar,quantization_flavor> & params,CpuBackendContext * context)116 void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
117 const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
118 const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
119 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
120 CpuBackendContext* context) {
121 ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
122 ValidateParams(lhs_params, rhs_params, dst_params, params);
123 // In some cases we want to unconditionally use ruy as the backend, overriding
124 // the `tflite_with_ruy` setting and the platform default.
125 bool must_use_ruy = false;
126 if (context->use_caching()) {
127 // Only ruy supports caching of pre-packed matrices. Due to the large
128 // performance impact in the cases where it's typically used, this overrides
129 // the default.
130 must_use_ruy = true;
131 }
132 if (lhs_params.order != Order::kRowMajor ||
133 rhs_params.order != Order::kColMajor ||
134 dst_params.order != Order::kColMajor) {
135 // ruy supports all 2^3=8 combinations of storage orders with comparable
136 // performance. In ruy, it's only a runtime switch. In other backends
137 // (gemmlowp, Eigen), storage orders are template parameters, supporting
138 // all 8 combinations would be up to a 8-fold code size increase, so we
139 // prefer to force usage of ruy in these cases.
140 must_use_ruy = true;
141 }
142 if (must_use_ruy) {
143 detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
144 quantization_flavor>::Run(lhs_params, lhs_data,
145 rhs_params, rhs_data,
146 dst_params, dst_data,
147 params, context);
148 return;
149 }
150 // If we did not choose to force usage of ruy above, then we may now consider
151 // using custom GEMV code for the matrix*vector cases.
152 const bool try_custom_gemv = (dst_params.cols == 1);
153 if (try_custom_gemv) {
154 // GEMV case: try a custom fast GEMV path. It will return true if it
155 // actually handled it.
156 if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
157 dst_params, dst_data, params, context)) {
158 return;
159 }
160 }
161 // Generic case: dispatch to any backend as a general GEMM.
162 GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
163 quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
164 dst_params, dst_data, params, context);
165 }
166
167 // Special path for gemm with raw accumulator case. i.e. AccumScalar ==
168 // DstScalar == int32 case.
169 template <typename LhsScalar, typename RhsScalar,
170 QuantizationFlavor quantization_flavor>
Gemm(const MatrixParams<LhsScalar> & lhs_params,const LhsScalar * lhs_data,const MatrixParams<RhsScalar> & rhs_params,const RhsScalar * rhs_data,const MatrixParams<int32_t> & dst_params,int32_t * dst_data,const GemmParams<int32_t,int32_t,quantization_flavor> & params,CpuBackendContext * context)171 void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
172 const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
173 const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
174 const GemmParams<int32_t, int32_t, quantization_flavor>& params,
175 CpuBackendContext* context) {
176 ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
177 ValidateParams(lhs_params, rhs_params, dst_params, params);
178
179 // Currently, only Ruy backend supports get raw accumulator, so we use ruy
180 // only.
181 ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
182 detail::GemmImplUsingRuy<LhsScalar, RhsScalar, int32_t, int32_t,
183 quantization_flavor>::Run(lhs_params, lhs_data,
184 rhs_params, rhs_data,
185 dst_params, dst_data,
186 params, context);
187 }
188
189 } // namespace cpu_backend_gemm
190
191 } // namespace tflite
192
193 #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
194