/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"

#include <algorithm>
#include <deque>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {
const Node* FindNodeByName(const string& name, const Graph& graph) {
  for (const Node* node : graph.nodes()) {
    CHECK_NOTNULL(node);
    if (node->name() == name) {
      return node;
    }
  }
  return nullptr;
}

std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
    const std::vector<string>& node_names_and_ports) {
  std::unordered_set<string> retval;
  for (const string& node_name_and_port : node_names_and_ports) {
    const TensorId tid = ParseTensorName(node_name_and_port);
    retval.emplace(tid.first);
  }
  return retval;
}
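
// For example (illustrative): the tensor names {"mul:0", "mul:1", "add"}
// yield the node set {"mul", "add"}; ParseTensorName drops the port suffix.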

Node* FindMutableNodeByName(const string& name, Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (node != nullptr && node->name() == name) {
      return node;
    }
  }
  return nullptr;
}

const NodeDef* FindNodeDefByName(const string& input,
                                 const GraphDef& graph_def) {
  const TensorId tid = ParseTensorName(input);
  const string name = string(tid.first);
  for (const NodeDef& node_def : graph_def.node()) {
    if (node_def.name() == name) {
      return &node_def;
    }
  }
  return nullptr;
}

bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
                    TensorId* tid) {
  CHECK_NOTNULL(tid);
  *tid = ParseTensorName(node_name_and_port);
  return node_def.name() == tid->first;
}

bool ContainsSameTensorId(const string& tensor_name,
                          const std::vector<string>& tensor_names) {
  const TensorId tid0 = ParseTensorName(tensor_name);
  for (const string& name : tensor_names) {
    const TensorId tid1 = ParseTensorName(name);
    if (tid0.first == tid1.first && tid0.second == tid1.second) {
      return true;
    }
  }
  return false;
}

void AppendDeliminator(string* str) {
  CHECK_NOTNULL(str);
  if (!str->empty()) {
    *str += ":";
  }
}

void ConvertMapToVector(const std::unordered_map<int, string>& in,
                        std::vector<string>* out) {
  CHECK_NOTNULL(out);
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    CHECK(in.count(i) > 0);
    out->at(i) = in.at(i);
  }
}
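
// For example (illustrative): {{1, "b"}, {0, "a"}} becomes {"a", "b"}. The
// keys must form the dense range 0..in.size()-1 or the CHECK fires.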

string DumpGraphDef(const GraphDef& graph_def) {
  string out;
  for (const NodeDef& node : graph_def.node()) {
    out += strings::StrCat("node: ", node.name(), "\n    input: ");
    for (const string& input : node.input()) {
      out += strings::StrCat(input, ", ");
    }
    out += "\n";
  }
  return out;
}

string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
  string out;
  out += "Nodes:\n";
  for (const string& str : std::get<0>(cluster)) {
    out += str + ", ";
  }
  out += "\nInput border:\n";
  for (const string& str : std::get<1>(cluster)) {
    out += str + ", ";
  }
  out += "\nOutput border:\n";
  for (const string& str : std::get<2>(cluster)) {
    out += str + ", ";
  }
  return out;
}

}  // namespace

/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;

RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
    const string& name, ExecutorBuildFunc executor_build_func) {
  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
  executor_build_registry[name] = std::move(executor_build_func);
}
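
// Illustrative registration sketch (hypothetical executor name and class,
// not part of this file): executors typically register themselves through a
// file-scope registrar object like the following.
//
//   static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
//       k_my_executor_registrar(
//           "my_executor",
//           [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
//             executor->reset(new MyExecutor());  // MyExecutor is hypothetical.
//             return Status::OK();
//           });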

/* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
  if (executor_build_registry.count(name) <= 0) {
    return nullptr;
  }
  return &executor_build_registry.at(name);
}

/* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
  static ExecutorBuildRegistry executor_builder_registry;
  return &executor_builder_registry;
}

/**
 * - DryRunInference
 * Dry-runs the graph to determine the shapes of the output tensors of all
 * nodes. The memory allocation information gathered here can be supplied
 * when loading the graph, and the results can be used to verify that shape
 * inference matches the actual output shapes.
 */
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names, const bool initialize_by_zero,
    std::vector<tensorflow::Tensor>* output_tensors) {
  // Create the input tensor vector.  If "initialize_by_zero" is true,
  // input tensor fields are initialized to zero.
  std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
  for (const std::pair<string, Tensor>& input : input_node_info_list) {
    CHECK(input.second.IsInitialized());
    if (!initialize_by_zero) {
      input_tensors.push_back({input.first, input.second});
      continue;
    }
    // Build a zero-filled tensor with the same type and shape as the input.
    const DataType data_type = input.second.dtype();
    const TensorShape& shape = input.second.shape();
    Tensor input_tensor(data_type, shape);
    switch (data_type) {
      case DT_INT32: {
        auto int_tensor = input_tensor.flat<int32>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      case DT_FLOAT: {
        auto float_tensor = input_tensor.flat<float>();
        float_tensor = float_tensor.constant(0.0f);
        break;
      }
      case DT_QUINT8: {
        auto int_tensor = input_tensor.flat<quint8>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      default:
        LOG(FATAL) << "Unsupported input type: " << data_type;
    }
    input_tensors.push_back({input.first, input_tensor});
  }

  // Set up the session.
  CHECK(output_tensors != nullptr);
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session =
      std::unique_ptr<Session>(NewSession(session_options));
  Status status = session->Create(graph_def);
  if (!status.ok()) {
    return status;
  }

  // Set up the session arguments.
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  // Run inference with all the requested nodes as outputs.
  status = session->Run(run_options, input_tensors, output_node_names, {},
                        output_tensors, &run_metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Error during inference: " << status;
    return status;
  }
  return Status::OK();
}
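
// Minimal usage sketch (hypothetical graph and tensor names; graph_def is
// assumed to be a GraphDef already in scope):
//
//   Tensor in(DT_FLOAT, TensorShape({1, 224, 224, 3}));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(RemoteFusedGraphExecuteUtils::DryRunInference(
//       graph_def, {{"input", in}}, {"softmax:0"},
//       /*initialize_by_zero=*/true, &outputs));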

/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const bool initialize_by_zero,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  CHECK(tensor_shape_map != nullptr);
  std::vector<Tensor> output_tensors;
  output_tensors.reserve(graph_def.node_size());
  std::vector<string> output_node_names;

  Graph graph(OpRegistry::Global());
  Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
  if (!status.ok()) {
    return status;
  }

  for (const Node* node : graph.nodes()) {
    if (IsInputNode(input_node_info_list, node->name())) {
      continue;
    }
    for (int i = 0; i < node->num_outputs(); ++i) {
      output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
    }
  }

  status = DryRunInference(graph_def, input_node_info_list, output_node_names,
                           initialize_by_zero, &output_tensors);
  if (!status.ok()) {
    VLOG(1) << "Failed to dry-run: " << status;
    return status;
  }

  CHECK_EQ(output_node_names.size(), output_tensors.size())
      << output_node_names.size() << ", " << output_tensors.size();

  // Append the output tensors of the input nodes in advance, before building
  // the map, so that the vector does not reallocate while its entries are
  // being referenced.
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    output_tensors.push_back(input_node_info.second);
  }

  for (size_t i = 0; i < output_node_names.size(); ++i) {
    const string& name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  for (size_t i = 0; i < input_node_info_list.size(); ++i) {
    const string& name = input_node_info_list.at(i).first;
    const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
           output_tensors.size());
  return status;
}

/* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
    const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
    const string& node_name) {
  for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
    const TensorId tid = ParseTensorName(pair.first);
    if (node_name == tid.first) {
      return true;
    }
  }
  return false;
}

/* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names,
    const std::vector<tensorflow::Tensor>& output_tensors,
    TensorShapeMap* tensor_shape_map) {
  CHECK_NE(tensor_shape_map, nullptr);
  tensor_shape_map->clear();
  tensor_shape_map->reserve(input_node_info_list.size() +
                            output_node_names.size());
  const int output_node_count = output_node_names.size();
  CHECK_EQ(output_node_count, output_tensors.size());
  for (int i = 0; i < output_node_count; ++i) {
    const string& node_name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(node_name, tensor, tensor_shape_map);
  }
}

/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
    const TensorProto& tensor_proto, Tensor* tensor) {
  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
    Tensor parsed(tensor_proto.dtype());
    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
      *tensor = parsed;
      return Status::OK();
    }
  }
  return errors::InvalidArgument("Cannot parse tensor from proto");
}

/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
    const std::vector<DataType>& data_types,
    const std::vector<TensorShape>& shapes, NodeDef* node_def) {
  AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
  AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
  return true;
}
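
// For example (illustrative): for a node with a single DT_FLOAT output of
// shape [1, 8], this records ATTR_OUTPUT_DATA_TYPES = [DT_FLOAT] and
// ATTR_OUTPUT_SHAPES = [[1, 8]] as attrs on the NodeDef.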

/* static */ Status
RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
    const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
  CHECK_NE(node_def, nullptr);
  std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
  auto its = tensor_shape_map.equal_range(node_def->name());
  for (auto it = its.first; it != its.second; ++it) {
    queue.emplace(std::make_tuple(it->second.first, &it->second.second));
  }
  int last_port = queue.size();
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  // The priority queue pops ports in descending order; prepending each entry
  // yields vectors ordered by ascending port, and the CHECK verifies that
  // the ports form a contiguous range.
  while (!queue.empty()) {
    const int port = std::get<0>(queue.top());
    const TensorShapeType* tst = std::get<1>(queue.top());
    CHECK_NE(tst, nullptr);
    data_types.emplace(data_types.begin(), tst->first);
    shapes.emplace(shapes.begin(), tst->second);
    CHECK_EQ(last_port - 1, port);
    last_port = port;
    queue.pop();
  }
  AddOutputTensorShapeType(data_types, shapes, node_def);
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    AttrSlice attrs, std::vector<DataType>* data_types,
    std::vector<TensorShape>* shapes) {
  Status status;
  if (data_types != nullptr) {
    status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
  }
  if (!status.ok()) {
    return status;
  }
  if (shapes != nullptr) {
    status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
    if (status.ok() && data_types != nullptr) {
      CHECK_EQ(data_types->size(), shapes->size());
    }
  }

  return status;
}

/* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
    TensorShape* shape) {
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  const TensorId tid = ParseTensorName(name_and_port);
  const string node_name(tid.first);
  const int port = tid.second;
  const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
  CHECK_NOTNULL(node_def);
  GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
  if (data_types.empty()) {
    return false;
  }
  const int data_types_size = data_types.size();
  CHECK_GT(data_types_size, port);
  *data_type = data_types.at(port);
  *shape = shapes.at(port);
  return true;
}

/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  Status status;
  auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
    if (!status.ok()) {
      return;
    }
    CHECK_NE(node, nullptr);
    // If we visit an input node, use the shape provided and set the
    // shape accordingly.
    bool is_input_node = false;
    for (const std::pair<string, Tensor>& input_node_info :
         input_node_info_list) {
      if (node->name() == input_node_info.first) {
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(node);
        shape_inference::ShapeHandle handle;
        status = context->MakeShapeFromTensorShape(
            input_node_info.second.shape(), &handle);
        if (!status.ok()) {
          break;
        }
        status = shape_refiner->SetShape(node, 0, handle);
        if (!status.ok()) {
          break;
        }
        is_input_node = true;
      }
      if (!status.ok()) {
        break;
      }
    }
    // If this is not an input node, call AddNode(), which recomputes its
    // shape.
    if (!is_input_node && status.ok()) {
      status = shape_refiner->AddNode(node);
    }
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node: " << node->name();
    }
  };

  ReverseDFS(*graph, {}, visit);

  return status;
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
    const Graph& graph, const ShapeRefiner& shape_refiner,
    TensorShapeMap* tensor_shape_map) {
  for (int i = 0; i < graph.num_node_ids(); ++i) {
    const Node* node = graph.FindNodeId(i);
    CHECK_NE(node, nullptr);
    for (int j = 0; j < node->num_outputs(); ++j) {
      const int output_index = j;
      const DataType dt = node->output_type(output_index);
      shape_inference::InferenceContext* context =
          shape_refiner.GetContext(node);
      CHECK_NE(context, nullptr);
      shape_inference::ShapeHandle shape_handle = context->output(output_index);
      if (context->RankKnown(shape_handle)) {
        TensorShape ts;
        for (int k = 0; k < context->Rank(shape_handle); ++k) {
          shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
          CHECK(context->ValueKnown(dh));
          ts.AddDim(context->Value(dh));
        }
        const string& node_name = node->name();
        CHECK(tensor_shape_map->count(node_name) == 0);
        tensor_shape_map->emplace(node_name,
                                  std::make_pair(j, std::make_pair(dt, ts)));
      } else {
        return errors::InvalidArgument("Graph contains unknown shapes");
      }
    }
  }
  return Status::OK();
}

/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name) {
  if (node_name.find(':') != string::npos) {
    const TensorId tid = ParseTensorName(node_name);
    return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
  } else {
    return GetTensorShapeType(tensor_shape_map, node_name, 0);
  }
}

/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name,
    const int port) {
  CHECK_EQ(node_name.find(':'), string::npos);
  if (tensor_shape_map.count(node_name) <= 0) {
    return nullptr;
  }
  auto its = tensor_shape_map.equal_range(node_name);
  for (auto it = its.first; it != its.second; ++it) {
    if (it->second.first == port) {
      return &it->second.second;
    }
  }
  return nullptr;
}
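
// Lookup sketch (hypothetical map contents): if the map holds entries for
// node "conv" at ports 0 and 1, GetTensorShapeType(map, "conv:1") returns
// the (DataType, TensorShape) pair registered for port 1, and nullptr when
// no entry matches.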

/* static */ void
RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
    const RemoteFusedGraphExecuteInfo& proto,
    std::vector<std::pair<string, Tensor>>* inputs,
    std::vector<string>* outputs) {
  CHECK_EQ(proto.graph_input_node_name_size(),
           proto.default_graph_input_tensor_shape_size());
  for (int i = 0; i < proto.graph_input_node_name_size(); ++i) {
    inputs->emplace_back(
        proto.graph_input_node_name(i),
        Tensor(proto.default_graph_input_tensor_shape(i).dtype(),
               TensorShape(proto.default_graph_input_tensor_shape(i).shape())));
  }
  for (const string& output_node_name : proto.graph_output_node_name()) {
    outputs->emplace_back(output_node_name);
  }
}

/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
    const string& name, const Tensor& tensor,
    TensorShapeMap* tensor_shape_map) {
  const TensorId tid = ParseTensorName(name);
  CHECK_EQ(tensor_shape_map->count(name), 0);
  tensor_shape_map->emplace(
      string(tid.first),
      std::make_pair(tid.second,
                     std::make_pair(tensor.dtype(), tensor.shape())));
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    const bool dry_run_inference, GraphDef* graph_def) {
  TensorShapeMap tensor_shape_map;
  if (dry_run_inference) {
    TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
                                                 /*initialize_by_zero=*/true,
                                                 &tensor_shape_map));
  } else {
    ImportGraphDefOptions opts;
    Graph graph(OpRegistry::Global());
    ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    TF_RETURN_IF_ERROR(
        ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
                                               &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(
        BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
  }

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    TF_RETURN_IF_ERROR(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }

  return Status::OK();
}

/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    const string& executor_name, const GraphDef& subgraph_def,
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    DataTypeVector* input_types, DataTypeVector* output_types) {
  CHECK_NOTNULL(execute_info);
  CHECK_NOTNULL(input_types);
  CHECK_NOTNULL(output_types);

  execute_info->Clear();
  execute_info->set_executor_name(executor_name);

  // Copy the subgraph into the execute info.
  *execute_info->mutable_remote_graph() = subgraph_def;

  for (const string& input : inputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);

    execute_info->add_graph_input_node_name(input);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
          *execute_info->add_default_graph_input_tensor_shape();
      tensor_shape_type.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      input_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
      // Assume the input type is float if no shape/type data is provided.
      input_types->push_back(DT_FLOAT);
    }
  }

  for (const string& output : outputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);

    execute_info->add_graph_output_node_name(output);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto&
          tensor_shape_type_proto =
              *execute_info->add_default_graph_output_tensor_shape();
      tensor_shape_type_proto.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto =
          *tensor_shape_type_proto.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      output_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
      // Assume the output type is float if no shape/type data is provided.
      output_types->push_back(DT_FLOAT);
    }
  }

  return Status::OK();
}

/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
    const string& node_name, const string& executor_name,
    const GraphDef& subgraph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const bool require_shape_type,
    Graph* graph, Node** created_node) {
  CHECK_NOTNULL(graph);
  CHECK_NOTNULL(created_node);

  RemoteFusedGraphExecuteInfo execute_info;
  DataTypeVector input_types;
  DataTypeVector output_types;

  TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
      executor_name, subgraph_def, inputs, outputs, require_shape_type,
      &execute_info, &input_types, &output_types));

  std::vector<NodeBuilder::NodeOut> node_out_list;
  for (const string& input : inputs) {
    const TensorId tid = ParseTensorName(input);
    Node* node = FindMutableNodeByName(string(tid.first), graph);
    CHECK_NOTNULL(node);
    node_out_list.emplace_back(node, tid.second);
  }

  const string execute_info_str = execute_info.SerializeAsString();

  auto builder =
      NodeBuilder(node_name, "RemoteFusedGraphExecute")
          .Input(node_out_list)
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);

  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}
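
// Minimal usage sketch (hypothetical node and executor names; graph and
// subgraph_def are assumed to be in scope):
//
//   Node* fused_node = nullptr;
//   TF_CHECK_OK(
//       RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
//           "remote_fused_node", "my_executor", subgraph_def, {"input:0"},
//           {"softmax:0"}, /*require_shape_type=*/false, &graph,
//           &fused_node));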

/* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
    const string& node_name, const string& input_node_name,
    const int input_node_port, const DataType dt, Graph* graph,
    Node** created_node) {
  Node* node = FindMutableNodeByName(input_node_name, graph);
  CHECK_NOTNULL(node);
  NodeBuilder::NodeOut node_out(node, input_node_port);

  auto builder =
      NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt);

  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    const std::unordered_set<string>& node_names, const GraphDef& graph_def,
    std::vector<ClusterInfo>* cluster_infos) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
  std::unordered_set<string> remaining_nodes = node_names;

  while (!remaining_nodes.empty()) {
    ClusterInfo ci;

    // Determine the nodes of one cluster.
    std::unordered_set<const Node*> visited;
    std::deque<const Node*> queue;
    queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
    while (!queue.empty()) {
      const Node* node = queue.front();
      CHECK_NOTNULL(node);
      queue.pop_front();
      const string& node_name = node->name();
      if (node_names.count(node_name) > 0) {
        std::get<0>(ci).emplace(node_name);
        remaining_nodes.erase(node_name);
      } else {
        // Edge of the subgraph.  Do nothing.
        continue;
      }
      for (const Node* in : node->in_nodes()) {
        if (visited.insert(in).second) {
          queue.push_back(in);
        }
      }
      for (const Node* out : node->out_nodes()) {
        if (visited.insert(out).second) {
          queue.push_back(out);
        }
      }
    }

    // Determine the border of one cluster.
    std::vector<string>& border_inputs = std::get<1>(ci);
    std::vector<string>& border_outputs = std::get<2>(ci);
    for (const string& node_name : node_names) {
      Node* node = FindMutableNodeByName(node_name, &graph);
      CHECK_NOTNULL(node);
      int input_count = 0;
      for (const Edge* in_edge : node->in_edges()) {
        const Node* src_node = in_edge->src();
        const bool src_is_outside =
            node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
        if (src_is_outside) {
          const string src_name =
              strings::StrCat(src_node->name(), ":", in_edge->src_output());
          CHECK_EQ(1, src_node->num_outputs())
              << "The output count of an input border node must be one: "
              << src_node->name();
          if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
              border_inputs.end()) {
            border_inputs.emplace_back(src_name);
          }
        } else {
          ++input_count;
        }
      }
      const int node_in_edges_size = node->in_edges().size();
      CHECK(input_count == 0 || input_count == node_in_edges_size)
          << "Invalid input_count(" << input_count << ", "
          << node_in_edges_size << ") " << node_name;

      for (const Edge* out_edge : node->out_edges()) {
        const Node* dst_node = out_edge->dst();
        CHECK_NOTNULL(dst_node);
        const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
        const string dst_name =
            strings::StrCat(node->name(), ":", out_edge->src_output());
        if (dst_is_outside) {
          if (dst_node->IsSink()) {
            CHECK_EQ(1, node->num_outputs())
                << "To specify a node as a subgraph output node, its output "
                << "count must be 1, because that node is replaced by an "
                << "identity node.";
            const string identity_dst_name =
                strings::StrCat(node->name(), ":", 0);
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          identity_dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(identity_dst_name);
            }
          } else {
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(dst_name);
            }
          }
        }
      }
    }
    cluster_infos->emplace_back(ci);
    VLOG(1) << DumpCluster(ci);
  }
  return Status::OK();
}
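
// Illustrative example (hypothetical linear graph A -> B -> C -> D): fusing
// {B, C} produces a single ClusterInfo whose node set is {B, C}, whose input
// border is {"A:0"}, and whose output border is {"C:0"}.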

/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    const ClusterInfo& cluster, const GraphDef& graph_def,
    GraphDef* subgraph_def) {
  const std::unordered_set<string>& node_names = std::get<0>(cluster);
  const std::unordered_set<string>& border_input_names =
      BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  for (Node* node : graph.nodes()) {
    if (node != nullptr && node_names.count(node->name()) <= 0 &&
        border_input_names.count(node->name()) <= 0 && !node->IsSource() &&
        !node->IsSink()) {
      graph.RemoveNode(node);
    }
  }
  graph.ToGraphDef(subgraph_def);

  for (const string& subgraph_input : std::get<1>(cluster)) {
    const TensorId tid = ParseTensorName(subgraph_input);
    const string subgraph_input_name(tid.first);
    const int subgraph_input_port = tid.second;
    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
    CHECK_NOTNULL(node_def);
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
    const DataType& dt =
        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
    const TensorShape& shape =
        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);

    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
                                                     shape, subgraph_def));
  }

  // Sort subgraph_def so that its node order matches graph_def.
  std::unordered_map<string, int> name_to_id_map;
  for (int i = 0; i < graph_def.node_size(); ++i) {
    name_to_id_map.emplace(graph_def.node(i).name(), i);
  }
  std::sort(subgraph_def->mutable_node()->begin(),
            subgraph_def->mutable_node()->end(),
            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
              CHECK(name_to_id_map.count(node0.name()) > 0);
              CHECK(name_to_id_map.count(node1.name()) > 0);
              const int id0 = name_to_id_map.at(node0.name());
              const int id1 = name_to_id_map.at(node1.name());
              return id0 < id1;
            });

  VLOG(1) << DumpGraphDef(*subgraph_def);
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs, const GraphDef& graph_def,
    ClusterInfo* cluster) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  std::unordered_set<const Node*> visited;
  std::deque<const Node*> queue;
  for (const string& output : border_outputs) {
    const TensorId tid = ParseTensorName(output);
    const string output_node_name(tid.first);
    for (const Node* node : graph.nodes()) {
      if (output_node_name == node->name()) {
        queue.push_back(node);
        visited.insert(node);
      }
    }
  }

  std::unordered_set<const Node*> border_input_nodes;
  // Propagate visits to parent nodes until the border input nodes are
  // reached.
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop_front();
    for (const Edge* edge : node->in_edges()) {
      const Node* src_node = edge->src();
      CHECK_NOTNULL(src_node);
      const int src_port = edge->src_output();
      bool input_found = false;
      for (const string& input : border_inputs) {
        const TensorId tid = ParseTensorName(input);
        if (tid.first == src_node->name() && tid.second == src_port) {
          input_found = true;
          border_input_nodes.insert(src_node);
        }
      }
      if (visited.insert(src_node).second) {
        if (!input_found) {
          queue.push_back(src_node);
        }
      }
    }
  }

  for (const Node* node : visited) {
    if (node != nullptr && !node->IsSource() && !node->IsSink() &&
        border_input_nodes.count(node) <= 0) {
      std::get<0>(*cluster).insert(node->name());
    }
  }
  std::get<1>(*cluster) = border_inputs;
  std::get<2>(*cluster) = border_outputs;
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op by fusing a specified subgraph...";

  CHECK(!remote_graph_executor_name.empty());

  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::vector<string>& border_outputs = std::get<2>(cluster);

  GraphDef subgraph_def;
  TF_RETURN_IF_ERROR(
      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));

  Node* fused_node;
  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));

  for (const Node* node : graph.nodes()) {
    for (int i = 0, end = node->num_inputs(); i < end; ++i) {
      const Edge* edge = nullptr;
      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
      for (int j = 0, second_end = border_outputs.size(); j < second_end; ++j) {
        const string& output = border_outputs.at(j);
        const TensorId tid = ParseTensorName(output);
        const string output_name(tid.first);
        Node* src_node = edge->src();
        if (src_node != nullptr && src_node->name() == output_name &&
            edge->src_output() == tid.second) {
          // The source node is replaced by the new fused node.
          Node* dst_node = edge->dst();
          const int dst_input = edge->dst_input();
          LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
                    << " from " << edge->src()->name();
          graph.RemoveEdge(edge);
          graph.AddEdge(fused_node, j, dst_node, dst_input);
        }
      }
    }
  }

  // Replace output nodes by identity nodes that forward outputs from the
  // RemoteFusedGraphExecute op node.
  for (const string& output : outputs) {
    const TensorId output_tid = ParseTensorName(output);
    const string output_name(output_tid.first);
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      const TensorId subgraph_output_tid =
          ParseTensorName(border_outputs.at(i));
      const string subgraph_output_name(subgraph_output_tid.first);
      if (output_name == subgraph_output_name) {
        LOG(INFO) << "As the graph output and the subgraph output are the "
                  << "same, the graph output node is replaced by an identity "
                  << "node";
        Node* original_output_node = FindMutableNodeByName(output_name, &graph);
        CHECK_NOTNULL(original_output_node);
        CHECK_EQ(1, original_output_node->num_outputs())
            << "Num outputs should be 1 for " << output << ".";
        graph.RemoveNode(original_output_node);
        Node* new_node;
        TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name,
                                               remote_fused_graph_node_name, i,
                                               DT_FLOAT, &graph, &new_node));
        CHECK_NOTNULL(new_node);
      }
    }
  }

  GraphDef result_graph_def;

  graph.ToGraphDef(&result_graph_def);

  ClusterInfo graph_cluster;
  TF_RETURN_IF_ERROR(
      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));

  // Remove unvisited nodes.
  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
                                             output_graph_def));

  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& subgraph_nodes,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  std::vector<ClusterInfo> ci_vec;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
      subgraph_nodes, input_graph_def, &ci_vec));

  for (size_t i = 0; i < ci_vec.size(); ++i) {
    const string remote_fused_graph_node_name =
        strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
    TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
                                   remote_fused_graph_node_name, ci_vec.at(i),
                                   remote_fused_graph_executor_name,
                                   require_shape_type, output_graph_def));
  }
  return Status::OK();
}
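
// Minimal usage sketch (hypothetical node and executor names):
//
//   GraphDef fused_graph_def;
//   TF_CHECK_OK(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
//       original_graph_def, {"input:0"}, {"softmax:0"}, "remote_fused_graph",
//       {"conv1", "relu1"}, "my_executor", /*require_shape_type=*/false,
//       &fused_graph_def));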

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  ClusterInfo cluster;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
      border_inputs, border_outputs, input_graph_def, &cluster));

  return FuseCluster(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster,
      remote_graph_executor_name, require_shape_type, output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);

  return FuseRemoteGraphByNodeNames(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
      fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name,
      require_shape_type, output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const string& executor_name,
    GraphDef* output_graph_def) {
  const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
  if (build_func == nullptr) {
    return errors::InvalidArgument("Unknown executor name: " + executor_name);
  }
  std::unique_ptr<IRemoteFusedGraphExecutor> executor;
  TF_RETURN_IF_ERROR((*build_func)(&executor));
  CHECK_NOTNULL(executor.get());
  if (!executor->IsEnabled()) {
    // This executor is not enabled, so return the original graph as is.
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
                                   output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const std::unordered_set<string>& fused_node_names,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_node_name,
    const string& remote_graph_executor_name, GraphDef* graph_def) {
  CHECK_NOTNULL(graph_def);

  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(*graph_def, fused_op_types);

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    string attr_str;
    TensorId tid;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i,
                                      remote_graph_executor_name,
                                      remote_fused_graph_node_name);
      }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i);
      }
    }
    for (const string& fused_node_name : fused_node_names) {
      if (fused_node_name == node_def.name()) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(FUSED_NODE);
      }
    }
    for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
      if (fused_node_name == node_def.name()) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(FUSED_NODE);
      }
    }
    for (size_t i = 0; i < border_inputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i);
      }
    }
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i);
      }
    }
    if (attr_str.empty()) {
      attr_str += BuildNodeTypeAttr(UNUSED);
    }
    AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
  }
  return Status::OK();
}
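
// Illustrative attr layout (inferred from the parsing in
// FuseRemoteGraphByPlacedArguments below; the exact enum encodings are
// defined elsewhere): each role a node plays is encoded as a comma-separated
// tuple such as "<node_type>,<port>,<index>", and multiple roles on one node
// are joined with ':'. For example, a node acting as graph input 0 and
// border input 1 might carry
// "<GRAPH_INPUT>,0,0,<executor_name>,<fused_node_name>:<BORDER_INPUT>,0,1".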
1172 
/* static */ Status
RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
    const GraphDef& input_graph_def,
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    GraphDef* output_graph_def) {
  std::unordered_map<int, string> input_map;
  std::unordered_map<int, string> output_map;
  std::unordered_set<string> fused_node_names;
  std::unordered_map<int, string> border_input_map;
  std::unordered_map<int, string> border_output_map;
  string remote_graph_executor_name;
  string remote_fused_graph_node_name;

  // First pass: decode the ATTR_NODE_TYPE attribute of every node and
  // collect the recorded inputs, outputs, fused nodes and border nodes.
  for (const NodeDef& node_def : input_graph_def.node()) {
    string attr_str;
    TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
    std::vector<std::vector<string>> attr_strs;
    for (const string& str : str_util::Split(attr_str, ":")) {
      attr_strs.emplace_back(str_util::Split(str, ","));
    }
    if (attr_strs.empty()) {
      return errors::InvalidArgument("Remote graph node type not found.");
    }
    for (const std::vector<string>& attr : attr_strs) {
      if (attr.empty()) {
        return errors::InvalidArgument("Empty remote graph node type attr.");
      }
      int node_type_int;
      CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
      const RemoteFusedGraphNodeType node_type =
          static_cast<RemoteFusedGraphNodeType>(node_type_int);
      const string& name = node_def.name();
      int port;
      int index;

      switch (node_type) {
        case GRAPH_INPUT:
          VLOG(2) << "Graph input: " << name;
          CHECK_EQ(5, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          CHECK(!attr.at(3).empty());
          remote_graph_executor_name = attr.at(3);
          CHECK(!attr.at(4).empty());
          remote_fused_graph_node_name = attr.at(4);
          input_map.emplace(index, strings::StrCat(name, ":", port));
          if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
            LOG(INFO) << "Executor for " << remote_graph_executor_name
                      << " not registered.  Do not fuse.";
            *output_graph_def = input_graph_def;
            return Status::OK();
          }
          break;
        case GRAPH_OUTPUT:
          VLOG(2) << "Graph output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case FUSED_NODE:
          VLOG(2) << "Fused node: " << name;
          CHECK_EQ(1, attr.size());
          fused_node_names.emplace(name);
          break;
        case BORDER_INPUT:
          VLOG(2) << "Border input: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_input_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case BORDER_OUTPUT:
          VLOG(2) << "Border output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case UNUSED:
          // Do nothing.
          break;
        default:
          LOG(FATAL) << "Unsupported remote fused graph node type: "
                     << node_type_int;
      }
    }
  }
  // Order the collected tensor names by their recorded indices.
  bool require_shape_type = false;
  std::vector<string> inputs;
  std::vector<string> outputs;
  std::vector<string> border_inputs;
  std::vector<string> border_outputs;
  ConvertMapToVector(input_map, &inputs);
  ConvertMapToVector(output_map, &outputs);
  ConvertMapToVector(border_input_map, &border_inputs);
  ConvertMapToVector(border_output_map, &border_outputs);

  // Verify that the placed input tensors match the recorded graph inputs
  // before fusing.
  if (!input_tensors.empty()) {
    bool input_match = false;
    if (inputs.size() == input_tensors.size()) {
      for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
        if (!ContainsSameTensorId(input_tensor.first, inputs)) {
          break;
        }
        DataType data_type;
        TensorShape shape;
        if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
                                     &data_type, &shape)) {
          if (data_type == input_tensor.second.dtype() &&
              shape == input_tensor.second.shape()) {
            VLOG(2) << "Input matched!";
            // Shape and type matched.
            input_match = true;
            require_shape_type = true;
          }
        } else {
          // Shape and type are not required to match.
          input_match = true;
        }
      }
    }
    if (!input_match) {
      // Input mismatch.  Just copy the original graph.
      *output_graph_def = input_graph_def;
      return Status::OK();
    }
  }

  // Fuse by node names if any were recorded; otherwise fall back to the
  // recorded border inputs/outputs.
  if (!fused_node_names.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        fused_node_names, remote_graph_executor_name, require_shape_type,
        output_graph_def));
  } else if (!border_inputs.empty() || !border_outputs.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        border_inputs, border_outputs, remote_graph_executor_name,
        require_shape_type, output_graph_def));
  } else {
    *output_graph_def = input_graph_def;
  }

  return Status::OK();
}

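// Returns true if the graph is ready to be fused by
// FuseRemoteGraphByPlacedArguments: every placed input tensor must resolve
// to a node that already carries a non-empty ATTR_NODE_TYPE attribute.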
/* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_tensors) {
  for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
    const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
    if (node_def == nullptr) {
      return false;
    }
    string attr;
    const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
    if (!status.ok() || attr.empty()) {
      return false;
    }
  }
  return true;
}

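// Copies src_size bytes from src_ptr into the flat buffer of an already
// allocated tensor.  The tensor's dtype must be one of the types handled
// below and its allocation must hold at least src_size bytes.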
/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
    const void* src_ptr, const int src_size, Tensor* tensor) {
  const size_t tensor_bytes = tensor->TotalBytes();
  CHECK_GE(tensor_bytes, static_cast<size_t>(src_size))
      << tensor_bytes << ", " << src_size;
  void* dst_ptr;
  // Resolve the destination buffer for each supported dtype.
  switch (tensor->dtype()) {
    case DT_FLOAT:
      dst_ptr = tensor->flat<float>().data();
      break;
    case DT_DOUBLE:
      dst_ptr = tensor->flat<double>().data();
      break;
    case DT_INT32:
      dst_ptr = tensor->flat<int32>().data();
      break;
    case DT_UINT8:
      dst_ptr = tensor->flat<uint8>().data();
      break;
    case DT_INT16:
      dst_ptr = tensor->flat<int16>().data();
      break;
    case DT_INT8:
      dst_ptr = tensor->flat<int8>().data();
      break;
    case DT_STRING:
      dst_ptr = tensor->flat<tstring>().data();
      break;
    case DT_INT64:
      dst_ptr = tensor->flat<int64>().data();
      break;
    case DT_BOOL:
      dst_ptr = tensor->flat<bool>().data();
      break;
    case DT_QINT8:
      dst_ptr = tensor->flat<qint8>().data();
      break;
    case DT_QUINT8:
      dst_ptr = tensor->flat<quint8>().data();
      break;
    case DT_QINT32:
      dst_ptr = tensor->flat<qint32>().data();
      break;
    case DT_BFLOAT16:
      dst_ptr = tensor->flat<bfloat16>().data();
      break;
    case DT_QINT16:
      dst_ptr = tensor->flat<qint16>().data();
      break;
    case DT_QUINT16:
      dst_ptr = tensor->flat<quint16>().data();
      break;
    case DT_UINT16:
      dst_ptr = tensor->flat<uint16>().data();
      break;
    default:
      LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
      break;
  }
  CHECK_NOTNULL(dst_ptr);
  std::memcpy(dst_ptr, src_ptr, src_size);
  return Status::OK();
}

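// Returns the names of all nodes in graph_def whose op type appears in
// op_types.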
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
    const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    if (op_types.count(node_def.op()) > 0) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}

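// Returns the names of all nodes in graph_def that the given ops definitions
// can execute, i.e. nodes for which GetOpIdFor returns a valid op id for the
// node's op type and inferred output data types.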
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
    const GraphDef& graph_def,
    const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    const Status status =
        GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
    if (!status.ok()) {
      shape_vec.clear();
    }
    if (ops_definitions.GetOpIdFor(
            node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
        IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}

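// Replaces the node named by `input` (which must reference output port 0)
// with a Placeholder of the given type and shape.  An existing Placeholder
// is left untouched; if no node with that name exists, an InvalidArgument
// error is returned.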
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
    const string& input, const DataType type, const TensorShape& shape,
    GraphDef* graph_def) {
  const TensorId tid = ParseTensorName(input);
  CHECK_EQ(0, tid.second);
  const string node_name(tid.first);
  for (NodeDef& node : *graph_def->mutable_node()) {
    if (node.name() != node_name) {
      continue;
    }
    if (node.op() == "Placeholder") {
      return Status::OK();
    } else {
      NodeDef placeholder_node;
      placeholder_node.set_op("Placeholder");
      placeholder_node.set_name(node_name);
      AddNodeAttr("dtype", type, &placeholder_node);
      AddNodeAttr("shape", shape, &placeholder_node);
      // TODO(satok): Remove once we merge attributes
      AddOutputTensorShapeType({type}, {shape}, &placeholder_node);
      node.Clear();
      node = placeholder_node;
      return Status::OK();
    }
  }
  return errors::InvalidArgument(
      strings::StrCat(node_name, " not found for replacement."));
}

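// Serializes a RemoteFusedGraphNodeType (plus optional port/index and, for
// graph inputs, the executor and fused-node names) into the comma-separated
// record format parsed by FuseRemoteGraphByPlacedArguments above.  For
// example, assuming GRAPH_OUTPUT has the integer value 2 in the header's
// enum, BuildNodeTypeAttr(GRAPH_OUTPUT, /*port=*/0, /*index=*/1) yields
// "2,0,1".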
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index,
    const string& executor_name, const string& node_name) {
  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index,
                         ",", executor_name, ",", node_name);
}

/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index) {
  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
}

/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type) {
  return strings::StrCat(static_cast<int>(node_type));
}

}  // namespace tensorflow