1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <utility>
21
22 #include "tensorflow/core/common_runtime/shape_refiner.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/node_builder.h"
30 #include "tensorflow/core/public/session.h"
31 #include "tensorflow/core/public/session_options.h"
32
33 namespace tensorflow {
34 namespace {
FindNodeByName(const string & name,const Graph & graph)35 const Node* FindNodeByName(const string& name, const Graph& graph) {
36 for (const Node* node : graph.nodes()) {
37 CHECK_NOTNULL(node);
38 if (node->name() == name) {
39 return node;
40 }
41 }
42 return nullptr;
43 }
44
BuildNodeSetFromNodeNamesAndPorts(const std::vector<string> & node_names_and_ports)45 std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
46 const std::vector<string>& node_names_and_ports) {
47 std::unordered_set<string> retval;
48 for (const string& node_name_and_port : node_names_and_ports) {
49 const TensorId tid = ParseTensorName(node_name_and_port);
50 retval.emplace(tid.first);
51 }
52 return retval;
53 }
54
FindMutableNodeByName(const string & name,Graph * graph)55 Node* FindMutableNodeByName(const string& name, Graph* graph) {
56 for (Node* node : graph->nodes()) {
57 if (node != nullptr && node->name() == name) {
58 return node;
59 }
60 }
61 return nullptr;
62 }
63
FindNodeDefByName(const string & input,const GraphDef & graph_def)64 const NodeDef* FindNodeDefByName(const string& input,
65 const GraphDef& graph_def) {
66 const TensorId tid = ParseTensorName(input);
67 const string name = string(tid.first);
68 for (const NodeDef& node_def : graph_def.node()) {
69 if (node_def.name() == name) {
70 return &node_def;
71 }
72 }
73 return nullptr;
74 }
75
IsSameNodeName(const NodeDef & node_def,const string & node_name_and_port,TensorId * tid)76 bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
77 TensorId* tid) {
78 CHECK_NOTNULL(tid);
79 *tid = ParseTensorName(node_name_and_port);
80 if (node_def.name() == tid->first) {
81 return true;
82 }
83 return false;
84 }
85
ContainsSameTensorId(const string & tensor_name,const std::vector<string> & tensor_names)86 bool ContainsSameTensorId(const string& tensor_name,
87 const std::vector<string>& tensor_names) {
88 const TensorId tid0 = ParseTensorName(tensor_name);
89 for (const string& name : tensor_names) {
90 const TensorId tid1 = ParseTensorName(name);
91 if (tid0.first == tid1.first && tid0.second == tid1.second) {
92 return true;
93 }
94 }
95 return false;
96 }
97
AppendDeliminator(string * str)98 void AppendDeliminator(string* str) {
99 CHECK_NOTNULL(str);
100 if (!str->empty()) {
101 *str += ":";
102 }
103 }
104
ConvertMapToVector(const std::unordered_map<int,string> & in,std::vector<string> * out)105 void ConvertMapToVector(const std::unordered_map<int, string>& in,
106 std::vector<string>* out) {
107 CHECK_NOTNULL(out);
108 out->resize(in.size());
109 for (size_t i = 0; i < in.size(); ++i) {
110 CHECK(in.count(i) > 0);
111 out->at(i) = in.at(i);
112 }
113 }
114
DumpGraphDef(const GraphDef & graph_def)115 string DumpGraphDef(const GraphDef& graph_def) {
116 string out;
117 for (const NodeDef& node : graph_def.node()) {
118 out += strings::StrCat("node: ", node.name(), "\n input: ");
119 for (const string& input : node.input()) {
120 out += strings::StrCat(input, ", ");
121 }
122 out += "\n";
123 }
124 return out;
125 }
126
DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo & cluster)127 string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
128 string out;
129 out += "Nodes:\n";
130 for (const string& str : std::get<0>(cluster)) {
131 out += str + ", ";
132 }
133 out += "\nInput border:\n";
134 for (const string& str : std::get<1>(cluster)) {
135 out += str + ", ";
136 }
137 out += "\nOutput border:\n";
138 for (const string& str : std::get<2>(cluster)) {
139 out += str + ", ";
140 }
141 return out;
142 }
143
144 } // namespace
145
// Out-of-line definitions for the static constexpr attribute/argument name
// constants declared in the header.  Required (pre-C++17) so the constants
// can be odr-used, e.g. bound to a const reference.
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;
172
ExecutorBuildRegistrar(const string & name,ExecutorBuildFunc executor_build_func)173 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
174 const string& name, ExecutorBuildFunc executor_build_func) {
175 ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
176 executor_build_registry[name] = std::move(executor_build_func);
177 }
178
179 /* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
GetExecutorBuildFunc(const string & name)180 RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
181 ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
182 if (executor_build_registry.count(name) <= 0) {
183 return nullptr;
184 }
185 return &executor_build_registry.at(name);
186 }
187
188 /* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
GetExecutorBuildRegistry()189 RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
190 static ExecutorBuildRegistry executor_builder_registry;
191 return &executor_builder_registry;
192 }
193
194 /**
195 * - DryRunInference
196 * To determine shapes of output tensors of all nodes, dryrun the graph.
197 * This function supplies memory allocation information when loading
198 * the graph. This function is used to verify shape inference and actual
199 * output shape.
200 */
DryRunInference(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names,const bool initialize_by_zero,std::vector<tensorflow::Tensor> * output_tensors)201 /* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
202 const GraphDef& graph_def,
203 const std::vector<std::pair<string, Tensor>>& input_node_info_list,
204 const std::vector<string>& output_node_names, const bool initialize_by_zero,
205 std::vector<tensorflow::Tensor>* output_tensors) {
206 // Create input tensor vector. If "initialize_by_zero" is true,
207 // input tensor fields are initialized by 0.
208 std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
209 for (const std::pair<string, Tensor>& input : input_node_info_list) {
210 CHECK(input.second.IsInitialized());
211 if (!initialize_by_zero) {
212 input_tensors.push_back({input.first, input.second});
213 continue;
214 }
215 // If input tensor is not initialized, initialize by 0-filling
216 const DataType data_type = input.second.dtype();
217 const TensorShape& shape = input.second.shape();
218 Tensor input_tensor(data_type, shape);
219 switch (data_type) {
220 case DT_INT32: {
221 auto int_tensor = input_tensor.flat<int32>();
222 int_tensor = int_tensor.constant(0);
223 break;
224 }
225 case DT_FLOAT: {
226 auto float_tensor = input_tensor.flat<float>();
227 float_tensor = float_tensor.constant(0.0f);
228 break;
229 }
230 case DT_QUINT8: {
231 auto int_tensor = input_tensor.flat<quint8>();
232 int_tensor = int_tensor.constant(0);
233 break;
234 }
235 default:
236 LOG(FATAL) << "Unsupported input type: " << data_type;
237 }
238 input_tensors.push_back({input.first, input_tensor});
239 }
240
241 // Setup session
242 CHECK(output_tensors != nullptr);
243 SessionOptions session_options;
244 session_options.env = Env::Default();
245 std::unique_ptr<Session> session =
246 std::unique_ptr<Session>(NewSession(session_options));
247 Status status = session->Create(graph_def);
248 if (!status.ok()) {
249 return status;
250 }
251
252 // Setup session arguments
253 RunOptions run_options;
254 run_options.set_trace_level(RunOptions::FULL_TRACE);
255 RunMetadata run_metadata;
256
257 // Run inference with all node as output
258 status = session->Run(run_options, input_tensors, output_node_names, {},
259 output_tensors, &run_metadata);
260 if (!status.ok()) {
261 LOG(ERROR) << "Error during inference: " << status;
262 return status;
263 }
264 return Status();
265 }
266
// Runs the graph once (via DryRunInference) fetching every output port of
// every non-input node, then records the dtype and shape of each produced
// tensor — plus those of the supplied input tensors — in |tensor_shape_map|.
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const bool initialize_by_zero,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  CHECK(tensor_shape_map != nullptr);
  std::vector<Tensor> output_tensors;
  output_tensors.reserve(graph_def.node_size());
  std::vector<string> output_node_names;

  Graph graph(OpRegistry::Global());
  Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
  if (!status.ok()) {
    return status;
  }

  // Fetch every output port ("name:i") of every node that is not one of the
  // provided inputs.
  for (const Node* node : graph.nodes()) {
    if (IsInputNode(input_node_info_list, node->name())) {
      continue;
    }
    for (int i = 0; i < node->num_outputs(); ++i) {
      output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
    }
  }

  status = DryRunInference(graph_def, input_node_info_list, output_node_names,
                           initialize_by_zero, &output_tensors);
  if (!status.ok()) {
    VLOG(1) << "Failed to dryrun " << status;
    return status;
  }

  CHECK_EQ(output_node_names.size(), output_tensors.size())
      << output_node_names.size() << ", " << output_tensors.size();

  // Append output tensor of input node in advance to create a map
  // to avoid memory reallocation inside vector
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    output_tensors.push_back(input_node_info.second);
  }

  // Record shape/type for each fetched output tensor...
  for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) {
    const string& name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  // ...and for each supplied input tensor, which was appended after the
  // fetched outputs above (hence the output_node_names.size() offset).
  for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) {
    const string& name = input_node_info_list.at(i).first;
    const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
           output_tensors.size());
  return status;
}
323
IsInputNode(const std::vector<std::pair<string,Tensor>> & input_tensor_vector,const string & node_name)324 /* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
325 const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
326 const string& node_name) {
327 for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
328 const TensorId tid = ParseTensorName(pair.first);
329 if (node_name == tid.first) {
330 return true;
331 }
332 }
333 return false;
334 }
335
ConvertToTensorShapeMap(const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names,const std::vector<tensorflow::Tensor> & output_tensors,TensorShapeMap * tensor_shape_map)336 /* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
337 const std::vector<std::pair<string, Tensor>>& input_node_info_list,
338 const std::vector<string>& output_node_names,
339 const std::vector<tensorflow::Tensor>& output_tensors,
340 TensorShapeMap* tensor_shape_map) {
341 CHECK_NE(tensor_shape_map, nullptr);
342 tensor_shape_map->clear();
343 tensor_shape_map->reserve(input_node_info_list.size() +
344 output_node_names.size());
345 const int output_node_count = output_node_names.size();
346 CHECK_EQ(output_node_count, output_tensors.size());
347 for (int i = 0; i < output_node_count; ++i) {
348 const string& node_name = output_node_names.at(i);
349 const Tensor& tensor = output_tensors.at(i);
350 EmplaceTensorShapeType(node_name, tensor, tensor_shape_map);
351 }
352 }
353
MakeTensorFromProto(const TensorProto & tensor_proto,Tensor * tensor)354 /* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
355 const TensorProto& tensor_proto, Tensor* tensor) {
356 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
357 Tensor parsed(tensor_proto.dtype());
358 if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
359 *tensor = parsed;
360 return Status::OK();
361 }
362 }
363 return errors::InvalidArgument("Cannot parse tensor from proto");
364 }
365
AddOutputTensorShapeType(const std::vector<DataType> & data_types,const std::vector<TensorShape> & shapes,NodeDef * node_def)366 /* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
367 const std::vector<DataType>& data_types,
368 const std::vector<TensorShape>& shapes, NodeDef* node_def) {
369 AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
370 AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
371 return true;
372 }
373
374 /* static */ Status
AddOutputTensorShapeTypeByTensorShapeMap(const TensorShapeMap & tensor_shape_map,NodeDef * node_def)375 RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
376 const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
377 CHECK_NE(node_def, nullptr);
378 std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
379 auto its = tensor_shape_map.equal_range(node_def->name());
380 for (auto it = its.first; it != its.second; ++it) {
381 queue.emplace(std::make_tuple(it->second.first, &it->second.second));
382 }
383 int last_port = queue.size();
384 std::vector<DataType> data_types;
385 std::vector<TensorShape> shapes;
386 while (!queue.empty()) {
387 const int port = std::get<0>(queue.top());
388 const TensorShapeType* tst = std::get<1>(queue.top());
389 CHECK_NE(tst, nullptr);
390 data_types.emplace(data_types.begin(), tst->first);
391 shapes.emplace(shapes.begin(), tst->second);
392 CHECK_EQ(last_port - 1, port);
393 last_port = port;
394 queue.pop();
395 }
396 AddOutputTensorShapeType(data_types, shapes, node_def);
397 return Status::OK();
398 }
399
GetOutputTensorShapeType(AttrSlice attrs,std::vector<DataType> * data_types,std::vector<TensorShape> * shapes)400 /* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
401 AttrSlice attrs, std::vector<DataType>* data_types,
402 std::vector<TensorShape>* shapes) {
403 Status status;
404 if (data_types != nullptr) {
405 status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
406 }
407 if (!status.ok()) {
408 return status;
409 }
410 if (shapes != nullptr) {
411 status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
412 if (status.ok() && data_types != nullptr) {
413 CHECK_EQ(data_types->size(), shapes->size());
414 }
415 }
416
417 return status;
418 }
419
GetOutputTensorShapeType(const GraphDef & graph_def,const string & name_and_port,DataType * data_type,TensorShape * shape)420 /* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
421 const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
422 TensorShape* shape) {
423 std::vector<DataType> data_types;
424 std::vector<TensorShape> shapes;
425 const TensorId tid = ParseTensorName(name_and_port);
426 const string node_name(tid.first);
427 const int port = tid.second;
428 const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
429 CHECK_NOTNULL(node_def);
430 GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
431 if (data_types.empty()) {
432 return false;
433 }
434 CHECK(data_types.size() > port);
435 *data_type = data_types.at(port);
436 *shape = shapes.at(port);
437 return true;
438 }
439
// Runs shape inference over |graph| via ReverseDFS.  Nodes named in
// |input_node_info_list| get their output-0 shape pinned to the shape of the
// supplied tensor; every other node has its shape computed by the refiner.
// The first failure is latched in |status| and short-circuits later visits.
// NOTE(review): |graph_def| is not referenced in this body — presumably kept
// for interface symmetry; confirm before relying on it.
/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  Status status;
  auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
    if (!status.ok()) {
      // A previous node already failed; skip the rest of the traversal.
      return;
    }
    CHECK_NE(node, nullptr);
    // If we visit an input node, we use the shape provided and set the
    // shape accordingly.
    bool is_input_node = false;
    for (const std::pair<string, Tensor>& input_node_info :
         input_node_info_list) {
      if (node->name() == input_node_info.first) {
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(node);
        shape_inference::ShapeHandle handle;
        status = context->MakeShapeFromTensorShape(
            input_node_info.second.shape(), &handle);
        if (!status.ok()) {
          break;
        }
        // Input nodes are assumed to produce their tensor on port 0.
        status = shape_refiner->SetShape(node, 0, handle);
        if (!status.ok()) {
          break;
        }
        is_input_node = true;
      }
      if (!status.ok()) {
        break;
      }
    }
    // If not an input node call AddNode() that recomputes the shape.
    if (!is_input_node && status.ok()) {
      status = shape_refiner->AddNode(node);
    }
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node: " << node->name();
    }
  };

  ReverseDFS(*graph, {}, visit);

  return status;
}
487
BuildTensorShapeMapFromGraph(const Graph & graph,const ShapeRefiner & shape_refiner,TensorShapeMap * tensor_shape_map)488 /* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
489 const Graph& graph, const ShapeRefiner& shape_refiner,
490 TensorShapeMap* tensor_shape_map) {
491 for (int i = 0; i < graph.num_node_ids(); ++i) {
492 const Node* node = graph.FindNodeId(i);
493 CHECK_NE(node, nullptr);
494 for (int j = 0; j < node->num_outputs(); ++j) {
495 const int output_index = j;
496 const DataType dt = node->output_type(output_index);
497 shape_inference::InferenceContext* context =
498 shape_refiner.GetContext(node);
499 CHECK_NE(context, nullptr);
500 shape_inference::ShapeHandle shape_handle = context->output(output_index);
501 if (context->RankKnown(shape_handle)) {
502 TensorShape ts;
503 for (int k = 0; k < context->Rank(shape_handle); ++k) {
504 shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
505 CHECK(context->ValueKnown(dh));
506 ts.AddDim(context->Value(dh));
507 }
508 const string& node_name = node->name();
509 CHECK(tensor_shape_map->count(node_name) == 0);
510 tensor_shape_map->emplace(node_name,
511 std::make_pair(j, std::make_pair(dt, ts)));
512 } else {
513 return errors::InvalidArgument("Graph contains unknow shapes");
514 }
515 }
516 }
517 return Status::OK();
518 }
519
520 /* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
GetTensorShapeType(const TensorShapeMap & tensor_shape_map,const string & node_name)521 RemoteFusedGraphExecuteUtils::GetTensorShapeType(
522 const TensorShapeMap& tensor_shape_map, const string& node_name) {
523 if (node_name.find(':') != string::npos) {
524 const TensorId tid = ParseTensorName(node_name);
525 return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
526 } else {
527 return GetTensorShapeType(tensor_shape_map, node_name, 0);
528 }
529 }
530
531 /* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
GetTensorShapeType(const TensorShapeMap & tensor_shape_map,const string & node_name,const int port)532 RemoteFusedGraphExecuteUtils::GetTensorShapeType(
533 const TensorShapeMap& tensor_shape_map, const string& node_name,
534 const int port) {
535 CHECK_EQ(node_name.find(':'), string::npos);
536 if (tensor_shape_map.count(node_name) <= 0) {
537 return nullptr;
538 }
539 auto its = tensor_shape_map.equal_range(node_name);
540 for (auto it = its.first; it != its.second; ++it) {
541 if (it->second.first == port) {
542 return &it->second.second;
543 }
544 }
545 return nullptr;
546 }
547
548 /* static */ void
BuildRemoteGraphInputsAndOutputsFromProto(const RemoteFusedGraphExecuteInfo & proto,std::vector<std::pair<string,Tensor>> * inputs,std::vector<string> * outputs)549 RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
550 const RemoteFusedGraphExecuteInfo& proto,
551 std::vector<std::pair<string, Tensor>>* inputs,
552 std::vector<string>* outputs) {
553 CHECK_EQ(proto.graph_input_node_name_size(),
554 proto.default_graph_input_tensor_shape_size());
555 for (int i = 0; i < proto.graph_input_node_name_size(); ++i) {
556 inputs->emplace_back(
557 proto.graph_input_node_name(i),
558 Tensor(proto.default_graph_input_tensor_shape(i).dtype(),
559 TensorShape(proto.default_graph_input_tensor_shape(i).shape())));
560 }
561 for (const string& output_node_name : proto.graph_output_node_name()) {
562 outputs->emplace_back(output_node_name);
563 }
564 }
565
// Records |tensor|'s dtype and shape under the node part of |name|, keyed by
// the port parsed from |name| (0 when no ":port" suffix is present).
/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
    const string& name, const Tensor& tensor,
    TensorShapeMap* tensor_shape_map) {
  const TensorId tid = ParseTensorName(name);
  // NOTE(review): this duplicate check uses the full "name:port" string while
  // entries are keyed by the bare node name, so it is vacuous whenever |name|
  // carries a port suffix.  Confirm whether it should check tid.first —
  // though callers such as DryRunInferenceForAllNode legitimately insert
  // multiple ports per node, so a tid.first check would fire for them.
  CHECK_EQ(tensor_shape_map->count(name), 0);
  tensor_shape_map->emplace(
      string(tid.first),
      std::make_pair(tid.second,
                     std::make_pair(tensor.dtype(), tensor.shape())));
}
576
// Determines the dtype/shape of every node output — either by actually
// running the graph with zero-filled inputs (|dry_run_inference|) or by
// static shape inference — then stores the results as attributes on each
// NodeDef in |graph_def|.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    const bool dry_run_inference, GraphDef* graph_def) {
  TensorShapeMap tensor_shape_map;
  if (dry_run_inference) {
    TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
                                                 /*initialize_by_zero=*/true,
                                                 &tensor_shape_map));
  } else {
    // Static path: import the graph with a refiner attached, pin the input
    // shapes, propagate, then harvest all refined shapes.
    ImportGraphDefOptions opts;
    Graph graph(OpRegistry::Global());
    ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    TF_RETURN_IF_ERROR(
        ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
                                               &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(
        BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
  }

  // Attach the collected shapes/types to every node as attributes.
  for (NodeDef& node_def : *graph_def->mutable_node()) {
    TF_RETURN_IF_ERROR(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }

  return Status::OK();
}
604
// Populates |execute_info| from the subgraph plus its border input/output
// tensor names, copying per-tensor dtype/shape when the subgraph's NodeDefs
// carry shape/type attributes.  Also emits parallel |input_types| /
// |output_types| vectors for the wrapper op.  When no shape/type attribute is
// found for a tensor, DT_FLOAT is assumed — unless |require_shape_type|, in
// which case this CHECK-fails.
/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    const string& executor_name, const GraphDef& subgraph_def,
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    DataTypeVector* input_types, DataTypeVector* output_types) {
  CHECK_NOTNULL(execute_info);
  CHECK_NOTNULL(input_types);
  CHECK_NOTNULL(output_types);

  execute_info->Clear();
  execute_info->set_executor_name(executor_name);

  // copy graph
  *execute_info->mutable_remote_graph() = subgraph_def;

  for (const string& input : inputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);

    execute_info->add_graph_input_node_name(input);
    if (has_shapetype) {
      // Record the default dtype/shape so a remote executor can allocate the
      // input tensor without further inference.
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
          *execute_info->add_default_graph_input_tensor_shape();
      tensor_shape_type.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      input_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
      // Assuming input type is float if no data provided.
      input_types->push_back(DT_FLOAT);
    }
  }

  // Same treatment for the border outputs.
  for (const string& output : outputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);

    execute_info->add_graph_output_node_name(output);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto&
          tensor_shape_type_proto =
              *execute_info->add_default_graph_output_tensor_shape();
      tensor_shape_type_proto.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto =
          *tensor_shape_type_proto.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      output_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
      // Assuming output type is float if no data provided.
      output_types->push_back(DT_FLOAT);
    }
  }

  return Status::OK();
}
673
674 /* static */ Status
BuildRemoteFusedGraphExecuteOpNode(const string & node_name,const string & executor_name,const GraphDef & subgraph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const bool require_shape_type,Graph * graph,Node ** created_node)675 RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
676 const string& node_name, const string& executor_name,
677 const GraphDef& subgraph_def, const std::vector<string>& inputs,
678 const std::vector<string>& outputs, const bool require_shape_type,
679 Graph* graph, Node** created_node) {
680 CHECK_NOTNULL(graph);
681 CHECK_NOTNULL(created_node);
682
683 RemoteFusedGraphExecuteInfo execute_info;
684 DataTypeVector input_types;
685 DataTypeVector output_types;
686
687 TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
688 executor_name, subgraph_def, inputs, outputs, require_shape_type,
689 &execute_info, &input_types, &output_types));
690
691 std::vector<NodeBuilder::NodeOut> node_out_list;
692 for (const string& input : inputs) {
693 const TensorId tid = ParseTensorName(input);
694 Node* node = FindMutableNodeByName(string(tid.first), graph);
695 CHECK_NOTNULL(node);
696 node_out_list.emplace_back(node, tid.second);
697 }
698
699 const string execute_info_str = execute_info.SerializeAsString();
700
701 auto builder =
702 NodeBuilder(node_name, "RemoteFusedGraphExecute")
703 .Input(node_out_list)
704 .Attr("Tinputs", input_types)
705 .Attr("Toutputs", output_types)
706 .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);
707
708 TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
709 return Status::OK();
710 }
711
BuildIdentityOpNode(const string & node_name,const string & input_node_name,const int input_node_port,const DataType dt,Graph * graph,Node ** created_node)712 /* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
713 const string& node_name, const string& input_node_name,
714 const int input_node_port, const DataType dt, Graph* graph,
715 Node** created_node) {
716 Node* node = FindMutableNodeByName(input_node_name, graph);
717 CHECK_NOTNULL(node);
718 NodeBuilder::NodeOut node_out(node, input_node_port);
719
720 auto builder =
721 NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt);
722
723 TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
724 return Status::OK();
725 }
726
// Partitions |node_names| into connected clusters within |graph_def| and
// computes each cluster's input/output border tensors.  Each ClusterInfo is
// (member node names, border input "name:port"s, border output "name:port"s).
/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    const std::unordered_set<string>& node_names, const GraphDef& graph_def,
    std::vector<ClusterInfo>* cluster_infos) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
  std::unordered_set<string> remaining_nodes = node_names;

  // Each outer iteration peels off one connected cluster of the requested
  // nodes, seeded from an arbitrary remaining node.
  while (!remaining_nodes.empty()) {
    ClusterInfo ci;

    // Determine one cluster nodes
    std::unordered_set<const Node*> visited;
    std::deque<const Node*> queue;
    queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
    // Undirected BFS over both in- and out-edges; traversal stops at nodes
    // outside |node_names| (the cluster edge).
    while (!queue.empty()) {
      const Node* node = queue.front();
      CHECK_NOTNULL(node);
      queue.pop_front();
      const string& node_name = node->name();
      if (node_names.count(node_name) > 0) {
        std::get<0>(ci).emplace(node_name);
        remaining_nodes.erase(node_name);
      } else {
        // Edge of subgraph. Do nothing.
        continue;
      }
      for (const Node* in : node->in_nodes()) {
        if (visited.insert(in).second) {
          queue.push_back(in);
        }
      }
      for (const Node* out : node->out_nodes()) {
        if (visited.insert(out).second) {
          queue.push_back(out);
        }
      }
    }

    // Determine one cluster border
    std::vector<string>& border_inputs = std::get<1>(ci);
    std::vector<string>& border_outputs = std::get<2>(ci);
    for (const string& node_name : node_names) {
      Node* node = FindMutableNodeByName(node_name, &graph);
      CHECK_NOTNULL(node);
      int input_count = 0;
      for (const Edge* in_edge : node->in_edges()) {
        const Node* src_node = in_edge->src();
        const bool src_is_outside =
            node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
        if (src_is_outside) {
          // A tensor crossing into the cluster becomes a border input
          // (deduplicated by its "name:port" string).
          const string src_name =
              strings::StrCat(src_node->name(), ":", in_edge->src_output());
          CHECK_EQ(1, src_node->num_outputs())
              << "output count of input border node must be one."
              << src_node->name();
          if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
              border_inputs.end()) {
            border_inputs.emplace_back(src_name);
          }
        } else {
          ++input_count;
        }
      }
      // A member node must have either all or none of its inputs inside the
      // cluster; mixed fan-in is not supported.
      CHECK(input_count == 0 || input_count == node->in_edges().size())
          << "Invalid input_count(" << input_count << ", "
          << node->in_edges().size() << ") " << node_name;

      for (const Edge* out_edge : node->out_edges()) {
        const Node* dst_node = out_edge->dst();
        CHECK_NOTNULL(dst_node);
        const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
        const string dst_name =
            strings::StrCat(node->name(), ":", out_edge->src_output());
        if (dst_is_outside) {
          if (dst_node->IsSink()) {
            // An edge to _SINK marks a graph output; the node is later
            // replaced by an identity, so it must have exactly one output and
            // the border entry is pinned to port 0.
            CHECK_EQ(1, node->num_outputs())
                << "If you want to specify output node as subgraph output node "
                << "the output count of the node must be 1 "
                << "because that node is replaced by identity node.";
            const string identity_dst_name =
                strings::StrCat(node->name(), ":", 0);
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          identity_dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(identity_dst_name);
            }
          } else {
            // A tensor leaving the cluster becomes a border output
            // (deduplicated by its "name:port" string).
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(dst_name);
            }
          }
        }
      }
    }
    cluster_infos->emplace_back(ci);
    VLOG(1) << DumpCluster(ci);
  }
  return Status::OK();
}
827
// Builds a standalone GraphDef for the given cluster: keeps only the
// cluster's nodes and its border input producers, replaces each border
// input by a Placeholder, and restores the original node ordering.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    const ClusterInfo& cluster, const GraphDef& graph_def,
    GraphDef* subgraph_def) {
  const std::unordered_set<string>& node_names = std::get<0>(cluster);
  const std::unordered_set<string>& border_input_names =
      BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  // Drop every node that is neither in the cluster nor a border input
  // producer (the synthetic source/sink nodes are kept).
  // NOTE(review): removes the current node while iterating graph.nodes();
  // assumes Graph's iterator tolerates RemoveNode mid-iteration -- confirm.
  for (Node* node : graph.nodes()) {
    if (node != nullptr && node_names.count(node->name()) <= 0 &&
        border_input_names.count(node->name()) <= 0 && !node->IsSource() &&
        !node->IsSink()) {
      graph.RemoveNode(node);
    }
  }
  graph.ToGraphDef(subgraph_def);

  // Replace each border input node by a Placeholder carrying the dtype and
  // shape recorded in the original graph_def, so the subgraph can be fed.
  for (const string& subgraph_input : std::get<1>(cluster)) {
    const TensorId tid = ParseTensorName(subgraph_input);
    const string subgraph_input_name(tid.first);
    const int subgraph_input_port = tid.second;
    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
    CHECK_NOTNULL(node_def);
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    // Shape/type attributes are optional; fall back to DT_FLOAT / scalar
    // shape when they were never recorded on the node.
    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
    const DataType& dt =
        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
    const TensorShape& shape =
        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);

    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
                                                     shape, subgraph_def));
  }

  // sort subgraph_def to align order in graph_def
  std::unordered_map<string, int> name_to_id_map;
  for (int i = 0; i < graph_def.node_size(); ++i) {
    name_to_id_map.emplace(graph_def.node(i).name(), i);
  }
  std::sort(subgraph_def->mutable_node()->begin(),
            subgraph_def->mutable_node()->end(),
            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
              CHECK(name_to_id_map.count(node0.name()) > 0);
              CHECK(name_to_id_map.count(node1.name()) > 0);
              const int id0 = name_to_id_map.at(node0.name());
              const int id1 = name_to_id_map.at(node1.name());
              return id0 < id1;
            });

  VLOG(1) << DumpGraphDef(*subgraph_def);
  return Status::OK();
}
884
// Builds a ClusterInfo from explicit border tensors: the cluster contains
// every node reachable backwards from border_outputs, stopping at (and
// excluding) the nodes that produce border_inputs.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs, const GraphDef& graph_def,
    ClusterInfo* cluster) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  // Seed the reverse traversal with the nodes producing border outputs.
  std::unordered_set<const Node*> visited;
  std::deque<const Node*> queue;
  for (const string& output : border_outputs) {
    const TensorId tid = ParseTensorName(output);
    const string output_node_name(tid.first);
    for (const Node* node : graph.nodes()) {
      if (output_node_name == node->name()) {
        queue.push_back(node);
        visited.insert(node);
      }
    }
  }

  std::unordered_set<const Node*> border_input_nodes;
  // propagate visit to parent nodes until input nodes
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop_front();
    for (const Edge* edge : node->in_edges()) {
      const Node* src_node = edge->src();
      CHECK_NOTNULL(src_node);
      const int src_port = edge->src_output();
      bool input_found = false;
      // A producer matching a border input tensor (name and port)
      // terminates the traversal along this path.
      for (const string& input : border_inputs) {
        const TensorId tid = ParseTensorName(input);
        if (tid.first == src_node->name() && tid.second == src_port) {
          input_found = true;
          border_input_nodes.insert(src_node);
        }
      }
      if (visited.insert(src_node).second) {
        if (!input_found) {
          queue.push_back(src_node);
        }
      }
    }
  }

  // The cluster's node set is everything visited except the border input
  // producers and the synthetic source/sink nodes.
  for (const Node* node : visited) {
    if (node != nullptr && !node->IsSource() && !node->IsSink() &&
        border_input_nodes.count(node) <= 0) {
      std::get<0>(*cluster).insert(node->name());
    }
  }
  std::get<1>(*cluster) = border_inputs;
  std::get<2>(*cluster) = border_outputs;
  return Status::OK();
}
941
// Fuses the given cluster of input_graph_def into a single
// RemoteFusedGraphExecuteOp node, rewires the rest of the graph to consume
// that node's outputs, and prunes nodes made unreachable by the fusion.
/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op by fusing a specified subgraph...";

  CHECK(!remote_graph_executor_name.empty());

  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::vector<string>& border_outputs = std::get<2>(cluster);

  // Extract the cluster as a standalone GraphDef to embed in the fused node.
  GraphDef subgraph_def;
  TF_RETURN_IF_ERROR(
      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));

  // Add the RemoteFusedGraphExecuteOp node carrying subgraph_def, wired to
  // the border input tensors.
  Node* fused_node;
  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));

  // Redirect every edge that consumed a border output tensor so it now
  // consumes the corresponding output port (j) of the fused node.
  for (const Node* node : graph.nodes()) {
    for (int i = 0; i < node->num_inputs(); ++i) {
      const Edge* edge = nullptr;
      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
      for (int j = 0; j < border_outputs.size(); ++j) {
        const string& output = border_outputs.at(j);
        const TensorId tid = ParseTensorName(output);
        const string output_name(tid.first);
        Node* src_node = edge->src();
        if (src_node != nullptr && src_node->name() == output_name &&
            edge->src_output() == tid.second) {
          // Source node is replaced by new fused node.
          Node* dst_node = edge->dst();
          const int dst_input = edge->dst_input();
          LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
                    << " from " << edge->src()->name();
          graph.RemoveEdge(edge);
          graph.AddEdge(fused_node, j, dst_node, dst_input);
        }
      }
    }
  }

  // Replace output nodes by identity nodes which forward outputs from
  // RemoteFusedGraphExecuteOpNode
  for (const string& output : outputs) {
    const TensorId output_tid = ParseTensorName(output);
    const string output_name(output_tid.first);
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      const TensorId subgraph_output_tid =
          ParseTensorName(border_outputs.at(i));
      const string subgraph_output_name(subgraph_output_tid.first);
      if (output_name == subgraph_output_name) {
        LOG(INFO) << "As graph output and subgraph output are same, "
                  << "the graph output node is replaced by identity node";
        Node* original_output_node = FindMutableNodeByName(output_name, &graph);
        CHECK_NOTNULL(original_output_node);
        CHECK_EQ(1, original_output_node->num_outputs())
            << "Num outputs should be 1 for " << output << ".";
        graph.RemoveNode(original_output_node);
        Node* new_node;
        // NOTE(review): the identity node is always built with DT_FLOAT,
        // regardless of the original output's dtype -- confirm non-float
        // graph outputs are not expected on this path.
        TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name,
                                               remote_fused_graph_node_name, i,
                                               DT_FLOAT, &graph, &new_node));
        CHECK_NOTNULL(new_node);
      }
    }
  }

  GraphDef result_graph_def;

  graph.ToGraphDef(&result_graph_def);

  // Re-derive the cluster between the overall graph inputs and outputs so
  // nodes made unreachable by the fusion can be pruned below.
  ClusterInfo graph_cluster;
  TF_RETURN_IF_ERROR(
      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));

  // Remove unvisited nodes
  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
                                             output_graph_def));

  return Status::OK();
}
1033
FuseRemoteGraphByNodeNames(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & remote_fused_graph_node_name_prefix,const std::unordered_set<string> & subgraph_nodes,const string & remote_fused_graph_executor_name,const bool require_shape_type,GraphDef * output_graph_def)1034 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
1035 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1036 const std::vector<string>& outputs,
1037 const string& remote_fused_graph_node_name_prefix,
1038 const std::unordered_set<string>& subgraph_nodes,
1039 const string& remote_fused_graph_executor_name,
1040 const bool require_shape_type, GraphDef* output_graph_def) {
1041 std::vector<ClusterInfo> ci_vec;
1042 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
1043 subgraph_nodes, input_graph_def, &ci_vec));
1044
1045 for (size_t i = 0; i < ci_vec.size(); ++i) {
1046 const string remote_fused_graph_node_name =
1047 strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
1048 TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
1049 remote_fused_graph_node_name, ci_vec.at(i),
1050 remote_fused_graph_executor_name,
1051 require_shape_type, output_graph_def));
1052 }
1053 return Status::OK();
1054 }
1055
FuseRemoteGraphByBorder(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & remote_fused_graph_node_name,const std::vector<string> & border_inputs,const std::vector<string> & border_outputs,const string & remote_graph_executor_name,const bool require_shape_type,GraphDef * output_graph_def)1056 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
1057 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1058 const std::vector<string>& outputs,
1059 const string& remote_fused_graph_node_name,
1060 const std::vector<string>& border_inputs,
1061 const std::vector<string>& border_outputs,
1062 const string& remote_graph_executor_name, const bool require_shape_type,
1063 GraphDef* output_graph_def) {
1064 ClusterInfo cluster;
1065 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
1066 border_inputs, border_outputs, input_graph_def, &cluster));
1067
1068 return FuseCluster(
1069 input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster,
1070 remote_graph_executor_name, require_shape_type, output_graph_def);
1071 }
1072
FuseRemoteGraphByOpTypes(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & remote_fused_graph_node_name_prefix,const std::unordered_set<string> & fused_op_types,const string & remote_fused_graph_executor_name,const bool require_shape_type,GraphDef * output_graph_def)1073 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
1074 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1075 const std::vector<string>& outputs,
1076 const string& remote_fused_graph_node_name_prefix,
1077 const std::unordered_set<string>& fused_op_types,
1078 const string& remote_fused_graph_executor_name,
1079 const bool require_shape_type, GraphDef* output_graph_def) {
1080 const std::unordered_set<string> fused_nodes_filtered_by_op_types =
1081 BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);
1082
1083 return FuseRemoteGraphByNodeNames(
1084 input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
1085 fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name,
1086 require_shape_type, output_graph_def);
1087 }
1088
FuseRemoteGraphByExecutor(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & executor_name,GraphDef * output_graph_def)1089 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
1090 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1091 const std::vector<string>& outputs, const string& executor_name,
1092 GraphDef* output_graph_def) {
1093 const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
1094 if (build_func == nullptr) {
1095 return errors::InvalidArgument("Unknown executor name: " + executor_name);
1096 }
1097 std::unique_ptr<IRemoteFusedGraphExecutor> executor;
1098 TF_RETURN_IF_ERROR((*build_func)(&executor));
1099 CHECK_NOTNULL(executor.get());
1100 if (!executor->IsEnabled()) {
1101 // As this executor is not enabled, just return original graph as is.
1102 *output_graph_def = input_graph_def;
1103 return Status::OK();
1104 }
1105 return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
1106 output_graph_def);
1107 }
1108
PlaceRemoteGraphArguments(const std::vector<string> & inputs,const std::vector<string> & outputs,const std::unordered_set<string> & fused_node_names,const std::vector<string> & border_inputs,const std::vector<string> & border_outputs,const std::unordered_set<string> & fused_op_types,const string & remote_fused_graph_node_name,const string & remote_graph_executor_name,GraphDef * graph_def)1109 /* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
1110 const std::vector<string>& inputs, const std::vector<string>& outputs,
1111 const std::unordered_set<string>& fused_node_names,
1112 const std::vector<string>& border_inputs,
1113 const std::vector<string>& border_outputs,
1114 const std::unordered_set<string>& fused_op_types,
1115 const string& remote_fused_graph_node_name,
1116 const string& remote_graph_executor_name, GraphDef* graph_def) {
1117 CHECK_NOTNULL(graph_def);
1118
1119 const std::unordered_set<string> fused_nodes_filtered_by_op_types =
1120 BuildNodeMapFromOpTypes(*graph_def, fused_op_types);
1121
1122 for (NodeDef& node_def : *graph_def->mutable_node()) {
1123 string attr_str;
1124 TensorId tid;
1125 for (size_t i = 0; i < inputs.size(); ++i) {
1126 if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
1127 AppendDeliminator(&attr_str);
1128 attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i,
1129 remote_graph_executor_name,
1130 remote_fused_graph_node_name);
1131 }
1132 }
1133 for (size_t i = 0; i < outputs.size(); ++i) {
1134 if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
1135 AppendDeliminator(&attr_str);
1136 attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i);
1137 }
1138 }
1139 for (const string& fused_node_name : fused_node_names) {
1140 if (fused_node_name == node_def.name()) {
1141 AppendDeliminator(&attr_str);
1142 attr_str += BuildNodeTypeAttr(FUSED_NODE);
1143 }
1144 }
1145 for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
1146 if (fused_node_name == node_def.name()) {
1147 AppendDeliminator(&attr_str);
1148 attr_str += BuildNodeTypeAttr(FUSED_NODE);
1149 }
1150 }
1151 for (size_t i = 0; i < border_inputs.size(); ++i) {
1152 if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
1153 AppendDeliminator(&attr_str);
1154 attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i);
1155 }
1156 }
1157 for (size_t i = 0; i < border_outputs.size(); ++i) {
1158 if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
1159 AppendDeliminator(&attr_str);
1160 attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i);
1161 }
1162 }
1163 if (attr_str.empty()) {
1164 attr_str += BuildNodeTypeAttr(UNUSED);
1165 }
1166 AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
1167 }
1168 return Status::OK();
1169 }
1170
// Fuses a graph whose nodes were annotated by PlaceRemoteGraphArguments:
// reads ATTR_NODE_TYPE from every node to reconstruct the fuse arguments
// (inputs, outputs, fused node names, borders, executor), optionally
// validates the provided input tensors against recorded shapes/types, then
// performs the fusion.  Returns the graph unchanged when the executor is
// unregistered, the inputs do not match, or nothing was marked for fusion.
/* static */ Status
RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
    const GraphDef& input_graph_def,
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    GraphDef* output_graph_def) {
  // index -> "node:port" maps preserve the argument ordering recorded in
  // the placed attributes.
  std::unordered_map<int, string> input_map;
  std::unordered_map<int, string> output_map;
  std::unordered_set<string> fused_node_names;
  std::unordered_map<int, string> border_input_map;
  std::unordered_map<int, string> border_output_map;
  string remote_graph_executor_name;
  string remote_fused_graph_node_name;

  for (const NodeDef& node_def : input_graph_def.node()) {
    string attr_str;
    TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
    // The attribute is a ':'-separated list of entries; each entry is a
    // ','-separated tuple whose first element is the node type id.
    std::vector<std::vector<string>> attr_strs;
    for (const string& str : str_util::Split(attr_str, ":")) {
      attr_strs.emplace_back(str_util::Split(str, ","));
    }
    if (attr_strs.empty()) {
      return errors::InvalidArgument("Remote graph node type not found.");
    }
    for (const std::vector<string>& attr : attr_strs) {
      if (attr.empty()) {
        return errors::InvalidArgument("Empty remote graph node type attr.");
      }
      int node_type_int;
      CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
      const RemoteFusedGraphNodeType node_type =
          static_cast<RemoteFusedGraphNodeType>(node_type_int);
      const string& name = node_def.name();
      int port;
      int index;

      switch (node_type) {
        case GRAPH_INPUT:
          // Entry layout: type, port, index, executor name, fused node name.
          VLOG(2) << "Graph input: " << name;
          CHECK_EQ(5, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          CHECK(!attr.at(3).empty());
          remote_graph_executor_name = attr.at(3);
          CHECK(!attr.at(4).empty());
          remote_fused_graph_node_name = attr.at(4);
          input_map.emplace(index, strings::StrCat(name, ":", port));
          // Without a registered executor, fusing cannot proceed; hand the
          // graph back unchanged.
          if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
            LOG(INFO) << "Executor for " << remote_graph_executor_name
                      << " not registered. Do not fuse.";
            *output_graph_def = input_graph_def;
            return Status::OK();
          }
          break;
        case GRAPH_OUTPUT:
          // Entry layout: type, port, index.
          VLOG(2) << "Graph output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case FUSED_NODE:
          // Entry layout: type only.
          VLOG(2) << "Fused node: " << name;
          CHECK_EQ(1, attr.size());
          fused_node_names.emplace(name);
          break;
        case BORDER_INPUT:
          VLOG(2) << "Border input: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_input_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case BORDER_OUTPUT:
          VLOG(2) << "Border output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case UNUSED:
          // do nothing
          break;
        default:
          // unsupported value
          LOG(FATAL);
      }
    }
  }
  bool require_shape_type = false;
  // Flatten the index-keyed maps into vectors ordered by index.
  std::vector<string> inputs;
  std::vector<string> outputs;
  std::vector<string> border_inputs;
  std::vector<string> border_outputs;
  ConvertMapToVector(input_map, &inputs);
  ConvertMapToVector(output_map, &outputs);
  ConvertMapToVector(border_input_map, &border_inputs);
  ConvertMapToVector(border_output_map, &border_outputs);

  if (!input_tensors.empty()) {
    bool input_match = false;
    if (inputs.size() == input_tensors.size()) {
      // NOTE(review): a tensor-id mismatch only breaks out of this loop;
      // input_match keeps whatever value an earlier iteration set, and the
      // last matching tensor decides the flags -- confirm this "any match"
      // behavior is intended rather than "all match".
      for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
        if (!ContainsSameTensorId(input_tensor.first, inputs)) {
          break;
        }
        DataType data_type;
        TensorShape shape;
        if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
                                     &data_type, &shape)) {
          if (data_type == input_tensor.second.dtype() &&
              shape == input_tensor.second.shape()) {
            VLOG(2) << "Input matched!";
            // Shape type matched.
            input_match = true;
            require_shape_type = true;
          }
        } else {
          // Shape type not required.
          input_match = true;
        }
      }
    }
    if (!input_match) {
      // Input mismatch. Just copy original graph
      *output_graph_def = input_graph_def;
      return Status::OK();
    }
  }

  // Prefer explicit fused node names; otherwise fall back to border-based
  // fusing; with neither, pass the graph through unchanged.
  if (!fused_node_names.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        fused_node_names, remote_graph_executor_name, require_shape_type,
        output_graph_def));
  } else if (!border_inputs.empty() || !border_outputs.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        border_inputs, border_outputs, remote_graph_executor_name,
        require_shape_type, output_graph_def));
  } else {
    *output_graph_def = input_graph_def;
  }

  return Status::OK();
}
1316
IsFuseReady(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_tensors)1317 /* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
1318 const GraphDef& graph_def,
1319 const std::vector<std::pair<string, Tensor>>& input_tensors) {
1320 for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
1321 const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
1322 if (node_def == nullptr) {
1323 return false;
1324 }
1325 string attr;
1326 const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
1327 if (!status.ok() || attr.empty()) {
1328 return false;
1329 }
1330 }
1331 return true;
1332 }
1333
// Copies src_size raw bytes from src_ptr into the tensor's flat buffer.
// The tensor must already be allocated with at least src_size bytes;
// unsupported dtypes abort the process via LOG(FATAL).
/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
    const void* src_ptr, const int src_size, Tensor* tensor) {
  CHECK(tensor->TotalBytes() >= src_size)
      << tensor->TotalBytes() << ", " << src_size;
  void* dst_ptr;
  // Resolve the destination buffer address for the tensor's dtype.
  switch (tensor->dtype()) {
    case DT_FLOAT:
      dst_ptr = tensor->flat<float>().data();
      break;
    case DT_DOUBLE:
      dst_ptr = tensor->flat<double>().data();
      break;
    case DT_INT32:
      dst_ptr = tensor->flat<int32>().data();
      break;
    case DT_UINT8:
      dst_ptr = tensor->flat<uint8>().data();
      break;
    case DT_INT16:
      dst_ptr = tensor->flat<int16>().data();
      break;
    case DT_INT8:
      dst_ptr = tensor->flat<int8>().data();
      break;
    case DT_STRING:
      dst_ptr = tensor->flat<string>().data();
      break;
    case DT_INT64:
      dst_ptr = tensor->flat<int64>().data();
      break;
    case DT_BOOL:
      dst_ptr = tensor->flat<bool>().data();
      break;
    case DT_QINT8:
      dst_ptr = tensor->flat<qint8>().data();
      break;
    case DT_QUINT8:
      dst_ptr = tensor->flat<quint8>().data();
      break;
    case DT_QINT32:
      dst_ptr = tensor->flat<qint32>().data();
      break;
    case DT_BFLOAT16:
      dst_ptr = tensor->flat<bfloat16>().data();
      break;
    case DT_QINT16:
      dst_ptr = tensor->flat<qint16>().data();
      break;
    case DT_QUINT16:
      dst_ptr = tensor->flat<quint16>().data();
      break;
    case DT_UINT16:
      dst_ptr = tensor->flat<uint16>().data();
      break;
    default:
      LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
      break;
  }
  CHECK_NOTNULL(dst_ptr);
  // NOTE(review): DT_STRING elements are raw-copied with memcpy like the
  // numeric types; this is only safe if callers never pass string tensors
  // whose elements own heap storage -- confirm.
  std::memcpy(dst_ptr, src_ptr, src_size);
  return Status::OK();
}
1396
1397 /* static */ std::unordered_set<string>
BuildNodeMapFromOpTypes(const GraphDef & graph_def,const std::unordered_set<string> & op_types)1398 RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
1399 const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
1400 std::unordered_set<string> retval;
1401 for (const NodeDef& node_def : graph_def.node()) {
1402 if (op_types.count(node_def.op()) > 0) {
1403 retval.emplace(node_def.name());
1404 }
1405 }
1406 return retval;
1407 }
1408
// Returns the names of all nodes in graph_def whose op (paired with its
// inferred output data types) is supported by the given ops definitions.
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
    const GraphDef& graph_def,
    const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    const Status status =
        GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
    if (!status.ok()) {
      // NOTE(review): only shape_vec is reset on failure; dt_vec may keep
      // partially filled data types that then feed GetOpIdFor below --
      // confirm this asymmetry is intended.
      shape_vec.clear();
    }
    if (ops_definitions.GetOpIdFor(
            node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
        IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}
1430
ReplaceInputNodeByPlaceHolder(const string & input,const DataType type,const TensorShape & shape,GraphDef * graph_def)1431 /* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
1432 const string& input, const DataType type, const TensorShape& shape,
1433 GraphDef* graph_def) {
1434 const TensorId tid = ParseTensorName(input);
1435 CHECK_EQ(0, tid.second);
1436 const string node_name(tid.first);
1437 for (NodeDef& node : *graph_def->mutable_node()) {
1438 if (node.name() != node_name) {
1439 continue;
1440 }
1441 if (node.op() == "Placeholder") {
1442 return Status::OK();
1443 } else {
1444 NodeDef placeholder_node;
1445 placeholder_node.set_op("Placeholder");
1446 placeholder_node.set_name(node_name);
1447 AddNodeAttr("dtype", type, &placeholder_node);
1448 AddNodeAttr("shape", shape, &placeholder_node);
1449 // TODO(satok): Remove once we merge attributes
1450 AddOutputTensorShapeType({type}, {shape}, &placeholder_node);
1451 node.Clear();
1452 node = placeholder_node;
1453 return Status::OK();
1454 }
1455 }
1456 return errors::InvalidArgument(
1457 strings::StrCat(node_name, " not found for replacement."));
1458 }
1459
// Serializes a node-type entry as
// "<type>,<port>,<index>,<executor_name>,<node_name>".
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index,
    const string& executor_name, const string& node_name) {
  const int node_type_int = static_cast<int>(node_type);
  return strings::StrCat(node_type_int, ",", port, ",", index, ",",
                         executor_name, ",", node_name);
}
1466
BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type,const int port,const int index)1467 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
1468 const RemoteFusedGraphNodeType node_type, const int port, const int index) {
1469 return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
1470 }
1471
BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type)1472 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
1473 const RemoteFusedGraphNodeType node_type) {
1474 return strings::StrCat(static_cast<int>(node_type));
1475 }
1476
1477 } // namespace tensorflow
1478