1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "QuantizedLstmLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
12 #include <backendsCommon/WorkloadFactory.hpp>
13
14 namespace armnn
15 {
16
QuantizedLstmLayer(const char * name)17 QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
18 : Layer(3, 2, LayerType::QuantizedLstm, name)
19 {
20 }
21
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24 QuantizedLstmQueueDescriptor descriptor;
25
26 // QuantizedLstmLayer parameters - there are no optional params
27 descriptor.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights.get();
28 descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
29 descriptor.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights.get();
30 descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();
31
32 descriptor.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
33 descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
34 descriptor.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
35 descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();
36
37 descriptor.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias.get();
38 descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
39 descriptor.m_CellBias = m_QuantizedLstmParameters.m_CellBias.get();
40 descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();
41
42 SetAdditionalInfo(descriptor);
43
44 return factory.CreateQuantizedLstm(descriptor, PrepInfoAndDesc(descriptor));
45 }
46
Clone(Graph & graph) const47 QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
48 {
49 auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());
50
51 layer->m_QuantizedLstmParameters.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights ?
52 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToInputWeights) : nullptr;
53 layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights ?
54 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToForgetWeights) : nullptr;
55 layer->m_QuantizedLstmParameters.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights ?
56 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToCellWeights) : nullptr;
57 layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights ?
58 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToOutputWeights) : nullptr;
59
60 layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights ?
61 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToInputWeights) : nullptr;
62 layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights
63 ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToForgetWeights) : nullptr;
64 layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights ?
65 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToCellWeights) : nullptr;
66 layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights
67 ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToOutputWeights) : nullptr;
68
69 layer->m_QuantizedLstmParameters.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias ?
70 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputGateBias) : nullptr;
71 layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias ?
72 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_ForgetGateBias) : nullptr;
73 layer->m_QuantizedLstmParameters.m_CellBias = m_QuantizedLstmParameters.m_CellBias ?
74 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_CellBias) : nullptr;
75 layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias ?
76 std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_OutputGateBias) : nullptr;
77
78 return std::move(layer);
79 }
80
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const81 std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
82 {
83 ARMNN_ASSERT(inputShapes.size() == 3);
84
85 // Get input values for validation
86 unsigned int numBatches = inputShapes[0][0];
87 unsigned int outputSize = inputShapes[1][1];
88
89 std::vector<TensorShape> outShapes;
90 outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
91 outShapes.push_back(TensorShape({numBatches, outputSize})); // output
92
93 return outShapes;
94 }
95
/// Validates that all three inputs are connected, that every mandatory constant
/// tensor is populated, and that the output slot shapes match the inferred shapes.
void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
{
    // All 3 input slots (input, previousCellStateIn, previousOutputIn) must be connected.
    VerifyLayerConnections(3, CHECK_LOCATION());

    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    auto inferredShapes = InferOutputShapes(
    {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape() // previousOutputIn
    });

    // Expect exactly two inferred shapes: cellStateOut and output.
    ARMNN_ASSERT(inferredShapes.size() == 2);

    // Check weights and bias for nullptr - none of these parameters are optional for QuantizedLstm.
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");

    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");

    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");

    // Check output TensorShape(s) match inferred shape: slot 0 is cellStateOut, slot 1 is output.
    ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QuantizedLstmLayer");

    ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
                         inferredShapes[1],
                         m_ShapeInferenceMethod,
                         "QuantizedLstmLayer",
                         1);
}
150
GetConstantTensorsByRef()151 Layer::ConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef()
152 {
153 return
154 {
155 m_QuantizedLstmParameters.m_InputToInputWeights,
156 m_QuantizedLstmParameters.m_InputToForgetWeights,
157 m_QuantizedLstmParameters.m_InputToCellWeights,
158 m_QuantizedLstmParameters.m_InputToOutputWeights,
159
160 m_QuantizedLstmParameters.m_RecurrentToInputWeights,
161 m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
162 m_QuantizedLstmParameters.m_RecurrentToCellWeights,
163 m_QuantizedLstmParameters.m_RecurrentToOutputWeights,
164
165 m_QuantizedLstmParameters.m_InputGateBias,
166 m_QuantizedLstmParameters.m_ForgetGateBias,
167 m_QuantizedLstmParameters.m_CellBias,
168 m_QuantizedLstmParameters.m_OutputGateBias
169 };
170 }
171
Accept(ILayerVisitor & visitor) const172 void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
173 {
174 QuantizedLstmInputParams inputParams;
175
176 // InputToX weight tensors
177 ConstTensor inputToInputWeightsTensor;
178 if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
179 {
180 ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
181 m_QuantizedLstmParameters.m_InputToInputWeights->Map(true));
182 inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
183 inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
184 }
185
186 ConstTensor inputToForgetWeightsTensor;
187 if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
188 {
189 ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
190 m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true));
191 inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
192 inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
193 }
194
195 ConstTensor inputToCellWeightsTensor;
196 if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
197 {
198 ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
199 m_QuantizedLstmParameters.m_InputToCellWeights->Map(true));
200 inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
201 inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
202 }
203
204 ConstTensor inputToOutputWeightsTensor;
205 if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
206 {
207 ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
208 m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true));
209 inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
210 inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
211 }
212
213 // RecurrentToX weight tensors
214 ConstTensor recurrentToInputWeightsTensor;
215 if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
216 {
217 ConstTensor recurrentToInputWeightsTensorCopy(
218 m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
219 m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true));
220 recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
221 inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
222 }
223
224 ConstTensor recurrentToForgetWeightsTensor;
225 if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
226 {
227 ConstTensor recurrentToForgetWeightsTensorCopy(
228 m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
229 m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true));
230 recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
231 inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
232 }
233
234 ConstTensor recurrentToCellWeightsTensor;
235 if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
236 {
237 ConstTensor recurrentToCellWeightsTensorCopy(
238 m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
239 m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true));
240 recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
241 inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
242 }
243
244 ConstTensor recurrentToOutputWeightsTensor;
245 if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
246 {
247 ConstTensor recurrentToOutputWeightsTensorCopy(
248 m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
249 m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true));
250 recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
251 inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
252 }
253
254 // Bias tensors
255 ConstTensor inputGateBiasTensor;
256 if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
257 {
258 ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
259 m_QuantizedLstmParameters.m_InputGateBias->Map(true));
260 inputGateBiasTensor = inputGateBiasTensorCopy;
261 inputParams.m_InputGateBias = &inputGateBiasTensor;
262 }
263
264 ConstTensor forgetGateBiasTensor;
265 if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
266 {
267 ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
268 m_QuantizedLstmParameters.m_ForgetGateBias->Map(true));
269 forgetGateBiasTensor = forgetGateBiasTensorCopy;
270 inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
271 }
272
273 ConstTensor cellBiasTensor;
274 if (m_QuantizedLstmParameters.m_CellBias != nullptr)
275 {
276 ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
277 m_QuantizedLstmParameters.m_CellBias->Map(true));
278 cellBiasTensor = cellBiasTensorCopy;
279 inputParams.m_CellBias = &cellBiasTensor;
280 }
281
282 ConstTensor outputGateBiasTensor;
283 if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
284 {
285 ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
286 m_QuantizedLstmParameters.m_OutputGateBias->Map(true));
287 outputGateBiasTensor = outputGateBiasCopy;
288 inputParams.m_OutputGateBias = &outputGateBiasTensor;
289 }
290
291 visitor.VisitQuantizedLstmLayer(this, inputParams, GetName());
292 }
293
294 } // namespace armnn
295