/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"

#include <algorithm>
#include <deque>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {
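// Returns the node named "name" in "graph", or nullptr if no such node
// exists.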
const Node* FindNodeByName(const string& name, const Graph& graph) {
  for (const Node* node : graph.nodes()) {
    CHECK_NOTNULL(node);
    if (node->name() == name) {
      return node;
    }
  }
  return nullptr;
}

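// Collects the node-name part of each "name:port" string into a set,
// dropping the port.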
std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
    const std::vector<string>& node_names_and_ports) {
  std::unordered_set<string> retval;
  for (const string& node_name_and_port : node_names_and_ports) {
    const TensorId tid = ParseTensorName(node_name_and_port);
    retval.emplace(tid.first);
  }
  return retval;
}

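// Mutable counterpart of FindNodeByName().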
Node* FindMutableNodeByName(const string& name, Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (node != nullptr && node->name() == name) {
      return node;
    }
  }
  return nullptr;
}

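// Returns the NodeDef whose name matches the node-name part of "input"
// (which may be a "name:port" string), or nullptr if not found.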
const NodeDef* FindNodeDefByName(const string& input,
                                 const GraphDef& graph_def) {
  const TensorId tid = ParseTensorName(input);
  const string name = string(tid.first);
  for (const NodeDef& node_def : graph_def.node()) {
    if (node_def.name() == name) {
      return &node_def;
    }
  }
  return nullptr;
}

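// Returns true if "node_def" has the same node name as "node_name_and_port".
// The parsed tensor id is returned via "tid".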
bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
                    TensorId* tid) {
  CHECK_NOTNULL(tid);
  *tid = ParseTensorName(node_name_and_port);
  if (node_def.name() == tid->first) {
    return true;
  }
  return false;
}

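// Returns true if "tensor_names" contains an entry with the same node name
// and port as "tensor_name".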
bool ContainsSameTensorId(const string& tensor_name,
                          const std::vector<string>& tensor_names) {
  const TensorId tid0 = ParseTensorName(tensor_name);
  for (const string& name : tensor_names) {
    const TensorId tid1 = ParseTensorName(name);
    if (tid0.first == tid1.first && tid0.second == tid1.second) {
      return true;
    }
  }
  return false;
}

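// Appends a ":" separator unless "str" is still empty.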
void AppendDeliminator(string* str) {
  CHECK_NOTNULL(str);
  if (!str->empty()) {
    *str += ":";
  }
}

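// Converts a map keyed by dense indices 0..N-1 into a vector of its values,
// preserving index order.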
void ConvertMapToVector(const std::unordered_map<int, string>& in,
                        std::vector<string>* out) {
  CHECK_NOTNULL(out);
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    CHECK(in.count(i) > 0);
    out->at(i) = in.at(i);
  }
}

string DumpGraphDef(const GraphDef& graph_def) {
  string out;
  for (const NodeDef& node : graph_def.node()) {
    out += strings::StrCat("node: ", node.name(), "\n    input: ");
    for (const string& input : node.input()) {
      out += strings::StrCat(input, ", ");
    }
    out += "\n";
  }
  return out;
}

string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
  string out;
  out += "Nodes:\n";
  for (const string& str : std::get<0>(cluster)) {
    out += str + ", ";
  }
  out += "\nInput border:\n";
  for (const string& str : std::get<1>(cluster)) {
    out += str + ", ";
  }
  out += "\nOutput border:\n";
  for (const string& str : std::get<2>(cluster)) {
    out += str + ", ";
  }
  return out;
}

}  // namespace

/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;

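// Registers "executor_build_func" under "name" in the global executor build
// registry.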
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
    const string& name, ExecutorBuildFunc executor_build_func) {
  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
  executor_build_registry[name] = std::move(executor_build_func);
}

/* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
  if (executor_build_registry.count(name) <= 0) {
    return nullptr;
  }
  return &executor_build_registry.at(name);
}

/* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
  static ExecutorBuildRegistry executor_builder_registry;
  return &executor_builder_registry;
}

/**
 * - DryRunInference
 * Dry-runs the graph to determine the shapes of the output tensors of all
 * nodes.  The resulting shapes supply memory-allocation information when the
 * graph is loaded, and are used to verify inferred shapes against actual
 * output shapes.
 */
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names, const bool initialize_by_zero,
    std::vector<tensorflow::Tensor>* output_tensors) {
  // Create the input tensor vector.  If "initialize_by_zero" is true, input
  // tensors are zero-filled.
  std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
  for (const std::pair<string, Tensor>& input : input_node_info_list) {
    CHECK(input.second.IsInitialized());
    if (!initialize_by_zero) {
      input_tensors.push_back({input.first, input.second});
      continue;
    }
    // Build a zero-filled tensor with the same type and shape as the input.
    const DataType data_type = input.second.dtype();
    const TensorShape& shape = input.second.shape();
    Tensor input_tensor(data_type, shape);
    switch (data_type) {
      case DT_INT32: {
        auto int_tensor = input_tensor.flat<int32>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      case DT_FLOAT: {
        auto float_tensor = input_tensor.flat<float>();
        float_tensor = float_tensor.constant(0.0f);
        break;
      }
      case DT_QUINT8: {
        auto int_tensor = input_tensor.flat<quint8>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      default:
        LOG(FATAL) << "Unsupported input type: " << data_type;
    }
    input_tensors.push_back({input.first, input_tensor});
  }

  // Set up the session.
  CHECK(output_tensors != nullptr);
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session =
      std::unique_ptr<Session>(NewSession(session_options));
  Status status = session->Create(graph_def);
  if (!status.ok()) {
    return status;
  }

  // Set up the session arguments.
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  // Run inference with all requested nodes as outputs.
  status = session->Run(run_options, input_tensors, output_node_names, {},
                        output_tensors, &run_metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Error during inference: " << status;
    return status;
  }
  return Status();
}

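// Dry-runs the graph with every output port of every non-input node as a
// fetch target, and records the resulting data type and shape of each tensor
// in "tensor_shape_map".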
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const bool initialize_by_zero,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  CHECK(tensor_shape_map != nullptr);
  std::vector<Tensor> output_tensors;
  output_tensors.reserve(graph_def.node_size());
  std::vector<string> output_node_names;

  Graph graph(OpRegistry::Global());
  Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
  if (!status.ok()) {
    return status;
  }

  for (const Node* node : graph.nodes()) {
    if (IsInputNode(input_node_info_list, node->name())) {
      continue;
    }
    for (int i = 0; i < node->num_outputs(); ++i) {
      output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
    }
  }

  status = DryRunInference(graph_def, input_node_info_list, output_node_names,
                           initialize_by_zero, &output_tensors);
  if (!status.ok()) {
    VLOG(1) << "Failed to dryrun " << status;
    return status;
  }

  CHECK_EQ(output_node_names.size(), output_tensors.size())
      << output_node_names.size() << ", " << output_tensors.size();

  // Append the output tensors of the input nodes up front so that the vector
  // does not reallocate while the map below is built.
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    output_tensors.push_back(input_node_info.second);
  }

  for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) {
    const string& name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) {
    const string& name = input_node_info_list.at(i).first;
    const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
           output_tensors.size());
  return status;
}

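// Returns true if "node_name" names one of the input tensors in
// "input_tensor_vector".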
/* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
    const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
    const string& node_name) {
  for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
    const TensorId tid = ParseTensorName(pair.first);
    if (node_name == tid.first) {
      return true;
    }
  }
  return false;
}

/* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names,
    const std::vector<tensorflow::Tensor>& output_tensors,
    TensorShapeMap* tensor_shape_map) {
  CHECK_NE(tensor_shape_map, nullptr);
  tensor_shape_map->clear();
  tensor_shape_map->reserve(input_node_info_list.size() +
                            output_node_names.size());
  const int output_node_count = output_node_names.size();
  CHECK_EQ(output_node_count, output_tensors.size());
  for (int i = 0; i < output_node_count; ++i) {
    const string& node_name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(node_name, tensor, tensor_shape_map);
  }
}

/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
    const TensorProto& tensor_proto, Tensor* tensor) {
  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
    Tensor parsed(tensor_proto.dtype());
    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
      *tensor = parsed;
      return Status::OK();
    }
  }
  return errors::InvalidArgument("Cannot parse tensor from proto");
}

/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
    const std::vector<DataType>& data_types,
    const std::vector<TensorShape>& shapes, NodeDef* node_def) {
  AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
  AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
  return true;
}

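// Collects all (port, shape-type) entries registered for this node in
// "tensor_shape_map" and attaches them to "node_def" in ascending port order.
// The priority queue pops the largest port first, so prepending each entry
// yields ascending order; the CHECK below verifies the ports are contiguous.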
/* static */ Status
RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
    const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
  CHECK_NE(node_def, nullptr);
  std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
  auto its = tensor_shape_map.equal_range(node_def->name());
  for (auto it = its.first; it != its.second; ++it) {
    queue.emplace(std::make_tuple(it->second.first, &it->second.second));
  }
  int last_port = queue.size();
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  while (!queue.empty()) {
    const int port = std::get<0>(queue.top());
    const TensorShapeType* tst = std::get<1>(queue.top());
    CHECK_NE(tst, nullptr);
    data_types.emplace(data_types.begin(), tst->first);
    shapes.emplace(shapes.begin(), tst->second);
    CHECK_EQ(last_port - 1, port);
    last_port = port;
    queue.pop();
  }
  AddOutputTensorShapeType(data_types, shapes, node_def);
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    AttrSlice attrs, std::vector<DataType>* data_types,
    std::vector<TensorShape>* shapes) {
  Status status;
  if (data_types != nullptr) {
    status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
  }
  if (!status.ok()) {
    return status;
  }
  if (shapes != nullptr) {
    status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
    if (status.ok() && data_types != nullptr) {
      CHECK_EQ(data_types->size(), shapes->size());
    }
  }

  return status;
}

/* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
    TensorShape* shape) {
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  const TensorId tid = ParseTensorName(name_and_port);
  const string node_name(tid.first);
  const int port = tid.second;
  const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
  CHECK_NOTNULL(node_def);
  GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
  if (data_types.empty()) {
    return false;
  }
  CHECK(data_types.size() > port);
  *data_type = data_types.at(port);
  *shape = shapes.at(port);
  return true;
}

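// Visits the graph in reverse DFS order, seeding input nodes with the shapes
// of the provided input tensors and letting the shape refiner infer the
// shapes of all remaining nodes.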
/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  Status status;
  auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
    if (!status.ok()) {
      return;
    }
    CHECK_NE(node, nullptr);
    // When visiting an input node, use the provided shape and set it on the
    // shape refiner directly.
    bool is_input_node = false;
    for (const std::pair<string, Tensor>& input_node_info :
         input_node_info_list) {
      if (node->name() == input_node_info.first) {
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(node);
        shape_inference::ShapeHandle handle;
        status = context->MakeShapeFromTensorShape(
            input_node_info.second.shape(), &handle);
        if (!status.ok()) {
          break;
        }
        status = shape_refiner->SetShape(node, 0, handle);
        if (!status.ok()) {
          break;
        }
        is_input_node = true;
      }
      if (!status.ok()) {
        break;
      }
    }
    // For non-input nodes, call AddNode(), which recomputes the shape.
    if (!is_input_node && status.ok()) {
      status = shape_refiner->AddNode(node);
    }
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node: " << node->name();
    }
  };

  ReverseDFS(*graph, {}, visit);

  return status;
}

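// Reads the fully-inferred shape of every output port of every node from
// "shape_refiner" into "tensor_shape_map".  Fails if any shape or dimension
// is unknown.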
/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
    const Graph& graph, const ShapeRefiner& shape_refiner,
    TensorShapeMap* tensor_shape_map) {
  for (int i = 0; i < graph.num_node_ids(); ++i) {
    const Node* node = graph.FindNodeId(i);
    CHECK_NE(node, nullptr);
    for (int j = 0; j < node->num_outputs(); ++j) {
      const int output_index = j;
      const DataType dt = node->output_type(output_index);
      shape_inference::InferenceContext* context =
          shape_refiner.GetContext(node);
      CHECK_NE(context, nullptr);
      shape_inference::ShapeHandle shape_handle = context->output(output_index);
      if (context->RankKnown(shape_handle)) {
        TensorShape ts;
        for (int k = 0; k < context->Rank(shape_handle); ++k) {
          shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
          CHECK(context->ValueKnown(dh));
          ts.AddDim(context->Value(dh));
        }
        const string& node_name = node->name();
        CHECK(tensor_shape_map->count(node_name) == 0);
        tensor_shape_map->emplace(node_name,
                                  std::make_pair(j, std::make_pair(dt, ts)));
      } else {
        return errors::InvalidArgument("Graph contains unknown shapes");
      }
    }
  }
  return Status::OK();
}

/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name) {
  if (node_name.find(':') != string::npos) {
    const TensorId tid = ParseTensorName(node_name);
    return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
  } else {
    return GetTensorShapeType(tensor_shape_map, node_name, 0);
  }
}

/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name,
    const int port) {
  CHECK_EQ(node_name.find(':'), string::npos);
  if (tensor_shape_map.count(node_name) <= 0) {
    return nullptr;
  }
  auto its = tensor_shape_map.equal_range(node_name);
  for (auto it = its.first; it != its.second; ++it) {
    if (it->second.first == port) {
      return &it->second.second;
    }
  }
  return nullptr;
}

/* static */ void
RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
    const RemoteFusedGraphExecuteInfo& proto,
    std::vector<std::pair<string, Tensor>>* inputs,
    std::vector<string>* outputs) {
  CHECK_EQ(proto.graph_input_node_name_size(),
           proto.default_graph_input_tensor_shape_size());
  for (int i = 0; i < proto.graph_input_node_name_size(); ++i) {
    inputs->emplace_back(
        proto.graph_input_node_name(i),
        Tensor(proto.default_graph_input_tensor_shape(i).dtype(),
               TensorShape(proto.default_graph_input_tensor_shape(i).shape())));
  }
  for (const string& output_node_name : proto.graph_output_node_name()) {
    outputs->emplace_back(output_node_name);
  }
}

/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
    const string& name, const Tensor& tensor,
    TensorShapeMap* tensor_shape_map) {
  const TensorId tid = ParseTensorName(name);
  CHECK_EQ(tensor_shape_map->count(name), 0);
  tensor_shape_map->emplace(
      string(tid.first),
      std::make_pair(tid.second,
                     std::make_pair(tensor.dtype(), tensor.shape())));
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    const bool dry_run_inference, GraphDef* graph_def) {
  TensorShapeMap tensor_shape_map;
  if (dry_run_inference) {
    TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
                                                 /*initialize_by_zero=*/true,
                                                 &tensor_shape_map));
  } else {
    ImportGraphDefOptions opts;
    Graph graph(OpRegistry::Global());
    ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    TF_RETURN_IF_ERROR(
        ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
                                               &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(
        BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
  }

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    TF_RETURN_IF_ERROR(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }

  return Status::OK();
}

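// Builds a RemoteFusedGraphExecuteInfo proto describing the subgraph, its
// border inputs and outputs, and their data types and shapes.  When shape
// and type information is unavailable and "require_shape_type" is false,
// float is assumed.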
/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    const string& executor_name, const GraphDef& subgraph_def,
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    DataTypeVector* input_types, DataTypeVector* output_types) {
  CHECK_NOTNULL(execute_info);
  CHECK_NOTNULL(input_types);
  CHECK_NOTNULL(output_types);

  execute_info->Clear();
  execute_info->set_executor_name(executor_name);

  // Copy the subgraph into the proto.
  *execute_info->mutable_remote_graph() = subgraph_def;

  for (const string& input : inputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);

    execute_info->add_graph_input_node_name(input);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
          *execute_info->add_default_graph_input_tensor_shape();
      tensor_shape_type.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      input_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
      // Assume the input type is float when no shape/type data is provided.
      input_types->push_back(DT_FLOAT);
    }
  }

  for (const string& output : outputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);

    execute_info->add_graph_output_node_name(output);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto&
          tensor_shape_type_proto =
              *execute_info->add_default_graph_output_tensor_shape();
      tensor_shape_type_proto.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto =
          *tensor_shape_type_proto.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      output_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
      // Assume the output type is float when no shape/type data is provided.
      output_types->push_back(DT_FLOAT);
    }
  }

  return Status::OK();
}

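// Adds a RemoteFusedGraphExecute node to "graph" that encapsulates
// "subgraph_def", wiring the border inputs as its inputs and serializing the
// execute info into a node attribute.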
/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
    const string& node_name, const string& executor_name,
    const GraphDef& subgraph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const bool require_shape_type,
    Graph* graph, Node** created_node) {
  CHECK_NOTNULL(graph);
  CHECK_NOTNULL(created_node);

  RemoteFusedGraphExecuteInfo execute_info;
  DataTypeVector input_types;
  DataTypeVector output_types;

  TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
      executor_name, subgraph_def, inputs, outputs, require_shape_type,
      &execute_info, &input_types, &output_types));

  std::vector<NodeBuilder::NodeOut> node_out_list;
  for (const string& input : inputs) {
    const TensorId tid = ParseTensorName(input);
    Node* node = FindMutableNodeByName(string(tid.first), graph);
    CHECK_NOTNULL(node);
    node_out_list.emplace_back(node, tid.second);
  }

  const string execute_info_str = execute_info.SerializeAsString();

  auto builder =
      NodeBuilder(node_name, "RemoteFusedGraphExecute")
          .Input(node_out_list)
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);

  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
    const string& node_name, const string& input_node_name,
    const int input_node_port, const DataType dt, Graph* graph,
    Node** created_node) {
  Node* node = FindMutableNodeByName(input_node_name, graph);
  CHECK_NOTNULL(node);
  NodeBuilder::NodeOut node_out(node, input_node_port);

  auto builder =
      NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt);

  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}

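// Groups "node_names" into connected clusters.  Each cluster is grown by a
// breadth-first traversal over both input and output edges, and its border
// input and output tensors are then collected into the ClusterInfo.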
/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    const std::unordered_set<string>& node_names, const GraphDef& graph_def,
    std::vector<ClusterInfo>* cluster_infos) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
  std::unordered_set<string> remaining_nodes = node_names;

  while (!remaining_nodes.empty()) {
    ClusterInfo ci;

    // Determine the nodes of one cluster by breadth-first traversal.
    std::unordered_set<const Node*> visited;
    std::deque<const Node*> queue;
    queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
    while (!queue.empty()) {
      const Node* node = queue.front();
      CHECK_NOTNULL(node);
      queue.pop_front();
      const string& node_name = node->name();
      if (node_names.count(node_name) > 0) {
        std::get<0>(ci).emplace(node_name);
        remaining_nodes.erase(node_name);
      } else {
        // Reached the edge of the subgraph.  Do nothing.
        continue;
      }
      for (const Node* in : node->in_nodes()) {
        if (visited.insert(in).second) {
          queue.push_back(in);
        }
      }
      for (const Node* out : node->out_nodes()) {
        if (visited.insert(out).second) {
          queue.push_back(out);
        }
      }
    }

    // Determine the border of the cluster.
    std::vector<string>& border_inputs = std::get<1>(ci);
    std::vector<string>& border_outputs = std::get<2>(ci);
    for (const string& node_name : node_names) {
      Node* node = FindMutableNodeByName(node_name, &graph);
      CHECK_NOTNULL(node);
      int input_count = 0;
      for (const Edge* in_edge : node->in_edges()) {
        const Node* src_node = in_edge->src();
        const bool src_is_outside =
            node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
        if (src_is_outside) {
          const string src_name =
              strings::StrCat(src_node->name(), ":", in_edge->src_output());
          CHECK_EQ(1, src_node->num_outputs())
              << "The output count of an input border node must be one: "
              << src_node->name();
          if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
              border_inputs.end()) {
            border_inputs.emplace_back(src_name);
          }
        } else {
          ++input_count;
        }
      }
      CHECK(input_count == 0 || input_count == node->in_edges().size())
          << "Invalid input_count(" << input_count << ", "
          << node->in_edges().size() << ") " << node_name;

      for (const Edge* out_edge : node->out_edges()) {
        const Node* dst_node = out_edge->dst();
        CHECK_NOTNULL(dst_node);
        const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
        const string dst_name =
            strings::StrCat(node->name(), ":", out_edge->src_output());
        if (dst_is_outside) {
          if (dst_node->IsSink()) {
            CHECK_EQ(1, node->num_outputs())
                << "If you want to specify an output node as a subgraph "
                << "output node, the output count of the node must be 1 "
                << "because that node is replaced by an identity node.";
            const string identity_dst_name =
                strings::StrCat(node->name(), ":", 0);
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          identity_dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(identity_dst_name);
            }
          } else {
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(dst_name);
            }
          }
        }
      }
    }
    cluster_infos->emplace_back(ci);
    VLOG(1) << DumpCluster(ci);
  }
  return Status::OK();
}

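// Extracts the cluster's nodes into "subgraph_def", replacing each border
// input by a placeholder and keeping the original node order of "graph_def".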
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    const ClusterInfo& cluster, const GraphDef& graph_def,
    GraphDef* subgraph_def) {
  const std::unordered_set<string>& node_names = std::get<0>(cluster);
  const std::unordered_set<string>& border_input_names =
      BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  for (Node* node : graph.nodes()) {
    if (node != nullptr && node_names.count(node->name()) <= 0 &&
        border_input_names.count(node->name()) <= 0 && !node->IsSource() &&
        !node->IsSink()) {
      graph.RemoveNode(node);
    }
  }
  graph.ToGraphDef(subgraph_def);

  for (const string& subgraph_input : std::get<1>(cluster)) {
    const TensorId tid = ParseTensorName(subgraph_input);
    const string subgraph_input_name(tid.first);
    const int subgraph_input_port = tid.second;
    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
    CHECK_NOTNULL(node_def);
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
    const DataType& dt =
        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
    const TensorShape& shape =
        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);

    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
                                                     shape, subgraph_def));
  }

  // Sort subgraph_def so its node order matches the order in graph_def.
  std::unordered_map<string, int> name_to_id_map;
  for (int i = 0; i < graph_def.node_size(); ++i) {
    name_to_id_map.emplace(graph_def.node(i).name(), i);
  }
  std::sort(subgraph_def->mutable_node()->begin(),
            subgraph_def->mutable_node()->end(),
            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
              CHECK(name_to_id_map.count(node0.name()) > 0);
              CHECK(name_to_id_map.count(node1.name()) > 0);
              const int id0 = name_to_id_map.at(node0.name());
              const int id1 = name_to_id_map.at(node1.name());
              return id0 < id1;
            });

  VLOG(1) << DumpGraphDef(*subgraph_def);
  return Status::OK();
}

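// Builds a ClusterInfo from explicit border inputs and outputs by walking
// backwards from the border output nodes until the border input tensors are
// reached.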
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs, const GraphDef& graph_def,
    ClusterInfo* cluster) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  std::unordered_set<const Node*> visited;
  std::deque<const Node*> queue;
  for (const string& output : border_outputs) {
    const TensorId tid = ParseTensorName(output);
    const string output_node_name(tid.first);
    for (const Node* node : graph.nodes()) {
      if (output_node_name == node->name()) {
        queue.push_back(node);
        visited.insert(node);
      }
    }
  }

  std::unordered_set<const Node*> border_input_nodes;
  // Propagate the visit to parent nodes until the border input nodes are
  // reached.
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop_front();
    for (const Edge* edge : node->in_edges()) {
      const Node* src_node = edge->src();
      CHECK_NOTNULL(src_node);
      const int src_port = edge->src_output();
      bool input_found = false;
      for (const string& input : border_inputs) {
        const TensorId tid = ParseTensorName(input);
        if (tid.first == src_node->name() && tid.second == src_port) {
          input_found = true;
          border_input_nodes.insert(src_node);
        }
      }
      if (visited.insert(src_node).second) {
        if (!input_found) {
          queue.push_back(src_node);
        }
      }
    }
  }

  for (const Node* node : visited) {
    if (node != nullptr && !node->IsSource() && !node->IsSink() &&
        border_input_nodes.count(node) <= 0) {
      std::get<0>(*cluster).insert(node->name());
    }
  }
  std::get<1>(*cluster) = border_inputs;
  std::get<2>(*cluster) = border_outputs;
  return Status::OK();
}

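// Replaces the cluster's nodes in "input_graph_def" by a single
// RemoteFusedGraphExecute node, rewires all edges that consumed the cluster's
// border outputs, and prunes the nodes that were fused away.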
/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op by fusing a specified subgraph...";

  CHECK(!remote_graph_executor_name.empty());

  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::vector<string>& border_outputs = std::get<2>(cluster);

  GraphDef subgraph_def;
  TF_RETURN_IF_ERROR(
      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));

  Node* fused_node;
  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));

  for (const Node* node : graph.nodes()) {
    for (int i = 0; i < node->num_inputs(); ++i) {
      const Edge* edge = nullptr;
      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
      for (int j = 0; j < border_outputs.size(); ++j) {
        const string& output = border_outputs.at(j);
        const TensorId tid = ParseTensorName(output);
        const string output_name(tid.first);
        Node* src_node = edge->src();
        if (src_node != nullptr && src_node->name() == output_name &&
            edge->src_output() == tid.second) {
          // The source node is replaced by the new fused node.
          Node* dst_node = edge->dst();
          const int dst_input = edge->dst_input();
          LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
                    << " from " << edge->src()->name();
          graph.RemoveEdge(edge);
          graph.AddEdge(fused_node, j, dst_node, dst_input);
        }
      }
    }
  }

  // Replace output nodes with identity nodes that forward the outputs of the
  // RemoteFusedGraphExecute node.
  for (const string& output : outputs) {
    const TensorId output_tid = ParseTensorName(output);
    const string output_name(output_tid.first);
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      const TensorId subgraph_output_tid =
          ParseTensorName(border_outputs.at(i));
      const string subgraph_output_name(subgraph_output_tid.first);
      if (output_name == subgraph_output_name) {
        LOG(INFO) << "As the graph output and the subgraph output are the "
                  << "same, the graph output node is replaced by an identity "
                  << "node";
        Node* original_output_node = FindMutableNodeByName(output_name, &graph);
        CHECK_NOTNULL(original_output_node);
        CHECK_EQ(1, original_output_node->num_outputs())
            << "Num outputs should be 1 for " << output << ".";
        graph.RemoveNode(original_output_node);
        Node* new_node;
        TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name,
                                               remote_fused_graph_node_name, i,
                                               DT_FLOAT, &graph, &new_node));
        CHECK_NOTNULL(new_node);
      }
    }
  }

  GraphDef result_graph_def;

  graph.ToGraphDef(&result_graph_def);

  ClusterInfo graph_cluster;
  TF_RETURN_IF_ERROR(
      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));

  // Remove unvisited nodes.
  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
                                             output_graph_def));

  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& subgraph_nodes,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  std::vector<ClusterInfo> ci_vec;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
      subgraph_nodes, input_graph_def, &ci_vec));

  for (size_t i = 0; i < ci_vec.size(); ++i) {
    const string remote_fused_graph_node_name =
        strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
    TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
                                   remote_fused_graph_node_name, ci_vec.at(i),
                                   remote_fused_graph_executor_name,
                                   require_shape_type, output_graph_def));
  }
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  ClusterInfo cluster;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
      border_inputs, border_outputs, input_graph_def, &cluster));

  return FuseCluster(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster,
      remote_graph_executor_name, require_shape_type, output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);

  return FuseRemoteGraphByNodeNames(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
      fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name,
      require_shape_type, output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const string& executor_name,
    GraphDef* output_graph_def) {
  const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
  if (build_func == nullptr) {
    return errors::InvalidArgument("Unknown executor name: " + executor_name);
  }
  std::unique_ptr<IRemoteFusedGraphExecutor> executor;
  TF_RETURN_IF_ERROR((*build_func)(&executor));
  CHECK_NOTNULL(executor.get());
  if (!executor->IsEnabled()) {
    // As this executor is not enabled, return the original graph as is.
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
                                   output_graph_def);
}

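// Annotates every node in "graph_def" with an ATTR_NODE_TYPE attribute that
// records its role (graph input/output, fused node, border input/output, or
// unused) so that the fusing can be performed later by
// FuseRemoteGraphByPlacedArguments().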
/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const std::unordered_set<string>& fused_node_names,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_node_name,
    const string& remote_graph_executor_name, GraphDef* graph_def) {
  CHECK_NOTNULL(graph_def);

  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(*graph_def, fused_op_types);

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    string attr_str;
    TensorId tid;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i,
                                      remote_graph_executor_name,
                                      remote_fused_graph_node_name);
      }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i);
      }
    }
    for (const string& fused_node_name : fused_node_names) {
      if (fused_node_name == node_def.name()) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(FUSED_NODE);
      }
    }
    for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
      if (fused_node_name == node_def.name()) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(FUSED_NODE);
      }
    }
    for (size_t i = 0; i < border_inputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i);
      }
    }
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i);
      }
    }
    if (attr_str.empty()) {
      attr_str += BuildNodeTypeAttr(UNUSED);
    }
    AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
  }
  return Status::OK();
}

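// Reads the ATTR_NODE_TYPE attributes placed by PlaceRemoteGraphArguments()
// and performs the actual fusing.  If the executor is not registered or the
// provided input tensors do not match the recorded inputs, the original
// graph is returned unchanged.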
1171 /* static */ Status
FuseRemoteGraphByPlacedArguments(const GraphDef & input_graph_def,const std::vector<std::pair<string,Tensor>> & input_tensors,GraphDef * output_graph_def)1172 RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
1173     const GraphDef& input_graph_def,
1174     const std::vector<std::pair<string, Tensor>>& input_tensors,
1175     GraphDef* output_graph_def) {
1176   std::unordered_map<int, string> input_map;
1177   std::unordered_map<int, string> output_map;
1178   std::unordered_set<string> fused_node_names;
1179   std::unordered_map<int, string> border_input_map;
1180   std::unordered_map<int, string> border_output_map;
1181   string remote_graph_executor_name;
1182   string remote_fused_graph_node_name;
1183 
  for (const NodeDef& node_def : input_graph_def.node()) {
    string attr_str;
    TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
    std::vector<std::vector<string>> attr_strs;
    for (const string& str : str_util::Split(attr_str, ":")) {
      attr_strs.emplace_back(str_util::Split(str, ","));
    }
    if (attr_strs.empty()) {
      return errors::InvalidArgument("Remote graph node type not found.");
    }
    for (const std::vector<string>& attr : attr_strs) {
      if (attr.empty()) {
        return errors::InvalidArgument("Empty remote graph node type attr.");
      }
      int node_type_int;
      CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
      const RemoteFusedGraphNodeType node_type =
          static_cast<RemoteFusedGraphNodeType>(node_type_int);
      const string& name = node_def.name();
      int port;
      int index;

      switch (node_type) {
        case GRAPH_INPUT:
          VLOG(2) << "Graph input: " << name;
          CHECK_EQ(5, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          CHECK(!attr.at(3).empty());
          remote_graph_executor_name = attr.at(3);
          CHECK(!attr.at(4).empty());
          remote_fused_graph_node_name = attr.at(4);
          input_map.emplace(index, strings::StrCat(name, ":", port));
          if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
            LOG(INFO) << "Executor for " << remote_graph_executor_name
                      << " not registered.  Do not fuse.";
            *output_graph_def = input_graph_def;
            return Status::OK();
          }
          break;
        case GRAPH_OUTPUT:
          VLOG(2) << "Graph output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case FUSED_NODE:
          VLOG(2) << "Fused node: " << name;
          CHECK_EQ(1, attr.size());
          fused_node_names.emplace(name);
          break;
        case BORDER_INPUT:
          VLOG(2) << "Border input: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_input_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case BORDER_OUTPUT:
          VLOG(2) << "Border output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case UNUSED:
          // do nothing
          break;
        default:
          // Unsupported value: the attribute was corrupted or written by an
          // incompatible placement pass.
          LOG(FATAL) << "Unsupported remote fused graph node type: "
                     << node_type_int;
      }
    }
  }
  bool require_shape_type = false;
  std::vector<string> inputs;
  std::vector<string> outputs;
  std::vector<string> border_inputs;
  std::vector<string> border_outputs;
  ConvertMapToVector(input_map, &inputs);
  ConvertMapToVector(output_map, &outputs);
  ConvertMapToVector(border_input_map, &border_inputs);
  ConvertMapToVector(border_output_map, &border_outputs);

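  // When the caller supplies concrete input tensors, check that they line up
  // with the recorded graph inputs; a full dtype/shape match also sets
  // require_shape_type, which is propagated to the fusing helpers below.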
  if (!input_tensors.empty()) {
    bool input_match = false;
    if (inputs.size() == input_tensors.size()) {
      for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
        if (!ContainsSameTensorId(input_tensor.first, inputs)) {
          break;
        }
        DataType data_type;
        TensorShape shape;
        if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
                                     &data_type, &shape)) {
          if (data_type == input_tensor.second.dtype() &&
              shape == input_tensor.second.shape()) {
            VLOG(2) << "Input matched!";
            // Shape and type matched.
            input_match = true;
            require_shape_type = true;
          }
        } else {
          // No recorded shape/type to compare against.
          input_match = true;
        }
      }
    }
    if (!input_match) {
      // Input mismatch: just copy the original graph.
      *output_graph_def = input_graph_def;
      return Status::OK();
    }
  }

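  // Dispatch on what the placement pass recorded: explicit fused node names
  // take precedence, then border inputs/outputs; otherwise pass the graph
  // through unchanged.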
  if (!fused_node_names.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        fused_node_names, remote_graph_executor_name, require_shape_type,
        output_graph_def));
  } else if (!border_inputs.empty() || !border_outputs.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        border_inputs, border_outputs, remote_graph_executor_name,
        require_shape_type, output_graph_def));
  } else {
    *output_graph_def = input_graph_def;
  }

  return Status::OK();
}

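// Returns true only if every supplied input tensor names a node that exists
// in graph_def and already carries a non-empty ATTR_NODE_TYPE attribute,
// i.e. the graph has been through the placement pass and can be fused by
// FuseRemoteGraphByPlacedArguments().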
/* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_tensors) {
  for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
    const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
    if (node_def == nullptr) {
      return false;
    }
    string attr;
    const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
    if (!status.ok() || attr.empty()) {
      return false;
    }
  }
  return true;
}

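// Copies src_size raw bytes from src_ptr into tensor's flat buffer.  The
// tensor must already be allocated with at least src_size bytes; unsupported
// dtypes abort via LOG(FATAL).
//
// A minimal usage sketch (buffer contents are illustrative):
//   Tensor t(DT_FLOAT, TensorShape({2, 2}));
//   const float data[4] = {0.0f, 1.0f, 2.0f, 3.0f};
//   TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
//       data, sizeof(data), &t));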
/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
    const void* src_ptr, const int src_size, Tensor* tensor) {
  CHECK(tensor->TotalBytes() >= src_size)
      << tensor->TotalBytes() << ", " << src_size;
  void* dst_ptr;
  switch (tensor->dtype()) {
    case DT_FLOAT:
      dst_ptr = tensor->flat<float>().data();
      break;
    case DT_DOUBLE:
      dst_ptr = tensor->flat<double>().data();
      break;
    case DT_INT32:
      dst_ptr = tensor->flat<int32>().data();
      break;
    case DT_UINT8:
      dst_ptr = tensor->flat<uint8>().data();
      break;
    case DT_INT16:
      dst_ptr = tensor->flat<int16>().data();
      break;
    case DT_INT8:
      dst_ptr = tensor->flat<int8>().data();
      break;
    case DT_STRING:
      dst_ptr = tensor->flat<string>().data();
      break;
    case DT_INT64:
      dst_ptr = tensor->flat<int64>().data();
      break;
    case DT_BOOL:
      dst_ptr = tensor->flat<bool>().data();
      break;
    case DT_QINT8:
      dst_ptr = tensor->flat<qint8>().data();
      break;
    case DT_QUINT8:
      dst_ptr = tensor->flat<quint8>().data();
      break;
    case DT_QINT32:
      dst_ptr = tensor->flat<qint32>().data();
      break;
    case DT_BFLOAT16:
      dst_ptr = tensor->flat<bfloat16>().data();
      break;
    case DT_QINT16:
      dst_ptr = tensor->flat<qint16>().data();
      break;
    case DT_QUINT16:
      dst_ptr = tensor->flat<quint16>().data();
      break;
    case DT_UINT16:
      dst_ptr = tensor->flat<uint16>().data();
      break;
    default:
      LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
      break;
  }
  CHECK_NOTNULL(dst_ptr);
  std::memcpy(dst_ptr, src_ptr, src_size);
  return Status::OK();
}

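// Returns the names of all nodes in graph_def whose op type appears in
// op_types.  Note that the result is a set of node names rather than a map,
// despite the function name.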
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
    const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    if (op_types.count(node_def.op()) > 0) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}

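// Returns the names of all nodes whose op, paired with its inferred output
// data types, is known to ops_definitions (GetOpIdFor() returns a valid id),
// i.e. the candidates for fusing into the remote graph.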
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
    const GraphDef& graph_def,
    const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    const Status status =
        GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
    if (!status.ok()) {
      shape_vec.clear();
    }
    if (ops_definitions.GetOpIdFor(
            node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
        IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}

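// Replaces the node named by `input` (which must reference output port 0)
// with a Placeholder of the given type and shape, keeping the node name so
// existing edges still resolve.  A node that is already a Placeholder is
// left untouched; an unknown name yields InvalidArgument.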
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
    const string& input, const DataType type, const TensorShape& shape,
    GraphDef* graph_def) {
  const TensorId tid = ParseTensorName(input);
  CHECK_EQ(0, tid.second);
  const string node_name(tid.first);
  for (NodeDef& node : *graph_def->mutable_node()) {
    if (node.name() != node_name) {
      continue;
    }
    if (node.op() == "Placeholder") {
      return Status::OK();
    } else {
      NodeDef placeholder_node;
      placeholder_node.set_op("Placeholder");
      placeholder_node.set_name(node_name);
      AddNodeAttr("dtype", type, &placeholder_node);
      AddNodeAttr("shape", shape, &placeholder_node);
      // TODO(satok): Remove once we merge attributes
      AddOutputTensorShapeType({type}, {shape}, &placeholder_node);
      node.Clear();
      node = placeholder_node;
      return Status::OK();
    }
  }
  return errors::InvalidArgument(
      strings::StrCat(node_name, " not found for replacement."));
}

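// Serializers for ATTR_NODE_TYPE records.  The three overloads emit
// "<type>,<port>,<index>,<executor_name>,<node_name>",
// "<type>,<port>,<index>", and "<type>" respectively; multiple records on a
// single node are joined with ':' via AppendDeliminator().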
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index,
    const string& executor_name, const string& node_name) {
  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index,
                         ",", executor_name, ",", node_name);
}

/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index) {
  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
}

/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type) {
  return strings::StrCat(static_cast<int>(node_type));
}

}  // namespace tensorflow