/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_LSTMPARAMS_H
#define ARM_COMPUTE_LSTMPARAMS_H

#include "arm_compute/core/IPyramid.h"
#include "arm_compute/core/PyramidInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/Tensor.h"

#include <cstddef>
#include <cstdint>
#include <memory>

namespace arm_compute
{
/** Bundle of optional LSTM parameters.
 *
 * The object stores the pointers and scalar values handed to its setters and
 * records, through boolean flags, which optional groups have been supplied.
 * It only keeps the pointers it is given: nothing is allocated or freed by
 * this class, so the pointed-to objects must be kept alive by the caller.
 *
 * Every setter returns a reference to the object so that calls can be chained.
 *
 * @tparam T Type of the objects the stored pointers refer to.
 */
template <typename T>
class LSTMParams
{
public:
    /** Constructor
     *
     * All pointers start as nullptr and all scalar values as zero. The peephole,
     * projection and layer normalization flags start as false, while the CIFG
     * flag starts as true (it is cleared by @ref set_cifg_params).
     */
    LSTMParams()
        : _input_to_input_weights(nullptr),
          _recurrent_to_input_weights(nullptr),
          _cell_to_input_weights(nullptr),
          _input_gate_bias(nullptr),
          _cell_to_forget_weights(nullptr),
          _cell_to_output_weights(nullptr),
          _projection_weights(nullptr),
          _projection_bias(nullptr),
          _input_layer_norm_weights(nullptr),
          _forget_layer_norm_weights(nullptr),
          _cell_layer_norm_weights(nullptr),
          _output_layer_norm_weights(nullptr),
          _cell_clip(0.f),
          _projection_clip(0.0f),
          _input_intermediate_scale(0.0f),
          _forget_intermediate_scale(0.0f),
          _cell_intermediate_scale(0.0f),
          _output_intermediate_scale(0.0f),
          _hidden_state_zero(0),
          _hidden_state_scale(0.0f),
          _has_peephole_opt(false),
          _has_projection(false),
          _has_cifg_opt(true),
          _use_layer_norm(false)
    {
    }
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor */
    ~LSTMParams() = default;
    /** Set CIFG tensor parameters.
     *
     * Calling this function clears the CIFG flag reported by @ref has_cifg_opt.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        // Supplying the input gate tensors switches the CIFG flag off.
        _has_cifg_opt = false;
        return *this;
    }
    /** Set projection tensor parameters.
     *
     * Calling this function sets the flag reported by @ref has_projection.
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
    /** Set peephole tensor parameters.
     *
     * Calling this function sets the flag reported by @ref has_peephole_opt.
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
    /** Set layer normalization tensor parameters.
     *
     * Calling this function sets the flag reported by @ref use_layer_norm.
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set scale of the intermediate results of matmul of each layer parameters.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set hidden state zero and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    /** Input to input weights accessor.
     *
     * @return Pointer stored by @ref set_cifg_params, nullptr if it was never called
     */
    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    /** Recurrent to input weights accessor.
     *
     * @return Pointer stored by @ref set_cifg_params, nullptr if it was never called
     */
    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    /** Cell to input weights accessor.
     *
     * @return Pointer stored by @ref set_cifg_params, nullptr if it was never called
     */
    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    /** Input gate bias accessor.
     *
     * @return Pointer stored by @ref set_cifg_params, nullptr if it was never called
     */
    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    /** Cell to forget weights accessor.
     *
     * @return Pointer stored by @ref set_peephole_params, nullptr if it was never called
     */
    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    /** Cell to output weights accessor.
     *
     * @return Pointer stored by @ref set_peephole_params, nullptr if it was never called
     */
    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    /** Projection weights accessor.
     *
     * @return Pointer stored by @ref set_projection_params, nullptr if it was never called
     */
    const T *projection_weights() const
    {
        return _projection_weights;
    }

    /** Projection bias accessor.
     *
     * @return Pointer stored by @ref set_projection_params, nullptr if it was never called
     */
    const T *projection_bias() const
    {
        return _projection_bias;
    }

    /** Input gate layer normalization weights accessor.
     *
     * @return Pointer stored by @ref set_layer_normalization_params, nullptr if it was never called
     */
    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    /** Forget gate layer normalization weights accessor.
     *
     * @return Pointer stored by @ref set_layer_normalization_params, nullptr if it was never called
     */
    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    /** Cell gate layer normalization weights accessor.
     *
     * @return Pointer stored by @ref set_layer_normalization_params, nullptr if it was never called
     */
    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    /** Output gate layer normalization weights accessor.
     *
     * @return Pointer stored by @ref set_layer_normalization_params, nullptr if it was never called
     */
    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    /** Cell clip accessor.
     *
     * @return Value stored by @ref set_cell_clip_params, 0 if it was never called
     */
    float cell_clip() const
    {
        return _cell_clip;
    }

    /** Projection clip accessor.
     *
     * @return Value stored by @ref set_projection_clip_params, 0 if it was never called
     */
    float projection_clip() const
    {
        return _projection_clip;
    }

    /** Input gate intermediate scale accessor.
     *
     * @return Value stored by @ref set_matmul_scale_params, 0 if it was never called
     */
    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    /** Forget gate intermediate scale accessor.
     *
     * @return Value stored by @ref set_matmul_scale_params, 0 if it was never called
     */
    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    /** Cell gate intermediate scale accessor.
     *
     * @return Value stored by @ref set_matmul_scale_params, 0 if it was never called
     */
    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    /** Output gate intermediate scale accessor.
     *
     * @return Value stored by @ref set_matmul_scale_params, 0 if it was never called
     */
    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    /** Hidden state zero point accessor.
     *
     * @return Value stored by @ref set_hidden_state_params, 0 if it was never called
     */
    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    /** Hidden state scale accessor.
     *
     * @return Value stored by @ref set_hidden_state_params, 0 if it was never called
     */
    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    /** Peephole flag accessor.
     *
     * @return True once @ref set_peephole_params has been called, false otherwise
     */
    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    /** Projection flag accessor.
     *
     * @return True once @ref set_projection_params has been called, false otherwise
     */
    bool has_projection() const
    {
        return _has_projection;
    }

    /** CIFG flag accessor.
     *
     * @return True until @ref set_cifg_params has been called, false afterwards
     */
    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    /** Layer normalization flag accessor.
     *
     * @return True once @ref set_layer_normalization_params has been called, false otherwise
     */
    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    const T *_input_to_input_weights;
    const T *_recurrent_to_input_weights;
    T       *_cell_to_input_weights;
    const T *_input_gate_bias;
    T       *_cell_to_forget_weights;
    T       *_cell_to_output_weights;
    const T *_projection_weights;
    const T *_projection_bias;
    T       *_input_layer_norm_weights;
    T       *_forget_layer_norm_weights;
    T       *_cell_layer_norm_weights;
    T       *_output_layer_norm_weights;
    float    _cell_clip;
    float    _projection_clip;
    float    _input_intermediate_scale;
    float    _forget_intermediate_scale;
    float    _cell_intermediate_scale;
    float    _output_intermediate_scale;
    int32_t  _hidden_state_zero;
    float    _hidden_state_scale;
    bool     _has_peephole_opt;
    bool     _has_projection;
    bool     _has_cifg_opt;
    bool     _use_layer_norm;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_LSTMPARAMS_H */