/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
16
17 namespace toco {
18
CreateOptionalArray(Model * model,string * input_array_buffer,const string & array_name)19 void CreateOptionalArray(Model* model, string* input_array_buffer,
20 const string& array_name) {
21 *input_array_buffer = array_name;
22 model->CreateOptionalArray(array_name);
23 }
24
CopyArrayData(const Buffer<ArrayDataType::kFloat> & src_buffer,int src_stride,int src_start_idx1,int src_start_idx2,Buffer<ArrayDataType::kFloat> * dst_buffer,int dst_stride,int dst_start_idx1,int dst_start_idx2,int dim1_copy_size,int dim2_copy_size)25 void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
26 int src_stride, int src_start_idx1, int src_start_idx2,
27 Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
28 int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
29 int dim2_copy_size) {
30 int src_offset = src_start_idx1 * src_stride + src_start_idx2;
31 int dst_offset = dst_start_idx1 * dst_stride + dst_start_idx2;
32 for (int i = 0; i < dim1_copy_size; i++) {
33 for (int j = 0; j < dim2_copy_size; j++) {
34 int idx_src = src_offset + i * src_stride + j;
35 int idx_dst = dst_offset + i * dst_stride + j;
36 dst_buffer->data[idx_dst] = src_buffer.data[idx_src];
37 }
38 }
39 }
40
CreateFloatArrayBuffer(Model * model,string * array_name,const Shape & shape)41 Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
42 string* array_name,
43 const Shape& shape) {
44 *array_name = AvailableArrayName(*model, *array_name);
45 auto& array = model->GetOrCreateArray(*array_name);
46 array.data_type = ArrayDataType::kFloat;
47 array.copy_shape(shape);
48 Buffer<ArrayDataType::kFloat>* buffer =
49 &(array.GetMutableBuffer<ArrayDataType::kFloat>());
50 buffer->data.resize(RequiredBufferSizeForShape(shape));
51 return buffer;
52 }
53
CopySubArrayToArray(Model * model,string * array_name,const string & tensor_name,int dim1_size,int dim2_size,const Array & original_array,int start_idx1,int start_idx2)54 void CopySubArrayToArray(Model* model, string* array_name,
55 const string& tensor_name, int dim1_size,
56 int dim2_size, const Array& original_array,
57 int start_idx1, int start_idx2) {
58 // Determine whether it's bias or not, create shape, buffer.
59 bool is_bias = dim2_size == 1;
60 Shape shape = is_bias ? Shape({dim1_size}) : Shape({dim1_size, dim2_size});
61 Buffer<ArrayDataType::kFloat>* buffer =
62 CreateFloatArrayBuffer(model, array_name, shape);
63 auto& orig_buffer = original_array.GetBuffer<ArrayDataType::kFloat>();
64
65 // Copy data from big tensor.
66 CopyArrayData(orig_buffer, is_bias ? 1 : original_array.shape().dims(1),
67 start_idx1, start_idx2, buffer, dim2_size, 0, 0, dim1_size,
68 dim2_size);
69 }
70
CopyArrayToSubArray(Buffer<ArrayDataType::kFloat> & tensor_buffer,int tensor_stride,const Array & sub_array,int start_idx1,int start_idx2)71 void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
72 int tensor_stride, const Array& sub_array,
73 int start_idx1, int start_idx2) {
74 // Get tensor data.
75 bool is_bias = sub_array.shape().dims().size() == 1;
76 int dim1_copy_size = sub_array.shape().dims()[0];
77 int dim2_copy_size = is_bias ? 1 : sub_array.shape().dims(1);
78 auto& sub_buffer = sub_array.GetBuffer<ArrayDataType::kFloat>();
79
80 // Copy data from sub tensor.
81 CopyArrayData(sub_buffer, dim2_copy_size, 0, 0, &tensor_buffer,
82 is_bias ? 1 : tensor_stride, start_idx1, start_idx2,
83 dim1_copy_size, dim2_copy_size);
84 }
85
GetMatchingRnnArray(Model * model,const string & back_edge_source_array,string * rnn_array)86 bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
87 string* rnn_array) {
88 for (const auto& rnn_state : model->flags.rnn_states()) {
89 if (rnn_state.back_edge_source_array() == back_edge_source_array) {
90 *rnn_array = rnn_state.state_array();
91 return true;
92 }
93 }
94 return false;
95 }
96
97 } // namespace toco
98