1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <string>
17 #include <unordered_map>
18 #include <vector>
19
20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
21 #include "tensorflow/lite/toco/model.h"
22 #include "tensorflow/lite/toco/tooling_util.h"
23 #include "tensorflow/core/platform/logging.h"
24
25 namespace toco {
26 namespace {
27
28 // Reroute all edges involving a given discardable array to another
29 // array instead. from_array is assumed to be discardable, and consequently
30 // this only updates operator edges (since discardable arrays only
31 // appear there, and not e.g. in model flags).
Reroute(const std::string & from,const std::string & to,Model * model)32 void Reroute(const std::string& from, const std::string& to, Model* model) {
33 for (const auto& op : model->operators) {
34 for (auto& output : op->outputs) {
35 if (output == from) {
36 output = to;
37 }
38 }
39 for (auto& input : op->inputs) {
40 if (input == from) {
41 input = to;
42 }
43 }
44 }
45 const Array& from_array = model->GetArray(from);
46 Array& to_array = model->GetOrCreateArray(to);
47 // Preserve minmax information if to_array didn't already have any.
48 if (from_array.minmax && !to_array.minmax) {
49 to_array.GetOrCreateMinMax() = from_array.GetMinMax();
50 // If we're copying minmax info, then we should also be copying
51 // narrow_range, which affects how minmax info is to be interpreted.
52 to_array.narrow_range = from_array.narrow_range;
53 }
54 // Separately, also preserve final_data_type if to_array didn't already
55 // have any.
56 if (from_array.final_data_type != ArrayDataType::kNone &&
57 to_array.final_data_type == ArrayDataType::kNone) {
58 to_array.final_data_type = from_array.final_data_type;
59 }
60 // The 'from' array may now be unused. We delete it here immediately
61 // so that this function doesn't violate graph invariants (no unused arrays)
62 // and as it's not trivial to get this right for the caller since
63 // DeleteOpAndArrays will no longer delete this array, since it's no longer
64 // referenced by this op.
65 DeleteArrayIfUnused(from, model);
66 }
67
68 } // namespace
69
RemoveTrivialPassthroughOp(GraphTransformation * transformation,Model * model,std::size_t op_index,int input_index)70 bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
71 Model* model, std::size_t op_index,
72 int input_index) {
73 auto passthru_it = model->operators.begin() + op_index;
74 auto* passthru_op = passthru_it->get();
75 CHECK_EQ(passthru_op->outputs.size(), 1);
76 CHECK_GE(passthru_op->inputs.size(), 1);
77
78 int main_input_array_index = 0;
79 if (input_index != -1) {
80 main_input_array_index = input_index;
81 } else {
82 // We call 'main input' the unique nonconstant input array if there is one,
83 // or else the 0-th input.
84 int count_nonconstant_input_arrays = 0;
85 for (size_t i = 0; i < passthru_op->inputs.size(); i++) {
86 if (!model->GetArray(passthru_op->inputs[i]).buffer) {
87 count_nonconstant_input_arrays++;
88 if (count_nonconstant_input_arrays == 1) {
89 main_input_array_index = i;
90 }
91 }
92 }
93 }
94
95 const std::string main_input_name =
96 passthru_op->inputs[main_input_array_index];
97 const std::string output_name = passthru_op->outputs[0];
98
99 if (IsDiscardableArray(*model, output_name)) {
100 transformation->AddMessageF(
101 "Removing %s, keeping its non-constant input array %s and removing %s",
102 LogName(*passthru_op), main_input_name, output_name);
103 Reroute(output_name, main_input_name, model);
104 } else if (IsDiscardableArray(*model, main_input_name) &&
105 !IsConstantParameterArray(*model, main_input_name)) {
106 transformation->AddMessageF(
107 "Removing %s, keeping its output array %s and removing non-constant "
108 "input %s",
109 LogName(*passthru_op), output_name, main_input_name);
110 Reroute(main_input_name, output_name, model);
111 } else {
112 transformation->AddMessageF(
113 "Cannot remove %s, neither its main input nor its output may be "
114 "discarded",
115 LogName(*passthru_op));
116 if (passthru_op->type != OperatorType::kReshape &&
117 model->GetArray(main_input_name).has_shape()) {
118 // We can't remove either array but we can remove the op. Converting it to
119 // a reshape gives us some hope of later on fixing that (either in the
120 // final runtime or as an additional fixup step).
121 //
122 // Note that we don't try to insert copies in place of reshapes as the
123 // copy itself is a trivial reshape and we'd go into an infinite loop!
124 transformation->AddMessageF("Replacing with a copy (reshape) instead");
125 InsertCopyOperator(model, main_input_name, output_name);
126 // To avoid using invalidated iterator, evaluate passthru_it again.
127 passthru_it = model->operators.begin() + op_index;
128 } else {
129 return false;
130 }
131 }
132
133 // Remove the pass-through node.
134 DeleteOpAndArrays(model, passthru_op);
135
136 return true;
137 }
138
139 } // namespace toco
140