1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tooling_util.h"
16
17 #include <functional>
18 #include <iterator>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23
24 #include "absl/strings/ascii.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_replace.h"
28 #include "absl/strings/str_split.h"
29 #include "re2/re2.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/dump_graphviz.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
35
36 namespace toco {
37
38 // Find the longest common prefix of two strings.
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)39 absl::string_view FindLongestCommonPrefix(absl::string_view a,
40 absl::string_view b) {
41 if (a.empty() || b.empty()) return absl::string_view();
42
43 const char* pa = a.data();
44 const char* pb = b.data();
45 size_t count = 0;
46 const size_t limit = std::min(a.size(), b.size());
47 while (count < limit && *pa == *pb) {
48 ++pa;
49 ++pb;
50 ++count;
51 }
52
53 return absl::string_view(a.data(), count);
54 }
55
LogName(const Operator & op)56 std::string LogName(const Operator& op) {
57 const std::string& opname = HelpfulOperatorTypeName(op);
58 if (op.outputs.empty()) {
59 return toco::port::StringF("{%s operator}", opname);
60 } else {
61 return toco::port::StringF("{%s operator with output %s}", opname,
62 op.outputs[0]);
63 }
64 }
65
ArrayDataTypeName(ArrayDataType data_type)66 std::string ArrayDataTypeName(ArrayDataType data_type) {
67 switch (data_type) {
68 case ArrayDataType::kFloat:
69 return "float";
70 case ArrayDataType::kInt8:
71 return "int8";
72 case ArrayDataType::kUint8:
73 return "uint8";
74 case ArrayDataType::kInt16:
75 return "int16";
76 case ArrayDataType::kUint16:
77 return "uint16";
78 case ArrayDataType::kInt32:
79 return "int32";
80 case ArrayDataType::kUint32:
81 return "uint32";
82 case ArrayDataType::kInt64:
83 return "int64";
84 case ArrayDataType::kUint64:
85 return "uint64";
86 case ArrayDataType::kString:
87 return "string";
88 case ArrayDataType::kBool:
89 return "bool";
90 case ArrayDataType::kComplex64:
91 return "complex64";
92 case ArrayDataType::kNone:
93 return "None";
94 default:
95 LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
96 }
97 }
98
IsInputArray(const Model & model,const std::string & array_name)99 bool IsInputArray(const Model& model, const std::string& array_name) {
100 for (const auto& input_array : model.flags.input_arrays()) {
101 if (array_name == input_array.name()) {
102 return true;
103 }
104 }
105 return false;
106 }
107
IsOutputArray(const Model & model,const std::string & array_name)108 bool IsOutputArray(const Model& model, const std::string& array_name) {
109 for (const auto& output_array : model.flags.output_arrays()) {
110 if (array_name == output_array) {
111 return true;
112 }
113 }
114 return false;
115 }
116
IsArrayConsumed(const Model & model,const std::string & name)117 bool IsArrayConsumed(const Model& model, const std::string& name) {
118 if (GetOpWithInput(model, name)) {
119 return true;
120 }
121 if (IsOutputArray(model, name)) {
122 return true;
123 }
124 for (const auto& rnn_state : model.flags.rnn_states()) {
125 if (rnn_state.back_edge_source_array() == name) {
126 return true;
127 }
128 }
129 return false;
130 }
131
CountTrueOutputs(const Model & model,const Operator & op)132 int CountTrueOutputs(const Model& model, const Operator& op) {
133 int count = 0;
134 for (const std::string& output : op.outputs) {
135 if (IsArrayConsumed(model, output)) {
136 ++count;
137 }
138 }
139 return count;
140 }
141
CountOpsWithInput(const Model & model,const std::string & array_name)142 int CountOpsWithInput(const Model& model, const std::string& array_name) {
143 int count = 0;
144 for (const auto& op : model.operators) {
145 for (auto& input : op->inputs) {
146 if (input == array_name) {
147 count++;
148 // Breaking here is important: some graphs have ops that use the
149 // same array as more than one of their inputs, and in that case
150 // we want it counted only once.
151 break;
152 }
153 }
154 }
155 return count;
156 }
157
DeleteArrayIfUnused(const std::string & array_name,Model * model)158 bool DeleteArrayIfUnused(const std::string& array_name, Model* model) {
159 if (IsDiscardableArray(*model, array_name) &&
160 CountOpsWithInput(*model, array_name) == 0 &&
161 GetOpWithOutput(*model, array_name) == nullptr) {
162 model->EraseArray(array_name);
163 return true;
164 }
165 return false;
166 }
167
DeleteArrayIfUnusedOutsideOfOp(const std::string & array_name,const Operator * op,Model * model)168 bool DeleteArrayIfUnusedOutsideOfOp(const std::string& array_name,
169 const Operator* op, Model* model) {
170 if (!IsDiscardableArray(*model, array_name)) {
171 return false;
172 }
173 if (CountOpsWithInput(*model, array_name) > 1) {
174 return false;
175 }
176 const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
177 if (op_having_this_as_input && op_having_this_as_input != op) {
178 return false;
179 }
180 const Operator* op_having_this_as_output =
181 GetOpWithOutput(*model, array_name);
182 if (op_having_this_as_output && op_having_this_as_output != op) {
183 return false;
184 }
185 model->EraseArray(array_name);
186 return true;
187 }
188
DeleteOpAndArrays(Model * model,const Operator * op)189 void DeleteOpAndArrays(Model* model, const Operator* op) {
190 for (const std::string& array_name : op->inputs) {
191 DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
192 }
193 for (const std::string& array_name : op->outputs) {
194 DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
195 }
196 auto op_it = FindOp(*model, op);
197 CHECK(op_it != model->operators.end());
198 model->operators.erase(op_it);
199 }
200
FindOpWithOutput(const Model & model,const std::string & array_name)201 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
202 const Model& model, const std::string& array_name) {
203 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
204 for (auto& output : it->get()->outputs) {
205 if (output == array_name) {
206 return it;
207 }
208 }
209 }
210 return model.operators.end();
211 }
212
FindOpWithOutput(Model & model,const std::string & array_name)213 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
214 Model& model, const std::string& array_name) {
215 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
216 for (auto& output : it->get()->outputs) {
217 if (output == array_name) {
218 return it;
219 }
220 }
221 }
222 return model.operators.end();
223 }
224
GetOpWithOutput(const Model & model,const std::string & array_name)225 Operator* GetOpWithOutput(const Model& model, const std::string& array_name) {
226 auto it = FindOpWithOutput(model, array_name);
227 return it == model.operators.end() ? nullptr : it->get();
228 }
229
230 // GetFirstOpWithInput assumes that this finds the first op.
FindOpWithInput(const Model & model,const std::string & array_name)231 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
232 const Model& model, const std::string& array_name) {
233 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
234 for (auto& input : it->get()->inputs) {
235 if (input == array_name) {
236 return it;
237 }
238 }
239 }
240 return model.operators.end();
241 }
242
FindOpWithInput(Model & model,const std::string & array_name)243 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
244 Model& model, const std::string& array_name) {
245 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
246 for (auto& input : it->get()->inputs) {
247 if (input == array_name) {
248 return it;
249 }
250 }
251 }
252 return model.operators.end();
253 }
254
FindOp(const Model & model,const Operator * op)255 std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
256 const Model& model, const Operator* op) {
257 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
258 if (it->get() == op) {
259 return it;
260 }
261 }
262 return model.operators.end();
263 }
264
FindOp(Model & model,const Operator * op)265 std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
266 const Operator* op) {
267 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
268 if (it->get() == op) {
269 return it;
270 }
271 }
272 return model.operators.end();
273 }
274
GetOpWithInput(const Model & model,const std::string & array_name)275 Operator* GetOpWithInput(const Model& model, const std::string& array_name) {
276 auto it = FindOpWithInput(model, array_name);
277 return it == model.operators.end() ? nullptr : it->get();
278 }
279
GetFirstOpWithInput(const Model & model,const std::string & array_name)280 Operator* GetFirstOpWithInput(const Model& model,
281 const std::string& array_name) {
282 auto it = FindOpWithInput(model, array_name);
283 return it == model.operators.end() ? nullptr : it->get();
284 }
285
ReplaceArrayUsage(Model * model,const std::string & old_array_name,const std::string & new_array_name)286 void ReplaceArrayUsage(Model* model, const std::string& old_array_name,
287 const std::string& new_array_name) {
288 for (auto& op_it : model->operators) {
289 Operator* op = op_it.get();
290 for (size_t i = 0; i < op->inputs.size(); ++i) {
291 if (op->inputs[i] == old_array_name) {
292 op->inputs[i] = new_array_name;
293 }
294 }
295 for (size_t i = 0; i < op->outputs.size(); ++i) {
296 if (op->outputs[i] == old_array_name) {
297 op->outputs[i] = new_array_name;
298 }
299 }
300 }
301 }
302
FormatArraysList(const Model & model,const std::vector<std::string> & list)303 std::string FormatArraysList(const Model& model,
304 const std::vector<std::string>& list) {
305 if (list.empty()) {
306 return "[]";
307 }
308 std::string result = "";
309 if (list.size() > 1) {
310 result += "[ ";
311 }
312 for (std::size_t i = 0; i < list.size(); i++) {
313 if (i > 0) {
314 result += ", ";
315 }
316 result += list[i];
317 }
318 if (list.size() > 1) {
319 result += " ]";
320 }
321 return result;
322 }
323
// Maps an OperatorType enum value to its display name, e.g.
// OperatorType::kConv -> "Conv". LOG(FATAL)s on any value not listed below.
const char* OperatorTypeName(OperatorType type) {
  switch (type) {
// Expands to a `case` returning the stringified enum suffix.
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Abs)
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(HardSwish)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Elu)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(SplitV)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Ceil)
    HANDLE_OPERATORTYPENAME_CASE(Round)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(GatherNd)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
    HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
    HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
    HANDLE_OPERATORTYPENAME_CASE(Unique)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(ScatterNd)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}
463
HelpfulOperatorTypeName(const Operator & op)464 std::string HelpfulOperatorTypeName(const Operator& op) {
465 if (op.type == OperatorType::kUnsupported) {
466 return toco::port::StringF(
467 "(Unsupported TensorFlow op: %s)",
468 static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
469 }
470 return OperatorTypeName(op.type);
471 }
472
// Returns whether ops of the given type can carry a fused activation
// function (e.g. a following ReLU folded into the op itself).
bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kAdd:
    case OperatorType::kAveragePool:
    case OperatorType::kBatchNormalization:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv:
    case OperatorType::kDiv:
    case OperatorType::kFullyConnected:
    case OperatorType::kL2Pool:
    case OperatorType::kMaxPool:
    case OperatorType::kMul:
    case OperatorType::kSub:
    case OperatorType::kSquaredDifference:
      return true;
    default:
      return false;
  }
}
492
LogSummary(int log_level,const Model & model)493 void LogSummary(int log_level, const Model& model) {
494 VLOG(log_level) << "Operators summary (" << model.operators.size()
495 << " operators):";
496 std::unordered_multiset<OperatorType> ops_by_type;
497 for (const auto& op : model.operators) {
498 ops_by_type.insert(op->type);
499 }
500 auto it = ops_by_type.begin();
501 while (it != ops_by_type.end()) {
502 int count = ops_by_type.count(*it);
503 VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
504 std::advance(it, count);
505 }
506 }
507
// Logs, at `log_level`, a multi-line description of the named array: data
// type, final type, buffer/alloc status, shape, and min/max and
// quantization parameters — whichever of those attributes are present.
void LogArray(int log_level, const Model& model, const std::string& name) {
  VLOG(log_level) << "Array: " << name;
  if (!model.HasArray(name)) {
    VLOG(log_level) << "  DOES NOT EXIST";
    return;
  }
  const auto& array = model.GetArray(name);
  VLOG(log_level) << "  Data type: " << ArrayDataTypeName(array.data_type);
  VLOG(log_level) << "  Final type: "
                  << ArrayDataTypeName(array.final_data_type);
  if (array.buffer) {
    // Constant arrays carry their data in `buffer`.
    VLOG(log_level) << "  Constant Buffer";
  }
  if (array.alloc) {
    // Non-constant arrays may have a transient runtime allocation.
    VLOG(log_level) << "  Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << "  (Zero dimensions)";
    } else {
      // Render dims as a comma-separated list, e.g. "  Dims: 1, 224, 224, 3".
      std::string message = "  Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << "  MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << "  QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}
551
// When --dump_graphviz_video is enabled, writes the current model state as
// the next graphviz "video frame" (toco_video_NNNNN.dot in the
// --dump_graphviz directory). Frames identical to a previously written one
// (detected via a hash of the dump text) are skipped without consuming an id.
void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // really only is intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned-up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  std::string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump,
               toco::port::StringF("VIDEO frame:%05d", dump_id));
  std::size_t hash = std::hash<std::string>{}(graphviz_dump);
  // Only write (and advance dump_id) when this frame differs from all
  // previously dumped frames.
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            toco::port::StringF("toco_video_%05d.dot", dump_id)),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
    dump_id++;
  }
}
589
// Dumps the model: emits a graphviz video frame (if enabled), writes a
// "toco_<message>.dot" graphviz file (if --dump_graphviz is set), and then,
// when VLOG is on at `log_level`, logs a full text dump — summary, every
// array (printed once), and every operator with its input/output lists.
void LogDump(int log_level, const std::string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    std::string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump, message);
    // Derive the file name from `message`, replacing spaces so it is
    // filesystem-friendly.
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
                         ".dot")),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
  }

  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  // Log each array only once, the first time it appears as an input or
  // output of some operator.
  std::unordered_set<std::string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << "  " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << "  (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}
636
637 // Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
ExtendShape(Shape * shape,int new_shape_size)638 void ExtendShape(Shape* shape, int new_shape_size) {
639 CHECK_GE(new_shape_size, shape->dimensions_count());
640 const int size_increase = new_shape_size - shape->dimensions_count();
641 auto* shape_dims = shape->mutable_dims();
642 shape_dims->insert(shape_dims->begin(), size_increase, 1);
643 }
644
645 // TODO(b/62904716) Remove along with remaining uses.
UnextendShape(Shape * shape,int new_shape_size)646 void UnextendShape(Shape* shape, int new_shape_size) {
647 CHECK_LE(new_shape_size, shape->dimensions_count());
648 const int size_reduction = shape->dimensions_count() - new_shape_size;
649 for (int i = 0; i < size_reduction; i++) {
650 CHECK_EQ(shape->dims(i), 1);
651 }
652 std::vector<int>& shape_dims = *shape->mutable_dims();
653 shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
654 }
655
656 // In general, zero-sized dimensions are disallowed, but there are exceptions,
657 // e.g., if the tensor data itself represents a scalar (rank 0) shape, its
658 // shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
659 // strict, and is appropriate for ops and comparisons where an empty shape
660 // doesn't make sense.
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  // The one-dimensional shape [0] encodes scalar (rank 0) data and is the
  // only place a zero-sized dimension is permitted.
  const bool scalar_encoding = (dims.size() == 1 && dims[0] == 0);
  if (!scalar_encoding) {
    for (const auto& dim : dims) {
      CHECK_GE(dim, 1);
    }
  }
}
670
// CHECK-fails unless `shape` is valid: every dimension >= 1, or the special
// scalar encoding [0] accepted by CheckValidShapeDimensions.
void CheckValidShape(const Shape& shape) {
  CheckValidShapeDimensions(shape.dims());
}
674
IsNonEmpty(const Shape & shape)675 bool IsNonEmpty(const Shape& shape) {
676 for (int i = 0; i < shape.dimensions_count(); ++i) {
677 if (shape.dims(i) < 1) return false;
678 }
679 return true;
680 }
681
CheckNonEmptyShapeDimensions(const Shape & shape)682 void CheckNonEmptyShapeDimensions(const Shape& shape) {
683 for (int i = 0; i < shape.dimensions_count(); ++i) {
684 CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
685 << ". shape = " << ShapeToString(shape);
686 }
687 }
688
ShapesAgreeUpToBroadcasting(const Shape & shape0,const Shape & shape1)689 bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
690 CheckNonEmptyShapeDimensions(shape0);
691 CheckNonEmptyShapeDimensions(shape1);
692
693 const Shape* longer = &shape0;
694 const Shape* shorter = &shape1;
695 if (shape1.dimensions_count() > shape0.dimensions_count()) {
696 longer = &shape1;
697 shorter = &shape0;
698 }
699
700 // Walk dimensions back to front until we run out of dimensions in the shorter
701 // shape.
702 int longer_index = longer->dimensions_count() - 1;
703 int shorter_index = shorter->dimensions_count() - 1;
704 while (shorter_index >= 0) {
705 const int d_long = longer->dims(longer_index);
706 const int d_short = shorter->dims(shorter_index);
707 // Broadcasting fails if the dimensions are different *and* neither is 1.
708 if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
709 return false;
710 }
711 longer_index--;
712 shorter_index--;
713 }
714 return true;
715 }
716
ShapesAgreeUpToExtending(const Shape & shape0,const Shape & shape1)717 bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
718 CheckNonEmptyShapeDimensions(shape0);
719 CheckNonEmptyShapeDimensions(shape1);
720
721 const Shape* longer = &shape0;
722 const Shape* shorter = &shape1;
723 if (shape1.dimensions_count() > shape0.dimensions_count()) {
724 longer = &shape1;
725 shorter = &shape0;
726 }
727
728 // Walk dimensions back to front until we run out of dimensions in the shorter
729 // shape.
730 int longer_index = longer->dimensions_count() - 1;
731 int shorter_index = shorter->dimensions_count() - 1;
732 while (shorter_index >= 0) {
733 const int d_long = longer->dims(longer_index);
734 const int d_short = shorter->dims(shorter_index);
735 // Extending fails if the dimensions are different.
736 if (d_long != d_short) {
737 return false;
738 }
739 longer_index--;
740 shorter_index--;
741 }
742
743 // The remaining dimensions in the longer shape must be 1.
744 while (longer_index >= 0) {
745 const int d_long = longer->dims(longer_index);
746 if (d_long != 1) {
747 return false;
748 }
749 longer_index--;
750 }
751
752 return true;
753 }
754
RequiredBufferSizeForShape(const Shape & shape)755 int RequiredBufferSizeForShape(const Shape& shape) {
756 CheckValidShape(shape);
757 int max_offset = 1;
758 for (const auto& dim : shape.dims()) {
759 max_offset *= dim;
760 }
761 return max_offset;
762 }
763
IsConstantParameterArray(const Model & model,const std::string & name)764 bool IsConstantParameterArray(const Model& model, const std::string& name) {
765 if (!model.HasArray(name)) {
766 return false;
767 }
768
769 return !!model.GetArray(name).buffer;
770 }
771
772 namespace {
773 template <ArrayDataType A>
CompareArrayBuffers(const Array & lhs_array,const Array & rhs_array)774 bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
775 CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
776 CHECK(lhs_array.buffer) << "LHS must be constant";
777 CHECK(rhs_array.buffer) << "RHS must be constant";
778 const auto& lhs_data = lhs_array.GetBuffer<A>().data;
779 const auto& rhs_data = rhs_array.GetBuffer<A>().data;
780 CHECK_EQ(lhs_data.size(), rhs_data.size())
781 << "Buffer sizes must match in element count";
782 for (int i = 0; i < lhs_data.size(); ++i) {
783 if (lhs_data[i] != rhs_data[i]) {
784 return false;
785 }
786 }
787 return true;
788 }
789
HaveSameMinMax(const Array & lhs_array,const Array & rhs_array)790 bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
791 if (lhs_array.minmax || rhs_array.minmax) {
792 if (!lhs_array.minmax || !rhs_array.minmax) {
793 return false;
794 }
795 if (!(*lhs_array.minmax == *rhs_array.minmax)) {
796 return false;
797 }
798 }
799 return true;
800 }
801
HaveSameQuantizationParams(const Array & lhs_array,const Array & rhs_array)802 bool HaveSameQuantizationParams(const Array& lhs_array,
803 const Array& rhs_array) {
804 if (lhs_array.quantization_params || rhs_array.quantization_params) {
805 if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
806 return false;
807 }
808 if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
809 return false;
810 }
811 }
812 return true;
813 }
814
815 } // namespace
816
// Returns true iff two constant arrays are fully identical: same shape, data
// types, min/max, quantization params, narrow_range flag, and
// element-for-element buffer contents. Both arrays must be constant.
bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
  // Cheap attribute comparison first; only compare buffers if it passes.
  bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
                     lhs_array.data_type == rhs_array.data_type &&
                     lhs_array.final_data_type == rhs_array.final_data_type &&
                     HaveSameMinMax(lhs_array, rhs_array) &&
                     HaveSameQuantizationParams(lhs_array, rhs_array) &&
                     lhs_array.narrow_range == rhs_array.narrow_range;
  if (!attrs_equal) {
    return false;
  }
  // Dispatch the buffer comparison on the runtime data type.
  switch (lhs_array.data_type) {
    case ArrayDataType::kBool:
      return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
    case ArrayDataType::kFloat:
      return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
    case ArrayDataType::kInt8:
      return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
    case ArrayDataType::kUint8:
      return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
    case ArrayDataType::kInt16:
      return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
    case ArrayDataType::kUint16:
      return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
    case ArrayDataType::kInt32:
      return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
    case ArrayDataType::kUint32:
      return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
    case ArrayDataType::kInt64:
      return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
    case ArrayDataType::kUint64:
      return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
    case ArrayDataType::kString:
      return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
    case ArrayDataType::kComplex64:
      return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
                                                            rhs_array);
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(lhs_array.data_type);
      return false;
  }
}
859
860 namespace {
// Take an array name, which may be something like "name:3_5" and make it
// acceptable as a TF node name, say "name_3_5"; every ':' becomes '_'.
std::string SanitizeNameForTFNode(const std::string& array_name) {
  std::string sanitized;
  sanitized.reserve(array_name.size());
  for (const char c : array_name) {
    sanitized.push_back(c == ':' ? '_' : c);
  }
  return sanitized;
}
868
CheckInputArraysAreNotOutputArrays(const ModelFlags & model_flags)869 void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
870 for (const auto& input_array : model_flags.input_arrays()) {
871 for (const std::string& output_array : model_flags.output_arrays()) {
872 QCHECK_NE(input_array.name(), output_array)
873 << "The array " << output_array
874 << " is listed in both --input_arrays and --output_arrays.";
875 }
876 }
877 }
878
IsAsciiPrintable(const std::string & name)879 bool IsAsciiPrintable(const std::string& name) {
880 for (char c : name) {
881 if (!absl::ascii_isprint(c)) {
882 return false;
883 }
884 }
885 return true;
886 }
887
DumpAscii(const std::string & name)888 std::string DumpAscii(const std::string& name) {
889 std::string result;
890 port::AppendF(&result, "ASCII | Hex\n");
891 port::AppendF(&result, "------+----\n");
892 for (char c : name) {
893 if (absl::ascii_isprint(c)) {
894 port::AppendF(&result, "%c | %x\n", c, c);
895 } else {
896 port::AppendF(&result, " | %x Not ASCII printable!\n", c);
897 }
898 }
899 return result;
900 }
901
void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
  // QCHECK-fails if any --input_arrays/--output_arrays name contains a
  // non-printable-ASCII character, unless --allow_nonascii_arrays is set.
  // On failure, dumps the offending name byte-by-byte (DumpAscii) to make
  // the bad character findable.
  if (model_flags.allow_nonascii_arrays()) {
    return;
  }
  for (const auto& input_array : model_flags.input_arrays()) {
    QCHECK(IsAsciiPrintable(input_array.name()))
        << "Non-ASCII-printable character found in --input_arrays: "
        << input_array.name()
        << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(input_array.name());
  }
  for (const std::string& output_array : model_flags.output_arrays()) {
    QCHECK(IsAsciiPrintable(output_array))
        << "Non-ASCII-printable character found in --output_arrays: "
        << output_array << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(output_array);
  }
}
922
void CheckNonExistentIOArrays(const Model& model) {
  // QCHECK-fails if a declared output array, or a non-discardable RNN state
  // array, is not actually wired into the graph.
  //
  // "non-existent" is interpreted in the stronger sense of
  // "not actually produced/consumed by an op".
  // Rationale: we have to artificially fix up TensorFlow graphs by creating
  // any array that it refers to, so just checking that arrays exist isn't
  // sufficient. The real invariant here is whether arrays are produced/consumed
  // by something.
  if (model.flags.allow_nonexistent_arrays()) {
    return;
  }
  static constexpr char general_comment[] =
      "Is it a typo? This should not happen. If you trigger this error "
      "please send a bug report (with code to reproduce this error), to the "
      "TensorFlow Lite team.";
  for (const std::string& output_array : model.flags.output_arrays()) {
    if (IsConstantParameterArray(model, output_array)) {
      continue;  // It is OK to request that a constant be an output.
    }
    QCHECK(GetOpWithOutput(model, output_array))
        << "Specified output array \"" << output_array
        << "\" is not produced by any op in this graph. " << general_comment;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (!rnn_state.discardable()) {
      // Check that all RNN states are consumed
      QCHECK(GetOpWithInput(model, rnn_state.state_array()))
          << "Specified RNN state \"" << rnn_state.state_array()
          << "\" is not consumed by any op in this graph. " << general_comment;
      // Check that all RNN back-edge source arrays are produced
      QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
          << "Specified RNN back-edge source array \""
          << rnn_state.back_edge_source_array()
          << "\" is not produced by any op in this graph. " << general_comment;
    }
  }
}
959
960 } // namespace
961
CheckNoMissingArray(const Model & model)962 void CheckNoMissingArray(const Model& model) {
963 for (const auto& op : model.operators) {
964 for (const auto& input : op->inputs) {
965 CHECK(model.HasArray(input) || model.optional_arrays.count(input))
966 << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
967 }
968 for (const auto& output : op->outputs) {
969 CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
970 }
971 }
972 CheckNonExistentIOArrays(model);
973 }
974
FixNoMissingArray(Model * model)975 void FixNoMissingArray(Model* model) {
976 for (const auto& op : model->operators) {
977 for (const auto& input : op->inputs) {
978 if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
979 model->GetOrCreateArray(input);
980 }
981 }
982 for (const auto& output : op->outputs) {
983 if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
984 model->GetOrCreateArray(output);
985 }
986 }
987 }
988 if (model->flags.allow_nonexistent_arrays()) {
989 for (const std::string& output_array : model->flags.output_arrays()) {
990 model->GetOrCreateArray(output_array);
991 }
992 for (const auto& rnn_state : model->flags.rnn_states()) {
993 model->GetOrCreateArray(rnn_state.state_array());
994 model->GetOrCreateArray(rnn_state.back_edge_source_array());
995 }
996 }
997 }
998
CheckNoOrphanedArray(const Model & model)999 void CheckNoOrphanedArray(const Model& model) {
1000 std::unordered_set<std::string> arrays_without_known_use;
1001 for (const auto& array : model.GetArrayMap()) {
1002 if (IsDiscardableArray(model, array.first)) {
1003 arrays_without_known_use.insert(array.first);
1004 }
1005 }
1006 for (const auto& op : model.operators) {
1007 for (const auto& input : op->inputs) {
1008 arrays_without_known_use.erase(input);
1009 }
1010 for (const auto& output : op->outputs) {
1011 arrays_without_known_use.erase(output);
1012 }
1013 }
1014 for (const auto& rnn_state : model.flags.rnn_states()) {
1015 arrays_without_known_use.erase(rnn_state.state_array());
1016 arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1017 }
1018 if (!arrays_without_known_use.empty()) {
1019 for (const auto& array : arrays_without_known_use) {
1020 LOG(INFO) << "Error: Orphaned array: " << array;
1021 }
1022 }
1023 CHECK(arrays_without_known_use.empty());
1024 }
1025
FixNoOrphanedArray(Model * model)1026 void FixNoOrphanedArray(Model* model) {
1027 std::unordered_set<std::string> arrays_without_known_use;
1028 for (const auto& array : model->GetArrayMap()) {
1029 arrays_without_known_use.insert(array.first);
1030 }
1031 for (const auto& op : model->operators) {
1032 for (const auto& input : op->inputs) {
1033 arrays_without_known_use.erase(input);
1034 }
1035 for (const auto& output : op->outputs) {
1036 arrays_without_known_use.erase(output);
1037 }
1038 }
1039 for (const auto& rnn_state : model->flags.rnn_states()) {
1040 arrays_without_known_use.erase(rnn_state.state_array());
1041 arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1042 }
1043 for (const auto& array : arrays_without_known_use) {
1044 if (IsDiscardableArray(*model, array)) {
1045 model->EraseArray(array);
1046 }
1047 }
1048 }
1049
// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc) << "Tensor: " << array_entry.first;
    if (array->buffer) {
      // If there is a buffer, its type should be consistent with data_type.
      CHECK(array->buffer->type == array->data_type)
          << "Tensor: " << array_entry.first;
      // The presence of a fixed buffer should imply the presence of a fixed
      // shape.
      CHECK(array->has_shape()) << array_entry.first;
      // Constant buffers should have a valid shape.
      CheckValidShape(array->shape());
      // The shape flat-size should agree with the buffer length.
      CHECK_EQ(array->buffer->Length(),
               RequiredBufferSizeForShape(array->shape()))
          << "Tensor: " << array_entry.first;
    }

    // Check name. Either "name_with_suffix_8", "name_with_port:3", but not
    // "name_with_both:3_8".
    const std::string& name = array_entry.first;
    auto colon_pos = name.find_first_of(':');
    if (colon_pos != std::string::npos) {
      // Everything after the colon must be digits (a TF-style port number).
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
               std::string::npos)
          << "Array '" << name << "' has non-digit characters after colon.";
    }
    // colon_pos is std::string::npos (a huge unsigned value) when there is
    // no colon, so this CHECK_GT only rejects names whose first character
    // is ':'.
    CHECK_GT(colon_pos, 0) << "Array '" << name
                           << "' must not start with a colon.";
  }
}
1087
CheckOperatorOrdering(const Model & model)1088 void CheckOperatorOrdering(const Model& model) {
1089 std::unordered_set<std::string> arrays_behind_us;
1090 for (const auto& array_entry : model.GetArrayMap()) {
1091 if (!GetOpWithOutput(model, array_entry.first)) {
1092 arrays_behind_us.insert(array_entry.first);
1093 }
1094 }
1095 arrays_behind_us.insert(model.optional_arrays.begin(),
1096 model.optional_arrays.end());
1097 for (const auto& op : model.operators) {
1098 for (const auto& input : op->inputs) {
1099 if (!IsConstantParameterArray(model, input)) {
1100 CHECK(arrays_behind_us.count(input));
1101 }
1102 }
1103 for (const auto& output : op->outputs) {
1104 CHECK(!arrays_behind_us.count(output));
1105 arrays_behind_us.insert(output);
1106 }
1107 }
1108 for (const std::string& output_array : model.flags.output_arrays()) {
1109 CHECK(arrays_behind_us.count(output_array));
1110 }
1111 }
1112
void FixOperatorOrdering(Model* model) {
  // Topologically re-sorts model->operators so that every operator appears
  // after the producers of all of its non-constant inputs. If no valid
  // ordering exists (missing producer or a cycle), logs a diagnostic trace
  // of the offending dependency chain and dies with LOG(FATAL).
  //
  // Seed the set of already-available arrays: those with no producing op
  // (graph inputs, constants), plus the optional arrays.
  std::unordered_set<std::string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  // Move all operators out of the model; re-insert them one at a time as
  // their inputs become available.
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  // Maps an op's output array to the unavailable input that blocked the op;
  // used only to build the diagnostic trace below.
  std::unordered_map<std::string, std::string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const std::string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        remaining.erase(i);
        inserted_something = true;
        // Restart the scan: `remaining` was just mutated while iterating.
        break;
      }
    }
    // No progress in a full pass means the leftover ops can never be
    // scheduled; fall through to diagnostics.
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has as "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<std::string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      bool found_bad_output = false;
      std::string bad_output;
      for (const std::string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const std::string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      // Seeing the same bad input twice means we walked a cycle.
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      bad_op = nullptr;
      // Walk back: find the (still unscheduled) operator producing bad_input.
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const std::string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      // No producer at all: the root cause is a genuinely missing array.
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}
1227
void CheckInvariants(const Model& model) {
  // Runs the full battery of model sanity checks; each check dies (CHECK /
  // QCHECK) on violation. Array existence is validated before the checks
  // that inspect array contents and operator ordering.
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}
1236
void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
                       const int count, const std::string& count_description) {
  // Dies unless `count` satisfies the [count_min, count_max] constraint from
  // `model_check`. A negative count_min disables the lower-bound check; the
  // upper-bound check applies only when count_max > count_min. When
  // count_max <= count_min, count_min acts as an exact expected value —
  // hence the "minimum" vs "value" wording below.
  if (model_check.count_min() >= 0) {
    CHECK_GE(count, model_check.count_min())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified "
        << (model_check.count_max() > model_check.count_min() ? "minimum"
                                                              : "value")
        << " was " << model_check.count_min() << ".";
  }
  if (model_check.count_max() > model_check.count_min()) {
    CHECK_LE(count, model_check.count_max())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified maximum was " << model_check.count_max() << ".";
  }
}
1253
void CheckModelCounts(const Model& model) {
  // Enforces the model_checks declared in the model flags: each check
  // constrains the count of arrays ("Arrays"), of all operator instances
  // ("Total"), or of instances of a specific operator type (any other
  // string). "None" disables a check.
  std::unordered_multiset<OperatorType> ops_by_type;
  std::unordered_map<std::string, OperatorType> op_type_by_name;
  if (model.flags.model_checks_size() == 0) {
    return;
  }

  // Index the model's operators by type and by type name.
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
    op_type_by_name[OperatorTypeName(op->type)] = op->type;
  }
  for (const auto& model_check : model.flags.model_checks()) {
    std::string count_type = model_check.count_type();
    if (count_type == "None") {
      continue;
    } else if (count_type == "Arrays") {
      CheckCountInRange(model_check, model.GetArrayMap().size(),
                        "count of arrays");
    } else if (count_type == "Total") {
      CheckCountInRange(model_check, model.operators.size(),
                        "count of all operator instances");
    } else {
      // The check type is not itself checked against the set of valid
      // operators, mainly because the enum set cannot be iterated in C++.
      // An unknown operator name simply yields a count of 0.
      const int found_count =
          op_type_by_name.count(count_type) > 0
              ? ops_by_type.count(op_type_by_name[count_type])
              : 0;
      CheckCountInRange(model_check, found_count,
                        "count of instances of " + count_type + " operator");
    }
  }
}
1287
FixEdgeArrays(Model * model)1288 void FixEdgeArrays(Model* model) {
1289 for (const std::string& output_array_name : model->flags.output_arrays()) {
1290 if (!GetOpWithOutput(*model, output_array_name)) {
1291 // Output has no operator producing it. Change that by inserting a copy.
1292 LOG(WARNING) << "Fixing constant output array " << output_array_name
1293 << " by inserting a copy. This is not optimal.";
1294 std::string intermediate_array_name =
1295 AvailableArrayName(*model, output_array_name + "_copy");
1296 CloneArray(model, output_array_name, intermediate_array_name);
1297 InsertCopyOperator(model, intermediate_array_name, output_array_name);
1298 }
1299 }
1300 }
1301
void DedupeConstantArrays(Model* model, size_t min_size) {
  // Merges duplicate constant arrays (equal per CompareConstantArrays and at
  // least `min_size` bytes, except strings which are always considered): all
  // usages of the duplicate are redirected to the first occurrence and the
  // duplicate is erased.
  //
  // Walk all 0..N and compare with the remaining n+1..N.
  // This lets us avoid N^2 comparisons and erase duplicate arrays while
  // iterating.
  const auto& array_map = model->GetArrayMap();
  for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
       ++lhs_array_it) {
    const auto& lhs_array_name = lhs_array_it->first;
    const auto& lhs_array = *lhs_array_it->second;
    if (!IsConstantParameterArray(*model, lhs_array_name)) {
      // Not a constant array; skip.
      continue;
    }
    ArrayDataType final_data_type =
        lhs_array.final_data_type != ArrayDataType::kNone
            ? lhs_array.final_data_type
            : lhs_array.data_type;
    // Ignore small arrays, don't check string arrays because it is not possible
    // to estimate its size.
    if (final_data_type != ArrayDataType::kString) {
      size_t array_byte_size =
          lhs_array.buffer->Length() * ElementSize(final_data_type);
      if (array_byte_size < min_size) {
        // Too small; skip.
        continue;
      }
    }

    auto next_lhs_array_it = lhs_array_it;
    ++next_lhs_array_it;
    // NOTE: rhs_array_it is advanced at the top of the loop body, *before*
    // any potential EraseArray of the entry it pointed to — this keeps the
    // iterator valid while erasing during iteration.
    for (auto rhs_array_it = next_lhs_array_it;
         rhs_array_it != array_map.end();) {
      const auto& rhs_array_name = rhs_array_it->first;
      const auto& rhs_array = *rhs_array_it->second;
      ++rhs_array_it;
      if (!IsConstantParameterArray(*model, rhs_array_name)) {
        // Not a constant array; skip.
        continue;
      }
      if (!IsDiscardableArray(*model, rhs_array_name)) {
        // Can't remove the array as it's not discardable (such as an IO edge).
        continue;
      }
      if (!CompareConstantArrays(lhs_array, rhs_array)) {
        // Arrays aren't equal; skip.
        continue;
      }

      // Arrays can be deduped!
      VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
              << " in place of " << rhs_array_name;
      ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
      // Note: rhs_array_it above is already incremented so this is safe.
      model->EraseArray(rhs_array_name);
    }
  }
}
1359
1360 namespace {
CopyArrayAttribs(const Array & source_array,Array * target_array)1361 void CopyArrayAttribs(const Array& source_array, Array* target_array) {
1362 target_array->data_type = source_array.data_type;
1363 target_array->final_data_type = source_array.final_data_type;
1364 if (source_array.has_shape()) {
1365 target_array->copy_shape(source_array.shape());
1366 }
1367
1368 if (source_array.minmax) {
1369 target_array->GetOrCreateMinMax() = source_array.GetMinMax();
1370 } else {
1371 target_array->minmax.reset();
1372 }
1373
1374 if (source_array.quantization_params) {
1375 target_array->GetOrCreateQuantizationParams() =
1376 source_array.GetQuantizationParams();
1377 } else {
1378 target_array->quantization_params.reset();
1379 }
1380 }
1381 } // namespace
1382
void InsertCopyOperator(Model* model, const std::string& source_array_name,
                        const std::string& target_array_name) {
  // Inserts a runtime copy from source to target, implemented as a
  // same-shape Reshape op. The target array is created if needed; any
  // constant buffer it held is dropped since its value will now be produced
  // at runtime.
  //
  // Reshape to the same size. This should be a no-op.
  const Array& source_array = model->GetArray(source_array_name);
  std::vector<int> shape = source_array.shape().dims();

  // Drop constant data from the target array as the copy will be done at
  // runtime.
  Array& target_array = model->GetOrCreateArray(target_array_name);
  target_array.buffer.reset();
  CopyArrayAttribs(source_array, &target_array);

  // Insert copy operator. The second input is a new constant int32 array
  // holding the (unchanged) shape, as required by Reshape.
  auto* copy_op = new TensorFlowReshapeOperator;
  copy_op->inputs = {
      source_array_name,
      CreateInt32Array(
          model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
          shape)};
  copy_op->outputs = {target_array_name};
  if (target_array.has_shape()) {
    copy_op->shape = target_array.shape().dims();
  }
  model->operators.emplace_back(copy_op);
}
1408
void CloneArray(Model* model, const std::string& source_array_name,
                const std::string& target_array_name) {
  // Creates `target_array_name` (which must not already exist) as a full
  // copy of the source array: attributes first, then — if the source has a
  // constant buffer — a deep copy of the buffer, dispatched on the dtype.
  CHECK(!model->HasArray(target_array_name));
  const Array& source_array = model->GetArray(source_array_name);
  Array& target_array = model->GetOrCreateArray(target_array_name);
  CopyArrayAttribs(source_array, &target_array);

  // Non-constant arrays have no buffer to copy.
  if (!source_array.buffer) {
    return;
  }

  switch (source_array.data_type) {
    case ArrayDataType::kBool:
      CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
      break;
    case ArrayDataType::kFloat:
      CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
      break;
    case ArrayDataType::kInt8:
      CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
      break;
    case ArrayDataType::kUint8:
      CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
      break;
    case ArrayDataType::kInt16:
      CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
      break;
    case ArrayDataType::kUint16:
      CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
      break;
    case ArrayDataType::kInt32:
      CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
      break;
    case ArrayDataType::kUint32:
      CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
      break;
    case ArrayDataType::kInt64:
      CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
      break;
    case ArrayDataType::kUint64:
      CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
      break;
    case ArrayDataType::kString:
      CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
      break;
    case ArrayDataType::kComplex64:
      CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
      break;
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(source_array.data_type);
      return;
  }
}
1463
MakeArrayDims(int num_dims,int batch,int height,int width,int depth,std::vector<int> * out_dims)1464 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
1465 std::vector<int>* out_dims) {
1466 CHECK(out_dims->empty());
1467 if (num_dims == 0) {
1468 return;
1469 } else if (num_dims == 1) {
1470 CHECK_EQ(batch, 1);
1471 *out_dims = {depth};
1472 } else if (num_dims == 2) {
1473 *out_dims = {batch, depth};
1474 } else if (num_dims == 3) {
1475 CHECK_EQ(batch, 1);
1476 *out_dims = {height, width, depth};
1477 } else if (num_dims == 4) {
1478 *out_dims = {batch, height, width, depth};
1479 } else {
1480 LOG(FATAL) << "Should not get here: " << num_dims;
1481 }
1482 }
1483
void CreateOrCheckRnnStateArray(const std::string& name, int size,
                                int state_num_dims, Model* model) {
  // Ensures the RNN state array `name` exists and, if it has no shape yet,
  // gives it one of rank `state_num_dims` (or a rank inferred from the model
  // input arrays) with inner dimension `size` and batch taken from the
  // inputs.
  int batch = 1;
  int num_dims = -1;
  if (state_num_dims > 0) {
    num_dims = state_num_dims;
  } else {
    // state_num_dims is not given. We will infer it from an input tensor.
    for (const auto& input_array : model->flags.input_arrays()) {
      // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
      // a better match by name.
      if (input_array.name() == name || num_dims == -1) {
        num_dims = input_array.shape().dims_size();
        if (num_dims > 0) {
          batch = input_array.shape().dims(0);
        }
      }
    }
  }
  Array& array = model->GetOrCreateArray(name);
  if (array.has_shape()) {
    // An already-present shape wins; only its rank is recorded here — the
    // dims themselves are not validated against `size`.
    num_dims = array.shape().dimensions_count();
  }
  if (!array.has_shape() && num_dims >= 0) {
    Shape* shape = array.mutable_shape();
    std::vector<int> dims;
    MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
    *shape->mutable_dims() = dims;
  }
}
1514
ResolveModelFlags(const ModelFlags & model_flags,Model * model)1515 void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
1516 // Merge info about input_arrays from model_flags into model->flags
1517 for (const auto& specified_input_array : model_flags.input_arrays()) {
1518 toco::InputArray* dst_input_array = nullptr;
1519 for (int i = 0; i < model->flags.input_arrays_size(); i++) {
1520 toco::InputArray* candidate_dst_input_array =
1521 model->flags.mutable_input_arrays(i);
1522 if (candidate_dst_input_array->name() == specified_input_array.name()) {
1523 // specified_input_array from model_flags maps to dst_input_array
1524 // in model->flags
1525 dst_input_array = candidate_dst_input_array;
1526 break;
1527 }
1528 }
1529 if (!dst_input_array) {
1530 // Specified_input_array from model_flags is not found in model->flags.
1531 // Match a name-less specified input array when there can be no ambiguity
1532 // as there is only 1 input array.
1533 if (model->flags.input_arrays_size() == 1 &&
1534 model_flags.input_arrays_size() == 1 &&
1535 !specified_input_array.has_name()) {
1536 dst_input_array = model->flags.mutable_input_arrays(0);
1537 }
1538 }
1539 if (!dst_input_array) {
1540 // Still no match, so create a new input array to copy
1541 // specified_input_array into.
1542 dst_input_array = model->flags.add_input_arrays();
1543 dst_input_array->set_name(specified_input_array.name());
1544 }
1545
1546 #define RESOLVE_MODEL_FLAG(field_name) \
1547 if (specified_input_array.has_##field_name()) { \
1548 if (dst_input_array->has_##field_name()) { \
1549 QCHECK_EQ(dst_input_array->field_name(), \
1550 specified_input_array.field_name()) \
1551 << "For input array '" << dst_input_array->name() << "', " \
1552 << "specified " #field_name " flag with value: " \
1553 << specified_input_array.field_name() \
1554 << " does not agree with already defined " #field_name \
1555 " of this model, with value: " \
1556 << specified_input_array.field_name(); \
1557 } else { \
1558 dst_input_array->set_##field_name(specified_input_array.field_name()); \
1559 } \
1560 }
1561 RESOLVE_MODEL_FLAG(std_value);
1562 RESOLVE_MODEL_FLAG(mean_value);
1563 #undef RESOLVE_MODEL_FLAG
1564
1565 if (specified_input_array.has_shape()) {
1566 if (dst_input_array->has_shape()) {
1567 QCHECK_EQ(specified_input_array.shape().dims_size(),
1568 dst_input_array->shape().dims_size())
1569 << "For input array '" << specified_input_array.name() << "', "
1570 << "size of specified input shape flag with size: "
1571 << specified_input_array.shape().dims_size()
1572 << " does not agree with already defined input shape"
1573 " of this model, with size: "
1574 << dst_input_array->shape().dims_size();
1575 // We treat the first dimension as a special case, since it is often
1576 // a batch size and the input_shape flag is effectively overriding
1577 // the model.
1578 for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
1579 QCHECK_EQ(specified_input_array.shape().dims(i),
1580 dst_input_array->shape().dims(i))
1581 << "At dimension number " << i << " of input array "
1582 << specified_input_array.name() << ", the specified shape's "
1583 << "dimension flag with dimension: "
1584 << specified_input_array.shape().dims(i)
1585 << " does not agree with already defined shape"
1586 << " of this model, with dimension: "
1587 << dst_input_array->shape().dims(i);
1588 }
1589 } else {
1590 *dst_input_array->mutable_shape() = specified_input_array.shape();
1591 }
1592 }
1593
1594 if (specified_input_array.has_data_type()) {
1595 QCHECK(!dst_input_array->has_data_type());
1596 dst_input_array->set_data_type(specified_input_array.data_type());
1597 }
1598 }
1599
1600 if (model_flags.output_arrays_size() > 0) {
1601 model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
1602 }
1603
1604 #define RESOLVE_MODEL_FLAG(name) \
1605 if (model_flags.has_##name()) { \
1606 if (model->flags.has_##name()) { \
1607 QCHECK_EQ(model_flags.name(), model->flags.name()) \
1608 << "Specified " #name " flag with value: " << model_flags.name() \
1609 << " does not agree with already defined " #name \
1610 " of this model, with value: " \
1611 << model->flags.name(); \
1612 } else { \
1613 model->flags.set_##name(model_flags.name()); \
1614 } \
1615 }
1616
1617 RESOLVE_MODEL_FLAG(variable_batch)
1618
1619 #undef RESOLVE_MODEL_FLAG
1620
1621 if (!model_flags.rnn_states().empty()) {
1622 model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
1623 }
1624
1625 if (model->flags.model_checks_size() == 0) {
1626 model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
1627 }
1628
1629 QCHECK_GT(model->flags.output_arrays_size(), 0)
1630 << "This model does not define output arrays, so a "
1631 "--output_arrays flag must be given on the command-line.";
1632
1633 for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
1634 auto& input_array = model->GetOrCreateArray(input_array_proto.name());
1635 if (input_array_proto.has_data_type()) {
1636 const ArrayDataType specified_type =
1637 ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
1638 QCHECK(specified_type != ArrayDataType::kNone);
1639 if (input_array.data_type != ArrayDataType::kNone) {
1640 QCHECK(specified_type == input_array.data_type)
1641 << "For input array " << input_array_proto.name()
1642 << " the specified input data type "
1643 << IODataType_Name(input_array_proto.data_type())
1644 << " conflicts with the existing type.";
1645 }
1646 input_array.data_type = specified_type;
1647 }
1648
1649 if (input_array.data_type == ArrayDataType::kNone) {
1650 // We start out with a float input array;
1651 // that may get replaced by a uint8 array later, by
1652 // MakeInitialDequantizeOp.
1653 input_array.data_type = ArrayDataType::kFloat;
1654 }
1655
1656 // Compare/merge the model->flags describing the input_shape with
1657 // the actual input array's shape.
1658 if (!input_array.has_shape()) {
1659 if (input_array_proto.has_shape()) {
1660 auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
1661 CheckValidShapeDimensions(input_array_proto.shape().dims());
1662 for (const auto& dim : input_array_proto.shape().dims()) {
1663 input_array_dims.push_back(dim);
1664 }
1665 }
1666 } else {
1667 if (input_array_proto.has_shape()) {
1668 // If an input shape was specified on the flags ensure that it matches
1669 // the actual shape in the model.
1670 const auto& input_array_dims =
1671 *input_array.mutable_shape()->mutable_dims();
1672 CHECK_EQ(input_array_dims.size(),
1673 input_array_proto.shape().dims_size());
1674 for (int i = 0; i < input_array_dims.size(); i++) {
1675 CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
1676 }
1677 } else {
1678 for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
1679 input_array_proto.mutable_shape()->add_dims(
1680 input_array.shape().dims(i));
1681 }
1682 }
1683 }
1684
1685 const float mean_value = input_array_proto.mean_value();
1686 const float std_value = input_array_proto.std_value();
1687 MinMax input_minmax;
1688 float qmin = 0, qmax = 255;
1689 if (input_array.data_type == ArrayDataType::kInt16) {
1690 qmin = -32768;
1691 qmax = 32767;
1692 }
1693 input_minmax.min = (qmin - mean_value) / std_value;
1694 input_minmax.max = (qmax - mean_value) / std_value;
1695 if (!input_array.minmax) {
1696 input_array.GetOrCreateMinMax() = input_minmax;
1697 }
1698 }
1699
1700 // Creation of the RNN state arrays
1701 for (const auto& rnn_state : model->flags.rnn_states()) {
1702 CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
1703 rnn_state.num_dims(), model);
1704 }
1705
1706 model->flags.set_change_concat_input_ranges(
1707 model_flags.change_concat_input_ranges());
1708 model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
1709 model->flags.set_allow_nonexistent_arrays(
1710 model_flags.allow_nonexistent_arrays());
1711
1712 CHECK(!model->flags.has_arrays_extra_info());
1713 *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
1714 }
1715
CheckIsReadyForQuantization(const Model & model)1716 void CheckIsReadyForQuantization(const Model& model) {
1717 for (const auto& op : model.operators) {
1718 for (const auto& input : op->inputs) {
1719 const auto& input_array = model.GetArray(input);
1720 if (input_array.data_type != ArrayDataType::kFloat) {
1721 // The array is not floats, no quantization needed.
1722 continue;
1723 }
1724 if (input_array.minmax) {
1725 // The array has minmax, we're good.
1726 continue;
1727 }
1728 if (input_array.buffer) {
1729 // The array has a constant buffer, so we can
1730 // fall back to computing the minmax from actual array entries
1731 // (with a WARNING about possible accuracy implications).
1732 continue;
1733 }
1734 LOG(FATAL)
1735 << "Array " << input << ", which is an input to the "
1736 << HelpfulOperatorTypeName(*op) << " operator producing the output "
1737 << "array " << op->outputs[0] << ", is lacking min/max data, "
1738 << "which is necessary for quantization. If accuracy matters, either "
1739 << "target a non-quantized output format, or run quantized training "
1740 << "with your model from a floating point checkpoint to change the "
1741 << "input graph to contain min/max information. If you don't care "
1742 << "about accuracy, you can pass --default_ranges_min= and "
1743 << "--default_ranges_max= for easy experimentation.";
1744 }
1745 }
1746 }
1747
ElementSize(ArrayDataType data_type)1748 int ElementSize(ArrayDataType data_type) {
1749 switch (data_type) {
1750 case ArrayDataType::kBool:
1751 return sizeof(bool);
1752 case ArrayDataType::kFloat:
1753 return 4;
1754 case ArrayDataType::kInt8:
1755 return 1;
1756 case ArrayDataType::kUint8:
1757 return 1;
1758 case ArrayDataType::kInt16:
1759 return 2;
1760 case ArrayDataType::kUint16:
1761 return 2;
1762 case ArrayDataType::kInt32:
1763 return 4;
1764 case ArrayDataType::kUint32:
1765 return 4;
1766 case ArrayDataType::kInt64:
1767 return 8;
1768 case ArrayDataType::kUint64:
1769 return 8;
1770 case ArrayDataType::kComplex64:
1771 return 8;
1772 case ArrayDataType::kComplex128:
1773 return 16;
1774 case ArrayDataType::kFloat64:
1775 return 8;
1776
1777 // Usually not critical limitation because strings are only input and/or
1778 // output.
1779 case ArrayDataType::kString:
1780 LOG(FATAL) << "Transient arrays with strings are not supported yet";
1781 return 0;
1782 default:
1783 LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
1784 return 0;
1785 }
1786 }
1787
DropMinMax(Model * model,const std::string & array_name)1788 void DropMinMax(Model* model, const std::string& array_name) {
1789 auto& array = model->GetArray(array_name);
1790 if (!!array.minmax) {
1791 LOG(WARNING) << "Dropping MinMax information in array " << array_name
1792 << ". Expect inaccuracy in quantized inference.";
1793 array.minmax = nullptr;
1794 }
1795 }
1796
IsAllocatableTransientArray(const Model & model,const std::string & array_name)1797 bool IsAllocatableTransientArray(const Model& model,
1798 const std::string& array_name) {
1799 // Optional array is not transient
1800 if (model.IsOptionalArray(array_name)) return false;
1801 // The model's input and output arrays are externally allocated.
1802 // They are not transient arrays.
1803 if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
1804 return false;
1805 }
1806 const auto& array = &model.GetArray(array_name);
1807 // An array with a constant buffer isn't a transient array.
1808 if (!!array->buffer) {
1809 return false;
1810 }
1811 // An array without shape isn't allocatable.
1812 if (!array->has_shape()) {
1813 return false;
1814 }
1815
1816 // The size of string tensors is rarely known ahead of time, so all transient
1817 // tensors of this type will need to be dynamically allocated.
1818 if (array->final_data_type == ArrayDataType::kString ||
1819 array->data_type == ArrayDataType::kString) {
1820 return false;
1821 }
1822
1823 return true;
1824 }
1825
// Returns a sanitized array name derived from |name| that is not already
// present in |model|, trying "_0".."_999" suffixes if the sanitized name
// is taken. LOG(FATAL)s if all suffixes are exhausted.
std::string AvailableArrayName(const Model& model, const std::string& name) {
  std::string sanitized_name = SanitizeNameForTFNode(name);
  if (!model.HasArray(sanitized_name) &&
      !model.IsOptionalArray(sanitized_name)) {
    return sanitized_name;
  }
  const int kNumSuffixesToTry = 1000;
  for (int i = 0; i < kNumSuffixesToTry; i++) {
    // Hold the candidate by value; the previous code bound a const
    // reference to the temporary returned by StringF (legal via lifetime
    // extension, but misleading).
    const std::string name_with_suffix =
        toco::port::StringF("%s_%d", sanitized_name, i);
    if (!model.HasArray(name_with_suffix) &&
        !model.IsOptionalArray(name_with_suffix)) {
      return name_with_suffix;
    }
  }
  LOG(FATAL) << "Could not find an available array name starting with "
             << sanitized_name << ". Tried " << kNumSuffixesToTry
             << " suffixes, all were taken!";
  return "";
}
1846
ShapeToString(const Shape & shape)1847 std::string ShapeToString(const Shape& shape) {
1848 if (shape.dimensions_count() == 0) {
1849 return "[]";
1850 }
1851
1852 return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
1853 }
1854
PrintArrayShape(Model * model,const std::string & name)1855 void PrintArrayShape(Model* model, const std::string& name) {
1856 if (!model->GetArray(name).has_shape()) {
1857 LOG(INFO) << name << " has no shape";
1858 return;
1859 }
1860 LOG(INFO) << name
1861 << " has shape: " << ShapeToString(model->GetArray(name).shape());
1862 }
1863
IsArrayFullyConnectedWeights(const Model & model,const std::string & name)1864 bool IsArrayFullyConnectedWeights(const Model& model, const std::string& name) {
1865 bool is_fc_weights = false;
1866 bool is_something_else = false;
1867 for (const auto& op : model.operators) {
1868 for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
1869 if (op->inputs[input_index] == name) {
1870 if (op->type == OperatorType::kFullyConnected && input_index == 1) {
1871 is_fc_weights = true;
1872 } else {
1873 is_something_else = true;
1874 }
1875 }
1876 }
1877 }
1878 CHECK(!(is_fc_weights && is_something_else));
1879 return is_fc_weights;
1880 }
1881
CreateInt32Array(Model * model,const std::string & param_name,const std::vector<int> & value)1882 std::string CreateInt32Array(Model* model, const std::string& param_name,
1883 const std::vector<int>& value) {
1884 auto param_array_name = AvailableArrayName(*model, param_name);
1885 auto& param_array = model->GetOrCreateArray(param_array_name);
1886 param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
1887 param_array.data_type = ArrayDataType::kInt32;
1888 auto& param_array_data =
1889 param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
1890 param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
1891 for (int i = 0; i < value.size(); ++i) {
1892 param_array_data[i] = value[i];
1893 }
1894 return param_array_name;
1895 }
1896
// Estimates the number of arithmetic ops performed by |op| (multiply-adds
// count as 2 ops), writing the estimate into *result. Returns false when
// the estimate cannot be computed because a required shape is unknown.
// Operator types not handled here are treated as costing 0 ops.
bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64* result) {
  switch (op.type) {
    case OperatorType::kFullyConnected:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!output_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      // One "column" per output position over all but the innermost
      // output dimension; each column costs one multiply-add per weight.
      int64 cols = 1;
      for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
        cols *= output_array.shape().dims(i);
      }
      const int64 cost_per_col =
          2 * RequiredBufferSizeForShape(weights_array.shape());
      *result = cost_per_col * cols;
      if (op.inputs.size() > 2) {
        // There is a bias vector. One more op per output value.
        *result += RequiredBufferSizeForShape(output_array.shape());
      }
      break;
    }
    case OperatorType::kTransposeConv: {
      const auto& input_array = model.GetArray(op.inputs[2]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!input_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      const Shape& input = input_array.shape();
      const Shape& weights = weights_array.shape();
      // Compute op count from the seven nested loops of
      // tflite::reference_ops::TransposeConv():
      *result = 2 * input.dims(0) * input.dims(1) * input.dims(2) *
                input.dims(3) * weights.dims(1) * weights.dims(2) *
                weights.dims(0);
      // Note that tflite::optimized_ops::TransposeConv() uses an im2col matrix
      // and has a higher op count, by a factor of (output_height*output_width)
      // vs. (input_height*input_width). Yet it generally performs better
      // because of coherent memory access. (At least for 2x2 striding. But not
      // likely for all cases.)
      break;
    }
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul: {
      // Elementwise binary ops: one op per output value.
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kAddN: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // AddN cost is roughly the same cost as N-1 Adds.
      const int64 num_adds = op.inputs.size() - 1;
      *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kLogistic:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kTanh: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // As a very rough ballpark, the cost of evaluating a math function
      // such as tanh or logistic is about 32 multiplications, and about as
      // many additions/subtractions. (Just a power-of-two order-of-magnitude
      // from looking at actual implementations that we use in runtime/ code).
      *result = 64 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kMaxPool: {
      const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One comparison per window element per output value.
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                maxpool.kheight * maxpool.kwidth;
      break;
    }
    case OperatorType::kAveragePool: {
      const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One addition per window element per output value.
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                avgpool.kheight * avgpool.kwidth;
      break;
    }
    case OperatorType::kL2Pool: {
      // NOTE(review): this casts an L2Pool op to MaxPoolOperator; it only
      // reads kheight/kwidth, presumably laid out identically in
      // L2PoolOperator -- confirm against the operator definitions.
      const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // The sum of squares requires (kheight*kwidth) multiply-adds,
      // and then there is the sqrt which we ballpark at 32 ops.
      const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
      *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
      break;
    }
    case OperatorType::kL2Normalization: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // Computing the squared L2 norm is N multiply-adds so 2N ops,
      // then the single inverse-sqrt is negligible, then we multiply each
      // value by the resulting multiplier, so an extra N ops. count 3N ops.
      *result = 3 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    default:
      // Unhandled operator types are assumed to be free.
      *result = 0;
      break;
  }
  return true;
}
2025
EstimateArithmeticOpsCount(const Model & model,int64 * result)2026 bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
2027 int64 total = 0;
2028 for (const auto& op : model.operators) {
2029 int64 num_ops;
2030 if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
2031 return false;
2032 }
2033 total += num_ops;
2034 }
2035 *result = total;
2036 return true;
2037 }
2038
FormattedNumber(int64 x)2039 std::string FormattedNumber(int64 x) {
2040 const int64 million = 1000000;
2041 const int64 billion = 1000000000;
2042 if (x < 10000) {
2043 return toco::port::StringF("%d ", x);
2044 } else if (x < billion) {
2045 return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
2046 } else {
2047 return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
2048 }
2049 }
2050
// Computes the axis permutation |shuffle| such that output axis i takes
// its size from input axis shuffle[i]. The result has 4 entries, except
// for 2-axis orders where it is resized to 2. LOG(FATAL)s on unsupported
// order pairs.
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
                     std::vector<int>* shuffle) {
  CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
  // Start from the identity permutation over 4 axes; branches below
  // overwrite it as needed.
  shuffle->resize(4);
  for (int i = 0; i < 4; i++) {
    (*shuffle)[i] = i;
  }
  if (input_axes_order == output_axes_order) {
    // nothing to do
  } else if (AxesCount(input_axes_order) == 2) {
    // Any reorder between distinct 2-axis orders is a plain transpose.
    shuffle->resize(2);
    (*shuffle)[0] = 1;
    (*shuffle)[1] = 0;
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWIO) {
    // 3210 <- 3210
    // HWIO <- OHWI
    *shuffle = {1, 2, 3, 0};
  } else if (input_axes_order == AxesOrder::kHWIO &&
             output_axes_order == AxesOrder::kOHWI) {
    // 3210 <- 3210
    // OHWI <- HWIO
    *shuffle = {3, 0, 1, 2};
  } else if (input_axes_order == AxesOrder::kOHWI &&
             output_axes_order == AxesOrder::kHWOI) {
    *shuffle = {1, 2, 0, 3};
  } else {
    LOG(FATAL) << "Bad shuffle";
  }
}
2081
ExtendShuffle(const std::vector<int> & input_shuffle,int newdim,std::vector<int> * extended_shuffle)2082 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
2083 std::vector<int>* extended_shuffle) {
2084 *extended_shuffle = input_shuffle;
2085 CHECK(newdim >= input_shuffle.size());
2086 const int pad_size = newdim - input_shuffle.size();
2087 extended_shuffle->resize(newdim);
2088 for (int i = 0; i < pad_size; i++) {
2089 (*extended_shuffle)[i] = i;
2090 }
2091 for (int i = pad_size; i < newdim; i++) {
2092 (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
2093 }
2094 }
2095
ShuffleDims(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,Shape * output_shape)2096 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
2097 AxesOrder output_axes_order, Shape* output_shape) {
2098 if (input_axes_order == AxesOrder::kHWIM &&
2099 output_axes_order == AxesOrder::k1HWO) {
2100 // This special case isn't just a permutation, the IM pair of dims get
2101 // merged into the 3 dim, so we have to special-case it.
2102 *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
2103 input_shape.dims(3) * input_shape.dims(2)});
2104 } else {
2105 std::vector<int> shuffle;
2106 GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2107 std::vector<int>* output_dims = output_shape->mutable_dims();
2108 output_dims->resize(input_shape.dimensions_count());
2109 for (int i = 0; i < input_shape.dimensions_count(); i++) {
2110 (*output_dims)[i] = input_shape.dims(shuffle[i]);
2111 }
2112 }
2113 }
2114
// Permutes the elements of an array (rank <= 4) from input_axes_order to
// output_axes_order into output_data. The shapes must already agree with
// the permutation (see GetShuffleShape / ShuffleDims).
template <typename T>
void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
                          AxesOrder output_axes_order,
                          const Shape& output_shape, const T* input_data,
                          T* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the O dim, so we have to special-case it. Fortunately,
    // as far as array shuffling is concerned, it's just the identity
    // transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  // Pad shapes and the permutation to exactly 4 dims so the fixed 4-deep
  // loop nest below handles any rank <= 4.
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major input strides, then remapped through the permutation so that
  // input_stride_k advances along the input axis feeding output axis k
  // (k = 0 is the innermost output axis).
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  // Output sizes/strides listed innermost-first; the output is written
  // contiguously in row-major order.
  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  for (int i3 = 0; i3 < output_size_3; i3++) {
    const T* const input_ptr_3 = input_data + i3 * input_stride_3;
    T* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
        T* output_ptr = output_ptr_2 + i1 * output_stride_1;
        T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
        // Innermost copy: walk the output contiguously while striding
        // through the input.
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}
2189
// Non-template uint8 entry point; forwards to ShuffleArrayTemplate<uint8>.
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const uint8* input_data, uint8* output_data) {
  ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
                              output_shape, input_data, output_data);
}
2196
// Non-template float entry point; forwards to ShuffleArrayTemplate<float>.
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const float* input_data, float* output_data) {
  ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
                              output_shape, input_data, output_data);
}
2203
AxesCount(AxesOrder axes_order)2204 int AxesCount(AxesOrder axes_order) {
2205 switch (axes_order) {
2206 case AxesOrder::kOneAxis:
2207 return 1;
2208 case AxesOrder::kRC:
2209 return 2;
2210 case AxesOrder::kCR:
2211 return 2;
2212 case AxesOrder::kHWIO:
2213 return 4;
2214 case AxesOrder::kOHWI:
2215 return 4;
2216 case AxesOrder::kHWIM:
2217 return 4;
2218 case AxesOrder::k1HWO:
2219 return 4;
2220 case AxesOrder::kNHWC:
2221 return 4;
2222 case AxesOrder::kHWOI:
2223 return 4;
2224 default:
2225 LOG(FATAL) << "Bad AxesOrder";
2226 return 0;
2227 }
2228 }
2229
IsDiscardableArray(const Model & model,const std::string & array_name)2230 bool IsDiscardableArray(const Model& model, const std::string& array_name) {
2231 if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
2232 return false;
2233 }
2234 for (const auto& rnn_state : model.flags.rnn_states()) {
2235 if (!rnn_state.discardable()) {
2236 if (array_name == rnn_state.state_array()) {
2237 return false;
2238 }
2239 if (array_name == rnn_state.back_edge_source_array()) {
2240 return false;
2241 }
2242 }
2243 }
2244 return true;
2245 }
2246
ReshapeIsEquivalentToTranspose(const Model & model,const TensorFlowReshapeOperator * op,bool allow_extra_unary_dims)2247 bool ReshapeIsEquivalentToTranspose(const Model& model,
2248 const TensorFlowReshapeOperator* op,
2249 bool allow_extra_unary_dims) {
2250 CHECK(!op->shape.empty());
2251 CHECK(model.HasArray(op->inputs[0]));
2252 CHECK(model.HasArray(op->outputs[0]));
2253
2254 const auto& input_array = model.GetArray(op->inputs[0]);
2255 const auto& output_array = model.GetArray(op->outputs[0]);
2256
2257 CHECK(input_array.has_shape());
2258 CHECK(output_array.has_shape());
2259
2260 std::vector<int> in_shape = input_array.shape().dims();
2261 std::vector<int> out_shape = output_array.shape().dims();
2262
2263 // If the reshape changes the number of dimensions so it cannot be interpreted
2264 // as a transpose.
2265 if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
2266 return false;
2267 }
2268
2269 in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
2270 in_shape.end());
2271 out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
2272 out_shape.end());
2273 return in_shape == out_shape;
2274 }
2275
CheckFinalDataTypesSatisfied(const Model & model)2276 void CheckFinalDataTypesSatisfied(const Model& model) {
2277 for (const auto& array_entry : model.GetArrayMap()) {
2278 const auto& array = *array_entry.second;
2279 if (array.data_type == ArrayDataType::kBool) {
2280 // Boolean values are never quantized.
2281 continue;
2282 }
2283
2284 // If the final data type is int16, the data type may be float, for example
2285 // after dequantization.
2286 if (array.final_data_type != ArrayDataType::kNone &&
2287 array.final_data_type != ArrayDataType::kInt16) {
2288 CHECK(array.data_type == array.final_data_type)
2289 << "Array \"" << array_entry.first
2290 << "\" has mis-matching actual and final data types (data_type="
2291 << ArrayDataTypeName(array.data_type)
2292 << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
2293 << ").";
2294 }
2295 }
2296 }
2297
// Maps a toco IODataType (model_flags proto enum) to the corresponding
// internal ArrayDataType. RESOURCE, VARIANT, and any unrecognized values
// map to kNone.
ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
  switch (type) {
    case FLOAT:
      return ArrayDataType::kFloat;
    case QUANTIZED_UINT8:
      return ArrayDataType::kUint8;
    case INT8:
      return ArrayDataType::kInt8;
    case QUANTIZED_INT16:
      return ArrayDataType::kInt16;
    case INT32:
      return ArrayDataType::kInt32;
    case UINT32:
      return ArrayDataType::kUint32;
    case INT64:
      return ArrayDataType::kInt64;
    case UINT64:
      return ArrayDataType::kUint64;
    case BOOL:
      return ArrayDataType::kBool;
    case STRING:
      return ArrayDataType::kString;
    case COMPLEX64:
      return ArrayDataType::kComplex64;
    case COMPLEX128:
      return ArrayDataType::kComplex128;
    case FLOAT16:
      return ArrayDataType::kFloat16;
    case FLOAT64:
      return ArrayDataType::kFloat64;
    case RESOURCE:
    case VARIANT:
    default:
      return ArrayDataType::kNone;
  }
}
2334
// Finalizes RNN state arrays: when neither end of a back-edge has a data
// type yet, default the state array to float.
void FinishBuildingRNNStates(Model* model) {
  for (const auto& rnn_state : model->flags.rnn_states()) {
    if (!model->HasArray(rnn_state.back_edge_source_array()) ||
        !model->HasArray(rnn_state.state_array())) {
      // NOTE(review): these CHECKs abort on whichever array is missing,
      // which makes the `continue` effectively unreachable; they exist to
      // report which of the two arrays was absent -- confirm intent.
      CHECK(model->HasArray(rnn_state.back_edge_source_array()));
      CHECK(model->HasArray(rnn_state.state_array()));
      continue;
    }
    const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
    auto& dst_array = model->GetArray(rnn_state.state_array());
    if (src_array.data_type == ArrayDataType::kNone &&
        dst_array.data_type == ArrayDataType::kNone) {
      dst_array.data_type = ArrayDataType::kFloat;
    }
  }
}
2351
2352 // Returns the array names that match the ArraysExtraInfo's name and
2353 // name_regexp. The regexp match is for a full match.
ScanArrayNames(const Model & model,const toco::ArraysExtraInfo_Entry & entry)2354 std::unordered_set<std::string> ScanArrayNames(
2355 const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
2356 std::unordered_set<std::string> matches;
2357 if (model.HasArray(entry.name())) {
2358 matches.insert(entry.name());
2359 }
2360 if (!entry.name_regexp().empty()) {
2361 const auto& arrays = model.GetArrayMap();
2362 const RE2 name_regexp = {entry.name_regexp()};
2363 for (auto it = arrays.begin(); it != arrays.end(); ++it) {
2364 if (RE2::FullMatch(it->first, name_regexp)) {
2365 matches.insert(it->first);
2366 }
2367 }
2368 }
2369 return matches;
2370 }
2371
// Applies each arrays_extra_info entry to every matching array in the
// model: overriding minmax, final data type (only when |quantize_output|),
// shape, and constant float contents as specified by the entry.
void UseArraysExtraInfo(Model* model, bool quantize_output) {
  for (const auto& entry : model->flags.arrays_extra_info().entries()) {
    const auto matches = ScanArrayNames(*model, entry);
    if (matches.empty()) {
      // Non-fatal: an entry that matches nothing is reported and skipped.
      LOG(ERROR) << "arrays_extra_info_file: No matching arrays found for "
                 << (entry.has_name() ? entry.name() : "")
                 << (entry.has_name_regexp() ? entry.name_regexp() : "");
      continue;
    }
    for (const auto& matched_name : matches) {
      auto& array = model->GetArray(matched_name);
      if (entry.has_min() || entry.has_max()) {
        // min and max must be specified together.
        CHECK_EQ(entry.has_min(), entry.has_max());
        auto& minmax = array.GetOrCreateMinMax();
        minmax.min = entry.min();
        minmax.max = entry.max();
      }
      if (entry.has_data_type() && quantize_output) {
        array.final_data_type =
            ConvertIODataTypeToArrayDataType(entry.data_type());
      }
      if (entry.has_shape()) {
        array.clear_shape();
        // Make sure to create the shape even if there are no dims, to
        // correctly record 0-D shapes.
        array.mutable_shape();
        for (const auto& dim : entry.shape().dims()) {
          array.mutable_shape()->mutable_dims()->push_back(dim);
        }
      }
      if (entry.has_constant_float_value()) {
        CHECK(array.has_shape());
        // Only float arrays get their buffer filled; other data types are
        // silently left unchanged.
        if (array.data_type == ArrayDataType::kFloat) {
          auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
          data.resize(RequiredBufferSizeForShape(array.shape()));
          for (float& f : data) {
            f = entry.constant_float_value();
          }
        }
      }
    }
  }
}
2415
// Reverts the 4x16-block shuffling (and sign-bit flip) applied to uint8
// FullyConnected weights, restoring plain row-major storage. Only FC ops
// with a non-default weights_format are affected.
void UndoWeightsShuffling(Model* model) {
  for (const auto& op : model->operators) {
    if (op->type != toco::OperatorType::kFullyConnected) {
      continue;
    }
    const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
    if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
      continue;
    }
    const std::string& weights_name = fc_op.inputs[1];
    // The weights must not be shared with any other op, since deshuffling
    // mutates them in place.
    QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
    auto& weights_array = model->GetArray(weights_name);
    QCHECK(weights_array.data_type == ArrayDataType::kUint8);
    auto& weights_data =
        weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
    const auto& weights_shape = weights_array.shape();
    QCHECK_EQ(weights_shape.dimensions_count(), 2);
    const int rows = weights_shape.dims(0);
    const int cols = weights_shape.dims(1);
    // The shuffled layout is blocked in 4-row x 16-column tiles, so both
    // dimensions must be multiples of the tile size.
    QCHECK_EQ(rows % 4, 0);
    QCHECK_EQ(cols % 16, 0);
    CHECK_EQ(rows * cols, weights_data.size());
    // Compute the de-shuffled weights
    std::vector<uint8> deshuffled_data(weights_data.size());
    uint8* shuffled_data_ptr = weights_data.data();
    // The shuffled buffer is read sequentially; each 4x16 tile is
    // scattered back to its row-major destination.
    for (int r = 0; r < rows; r += 4) {
      for (int c = 0; c < cols; c += 16) {
        for (int i = 0; i < 4; i++) {
          uint8* deshuffled_data_ptr =
              deshuffled_data.data() + (r + i) * cols + c;
          for (int j = 0; j < 16; j++) {
            uint8 shuffled_val = *shuffled_data_ptr++;
            // Deshuffling isn't only about deshuffling the storage layout,
            // it's also about undoing the flipping of the sign bit, which is
            // performed on the shuffled weights.
            uint8 deshuffled_val = shuffled_val ^ 0x80;
            *deshuffled_data_ptr++ = deshuffled_val;
          }
        }
      }
    }
    CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
    // Switch this FC op to using the deshuffled weights.
    weights_data = std::move(deshuffled_data);
  }
}
2462
CopyMinMaxAndQuantizationRelatedFields(const Array & src,Array * dst)2463 void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
2464 if (src.minmax) {
2465 dst->GetOrCreateMinMax() = src.GetMinMax();
2466 }
2467 if (src.quantization_params) {
2468 dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
2469 }
2470 dst->narrow_range = src.narrow_range;
2471 }
2472
2473 } // namespace toco
2474