# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cfg module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test


class CountingVisitor(cfg.GraphVisitor):
  """GraphVisitor that tallies how many times each AST node is visited."""

  def __init__(self, graph):
    super(CountingVisitor, self).__init__(graph)
    # AST node -> number of times a CFG node wrapping it was visited.
    self.counts = {}

  def init_state(self, _):
    # This visitor keeps no per-node state; only `counts` matters.
    return None

  def visit_node(self, node):
    key = node.ast_node
    seen_so_far = self.counts[key] if key in self.counts else 0
    self.counts[key] = seen_so_far + 1
    # Returning False means the node should not be revisited (visit once).
    return False


class GraphVisitorTest(test.TestCase):
  """Checks how often CountingVisitor reaches each AST node in CFG walks.

  Both tests walk the CFG of the same small function, once forward and once
  in reverse, and assert the per-AST-node visit counts.
  """

  def _build_cfg(self, fn):
    """Parses `fn` and builds its CFGs.

    Args:
      fn: the Python function to parse.

    Returns:
      A (cfgs, node) pair: the mapping returned by `cfg.build` and the parsed
      AST node of `fn`.
    """
    node, _, _ = parser.parse_entity(fn)
    cfgs = cfg.build(node)
    return cfgs, node

  def test_basic_coverage_forward(self):

    def test_fn(a):
      while a > 0:
        a = 1
        break
        return a  # pylint:disable=unreachable
      a = 2

    graphs, node = self._build_cfg(test_fn)
    graph, = graphs.values()
    visitor = CountingVisitor(graph)
    visitor.visit_forward()

    self.assertEqual(visitor.counts[node.args], 1)
    self.assertEqual(visitor.counts[node.body[0].test], 1)
    self.assertEqual(visitor.counts[node.body[0].body[0]], 1)
    self.assertEqual(visitor.counts[node.body[0].body[1]], 1)
    # The return node should be unreachable in forward direction.
    self.assertNotIn(node.body[0].body[2], visitor.counts)
    self.assertEqual(visitor.counts[node.body[1]], 1)

  def test_basic_coverage_reverse(self):

    def test_fn(a):
      while a > 0:
        a = 1
        break
        return a  # pylint:disable=unreachable
      a = 2

    graphs, node = self._build_cfg(test_fn)
    graph, = graphs.values()
    visitor = CountingVisitor(graph)
    visitor.visit_reverse()

    self.assertEqual(visitor.counts[node.args], 1)
    self.assertEqual(visitor.counts[node.body[0].test], 1)
    self.assertEqual(visitor.counts[node.body[0].body[0]], 1)
    self.assertEqual(visitor.counts[node.body[0].body[1]], 1)
    # Unlike the forward walk, the reverse walk does reach the unreachable
    # return. Use assertEqual: assertTrue(x, 1) would treat 1 as the failure
    # message and only check that the count is truthy.
    self.assertEqual(visitor.counts[node.body[0].body[2]], 1)
    self.assertEqual(visitor.counts[node.body[1]], 1)


class AstToCfgTest(test.TestCase):
  """Checks the CFG that `cfg.build` produces for small Python functions.

  Expected CFG nodes are named by their `repr` (for example `'a = 1'` or
  `'(a > 0)'`). Statement edges name the AST statement as
  `'<AST class name>:<lineno>'`. Line numbers are relative to the parsed test
  function, whose `def` line is line 1.
  """

  def _build_cfg(self, fn):
    """Parses `fn` and returns the mapping of CFGs built by `cfg.build`."""
    node, _, _ = parser.parse_entity(fn)
    cfgs = cfg.build(node)
    return cfgs

  def _repr_set(self, node_set):
    """Returns the reprs of the nodes in `node_set` as a frozenset."""
    return frozenset(repr(n) for n in node_set)

  def _as_set(self, elements):
    """Normalizes an expected-edge specification into a frozenset of reprs.

    Args:
      elements: None (no edges), a single repr string, or an iterable of
        repr strings.

    Returns:
      A frozenset of repr strings.
    """
    if elements is None:
      return frozenset()
    elif isinstance(elements, str):
      return frozenset((elements,))
    else:
      return frozenset(elements)

  def assertGraphMatches(self, graph, edges):
    """Tests whether the CFG contains the specified edges.

    Args:
      graph: the CFG under test.
      edges: iterable of (prev, node_repr, next) triples. A triple matches
        when some CFG node whose repr is `node_repr` has exactly the
        predecessors `prev` and successors `next` (in any form accepted by
        `_as_set`).
    """
    for prev, node_repr, next_ in edges:
      expected_prev = self._as_set(prev)
      expected_next = self._as_set(next_)
      matched = any(
          repr(cfg_node) == node_repr and
          self._repr_set(cfg_node.prev) == expected_prev and
          self._repr_set(cfg_node.next) == expected_next
          for cfg_node in graph.index.values())
      if not matched:
        self.fail(
            'match failed for node "%s" in graph:\n%s' % (node_repr, graph))

  def assertStatementEdges(self, graph, edges):
    """Tests whether the CFG contains the specified statement edges.

    `graph.stmt_prev` and `graph.stmt_next` map statement AST nodes to sets
    of CFG nodes; this compares the reprs of those sets.

    Args:
      graph: the CFG under test.
      edges: iterable of (prev, stmt_repr, next) triples, where `stmt_repr`
        is '<AST class name>:<lineno>' and `prev`/`next` are in any form
        accepted by `_as_set`.
    """
    # stmt_prev and stmt_next must be keyed by the same statements. This does
    # not depend on `edges`, so check it once instead of once per edge.
    self.assertSetEqual(
        frozenset(graph.stmt_next.keys()), frozenset(graph.stmt_prev.keys()))
    for prev_node_reprs, node_repr, next_node_reprs in edges:
      expected_prev = self._as_set(prev_node_reprs)
      expected_next = self._as_set(next_node_reprs)
      matched = False
      # Every statement whose repr matches, kept for the failure message.
      partial_matches = []
      for stmt_ast_node in graph.stmt_next:
        ast_repr = '%s:%s' % (stmt_ast_node.__class__.__name__,
                              stmt_ast_node.lineno)
        if ast_repr != node_repr:
          continue
        actual_next = self._repr_set(graph.stmt_next[stmt_ast_node])
        actual_prev = self._repr_set(graph.stmt_prev[stmt_ast_node])
        partial_matches.append((actual_prev, node_repr, actual_next))
        if expected_prev == actual_prev and expected_next == actual_next:
          matched = True
          break
      if not matched:
        self.fail('edges mismatch for %s: %s' % (node_repr, partial_matches))

  def test_straightline(self):

    def test_fn(a):
      a += 1
      a = 2
      a = 3
      return

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (None, 'a', 'a += 1'),
            ('a += 1', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', 'return'),
            ('a = 3', 'return', None),
        ),
    )

  def test_straightline_no_return(self):

    def test_fn(a, b):
      a = b + 1
      a += max(a)

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (None, 'a, b', 'a = b + 1'),
            ('a = b + 1', 'a += max(a)', None),
        ),
    )

  def test_unreachable_code(self):

    def test_fn(a):
      return
      a += 1  # pylint:disable=unreachable

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (None, 'a', 'return'),
            ('a', 'return', None),
            (None, 'a += 1', None),
        ),
    )

  def test_if_straightline(self):

    def test_fn(a):
      if a > 0:
        a = 1
      else:
        a += -1

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (None, 'a', '(a > 0)'),
            ('(a > 0)', 'a = 1', None),
            ('(a > 0)', 'a += -1', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'If:2', None),),
    )

  def test_branch_nested(self):

    def test_fn(a):
      if a > 0:
        if a > 1:
          a = 1
        else:
          a = 2
      else:
        if a > 2:
          a = 3
        else:
          a = 4

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (None, 'a', '(a > 0)'),
            ('a', '(a > 0)', ('(a > 1)', '(a > 2)')),
            ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')),
            ('(a > 1)', 'a = 1', None),
            ('(a > 1)', 'a = 2', None),
            ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')),
            ('(a > 2)', 'a = 3', None),
            ('(a > 2)', 'a = 4', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'If:2', None),
            ('(a > 0)', 'If:3', None),
            ('(a > 0)', 'If:8', None),
        ),
    )

  def test_branch_straightline_semi(self):

    def test_fn(a):
      if a > 0:
        a = 1

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (None, 'a', '(a > 0)'),
            ('a', '(a > 0)', 'a = 1'),
            ('(a > 0)', 'a = 1', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'If:2', None),),
    )

  def test_branch_return(self):

    def test_fn(a):
      if a > 0:
        return
      else:
        a = 1
      a = 2

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', '(a > 0)', ('return', 'a = 1')),
            ('(a > 0)', 'a = 1', 'a = 2'),
            ('(a > 0)', 'return', None),
            ('a = 1', 'a = 2', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'If:2', 'a = 2'),),
    )

  def test_branch_return_minimal(self):

    def test_fn(a):
      if a > 0:
        return

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', '(a > 0)', 'return'),
            ('(a > 0)', 'return', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'If:2', None),),
    )

  def test_while_straightline(self):

    def test_fn(a):
      while a > 0:
        a = 1
      a = 2

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
            ('(a > 0)', 'a = 1', '(a > 0)'),
            ('(a > 0)', 'a = 2', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'While:2', 'a = 2'),),
    )

  def test_while_else_straightline(self):

    def test_fn(a):
      while a > 0:
        a = 1
      else:  # pylint:disable=useless-else-on-loop
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
            ('(a > 0)', 'a = 1', '(a > 0)'),
            ('(a > 0)', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'While:2', 'a = 3'),),
    )

  def test_while_else_continue(self):

    def test_fn(a):
      while a > 0:
        if a > 1:
          continue
        else:
          a = 0
        a = 1
      else:  # pylint:disable=useless-else-on-loop
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
            ('(a > 0)', '(a > 1)', ('continue', 'a = 0')),
            ('(a > 1)', 'continue', '(a > 0)'),
            ('(a > 1)', 'a = 0', 'a = 1'),
            ('a = 0', 'a = 1', '(a > 0)'),
            ('(a > 0)', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'While:2', 'a = 3'),
            ('(a > 0)', 'If:3', ('a = 1', '(a > 0)')),
        ),
    )

  def test_while_else_break(self):

    def test_fn(a):
      while a > 0:
        if a > 1:
          break
        a = 1
      else:
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
            ('(a > 0)', '(a > 1)', ('break', 'a = 1')),
            ('(a > 1)', 'break', 'a = 3'),
            ('(a > 1)', 'a = 1', '(a > 0)'),
            ('(a > 0)', 'a = 2', 'a = 3'),
            (('break', 'a = 2'), 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'While:2', 'a = 3'),
            ('(a > 0)', 'If:3', ('a = 1', 'a = 3')),
        ),
    )

  def test_while_else_return(self):

    def test_fn(a):
      while a > 0:
        if a > 1:
          return
        a = 1
      else:  # pylint:disable=useless-else-on-loop
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
            ('(a > 0)', '(a > 1)', ('return', 'a = 1')),
            ('(a > 1)', 'return', None),
            ('(a > 1)', 'a = 1', '(a > 0)'),
            ('(a > 0)', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'While:2', 'a = 3'),
            ('(a > 0)', 'If:3', 'a = 1'),
        ),
    )

  def test_while_nested_straightline(self):

    def test_fn(a):
      while a > 0:
        while a > 1:
          a = 1
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
            (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')),
            ('(a > 1)', 'a = 1', '(a > 1)'),
            ('(a > 1)', 'a = 2', '(a > 0)'),
            ('(a > 0)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'While:2', 'a = 3'),
            ('(a > 0)', 'While:3', 'a = 2'),
        ),
    )

  def test_while_nested_continue(self):

    def test_fn(a):
      while a > 0:
        while a > 1:
          if a > 3:
            continue
          a = 1
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
            (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')),
            ('(a > 1)', '(a > 3)', ('continue', 'a = 1')),
            ('(a > 3)', 'continue', '(a > 1)'),
            ('(a > 3)', 'a = 1', '(a > 1)'),
            ('(a > 1)', 'a = 2', '(a > 0)'),
            ('(a > 0)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'While:2', 'a = 3'),
            ('(a > 0)', 'While:3', 'a = 2'),
            ('(a > 1)', 'If:4', ('a = 1', '(a > 1)')),
        ),
    )

  def test_while_nested_break(self):

    def test_fn(a):
      while a > 0:
        while a > 1:
          if a > 2:
            break
          a = 1
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
            (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')),
            ('(a > 1)', '(a > 2)', ('break', 'a = 1')),
            ('(a > 2)', 'break', 'a = 2'),
            ('(a > 2)', 'a = 1', '(a > 1)'),
            (('(a > 1)', 'break'), 'a = 2', '(a > 0)'),
            ('(a > 0)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'While:2', 'a = 3'),
            ('(a > 0)', 'While:3', 'a = 2'),
            ('(a > 1)', 'If:4', ('a = 1', 'a = 2')),
        ),
    )

  def test_for_straightline(self):

    def test_fn(a):
      for a in range(0, a):
        a = 1
      a = 2

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
            ('range(0, a)', 'a = 1', 'range(0, a)'),
            ('range(0, a)', 'a = 2', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'For:2', 'a = 2'),),
    )

  def test_for_else_straightline(self):

    def test_fn(a):
      for a in range(0, a):
        a = 1
      else:  # pylint:disable=useless-else-on-loop
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
            ('range(0, a)', 'a = 1', 'range(0, a)'),
            ('range(0, a)', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (('a', 'For:2', 'a = 3'),),
    )

  def test_for_else_continue(self):

    def test_fn(a):
      for a in range(0, a):
        if a > 1:
          continue
        else:
          a = 0
        a = 1
      else:  # pylint:disable=useless-else-on-loop
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
            ('range(0, a)', '(a > 1)', ('continue', 'a = 0')),
            ('(a > 1)', 'continue', 'range(0, a)'),
            ('(a > 1)', 'a = 0', 'a = 1'),
            ('a = 0', 'a = 1', 'range(0, a)'),
            ('range(0, a)', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'For:2', 'a = 3'),
            ('range(0, a)', 'If:3', ('a = 1', 'range(0, a)')),
        ),
    )

  def test_for_else_break(self):

    def test_fn(a):
      for a in range(0, a):
        if a > 1:
          break
        a = 1
      else:
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
            ('range(0, a)', '(a > 1)', ('break', 'a = 1')),
            ('(a > 1)', 'break', 'a = 3'),
            ('(a > 1)', 'a = 1', 'range(0, a)'),
            ('range(0, a)', 'a = 2', 'a = 3'),
            (('break', 'a = 2'), 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'For:2', 'a = 3'),
            ('range(0, a)', 'If:3', ('a = 1', 'a = 3')),
        ),
    )

  def test_for_else_return(self):

    def test_fn(a):
      for a in range(0, a):
        if a > 1:
          return
        a = 1
      else:  # pylint:disable=useless-else-on-loop
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
            ('range(0, a)', '(a > 1)', ('return', 'a = 1')),
            ('(a > 1)', 'return', None),
            ('(a > 1)', 'a = 1', 'range(0, a)'),
            ('range(0, a)', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'For:2', 'a = 3'),
            ('range(0, a)', 'If:3', 'a = 1'),
        ),
    )

  def test_for_nested_straightline(self):

    def test_fn(a):
      for a in range(0, a):
        for b in range(1, a):
          b += 1
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
            (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')),
            ('range(1, a)', 'b += 1', 'range(1, a)'),
            ('range(1, a)', 'a = 2', 'range(0, a)'),
            ('range(0, a)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'For:2', 'a = 3'),
            ('range(0, a)', 'For:3', 'a = 2'),
        ),
    )

  def test_for_nested_continue(self):

    def test_fn(a):
      for a in range(0, a):
        for b in range(1, a):
          if a > 3:
            continue
          b += 1
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
            (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)',
             ('(a > 3)', 'a = 2')),
            ('range(1, a)', '(a > 3)', ('continue', 'b += 1')),
            ('(a > 3)', 'continue', 'range(1, a)'),
            ('(a > 3)', 'b += 1', 'range(1, a)'),
            ('range(1, a)', 'a = 2', 'range(0, a)'),
            ('range(0, a)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'For:2', 'a = 3'),
            ('range(0, a)', 'For:3', 'a = 2'),
            ('range(1, a)', 'If:4', ('b += 1', 'range(1, a)')),
        ),
    )

  def test_for_nested_break(self):

    def test_fn(a):
      for a in range(0, a):
        for b in range(1, a):
          if a > 2:
            break
          b += 1
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
            (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')),
            ('range(1, a)', '(a > 2)', ('break', 'b += 1')),
            ('(a > 2)', 'break', 'a = 2'),
            ('(a > 2)', 'b += 1', 'range(1, a)'),
            (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'),
            ('range(0, a)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('a', 'For:2', 'a = 3'),
            ('range(0, a)', 'For:3', 'a = 2'),
            ('range(1, a)', 'If:4', ('b += 1', 'a = 2')),
        ),
    )

  def test_complex(self):

    def test_fn(a):
      b = 0
      while a > 0:
        for b in range(0, a):
          if a > 2:
            break
          if a > 3:
            if a > 4:
              continue
            else:
              max(a)
              break
          b += 1
        else:  # for b in range(0, a):
          return a
        a = 2
      for a in range(1, a):
        return b
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')),
            (
                ('(a > 0)', 'continue', 'b += 1'),
                'range(0, a)',
                ('(a > 2)', 'return a'),
            ),
            ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')),
            ('(a > 2)', 'break', 'a = 2'),
            ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')),
            ('(a > 3)', '(a > 4)', ('continue', 'max(a)')),
            ('(a > 4)', 'max(a)', 'break'),
            ('max(a)', 'break', 'a = 2'),
            ('(a > 4)', 'continue', 'range(0, a)'),
            ('(a > 3)', 'b += 1', 'range(0, a)'),
            ('range(0, a)', 'return a', None),
            ('break', 'a = 2', '(a > 0)'),
            ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')),
            ('range(1, a)', 'return b', None),
            ('range(1, a)', 'a = 3', None),
        ),
    )
    self.assertStatementEdges(
        graph,
        (
            ('b = 0', 'While:3', 'range(1, a)'),
            ('(a > 0)', 'For:4', 'a = 2'),
            ('range(0, a)', 'If:5', ('(a > 3)', 'a = 2')),
            ('(a > 2)', 'If:7', ('b += 1', 'a = 2', 'range(0, a)')),
            ('(a > 3)', 'If:8', ('a = 2', 'range(0, a)')),
            ('(a > 0)', 'For:17', 'a = 3'),
        ),
    )

  def test_finally_straightline(self):

    def test_fn(a):
      try:
        a += 1
      finally:
        a = 2
      a = 3

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', 'a += 1', 'a = 2'),
            ('a += 1', 'a = 2', 'a = 3'),
            ('a = 2', 'a = 3', None),
        ),
    )

  def test_return_finally(self):

    def test_fn(a):
      try:
        return a
      finally:
        a = 1
      a = 2

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', 'return a', 'a = 1'),
            ('return a', 'a = 1', None),
            (None, 'a = 2', None),
        ),
    )

  def test_break_finally(self):

    def test_fn(a):
      while a > 0:
        try:
          break
        finally:
          a = 1

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', '(a > 0)', 'break'),
            ('(a > 0)', 'break', 'a = 1'),
            ('break', 'a = 1', None),
        ),
    )

  def test_continue_finally(self):

    def test_fn(a):
      while a > 0:
        try:
          continue
        finally:
          a = 1

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            (('a', 'a = 1'), '(a > 0)', 'continue'),
            ('(a > 0)', 'continue', 'a = 1'),
            ('continue', 'a = 1', '(a > 0)'),
        ),
    )

  def test_with_straightline(self):

    def test_fn(a):
      with max(a) as b:
        a = 0
        return b

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', 'max(a)', 'a = 0'),
            ('max(a)', 'a = 0', 'return b'),
            ('a = 0', 'return b', None),
        ),
    )

  def test_lambda_basic(self):

    def test_fn(a):
      a = lambda b: a + b
      return a

    graph, = self._build_cfg(test_fn).values()

    self.assertGraphMatches(
        graph,
        (
            ('a', 'a = lambda b: a + b', 'return a'),
            ('a = lambda b: a + b', 'return a', None),
        ),
    )


if __name__ == '__main__':
  # Run the tests above via TensorFlow's platform test entry point.
  test.main()