/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#include "convolution_parameters.hpp"
#include "ndrange.hpp"

#include <cstddef>
#include <cstdint>

namespace arm_gemm
{
// Abstract class for the GEMM/GEMV functions.
//
// GEMM implementations may be "native" (never require any input
// permutation), "pretransposed" (require permutation up-front) or require
// working space (permute as they go along). This interface should support
// all of them.

// The real GemmCommon class is templated based on the operand and return
// type. This is an interface class which is independent of those types.
class IGemmCommon
{
public:
    /* Pass in the pointers to the arrays to be operated on and their
     * strides. This "generic" version uses void *s; the preferred version
     * is the one provided by the templated GemmCommon (below), which takes
     * appropriately typed pointers. If B is pretransposed (see below) then
     * the settings for B here are ignored.
     */
    virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
                                    const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
                                    void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
                                    const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0;

    /** @returns an ndrange containing ranges of the compute space which can be
     * broken up and parallelised over
     */
    virtual ndrange_t get_window_size() const = 0;

    /* The maximum thread count is specified when the GEMM is created. Some
     * implementations need to know how many threads will actually run in
     * order to work properly.
     *
     * In some cases, after creating the GEMM the number of threads needs to
     * be reduced (e.g. there is not enough work to split across threads).
     * This method allows the number of threads that will actually run to be
     * set (the value must be equal to or lower than the maximum specified at
     * creation).
     *
     * This has an empty default implementation, as GEMMs which don't care
     * about thread count can safely ignore this.
     */
    virtual void set_nthreads(int) {};
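
    /* A minimal sketch (not part of this interface) of shrinking the thread
     * count before execution. The names "gemm" and "max_threads" are
     * hypothetical, and this assumes ndrange_t exposes a total_size()
     * accessor as declared in ndrange.hpp:
     *
     *   const ndrange_t window = gemm->get_window_size();
     *   const int work_units   = static_cast<int>(window.total_size());
     *   gemm->set_nthreads(std::min(max_threads, work_units));
     */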

    /* Whether this GEMM can be dynamically scheduled or not. */
    virtual bool supports_dynamic_scheduling() const
    {
        return false;
    }

    /** Main execute member function
     * @param [in] work_range     specifies the range of work to be computed; the total range is defined by get_window_size()
     * @param [in] thread_locator where we are inside of the thread space
     * @param [in] threadid       a unique threadid
     */
    virtual void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) = 0;

    /*** Working space interface (optional) ***/
    /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */
    virtual size_t get_working_size() const
    {
        return 0;
    }

    /* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */
    virtual void set_working_space(void *) {};

    /*** "Pretransposed" interface (optional) ***/
    /* Is this object set up for pretranspose? If so, pretranspose_B_array() needs to be called before execute(). */
    virtual bool B_is_pretransposed() const
    {
        return false;
    }

    /* Does pretranspose still need to be done? */
    virtual bool B_pretranspose_required() const
    {
        return false;
    }

    /* Total number of bytes of space needed for pretransposed arrays. */
    virtual size_t get_B_pretransposed_array_size() const
    {
        return 0;
    }

    /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */
    /* The "real" version of this depends on the templated operand type (see below). */
    virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0;

    /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
    virtual void set_pretransposed_B_data(void *)
    {
    }

    /*** "Quantized bias" interface (optional) ***/
    /* Set the bias vector for quantized GEMMs */
    virtual void set_quantized_bias(const int32_t *, size_t)
    {
    }

    /*** Indirect interface (optional) ***/
    /* Set the indirect table. This comprises a number of values per kernel point, and a densely packed array of pointers,
     * multis * batches * kernel_points in size. */
    virtual void set_indirect_parameters_generic(size_t, const void *const *const *)
    {
    }

    /*** Convolution interface (optional) ***/
    /* Set the convolution parameters. */
    virtual void set_convolution_parameters(ConvolutionParameters)
    {
    }

    // Destructor
    virtual ~IGemmCommon()
    {
    }
};
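
/* Illustrative driving sequence (a sketch, not part of this header): a
 * single-threaded caller holding a pointer "gemm" to some concrete
 * implementation might exercise the optional interfaces roughly as follows.
 * The std::vector buffer management and the names "full_range" and
 * "thread_origin" are assumptions for the example only:
 *
 *   // Optional working space.
 *   std::vector<uint8_t> workspace(gemm->get_working_size());
 *   if(!workspace.empty())
 *   {
 *       gemm->set_working_space(workspace.data());
 *   }
 *
 *   // Optional up-front permutation of B; once done, the B settings
 *   // passed to set_arrays_generic() are ignored.
 *   std::vector<uint8_t> packed_B;
 *   if(gemm->B_pretranspose_required())
 *   {
 *       packed_B.resize(gemm->get_B_pretransposed_array_size());
 *       gemm->pretranspose_B_array_generic(packed_B.data(), B, ldb, B_multi_stride);
 *   }
 *
 *   // Run the whole compute window on thread 0. Construction of the two
 *   // ndcoord_t values ("full_range" spanning get_window_size() and
 *   // "thread_origin" locating this thread) is omitted here.
 *   gemm->execute(full_range, thread_origin, 0);
 */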

/* "Real" GemmCommon class which is templated on the operand and return types.
 *
 * In addition to correctly typed versions of the functions that operate on
 * operand and return data, this class provides a default implementation of
 * 'set_arrays' to capture the provided arguments in protected class
 * members, as essentially any implementation will need these.
 */
template <typename To, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
    const To *_Aptr              = nullptr;
    int       _lda               = 0;
    int       _A_batch_stride    = 0;
    int       _A_multi_stride    = 0;
    const To *_Bptr              = nullptr;
    int       _ldb               = 0;
    int       _B_multi_stride    = 0;
    Tr       *_Cptr              = nullptr;
    int       _ldc               = 0;
    int       _C_batch_stride    = 0;
    int       _C_multi_stride    = 0;
    const Tr *_bias              = nullptr;
    int       _bias_multi_stride = 0;

public:
    /* Pass in the pointers to the arrays to be operated on and their
     * strides (templated version with appropriate types). */
    virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
                            const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
                            Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
                            const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride)
    {
        _Aptr              = A;
        _lda               = lda;
        _A_batch_stride    = A_batch_stride;
        _A_multi_stride    = A_multi_stride;
        _Bptr              = B;
        _ldb               = ldb;
        _B_multi_stride    = B_multi_stride;
        _Cptr              = C;
        _ldc               = ldc;
        _C_batch_stride    = C_batch_stride;
        _C_multi_stride    = C_multi_stride;
        _bias              = bias;
        _bias_multi_stride = bias_multi_stride;
    }

    /* Implementation of the void * overload which casts its arguments to the appropriate type. */
    void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride,
                            const void *B, const int ldb, /* batches share B */ const int B_multi_stride,
                            void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
                            const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override
    {
        set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride,
                   static_cast<const To *>(B), ldb, B_multi_stride,
                   static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
                   static_cast<const Tr *>(bias), bias_multi_stride);
    }

    /*** "Pretransposed" interface ***/

    /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
    /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
    virtual void pretranspose_B_array(void *, const To *, const int, const int) {};

    /* Implementation of the void * overload which casts its arguments to the appropriate type. */
    void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override
    {
        pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride);
    }

    /*** Indirect interface ***/
    virtual void set_indirect_parameters(size_t, const To *const *const *)
    {
    }

    void set_indirect_parameters_generic(size_t sz, const void *const *const *ptr) override
    {
        set_indirect_parameters(sz, reinterpret_cast<const To *const *const *>(ptr));
    }
};

} // namespace arm_gemm
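
/* Illustrative sketch (not part of this header): the smallest useful subclass
 * of GemmCommon overrides the two remaining pure virtuals, get_window_size()
 * and execute(). The hypothetical "TrivialGemm" below shows what a "native"
 * implementation (no permutation, no working space) must provide; it assumes
 * ndrange_t can be constructed from a single leading dimension:
 *
 *   class TrivialGemm : public GemmCommon<float, float>
 *   {
 *   public:
 *       ndrange_t get_window_size() const override
 *       {
 *           // One unit of parallelisable work per output row (_M is a
 *           // hypothetical row count set at construction).
 *           return ndrange_t(_M);
 *       }
 *
 *       void execute(const ndcoord_t &work_range, const ndcoord_t &, int) override
 *       {
 *           // Compute only the rows of C described by work_range, reading
 *           // from _Aptr/_Bptr and writing to _Cptr via the strides captured
 *           // by GemmCommon::set_arrays().
 *       }
 *
 *   private:
 *       unsigned int _M = 0;
 *   };
 */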