/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_LSTMPARAMS_H
#define ARM_COMPUTE_LSTMPARAMS_H

#include "arm_compute/core/IPyramid.h"
#include "arm_compute/core/PyramidInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/Tensor.h"

#include <cstddef>
#include <cstdint>
#include <memory>

namespace arm_compute
{
/** Holder of the optional parameters of an LSTM layer: CIFG, peephole, projection,
 * layer normalization, clipping and quantization settings. @p T is the tensor type,
 * e.g. ITensor when configuring a function or ITensorInfo when validating one.
 */
template <typename T>
class LSTMParams
{
public:
    /** Constructor */
    LSTMParams()
        : _input_to_input_weights(nullptr),
          _recurrent_to_input_weights(nullptr),
          _cell_to_input_weights(nullptr),
          _input_gate_bias(nullptr),
          _cell_to_forget_weights(nullptr),
          _cell_to_output_weights(nullptr),
          _projection_weights(nullptr),
          _projection_bias(nullptr),
          _input_layer_norm_weights(nullptr),
          _forget_layer_norm_weights(nullptr),
          _cell_layer_norm_weights(nullptr),
          _output_layer_norm_weights(nullptr),
          _cell_clip(0.0f),
          _projection_clip(0.0f),
          _input_intermediate_scale(0.0f),
          _forget_intermediate_scale(0.0f),
          _cell_intermediate_scale(0.0f),
          _output_intermediate_scale(0.0f),
          _hidden_state_zero(0),
          _hidden_state_scale(0.0f),
          _has_peephole_opt(false),
          _has_projection(false),
          _has_cifg_opt(true),
          _use_layer_norm(false)
    {
    }
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor */
    ~LSTMParams() = default;
    /** Set CIFG tensor parameters.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights; S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        _has_cifg_opt               = false;
        return *this;
    }
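    // Usage sketch (illustrative only): calling this setter disables the CIFG
    // (coupled input-forget gate) optimization, because explicit input-gate tensors
    // become available. The tensor names below are hypothetical, caller-owned,
    // already-allocated tensors:
    //
    //   LSTMParams<ITensor> params;
    //   params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, nullptr, &input_gate_bias);
    //   // params.has_cifg_opt() now returns false.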
    /** Set projection tensor parameters.
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights; S32 when @p projection_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
    /** Set peephole tensor parameters.
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
    /** Set layer normalization tensor parameters.
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }
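    // Usage sketch (illustrative only): every setter returns *this, so optional
    // features can be enabled fluently. Tensor names are hypothetical caller-owned
    // tensors:
    //
    //   LSTMParams<ITensor> params;
    //   params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w)
    //         .set_layer_normalization_params(&in_ln_w, &forget_ln_w, &cell_ln_w, &out_ln_w);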

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set the scales of the intermediate matmul results of each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate matmul result (the input to layer normalization) at the input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate matmul result (the input to layer normalization) at the forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate matmul result (the input to layer normalization) at the cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate matmul result (the input to layer normalization) at the output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }
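    // Usage sketch (illustrative only): for a quantized LSTM, the per-gate
    // intermediate scales and the hidden state quantization (declared below) are
    // typically set together. The numeric values here are made up:
    //
    //   LSTMParams<ITensorInfo> params;
    //   params.set_matmul_scale_params(0.007f, 0.007f, 0.007f, 0.007f)
    //         .set_hidden_state_params(/* zero point */ 0, /* scale */ 0.001f);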

    /** Set hidden state zero and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    const T *projection_weights() const
    {
        return _projection_weights;
    }

    const T *projection_bias() const
    {
        return _projection_bias;
    }

    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    float cell_clip() const
    {
        return _cell_clip;
    }

    float projection_clip() const
    {
        return _projection_clip;
    }

    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    bool has_projection() const
    {
        return _has_projection;
    }

    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }
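    // Usage sketch (illustrative only): consumers, e.g. an LSTM function's validate
    // step, read the flags before touching the optional tensors, since those
    // accessors may return nullptr:
    //
    //   if(!params.has_cifg_opt())
    //   {
    //       // input_to_input_weights(), recurrent_to_input_weights() and
    //       // input_gate_bias() are expected to be non-null here.
    //   }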

private:
    const T *_input_to_input_weights;
    const T *_recurrent_to_input_weights;
    T       *_cell_to_input_weights;
    const T *_input_gate_bias;
    T       *_cell_to_forget_weights;
    T       *_cell_to_output_weights;
    const T *_projection_weights;
    const T *_projection_bias;
    T       *_input_layer_norm_weights;
    T       *_forget_layer_norm_weights;
    T       *_cell_layer_norm_weights;
    T       *_output_layer_norm_weights;
    float    _cell_clip;
    float    _projection_clip;
    float    _input_intermediate_scale;
    float    _forget_intermediate_scale;
    float    _cell_intermediate_scale;
    float    _output_intermediate_scale;
    int32_t  _hidden_state_zero;
    float    _hidden_state_scale;
    bool     _has_peephole_opt;
    bool     _has_projection;
    bool     _has_cifg_opt;
    bool     _use_layer_norm;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_LSTMPARAMS_H */