//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "QuantizedLstmLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/TypesUtils.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

namespace armnn
{

// 3 inputs (input, previousCellStateIn, previousOutputIn), 2 outputs (cellStateOut, output)
QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
    : Layer(3, 2, LayerType::QuantizedLstm, name)
{
}

std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    QuantizedLstmQueueDescriptor descriptor;

    // QuantizedLstmLayer parameters - there are no optional params
    descriptor.m_InputToInputWeights  = m_QuantizedLstmParameters.m_InputToInputWeights.get();
    descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
    descriptor.m_InputToCellWeights   = m_QuantizedLstmParameters.m_InputToCellWeights.get();
    descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();

    descriptor.m_RecurrentToInputWeights  = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
    descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
    descriptor.m_RecurrentToCellWeights   = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
    descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();

    descriptor.m_InputGateBias  = m_QuantizedLstmParameters.m_InputGateBias.get();
    descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
    descriptor.m_CellBias       = m_QuantizedLstmParameters.m_CellBias.get();
    descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();

    SetAdditionalInfo(descriptor);

    return factory.CreateQuantizedLstm(descriptor, PrepInfoAndDesc(descriptor));
}

QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
{
    auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());

    layer->m_QuantizedLstmParameters.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToInputWeights) : nullptr;
    layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToForgetWeights) : nullptr;
    layer->m_QuantizedLstmParameters.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToCellWeights) : nullptr;
    layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToOutputWeights) : nullptr;

    layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToInputWeights) : nullptr;
    layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights
            ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToForgetWeights) : nullptr;
    layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToCellWeights) : nullptr;
    layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights
            ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToOutputWeights) : nullptr;

    layer->m_QuantizedLstmParameters.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputGateBias) : nullptr;
    layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_ForgetGateBias) : nullptr;
    layer->m_QuantizedLstmParameters.m_CellBias = m_QuantizedLstmParameters.m_CellBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_CellBias) : nullptr;
    layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_OutputGateBias) : nullptr;

    return std::move(layer);
}

std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 3);

    // Get input values for validation
    unsigned int numBatches = inputShapes[0][0]; // batch size, from the input tensor shape
    unsigned int outputSize = inputShapes[1][1]; // output size, from the previousCellStateIn tensor shape

    std::vector<TensorShape> outShapes;
    outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
    outShapes.push_back(TensorShape({numBatches, outputSize})); // output

    return outShapes;
}

void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(3, CHECK_LOCATION());

    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    auto inferredShapes = InferOutputShapes(
    {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()  // previousOutputIn
    });

    ARMNN_ASSERT(inferredShapes.size() == 2);

    // Check weights and bias for nullptr
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");

    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");

    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");

    // Check output TensorShape(s) match inferred shape
    ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QuantizedLstmLayer");

    ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
                         inferredShapes[1],
                         m_ShapeInferenceMethod,
                         "QuantizedLstmLayer",
                         1);
}

Layer::ConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef()
{
    return
    {
        m_QuantizedLstmParameters.m_InputToInputWeights,
        m_QuantizedLstmParameters.m_InputToForgetWeights,
        m_QuantizedLstmParameters.m_InputToCellWeights,
        m_QuantizedLstmParameters.m_InputToOutputWeights,

        m_QuantizedLstmParameters.m_RecurrentToInputWeights,
        m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
        m_QuantizedLstmParameters.m_RecurrentToCellWeights,
        m_QuantizedLstmParameters.m_RecurrentToOutputWeights,

        m_QuantizedLstmParameters.m_InputGateBias,
        m_QuantizedLstmParameters.m_ForgetGateBias,
        m_QuantizedLstmParameters.m_CellBias,
        m_QuantizedLstmParameters.m_OutputGateBias
    };
}

void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
{
    QuantizedLstmInputParams inputParams;

    // InputToX weight tensors
    ConstTensor inputToInputWeightsTensor;
    if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
    {
        ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
                                                  m_QuantizedLstmParameters.m_InputToInputWeights->Map(true));
        inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
        inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
    }

    ConstTensor inputToForgetWeightsTensor;
    if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
    {
        ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
                                                   m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true));
        inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
        inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    }

    ConstTensor inputToCellWeightsTensor;
    if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
    {
        ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
                                                 m_QuantizedLstmParameters.m_InputToCellWeights->Map(true));
        inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
        inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
    }

    ConstTensor inputToOutputWeightsTensor;
    if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
    {
        ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
                                                   m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true));
        inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
        inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
    }

    // RecurrentToX weight tensors
    ConstTensor recurrentToInputWeightsTensor;
    if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
    {
        ConstTensor recurrentToInputWeightsTensorCopy(
                m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
                m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true));
        recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
        inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
    }

    ConstTensor recurrentToForgetWeightsTensor;
    if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
    {
        ConstTensor recurrentToForgetWeightsTensorCopy(
                m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
                m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true));
        recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
        inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    }

    ConstTensor recurrentToCellWeightsTensor;
    if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
    {
        ConstTensor recurrentToCellWeightsTensorCopy(
                m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
                m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true));
        recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
        inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    }

    ConstTensor recurrentToOutputWeightsTensor;
    if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
    {
        ConstTensor recurrentToOutputWeightsTensorCopy(
                m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
                m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true));
        recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
        inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
    }

    // Bias tensors
    ConstTensor inputGateBiasTensor;
    if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
    {
        ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
                                            m_QuantizedLstmParameters.m_InputGateBias->Map(true));
        inputGateBiasTensor = inputGateBiasTensorCopy;
        inputParams.m_InputGateBias = &inputGateBiasTensor;
    }

    ConstTensor forgetGateBiasTensor;
    if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
    {
        ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
                                             m_QuantizedLstmParameters.m_ForgetGateBias->Map(true));
        forgetGateBiasTensor = forgetGateBiasTensorCopy;
        inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
    }

    ConstTensor cellBiasTensor;
    if (m_QuantizedLstmParameters.m_CellBias != nullptr)
    {
        ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
                                       m_QuantizedLstmParameters.m_CellBias->Map(true));
        cellBiasTensor = cellBiasTensorCopy;
        inputParams.m_CellBias = &cellBiasTensor;
    }

    ConstTensor outputGateBiasTensor;
    if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
    {
        ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
                                       m_QuantizedLstmParameters.m_OutputGateBias->Map(true));
        outputGateBiasTensor = outputGateBiasCopy;
        inputParams.m_OutputGateBias = &outputGateBiasTensor;
    }

    visitor.VisitQuantizedLstmLayer(this, inputParams, GetName());
}

} // namespace armnn