/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

#ifndef TFLITE_WITH_RUY
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_x86.h"
#endif

namespace tflite {

namespace cpu_backend_gemm {

// The main entry point for CpuBackendGemm::Gemm.
//
// If TFLITE_WITH_RUY is set, CpuBackendGemm::Gemm always goes to Ruy aka
// GemmImplUsingRuy. The other cases are as follows:
//
//                    | Quantized (uint8) | Quantized (int8) | Float |
// TFLITE_WITH_RUY    |        Ruy        |       Ruy        |  Ruy  |
// !TFLITE_WITH_RUY   |      gemmlowp     |  Ruy/gemmlowp*   | eigen |
//
// * - Ruy if NEON is not available.
//
// On x86 platforms:
//                             | Quantized (uint8) | Quantized (int8) | Float |
// (default)                   |      gemmlowp     |       Ruy        | eigen |
// TFLITE_X86_RUY_ENABLED &&   |        Ruy        |       Ruy        |  Ruy  |
// AVX or above available      |                   |                  |       |
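//
// As a concrete illustration of the tables above (a dispatch sketch, not new
// API surface): in a !TFLITE_WITH_RUY non-x86 build with GEMMLOWP_NEON
// defined, the uint8 quantized instantiation
//
//   GemmImpl<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t,
//            QuantizationFlavor::kIntegerWithUniformMultiplier>
//
// resolves to detail::GemmImplUsingGemmlowp, while the float instantiation
//
//   GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
//
// resolves to detail::GemmImplUsingEigen, per the specializations below.
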
#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
/* GEMM dispatch implementation for x86. */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplX86<LhsScalar, RhsScalar, AccumScalar,
                                      DstScalar, quantization_flavor> {};
#else
/* Generic implementation using ruy.
 * Non-ruy implementations will be partial specializations of this template.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
                                           DstScalar, quantization_flavor> {};

#if !defined(TFLITE_WITH_RUY)

/* Specializations using gemmlowp */
template <typename SrcScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
                                    DstScalar, quantization_flavor> {};

// When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile outside of
// NEON. We avoid that compilation failure by subspecializing these cases and
// rerouting them back to ruy.
#if !defined(GEMMLOWP_NEON)
template <typename SrcScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                               quantization_flavor> {};

template <typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               DstScalar, quantization_flavor> {};

template <QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               std::int8_t, quantization_flavor> {};
#endif  // not GEMMLOWP_NEON

/* Specializations using Eigen */

template <>
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
    : detail::GemmImplUsingEigen {};

#endif  // not TFLITE_WITH_RUY

#endif  // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM

/* Public entry point */

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  // In some cases we want to unconditionally use ruy as the backend, overriding
  // the `tflite_with_ruy` setting and the platform default.
  bool must_use_ruy = false;
  if (context->use_caching()) {
    // Only ruy supports caching of pre-packed matrices. Due to the large
    // performance impact in the cases where it's typically used, this overrides
    // the default.
    must_use_ruy = true;
  }
  if (lhs_params.order != Order::kRowMajor ||
      rhs_params.order != Order::kColMajor ||
      dst_params.order != Order::kColMajor) {
    // ruy supports all 2^3 = 8 combinations of storage orders with comparable
    // performance; there, it is only a runtime switch. In the other backends
    // (gemmlowp, Eigen), storage orders are template parameters, so supporting
    // all 8 combinations would mean up to an 8-fold code size increase. We
    // therefore prefer to force usage of ruy in these cases.
    must_use_ruy = true;
  }
  if (must_use_ruy) {
    detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                             quantization_flavor>::Run(lhs_params, lhs_data,
                                                       rhs_params, rhs_data,
                                                       dst_params, dst_data,
                                                       params, context);
    return;
  }
  // If we did not choose to force usage of ruy above, then we may now consider
  // using custom GEMV code for the matrix*vector cases.
  const bool try_custom_gemv = (dst_params.cols == 1);
  if (try_custom_gemv) {
    // GEMV case: try a custom fast GEMV path. It returns true if it actually
    // handled the call.
    if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                           dst_params, dst_data, params, context)) {
      return;
    }
  }
  // Generic case: dispatch to any backend as a general GEMM.
  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                     dst_params, dst_data, params, context);
}
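
// Example usage (a minimal float sketch; the MatrixParams/GemmParams field
// names follow cpu_backend_gemm_params.h, and the dimensions M, K, N and
// buffers here are illustrative only):
//
//   MatrixParams<float> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = M;
//   lhs_params.cols = K;
//   MatrixParams<float> rhs_params;
//   rhs_params.order = Order::kColMajor;
//   rhs_params.rows = K;
//   rhs_params.cols = N;
//   MatrixParams<float> dst_params;
//   dst_params.order = Order::kColMajor;
//   dst_params.rows = M;
//   dst_params.cols = N;
//   GemmParams<float, float> gemm_params;  // kFloatingPoint flavor.
//   gemm_params.bias = bias_data;          // Optional; may be left null.
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, cpu_backend_context);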

// Special path for GEMM with raw accumulators, i.e. the
// AccumScalar == DstScalar == int32 case.
template <typename LhsScalar, typename RhsScalar,
          QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
          const GemmParams<int32_t, int32_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);

  // Currently only the Ruy backend supports outputting raw accumulators, so
  // we use ruy unconditionally here.
  ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
  detail::GemmImplUsingRuy<LhsScalar, RhsScalar, int32_t, int32_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}
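
// For this raw-accumulator overload, a minimal int8 sketch (illustrative
// only; the quantized multiplier and clamp fields are deliberately left at
// their GemmParams defaults, since the output is the raw int32 accumulators):
//
//   MatrixParams<int8_t> lhs_params;   // M x K, e.g. Order::kRowMajor.
//   MatrixParams<int8_t> rhs_params;   // K x N, e.g. Order::kColMajor.
//   MatrixParams<int32_t> dst_params;  // M x N, e.g. Order::kColMajor.
//   GemmParams<int32_t, int32_t> gemm_params;  // No multiplier: raw output.
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, cpu_backend_context);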

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_