1 // 2 // Copyright © 2020 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/TypesUtils.hpp> 9 10 #include <backendsCommon/Workload.hpp> 11 #include <backendsCommon/WorkloadData.hpp> 12 13 namespace armnn 14 { 15 16 class RefQLstmWorkload : public BaseWorkload<QLstmQueueDescriptor> 17 { 18 public: 19 explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info); 20 21 virtual void Execute() const override; 22 23 private: 24 std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeightsTensor; 25 std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeightsTensor; 26 std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeightsTensor; 27 std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeightsTensor; 28 29 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeightsTensor; 30 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeightsTensor; 31 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeightsTensor; 32 std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeightsTensor; 33 34 std::unique_ptr<ScopedCpuTensorHandle> m_CellToInputWeightsTensor; 35 std::unique_ptr<ScopedCpuTensorHandle> m_CellToForgetWeightsTensor; 36 std::unique_ptr<ScopedCpuTensorHandle> m_CellToOutputWeightsTensor; 37 38 std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBiasTensor; 39 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBiasTensor; 40 std::unique_ptr<ScopedCpuTensorHandle> m_CellBiasTensor; 41 std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBiasTensor; 42 43 std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionWeightsTensor; 44 std::unique_ptr<ScopedCpuTensorHandle> m_ProjectionBiasTensor; 45 46 std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeightsTensor; 47 std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeightsTensor; 48 std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeightsTensor; 49 std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeightsTensor; 50 51 float m_LayerNormEpsilon = static_cast<float>(1e-8); 52 }; 53 54 } //namespace armnn 55