1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/control_flow.h"
27 #include "tensorflow/core/graph/tensor_id.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29
30 // ALGORITHM OVERVIEW
31 // ==================
32 //
33 // We map every output produced by each node in the TensorFlow graph (including
34 // control dependence) into an instance of the Predicate class. Instances of
35 // Predicate denote logical formulas and mapping a node `n` to a predicate
36 // `pred` implies that `n` is live whenever `pred` is true. Then we can deduce
// mismatching liveness in the inputs to a node by comparing the predicates those
// inputs are mapped to. The core logic of this pass resides in creating the
39 // map from TensorFlow nodes to predicates.
40 //
41 //
42 // MAPPING NODES TO PREDICATES, MODULO CYCLES
43 // ------------------------------------------
44 //
45 // If we ignore cycles for a moment, computing predicates is fairly
46 // straightforward. We traverse the graph in RPO, mapping each node to a
47 // predicate based on the predicates its inputs are mapped to. For instance a
48 // Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)).
// Roughly speaking, we abstractly interpret each node in the "liveness" domain,
50 // where values in the domain represent if a tensor carries a dead signal or
51 // not.
52 //
53 //
54 // DEALING WITH CYCLES
55 // -------------------
56 //
57 // We map Merge nodes that are the target of a backedge to AndRecurrence
58 // instances. An AndRecurrence with start() = S and step() = X, printed as
59 // {S,&,X}, *roughly* represents the infinite list of predicates
// [S,S&X,S&X&X,S&X&X&X, ...]. So {S,&,X} can be used to represent the predicate
61 // for Merge in a graph like:
62 //
63 // Init
64 // |
65 // v
66 // Merge <-----------+
67 // | |
68 // v |
69 // Incr |
70 // | |
71 // v |
72 // Switch <- Cond |
73 // | |
74 // v (oidx: 1) |
75 // | |
76 // +---------------+
77 //
78 // Where S is the predicate for Init and X is the predicate that asserts that
79 // Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff
80 // S is true, live on the second iteration iff "S&X" is true, live on the third
// iteration iff "S&X&X" is true etc. There is a subtlety here: S&X&X would
82 // normally be equivalent to S&X which isn't quite what we want to represent.
83 // Instead we want {S,&,X} to denote the infinite list [S, S&X,
84 // S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
85 // true on iteration 0, 1, 2 respectively. This is made more precise in the
86 // comment on the AndRecurrence class.
87 //
88 // The general algorithm that deals with cycles does two RPO (reverse post
89 // order) passes over the graph. On the first pass it assigns a symbolic
90 // predicate to merge nodes with backedges. On the second pass it tries to
// pattern match the predicates for the backedges of these merges and infer an
92 // AndRecurrence for the merge.
93 //
94 // In other words, we do a pessimistic data flow analysis where the data-flow
95 // lattice has two elements, Symbolic and NonSymbolic with Symbolic >
96 // NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
97 // converge. We don't do an optimistic data flow analysis to make pattern
98 // matching easier: if we assigned the predicate of the initial value to the
99 // merge during the first pass, on the second pass the backedge may see a
100 // simplified value that would be difficult to pattern match.
101 //
102 // We still use symbolic predicates for merges for which we can't pattern match
103 // on the backedge predicate. This is conservatively correct.
104
105 namespace tensorflow {
106
107 namespace {
108
109 // Represents a logical predicate, used as described in the algorithm overview
110 // above.
111 class Predicate {
112 public:
113 enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol };
114
115 virtual string ToString() const = 0;
116
117 // An ID assigned to the Predicate at construction time. Conceptually like a
118 // pointer, except that it is stable across runs.
id() const119 int64 id() const { return id_; }
120
121 virtual absl::Span<Predicate* const> GetOperands() const = 0;
122
123 virtual Kind kind() const = 0;
~Predicate()124 virtual ~Predicate() {}
125
126 // Invokes func on p and on all of its operands recursively. Does not invoke
127 // `func` on the same Predicate instance twice. Aborts the search if `func`
128 // returns true.
129 template <typename FunctionTy>
130 static void Visit(Predicate* p, const FunctionTy& func);
131
132 protected:
Predicate(int64 id)133 explicit Predicate(int64 id) : id_(id) {}
134
135 private:
136 const int64 id_;
137
138 TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
139 };
140
141 // Represents a logical conjunction of a set of predicates.
142 class AndPredicate : public Predicate {
143 public:
AndPredicate(int64 id,std::vector<Predicate * > operands)144 explicit AndPredicate(int64 id, std::vector<Predicate*> operands)
145 : Predicate(id), operands_(std::move(operands)) {}
146
ToString() const147 string ToString() const override {
148 if (operands().empty()) {
149 return "#true";
150 }
151
152 std::vector<string> operands_str;
153 std::transform(operands().begin(), operands().end(),
154 std::back_inserter(operands_str),
155 [](Predicate* pred) { return pred->ToString(); });
156
157 return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
158 }
159
kind() const160 Kind kind() const override { return Kind::kAnd; }
161
GetOperands() const162 absl::Span<Predicate* const> GetOperands() const override {
163 return operands_;
164 }
operands() const165 absl::Span<Predicate* const> operands() const { return operands_; }
166
167 private:
168 std::vector<Predicate*> operands_;
169 };
170
171 // Represents a logical disjunction of a set of predicates.
172 class OrPredicate : public Predicate {
173 public:
OrPredicate(int64 id,std::vector<Predicate * > operands)174 explicit OrPredicate(int64 id, std::vector<Predicate*> operands)
175 : Predicate(id), operands_(std::move(operands)) {}
176
ToString() const177 string ToString() const override {
178 if (operands().empty()) {
179 return "#false";
180 }
181
182 std::vector<string> operands_str;
183 std::transform(operands().begin(), operands().end(),
184 std::back_inserter(operands_str),
185 [](Predicate* pred) { return pred->ToString(); });
186
187 return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
188 }
189
kind() const190 Kind kind() const override { return Kind::kOr; }
GetOperands() const191 absl::Span<Predicate* const> GetOperands() const override {
192 return operands_;
193 }
operands() const194 absl::Span<Predicate* const> operands() const { return operands_; }
195
196 private:
197 std::vector<Predicate*> operands_;
198 };
199
// Represents the logical negation of a single predicate.
201 class NotPredicate : public Predicate {
202 public:
NotPredicate(int64 id,Predicate * operand)203 explicit NotPredicate(int64 id, Predicate* operand)
204 : Predicate(id), operands_({operand}) {}
205
ToString() const206 string ToString() const override {
207 return absl::StrCat("~", operand()->ToString());
208 }
209
kind() const210 Kind kind() const override { return Kind::kNot; }
operand() const211 Predicate* operand() const { return operands_[0]; }
GetOperands() const212 absl::Span<Predicate* const> GetOperands() const override {
213 return operands_;
214 }
215
216 private:
217 std::array<Predicate*, 1> operands_;
218 };
219
220 // Represents the liveness of an induction variable. For users inside the loop
221 // this represents the "current" liveness of the induction variable. For users
222 // outside the loop it represents the "last" liveness of the induction variable.
223 //
224 // More concretely, an and recurrence {S,&,X}<loop> represents the liveness of V
225 // in the following graph:
226 //
227 // V = Merge(S', V_NextIt)
228 // V = Op(V, X')
229 // V_NextIt = NextIteration(V)
230 //
231 // where Predicate(S') = S and Predicate(X') = X.
232 //
233 // `X` may contain symbolic predicates and the operations corresponding to these
234 // symbolic predicates are either in frame `loop` or outside it. The symbols
235 // that are inside frame `loop` are loop variant (i.e. can have different
236 // liveness in each loop iteration) and the symbols that are outside frame
237 // `loop` are loop invariant (i.e. have the same liveness across all
238 // iterations).
239 class AndRecurrencePredicate : public Predicate {
240 public:
AndRecurrencePredicate(int64 id,Predicate * start,Predicate * step,std::vector<string> frame)241 explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step,
242 std::vector<string> frame)
243 : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}
244
start() const245 Predicate* start() const { return operands_[0]; }
step() const246 Predicate* step() const { return operands_[1]; }
frame() const247 absl::Span<const string> frame() const { return frame_; }
248
ToString() const249 string ToString() const override {
250 return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
251 "}<", absl::StrJoin(frame(), ";"), ">");
252 }
253
kind() const254 Kind kind() const override { return Kind::kAndRecurrence; }
255
GetOperands() const256 absl::Span<Predicate* const> GetOperands() const override {
257 return operands_;
258 }
259
260 private:
261 std::array<Predicate*, 2> operands_;
262 std::vector<string> frame_;
263 };
264
265 // Represents an uninterpreted symbol in a logical predicate.
266 //
267 // Two predicates are equivalent iff they are equivalent for all assignments to
268 // the symbols contained in them, i.e. predicates are forall qualified over
269 // symbols.
270 class SymbolPredicate : public Predicate {
271 public:
SymbolPredicate(int64 id,TensorId tensor_id,bool must_be_true)272 explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true)
273 : Predicate(id),
274 tensor_id_(std::move(tensor_id)),
275 must_be_true_(must_be_true) {}
276
ToString() const277 string ToString() const override {
278 return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
279 : tensor_id_.ToString();
280 }
281
kind() const282 Kind kind() const override { return Kind::kSymbol; }
GetOperands() const283 absl::Span<Predicate* const> GetOperands() const override { return {}; }
284
285 // If `must_be_true()` is true this SymbolPredicate represents the proposition
286 // "tensor_id() is live and evaluates to true".
287 //
288 // If `must_be_true()` is false then this SymbolPredicate represents the
289 // proposition "tensor_id() is live (and may evaluate to any value)"
tensor_id() const290 TensorId tensor_id() const { return tensor_id_; }
must_be_true() const291 bool must_be_true() const { return must_be_true_; }
292
293 private:
294 TensorId tensor_id_;
295 bool must_be_true_;
296 };
297
298 template <typename FunctionTy>
Visit(Predicate * p,const FunctionTy & func)299 /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
300 absl::flat_hash_set<Predicate*> visited;
301 std::vector<Predicate*> stack;
302
303 stack.push_back(p);
304 visited.insert(p);
305
306 while (!stack.empty()) {
307 Predicate* current = stack.back();
308 stack.pop_back();
309 bool done = func(current);
310 if (done) {
311 return;
312 }
313 for (Predicate* op : current->GetOperands()) {
314 if (visited.insert(op).second) {
315 stack.push_back(op);
316 }
317 }
318 }
319 }
320
321 // Creates and owns Predicate instances. Simplifies predicates as it creates
322 // them.
323 class PredicateFactory {
324 public:
MakeAndPredicate(absl::Span<Predicate * const> operands)325 Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
326 return MakeAndOrImpl(operands, /*is_and=*/true);
327 }
328
MakeOrPredicate(absl::Span<Predicate * const> operands)329 Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
330 return MakeAndOrImpl(operands, /*is_and=*/false);
331 }
332
MakeNotPredicate(Predicate * pred)333 Predicate* MakeNotPredicate(Predicate* pred) {
334 auto it = make_not_predicate_cache_.find(pred);
335 if (it != make_not_predicate_cache_.end()) {
336 return it->second;
337 }
338
339 Predicate* result = MakeNotPredicateImpl(pred);
340
341 bool insert_successful =
342 make_not_predicate_cache_.insert({pred, result}).second;
343 (void)insert_successful;
344 DCHECK(insert_successful);
345
346 return result;
347 }
348
MakeAndRecurrencePredicate(Predicate * start,Predicate * step,std::vector<string> frame)349 Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
350 std::vector<string> frame) {
351 SignatureForAndRec signature(start, step, std::move(frame));
352 auto it = interned_and_rec_instances_.find(signature);
353 if (it != interned_and_rec_instances_.end()) {
354 return it->second.get();
355 }
356
357 std::unique_ptr<Predicate> new_pred = Make<AndRecurrencePredicate>(
358 std::get<0>(signature), std::get<1>(signature), std::get<2>(signature));
359 Predicate* new_pred_ptr = new_pred.get();
360 bool inserted =
361 interned_and_rec_instances_.emplace(signature, std::move(new_pred))
362 .second;
363 (void)inserted;
364 DCHECK(inserted);
365 return new_pred_ptr;
366 }
367
MakeSymbolPredicate(Node * node,int output_idx,bool must_be_true,Predicate ** predicate)368 Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true,
369 Predicate** predicate) {
370 TensorId tensor_id(node->name(), output_idx);
371
372 bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
373 TF_RET_CHECK(!must_be_true || is_boolean_tensor);
374
375 if (node->type_string() == "Const" && must_be_true) {
376 const TensorProto* proto = nullptr;
377 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
378
379 Tensor tensor(proto->dtype());
380 TF_RET_CHECK(tensor.FromProto(*proto));
381
382 *predicate = tensor.scalar<bool>()() ? MakeTrue() : MakeFalse();
383 return Status::OK();
384 }
385
386 SignatureForSymbol signature = {tensor_id, must_be_true};
387 auto it = interned_symbol_instances_.find(signature);
388 if (it == interned_symbol_instances_.end()) {
389 std::unique_ptr<Predicate> new_pred =
390 Make<SymbolPredicate>(tensor_id, must_be_true);
391 Predicate* new_pred_ptr = new_pred.get();
392 interned_symbol_instances_.emplace(std::move(signature),
393 std::move(new_pred));
394 *predicate = new_pred_ptr;
395 } else {
396 *predicate = it->second.get();
397 }
398
399 return Status::OK();
400 }
401
MakeTrue()402 Predicate* MakeTrue() { return MakeAndPredicate({}); }
MakeFalse()403 Predicate* MakeFalse() { return MakeOrPredicate({}); }
404
~PredicateFactory()405 ~PredicateFactory() {
406 DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
407 }
408
409 private:
MakeNotPredicateImpl(Predicate * pred)410 Predicate* MakeNotPredicateImpl(Predicate* pred) {
411 IncrementStackDepth stack_frame(this);
412 if (!stack_frame.HasOverflowed()) {
413 if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
414 return simplified;
415 }
416
417 // ~~A => A
418 if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
419 return not_pred->operand();
420 }
421 }
422
423 SignatureForNot signature = pred;
424 auto it = interned_not_instances_.find(signature);
425 if (it == interned_not_instances_.end()) {
426 std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
427 Predicate* new_pred_ptr = new_pred.get();
428 interned_not_instances_.emplace(signature, std::move(new_pred));
429 return new_pred_ptr;
430 } else {
431 return it->second.get();
432 }
433 }
434
SimplifyUsingDeMorgan(Predicate * pred)435 Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
436 // ~(A & B & C & ...) => ~A | ~B | ~C | ~...
437 // ~(A | B | C | ...) -> ~A & ~B & ~C & ~...
438 Predicate::Kind kind = pred->kind();
439
440 if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
441 std::vector<Predicate*> new_operands;
442 absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
443 [&](Predicate* p) { return MakeNotPredicate(p); });
444 return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
445 : MakeOrPredicate(new_operands);
446 }
447
448 return nullptr;
449 }
450
451 template <typename PredicateT, typename... Args>
Make(Args &&...args)452 std::unique_ptr<Predicate> Make(Args&&... args) {
453 // If we ever expose the Predicate class outside this .cc file then we may
454 // want to make this hard to misuse (by accidentally passing in an arbitrary
455 // integer to the Predicate constructor for instance).
456 return std::unique_ptr<PredicateT>(
457 new PredicateT(id_counter_++, std::forward<Args>(args)...));
458 }
459
460 Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
461 Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
462 Predicate::Kind pred_kind);
463
464 // Predicate instances are interned, meaning that there is only a single
465 // instance of a Predicate object with a given content. This makes checking
466 // for structural equality super-cheap -- we can just compare pointers.
467 //
468 // We intern predicates by maintaining a map from the content of a Predicate
469 // to the only instance of said predicate we allow to exist in the
470 // interned_and_or_instances_, interned_not_instances_ and
471 // interned_symbol_instances_ fields. These maps also double up as storage
472 // for the owning pointers to predicate instances.
473
474 using SignatureForAndOr =
475 std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
476 using SignatureForNot = Predicate*;
477 using SignatureForAndRec =
478 std::tuple<Predicate*, Predicate*, std::vector<string>>;
479 using SignatureForSymbol = std::pair<SafeTensorId, bool>;
480
481 struct HashSignatureForAndOr {
operator ()tensorflow::__anon9db8ce5c0111::PredicateFactory::HashSignatureForAndOr482 size_t operator()(const SignatureForAndOr& signature) const {
483 size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
484 for (Predicate* p : signature.second) {
485 hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
486 }
487 return hash;
488 }
489 };
490
491 struct HashSignatureForSymbol {
operator ()tensorflow::__anon9db8ce5c0111::PredicateFactory::HashSignatureForSymbol492 size_t operator()(const SignatureForSymbol& signature) const {
493 return Hash64Combine(SafeTensorId::Hasher()(signature.first),
494 ::tensorflow::hash<bool>()(signature.second));
495 }
496 };
497
498 // Used to limit recursion to avoid blowing up the stack and cap compile time.
499 class IncrementStackDepth {
500 public:
IncrementStackDepth(PredicateFactory * parent)501 explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
502 parent_->stack_depth_++;
503 }
504
HasOverflowed() const505 bool HasOverflowed() const {
506 const int kMaxStackDepth = 8;
507 return parent_->stack_depth_ >= kMaxStackDepth;
508 }
509
~IncrementStackDepth()510 ~IncrementStackDepth() { parent_->stack_depth_--; }
511
512 private:
513 PredicateFactory* parent_;
514 };
515
516 // A cache for the MakeNotPredicate function.
517 //
518 // NB! This is *not* the same as `interned_not_instances_`.
519 // `interned_not_instances_` maps ensures pointer identity for `NotPredicate`
520 // instances, i.e., it ensures there at most one instance of Not(predicate)
521 // for any given predicate whereas `make_not_predicate_cache_` simply caches
522 // the result of the `MakeNotPredicate` function. The values in
523 // `interned_not_instances_` are always instance of `NotPredicate` whereas the
524 // values in `make_not_predicate_cache_` may not be (for instance it will map
525 // Not(Not(A)) to A).
526 absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;
527
528 absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
529 HashSignatureForAndOr>
530 interned_and_or_instances_;
531 absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
532 interned_not_instances_;
533 absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
534 interned_and_rec_instances_;
535 absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
536 HashSignatureForSymbol>
537 interned_symbol_instances_;
538 int64 id_counter_ = 0;
539 int stack_depth_ = 0;
540 };
541
MakeInternedAndOr(std::vector<Predicate * > simplified_ops,Predicate::Kind pred_kind)542 Predicate* PredicateFactory::MakeInternedAndOr(
543 std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
544 std::stable_sort(
545 simplified_ops.begin(), simplified_ops.end(),
546 [](Predicate* a, Predicate* b) { return a->id() < b->id(); });
547
548 auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
549 if (it != interned_and_or_instances_.end()) {
550 return it->second.get();
551 }
552
553 simplified_ops.shrink_to_fit();
554 // NB! Because we'll use a non-owning reference to simplified_ops in the
555 // key for interned_and_or_instances_ we need to be careful to std::move()
556 // it all the way through.
557 absl::Span<Predicate* const> operands_slice = simplified_ops;
558 std::unique_ptr<Predicate> new_pred =
559 pred_kind == Predicate::Kind::kAnd
560 ? Make<AndPredicate>(std::move(simplified_ops))
561 : Make<OrPredicate>(std::move(simplified_ops));
562
563 Predicate* new_pred_ptr = new_pred.get();
564 interned_and_or_instances_.emplace(
565 SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
566 return new_pred_ptr;
567 }
568
569 // Common code to create AndPredicate or OrPredicate instances.
MakeAndOrImpl(absl::Span<Predicate * const> operands,bool is_and)570 Predicate* PredicateFactory::MakeAndOrImpl(
571 absl::Span<Predicate* const> operands, bool is_and) {
572 Predicate::Kind pred_kind =
573 is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
574
575 IncrementStackDepth stack_frame(this);
576 if (stack_frame.HasOverflowed()) {
577 return MakeInternedAndOr(
578 std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
579 }
580
581 Predicate::Kind other_pred_kind =
582 is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
583 absl::flat_hash_set<Predicate*> simplified_ops_set;
584 std::vector<Predicate*> simplified_ops;
585 for (Predicate* op : operands) {
586 // Simplify A&A => A and A|A => A.
587 if (!simplified_ops_set.insert(op).second) {
588 continue;
589 }
590
591 if (op->kind() == pred_kind) {
592 // "Inline" the operands of an inner And/Or into the parent And/Or.
593 for (Predicate* subop : op->GetOperands()) {
594 if (simplified_ops_set.insert(subop).second) {
595 simplified_ops.push_back(subop);
596 }
597 }
598 } else {
599 simplified_ops.push_back(op);
600 }
601 }
602
603 if (simplified_ops.size() == 1) {
604 return simplified_ops[0];
605 }
606
607 // Simplify "A&~A=>False" and "A|~A=>True".
608 absl::flat_hash_set<Predicate*> negated_ops;
609 for (Predicate* op : simplified_ops) {
610 if (negated_ops.count(op)) {
611 // Simple case:
612 //
613 // A & ~A & ... == False
614 // A | ~A | ... == True
615 return is_and ? MakeFalse() : MakeTrue();
616 }
617
618 Predicate* negated_op = MakeNotPredicate(op);
619 if (negated_op->kind() == pred_kind) {
620 // Slightly more complicated case:
621 //
622 // (~A | ~B | ~C) & A & B & C & ... ==
623 // ~(A & B & C) & (A & B & C) & ... == False
624 //
625 // (~A & ~B & ~C) | A | B | C | ... ==
626 // ~(A | B | C) | (A | B | C) | ... == True
627 if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
628 return simplified_ops_set.contains(p);
629 })) {
630 return is_and ? MakeFalse() : MakeTrue();
631 }
632 }
633 negated_ops.insert(negated_op);
634 }
635
636 // If all ops contain the same subop, then factor it out thanks to the
637 // distributive property. Such as:
638 // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
639 // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
640 //
641 // First find any predicates contained in all subops.
642 std::vector<Predicate*> common_inner_operands;
643 absl::flat_hash_set<Predicate*> common_inner_operands_set;
644 for (Predicate* op : simplified_ops) {
645 if (op->kind() != other_pred_kind) {
646 common_inner_operands.clear();
647 break;
648 }
649
650 if (common_inner_operands.empty()) {
651 common_inner_operands.insert(common_inner_operands.end(),
652 op->GetOperands().begin(),
653 op->GetOperands().end());
654 } else {
655 common_inner_operands.clear();
656 absl::c_copy_if(op->GetOperands(),
657 std::back_inserter(common_inner_operands),
658 [&](Predicate* sub_op) {
659 return common_inner_operands_set.count(sub_op) == 1;
660 });
661 }
662 if (common_inner_operands.empty()) break;
663 common_inner_operands_set.clear();
664 common_inner_operands_set.insert(common_inner_operands.begin(),
665 common_inner_operands.end());
666 }
667
668 if (common_inner_operands.empty()) {
669 return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
670 }
671
672 // For all predicates that can be factored out, remove them and recreate the
673 // subops.
674 std::vector<Predicate*> factored_ops;
675 for (Predicate* op : simplified_ops) {
676 std::vector<Predicate*> new_sub_op_ops;
677 absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
678 [&](Predicate* sub_op) {
679 return std::find(common_inner_operands.begin(),
680 common_inner_operands.end(),
681 sub_op) == common_inner_operands.end();
682 });
683 factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
684 }
685
686 Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
687 std::vector<Predicate*> outer_ops;
688 outer_ops.push_back(new_inner_op);
689 outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
690 common_inner_operands.end());
691 return MakeAndOrImpl(outer_ops, !is_and);
692 }
693
694 class DeadnessAnalysisImpl : public DeadnessAnalysis {
695 public:
DeadnessAnalysisImpl(const Graph * graph)696 explicit DeadnessAnalysisImpl(const Graph* graph)
697 : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
698
699 Status Populate();
700 Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
701 bool HasInputsWithMismatchingDeadness(const Node& node) override;
702 void Print() const override;
703 absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
704 const;
705
706 private:
707 enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
708
709 Status GetInputPreds(Node* n, EdgeKind edge_kind,
710 std::vector<Predicate*>* result);
711
712 // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
713 // bit of `should_revisit` if `pred` is different from the current predicate
714 // for the `output_idx` output of `n`.
SetPredicate(Node * n,int output_idx,Predicate * pred,std::vector<bool> * should_revisit)715 void SetPredicate(Node* n, int output_idx, Predicate* pred,
716 std::vector<bool>* should_revisit) {
717 auto insert_result =
718 predicate_map_.insert({TensorId(n->name(), output_idx), pred});
719 if (!insert_result.second && insert_result.first->second != pred) {
720 VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
721 << insert_result.first->second->ToString() << " "
722 << insert_result.first->second << " to " << pred->ToString()
723 << " " << pred;
724 insert_result.first->second = pred;
725 if (should_revisit != nullptr) {
726 for (const Edge* e : n->out_edges()) {
727 (*should_revisit)[e->dst()->id()] = true;
728 }
729 }
730 }
731 }
732
SetPredicate(Node * n,absl::Span<const int> output_idxs,Predicate * pred,std::vector<bool> * should_revisit)733 void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
734 std::vector<bool>* should_revisit) {
735 for (int output_idx : output_idxs) {
736 SetPredicate(n, output_idx, pred, should_revisit);
737 }
738 }
739
740 Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
741 Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
742 Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
743 Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
744 Status HandleNode(Node* n, std::vector<bool>* should_revisit);
745
746 const Graph& graph_;
747 absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
748 PredicateFactory predicate_factory_;
749 std::vector<ControlFlowInfo> control_flow_info_;
750 bool vlog_;
751 };
752
InputEdgeToTensorId(const Edge * e)753 TensorId InputEdgeToTensorId(const Edge* e) {
754 return TensorId(e->src()->name(), e->src_output());
755 }
756
GetInputPreds(Node * n,DeadnessAnalysisImpl::EdgeKind edge_kind,std::vector<Predicate * > * result)757 Status DeadnessAnalysisImpl::GetInputPreds(
758 Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
759 std::vector<Predicate*>* result) {
760 result->clear();
761 for (const Edge* in_edge : n->in_edges()) {
762 bool should_process =
763 edge_kind == EdgeKind::kDataAndControl ||
764 (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
765 (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
766
767 if (should_process) {
768 auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
769 if (it == predicate_map_.end()) {
770 GraphCycles graph_cycles;
771 TF_RETURN_IF_ERROR(
772 CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
773
774 // If we didn't return with an error above then the graph is probably
775 // fine and we have a bug in deadness analysis.
776 return errors::Internal("Could not find input ", in_edge->DebugString(),
777 " to ", n->name(),
778 " when visiting the graph in post-order. Most "
779 "likely indicates a bug in deadness analysis.");
780 }
781 result->push_back(it->second);
782 }
783 }
784 return Status::OK();
785 }
786
HandleSwitch(Node * n,std::vector<bool> * should_revisit)787 Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
788 std::vector<bool>* should_revisit) {
789 std::vector<Predicate*> input_preds;
790 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
791 const Edge* pred_edge;
792 TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
793
794 Predicate* true_switch;
795 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
796 pred_edge->src(), pred_edge->src_output(),
797 /*must_be_true=*/true, &true_switch));
798
799 Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
800
801 // Output 0 is alive iff all inputs are alive and the condition is false.
802 input_preds.push_back(false_switch);
803 SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
804 should_revisit);
805 input_preds.pop_back();
806
807 // Output 1 is alive iff all inputs are alive and the condition is true.
808 input_preds.push_back(true_switch);
809 SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
810 should_revisit);
811 input_preds.pop_back();
812
813 // Control is alive iff all inputs are alive.
814 SetPredicate(n, Graph::kControlSlot,
815 predicate_factory_.MakeAndPredicate(input_preds),
816 should_revisit);
817
818 return Status::OK();
819 }
820
821 namespace {
CreateMultipleNextIterationInputsError(Node * merge)822 Status CreateMultipleNextIterationInputsError(Node* merge) {
823 std::vector<string> backedges;
824 for (const Edge* backedge : merge->in_edges()) {
825 if (backedge->src()->IsNextIteration()) {
826 backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src())));
827 }
828 }
829 return errors::InvalidArgument(
830 "Multiple NextIteration inputs to merge node ",
831 FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
832 "\nMerge nodes can have at most one incoming NextIteration edge.");
833 }
834
FindUniqueBackedge(Node * merge,const Edge ** result)835 Status FindUniqueBackedge(Node* merge, const Edge** result) {
836 *result = nullptr;
837 CHECK(merge->IsMerge());
838 for (const Edge* e : merge->in_edges()) {
839 if (e->src()->IsNextIteration()) {
840 if (*result != nullptr) {
841 return CreateMultipleNextIterationInputsError(merge);
842 }
843 *result = e;
844 }
845 }
846 return Status::OK();
847 }
848
849 // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
850 // does not contain `symbolic_predicate` as an inner (not top-level) operand
851 // then returns `Step`. Otherwise returns nullptr.
DeduceStepPredicate(PredicateFactory * predicate_factory,Predicate * symbolic_predicate,Predicate * backedge_predicate)852 Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
853 Predicate* symbolic_predicate,
854 Predicate* backedge_predicate) {
855 CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
856 if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
857 return nullptr;
858 }
859
860 std::vector<Predicate*> and_ops;
861 absl::Span<Predicate* const> recurrent_pred_ops =
862 backedge_predicate->GetOperands();
863
864 bool found_sym = false;
865 for (Predicate* and_op : recurrent_pred_ops) {
866 // We want the `symbol_predicate` to be the one of the operands of
867 // `backedge_predicate`,
868 if (and_op == symbolic_predicate) {
869 found_sym = true;
870 continue;
871 }
872
873 // but we don't want it to be present anywhere else in the formula. E.g. we
874 // don't want the recurrent predicate to be
875 // symbol_predicate&(X|symbol_predicate).
876 bool found_sym_as_inner_operand = false;
877 auto has_self_as_inner_operand = [&](Predicate* p) {
878 if (p == symbolic_predicate) {
879 found_sym_as_inner_operand = true;
880 return true; // Stop searching, we're done.
881 }
882
883 // Continue searching.
884 return false;
885 };
886
887 Predicate::Visit(and_op, has_self_as_inner_operand);
888 if (found_sym_as_inner_operand) {
889 return nullptr;
890 }
891 and_ops.push_back(and_op);
892 }
893
894 return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
895 }
896
GetFullFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,std::vector<string> * frame)897 Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
898 std::vector<string>* frame) {
899 int depth = 0;
900 for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
901 n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
902 frame->push_back(cfi_iter->frame_name);
903
904 if (depth++ > 5000) {
905 return errors::Internal(
906 "Frame of depth > 5000: Probably malformed graph or a bug in "
907 "BuildControlFlowInfo");
908 }
909 }
910
911 return Status::OK();
912 }
913 } // namespace
914
HandleMerge(Node * n,std::vector<bool> * should_revisit)915 Status DeadnessAnalysisImpl::HandleMerge(Node* n,
916 std::vector<bool>* should_revisit) {
917 // Merge ignores deadness of its control inputs. A merge that isn't the
918 // target of a backedge has is alive iff any of its data inputs are. The
919 // liveness of a merge that is the target of a backedge can sometimes be
920 // represented using a AndRecurrencePredicate. If neither apply, we represent
921 // the liveness of the merge symbolically.
922
923 bool has_unvisited_backedge = false;
924 for (const Edge* e : n->in_edges()) {
925 if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
926 has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
927 }
928 }
929
930 auto it = predicate_map_.find(TensorId(n->name(), 0));
931 if (it == predicate_map_.end()) {
932 if (has_unvisited_backedge) {
933 // We're visiting this merge for the first time and it has an unvisited
934 // backedge.
935 Predicate* input_data_pred;
936 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
937 n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
938
939 SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
940 should_revisit);
941 return Status::OK();
942 }
943
944 std::vector<Predicate*> input_preds;
945 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
946
947 // We're visiting this merge for the first time and it is a acyclic merge.
948 Predicate* input_data_pred =
949 predicate_factory_.MakeOrPredicate(input_preds);
950 SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
951 should_revisit);
952 return Status::OK();
953 }
954
955 if (it->second->kind() == Predicate::Kind::kSymbol) {
956 // Last time we visited this merge we only got a symbolic predicate because
957 // of an unvisited backedge. Try to pattern match the predicate expression
958 // for that backedge (which should be visited now) into an and recurrence
959 // for the merge node.
960 const Edge* unique_backedge;
961 TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
962 if (unique_backedge) {
963 if (Predicate* step = DeduceStepPredicate(
964 &predicate_factory_, it->second,
965 predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
966 // If the predicate for the backedge is "Sym&X" where "Sym" is the
967 // predicate for the merge then the merge has predicate {S,&,X} where S
968 // is the predicate for the merge ignoring the backedge.
969 std::vector<Predicate*> non_recurrent_inputs;
970 for (const Edge* e : n->in_edges()) {
971 if (e != unique_backedge) {
972 non_recurrent_inputs.push_back(
973 predicate_map_[InputEdgeToTensorId(e)]);
974 }
975 }
976
977 Predicate* start =
978 predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
979 std::vector<string> frame;
980 TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));
981 Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
982 start, step, std::move(frame));
983 SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
984 return Status::OK();
985 }
986 }
987 }
988 return Status::OK();
989 }
990
HandleRecv(Node * n,std::vector<bool> * should_revisit)991 Status DeadnessAnalysisImpl::HandleRecv(Node* n,
992 std::vector<bool>* should_revisit) {
993 // In addition to being alive or dead based on the inputs, a _Recv can also
994 // acquire a dead signal from a _Send.
995 std::vector<Predicate*> input_preds;
996 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
997 Predicate* signal_is_alive;
998 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
999 n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive));
1000 input_preds.push_back(signal_is_alive);
1001 SetPredicate(n, {0, Graph::kControlSlot},
1002 predicate_factory_.MakeAndPredicate(input_preds),
1003 should_revisit);
1004 return Status::OK();
1005 }
1006
HandleGeneric(Node * n,std::vector<bool> * should_revisit)1007 Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
1008 std::vector<bool>* should_revisit) {
1009 // Generally nodes are alive iff all their inputs are alive.
1010 std::vector<Predicate*> input_preds;
1011 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1012 Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
1013 for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
1014 SetPredicate(n, output_idx, pred, should_revisit);
1015 }
1016 SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
1017 return Status::OK();
1018 }
1019
HandleNode(Node * n,std::vector<bool> * should_revisit)1020 Status DeadnessAnalysisImpl::HandleNode(Node* n,
1021 std::vector<bool>* should_revisit) {
1022 if (n->IsSwitch()) {
1023 TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
1024 } else if (n->IsMerge()) {
1025 TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
1026 } else if (n->IsControlTrigger()) {
1027 SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
1028 nullptr);
1029 } else if (n->IsRecv() || n->IsHostRecv()) {
1030 TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
1031 } else if (n->IsNextIteration()) {
1032 TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1033 } else {
1034 TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1035 }
1036 return Status::OK();
1037 }
1038
Populate()1039 Status DeadnessAnalysisImpl::Populate() {
1040 std::vector<Node*> rpo;
1041 GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
1042 /*edge_filter=*/[](const Edge& edge) {
1043 return !edge.src()->IsNextIteration();
1044 });
1045 return PopulateWithReversePostOrder(rpo);
1046 }
1047
PopulateWithReversePostOrder(absl::Span<Node * const> rpo)1048 Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
1049 absl::Span<Node* const> rpo) {
1050 std::vector<string> unreachable_nodes;
1051 // Compute the loop structure of the graph.
1052 TF_RETURN_IF_ERROR(
1053 BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));
1054
1055 // Do some opportunistic error checking:
1056 if (!unreachable_nodes.empty()) {
1057 if (unreachable_nodes.size() > 5) {
1058 unreachable_nodes.erase(unreachable_nodes.begin() + 5,
1059 unreachable_nodes.end());
1060 }
1061
1062 return errors::InvalidArgument(
1063 "Found unreachable nodes, most likely source and sink nodes not "
1064 "connected: ",
1065 absl::StrJoin(unreachable_nodes, ", "));
1066 }
1067
1068 // This an abstract interpretation over the deadness propagation semantics of
1069 // the graph executor.
1070 //
1071 // We iterate over the graph twice, each time in RPO. On the first iteration
1072 // merge nodes with backedges are mapped to symbolic predicates. On the
1073 // second iteration we use the predicates assigned to the backedges in the
1074 // previous iteration to infer a more precise predicate for the backedge merge
1075 // nodes and all the nodes that transitively use it.
1076 //
1077 // We don't track the output indices for should_revisit. Instead, putting a
1078 // node in `should_revisit` denotes that the deadness flowing out from any
1079 // output from said node may have changed. This is fine; only switches
1080 // propagate different deadness along different output edges, and since the
1081 // delta is solely due to the input *values* (and not input deadness), the
1082 // delta should not change in the second iteration.
1083 std::vector<bool> should_revisit;
1084 should_revisit.resize(graph_.num_node_ids());
1085 for (Node* n : rpo) {
1086 VLOG(4) << "Visiting " << n->name();
1087 TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
1088 if (n->IsNextIteration()) {
1089 // If this is a backedge for a merge node then remember to reprocess the
1090 // merge the next time we run.
1091 for (const Edge* e : n->out_edges()) {
1092 if (e->dst()->IsMerge()) {
1093 should_revisit[e->dst()->id()] = true;
1094 }
1095 }
1096 }
1097 }
1098
1099 for (Node* n : rpo) {
1100 // The nodes added to should_revisit in the previous loop need to be
1101 // revisited now. Reprocesing these initial nodes may add *their* consumers
1102 // to should_revisit, and these newly added nodes will also be processed by
1103 // this very same loop. Since we're traversing the graph in reverse post
1104 // order (producers before consumers) and HandleNode(n) can only ever add
1105 // n's consumers to should_revisit, we won't "miss" an addition to
1106 // should_revisit.
1107 if (should_revisit[n->id()]) {
1108 VLOG(4) << "Revisiting " << n->name();
1109 TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
1110 }
1111 }
1112
1113 return Status::OK();
1114 }
1115
HasInputsWithMismatchingDeadness(const Node & node)1116 bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
1117 CHECK(!node.IsMerge());
1118
1119 if (vlog_) {
1120 VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
1121 }
1122
1123 Predicate* pred = nullptr;
1124 for (const Edge* edge : node.in_edges()) {
1125 auto it = predicate_map_.find(InputEdgeToTensorId(edge));
1126 CHECK(it != predicate_map_.end());
1127 if (vlog_) {
1128 VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
1129 << it->second->ToString();
1130 }
1131
1132 // Today we just compare the predicates for equality (with some
1133 // canonicalization/simplification happening before) but we could be more
1134 // sophisticated here if need be. Comparing pointers is sufficient because
1135 // we intern Predicate instances by their content.
1136 if (pred != nullptr && pred != it->second) {
1137 if (vlog_) {
1138 VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
1139 << ") -> true";
1140 }
1141 return true;
1142 }
1143 pred = it->second;
1144 }
1145
1146 if (vlog_) {
1147 VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
1148 << ") -> false";
1149 }
1150
1151 return false;
1152 }
1153
Print() const1154 void DeadnessAnalysisImpl::Print() const {
1155 std::vector<TensorId> tensor_ids;
1156 for (const auto& kv_pair : predicate_map_) {
1157 tensor_ids.push_back(kv_pair.first);
1158 }
1159
1160 std::sort(tensor_ids.begin(), tensor_ids.end());
1161
1162 for (TensorId tensor_id : tensor_ids) {
1163 auto it = predicate_map_.find(tensor_id);
1164 CHECK(it != predicate_map_.end()) << tensor_id.ToString();
1165 VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
1166 }
1167 }
1168
1169 } // namespace
1170
~DeadnessAnalysis()1171 DeadnessAnalysis::~DeadnessAnalysis() {}
1172
Run(const Graph & graph,std::unique_ptr<DeadnessAnalysis> * result)1173 /*static*/ Status DeadnessAnalysis::Run(
1174 const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
1175 std::unique_ptr<DeadnessAnalysisImpl> analysis(
1176 new DeadnessAnalysisImpl(&graph));
1177 TF_RETURN_IF_ERROR(analysis->Populate());
1178
1179 if (VLOG_IS_ON(2)) {
1180 analysis->Print();
1181 }
1182
1183 *result = std::move(analysis);
1184 return Status::OK();
1185 }
1186
1187 absl::flat_hash_map<TensorId, string, TensorId::Hasher>
PredicateMapAsString() const1188 DeadnessAnalysisImpl::PredicateMapAsString() const {
1189 absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
1190 std::vector<TensorId> tensor_ids;
1191 for (const auto& kv_pair : predicate_map_) {
1192 CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
1193 }
1194 return result;
1195 }
1196
1197 namespace deadness_analysis_internal {
ComputePredicates(const Graph & graph,PredicateMapTy * out_predicate_map)1198 Status ComputePredicates(const Graph& graph,
1199 PredicateMapTy* out_predicate_map) {
1200 DeadnessAnalysisImpl impl(&graph);
1201 TF_RETURN_IF_ERROR(impl.Populate());
1202 *out_predicate_map = impl.PredicateMapAsString();
1203 return Status::OK();
1204 }
1205
ComputePredicates(const Graph & graph,absl::Span<Node * const> reverse_post_order,PredicateMapTy * out_predicate_map)1206 Status ComputePredicates(const Graph& graph,
1207 absl::Span<Node* const> reverse_post_order,
1208 PredicateMapTy* out_predicate_map) {
1209 DeadnessAnalysisImpl impl(&graph);
1210 TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
1211 *out_predicate_map = impl.PredicateMapAsString();
1212 return Status::OK();
1213 }
1214 } // namespace deadness_analysis_internal
1215
1216 } // namespace tensorflow
1217