• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <string>
17 #include <unordered_map>
18 #include <vector>
19 
20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
21 #include "tensorflow/lite/toco/model.h"
22 #include "tensorflow/lite/toco/tooling_util.h"
23 #include "tensorflow/core/platform/logging.h"
24 
25 namespace toco {
26 namespace {
27 
28 // Reroute all edges involving a given discardable array to another
29 // array instead. from_array is assumed to be discardable, and consequently
30 // this only updates operator edges (since discardable arrays only
31 // appear there, and not e.g. in model flags).
Reroute(const std::string & from,const std::string & to,Model * model)32 void Reroute(const std::string& from, const std::string& to, Model* model) {
33   for (const auto& op : model->operators) {
34     for (auto& output : op->outputs) {
35       if (output == from) {
36         output = to;
37       }
38     }
39     for (auto& input : op->inputs) {
40       if (input == from) {
41         input = to;
42       }
43     }
44   }
45   const Array& from_array = model->GetArray(from);
46   Array& to_array = model->GetOrCreateArray(to);
47   // Preserve minmax information if to_array didn't already have any.
48   if (from_array.minmax && !to_array.minmax) {
49     to_array.GetOrCreateMinMax() = from_array.GetMinMax();
50     // If we're copying minmax info, then we should also be copying
51     // narrow_range, which affects how minmax info is to be interpreted.
52     to_array.narrow_range = from_array.narrow_range;
53   }
54   // Separately, also preserve final_data_type if to_array didn't already
55   // have any.
56   if (from_array.final_data_type != ArrayDataType::kNone &&
57       to_array.final_data_type == ArrayDataType::kNone) {
58     to_array.final_data_type = from_array.final_data_type;
59   }
60   // The 'from' array may now be unused. We delete it here immediately
61   // so that this function doesn't violate graph invariants (no unused arrays)
62   // and as it's not trivial to get this right for the caller since
63   // DeleteOpAndArrays will no longer delete this array, since it's no longer
64   // referenced by this op.
65   DeleteArrayIfUnused(from, model);
66 }
67 
68 }  // namespace
69 
RemoveTrivialPassthroughOp(GraphTransformation * transformation,Model * model,std::size_t op_index,int input_index)70 bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
71                                 Model* model, std::size_t op_index,
72                                 int input_index) {
73   auto passthru_it = model->operators.begin() + op_index;
74   auto* passthru_op = passthru_it->get();
75   CHECK_EQ(passthru_op->outputs.size(), 1);
76   CHECK_GE(passthru_op->inputs.size(), 1);
77 
78   int main_input_array_index = 0;
79   if (input_index != -1) {
80     main_input_array_index = input_index;
81   } else {
82     // We call 'main input' the unique nonconstant input array if there is one,
83     // or else the 0-th input.
84     int count_nonconstant_input_arrays = 0;
85     for (size_t i = 0; i < passthru_op->inputs.size(); i++) {
86       if (!model->GetArray(passthru_op->inputs[i]).buffer) {
87         count_nonconstant_input_arrays++;
88         if (count_nonconstant_input_arrays == 1) {
89           main_input_array_index = i;
90         }
91       }
92     }
93   }
94 
95   const std::string main_input_name =
96       passthru_op->inputs[main_input_array_index];
97   const std::string output_name = passthru_op->outputs[0];
98 
99   if (IsDiscardableArray(*model, output_name)) {
100     transformation->AddMessageF(
101         "Removing %s, keeping its non-constant input array %s and removing %s",
102         LogName(*passthru_op), main_input_name, output_name);
103     Reroute(output_name, main_input_name, model);
104   } else if (IsDiscardableArray(*model, main_input_name) &&
105              !IsConstantParameterArray(*model, main_input_name)) {
106     transformation->AddMessageF(
107         "Removing %s, keeping its output array %s and removing non-constant "
108         "input %s",
109         LogName(*passthru_op), output_name, main_input_name);
110     Reroute(main_input_name, output_name, model);
111   } else {
112     transformation->AddMessageF(
113         "Cannot remove %s, neither its main input nor its output may be "
114         "discarded",
115         LogName(*passthru_op));
116     if (passthru_op->type != OperatorType::kReshape &&
117         model->GetArray(main_input_name).has_shape()) {
118       // We can't remove either array but we can remove the op. Converting it to
119       // a reshape gives us some hope of later on fixing that (either in the
120       // final runtime or as an additional fixup step).
121       //
122       // Note that we don't try to insert copies in place of reshapes as the
123       // copy itself is a trivial reshape and we'd go into an infinite loop!
124       transformation->AddMessageF("Replacing with a copy (reshape) instead");
125       InsertCopyOperator(model, main_input_name, output_name);
126       // To avoid using invalidated iterator, evaluate passthru_it again.
127       passthru_it = model->operators.begin() + op_index;
128     } else {
129       return false;
130     }
131   }
132 
133   // Remove the pass-through node.
134   DeleteOpAndArrays(model, passthru_op);
135 
136   return true;
137 }
138 
139 }  // namespace toco
140