1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
27
28 namespace toco {
29
30 namespace {
31
VectorGreaterThan(const std::vector<int> & a,const std::vector<int> & b)32 std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
33 const std::vector<int>& b) {
34 DCHECK_EQ(a.size(), b.size());
35 const int size = a.size();
36 std::vector<bool> result(size);
37 for (int i = 0; i < size; i++) {
38 result[i] = a[i] > b[i];
39 }
40 return result;
41 }
42
PairwiseVectorSelect(const std::vector<bool> & selector,const std::vector<int> & input_a,const std::vector<int> & input_b,std::vector<int> * output_a,std::vector<int> * output_b)43 void PairwiseVectorSelect(const std::vector<bool>& selector,
44 const std::vector<int>& input_a,
45 const std::vector<int>& input_b,
46 std::vector<int>* output_a,
47 std::vector<int>* output_b) {
48 DCHECK_EQ(input_a.size(), input_b.size());
49 DCHECK_EQ(output_a->size(), output_b->size());
50 DCHECK_EQ(input_a.size(), output_a->size());
51 DCHECK_EQ(selector.size(), input_a.size());
52 const int size = input_a.size();
53 for (int i = 0; i < size; i++) {
54 if (selector[i]) {
55 (*output_a)[i] = input_a[i];
56 (*output_b)[i] = input_b[i];
57 } else {
58 (*output_a)[i] = input_b[i];
59 (*output_b)[i] = input_a[i];
60 }
61 }
62 }
63
64 template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
EvaluateBinaryOperatorOnConstantInputs(Model * model,const Operator * binary_op)65 void EvaluateBinaryOperatorOnConstantInputs(Model* model,
66 const Operator* binary_op) {
67 CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
68 CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
69 CHECK(binary_op->fused_activation_function ==
70 FusedActivationFunctionType::kNone);
71 const auto& input0_array = model->GetArray(binary_op->inputs[0]);
72 const auto& input1_array = model->GetArray(binary_op->inputs[1]);
73 const auto& output_name = binary_op->outputs[0];
74 auto& output_array = model->GetArray(output_name);
75 CHECK(input0_array.data_type == InputsDataType);
76 CHECK(input1_array.data_type == InputsDataType);
77 CHECK(output_array.data_type == OutputDataType);
78
79 // We have already tested above for existence of input buffers
80 // (synonymous to being a constant param).
81 CHECK(input0_array.buffer);
82 CHECK(input1_array.buffer);
83 // On the other hand, the output should not already have a buffer.
84 CHECK(!output_array.buffer);
85
86 const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
87 const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
88 // Create the buffer on the output array, effectively turning it into
89 // a constant parameter
90
91 const Shape& output_shape = output_array.shape();
92 auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
93 const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
94 output_data.resize(output_buffer_size);
95 const int dims_count = output_shape.dimensions_count();
96
97 // It will be convenient here to have copies of the operands shapes
98 // extended to match the number of dimensions of the output shape.
99 Shape input0_shape = input0_array.shape();
100 Shape input1_shape = input1_array.shape();
101 ExtendShape(&input0_shape, dims_count);
102 ExtendShape(&input1_shape, dims_count);
103 // Now we may still have operands of different sizes, which would indicate
104 // that we have to "broadcast" the smaller dimension. We do this using a
105 // a vector of Booleans indicating which input is the larger in each
106 // dimension.
107 CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
108 CHECK_EQ(input0_shape.dimensions_count(), dims_count);
109 const std::vector<bool> input0_larger =
110 VectorGreaterThan(input0_shape.dims(), input1_shape.dims());
111
112 std::vector<int> big_sizes(dims_count);
113 std::vector<int> small_sizes(dims_count);
114 PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
115 &big_sizes, &small_sizes);
116
117 // The output should already be correctly sized to match the big dimensions.
118 for (int i = 0; i < dims_count; i++) {
119 CHECK_EQ(output_shape.dims(i), big_sizes[i]);
120 }
121
122 std::vector<int> input0_indices(dims_count);
123 std::vector<int> input1_indices(dims_count);
124 std::vector<int> modulo_indices(dims_count);
125
126 for (int k = 0; k < output_buffer_size; k++) {
127 const std::vector<int> output_indices = ReverseOffset(output_shape, k);
128 for (int i = 0; i < dims_count; i++) {
129 modulo_indices[i] = output_indices[i] % small_sizes[i];
130 }
131 PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
132 &input0_indices, &input1_indices);
133 const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
134 const auto val1 = input1_data[Offset(input1_shape, input1_indices)];
135
136 DataType<OutputDataType> outval;
137 if (binary_op->type == OperatorType::kAdd) {
138 outval = val0 + val1;
139 } else if (binary_op->type == OperatorType::kMul) {
140 outval = val0 * val1;
141 } else if (binary_op->type == OperatorType::kSub) {
142 outval = val0 - val1;
143 } else if (binary_op->type == OperatorType::kDiv) {
144 outval = val0 / val1;
145 } else if (binary_op->type == OperatorType::kFloorDiv) {
146 outval = std::floor(val0 / val1);
147 } else if (binary_op->type == OperatorType::kFloorMod) {
148 outval = val0 - (std::floor(val0 / val1) * val1);
149 } else if (binary_op->type == OperatorType::kMinimum) {
150 outval = std::min(val0, val1);
151 } else if (binary_op->type == OperatorType::kMaximum) {
152 outval = std::max(val0, val1);
153 } else if (binary_op->type == OperatorType::kLess) {
154 outval = val0 < val1;
155 } else if (binary_op->type == OperatorType::kLessEqual) {
156 outval = val0 <= val1;
157 } else if (binary_op->type == OperatorType::kGreater) {
158 outval = val0 > val1;
159 } else if (binary_op->type == OperatorType::kGreaterEqual) {
160 outval = val0 >= val1;
161 } else {
162 LOG(FATAL) << "should not get here";
163 }
164 output_data[Offset(output_shape, output_indices)] = outval;
165 }
166 }
167
EvaluateBinaryOperatorOnConstantInputs(Model * model,const Operator * binary_op)168 bool EvaluateBinaryOperatorOnConstantInputs(Model* model,
169 const Operator* binary_op) {
170 const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
171 const auto output_data_type =
172 model->GetArray(binary_op->outputs[0]).data_type;
173 #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
174 if (inputs_data_type == InputsDataType && \
175 output_data_type == OutputDataType) { \
176 EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
177 model, binary_op); \
178 return true; \
179 }
180 TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
181 TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
182 TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
183 TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
184 TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
185 TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
186 return false;
187 #undef TOCO_HANDLE_CASE
188 }
189 } // namespace
190
Run(Model * model,std::size_t op_index,bool * modified)191 ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
192 std::size_t op_index,
193 bool* modified) {
194 *modified = false;
195 const auto binary_it = model->operators.begin() + op_index;
196 const auto* binary_op = binary_it->get();
197 // Test for binary ops of types that we know how to resolve
198 if (binary_op->type != OperatorType::kAdd &&
199 binary_op->type != OperatorType::kMul &&
200 binary_op->type != OperatorType::kSub &&
201 binary_op->type != OperatorType::kDiv &&
202 binary_op->type != OperatorType::kFloorDiv &&
203 binary_op->type != OperatorType::kFloorMod &&
204 binary_op->type != OperatorType::kMinimum &&
205 binary_op->type != OperatorType::kMaximum &&
206 binary_op->type != OperatorType::kLess &&
207 binary_op->type != OperatorType::kLessEqual &&
208 binary_op->type != OperatorType::kGreater &&
209 binary_op->type != OperatorType::kGreaterEqual) {
210 return ::tensorflow::Status::OK();
211 }
212 CHECK_EQ(binary_op->inputs.size(), 2);
213
214 const auto& input0_array = model->GetArray(binary_op->inputs[0]);
215 const auto& input1_array = model->GetArray(binary_op->inputs[1]);
216 // Check if both inputs are constant parameters.
217 if (!input0_array.buffer || !input1_array.buffer) {
218 return ::tensorflow::Status::OK();
219 }
220
221 auto& output_array = model->GetArray(binary_op->outputs[0]);
222 // Yield until the output array dims have been resolved.
223 if (!output_array.has_shape()) {
224 return ::tensorflow::Status::OK();
225 }
226
227 // At the moment we don't want to care about fused activation functions.
228 // The idea is that we should do the present constants-propagation before
229 // activation functions get fused.
230 if (binary_op->fused_activation_function !=
231 FusedActivationFunctionType::kNone) {
232 AddMessageF(
233 "Not resolving constant %s because it has a fused activation function",
234 LogName(*binary_op));
235 return ::tensorflow::Status::OK();
236 }
237
238 // Check that input data types agree.
239 CHECK(input0_array.data_type == input1_array.data_type)
240 << "Dissimilar data types given to op outputting \""
241 << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"("
242 << static_cast<int>(input0_array.data_type) << ") 1:\""
243 << binary_op->inputs[1] << "\"("
244 << static_cast<int>(input1_array.data_type) << ").";
245
246 // Do the actual constants propagation
247 if (!EvaluateBinaryOperatorOnConstantInputs(model, binary_op)) {
248 return ::tensorflow::Status::OK();
249 }
250
251 DeleteOpAndArrays(model, binary_op);
252 *modified = true;
253 return ::tensorflow::Status::OK();
254 }
255
256 } // namespace toco
257