• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
2 #define FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
3 
4 #include "HalOperation.h"
5 #include "OperationsUtils.h"
6 
7 #include <vector>
8 
9 namespace android {
10 namespace nn {
11 
12 struct RunTimeOperandInfo;
13 
14 class QuantizedLSTMCell {
15    public:
16     QuantizedLSTMCell(const android::hardware::neuralnetworks::V1_2::Operation& operation,
17                       std::vector<RunTimeOperandInfo>& operands);
18 
19     static bool prepare(const android::hardware::neuralnetworks::V1_2::Operation& operation,
20                         std::vector<RunTimeOperandInfo>& operands, Shape* cellStateShape,
21                         Shape* outputShape);
22     bool eval();
23 
24     // Inputs:
25     static constexpr int kInputTensor = 0;
26     // Input weight tensors of size: {n_cell, n_input}
27     static constexpr int kInputToInputWeightsTensor = 1;
28     static constexpr int kInputToForgetWeightsTensor = 2;
29     static constexpr int kInputToCellWeightsTensor = 3;
30     static constexpr int kInputToOutputWeightsTensor = 4;
31 
32     // Recurrent weight tensors of size {n_cell, n_output}
33     static constexpr int kRecurrentToInputWeightsTensor = 5;
34     static constexpr int kRecurrentToForgetWeightsTensor = 6;
35     static constexpr int kRecurrentToCellWeightsTensor = 7;
36     static constexpr int kRecurrentToOutputWeightsTensor = 8;
37 
38     // Gates bias tensors of size {n_cell}
39     static constexpr int kInputGateBiasTensor = 9;
40     static constexpr int kForgetGateBiasTensor = 10;
41     static constexpr int kCellGateBiasTensor = 11;
42     static constexpr int kOutputGateBiasTensor = 12;
43 
44     static constexpr int kPrevCellStateTensor = 13;
45     static constexpr int kPrevOutputTensor = 14;
46 
47     // Outputs:
48     static constexpr int kCellStateOutTensor = 0;
49     static constexpr int kOutputTensor = 1;
50 
51    private:
52     const RunTimeOperandInfo* input_;
53 
54     const RunTimeOperandInfo* inputToInputWeights_;
55     const RunTimeOperandInfo* inputToForgetWeights_;
56     const RunTimeOperandInfo* inputToCellWeights_;
57     const RunTimeOperandInfo* inputToOutputWeights_;
58 
59     const RunTimeOperandInfo* recurrentToInputWeights_;
60     const RunTimeOperandInfo* recurrentToForgetWeights_;
61     const RunTimeOperandInfo* recurrentToCellWeights_;
62     const RunTimeOperandInfo* recurrentToOutputWeights_;
63 
64     const RunTimeOperandInfo* inputGateBias_;
65     const RunTimeOperandInfo* forgetGateBias_;
66     const RunTimeOperandInfo* cellGateBias_;
67     const RunTimeOperandInfo* outputGateBias_;
68 
69     const RunTimeOperandInfo* prevCellState_;
70     const RunTimeOperandInfo* prevOutput_;
71 
72     RunTimeOperandInfo* cellStateOut_;
73     RunTimeOperandInfo* output_;
74 
75     void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights);
76     void concatenateBiases(uint32_t outputSize, int32_t* bias);
77 };
78 
79 }  // namespace nn
80 }  // namespace android
81 
82 #endif  // FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
83