1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
16
17 #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
18 #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
19
#include <cassert>
#include <cstddef>
#include <tuple>

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "multi_thread_gemm.h"
24
25 namespace gemmlowp {
26
// TransposeImpl<T> tells Transpose() how to handle a value of type T:
//   DstType -- the type of the transposed value,
//   Run(t)  -- produces the transposed value.
// This primary template is the fallback: any type without a dedicated
// specialization is treated as transpose-invariant and returned as a copy.
template <typename T>
struct TransposeImpl {
  using DstType = T;
  static DstType Run(const T& t) { return t; }
};

// Type obtained by transposing a value of type T.
template <typename T>
using TransposeType = typename TransposeImpl<T>::DstType;

// Returns the transpose of `t`, as defined by TransposeImpl<T>.
template <typename T>
TransposeType<T> Transpose(const T& t) {
  using Impl = TransposeImpl<T>;
  return Impl::Run(t);
}
40
41 template <MapOrder Order>
42 struct TransposeMapOrder {
43 static constexpr MapOrder Value =
44 Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
45 };
46
47 template <VectorShape Shape>
48 struct TransposeVectorShape {
49 static constexpr VectorShape Value =
50 Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
51 };
52
53 template <typename Scalar, VectorShape Shape>
54 struct TransposeImpl<VectorMap<Scalar, Shape>> {
55 typedef VectorMap<Scalar, Shape> SrcType;
56 static constexpr VectorShape TransposedShape =
57 TransposeVectorShape<Shape>::Value;
58 typedef VectorMap<Scalar, TransposedShape> DstType;
59 static DstType Run(const SrcType& src) {
60 return DstType(src.data(), src.size());
61 }
62 };
63
64 template <typename Scalar, MapOrder Order>
65 struct TransposeImpl<MatrixMap<Scalar, Order>> {
66 typedef MatrixMap<Scalar, Order> SrcType;
67 static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
68 typedef MatrixMap<Scalar, TransposedOrder> DstType;
69 static DstType Run(const SrcType& src) {
70 return DstType(src.data(), src.cols(), src.rows(), src.stride());
71 }
72 };
73
74 template <VectorShape Shape>
75 struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
76 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
77 static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
78 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
79 static DstType Run(const SrcType& src) {
80 DstType dst;
81 dst.result_shift = src.result_shift;
82 dst.result_offset = Transpose(src.result_offset);
83 dst.result_mult_int = Transpose(src.result_mult_int);
84 return dst;
85 }
86 };
87
88 template <VectorShape Shape>
89 struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
90 typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
91 static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
92 typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
93 DstType;
94 static DstType Run(const SrcType& src) {
95 DstType dst;
96 dst.result_fixedpoint_multiplier =
97 Transpose(src.result_fixedpoint_multiplier);
98 dst.result_exponent = Transpose(src.result_exponent);
99 dst.result_offset_after_shift = src.result_offset_after_shift;
100 return dst;
101 }
102 };
103
104 template <typename VectorMapType>
105 struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
106 typedef OutputStageBiasAddition<VectorMapType> SrcType;
107 typedef TransposeType<VectorMapType> TransposedVectorMapType;
108 typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
109 static DstType Run(const SrcType& src) {
110 DstType dst;
111 dst.bias_vector = Transpose(src.bias_vector);
112 return dst;
113 }
114 };
115
116 // TODO(benoitjacob) - does anyone understand C++ variadic templates?
117 // How to use them to implement TransposeTuple? Note: there are lots
118 // of answers on StackOverflow but they seem to all involve either
119 // C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
120 inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
121
122 template <typename T0>
123 std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
124 return std::make_tuple(Transpose(std::get<0>(t)));
125 }
126
127 template <typename T0, typename T1>
128 std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
129 const std::tuple<T0, T1>& t) {
130 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
131 }
132
133 template <typename T0, typename T1, typename T2>
134 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
135 TransposeTuple(const std::tuple<T0, T1, T2>& t) {
136 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
137 Transpose(std::get<2>(t)));
138 }
139
140 template <typename T0, typename T1, typename T2, typename T3>
141 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
142 TransposeType<T3>>
143 TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
144 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
145 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
146 }
147
148 template <typename T0, typename T1, typename T2, typename T3, typename T4>
149 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
150 TransposeType<T3>, TransposeType<T4>>
151 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
152 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
153 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
154 Transpose(std::get<4>(t)));
155 }
156
157 template <typename T0, typename T1, typename T2, typename T3, typename T4,
158 typename T5>
159 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
160 TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
161 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
162 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
163 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
164 Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
165 }
166
167 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
168 MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
169 typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
170 typename GemmContextType>
171 void DispatchGemmShape(GemmContextType* context,
172 const MatrixMap<const InputScalar, LhsOrder>& lhs,
173 const MatrixMap<const InputScalar, RhsOrder>& rhs,
174 MatrixMap<OutputScalar, ResultOrder>* result,
175 const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
176 const OutputPipelineType& output_pipeline) {
177 assert(lhs.cols() == rhs.rows());
178
179 int rows = result->rows();
180 int cols = result->cols();
181 int depth = lhs.cols();
182
183 if (rows == 0 || cols == 0 || depth == 0) {
184 // Vacuous GEMM, return early to avoid having to deal with
185 // zero sizes below.
186 return;
187 }
188
189 if (rows < cols) {
190 auto transposed_result_map = Transpose(*result);
191 return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
192 context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
193 Transpose(rhs_offset), Transpose(lhs_offset),
194 TransposeTuple(output_pipeline));
195 }
196
197 typedef DefaultKernel<BitDepthParams> Kernel;
198 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
199 BitDepthParams>(context, Kernel(), lhs, rhs, result,
200 lhs_offset, rhs_offset, output_pipeline);
201 }
202
203 } // end namespace gemmlowp
204
205 #endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
206