/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"

#include <queue>
#include <set>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {
namespace segment {
namespace {
using absl::StrAppend;
using absl::StrAppendFormat;
using absl::StrCat;
using absl::StrJoin;
// A simple graph representation to mirror Graph. This structure
// helps save memory since the segmenter modifies the graph in place, avoiding
// the need to create a copy of the graph. It is composed of edges and nodes.
// Nodes keep pointers to the original TF nodes.
class SimpleNode;
class SimpleGraph;
class SimpleEdge {
 public:
  SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
             int dst_port, bool is_control = false)
      : id_(id),
        src_(src),
        src_port_(src_port),
        dst_(dst),
        dst_port_(dst_port),
        control_(is_control) {}
  ~SimpleEdge() {}

  SimpleNode* src() const { return src_; }
  SimpleNode* dst() const { return dst_; }
  int src_output() const { return src_port_; }
  int dst_input() const { return dst_port_; }
  int id() const { return id_; }
  bool IsControlEdge() const { return control_; }

 private:
  int id_;
  SimpleNode* src_;
  int src_port_;
  SimpleNode* dst_;
  int dst_port_;
  bool control_;
};

class SimpleNode {
 public:
  SimpleNode(const Node* node, const int id);

  const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
  const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }

  std::vector<SimpleNode*> in_nodes() const {
    std::vector<SimpleNode*> res;
    res.reserve(in_edges_.size());
    for (const auto e : in_edges_) {
      if (e) res.push_back(e->src());
    }
    return res;
  }

  std::vector<SimpleNode*> out_nodes() const {
    std::vector<SimpleNode*> res;
    res.reserve(out_edges_.size());
    for (const auto e : out_edges_) {
      if (e) res.push_back(e->dst());
    }
    return res;
  }

  const string& name() const { return node_->name(); }
  const Node* tf_node() const { return node_; }
  int id() const { return id_; }

 private:
  const Node* node_;
  std::vector<SimpleEdge*> in_edges_;
  std::vector<SimpleEdge*> out_edges_;
  int id_;

  friend class SimpleGraph;
};

class SimpleGraph {
 public:
  explicit SimpleGraph(const Graph* g);
  ~SimpleGraph();

  void AddControlEdge(SimpleNode* src, SimpleNode* dst);
  void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
  void RemoveEdge(const SimpleEdge*);
  SimpleNode* FindNodeId(int node_id) {
    if (node_id < 0 || node_id >= static_cast<int>(nodes_.size())) {
      return nullptr;
    }
    return nodes_[node_id];
  }
  int num_node_ids() const { return nodes_.size(); }
  const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; }
  const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; }

 private:
  const Graph* g_;
  std::vector<SimpleNode*> nodes_;
  std::vector<SimpleEdge*> edges_;
  // free_edge_ids_ and free_node_ids_ contain freed indices.
  std::set<int> free_edge_ids_;
  std::set<int> free_node_ids_;
};

SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) {
  if (node_) {
    in_edges_.reserve(node_->in_edges().size());
    out_edges_.reserve(node_->out_edges().size());
  }
}

SimpleGraph::SimpleGraph(const Graph* g) : g_(g) {
  int n_nodes = g_->num_node_ids();
  nodes_.resize(n_nodes, nullptr);
  nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
  nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
  int n_edges = g->num_edge_ids();
  edges_.resize(n_edges, nullptr);
  for (int i = 2; i < n_nodes; i++) {
    const auto n = g->FindNodeId(i);
    if (n) {
      nodes_[i] = new SimpleNode(n, i);
    } else {
      free_node_ids_.insert(i);
    }
  }
  for (int i = 0; i < n_edges; i++) {
    const auto e = g->FindEdgeId(i);
    if (e) {
      const auto tfsrc = e->src();
      const auto tfdst = e->dst();
      bool is_control = e->IsControlEdge();
      auto src = nodes_[tfsrc->id()];
      auto dst = nodes_[tfdst->id()];
      auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
                                 is_control);
      edges_[i] = edge;
      src->out_edges_.push_back(edge);
      dst->in_edges_.push_back(edge);
    } else {
      free_edge_ids_.insert(i);
    }
  }
}

void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
                          int in_port) {
  int i = edges_.size();
  if (!free_edge_ids_.empty()) {
    auto it = free_edge_ids_.begin();
    i = *it;
    free_edge_ids_.erase(it);
  } else {
    edges_.push_back(nullptr);
  }
  bool is_control = (out_port == Graph::kControlSlot);
  is_control |= (in_port == Graph::kControlSlot);
  auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
  edges_[i] = edge;
  src->out_edges_.push_back(edge);
  dst->in_edges_.push_back(edge);
}

void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
  AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
}

void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
  auto src = edge->src();
  auto dst = edge->dst();
  for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
    if (*it == edge) {
      src->out_edges_.erase(it);
      break;
    }
  }
  for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
    if (*it == edge) {
      dst->in_edges_.erase(it);
      break;
    }
  }
}

SimpleGraph::~SimpleGraph() {
  for (auto x : nodes_) delete x;
  for (auto x : edges_) delete x;
}

// Define comparison functions for std::set with pointer keys so that behavior
// is deterministic. When using std::set with pointer key types, the items are
// sorted by pointer address, which is non-deterministic. This can cause issues
// for INT8 mode because the graph is converted twice and non-determinism may
// cause a mismatch between the calibration tables of the conversions.
struct SimpleEdgePtrCompare {
  bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const {
    return lhs->id() < rhs->id();
  }
};
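
// For illustration: keyed by SimpleEdgePtrCompare, a set such as
//   std::set<const SimpleEdge*, SimpleEdgePtrCompare> edges;
// iterates edges in increasing id order, independent of where the edge
// objects happen to be allocated.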

// Copied from TF ReverseDFS, which only works for Graph.
void StableDFS(const SimpleGraph& g, bool reverse,
               const std::vector<const SimpleNode*>& start,
               const std::function<bool(const SimpleNode*)>& enter,
               const std::function<bool(const SimpleNode*)>& leave) {
  // Stack of work to do.
  struct Work {
    const SimpleNode* node;
    bool leave;  // Are we entering or leaving n?
  };
  std::vector<Work> stack(start.size());
  for (int i = 0; i < start.size(); ++i) {
    stack[i] = Work{start[i], false};
  }

  auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
                           : [](const SimpleNode* n) { return n->out_nodes(); };
  std::vector<bool> visited(g.num_node_ids(), false);
  while (!stack.empty()) {
    Work w = stack.back();
    stack.pop_back();

    auto n = w.node;
    if (w.leave) {
      if (leave && !leave(n)) return;
      continue;
    }

    if (visited[n->id()]) continue;
    visited[n->id()] = true;
    if (enter && !enter(n)) return;

    // Arrange to call leave(n) when all done with descendants.
    if (leave) stack.push_back(Work{n, true});

    auto nodes = get_nodes(n);
    std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
    std::sort(nodes_sorted.begin(), nodes_sorted.end(),
              [](const SimpleNode* lhs, const SimpleNode* rhs) {
                return lhs->name() < rhs->name();
              });
    for (const SimpleNode* node : nodes_sorted) {
      if (!visited[node->id()]) {
        stack.push_back(Work{node, false});
      }
    }
  }
}
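
// A minimal usage sketch (mirroring how SegmentGraph() below builds its
// visiting order): collect nodes in a stable leave-order starting from the
// source node.
//
//   std::vector<const SimpleNode*> order;
//   StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
//             /*enter=*/nullptr, [&order](const SimpleNode* n) {
//               order.push_back(n);
//               return true;
//             });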

bool CanContractEdge(const SimpleEdge* edge,
                     const std::unique_ptr<SimpleGraph>& graph) {
  const auto src = edge->src();
  const auto dst = edge->dst();

  // Can't contract edge if doing so would cause a cycle in the
  // graph. So, if there is a directed path from 'src' to 'dst', other
  // than 'edge' (or any other direct edge from 'src' to 'dst'), then
  // combining 'src' and 'dst' will cause a cycle along that path.
  //
  // In practice, to avoid modifying the graph and to take advantage
  // of existing graph functions, we perform an equivalent.
  // 1. Get all nodes incoming to 'dst', excluding 'src'
  // 2. Reverse DFS from those nodes
  // 3. If reverse DFS reaches 'src' then we have a cycle
  //
  // TODO(aaroey): there are several problems with the current approach:
  // 1. src->dst->src, this is not detected but it should be;
  // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
  //    is detected but it should not be.
  //
  // Note that it's fine if dst connects back to src indirectly (i.e. through
  // a path with length > 1 that consists of intermediate nodes other than
  // src). While loops are one example.
  //
  // The goal is to make sure that the TRT subgraph:
  // 1. has no loops (i.e. is a DAG), and
  // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
  //    in the subgraph), then all paths from X to Y are in the subgraph.
  //
  // To achieve this goal, the correct way seems to be:
  // 1. remove any direct edge from src->dst;
  // 2. detect if src can reach dst; if so, they cannot be merged.
  std::vector<const SimpleNode*> dfs_start_nodes;
  for (const SimpleNode* node : dst->in_nodes()) {
    if (node != src) {
      dfs_start_nodes.push_back(node);
    }
  }
  bool has_cycle = false;
  StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
            [&has_cycle, src](const SimpleNode* n) {
              if (n == src) {
                has_cycle = true;
                return false;
              }
              return true;
            });
  return !has_cycle;
}
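
// Illustrative example: with edges A->B, A->C, C->B, contracting A->B would
// merge {A,B}, and the remaining path A->C->B would become the cycle
// {A,B}->C->{A,B}. The reverse DFS above starts from C (the only input of B
// other than A) and reaches A, so CanContractEdge returns false in this case.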

// TODO(bixia): put this to a common utility file.
string TensorPropertiesToString(const OpInfo::TensorProperties& prop) {
  string s = StrCat(DataTypeString(prop.dtype()), ": ");
  StrAppend(&s, "[");
  if (prop.shape().unknown_rank()) {
    StrAppend(&s, "?");
  } else {
    StrAppend(&s, StrJoin(prop.shape().dim(), ",",
                          [](string* out, const TensorShapeProto_Dim& d) {
                            StrAppendFormat(out, "%d", d.size());
                          }));
  }
  StrAppend(&s, "]");
  return s;
}

string TensorPropertiesToString(
    const std::vector<OpInfo::TensorProperties>& properties) {
  return StrJoin(properties, "; ",
                 [](string* out, const OpInfo::TensorProperties& prop) {
                   StrAppend(out, TensorPropertiesToString(prop));
                 });
}

// From the given list of input properties, returns the leading shape, which is
// the shape that determines the batch size of the operation. The leading shape
// is selected from the group of input shapes with the highest rank as follows:
// . If all of those shapes have non-negative values for the batch dimension,
//   the leading shape is the one with the largest value for the batch
//   dimension.
// . If some or all of those shapes have negative values for the batch
//   dimension, and the rest of those shapes have 1 for the batch dimension,
//   the leading shape is the first of those shapes with a negative value for
//   the batch dimension.
// . Otherwise, we can't determine the leading shape for the operation and
//   have to exclude the operation from TRT.
//
// Examples:
//   case-1: a[1,3,4] + b[2,3,4] => leading shape [2,3,4]
//   case-2: a[2,3,4] + b[scalar] => leading shape [2,3,4]
//   case-3: a[-1,3,4] + b[1,3,4] => leading shape [-1,3,4]
//   case-4: a[-1,3,4] + b[2,3,4] => no leading shape
//
// We have to return "no leading shape" for case-4 to exclude such operations
// from being translated, for this reason: the actual input for "a" has to be
// of shape [2,3,4] for the operation to be valid. On the other hand, if we
// translate the operation to implicit batch mode, it becomes a[3,4]+b[3,4],
// which is valid for any input shape of "a".
//
// This routine assumes the input program is valid. For example, we shouldn't
// see an invalid operation like a[2,3,4] + b[3,3,4]. It also assumes the input
// properties list is not empty and all inputs have known shapes.
//
// TODO(bixia): find a way to share this knowledge with the converter.
// TODO(bixia): investigate the use of symbolic shape analysis to improve
// segmentation, such as by requiring the dynamic dimensions to have the same
// negative value.
absl::optional<const TensorShapeProto*> FindLeadingShape(
    absl::Span<const OpInfo::TensorProperties> properties) {
  DCHECK(!properties.empty());
  const TensorShapeProto* result;
  int max_batch_dim_value;
  auto choose_shape_with_higher_rank = [&](const TensorShapeProto* s) {
    result = s;
    max_batch_dim_value = s->dim_size() < 1 ? 1 : s->dim(0).size();
  };

  DCHECK(!properties[0].shape().unknown_rank());
  choose_shape_with_higher_rank(&properties[0].shape());

  for (const OpInfo::TensorProperties& p : properties.subspan(1)) {
    DCHECK(!p.shape().unknown_rank());
    if (p.shape().dim_size() < result->dim_size()) continue;

    if (p.shape().dim_size() > result->dim_size()) {
      choose_shape_with_higher_rank(&p.shape());
      continue;
    }

    // Among the shapes with the same rank, choose the one with a dynamic batch
    // size. If no shapes have a dynamic batch size, choose the one with the
    // largest size.
    if (result->dim_size() < 1) continue;

    if (p.shape().dim(0).size() < 0 || result->dim(0).size() < 0) {
      if (p.shape().dim(0).size() < 0 && result->dim(0).size() >= 0) {
        result = &p.shape();
      } else {
        max_batch_dim_value =
            std::max<int>(max_batch_dim_value, p.shape().dim(0).size());
      }

      continue;
    }

    if (p.shape().dim(0).size() > result->dim(0).size()) {
      result = &p.shape();
      max_batch_dim_value = result->dim(0).size();
    }
  }

  if (result->dim_size() > 0 && result->dim(0).size() < 0) {
    // dynamic batch size
    if (max_batch_dim_value <= 1) {
      return result;
    } else {
      return absl::nullopt;
    }
  }

  return result;
}
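
// Illustrative traces of the cases in the comment above:
//   case-3: a[-1,3,4] + b[1,3,4] => result=[-1,3,4], max_batch_dim_value=1,
//           so [-1,3,4] is returned.
//   case-4: a[-1,3,4] + b[2,3,4] => result=[-1,3,4], max_batch_dim_value=2,
//           so absl::nullopt is returned.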

// Returns the inputs that are relevant to determine the batch size of the
// operation. This routine handles the following cases:
// . Operations that support implicit broadcasting, such as Mul.
//   In this case, we need to inspect all the inputs in order to determine the
//   batch size of the operation.
// . Special cases, such as "Conv2DBackpropInput" and "Conv3DBackpropInputV2".
// . Otherwise, the batch size of the operation is determined by the first
//   input of the operation.
absl::Span<const OpInfo::TensorProperties> GetInputsToDeterminateBatchSize(
    const Node* node, const std::vector<OpInfo::TensorProperties>& all_inputs) {
  // TODO(bixia): Find a way to share this knowledge with the converter.
  static std::set<string> broadcast_supporting_ops = {
      // ops corresponding to ConvertBinary in the converter
      "Add",
      "AddV2",
      "Mul",
      "Sub",
      "Div",
      "FloorDiv",
      "RealDiv",
      "Minimum",
      "Maximum",
      "Pow",
      // other ops that need GetTrtBroadcastShape to convert
      "BiasAdd",
      "SquaredDifference",
      "BatchMatMul",
      "BatchMatMulV2",
  };
  const string& op = node->def().op();

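  // For example, Conv2DBackpropInput takes (input_sizes, filter,
  // out_backprop); only out_backprop (input index 2) carries the batch
  // dimension, hence the subspan(2, 1) below.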
  if (op == "Conv2DBackpropInput" || op == "Conv3DBackpropInputV2") {
    DCHECK_EQ(all_inputs.size(), 3);
    return absl::MakeSpan(all_inputs).subspan(2, 1);
  }

  if (broadcast_supporting_ops.count(op)) {
    return absl::MakeSpan(all_inputs);
  }

  // This is the common case for operations that don't support implicit
  // broadcasting: the first operand determines the batch size. All other
  // cases are handled before reaching here.
  return absl::MakeSpan(all_inputs).subspan(0, 1);
}

// Returns true if we can remove the implicit batch dimension of the operation.
//
// In particular, if the input shape has dynamic rank or the input shape rank
// is less than 2, we can't remove the implicit batch dimension and generate
// a new operation for TRT translation.
bool OperationCanBeTranslatedToImplicitBatch(
    const grappler::GraphProperties* graph_properties, const Node* node) {
  VLOG(3) << "process node " << node->name();
  if (node->num_inputs() == 0) return true;
  if (!graph_properties || !graph_properties->HasInputProperties(node->name()))
    return false;

  VLOG(3) << "input shapes "
          << TensorPropertiesToString(
                 graph_properties->GetInputProperties(node->name()));

  const std::vector<OpInfo::TensorProperties>& all_input_properties =
      graph_properties->GetInputProperties(node->name());
  absl::Span<const OpInfo::TensorProperties> input_properties =
      GetInputsToDeterminateBatchSize(node, all_input_properties);
  if (absl::c_any_of(input_properties, [](const OpInfo::TensorProperties& p) {
        return p.shape().unknown_rank();
      })) {
    return false;
  }

  absl::optional<const TensorShapeProto*> leading_shape =
      FindLeadingShape(input_properties);
  return leading_shape.has_value() && leading_shape.value()->dim_size() >= 2;
}

// Returns true if we can't be sure that the operand with the given properties
// won't have negative values for non-batch dimensions.
bool HasDynamicNonBatchDimension(const OpInfo::TensorProperties& prop) {
  const TensorShapeProto& shape = prop.shape();
  if (shape.unknown_rank()) return true;

  // Scalar is a well specified shape, and TRT supports implicit broadcasting
  // from scalar to other shapes.
  if (shape.dim_size() == 0) return false;
  for (int i = 1; i < shape.dim_size(); ++i) {
    // The value of a dynamic dimension can be other negative values besides
    // -1, representing the symbolic group of the dimension.
    if (shape.dim(i).size() <= -1) {
      return true;
    }
  }
  return false;
}
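
// For example (illustrative): shape [-1,3,4] has a dynamic batch dimension
// but static non-batch dimensions, so the check above returns false, while
// [2,-1,4] has a dynamic non-batch dimension and returns true.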

// Returns true if we can't be sure that the operation won't have a dynamic
// non-batch dimension involved. We only check the shape of the first output,
// assuming shape inference already propagated the shapes.
bool OperationHasDynamicNonBatchDimension(
    const grappler::GraphProperties* graph_properties, const Node* node) {
  VLOG(3) << "process node " << node->name();
  // If the node doesn't have any input or output, no computation is involved.
  if (node->num_inputs() == 0 || node->num_outputs() == 0) return false;

  // If the node doesn't have output properties, return true to be
  // conservative.
  if (!graph_properties->HasOutputProperties(node->name())) return true;
  VLOG(3) << "output shapes "
          << TensorPropertiesToString(
                 graph_properties->GetOutputProperties(node->name()));
  return HasDynamicNonBatchDimension(
      graph_properties->GetOutputProperties(node->name()).at(0));
}

void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
                  std::vector<const SimpleEdge*>* remove_edges) {
  // Transfer all inputs and outputs of 'dst' to 'src' except edges
  // connecting the two.
  auto src = edge->src();
  auto dst = edge->dst();

  // We can use '0' for input/output index because we don't need them
  // to be accurate for the way we are using the graph.
  std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
                                          dst->in_edges().end());
  for (const SimpleEdge* in_edge : in_edges) {
    if (in_edge->IsControlEdge()) {
      if (in_edge->src() != src) {
        SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
        graph->AddControlEdge(e->src(), src);
      }
    } else {
      if (in_edge->src() != src) {
        SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
        if (e->src() == graph->source_node()) {
          graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot);
        } else {
          graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
        }
      }
    }
  }

  std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
                                           dst->out_edges().end());
  for (const SimpleEdge* out_edge : out_edges) {
    if (out_edge->IsControlEdge()) {
      SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
      graph->AddControlEdge(src, e->dst());
    } else {
      SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
      if (e->dst() == graph->sink_node()) {
        VLOG(1) << " edge to sink node " << src->name() << " -> "
                << e->dst()->name();
        graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input());
      } else {
        graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
      }
    }
  }

  // Return the edges that must be removed to disconnect 'dst' from
  // the graph. We don't actually remove 'dst' since the caller holds
  // references to all the nodes.
  for (const auto& in_edge : dst->in_edges()) {
    remove_edges->push_back(in_edge);
  }
  for (const auto& out_edge : dst->out_edges()) {
    remove_edges->push_back(out_edge);
  }
}
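
// Illustrative example: contracting the edge B->C in A->B->C->D re-attaches
// C's output edge to B so the graph reads A->B->D, and 'remove_edges'
// receives C's old edges (B->C and the original C->D) so the caller can erase
// them from the graph.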

// Returns a batch size representation for a segment that only contains the
// given node.
ClusterBatchSize GetClusterBatchSizeForNode(
    const grappler::GraphProperties* graph_properties, const Node* node,
    bool use_implicit_batch) {
  ClusterBatchSize cluster_batch_size;
  if (!use_implicit_batch || !node || node->num_inputs() == 0) {
    return cluster_batch_size;
  }

  const NodeDef& node_def = node->def();
  if (node_def.attr().count(kTftrtOpMaxBatchSizeAttr)) {
    cluster_batch_size.SetMaxBatchSize(
        node_def.attr().at(kTftrtOpMaxBatchSizeAttr).i());
  }

  // If shape inference cannot provide any useful information about the batch
  // size, we keep it as missing.
  if (!graph_properties ||
      !graph_properties->HasInputProperties(node->name())) {
    VLOG(3) << "doesn't have input property";
    return cluster_batch_size;
  }

  const std::vector<OpInfo::TensorProperties>& input_properties =
      graph_properties->GetInputProperties(node->name());
  absl::optional<const TensorShapeProto*> optional_leading_shape =
      FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties));
  DCHECK(optional_leading_shape.has_value());
  const TensorShapeProto* leading_shape = optional_leading_shape.value();
  DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2);
  VLOG(3) << "set batch size as " << leading_shape->dim(0).size();
  return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size());
}

void AddSegmentForNode(const grappler::GraphProperties* graph_properties,
                       std::vector<UnionFind<SimpleNode*>>* segments,
                       SimpleNode* node,
                       const DeviceNameUtils::ParsedName& device_name,
                       bool use_implicit_batch) {
  ClusterProperty property(
      GetClusterBatchSizeForNode(graph_properties,
                                 node == nullptr ? nullptr : node->tf_node(),
                                 use_implicit_batch),
      device_name);
  segments->emplace_back(node, std::move(property));
}

}  // namespace

Status SegmentGraph(const Graph* tf_graph,
                    const grappler::GraphProperties* graph_properties,
                    const std::function<Status(const Node*)>& candidate_fn,
                    const std::function<bool(const Edge*)>& input_candidate_fn,
                    const std::function<bool(const Edge*)>& output_candidate_fn,
                    const SegmentOptions& options, SegmentVector* segments) {
  if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
    return errors::Internal(
        "Explicit batch mode should allow dynamic non-batch dimensions");
  }

  if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) {
    return errors::Internal("Implicit batch mode requires maximum_batch_size");
  }

  if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
    return errors::Internal(
        "Need graph properties to disallow dynamic non-batch dimensions");
  }

  // Steps:
  // 1. run the segmentation algorithm to find all the segments, which uses
  //    candidate_fn to determine the candidate segment nodes;
  // 2. for each segment, remove the nodes that are inputs/outputs of the
  //    segment but are not eligible, using input/output_candidate_fn to
  //    determine the eligibilities;
  // 3. convert the segments into the expected return format and return the
  //    result.

  // --------------------------------- Step 1 ---------------------------------
  auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
  // Use a union-find to collect the nodes that belong to the same
  // segment. A node value of nullptr indicates that the node is not a
  // candidate for TRT.
  std::unordered_set<string> unsupported_ops;
  int num_unsupported_ops = 0;

  // Getting the operations denylisted for conversion.
  string tftrt_op_denylist_str;
  TF_CHECK_OK(
      ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str));

  auto tftrt_op_denylist = gtl::FlatSet<string>{};  // non-absl ok

  for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
    tftrt_op_denylist.insert(x);
  }
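
  // For example (illustrative), running with
  //   TF_TRT_OP_DENYLIST="Reshape,Transpose"
  // keeps those two op types out of every TRT segment below.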

  // Parsing each node of the graph.
  std::vector<UnionFind<SimpleNode*>> node_segments;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);
    if (!node) {
      VLOG(3) << "Node " << i << " doesn't exist in the graph";
      continue;
    }
    auto exclude_node = [&](absl::string_view reason) {
      VLOG(1) << "Not a TF-TRT candidate, "
              << "(Op type: " << node->tf_node()->type_string() << "), "
              << "(Op name: " << node->name() << "), "
              << "(Reason: " << reason << ")";
      unsupported_ops.emplace(node->tf_node()->type_string());
      num_unsupported_ops++;
      node = nullptr;
    };
    absl::optional<DeviceNameUtils::ParsedName> device_name =
        GetDeviceParsedName(node->tf_node());
    // GetDeviceParsedName capitalizes the device type.
    if (!device_name.has_value() ||
        (device_name->has_type && device_name->type != "GPU")) {
      exclude_node("node can't be placed on GPU");
    } else if (options.exclude_node_list.count(node->name()) != 0) {
      exclude_node("excluded by segmenter option");
    } else if (options.use_implicit_batch &&
               !OperationCanBeTranslatedToImplicitBatch(graph_properties,
                                                        node->tf_node())) {
      exclude_node(
          "implicit batch mode requires input shape with at least two "
          "dimensions");
    } else if (!options.allow_dynamic_non_batch_dim &&
               OperationHasDynamicNonBatchDimension(graph_properties,
                                                    node->tf_node())) {
      exclude_node("dynamic non-batch dimensions not allowed");
    } else {
      const Status status = candidate_fn(node->tf_node());
      if (!status.ok()) {
        exclude_node(status.error_message());
      } else if (tftrt_op_denylist.count(node->tf_node()->type_string())) {
        // WARNING verbosity since the user explicitly requests this behavior.
        LOG_WARNING_WITH_PREFIX
            << "Denylisted as TF-TRT candidate, "
            << "(Op type: " << node->tf_node()->type_string() << "), "
            << "(Op name: " << node->name() << ")";
        exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
      } else {
        VLOG(2) << "Accepted as a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << ")";
      }
    }
    // 'device_name' may be empty when the node was excluded above; fall back
    // to a default ParsedName so we never dereference an empty optional.
    AddSegmentForNode(graph_properties, &node_segments, node,
                      device_name.value_or(DeviceNameUtils::ParsedName()),
                      options.use_implicit_batch);
  }
  string msg = StrCat(
      "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(),
      " different types in the graph that", " are not converted to TensorRT: ");
  for (const auto& elem : unsupported_ops) {
    StrAppend(&msg, elem, ", ");
  }
  LOG(INFO) << msg << "(For more information see "
            << "https://docs.nvidia.com/deeplearning"
            << "/frameworks/tf-trt-user-guide/index.html#supported-ops).";

  // The segmentation algorithm below visits nodes in reverse topological order
  // and attempts to merge nodes along output edges. That means that subgraphs
  // grow from the output-side of the network towards the inputs.
  //
  // In general this is not guaranteed to produce a globally optimal
  // segmentation. For example, consider a graph with nodes {A, B, C, D} and
  // edges {A->B, A->C, B->D, C->D}, where A, B, D are TRT compatible but C is
  // not. In theory we can choose to contract either {A, B} or {B, D} but not
  // both; here the algorithm always chooses to contract {B, D}.
  //
  // In the future, if we have a measure of how beneficial it is to include a
  // given node in a TRT subgraph, then we can revisit this algorithm to take
  // advantage of that information.
  std::vector<const SimpleNode*> order;
  order.reserve(graph->num_node_ids());
  StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
            /*enter=*/nullptr, [&order](const SimpleNode* n) {
              order.push_back(n);
              return true;
            });
  for (const SimpleNode* node : order) {
    // All output nodes of 'node' have been visited.
    VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
    // 'node' must be a TRT candidate.
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(3) << "... not a TRT candidate";
      continue;
    }
    // Contract output edges to combine 'node' with output nodes. Repeat this
    // step until no output edges can be further contracted. This is because
    // contracting an output edge may unblock new edges for contracting.
    ClusterBatchSize expected_batch_size =
        node_segments[node->id()].Property().BatchSize();
    DeviceNameUtils::ParsedName expected_device_name =
        node_segments[node->id()].Property().DeviceName();
    VLOG(3) << "batch size " << expected_batch_size;
    while (true) {
      std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
      // TODO(bixia): consider merging the loop to find the edges and the loop
      // to contract the edges.
      for (const SimpleEdge* out_edge : node->out_edges()) {
        VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
                << out_edge->dst()->id() << " <- " << node->id() << " )";
        if (out_edge->IsControlEdge()) {
          VLOG(3) << "... ... Control Edge, Skipping";
          continue;
        }
        UnionFind<SimpleNode*>* out_cluster =
            &node_segments[out_edge->dst()->id()];
        // Out node must be a TRT candidate.
        if (out_cluster->Value() == nullptr) {
          VLOG(3) << "... ... not a TRT candidate";
          continue;
        }
        // Out node must have a compatible batch size.
        ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
        ClusterBatchSize merged_batch_size = expected_batch_size;
        if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
          VLOG(3) << "... ... incompatible batch sizes "
                  << expected_batch_size.ToString() << " "
                  << out_batch_size.ToString();
          continue;
        }

        const DeviceNameUtils::ParsedName& out_device_name =
            out_cluster->Property().DeviceName();
        absl::optional<DeviceNameUtils::ParsedName> merged_device_name =
            MergeIfCompatible(expected_device_name, out_device_name);
        if (!merged_device_name.has_value()) {
          VLOG(3) << "... ... incompatible device names "
                  << expected_device_name << " " << out_device_name;
          continue;
        }

        if (CanContractEdge(out_edge, graph)) {
          VLOG(3) << "... ... can contract. new batch size "
                  << merged_batch_size.ToString();
          contract_edges.insert(out_edge);
          expected_batch_size = merged_batch_size;
          expected_device_name = *merged_device_name;
        } else {
          VLOG(3) << "... ... cannot contract, would form cycle";
        }
      }
      if (contract_edges.empty()) {
        break;
      }
      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const SimpleEdge* contract_edge = *contract_edges.begin();
        const SimpleNode* src = contract_edge->src();
        const SimpleNode* dst = contract_edge->dst();

        VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
                << src->id() << " <- " << dst->id() << ")";
        TF_RETURN_IF_ERROR(
            node_segments[src->id()].Merge(&node_segments[dst->id()]));

        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
        std::vector<const SimpleEdge*> remove_edges;
        ContractEdge(e, graph.get(), &remove_edges);

        for (const SimpleEdge* r : remove_edges) {
          contract_edges.erase(r);
          graph->RemoveEdge(r);
        }
      }
      if (expected_batch_size !=
          node_segments[node->id()].Property().BatchSize()) {
        return errors::Internal(
            "expected batch size is not the same as the actual batch size");
      }
      if (expected_device_name !=
          node_segments[node->id()].Property().DeviceName()) {
        return errors::Internal(
            "expected device name is not the same as the actual device name");
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by a
  // set of the nodes in that subgraph.

  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the segment nodes set.
  std::map<string, Segment> sg_map;

  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
    }
    if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
      sg_map[u.Value()->name()].property = u.Property();
    }
  }
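
  // For example (illustrative): if the union-find merged nodes {conv, relu}
  // under root 'conv', then sg_map["conv"].nodes == {conv, relu} and
  // sg_map["conv"].property holds the merged batch size and device name.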

  // --------------------------------- Step 2 ---------------------------------
  // Remove ineligible input/output nodes.
  for (auto& itr : sg_map) {
    std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
    VLOG(1) << "Segment original size: " << segment_nodes.size();
    while (true) {
      std::deque<const Node*> in_nodes_que, out_nodes_que;
      // Find an input node that is not eligible and add it to the queue.
      // Nodes that have no incoming edges should not be treated as "input",
      // as there are really no inputs to them. Similarly for output nodes.
      for (auto node : segment_nodes) {
        bool added = false;
        for (const Edge* edge : node->in_edges()) {
          if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
              !segment_nodes.count(edge->src())) {  // 'node' is an input node.
            if (!input_candidate_fn(edge)) {
              in_nodes_que.push_back(node);
              added = true;
              break;
            }
          }
        }
        if (added) continue;  // Only add the node once to either queue.
        for (const Edge* edge : node->out_edges()) {
          if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
              !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
            if (!output_candidate_fn(edge)) {
              out_nodes_que.push_back(node);
              break;
            }
          }
        }
      }
      if (in_nodes_que.empty() && out_nodes_que.empty()) {
        // No more ineligible input/output nodes.
        break;
      }
      // Now for each ineligible node, remove all of its inputs or outputs from
      // the subgraph.
      //
      // It can be proven that, if the original subgraph:
      // 1. is a DAG, and
      // 2. all paths between two nodes in the subgraph are all inside the
      //    subgraph
      // then after doing this operation the resulting subgraph will keep the
      // same properties 1 and 2.
      //
      // For simplicity we use heuristics: for input and const output nodes
      // remove all their inputs, and for non-const output nodes remove all
      // their outputs. In this way, for common cases the number of removed
      // nodes should be minimal.
      auto remove_nodes = [&segment_nodes](bool is_input_nodes,
                                           std::deque<const Node*>* que) {
        // Run a BFS on the queue to find all the input/output nodes.
        std::set<const Node*, NodePtrCompare> visited;
        std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
        while (!que->empty()) {
          auto node = que->front();
          que->pop_front();
          if (!visited.insert(node).second) continue;
          segment_nodes.erase(node);
          for (auto in : (is_input_nodes || node->type_string() == "Const")
                             ? node->in_nodes()
                             : node->out_nodes()) {
            if (segment_nodes.count(in)) {
              que->push_back(in);
              if (VLOG_IS_ON(2)) {
                if (!logged.count(in)) {
                  VLOG(2) << "----> Need to remove node " << in->name()
                          << " because one of its "
                          << (is_input_nodes ? "output" : "input")
                          << " nodes in the graph was removed: "
                          << node->name();
                  logged.insert(in);
                }
              }
            }
          }
        }
      };
      remove_nodes(true, &in_nodes_que);
      remove_nodes(false, &out_nodes_que);
    }
    VLOG(1) << "Segment new size: " << segment_nodes.size();
  }

  // --------------------------------- Step 3 ---------------------------------
  // Convert the segments into the expected return format.
  for (const auto& itr : sg_map) {
    const string& segment_root = itr.first;
    // Return format does not require set comparator.
    std::set<const Node*, NodePtrCompare> segment_nodes(
        itr.second.nodes.begin(), itr.second.nodes.end());
    if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
      string s;
      for (auto node : segment_nodes) {
        StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
      }
      VLOG(1) << "Nodes in segment " << segments->size()
              << " with parent=" << segment_root << ":" << s;
    }

    const int num_effective_nodes = std::count_if(
        segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
          static auto noops =
              new std::set<string>{"Identity", "Snapshot", "StopGradient"};
          return noops->count(node->type_string()) == 0;
        });

    // Don't use segments whose number of effective nodes is small.
    if (num_effective_nodes == 0 ||
        num_effective_nodes < options.minimum_segment_size) {
      VLOG(1) << "Segment " << segments->size() << " has only "
              << num_effective_nodes << " effective nodes, dropping";
      continue;
    }
    segments->emplace_back(itr.second.property, segment_nodes);
  }

  return Status::OK();
}
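
// A minimal usage sketch (hypothetical wiring; the actual call site lives in
// the TF-TRT converter, and 'IsTensorRTCandidate' stands in for the real
// per-node validator):
//
//   SegmentVector segments;
//   TF_RETURN_IF_ERROR(SegmentGraph(
//       &graph, &graph_properties,
//       /*candidate_fn=*/IsTensorRTCandidate,
//       /*input_candidate_fn=*/[](const Edge*) { return true; },
//       /*output_candidate_fn=*/[](const Edge*) { return true; },
//       options, &segments));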

}  // namespace segment
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT