1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <numeric>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
25 #include "tensorflow/lite/toco/model.h"
26 #include "tensorflow/lite/toco/tooling_util.h"
27
28 namespace toco {
29
30 namespace {
31
UnrollBatchMatMul3D(const string & input_lhs,const string & input_rhs,const BatchMatMulOperator * batch_op,const std::vector<int> batch,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it,std::vector<string> * pack_inputs)32 void UnrollBatchMatMul3D(
33 const string& input_lhs, const string& input_rhs,
34 const BatchMatMulOperator* batch_op, const std::vector<int> batch,
35 Model* model, std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
36 std::vector<string>* pack_inputs) {
37 const std::string batch_name =
38 absl::StrCat(batch_op->outputs[0], "_b", absl::StrJoin(batch, "-"));
39 const auto& input_array_a = model->GetArray(input_lhs);
40 const auto& input_array_b = model->GetArray(input_rhs);
41 const int dims_count = input_array_a.shape().dimensions_count();
42
43 // tf.slice(a, ...).
44 std::vector<int> begin_indices_a = batch;
45 begin_indices_a.resize(dims_count);
46 std::vector<int> slice_size_a = input_array_a.shape().dims();
47 for (int i = 0; i < batch.size(); ++i) {
48 slice_size_a[i] = 1;
49 }
50 auto* slice_a_op = new SliceOperator;
51 slice_a_op->inputs = {
52 input_lhs,
53 CreateInt32Array(model, batch_name + "/slice_a/slice/begin",
54 begin_indices_a),
55 CreateInt32Array(model, batch_name + "/slice_a/slice/size", slice_size_a),
56 };
57 slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")};
58 auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]);
59 slice_a_op_output.data_type = input_array_a.data_type;
60 *tail_it = model->operators.emplace(*tail_it, slice_a_op) + 1;
61
62 // Reshape to remove the first dimension ([1,M,N] -> [M,N]).
63 auto* slice_a_reshape_op = new TensorFlowReshapeOperator;
64 slice_a_reshape_op->inputs = {
65 slice_a_op->outputs[0],
66 CreateInt32Array(model, batch_name + "/slice_a/reshape/shape",
67 {-1, input_array_a.shape().dims(dims_count - 1)})};
68 slice_a_reshape_op->outputs = {
69 AvailableArrayName(*model, batch_name + "/slice_a/reshape")};
70 auto& slice_a_reshape_op_output =
71 model->GetOrCreateArray(slice_a_reshape_op->outputs[0]);
72 slice_a_reshape_op_output.data_type = input_array_a.data_type;
73 *tail_it = model->operators.emplace(*tail_it, slice_a_reshape_op) + 1;
74
75 // tf.slice(b, ...).
76 std::vector<int> begin_indices_b = batch;
77 begin_indices_b.resize(dims_count);
78 std::vector<int> slice_size_b = input_array_b.shape().dims();
79 for (int i = 0; i < batch.size(); ++i) {
80 slice_size_b[i] = 1;
81 }
82 auto* slice_b_op = new SliceOperator;
83 slice_b_op->inputs = {
84 input_rhs,
85 CreateInt32Array(model, batch_name + "/slice_b/slice/begin",
86 begin_indices_b),
87 CreateInt32Array(model, batch_name + "/slice_b/slice/size", slice_size_b),
88 };
89 slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")};
90 auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]);
91 slice_b_op_output.data_type = input_array_b.data_type;
92 *tail_it = model->operators.emplace(*tail_it, slice_b_op) + 1;
93
94 // Reshape to remove the first dimension ([1,M,N] -> [M,N]).
95 auto* slice_b_reshape_op = new TensorFlowReshapeOperator;
96 slice_b_reshape_op->inputs = {
97 slice_b_op->outputs[0],
98 CreateInt32Array(model, batch_name + "/slice_b/reshape/shape",
99 {-1, input_array_b.shape().dims(dims_count - 1)})};
100 slice_b_reshape_op->outputs = {
101 AvailableArrayName(*model, batch_name + "/slice_b/reshape")};
102 auto& slice_b_reshape_op_output =
103 model->GetOrCreateArray(slice_b_reshape_op->outputs[0]);
104 slice_b_reshape_op_output.data_type = input_array_b.data_type;
105 *tail_it = model->operators.emplace(*tail_it, slice_b_reshape_op) + 1;
106
107 // tf.matmul(slice_a, slice_b).
108 auto* matmul_op = new TensorFlowMatMulOperator;
109 matmul_op->inputs = {slice_a_reshape_op->outputs[0],
110 slice_b_reshape_op->outputs[0]};
111 matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
112 auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
113 matmul_op_output.data_type = input_array_a.data_type;
114 *tail_it = model->operators.emplace(*tail_it, matmul_op) + 1;
115
116 // Add to stack.
117 pack_inputs->push_back(matmul_op->outputs[0]);
118 }
119
UnrollBatchMatMulRecursion(const string & input_lhs,const string & input_rhs,const BatchMatMulOperator * batch_op,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it,const std::vector<int> & batch_prefix)120 std::vector<string> UnrollBatchMatMulRecursion(
121 const string& input_lhs, const string& input_rhs,
122 const BatchMatMulOperator* batch_op, Model* model,
123 std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
124 const std::vector<int>& batch_prefix) {
125 const auto& input_array_a = model->GetArray(input_lhs);
126 const auto& dims_vec = input_array_a.shape().dims();
127 const int current_dim_size = dims_vec[batch_prefix.size()];
128 std::vector<string> batch_pack_inputs;
129
130 if (batch_prefix.size() + 3 == dims_vec.size()) {
131 // Base case
132 for (int batch = 0; batch < current_dim_size; ++batch) {
133 std::vector<int> new_batch_prefix = batch_prefix;
134 new_batch_prefix.emplace_back(batch);
135 UnrollBatchMatMul3D(input_lhs, input_rhs, batch_op, new_batch_prefix,
136 model, tail_it, &batch_pack_inputs);
137 }
138 } else {
139 // Recursion
140 for (int batch = 0; batch < current_dim_size; ++batch) {
141 std::vector<int> new_batch_prefix = batch_prefix;
142 new_batch_prefix.emplace_back(batch);
143 std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
144 input_lhs, input_rhs, batch_op, model, tail_it, new_batch_prefix);
145
146 // The pack that will join all the individual matmul results together.
147 auto* pack_op = new PackOperator;
148 std::string batch_name = absl::StrCat(
149 batch_op->outputs[0], "_b", absl::StrJoin(new_batch_prefix, "-"));
150 pack_op->inputs = pack_inputs;
151 pack_op->outputs = {AvailableArrayName(*model, batch_name + "/pack")};
152 auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
153 pack_op_output.data_type = input_array_a.data_type;
154 pack_op->axis = 0;
155 pack_op->values_count = pack_inputs.size();
156 *tail_it = model->operators.emplace(*tail_it, pack_op) + 1;
157
158 batch_pack_inputs.push_back(pack_op->outputs[0]);
159 }
160 }
161 return batch_pack_inputs;
162 }
163
GetTransposePerm(const Array & input_array)164 std::vector<int32> GetTransposePerm(const Array& input_array) {
165 const int32 dims = input_array.shape().dimensions_count();
166 std::vector<int32> perm_array_val(dims);
167 for (int i = 0; i < dims; ++i) {
168 perm_array_val[i] = i;
169 }
170 perm_array_val[dims - 2] = dims - 1;
171 perm_array_val[dims - 1] = dims - 2;
172 return perm_array_val;
173 }
174
GetTransposeShape(const Shape & input_shape,const std::vector<int32> & perm_array_val)175 std::vector<int32> GetTransposeShape(const Shape& input_shape,
176 const std::vector<int32>& perm_array_val) {
177 const int32 dims = input_shape.dimensions_count();
178 std::vector<int32> output_shape(dims);
179 for (int i = 0; i < dims; ++i) {
180 output_shape[i] = input_shape.dims(perm_array_val[i]);
181 }
182 return output_shape;
183 }
184
TransposeInput(const string & input,Model * model)185 TransposeOperator* TransposeInput(const string& input, Model* model) {
186 const auto& input_array = model->GetArray(input);
187 const auto perm_array = GetTransposePerm(input_array);
188 const string perm_array_name = CreateInt32Array(
189 model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
190 auto* transpose_op = new TransposeOperator;
191 transpose_op->inputs = {input, perm_array_name};
192 transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")};
193 auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]);
194 *transpose_array.mutable_shape()->mutable_dims() =
195 GetTransposeShape(input_array.shape(), perm_array);
196 model->GetOrCreateArray(transpose_op->outputs[0]);
197 return transpose_op;
198 }
199
200 } // namespace
201
// Unrolls a BatchMatMul over its batch dimensions.
// We need to slice each batch out of the inputs, matmul the slices
// individually, then stack the results back together at the end.
//
// This transform effectively looks like (for a [B, M, N] x [B, N, P] matmul):
//  result_slices = []
//  for bat in B:
//    slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N])
//    slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, N, P]), [N, P])
//    slice_c = tf.matmul(slice_a, slice_b)
//    result_slices.append(slice_c)
//  result = tf.stack(result_slices)
Run(Model * model,std::size_t op_index,bool * modified)214 ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
215 bool* modified) {
216 *modified = false;
217 auto batch_op_it = model->operators.begin() + op_index;
218 if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
219 return ::tensorflow::Status::OK();
220 }
221 const auto* batch_op =
222 static_cast<const BatchMatMulOperator*>(batch_op_it->get());
223
224 auto& tail_it = batch_op_it;
225
226 string input_lhs = batch_op->inputs[0];
227 string input_rhs = batch_op->inputs[1];
228 const auto& input_lhs_array = model->GetArray(input_lhs);
229 const auto& input_rhs_array = model->GetArray(input_rhs);
230 if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
231 return ::tensorflow::Status::OK();
232
233 // Transpose LHS input if necessary.
234 if (batch_op->adj_x) {
235 TransposeOperator* transpose_op = TransposeInput(input_lhs, model);
236 tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
237 input_lhs = transpose_op->outputs[0];
238 }
239 const auto& input_array_a = model->GetArray(input_lhs);
240
241 // Transpose RHS input if necessary.
242 if (batch_op->adj_y) {
243 TransposeOperator* transpose_op = TransposeInput(input_rhs, model);
244 tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
245 input_rhs = transpose_op->outputs[0];
246 }
247 const auto& input_array_b = model->GetArray(input_rhs);
248
249 const int dims = input_array_a.shape().dimensions_count();
250 for (int i = 0; i < dims - 2; ++i) {
251 CHECK_EQ(input_array_a.shape().dims(i), input_array_b.shape().dims(i))
252 << "input array not consistent at index " << i;
253 }
254 CHECK_EQ(input_array_a.shape().dims(dims - 1),
255 input_array_b.shape().dims(dims - 2))
256 << "Input dimensions must be compatible for multipication. shape a = ["
257 << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
258 << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
259
260 if (dims == 2) {
261 // This is really just a MatMul. This likely means that someone hand-crafted
262 // a graphdef with a BatchMatMul when they really wanted a MatMul.
263 AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
264 LogName(*batch_op));
265 auto* matmul_op = new TensorFlowMatMulOperator;
266 matmul_op->inputs = {input_lhs, input_rhs};
267 matmul_op->outputs = batch_op->outputs;
268 tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
269 CHECK_EQ(tail_it->get(), batch_op);
270 model->operators.erase(tail_it);
271 *modified = true;
272 return ::tensorflow::Status::OK();
273 }
274
275 CHECK_GE(input_array_a.shape().dimensions_count(), 3)
276 << "Input arrays must have rank >= 3";
277
278 const auto& dims_vec = input_array_a.shape().dims();
279 AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
280 std::accumulate(dims_vec.begin(), dims_vec.end() - 2, 1,
281 std::multiplies<int>()));
282
283 std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
284 input_lhs, input_rhs, batch_op, model, &tail_it, {});
285 auto* pack_op = new PackOperator;
286 pack_op->inputs = pack_inputs;
287 pack_op->outputs = {batch_op->outputs[0]};
288 pack_op->axis = 0;
289 pack_op->values_count = pack_inputs.size();
290 model->operators.emplace(tail_it, pack_op);
291
292 // Remove the old batch matmul now that we've unrolled.
293 batch_op_it = model->operators.begin();
294 for (; batch_op_it != model->operators.end(); ++batch_op_it) {
295 if (batch_op_it->get() == batch_op) {
296 break;
297 }
298 }
299 CHECK(batch_op_it != model->operators.end());
300 CHECK(batch_op_it->get() == batch_op);
301 model->operators.erase(batch_op_it);
302 *modified = true;
303 return ::tensorflow::Status::OK();
304 }
305
306 } // namespace toco
307