• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // gemmlowp.h: the main public interface header of gemmlowp.
16 
17 #ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_
18 #define GEMMLOWP_PUBLIC_GEMMLOWP_H_
19 #include "../internal/dispatch_gemm_shape.h"
20 #include "bit_depth.h"
21 #include "map.h"
22 #include "output_stages.h"
23 
24 namespace gemmlowp {
25 
// Public context type for gemmlowp GEMM calls.
// It declares no members of its own: everything it provides is inherited
// publicly from MultiThreadGemmContext.
// NOTE(review): presumably this is the type users instantiate and pass as the
// `context` argument of Gemm() and the GemmWithOutputPipeline* functions
// (whose context parameter is a template type) — confirm against callers.
class GemmContext : public MultiThreadGemmContext {};
27 
28 // Computes a general matrix product ("GEMM").
29 // This is a version that supports per channel quantization.
30 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
31           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
32           typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
33           typename GemmContextType>
GemmWithOutputPipelinePC(GemmContextType * context,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,const LhsOffset & lhs_offset,const RhsOffset & rhs_offset,const OutputPipelineType & output_pipeline)34 void GemmWithOutputPipelinePC(GemmContextType* context,
35                               const MatrixMap<const InputScalar, LhsOrder>& lhs,
36                               const MatrixMap<const InputScalar, RhsOrder>& rhs,
37                               MatrixMap<OutputScalar, ResultOrder>* result,
38                               const LhsOffset& lhs_offset,
39                               const RhsOffset& rhs_offset,
40                               const OutputPipelineType& output_pipeline) {
41   DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
42       context, lhs, rhs, result, lhs_offset, rhs_offset, output_pipeline);
43 }
44 
45 // Computes a general matrix product ("GEMM").
46 // This is the legacy version that does not support per channel quantization.
47 // The meaning of the offsets, result_mult_int and result_shift
48 // parameters is the same as in the standard EightBitIntGemm interface
49 // (which is also implemented in the eight_bit_int_gemm directory).
50 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
51           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
52           typename OutputPipelineType, typename GemmContextType>
GemmWithOutputPipeline(GemmContextType * context,const MatrixMap<const InputScalar,LhsOrder> & lhs,const MatrixMap<const InputScalar,RhsOrder> & rhs,MatrixMap<OutputScalar,ResultOrder> * result,int lhs_offset,int rhs_offset,const OutputPipelineType & output_pipeline)53 void GemmWithOutputPipeline(GemmContextType* context,
54                             const MatrixMap<const InputScalar, LhsOrder>& lhs,
55                             const MatrixMap<const InputScalar, RhsOrder>& rhs,
56                             MatrixMap<OutputScalar, ResultOrder>* result,
57                             int lhs_offset, int rhs_offset,
58                             const OutputPipelineType& output_pipeline) {
59   typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
60   typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;
61   const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
62   const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
63   DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
64       context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector,
65       output_pipeline);
66 }
67 
68 // Computes a general matrix product ("GEMM").
69 // The meaning of the offsets, result_mult_int and result_shift
70 // parameters is the same as in the standard EightBitIntGemm interface
71 // (which is also implemented in the eight_bit_int_gemm directory).
72 template <typename Scalar, typename BitDepthParams, MapOrder LhsOrder,
73           MapOrder RhsOrder, MapOrder ResultOrder, typename GemmContextType>
Gemm(GemmContextType * context,const MatrixMap<const Scalar,LhsOrder> & lhs,const MatrixMap<const Scalar,RhsOrder> & rhs,MatrixMap<Scalar,ResultOrder> * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift)74 void Gemm(GemmContextType* context,
75           const MatrixMap<const Scalar, LhsOrder>& lhs,
76           const MatrixMap<const Scalar, RhsOrder>& rhs,
77           MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
78           int rhs_offset, int result_offset, int result_mult_int,
79           int result_shift) {
80   GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>(
81       context, lhs, rhs, result, lhs_offset, rhs_offset,
82       MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift));
83 }
84 
85 }  // namespace gemmlowp
86 
87 #endif  // GEMMLOWP_PUBLIC_GEMMLOWP_H_
88