1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "TensorFwd.hpp" 8 #include "Exceptions.hpp" 9 10 namespace armnn 11 { 12 13 struct LstmInputParams 14 { LstmInputParamsarmnn::LstmInputParams15 LstmInputParams() 16 : m_InputToInputWeights(nullptr) 17 , m_InputToForgetWeights(nullptr) 18 , m_InputToCellWeights(nullptr) 19 , m_InputToOutputWeights(nullptr) 20 , m_RecurrentToInputWeights(nullptr) 21 , m_RecurrentToForgetWeights(nullptr) 22 , m_RecurrentToCellWeights(nullptr) 23 , m_RecurrentToOutputWeights(nullptr) 24 , m_CellToInputWeights(nullptr) 25 , m_CellToForgetWeights(nullptr) 26 , m_CellToOutputWeights(nullptr) 27 , m_InputGateBias(nullptr) 28 , m_ForgetGateBias(nullptr) 29 , m_CellBias(nullptr) 30 , m_OutputGateBias(nullptr) 31 , m_ProjectionWeights(nullptr) 32 , m_ProjectionBias(nullptr) 33 , m_InputLayerNormWeights(nullptr) 34 , m_ForgetLayerNormWeights(nullptr) 35 , m_CellLayerNormWeights(nullptr) 36 , m_OutputLayerNormWeights(nullptr) 37 { 38 } 39 40 const ConstTensor* m_InputToInputWeights; 41 const ConstTensor* m_InputToForgetWeights; 42 const ConstTensor* m_InputToCellWeights; 43 const ConstTensor* m_InputToOutputWeights; 44 const ConstTensor* m_RecurrentToInputWeights; 45 const ConstTensor* m_RecurrentToForgetWeights; 46 const ConstTensor* m_RecurrentToCellWeights; 47 const ConstTensor* m_RecurrentToOutputWeights; 48 const ConstTensor* m_CellToInputWeights; 49 const ConstTensor* m_CellToForgetWeights; 50 const ConstTensor* m_CellToOutputWeights; 51 const ConstTensor* m_InputGateBias; 52 const ConstTensor* m_ForgetGateBias; 53 const ConstTensor* m_CellBias; 54 const ConstTensor* m_OutputGateBias; 55 const ConstTensor* m_ProjectionWeights; 56 const ConstTensor* m_ProjectionBias; 57 const ConstTensor* m_InputLayerNormWeights; 58 const ConstTensor* m_ForgetLayerNormWeights; 59 const ConstTensor* m_CellLayerNormWeights; 60 const ConstTensor* m_OutputLayerNormWeights; 61 }; 62 63 struct LstmInputParamsInfo 64 { LstmInputParamsInfoarmnn::LstmInputParamsInfo65 LstmInputParamsInfo() 66 : m_InputToInputWeights(nullptr) 67 , m_InputToForgetWeights(nullptr) 68 , m_InputToCellWeights(nullptr) 69 , m_InputToOutputWeights(nullptr) 70 , m_RecurrentToInputWeights(nullptr) 71 , m_RecurrentToForgetWeights(nullptr) 72 , m_RecurrentToCellWeights(nullptr) 73 , m_RecurrentToOutputWeights(nullptr) 74 , m_CellToInputWeights(nullptr) 75 , m_CellToForgetWeights(nullptr) 76 , m_CellToOutputWeights(nullptr) 77 , m_InputGateBias(nullptr) 78 , m_ForgetGateBias(nullptr) 79 , m_CellBias(nullptr) 80 , m_OutputGateBias(nullptr) 81 , m_ProjectionWeights(nullptr) 82 , m_ProjectionBias(nullptr) 83 , m_InputLayerNormWeights(nullptr) 84 , m_ForgetLayerNormWeights(nullptr) 85 , m_CellLayerNormWeights(nullptr) 86 , m_OutputLayerNormWeights(nullptr) 87 { 88 } 89 const TensorInfo* m_InputToInputWeights; 90 const TensorInfo* m_InputToForgetWeights; 91 const TensorInfo* m_InputToCellWeights; 92 const TensorInfo* m_InputToOutputWeights; 93 const TensorInfo* m_RecurrentToInputWeights; 94 const TensorInfo* m_RecurrentToForgetWeights; 95 const TensorInfo* m_RecurrentToCellWeights; 96 const TensorInfo* m_RecurrentToOutputWeights; 97 const TensorInfo* m_CellToInputWeights; 98 const TensorInfo* m_CellToForgetWeights; 99 const TensorInfo* m_CellToOutputWeights; 100 const TensorInfo* m_InputGateBias; 101 const TensorInfo* m_ForgetGateBias; 102 const TensorInfo* m_CellBias; 103 const TensorInfo* m_OutputGateBias; 104 const TensorInfo* m_ProjectionWeights; 105 const TensorInfo* m_ProjectionBias; 106 const TensorInfo* m_InputLayerNormWeights; 107 const TensorInfo* m_ForgetLayerNormWeights; 108 const TensorInfo* m_CellLayerNormWeights; 109 const TensorInfo* m_OutputLayerNormWeights; 110 Derefarmnn::LstmInputParamsInfo111 const TensorInfo& Deref(const TensorInfo* tensorInfo) const 112 { 113 if (tensorInfo != nullptr) 114 { 115 const TensorInfo &temp = *tensorInfo; 116 return temp; 117 } 118 throw InvalidArgumentException("Can't dereference a null pointer"); 119 } 120 GetInputToInputWeightsarmnn::LstmInputParamsInfo121 const TensorInfo& GetInputToInputWeights() const 122 { 123 return Deref(m_InputToInputWeights); 124 } GetInputToForgetWeightsarmnn::LstmInputParamsInfo125 const TensorInfo& GetInputToForgetWeights() const 126 { 127 return Deref(m_InputToForgetWeights); 128 } GetInputToCellWeightsarmnn::LstmInputParamsInfo129 const TensorInfo& GetInputToCellWeights() const 130 { 131 return Deref(m_InputToCellWeights); 132 } GetInputToOutputWeightsarmnn::LstmInputParamsInfo133 const TensorInfo& GetInputToOutputWeights() const 134 { 135 return Deref(m_InputToOutputWeights); 136 } GetRecurrentToInputWeightsarmnn::LstmInputParamsInfo137 const TensorInfo& GetRecurrentToInputWeights() const 138 { 139 return Deref(m_RecurrentToInputWeights); 140 } GetRecurrentToForgetWeightsarmnn::LstmInputParamsInfo141 const TensorInfo& GetRecurrentToForgetWeights() const 142 { 143 return Deref(m_RecurrentToForgetWeights); 144 } GetRecurrentToCellWeightsarmnn::LstmInputParamsInfo145 const TensorInfo& GetRecurrentToCellWeights() const 146 { 147 return Deref(m_RecurrentToCellWeights); 148 } GetRecurrentToOutputWeightsarmnn::LstmInputParamsInfo149 const TensorInfo& GetRecurrentToOutputWeights() const 150 { 151 return Deref(m_RecurrentToOutputWeights); 152 } GetCellToInputWeightsarmnn::LstmInputParamsInfo153 const TensorInfo& GetCellToInputWeights() const 154 { 155 return Deref(m_CellToInputWeights); 156 } GetCellToForgetWeightsarmnn::LstmInputParamsInfo157 const TensorInfo& GetCellToForgetWeights() const 158 { 159 return Deref(m_CellToForgetWeights); 160 } GetCellToOutputWeightsarmnn::LstmInputParamsInfo161 const TensorInfo& GetCellToOutputWeights() const 162 { 163 return Deref(m_CellToOutputWeights); 164 } GetInputGateBiasarmnn::LstmInputParamsInfo165 const TensorInfo& GetInputGateBias() const 166 { 167 return Deref(m_InputGateBias); 168 } GetForgetGateBiasarmnn::LstmInputParamsInfo169 const TensorInfo& GetForgetGateBias() const 170 { 171 return Deref(m_ForgetGateBias); 172 } GetCellBiasarmnn::LstmInputParamsInfo173 const TensorInfo& GetCellBias() const 174 { 175 return Deref(m_CellBias); 176 } GetOutputGateBiasarmnn::LstmInputParamsInfo177 const TensorInfo& GetOutputGateBias() const 178 { 179 return Deref(m_OutputGateBias); 180 } GetProjectionWeightsarmnn::LstmInputParamsInfo181 const TensorInfo& GetProjectionWeights() const 182 { 183 return Deref(m_ProjectionWeights); 184 } GetProjectionBiasarmnn::LstmInputParamsInfo185 const TensorInfo& GetProjectionBias() const 186 { 187 return Deref(m_ProjectionBias); 188 } GetInputLayerNormWeightsarmnn::LstmInputParamsInfo189 const TensorInfo& GetInputLayerNormWeights() const 190 { 191 return Deref(m_InputLayerNormWeights); 192 } GetForgetLayerNormWeightsarmnn::LstmInputParamsInfo193 const TensorInfo& GetForgetLayerNormWeights() const 194 { 195 return Deref(m_ForgetLayerNormWeights); 196 } GetCellLayerNormWeightsarmnn::LstmInputParamsInfo197 const TensorInfo& GetCellLayerNormWeights() const 198 { 199 return Deref(m_CellLayerNormWeights); 200 } GetOutputLayerNormWeightsarmnn::LstmInputParamsInfo201 const TensorInfo& GetOutputLayerNormWeights() const 202 { 203 return Deref(m_OutputLayerNormWeights); 204 } 205 }; 206 207 } // namespace armnn 208 209