• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TensorFwd.hpp"
8 #include "Exceptions.hpp"
9 
10 namespace armnn
11 {
12 
13 struct LstmInputParams
14 {
LstmInputParamsarmnn::LstmInputParams15     LstmInputParams()
16         : m_InputToInputWeights(nullptr)
17         , m_InputToForgetWeights(nullptr)
18         , m_InputToCellWeights(nullptr)
19         , m_InputToOutputWeights(nullptr)
20         , m_RecurrentToInputWeights(nullptr)
21         , m_RecurrentToForgetWeights(nullptr)
22         , m_RecurrentToCellWeights(nullptr)
23         , m_RecurrentToOutputWeights(nullptr)
24         , m_CellToInputWeights(nullptr)
25         , m_CellToForgetWeights(nullptr)
26         , m_CellToOutputWeights(nullptr)
27         , m_InputGateBias(nullptr)
28         , m_ForgetGateBias(nullptr)
29         , m_CellBias(nullptr)
30         , m_OutputGateBias(nullptr)
31         , m_ProjectionWeights(nullptr)
32         , m_ProjectionBias(nullptr)
33         , m_InputLayerNormWeights(nullptr)
34         , m_ForgetLayerNormWeights(nullptr)
35         , m_CellLayerNormWeights(nullptr)
36         , m_OutputLayerNormWeights(nullptr)
37     {
38     }
39 
40     const ConstTensor* m_InputToInputWeights;
41     const ConstTensor* m_InputToForgetWeights;
42     const ConstTensor* m_InputToCellWeights;
43     const ConstTensor* m_InputToOutputWeights;
44     const ConstTensor* m_RecurrentToInputWeights;
45     const ConstTensor* m_RecurrentToForgetWeights;
46     const ConstTensor* m_RecurrentToCellWeights;
47     const ConstTensor* m_RecurrentToOutputWeights;
48     const ConstTensor* m_CellToInputWeights;
49     const ConstTensor* m_CellToForgetWeights;
50     const ConstTensor* m_CellToOutputWeights;
51     const ConstTensor* m_InputGateBias;
52     const ConstTensor* m_ForgetGateBias;
53     const ConstTensor* m_CellBias;
54     const ConstTensor* m_OutputGateBias;
55     const ConstTensor* m_ProjectionWeights;
56     const ConstTensor* m_ProjectionBias;
57     const ConstTensor* m_InputLayerNormWeights;
58     const ConstTensor* m_ForgetLayerNormWeights;
59     const ConstTensor* m_CellLayerNormWeights;
60     const ConstTensor* m_OutputLayerNormWeights;
61 };
62 
63 struct LstmInputParamsInfo
64 {
LstmInputParamsInfoarmnn::LstmInputParamsInfo65     LstmInputParamsInfo()
66             : m_InputToInputWeights(nullptr)
67             , m_InputToForgetWeights(nullptr)
68             , m_InputToCellWeights(nullptr)
69             , m_InputToOutputWeights(nullptr)
70             , m_RecurrentToInputWeights(nullptr)
71             , m_RecurrentToForgetWeights(nullptr)
72             , m_RecurrentToCellWeights(nullptr)
73             , m_RecurrentToOutputWeights(nullptr)
74             , m_CellToInputWeights(nullptr)
75             , m_CellToForgetWeights(nullptr)
76             , m_CellToOutputWeights(nullptr)
77             , m_InputGateBias(nullptr)
78             , m_ForgetGateBias(nullptr)
79             , m_CellBias(nullptr)
80             , m_OutputGateBias(nullptr)
81             , m_ProjectionWeights(nullptr)
82             , m_ProjectionBias(nullptr)
83             , m_InputLayerNormWeights(nullptr)
84             , m_ForgetLayerNormWeights(nullptr)
85             , m_CellLayerNormWeights(nullptr)
86             , m_OutputLayerNormWeights(nullptr)
87     {
88     }
89     const TensorInfo* m_InputToInputWeights;
90     const TensorInfo* m_InputToForgetWeights;
91     const TensorInfo* m_InputToCellWeights;
92     const TensorInfo* m_InputToOutputWeights;
93     const TensorInfo* m_RecurrentToInputWeights;
94     const TensorInfo* m_RecurrentToForgetWeights;
95     const TensorInfo* m_RecurrentToCellWeights;
96     const TensorInfo* m_RecurrentToOutputWeights;
97     const TensorInfo* m_CellToInputWeights;
98     const TensorInfo* m_CellToForgetWeights;
99     const TensorInfo* m_CellToOutputWeights;
100     const TensorInfo* m_InputGateBias;
101     const TensorInfo* m_ForgetGateBias;
102     const TensorInfo* m_CellBias;
103     const TensorInfo* m_OutputGateBias;
104     const TensorInfo* m_ProjectionWeights;
105     const TensorInfo* m_ProjectionBias;
106     const TensorInfo* m_InputLayerNormWeights;
107     const TensorInfo* m_ForgetLayerNormWeights;
108     const TensorInfo* m_CellLayerNormWeights;
109     const TensorInfo* m_OutputLayerNormWeights;
110 
Derefarmnn::LstmInputParamsInfo111     const TensorInfo& Deref(const TensorInfo* tensorInfo) const
112     {
113         if (tensorInfo != nullptr)
114         {
115             const TensorInfo &temp = *tensorInfo;
116             return temp;
117         }
118         throw InvalidArgumentException("Can't dereference a null pointer");
119     }
120 
GetInputToInputWeightsarmnn::LstmInputParamsInfo121     const TensorInfo& GetInputToInputWeights() const
122     {
123         return Deref(m_InputToInputWeights);
124     }
GetInputToForgetWeightsarmnn::LstmInputParamsInfo125     const TensorInfo& GetInputToForgetWeights() const
126     {
127         return Deref(m_InputToForgetWeights);
128     }
GetInputToCellWeightsarmnn::LstmInputParamsInfo129     const TensorInfo& GetInputToCellWeights() const
130     {
131         return Deref(m_InputToCellWeights);
132     }
GetInputToOutputWeightsarmnn::LstmInputParamsInfo133     const TensorInfo& GetInputToOutputWeights() const
134     {
135         return Deref(m_InputToOutputWeights);
136     }
GetRecurrentToInputWeightsarmnn::LstmInputParamsInfo137     const TensorInfo& GetRecurrentToInputWeights() const
138     {
139         return Deref(m_RecurrentToInputWeights);
140     }
GetRecurrentToForgetWeightsarmnn::LstmInputParamsInfo141     const TensorInfo& GetRecurrentToForgetWeights() const
142     {
143         return Deref(m_RecurrentToForgetWeights);
144     }
GetRecurrentToCellWeightsarmnn::LstmInputParamsInfo145     const TensorInfo& GetRecurrentToCellWeights() const
146     {
147         return Deref(m_RecurrentToCellWeights);
148     }
GetRecurrentToOutputWeightsarmnn::LstmInputParamsInfo149     const TensorInfo& GetRecurrentToOutputWeights() const
150     {
151         return Deref(m_RecurrentToOutputWeights);
152     }
GetCellToInputWeightsarmnn::LstmInputParamsInfo153     const TensorInfo& GetCellToInputWeights() const
154     {
155         return Deref(m_CellToInputWeights);
156     }
GetCellToForgetWeightsarmnn::LstmInputParamsInfo157     const TensorInfo& GetCellToForgetWeights() const
158     {
159         return Deref(m_CellToForgetWeights);
160     }
GetCellToOutputWeightsarmnn::LstmInputParamsInfo161     const TensorInfo& GetCellToOutputWeights() const
162     {
163         return Deref(m_CellToOutputWeights);
164     }
GetInputGateBiasarmnn::LstmInputParamsInfo165     const TensorInfo& GetInputGateBias() const
166     {
167         return Deref(m_InputGateBias);
168     }
GetForgetGateBiasarmnn::LstmInputParamsInfo169     const TensorInfo& GetForgetGateBias() const
170     {
171         return Deref(m_ForgetGateBias);
172     }
GetCellBiasarmnn::LstmInputParamsInfo173     const TensorInfo& GetCellBias() const
174     {
175         return Deref(m_CellBias);
176     }
GetOutputGateBiasarmnn::LstmInputParamsInfo177     const TensorInfo& GetOutputGateBias() const
178     {
179         return Deref(m_OutputGateBias);
180     }
GetProjectionWeightsarmnn::LstmInputParamsInfo181     const TensorInfo& GetProjectionWeights() const
182     {
183         return Deref(m_ProjectionWeights);
184     }
GetProjectionBiasarmnn::LstmInputParamsInfo185     const TensorInfo& GetProjectionBias() const
186     {
187         return Deref(m_ProjectionBias);
188     }
GetInputLayerNormWeightsarmnn::LstmInputParamsInfo189     const TensorInfo& GetInputLayerNormWeights() const
190     {
191         return Deref(m_InputLayerNormWeights);
192     }
GetForgetLayerNormWeightsarmnn::LstmInputParamsInfo193     const TensorInfo& GetForgetLayerNormWeights() const
194     {
195         return Deref(m_ForgetLayerNormWeights);
196     }
GetCellLayerNormWeightsarmnn::LstmInputParamsInfo197     const TensorInfo& GetCellLayerNormWeights() const
198     {
199         return Deref(m_CellLayerNormWeights);
200     }
GetOutputLayerNormWeightsarmnn::LstmInputParamsInfo201     const TensorInfo& GetOutputLayerNormWeights() const
202     {
203         return Deref(m_OutputLayerNormWeights);
204     }
205 };
206 
207 } // namespace armnn
208 
209