1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tooling_util.h"
16
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
23
24 #include "absl/strings/ascii.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_replace.h"
28 #include "absl/strings/str_split.h"
29 #include "re2/re2.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/dump_graphviz.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
35
36 namespace toco {
37
38 // Find the longest common prefix of two strings.
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)39 absl::string_view FindLongestCommonPrefix(absl::string_view a,
40 absl::string_view b) {
41 if (a.empty() || b.empty()) return absl::string_view();
42
43 const char* pa = a.data();
44 const char* pb = b.data();
45 size_t count = 0;
46 const size_t limit = std::min(a.size(), b.size());
47 while (count < limit && *pa == *pb) {
48 ++pa;
49 ++pb;
50 ++count;
51 }
52
53 return absl::string_view(a.data(), count);
54 }
55
LogName(const Operator & op)56 string LogName(const Operator& op) {
57 const string& opname = HelpfulOperatorTypeName(op);
58 if (op.outputs.empty()) {
59 return toco::port::StringF("{%s operator}", opname);
60 } else {
61 return toco::port::StringF("{%s operator with output %s}", opname,
62 op.outputs[0]);
63 }
64 }
65
ArrayDataTypeName(ArrayDataType data_type)66 string ArrayDataTypeName(ArrayDataType data_type) {
67 switch (data_type) {
68 case ArrayDataType::kFloat:
69 return "float";
70 case ArrayDataType::kInt8:
71 return "int8";
72 case ArrayDataType::kUint8:
73 return "uint8";
74 case ArrayDataType::kInt16:
75 return "int16";
76 case ArrayDataType::kUint16:
77 return "uint16";
78 case ArrayDataType::kInt32:
79 return "int32";
80 case ArrayDataType::kUint32:
81 return "uint32";
82 case ArrayDataType::kInt64:
83 return "int64";
84 case ArrayDataType::kUint64:
85 return "uint64";
86 case ArrayDataType::kString:
87 return "string";
88 case ArrayDataType::kBool:
89 return "bool";
90 case ArrayDataType::kComplex64:
91 return "complex64";
92 case ArrayDataType::kNone:
93 return "None";
94 default:
95 LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
96 }
97 }
98
IsInputArray(const Model & model,const string & array_name)99 bool IsInputArray(const Model& model, const string& array_name) {
100 for (const auto& input_array : model.flags.input_arrays()) {
101 if (array_name == input_array.name()) {
102 return true;
103 }
104 }
105 return false;
106 }
107
IsOutputArray(const Model & model,const string & array_name)108 bool IsOutputArray(const Model& model, const string& array_name) {
109 for (const auto& output_array : model.flags.output_arrays()) {
110 if (array_name == output_array) {
111 return true;
112 }
113 }
114 return false;
115 }
116
IsArrayConsumed(const Model & model,const string & name)117 bool IsArrayConsumed(const Model& model, const string& name) {
118 if (GetOpWithInput(model, name)) {
119 return true;
120 }
121 if (IsOutputArray(model, name)) {
122 return true;
123 }
124 for (const auto& rnn_state : model.flags.rnn_states()) {
125 if (rnn_state.back_edge_source_array() == name) {
126 return true;
127 }
128 }
129 return false;
130 }
131
CountTrueOutputs(const Model & model,const Operator & op)132 int CountTrueOutputs(const Model& model, const Operator& op) {
133 int count = 0;
134 for (const string& output : op.outputs) {
135 if (IsArrayConsumed(model, output)) {
136 ++count;
137 }
138 }
139 return count;
140 }
141
CountOpsWithInput(const Model & model,const string & array_name)142 int CountOpsWithInput(const Model& model, const string& array_name) {
143 int count = 0;
144 for (const auto& op : model.operators) {
145 for (auto& input : op->inputs) {
146 if (input == array_name) {
147 count++;
148 // Breaking here is important: some graphs have ops that use the
149 // same array as more than one of their inputs, and in that case
150 // we want it counted only once.
151 break;
152 }
153 }
154 }
155 return count;
156 }
157
DeleteArrayIfUnused(const string & array_name,Model * model)158 bool DeleteArrayIfUnused(const string& array_name, Model* model) {
159 if (IsDiscardableArray(*model, array_name) &&
160 CountOpsWithInput(*model, array_name) == 0) {
161 model->EraseArray(array_name);
162 return true;
163 }
164 return false;
165 }
166
DeleteArrayIfUsedOnce(const string & array_name,Model * model)167 bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) {
168 if (IsDiscardableArray(*model, array_name) &&
169 CountOpsWithInput(*model, array_name) == 1) {
170 model->EraseArray(array_name);
171 return true;
172 }
173 return false;
174 }
175
DeleteOpAndArraysIfUnused(Model * model,const Operator * op)176 void DeleteOpAndArraysIfUnused(Model* model, const Operator* op) {
177 for (const string& array_name : op->inputs) {
178 DeleteArrayIfUsedOnce(array_name, model);
179 }
180 auto op_it = FindOp(*model, op);
181 CHECK(op_it != model->operators.end());
182 model->operators.erase(op_it);
183 }
184
FindOpWithOutput(const Model & model,const string & array_name)185 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
186 const Model& model, const string& array_name) {
187 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
188 for (auto& output : it->get()->outputs) {
189 if (output == array_name) {
190 return it;
191 }
192 }
193 }
194 return model.operators.end();
195 }
196
FindOpWithOutput(Model & model,const string & array_name)197 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
198 Model& model, const string& array_name) {
199 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
200 for (auto& output : it->get()->outputs) {
201 if (output == array_name) {
202 return it;
203 }
204 }
205 }
206 return model.operators.end();
207 }
208
GetOpWithOutput(const Model & model,const string & array_name)209 Operator* GetOpWithOutput(const Model& model, const string& array_name) {
210 auto it = FindOpWithOutput(model, array_name);
211 return it == model.operators.end() ? nullptr : it->get();
212 }
213
214 // GetFirstOpWithInput assumes that this finds the first op.
FindOpWithInput(const Model & model,const string & array_name)215 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
216 const Model& model, const string& array_name) {
217 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
218 for (auto& input : it->get()->inputs) {
219 if (input == array_name) {
220 return it;
221 }
222 }
223 }
224 return model.operators.end();
225 }
226
FindOpWithInput(Model & model,const string & array_name)227 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
228 Model& model, const string& array_name) {
229 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
230 for (auto& input : it->get()->inputs) {
231 if (input == array_name) {
232 return it;
233 }
234 }
235 }
236 return model.operators.end();
237 }
238
FindOp(const Model & model,const Operator * op)239 std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
240 const Model& model, const Operator* op) {
241 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
242 if (it->get() == op) {
243 return it;
244 }
245 }
246 return model.operators.end();
247 }
248
FindOp(Model & model,const Operator * op)249 std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
250 const Operator* op) {
251 for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
252 if (it->get() == op) {
253 return it;
254 }
255 }
256 return model.operators.end();
257 }
258
GetOpWithInput(const Model & model,const string & array_name)259 Operator* GetOpWithInput(const Model& model, const string& array_name) {
260 auto it = FindOpWithInput(model, array_name);
261 return it == model.operators.end() ? nullptr : it->get();
262 }
263
GetFirstOpWithInput(const Model & model,const string & array_name)264 Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
265 auto it = FindOpWithInput(model, array_name);
266 return it == model.operators.end() ? nullptr : it->get();
267 }
268
ReplaceArrayUsage(Model * model,const string & old_array_name,const string & new_array_name)269 void ReplaceArrayUsage(Model* model, const string& old_array_name,
270 const string& new_array_name) {
271 for (auto& op_it : model->operators) {
272 Operator* op = op_it.get();
273 for (size_t i = 0; i < op->inputs.size(); ++i) {
274 if (op->inputs[i] == old_array_name) {
275 op->inputs[i] = new_array_name;
276 }
277 }
278 for (size_t i = 0; i < op->outputs.size(); ++i) {
279 if (op->outputs[i] == old_array_name) {
280 op->outputs[i] = new_array_name;
281 }
282 }
283 }
284 }
285
FormatArraysList(const Model & model,const std::vector<string> & list)286 string FormatArraysList(const Model& model, const std::vector<string>& list) {
287 if (list.empty()) {
288 return "[]";
289 }
290 string result = "";
291 if (list.size() > 1) {
292 result += "[ ";
293 }
294 for (std::size_t i = 0; i < list.size(); i++) {
295 if (i > 0) {
296 result += ", ";
297 }
298 result += list[i];
299 }
300 if (list.size() > 1) {
301 result += " ]";
302 }
303 return result;
304 }
305
// Returns the name of `type` as a string literal: the OperatorType
// enumerator spelled without its leading "k" (e.g. OperatorType::kConv ->
// "Conv"). Types not listed in the table below are a fatal error.
const char* OperatorTypeName(OperatorType type) {
  switch (type) {
// Expands to `case OperatorType::k<c>: return "<c>";`.
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Abs)
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Elu)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(SplitV)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Ceil)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(GatherNd)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
    HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
    HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
    HANDLE_OPERATORTYPENAME_CASE(Unique)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    default:
      // Any enumerator missing from the table above lands here; when adding a
      // new OperatorType, add a matching HANDLE_OPERATORTYPENAME_CASE line.
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}
435
HelpfulOperatorTypeName(const Operator & op)436 string HelpfulOperatorTypeName(const Operator& op) {
437 if (op.type == OperatorType::kUnsupported) {
438 return toco::port::StringF(
439 "(Unsupported TensorFlow op: %s)",
440 static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
441 }
442 return OperatorTypeName(op.type);
443 }
444
OperatorSupportsFusedActivation(OperatorType type)445 bool OperatorSupportsFusedActivation(OperatorType type) {
446 switch (type) {
447 case OperatorType::kAdd:
448 case OperatorType::kAveragePool:
449 case OperatorType::kBatchNormalization:
450 case OperatorType::kConv:
451 case OperatorType::kDepthwiseConv:
452 case OperatorType::kDiv:
453 case OperatorType::kFullyConnected:
454 case OperatorType::kL2Pool:
455 case OperatorType::kMaxPool:
456 case OperatorType::kMul:
457 case OperatorType::kSub:
458 case OperatorType::kSquaredDifference:
459 return true;
460 default:
461 return false;
462 }
463 }
464
LogSummary(int log_level,const Model & model)465 void LogSummary(int log_level, const Model& model) {
466 VLOG(log_level) << "Operators summary (" << model.operators.size()
467 << " operators):";
468 std::unordered_multiset<OperatorType> ops_by_type;
469 for (const auto& op : model.operators) {
470 ops_by_type.insert(op->type);
471 }
472 auto it = ops_by_type.begin();
473 while (it != ops_by_type.end()) {
474 int count = ops_by_type.count(*it);
475 VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
476 std::advance(it, count);
477 }
478 }
479
LogArray(int log_level,const Model & model,const string & name)480 void LogArray(int log_level, const Model& model, const string& name) {
481 VLOG(log_level) << "Array: " << name;
482 if (!model.HasArray(name)) {
483 VLOG(log_level) << " DOES NOT EXIST";
484 return;
485 }
486 const auto& array = model.GetArray(name);
487 VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
488 VLOG(log_level) << " Final type: "
489 << ArrayDataTypeName(array.final_data_type);
490 if (array.buffer) {
491 VLOG(log_level) << " Constant Buffer";
492 }
493 if (array.alloc) {
494 VLOG(log_level) << " Transient Alloc";
495 }
496 if (array.has_shape()) {
497 const Shape& array_shape = array.shape();
498 if (array_shape.dimensions_count() == 0) {
499 VLOG(log_level) << " (Zero dimensions)";
500 } else {
501 string message = " Dims: ";
502 bool first = true;
503 for (const int dim : array_shape.dims()) {
504 if (!first) {
505 message += ", ";
506 }
507 first = false;
508 toco::port::AppendF(&message, "%d", dim);
509 }
510 VLOG(log_level) << message;
511 }
512 }
513 if (array.minmax) {
514 VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
515 << array.minmax->max;
516 }
517 if (array.quantization_params) {
518 VLOG(log_level) << " QuantizationParams: zero_point="
519 << static_cast<int>(array.quantization_params->zero_point)
520 << ", scale=" << array.quantization_params->scale;
521 }
522 }
523
DumpGraphvizVideoFrame(const Model & model)524 void DumpGraphvizVideoFrame(const Model& model) {
525 namespace port = toco::port;
526
527 const auto& dump_options = *GraphVizDumpOptions::singleton();
528 if (!dump_options.dump_graphviz_video) {
529 return;
530 }
531 CHECK(!dump_options.dump_graphviz.empty());
532 // TODO(benoitjacob): the static data here means that this function
533 // is stateful, not reentrant, and effectively leaks memory till exit
534 // (since dump_hashes can only grow in size). It also means that it
535 // really only is intended to be called for a single model during the
536 // process' lifetime. So it's not great design at all. The overriding
537 // design aspect here is to make the video-dumping code as unintrusive
538 // and self-contained as possible. Eventually, we'll want to have that
539 // cleaned-up, but that will require some form of general statefulness
540 // in toco (some kind of 'tooling state' data structure) that does
541 // not exist at present, and would be premature to design here just for
542 // this new video-dumping feature.
543 static int dump_id = 0;
544 static std::unordered_set<std::size_t> dump_hashes;
545 string graphviz_dump;
546 DumpGraphviz(model, &graphviz_dump,
547 toco::port::StringF("VIDEO frame:%05d", dump_id));
548 std::size_t hash = std::hash<string>{}(graphviz_dump);
549 if (!dump_hashes.count(hash)) {
550 LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
551 dump_hashes.insert(hash);
552 const auto result = port::file::SetContents(
553 port::file::JoinPath(
554 dump_options.dump_graphviz,
555 toco::port::StringF("toco_video_%05d.dot", dump_id)),
556 graphviz_dump, port::file::Defaults());
557 QCHECK(result.ok()) << result.error_message();
558 dump_id++;
559 }
560 }
561
LogDump(int log_level,const string & message,const Model & model)562 void LogDump(int log_level, const string& message, const Model& model) {
563 namespace port = toco::port;
564 const auto& dump_options = *GraphVizDumpOptions::singleton();
565
566 DumpGraphvizVideoFrame(model);
567 if (!dump_options.dump_graphviz.empty()) {
568 string graphviz_dump;
569
570 DumpGraphviz(model, &graphviz_dump, message);
571 const auto result = port::file::SetContents(
572 port::file::JoinPath(
573 dump_options.dump_graphviz,
574 absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
575 ".dot")),
576 graphviz_dump, port::file::Defaults());
577 QCHECK(result.ok()) << result.error_message();
578 }
579
580 if (!VLOG_IS_ON(log_level)) {
581 return;
582 }
583 VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
584 LogSummary(log_level, model);
585 std::unordered_set<string> already_printed_arrays;
586 for (const auto& op : model.operators) {
587 for (const auto& input : op->inputs) {
588 if (!already_printed_arrays.count(input)) {
589 already_printed_arrays.insert(input);
590 LogArray(log_level, model, input);
591 }
592 }
593 VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
594 VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
595 << FormatArraysList(model, op->outputs);
596 if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
597 VLOG(log_level) << " (with fused activation function)";
598 }
599 for (const auto& output : op->outputs) {
600 if (!already_printed_arrays.count(output)) {
601 already_printed_arrays.insert(output);
602 LogArray(log_level, model, output);
603 }
604 }
605 }
606 VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
607 }
608
609 // Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
ExtendShape(Shape * shape,int new_shape_size)610 void ExtendShape(Shape* shape, int new_shape_size) {
611 CHECK_GE(new_shape_size, shape->dimensions_count());
612 const int size_increase = new_shape_size - shape->dimensions_count();
613 auto* shape_dims = shape->mutable_dims();
614 shape_dims->insert(shape_dims->begin(), size_increase, 1);
615 }
616
617 // TODO(b/62904716) Remove along with remaining uses.
UnextendShape(Shape * shape,int new_shape_size)618 void UnextendShape(Shape* shape, int new_shape_size) {
619 CHECK_LE(new_shape_size, shape->dimensions_count());
620 const int size_reduction = shape->dimensions_count() - new_shape_size;
621 for (int i = 0; i < size_reduction; i++) {
622 CHECK_EQ(shape->dims(i), 1);
623 }
624 std::vector<int>& shape_dims = *shape->mutable_dims();
625 shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
626 }
627
// CHECK-fails unless every entry of `dims` is at least 1. The one allowed
// exception is the single-element list [0]: when tensor data itself encodes a
// scalar (rank 0) shape, that shape is written as [0]. Use
// CheckNonEmptyShapeDimensions instead where an empty shape makes no sense.
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  const bool is_scalar_shape_encoding = dims.size() == 1 && dims[0] == 0;
  if (!is_scalar_shape_encoding) {
    for (const auto& dim : dims) {
      CHECK_GE(dim, 1);
    }
  }
}
642
CheckValidShape(const Shape & shape)643 void CheckValidShape(const Shape& shape) {
644 CheckValidShapeDimensions(shape.dims());
645 }
646
IsNonEmpty(const Shape & shape)647 bool IsNonEmpty(const Shape& shape) {
648 for (int i = 0; i < shape.dimensions_count(); ++i) {
649 if (shape.dims(i) < 1) return false;
650 }
651 return true;
652 }
653
CheckNonEmptyShapeDimensions(const Shape & shape)654 void CheckNonEmptyShapeDimensions(const Shape& shape) {
655 for (int i = 0; i < shape.dimensions_count(); ++i) {
656 CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
657 << ". shape = " << ShapeToString(shape);
658 }
659 }
660
ShapesAgreeUpToBroadcasting(const Shape & shape0,const Shape & shape1)661 bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
662 CheckNonEmptyShapeDimensions(shape0);
663 CheckNonEmptyShapeDimensions(shape1);
664
665 const Shape* longer = &shape0;
666 const Shape* shorter = &shape1;
667 if (shape1.dimensions_count() > shape0.dimensions_count()) {
668 longer = &shape1;
669 shorter = &shape0;
670 }
671
672 // Walk dimensions back to front until we run out of dimensions in the shorter
673 // shape.
674 int longer_index = longer->dimensions_count() - 1;
675 int shorter_index = shorter->dimensions_count() - 1;
676 while (shorter_index >= 0) {
677 const int d_long = longer->dims(longer_index);
678 const int d_short = shorter->dims(shorter_index);
679 // Broadcasting fails if the dimensions are different *and* neither is 1.
680 if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
681 return false;
682 }
683 longer_index--;
684 shorter_index--;
685 }
686 return true;
687 }
688
ShapesAgreeUpToExtending(const Shape & shape0,const Shape & shape1)689 bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
690 CheckNonEmptyShapeDimensions(shape0);
691 CheckNonEmptyShapeDimensions(shape1);
692
693 const Shape* longer = &shape0;
694 const Shape* shorter = &shape1;
695 if (shape1.dimensions_count() > shape0.dimensions_count()) {
696 longer = &shape1;
697 shorter = &shape0;
698 }
699
700 // Walk dimensions back to front until we run out of dimensions in the shorter
701 // shape.
702 int longer_index = longer->dimensions_count() - 1;
703 int shorter_index = shorter->dimensions_count() - 1;
704 while (shorter_index >= 0) {
705 const int d_long = longer->dims(longer_index);
706 const int d_short = shorter->dims(shorter_index);
707 // Extending fails if the dimensions are different.
708 if (d_long != d_short) {
709 return false;
710 }
711 longer_index--;
712 shorter_index--;
713 }
714
715 // The remaining dimensions in the longer shape must be 1.
716 while (longer_index >= 0) {
717 const int d_long = longer->dims(longer_index);
718 if (d_long != 1) {
719 return false;
720 }
721 longer_index--;
722 }
723
724 return true;
725 }
726
RequiredBufferSizeForShape(const Shape & shape)727 int RequiredBufferSizeForShape(const Shape& shape) {
728 CheckValidShape(shape);
729 int max_offset = 1;
730 for (const auto& dim : shape.dims()) {
731 max_offset *= dim;
732 }
733 return max_offset;
734 }
735
IsConstantParameterArray(const Model & model,const string & name)736 bool IsConstantParameterArray(const Model& model, const string& name) {
737 if (!model.HasArray(name)) {
738 return false;
739 }
740
741 return !!model.GetArray(name).buffer;
742 }
743
744 namespace {
745 template <ArrayDataType A>
CompareArrayBuffers(const Array & lhs_array,const Array & rhs_array)746 bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
747 CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
748 CHECK(lhs_array.buffer) << "LHS must be constant";
749 CHECK(rhs_array.buffer) << "RHS must be constant";
750 const auto& lhs_data = lhs_array.GetBuffer<A>().data;
751 const auto& rhs_data = rhs_array.GetBuffer<A>().data;
752 CHECK_EQ(lhs_data.size(), rhs_data.size())
753 << "Buffer sizes must match in element count";
754 for (int i = 0; i < lhs_data.size(); ++i) {
755 if (lhs_data[i] != rhs_data[i]) {
756 return false;
757 }
758 }
759 return true;
760 }
761
HaveSameMinMax(const Array & lhs_array,const Array & rhs_array)762 bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
763 if (lhs_array.minmax || rhs_array.minmax) {
764 if (!lhs_array.minmax || !rhs_array.minmax) {
765 return false;
766 }
767 if (!(*lhs_array.minmax == *rhs_array.minmax)) {
768 return false;
769 }
770 }
771 return true;
772 }
773
HaveSameQuantizationParams(const Array & lhs_array,const Array & rhs_array)774 bool HaveSameQuantizationParams(const Array& lhs_array,
775 const Array& rhs_array) {
776 if (lhs_array.quantization_params || rhs_array.quantization_params) {
777 if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
778 return false;
779 }
780 if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
781 return false;
782 }
783 }
784 return true;
785 }
786
787 } // namespace
788
CompareConstantArrays(const Array & lhs_array,const Array & rhs_array)789 bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
790 bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
791 lhs_array.data_type == rhs_array.data_type &&
792 lhs_array.final_data_type == rhs_array.final_data_type &&
793 HaveSameMinMax(lhs_array, rhs_array) &&
794 HaveSameQuantizationParams(lhs_array, rhs_array) &&
795 lhs_array.narrow_range == rhs_array.narrow_range;
796 if (!attrs_equal) {
797 return false;
798 }
799 switch (lhs_array.data_type) {
800 case ArrayDataType::kBool:
801 return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
802 case ArrayDataType::kFloat:
803 return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
804 case ArrayDataType::kInt8:
805 return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
806 case ArrayDataType::kUint8:
807 return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
808 case ArrayDataType::kInt16:
809 return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
810 case ArrayDataType::kUint16:
811 return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
812 case ArrayDataType::kInt32:
813 return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
814 case ArrayDataType::kUint32:
815 return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
816 case ArrayDataType::kInt64:
817 return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
818 case ArrayDataType::kUint64:
819 return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
820 case ArrayDataType::kString:
821 return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
822 case ArrayDataType::kComplex64:
823 return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
824 rhs_array);
825 default:
826 LOG(FATAL) << "Unsupported data type: "
827 << ArrayDataTypeName(lhs_array.data_type);
828 return false;
829 }
830 }
831
832 namespace {
// Turns an array name such as "name:3_5" into one acceptable as a TensorFlow
// node name ("name_3_5") by replacing every ':' with '_'.
string SanitizeNameForTFNode(const string& array_name) {
  string node_name;
  node_name.reserve(array_name.size());
  for (const char ch : array_name) {
    node_name.push_back(ch == ':' ? '_' : ch);
  }
  return node_name;
}
840
CheckInputArraysAreNotOutputArrays(const ModelFlags & model_flags)841 void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
842 for (const auto& input_array : model_flags.input_arrays()) {
843 for (const string& output_array : model_flags.output_arrays()) {
844 QCHECK_NE(input_array.name(), output_array)
845 << "The array " << output_array
846 << " is listed in both --input_arrays and --output_arrays.";
847 }
848 }
849 }
850
IsAsciiPrintable(const string & name)851 bool IsAsciiPrintable(const string& name) {
852 for (char c : name) {
853 if (!absl::ascii_isprint(c)) {
854 return false;
855 }
856 }
857 return true;
858 }
859
// Returns a two-column table listing each byte of `name`, for diagnosing
// non-ASCII array names. Printable bytes show the character and its hex
// code; other bytes show only the hex code plus a warning.
string DumpAscii(const string& name) {
  string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    // Go through unsigned char so that bytes >= 0x80 print as two hex digits
    // (e.g. "e9") rather than a sign-extended "ffffffe9" where char is
    // signed. This also passes %x the unsigned int it expects.
    const unsigned int byte = static_cast<unsigned char>(c);
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c | %x\n", c, byte);
    } else {
      port::AppendF(&result, " | %x Not ASCII printable!\n", byte);
    }
  }
  return result;
}
873
CheckNonAsciiIOArrays(const ModelFlags & model_flags)874 void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
875 if (model_flags.allow_nonascii_arrays()) {
876 return;
877 }
878 for (const auto& input_array : model_flags.input_arrays()) {
879 QCHECK(IsAsciiPrintable(input_array.name()))
880 << "Non-ASCII-printable character found in --input_arrays: "
881 << input_array.name()
882 << ". Pass --allow_nonascii_arrays to allow that. "
883 << "Here is a dump of the string:\n\n"
884 << DumpAscii(input_array.name());
885 }
886 for (const string& output_array : model_flags.output_arrays()) {
887 QCHECK(IsAsciiPrintable(output_array))
888 << "Non-ASCII-printable character found in --output_arrays: "
889 << output_array << ". Pass --allow_nonascii_arrays to allow that. "
890 << "Here is a dump of the string:\n\n"
891 << DumpAscii(output_array);
892 }
893 }
894
CheckNonExistentIOArrays(const Model & model)895 void CheckNonExistentIOArrays(const Model& model) {
896 // "non-existent" is interpreted in the stronger sense of
897 // "not actually produced/consumed by an op".
898 // Rationale: we have to artificially fix up TensorFlow graphs by creating
899 // any array that it refers to, so just checking that arrays exist isn't
900 // sufficient. The real invariant here is whether arrays are produced/consumed
901 // by something.
902 if (model.flags.allow_nonexistent_arrays()) {
903 return;
904 }
905 static constexpr char general_comment[] =
906 "Is it a typo? To silence this message, pass this flag: "
907 "allow_nonexistent_arrays";
908 for (const string& output_array : model.flags.output_arrays()) {
909 if (IsConstantParameterArray(model, output_array)) {
910 continue; // It is OK to request that a constant be an output.
911 }
912 QCHECK(GetOpWithOutput(model, output_array))
913 << "Specified output array \"" << output_array
914 << "\" is not produced by any op in this graph. " << general_comment;
915 }
916 for (const auto& rnn_state : model.flags.rnn_states()) {
917 if (!rnn_state.discardable()) {
918 // Check that all RNN states are consumed
919 QCHECK(GetOpWithInput(model, rnn_state.state_array()))
920 << "Specified RNN state \"" << rnn_state.state_array()
921 << "\" is not consumed by any op in this graph. " << general_comment;
922 // Check that all RNN back-edge source arrays are produced
923 QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
924 << "Specified RNN back-edge source array \""
925 << rnn_state.back_edge_source_array()
926 << "\" is not produced by any op in this graph. " << general_comment;
927 }
928 }
929 }
930
931 } // namespace
932
CheckNoMissingArray(const Model & model)933 void CheckNoMissingArray(const Model& model) {
934 for (const auto& op : model.operators) {
935 for (const auto& input : op->inputs) {
936 CHECK(model.HasArray(input) || model.optional_arrays.count(input))
937 << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
938 }
939 for (const auto& output : op->outputs) {
940 CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
941 }
942 }
943 CheckNonExistentIOArrays(model);
944 }
945
FixNoMissingArray(Model * model)946 void FixNoMissingArray(Model* model) {
947 for (const auto& op : model->operators) {
948 for (const auto& input : op->inputs) {
949 if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
950 model->GetOrCreateArray(input);
951 }
952 }
953 for (const auto& output : op->outputs) {
954 if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
955 model->GetOrCreateArray(output);
956 }
957 }
958 }
959 if (model->flags.allow_nonexistent_arrays()) {
960 for (const string& output_array : model->flags.output_arrays()) {
961 model->GetOrCreateArray(output_array);
962 }
963 for (const auto& rnn_state : model->flags.rnn_states()) {
964 model->GetOrCreateArray(rnn_state.state_array());
965 model->GetOrCreateArray(rnn_state.back_edge_source_array());
966 }
967 }
968 }
969
CheckNoOrphanedArray(const Model & model)970 void CheckNoOrphanedArray(const Model& model) {
971 std::unordered_set<string> arrays_without_known_use;
972 for (const auto& array : model.GetArrayMap()) {
973 if (IsDiscardableArray(model, array.first)) {
974 arrays_without_known_use.insert(array.first);
975 }
976 }
977 for (const auto& op : model.operators) {
978 for (const auto& input : op->inputs) {
979 arrays_without_known_use.erase(input);
980 }
981 for (const auto& output : op->outputs) {
982 arrays_without_known_use.erase(output);
983 }
984 }
985 for (const auto& rnn_state : model.flags.rnn_states()) {
986 arrays_without_known_use.erase(rnn_state.state_array());
987 arrays_without_known_use.erase(rnn_state.back_edge_source_array());
988 }
989 if (!arrays_without_known_use.empty()) {
990 for (const auto& array : arrays_without_known_use) {
991 LOG(INFO) << "Error: Orphaned array: " << array;
992 }
993 }
994 CHECK(arrays_without_known_use.empty());
995 }
996
FixNoOrphanedArray(Model * model)997 void FixNoOrphanedArray(Model* model) {
998 std::unordered_set<string> arrays_without_known_use;
999 for (const auto& array : model->GetArrayMap()) {
1000 arrays_without_known_use.insert(array.first);
1001 }
1002 for (const auto& op : model->operators) {
1003 for (const auto& input : op->inputs) {
1004 arrays_without_known_use.erase(input);
1005 }
1006 for (const auto& output : op->outputs) {
1007 arrays_without_known_use.erase(output);
1008 }
1009 }
1010 for (const auto& rnn_state : model->flags.rnn_states()) {
1011 arrays_without_known_use.erase(rnn_state.state_array());
1012 arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1013 }
1014 for (const auto& array : arrays_without_known_use) {
1015 if (IsDiscardableArray(*model, array)) {
1016 model->EraseArray(array);
1017 }
1018 }
1019 }
1020
1021 // Apply checks to arrays individually (for-each fashion).
1022 //
1023 // Check consistency of array fields, check name.
CheckEachArray(const Model & model)1024 void CheckEachArray(const Model& model) {
1025 for (const auto& array_entry : model.GetArrayMap()) {
1026 const auto& array = array_entry.second;
1027 // It's OK to have a buffer or an alloc, but not both.
1028 // (Since allocs are for transient arrays without a buffer).
1029 CHECK(!array->buffer || !array->alloc);
1030 if (array->buffer) {
1031 // If there is a buffer, its type should be consistent with data_type.
1032 CHECK(array->buffer->type == array->data_type);
1033 // The presence of a fixed buffer should imply the presence of a fixed
1034 // shape.
1035 CHECK(array->has_shape());
1036 // Constant buffer should has a valid shape.
1037 CheckValidShape(array->shape());
1038 // The shape flat-size should agree with the buffer length.
1039 CHECK_EQ(array->buffer->Length(),
1040 RequiredBufferSizeForShape(array->shape()));
1041 }
1042
1043 // Check name. Either "name_with_suffix_8", "name_with_port:3", but not
1044 // "name_with_both:3_8".
1045 const string& name = array_entry.first;
1046 auto colon_pos = name.find_first_of(":");
1047 if (colon_pos != string::npos) {
1048 CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
1049 string::npos)
1050 << "Array '" << name << "' has non-digit characters after colon.";
1051 }
1052 CHECK_GT(colon_pos, 0) << "Array '" << name
1053 << "' must not start with a colon.";
1054 }
1055 }
1056
CheckOperatorOrdering(const Model & model)1057 void CheckOperatorOrdering(const Model& model) {
1058 std::unordered_set<string> arrays_behind_us;
1059 for (const auto& array_entry : model.GetArrayMap()) {
1060 if (!GetOpWithOutput(model, array_entry.first)) {
1061 arrays_behind_us.insert(array_entry.first);
1062 }
1063 }
1064 arrays_behind_us.insert(model.optional_arrays.begin(),
1065 model.optional_arrays.end());
1066 for (const auto& op : model.operators) {
1067 for (const auto& input : op->inputs) {
1068 if (!IsConstantParameterArray(model, input)) {
1069 CHECK(arrays_behind_us.count(input));
1070 }
1071 }
1072 for (const auto& output : op->outputs) {
1073 CHECK(!arrays_behind_us.count(output));
1074 arrays_behind_us.insert(output);
1075 }
1076 }
1077 for (const string& output_array : model.flags.output_arrays()) {
1078 CHECK(arrays_behind_us.count(output_array));
1079 }
1080 }
1081
// Reorders model->operators into a valid execution order: one in which every
// non-constant input of an operator is available before the operator. An
// array is available if no operator produces it, if it is an optional array,
// or once an already-placed operator has produced it.
//
// The order is built greedily. Each pass appends the lowest-indexed remaining
// operator (in the original order) whose inputs are all available, then
// restarts the scan. An already valid ordering is therefore left unchanged.
// The worst-case cost is roughly quadratic in the number of operators.
//
// If some operators can never be placed, a trace is logged. It starts at the
// lowest-indexed stuck operator and walks back through the producers of
// missing inputs. The process then terminates via LOG(FATAL), either on a
// cycle or on an input that no remaining operator produces.
void FixOperatorOrdering(Model* model) {
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  // Take ownership of the current operators; placed ones are moved back into
  // model->operators one by one.
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  // Indices (into old_operators) of operators not placed yet. A std::set
  // iterates in ascending order, which makes the greedy pick the
  // lowest-indexed ready operator.
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  // For each output of an operator that could not be placed, the first input
  // that was found unavailable. Only used to build the failure trace below.
  std::unordered_map<string, string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        // Move the operator into model->operators, leaving a null slot behind
        // in old_operators.
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        // Erasing invalidates this range-for's iterator, so the scan must
        // stop right here; the outer while loop restarts it.
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has as "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      // Find an output of bad_op for which a blocking input was recorded.
      bool found_bad_output = false;
      string bad_output;
      for (const string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      // Walk back: look for a still-unplaced operator producing bad_input.
      bad_op = nullptr;
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}
1196
// Runs the structural checks on `model` in a fixed order; the first failing
// check aborts. In order: flags-level checks (input/output name overlap,
// non-ASCII names), then missing arrays, orphaned arrays, per-array
// consistency, and finally operator ordering.
void CheckInvariants(const Model& model) {
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}
1205
CheckCountInRange(const::toco::ModelFlags::ModelCheck & model_check,const int count,const string & count_description)1206 void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
1207 const int count, const string& count_description) {
1208 if (model_check.count_min() >= 0) {
1209 CHECK_GE(count, model_check.count_min())
1210 << "Mismatch in " << count_description << ": count was " << count
1211 << ", but the specified "
1212 << (model_check.count_max() > model_check.count_min() ? "minimum"
1213 : "value")
1214 << " was " << model_check.count_min() << ".";
1215 }
1216 if (model_check.count_max() > model_check.count_min()) {
1217 CHECK_LE(count, model_check.count_max())
1218 << "Mismatch in " << count_description << ": count was " << count
1219 << ", but the specified maximum was " << model_check.count_max() << ".";
1220 }
1221 }
1222
CheckModelCounts(const Model & model)1223 void CheckModelCounts(const Model& model) {
1224 std::unordered_multiset<OperatorType> ops_by_type;
1225 std::unordered_map<string, OperatorType> op_type_by_name;
1226 if (model.flags.model_checks_size() == 0) {
1227 return;
1228 }
1229
1230 for (const auto& op : model.operators) {
1231 ops_by_type.insert(op->type);
1232 op_type_by_name[OperatorTypeName(op->type)] = op->type;
1233 }
1234 for (const auto& model_check : model.flags.model_checks()) {
1235 string count_type = model_check.count_type();
1236 if (count_type == "None") {
1237 continue;
1238 } else if (count_type == "Arrays") {
1239 CheckCountInRange(model_check, model.GetArrayMap().size(),
1240 "count of arrays");
1241 } else if (count_type == "Total") {
1242 CheckCountInRange(model_check, model.operators.size(),
1243 "count of all operator instances");
1244 } else {
1245 // The check type is not itself checked against the set of valid
1246 // operators, mainly because the enum set cannot be iterated in C++.
1247 const int found_count =
1248 op_type_by_name.count(count_type) > 0
1249 ? ops_by_type.count(op_type_by_name[count_type])
1250 : 0;
1251 CheckCountInRange(model_check, found_count,
1252 "count of instances of " + count_type + " operator");
1253 }
1254 }
1255 }
1256
FixEdgeArrays(Model * model)1257 void FixEdgeArrays(Model* model) {
1258 for (const string& output_array_name : model->flags.output_arrays()) {
1259 if (!GetOpWithOutput(*model, output_array_name)) {
1260 // Output has no operator producing it. Change that by inserting a copy.
1261 LOG(WARNING) << "Fixing constant output array " << output_array_name
1262 << " by inserting a copy. This is not optimal.";
1263 string intermediate_array_name =
1264 AvailableArrayName(*model, output_array_name + "_copy");
1265 CloneArray(model, output_array_name, intermediate_array_name);
1266 InsertCopyOperator(model, intermediate_array_name, output_array_name);
1267 }
1268 }
1269 }
1270
DedupeConstantArrays(Model * model,size_t min_size)1271 void DedupeConstantArrays(Model* model, size_t min_size) {
1272 // Walk all 0..N and compare with the remaining n+1..N.
1273 // This lets us avoid N^2 comparisons and erase duplicate arrays while
1274 // iterating.
1275 const auto& array_map = model->GetArrayMap();
1276 for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
1277 ++lhs_array_it) {
1278 const auto& lhs_array_name = lhs_array_it->first;
1279 const auto& lhs_array = *lhs_array_it->second;
1280 if (!IsConstantParameterArray(*model, lhs_array_name)) {
1281 // Not a constant array; skip.
1282 continue;
1283 }
1284 ArrayDataType final_data_type =
1285 lhs_array.final_data_type != ArrayDataType::kNone
1286 ? lhs_array.final_data_type
1287 : lhs_array.data_type;
1288 // Ignore small arrays, don't check string arrays because it is not possible
1289 // to estimate its size.
1290 if (final_data_type != ArrayDataType::kString) {
1291 size_t array_byte_size =
1292 lhs_array.buffer->Length() * ElementSize(final_data_type);
1293 if (array_byte_size < min_size) {
1294 // Too small; skip.
1295 continue;
1296 }
1297 }
1298
1299 auto next_lhs_array_it = lhs_array_it;
1300 ++next_lhs_array_it;
1301 for (auto rhs_array_it = next_lhs_array_it;
1302 rhs_array_it != array_map.end();) {
1303 const auto& rhs_array_name = rhs_array_it->first;
1304 const auto& rhs_array = *rhs_array_it->second;
1305 ++rhs_array_it;
1306 if (!IsConstantParameterArray(*model, rhs_array_name)) {
1307 // Not a constant array; skip.
1308 continue;
1309 }
1310 if (!IsDiscardableArray(*model, rhs_array_name)) {
1311 // Can't remove the array as it's not discardable (such as an IO edge).
1312 continue;
1313 }
1314 if (!CompareConstantArrays(lhs_array, rhs_array)) {
1315 // Arrays aren't equal; skip.
1316 continue;
1317 }
1318
1319 // Arrays can be deduped!
1320 VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
1321 << " in place of " << rhs_array_name;
1322 ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
1323 // Note: rhs_array_it above is already incremented so this is safe.
1324 model->EraseArray(rhs_array_name);
1325 }
1326 }
1327 }
1328
1329 namespace {
CopyArrayAttribs(const Array & source_array,Array * target_array)1330 void CopyArrayAttribs(const Array& source_array, Array* target_array) {
1331 target_array->data_type = source_array.data_type;
1332 target_array->final_data_type = source_array.final_data_type;
1333 target_array->copy_shape(source_array.shape());
1334
1335 if (source_array.minmax) {
1336 target_array->GetOrCreateMinMax() = source_array.GetMinMax();
1337 } else {
1338 target_array->minmax.reset();
1339 }
1340
1341 if (source_array.quantization_params) {
1342 target_array->GetOrCreateQuantizationParams() =
1343 source_array.GetQuantizationParams();
1344 } else {
1345 target_array->quantization_params.reset();
1346 }
1347 }
1348 } // namespace
1349
InsertCopyOperator(Model * model,const string & source_array_name,const string & target_array_name)1350 void InsertCopyOperator(Model* model, const string& source_array_name,
1351 const string& target_array_name) {
1352 // Reshape to the same size. This should be a no-op.
1353 const Array& source_array = model->GetArray(source_array_name);
1354 std::vector<int> shape = source_array.shape().dims();
1355
1356 // Drop constant data from the target array as the copy will be done at
1357 // runtime.
1358 Array& target_array = model->GetOrCreateArray(target_array_name);
1359 target_array.buffer.reset();
1360 CopyArrayAttribs(source_array, &target_array);
1361
1362 // Insert copy operator.
1363 auto* copy_op = new TensorFlowReshapeOperator;
1364 copy_op->inputs = {
1365 source_array_name,
1366 CreateInt32Array(
1367 model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
1368 shape)};
1369 copy_op->outputs = {target_array_name};
1370 if (target_array.has_shape()) {
1371 copy_op->shape = target_array.shape().dims();
1372 }
1373 model->operators.emplace_back(copy_op);
1374 }
1375
CloneArray(Model * model,const string & source_array_name,const string & target_array_name)1376 void CloneArray(Model* model, const string& source_array_name,
1377 const string& target_array_name) {
1378 CHECK(!model->HasArray(target_array_name));
1379 const Array& source_array = model->GetArray(source_array_name);
1380 Array& target_array = model->GetOrCreateArray(target_array_name);
1381 CopyArrayAttribs(source_array, &target_array);
1382
1383 if (source_array.minmax) {
1384 const auto& smm = source_array.GetMinMax();
1385 auto& tmm = target_array.GetOrCreateMinMax();
1386 tmm.min = smm.min;
1387 tmm.max = smm.max;
1388 }
1389
1390 if (source_array.quantization_params) {
1391 const auto& sqp = source_array.GetQuantizationParams();
1392 auto& tqp = target_array.GetOrCreateQuantizationParams();
1393 tqp.zero_point = sqp.zero_point;
1394 tqp.scale = sqp.scale;
1395 }
1396
1397 target_array.data_type = source_array.data_type;
1398 target_array.final_data_type = source_array.final_data_type;
1399 target_array.copy_shape(source_array.shape());
1400
1401 switch (source_array.data_type) {
1402 case ArrayDataType::kBool:
1403 CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
1404 break;
1405 case ArrayDataType::kFloat:
1406 CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
1407 break;
1408 case ArrayDataType::kInt8:
1409 CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
1410 break;
1411 case ArrayDataType::kUint8:
1412 CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
1413 break;
1414 case ArrayDataType::kInt16:
1415 CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
1416 break;
1417 case ArrayDataType::kUint16:
1418 CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
1419 break;
1420 case ArrayDataType::kInt32:
1421 CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
1422 break;
1423 case ArrayDataType::kUint32:
1424 CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
1425 break;
1426 case ArrayDataType::kInt64:
1427 CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
1428 break;
1429 case ArrayDataType::kUint64:
1430 CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
1431 break;
1432 case ArrayDataType::kString:
1433 CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
1434 break;
1435 case ArrayDataType::kComplex64:
1436 CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
1437 break;
1438 default:
1439 LOG(FATAL) << "Unsupported data type: "
1440 << ArrayDataTypeName(source_array.data_type);
1441 return;
1442 }
1443 }
1444
MakeArrayDims(int num_dims,int batch,int height,int width,int depth,std::vector<int> * out_dims)1445 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
1446 std::vector<int>* out_dims) {
1447 CHECK(out_dims->empty());
1448 if (num_dims == 0) {
1449 return;
1450 } else if (num_dims == 1) {
1451 CHECK_EQ(batch, 1);
1452 *out_dims = {depth};
1453 } else if (num_dims == 2) {
1454 *out_dims = {batch, depth};
1455 } else if (num_dims == 3) {
1456 CHECK_EQ(batch, 1);
1457 *out_dims = {height, width, depth};
1458 } else if (num_dims == 4) {
1459 *out_dims = {batch, height, width, depth};
1460 } else {
1461 LOG(FATAL) << "Should not get here: " << num_dims;
1462 }
1463 }
1464
CreateOrCheckRnnStateArray(const string & name,int size,int state_num_dims,Model * model)1465 void CreateOrCheckRnnStateArray(const string& name, int size,
1466 int state_num_dims, Model* model) {
1467 int batch = 1;
1468 int num_dims = -1;
1469 if (state_num_dims > 0) {
1470 num_dims = state_num_dims;
1471 } else {
1472 // state_num_dims is not given. We will infer it from an input tensor.
1473 for (const auto& input_array : model->flags.input_arrays()) {
1474 // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
1475 // a better match by name.
1476 if (input_array.name() == name || num_dims == -1) {
1477 num_dims = input_array.shape().dims_size();
1478 if (num_dims > 0) {
1479 batch = input_array.shape().dims(0);
1480 }
1481 }
1482 }
1483 }
1484 Array& array = model->GetOrCreateArray(name);
1485 if (array.has_shape()) {
1486 num_dims = array.shape().dimensions_count();
1487 }
1488 if (!array.has_shape() && num_dims >= 0) {
1489 Shape* shape = array.mutable_shape();
1490 std::vector<int> dims;
1491 MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
1492 *shape->mutable_dims() = dims;
1493 }
1494 }
1495
ResolveModelFlags(const ModelFlags & model_flags,Model * model)1496 void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
1497 // Merge info about input_arrays from model_flags into model->flags
1498 for (const auto& specified_input_array : model_flags.input_arrays()) {
1499 toco::InputArray* dst_input_array = nullptr;
1500 for (int i = 0; i < model->flags.input_arrays_size(); i++) {
1501 toco::InputArray* candidate_dst_input_array =
1502 model->flags.mutable_input_arrays(i);
1503 if (candidate_dst_input_array->name() == specified_input_array.name()) {
1504 // specified_input_array from model_flags maps to dst_input_array
1505 // in model->flags
1506 dst_input_array = candidate_dst_input_array;
1507 break;
1508 }
1509 }
1510 if (!dst_input_array) {
1511 // Specified_input_array from model_flags is not found in model->flags.
1512 // Match a name-less specified input array when there can be no ambiguity
1513 // as there is only 1 input array.
1514 if (model->flags.input_arrays_size() == 1 &&
1515 model_flags.input_arrays_size() == 1 &&
1516 !specified_input_array.has_name()) {
1517 dst_input_array = model->flags.mutable_input_arrays(0);
1518 }
1519 }
1520 if (!dst_input_array) {
1521 // Still no match, so create a new input array to copy
1522 // specified_input_array into.
1523 dst_input_array = model->flags.add_input_arrays();
1524 dst_input_array->set_name(specified_input_array.name());
1525 }
1526
1527 #define RESOLVE_MODEL_FLAG(field_name) \
1528 if (specified_input_array.has_##field_name()) { \
1529 if (dst_input_array->has_##field_name()) { \
1530 QCHECK_EQ(dst_input_array->field_name(), \
1531 specified_input_array.field_name()) \
1532 << "For input array '" << dst_input_array->name() << "', " \
1533 << "specified " #field_name " flag with value: " \
1534 << specified_input_array.field_name() \
1535 << " does not agree with already defined " #field_name \
1536 " of this model, with value: " \
1537 << specified_input_array.field_name(); \
1538 } else { \
1539 dst_input_array->set_##field_name(specified_input_array.field_name()); \
1540 } \
1541 }
1542 RESOLVE_MODEL_FLAG(std_value);
1543 RESOLVE_MODEL_FLAG(mean_value);
1544 #undef RESOLVE_MODEL_FLAG
1545
1546 if (specified_input_array.has_shape()) {
1547 if (dst_input_array->has_shape()) {
1548 QCHECK_EQ(specified_input_array.shape().dims_size(),
1549 dst_input_array->shape().dims_size())
1550 << "For input array '" << specified_input_array.name() << "', "
1551 << "size of specified input shape flag with size: "
1552 << specified_input_array.shape().dims_size()
1553 << " does not agree with already defined input shape"
1554 " of this model, with size: "
1555 << dst_input_array->shape().dims_size();
1556 // We treat the first dimension as a special case, since it is often
1557 // a batch size and the input_shape flag is effectively overriding
1558 // the model.
1559 for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
1560 QCHECK_EQ(specified_input_array.shape().dims(i),
1561 dst_input_array->shape().dims(i))
1562 << "At dimension number " << i << " of input array "
1563 << specified_input_array.name() << ", the specified shape's "
1564 << "dimension flag with dimension: "
1565 << specified_input_array.shape().dims(i)
1566 << " does not agree with already defined shape"
1567 << " of this model, with dimension: "
1568 << dst_input_array->shape().dims(i);
1569 }
1570 } else {
1571 *dst_input_array->mutable_shape() = specified_input_array.shape();
1572 }
1573 }
1574
1575 if (specified_input_array.has_data_type()) {
1576 QCHECK(!dst_input_array->has_data_type());
1577 dst_input_array->set_data_type(specified_input_array.data_type());
1578 }
1579 }
1580
1581 if (model_flags.output_arrays_size() > 0) {
1582 model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
1583 }
1584
1585 #define RESOLVE_MODEL_FLAG(name) \
1586 if (model_flags.has_##name()) { \
1587 if (model->flags.has_##name()) { \
1588 QCHECK_EQ(model_flags.name(), model->flags.name()) \
1589 << "Specified " #name " flag with value: " << model_flags.name() \
1590 << " does not agree with already defined " #name \
1591 " of this model, with value: " \
1592 << model->flags.name(); \
1593 } else { \
1594 model->flags.set_##name(model_flags.name()); \
1595 } \
1596 }
1597
1598 RESOLVE_MODEL_FLAG(variable_batch)
1599
1600 #undef RESOLVE_MODEL_FLAG
1601
1602 if (!model_flags.rnn_states().empty()) {
1603 model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
1604 }
1605
1606 if (model->flags.model_checks_size() == 0) {
1607 model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
1608 }
1609
1610 QCHECK_GT(model->flags.output_arrays_size(), 0)
1611 << "This model does not define output arrays, so a "
1612 "--output_arrays flag must be given on the command-line.";
1613
1614 for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
1615 auto& input_array = model->GetOrCreateArray(input_array_proto.name());
1616 if (input_array_proto.has_data_type()) {
1617 const ArrayDataType specified_type =
1618 ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
1619 QCHECK(specified_type != ArrayDataType::kNone);
1620 if (input_array.data_type != ArrayDataType::kNone) {
1621 QCHECK(specified_type == input_array.data_type)
1622 << "For input array " << input_array_proto.name()
1623 << " the specified input data type "
1624 << IODataType_Name(input_array_proto.data_type())
1625 << " conflicts with the existing type.";
1626 }
1627 input_array.data_type = specified_type;
1628 }
1629
1630 if (input_array.data_type == ArrayDataType::kNone) {
1631 // We start out with a float input array;
1632 // that may get replaced by a uint8 array later, by
1633 // MakeInitialDequantizeOp.
1634 input_array.data_type = ArrayDataType::kFloat;
1635 }
1636
1637 // Compare/merge the model->flags describing the input_shape with
1638 // the actual input array's shape.
1639 if (!input_array.has_shape()) {
1640 if (input_array_proto.has_shape()) {
1641 auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
1642 CheckValidShapeDimensions(input_array_proto.shape().dims());
1643 for (const auto& dim : input_array_proto.shape().dims()) {
1644 input_array_dims.push_back(dim);
1645 }
1646 }
1647 } else {
1648 if (input_array_proto.has_shape()) {
1649 // If an input shape was specified on the flags ensure that it matches
1650 // the actual shape in the model.
1651 const auto& input_array_dims =
1652 *input_array.mutable_shape()->mutable_dims();
1653 CHECK_EQ(input_array_dims.size(),
1654 input_array_proto.shape().dims_size());
1655 for (int i = 0; i < input_array_dims.size(); i++) {
1656 CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
1657 }
1658 } else {
1659 for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
1660 input_array_proto.mutable_shape()->add_dims(
1661 input_array.shape().dims(i));
1662 }
1663 }
1664 }
1665
1666 const float mean_value = input_array_proto.mean_value();
1667 const float std_value = input_array_proto.std_value();
1668 MinMax input_minmax;
1669 float qmin = 0, qmax = 255;
1670 if (input_array.data_type == ArrayDataType::kInt16) {
1671 qmin = -32768;
1672 qmax = 32767;
1673 }
1674 input_minmax.min = (qmin - mean_value) / std_value;
1675 input_minmax.max = (qmax - mean_value) / std_value;
1676 if (!input_array.minmax) {
1677 input_array.GetOrCreateMinMax() = input_minmax;
1678 }
1679 }
1680
1681 // Creation of the RNN state arrays
1682 for (const auto& rnn_state : model->flags.rnn_states()) {
1683 CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
1684 rnn_state.num_dims(), model);
1685 }
1686
1687 model->flags.set_change_concat_input_ranges(
1688 model_flags.change_concat_input_ranges());
1689 model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
1690 model->flags.set_allow_nonexistent_arrays(
1691 model_flags.allow_nonexistent_arrays());
1692
1693 CHECK(!model->flags.has_arrays_extra_info());
1694 *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
1695 }
1696
CheckIsReadyForQuantization(const Model & model)1697 void CheckIsReadyForQuantization(const Model& model) {
1698 for (const auto& op : model.operators) {
1699 for (const auto& input : op->inputs) {
1700 const auto& input_array = model.GetArray(input);
1701 if (input_array.data_type != ArrayDataType::kFloat) {
1702 // The array is not floats, no quantization needed.
1703 continue;
1704 }
1705 if (input_array.minmax) {
1706 // The array has minmax, we're good.
1707 continue;
1708 }
1709 if (input_array.buffer) {
1710 // The array has a constant buffer, so we can
1711 // fall back to computing the minmax from actual array entries
1712 // (with a WARNING about possible accuracy implications).
1713 continue;
1714 }
1715 LOG(FATAL)
1716 << "Array " << input << ", which is an input to the "
1717 << HelpfulOperatorTypeName(*op) << " operator producing the output "
1718 << "array " << op->outputs[0] << ", is lacking min/max data, "
1719 << "which is necessary for quantization. If accuracy matters, either "
1720 << "target a non-quantized output format, or run quantized training "
1721 << "with your model from a floating point checkpoint to change the "
1722 << "input graph to contain min/max information. If you don't care "
1723 << "about accuracy, you can pass --default_ranges_min= and "
1724 << "--default_ranges_max= for easy experimentation.";
1725 }
1726 }
1727 }
1728
ElementSize(ArrayDataType data_type)1729 int ElementSize(ArrayDataType data_type) {
1730 switch (data_type) {
1731 case ArrayDataType::kBool:
1732 return sizeof(bool);
1733 case ArrayDataType::kFloat:
1734 return 4;
1735 case ArrayDataType::kInt8:
1736 return 1;
1737 case ArrayDataType::kUint8:
1738 return 1;
1739 case ArrayDataType::kInt16:
1740 return 2;
1741 case ArrayDataType::kUint16:
1742 return 2;
1743 case ArrayDataType::kInt32:
1744 return 4;
1745 case ArrayDataType::kUint32:
1746 return 4;
1747 case ArrayDataType::kInt64:
1748 return 8;
1749 case ArrayDataType::kUint64:
1750 return 8;
1751 case ArrayDataType::kComplex64:
1752 return 8;
1753
1754 // Usually not critical limitation because strings are only input and/or
1755 // output.
1756 case ArrayDataType::kString:
1757 LOG(FATAL) << "Transient arrays with strings are not supported yet";
1758 return 0;
1759 default:
1760 LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
1761 return 0;
1762 }
1763 }
1764
DropMinMax(Model * model,const string & array_name)1765 void DropMinMax(Model* model, const string& array_name) {
1766 auto& array = model->GetArray(array_name);
1767 if (!!array.minmax) {
1768 LOG(WARNING) << "Dropping MinMax information in array " << array_name
1769 << ". Expect inaccuracy in quantized inference.";
1770 array.minmax = nullptr;
1771 }
1772 }
1773
IsAllocatableTransientArray(const Model & model,const string & array_name)1774 bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
1775 // Optional array is not transient
1776 if (model.IsOptionalArray(array_name)) return false;
1777 // The model's input and output arrays are externally allocated.
1778 // They are not transient arrays.
1779 if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
1780 return false;
1781 }
1782 const auto& array = &model.GetArray(array_name);
1783 // An array with a constant buffer isn't a transient array.
1784 if (!!array->buffer) {
1785 return false;
1786 }
1787 // An array without shape isn't allocatable.
1788 if (!array->has_shape()) {
1789 return false;
1790 }
1791
1792 // The size of string tensors is rarely known ahead of time, so all transient
1793 // tensors of this type will need to be dynamically allocated.
1794 if (array->final_data_type == ArrayDataType::kString ||
1795 array->data_type == ArrayDataType::kString) {
1796 return false;
1797 }
1798
1799 return true;
1800 }
1801
// Returns a name derived from `name` that no regular or optional array of the
// model uses yet. The sanitized `name` itself is preferred; otherwise the
// suffixes "_0", "_1", ... are tried in order. Aborts if the first 1000
// suffixed candidates are all taken.
string AvailableArrayName(const Model& model, const string& name) {
  const auto is_taken = [&model](const string& candidate) {
    return model.HasArray(candidate) || model.IsOptionalArray(candidate);
  };
  string sanitized_name = SanitizeNameForTFNode(name);
  if (!is_taken(sanitized_name)) {
    return sanitized_name;
  }
  const int kNumSuffixesToTry = 1000;
  for (int suffix = 0; suffix < kNumSuffixesToTry; ++suffix) {
    string candidate = toco::port::StringF("%s_%d", sanitized_name, suffix);
    if (!is_taken(candidate)) {
      return candidate;
    }
  }
  LOG(FATAL) << "Could not find an available array name starting with "
             << sanitized_name << ". Tried " << kNumSuffixesToTry
             << " suffixes, all were taken!";
  return "";
}
1822
ShapeToString(const Shape & shape)1823 string ShapeToString(const Shape& shape) {
1824 if (shape.dimensions_count() == 0) {
1825 return "[]";
1826 }
1827
1828 return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
1829 }
1830
PrintArrayShape(Model * model,const string & name)1831 void PrintArrayShape(Model* model, const string& name) {
1832 if (!model->GetArray(name).has_shape()) {
1833 LOG(INFO) << name << " has no shape";
1834 return;
1835 }
1836 LOG(INFO) << name
1837 << " has shape: " << ShapeToString(model->GetArray(name).shape());
1838 }
1839
IsArrayFullyConnectedWeights(const Model & model,const string & name)1840 bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
1841 bool is_fc_weights = false;
1842 bool is_something_else = false;
1843 for (const auto& op : model.operators) {
1844 for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
1845 if (op->inputs[input_index] == name) {
1846 if (op->type == OperatorType::kFullyConnected && input_index == 1) {
1847 is_fc_weights = true;
1848 } else {
1849 is_something_else = true;
1850 }
1851 }
1852 }
1853 }
1854 CHECK(!(is_fc_weights && is_something_else));
1855 return is_fc_weights;
1856 }
1857
CreateInt32Array(Model * model,const string & param_name,const std::vector<int> & value)1858 string CreateInt32Array(Model* model, const string& param_name,
1859 const std::vector<int>& value) {
1860 auto param_array_name = AvailableArrayName(*model, param_name);
1861 auto& param_array = model->GetOrCreateArray(param_array_name);
1862 param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
1863 param_array.data_type = ArrayDataType::kInt32;
1864 auto& param_array_data =
1865 param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
1866 param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
1867 for (int i = 0; i < value.size(); ++i) {
1868 param_array_data[i] = value[i];
1869 }
1870 return param_array_name;
1871 }
1872
EstimateArithmeticOpsCount(const Model & model,const Operator & op,int64 * result)1873 bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
1874 int64* result) {
1875 switch (op.type) {
1876 case OperatorType::kFullyConnected:
1877 case OperatorType::kConv:
1878 case OperatorType::kDepthwiseConv: {
1879 const auto& output_array = model.GetArray(op.outputs[0]);
1880 const auto& weights_array = model.GetArray(op.inputs[1]);
1881 if (!output_array.has_shape() || !weights_array.has_shape()) {
1882 return false;
1883 }
1884 int64 cols = 1;
1885 for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
1886 cols *= output_array.shape().dims(i);
1887 }
1888 const int64 cost_per_col =
1889 2 * RequiredBufferSizeForShape(weights_array.shape());
1890 *result = cost_per_col * cols;
1891 if (op.inputs.size() > 2) {
1892 // There is a bias vector. One more op per output value.
1893 *result += RequiredBufferSizeForShape(output_array.shape());
1894 }
1895 break;
1896 }
1897 case OperatorType::kAdd:
1898 case OperatorType::kSub:
1899 case OperatorType::kMul: {
1900 const auto& output_array = model.GetArray(op.outputs[0]);
1901 if (!output_array.has_shape()) {
1902 return false;
1903 }
1904 *result = RequiredBufferSizeForShape(output_array.shape());
1905 break;
1906 }
1907 case OperatorType::kAddN: {
1908 const auto& output_array = model.GetArray(op.outputs[0]);
1909 if (!output_array.has_shape()) {
1910 return false;
1911 }
1912 // AddN cost is roughly the same cost as N-1 Adds.
1913 const int64 num_adds = op.inputs.size() - 1;
1914 *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
1915 break;
1916 }
1917 case OperatorType::kLogistic:
1918 case OperatorType::kSoftmax:
1919 case OperatorType::kLogSoftmax:
1920 case OperatorType::kTanh: {
1921 const auto& output_array = model.GetArray(op.outputs[0]);
1922 if (!output_array.has_shape()) {
1923 return false;
1924 }
1925 // As a very rough ballpark, the cost of evaluating a math function
1926 // such as tanh or logistic is about 32 multiplications, and about as
1927 // many additions/subtractions. (Just a power-of-two order-of-magnitude
1928 // from looking at actual implementations that we use in runtime/ code).
1929 *result = 64 * RequiredBufferSizeForShape(output_array.shape());
1930 break;
1931 }
1932 case OperatorType::kMaxPool: {
1933 const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
1934 const auto& output_array = model.GetArray(op.outputs[0]);
1935 if (!output_array.has_shape()) {
1936 return false;
1937 }
1938 *result = RequiredBufferSizeForShape(output_array.shape()) *
1939 maxpool.kheight * maxpool.kwidth;
1940 break;
1941 }
1942 case OperatorType::kAveragePool: {
1943 const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
1944 const auto& output_array = model.GetArray(op.outputs[0]);
1945 if (!output_array.has_shape()) {
1946 return false;
1947 }
1948 *result = RequiredBufferSizeForShape(output_array.shape()) *
1949 avgpool.kheight * avgpool.kwidth;
1950 break;
1951 }
1952 case OperatorType::kL2Pool: {
1953 const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
1954 const auto& output_array = model.GetArray(op.outputs[0]);
1955 if (!output_array.has_shape()) {
1956 return false;
1957 }
1958 // The sum of squares requires (kheight*kwidth) multiply-adds,
1959 // and then there is the sqrt which we ballpark at 32 ops.
1960 const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
1961 *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
1962 break;
1963 }
1964 case OperatorType::kL2Normalization: {
1965 const auto& output_array = model.GetArray(op.outputs[0]);
1966 if (!output_array.has_shape()) {
1967 return false;
1968 }
1969 // Computing the squared L2 norm is N multiply-adds so 2N ops,
1970 // then the single inverse-sqrt is negligible, then we multiply each
1971 // value by the resulting multiplier, so an extra N ops. count 3N ops.
1972 *result = 3 * RequiredBufferSizeForShape(output_array.shape());
1973 break;
1974 }
1975 default:
1976 *result = 0;
1977 break;
1978 }
1979 return true;
1980 }
1981
EstimateArithmeticOpsCount(const Model & model,int64 * result)1982 bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
1983 int64 total = 0;
1984 for (const auto& op : model.operators) {
1985 int64 num_ops;
1986 if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
1987 return false;
1988 }
1989 total += num_ops;
1990 }
1991 *result = total;
1992 return true;
1993 }
1994
FormattedNumber(int64 x)1995 string FormattedNumber(int64 x) {
1996 const int64 million = 1000000;
1997 const int64 billion = 1000000000;
1998 if (x < 10000) {
1999 return toco::port::StringF("%d ", x);
2000 } else if (x < billion) {
2001 return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
2002 } else {
2003 return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
2004 }
2005 }
2006
GetShuffleShape(AxesOrder input_axes_order,AxesOrder output_axes_order,std::vector<int> * shuffle)2007 void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
2008 std::vector<int>* shuffle) {
2009 CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
2010 shuffle->resize(4);
2011 for (int i = 0; i < 4; i++) {
2012 (*shuffle)[i] = i;
2013 }
2014 if (input_axes_order == output_axes_order) {
2015 // nothing to do
2016 } else if (AxesCount(input_axes_order) == 2) {
2017 shuffle->resize(2);
2018 (*shuffle)[0] = 1;
2019 (*shuffle)[1] = 0;
2020 } else if (input_axes_order == AxesOrder::kOHWI &&
2021 output_axes_order == AxesOrder::kHWIO) {
2022 // 3210 <- 3210
2023 // HWIO <- OHWI
2024 *shuffle = {1, 2, 3, 0};
2025 } else if (input_axes_order == AxesOrder::kHWIO &&
2026 output_axes_order == AxesOrder::kOHWI) {
2027 // 3210 <- 3210
2028 // OHWI <- HWIO
2029 *shuffle = {3, 0, 1, 2};
2030 } else if (input_axes_order == AxesOrder::kOHWI &&
2031 output_axes_order == AxesOrder::kHWOI) {
2032 *shuffle = {1, 2, 0, 3};
2033 } else {
2034 LOG(FATAL) << "Bad shuffle";
2035 }
2036 }
2037
ExtendShuffle(const std::vector<int> & input_shuffle,int newdim,std::vector<int> * extended_shuffle)2038 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
2039 std::vector<int>* extended_shuffle) {
2040 *extended_shuffle = input_shuffle;
2041 CHECK(newdim >= input_shuffle.size());
2042 const int pad_size = newdim - input_shuffle.size();
2043 extended_shuffle->resize(newdim);
2044 for (int i = 0; i < pad_size; i++) {
2045 (*extended_shuffle)[i] = i;
2046 }
2047 for (int i = pad_size; i < newdim; i++) {
2048 (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
2049 }
2050 }
2051
ShuffleDims(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,Shape * output_shape)2052 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
2053 AxesOrder output_axes_order, Shape* output_shape) {
2054 if (input_axes_order == AxesOrder::kHWIM &&
2055 output_axes_order == AxesOrder::k1HWO) {
2056 // This special case isn't just a permutation, the IM pair of dims get
2057 // merged into the 3 dim, so we have to special-case it.
2058 *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
2059 input_shape.dims(3) * input_shape.dims(2)});
2060 } else {
2061 std::vector<int> shuffle;
2062 GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2063 std::vector<int>* output_dims = output_shape->mutable_dims();
2064 output_dims->resize(input_shape.dimensions_count());
2065 for (int i = 0; i < input_shape.dimensions_count(); i++) {
2066 (*output_dims)[i] = input_shape.dims(shuffle[i]);
2067 }
2068 }
2069 }
2070
2071 template <typename T>
ShuffleArrayTemplate(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,const Shape & output_shape,const T * input_data,T * output_data)2072 void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
2073 AxesOrder output_axes_order,
2074 const Shape& output_shape, const T* input_data,
2075 T* output_data) {
2076 if (input_axes_order == AxesOrder::kHWIM &&
2077 output_axes_order == AxesOrder::k1HWO) {
2078 // This special case isn't just a permutation, the IM pair of dims get
2079 // merged into the O dim, so we have to special-case it. Fortunately,
2080 // as far as array shuffling is concerned, it's just the identity
2081 // transformation.
2082 memcpy(output_data, input_data,
2083 RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
2084 return;
2085 }
2086 CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
2087 const int dim = input_shape.dimensions_count();
2088 CHECK_LE(dim, 4);
2089 std::vector<int> shuffle;
2090 GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2091 CHECK(shuffle.size() >= dim);
2092 for (int i = 0; i < dim; i++) {
2093 CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
2094 CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
2095 }
2096 Shape extended_input_shape = input_shape;
2097 ExtendShape(&extended_input_shape, 4);
2098 Shape extended_output_shape = output_shape;
2099 ExtendShape(&extended_output_shape, 4);
2100 std::vector<int> extended_shuffle;
2101 ExtendShuffle(shuffle, 4, &extended_shuffle);
2102
2103 const std::vector<int>& extended_input_dims = extended_input_shape.dims();
2104 const std::vector<int>& extended_output_dims = extended_output_shape.dims();
2105
2106 // TODO(starka): Rework to handle different numbers of dimensions.
2107 int input_strides[4];
2108 input_strides[3] = 1;
2109 input_strides[2] = extended_input_dims[3];
2110 input_strides[1] = input_strides[2] * extended_input_dims[2];
2111 input_strides[0] = input_strides[1] * extended_input_dims[1];
2112 const int input_stride_0 = input_strides[extended_shuffle[3]];
2113 const int input_stride_1 = input_strides[extended_shuffle[2]];
2114 const int input_stride_2 = input_strides[extended_shuffle[1]];
2115 const int input_stride_3 = input_strides[extended_shuffle[0]];
2116
2117 const int output_size_0 = extended_output_dims[3];
2118 const int output_size_1 = extended_output_dims[2];
2119 const int output_size_2 = extended_output_dims[1];
2120 const int output_size_3 = extended_output_dims[0];
2121 const int output_stride_0 = 1;
2122 const int output_stride_1 = output_size_0;
2123 const int output_stride_2 = output_stride_1 * output_size_1;
2124 const int output_stride_3 = output_stride_2 * output_size_2;
2125
2126 for (int i3 = 0; i3 < output_size_3; i3++) {
2127 const T* const input_ptr_3 = input_data + i3 * input_stride_3;
2128 T* const output_ptr_3 = output_data + i3 * output_stride_3;
2129 for (int i2 = 0; i2 < output_size_2; i2++) {
2130 const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
2131 T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
2132 for (int i1 = 0; i1 < output_size_1; i1++) {
2133 const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
2134 T* output_ptr = output_ptr_2 + i1 * output_stride_1;
2135 T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
2136 while (output_ptr != output_ptr_end) {
2137 *output_ptr = *input_ptr;
2138 input_ptr += input_stride_0;
2139 output_ptr += output_stride_0;
2140 }
2141 }
2142 }
2143 }
2144 }
2145
ShuffleArray(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,const Shape & output_shape,const uint8 * input_data,uint8 * output_data)2146 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2147 AxesOrder output_axes_order, const Shape& output_shape,
2148 const uint8* input_data, uint8* output_data) {
2149 ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
2150 output_shape, input_data, output_data);
2151 }
2152
ShuffleArray(const Shape & input_shape,AxesOrder input_axes_order,AxesOrder output_axes_order,const Shape & output_shape,const float * input_data,float * output_data)2153 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2154 AxesOrder output_axes_order, const Shape& output_shape,
2155 const float* input_data, float* output_data) {
2156 ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
2157 output_shape, input_data, output_data);
2158 }
2159
AxesCount(AxesOrder axes_order)2160 int AxesCount(AxesOrder axes_order) {
2161 switch (axes_order) {
2162 case AxesOrder::kOneAxis:
2163 return 1;
2164 case AxesOrder::kRC:
2165 return 2;
2166 case AxesOrder::kCR:
2167 return 2;
2168 case AxesOrder::kHWIO:
2169 return 4;
2170 case AxesOrder::kOHWI:
2171 return 4;
2172 case AxesOrder::kHWIM:
2173 return 4;
2174 case AxesOrder::k1HWO:
2175 return 4;
2176 case AxesOrder::kNHWC:
2177 return 4;
2178 case AxesOrder::kHWOI:
2179 return 4;
2180 default:
2181 LOG(FATAL) << "Bad AxesOrder";
2182 return 0;
2183 }
2184 }
2185
IsDiscardableArray(const Model & model,const string & array_name)2186 bool IsDiscardableArray(const Model& model, const string& array_name) {
2187 if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
2188 return false;
2189 }
2190 for (const auto& rnn_state : model.flags.rnn_states()) {
2191 if (!rnn_state.discardable()) {
2192 if (array_name == rnn_state.state_array()) {
2193 return false;
2194 }
2195 if (array_name == rnn_state.back_edge_source_array()) {
2196 return false;
2197 }
2198 }
2199 }
2200 return true;
2201 }
2202
ReshapeIsEquivalentToTranspose(const Model & model,const TensorFlowReshapeOperator * op,bool allow_extra_unary_dims)2203 bool ReshapeIsEquivalentToTranspose(const Model& model,
2204 const TensorFlowReshapeOperator* op,
2205 bool allow_extra_unary_dims) {
2206 CHECK(!op->shape.empty());
2207 CHECK(model.HasArray(op->inputs[0]));
2208 CHECK(model.HasArray(op->outputs[0]));
2209
2210 const auto& input_array = model.GetArray(op->inputs[0]);
2211 const auto& output_array = model.GetArray(op->outputs[0]);
2212
2213 CHECK(input_array.has_shape());
2214 CHECK(output_array.has_shape());
2215
2216 std::vector<int> in_shape = input_array.shape().dims();
2217 std::vector<int> out_shape = output_array.shape().dims();
2218
2219 // If the reshape changes the number of dimensions so it cannot be interpreted
2220 // as a transpose.
2221 if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
2222 return false;
2223 }
2224
2225 in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
2226 in_shape.end());
2227 out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
2228 out_shape.end());
2229 return in_shape == out_shape;
2230 }
2231
CheckFinalDataTypesSatisfied(const Model & model)2232 void CheckFinalDataTypesSatisfied(const Model& model) {
2233 for (const auto& array_entry : model.GetArrayMap()) {
2234 const auto& array = *array_entry.second;
2235 if (array.data_type == ArrayDataType::kBool) {
2236 // Boolean values are never quantized.
2237 continue;
2238 }
2239
2240 // If the final data type is int16, the data type may be float, for example
2241 // after dequantization.
2242 if (array.final_data_type != ArrayDataType::kNone &&
2243 array.final_data_type != ArrayDataType::kInt16) {
2244 CHECK(array.data_type == array.final_data_type)
2245 << "Array \"" << array_entry.first
2246 << "\" has mis-matching actual and final data types (data_type="
2247 << ArrayDataTypeName(array.data_type)
2248 << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
2249 << ").";
2250 }
2251 }
2252 }
2253
ConvertIODataTypeToArrayDataType(IODataType type)2254 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
2255 switch (type) {
2256 case FLOAT:
2257 return ArrayDataType::kFloat;
2258 case QUANTIZED_UINT8:
2259 return ArrayDataType::kUint8;
2260 case INT8:
2261 return ArrayDataType::kInt8;
2262 case QUANTIZED_INT16:
2263 return ArrayDataType::kInt16;
2264 case INT32:
2265 return ArrayDataType::kInt32;
2266 case INT64:
2267 return ArrayDataType::kInt64;
2268 case BOOL:
2269 return ArrayDataType::kBool;
2270 case STRING:
2271 return ArrayDataType::kString;
2272 case COMPLEX64:
2273 return ArrayDataType::kComplex64;
2274 default:
2275 return ArrayDataType::kNone;
2276 }
2277 }
2278
FinishBuildingRNNStates(Model * model)2279 void FinishBuildingRNNStates(Model* model) {
2280 for (const auto& rnn_state : model->flags.rnn_states()) {
2281 if (!model->HasArray(rnn_state.back_edge_source_array()) ||
2282 !model->HasArray(rnn_state.state_array())) {
2283 CHECK(model->HasArray(rnn_state.back_edge_source_array()));
2284 CHECK(model->HasArray(rnn_state.state_array()));
2285 continue;
2286 }
2287 const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
2288 auto& dst_array = model->GetArray(rnn_state.state_array());
2289 if (src_array.data_type == ArrayDataType::kNone &&
2290 dst_array.data_type == ArrayDataType::kNone) {
2291 dst_array.data_type = ArrayDataType::kFloat;
2292 }
2293 }
2294 }
2295
2296 // Returns the array names that match the ArraysExtraInfo's name and
2297 // name_regexp. The regexp match is for a full match.
ScanArrayNames(const Model & model,const toco::ArraysExtraInfo_Entry & entry)2298 std::unordered_set<string> ScanArrayNames(
2299 const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
2300 std::unordered_set<string> matches;
2301 if (model.HasArray(entry.name())) {
2302 matches.insert(entry.name());
2303 }
2304 if (!entry.name_regexp().empty()) {
2305 const auto& arrays = model.GetArrayMap();
2306 const RE2 name_regexp = {entry.name_regexp()};
2307 for (auto it = arrays.begin(); it != arrays.end(); ++it) {
2308 if (RE2::FullMatch(it->first, name_regexp)) {
2309 matches.insert(it->first);
2310 }
2311 }
2312 }
2313 return matches;
2314 }
2315
UseArraysExtraInfo(Model * model,bool quantize_output)2316 void UseArraysExtraInfo(Model* model, bool quantize_output) {
2317 for (const auto& entry : model->flags.arrays_extra_info().entries()) {
2318 const auto matches = ScanArrayNames(*model, entry);
2319 for (const auto& matched_name : matches) {
2320 auto& array = model->GetArray(matched_name);
2321 if (entry.has_min() || entry.has_max()) {
2322 CHECK_EQ(entry.has_min(), entry.has_max());
2323 auto& minmax = array.GetOrCreateMinMax();
2324 minmax.min = entry.min();
2325 minmax.max = entry.max();
2326 }
2327 if (entry.has_data_type() && quantize_output) {
2328 array.final_data_type =
2329 ConvertIODataTypeToArrayDataType(entry.data_type());
2330 }
2331 if (entry.has_shape()) {
2332 array.clear_shape();
2333 // Make sure to create the shape even if there are no dims, to
2334 // correctly record 0-D shapes.
2335 array.mutable_shape();
2336 for (const auto& dim : entry.shape().dims()) {
2337 array.mutable_shape()->mutable_dims()->push_back(dim);
2338 }
2339 }
2340 if (entry.has_constant_float_value()) {
2341 CHECK(array.has_shape());
2342 if (array.data_type == ArrayDataType::kFloat) {
2343 auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
2344 data.resize(RequiredBufferSizeForShape(array.shape()));
2345 for (float& f : data) {
2346 f = entry.constant_float_value();
2347 }
2348 }
2349 }
2350 }
2351 }
2352 }
2353
UndoWeightsShuffling(Model * model)2354 void UndoWeightsShuffling(Model* model) {
2355 for (const auto& op : model->operators) {
2356 if (op->type != toco::OperatorType::kFullyConnected) {
2357 continue;
2358 }
2359 const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
2360 if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
2361 continue;
2362 }
2363 const string& weights_name = fc_op.inputs[1];
2364 QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
2365 auto& weights_array = model->GetArray(weights_name);
2366 QCHECK(weights_array.data_type == ArrayDataType::kUint8);
2367 auto& weights_data =
2368 weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
2369 const auto& weights_shape = weights_array.shape();
2370 QCHECK_EQ(weights_shape.dimensions_count(), 2);
2371 const int rows = weights_shape.dims(0);
2372 const int cols = weights_shape.dims(1);
2373 QCHECK_EQ(rows % 4, 0);
2374 QCHECK_EQ(cols % 16, 0);
2375 CHECK_EQ(rows * cols, weights_data.size());
2376 // Compute the de-shuffled weights
2377 std::vector<uint8> deshuffled_data(weights_data.size());
2378 uint8* shuffled_data_ptr = weights_data.data();
2379 for (int r = 0; r < rows; r += 4) {
2380 for (int c = 0; c < cols; c += 16) {
2381 for (int i = 0; i < 4; i++) {
2382 uint8* deshuffled_data_ptr =
2383 deshuffled_data.data() + (r + i) * cols + c;
2384 for (int j = 0; j < 16; j++) {
2385 uint8 shuffled_val = *shuffled_data_ptr++;
2386 // Deshuffling isn't only about deshuffling the storage layout,
2387 // it's also about undoing the flipping of the sign bit, which is
2388 // performed on the shuffled weights.
2389 uint8 deshuffled_val = shuffled_val ^ 0x80;
2390 *deshuffled_data_ptr++ = deshuffled_val;
2391 }
2392 }
2393 }
2394 }
2395 CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
2396 // Switch this FC op to using the deshuffled weights.
2397 weights_data = std::move(deshuffled_data);
2398 }
2399 }
2400
CopyMinMaxAndQuantizationRelatedFields(const Array & src,Array * dst)2401 void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
2402 if (src.minmax) {
2403 dst->GetOrCreateMinMax() = src.GetMinMax();
2404 }
2405 if (src.quantization_params) {
2406 dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
2407 }
2408 dst->narrow_range = src.narrow_range;
2409 }
2410
2411 } // namespace toco
2412