1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
16
17 #include <ctype.h>
18 #include <stddef.h>
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "google/protobuf/map.h"
25 #include "tensorflow/lite/toco/model.h"
26 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
27 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
28 #include "tensorflow/lite/toco/toco_port.h"
29 #include "tensorflow/lite/toco/tooling_util.h"
30
31 #include "tensorflow/core/framework/attr_value.pb.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/framework/tensor_shape.pb.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 using tensorflow::GraphDef;
39 using tensorflow::NodeDef;
40
41 namespace toco {
42
43 namespace {
44
45 // Receives a vector of cluster nodes and returns only those which are array
46 // partitions (of type 'Const' and have the pattern 'part_<.*>' in their name.
47 // Since these nodes are connected to a Concatenate node, it makes sure the
48 // axis value input of the Concatenate operator is 0.
FilterPartitionedConstNodes(const std::string & const_pattern,const std::vector<const NodeDef * > & cluster_nodes,std::vector<const NodeDef * > * const_node_parts)49 void FilterPartitionedConstNodes(
50 const std::string& const_pattern,
51 const std::vector<const NodeDef*>& cluster_nodes,
52 std::vector<const NodeDef*>* const_node_parts) {
53 for (const NodeDef* node : cluster_nodes) {
54 std::string node_name_to_upper = node->name();
55 std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
56 node_name_to_upper.begin(), ::toupper);
57 if (StrContains(node->name(), const_pattern) && node->op() == "Const") {
58 if (StrContains(node_name_to_upper, "/PART_")) {
59 const_node_parts->push_back(node);
60 } else if (StrContains(node->name(), "AXIS") &&
61 StrContains(node->name(), "CONCAT")) {
62 // For now only supporting Concatenate on Axix 0
63 const auto& value_attr = node->attr().at("value");
64 const tensorflow::TensorProto& tensor = value_attr.tensor();
65 CHECK_EQ(tensor.int_val(0), 0);
66 }
67 }
68 }
69 sort(const_node_parts->begin(), const_node_parts->end(),
70 [](const NodeDef* a, const NodeDef* b) {
71 return (a->name().compare(b->name()) < 0 &&
72 (a->name().size() < b->name().size()));
73 });
74 }
75
76 } // namespace
77
78 // SvdfCluster methods
79
InferFilterRank()80 int SvdfCluster::InferFilterRank() {
81 for (const NodeDef* node : nodes_) {
82 if (StrContains(node->name(), "Reshape/shape")) {
83 const auto& value_attr = node->attr().at("value");
84 const tensorflow::TensorProto& tensor = value_attr.tensor();
85 std::vector<int32> shape_values(
86 tensor.tensor_content().size() / sizeof(int), 0);
87 port::CopyToBuffer(tensor.tensor_content(),
88 reinterpret_cast<char*>(shape_values.data()));
89 CHECK_EQ(shape_values.size(), 3);
90 // shape_value array is arranged as:
91 // [num_units, rank, -1]
92 CHECK_EQ(shape_values[2], -1);
93 return shape_values[1];
94 }
95 }
96 return -1;
97 }
98
CreateNodes()99 void SvdfCluster::CreateNodes() {
100 for (const std::string& const_pattern : const_node_patterns_) {
101 CreateConstNode(const_pattern);
102 }
103 std::unique_ptr<tensorflow::NodeDef> svdf_node(new NodeDef);
104 svdf_node->set_op("Svdf");
105 svdf_node->set_name(name_);
106 svdf_node->set_device(device_);
107
108 // Add the main input.
109 svdf_node->add_input(inputs_[0]);
110
111 // Add the rest of the inputs to Svdf cell: weights and bias.
112 CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2);
113 std::string* weights_feature_input = svdf_node->add_input();
114 std::string* weights_time_input = svdf_node->add_input();
115 std::string* bias_input;
116 if (new_nodes_.size() == 3) {
117 bias_input = svdf_node->add_input();
118 }
119 for (const std::unique_ptr<tensorflow::NodeDef>& node : new_nodes_) {
120 const std::string node_name = node->name();
121 if (StrContains(node_name, "SVDF_weights_feature")) {
122 *weights_feature_input = node_name;
123 } else if (StrContains(node_name, "SVDF_weights_time")) {
124 *weights_time_input = node_name;
125 } else if (StrContains(node_name, "SVDF_bias")) {
126 CHECK(bias_input) << "Bias input cannot be provided when there are only "
127 "two Const input nodes!";
128 *bias_input = node_name;
129 } else {
130 // Unexpected input for Svdf op.
131 LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: "
132 "weights_feature, weights_time and bias.";
133 }
134 }
135 const int rank = InferFilterRank();
136 CHECK_GT(rank, 0);
137
138 // Add Svdf activation and rank.
139 std::string activation_function =
140 StrContains(outputs_[0], "Relu") ? "Relu" : "None";
141 (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function);
142 (*svdf_node->mutable_attr())["Rank"].set_i(rank);
143
144 // Finally add it to the list of the newly created nodes.
145 new_nodes_.push_back(std::move(svdf_node));
146 }
147
CreateConstNode(const std::string & const_pattern)148 void SvdfCluster::CreateConstNode(const std::string& const_pattern) {
149 // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const.
150 std::vector<const NodeDef*> const_node_parts;
151 FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts);
152
153 if (const_node_parts.empty()) return;
154
155 bool transpose_tensor_value =
156 StrContains(const_pattern, "SVDF_weights_feature");
157
158 // Merge them if necessary.
159 std::unique_ptr<tensorflow::NodeDef> merged_node(new NodeDef);
160 MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node);
161 new_nodes_.push_back(std::move(merged_node));
162 }
163
MaybeMergeConstNodes(const std::vector<const NodeDef * > & const_node_parts,bool transpose_tensor_value,const std::unique_ptr<tensorflow::NodeDef> & merged_node)164 void SvdfCluster::MaybeMergeConstNodes(
165 const std::vector<const NodeDef*>& const_node_parts,
166 bool transpose_tensor_value,
167 const std::unique_ptr<tensorflow::NodeDef>& merged_node) {
168 merged_node->set_name(const_node_parts[0]->name());
169 merged_node->set_op("Const");
170 merged_node->set_device(const_node_parts[0]->device());
171 (*merged_node->mutable_attr())["dtype"].set_type(
172 const_node_parts[0]->attr().at("dtype").type());
173
174 // Figuring out Value attribute for the merged node.
175 // Assuming the partitioning is done on Axis 0.
176 // The attributes which are inferred:
177 // * Shape and dimensions
178 // * Float content values
179
180 // Inferring shape and dimension
181 int dim0_size = 0;
182 int dim1_size = 1;
183 tensorflow::TensorProto* allocated_tensor =
184 (*merged_node->mutable_attr())["value"].mutable_tensor();
185 tensorflow::TensorShapeProto* allocated_tensor_shape =
186 allocated_tensor->mutable_tensor_shape();
187 auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
188 int allocated_content_flat_size = 0;
189 for (size_t i = 0; i < const_node_parts.size(); i++) {
190 const auto& value_attr = const_node_parts[i]->attr().at("value");
191 const tensorflow::TensorProto& tensor = value_attr.tensor();
192 if (i == 0) {
193 allocated_tensor->set_dtype(tensor.dtype());
194 } else {
195 CHECK_EQ(allocated_tensor->dtype(), tensor.dtype());
196 }
197 allocated_content_flat_size += tensor.tensor_content().size();
198 CHECK(tensor.has_tensor_shape());
199 const tensorflow::TensorShapeProto shape = tensor.tensor_shape();
200 dim0_size += shape.dim(0).size();
201 for (int d = 1; d < shape.dim_size(); d++) {
202 if (i == 0) {
203 allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size());
204 allocated_tensor_shape->set_unknown_rank(shape.unknown_rank());
205 dim1_size *= shape.dim(d).size();
206 } else {
207 CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size());
208 CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank());
209 }
210 }
211 }
212
213 // Copying the float content from each array partition.
214 std::unique_ptr<char[]> allocated_content(
215 new char[allocated_content_flat_size]);
216 char* content_ptr = allocated_content.get();
217 for (size_t i = 0; i < const_node_parts.size(); i++) {
218 const auto& value_attr = const_node_parts[i]->attr().at("value");
219 const tensorflow::TensorProto& tensor = value_attr.tensor();
220 port::CopyToBuffer(tensor.tensor_content(), content_ptr);
221 content_ptr += tensor.tensor_content().size();
222 }
223
224 // Transpose the tensor if needed.
225 if (transpose_tensor_value) {
226 // We use dimension 0 to show the row size for the tensor.
227 // We use multiplication of the rest of dimension size to for the col size
228 // of the tensor.
229 std::unique_ptr<float[]> transposed_tensor(
230 new float[dim0_size * dim1_size]);
231 Transpose2DTensor(reinterpret_cast<float*>(allocated_content.get()),
232 dim0_size, dim1_size, transposed_tensor.get());
233 allocated_tensor_shape->clear_dim();
234 allocated_tensor_shape->add_dim()->set_size(dim1_size);
235 allocated_tensor_shape->add_dim()->set_size(dim0_size);
236
237 // Set the tensor attributes.
238 allocated_tensor->set_tensor_content(
239 std::string(reinterpret_cast<const char*>(transposed_tensor.get()),
240 allocated_content_flat_size));
241 } else {
242 tensor_shape_dim0->set_size(dim0_size);
243
244 // Set the tensor attributes.
245 allocated_tensor->set_tensor_content(
246 std::string(reinterpret_cast<const char*>(allocated_content.get()),
247 allocated_content_flat_size));
248 }
249 }
250
251 // SvdfClusterFactory methods
252
CreateCluster(const NodeDef & node,const GraphDef & graph_def) const253 std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
254 const NodeDef& node, const GraphDef& graph_def) const {
255 std::vector<std::string> node_patterns = {"SVDF_weights_feature",
256 "SVDF_weights_time", "SVDF_bias"};
257
258 std::string node_name_to_upper = node.name();
259 std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
260 node_name_to_upper.begin(), ::toupper);
261 std::unique_ptr<SvdfCluster> cluster = nullptr;
262 if (node_name_to_upper.find("SVDF", 0) != std::string::npos) {
263 size_t weights_pos = node.name().find(node_patterns[0]);
264 if (weights_pos != std::string::npos) {
265 // Assuming the node name has a pattern like:
266 // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
267 // CELLNAME as the cluster name.
268 size_t cell_pos = node.name().rfind('/', weights_pos - 2) + 1;
269 std::string cell_name =
270 node.name().substr(cell_pos, weights_pos - cell_pos - 1);
271 cluster = std::unique_ptr<SvdfCluster>(new SvdfCluster);
272 cluster->SetName(cell_name);
273 cluster->SetDevice(node.device());
274 cluster->SetGraphDefInfo(&graph_def);
275 CHECK(cluster->FindClusterInputsAndOutputs());
276
277 for (const std::string& const_pattern : node_patterns) {
278 cluster->AddConstNodePattern(const_pattern);
279 }
280 }
281 }
282 return std::move(cluster);
283 }
284
285 } // end namespace toco
286