/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/hexagon/graph_transferer.h"

#include <algorithm>
#include <cinttypes>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_slice_writer.h"

namespace tensorflow {

// Function alias.
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
    &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;

constexpr bool DBG_DUMP_VERIFICATION_STRING = false;
constexpr bool DBG_DUMP_PARAMS = false;

const char RESHAPE_NODE_TYPE_STRING[] = "Reshape";
const char SOURCE_NODE_NAME[] = "_SOURCE";
const char SINK_NODE_NAME[] = "_SINK";
const char INPUTS_NODE_PREFIX[] = "inputs_for_";
const char OUTPUTS_NODE_PREFIX[] = "outputs_for_";
const char DATA_NODE_PREFIX[] = "data_for_op_";
const char CONST_SHAPE_PREFIX[] = "const_shape_";
const char CONST_VAL_PREFIX[] = "const_val_";
const char CONST_TENSOR_PREFIX[] = "const_tensor_";
const char PADDING_ATTR_NAME[] = "padding";
const char STRIDES_ATTR_NAME[] = "strides";
const char KEEP_DIMS_ATTR_NAME[] = "keep_dims";
const char KSIZE_ATTR_NAME[] = "ksize";
const char NULL_OUTPUT_NAME[] = "NULL";
const char AGGREGATED_INPUT_NODE_NAME[] = "graph_transfer_aggregated_input";
const int PADDING_NA_ID = 0;  // VALID = 1, SAME = 2

// This is a temporary workaround to support Android builds
// where std::string is not supported even with the C++11 option.
template <typename T>
static string ToString(T val) {
  std::stringstream stream;
  stream << val;
  return stream.str();
}

static Node* FindMutableNodeByName(const string& name, Graph* graph) {
  const TensorId tid = ParseTensorName(name);
  for (Node* node : graph->nodes()) {
    if (node != nullptr && node->name() == tid.first) {
      return node;
    }
  }
  return nullptr;
}

GraphTransferer::GraphTransferer() {
  graph_transfer_info_ = new GraphTransferInfo();
}

GraphTransferer::~GraphTransferer() { delete graph_transfer_info_; }

/**
 * Graph loading functions:
 * - LoadGraphFromProto
 * - LoadGraphFromProtoFile
 * These functions read a graph definition and store the parameters of each
 * node needed to transfer the graph to the SoC.
 */
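//
// Usage sketch (illustrative only; the graph path and node names below are
// made up, and ops_definitions stands for some IRemoteFusedGraphOpsDefinitions
// implementation such as the Hexagon one):
//
//   GraphTransferer gt;
//   std::vector<std::pair<string, Tensor>> inputs{
//       {"Placeholder", Tensor(DT_FLOAT, TensorShape({1, 224, 224, 3}))}};
//   std::vector<string> outputs{"softmax"};
//   TF_CHECK_OK(gt.LoadGraphFromProtoFile(
//       ops_definitions, "/tmp/graph.pb", inputs, outputs,
//       false /* is_text_proto */,
//       true /* shape_inference_for_unknown_shape */,
//       false /* dry_run_for_unknown_shape */));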
Status GraphTransferer::LoadGraphFromProto(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names,
    const bool shape_inference_for_unknown_shape) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  Status status = ImportGraphDef({}, graph_def, &graph, &shape_refiner);
  if (!status.ok()) {
    return status;
  }

  if (shape_inference_for_unknown_shape) {
    status = RemoteFusedGraphExecuteUtils::PropagateShapeInference(
        graph_def, input_node_info_list, &graph, &shape_refiner);
    if (!status.ok()) {
      return status;
    }
  }

  TF_RETURN_IF_ERROR(TransformGraphToAddAggregatedInputNode(
      input_node_info_list, &graph, &shape_refiner));

  std::unordered_multimap<string, const Node*> op_name_to_node_multimap(
      graph.num_nodes());
  for (const Node* const node : graph.nodes()) {
    if (node == nullptr) {
      continue;
    }
    CacheNode(*node);
  }

  for (const Node* const node : graph.nodes()) {
    if (node == nullptr) {
      continue;
    }
    VLOG(1) << "<Node> " << node->name();
    for (const Node* const input_node : node->in_nodes()) {
      const string& name = input_node->name();
      op_name_to_node_multimap.emplace(name, node);
      VLOG(1) << "Add dependency: " << name << " -> " << node->name();
    }
  }

  for (const Node* const node : graph.nodes()) {
    if (node == nullptr) {
      continue;
    }
    status = RegisterNodeIfAllInputsAreCached(
        ops_definitions, shape_refiner, *node, false, input_node_info_list,
        output_node_names);
    if (!status.ok()) {
      LOG(ERROR) << "Failed to transfer graph " << status;
      return status;
    }
  }

  SortParams(output_node_names);

  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    GraphTransferGraphInputNodeInfo& graph_input_node_info =
        *graph_transfer_info_->add_graph_input_node_info();
    graph_input_node_info.set_name(input_node_info.first);
    graph_input_node_info.set_dtype(input_node_info.second.dtype());
    for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
      graph_input_node_info.add_shape(dim);
    }
  }

  for (const string& output_node_name : output_node_names) {
    const TensorId tid = ParseTensorName(output_node_name);
    const string node_name(tid.first);
    const int port = tid.second;
    const int node_id = node_name_to_id_cache_map_.at(node_name);
    const Node* node = node_name_cache_list_.at(node_id);
    CHECK_NOTNULL(node);

    GraphTransferGraphOutputNodeInfo& graph_output_node_info =
        *graph_transfer_info_->add_graph_output_node_info();
    graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));

    // Get the output tensor shape type.
    std::vector<DataType> data_types;
    std::vector<TensorShape> shapes;
    status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
        node->attrs(), &data_types, &shapes);
    if (status.ok()) {
      CHECK(data_types.size() > port);
      graph_output_node_info.set_dtype(data_types.at(port));
      for (const int64 dim : ToTensorShapeArray(shapes.at(port))) {
        graph_output_node_info.add_shape(dim);
      }
    }
  }

  ClearCache();
  if (DBG_DUMP_PARAMS) {
    DumpNodeTransferParams();
  }
  if (DBG_DUMP_VERIFICATION_STRING) {
    DumpVerificationStringOfNodeTransferParams();
  }
  return Status::OK();
}

Status GraphTransferer::LoadGraphFromProtoFile(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const string& graph_def_path,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names, const bool is_text_proto,
    const bool shape_inference_for_unknown_shape,
    const bool dry_run_for_unknown_shape) {
  GraphDef graph_def;
  string output;
  Status status;
  VLOG(1) << "Parse file " << graph_def_path;
  if (is_text_proto) {
    status = ReadFileToString(Env::Default(), graph_def_path, &output);
    if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) {
      return errors::InvalidArgument("Cannot parse proto string.");
    }
  } else {
    status = ReadBinaryProto(Env::Default(), graph_def_path, &graph_def);
  }
  if (!status.ok()) {
    VLOG(1) << "Failed to load graph " << status;
    return status;
  }
  if (dry_run_for_unknown_shape) {
    VLOG(1) << "Dry-run graph to obtain the shapes of nodes";
    RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
    status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
        graph_def, input_node_info_list, true, &tensor_shape_map);
    if (!status.ok()) {
      return status;
    }
    for (NodeDef& node_def : *graph_def.mutable_node()) {
      TF_CHECK_OK(AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map,
                                                           &node_def));
    }
  }
  VLOG(1) << "Load graph with output tensors";
  return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
                            output_node_names,
                            shape_inference_for_unknown_shape);
}

void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
  // TODO(satok): Optimize complexity.
  std::unordered_map<int, GraphTransferNodeInputInfo*> input_map;
  for (GraphTransferNodeInputInfo& input :
       *graph_transfer_info_->mutable_node_input_info()) {
    input_map.emplace(input.node_id(), &input);
  }

  // Set up the dependency map placeholder.
  std::vector<int> output_node_ids;
  std::unordered_map<int, std::unordered_set<int>> dependency_map;
  for (const GraphTransferNodeInfo& params :
       graph_transfer_info_->node_info()) {
    const int node_id = params.node_id();
    for (const string& output_node_name : output_node_names) {
      if (params.name() == output_node_name) {
        output_node_ids.emplace_back(node_id);
      }
    }

    dependency_map.emplace(std::piecewise_construct, std::make_tuple(node_id),
                           std::make_tuple());
    if (params.input_count() == 0) {
      continue;
    }
    CHECK_EQ(input_map.count(node_id), 1);
    for (const GraphTransferNodeInput& node_input :
         input_map.at(node_id)->node_input()) {
      dependency_map.at(node_id).emplace(node_input.node_id());
    }
  }

  // Create the dependency map traversed from the output nodes.
  std::unordered_set<int> completed;
  for (int output_node_id : output_node_ids) {
    FillDependencyRec(output_node_id, dependency_map, completed);
  }

  std::sort(graph_transfer_info_->mutable_node_info()->begin(),
            graph_transfer_info_->mutable_node_info()->end(),
            TransferParamsComparator(dependency_map));
}

void GraphTransferer::EnableStrictCheckMode(const bool enable) {
  strict_check_mode_ = enable;
}

void GraphTransferer::SetSerializedGraphTransferInfo(
    const string& serialized_proto) {
  graph_transfer_info_->ParseFromString(serialized_proto);
}

const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
  return *graph_transfer_info_;
}

GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
  return *graph_transfer_info_;
}

void GraphTransferer::CacheNode(const Node& node) {
  if (node_name_to_id_cache_map_.count(node.name()) > 0) {
    return;
  }
  node_name_cache_list_.emplace_back(&node);
  const int node_id = node_name_cache_list_.size() - 1;
  bool emplace_succeeded = false;
  std::tie(std::ignore, emplace_succeeded) =
      node_name_to_id_cache_map_.emplace(node.name(), node_id);
  CHECK(emplace_succeeded);
}

bool GraphTransferer::AreAllInputsCached(const Node& node) const {
  for (const Node* const input_node : node.in_nodes()) {
    if (node_name_to_id_cache_map_.count(input_node->name()) <= 0) {
      VLOG(1) << "input_node " << input_node->name() << " of " << node.name()
              << " is not cached yet.";
      return false;
    }
  }
  return true;
}
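
// The transformation below rewrites the fused graph so that all original
// inputs are fed from a single aggregated input node. Roughly (node names
// are illustrative):
//
//   Before:  input_a --> op_x        After:   aggregated_input
//            input_b --> op_y                   :0          :1
//                                                |           |
//                                           identity_a  identity_b
//                                                |           |
//                                               op_x        op_y
//
// The original input nodes are replaced by Identity nodes reading the
// corresponding output port of the aggregated node.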
Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  // Transform a remote fused graph to add an aggregated input node which takes
  // all inputs of the remote graph.
  DataTypeVector input_data_types;
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  std::vector<string> input_nodes;
  for (int i = 0; i < input_node_info_list.size(); ++i) {
    Node* node = FindMutableNodeByName(input_node_info_list.at(i).first, graph);
    CHECK_NOTNULL(node);
    input_nodes.emplace_back(node->name());
    input_data_types.emplace_back(input_node_info_list.at(i).second.dtype());
    data_types.emplace_back(input_node_info_list.at(i).second.dtype());
    shapes.emplace_back(input_node_info_list.at(i).second.shape());
  }

  auto builder =
      NodeBuilder(AGGREGATED_INPUT_NODE_NAME, "RemoteFusedGraphExecute")
          .Input(std::vector<NodeBuilder::NodeOut>{})
          .Attr("Tinputs", DataTypeVector{})
          .Attr("Toutputs", input_data_types)
          .Attr("serialized_remote_fused_graph_execute_info", "")
          .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES,
                data_types)
          .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, shapes);

  Node* input_node;
  TF_RETURN_IF_ERROR(builder.Finalize(graph, &input_node));
  CHECK_NOTNULL(input_node);

  bool refined;
  TF_RETURN_IF_ERROR(
      shape_refiner->UpdateNode(input_node, false /* relax */, &refined));

  shape_inference::InferenceContext* context =
      shape_refiner->GetContext(input_node);
  for (int i = 0; i < input_node_info_list.size(); ++i) {
    shape_inference::ShapeHandle handle;
    TF_RETURN_IF_ERROR(context->MakeShapeFromTensorShape(
        input_node_info_list.at(i).second.shape(), &handle));
    TF_RETURN_IF_ERROR(shape_refiner->SetShape(input_node, i, handle));
  }

  // Cache the aggregated input node first as it's consumed first.
  CacheNode(*input_node);

  std::vector<Node*> original_input_nodes(input_nodes.size());

  for (int i = 0; i < input_nodes.size(); ++i) {
    const string& node_name = input_nodes.at(i);
    Node* original_input_node = FindMutableNodeByName(node_name, graph);
    CHECK_NOTNULL(original_input_node);
    CHECK_EQ(1, original_input_node->num_outputs());  // Replaced by identity.
    Node* created_node;
    TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
        node_name, AGGREGATED_INPUT_NODE_NAME, i, data_types.at(i), graph,
        &created_node));
    CHECK_NOTNULL(created_node);
    std::vector<DataType> data_types;
    std::vector<TensorShape> shapes;
    Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
        original_input_node->attrs(), &data_types, &shapes);
    if (status.ok()) {
      created_node->AddAttr(
          RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, data_types);
      created_node->AddAttr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES,
                            shapes);
    }
    for (const Edge* out_edge : original_input_node->out_edges()) {
      Node* dst = out_edge->dst();
      int dst_port = out_edge->dst_input();
      // Unused edges will be removed when the node is removed.
      graph->AddEdge(created_node, 0, dst, dst_port);
    }
    original_input_nodes[i] = original_input_node;

    TF_RETURN_IF_ERROR(
        shape_refiner->UpdateNode(created_node, false /* relax */, &refined));

    shape_inference::InferenceContext* context =
        shape_refiner->GetContext(created_node);
    CHECK_NOTNULL(context);

    // Cache the replacement input node next to the aggregated input node.
    CacheNode(*created_node);
  }

  // Remove the original input nodes after adding the new input nodes to avoid
  // reusing the same pointers in the Graph.
  for (Node* original_input_node : original_input_nodes) {
    graph->RemoveNode(original_input_node);
  }

  return Status::OK();
}

Status GraphTransferer::RegisterNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names) {
  VLOG(1) << "Register node: " << node.name() << ", " << std::hex
          << node_name_to_id_cache_map_.at(node.name());
  if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) {
    // Just ignore the source and sink nodes.
    return Status::OK();
  } else if (node.name() == AGGREGATED_INPUT_NODE_NAME) {
    RegisterInputNode(ops_definitions, shape_refiner, node);
    return Status::OK();
  } else if (node.IsConstant()) {
    RegisterConstantNode(shape_refiner, node);
  } else if (IsPadNode(node)) {
    RegisterPadNode(ops_definitions, shape_refiner, node);
  } else if (HasPaddingAndStrides(node)) {
    RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, node);
  } else if (NeedsToAddRank(node)) {
    RegisterNodeWithRank(ops_definitions, shape_refiner, node);
  } else if (IsNodeFlattenReshape(node, shape_refiner)) {
    RegisterFlattenNode(ops_definitions, shape_refiner, node);
  } else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
             IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
    // TODO(satok): Set the correct data type if it's given.
    RegisterGenericNode(ops_definitions, shape_refiner, node);
  } else {
    return errors::InvalidArgument(node.type_string() +
                                   " has not been implemented yet.");
  }

  return Status::OK();
}

void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
                                           const Node& node) {
  VLOG(1) << "Register constant node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  const int output_node_size = node.num_outputs();
  CHECK_EQ(output_node_size, 1);
  // TODO(satok): Support multiple outputs?
  const int output_index = 0;
  const DataType dt = node.output_type(output_index);
  const size_t max_bytes_per_data = DataTypeSize(dt);
  CHECK_GT(max_bytes_per_data, 0)
      << "dt = " << dt << ", " << DataTypeString(dt) << ", "
      << max_bytes_per_data << ", " << static_cast<int>(DataTypeSize(dt));
  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  shape_inference::ShapeHandle shape_handle = context->output(output_index);
  const shape_inference::DimensionHandle num_elements_dim =
      context->NumElements(shape_handle);
  std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
  int data_size;
  // The shape of a constant node must be known.
  CHECK(context->ValueKnown(num_elements_dim));
  const int64 num_output_elements = context->Value(num_elements_dim);
  data_size = max_bytes_per_data * num_output_elements;
  shape_array = BuildShapeArray(shape_handle, context);

  GraphTransferConstNodeInfo& const_node_info =
      *graph_transfer_info_->add_const_node_info();
  const_node_info.set_name(node.name());
  const_node_info.set_node_id(id);
  // TODO(satok): Make this generic. Never assume the rank is 4.
  CHECK_EQ(4, SHAPE_ARRAY_SIZE);
  const_node_info.add_shape(shape_array[0]);
  const_node_info.add_shape(shape_array[1]);
  const_node_info.add_shape(shape_array[2]);
  const_node_info.add_shape(shape_array[3]);
  const TensorProto* proto = nullptr;
  TF_CHECK_OK(GetNodeAttr(node.attrs(), "value", &proto));
  Tensor const_tensor;
  TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));

  const_node_info.set_dtype(const_tensor.dtype());
  if (data_size > 0) {
    const_node_info.set_data(const_tensor.tensor_data().data(), data_size);
  }
}

int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
  VLOG(1) << "Cache constant shape.";
  // TODO(satok): Handle non-4dim strides.
  CHECK_EQ(shape.size(), 4);
  const string shape_name = CONST_SHAPE_PREFIX + ToString(shape.at(0)) + 'x' +
                            ToString(shape.at(1)) + 'x' +
                            ToString(shape.at(2)) + 'x' + ToString(shape.at(3));
  if (node_name_to_id_cache_map_.count(shape_name) <= 0) {
    node_name_cache_list_.emplace_back(nullptr);
    const int id = node_name_cache_list_.size() - 1;
    node_name_to_id_cache_map_.emplace(shape_name, id);
    GraphTransferConstNodeInfo& const_node_info =
        *graph_transfer_info_->add_const_node_info();
    const_node_info.set_name(shape_name);
    const_node_info.set_node_id(id);
    // TODO(satok): Make this generic. Never assume the rank is 4.
    const_node_info.add_shape(static_cast<int64>(shape[0]));
    const_node_info.add_shape(static_cast<int64>(shape[1]));
    const_node_info.add_shape(static_cast<int64>(shape[2]));
    const_node_info.add_shape(static_cast<int64>(shape[3]));
  }
  return node_name_to_id_cache_map_[shape_name];
}
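
// For example, registering strides {1, 2, 2, 1} caches a const node named
// "const_shape_1x2x2x1"; registering the same strides again returns the
// cached id instead of adding a new const node.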

int GraphTransferer::RegisterConstTensor(const Tensor& tensor,
                                         const string& suffix) {
  VLOG(1) << "Cache const tensor.";
  const int dims = tensor.shape().dims();
  CHECK(dims <= 4);
  const string node_name = strings::StrCat(CONST_TENSOR_PREFIX, "_", suffix);
  if (node_name_to_id_cache_map_.count(node_name) <= 0) {
    node_name_cache_list_.emplace_back(nullptr);
    const int id = node_name_cache_list_.size() - 1;
    node_name_to_id_cache_map_.emplace(node_name, id);
    GraphTransferConstNodeInfo& const_node_info =
        *graph_transfer_info_->add_const_node_info();
    const_node_info.set_name(node_name);
    const_node_info.set_node_id(id);
    CHECK_EQ(4, SHAPE_ARRAY_SIZE);
    for (int i = 0; i < SHAPE_ARRAY_SIZE; ++i) {
      if (i < SHAPE_ARRAY_SIZE - dims) {
        const_node_info.add_shape(1);
      } else {
        const_node_info.add_shape(
            tensor.shape().dim_size(i - (SHAPE_ARRAY_SIZE - dims)));
      }
    }
    const_node_info.set_dtype(tensor.dtype());
    const_node_info.set_data(tensor.tensor_data().data(),
                             tensor.tensor_data().size());
  }
  return node_name_to_id_cache_map_[node_name];
}
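
// The loop above left-pads the shape with 1s up to rank 4. For example, a
// rank-2 tensor of shape {3, 5} is registered with shape {1, 1, 3, 5}, and a
// scalar (rank 0) with shape {1, 1, 1, 1}.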

int GraphTransferer::RegisterConstScalar(const DataType dt, const int val,
                                         const int dst_id,
                                         const int dst_input_count) {
  VLOG(1) << "Cache const.";
  const string val_name =
      CONST_VAL_PREFIX + ToString(dst_id) + '_' + ToString(dst_input_count);
  if (node_name_to_id_cache_map_.count(val_name) <= 0) {
    node_name_cache_list_.emplace_back(nullptr);
    const int id = node_name_cache_list_.size() - 1;
    node_name_to_id_cache_map_.emplace(val_name, id);
    GraphTransferConstNodeInfo& const_node_info =
        *graph_transfer_info_->add_const_node_info();
    const_node_info.set_name(val_name);
    const_node_info.set_node_id(id);
    // TODO(satok): Do not assume the rank is 4 here.
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.set_data(&val, DataTypeSize(dt));
  }
  return node_name_to_id_cache_map_[val_name];
}
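
// For example, RegisterNodeWithRank below uses this to attach the rank of an
// op's first input as an extra const input: RegisterConstScalar(DT_INT32, 4,
// id, input_count) caches a {1, 1, 1, 1}-shaped const named
// "const_val_<id>_<input_count>" whose data is the 4-byte value 4.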

bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
  auto attrs = node.attrs();
  return attrs.Find(PADDING_ATTR_NAME) != nullptr &&
         attrs.Find(STRIDES_ATTR_NAME) != nullptr;
}

bool GraphTransferer::NeedsToAddRank(const Node& node) {
  const StringPiece op_type(node.type_string());
  if (op_type == "Transpose" || op_type == "ExpandDims") {
    return true;
  }
  return false;
}

bool GraphTransferer::IsPadNode(const Node& node) {
  const StringPiece op_type(node.type_string());
  if (op_type == "Pad") {
    return true;
  }
  return false;
}

bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
                                           const ShapeRefiner& shape_refiner) {
  // Check if the node is a reshape op.
  if (node.type_string() != RESHAPE_NODE_TYPE_STRING) {
    return false;
  }

  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  // Check if the output count is valid.
  if (context->num_outputs() != 1) {
    return false;
  }

  shape_inference::ShapeHandle shape_handle = context->output(0);
  std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
  const shape_inference::DimensionHandle dim_handle =
      context->NumElements(shape_handle);

  // Obtain the shape of the node's output.
  if (context->ValueKnown(dim_handle)) {
    shape_array = BuildShapeArray(shape_handle, context);
  } else {
    std::vector<TensorShape> shapes;
    TF_CHECK_OK(RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
        node.attrs(), nullptr, &shapes));

    // The number of outputs should be 1 for a reshape node.
    CHECK_EQ(1, shapes.size());
    shape_array = ToTensorShapeArray(shapes.at(0));
  }

  // Check if the reshape op just does a flatten.
  if (shape_array[0] == 1 && shape_array[1] == 1 && shape_array[2] == 1) {
    return true;
  } else {
    return false;
  }
}
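
// In the rank-4 normalized form used here, "flatten" means the output shape
// is {1, 1, 1, N}. For example, a Reshape producing shape {1000} normalizes
// to {1, 1, 1, 1000} and is treated as a flatten, while one producing
// {2, 500} normalizes to {1, 1, 2, 500} and is not.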

void GraphTransferer::RegisterNodeWithPaddingAndStrides(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  CHECK(node.attrs().Find(PADDING_ATTR_NAME));
  // TODO(satok): Use context->GetAttr(...) instead?
  Padding padding;
  TF_CHECK_OK(context->GetAttr(PADDING_ATTR_NAME, &padding));
  CHECK(node.attrs().Find(STRIDES_ATTR_NAME));
  std::vector<int32> strides;
  TF_CHECK_OK(context->GetAttr(STRIDES_ATTR_NAME, &strides));
  const int stride_id = RegisterConstantShape(strides);
  std::vector<int> extra_inputs{stride_id};
  if (node.attrs().Find(KSIZE_ATTR_NAME)) {
    std::vector<int32> kernel_sizes;
    TF_CHECK_OK(context->GetAttr(KSIZE_ATTR_NAME, &kernel_sizes));
    const int ksize_id = RegisterConstantShape(kernel_sizes);
    extra_inputs.insert(extra_inputs.begin(), ksize_id);
  }
  // TODO(satok): Set the correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
      << "Op " << node.type_string() << " not found in map (id = "
      << op_type_id << ")";
  // Safety check of the padding id: it must be VALID (1) or SAME (2).
  CHECK(padding == Padding::VALID || padding == Padding::SAME);
  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      static_cast<int>(padding), node.num_inputs(), extra_inputs,
      node.num_outputs(), true /* append_input */, true /* append_output */);
}

void GraphTransferer::RegisterNodeWithRank(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  const Node* input0_node;
  TF_CHECK_OK(node.input_node(0, &input0_node));
  CHECK_NOTNULL(input0_node);
  std::vector<TensorShape> shapes;
  Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
      input0_node->attrs(), nullptr, &shapes);
  CHECK_EQ(1, shapes.size()) << "Output size should be 1.";
  const int const_val_id =
      RegisterConstScalar(DT_INT32, shapes.at(0).dims(), id, node.num_inputs());
  std::vector<int> extra_inputs{const_val_id};
  // TODO(satok): Set the correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
      << "Op " << node.type_string() << " not found in map (id = "
      << op_type_id << ")";
  bool keep_dims = false;
  int padding_id = PADDING_NA_ID;
  if (context->GetAttr(KEEP_DIMS_ATTR_NAME, &keep_dims).ok()) {
    padding_id = keep_dims ? Padding::SAME : Padding::VALID;
  }

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      padding_id, node.num_inputs(), extra_inputs, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

void GraphTransferer::RegisterPadNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  static constexpr int PAD_WIDTH = 4;
  static constexpr int PAD_HEIGHT = 2;
  VLOG(1) << "Register pad node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];

  // TODO(satok): Set the correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());

  CHECK_EQ(2, node.num_inputs());

  GraphTransferNodeInputInfo& node_input_info =
      *graph_transfer_info_->add_node_input_info();
  node_input_info.set_node_id(id);

  AddNodeInputByInputIndex(node, 0, &node_input_info);

  const Edge* edge = nullptr;
  TF_CHECK_OK(node.input_edge(1, &edge));
  const Node* input_node = edge->src();
  CHECK_NOTNULL(input_node);
  CHECK(input_node->IsConstant());

  const TensorProto* tensor_proto = nullptr;
  TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &tensor_proto));
  CHECK_NOTNULL(tensor_proto);
  Tensor const_tensor;
  TF_CHECK_OK(MakeTensorFromProto(*tensor_proto, &const_tensor));
  CHECK_EQ(2, const_tensor.shape().dims());
  CHECK_EQ(PAD_HEIGHT, const_tensor.shape().dim_size(1));
  if (const_tensor.shape().dim_size(0) == PAD_WIDTH) {
    AddNodeInputByInputIndex(node, 1, &node_input_info);
  } else if (const_tensor.shape().dim_size(0) < PAD_WIDTH) {
    const int width = const_tensor.shape().dim_size(0);
    CHECK_EQ(DT_INT32, const_tensor.dtype());
    // Extend the paddings input to rank 4 by prepending zero rows.
    // TODO(satok): Never assume the rank is 4.
    Tensor new_const_tensor(const_tensor.dtype(), TensorShape{4, 2});
    for (int i = 0; i < PAD_HEIGHT; ++i) {
      for (int j = 0; j < PAD_WIDTH; ++j) {
        if (j < PAD_WIDTH - width) {
          new_const_tensor.matrix<int32>()(j, i) = 0;
        } else {
          new_const_tensor.matrix<int32>()(j, i) =
              const_tensor.matrix<int32>()(j - (PAD_WIDTH - width), i);
        }
      }
    }

    const int id = RegisterConstTensor(
        new_const_tensor,
        strings::StrCat(input_node->name(), "_", node.name(), "_1"));

    GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
    node_input.set_node_id(id);
    node_input.set_output_port(0);
  } else {
    LOG(FATAL);
  }

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      false /* append_input */, true /* append_output */);
}
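
// Worked example of the rank-4 extension above (assuming an int32 paddings
// tensor): a rank-2 Pad input [[1, 2], [3, 4]] (width 2) becomes
// [[0, 0], [0, 0], [1, 2], [3, 4]], i.e. two zero rows are prepended so the
// paddings describe a rank-4 tensor.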

void GraphTransferer::RegisterInputNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  const string op_type = node.type_string();
  VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  // TODO(satok): Set the correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor("INPUT", {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
      << "Op " << node.name() << ", " << op_type << " is not supported, "
      << op_type_id;
  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

void GraphTransferer::RegisterFlattenNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  VLOG(1) << "Register flatten node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  // TODO(satok): Remove the dependency on a specific type.
  const string op_type = "FLATTEN";
  // TODO(satok): Set the correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

void GraphTransferer::RegisterGenericNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  VLOG(1) << "Register generic node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  // TODO(satok): Set the correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

// TODO(satok): Remove this function.
// TODO(satok): Remove only_register_const_node.
Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node,
    const bool only_register_const_node,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names) {
  if (only_register_const_node && !node.IsConstant()) {
    return Status::OK();
  }
  CHECK(AreAllInputsCached(node));
  return RegisterNode(ops_definitions, shape_refiner, node,
                      input_node_info_list, output_node_names);
}

// CAVEAT: Append input and output params accordingly.
void GraphTransferer::AppendNodeParams(const string& name, const int id,
                                       const string& type, const int type_id,
                                       const int padding, const int inputs_size,
                                       const std::vector<int>& extra_inputs,
                                       const int outputs_size) {
  GraphTransferNodeInfo& node_info = *graph_transfer_info_->add_node_info();
  node_info.set_name(name);
  node_info.set_node_id(id);
  node_info.set_type_name(type);
  node_info.set_soc_op_id(type_id);
  node_info.set_padding_id(padding);
  node_info.set_input_count(inputs_size +
                            static_cast<int>(extra_inputs.size()));
  node_info.set_output_count(static_cast<int>(outputs_size));
}

void GraphTransferer::AddNodeInputByInputIndex(
    const Node& node, const int idx,
    GraphTransferNodeInputInfo* node_input_info) {
  const Edge* edge = nullptr;
  TF_CHECK_OK(node.input_edge(idx, &edge));
  const Node* input_node = edge->src();
  CHECK_NOTNULL(input_node);
  const int port = edge->src_output();

  const std::string& op_name = input_node->name();
  CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
  const int src_id = node_name_to_id_cache_map_[op_name];
  GraphTransferNodeInput& node_input = *node_input_info->add_node_input();
  node_input.set_node_id(src_id);
  node_input.set_output_port(port);
}

void GraphTransferer::AppendNodeInputParams(
    const int id, const Node& node, const std::vector<int>& extra_inputs) {
  VLOG(1) << "Append input params: " << node.name() << ", "
          << node.num_inputs() << ", " << extra_inputs.size();
  GraphTransferNodeInputInfo& node_input_info =
      *graph_transfer_info_->add_node_input_info();
  node_input_info.set_node_id(id);
  for (int i = 0; i < node.num_inputs(); ++i) {
    AddNodeInputByInputIndex(node, i, &node_input_info);
  }
  for (const int extra_input : extra_inputs) {
    GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
    node_input.set_node_id(extra_input);
    node_input.set_output_port(0);
  }
}

void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
                                             const int id, const Node& node) {
  VLOG(1) << "Append output params: " << node.name() << ", "
          << node.num_outputs();
  GraphTransferNodeOutputInfo& node_output_info =
      *graph_transfer_info_->add_node_output_info();
  node_output_info.set_node_id(id);

  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
      node.attrs(), &data_types, &shapes);

  for (int i = 0; i < node.num_outputs(); ++i) {
    int data_size = -1;
    const int output_index = i;
    const DataType dt = node.output_type(output_index);
    const size_t max_bytes_per_data = DataTypeSize(dt);

    shape_inference::InferenceContext* context =
        shape_refiner.GetContext(&node);

    if (context != nullptr && context->ValueKnown(context->NumElements(
                                  context->output(output_index)))) {
      const shape_inference::DimensionHandle num_elements_dim =
          context->NumElements(context->output(output_index));
      const int64 num_output_elements = context->Value(num_elements_dim);
      data_size = max_bytes_per_data * num_output_elements;
      if (status.ok()) {
        CHECK_EQ(shapes.at(i).num_elements(), num_output_elements);
      }
    } else {
      TF_CHECK_OK(status);
      // Use the attribute attached to the node.
      data_size = max_bytes_per_data * shapes.at(i).num_elements();
    }
    CHECK_GE(data_size, 0);
    node_output_info.add_max_byte_size(data_size);
  }
}

void GraphTransferer::AppendNodeParamsWithIoParams(
    const ShapeRefiner& shape_refiner, const Node& node, const string& name,
    const int id, const string& type, const int type_id, const int padding,
    const int inputs_size, const std::vector<int>& extra_inputs,
    const int outputs_size, const bool append_input_params,
    const bool append_output_params) {
  VLOG(1) << "Append node with io params: " << node.name();
  if (append_input_params) {
    AppendNodeInputParams(id, node, extra_inputs);
  }
  if (append_output_params) {
    AppendNodeOutputParams(shape_refiner, id, node);
  }
  AppendNodeParams(name, id, type, type_id, padding, inputs_size, extra_inputs,
                   outputs_size);
}
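
// Both helpers below normalize a shape to a rank-4 int64 array by
// left-padding with 1s, e.g. (illustrative):
//   rank 0: {}           -> {1, 1, 1, 1}
//   rank 1: {N}          -> {1, 1, 1, N}
//   rank 2: {H, W}       -> {1, 1, H, W}
//   rank 4: {B, H, W, C} -> {B, H, W, C}
// Ranks above 4 are not supported and hit LOG(FATAL).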
/* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
GraphTransferer::BuildShapeArray(
    const shape_inference::ShapeHandle& shape_handle,
    shape_inference::InferenceContext* context) {
  switch (context->Rank(shape_handle)) {
    case 0:
      return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
    case 1:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, 1, 1, context->Value(context->Dim(shape_handle, 0))}};
    case 2:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, 1, context->Value(context->Dim(shape_handle, 0)),
           context->Value(context->Dim(shape_handle, 1))}};
    case 3:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, context->Value(context->Dim(shape_handle, 0)),
           context->Value(context->Dim(shape_handle, 1)),
           context->Value(context->Dim(shape_handle, 2))}};
    case 4:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {context->Value(context->Dim(shape_handle, 0)),
           context->Value(context->Dim(shape_handle, 1)),
           context->Value(context->Dim(shape_handle, 2)),
           context->Value(context->Dim(shape_handle, 3))}};
    default:
      // TODO(satok): Support more ranks?
      LOG(FATAL);
      return std::array<int64, SHAPE_ARRAY_SIZE>();
  }
}

/* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
GraphTransferer::ToTensorShapeArray(const TensorShape& shape) {
  switch (shape.dims()) {
    case 0:
      return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
    case 1:
      return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, shape.dim_size(0)}};
    case 2:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, 1, shape.dim_size(0), shape.dim_size(1)}};
    case 3:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, shape.dim_size(0), shape.dim_size(1), shape.dim_size(2)}};
    case 4:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {shape.dim_size(0), shape.dim_size(1), shape.dim_size(2),
           shape.dim_size(3)}};
    default:
      // TODO(satok): Support more ranks?
      LOG(FATAL);
      return std::array<int64, SHAPE_ARRAY_SIZE>();
  }
}

/* static */ string GraphTransferer::ToPaddingDebugString(const int padding) {
  switch (padding) {
    case 0:
      return "NN_PAD_NA";
    case Padding::VALID:
      return "NN_PAD_VALID";
    case Padding::SAME:
      return "NN_PAD_SAME";
    default:
      LOG(FATAL);
      return "";
  }
}

GraphTransferer::TransferParamsComparator::TransferParamsComparator(
    const std::unordered_map<int, std::unordered_set<int>>& dep_map)
    : dependency_map_(dep_map) {}

bool GraphTransferer::TransferParamsComparator::operator()(
    const GraphTransferNodeInfo& obj0, const GraphTransferNodeInfo& obj1) {
  const int node_id0 = obj0.node_id();
  const int node_id1 = obj1.node_id();
  bool obj0_uses_obj1 = false;
  if (dependency_map_.count(node_id0) > 0) {
    obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0;
  }
  bool obj1_uses_obj0 = false;
  if (dependency_map_.count(node_id1) > 0) {
    obj1_uses_obj0 = dependency_map_.at(node_id1).count(node_id0) > 0;
  }
  CHECK(!obj0_uses_obj1 || !obj1_uses_obj0);
  if (obj0_uses_obj1) {
    return false;
  } else if (obj1_uses_obj0) {
    return true;
  }
  // If there is no dependency between the two nodes, assume the execution
  // order follows node id order.
  return node_id0 < node_id1;
}
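
// Note: this comparator assumes FillDependencyRec has already expanded each
// node's dependency set to its transitive closure; otherwise the ordering
// seen by std::sort may be inconsistent and the sorted result would not be a
// valid execution order.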

/* static */ void GraphTransferer::FillDependencyRec(
    const int node_id,
    std::unordered_map<int, std::unordered_set<int>>& dep_map,
    std::unordered_set<int>& completed) {
  if (dep_map.count(node_id) == 0 || dep_map.at(node_id).empty() ||
      completed.count(node_id) == 1) {
    return;
  }
  CHECK_EQ(dep_map.count(node_id), 1);

  // Complete the children's dependency maps.
  for (int child_node_id : dep_map.at(node_id)) {
    CHECK(child_node_id != node_id);
    if (completed.count(child_node_id) != 0) {
      continue;
    }
    FillDependencyRec(child_node_id, dep_map, completed);
  }

  // Find additional depending ids.
  std::vector<int> depending_ids;
  for (int child_node_id : dep_map.at(node_id)) {
    if (dep_map.count(child_node_id) == 0) {
      continue;
    }
    for (int depending_id : dep_map.at(child_node_id)) {
      depending_ids.emplace_back(depending_id);
    }
  }

  // Insert the additional depending ids.
  for (int depending_id : depending_ids) {
    if (dep_map.at(node_id).count(depending_id) == 0) {
      dep_map.at(node_id).emplace(depending_id);
    }
  }

  // DP: Record the completed node id.
  completed.emplace(node_id);
}
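
// Example of the closure computed above (ids are illustrative): with direct
// dependencies dep(3) = {2} and dep(2) = {1}, FillDependencyRec(3, ...)
// expands dep(3) to {1, 2}, so the comparator can order node 3 after both of
// its transitive inputs.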

/* static */ Status GraphTransferer::MakeTensorFromProto(
    const TensorProto& tensor_proto, Tensor* tensor) {
  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
    Tensor parsed(tensor_proto.dtype());
    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
      *tensor = parsed;
      return Status::OK();
    }
  }
  return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                 tensor_proto.DebugString());
}

void GraphTransferer::ClearCache() {
  node_name_cache_list_.clear();
  node_name_to_id_cache_map_.clear();
}

void GraphTransferer::DumpNodeTransferParams() const {
  LOG(INFO) << "*** Const Nodes ***";
  for (const GraphTransferConstNodeInfo& params :
       graph_transfer_info_->const_node_info()) {
    // TODO(satok): Stop assuming the shape size is 4.
    CHECK_EQ(params.shape_size(), 4);
    LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
              << "\" (Const)";
    LOG(INFO) << "  shape: " << params.shape(0) << "," << params.shape(1)
              << "," << params.shape(2) << "," << params.shape(3);
    LOG(INFO) << "  data_name: "
              << (params.data().length() <= 0
                      ? ""
                      : DATA_NODE_PREFIX + ToString(params.node_id()));
    LOG(INFO) << "  data_size: " << params.data().length() << " bytes"
              << " ]";
  }
  LOG(INFO) << "******\n";
  LOG(INFO) << "*** Op Nodes ***";
  for (const GraphTransferNodeInfo& params :
       graph_transfer_info_->node_info()) {
    LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
    LOG(INFO) << "  type: " << params.type_name();
    LOG(INFO) << "  padding: " << ToPaddingDebugString(params.padding_id());
    LOG(INFO) << "  inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id())
              << ", size = " << params.input_count();
    LOG(INFO) << "  outputs: "
              << (params.output_count() <= 0
                      ? NULL_OUTPUT_NAME
                      : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
              << ", size = " << params.output_count() << " ]";
  }
  LOG(INFO) << "******\n";
  LOG(INFO) << "*** Node input params ***";
  for (const GraphTransferNodeInputInfo& params :
       graph_transfer_info_->node_input_info()) {
    LOG(INFO) << "[ " << params.node_id() << " ]";
    for (const GraphTransferNodeInput& node_input : params.node_input()) {
      LOG(INFO) << "  src node id = " << node_input.node_id()
                << ", output port = " << node_input.output_port();
    }
  }
  LOG(INFO) << "******\n";
  LOG(INFO) << "*** Node output params ***";
  for (const GraphTransferNodeOutputInfo& params :
       graph_transfer_info_->node_output_info()) {
    LOG(INFO) << "[ " << params.node_id() << " ]";
    for (const int max_size : params.max_byte_size()) {
      LOG(INFO) << "  max_size = " << max_size;
    }
  }
  LOG(INFO) << "******\n";
}

void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
  for (const GraphTransferConstNodeInfo& params :
       graph_transfer_info_->const_node_info()) {
    std::stringstream sstream;
    // TODO(satok): Stop assuming the shape size is 4.
    CHECK_EQ(params.shape_size(), 4);
    sstream << "---(CONST) [" << std::hex << params.node_id() << std::dec << ","
            << params.shape(0) << "," << params.shape(1) << ","
            << params.shape(2) << "," << params.shape(3) << ","
            << (params.data().length() <= 0
                    ? ""
                    : DATA_NODE_PREFIX + ToString(params.node_id()))
            << "," << params.data().length() << "," << params.name() << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Const node count = "
            << graph_transfer_info_->const_node_info_size();
  for (const GraphTransferNodeInfo& params :
       graph_transfer_info_->node_info()) {
    std::stringstream sstream;
    sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
            << params.node_id() << std::dec << "," << params.soc_op_id() << ","
            << ToPaddingDebugString(params.padding_id()) << ","
            << INPUTS_NODE_PREFIX + ToString(params.node_id()) << ","
            << params.input_count() << ","
            << (params.output_count() <= 0
                    ? NULL_OUTPUT_NAME
                    : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
            << "," << params.output_count() << "," << params.type_name() << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Op node count = " << graph_transfer_info_->node_info_size();
  for (const GraphTransferNodeInputInfo& params :
       graph_transfer_info_->node_input_info()) {
    std::stringstream sstream;
    sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
    for (const GraphTransferNodeInput& node_input : params.node_input()) {
      sstream << "," << std::hex << node_input.node_id() << std::dec << ","
              << node_input.output_port();
    }
    sstream << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Input params count = "
            << graph_transfer_info_->node_input_info_size();
  for (const GraphTransferNodeOutputInfo& params :
       graph_transfer_info_->node_output_info()) {
    std::stringstream sstream;
    sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
    for (const int max_size : params.max_byte_size()) {
      sstream << "," << max_size;
    }
    sstream << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Output params count = "
            << graph_transfer_info_->node_output_info_size();
}

}  // namespace tensorflow