• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_
16 #define TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_
17 
18 #include <iostream>
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/lite/toco/model.h"
23 #include "tensorflow/lite/toco/tooling_util.h"
24 
25 namespace toco {
26 
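// For consistency with the input/output index constants defined in the
// extended LstmCell kernel (tensorflow/lite/kernels/lstm.cc), these constants
// use the kernel's naming style (a lowercase `k` prefix followed by CamelCase,
// e.g. kInputTensor) rather than ALL_CAPS.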
30 
// Input indices of the extended LstmCell op.
//
// Each value is the position of an input in the op's input list, kept
// consistent with the LSTM kernel (tensorflow/lite/kernels/lstm.cc); entries
// must therefore not be reordered or renumbered independently of the kernel.
enum ExtendedLstmCellInputs {
  kInputTensor = 0,
  kInputToInputWeightsTensor = 1,  // Optional
  kInputToForgetWeightsTensor = 2,
  kInputToCellWeightsTensor = 3,
  kInputToOutputWeightsTensor = 4,
  kRecurrentToInputWeightsTensor = 5,  // Optional
  kRecurrentToForgetWeightsTensor = 6,
  kRecurrentToCellWeightsTensor = 7,
  kRecurrentToOutputWeightsTensor = 8,
  kCellToInputWeightsTensor = 9,    // Optional
  kCellToForgetWeightsTensor = 10,  // Optional
  kCellToOutputWeightsTensor = 11,  // Optional
  kInputGateBiasTensor = 12,        // Optional
  kForgetGateBiasTensor = 13,
  kCellGateBiasTensor = 14,
  kOutputGateBiasTensor = 15,
  kProjectionWeightsTensor = 16,  // Optional
  kProjectionBiasTensor = 17,     // Optional
  // The op can handle 18 inputs or 20 inputs: 18 inputs covers indices 0-17
  // above, while 20 inputs additionally includes the two state tensors below.
  kInputActivationStateTensor = 18,
  kInputCellStateTensor = 19,
  // Total number of inputs in the 20-input form; not itself an input index.
  kExtendedLstmInputCount = 20,
};

// The count above is maintained by hand; fail the build if it ever stops
// being exactly one past the last input index.
static_assert(kExtendedLstmInputCount == kInputCellStateTensor + 1,
              "kExtendedLstmInputCount must be one past the last input index");
55 
// Output indices of the extended LstmCell op.
enum ExtendedLstmCellOutputs {
  // TODO(ycling): Make the 2 output state tensors optional.
  kOutputStateTensor = 0,
  kCellStateTensor = 1,
  kOutputTensor = 2,
  // Total number of outputs; not itself an output index.
  kExtendedLstmOutputCount = 3
};

// The count above is maintained by hand; fail the build if it ever stops
// being exactly one past the last output index.
static_assert(kExtendedLstmOutputCount == kOutputTensor + 1,
              "kExtendedLstmOutputCount must be one past the last output index");
63 
64 // Create optional array used for optional tensor in ExtendedLstmCell inputs.
65 void CreateOptionalArray(Model* model, std::string* input_array_buffer,
66                          const std::string& array_name);
67 
68 // Create float array and get its buffer.
69 Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
70                                                       std::string* array_name,
71                                                       const Shape& shape);
72 
73 // Copy data from one array to the other one (supports 1D and 2D array),
74 // for 1D array, the 2nd dim's size is 1.
75 // Arguments:
76 //   src_buffer: the source buffer
77 //   src_stride: the stride of source buffer, i.e., 2nd dim's size
78 //   src_start_idx1: the 1st dim index of start point in src matrix
79 //   src_start_idx2: the 2nd dim index of start point in src matrix
80 //   dst_buffer: the destination buffer
81 //   dst_stride: the stride of destination buffer, i.e., 2nd dim's size
82 //   dst_start_idx1: the 1st dim index of start point in dst matrix
83 //   dst_start_idx2: the 2nd dim index of start point in dst matrix
84 //   dim1_copy_size: 1st dim size of copy data
85 //   dim2_copy_size: 2nd dim size of copy data
86 void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
87                    int src_stride, int src_start_idx1, int src_start_idx2,
88                    Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
89                    int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
90                    int dim2_copy_size);
91 
92 // Copy a subset of array data and create a smaller array,
93 // mostly used for spliting weights and bias for Lstm cell.
94 void CopySubArrayToArray(Model* model, std::string* array_name,
95                          const std::string& tensor_name, int dim1_size,
96                          int dim2_size, const Array& original_array,
97                          int start_idx1, int start_idx2);
98 
99 // Copy array data to a large array's submatrix,
100 // mostly used for merging weights and bias for Lstm cell.
101 void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
102                          int tensor_stride, const Array& sub_array,
103                          int start_idx1, int start_idx2);
104 
105 // Get mating rnn array inputs using rnn_states flag.
106 bool GetMatchingRnnArray(Model* model,
107                          const std::string& back_edge_source_array,
108                          std::string* rnn_array);
109 
110 }  // namespace toco
111 
112 #endif  // TENSORFLOW_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_
113