• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
16 
17 #include <ctype.h>
18 #include <stddef.h>
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "google/protobuf/map.h"
25 #include "tensorflow/lite/toco/model.h"
26 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
27 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
28 #include "tensorflow/lite/toco/toco_port.h"
29 #include "tensorflow/lite/toco/tooling_util.h"
30 
31 #include "tensorflow/core/framework/attr_value.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/framework/tensor_shape.pb.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 using tensorflow::GraphDef;
39 using tensorflow::NodeDef;
40 
41 namespace toco {
42 
43 namespace {
44 
45 // Receives a vector of cluster nodes and returns only those which are array
46 // partitions (of type 'Const' and have the pattern 'part_<.*>' in their name.
47 // Since these nodes are connected to a Concatenate node, it makes sure the
48 // axis value input of the Concatenate operator is 0.
FilterPartitionedConstNodes(const string & const_pattern,const std::vector<const NodeDef * > & cluster_nodes,std::vector<const NodeDef * > * const_node_parts)49 void FilterPartitionedConstNodes(
50     const string& const_pattern,
51     const std::vector<const NodeDef*>& cluster_nodes,
52     std::vector<const NodeDef*>* const_node_parts) {
53   for (const NodeDef* node : cluster_nodes) {
54     string node_name_to_upper = node->name();
55     std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
56                    node_name_to_upper.begin(), ::toupper);
57     if (StrContains(node->name(), const_pattern) && node->op() == "Const") {
58       if (StrContains(node_name_to_upper, "/PART_")) {
59         const_node_parts->push_back(node);
60       } else if (StrContains(node->name(), "AXIS") &&
61                  StrContains(node->name(), "CONCAT")) {
62         // For now only supporting Concatenate on Axix 0
63         const auto& value_attr = node->attr().at("value");
64         const tensorflow::TensorProto& tensor = value_attr.tensor();
65         CHECK_EQ(tensor.int_val(0), 0);
66       }
67     }
68   }
69   sort(const_node_parts->begin(), const_node_parts->end(),
70        [](const NodeDef* a, const NodeDef* b) {
71          return (a->name().compare(b->name()) < 0 &&
72                  (a->name().size() < b->name().size()));
73        });
74 }
75 
76 }  // namespace
77 
78 // SvdfCluster methods
79 
InferFilterRank()80 int SvdfCluster::InferFilterRank() {
81   for (const NodeDef* node : nodes_) {
82     if (StrContains(node->name(), "Reshape/shape")) {
83       const auto& value_attr = node->attr().at("value");
84       const tensorflow::TensorProto& tensor = value_attr.tensor();
85       std::vector<int32> shape_values(
86           tensor.tensor_content().size() / sizeof(int), 0);
87       port::CopyToBuffer(tensor.tensor_content(),
88                          reinterpret_cast<char*>(shape_values.data()));
89       CHECK_EQ(shape_values.size(), 3);
90       // shape_value array is arranged as:
91       // [num_units, rank, -1]
92       CHECK_EQ(shape_values[2], -1);
93       return shape_values[1];
94     }
95   }
96   return -1;
97 }
98 
CreateNodes()99 void SvdfCluster::CreateNodes() {
100   for (const string& const_pattern : const_node_patterns_) {
101     CreateConstNode(const_pattern);
102   }
103   std::unique_ptr<tensorflow::NodeDef> svdf_node(new NodeDef);
104   svdf_node->set_op("Svdf");
105   svdf_node->set_name(name_);
106   svdf_node->set_device(device_);
107 
108   // Add the main input.
109   svdf_node->add_input(inputs_[0]);
110 
111   // Add the rest of the inputs to Svdf cell: weights and bias.
112   CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2);
113   string* weights_feature_input = svdf_node->add_input();
114   string* weights_time_input = svdf_node->add_input();
115   string* bias_input;
116   if (new_nodes_.size() == 3) {
117     bias_input = svdf_node->add_input();
118   }
119   for (const std::unique_ptr<tensorflow::NodeDef>& node : new_nodes_) {
120     const string node_name = node->name();
121     if (StrContains(node_name, "SVDF_weights_feature")) {
122       *weights_feature_input = node_name;
123     } else if (StrContains(node_name, "SVDF_weights_time")) {
124       *weights_time_input = node_name;
125     } else if (StrContains(node_name, "SVDF_bias")) {
126       CHECK(bias_input) << "Bias input cannot be provided when there are only "
127                            "two Const input nodes!";
128       *bias_input = node_name;
129     } else {
130       // Unexpected input for Svdf op.
131       LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: "
132                     "weights_feature, weights_time and bias.";
133     }
134   }
135   const int rank = InferFilterRank();
136   CHECK_GT(rank, 0);
137 
138   // Add Svdf activation and rank.
139   string activation_function =
140       StrContains(outputs_[0], "Relu") ? "Relu" : "None";
141   (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function);
142   (*svdf_node->mutable_attr())["Rank"].set_i(rank);
143 
144   // Finally add it to the list of the newly created nodes.
145   new_nodes_.push_back(std::move(svdf_node));
146 }
147 
CreateConstNode(const string & const_pattern)148 void SvdfCluster::CreateConstNode(const string& const_pattern) {
149   // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const.
150   std::vector<const NodeDef*> const_node_parts;
151   FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts);
152 
153   if (const_node_parts.empty()) return;
154 
155   bool transpose_tensor_value =
156       StrContains(const_pattern, "SVDF_weights_feature");
157 
158   // Merge them if necessary.
159   std::unique_ptr<tensorflow::NodeDef> merged_node(new NodeDef);
160   MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node);
161   new_nodes_.push_back(std::move(merged_node));
162 }
163 
MaybeMergeConstNodes(const std::vector<const NodeDef * > & const_node_parts,bool transpose_tensor_value,const std::unique_ptr<tensorflow::NodeDef> & merged_node)164 void SvdfCluster::MaybeMergeConstNodes(
165     const std::vector<const NodeDef*>& const_node_parts,
166     bool transpose_tensor_value,
167     const std::unique_ptr<tensorflow::NodeDef>& merged_node) {
168   merged_node->set_name(const_node_parts[0]->name());
169   merged_node->set_op("Const");
170   merged_node->set_device(const_node_parts[0]->device());
171   (*merged_node->mutable_attr())["dtype"].set_type(
172       const_node_parts[0]->attr().at("dtype").type());
173 
174   // Figuring out Value attribute for the merged node.
175   // Assuming the partitioning is done on Axis 0.
176   // The attributes which are inferred:
177   // * Shape and dimensions
178   // * Float content values
179 
180   // Inferring shape and dimension
181   int dim0_size = 0;
182   int dim1_size = 1;
183   tensorflow::TensorProto* allocated_tensor =
184       (*merged_node->mutable_attr())["value"].mutable_tensor();
185   tensorflow::TensorShapeProto* allocated_tensor_shape =
186       allocated_tensor->mutable_tensor_shape();
187   auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
188   int allocated_content_flat_size = 0;
189   for (size_t i = 0; i < const_node_parts.size(); i++) {
190     const auto& value_attr = const_node_parts[i]->attr().at("value");
191     const tensorflow::TensorProto& tensor = value_attr.tensor();
192     if (i == 0) {
193       allocated_tensor->set_dtype(tensor.dtype());
194     } else {
195       CHECK_EQ(allocated_tensor->dtype(), tensor.dtype());
196     }
197     allocated_content_flat_size += tensor.tensor_content().size();
198     CHECK(tensor.has_tensor_shape());
199     const tensorflow::TensorShapeProto shape = tensor.tensor_shape();
200     dim0_size += shape.dim(0).size();
201     for (int d = 1; d < shape.dim_size(); d++) {
202       if (i == 0) {
203         allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size());
204         allocated_tensor_shape->set_unknown_rank(shape.unknown_rank());
205         dim1_size *= shape.dim(d).size();
206       } else {
207         CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size());
208         CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank());
209       }
210     }
211   }
212 
213   // Copying the float content from each array partition.
214   std::unique_ptr<char[]> allocated_content(
215       new char[allocated_content_flat_size]);
216   char* content_ptr = allocated_content.get();
217   for (size_t i = 0; i < const_node_parts.size(); i++) {
218     const auto& value_attr = const_node_parts[i]->attr().at("value");
219     const tensorflow::TensorProto& tensor = value_attr.tensor();
220     port::CopyToBuffer(tensor.tensor_content(), content_ptr);
221     content_ptr += tensor.tensor_content().size();
222   }
223 
224   // Transpose the tensor if needed.
225   if (transpose_tensor_value) {
226     // We use dimension 0 to show the row size for the tensor.
227     // We use multiplication of the rest of dimension size to for the col size
228     // of the tensor.
229     std::unique_ptr<float[]> transposed_tensor(
230         new float[dim0_size * dim1_size]);
231     Transpose2DTensor(reinterpret_cast<float*>(allocated_content.get()),
232                       dim0_size, dim1_size, transposed_tensor.get());
233     allocated_tensor_shape->clear_dim();
234     allocated_tensor_shape->add_dim()->set_size(dim1_size);
235     allocated_tensor_shape->add_dim()->set_size(dim0_size);
236 
237     // Set the tensor attributes.
238     allocated_tensor->set_tensor_content(
239         string(reinterpret_cast<const char*>(transposed_tensor.get()),
240                allocated_content_flat_size));
241   } else {
242     tensor_shape_dim0->set_size(dim0_size);
243 
244     // Set the tensor attributes.
245     allocated_tensor->set_tensor_content(
246         string(reinterpret_cast<const char*>(allocated_content.get()),
247                allocated_content_flat_size));
248   }
249 }
250 
251 // SvdfClusterFactory methods
252 
CreateCluster(const NodeDef & node,const GraphDef & graph_def) const253 std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
254     const NodeDef& node, const GraphDef& graph_def) const {
255   std::vector<string> node_patterns = {"SVDF_weights_feature",
256                                        "SVDF_weights_time", "SVDF_bias"};
257 
258   string node_name_to_upper = node.name();
259   std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
260                  node_name_to_upper.begin(), ::toupper);
261   std::unique_ptr<SvdfCluster> cluster = nullptr;
262   if (node_name_to_upper.find("SVDF", 0) != string::npos) {
263     size_t weights_pos = node.name().find(node_patterns[0]);
264     if (weights_pos != string::npos) {
265       // Assuming the node name has a pattern like:
266       // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
267       // CELLNAME as the cluster name.
268       size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1;
269       string cell_name =
270           node.name().substr(cell_pos, weights_pos - cell_pos - 1);
271       cluster = std::unique_ptr<SvdfCluster>(new SvdfCluster);
272       cluster->SetName(cell_name);
273       cluster->SetDevice(node.device());
274       cluster->SetGraphDefInfo(&graph_def);
275       CHECK(cluster->FindClusterInputsAndOutputs());
276 
277       for (const string& const_pattern : node_patterns) {
278         cluster->AddConstNodePattern(const_pattern);
279       }
280     }
281   }
282   return std::move(cluster);
283 }
284 
285 }  // end namespace toco
286