Lines Matching refs:context
36 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) { in hasTensor() argument
37 return context->getInputBuffer(tensor) != nullptr; in hasTensor()
42 bool prepare(IOperationExecutionContext* context) { in prepare() argument
59 NN_RET_CHECK(!context->isOmittedInput(tensor)) in prepare()
63 const Shape inputShape = context->getInputShape(kInputTensor); in prepare()
70 const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor); in prepare()
75 const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor); in prepare()
80 if (hasTensor(context, kInputToInputWeightsTensor)) { in prepare()
81 const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor); in prepare()
87 const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor); in prepare()
91 const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor); in prepare()
96 if (hasTensor(context, kRecurrentToInputWeightsTensor)) { in prepare()
97 const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor); in prepare()
103 const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor); in prepare()
107 const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor); in prepare()
114 const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) && in prepare()
115 hasTensor(context, kRecurrentToInputWeightsTensor)) || in prepare()
116 (!hasTensor(context, kInputToInputWeightsTensor) && in prepare()
117 !hasTensor(context, kRecurrentToInputWeightsTensor)); in prepare()
120 if (hasTensor(context, kCellToInputWeightsTensor)) { in prepare()
121 const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); in prepare()
126 if (hasTensor(context, kCellToForgetWeightsTensor)) { in prepare()
127 const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); in prepare()
132 if (hasTensor(context, kCellToOutputWeightsTensor)) { in prepare()
133 const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); in prepare()
139 const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor); in prepare()
141 ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) && in prepare()
142 hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
143 hasTensor(context, kCellToOutputWeightsTensor)) || in prepare()
144 (!hasTensor(context, kCellToInputWeightsTensor) && in prepare()
145 !hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
146 !hasTensor(context, kCellToOutputWeightsTensor)); in prepare()
150 NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor)); in prepare()
151 const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor); in prepare()
155 NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor)) in prepare()
159 const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor); in prepare()
162 const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor); in prepare()
165 const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor); in prepare()
169 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare()
170 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare()
176 if (hasTensor(context, kProjectionBiasTensor)) { in prepare()
177 const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor); in prepare()
182 const Shape outputStateShape = context->getInputShape(kPrevOutputTensor); in prepare()
186 const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor); in prepare()
191 if (hasTensor(context, kInputLayerNormTensor)) { in prepare()
192 const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor); in prepare()
197 if (hasTensor(context, kForgetLayerNormTensor)) { in prepare()
198 const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor); in prepare()
203 if (hasTensor(context, kCellLayerNormTensor)) { in prepare()
204 const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor); in prepare()
209 if (hasTensor(context, kOutputLayerNormTensor)) { in prepare()
210 const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor); in prepare()
216 NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor)) in prepare()
218 const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) && in prepare()
219 hasTensor(context, kCellLayerNormTensor) && in prepare()
220 hasTensor(context, kOutputLayerNormTensor)) || in prepare()
221 (!hasTensor(context, kForgetLayerNormTensor) && in prepare()
222 !hasTensor(context, kCellLayerNormTensor) && in prepare()
223 !hasTensor(context, kOutputLayerNormTensor)); in prepare()
226 const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) && in prepare()
227 hasTensor(context, kForgetLayerNormTensor) && in prepare()
228 hasTensor(context, kCellLayerNormTensor) && in prepare()
229 hasTensor(context, kOutputLayerNormTensor)) || in prepare()
230 (!hasTensor(context, kInputLayerNormTensor) && in prepare()
231 !hasTensor(context, kForgetLayerNormTensor) && in prepare()
232 !hasTensor(context, kCellLayerNormTensor) && in prepare()
233 !hasTensor(context, kOutputLayerNormTensor)); in prepare()
237 const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor); in prepare()
238 Shape outputShape = context->getOutputShape(kOutputTensor); in prepare()
241 const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor); in prepare()
242 Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor); in prepare()
245 return context->setOutputShape(kOutputStateOutTensor, outputShape) && in prepare()
246 context->setOutputShape(kCellStateOutTensor, cellStateOutShape) && in prepare()
247 context->setOutputShape(kOutputTensor, outputShape); in prepare()
251 bool execute(IOperationExecutionContext* context) { in execute() argument
253 const Shape inputShape = context->getInputShape(kInputTensor); in execute()
254 const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor); in execute()
256 context->getInputShape(kRecurrentToInputWeightsTensor); in execute()
257 const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); in execute()
258 const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor); in execute()
259 const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor); in execute()
261 context->getInputShape(kRecurrentToForgetWeightsTensor); in execute()
262 const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); in execute()
263 const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor); in execute()
264 const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor); in execute()
265 const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor); in execute()
266 const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor); in execute()
267 const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor); in execute()
269 context->getInputShape(kRecurrentToOutputWeightsTensor); in execute()
270 const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); in execute()
271 const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor); in execute()
272 const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor); in execute()
273 const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor); in execute()
274 const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor); in execute()
281 const float cellClip = context->getInputValue<float>(kCellClip); in execute()
282 const float projectionClip = context->getInputValue<float>(kProjectionClip); in execute()
283 const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale); in execute()
284 const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale); in execute()
285 const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale); in execute()
286 const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale); in execute()
287 const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint); in execute()
288 const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale); in execute()
291 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor)); in execute()
294 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor)); in execute()
297 context->getInputBuffer(kRecurrentToInputWeightsTensor)); in execute()
299 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor)); in execute()
301 reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor)); in execute()
303 reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor)); in execute()
306 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor)); in execute()
308 context->getInputBuffer(kRecurrentToForgetWeightsTensor)); in execute()
310 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor)); in execute()
312 reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor)); in execute()
314 reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor)); in execute()
317 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor)); in execute()
319 reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor)); in execute()
321 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor)); in execute()
323 reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor)); in execute()
326 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor)); in execute()
328 context->getInputBuffer(kRecurrentToOutputWeightsTensor)); in execute()
330 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor)); in execute()
332 reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor)); in execute()
334 reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor)); in execute()
337 reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor)); in execute()
339 reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor)); in execute()
342 reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor)); in execute()
344 reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor)); in execute()
347 reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor)); in execute()
349 reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor)); in execute()
350 int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor)); in execute()