/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

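// Element-wise comparison of two equal-sized vectors: result[i] is true iff
// a[i] > b[i].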
std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
                                    const std::vector<int>& b) {
  DCHECK_EQ(a.size(), b.size());
  const int size = a.size();
  std::vector<bool> result(size);
  for (int i = 0; i < size; i++) {
    result[i] = a[i] > b[i];
  }
  return result;
}

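// For each index i, routes (input_a[i], input_b[i]) to (output_a[i],
// output_b[i]) when selector[i] is true, and in swapped order otherwise.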
void PairwiseVectorSelect(const std::vector<bool>& selector,
                          const std::vector<int>& input_a,
                          const std::vector<int>& input_b,
                          std::vector<int>* output_a,
                          std::vector<int>* output_b) {
  DCHECK_EQ(input_a.size(), input_b.size());
  DCHECK_EQ(output_a->size(), output_b->size());
  DCHECK_EQ(input_a.size(), output_a->size());
  DCHECK_EQ(selector.size(), input_a.size());
  const int size = input_a.size();
  for (int i = 0; i < size; i++) {
    if (selector[i]) {
      (*output_a)[i] = input_a[i];
      (*output_b)[i] = input_b[i];
    } else {
      (*output_a)[i] = input_b[i];
      (*output_b)[i] = input_a[i];
    }
  }
}

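// Evaluates the elementwise binary operator over the two constant input
// arrays, with broadcasting of the smaller dimensions, and stores the result
// in a newly created buffer on the output array.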
template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
  CHECK(binary_op->fused_activation_function ==
        FusedActivationFunctionType::kNone);
  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  const auto& output_name = binary_op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  CHECK(input0_array.data_type == InputsDataType);
  CHECK(input1_array.data_type == InputsDataType);
  CHECK(output_array.data_type == OutputDataType);

  // We have already tested above for the existence of the input buffers
  // (synonymous with being a constant parameter).
  CHECK(input0_array.buffer);
  CHECK(input1_array.buffer);
  // On the other hand, the output should not already have a buffer.
  CHECK(!output_array.buffer);

  const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
  const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
  // Create the buffer on the output array, effectively turning it into
  // a constant parameter.

  const Shape& output_shape = output_array.shape();
  auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
  const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
  output_data.resize(output_buffer_size);
  const int dims_count = output_shape.dimensions_count();

  // It will be convenient here to have copies of the operands' shapes
  // extended to match the number of dimensions of the output shape.
  Shape input0_shape = input0_array.shape();
  Shape input1_shape = input1_array.shape();
  ExtendShape(&input0_shape, dims_count);
  ExtendShape(&input1_shape, dims_count);
  // Now we may still have operands of different sizes, which would indicate
  // that we have to "broadcast" the smaller dimension. We do this using a
  // vector of Booleans indicating which input is the larger in each
  // dimension.
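  // For example, with input0 shape {1, 4} and input1 shape {3, 1}, the output
  // shape is {3, 4}; the output element at (i, j) is computed from input0 at
  // (0, j) and input1 at (i, 0).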
  CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
  CHECK_EQ(input0_shape.dimensions_count(), dims_count);
  const std::vector<bool> input0_larger =
      VectorGreaterThan(input0_shape.dims(), input1_shape.dims());

  std::vector<int> big_sizes(dims_count);
  std::vector<int> small_sizes(dims_count);
  PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
                       &big_sizes, &small_sizes);

  // The output should already be correctly sized to match the big dimensions.
  for (int i = 0; i < dims_count; i++) {
    CHECK_EQ(output_shape.dims(i), big_sizes[i]);
  }

  std::vector<int> input0_indices(dims_count);
  std::vector<int> input1_indices(dims_count);
  std::vector<int> modulo_indices(dims_count);

  for (int k = 0; k < output_buffer_size; k++) {
    const std::vector<int> output_indices = ReverseOffset(output_shape, k);
    for (int i = 0; i < dims_count; i++) {
      modulo_indices[i] = output_indices[i] % small_sizes[i];
    }
    PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
                         &input0_indices, &input1_indices);
    const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
    const auto val1 = input1_data[Offset(input1_shape, input1_indices)];

    DataType<OutputDataType> outval;
    if (binary_op->type == OperatorType::kAdd) {
      outval = val0 + val1;
    } else if (binary_op->type == OperatorType::kMul) {
      outval = val0 * val1;
    } else if (binary_op->type == OperatorType::kSub) {
      outval = val0 - val1;
    } else if (binary_op->type == OperatorType::kDiv) {
      outval = val0 / val1;
    } else if (binary_op->type == OperatorType::kFloorDiv) {
      outval = std::floor(val0 / val1);
    } else if (binary_op->type == OperatorType::kFloorMod) {
      outval = val0 - (std::floor(val0 / val1) * val1);
    } else if (binary_op->type == OperatorType::kMinimum) {
      outval = std::min(val0, val1);
    } else if (binary_op->type == OperatorType::kMaximum) {
      outval = std::max(val0, val1);
    } else if (binary_op->type == OperatorType::kLess) {
      outval = val0 < val1;
    } else if (binary_op->type == OperatorType::kLessEqual) {
      outval = val0 <= val1;
    } else if (binary_op->type == OperatorType::kGreater) {
      outval = val0 > val1;
    } else if (binary_op->type == OperatorType::kGreaterEqual) {
      outval = val0 >= val1;
    } else {
      LOG(FATAL) << "should not get here";
    }
    output_data[Offset(output_shape, output_indices)] = outval;
  }
}

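// Dispatches on the runtime input/output data types of the operator to the
// templated implementation above. Returns false if the combination of data
// types is not handled.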
bool EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
  const auto output_data_type =
      model->GetArray(binary_op->outputs[0]).data_type;
#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType)                    \
  if (inputs_data_type == InputsDataType &&                                 \
      output_data_type == OutputDataType) {                                 \
    EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
        model, binary_op);                                                  \
    return true;                                                            \
  }
  TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
  TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
  TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
  TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
  TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
  TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
  return false;
#undef TOCO_HANDLE_CASE
}
}  // namespace

::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
                                                        std::size_t op_index,
                                                        bool* modified) {
  *modified = false;
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  // Test for binary ops of types that we know how to resolve.
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv &&
      binary_op->type != OperatorType::kFloorDiv &&
      binary_op->type != OperatorType::kFloorMod &&
      binary_op->type != OperatorType::kMinimum &&
      binary_op->type != OperatorType::kMaximum &&
      binary_op->type != OperatorType::kLess &&
      binary_op->type != OperatorType::kLessEqual &&
      binary_op->type != OperatorType::kGreater &&
      binary_op->type != OperatorType::kGreaterEqual) {
    return ::tensorflow::Status::OK();
  }
  CHECK_EQ(binary_op->inputs.size(), 2);

  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  // Check if both inputs are constant parameters.
  if (!input0_array.buffer || !input1_array.buffer) {
    return ::tensorflow::Status::OK();
  }

  auto& output_array = model->GetArray(binary_op->outputs[0]);
  // Yield until the output array dims have been resolved.
  if (!output_array.has_shape()) {
    return ::tensorflow::Status::OK();
  }

  // At the moment we don't handle fused activation functions: the idea is
  // that this constant propagation should run before activation functions
  // get fused.
  if (binary_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not resolving constant %s because it has a fused activation function",
        LogName(*binary_op));
    return ::tensorflow::Status::OK();
  }

  // Check that the input data types agree.
  CHECK(input0_array.data_type == input1_array.data_type)
      << "Dissimilar data types given to op outputting \""
      << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"("
      << static_cast<int>(input0_array.data_type) << ")   1:\""
      << binary_op->inputs[1] << "\"("
      << static_cast<int>(input1_array.data_type) << ").";

  // Do the actual constant propagation.
  if (!EvaluateBinaryOperatorOnConstantInputs(model, binary_op)) {
    return ::tensorflow::Status::OK();
  }

  DeleteOpAndArrays(model, binary_op);
  *modified = true;
  return ::tensorflow::Status::OK();
}

}  // namespace toco