1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <utility>
21
22 #include "tensorflow/core/common_runtime/shape_refiner.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/node_builder.h"
30 #include "tensorflow/core/public/session.h"
31 #include "tensorflow/core/public/session_options.h"
32
33 namespace tensorflow {
34 namespace {
FindNodeByName(const string & name,const Graph & graph)35 const Node* FindNodeByName(const string& name, const Graph& graph) {
36 for (const Node* node : graph.nodes()) {
37 CHECK_NOTNULL(node);
38 if (node->name() == name) {
39 return node;
40 }
41 }
42 return nullptr;
43 }
44
BuildNodeSetFromNodeNamesAndPorts(const std::vector<string> & node_names_and_ports)45 std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
46 const std::vector<string>& node_names_and_ports) {
47 std::unordered_set<string> retval;
48 for (const string& node_name_and_port : node_names_and_ports) {
49 const TensorId tid = ParseTensorName(node_name_and_port);
50 retval.emplace(tid.first);
51 }
52 return retval;
53 }
54
FindMutableNodeByName(const string & name,Graph * graph)55 Node* FindMutableNodeByName(const string& name, Graph* graph) {
56 for (Node* node : graph->nodes()) {
57 if (node != nullptr && node->name() == name) {
58 return node;
59 }
60 }
61 return nullptr;
62 }
63
FindNodeDefByName(const string & input,const GraphDef & graph_def)64 const NodeDef* FindNodeDefByName(const string& input,
65 const GraphDef& graph_def) {
66 const TensorId tid = ParseTensorName(input);
67 const string name = string(tid.first);
68 for (const NodeDef& node_def : graph_def.node()) {
69 if (node_def.name() == name) {
70 return &node_def;
71 }
72 }
73 return nullptr;
74 }
75
IsSameNodeName(const NodeDef & node_def,const string & node_name_and_port,TensorId * tid)76 bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
77 TensorId* tid) {
78 CHECK_NOTNULL(tid);
79 *tid = ParseTensorName(node_name_and_port);
80 if (node_def.name() == tid->first) {
81 return true;
82 }
83 return false;
84 }
85
ContainsSameTensorId(const string & tensor_name,const std::vector<string> & tensor_names)86 bool ContainsSameTensorId(const string& tensor_name,
87 const std::vector<string>& tensor_names) {
88 const TensorId tid0 = ParseTensorName(tensor_name);
89 for (const string& name : tensor_names) {
90 const TensorId tid1 = ParseTensorName(name);
91 if (tid0.first == tid1.first && tid0.second == tid1.second) {
92 return true;
93 }
94 }
95 return false;
96 }
97
AppendDeliminator(string * str)98 void AppendDeliminator(string* str) {
99 CHECK_NOTNULL(str);
100 if (!str->empty()) {
101 *str += ":";
102 }
103 }
104
ConvertMapToVector(const std::unordered_map<int,string> & in,std::vector<string> * out)105 void ConvertMapToVector(const std::unordered_map<int, string>& in,
106 std::vector<string>* out) {
107 CHECK_NOTNULL(out);
108 out->resize(in.size());
109 for (size_t i = 0; i < in.size(); ++i) {
110 CHECK(in.count(i) > 0);
111 out->at(i) = in.at(i);
112 }
113 }
114
DumpGraphDef(const GraphDef & graph_def)115 string DumpGraphDef(const GraphDef& graph_def) {
116 string out;
117 for (const NodeDef& node : graph_def.node()) {
118 out += strings::StrCat("node: ", node.name(), "\n input: ");
119 for (const string& input : node.input()) {
120 out += strings::StrCat(input, ", ");
121 }
122 out += "\n";
123 }
124 return out;
125 }
126
DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo & cluster)127 string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
128 string out;
129 out += "Nodes:\n";
130 for (const string& str : std::get<0>(cluster)) {
131 out += str + ", ";
132 }
133 out += "\nInput border:\n";
134 for (const string& str : std::get<1>(cluster)) {
135 out += str + ", ";
136 }
137 out += "\nOutput border:\n";
138 for (const string& str : std::get<2>(cluster)) {
139 out += str + ", ";
140 }
141 return out;
142 }
143
144 } // namespace
145
// Out-of-class definitions for the static constexpr string constants
// declared in remote_fused_graph_execute_utils.h. Pre-C++17, odr-used
// static constexpr members require a namespace-scope definition.
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;
172
ExecutorBuildRegistrar(const string & name,ExecutorBuildFunc executor_build_func)173 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
174 const string& name, ExecutorBuildFunc executor_build_func) {
175 ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
176 executor_build_registry[name] = std::move(executor_build_func);
177 }
178
179 /* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
GetExecutorBuildFunc(const string & name)180 RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
181 ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
182 if (executor_build_registry.count(name) <= 0) {
183 return nullptr;
184 }
185 return &executor_build_registry.at(name);
186 }
187
188 /* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
GetExecutorBuildRegistry()189 RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
190 static ExecutorBuildRegistry executor_builder_registry;
191 return &executor_builder_registry;
192 }
193
194 /**
195 * - DryRunInference
196 * To determine shapes of output tensors of all nodes, dryrun the graph.
197 * This function supplies memory allocation information when loading
198 * the graph. This function is used to verify shape inference and actual
199 * output shape.
200 */
DryRunInference(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names,const bool initialize_by_zero,std::vector<tensorflow::Tensor> * output_tensors)201 /* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
202 const GraphDef& graph_def,
203 const std::vector<std::pair<string, Tensor>>& input_node_info_list,
204 const std::vector<string>& output_node_names, const bool initialize_by_zero,
205 std::vector<tensorflow::Tensor>* output_tensors) {
206 // Create input tensor vector. If "initialize_by_zero" is true,
207 // input tensor fields are initialized by 0.
208 std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
209 for (const std::pair<string, Tensor>& input : input_node_info_list) {
210 CHECK(input.second.IsInitialized());
211 if (!initialize_by_zero) {
212 input_tensors.push_back({input.first, input.second});
213 continue;
214 }
215 // If input tensor is not initialized, initialize by 0-filling
216 const DataType data_type = input.second.dtype();
217 const TensorShape& shape = input.second.shape();
218 Tensor input_tensor(data_type, shape);
219 switch (data_type) {
220 case DT_INT32: {
221 auto int_tensor = input_tensor.flat<int32>();
222 int_tensor = int_tensor.constant(0);
223 break;
224 }
225 case DT_FLOAT: {
226 auto float_tensor = input_tensor.flat<float>();
227 float_tensor = float_tensor.constant(0.0f);
228 break;
229 }
230 case DT_QUINT8: {
231 auto int_tensor = input_tensor.flat<quint8>();
232 int_tensor = int_tensor.constant(0);
233 break;
234 }
235 default:
236 LOG(FATAL) << "Unsupported input type: " << data_type;
237 }
238 input_tensors.push_back({input.first, input_tensor});
239 }
240
241 // Setup session
242 CHECK(output_tensors != nullptr);
243 SessionOptions session_options;
244 session_options.env = Env::Default();
245 std::unique_ptr<Session> session =
246 std::unique_ptr<Session>(NewSession(session_options));
247 Status status = session->Create(graph_def);
248 if (!status.ok()) {
249 return status;
250 }
251
252 // Setup session arguments
253 RunOptions run_options;
254 run_options.set_trace_level(RunOptions::FULL_TRACE);
255 RunMetadata run_metadata;
256
257 // Run inference with all node as output
258 status = session->Run(run_options, input_tensors, output_node_names, {},
259 output_tensors, &run_metadata);
260 if (!status.ok()) {
261 LOG(ERROR) << "Error during inference: " << status;
262 return status;
263 }
264 return Status();
265 }
266
DryRunInferenceForAllNode(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_node_info_list,const bool initialize_by_zero,RemoteFusedGraphExecuteUtils::TensorShapeMap * tensor_shape_map)267 /* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
268 const GraphDef& graph_def,
269 const std::vector<std::pair<string, Tensor>>& input_node_info_list,
270 const bool initialize_by_zero,
271 RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
272 CHECK(tensor_shape_map != nullptr);
273 std::vector<Tensor> output_tensors;
274 output_tensors.reserve(graph_def.node_size());
275 std::vector<string> output_node_names;
276
277 Graph graph(OpRegistry::Global());
278 Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
279 if (!status.ok()) {
280 return status;
281 }
282
283 for (const Node* node : graph.nodes()) {
284 if (IsInputNode(input_node_info_list, node->name())) {
285 continue;
286 }
287 for (int i = 0; i < node->num_outputs(); ++i) {
288 output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
289 }
290 }
291
292 status = DryRunInference(graph_def, input_node_info_list, output_node_names,
293 initialize_by_zero, &output_tensors);
294 if (!status.ok()) {
295 VLOG(1) << "Failed to dryrun " << status;
296 return status;
297 }
298
299 CHECK_EQ(output_node_names.size(), output_tensors.size())
300 << output_node_names.size() << ", " << output_tensors.size();
301
302 // Append output tensor of input node in advance to create a map
303 // to avoid memory reallocation inside vector
304 for (const std::pair<string, Tensor>& input_node_info :
305 input_node_info_list) {
306 output_tensors.push_back(input_node_info.second);
307 }
308
309 for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) {
310 const string& name = output_node_names.at(i);
311 const Tensor& tensor = output_tensors.at(i);
312 EmplaceTensorShapeType(name, tensor, tensor_shape_map);
313 }
314 for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) {
315 const string& name = input_node_info_list.at(i).first;
316 const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
317 EmplaceTensorShapeType(name, tensor, tensor_shape_map);
318 }
319 CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
320 output_tensors.size());
321 return status;
322 }
323
IsInputNode(const std::vector<std::pair<string,Tensor>> & input_tensor_vector,const string & node_name)324 /* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
325 const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
326 const string& node_name) {
327 for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
328 const TensorId tid = ParseTensorName(pair.first);
329 if (node_name == tid.first) {
330 return true;
331 }
332 }
333 return false;
334 }
335
ConvertToTensorShapeMap(const std::vector<std::pair<string,Tensor>> & input_node_info_list,const std::vector<string> & output_node_names,const std::vector<tensorflow::Tensor> & output_tensors,TensorShapeMap * tensor_shape_map)336 /* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
337 const std::vector<std::pair<string, Tensor>>& input_node_info_list,
338 const std::vector<string>& output_node_names,
339 const std::vector<tensorflow::Tensor>& output_tensors,
340 TensorShapeMap* tensor_shape_map) {
341 CHECK_NE(tensor_shape_map, nullptr);
342 tensor_shape_map->clear();
343 tensor_shape_map->reserve(input_node_info_list.size() +
344 output_node_names.size());
345 const int output_node_count = output_node_names.size();
346 CHECK_EQ(output_node_count, output_tensors.size());
347 for (int i = 0; i < output_node_count; ++i) {
348 const string& node_name = output_node_names.at(i);
349 const Tensor& tensor = output_tensors.at(i);
350 EmplaceTensorShapeType(node_name, tensor, tensor_shape_map);
351 }
352 }
353
MakeTensorFromProto(const TensorProto & tensor_proto,Tensor * tensor)354 /* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
355 const TensorProto& tensor_proto, Tensor* tensor) {
356 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
357 Tensor parsed(tensor_proto.dtype());
358 if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
359 *tensor = parsed;
360 return Status::OK();
361 }
362 }
363 return errors::InvalidArgument("Cannot parse tensor from proto");
364 }
365
AddOutputTensorShapeType(const std::vector<DataType> & data_types,const std::vector<TensorShape> & shapes,NodeDef * node_def)366 /* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
367 const std::vector<DataType>& data_types,
368 const std::vector<TensorShape>& shapes, NodeDef* node_def) {
369 AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
370 AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
371 return true;
372 }
373
374 /* static */ Status
AddOutputTensorShapeTypeByTensorShapeMap(const TensorShapeMap & tensor_shape_map,NodeDef * node_def)375 RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
376 const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
377 CHECK_NE(node_def, nullptr);
378 std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
379 auto its = tensor_shape_map.equal_range(node_def->name());
380 for (auto it = its.first; it != its.second; ++it) {
381 queue.emplace(std::make_tuple(it->second.first, &it->second.second));
382 }
383 int last_port = queue.size();
384 std::vector<DataType> data_types;
385 std::vector<TensorShape> shapes;
386 while (!queue.empty()) {
387 const int port = std::get<0>(queue.top());
388 const TensorShapeType* tst = std::get<1>(queue.top());
389 CHECK_NE(tst, nullptr);
390 data_types.emplace(data_types.begin(), tst->first);
391 shapes.emplace(shapes.begin(), tst->second);
392 CHECK_EQ(last_port - 1, port);
393 last_port = port;
394 queue.pop();
395 }
396 AddOutputTensorShapeType(data_types, shapes, node_def);
397 return Status::OK();
398 }
399
GetOutputTensorShapeType(AttrSlice attrs,std::vector<DataType> * data_types,std::vector<TensorShape> * shapes)400 /* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
401 AttrSlice attrs, std::vector<DataType>* data_types,
402 std::vector<TensorShape>* shapes) {
403 Status status;
404 if (data_types != nullptr) {
405 status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
406 }
407 if (!status.ok()) {
408 return status;
409 }
410 if (shapes != nullptr) {
411 status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
412 if (status.ok() && data_types != nullptr) {
413 CHECK_EQ(data_types->size(), shapes->size());
414 }
415 }
416
417 return status;
418 }
419
GetOutputTensorShapeType(const GraphDef & graph_def,const string & name_and_port,DataType * data_type,TensorShape * shape)420 /* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
421 const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
422 TensorShape* shape) {
423 std::vector<DataType> data_types;
424 std::vector<TensorShape> shapes;
425 const TensorId tid = ParseTensorName(name_and_port);
426 const string node_name(tid.first);
427 const int port = tid.second;
428 const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
429 CHECK_NOTNULL(node_def);
430 GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
431 if (data_types.empty()) {
432 return false;
433 }
434 int data_types_size = data_types.size();
435 CHECK(data_types_size > port);
436 *data_type = data_types.at(port);
437 *shape = shapes.at(port);
438 return true;
439 }
440
PropagateShapeInference(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_node_info_list,Graph * graph,ShapeRefiner * shape_refiner)441 /* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
442 const GraphDef& graph_def,
443 const std::vector<std::pair<string, Tensor>>& input_node_info_list,
444 Graph* graph, ShapeRefiner* shape_refiner) {
445 Status status;
446 auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
447 if (!status.ok()) {
448 return;
449 }
450 CHECK_NE(node, nullptr);
451 // If we visit an input node, we use the shape provided and set the
452 // shape accordingly.
453 bool is_input_node = false;
454 for (const std::pair<string, Tensor>& input_node_info :
455 input_node_info_list) {
456 if (node->name() == input_node_info.first) {
457 shape_inference::InferenceContext* context =
458 shape_refiner->GetContext(node);
459 shape_inference::ShapeHandle handle;
460 status = context->MakeShapeFromTensorShape(
461 input_node_info.second.shape(), &handle);
462 if (!status.ok()) {
463 break;
464 }
465 status = shape_refiner->SetShape(node, 0, handle);
466 if (!status.ok()) {
467 break;
468 }
469 is_input_node = true;
470 }
471 if (!status.ok()) {
472 break;
473 }
474 }
475 // If not an input node call AddNode() that recomputes the shape.
476 if (!is_input_node && status.ok()) {
477 status = shape_refiner->AddNode(node);
478 }
479 if (!status.ok()) {
480 VLOG(1) << "Shape inference failed for node: " << node->name();
481 }
482 };
483
484 ReverseDFS(*graph, {}, visit);
485
486 return status;
487 }
488
BuildTensorShapeMapFromGraph(const Graph & graph,const ShapeRefiner & shape_refiner,TensorShapeMap * tensor_shape_map)489 /* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
490 const Graph& graph, const ShapeRefiner& shape_refiner,
491 TensorShapeMap* tensor_shape_map) {
492 for (int i = 0; i < graph.num_node_ids(); ++i) {
493 const Node* node = graph.FindNodeId(i);
494 CHECK_NE(node, nullptr);
495 for (int j = 0; j < node->num_outputs(); ++j) {
496 const int output_index = j;
497 const DataType dt = node->output_type(output_index);
498 shape_inference::InferenceContext* context =
499 shape_refiner.GetContext(node);
500 CHECK_NE(context, nullptr);
501 shape_inference::ShapeHandle shape_handle = context->output(output_index);
502 if (context->RankKnown(shape_handle)) {
503 TensorShape ts;
504 for (int k = 0; k < context->Rank(shape_handle); ++k) {
505 shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
506 CHECK(context->ValueKnown(dh));
507 ts.AddDim(context->Value(dh));
508 }
509 const string& node_name = node->name();
510 CHECK(tensor_shape_map->count(node_name) == 0);
511 tensor_shape_map->emplace(node_name,
512 std::make_pair(j, std::make_pair(dt, ts)));
513 } else {
514 return errors::InvalidArgument("Graph contains unknown shapes");
515 }
516 }
517 }
518 return Status::OK();
519 }
520
521 /* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
GetTensorShapeType(const TensorShapeMap & tensor_shape_map,const string & node_name)522 RemoteFusedGraphExecuteUtils::GetTensorShapeType(
523 const TensorShapeMap& tensor_shape_map, const string& node_name) {
524 if (node_name.find(':') != string::npos) {
525 const TensorId tid = ParseTensorName(node_name);
526 return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
527 } else {
528 return GetTensorShapeType(tensor_shape_map, node_name, 0);
529 }
530 }
531
532 /* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
GetTensorShapeType(const TensorShapeMap & tensor_shape_map,const string & node_name,const int port)533 RemoteFusedGraphExecuteUtils::GetTensorShapeType(
534 const TensorShapeMap& tensor_shape_map, const string& node_name,
535 const int port) {
536 CHECK_EQ(node_name.find(':'), string::npos);
537 if (tensor_shape_map.count(node_name) <= 0) {
538 return nullptr;
539 }
540 auto its = tensor_shape_map.equal_range(node_name);
541 for (auto it = its.first; it != its.second; ++it) {
542 if (it->second.first == port) {
543 return &it->second.second;
544 }
545 }
546 return nullptr;
547 }
548
549 /* static */ void
BuildRemoteGraphInputsAndOutputsFromProto(const RemoteFusedGraphExecuteInfo & proto,std::vector<std::pair<string,Tensor>> * inputs,std::vector<string> * outputs)550 RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
551 const RemoteFusedGraphExecuteInfo& proto,
552 std::vector<std::pair<string, Tensor>>* inputs,
553 std::vector<string>* outputs) {
554 CHECK_EQ(proto.graph_input_node_name_size(),
555 proto.default_graph_input_tensor_shape_size());
556 for (int i = 0; i < proto.graph_input_node_name_size(); ++i) {
557 inputs->emplace_back(
558 proto.graph_input_node_name(i),
559 Tensor(proto.default_graph_input_tensor_shape(i).dtype(),
560 TensorShape(proto.default_graph_input_tensor_shape(i).shape())));
561 }
562 for (const string& output_node_name : proto.graph_output_node_name()) {
563 outputs->emplace_back(output_node_name);
564 }
565 }
566
EmplaceTensorShapeType(const string & name,const Tensor & tensor,TensorShapeMap * tensor_shape_map)567 /* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
568 const string& name, const Tensor& tensor,
569 TensorShapeMap* tensor_shape_map) {
570 const TensorId tid = ParseTensorName(name);
571 CHECK_EQ(tensor_shape_map->count(name), 0);
572 tensor_shape_map->emplace(
573 string(tid.first),
574 std::make_pair(tid.second,
575 std::make_pair(tensor.dtype(), tensor.shape())));
576 }
577
BuildAndAddTensorShapes(const std::vector<std::pair<string,Tensor>> & input_tensors,const bool dry_run_inference,GraphDef * graph_def)578 /* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
579 const std::vector<std::pair<string, Tensor>>& input_tensors,
580 const bool dry_run_inference, GraphDef* graph_def) {
581 TensorShapeMap tensor_shape_map;
582 if (dry_run_inference) {
583 TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
584 /*initialize_by_zero=*/true,
585 &tensor_shape_map));
586 } else {
587 ImportGraphDefOptions opts;
588 Graph graph(OpRegistry::Global());
589 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
590 TF_RETURN_IF_ERROR(
591 ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
592 TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
593 &graph, &shape_refiner));
594 TF_RETURN_IF_ERROR(
595 BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
596 }
597
598 for (NodeDef& node_def : *graph_def->mutable_node()) {
599 TF_RETURN_IF_ERROR(
600 AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
601 }
602
603 return Status::OK();
604 }
605
// Builds the RemoteFusedGraphExecuteInfo proto describing a fused subgraph:
// copies |subgraph_def| into it, records input/output node names, and — when
// shape/type attributes are available — their default shapes and dtypes.
// |input_types|/|output_types| receive one dtype per input/output, falling
// back to DT_FLOAT when no shape/type info exists (CHECK-fails instead if
// |require_shape_type| is set).
/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    const string& executor_name, const GraphDef& subgraph_def,
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    DataTypeVector* input_types, DataTypeVector* output_types) {
  CHECK_NOTNULL(execute_info);
  CHECK_NOTNULL(input_types);
  CHECK_NOTNULL(output_types);

  execute_info->Clear();
  execute_info->set_executor_name(executor_name);

  // copy graph
  *execute_info->mutable_remote_graph() = subgraph_def;

  // Record each input's name and, when known, its dtype and default shape.
  for (const string& input : inputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);

    execute_info->add_graph_input_node_name(input);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
          *execute_info->add_default_graph_input_tensor_shape();
      tensor_shape_type.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      input_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
      // Assuming input type is float if no data provided.
      input_types->push_back(DT_FLOAT);
    }
  }

  // Same as above, for each output.
  for (const string& output : outputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);

    execute_info->add_graph_output_node_name(output);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto&
          tensor_shape_type_proto =
              *execute_info->add_default_graph_output_tensor_shape();
      tensor_shape_type_proto.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto =
          *tensor_shape_type_proto.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      output_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
      // Assuming output type is float if no data provided.
      output_types->push_back(DT_FLOAT);
    }
  }

  return Status::OK();
}
674
675 /* static */ Status
BuildRemoteFusedGraphExecuteOpNode(const string & node_name,const string & executor_name,const GraphDef & subgraph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const bool require_shape_type,Graph * graph,Node ** created_node)676 RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
677 const string& node_name, const string& executor_name,
678 const GraphDef& subgraph_def, const std::vector<string>& inputs,
679 const std::vector<string>& outputs, const bool require_shape_type,
680 Graph* graph, Node** created_node) {
681 CHECK_NOTNULL(graph);
682 CHECK_NOTNULL(created_node);
683
684 RemoteFusedGraphExecuteInfo execute_info;
685 DataTypeVector input_types;
686 DataTypeVector output_types;
687
688 TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
689 executor_name, subgraph_def, inputs, outputs, require_shape_type,
690 &execute_info, &input_types, &output_types));
691
692 std::vector<NodeBuilder::NodeOut> node_out_list;
693 for (const string& input : inputs) {
694 const TensorId tid = ParseTensorName(input);
695 Node* node = FindMutableNodeByName(string(tid.first), graph);
696 CHECK_NOTNULL(node);
697 node_out_list.emplace_back(node, tid.second);
698 }
699
700 const string execute_info_str = execute_info.SerializeAsString();
701
702 auto builder =
703 NodeBuilder(node_name, "RemoteFusedGraphExecute")
704 .Input(node_out_list)
705 .Attr("Tinputs", input_types)
706 .Attr("Toutputs", output_types)
707 .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);
708
709 TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
710 return Status::OK();
711 }
712
BuildIdentityOpNode(const string & node_name,const string & input_node_name,const int input_node_port,const DataType dt,Graph * graph,Node ** created_node)713 /* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
714 const string& node_name, const string& input_node_name,
715 const int input_node_port, const DataType dt, Graph* graph,
716 Node** created_node) {
717 Node* node = FindMutableNodeByName(input_node_name, graph);
718 CHECK_NOTNULL(node);
719 NodeBuilder::NodeOut node_out(node, input_node_port);
720
721 auto builder =
722 NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt);
723
724 TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
725 return Status::OK();
726 }
727
ClusterizeNodes(const std::unordered_set<string> & node_names,const GraphDef & graph_def,std::vector<ClusterInfo> * cluster_infos)728 /* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
729 const std::unordered_set<string>& node_names, const GraphDef& graph_def,
730 std::vector<ClusterInfo>* cluster_infos) {
731 Graph graph(OpRegistry::Global());
732 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
733 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
734 std::unordered_set<string> remaining_nodes = node_names;
735
736 while (!remaining_nodes.empty()) {
737 ClusterInfo ci;
738
739 // Determine one cluster nodes
740 std::unordered_set<const Node*> visited;
741 std::deque<const Node*> queue;
742 queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
743 while (!queue.empty()) {
744 const Node* node = queue.front();
745 CHECK_NOTNULL(node);
746 queue.pop_front();
747 const string& node_name = node->name();
748 if (node_names.count(node_name) > 0) {
749 std::get<0>(ci).emplace(node_name);
750 remaining_nodes.erase(node_name);
751 } else {
752 // Edge of subgraph. Do nothing.
753 continue;
754 }
755 for (const Node* in : node->in_nodes()) {
756 if (visited.insert(in).second) {
757 queue.push_back(in);
758 }
759 }
760 for (const Node* out : node->out_nodes()) {
761 if (visited.insert(out).second) {
762 queue.push_back(out);
763 }
764 }
765 }
766
767 // Determine one cluster border
768 std::vector<string>& border_inputs = std::get<1>(ci);
769 std::vector<string>& border_outputs = std::get<2>(ci);
770 for (const string& node_name : node_names) {
771 Node* node = FindMutableNodeByName(node_name, &graph);
772 CHECK_NOTNULL(node);
773 int input_count = 0;
774 for (const Edge* in_edge : node->in_edges()) {
775 const Node* src_node = in_edge->src();
776 const bool src_is_outside =
777 node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
778 if (src_is_outside) {
779 const string src_name =
780 strings::StrCat(src_node->name(), ":", in_edge->src_output());
781 CHECK_EQ(1, src_node->num_outputs())
782 << "output count of input border node must be one."
783 << src_node->name();
784 if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
785 border_inputs.end()) {
786 border_inputs.emplace_back(src_name);
787 }
788 } else {
789 ++input_count;
790 }
791 }
792 int node_in_edges_size = node->in_edges().size();
793 CHECK(input_count == 0 || input_count == node_in_edges_size)
794 << "Invalid input_count(" << input_count << ", "
795 << node->in_edges().size() << ") " << node_name;
796
797 for (const Edge* out_edge : node->out_edges()) {
798 const Node* dst_node = out_edge->dst();
799 CHECK_NOTNULL(dst_node);
800 const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
801 const string dst_name =
802 strings::StrCat(node->name(), ":", out_edge->src_output());
803 if (dst_is_outside) {
804 if (dst_node->IsSink()) {
805 CHECK_EQ(1, node->num_outputs())
806 << "If you want to specify output node as subgraph output node "
807 << "the output count of the node must be 1 "
808 << "because that node is replaced by identity node.";
809 const string identity_dst_name =
810 strings::StrCat(node->name(), ":", 0);
811 if (std::find(border_outputs.begin(), border_outputs.end(),
812 identity_dst_name) == border_outputs.end()) {
813 border_outputs.emplace_back(identity_dst_name);
814 }
815 } else {
816 if (std::find(border_outputs.begin(), border_outputs.end(),
817 dst_name) == border_outputs.end()) {
818 border_outputs.emplace_back(dst_name);
819 }
820 }
821 }
822 }
823 }
824 cluster_infos->emplace_back(ci);
825 VLOG(1) << DumpCluster(ci);
826 }
827 return Status::OK();
828 }
829
/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    const ClusterInfo& cluster, const GraphDef& graph_def,
    GraphDef* subgraph_def) {
  // Extracts the nodes of |cluster| from |graph_def| into |subgraph_def|:
  // every node outside the cluster and its border inputs is dropped, each
  // border input node is replaced by a Placeholder, and the result is
  // re-sorted to match the node order of the original |graph_def|.
  const std::unordered_set<string>& node_names = std::get<0>(cluster);
  const std::unordered_set<string>& border_input_names =
      BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  // Drop every node that is neither a cluster member nor a border input.
  // NOTE(review): nodes are removed while iterating graph.nodes() — this
  // assumes the Graph node iterator tolerates removal mid-iteration;
  // verify against the Graph implementation.
  for (Node* node : graph.nodes()) {
    if (node != nullptr && node_names.count(node->name()) <= 0 &&
        border_input_names.count(node->name()) <= 0 && !node->IsSource() &&
        !node->IsSink()) {
      graph.RemoveNode(node);
    }
  }
  graph.ToGraphDef(subgraph_def);

  // Turn each border input node into a Placeholder carrying the recorded
  // type/shape.
  for (const string& subgraph_input : std::get<1>(cluster)) {
    const TensorId tid = ParseTensorName(subgraph_input);
    const string subgraph_input_name(tid.first);
    const int subgraph_input_port = tid.second;
    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
    CHECK_NOTNULL(node_def);
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    // Best effort: on inference failure fall back to DT_FLOAT / scalar
    // shape below.
    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
    const DataType& dt =
        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
    const TensorShape& shape =
        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);

    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
                                                     shape, subgraph_def));
  }

  // sort subgraph_def to align order in graph_def
  std::unordered_map<string, int> name_to_id_map;
  for (int i = 0; i < graph_def.node_size(); ++i) {
    name_to_id_map.emplace(graph_def.node(i).name(), i);
  }
  std::sort(subgraph_def->mutable_node()->begin(),
            subgraph_def->mutable_node()->end(),
            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
              CHECK(name_to_id_map.count(node0.name()) > 0);
              CHECK(name_to_id_map.count(node1.name()) > 0);
              const int id0 = name_to_id_map.at(node0.name());
              const int id1 = name_to_id_map.at(node1.name());
              return id0 < id1;
            });

  VLOG(1) << DumpGraphDef(*subgraph_def);
  return Status::OK();
}
886
BuildClusterByBorder(const std::vector<string> & border_inputs,const std::vector<string> & border_outputs,const GraphDef & graph_def,ClusterInfo * cluster)887 /* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
888 const std::vector<string>& border_inputs,
889 const std::vector<string>& border_outputs, const GraphDef& graph_def,
890 ClusterInfo* cluster) {
891 Graph graph(OpRegistry::Global());
892 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
893 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
894
895 std::unordered_set<const Node*> visited;
896 std::deque<const Node*> queue;
897 for (const string& output : border_outputs) {
898 const TensorId tid = ParseTensorName(output);
899 const string output_node_name(tid.first);
900 for (const Node* node : graph.nodes()) {
901 if (output_node_name == node->name()) {
902 queue.push_back(node);
903 visited.insert(node);
904 }
905 }
906 }
907
908 std::unordered_set<const Node*> border_input_nodes;
909 // propagate visit to parent nodes until input nodes
910 while (!queue.empty()) {
911 const Node* node = queue.front();
912 queue.pop_front();
913 for (const Edge* edge : node->in_edges()) {
914 const Node* src_node = edge->src();
915 CHECK_NOTNULL(src_node);
916 const int src_port = edge->src_output();
917 bool input_found = false;
918 for (const string& input : border_inputs) {
919 const TensorId tid = ParseTensorName(input);
920 if (tid.first == src_node->name() && tid.second == src_port) {
921 input_found = true;
922 border_input_nodes.insert(src_node);
923 }
924 }
925 if (visited.insert(src_node).second) {
926 if (!input_found) {
927 queue.push_back(src_node);
928 }
929 }
930 }
931 }
932
933 for (const Node* node : visited) {
934 if (node != nullptr && !node->IsSource() && !node->IsSink() &&
935 border_input_nodes.count(node) <= 0) {
936 std::get<0>(*cluster).insert(node->name());
937 }
938 }
939 std::get<1>(*cluster) = border_inputs;
940 std::get<2>(*cluster) = border_outputs;
941 return Status::OK();
942 }
943
/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  // Replaces the subgraph described by |cluster| inside |input_graph_def|
  // with a single RemoteFusedGraphExecute node and prunes nodes no longer
  // reachable from |outputs|.
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op by fusing a specified subgraph...";

  CHECK(!remote_graph_executor_name.empty());

  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::vector<string>& border_outputs = std::get<2>(cluster);

  // Extract the cluster into a standalone GraphDef carried by the fused
  // node.
  GraphDef subgraph_def;
  TF_RETURN_IF_ERROR(
      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));

  // Build the node that executes |subgraph_def| on the remote executor.
  Node* fused_node;
  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));

  // Re-route every edge that consumed a border output tensor so that it
  // now consumes the corresponding output port of the fused node.
  for (const Node* node : graph.nodes()) {
    for (int i = 0, end = node->num_inputs(); i < end; ++i) {
      const Edge* edge = nullptr;
      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
      for (int j = 0, second_end = border_outputs.size(); j < second_end; ++j) {
        const string& output = border_outputs.at(j);
        const TensorId tid = ParseTensorName(output);
        const string output_name(tid.first);
        Node* src_node = edge->src();
        if (src_node != nullptr && src_node->name() == output_name &&
            edge->src_output() == tid.second) {
          // Source node is replaced by new fused node.
          Node* dst_node = edge->dst();
          const int dst_input = edge->dst_input();
          LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
                    << " from " << edge->src()->name();
          graph.RemoveEdge(edge);
          graph.AddEdge(fused_node, j, dst_node, dst_input);
        }
      }
    }
  }

  // Replace output nodes by identity nodes which forward outputs from
  // RemoteFusedGraphExecuteOpNode
  for (const string& output : outputs) {
    const TensorId output_tid = ParseTensorName(output);
    const string output_name(output_tid.first);
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      const TensorId subgraph_output_tid =
          ParseTensorName(border_outputs.at(i));
      const string subgraph_output_name(subgraph_output_tid.first);
      if (output_name == subgraph_output_name) {
        LOG(INFO) << "As graph output and subgraph output are same, "
                  << "the graph output node is replaced by identity node";
        Node* original_output_node = FindMutableNodeByName(output_name, &graph);
        CHECK_NOTNULL(original_output_node);
        CHECK_EQ(1, original_output_node->num_outputs())
            << "Num outputs should be 1 for " << output << ".";
        graph.RemoveNode(original_output_node);
        Node* new_node;
        // NOTE(review): the identity node is created with DT_FLOAT
        // unconditionally — confirm non-float graph outputs are not
        // expected on this path.
        TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name,
                                               remote_fused_graph_node_name, i,
                                               DT_FLOAT, &graph, &new_node));
        CHECK_NOTNULL(new_node);
      }
    }
  }

  GraphDef result_graph_def;

  graph.ToGraphDef(&result_graph_def);

  // Build a cluster spanning the whole fused graph so that everything not
  // reachable between |inputs| and |outputs| can be pruned below.
  ClusterInfo graph_cluster;
  TF_RETURN_IF_ERROR(
      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));

  // Remove unvisited nodes
  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
                                             output_graph_def));

  return Status::OK();
}
1035
FuseRemoteGraphByNodeNames(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & remote_fused_graph_node_name_prefix,const std::unordered_set<string> & subgraph_nodes,const string & remote_fused_graph_executor_name,const bool require_shape_type,GraphDef * output_graph_def)1036 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
1037 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1038 const std::vector<string>& outputs,
1039 const string& remote_fused_graph_node_name_prefix,
1040 const std::unordered_set<string>& subgraph_nodes,
1041 const string& remote_fused_graph_executor_name,
1042 const bool require_shape_type, GraphDef* output_graph_def) {
1043 std::vector<ClusterInfo> ci_vec;
1044 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
1045 subgraph_nodes, input_graph_def, &ci_vec));
1046
1047 for (size_t i = 0; i < ci_vec.size(); ++i) {
1048 const string remote_fused_graph_node_name =
1049 strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
1050 TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
1051 remote_fused_graph_node_name, ci_vec.at(i),
1052 remote_fused_graph_executor_name,
1053 require_shape_type, output_graph_def));
1054 }
1055 return Status::OK();
1056 }
1057
FuseRemoteGraphByBorder(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & remote_fused_graph_node_name,const std::vector<string> & border_inputs,const std::vector<string> & border_outputs,const string & remote_graph_executor_name,const bool require_shape_type,GraphDef * output_graph_def)1058 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
1059 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1060 const std::vector<string>& outputs,
1061 const string& remote_fused_graph_node_name,
1062 const std::vector<string>& border_inputs,
1063 const std::vector<string>& border_outputs,
1064 const string& remote_graph_executor_name, const bool require_shape_type,
1065 GraphDef* output_graph_def) {
1066 ClusterInfo cluster;
1067 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
1068 border_inputs, border_outputs, input_graph_def, &cluster));
1069
1070 return FuseCluster(
1071 input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster,
1072 remote_graph_executor_name, require_shape_type, output_graph_def);
1073 }
1074
FuseRemoteGraphByOpTypes(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & remote_fused_graph_node_name_prefix,const std::unordered_set<string> & fused_op_types,const string & remote_fused_graph_executor_name,const bool require_shape_type,GraphDef * output_graph_def)1075 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
1076 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1077 const std::vector<string>& outputs,
1078 const string& remote_fused_graph_node_name_prefix,
1079 const std::unordered_set<string>& fused_op_types,
1080 const string& remote_fused_graph_executor_name,
1081 const bool require_shape_type, GraphDef* output_graph_def) {
1082 const std::unordered_set<string> fused_nodes_filtered_by_op_types =
1083 BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);
1084
1085 return FuseRemoteGraphByNodeNames(
1086 input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
1087 fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name,
1088 require_shape_type, output_graph_def);
1089 }
1090
FuseRemoteGraphByExecutor(const GraphDef & input_graph_def,const std::vector<string> & inputs,const std::vector<string> & outputs,const string & executor_name,GraphDef * output_graph_def)1091 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
1092 const GraphDef& input_graph_def, const std::vector<string>& inputs,
1093 const std::vector<string>& outputs, const string& executor_name,
1094 GraphDef* output_graph_def) {
1095 const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
1096 if (build_func == nullptr) {
1097 return errors::InvalidArgument("Unknown executor name: " + executor_name);
1098 }
1099 std::unique_ptr<IRemoteFusedGraphExecutor> executor;
1100 TF_RETURN_IF_ERROR((*build_func)(&executor));
1101 CHECK_NOTNULL(executor.get());
1102 if (!executor->IsEnabled()) {
1103 // As this executor is not enabled, just return original graph as is.
1104 *output_graph_def = input_graph_def;
1105 return Status::OK();
1106 }
1107 return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
1108 output_graph_def);
1109 }
1110
PlaceRemoteGraphArguments(const std::vector<string> & inputs,const std::vector<string> & outputs,const std::unordered_set<string> & fused_node_names,const std::vector<string> & border_inputs,const std::vector<string> & border_outputs,const std::unordered_set<string> & fused_op_types,const string & remote_fused_graph_node_name,const string & remote_graph_executor_name,GraphDef * graph_def)1111 /* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
1112 const std::vector<string>& inputs, const std::vector<string>& outputs,
1113 const std::unordered_set<string>& fused_node_names,
1114 const std::vector<string>& border_inputs,
1115 const std::vector<string>& border_outputs,
1116 const std::unordered_set<string>& fused_op_types,
1117 const string& remote_fused_graph_node_name,
1118 const string& remote_graph_executor_name, GraphDef* graph_def) {
1119 CHECK_NOTNULL(graph_def);
1120
1121 const std::unordered_set<string> fused_nodes_filtered_by_op_types =
1122 BuildNodeMapFromOpTypes(*graph_def, fused_op_types);
1123
1124 for (NodeDef& node_def : *graph_def->mutable_node()) {
1125 string attr_str;
1126 TensorId tid;
1127 for (size_t i = 0; i < inputs.size(); ++i) {
1128 if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
1129 AppendDeliminator(&attr_str);
1130 attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i,
1131 remote_graph_executor_name,
1132 remote_fused_graph_node_name);
1133 }
1134 }
1135 for (size_t i = 0; i < outputs.size(); ++i) {
1136 if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
1137 AppendDeliminator(&attr_str);
1138 attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i);
1139 }
1140 }
1141 for (const string& fused_node_name : fused_node_names) {
1142 if (fused_node_name == node_def.name()) {
1143 AppendDeliminator(&attr_str);
1144 attr_str += BuildNodeTypeAttr(FUSED_NODE);
1145 }
1146 }
1147 for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
1148 if (fused_node_name == node_def.name()) {
1149 AppendDeliminator(&attr_str);
1150 attr_str += BuildNodeTypeAttr(FUSED_NODE);
1151 }
1152 }
1153 for (size_t i = 0; i < border_inputs.size(); ++i) {
1154 if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
1155 AppendDeliminator(&attr_str);
1156 attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i);
1157 }
1158 }
1159 for (size_t i = 0; i < border_outputs.size(); ++i) {
1160 if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
1161 AppendDeliminator(&attr_str);
1162 attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i);
1163 }
1164 }
1165 if (attr_str.empty()) {
1166 attr_str += BuildNodeTypeAttr(UNUSED);
1167 }
1168 AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
1169 }
1170 return Status::OK();
1171 }
1172
1173 /* static */ Status
FuseRemoteGraphByPlacedArguments(const GraphDef & input_graph_def,const std::vector<std::pair<string,Tensor>> & input_tensors,GraphDef * output_graph_def)1174 RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
1175 const GraphDef& input_graph_def,
1176 const std::vector<std::pair<string, Tensor>>& input_tensors,
1177 GraphDef* output_graph_def) {
1178 std::unordered_map<int, string> input_map;
1179 std::unordered_map<int, string> output_map;
1180 std::unordered_set<string> fused_node_names;
1181 std::unordered_map<int, string> border_input_map;
1182 std::unordered_map<int, string> border_output_map;
1183 string remote_graph_executor_name;
1184 string remote_fused_graph_node_name;
1185
1186 for (const NodeDef& node_def : input_graph_def.node()) {
1187 string attr_str;
1188 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
1189 std::vector<std::vector<string>> attr_strs;
1190 for (const string& str : str_util::Split(attr_str, ":")) {
1191 attr_strs.emplace_back(str_util::Split(str, ","));
1192 }
1193 if (attr_strs.empty()) {
1194 return errors::InvalidArgument("Remote graph node type not found.");
1195 }
1196 for (const std::vector<string>& attr : attr_strs) {
1197 if (attr.empty()) {
1198 return errors::InvalidArgument("Empty remote graph node type attr.");
1199 }
1200 int node_type_int;
1201 CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
1202 const RemoteFusedGraphNodeType node_type =
1203 static_cast<RemoteFusedGraphNodeType>(node_type_int);
1204 const string& name = node_def.name();
1205 int port;
1206 int index;
1207
1208 switch (node_type) {
1209 case GRAPH_INPUT:
1210 VLOG(2) << "Graph input: " << name;
1211 CHECK_EQ(5, attr.size());
1212 CHECK(strings::safe_strto32(attr.at(1), &port));
1213 CHECK(strings::safe_strto32(attr.at(2), &index));
1214 CHECK(!attr.at(3).empty());
1215 remote_graph_executor_name = attr.at(3);
1216 CHECK(!attr.at(4).empty());
1217 remote_fused_graph_node_name = attr.at(4);
1218 input_map.emplace(index, strings::StrCat(name, ":", port));
1219 if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
1220 LOG(INFO) << "Executor for " << remote_graph_executor_name
1221 << " not registered. Do not fuse.";
1222 *output_graph_def = input_graph_def;
1223 return Status::OK();
1224 }
1225 break;
1226 case GRAPH_OUTPUT:
1227 VLOG(2) << "Graph output: " << name;
1228 CHECK_EQ(3, attr.size());
1229 CHECK(strings::safe_strto32(attr.at(1), &port));
1230 CHECK(strings::safe_strto32(attr.at(2), &index));
1231 output_map.emplace(index, strings::StrCat(name, ":", port));
1232 break;
1233 case FUSED_NODE:
1234 VLOG(2) << "Fused node: " << name;
1235 CHECK_EQ(1, attr.size());
1236 fused_node_names.emplace(name);
1237 break;
1238 case BORDER_INPUT:
1239 VLOG(2) << "Border input: " << name;
1240 CHECK_EQ(3, attr.size());
1241 CHECK(strings::safe_strto32(attr.at(1), &port));
1242 CHECK(strings::safe_strto32(attr.at(2), &index));
1243 border_input_map.emplace(index, strings::StrCat(name, ":", port));
1244 break;
1245 case BORDER_OUTPUT:
1246 VLOG(2) << "Border output: " << name;
1247 CHECK_EQ(3, attr.size());
1248 CHECK(strings::safe_strto32(attr.at(1), &port));
1249 CHECK(strings::safe_strto32(attr.at(2), &index));
1250 border_output_map.emplace(index, strings::StrCat(name, ":", port));
1251 break;
1252 case UNUSED:
1253 // do nothing
1254 break;
1255 default:
1256 // unsupported value
1257 LOG(FATAL);
1258 }
1259 }
1260 }
1261 bool require_shape_type = false;
1262 std::vector<string> inputs;
1263 std::vector<string> outputs;
1264 std::vector<string> border_inputs;
1265 std::vector<string> border_outputs;
1266 ConvertMapToVector(input_map, &inputs);
1267 ConvertMapToVector(output_map, &outputs);
1268 ConvertMapToVector(border_input_map, &border_inputs);
1269 ConvertMapToVector(border_output_map, &border_outputs);
1270
1271 if (!input_tensors.empty()) {
1272 bool input_match = false;
1273 if (inputs.size() == input_tensors.size()) {
1274 for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
1275 if (!ContainsSameTensorId(input_tensor.first, inputs)) {
1276 break;
1277 }
1278 DataType data_type;
1279 TensorShape shape;
1280 if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
1281 &data_type, &shape)) {
1282 if (data_type == input_tensor.second.dtype() &&
1283 shape == input_tensor.second.shape()) {
1284 VLOG(2) << "Input matched!";
1285 // Shape type matched.
1286 input_match = true;
1287 require_shape_type = true;
1288 }
1289 } else {
1290 // Shape type not required.
1291 input_match = true;
1292 }
1293 }
1294 }
1295 if (!input_match) {
1296 // Input mismatch. Just copy original graph
1297 *output_graph_def = input_graph_def;
1298 return Status::OK();
1299 }
1300 }
1301
1302 if (!fused_node_names.empty()) {
1303 TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
1304 input_graph_def, inputs, outputs, remote_fused_graph_node_name,
1305 fused_node_names, remote_graph_executor_name, require_shape_type,
1306 output_graph_def));
1307 } else if (!border_inputs.empty() || !border_outputs.empty()) {
1308 TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
1309 input_graph_def, inputs, outputs, remote_fused_graph_node_name,
1310 border_inputs, border_outputs, remote_graph_executor_name,
1311 require_shape_type, output_graph_def));
1312 } else {
1313 *output_graph_def = input_graph_def;
1314 }
1315
1316 return Status::OK();
1317 }
1318
IsFuseReady(const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_tensors)1319 /* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
1320 const GraphDef& graph_def,
1321 const std::vector<std::pair<string, Tensor>>& input_tensors) {
1322 for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
1323 const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
1324 if (node_def == nullptr) {
1325 return false;
1326 }
1327 string attr;
1328 const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
1329 if (!status.ok() || attr.empty()) {
1330 return false;
1331 }
1332 }
1333 return true;
1334 }
1335
CopyByteArrayToTensor(const void * src_ptr,const int src_size,Tensor * tensor)1336 /* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
1337 const void* src_ptr, const int src_size, Tensor* tensor) {
1338 int tensor_TotalBytes = tensor->TotalBytes();
1339 CHECK(tensor_TotalBytes >= src_size) << tensor_TotalBytes << ", " << src_size;
1340 void* dst_ptr;
1341 switch (tensor->dtype()) {
1342 case DT_FLOAT:
1343 dst_ptr = tensor->flat<float>().data();
1344 break;
1345 case DT_DOUBLE:
1346 dst_ptr = tensor->flat<double>().data();
1347 break;
1348 case DT_INT32:
1349 dst_ptr = tensor->flat<int32>().data();
1350 break;
1351 case DT_UINT8:
1352 dst_ptr = tensor->flat<uint8>().data();
1353 break;
1354 case DT_INT16:
1355 dst_ptr = tensor->flat<int16>().data();
1356 break;
1357 case DT_INT8:
1358 dst_ptr = tensor->flat<int8>().data();
1359 break;
1360 case DT_STRING:
1361 dst_ptr = tensor->flat<tstring>().data();
1362 break;
1363 case DT_INT64:
1364 dst_ptr = tensor->flat<int64>().data();
1365 break;
1366 case DT_BOOL:
1367 dst_ptr = tensor->flat<bool>().data();
1368 break;
1369 case DT_QINT8:
1370 dst_ptr = tensor->flat<qint8>().data();
1371 break;
1372 case DT_QUINT8:
1373 dst_ptr = tensor->flat<quint8>().data();
1374 break;
1375 case DT_QINT32:
1376 dst_ptr = tensor->flat<qint32>().data();
1377 break;
1378 case DT_BFLOAT16:
1379 dst_ptr = tensor->flat<bfloat16>().data();
1380 break;
1381 case DT_QINT16:
1382 dst_ptr = tensor->flat<qint16>().data();
1383 break;
1384 case DT_QUINT16:
1385 dst_ptr = tensor->flat<quint16>().data();
1386 break;
1387 case DT_UINT16:
1388 dst_ptr = tensor->flat<uint16>().data();
1389 break;
1390 default:
1391 LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
1392 break;
1393 }
1394 CHECK_NOTNULL(dst_ptr);
1395 std::memcpy(dst_ptr, src_ptr, src_size);
1396 return Status::OK();
1397 }
1398
1399 /* static */ std::unordered_set<string>
BuildNodeMapFromOpTypes(const GraphDef & graph_def,const std::unordered_set<string> & op_types)1400 RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
1401 const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
1402 std::unordered_set<string> retval;
1403 for (const NodeDef& node_def : graph_def.node()) {
1404 if (op_types.count(node_def.op()) > 0) {
1405 retval.emplace(node_def.name());
1406 }
1407 }
1408 return retval;
1409 }
1410
1411 /* static */ std::unordered_set<string>
BuildNodeMapFromOpsDefinitions(const GraphDef & graph_def,const IRemoteFusedGraphOpsDefinitions & ops_definitions)1412 RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
1413 const GraphDef& graph_def,
1414 const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
1415 std::unordered_set<string> retval;
1416 for (const NodeDef& node_def : graph_def.node()) {
1417 std::vector<DataType> dt_vec;
1418 std::vector<TensorShape> shape_vec;
1419 const Status status =
1420 GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
1421 if (!status.ok()) {
1422 shape_vec.clear();
1423 }
1424 if (ops_definitions.GetOpIdFor(
1425 node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
1426 IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
1427 retval.emplace(node_def.name());
1428 }
1429 }
1430 return retval;
1431 }
1432
ReplaceInputNodeByPlaceHolder(const string & input,const DataType type,const TensorShape & shape,GraphDef * graph_def)1433 /* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
1434 const string& input, const DataType type, const TensorShape& shape,
1435 GraphDef* graph_def) {
1436 const TensorId tid = ParseTensorName(input);
1437 CHECK_EQ(0, tid.second);
1438 const string node_name(tid.first);
1439 for (NodeDef& node : *graph_def->mutable_node()) {
1440 if (node.name() != node_name) {
1441 continue;
1442 }
1443 if (node.op() == "Placeholder") {
1444 return Status::OK();
1445 } else {
1446 NodeDef placeholder_node;
1447 placeholder_node.set_op("Placeholder");
1448 placeholder_node.set_name(node_name);
1449 AddNodeAttr("dtype", type, &placeholder_node);
1450 AddNodeAttr("shape", shape, &placeholder_node);
1451 // TODO(satok): Remove once we merge attributes
1452 AddOutputTensorShapeType({type}, {shape}, &placeholder_node);
1453 node.Clear();
1454 node = placeholder_node;
1455 return Status::OK();
1456 }
1457 }
1458 return errors::InvalidArgument(
1459 strings::StrCat(node_name, " not found for replacement."));
1460 }
1461
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphNodeType node_type, const int port, const int index,
    const string& executor_name, const string& node_name) {
  // Serialized as "type,port,index,executor,node"; parsed back by
  // FuseRemoteGraphByPlacedArguments.
  const int node_type_id = static_cast<int>(node_type);
  return strings::StrCat(node_type_id, ",", port, ",", index, ",",
                         executor_name, ",", node_name);
}
1468
BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type,const int port,const int index)1469 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
1470 const RemoteFusedGraphNodeType node_type, const int port, const int index) {
1471 return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
1472 }
1473
BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type)1474 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
1475 const RemoteFusedGraphNodeType node_type) {
1476 return strings::StrCat(static_cast<int>(node_type));
1477 }
1478
1479 } // namespace tensorflow
1480