1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tooling_util.h"
16 
17 #include <functional>
18 #include <iterator>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 
24 #include "absl/strings/ascii.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_replace.h"
28 #include "absl/strings/str_split.h"
29 #include "re2/re2.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/dump_graphviz.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
35 
36 namespace toco {
37 
38 // Find the longest common prefix of two strings.
39 absl::string_view FindLongestCommonPrefix(absl::string_view a,
40                                           absl::string_view b) {
41   if (a.empty() || b.empty()) return absl::string_view();
42 
43   const char* pa = a.data();
44   const char* pb = b.data();
45   size_t count = 0;
46   const size_t limit = std::min(a.size(), b.size());
47   while (count < limit && *pa == *pb) {
48     ++pa;
49     ++pb;
50     ++count;
51   }
52 
53   return absl::string_view(a.data(), count);
54 }
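// Worked example for FindLongestCommonPrefix above (illustrative only):
// FindLongestCommonPrefix("conv1/weights", "conv1/bias") yields "conv1/",
// and an empty string_view is returned if either argument is empty.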
55 
56 std::string LogName(const Operator& op) {
57   const std::string& opname = HelpfulOperatorTypeName(op);
58   if (op.outputs.empty()) {
59     return toco::port::StringF("{%s operator}", opname);
60   } else {
61     return toco::port::StringF("{%s operator with output %s}", opname,
62                                op.outputs[0]);
63   }
64 }
65 
66 std::string ArrayDataTypeName(ArrayDataType data_type) {
67   switch (data_type) {
68     case ArrayDataType::kFloat:
69       return "float";
70     case ArrayDataType::kInt8:
71       return "int8";
72     case ArrayDataType::kUint8:
73       return "uint8";
74     case ArrayDataType::kInt16:
75       return "int16";
76     case ArrayDataType::kUint16:
77       return "uint16";
78     case ArrayDataType::kInt32:
79       return "int32";
80     case ArrayDataType::kUint32:
81       return "uint32";
82     case ArrayDataType::kInt64:
83       return "int64";
84     case ArrayDataType::kUint64:
85       return "uint64";
86     case ArrayDataType::kString:
87       return "string";
88     case ArrayDataType::kBool:
89       return "bool";
90     case ArrayDataType::kComplex64:
91       return "complex64";
92     case ArrayDataType::kNone:
93       return "None";
94     default:
95       LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
96   }
97 }
98 
99 bool IsInputArray(const Model& model, const std::string& array_name) {
100   for (const auto& input_array : model.flags.input_arrays()) {
101     if (array_name == input_array.name()) {
102       return true;
103     }
104   }
105   return false;
106 }
107 
108 bool IsOutputArray(const Model& model, const std::string& array_name) {
109   for (const auto& output_array : model.flags.output_arrays()) {
110     if (array_name == output_array) {
111       return true;
112     }
113   }
114   return false;
115 }
116 
117 bool IsArrayConsumed(const Model& model, const std::string& name) {
118   if (GetOpWithInput(model, name)) {
119     return true;
120   }
121   if (IsOutputArray(model, name)) {
122     return true;
123   }
124   for (const auto& rnn_state : model.flags.rnn_states()) {
125     if (rnn_state.back_edge_source_array() == name) {
126       return true;
127     }
128   }
129   return false;
130 }
131 
132 int CountTrueOutputs(const Model& model, const Operator& op) {
133   int count = 0;
134   for (const std::string& output : op.outputs) {
135     if (IsArrayConsumed(model, output)) {
136       ++count;
137     }
138   }
139   return count;
140 }
141 
142 int CountOpsWithInput(const Model& model, const std::string& array_name) {
143   int count = 0;
144   for (const auto& op : model.operators) {
145     for (auto& input : op->inputs) {
146       if (input == array_name) {
147         count++;
148         // Breaking here is important: some graphs have ops that use the
149         // same array as more than one of their inputs, and in that case
150         // we want it counted only once.
151         break;
152       }
153     }
154   }
155   return count;
156 }
157 
158 bool DeleteArrayIfUnused(const std::string& array_name, Model* model) {
159   if (IsDiscardableArray(*model, array_name) &&
160       CountOpsWithInput(*model, array_name) == 0 &&
161       GetOpWithOutput(*model, array_name) == nullptr) {
162     model->EraseArray(array_name);
163     return true;
164   }
165   return false;
166 }
167 
168 bool DeleteArrayIfUnusedOutsideOfOp(const std::string& array_name,
169                                     const Operator* op, Model* model) {
170   if (!IsDiscardableArray(*model, array_name)) {
171     return false;
172   }
173   if (CountOpsWithInput(*model, array_name) > 1) {
174     return false;
175   }
176   const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
177   if (op_having_this_as_input && op_having_this_as_input != op) {
178     return false;
179   }
180   const Operator* op_having_this_as_output =
181       GetOpWithOutput(*model, array_name);
182   if (op_having_this_as_output && op_having_this_as_output != op) {
183     return false;
184   }
185   model->EraseArray(array_name);
186   return true;
187 }
188 
189 void DeleteOpAndArrays(Model* model, const Operator* op) {
190   for (const std::string& array_name : op->inputs) {
191     DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
192   }
193   for (const std::string& array_name : op->outputs) {
194     DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
195   }
196   auto op_it = FindOp(*model, op);
197   CHECK(op_it != model->operators.end());
198   model->operators.erase(op_it);
199 }
200 
201 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
202     const Model& model, const std::string& array_name) {
203   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
204     for (auto& output : it->get()->outputs) {
205       if (output == array_name) {
206         return it;
207       }
208     }
209   }
210   return model.operators.end();
211 }
212 
213 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
214     Model& model, const std::string& array_name) {
215   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
216     for (auto& output : it->get()->outputs) {
217       if (output == array_name) {
218         return it;
219       }
220     }
221   }
222   return model.operators.end();
223 }
224 
225 Operator* GetOpWithOutput(const Model& model, const std::string& array_name) {
226   auto it = FindOpWithOutput(model, array_name);
227   return it == model.operators.end() ? nullptr : it->get();
228 }
229 
230 // GetFirstOpWithInput assumes that this finds the first op.
231 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
232     const Model& model, const std::string& array_name) {
233   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
234     for (auto& input : it->get()->inputs) {
235       if (input == array_name) {
236         return it;
237       }
238     }
239   }
240   return model.operators.end();
241 }
242 
243 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
244     Model& model, const std::string& array_name) {
245   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
246     for (auto& input : it->get()->inputs) {
247       if (input == array_name) {
248         return it;
249       }
250     }
251   }
252   return model.operators.end();
253 }
254 
255 std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
256     const Model& model, const Operator* op) {
257   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
258     if (it->get() == op) {
259       return it;
260     }
261   }
262   return model.operators.end();
263 }
264 
265 std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
266                                                         const Operator* op) {
267   for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
268     if (it->get() == op) {
269       return it;
270     }
271   }
272   return model.operators.end();
273 }
274 
275 Operator* GetOpWithInput(const Model& model, const std::string& array_name) {
276   auto it = FindOpWithInput(model, array_name);
277   return it == model.operators.end() ? nullptr : it->get();
278 }
279 
280 Operator* GetFirstOpWithInput(const Model& model,
281                               const std::string& array_name) {
282   auto it = FindOpWithInput(model, array_name);
283   return it == model.operators.end() ? nullptr : it->get();
284 }
285 
286 void ReplaceArrayUsage(Model* model, const std::string& old_array_name,
287                        const std::string& new_array_name) {
288   for (auto& op_it : model->operators) {
289     Operator* op = op_it.get();
290     for (size_t i = 0; i < op->inputs.size(); ++i) {
291       if (op->inputs[i] == old_array_name) {
292         op->inputs[i] = new_array_name;
293       }
294     }
295     for (size_t i = 0; i < op->outputs.size(); ++i) {
296       if (op->outputs[i] == old_array_name) {
297         op->outputs[i] = new_array_name;
298       }
299     }
300   }
301 }
302 
303 std::string FormatArraysList(const Model& model,
304                              const std::vector<std::string>& list) {
305   if (list.empty()) {
306     return "[]";
307   }
308   std::string result = "";
309   if (list.size() > 1) {
310     result += "[ ";
311   }
312   for (std::size_t i = 0; i < list.size(); i++) {
313     if (i > 0) {
314       result += ", ";
315     }
316     result += list[i];
317   }
318   if (list.size() > 1) {
319     result += " ]";
320   }
321   return result;
322 }
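// Worked example for FormatArraysList above (illustrative only): an empty
// list yields "[]", {"foo"} yields "foo", and {"foo", "bar"} yields
// "[ foo, bar ]".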
323 
324 const char* OperatorTypeName(OperatorType type) {
325   switch (type) {
326 #define HANDLE_OPERATORTYPENAME_CASE(c) \
327   case OperatorType::k##c:              \
328     return #c;
329     HANDLE_OPERATORTYPENAME_CASE(Abs)
330     HANDLE_OPERATORTYPENAME_CASE(Add)
331     HANDLE_OPERATORTYPENAME_CASE(AddN)
332     HANDLE_OPERATORTYPENAME_CASE(AveragePool)
333     HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
334     HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
335     HANDLE_OPERATORTYPENAME_CASE(Conv)
336     HANDLE_OPERATORTYPENAME_CASE(Concatenation)
337     HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
338     HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
339     HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
340     HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
341     HANDLE_OPERATORTYPENAME_CASE(HardSwish)
342     HANDLE_OPERATORTYPENAME_CASE(Dequantize)
343     HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
344     HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
345     HANDLE_OPERATORTYPENAME_CASE(Log)
346     HANDLE_OPERATORTYPENAME_CASE(Logistic)
347     HANDLE_OPERATORTYPENAME_CASE(LstmCell)
348     HANDLE_OPERATORTYPENAME_CASE(MaxPool)
349     HANDLE_OPERATORTYPENAME_CASE(L2Pool)
350     HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
351     HANDLE_OPERATORTYPENAME_CASE(Mul)
352     HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
353     HANDLE_OPERATORTYPENAME_CASE(Elu)
354     HANDLE_OPERATORTYPENAME_CASE(Relu)
355     HANDLE_OPERATORTYPENAME_CASE(Relu1)
356     HANDLE_OPERATORTYPENAME_CASE(Relu6)
357     HANDLE_OPERATORTYPENAME_CASE(PRelu)
358     HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
359     HANDLE_OPERATORTYPENAME_CASE(Softmax)
360     HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
361     HANDLE_OPERATORTYPENAME_CASE(Div)
362     HANDLE_OPERATORTYPENAME_CASE(Tanh)
363     HANDLE_OPERATORTYPENAME_CASE(Sin)
364     HANDLE_OPERATORTYPENAME_CASE(All)
365     HANDLE_OPERATORTYPENAME_CASE(Assert)
366     HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
367     HANDLE_OPERATORTYPENAME_CASE(Fill)
368     HANDLE_OPERATORTYPENAME_CASE(FloorMod)
369     HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
370     HANDLE_OPERATORTYPENAME_CASE(Greater)
371     HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
372     HANDLE_OPERATORTYPENAME_CASE(Identity)
373     HANDLE_OPERATORTYPENAME_CASE(Less)
374     HANDLE_OPERATORTYPENAME_CASE(LessEqual)
375     HANDLE_OPERATORTYPENAME_CASE(MatMul)
376     HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
377     HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
378     HANDLE_OPERATORTYPENAME_CASE(Merge)
379     HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
380     HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
381     HANDLE_OPERATORTYPENAME_CASE(Neg)
382     HANDLE_OPERATORTYPENAME_CASE(OneHot)
383     HANDLE_OPERATORTYPENAME_CASE(Pack)
384     HANDLE_OPERATORTYPENAME_CASE(Pad)
385     HANDLE_OPERATORTYPENAME_CASE(PadV2)
386     HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
387     HANDLE_OPERATORTYPENAME_CASE(Range)
388     HANDLE_OPERATORTYPENAME_CASE(Rank)
389     HANDLE_OPERATORTYPENAME_CASE(Reshape)
390     HANDLE_OPERATORTYPENAME_CASE(Squeeze)
391     HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
392     HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
393     HANDLE_OPERATORTYPENAME_CASE(Shape)
394     HANDLE_OPERATORTYPENAME_CASE(Slice)
395     HANDLE_OPERATORTYPENAME_CASE(Split)
396     HANDLE_OPERATORTYPENAME_CASE(SplitV)
397     HANDLE_OPERATORTYPENAME_CASE(Sqrt)
398     HANDLE_OPERATORTYPENAME_CASE(Square)
399     HANDLE_OPERATORTYPENAME_CASE(Switch)
400     HANDLE_OPERATORTYPENAME_CASE(Sub)
401     HANDLE_OPERATORTYPENAME_CASE(Sum)
402     HANDLE_OPERATORTYPENAME_CASE(Tile)
403     HANDLE_OPERATORTYPENAME_CASE(Transpose)
404     HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
405     HANDLE_OPERATORTYPENAME_CASE(Concat)
406     HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
407     HANDLE_OPERATORTYPENAME_CASE(Cast)
408     HANDLE_OPERATORTYPENAME_CASE(Floor)
409     HANDLE_OPERATORTYPENAME_CASE(Ceil)
410     HANDLE_OPERATORTYPENAME_CASE(Round)
411     HANDLE_OPERATORTYPENAME_CASE(Gather)
412     HANDLE_OPERATORTYPENAME_CASE(GatherNd)
413     HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
414     HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
415     HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
416     HANDLE_OPERATORTYPENAME_CASE(Mean)
417     HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
418     HANDLE_OPERATORTYPENAME_CASE(Svdf)
419     HANDLE_OPERATORTYPENAME_CASE(ArgMax)
420     HANDLE_OPERATORTYPENAME_CASE(ArgMin)
421     HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
422     HANDLE_OPERATORTYPENAME_CASE(Unsupported)
423     HANDLE_OPERATORTYPENAME_CASE(Exp)
424     HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
425     HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
426     HANDLE_OPERATORTYPENAME_CASE(Select)
427     HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
428     HANDLE_OPERATORTYPENAME_CASE(Equal)
429     HANDLE_OPERATORTYPENAME_CASE(NotEqual)
430     HANDLE_OPERATORTYPENAME_CASE(Pow)
431     HANDLE_OPERATORTYPENAME_CASE(Any)
432     HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
433     HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
434     HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
435     HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
436     HANDLE_OPERATORTYPENAME_CASE(Unpack)
437     HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
438     HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
439     HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
440     HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
441     HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
442     HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
443     HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
444     HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
445     HANDLE_OPERATORTYPENAME_CASE(Unique)
446     HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
447     HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
448     HANDLE_OPERATORTYPENAME_CASE(Cos)
449     HANDLE_OPERATORTYPENAME_CASE(Where)
450     HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
451     HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
452     HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
453     HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
454     HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
455     HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
456     HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
457     HANDLE_OPERATORTYPENAME_CASE(ScatterNd)
458     default:
459       LOG(FATAL) << "Unhandled op type";
460 #undef HANDLE_OPERATORTYPENAME_CASE
461   }
462 }
463 
464 std::string HelpfulOperatorTypeName(const Operator& op) {
465   if (op.type == OperatorType::kUnsupported) {
466     return toco::port::StringF(
467         "(Unsupported TensorFlow op: %s)",
468         static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
469   }
470   return OperatorTypeName(op.type);
471 }
472 
473 bool OperatorSupportsFusedActivation(OperatorType type) {
474   switch (type) {
475     case OperatorType::kAdd:
476     case OperatorType::kAveragePool:
477     case OperatorType::kBatchNormalization:
478     case OperatorType::kConv:
479     case OperatorType::kDepthwiseConv:
480     case OperatorType::kDiv:
481     case OperatorType::kFullyConnected:
482     case OperatorType::kL2Pool:
483     case OperatorType::kMaxPool:
484     case OperatorType::kMul:
485     case OperatorType::kSub:
486     case OperatorType::kSquaredDifference:
487       return true;
488     default:
489       return false;
490   }
491 }
492 
493 void LogSummary(int log_level, const Model& model) {
494   VLOG(log_level) << "Operators summary (" << model.operators.size()
495                   << " operators):";
496   std::unordered_multiset<OperatorType> ops_by_type;
497   for (const auto& op : model.operators) {
498     ops_by_type.insert(op->type);
499   }
500   auto it = ops_by_type.begin();
501   while (it != ops_by_type.end()) {
502     int count = ops_by_type.count(*it);
503     VLOG(log_level) << "    " << OperatorTypeName(*it) << ": " << count;
504     std::advance(it, count);
505   }
506 }
507 
508 void LogArray(int log_level, const Model& model, const std::string& name) {
509   VLOG(log_level) << "Array: " << name;
510   if (!model.HasArray(name)) {
511     VLOG(log_level) << "  DOES NOT EXIST";
512     return;
513   }
514   const auto& array = model.GetArray(name);
515   VLOG(log_level) << "  Data type: " << ArrayDataTypeName(array.data_type);
516   VLOG(log_level) << "  Final type: "
517                   << ArrayDataTypeName(array.final_data_type);
518   if (array.buffer) {
519     VLOG(log_level) << "  Constant Buffer";
520   }
521   if (array.alloc) {
522     VLOG(log_level) << "  Transient Alloc";
523   }
524   if (array.has_shape()) {
525     const Shape& array_shape = array.shape();
526     if (array_shape.dimensions_count() == 0) {
527       VLOG(log_level) << "  (Zero dimensions)";
528     } else {
529       std::string message = "  Dims: ";
530       bool first = true;
531       for (const int dim : array_shape.dims()) {
532         if (!first) {
533           message += ", ";
534         }
535         first = false;
536         toco::port::AppendF(&message, "%d", dim);
537       }
538       VLOG(log_level) << message;
539     }
540   }
541   if (array.minmax) {
542     VLOG(log_level) << "  MinMax: " << array.minmax->min << " .. "
543                     << array.minmax->max;
544   }
545   if (array.quantization_params) {
546     VLOG(log_level) << "  QuantizationParams: zero_point="
547                     << static_cast<int>(array.quantization_params->zero_point)
548                     << ", scale=" << array.quantization_params->scale;
549   }
550 }
551 
552 void DumpGraphvizVideoFrame(const Model& model) {
553   namespace port = toco::port;
554 
555   const auto& dump_options = *GraphVizDumpOptions::singleton();
556   if (!dump_options.dump_graphviz_video) {
557     return;
558   }
559   CHECK(!dump_options.dump_graphviz.empty());
560   // TODO(benoitjacob): the static data here means that this function
561   // is stateful, not reentrant, and effectively leaks memory till exit
562   // (since dump_hashes can only grow in size). It also means that it
563   // really only is intended to be called for a single model during the
564   // process' lifetime. So it's not great design at all. The overriding
565   // design aspect here is to make the video-dumping code as unintrusive
566   // and self-contained as possible. Eventually, we'll want to have that
567   // cleaned-up, but that will require some form of general statefulness
568   // in toco (some kind of 'tooling state' data structure) that does
569   // not exist at present, and would be premature to design here just for
570   // this new video-dumping feature.
571   static int dump_id = 0;
572   static std::unordered_set<std::size_t> dump_hashes;
573   std::string graphviz_dump;
574   DumpGraphviz(model, &graphviz_dump,
575                toco::port::StringF("VIDEO frame:%05d", dump_id));
576   std::size_t hash = std::hash<std::string>{}(graphviz_dump);
577   if (!dump_hashes.count(hash)) {
578     LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
579     dump_hashes.insert(hash);
580     const auto result = port::file::SetContents(
581         port::file::JoinPath(
582             dump_options.dump_graphviz,
583             toco::port::StringF("toco_video_%05d.dot", dump_id)),
584         graphviz_dump, port::file::Defaults());
585     QCHECK(result.ok()) << result.error_message();
586     dump_id++;
587   }
588 }
589 
590 void LogDump(int log_level, const std::string& message, const Model& model) {
591   namespace port = toco::port;
592   const auto& dump_options = *GraphVizDumpOptions::singleton();
593 
594   DumpGraphvizVideoFrame(model);
595   if (!dump_options.dump_graphviz.empty()) {
596     std::string graphviz_dump;
597 
598     DumpGraphviz(model, &graphviz_dump, message);
599     const auto result = port::file::SetContents(
600         port::file::JoinPath(
601             dump_options.dump_graphviz,
602             absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
603                          ".dot")),
604         graphviz_dump, port::file::Defaults());
605     QCHECK(result.ok()) << result.error_message();
606   }
607 
608   if (!VLOG_IS_ON(log_level)) {
609     return;
610   }
611   VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
612   LogSummary(log_level, model);
613   std::unordered_set<std::string> already_printed_arrays;
614   for (const auto& op : model.operators) {
615     for (const auto& input : op->inputs) {
616       if (!already_printed_arrays.count(input)) {
617         already_printed_arrays.insert(input);
618         LogArray(log_level, model, input);
619       }
620     }
621     VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
622     VLOG(log_level) << "  " << FormatArraysList(model, op->inputs) << " -> "
623                     << FormatArraysList(model, op->outputs);
624     if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
625       VLOG(log_level) << "    (with fused activation function)";
626     }
627     for (const auto& output : op->outputs) {
628       if (!already_printed_arrays.count(output)) {
629         already_printed_arrays.insert(output);
630         LogArray(log_level, model, output);
631       }
632     }
633   }
634   VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
635 }
636 
637 // Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
638 void ExtendShape(Shape* shape, int new_shape_size) {
639   CHECK_GE(new_shape_size, shape->dimensions_count());
640   const int size_increase = new_shape_size - shape->dimensions_count();
641   auto* shape_dims = shape->mutable_dims();
642   shape_dims->insert(shape_dims->begin(), size_increase, 1);
643 }
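// Worked example for ExtendShape above (illustrative only): extending a
// shape of [3, 4] to new_shape_size 4 prepends 1s, giving [1, 1, 3, 4].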
644 
645 // TODO(b/62904716) Remove along with remaining uses.
646 void UnextendShape(Shape* shape, int new_shape_size) {
647   CHECK_LE(new_shape_size, shape->dimensions_count());
648   const int size_reduction = shape->dimensions_count() - new_shape_size;
649   for (int i = 0; i < size_reduction; i++) {
650     CHECK_EQ(shape->dims(i), 1);
651   }
652   std::vector<int>& shape_dims = *shape->mutable_dims();
653   shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
654 }
655 
656 // In general, zero-sized dimensions are disallowed, but there are exceptions,
657 // e.g., if the tensor data itself represents a scalar (rank 0) shape, its
658 // shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
659 // strict, and is appropriate for ops and comparisons where an empty shape
660 // doesn't make sense.
661 template <typename Dims>
662 void CheckValidShapeDimensions(const Dims& dims) {
663   if (dims.size() == 1 && dims[0] == 0) {
664     return;
665   }
666   for (const auto& dim : dims) {
667     CHECK_GE(dim, 1);
668   }
669 }
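// Worked example for CheckValidShapeDimensions above (illustrative only):
// dims [0] is accepted (scalar data encoded with a [0] shape), dims [2, 3]
// is accepted, and dims [2, 0] fails the CHECK since a zero dimension is
// only allowed when it is the sole dimension.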
670 
671 void CheckValidShape(const Shape& shape) {
672   CheckValidShapeDimensions(shape.dims());
673 }
674 
675 bool IsNonEmpty(const Shape& shape) {
676   for (int i = 0; i < shape.dimensions_count(); ++i) {
677     if (shape.dims(i) < 1) return false;
678   }
679   return true;
680 }
681 
682 void CheckNonEmptyShapeDimensions(const Shape& shape) {
683   for (int i = 0; i < shape.dimensions_count(); ++i) {
684     CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index " << i
685                                  << ". shape = " << ShapeToString(shape);
686   }
687 }
688 
689 bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
690   CheckNonEmptyShapeDimensions(shape0);
691   CheckNonEmptyShapeDimensions(shape1);
692 
693   const Shape* longer = &shape0;
694   const Shape* shorter = &shape1;
695   if (shape1.dimensions_count() > shape0.dimensions_count()) {
696     longer = &shape1;
697     shorter = &shape0;
698   }
699 
700   // Walk dimensions back to front until we run out of dimensions in the shorter
701   // shape.
702   int longer_index = longer->dimensions_count() - 1;
703   int shorter_index = shorter->dimensions_count() - 1;
704   while (shorter_index >= 0) {
705     const int d_long = longer->dims(longer_index);
706     const int d_short = shorter->dims(shorter_index);
707     // Broadcasting fails if the dimensions are different *and* neither is 1.
708     if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
709       return false;
710     }
711     longer_index--;
712     shorter_index--;
713   }
714   return true;
715 }
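// Worked example for ShapesAgreeUpToBroadcasting above (illustrative only):
// [8, 1, 6, 1] and [7, 1, 5] agree, since each trailing dimension pair is
// either equal or contains a 1; [2, 3] and [4, 3] do not, since 2 != 4 and
// neither is 1.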
716 
717 bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
718   CheckNonEmptyShapeDimensions(shape0);
719   CheckNonEmptyShapeDimensions(shape1);
720 
721   const Shape* longer = &shape0;
722   const Shape* shorter = &shape1;
723   if (shape1.dimensions_count() > shape0.dimensions_count()) {
724     longer = &shape1;
725     shorter = &shape0;
726   }
727 
728   // Walk dimensions back to front until we run out of dimensions in the shorter
729   // shape.
730   int longer_index = longer->dimensions_count() - 1;
731   int shorter_index = shorter->dimensions_count() - 1;
732   while (shorter_index >= 0) {
733     const int d_long = longer->dims(longer_index);
734     const int d_short = shorter->dims(shorter_index);
735     // Extending fails if the dimensions are different.
736     if (d_long != d_short) {
737       return false;
738     }
739     longer_index--;
740     shorter_index--;
741   }
742 
743   // The remaining dimensions in the longer shape must be 1.
744   while (longer_index >= 0) {
745     const int d_long = longer->dims(longer_index);
746     if (d_long != 1) {
747       return false;
748     }
749     longer_index--;
750   }
751 
752   return true;
753 }
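// Worked example for ShapesAgreeUpToExtending above (illustrative only):
// [1, 1, 3, 4] and [3, 4] agree, since the trailing dimensions match and the
// extra leading dimensions are all 1; [2, 3, 4] and [3, 4] do not, since the
// extra leading dimension is 2.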
754 
755 int RequiredBufferSizeForShape(const Shape& shape) {
756   CheckValidShape(shape);
757   int max_offset = 1;
758   for (const auto& dim : shape.dims()) {
759     max_offset *= dim;
760   }
761   return max_offset;
762 }
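// Worked example for RequiredBufferSizeForShape above (illustrative only):
// a shape of [2, 3, 4] requires a buffer of 2 * 3 * 4 = 24 elements.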
763 
764 bool IsConstantParameterArray(const Model& model, const std::string& name) {
765   if (!model.HasArray(name)) {
766     return false;
767   }
768 
769   return !!model.GetArray(name).buffer;
770 }
771 
772 namespace {
773 template <ArrayDataType A>
774 bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
775   CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
776   CHECK(lhs_array.buffer) << "LHS must be constant";
777   CHECK(rhs_array.buffer) << "RHS must be constant";
778   const auto& lhs_data = lhs_array.GetBuffer<A>().data;
779   const auto& rhs_data = rhs_array.GetBuffer<A>().data;
780   CHECK_EQ(lhs_data.size(), rhs_data.size())
781       << "Buffer sizes must match in element count";
782   for (int i = 0; i < lhs_data.size(); ++i) {
783     if (lhs_data[i] != rhs_data[i]) {
784       return false;
785     }
786   }
787   return true;
788 }
789 
790 bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
791   if (lhs_array.minmax || rhs_array.minmax) {
792     if (!lhs_array.minmax || !rhs_array.minmax) {
793       return false;
794     }
795     if (!(*lhs_array.minmax == *rhs_array.minmax)) {
796       return false;
797     }
798   }
799   return true;
800 }
801 
802 bool HaveSameQuantizationParams(const Array& lhs_array,
803                                 const Array& rhs_array) {
804   if (lhs_array.quantization_params || rhs_array.quantization_params) {
805     if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
806       return false;
807     }
808     if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
809       return false;
810     }
811   }
812   return true;
813 }
814 
815 }  // namespace
816 
817 bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
818   bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
819                      lhs_array.data_type == rhs_array.data_type &&
820                      lhs_array.final_data_type == rhs_array.final_data_type &&
821                      HaveSameMinMax(lhs_array, rhs_array) &&
822                      HaveSameQuantizationParams(lhs_array, rhs_array) &&
823                      lhs_array.narrow_range == rhs_array.narrow_range;
824   if (!attrs_equal) {
825     return false;
826   }
827   switch (lhs_array.data_type) {
828     case ArrayDataType::kBool:
829       return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
830     case ArrayDataType::kFloat:
831       return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
832     case ArrayDataType::kInt8:
833       return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
834     case ArrayDataType::kUint8:
835       return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
836     case ArrayDataType::kInt16:
837       return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
838     case ArrayDataType::kUint16:
839       return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
840     case ArrayDataType::kInt32:
841       return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
842     case ArrayDataType::kUint32:
843       return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
844     case ArrayDataType::kInt64:
845       return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
846     case ArrayDataType::kUint64:
847       return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
848     case ArrayDataType::kString:
849       return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
850     case ArrayDataType::kComplex64:
851       return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
852                                                             rhs_array);
853     default:
854       LOG(FATAL) << "Unsupported data type: "
855                  << ArrayDataTypeName(lhs_array.data_type);
856       return false;
857   }
858 }
859 
860 namespace {
861 // Take an array name, which may be something like "name:3_5", and make it
862 // acceptable as a TF node name, say "name_3_5".
863 std::string SanitizeNameForTFNode(const std::string& array_name) {
864   auto node_name = array_name;
865   std::replace(node_name.begin(), node_name.end(), ':', '_');
866   return node_name;
867 }
868 
869 void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
870   for (const auto& input_array : model_flags.input_arrays()) {
871     for (const std::string& output_array : model_flags.output_arrays()) {
872       QCHECK_NE(input_array.name(), output_array)
873           << "The array " << output_array
874           << " is listed in both --input_arrays and --output_arrays.";
875     }
876   }
877 }
878 
879 bool IsAsciiPrintable(const std::string& name) {
880   for (char c : name) {
881     if (!absl::ascii_isprint(c)) {
882       return false;
883     }
884   }
885   return true;
886 }
887 
888 std::string DumpAscii(const std::string& name) {
889   std::string result;
890   port::AppendF(&result, "ASCII | Hex\n");
891   port::AppendF(&result, "------+----\n");
892   for (char c : name) {
893     if (absl::ascii_isprint(c)) {
894       port::AppendF(&result, "%c     | %x\n", c, c);
895     } else {
896       port::AppendF(&result, "      | %x   Not ASCII printable!\n", c);
897     }
898   }
899   return result;
900 }
901 
902 void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
903   if (model_flags.allow_nonascii_arrays()) {
904     return;
905   }
906   for (const auto& input_array : model_flags.input_arrays()) {
907     QCHECK(IsAsciiPrintable(input_array.name()))
908         << "Non-ASCII-printable character found in --input_arrays: "
909         << input_array.name()
910         << ". Pass --allow_nonascii_arrays to allow that. "
911         << "Here is a dump of the string:\n\n"
912         << DumpAscii(input_array.name());
913   }
914   for (const std::string& output_array : model_flags.output_arrays()) {
915     QCHECK(IsAsciiPrintable(output_array))
916         << "Non-ASCII-printable character found in --output_arrays: "
917         << output_array << ". Pass --allow_nonascii_arrays to allow that. "
918         << "Here is a dump of the string:\n\n"
919         << DumpAscii(output_array);
920   }
921 }
922 
923 void CheckNonExistentIOArrays(const Model& model) {
924   // "non-existent" is interpreted in the stronger sense of
925   // "not actually produced/consumed by an op".
926   // Rationale: we have to artificially fix up TensorFlow graphs by creating
927   // any array that it refers to, so just checking that arrays exist isn't
928   // sufficient. The real invariant here is whether arrays are produced/consumed
929   // by something.
930   if (model.flags.allow_nonexistent_arrays()) {
931     return;
932   }
933   static constexpr char general_comment[] =
934       "Is it a typo? This should not happen. If you trigger this error "
935       "please send a bug report (with code to reproduce this error), to the "
936       "TensorFlow Lite team.";
937   for (const std::string& output_array : model.flags.output_arrays()) {
938     if (IsConstantParameterArray(model, output_array)) {
939       continue;  // It is OK to request that a constant be an output.
940     }
941     QCHECK(GetOpWithOutput(model, output_array))
942         << "Specified output array \"" << output_array
943         << "\" is not produced by any op in this graph. " << general_comment;
944   }
945   for (const auto& rnn_state : model.flags.rnn_states()) {
946     if (!rnn_state.discardable()) {
947       // Check that all RNN states are consumed
948       QCHECK(GetOpWithInput(model, rnn_state.state_array()))
949           << "Specified RNN state \"" << rnn_state.state_array()
950           << "\" is not consumed by any op in this graph. " << general_comment;
951       // Check that all RNN back-edge source arrays are produced
952       QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
953           << "Specified RNN back-edge source array \""
954           << rnn_state.back_edge_source_array()
955           << "\" is not produced by any op in this graph. " << general_comment;
956     }
957   }
958 }
959 
960 }  // namespace
961 
962 void CheckNoMissingArray(const Model& model) {
963   for (const auto& op : model.operators) {
964     for (const auto& input : op->inputs) {
965       CHECK(model.HasArray(input) || model.optional_arrays.count(input))
966           << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
967     }
968     for (const auto& output : op->outputs) {
969       CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
970     }
971   }
972   CheckNonExistentIOArrays(model);
973 }
974 
975 void FixNoMissingArray(Model* model) {
976   for (const auto& op : model->operators) {
977     for (const auto& input : op->inputs) {
978       if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
979         model->GetOrCreateArray(input);
980       }
981     }
982     for (const auto& output : op->outputs) {
983       if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
984         model->GetOrCreateArray(output);
985       }
986     }
987   }
988   if (model->flags.allow_nonexistent_arrays()) {
989     for (const std::string& output_array : model->flags.output_arrays()) {
990       model->GetOrCreateArray(output_array);
991     }
992     for (const auto& rnn_state : model->flags.rnn_states()) {
993       model->GetOrCreateArray(rnn_state.state_array());
994       model->GetOrCreateArray(rnn_state.back_edge_source_array());
995     }
996   }
997 }
998 
999 void CheckNoOrphanedArray(const Model& model) {
1000   std::unordered_set<std::string> arrays_without_known_use;
1001   for (const auto& array : model.GetArrayMap()) {
1002     if (IsDiscardableArray(model, array.first)) {
1003       arrays_without_known_use.insert(array.first);
1004     }
1005   }
1006   for (const auto& op : model.operators) {
1007     for (const auto& input : op->inputs) {
1008       arrays_without_known_use.erase(input);
1009     }
1010     for (const auto& output : op->outputs) {
1011       arrays_without_known_use.erase(output);
1012     }
1013   }
1014   for (const auto& rnn_state : model.flags.rnn_states()) {
1015     arrays_without_known_use.erase(rnn_state.state_array());
1016     arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1017   }
1018   if (!arrays_without_known_use.empty()) {
1019     for (const auto& array : arrays_without_known_use) {
1020       LOG(INFO) << "Error: Orphaned array: " << array;
1021     }
1022   }
1023   CHECK(arrays_without_known_use.empty());
1024 }
1025 
1026 void FixNoOrphanedArray(Model* model) {
1027   std::unordered_set<std::string> arrays_without_known_use;
1028   for (const auto& array : model->GetArrayMap()) {
1029     arrays_without_known_use.insert(array.first);
1030   }
1031   for (const auto& op : model->operators) {
1032     for (const auto& input : op->inputs) {
1033       arrays_without_known_use.erase(input);
1034     }
1035     for (const auto& output : op->outputs) {
1036       arrays_without_known_use.erase(output);
1037     }
1038   }
1039   for (const auto& rnn_state : model->flags.rnn_states()) {
1040     arrays_without_known_use.erase(rnn_state.state_array());
1041     arrays_without_known_use.erase(rnn_state.back_edge_source_array());
1042   }
1043   for (const auto& array : arrays_without_known_use) {
1044     if (IsDiscardableArray(*model, array)) {
1045       model->EraseArray(array);
1046     }
1047   }
1048 }
1049 
1050 // Apply checks to arrays individually (for-each fashion).
1051 //
1052 // Check consistency of array fields, check name.
1053 void CheckEachArray(const Model& model) {
1054   for (const auto& array_entry : model.GetArrayMap()) {
1055     const auto& array = array_entry.second;
1056     // It's OK to have a buffer or an alloc, but not both.
1057     // (Since allocs are for transient arrays without a buffer).
1058     CHECK(!array->buffer || !array->alloc) << "Tensor: " << array_entry.first;
1059     if (array->buffer) {
1060       // If there is a buffer, its type should be consistent with data_type.
1061       CHECK(array->buffer->type == array->data_type)
1062           << "Tensor: " << array_entry.first;
1063       // The presence of a fixed buffer should imply the presence of a fixed
1064       // shape.
1065       CHECK(array->has_shape()) << array_entry.first;
1066       // A constant buffer should have a valid shape.
1067       CheckValidShape(array->shape());
1068       // The shape flat-size should agree with the buffer length.
1069       CHECK_EQ(array->buffer->Length(),
1070                RequiredBufferSizeForShape(array->shape()))
1071           << "Tensor: " << array_entry.first;
1072     }
1073 
1074     // Check the name. Either "name_with_suffix_8" or "name_with_port:3" is
1075     // acceptable, but not "name_with_both:3_8".
1076     const std::string& name = array_entry.first;
1077     auto colon_pos = name.find_first_of(':');
1078     if (colon_pos != std::string::npos) {
1079       CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
1080                std::string::npos)
1081           << "Array '" << name << "' has non-digit characters after colon.";
1082     }
1083     CHECK_GT(colon_pos, 0) << "Array '" << name
1084                            << "' must not start with a colon.";
1085   }
1086 }
1087 
1088 void CheckOperatorOrdering(const Model& model) {
1089   std::unordered_set<std::string> arrays_behind_us;
1090   for (const auto& array_entry : model.GetArrayMap()) {
1091     if (!GetOpWithOutput(model, array_entry.first)) {
1092       arrays_behind_us.insert(array_entry.first);
1093     }
1094   }
1095   arrays_behind_us.insert(model.optional_arrays.begin(),
1096                           model.optional_arrays.end());
1097   for (const auto& op : model.operators) {
1098     for (const auto& input : op->inputs) {
1099       if (!IsConstantParameterArray(model, input)) {
1100         CHECK(arrays_behind_us.count(input));
1101       }
1102     }
1103     for (const auto& output : op->outputs) {
1104       CHECK(!arrays_behind_us.count(output));
1105       arrays_behind_us.insert(output);
1106     }
1107   }
1108   for (const std::string& output_array : model.flags.output_arrays()) {
1109     CHECK(arrays_behind_us.count(output_array));
1110   }
1111 }
1112 
1113 void FixOperatorOrdering(Model* model) {
1114   std::unordered_set<std::string> arrays_behind_us;
1115   for (const auto& array_entry : model->GetArrayMap()) {
1116     if (!GetOpWithOutput(*model, array_entry.first)) {
1117       arrays_behind_us.insert(array_entry.first);
1118     }
1119   }
1120   arrays_behind_us.insert(model->optional_arrays.begin(),
1121                           model->optional_arrays.end());
1122   std::vector<std::unique_ptr<Operator>> old_operators;
1123   std::swap(old_operators, model->operators);
1124   std::set<std::size_t> remaining;
1125   for (std::size_t i = 0; i < old_operators.size(); i++) {
1126     remaining.insert(i);
1127   }
1128   std::unordered_map<std::string, std::string> reason_why_leftover;
1129   while (true) {
1130     bool inserted_something = false;
1131     for (const auto& i : remaining) {
1132       bool can_insert = true;
1133       auto& op = old_operators[i];
1134       CHECK(op);
1135       for (const auto& input : op->inputs) {
1136         if (!IsConstantParameterArray(*model, input) &&
1137             !arrays_behind_us.count(input)) {
1138           for (const std::string& output : op->outputs) {
1139             reason_why_leftover[output] = input;
1140           }
1141           can_insert = false;
1142           break;
1143         }
1144       }
1145       if (can_insert) {
1146         model->operators.emplace_back(nullptr);
1147         for (const auto& output : op->outputs) {
1148           arrays_behind_us.insert(output);
1149         }
1150         std::swap(op, model->operators.back());
1151         remaining.erase(i);
1152         inserted_something = true;
1153         break;
1154       }
1155     }
1156     if (!inserted_something) {
1157       break;
1158     }
1159   }
1160   if (!remaining.empty()) {
1161     LOG(ERROR)
1162         << "No viable ordering of operators was found. "
1163         << "Here is a 'backtrace' of at least one part of the graph that is "
1164         << "problematic. It starts with the first operator that has a "
1165         << "problematic input array, and then walks back the graph to "
1166         << "the operator that produced that input array, etc., until we find "
1167         << "the root cause:";
1168     LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
1169     LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
1170     const Operator* bad_op = old_operators[*remaining.begin()].get();
1171     std::unordered_set<std::string> bad_inputs_already_traced;
1172     // The following while(true) loop should always end with a LOG(FATAL).
1173     while (true) {
1174       LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
1175                  << FormatArraysList(*model, bad_op->inputs) << " -> "
1176                  << FormatArraysList(*model, bad_op->outputs);
1177       bool found_bad_output = false;
1178       std::string bad_output;
1179       for (const std::string& output : bad_op->outputs) {
1180         if (reason_why_leftover.count(output)) {
1181           found_bad_output = true;
1182           bad_output = output;
1183           break;
1184         }
1185       }
1186       CHECK(found_bad_output);
1187       const std::string& bad_input = reason_why_leftover[bad_output];
1188       LOG(ERROR) << "The bad input here is: " << bad_input;
1189       if (bad_inputs_already_traced.count(bad_input)) {
1190         LOG(FATAL)
1191             << "Cycle found! We already encountered that "
1192             << "input array, " << bad_input << ", earlier in the "
1193             << "above trace! We expect graphs to be acyclic, even "
1194             << "RNNs. Let us know if some graph actually needs to have "
1195             << "cycles, but first, please check if it really is "
1196             << "an *inference* graph. *Training* graphs are out-of-scope "
1197             << "for toco.";
1198       }
1199       bad_inputs_already_traced.insert(bad_input);
1200       bad_op = nullptr;
1201       for (const auto& i : remaining) {
1202         const Operator* op = old_operators[i].get();
1203         for (const std::string& output : op->outputs) {
1204           if (bad_input == output) {
1205             bad_op = op;
1206             break;
1207           }
1208         }
1209         if (bad_op) {
1210           break;
1211         }
1212       }
1213       if (!bad_op) {
1214         LOG(ERROR) << "And that's the root cause: "
1215                    << "that array, " << bad_input << ", isn't produced by any "
1216                    << "operator, or provided in any other way.";
1217         LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
1218         LOG(FATAL) << "(The above was a multi-line fatal error)";
1219       }
1220       LOG(ERROR) << "And that array is the output of the following operator:";
1221     }
1222   }
1223   CHECK(remaining.empty())
1224       << "Should never get here! In case of bad graph, "
1225       << "the above code should have generated a FATAL error already!";
1226 }
1227 
1228 void CheckInvariants(const Model& model) {
1229   CheckInputArraysAreNotOutputArrays(model.flags);
1230   CheckNonAsciiIOArrays(model.flags);
1231   CheckNoMissingArray(model);
1232   CheckNoOrphanedArray(model);
1233   CheckEachArray(model);
1234   CheckOperatorOrdering(model);
1235 }
1236 
1237 void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
1238                        const int count, const std::string& count_description) {
1239   if (model_check.count_min() >= 0) {
1240     CHECK_GE(count, model_check.count_min())
1241         << "Mismatch in " << count_description << ": count  was " << count
1242         << ", but the specified "
1243         << (model_check.count_max() > model_check.count_min() ? "minimum"
1244                                                               : "value")
1245         << " was " << model_check.count_min() << ".";
1246   }
1247   if (model_check.count_max() > model_check.count_min()) {
1248     CHECK_LE(count, model_check.count_max())
1249         << "Mismatch in " << count_description << ": count  was " << count
1250         << ", but the specified maximum was " << model_check.count_max() << ".";
1251   }
1252 }
1253 
1254 void CheckModelCounts(const Model& model) {
1255   std::unordered_multiset<OperatorType> ops_by_type;
1256   std::unordered_map<std::string, OperatorType> op_type_by_name;
1257   if (model.flags.model_checks_size() == 0) {
1258     return;
1259   }
1260 
1261   for (const auto& op : model.operators) {
1262     ops_by_type.insert(op->type);
1263     op_type_by_name[OperatorTypeName(op->type)] = op->type;
1264   }
1265   for (const auto& model_check : model.flags.model_checks()) {
1266     std::string count_type = model_check.count_type();
1267     if (count_type == "None") {
1268       continue;
1269     } else if (count_type == "Arrays") {
1270       CheckCountInRange(model_check, model.GetArrayMap().size(),
1271                         "count of arrays");
1272     } else if (count_type == "Total") {
1273       CheckCountInRange(model_check, model.operators.size(),
1274                         "count of all operator instances");
1275     } else {
1276       // The check type is not itself checked against the set of valid
1277       // operators, mainly because the enum set cannot be iterated in C++.
1278       const int found_count =
1279           op_type_by_name.count(count_type) > 0
1280               ? ops_by_type.count(op_type_by_name[count_type])
1281               : 0;
1282       CheckCountInRange(model_check, found_count,
1283                         "count of instances of " + count_type + " operator");
1284     }
1285   }
1286 }
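// Worked example for CheckModelCounts above (illustrative only, with
// hypothetical flag values): a ModelCheck with count_type "Conv",
// count_min 1 and count_max 10 requires between 1 and 10 Conv operators;
// count_type "Total" checks the total operator count, and "Arrays" checks
// the number of arrays in the model.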
1287 
1288 void FixEdgeArrays(Model* model) {
1289   for (const std::string& output_array_name : model->flags.output_arrays()) {
1290     if (!GetOpWithOutput(*model, output_array_name)) {
1291       // Output has no operator producing it. Change that by inserting a copy.
1292       LOG(WARNING) << "Fixing constant output array " << output_array_name
1293                    << " by inserting a copy. This is not optimal.";
1294       std::string intermediate_array_name =
1295           AvailableArrayName(*model, output_array_name + "_copy");
1296       CloneArray(model, output_array_name, intermediate_array_name);
1297       InsertCopyOperator(model, intermediate_array_name, output_array_name);
1298     }
1299   }
1300 }
1301 
1302 void DedupeConstantArrays(Model* model, size_t min_size) {
1303   // Walk each array i in 0..N and compare it with the remaining arrays
1304   // i+1..N. This lets us avoid N^2 comparisons and erase duplicate arrays
1305   // while iterating.
1306   const auto& array_map = model->GetArrayMap();
1307   for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
1308        ++lhs_array_it) {
1309     const auto& lhs_array_name = lhs_array_it->first;
1310     const auto& lhs_array = *lhs_array_it->second;
1311     if (!IsConstantParameterArray(*model, lhs_array_name)) {
1312       // Not a constant array; skip.
1313       continue;
1314     }
1315     ArrayDataType final_data_type =
1316         lhs_array.final_data_type != ArrayDataType::kNone
1317             ? lhs_array.final_data_type
1318             : lhs_array.data_type;
1319     // Ignore small arrays. Don't check string arrays, because it is not
1320     // possible to estimate their size.
1321     if (final_data_type != ArrayDataType::kString) {
1322       size_t array_byte_size =
1323           lhs_array.buffer->Length() * ElementSize(final_data_type);
1324       if (array_byte_size < min_size) {
1325         // Too small; skip.
1326         continue;
1327       }
1328     }
1329 
1330     auto next_lhs_array_it = lhs_array_it;
1331     ++next_lhs_array_it;
1332     for (auto rhs_array_it = next_lhs_array_it;
1333          rhs_array_it != array_map.end();) {
1334       const auto& rhs_array_name = rhs_array_it->first;
1335       const auto& rhs_array = *rhs_array_it->second;
1336       ++rhs_array_it;
1337       if (!IsConstantParameterArray(*model, rhs_array_name)) {
1338         // Not a constant array; skip.
1339         continue;
1340       }
1341       if (!IsDiscardableArray(*model, rhs_array_name)) {
1342         // Can't remove the array as it's not discardable (such as an IO edge).
1343         continue;
1344       }
1345       if (!CompareConstantArrays(lhs_array, rhs_array)) {
1346         // Arrays aren't equal; skip.
1347         continue;
1348       }
1349 
1350       // Arrays can be deduped!
1351       VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
1352               << " in place of " << rhs_array_name;
1353       ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
1354       // Note: rhs_array_it above is already incremented so this is safe.
1355       model->EraseArray(rhs_array_name);
1356     }
1357   }
1358 }
1359 
1360 namespace {
1361 void CopyArrayAttribs(const Array& source_array, Array* target_array) {
1362   target_array->data_type = source_array.data_type;
1363   target_array->final_data_type = source_array.final_data_type;
1364   if (source_array.has_shape()) {
1365     target_array->copy_shape(source_array.shape());
1366   }
1367 
1368   if (source_array.minmax) {
1369     target_array->GetOrCreateMinMax() = source_array.GetMinMax();
1370   } else {
1371     target_array->minmax.reset();
1372   }
1373 
1374   if (source_array.quantization_params) {
1375     target_array->GetOrCreateQuantizationParams() =
1376         source_array.GetQuantizationParams();
1377   } else {
1378     target_array->quantization_params.reset();
1379   }
1380 }
1381 }  // namespace
1382 
1383 void InsertCopyOperator(Model* model, const std::string& source_array_name,
1384                         const std::string& target_array_name) {
1385   // Reshape to the same size. This should be a no-op.
1386   const Array& source_array = model->GetArray(source_array_name);
1387   std::vector<int> shape = source_array.shape().dims();
1388 
1389   // Drop constant data from the target array as the copy will be done at
1390   // runtime.
1391   Array& target_array = model->GetOrCreateArray(target_array_name);
1392   target_array.buffer.reset();
1393   CopyArrayAttribs(source_array, &target_array);
1394 
1395   // Insert copy operator.
1396   auto* copy_op = new TensorFlowReshapeOperator;
1397   copy_op->inputs = {
1398       source_array_name,
1399       CreateInt32Array(
1400           model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
1401           shape)};
1402   copy_op->outputs = {target_array_name};
1403   if (target_array.has_shape()) {
1404     copy_op->shape = target_array.shape().dims();
1405   }
1406   model->operators.emplace_back(copy_op);
1407 }
1408 
1409 void CloneArray(Model* model, const std::string& source_array_name,
1410                 const std::string& target_array_name) {
1411   CHECK(!model->HasArray(target_array_name));
1412   const Array& source_array = model->GetArray(source_array_name);
1413   Array& target_array = model->GetOrCreateArray(target_array_name);
1414   CopyArrayAttribs(source_array, &target_array);
1415 
1416   if (!source_array.buffer) {
1417     return;
1418   }
1419 
1420   switch (source_array.data_type) {
1421     case ArrayDataType::kBool:
1422       CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
1423       break;
1424     case ArrayDataType::kFloat:
1425       CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
1426       break;
1427     case ArrayDataType::kInt8:
1428       CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
1429       break;
1430     case ArrayDataType::kUint8:
1431       CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
1432       break;
1433     case ArrayDataType::kInt16:
1434       CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
1435       break;
1436     case ArrayDataType::kUint16:
1437       CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
1438       break;
1439     case ArrayDataType::kInt32:
1440       CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
1441       break;
1442     case ArrayDataType::kUint32:
1443       CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
1444       break;
1445     case ArrayDataType::kInt64:
1446       CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
1447       break;
1448     case ArrayDataType::kUint64:
1449       CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
1450       break;
1451     case ArrayDataType::kString:
1452       CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
1453       break;
1454     case ArrayDataType::kComplex64:
1455       CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
1456       break;
1457     default:
1458       LOG(FATAL) << "Unsupported data type: "
1459                  << ArrayDataTypeName(source_array.data_type);
1460       return;
1461   }
1462 }
1463 
1464 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
1465                    std::vector<int>* out_dims) {
1466   CHECK(out_dims->empty());
1467   if (num_dims == 0) {
1468     return;
1469   } else if (num_dims == 1) {
1470     CHECK_EQ(batch, 1);
1471     *out_dims = {depth};
1472   } else if (num_dims == 2) {
1473     *out_dims = {batch, depth};
1474   } else if (num_dims == 3) {
1475     CHECK_EQ(batch, 1);
1476     *out_dims = {height, width, depth};
1477   } else if (num_dims == 4) {
1478     *out_dims = {batch, height, width, depth};
1479   } else {
1480     LOG(FATAL) << "Should not get here: " << num_dims;
1481   }
1482 }
1483 
1484 void CreateOrCheckRnnStateArray(const std::string& name, int size,
1485                                 int state_num_dims, Model* model) {
1486   int batch = 1;
1487   int num_dims = -1;
1488   if (state_num_dims > 0) {
1489     num_dims = state_num_dims;
1490   } else {
1491     // state_num_dims is not given. We will infer it from an input tensor.
1492     for (const auto& input_array : model->flags.input_arrays()) {
1493       // Pick 'num_dims' and 'batch' from the first input array, unless we
1494       // find a better match by name.
1495       if (input_array.name() == name || num_dims == -1) {
1496         num_dims = input_array.shape().dims_size();
1497         if (num_dims > 0) {
1498           batch = input_array.shape().dims(0);
1499         }
1500       }
1501     }
1502   }
1503   Array& array = model->GetOrCreateArray(name);
1504   if (array.has_shape()) {
1505     num_dims = array.shape().dimensions_count();
1506   }
1507   if (!array.has_shape() && num_dims >= 0) {
1508     Shape* shape = array.mutable_shape();
1509     std::vector<int> dims;
1510     MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
1511     *shape->mutable_dims() = dims;
1512   }
1513 }
1514 
1515 void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
1516   // Merge info about input_arrays from model_flags into model->flags
1517   for (const auto& specified_input_array : model_flags.input_arrays()) {
1518     toco::InputArray* dst_input_array = nullptr;
1519     for (int i = 0; i < model->flags.input_arrays_size(); i++) {
1520       toco::InputArray* candidate_dst_input_array =
1521           model->flags.mutable_input_arrays(i);
1522       if (candidate_dst_input_array->name() == specified_input_array.name()) {
1523         // specified_input_array from model_flags maps to dst_input_array
1524         // in model->flags
1525         dst_input_array = candidate_dst_input_array;
1526         break;
1527       }
1528     }
1529     if (!dst_input_array) {
1530       // specified_input_array from model_flags was not found in model->flags.
1531       // Match a name-less specified input array when there can be no ambiguity
1532       // as there is only 1 input array.
1533       if (model->flags.input_arrays_size() == 1 &&
1534           model_flags.input_arrays_size() == 1 &&
1535           !specified_input_array.has_name()) {
1536         dst_input_array = model->flags.mutable_input_arrays(0);
1537       }
1538     }
1539     if (!dst_input_array) {
1540       // Still no match, so create a new input array to copy
1541       // specified_input_array into.
1542       dst_input_array = model->flags.add_input_arrays();
1543       dst_input_array->set_name(specified_input_array.name());
1544     }
1545 
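// If the command-line-specified input array sets field_name, either copy it
// into the model's input array, or QCHECK that it agrees with the value the
// model already defines.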
1546 #define RESOLVE_MODEL_FLAG(field_name)                                       \
1547   if (specified_input_array.has_##field_name()) {                            \
1548     if (dst_input_array->has_##field_name()) {                               \
1549       QCHECK_EQ(dst_input_array->field_name(),                               \
1550                 specified_input_array.field_name())                          \
1551           << "For input array '" << dst_input_array->name() << "', "         \
1552           << "specified " #field_name " flag with value: "                   \
1553           << specified_input_array.field_name()                              \
1554           << " does not agree with already defined " #field_name             \
1555              " of this model, with value: "                                  \
1556           << dst_input_array->field_name();                                  \
1557     } else {                                                                 \
1558       dst_input_array->set_##field_name(specified_input_array.field_name()); \
1559     }                                                                        \
1560   }
1561     RESOLVE_MODEL_FLAG(std_value);
1562     RESOLVE_MODEL_FLAG(mean_value);
1563 #undef RESOLVE_MODEL_FLAG
1564 
1565     if (specified_input_array.has_shape()) {
1566       if (dst_input_array->has_shape()) {
1567         QCHECK_EQ(specified_input_array.shape().dims_size(),
1568                   dst_input_array->shape().dims_size())
1569             << "For input array '" << specified_input_array.name() << "', "
1570             << "size of specified input shape flag with size: "
1571             << specified_input_array.shape().dims_size()
1572             << " does not agree with already defined input shape"
1573                " of this model, with size: "
1574             << dst_input_array->shape().dims_size();
1575         // We treat the first dimension as a special case, since it is often
1576         // a batch size and the input_shape flag is effectively overriding
1577         // the model.
1578         for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
1579           QCHECK_EQ(specified_input_array.shape().dims(i),
1580                     dst_input_array->shape().dims(i))
1581               << "At dimension number " << i << " of input array "
1582               << specified_input_array.name() << ", the specified shape's "
1583               << "dimension flag with dimension: "
1584               << specified_input_array.shape().dims(i)
1585               << " does not agree with already defined shape"
1586               << " of this model, with dimension: "
1587               << dst_input_array->shape().dims(i);
1588         }
1589       } else {
1590         *dst_input_array->mutable_shape() = specified_input_array.shape();
1591       }
1592     }
1593 
1594     if (specified_input_array.has_data_type()) {
1595       QCHECK(!dst_input_array->has_data_type());
1596       dst_input_array->set_data_type(specified_input_array.data_type());
1597     }
1598   }
1599 
1600   if (model_flags.output_arrays_size() > 0) {
1601     model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
1602   }
1603 
1604 #define RESOLVE_MODEL_FLAG(name)                                           \
1605   if (model_flags.has_##name()) {                                          \
1606     if (model->flags.has_##name()) {                                       \
1607       QCHECK_EQ(model_flags.name(), model->flags.name())                   \
1608           << "Specified " #name " flag with value: " << model_flags.name() \
1609           << " does not agree with already defined " #name                 \
1610              " of this model, with value: "                                \
1611           << model->flags.name();                                          \
1612     } else {                                                               \
1613       model->flags.set_##name(model_flags.name());                         \
1614     }                                                                      \
1615   }
1616 
1617   RESOLVE_MODEL_FLAG(variable_batch)
1618 
1619 #undef RESOLVE_MODEL_FLAG
1620 
1621   if (!model_flags.rnn_states().empty()) {
1622     model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
1623   }
1624 
1625   if (model->flags.model_checks_size() == 0) {
1626     model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
1627   }
1628 
1629   QCHECK_GT(model->flags.output_arrays_size(), 0)
1630       << "This model does not define output arrays, so a "
1631          "--output_arrays flag must be given on the command-line.";
1632 
1633   for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
1634     auto& input_array = model->GetOrCreateArray(input_array_proto.name());
1635     if (input_array_proto.has_data_type()) {
1636       const ArrayDataType specified_type =
1637           ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
1638       QCHECK(specified_type != ArrayDataType::kNone);
1639       if (input_array.data_type != ArrayDataType::kNone) {
1640         QCHECK(specified_type == input_array.data_type)
1641             << "For input array " << input_array_proto.name()
1642             << " the specified input data type "
1643             << IODataType_Name(input_array_proto.data_type())
1644             << " conflicts with the existing type.";
1645       }
1646       input_array.data_type = specified_type;
1647     }
1648 
1649     if (input_array.data_type == ArrayDataType::kNone) {
1650       // We start out with a float input array;
1651       // that may get replaced by a uint8 array later, by
1652       // MakeInitialDequantizeOp.
1653       input_array.data_type = ArrayDataType::kFloat;
1654     }
1655 
1656     // Compare/merge the model->flags describing the input_shape with
1657     // the actual input array's shape.
1658     if (!input_array.has_shape()) {
1659       if (input_array_proto.has_shape()) {
1660         auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
1661         CheckValidShapeDimensions(input_array_proto.shape().dims());
1662         for (const auto& dim : input_array_proto.shape().dims()) {
1663           input_array_dims.push_back(dim);
1664         }
1665       }
1666     } else {
1667       if (input_array_proto.has_shape()) {
1668         // If an input shape was specified on the flags, ensure that it
1669         // matches the actual shape in the model.
1670         const auto& input_array_dims =
1671             *input_array.mutable_shape()->mutable_dims();
1672         CHECK_EQ(input_array_dims.size(),
1673                  input_array_proto.shape().dims_size());
1674         for (int i = 0; i < input_array_dims.size(); i++) {
1675           CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
1676         }
1677       } else {
1678         for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
1679           input_array_proto.mutable_shape()->add_dims(
1680               input_array.shape().dims(i));
1681         }
1682       }
1683     }
1684 
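    // mean_value/std_value describe the affine dequantization of this input
    // (real = (quantized - mean) / std), so the real-valued range covered by
    // the full quantized range [qmin, qmax] is computed below.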
1685     const float mean_value = input_array_proto.mean_value();
1686     const float std_value = input_array_proto.std_value();
1687     MinMax input_minmax;
1688     float qmin = 0, qmax = 255;
1689     if (input_array.data_type == ArrayDataType::kInt16) {
1690       qmin = -32768;
1691       qmax = 32767;
1692     }
1693     input_minmax.min = (qmin - mean_value) / std_value;
1694     input_minmax.max = (qmax - mean_value) / std_value;
1695     if (!input_array.minmax) {
1696       input_array.GetOrCreateMinMax() = input_minmax;
1697     }
1698   }
1699 
1700   // Creation of the RNN state arrays
1701   for (const auto& rnn_state : model->flags.rnn_states()) {
1702     CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
1703                                rnn_state.num_dims(), model);
1704   }
1705 
1706   model->flags.set_change_concat_input_ranges(
1707       model_flags.change_concat_input_ranges());
1708   model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
1709   model->flags.set_allow_nonexistent_arrays(
1710       model_flags.allow_nonexistent_arrays());
1711 
1712   CHECK(!model->flags.has_arrays_extra_info());
1713   *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
1714 }
1715 
1716 void CheckIsReadyForQuantization(const Model& model) {
1717   for (const auto& op : model.operators) {
1718     for (const auto& input : op->inputs) {
1719       const auto& input_array = model.GetArray(input);
1720       if (input_array.data_type != ArrayDataType::kFloat) {
1721         // The array is not floats, no quantization needed.
1722         continue;
1723       }
1724       if (input_array.minmax) {
1725         // The array has minmax, we're good.
1726         continue;
1727       }
1728       if (input_array.buffer) {
1729         // The array has a constant buffer, so we can
1730         // fall back to computing the minmax from actual array entries
1731         // (with a WARNING about possible accuracy implications).
1732         continue;
1733       }
1734       LOG(FATAL)
1735           << "Array " << input << ", which is an input to the "
1736           << HelpfulOperatorTypeName(*op) << " operator producing the output "
1737           << "array " << op->outputs[0] << ", is lacking min/max data, "
1738           << "which is necessary for quantization. If accuracy matters, either "
1739           << "target a non-quantized output format, or run quantized training "
1740           << "with your model from a floating point checkpoint to change the "
1741           << "input graph to contain min/max information. If you don't care "
1742           << "about accuracy, you can pass --default_ranges_min= and "
1743           << "--default_ranges_max= for easy experimentation.";
1744     }
1745   }
1746 }
1747 
1748 int ElementSize(ArrayDataType data_type) {
1749   switch (data_type) {
1750     case ArrayDataType::kBool:
1751       return sizeof(bool);
1752     case ArrayDataType::kFloat:
1753       return 4;
1754     case ArrayDataType::kInt8:
1755       return 1;
1756     case ArrayDataType::kUint8:
1757       return 1;
1758     case ArrayDataType::kInt16:
1759       return 2;
1760     case ArrayDataType::kUint16:
1761       return 2;
1762     case ArrayDataType::kInt32:
1763       return 4;
1764     case ArrayDataType::kUint32:
1765       return 4;
1766     case ArrayDataType::kInt64:
1767       return 8;
1768     case ArrayDataType::kUint64:
1769       return 8;
1770     case ArrayDataType::kComplex64:
1771       return 8;
1772     case ArrayDataType::kComplex128:
1773       return 16;
1774     case ArrayDataType::kFloat64:
1775       return 8;
1776 
1777     // Usually not a critical limitation because strings are typically only
1778     // inputs and/or outputs.
1779     case ArrayDataType::kString:
1780       LOG(FATAL) << "Transient arrays with strings are not supported yet";
1781       return 0;
1782     default:
1783       LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
1784       return 0;
1785   }
1786 }
1787 
1788 void DropMinMax(Model* model, const std::string& array_name) {
1789   auto& array = model->GetArray(array_name);
1790   if (!!array.minmax) {
1791     LOG(WARNING) << "Dropping MinMax information in array " << array_name
1792                  << ". Expect inaccuracy in quantized inference.";
1793     array.minmax = nullptr;
1794   }
1795 }
1796 
1797 bool IsAllocatableTransientArray(const Model& model,
1798                                  const std::string& array_name) {
1799   // An optional array is not transient.
1800   if (model.IsOptionalArray(array_name)) return false;
1801   // The model's input and output arrays are externally allocated.
1802   // They are not transient arrays.
1803   if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
1804     return false;
1805   }
1806   const auto* array = &model.GetArray(array_name);
1807   // An array with a constant buffer isn't a transient array.
1808   if (!!array->buffer) {
1809     return false;
1810   }
1811   // An array without shape isn't allocatable.
1812   if (!array->has_shape()) {
1813     return false;
1814   }
1815 
1816   // The size of string tensors is rarely known ahead of time, so all transient
1817   // tensors of this type will need to be dynamically allocated.
1818   if (array->final_data_type == ArrayDataType::kString ||
1819       array->data_type == ArrayDataType::kString) {
1820     return false;
1821   }
1822 
1823   return true;
1824 }
1825 
1826 std::string AvailableArrayName(const Model& model, const std::string& name) {
1827   std::string sanitized_name = SanitizeNameForTFNode(name);
1828   if (!model.HasArray(sanitized_name) &&
1829       !model.IsOptionalArray(sanitized_name)) {
1830     return sanitized_name;
1831   }
1832   const int kNumSuffixesToTry = 1000;
1833   for (int i = 0; i < kNumSuffixesToTry; i++) {
1834     const std::string& name_with_suffix =
1835         toco::port::StringF("%s_%d", sanitized_name, i);
1836     if (!model.HasArray(name_with_suffix) &&
1837         !model.IsOptionalArray(name_with_suffix)) {
1838       return name_with_suffix;
1839     }
1840   }
1841   LOG(FATAL) << "Could not find an available array name starting with "
1842              << sanitized_name << ". Tried " << kNumSuffixesToTry
1843              << " suffixes, all were taken!";
1844   return "";
1845 }
1846 
1847 std::string ShapeToString(const Shape& shape) {
1848   if (shape.dimensions_count() == 0) {
1849     return "[]";
1850   }
1851 
1852   return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
1853 }
1854 
1855 void PrintArrayShape(Model* model, const std::string& name) {
1856   if (!model->GetArray(name).has_shape()) {
1857     LOG(INFO) << name << " has no shape";
1858     return;
1859   }
1860   LOG(INFO) << name
1861             << " has shape: " << ShapeToString(model->GetArray(name).shape());
1862 }
1863 
1864 bool IsArrayFullyConnectedWeights(const Model& model, const std::string& name) {
1865   bool is_fc_weights = false;
1866   bool is_something_else = false;
1867   for (const auto& op : model.operators) {
1868     for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
1869       if (op->inputs[input_index] == name) {
1870         if (op->type == OperatorType::kFullyConnected && input_index == 1) {
1871           is_fc_weights = true;
1872         } else {
1873           is_something_else = true;
1874         }
1875       }
1876     }
1877   }
1878   CHECK(!(is_fc_weights && is_something_else));
1879   return is_fc_weights;
1880 }
1881 
1882 std::string CreateInt32Array(Model* model, const std::string& param_name,
1883                              const std::vector<int>& value) {
1884   auto param_array_name = AvailableArrayName(*model, param_name);
1885   auto& param_array = model->GetOrCreateArray(param_array_name);
1886   param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
1887   param_array.data_type = ArrayDataType::kInt32;
1888   auto& param_array_data =
1889       param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
1890   param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
1891   for (int i = 0; i < value.size(); ++i) {
1892     param_array_data[i] = value[i];
1893   }
1894   return param_array_name;
1895 }
1896 
1897 bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
1898                                 int64* result) {
1899   switch (op.type) {
1900     case OperatorType::kFullyConnected:
1901     case OperatorType::kConv:
1902     case OperatorType::kDepthwiseConv: {
1903       const auto& output_array = model.GetArray(op.outputs[0]);
1904       const auto& weights_array = model.GetArray(op.inputs[1]);
1905       if (!output_array.has_shape() || !weights_array.has_shape()) {
1906         return false;
1907       }
1908       int64 cols = 1;
1909       for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
1910         cols *= output_array.shape().dims(i);
1911       }
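      // 'cols' is the number of output positions over all but the last output
      // dim; each one does one multiply-add per weight value, and a
      // multiply-add counts as 2 ops.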
1912       const int64 cost_per_col =
1913           2 * RequiredBufferSizeForShape(weights_array.shape());
1914       *result = cost_per_col * cols;
1915       if (op.inputs.size() > 2) {
1916         // There is a bias vector. One more op per output value.
1917         *result += RequiredBufferSizeForShape(output_array.shape());
1918       }
1919       break;
1920     }
1921     case OperatorType::kTransposeConv: {
1922       const auto& input_array = model.GetArray(op.inputs[2]);
1923       const auto& weights_array = model.GetArray(op.inputs[1]);
1924       if (!input_array.has_shape() || !weights_array.has_shape()) {
1925         return false;
1926       }
1927       const Shape& input = input_array.shape();
1928       const Shape& weights = weights_array.shape();
1929       // Compute op count from the seven nested loops of
1930       // tflite::reference_ops::TransposeConv():
1931       *result = 2 * input.dims(0) * input.dims(1) * input.dims(2) *
1932                 input.dims(3) * weights.dims(1) * weights.dims(2) *
1933                 weights.dims(0);
1934       // Note that tflite::optimized_ops::TransposeConv() uses an im2col matrix
1935       // and has a higher op count, by a factor of (output_height*output_width)
1936       // vs. (input_height*input_width). Yet it generally performs better
1937       // because of coherent memory access (at least for 2x2 striding, though
1938       // not necessarily in all cases).
1939       break;
1940     }
1941     case OperatorType::kAdd:
1942     case OperatorType::kSub:
1943     case OperatorType::kMul: {
1944       const auto& output_array = model.GetArray(op.outputs[0]);
1945       if (!output_array.has_shape()) {
1946         return false;
1947       }
1948       *result = RequiredBufferSizeForShape(output_array.shape());
1949       break;
1950     }
1951     case OperatorType::kAddN: {
1952       const auto& output_array = model.GetArray(op.outputs[0]);
1953       if (!output_array.has_shape()) {
1954         return false;
1955       }
1956       // AddN costs roughly the same as N-1 Adds.
1957       const int64 num_adds = op.inputs.size() - 1;
1958       *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
1959       break;
1960     }
1961     case OperatorType::kLogistic:
1962     case OperatorType::kSoftmax:
1963     case OperatorType::kLogSoftmax:
1964     case OperatorType::kTanh: {
1965       const auto& output_array = model.GetArray(op.outputs[0]);
1966       if (!output_array.has_shape()) {
1967         return false;
1968       }
1969       // As a very rough ballpark, the cost of evaluating a math function
1970       // such as tanh or logistic is about 32 multiplications, and about as
1971       // many additions/subtractions. (Just a power-of-two order-of-magnitude
1972       // from looking at actual implementations that we use in runtime/ code).
1973       *result = 64 * RequiredBufferSizeForShape(output_array.shape());
1974       break;
1975     }
1976     case OperatorType::kMaxPool: {
1977       const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
1978       const auto& output_array = model.GetArray(op.outputs[0]);
1979       if (!output_array.has_shape()) {
1980         return false;
1981       }
1982       *result = RequiredBufferSizeForShape(output_array.shape()) *
1983                 maxpool.kheight * maxpool.kwidth;
1984       break;
1985     }
1986     case OperatorType::kAveragePool: {
1987       const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
1988       const auto& output_array = model.GetArray(op.outputs[0]);
1989       if (!output_array.has_shape()) {
1990         return false;
1991       }
1992       *result = RequiredBufferSizeForShape(output_array.shape()) *
1993                 avgpool.kheight * avgpool.kwidth;
1994       break;
1995     }
1996     case OperatorType::kL2Pool: {
1997       const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
1998       const auto& output_array = model.GetArray(op.outputs[0]);
1999       if (!output_array.has_shape()) {
2000         return false;
2001       }
2002       // The sum of squares requires (kheight*kwidth) multiply-adds,
2003       // and then there is the sqrt which we ballpark at 32 ops.
2004       const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
2005       *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
2006       break;
2007     }
2008     case OperatorType::kL2Normalization: {
2009       const auto& output_array = model.GetArray(op.outputs[0]);
2010       if (!output_array.has_shape()) {
2011         return false;
2012       }
2013       // Computing the squared L2 norm is N multiply-adds so 2N ops,
2014       // then the single inverse-sqrt is negligible, then we multiply each
2015       // value by the resulting multiplier, so an extra N ops. Count 3N ops total.
2016       *result = 3 * RequiredBufferSizeForShape(output_array.shape());
2017       break;
2018     }
2019     default:
2020       *result = 0;
2021       break;
2022   }
2023   return true;
2024 }
2025 
2026 bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
2027   int64 total = 0;
2028   for (const auto& op : model.operators) {
2029     int64 num_ops;
2030     if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
2031       return false;
2032     }
2033     total += num_ops;
2034   }
2035   *result = total;
2036   return true;
2037 }
2038 
2039 std::string FormattedNumber(int64 x) {
2040   const int64 million = 1000000;
2041   const int64 billion = 1000000000;
2042   if (x < 10000) {
2043     return toco::port::StringF("%d ", x);
2044   } else if (x < billion) {
2045     return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
2046   } else {
2047     return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
2048   }
2049 }
2050 
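// Computes 'shuffle' such that output axis i is taken from input axis
// shuffle[i]; see ShuffleDims() and ShuffleArrayTemplate() below for how it
// is applied.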
2051 void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
2052                      std::vector<int>* shuffle) {
2053   CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
2054   shuffle->resize(4);
2055   for (int i = 0; i < 4; i++) {
2056     (*shuffle)[i] = i;
2057   }
2058   if (input_axes_order == output_axes_order) {
2059     // nothing to do
2060   } else if (AxesCount(input_axes_order) == 2) {
2061     shuffle->resize(2);
2062     (*shuffle)[0] = 1;
2063     (*shuffle)[1] = 0;
2064   } else if (input_axes_order == AxesOrder::kOHWI &&
2065              output_axes_order == AxesOrder::kHWIO) {
2066     // 3210 <- 3210
2067     // HWIO <- OHWI
2068     *shuffle = {1, 2, 3, 0};
2069   } else if (input_axes_order == AxesOrder::kHWIO &&
2070              output_axes_order == AxesOrder::kOHWI) {
2071     // 3210 <- 3210
2072     // OHWI <- HWIO
2073     *shuffle = {3, 0, 1, 2};
2074   } else if (input_axes_order == AxesOrder::kOHWI &&
2075              output_axes_order == AxesOrder::kHWOI) {
2076     *shuffle = {1, 2, 0, 3};
2077   } else {
2078     LOG(FATAL) << "Bad shuffle";
2079   }
2080 }
2081 
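// Extends 'input_shuffle' to 'newdim' dimensions by prepending identity
// entries and offsetting the original entries by the padding, e.g. the 2-D
// shuffle {1, 0} extended to 4 dims becomes {0, 1, 3, 2}.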
2082 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
2083                    std::vector<int>* extended_shuffle) {
2084   *extended_shuffle = input_shuffle;
2085   CHECK(newdim >= input_shuffle.size());
2086   const int pad_size = newdim - input_shuffle.size();
2087   extended_shuffle->resize(newdim);
2088   for (int i = 0; i < pad_size; i++) {
2089     (*extended_shuffle)[i] = i;
2090   }
2091   for (int i = pad_size; i < newdim; i++) {
2092     (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
2093   }
2094 }
2095 
2096 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
2097                  AxesOrder output_axes_order, Shape* output_shape) {
2098   if (input_axes_order == AxesOrder::kHWIM &&
2099       output_axes_order == AxesOrder::k1HWO) {
2100     // This special case isn't just a permutation: the IM pair of dims gets
2101     // merged into the last (O) dim, so we have to special-case it.
2102     *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
2103                            input_shape.dims(3) * input_shape.dims(2)});
2104   } else {
2105     std::vector<int> shuffle;
2106     GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2107     std::vector<int>* output_dims = output_shape->mutable_dims();
2108     output_dims->resize(input_shape.dimensions_count());
2109     for (int i = 0; i < input_shape.dimensions_count(); i++) {
2110       (*output_dims)[i] = input_shape.dims(shuffle[i]);
2111     }
2112   }
2113 }
2114 
2115 template <typename T>
2116 void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
2117                           AxesOrder output_axes_order,
2118                           const Shape& output_shape, const T* input_data,
2119                           T* output_data) {
2120   if (input_axes_order == AxesOrder::kHWIM &&
2121       output_axes_order == AxesOrder::k1HWO) {
2122     // This special case isn't just a permutation: the IM pair of dims gets
2123     // merged into the O dim, so we have to special-case it. Fortunately,
2124     // as far as array shuffling is concerned, it's just the identity
2125     // transformation.
2126     memcpy(output_data, input_data,
2127            RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
2128     return;
2129   }
2130   CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
2131   const int dim = input_shape.dimensions_count();
2132   CHECK_LE(dim, 4);
2133   std::vector<int> shuffle;
2134   GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2135   CHECK(shuffle.size() >= dim);
2136   for (int i = 0; i < dim; i++) {
2137     CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
2138     CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
2139   }
2140   Shape extended_input_shape = input_shape;
2141   ExtendShape(&extended_input_shape, 4);
2142   Shape extended_output_shape = output_shape;
2143   ExtendShape(&extended_output_shape, 4);
2144   std::vector<int> extended_shuffle;
2145   ExtendShuffle(shuffle, 4, &extended_shuffle);
2146 
2147   const std::vector<int>& extended_input_dims = extended_input_shape.dims();
2148   const std::vector<int>& extended_output_dims = extended_output_shape.dims();
2149 
2150   // TODO(starka): Rework to handle different numbers of dimensions.
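  // Strides are computed in the (4-D extended) input layout and then indexed
  // through the extended shuffle, so that iterating the output in its own
  // row-major order advances the input pointer by the stride of the
  // corresponding input dimension.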
2151   int input_strides[4];
2152   input_strides[3] = 1;
2153   input_strides[2] = extended_input_dims[3];
2154   input_strides[1] = input_strides[2] * extended_input_dims[2];
2155   input_strides[0] = input_strides[1] * extended_input_dims[1];
2156   const int input_stride_0 = input_strides[extended_shuffle[3]];
2157   const int input_stride_1 = input_strides[extended_shuffle[2]];
2158   const int input_stride_2 = input_strides[extended_shuffle[1]];
2159   const int input_stride_3 = input_strides[extended_shuffle[0]];
2160 
2161   const int output_size_0 = extended_output_dims[3];
2162   const int output_size_1 = extended_output_dims[2];
2163   const int output_size_2 = extended_output_dims[1];
2164   const int output_size_3 = extended_output_dims[0];
2165   const int output_stride_0 = 1;
2166   const int output_stride_1 = output_size_0;
2167   const int output_stride_2 = output_stride_1 * output_size_1;
2168   const int output_stride_3 = output_stride_2 * output_size_2;
2169 
2170   for (int i3 = 0; i3 < output_size_3; i3++) {
2171     const T* const input_ptr_3 = input_data + i3 * input_stride_3;
2172     T* const output_ptr_3 = output_data + i3 * output_stride_3;
2173     for (int i2 = 0; i2 < output_size_2; i2++) {
2174       const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
2175       T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
2176       for (int i1 = 0; i1 < output_size_1; i1++) {
2177         const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
2178         T* output_ptr = output_ptr_2 + i1 * output_stride_1;
2179         T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
2180         while (output_ptr != output_ptr_end) {
2181           *output_ptr = *input_ptr;
2182           input_ptr += input_stride_0;
2183           output_ptr += output_stride_0;
2184         }
2185       }
2186     }
2187   }
2188 }
2189 
2190 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2191                   AxesOrder output_axes_order, const Shape& output_shape,
2192                   const uint8* input_data, uint8* output_data) {
2193   ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
2194                               output_shape, input_data, output_data);
2195 }
2196 
2197 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2198                   AxesOrder output_axes_order, const Shape& output_shape,
2199                   const float* input_data, float* output_data) {
2200   ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
2201                               output_shape, input_data, output_data);
2202 }
2203 
2204 int AxesCount(AxesOrder axes_order) {
2205   switch (axes_order) {
2206     case AxesOrder::kOneAxis:
2207       return 1;
2208     case AxesOrder::kRC:
2209       return 2;
2210     case AxesOrder::kCR:
2211       return 2;
2212     case AxesOrder::kHWIO:
2213       return 4;
2214     case AxesOrder::kOHWI:
2215       return 4;
2216     case AxesOrder::kHWIM:
2217       return 4;
2218     case AxesOrder::k1HWO:
2219       return 4;
2220     case AxesOrder::kNHWC:
2221       return 4;
2222     case AxesOrder::kHWOI:
2223       return 4;
2224     default:
2225       LOG(FATAL) << "Bad AxesOrder";
2226       return 0;
2227   }
2228 }
2229 
2230 bool IsDiscardableArray(const Model& model, const std::string& array_name) {
2231   if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
2232     return false;
2233   }
2234   for (const auto& rnn_state : model.flags.rnn_states()) {
2235     if (!rnn_state.discardable()) {
2236       if (array_name == rnn_state.state_array()) {
2237         return false;
2238       }
2239       if (array_name == rnn_state.back_edge_source_array()) {
2240         return false;
2241       }
2242     }
2243   }
2244   return true;
2245 }
2246 
2247 bool ReshapeIsEquivalentToTranspose(const Model& model,
2248                                     const TensorFlowReshapeOperator* op,
2249                                     bool allow_extra_unary_dims) {
2250   CHECK(!op->shape.empty());
2251   CHECK(model.HasArray(op->inputs[0]));
2252   CHECK(model.HasArray(op->outputs[0]));
2253 
2254   const auto& input_array = model.GetArray(op->inputs[0]);
2255   const auto& output_array = model.GetArray(op->outputs[0]);
2256 
2257   CHECK(input_array.has_shape());
2258   CHECK(output_array.has_shape());
2259 
2260   std::vector<int> in_shape = input_array.shape().dims();
2261   std::vector<int> out_shape = output_array.shape().dims();
2262 
2263   // If the reshape changes the number of dimensions, it cannot be interpreted
2264   // as a transpose.
2265   if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
2266     return false;
2267   }
2268 
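  // Strip all unary (size-1) dimensions from both shapes; if the remaining
  // dimensions match in order, the reshape only moves unary dimensions around
  // and is therefore equivalent to a transpose.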
2269   in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
2270                  in_shape.end());
2271   out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
2272                   out_shape.end());
2273   return in_shape == out_shape;
2274 }
2275 
2276 void CheckFinalDataTypesSatisfied(const Model& model) {
2277   for (const auto& array_entry : model.GetArrayMap()) {
2278     const auto& array = *array_entry.second;
2279     if (array.data_type == ArrayDataType::kBool) {
2280       // Boolean values are never quantized.
2281       continue;
2282     }
2283 
2284     // If the final data type is int16, the data type may be float, for example
2285     // after dequantization.
2286     if (array.final_data_type != ArrayDataType::kNone &&
2287         array.final_data_type != ArrayDataType::kInt16) {
2288       CHECK(array.data_type == array.final_data_type)
2289           << "Array \"" << array_entry.first
2290           << "\" has mis-matching actual and final data types (data_type="
2291           << ArrayDataTypeName(array.data_type)
2292           << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
2293           << ").";
2294     }
2295   }
2296 }
2297 
2298 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
2299   switch (type) {
2300     case FLOAT:
2301       return ArrayDataType::kFloat;
2302     case QUANTIZED_UINT8:
2303       return ArrayDataType::kUint8;
2304     case INT8:
2305       return ArrayDataType::kInt8;
2306     case QUANTIZED_INT16:
2307       return ArrayDataType::kInt16;
2308     case INT32:
2309       return ArrayDataType::kInt32;
2310     case UINT32:
2311       return ArrayDataType::kUint32;
2312     case INT64:
2313       return ArrayDataType::kInt64;
2314     case UINT64:
2315       return ArrayDataType::kUint64;
2316     case BOOL:
2317       return ArrayDataType::kBool;
2318     case STRING:
2319       return ArrayDataType::kString;
2320     case COMPLEX64:
2321       return ArrayDataType::kComplex64;
2322     case COMPLEX128:
2323       return ArrayDataType::kComplex128;
2324     case FLOAT16:
2325       return ArrayDataType::kFloat16;
2326     case FLOAT64:
2327       return ArrayDataType::kFloat64;
2328     case RESOURCE:
2329     case VARIANT:
2330     default:
2331       return ArrayDataType::kNone;
2332   }
2333 }
2334 
2335 void FinishBuildingRNNStates(Model* model) {
2336   for (const auto& rnn_state : model->flags.rnn_states()) {
2337     if (!model->HasArray(rnn_state.back_edge_source_array()) ||
2338         !model->HasArray(rnn_state.state_array())) {
2339       CHECK(model->HasArray(rnn_state.back_edge_source_array()));
2340       CHECK(model->HasArray(rnn_state.state_array()));
2341       continue;
2342     }
2343     const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
2344     auto& dst_array = model->GetArray(rnn_state.state_array());
2345     if (src_array.data_type == ArrayDataType::kNone &&
2346         dst_array.data_type == ArrayDataType::kNone) {
2347       dst_array.data_type = ArrayDataType::kFloat;
2348     }
2349   }
2350 }
2351 
2352 // Returns the array names that match the ArraysExtraInfo's name and
2353 // name_regexp. The regexp must match the full array name.
2354 std::unordered_set<std::string> ScanArrayNames(
2355     const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
2356   std::unordered_set<std::string> matches;
2357   if (model.HasArray(entry.name())) {
2358     matches.insert(entry.name());
2359   }
2360   if (!entry.name_regexp().empty()) {
2361     const auto& arrays = model.GetArrayMap();
2362     const RE2 name_regexp = {entry.name_regexp()};
2363     for (auto it = arrays.begin(); it != arrays.end(); ++it) {
2364       if (RE2::FullMatch(it->first, name_regexp)) {
2365         matches.insert(it->first);
2366       }
2367     }
2368   }
2369   return matches;
2370 }
2371 
2372 void UseArraysExtraInfo(Model* model, bool quantize_output) {
2373   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
2374     const auto matches = ScanArrayNames(*model, entry);
2375     if (matches.empty()) {
2376       LOG(ERROR) << "arrays_extra_info_file: No matching arrays found for "
2377                  << (entry.has_name() ? entry.name() : "")
2378                  << (entry.has_name_regexp() ? entry.name_regexp() : "");
2379       continue;
2380     }
2381     for (const auto& matched_name : matches) {
2382       auto& array = model->GetArray(matched_name);
2383       if (entry.has_min() || entry.has_max()) {
2384         CHECK_EQ(entry.has_min(), entry.has_max());
2385         auto& minmax = array.GetOrCreateMinMax();
2386         minmax.min = entry.min();
2387         minmax.max = entry.max();
2388       }
2389       if (entry.has_data_type() && quantize_output) {
2390         array.final_data_type =
2391             ConvertIODataTypeToArrayDataType(entry.data_type());
2392       }
2393       if (entry.has_shape()) {
2394         array.clear_shape();
2395         // Make sure to create the shape even if there are no dims, to
2396         // correctly record 0-D shapes.
2397         array.mutable_shape();
2398         for (const auto& dim : entry.shape().dims()) {
2399           array.mutable_shape()->mutable_dims()->push_back(dim);
2400         }
2401       }
2402       if (entry.has_constant_float_value()) {
2403         CHECK(array.has_shape());
2404         if (array.data_type == ArrayDataType::kFloat) {
2405           auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
2406           data.resize(RequiredBufferSizeForShape(array.shape()));
2407           for (float& f : data) {
2408             f = entry.constant_float_value();
2409           }
2410         }
2411       }
2412     }
2413   }
2414 }
2415 
2416 void UndoWeightsShuffling(Model* model) {
2417   for (const auto& op : model->operators) {
2418     if (op->type != toco::OperatorType::kFullyConnected) {
2419       continue;
2420     }
2421     const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
2422     if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
2423       continue;
2424     }
2425     const std::string& weights_name = fc_op.inputs[1];
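    // The weights buffer is rewritten in place below, so it must not be
    // shared with any other operator.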
2426     QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
2427     auto& weights_array = model->GetArray(weights_name);
2428     QCHECK(weights_array.data_type == ArrayDataType::kUint8);
2429     auto& weights_data =
2430         weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
2431     const auto& weights_shape = weights_array.shape();
2432     QCHECK_EQ(weights_shape.dimensions_count(), 2);
2433     const int rows = weights_shape.dims(0);
2434     const int cols = weights_shape.dims(1);
2435     QCHECK_EQ(rows % 4, 0);
2436     QCHECK_EQ(cols % 16, 0);
2437     CHECK_EQ(rows * cols, weights_data.size());
2438     // Compute the de-shuffled weights
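    // The shuffled layout stores the weights as consecutive 4x16 blocks: for
    // each block anchored at (r, c), the four 16-value runs for rows r..r+3
    // are laid out back to back. Walk the shuffled data linearly and scatter
    // each run back to its row-major position.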
2439     std::vector<uint8> deshuffled_data(weights_data.size());
2440     uint8* shuffled_data_ptr = weights_data.data();
2441     for (int r = 0; r < rows; r += 4) {
2442       for (int c = 0; c < cols; c += 16) {
2443         for (int i = 0; i < 4; i++) {
2444           uint8* deshuffled_data_ptr =
2445               deshuffled_data.data() + (r + i) * cols + c;
2446           for (int j = 0; j < 16; j++) {
2447             uint8 shuffled_val = *shuffled_data_ptr++;
2448             // Deshuffling isn't only about deshuffling the storage layout,
2449             // it's also about undoing the flipping of the sign bit, which is
2450             // performed on the shuffled weights.
2451             uint8 deshuffled_val = shuffled_val ^ 0x80;
2452             *deshuffled_data_ptr++ = deshuffled_val;
2453           }
2454         }
2455       }
2456     }
2457     CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
2458     // Switch this FC op to using the deshuffled weights.
2459     weights_data = std::move(deshuffled_data);
2460   }
2461 }
2462 
2463 void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
2464   if (src.minmax) {
2465     dst->GetOrCreateMinMax() = src.GetMinMax();
2466   }
2467   if (src.quantization_params) {
2468     dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
2469   }
2470   dst->narrow_range = src.narrow_range;
2471 }
2472 
2473 }  // namespace toco
2474