/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
#define FRAMEWORKS_ML_NN_LSTMCELL_H

#include "ActivationFunctor.h"
#include "HalOperation.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

namespace android {
namespace nn {

struct LSTMParams {
    TfLiteFusedActivation activation;
    float cell_clip;
    float proj_clip;
    bool use_cifg;
    bool use_peephole;
    bool use_layer_norm;
    bool use_projection_weight;
    bool use_projection_bias;
    bool merge_outputs;
    bool time_major;
};

struct RunTimeOperandInfo;
struct Shape;

class LSTMCell {
   public:
    LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands);

    bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
                 Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                 Shape* outputShape);
    bool Eval();

    // Input Tensors of size {n_batch, n_input}
    static constexpr int kInputTensor = 0;

    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kInputToInputWeightsTensor = 1;  // Optional
    static constexpr int kInputToForgetWeightsTensor = 2;
    static constexpr int kInputToCellWeightsTensor = 3;
    static constexpr int kInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
    static constexpr int kRecurrentToForgetWeightsTensor = 6;
    static constexpr int kRecurrentToCellWeightsTensor = 7;
    static constexpr int kRecurrentToOutputWeightsTensor = 8;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kCellToInputWeightsTensor = 9;    // Optional
    static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
    static constexpr int kCellToOutputWeightsTensor = 11;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kInputGateBiasTensor = 12;  // Optional
    static constexpr int kForgetGateBiasTensor = 13;
    static constexpr int kCellGateBiasTensor = 14;
    static constexpr int kOutputGateBiasTensor = 15;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kProjectionWeightsTensor = 16;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kProjectionBiasTensor = 17;  // Optional

    static constexpr int kOutputStateInTensor = 18;
    static constexpr int kCellStateInTensor = 19;

    static constexpr int kActivationParam = 20;
    static constexpr int kCellClipParam = 21;
    static constexpr int kProjClipParam = 22;

    // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kInputLayerNormWeightsTensor = 23;
    static constexpr int kForgetLayerNormWeightsTensor = 24;
    static constexpr int kCellLayerNormWeightsTensor = 25;
    static constexpr int kOutputLayerNormWeightsTensor = 26;

    // Output tensors.
    static constexpr int kScratchBufferTensor = 0;
    static constexpr int kOutputStateOutTensor = 1;
    static constexpr int kCellStateOutTensor = 2;
    static constexpr int kOutputTensor = 3;

    static constexpr float kLayerNormEpsilon = 1e-8;

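    // A sketch of the conventional LSTM recurrence that the evaluation routines
    // below are expected to compute (the authoritative version is the
    // corresponding .cpp implementation). "(.)" denotes elementwise
    // multiplication and g is the activation selected by LSTMParams::activation
    // (typically tanh):
    //   i_t = sigmoid(W_{xi} x_t + W_{hi} h_{t-1} + w_{ci} (.) c_{t-1} + b_i)
    //   f_t = sigmoid(W_{xf} x_t + W_{hf} h_{t-1} + w_{cf} (.) c_{t-1} + b_f)
    //   c_t = clip(f_t (.) c_{t-1} + i_t (.) g(W_{xc} x_t + W_{hc} h_{t-1} + b_c), cell_clip)
    //   o_t = sigmoid(W_{xo} x_t + W_{ho} h_{t-1} + w_{co} (.) c_t + b_o)
    //   h_t = clip(W_{proj} (o_t (.) g(c_t)) + b_proj, proj_clip)
    //         (h_t = o_t (.) g(c_t) when no projection is used)
    // With CIFG (no input-to-input / recurrent-to-input weights) the input gate
    // is coupled to the forget gate as i_t = 1 - f_t. The peephole terms w_{c*}
    // and the projection apply only when the corresponding optional tensors are
    // present. With layer normalization, each gate pre-activation is normalized
    // (using kLayerNormEpsilon) and scaled by the matching *LayerNormWeights
    // tensor before the bias is added.
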
    static bool LSTMEvalFloat32(
            const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
            const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
            const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const float* recurrent_to_input_weights_buffer,
            const float* recurrent_to_forget_weights_buffer,
            const float* recurrent_to_cell_weights_buffer,
            const float* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
            const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
            const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
            const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
            const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
            const float* cell_bias_buffer, const float* output_gate_bias_buffer,
            const float* projection_weights_buffer, const float* projection_bias_buffer,
            const float* output_state_in_buffer, const float* cell_state_in_buffer,
            const float* input_layer_norm_weights_buffer,
            const float* forget_layer_norm_weights_buffer,
            const float* cell_layer_norm_weights_buffer,
            const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
            float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
            bool timeMajor = true, bool forwardSequence = true);

    static bool LSTMEvalFloat16(
            const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
            const _Float16* input_to_input_weights_buffer,
            const _Float16* input_to_forget_weights_buffer,
            const _Float16* input_to_cell_weights_buffer,
            const _Float16* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const _Float16* recurrent_to_input_weights_buffer,
            const _Float16* recurrent_to_forget_weights_buffer,
            const _Float16* recurrent_to_cell_weights_buffer,
            const _Float16* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const _Float16* cell_to_input_weights_buffer,
            const _Float16* cell_to_forget_weights_buffer,
            const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
            const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
            const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
            const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
            const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
            const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
            const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
            const _Float16* input_layer_norm_weights_buffer,
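    // Computes a single time step of the recurrence on float32 buffers; the
    // sequence-level LSTMEval* routines above are expected to invoke this once
    // per time step.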
            const _Float16* forget_layer_norm_weights_buffer,
            const _Float16* cell_layer_norm_weights_buffer,
            const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
            _Float16* cell_state_out_buffer, _Float16* output_buffer,
            _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);

    static bool LSTMStep(
            const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
            const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
            const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const float* recurrent_to_input_weights_buffer,
            const float* recurrent_to_forget_weights_buffer,
            const float* recurrent_to_cell_weights_buffer,
            const float* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
            const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
            const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
            const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
            const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
            const float* cell_bias_buffer, const float* output_gate_bias_buffer,
            const float* projection_weights_buffer, const float* projection_bias_buffer,
            const float* output_state_in_buffer, const float* cell_state_in_buffer,
            const float* input_layer_norm_weights_buffer,
            const float* forget_layer_norm_weights_buffer,
            const float* cell_layer_norm_weights_buffer,
            const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
            float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);

    static bool CheckInputTensorDimensions(
            const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
            const RunTimeOperandInfo* input_to_forget_weights,
            const RunTimeOperandInfo* input_to_cell_weights,
            const RunTimeOperandInfo* input_to_output_weights,
            const RunTimeOperandInfo* recurrent_to_input_weights,
            const RunTimeOperandInfo* recurrent_to_forget_weights,
            const RunTimeOperandInfo* recurrent_to_cell_weights,
            const RunTimeOperandInfo* recurrent_to_output_weights,
            const RunTimeOperandInfo* cell_to_input_weights,
            const RunTimeOperandInfo* cell_to_forget_weights,
            const RunTimeOperandInfo* cell_to_output_weights,
            const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
            const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
            const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
            const RunTimeOperandInfo* input_layer_norm_weights,
            const RunTimeOperandInfo* forget_layer_norm_weights,
            const RunTimeOperandInfo* cell_layer_norm_weights,
            const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
            uint32_t n_output, uint32_t n_cell, LSTMParams* params);

   private:
    LSTMParams params_;
    const RunTimeOperandInfo* input_;

    const RunTimeOperandInfo* input_to_input_weights_;
    const RunTimeOperandInfo* input_to_forget_weights_;
    const RunTimeOperandInfo* input_to_cell_weights_;
    const RunTimeOperandInfo* input_to_output_weights_;

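// A minimal usage sketch (hypothetical caller code, not part of this header):
// the CPU executor is expected to construct the cell from the Operation and its
// operands, derive the output shapes via Prepare(), allocate the four output
// operands, and then run Eval().
//
//   LSTMCell lstm(operation, operands);
//   Shape scratchShape, outputStateShape, cellStateShape, outputShape;
//   if (lstm.Prepare(operation, operands, &scratchShape, &outputStateShape,
//                    &cellStateShape, &outputShape)) {
//       // ... allocate kScratchBufferTensor, kOutputStateOutTensor,
//       //     kCellStateOutTensor and kOutputTensor from the computed shapes ...
//       lstm.Eval();
//   }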
    const RunTimeOperandInfo* recurrent_to_input_weights_;
    const RunTimeOperandInfo* recurrent_to_forget_weights_;
    const RunTimeOperandInfo* recurrent_to_cell_weights_;
    const RunTimeOperandInfo* recurrent_to_output_weights_;

    const RunTimeOperandInfo* cell_to_input_weights_;
    const RunTimeOperandInfo* cell_to_forget_weights_;
    const RunTimeOperandInfo* cell_to_output_weights_;

    const RunTimeOperandInfo* input_gate_bias_;
    const RunTimeOperandInfo* forget_gate_bias_;
    const RunTimeOperandInfo* cell_bias_;
    const RunTimeOperandInfo* output_gate_bias_;

    const RunTimeOperandInfo* projection_weights_;
    const RunTimeOperandInfo* projection_bias_;

    const RunTimeOperandInfo* output_state_in_;
    const RunTimeOperandInfo* cell_state_in_;

    const RunTimeOperandInfo* input_layer_norm_weights_;
    const RunTimeOperandInfo* forget_layer_norm_weights_;
    const RunTimeOperandInfo* cell_layer_norm_weights_;
    const RunTimeOperandInfo* output_layer_norm_weights_;

    RunTimeOperandInfo* output_state_out_;
    RunTimeOperandInfo* cell_state_out_;
    RunTimeOperandInfo* output_;

    RunTimeOperandInfo* scratch_buffer_;
};

}  // namespace nn
}  // namespace android

#endif  // FRAMEWORKS_ML_NN_LSTMCELL_H