• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/hexagon/graph_transferer.h"
17 
18 #include <algorithm>
19 #include <cinttypes>
20 
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/graph_transfer_info.pb.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/graph/algorithm.h"
25 #include "tensorflow/core/graph/graph_constructor.h"
26 #include "tensorflow/core/graph/node_builder.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/public/session.h"
30 #include "tensorflow/core/public/session_options.h"
31 #include "tensorflow/core/util/tensor_slice_writer.h"
32 
33 namespace tensorflow {
34 
// function alias
// Short name for the RemoteFusedGraphExecuteUtils helper; called by
// LoadGraphFromProtoFile on every NodeDef after a dry run.
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
    &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;

// Compile-time debug switches consulted at the end of LoadGraphFromProto:
// when true, the corresponding Dump* member function is called.
constexpr bool DBG_DUMP_VERIFICATION_STRING = false;
constexpr bool DBG_DUMP_PARAMS = false;

// Op type string that IsNodeFlattenReshape compares against.
const char RESHAPE_NODE_TYPE_STRING[] = "Reshape";
// Node names that RegisterNode skips without registering anything.
const char SOURCE_NODE_NAME[] = "_SOURCE";
const char SINK_NODE_NAME[] = "_SINK";
// Name prefixes. CONST_SHAPE_PREFIX, CONST_VAL_PREFIX and CONST_TENSOR_PREFIX
// are used by RegisterConstantShape, RegisterConstScalar and
// RegisterConstTensor respectively to name the const nodes they create.
const char INPUTS_NODE_PREFIX[] = "inputs_for_";
const char OUTPUTS_NODE_PREFIX[] = "outputs_for_";
const char DATA_NODE_PREFIX[] = "data_for_op_";
const char CONST_SHAPE_PREFIX[] = "const_shape_";
const char CONST_VAL_PREFIX[] = "const_val_";
const char CONST_TENSOR_PREFIX[] = "const_tensor_";
// Attribute names looked up on nodes while registering them.
const char PADDING_ATTR_NAME[] = "padding";
const char STRIDES_ATTR_NAME[] = "strides";
const char KEEP_DIMS_ATTR_NAME[] = "keep_dims";
const char KSIZE_ATTR_NAME[] = "ksize";
const char NULL_OUTPUT_NAME[] = "NULL";
// Name of the node created by TransformGraphToAddAggregatedInputNode.
const char AGGREGATED_INPUT_NODE_NAME[] = "graph_transfer_aggregated_input";
// Padding id used when padding does not apply to a node.
const int PADDING_NA_ID = 0;  // VALID = 1, SAME = 2
58 
// Temporary workaround to support the android build, where std::string
// (presumably std::to_string -- TODO confirm) is not supported even with the
// c++11 option. Formats `val` by streaming it with operator<<.
template <typename T>
static string ToString(T val) {
  std::ostringstream formatted;
  formatted << val;
  return formatted.str();
}
67 
FindMutableNodeByName(const string & name,Graph * graph)68 static Node* FindMutableNodeByName(const string& name, Graph* graph) {
69   const TensorId tid = ParseTensorName(name);
70   for (Node* node : graph->nodes()) {
71     if (node != nullptr && node->name() == tid.first) {
72       return node;
73     }
74   }
75   return nullptr;
76 }
77 
GraphTransferer()78 GraphTransferer::GraphTransferer() {
79   graph_transfer_info_ = new GraphTransferInfo();
80 }
81 
~GraphTransferer()82 GraphTransferer::~GraphTransferer() { delete graph_transfer_info_; }
83 
/**
 * graph loading functions
 * - LoadGraphFromProto
 * - LoadGraphFromProtoFile
 * These functions read a graph definition and store parameters
 * of nodes to transfer the graph to SOC.
 */
LoadGraphFromProto(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names,const bool shape_inference_for_unknown_shape)91 Status GraphTransferer::LoadGraphFromProto(
92     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
93     const GraphDef& graph_def,
94     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
95     const std::vector<string>& output_node_names,
96     const bool shape_inference_for_unknown_shape) {
97   Graph graph(OpRegistry::Global());
98   ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
99   Status status = ImportGraphDef({}, graph_def, &graph, &shape_refiner);
100   if (!status.ok()) {
101     return status;
102   }
103 
104   if (shape_inference_for_unknown_shape) {
105     status = RemoteFusedGraphExecuteUtils::PropagateShapeInference(
106         graph_def, input_node_info_list, &graph, &shape_refiner);
107     if (!status.ok()) {
108       return status;
109     }
110   }
111 
112   TF_RETURN_IF_ERROR(TransformGraphToAddAggregatedInputNode(
113       input_node_info_list, &graph, &shape_refiner));
114 
115   std::unordered_multimap<string, const Node*> op_name_to_node_multimap(
116       graph.num_nodes());
117   for (const Node* const node : graph.nodes()) {
118     if (node == nullptr) {
119       continue;
120     }
121     CacheNode(*node);
122   }
123 
124   for (const Node* const node : graph.nodes()) {
125     if (node == nullptr) {
126       continue;
127     }
128     VLOG(1) << "<Node> " << node->name();
129     for (const Node* const input_node : node->in_nodes()) {
130       const string& name = input_node->name();
131       op_name_to_node_multimap.emplace(name, node);
132       VLOG(1) << "Add dependency: " << name << " -> " << node->name();
133     }
134   }
135 
136   for (const Node* const node : graph.nodes()) {
137     if (node == nullptr) {
138       continue;
139     }
140     status = RegisterNodeIfAllInputsAreCached(
141         ops_definitions, shape_refiner, *node, false, input_node_info_list,
142         output_node_names);
143     if (!status.ok()) {
144       LOG(ERROR) << "Failed to transfer graph " << status;
145       return status;
146     }
147   }
148 
149   SortParams(output_node_names);
150 
151   for (const std::pair<string, Tensor>& input_node_info :
152        input_node_info_list) {
153     GraphTransferGraphInputNodeInfo& graph_input_node_info =
154         *graph_transfer_info_->add_graph_input_node_info();
155     graph_input_node_info.set_name(input_node_info.first);
156     graph_input_node_info.set_dtype(input_node_info.second.dtype());
157     for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
158       graph_input_node_info.add_shape(dim);
159     }
160   }
161 
162   for (const string& output_node_name : output_node_names) {
163     const TensorId tid = ParseTensorName(output_node_name);
164     const string node_name(tid.first);
165     const int port = tid.second;
166     const int node_id = node_name_to_id_cache_map_.at(node_name);
167     const Node* node = node_name_cache_list_.at(node_id);
168     CHECK_NOTNULL(node);
169 
170     GraphTransferGraphOutputNodeInfo& graph_output_node_info =
171         *graph_transfer_info_->add_graph_output_node_info();
172     graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));
173 
174     // Get output tensor shape type
175     std::vector<DataType> data_types;
176     std::vector<TensorShape> shapes;
177     status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
178         node->attrs(), &data_types, &shapes);
179     if (status.ok()) {
180       CHECK(data_types.size() > port);
181       graph_output_node_info.set_dtype(data_types.at(port));
182       for (const int64 dim : ToTensorShapeArray(shapes.at(port))) {
183         graph_output_node_info.add_shape(dim);
184       }
185     }
186   }
187 
188   ClearCache();
189   if (DBG_DUMP_PARAMS) {
190     DumpNodeTransferParams();
191   }
192   if (DBG_DUMP_VERIFICATION_STRING) {
193     DumpVerificationStringOfNodeTransferParams();
194   }
195   return Status();
196 }
197 
LoadGraphFromProtoFile(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const string & graph_def_path,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names,const bool is_text_proto,const bool shape_inference_for_unknown_shape,const bool dry_run_for_unknown_shape)198 Status GraphTransferer::LoadGraphFromProtoFile(
199     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
200     const string& graph_def_path,
201     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
202     const std::vector<string>& output_node_names, const bool is_text_proto,
203     const bool shape_inference_for_unknown_shape,
204     const bool dry_run_for_unknown_shape) {
205   GraphDef graph_def;
206   string output;
207   Status status;
208   VLOG(1) << "Parse file " << graph_def_path;
209   if (is_text_proto) {
210     status = ReadFileToString(Env::Default(), graph_def_path, &output);
211     if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) {
212       return errors::InvalidArgument("Cannot parse proto string.");
213     }
214   } else {
215     status = ReadBinaryProto(Env::Default(), graph_def_path, &graph_def);
216   }
217   if (!status.ok()) {
218     VLOG(1) << "Failed to load graph " << status;
219     return status;
220   }
221   if (dry_run_for_unknown_shape) {
222     VLOG(1) << "Dry run graph to obtain shape of nodes";
223     RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
224     status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
225         graph_def, input_node_info_list, true, &tensor_shape_map);
226     if (!status.ok()) {
227       return status;
228     }
229     for (NodeDef& node_def : *graph_def.mutable_node()) {
230       TF_CHECK_OK(AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map,
231                                                            &node_def));
232     }
233   }
234   VLOG(1) << "Load graph with output tensors";
235   return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
236                             output_node_names,
237                             shape_inference_for_unknown_shape);
238 }
239 
SortParams(const std::vector<string> & output_node_names)240 void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
241   // TODO(satok): optimize complexity
242   std::unordered_map<int, GraphTransferNodeInputInfo*> input_map;
243   for (GraphTransferNodeInputInfo& input :
244        *graph_transfer_info_->mutable_node_input_info()) {
245     input_map.emplace(input.node_id(), &input);
246   }
247 
248   // Setup dependency map placeholder
249   std::vector<int> output_node_ids;
250   std::unordered_map<int, std::unordered_set<int>> dependency_map;
251   for (const GraphTransferNodeInfo& params :
252        graph_transfer_info_->node_info()) {
253     const int node_id = params.node_id();
254     for (const string& output_node_name : output_node_names) {
255       if (params.name() == output_node_name) {
256         output_node_ids.emplace_back(node_id);
257       }
258     }
259 
260     dependency_map.emplace(std::piecewise_construct, std::make_tuple(node_id),
261                            std::make_tuple());
262     if (params.input_count() == 0) {
263       continue;
264     }
265     CHECK_EQ(input_map.count(node_id), 1);
266     for (const GraphTransferNodeInput& node_input :
267          input_map.at(node_id)->node_input()) {
268       dependency_map.at(node_id).emplace(node_input.node_id());
269     }
270   }
271 
272   // Create dependency map traversed from output nodes
273   std::unordered_set<int> completed;
274   for (int output_node_id : output_node_ids) {
275     FillDependencyRec(output_node_id, dependency_map, completed);
276   }
277 
278   std::sort(graph_transfer_info_->mutable_node_info()->begin(),
279             graph_transfer_info_->mutable_node_info()->end(),
280             TransferParamsComparator(dependency_map));
281 }
282 
EnableStrictCheckMode(const bool enable)283 void GraphTransferer::EnableStrictCheckMode(const bool enable) {
284   strict_check_mode_ = enable;
285 }
286 
SetSerializedGraphTransferInfo(const string & serialized_proto)287 void GraphTransferer::SetSerializedGraphTransferInfo(
288     const string& serialized_proto) {
289   graph_transfer_info_->ParseFromString(serialized_proto);
290 }
291 
GetGraphTransferInfo() const292 const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
293   return *graph_transfer_info_;
294 }
295 
GetMutableGraphTransferInfo()296 GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
297   return *graph_transfer_info_;
298 }
299 
CacheNode(const Node & node)300 void GraphTransferer::CacheNode(const Node& node) {
301   if (node_name_to_id_cache_map_.count(node.name()) > 0) {
302     return;
303   }
304   node_name_cache_list_.emplace_back(&node);
305   const int node_id = node_name_cache_list_.size() - 1;
306   bool emplace_succeeded = false;
307   std::tie(std::ignore, emplace_succeeded) =
308       node_name_to_id_cache_map_.emplace(node.name(), node_id);
309   CHECK(emplace_succeeded);
310 }
311 
AreAllInputsCached(const Node & node) const312 bool GraphTransferer::AreAllInputsCached(const Node& node) const {
313   for (const Node* const input_node : node.in_nodes()) {
314     if (node_name_to_id_cache_map_.count(input_node->name()) <= 0) {
315       VLOG(1) << "input_node " << input_node->name() << " of " << node.name()
316               << " is not cached yet.";
317       return false;
318     }
319   }
320   return true;
321 }
322 
TransformGraphToAddAggregatedInputNode(const std::vector<std::pair<string,Tensor>> & input_node_info_list,Graph * graph,ShapeRefiner * shape_refiner)323 Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
324     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
325     Graph* graph, ShapeRefiner* shape_refiner) {
326   // Transform a remote fused graph to add an aggregated input node which takes
327   // all inputs of the remote graph.
328   DataTypeVector input_data_types;
329   std::vector<DataType> data_types;
330   std::vector<TensorShape> shapes;
331   std::vector<string> input_nodes;
332   for (int i = 0; i < input_node_info_list.size(); ++i) {
333     Node* node = FindMutableNodeByName(input_node_info_list.at(i).first, graph);
334     CHECK_NOTNULL(node);
335     input_nodes.emplace_back(node->name());
336     input_data_types.emplace_back(input_node_info_list.at(i).second.dtype());
337     data_types.emplace_back(input_node_info_list.at(i).second.dtype());
338     shapes.emplace_back(input_node_info_list.at(i).second.shape());
339   }
340 
341   auto builder =
342       NodeBuilder(AGGREGATED_INPUT_NODE_NAME, "RemoteFusedGraphExecute")
343           .Input(std::vector<NodeBuilder::NodeOut>{})
344           .Attr("Tinputs", DataTypeVector{})
345           .Attr("Toutputs", input_data_types)
346           .Attr("serialized_remote_fused_graph_execute_info", "")
347           .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES,
348                 data_types)
349           .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, shapes);
350 
351   Node* input_node;
352   TF_RETURN_IF_ERROR(builder.Finalize(graph, &input_node));
353   CHECK_NOTNULL(input_node);
354 
355   bool refined;
356   TF_RETURN_IF_ERROR(
357       shape_refiner->UpdateNode(input_node, false /* relax */, &refined));
358 
359   shape_inference::InferenceContext* context =
360       shape_refiner->GetContext(input_node);
361   for (int i = 0; i < input_node_info_list.size(); ++i) {
362     shape_inference::ShapeHandle handle;
363     TF_RETURN_IF_ERROR(context->MakeShapeFromTensorShape(
364         input_node_info_list.at(i).second.shape(), &handle));
365     TF_RETURN_IF_ERROR(shape_refiner->SetShape(input_node, i, handle));
366   }
367 
368   // Cache the aggregate input node first as it's consumed first.
369   CacheNode(*input_node);
370 
371   std::vector<Node*> original_input_nodes(input_nodes.size());
372 
373   for (int i = 0; i < input_nodes.size(); ++i) {
374     const string& node_name = input_nodes.at(i);
375     Node* original_input_node = FindMutableNodeByName(node_name, graph);
376     CHECK_NOTNULL(original_input_node);
377     CHECK_EQ(1, original_input_node->num_outputs());  // replaced by identity.
378     Node* created_node;
379     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
380         node_name, AGGREGATED_INPUT_NODE_NAME, i, data_types.at(i), graph,
381         &created_node));
382     CHECK_NOTNULL(created_node);
383     std::vector<DataType> data_types;
384     std::vector<TensorShape> shapes;
385     Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
386         original_input_node->attrs(), &data_types, &shapes);
387     if (status.ok()) {
388       created_node->AddAttr(
389           RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, data_types);
390       created_node->AddAttr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES,
391                             shapes);
392     }
393     for (const Edge* out_edge : original_input_node->out_edges()) {
394       Node* dst = out_edge->dst();
395       int dst_port = out_edge->dst_input();
396       // Unused edge will be removed when removing node.
397       graph->AddEdge(created_node, 0, dst, dst_port);
398     }
399     original_input_nodes[i] = original_input_node;
400 
401     TF_RETURN_IF_ERROR(
402         shape_refiner->UpdateNode(created_node, false /* relax */, &refined));
403 
404     shape_inference::InferenceContext* context =
405         shape_refiner->GetContext(created_node);
406     CHECK_NOTNULL(context);
407 
408     // Cache replaced input node next to the aggregated input node.
409     CacheNode(*created_node);
410   }
411 
412   // Remove original input nodes after adding new input nodes to avoid
413   // reusing same pointer in Graph.
414   for (Node* original_input_node : original_input_nodes) {
415     graph->RemoveNode(original_input_node);
416   }
417 
418   return Status::OK();
419 }
420 
RegisterNode(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names)421 Status GraphTransferer::RegisterNode(
422     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
423     const ShapeRefiner& shape_refiner, const Node& node,
424     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
425     const std::vector<string>& output_node_names) {
426   VLOG(1) << "Register node: " << node.name() << ", " << std::hex
427           << node_name_to_id_cache_map_.at(node.name());
428   if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) {
429     // Just ignore sink and source
430     return Status::OK();
431   } else if (node.name() == AGGREGATED_INPUT_NODE_NAME) {
432     RegisterInputNode(ops_definitions, shape_refiner, node);
433     return Status::OK();
434   } else if (node.IsConstant()) {
435     RegisterConstantNode(shape_refiner, node);
436   } else if (IsPadNode(node)) {
437     RegisterPadNode(ops_definitions, shape_refiner, node);
438   } else if (HasPaddingAndStrides(node)) {
439     RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, node);
440   } else if (NeedsToAddRank(node)) {
441     RegisterNodeWithRank(ops_definitions, shape_refiner, node);
442   } else if (IsNodeFlattenReshape(node, shape_refiner)) {
443     RegisterFlattenNode(ops_definitions, shape_refiner, node);
444   } else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
445              IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
446     // TODO(satok): Set correct data type if it's given.
447     RegisterGenericNode(ops_definitions, shape_refiner, node);
448   } else {
449     return errors::InvalidArgument(node.type_string() +
450                                    " has not been implemented yet.");
451   }
452 
453   return Status::OK();
454 }
455 
RegisterConstantNode(const ShapeRefiner & shape_refiner,const Node & node)456 void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
457                                            const Node& node) {
458   VLOG(1) << "Register constant node: " << node.name();
459   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
460   const int id = node_name_to_id_cache_map_[node.name()];
461   const int output_node_size = node.num_outputs();
462   CHECK_EQ(output_node_size, 1);
463   // TODO(satok): support multiple outputs?
464   const int output_index = 0;
465   const DataType dt = node.output_type(output_index);
466   const size_t max_bytes_per_data = DataTypeSize(dt);
467   CHECK_GT(max_bytes_per_data, 0)
468       << "dt = " << dt << ", " + DataTypeString(dt) << ", "
469       << max_bytes_per_data << ", " << static_cast<int>(DataTypeSize(dt))
470       << ",,,,,,,";
471   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
472   shape_inference::ShapeHandle shape_handle = context->output(output_index);
473   const shape_inference::DimensionHandle num_elements_dim =
474       context->NumElements(shape_handle);
475   std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
476   int data_size;
477   // Shape of constant node must be known
478   CHECK(context->ValueKnown(num_elements_dim));
479   const int64 num_output_elements = context->Value(num_elements_dim);
480   data_size = max_bytes_per_data * num_output_elements;
481   shape_array = BuildShapeArray(shape_handle, context);
482 
483   GraphTransferConstNodeInfo& const_node_info =
484       *graph_transfer_info_->add_const_node_info();
485   const_node_info.set_name(node.name());
486   const_node_info.set_node_id(id);
487   // TODO(satok): Make this generic. Never assume rank is 4.
488   CHECK_EQ(4, SHAPE_ARRAY_SIZE);
489   const_node_info.add_shape(shape_array[0]);
490   const_node_info.add_shape(shape_array[1]);
491   const_node_info.add_shape(shape_array[2]);
492   const_node_info.add_shape(shape_array[3]);
493   const TensorProto* proto = nullptr;
494   TF_CHECK_OK(GetNodeAttr(node.attrs(), "value", &proto));
495   Tensor const_tensor;
496   TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));
497 
498   const_node_info.set_dtype(const_tensor.dtype());
499   if (data_size > 0) {
500     const_node_info.set_data(const_tensor.tensor_data().data(), data_size);
501   }
502 }
503 
RegisterConstantShape(const std::vector<int> & shape)504 int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
505   VLOG(1) << "Cache constant shape.";
506   // TODO(satok): Handle non-4dim strides
507   CHECK_EQ(shape.size(), 4);
508   const string shape_name = CONST_SHAPE_PREFIX + ToString(shape.at(0)) + 'x' +
509                             ToString(shape.at(1)) + 'x' +
510                             ToString(shape.at(2)) + 'x' + ToString(shape.at(3));
511   if (node_name_to_id_cache_map_.count(shape_name) <= 0) {
512     node_name_cache_list_.emplace_back(nullptr);
513     const int id = node_name_cache_list_.size() - 1;
514     node_name_to_id_cache_map_.emplace(shape_name, id);
515     GraphTransferConstNodeInfo& const_node_info =
516         *graph_transfer_info_->add_const_node_info();
517     const_node_info.set_name(shape_name);
518     const_node_info.set_node_id(id);
519     // TODO(satok): Make this generic. Never assume rank is 5.
520     const_node_info.add_shape(static_cast<int64>(shape[0]));
521     const_node_info.add_shape(static_cast<int64>(shape[1]));
522     const_node_info.add_shape(static_cast<int64>(shape[2]));
523     const_node_info.add_shape(static_cast<int64>(shape[3]));
524   }
525   return node_name_to_id_cache_map_[shape_name];
526 }
527 
RegisterConstTensor(const Tensor & tensor,const string & suffix)528 int GraphTransferer::RegisterConstTensor(const Tensor& tensor,
529                                          const string& suffix) {
530   VLOG(1) << "Cache const tensor.";
531   const int dims = tensor.shape().dims();
532   CHECK(dims <= 4);
533   const string node_name = strings::StrCat(CONST_TENSOR_PREFIX, "_", suffix);
534   if (node_name_to_id_cache_map_.count(node_name) <= 0) {
535     node_name_cache_list_.emplace_back(nullptr);
536     const int id = node_name_cache_list_.size() - 1;
537     node_name_to_id_cache_map_.emplace(node_name, id);
538     GraphTransferConstNodeInfo& const_node_info =
539         *graph_transfer_info_->add_const_node_info();
540     const_node_info.set_name(node_name);
541     const_node_info.set_node_id(id);
542     CHECK_EQ(4, SHAPE_ARRAY_SIZE);
543     for (int i = 0; i < SHAPE_ARRAY_SIZE; ++i) {
544       if (i < SHAPE_ARRAY_SIZE - dims) {
545         const_node_info.add_shape(1);
546       } else {
547         const_node_info.add_shape(
548             tensor.shape().dim_size(i - (SHAPE_ARRAY_SIZE - dims)));
549       }
550     }
551     const_node_info.set_dtype(tensor.dtype());
552     const_node_info.set_data(tensor.tensor_data().data(),
553                              tensor.tensor_data().size());
554   }
555   return node_name_to_id_cache_map_[node_name];
556 }
557 
RegisterConstScalar(const DataType dt,const int val,const int dst_id,const int dst_input_count)558 int GraphTransferer::RegisterConstScalar(const DataType dt, const int val,
559                                          const int dst_id,
560                                          const int dst_input_count) {
561   VLOG(1) << "Cache const.";
562   const string val_name =
563       CONST_VAL_PREFIX + ToString(dst_id) + '_' + ToString(dst_input_count);
564   if (node_name_to_id_cache_map_.count(val_name) <= 0) {
565     node_name_cache_list_.emplace_back(nullptr);
566     const int id = node_name_cache_list_.size() - 1;
567     node_name_to_id_cache_map_.emplace(val_name, id);
568     GraphTransferConstNodeInfo& const_node_info =
569         *graph_transfer_info_->add_const_node_info();
570     const_node_info.set_name(val_name);
571     const_node_info.set_node_id(id);
572     // TODO(satok): Do not assume rank is 4 here.
573     const_node_info.add_shape(static_cast<int64>(1));
574     const_node_info.add_shape(static_cast<int64>(1));
575     const_node_info.add_shape(static_cast<int64>(1));
576     const_node_info.add_shape(static_cast<int64>(1));
577     const_node_info.set_data(&val, DataTypeSize(dt));
578   }
579   return node_name_to_id_cache_map_[val_name];
580 }
581 
HasPaddingAndStrides(const Node & node)582 bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
583   auto attrs = node.attrs();
584   return attrs.Find(PADDING_ATTR_NAME) != nullptr &&
585          attrs.Find(STRIDES_ATTR_NAME) != nullptr;
586 }
587 
NeedsToAddRank(const Node & node)588 bool GraphTransferer::NeedsToAddRank(const Node& node) {
589   const StringPiece op_type(node.type_string());
590   if (op_type == "Transpose" || op_type == "ExpandDims") {
591     return true;
592   }
593   return false;
594 }
595 
IsPadNode(const Node & node)596 bool GraphTransferer::IsPadNode(const Node& node) {
597   const StringPiece op_type(node.type_string());
598   if (op_type == "Pad") {
599     return true;
600   }
601   return false;
602 }
603 
IsNodeFlattenReshape(const Node & node,const ShapeRefiner & shape_refiner)604 bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
605                                            const ShapeRefiner& shape_refiner) {
606   // Check if node is reshape op
607   if (node.type_string() != RESHAPE_NODE_TYPE_STRING) {
608     return false;
609   }
610 
611   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
612   // Check if output count is valid
613   if (context->num_outputs() != 1) {
614     return false;
615   }
616 
617   shape_inference::ShapeHandle shape_handle = context->output(0);
618   std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
619   const shape_inference::DimensionHandle dim_handle =
620       context->NumElements(shape_handle);
621 
622   // Obtain shape of output of node
623   if (context->ValueKnown(dim_handle)) {
624     shape_array = BuildShapeArray(shape_handle, context);
625   } else {
626     std::vector<TensorShape> shapes;
627     TF_CHECK_OK(RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
628         node.attrs(), nullptr, &shapes));
629 
630     // Number of outputs should be 1 for reshape node.
631     CHECK_EQ(1, shapes.size());
632     shape_array = ToTensorShapeArray(shapes.at(0));
633   }
634 
635   // check if reshape op just does flatten
636   if (shape_array[0] == 1 && shape_array[1] == 1 && shape_array[2] == 1) {
637     return true;
638   } else {
639     return false;
640   }
641 }
642 
RegisterNodeWithPaddingAndStrides(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node)643 void GraphTransferer::RegisterNodeWithPaddingAndStrides(
644     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
645     const ShapeRefiner& shape_refiner, const Node& node) {
646   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
647   const int id = node_name_to_id_cache_map_[node.name()];
648   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
649   CHECK(node.attrs().Find(PADDING_ATTR_NAME));
650   // TODO(satok): Use context->GetAttr(...) instead?
651   Padding padding;
652   TF_CHECK_OK(context->GetAttr(PADDING_ATTR_NAME, &padding));
653   CHECK(node.attrs().Find(STRIDES_ATTR_NAME));
654   std::vector<int32> strides;
655   TF_CHECK_OK(context->GetAttr(STRIDES_ATTR_NAME, &strides));
656   const int stride_id = RegisterConstantShape(strides);
657   std::vector<int> extra_inputs{stride_id};
658   if (node.attrs().Find(KSIZE_ATTR_NAME)) {
659     std::vector<int32> kernel_sizes;
660     TF_CHECK_OK(context->GetAttr(KSIZE_ATTR_NAME, &kernel_sizes));
661     const int ksize_id = RegisterConstantShape(kernel_sizes);
662     extra_inputs.insert(extra_inputs.begin(), ksize_id);
663   }
664   // TODO(satok): Set correct data type if it's given.
665   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
666   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
667       << "Op " << node.type_string() << " not found in map(id = " << op_type_id
668       << ")";
669   // Safety check of padding id
670   CHECK(padding == Padding::VALID ? 1 : 2);
671   AppendNodeParamsWithIoParams(
672       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
673       static_cast<int>(padding), node.num_inputs(), extra_inputs,
674       node.num_outputs(), true /* append_input */, true /* append_output */);
675 }
676 
RegisterNodeWithRank(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node)677 void GraphTransferer::RegisterNodeWithRank(
678     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
679     const ShapeRefiner& shape_refiner, const Node& node) {
680   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
681   const int id = node_name_to_id_cache_map_[node.name()];
682   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
683   const Node* input0_node;
684   TF_CHECK_OK(node.input_node(0, &input0_node));
685   CHECK_NOTNULL(input0_node);
686   std::vector<TensorShape> shapes;
687   Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
688       input0_node->attrs(), nullptr, &shapes);
689   CHECK_EQ(1, shapes.size()) << "Output size should be 1.";
690   const int const_val_id =
691       RegisterConstScalar(DT_INT32, shapes.at(0).dims(), id, node.num_inputs());
692   std::vector<int> extra_inputs{const_val_id};
693   // TODO(satok): Set correct data type if it's given.
694   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
695   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
696       << "Op " << node.type_string() << " not found in map(id = " << op_type_id
697       << ")";
698   bool keep_dims = false;
699   int padding_id = PADDING_NA_ID;
700   if (context->GetAttr(KEEP_DIMS_ATTR_NAME, &keep_dims).ok()) {
701     padding_id = keep_dims ? Padding::SAME : Padding::VALID;
702   }
703 
704   AppendNodeParamsWithIoParams(
705       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
706       padding_id, node.num_inputs(), extra_inputs, node.num_outputs(),
707       true /* append_input */, true /* append_output */);
708 }
709 
RegisterPadNode(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node)710 void GraphTransferer::RegisterPadNode(
711     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
712     const ShapeRefiner& shape_refiner, const Node& node) {
713   static constexpr int PAD_WIDTH = 4;
714   static constexpr int PAD_HEIGHT = 2;
715   VLOG(1) << "Register generic node: " << node.name();
716   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
717   const int id = node_name_to_id_cache_map_[node.name()];
718 
719   // TODO(satok): Set correct data type if it's given.
720   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
721   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
722 
723   CHECK_EQ(2, node.num_inputs());
724 
725   GraphTransferNodeInputInfo& node_input_info =
726       *graph_transfer_info_->add_node_input_info();
727   node_input_info.set_node_id(id);
728 
729   AddNodeInputByInputIndex(node, 0, &node_input_info);
730 
731   const Edge* edge = nullptr;
732   TF_CHECK_OK(node.input_edge(1, &edge));
733   const Node* input_node = edge->src();
734   CHECK_NOTNULL(input_node);
735   CHECK(input_node->IsConstant());
736 
737   const TensorProto* tensor_proto = nullptr;
738   TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &tensor_proto));
739   CHECK_NOTNULL(tensor_proto);
740   Tensor const_tensor;
741   TF_CHECK_OK(MakeTensorFromProto(*tensor_proto, &const_tensor));
742   CHECK_EQ(2, const_tensor.shape().dims());
743   CHECK_EQ(PAD_HEIGHT, const_tensor.shape().dim_size(1));
744   if (const_tensor.shape().dim_size(0) == PAD_WIDTH) {
745     AddNodeInputByInputIndex(node, 1, &node_input_info);
746   } else if (const_tensor.shape().dim_size(0) < PAD_WIDTH) {
747     const int width = const_tensor.shape().dim_size(0);
748     const TensorProto* proto = nullptr;
749     TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &proto));
750     Tensor const_tensor;
751     TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));
752     CHECK_EQ(DT_INT32, const_tensor.dtype());
753     // reshape tensor input to be rank 4.
754     // TODO(satok): Never assume rank is 4.
755     Tensor new_const_tensor(const_tensor.dtype(), TensorShape{4, 2});
756     for (int i = 0; i < PAD_HEIGHT; ++i) {
757       for (int j = 0; j < PAD_WIDTH; ++j) {
758         if (j < PAD_WIDTH - width) {
759           new_const_tensor.matrix<int32>()(j, i) = 0;
760         } else {
761           new_const_tensor.matrix<int32>()(j, i) =
762               const_tensor.matrix<int32>()(j - (PAD_WIDTH - width), i);
763         }
764       }
765     }
766 
767     const int id = RegisterConstTensor(
768         new_const_tensor,
769         strings::StrCat(input_node->name(), "_", node.name(), "_1"));
770 
771     GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
772     node_input.set_node_id(id);
773     node_input.set_output_port(0);
774   } else {
775     LOG(FATAL);
776   }
777 
778   AppendNodeParamsWithIoParams(
779       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
780       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
781       false /* append_input */, true /* append_output */);
782 }
783 
RegisterInputNode(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node)784 void GraphTransferer::RegisterInputNode(
785     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
786     const ShapeRefiner& shape_refiner, const Node& node) {
787   const string op_type = node.type_string();
788   VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
789   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
790   const int id = node_name_to_id_cache_map_[node.name()];
791   // TODO(satok): Set correct data type if it's given.
792   const int op_type_id = ops_definitions.GetOpIdFor("INPUT", {});
793   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
794       << "Op" << node.name() << ", " << op_type << " is not supported,"
795       << op_type_id;
796   AppendNodeParamsWithIoParams(
797       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
798       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
799       true /* append_input */, true /* append_output */);
800 }
801 
RegisterFlattenNode(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node)802 void GraphTransferer::RegisterFlattenNode(
803     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
804     const ShapeRefiner& shape_refiner, const Node& node) {
805   VLOG(1) << "Register flatten node: " << node.name();
806   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
807   const int id = node_name_to_id_cache_map_[node.name()];
808   // TODO(satok): Remove dependency to specific type
809   const string op_type = "FLATTEN";
810   // TODO(satok): Set correct data type if it's given.
811   const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
812   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
813 
814   AppendNodeParamsWithIoParams(
815       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
816       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
817       true /* append_input */, true /* append_output */);
818 }
819 
RegisterGenericNode(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node)820 void GraphTransferer::RegisterGenericNode(
821     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
822     const ShapeRefiner& shape_refiner, const Node& node) {
823   VLOG(1) << "Register generic node: " << node.name();
824   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
825   const int id = node_name_to_id_cache_map_[node.name()];
826   // TODO(satok): Set correct data type if it's given.
827   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
828   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
829 
830   AppendNodeParamsWithIoParams(
831       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
832       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
833       true /* append_input */, true /* append_output */);
834 }
835 
836 // TODO(satok): Remove this function.
837 // TODO(satok): Remove only_register_const_node.
RegisterNodeIfAllInputsAreCached(const IRemoteFusedGraphOpsDefinitions & ops_definitions,const ShapeRefiner & shape_refiner,const Node & node,const bool only_register_const_node,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names)838 Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
839     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
840     const ShapeRefiner& shape_refiner, const Node& node,
841     const bool only_register_const_node,
842     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
843     const std::vector<string>& output_node_names) {
844   if (only_register_const_node && !node.IsConstant()) {
845     return Status();
846   }
847   CHECK(AreAllInputsCached(node));
848   return RegisterNode(ops_definitions, shape_refiner, node,
849                       input_node_info_list, output_node_names);
850 }
851 
852 // CAVEAT: Append inputs and outputs params accordingly
AppendNodeParams(const string & name,const int id,const string & type,const int type_id,const int padding,const int inputs_size,const std::vector<int> & extra_inputs,const int outputs_size)853 void GraphTransferer::AppendNodeParams(const string& name, const int id,
854                                        const string& type, const int type_id,
855                                        const int padding, const int inputs_size,
856                                        const std::vector<int>& extra_inputs,
857                                        const int outputs_size) {
858   GraphTransferNodeInfo& node_info = *graph_transfer_info_->add_node_info();
859   node_info.set_name(name);
860   node_info.set_node_id(id);
861   node_info.set_type_name(type);
862   node_info.set_soc_op_id(type_id);
863   node_info.set_padding_id(padding);
864   node_info.set_input_count(inputs_size +
865                             static_cast<int>(extra_inputs.size()));
866   node_info.set_output_count(static_cast<int>(outputs_size));
867 }
868 
AddNodeInputByInputIndex(const Node & node,const int idx,GraphTransferNodeInputInfo * node_input_info)869 void GraphTransferer::AddNodeInputByInputIndex(
870     const Node& node, const int idx,
871     GraphTransferNodeInputInfo* node_input_info) {
872   const Edge* edge = nullptr;
873   TF_CHECK_OK(node.input_edge(idx, &edge));
874   const Node* input_node = edge->src();
875   CHECK_NOTNULL(input_node);
876   const int port = edge->src_output();
877 
878   const std::string& op_name = input_node->name();
879   CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
880   const int src_id = node_name_to_id_cache_map_[op_name];
881   GraphTransferNodeInput& node_input = *node_input_info->add_node_input();
882   node_input.set_node_id(src_id);
883   node_input.set_output_port(port);
884 }
885 
AppendNodeInputParams(const int id,const Node & node,const std::vector<int> & extra_inputs)886 void GraphTransferer::AppendNodeInputParams(
887     const int id, const Node& node, const std::vector<int>& extra_inputs) {
888   VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs()
889           << ", " << extra_inputs.size();
890   GraphTransferNodeInputInfo& node_input_info =
891       *graph_transfer_info_->add_node_input_info();
892   node_input_info.set_node_id(id);
893   for (int i = 0; i < node.num_inputs(); ++i) {
894     AddNodeInputByInputIndex(node, i, &node_input_info);
895   }
896   for (const int extra_input : extra_inputs) {
897     GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
898     node_input.set_node_id(extra_input);
899     node_input.set_output_port(0);
900   }
901 }
902 
AppendNodeOutputParams(const ShapeRefiner & shape_refiner,const int id,const Node & node)903 void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
904                                              const int id, const Node& node) {
905   VLOG(1) << "Append output params: " << node.name() << ", "
906           << node.num_outputs();
907   GraphTransferNodeOutputInfo& node_output_info =
908       *graph_transfer_info_->add_node_output_info();
909   node_output_info.set_node_id(id);
910 
911   std::vector<DataType> data_types;
912   std::vector<TensorShape> shapes;
913   Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
914       node.attrs(), &data_types, &shapes);
915 
916   for (int i = 0; i < node.num_outputs(); ++i) {
917     int data_size = -1;
918     const int output_index = i;
919     const DataType dt = node.output_type(output_index);
920     const size_t max_bytes_per_data = DataTypeSize(dt);
921 
922     shape_inference::InferenceContext* context =
923         shape_refiner.GetContext(&node);
924 
925     if (context != nullptr && context->ValueKnown(context->NumElements(
926                                   context->output(output_index)))) {
927       const shape_inference::DimensionHandle num_elements_dim =
928           context->NumElements(context->output(output_index));
929       const int64 num_output_elements = context->Value(num_elements_dim);
930       data_size = max_bytes_per_data * num_output_elements;
931       if (status.ok()) {
932         TF_CHECK_OK(status);
933         CHECK_EQ(shapes.at(i).num_elements(), num_output_elements);
934       }
935     } else {
936       TF_CHECK_OK(status);
937       // Use attribute attached to node
938       data_size = max_bytes_per_data * shapes.at(i).num_elements();
939     }
940     CHECK_GE(data_size, 0);
941     node_output_info.add_max_byte_size(data_size);
942   }
943 }
944 
AppendNodeParamsWithIoParams(const ShapeRefiner & shape_refiner,const Node & node,const string & name,const int id,const string & type,const int type_id,const int padding,const int inputs_size,const std::vector<int> & extra_inputs,const int outputs_size,const bool append_input_params,const bool append_output_params)945 void GraphTransferer::AppendNodeParamsWithIoParams(
946     const ShapeRefiner& shape_refiner, const Node& node, const string& name,
947     const int id, const string& type, const int type_id, const int padding,
948     const int inputs_size, const std::vector<int>& extra_inputs,
949     const int outputs_size, const bool append_input_params,
950     const bool append_output_params) {
951   VLOG(1) << "Append node with io params: " << node.name();
952   if (append_input_params) {
953     AppendNodeInputParams(id, node, extra_inputs);
954   }
955   if (append_output_params) {
956     AppendNodeOutputParams(shape_refiner, id, node);
957   }
958   AppendNodeParams(name, id, type, type_id, padding, inputs_size, extra_inputs,
959                    outputs_size);
960 }
961 
962 /* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
BuildShapeArray(const shape_inference::ShapeHandle & shape_handle,shape_inference::InferenceContext * context)963 GraphTransferer::BuildShapeArray(
964     const shape_inference::ShapeHandle& shape_handle,
965     shape_inference::InferenceContext* context) {
966   switch (context->Rank(shape_handle)) {
967     case 0:
968       return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
969     case 1:
970       return std::array<int64, SHAPE_ARRAY_SIZE>{
971           {1, 1, 1, context->Value(context->Dim(shape_handle, 0))}};
972     case 2:
973       return std::array<int64, SHAPE_ARRAY_SIZE>{
974           {1, 1, context->Value(context->Dim(shape_handle, 0)),
975            context->Value(context->Dim(shape_handle, 1))}};
976     case 3:
977       return std::array<int64, SHAPE_ARRAY_SIZE>{
978           {1, context->Value(context->Dim(shape_handle, 0)),
979            context->Value(context->Dim(shape_handle, 1)),
980            context->Value(context->Dim(shape_handle, 2))}};
981     case 4:
982       return std::array<int64, SHAPE_ARRAY_SIZE>{
983           {context->Value(context->Dim(shape_handle, 0)),
984            context->Value(context->Dim(shape_handle, 1)),
985            context->Value(context->Dim(shape_handle, 2)),
986            context->Value(context->Dim(shape_handle, 3))}};
987     default:
988       // TODO(satok): Support more ranks?
989       LOG(FATAL);
990       return std::array<int64, SHAPE_ARRAY_SIZE>();
991   }
992 }
993 
994 /* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
ToTensorShapeArray(const TensorShape & shape)995 GraphTransferer::ToTensorShapeArray(const TensorShape& shape) {
996   switch (shape.dims()) {
997     case 0:
998       return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
999     case 1:
1000       return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, shape.dim_size(0)}};
1001     case 2:
1002       return std::array<int64, SHAPE_ARRAY_SIZE>{
1003           {1, 1, shape.dim_size(0), shape.dim_size(1)}};
1004     case 3:
1005       return std::array<int64, SHAPE_ARRAY_SIZE>{
1006           {1, shape.dim_size(0), shape.dim_size(1), shape.dim_size(2)}};
1007     case 4:
1008       return std::array<int64, SHAPE_ARRAY_SIZE>{
1009           {shape.dim_size(0), shape.dim_size(1), shape.dim_size(2),
1010            shape.dim_size(3)}};
1011     default:
1012       // TODO(satok): Support more ranks?
1013       LOG(FATAL);
1014       return std::array<int64, SHAPE_ARRAY_SIZE>();
1015   }
1016 }
1017 
ToPaddingDebugString(const int padding)1018 /* static */ string GraphTransferer::ToPaddingDebugString(const int padding) {
1019   switch (padding) {
1020     case 0:
1021       return "NN_PAD_NA";
1022     case Padding::VALID:
1023       return "NN_PAD_VALID";
1024     case Padding::SAME:
1025       return "NN_PAD_SAME";
1026     default:
1027       LOG(FATAL);
1028       return "";
1029   }
1030 }
1031 
TransferParamsComparator(const std::unordered_map<int,std::unordered_set<int>> & dep_map)1032 GraphTransferer::TransferParamsComparator::TransferParamsComparator(
1033     const std::unordered_map<int, std::unordered_set<int>>& dep_map)
1034     : dependency_map_(dep_map) {}
1035 
operator ()(const GraphTransferNodeInfo & obj0,const GraphTransferNodeInfo & obj1)1036 bool GraphTransferer::TransferParamsComparator::operator()(
1037     const GraphTransferNodeInfo& obj0, const GraphTransferNodeInfo& obj1) {
1038   const int node_id0 = obj0.node_id();
1039   const int node_id1 = obj1.node_id();
1040   bool obj0_uses_obj1 = false;
1041   if (dependency_map_.count(node_id0) > 0) {
1042     obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0;
1043   }
1044   bool obj1_uses_obj0 = false;
1045   if (dependency_map_.count(node_id1) > 0) {
1046     obj1_uses_obj0 = dependency_map_.at(node_id1).count(node_id0) > 0;
1047   }
1048   CHECK(!obj0_uses_obj1 || !obj1_uses_obj0);
1049   if (obj0_uses_obj1) {
1050     return false;
1051   } else if (obj1_uses_obj0) {
1052     return true;
1053   }
1054   // If there is no dependency between two nodes, it expects that
1055   // the execution order follows node id order.
1056   return node_id0 < node_id1;
1057 }
1058 
FillDependencyRec(const int node_id,std::unordered_map<int,std::unordered_set<int>> & dep_map,std::unordered_set<int> & completed)1059 /* static */ void GraphTransferer::FillDependencyRec(
1060     const int node_id,
1061     std::unordered_map<int, std::unordered_set<int>>& dep_map,
1062     std::unordered_set<int>& completed) {
1063   if (dep_map.count(node_id) == 0 || dep_map.at(node_id).empty() ||
1064       completed.count(node_id) == 1) {
1065     return;
1066   }
1067   CHECK_EQ(dep_map.count(node_id), 1);
1068 
1069   // Complete children's dependency map
1070   for (int child_node_id : dep_map.at(node_id)) {
1071     CHECK(child_node_id != node_id);
1072     if (completed.count(child_node_id) != 0) {
1073       continue;
1074     }
1075     FillDependencyRec(child_node_id, dep_map, completed);
1076   }
1077 
1078   // Find additional depending ids
1079   std::vector<int> depending_ids;
1080   for (int child_node_id : dep_map.at(node_id)) {
1081     if (dep_map.count(child_node_id) == 0) {
1082       continue;
1083     }
1084     for (int depending_id : dep_map.at(child_node_id)) {
1085       depending_ids.emplace_back(depending_id);
1086     }
1087   }
1088 
1089   // Insert additional depending ids
1090   for (int depending_id : depending_ids) {
1091     if (dep_map.at(node_id).count(depending_id) == 0) {
1092       dep_map.at(node_id).emplace(depending_id);
1093     }
1094   }
1095 
1096   // DP: Record completed node id
1097   completed.emplace(node_id);
1098 }
1099 
MakeTensorFromProto(const TensorProto & tensor_proto,Tensor * tensor)1100 /* static */ Status GraphTransferer::MakeTensorFromProto(
1101     const TensorProto& tensor_proto, Tensor* tensor) {
1102   if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
1103     Tensor parsed(tensor_proto.dtype());
1104     if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
1105       *tensor = parsed;
1106       return Status::OK();
1107     }
1108   }
1109   return errors::InvalidArgument("Cannot parse tensor from proto: ",
1110                                  tensor_proto.DebugString());
1111 }
1112 
ClearCache()1113 void GraphTransferer::ClearCache() {
1114   node_name_cache_list_.clear();
1115   node_name_to_id_cache_map_.clear();
1116 }
1117 
DumpNodeTransferParams() const1118 void GraphTransferer::DumpNodeTransferParams() const {
1119   LOG(INFO) << "*** Const Nodes ***";
1120   for (const GraphTransferConstNodeInfo& params :
1121        graph_transfer_info_->const_node_info()) {
1122     // TODO(satok): Stop assuming shape size is 4.
1123     CHECK_EQ(params.shape_size(), 4);
1124     LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
1125               << "\" (Const)";
1126     LOG(INFO) << "  shape: " << params.shape(0) << params.shape(1)
1127               << params.shape(2) << params.shape(3);
1128     LOG(INFO) << "  data_name: "
1129               << (params.data().length() <= 0
1130                       ? ""
1131                       : DATA_NODE_PREFIX + ToString(params.node_id()));
1132     LOG(INFO) << "  data_size: " << params.data().length() << " bytes"
1133               << " ]";
1134   }
1135   LOG(INFO) << "******\n";
1136   LOG(INFO) << "*** Op Nodes ***";
1137   for (const GraphTransferNodeInfo& params :
1138        graph_transfer_info_->node_info()) {
1139     LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
1140     LOG(INFO) << "  type: " << params.type_name();
1141     LOG(INFO) << "  padding: " << ToPaddingDebugString(params.padding_id());
1142     LOG(INFO) << "  inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id())
1143               << ", size = " << params.input_count();
1144     LOG(INFO) << "  outputs: "
1145               << (params.output_count() <= 0
1146                       ? NULL_OUTPUT_NAME
1147                       : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
1148               << ", size = " << params.output_count() << " ]";
1149   }
1150   LOG(INFO) << "******\n";
1151   LOG(INFO) << "*** Node input params ***";
1152   for (const GraphTransferNodeInputInfo& params :
1153        graph_transfer_info_->node_input_info()) {
1154     LOG(INFO) << "[ " << params.node_id() << " ]";
1155     for (const GraphTransferNodeInput& node_input : params.node_input()) {
1156       LOG(INFO) << "    src node id = " << node_input.node_id()
1157                 << ", output port = " << node_input.output_port();
1158     }
1159   }
1160   LOG(INFO) << "******\n";
1161   LOG(INFO) << "*** Node output params ***";
1162   for (const GraphTransferNodeOutputInfo& params :
1163        graph_transfer_info_->node_output_info()) {
1164     LOG(INFO) << "[ " << params.node_id() << " ]";
1165     for (const int max_size : params.max_byte_size()) {
1166       LOG(INFO) << "    max_size = " << max_size;
1167     }
1168   }
1169   LOG(INFO) << "******\n";
1170 }
1171 
DumpVerificationStringOfNodeTransferParams() const1172 void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
1173   for (const GraphTransferConstNodeInfo& params :
1174        graph_transfer_info_->const_node_info()) {
1175     std::stringstream sstream;
1176     // TODO(satok): Stop assuming shape size is 4.
1177     CHECK_EQ(params.shape_size(), 4);
1178     sstream << "---(CONST) [" << std::hex << params.node_id() << std::dec << ","
1179             << params.shape(0) << "," << params.shape(1) << ","
1180             << params.shape(2) << "," << params.shape(3) << ","
1181             << (params.data().length() <= 0
1182                     ? ""
1183                     : DATA_NODE_PREFIX + ToString(params.node_id()))
1184             << "," << params.data().length() << "," << params.name() << "]";
1185     LOG(INFO) << sstream.str();
1186   }
1187   LOG(INFO) << "Const node count = "
1188             << graph_transfer_info_->const_node_info_size();
1189   for (const GraphTransferNodeInfo& params :
1190        graph_transfer_info_->node_info()) {
1191     std::stringstream sstream;
1192     sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
1193             << params.node_id() << std::dec << "," << params.soc_op_id() << ","
1194             << ToPaddingDebugString(params.padding_id()) << ","
1195             << INPUTS_NODE_PREFIX + ToString(params.node_id()) << ","
1196             << params.input_count() << ","
1197             << (params.output_count() <= 0
1198                     ? NULL_OUTPUT_NAME
1199                     : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
1200             << "," << params.output_count() << "," << params.type_name() << "]";
1201     LOG(INFO) << sstream.str();
1202   }
1203   LOG(INFO) << "Op node count = " << graph_transfer_info_->node_info_size();
1204   for (const GraphTransferNodeInputInfo& params :
1205        graph_transfer_info_->node_input_info()) {
1206     std::stringstream sstream;
1207     sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
1208     for (const GraphTransferNodeInput& node_input : params.node_input()) {
1209       sstream << "," << std::hex << node_input.node_id() << std::dec << ","
1210               << node_input.output_port();
1211     }
1212     sstream << "]";
1213     LOG(INFO) << sstream.str();
1214   }
1215   LOG(INFO) << "Input params count = "
1216             << graph_transfer_info_->node_input_info_size();
1217   for (const GraphTransferNodeOutputInfo& params :
1218        graph_transfer_info_->node_output_info()) {
1219     std::stringstream sstream;
1220     sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
1221     for (const int max_size : params.max_byte_size()) {
1222       sstream << "," << max_size;
1223     }
1224     sstream << "]";
1225     LOG(INFO) << sstream.str();
1226   }
1227   LOG(INFO) << "Output params count = "
1228             << graph_transfer_info_->node_output_info_size();
1229 }
1230 
1231 }  // namespace tensorflow
1232