# AutoGraph reference

[Index](index.md)

## Limitations

When AutoGraph is applied to normal Python code, you should expect no change
in functionality.
However, when applied to TensorFlow control flow (for example, an if statement
with a `tf.Tensor` condition), there are certain limitations. This section
describes these limitations and practices that will allow you to avoid them.

Key Term: Python variables refer to Python symbols (or symbols for short) and
should not be confused with TensorFlow variables.

Key Term: A TensorFlow loop variable (or loop variable for short) refers to a
value (typically a `tf.Tensor`) modified by a loop. See `tf.while_loop`.

### Undefined and None values in TensorFlow

TensorFlow does not support undefined or `None` values. All tensors must have
a value.

Example:

```
x = tf.cond(
    tf.random.uniform(()) > 0.5,
    lambda: tf.constant(1),
    lambda: None)  # Error -- a Tensor cannot be None
```

The same restriction carries over in AutoGraph. If a variable is created inside
control flow, and used after, then it must be defined before the control flow
statement:

```
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  x = None
tf.print(x)  # Error -- x may be None here
```

For this reason, AutoGraph forbids variables to be defined in only one branch
of a TensorFlow conditional, if the variable is used afterwards:

```
del x
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  pass
tf.print(x)  # Error -- x may be undefined here
```

Note that if the variable is not used after the control flow statement, then it
is considered local to the control flow block, and is not subject to these
restrictions.

```
del x
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)  # Okay -- x does not need to be returned from the TF cond
else:
  pass
```

Similarly, variables must usually be defined before a TensorFlow loop.

The most common example that is not allowed is a loop which initializes some
accumulator variable in the first iteration:

```
del x
for i in tf.range(100):  # Error -- x must be defined before the loop
  if i == 0:
    x = tf.constant(1)
  else:
    x = x + 1
tf.print(x)
```

When the variable is only used inside the loop and does not depend on previous
iterations, then it is okay to initialize it only inside the loop.

```
del x
while tf.random.uniform(()) > 0.5:  # Okay -- x is not used after the loop
  x = tf.constant(1)
```

* New in TF 2.4 *

As long as it doesn't depend on previous iterations, the variable may also be
used after the loop; however, in that case the loop must execute at least one
iteration, and a runtime error is raised otherwise.

```
del x
for i in tf.range(10):  # Okay -- x does not depend on previous iterations
  x = tf.constant(1)
tf.print(x)
```

```
del x
while tf.constant(False):  # Error -- loop must initialize x!
  x = tf.constant(1)
tf.print(x)
```

Avoid these limitations by defining a default value before the control flow
statement:

```
x = tf.constant(0)
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
tf.print(x)  # Okay -- x is either 0 or 1
```

Note: `None` values and undefined symbols are allowed in Eager control flow,
because Eager execution uses Python control flow, rather than TensorFlow
control flow ops.

#### Special case: creating Tensors in a loop

* New in TF 2.4 *

A very common use-case is to run a training loop that creates some outputs:

```
for i in tf.range(num_steps):
  outputs = train(next(data_iterator))
```

Often, these outputs can be nested structures of Tensors, which makes them
impractical to initialize ahead of the loop.

To help with this use-case, AutoGraph lets you run such loops, under certain
conditions:

 * outputs must be a Tensor, Python numeric, or a structure of these
 * outputs must not depend on the value from a previous iteration; in other
   words, the outputs may only appear to the left of an assignment operation
 * the loop must run at least one iteration

If the type of outputs is not recognized, then the usual
"outputs must be defined before the loop" error is raised at graph construction.

AutoGraph also inserts a `tf.Assert` statement that raises a runtime error
if the loop did not execute at least one iteration.
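
For example, under these conditions a loop may create its output in the first
iteration and use it after the loop. A minimal sketch, assuming TF 2.4 or
newer (the function name `collect` is hypothetical):

```
import tensorflow as tf

@tf.function
def collect():
  for i in tf.range(3):
    outputs = i * 2  # Okay -- created in the loop, no dependency on prior iterations
  return outputs

collect()  # returns a Tensor equal to 4
```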

### Indirect modifications and hidden side effects in TensorFlow control flow

Key Point: We recommend using a functional programming style, immutable Python
collections, TensorFlow ops and collections. Only TensorFlow objects should be
used for side effects.

#### AutoGraph analyzes code to detect modifications to Python objects

Note: Modifications to TensorFlow objects, such as `tf.Variable`, are tracked
using a different mechanism (automatic control dependencies) which does not
rely on code analysis.

One of the most important functions of AutoGraph is to rewrite Python control
flow statements into equivalent TensorFlow ops. This process requires "wiring"
variables covered by these control flow statements into the respective ops.

The examples below use a `while` loop, but the same notions extend to all
control flow such as `if` and `for` statements.

In the example below, `x` needs to become a loop variable of the
corresponding `tf.while_loop`:

```
while x > 0:
  x = x - 1
```
```
x = tf.while_loop(..., loop_vars=(x,))
```

TF control ops support only a limited set of types for loop variables. At the
same time, the efficiency of TensorFlow graphs is influenced by the number of
loop variables, so we don't want to create them unnecessarily. For this reason,
AutoGraph turns a symbol into a loop variable only when necessary.

Note: If a symbol refers to a nested structure, such as a `dict` of `dict`s,
the entire structure is mapped to multiple loop variables - TensorFlow
automatically unpacks it.

For example, the symbol `y` below is not wired through the `tf.while_loop`'s
`loop_vars` because it is not affected by the `while` loop:

```
y = 0
while x > 0:
  x = x - 1
print(y)
```
```
x = tf.while_loop(..., loop_vars=(x,))  # y does not need to be a loop variable
```

AutoGraph uses static analysis to determine which symbols are modified by the
code, in order to transform them into control flow variables. Static analysis
is generally performed on single functions - Python's dynamic nature limits its
effectiveness across functions.

#### Modifications of Python objects are not detected across functions

Note: Modifications to TensorFlow objects, such as `tf.Variable`, are tracked
using a different mechanism (automatic control dependencies). Modifications
to `tf.Variable` objects are correctly handled even when called in other
functions.

Because static analysis is limited to single functions, modifications that are
performed in other functions are not visible to AutoGraph:

```
def change_y():
  global y
  y = y + 1

while x > 0:
  change_y()  # Problem -- change made to y is not visible here!
```

This can be easily remedied using a functional programming style - writing
functions that use arguments for all their inputs and return values for all
their outputs:

```
def change(y):
  y = y + 1
  return y

while x > 0:
  y = change(y)  # Okay -- y can now be properly tracked!
```

As noted before, this limitation does not apply to most TensorFlow objects,
although it is still a good idea to use functional programming style for
better code readability:

```
def change(y_var):
  y_var.assign_add(1)

y = tf.Variable(1)
while x > 0:
  change(y)  # This is still okay -- TensorFlow side effects are robust.
```

Keep in mind, however, that certain types like `tf.TensorArray` don't support
side effects and must have their result assigned, otherwise they may raise an
error:

```
def change(ta):
  ta.write(0, 1)  # Incorrect use of TensorArray - will raise an error
```

In other words, `tf.TensorArray` must be handled using functional programming
style:

```
def change(ta):
  ta = ta.write(0, 1)  # Modifications create a new TensorArray efficiently.
  return ta

ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
while x > 0:
  # TensorArray must be handled using functional programming style.
  ta = change(ta)
```

#### Modifications of Python objects are not detected in methods

A special case of hidden side effects are methods, which are commonly used
to change the value of objects:

```
class MyClass(object):
  def change(self):
    self.y += 1

c = MyClass()
while x > 0:
  c.change()  # Problem -- modification to c.y is not visible here!
```

This can be addressed in a number of ways.

One possibility is to operate directly on the object properties:

```
c = MyClass()
while x > 0:
  c.y += 1  # Okay -- c.y can now be properly tracked!
```

Another possibility is to rely on immutable objects with value semantics. This
may lead to many temporary objects when executing eagerly, but their number is
greatly reduced in `@tf.function`:

```
class MyClass(collections.namedtuple('MyClass', ('y',))):
  def change(self):
    new_y = self.y + 1
    return MyClass(new_y)

c = MyClass(0)
while x > 0:
  c = c.change()  # Okay -- c is now a loop var.
```

It is also recommended to use a functional programming style with such immutable
objects - that is, all arguments are inputs, all changes are return values:

```
def use_my_class(c: MyClass) -> MyClass:
  new_c = c.change()
  return new_c
```

Don't worry about creating a few extra objects - they are only used at trace
time, and don't exist at graph execution.

Note: TensorFlow control flow does not currently support arbitrary Python
objects, but it does support basic collection objects such as `list`, `dict`,
`tuple`, `namedtuple` and their subclasses. Design your objects as subclasses
of [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple),
or other types that [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest/map_structure)
recognizes.

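For instance, `tf.nest` operations preserve such structures. A small
illustration (the `Point` class below is hypothetical):

```
import collections

import tensorflow as tf

# A namedtuple subclass is a structure that tf.nest recognizes.
class Point(collections.namedtuple('Point', ('x', 'y'))):
  pass

p = Point(tf.constant(1), tf.constant(2))
doubled = tf.nest.map_structure(lambda t: t * 2, p)  # still a Point
```
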
#### Variables closed over by lambda functions

AutoGraph assumes that variables that local functions close over may be used
anywhere in the parent function, because in general it is possible to hide a
function call in almost any Python statement. For this reason, these variables
are accounted for within TensorFlow loops.

For example, the following code correctly captures `a` in the TensorFlow loop
variables:

```
a = 0
def f():
  tf.print(a)
for i in tf.range(3):
  a = i
f()  # Prints 2
```

A consequence is that these variables must be defined before the loop (see
Undefined and None values above). So the following code will raise an error,
even if the variable is never used after the loop:

```
def f():
  tf.print(a)
for i in tf.range(3):  # Error -- `a` must be defined before the loop.
  a = i
```

However, lambda functions are handled differently, for reasons of backward
compatibility. Lambda functions are assumed to be used only in the statement
where they are defined, or at least within the same block.

```
a = 0
foo(lambda: a)  # This lambda is not expected to be called anywhere else.
for i in tf.range(3):  # Okay -- `a` is local to the loop.
  a = i
```

For that reason, the following code will not work as expected for TensorFlow
loops:

```
a = 0
l = lambda: tf.print(a)
for i in tf.range(3):
  a = i  # `a` is considered local to the loop
l()  # Prints 0!
```

Note that these restrictions apply only to TensorFlow loops; Python loops
correctly handle closures in all cases.
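
To illustrate the Python behavior that AutoGraph mirrors for TF loops, here is
a plain-Python sketch (no TensorFlow involved; the function name is
illustrative):

```
def closure_over_loop_var():
  a = 0
  f = lambda: a  # closes over `a` by reference, not by value
  for i in range(3):  # plain Python loop
    a = i
  return f()  # the closure sees the final value of `a`

closure_over_loop_var()  # returns 2
```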

### Python collections in TensorFlow control flow

Key Point: Use TensorFlow collection classes instead of Python collections.
Python collections are okay to use when they represent a fixed structure (that
is, `list`s don't change length, `dict`s don't add or remove keys).

#### Modifying Python collections in TensorFlow control flow is not allowed

One of the advantages of eager execution is that you may use the usual Python
collections, like `list` or `dict` to hold `tf.Tensor` values. However, these
are generally not compatible with TensorFlow control flow. Specialized
collections like `tf.TensorArray` are required.

Consider the following example:

```
def fn():
  l = []

  def loop_cond(i):
    return i < 10

  def loop_body(i):
    i = i + 1
    l.append(i)
    return i,

  tf.while_loop(
      cond=loop_cond,
      body=loop_body,
      loop_vars=(0,))

  return l
```

This code works in eager execution, which does not use the TensorFlow runtime
for the `tf.while_loop`:

```
fn()
```

However, it does not work in graph execution, because TensorFlow uses special
mechanisms to ensure the computations are correctly sequenced in the dataflow
graph:

```
tf.function(fn)()  # Error -- illegal tensor capture!
```

The equivalent AutoGraph code raises the same error:

```
l = []
for i in tf.range(10):
  l.append(i)  # Error -- illegal tensor capture!
```

Instead, use the specialized `tf.TensorArray` class:

```
l = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
for i in tf.range(10):
  l = l.write(l.size(), i)  # Okay
```
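
To read the accumulated values back as a single `tf.Tensor`, `tf.TensorArray.stack`
can be used. A short sketch, wrapped in a `tf.function` so the loop runs in
graph execution (the function name `accumulate` is hypothetical):

```
import tensorflow as tf

@tf.function
def accumulate():
  l = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  for i in tf.range(10):
    l = l.write(l.size(), i)
  return l.stack()  # a Tensor of shape (10,), containing 0 through 9

accumulate()
```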

#### Python collections of fixed structure are allowed in TensorFlow control flow

An exception to the previous rule is made for Python collections that are
static, that is, they don't grow in size for the duration of the computation.

Caution: Use functional programming style when manipulating static collections.

Examples:

```
static_list = [tf.constant(3)]
while static_list[0] > 0:
  static_list[0] -= 1  # Okay -- static_list does not change structure
```
```
static_object = MyClass()
static_object.field = tf.constant(3)
while static_object.field > 0:
  static_object.field -= 1  # Okay -- static_object does not change structure
```
```
static_dict = {'field': tf.constant(3)}
while static_dict['field'] > 0:
  static_dict['field'] -= 1  # Okay -- static_dict does not change structure
```

However, remember to use functional programming style when these collections
are used inside control flow.

#### Python collections of fixed structure with dynamic index

A more subtle error occurs when the collection is static, but is accessed in a
dynamic way, that is, with a key that is not constant.

For example:

```
d = {'a': tf.constant(3)}
for i in tf.range(10):
  for key in d:
    d[key] += i  # Problem -- accessing `dict` using non-constant key
```

The code above will raise an "illegal capture" error. To remedy it, write it
in functional programming style:

```
d = {'a': tf.constant(3)}
for i in tf.range(10):
  d = {key: value + i for key, value in d.items()}  # Okay
```

### Shape and dtype consistency in TensorFlow control flow

Unlike Python, TensorFlow has limited support for dynamic typing. This means
that tensors must maintain consistent shapes and dtypes across control flow
paths.

Note: In general, these restrictions do not apply in control flow in Eager
execution, because Eager execution uses Python control flow, rather than
TensorFlow control flow ops.

#### Mixing dynamic computations and static shapes

Key Point: Use `.shape` on tensors of static shape, and `.shape.rank` on
tensors of static rank; only use `tf.shape` and `tf.rank` when the shape or
rank is dynamic.

TensorFlow has optional static types and shapes: the shape of tensors may be
static (e.g. `my_tensor.shape=(3, 3)` denotes a three by three matrix) or
dynamic (e.g. `my_tensor.shape=(None, 3)` denotes a matrix with a dynamic
number of rows and three columns). When the shapes are dynamic, you can still
query the shape at runtime by using the `tf.shape()` function.

Note: `tf.shape` always returns a tensor.
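
The difference between the static `.shape` attribute and the runtime `tf.shape`
op can be seen directly in eager execution (a small illustration):

```
import tensorflow as tf

x = tf.zeros((2, 3))
static_shape = x.shape       # a TensorShape, a Python object known at trace time
dynamic_shape = tf.shape(x)  # a Tensor, evaluated at runtime
```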

For static shapes, TensorFlow will perform additional shape verifications at
graph construction time, that is, during tracing. These static shape
verifications are useful because they work like a compiler: errors are
caught early, before execution even begins.

For example:

```
x = tf.constant([1, 2, 3])
x[4]  # Tracing error! 4 is out of bounds.
```

To avoid tracing errors, you can add static shape verifications, which help
make your code more robust:

```
if x.shape[0] > 4:
  val = x[4]
else:
  val = some_default_value
```

In the snippet above, the code is protected against index-out-of-bounds
errors. The code is also efficient because the verification `x.shape[0] > 4`
will not be included in the graph.

But what happens if you try to perform the index verifications using dynamic
control flow? You might expect that the code works in the same way:

```
val = tf.cond(
  x.shape[0] > 4,
  lambda: x[4],
  lambda: some_default_value)
```

However, TensorFlow will not let you write code that could result in an error,
even if that code appears in a branch of a `tf.cond` statement that would
never execute. Remember that the shape of `x` is `(3,)`, so TensorFlow performs
static shape verification.

This can lead to surprising behavior when using `tf.shape` on tensors with
static shape in TensorFlow:

```
x = tf.constant((1, 2, 3))
if tf.shape(x)[0] > 4:
  val = x[4]  # Error at tracing: 4 is out of bounds!
else:
  val = some_default_value
```

Because `tf.shape` always evaluates to a Tensor, the `if` statement above is
converted by AutoGraph into a `tf.cond`, which performs static shape
verification of both branches.

What if you need to write code which can handle both static and dynamic
shapes? There are a few options in this case:

A first option is to always work with dynamic shapes, for instance by
using `input_signature` in `tf.function`. Many shape and shape-related checks
are skipped when the shape is dynamic:

```
@tf.function(input_signature=(tf.TensorSpec(shape=(None,)),))
def f(x):  # x now has dynamic shape
  if tf.shape(x)[0] >= 3:  # Builds a tf.cond
    val = x[4]  # Okay, bounds checks are skipped when the shape is dynamic
  else:
    val = some_default_value
```

A second option is to first verify whether the shape is static or dynamic.
This can be done at tracing time, allowing you to use a Python `if` to only
trace the code that is suitable for the situation:

```
if x.shape[0] is None:  # Python bool, does not use tf.cond
  # ... use tf.shape(x) here ...
else:
  # ... use x.shape here ...
```

#### Consistency of dtype

The dtypes across all code paths must be consistent in conditionals and loops.

For example, if a `tf.cond` (and correspondingly, an AutoGraph `if`) sets a
tensor value conditionally, then that tensor must have the same shape and dtype
in both branches of the conditional.

Example of illegal dtype change in a conditional:

```
x = tf.cond(
    tf.random.uniform(()) > 0.5,
    lambda: tf.constant(1, dtype=tf.int32),
    lambda: tf.constant(1, dtype=tf.float32))  # Error -- inconsistent dtypes: int32, float32
```

The same restriction in AutoGraph code:

```
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1, dtype=tf.int32)
else:
  x = tf.constant(1, dtype=tf.float32)  # Error -- inconsistent dtypes: int32, float32
```

Example of illegal dtype change in a loop:

```
# This won't work -- "x" changes dtype inside the loop.
x = tf.while_loop(
    lambda _: tf.random.uniform(()) > 0.5,
    lambda x: tf.constant(1, dtype=tf.float32),
    loop_vars=(tf.constant(1, dtype=tf.int32),))  # Error -- inconsistent dtypes: int32, float32
```

The same restriction in AutoGraph code:

```
x = tf.constant(0, dtype=tf.int32)
while tf.random.uniform(()) > 0.5:
  x = tf.constant(0, dtype=tf.float32)  # Error -- inconsistent dtypes: int32, float32
```

#### Consistency of shape

In loops, the shapes across all code paths must be consistent. When tensors
do need to change shape across iterations, use `shape_invariants`.

Note: Shapes are allowed to be inconsistent in conditionals. The result will be
a partially dynamic shape.

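For instance, a conditional whose branches return different lengths produces a
result whose static shape is partially dynamic. A sketch, assuming TF 2.x (the
function name `pick` is hypothetical):

```
import tensorflow as tf

@tf.function
def pick(c):
  if c:  # c is a Tensor, so this becomes a tf.cond
    x = tf.constant([1])
  else:
    x = tf.constant([1, 2])
  return x  # static shape is (None,) -- partially dynamic

result = pick(tf.constant(True))
```
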
In a `tf.while_loop` (and correspondingly, an AutoGraph `while` or `for` loop)
all loop variables must maintain consistent shape and dtype across iterations.
That is, every loop variable must have the same shape at the end of the loop
body as it had at the beginning of the loop body.

Example of illegal shape change in a loop:

```
def loop_body(x):  # x.shape is ()
  return tf.constant((1, 2, 3))  # Error -- inconsistent shapes: (), (3,)

x = tf.while_loop(
    lambda _: tf.random.uniform(()) > 0.5,
    loop_body,
    loop_vars=(tf.constant(1),))
```

The same restriction in AutoGraph code:

```
x = tf.constant(1)
while tf.random.uniform(()) > 0.5:
  x = tf.constant((1, 2, 3))  # Error -- inconsistent shapes: (), (3,)
```

### Consistency of control flow types

In AutoGraph, one can write Python control flow like `for i in range(10)`, as
well as TensorFlow control flow like `for i in tf.range(10)`.

However, one could also write (illegal) programs which start as Python control
flow, then turn into TensorFlow control flow. In such cases, an error will be
raised.

Below are a few examples, along with recommendations.

#### Python loop, TF-dependent break or return

Example:

```
for i in range(10):
  if tf.greater(i, 3):
    break  # Error -- TF break inside a Python loop
```

The solution in this case is to change the loop type to a TF loop:

```
for i in tf.range(10):
  if tf.greater(i, 3):
    break  # Okay
```

#### Python loop that turns into a TensorFlow loop

Example:

```
i = 10
while i > 0:
  i = tf.math.subtract(i, 1)  # Error -- the loop would turn into a TF loop
```

The solution in this case is to make sure the loop type starts as a TF loop,
typically by making sure the condition is always a Tensor:

```
i = tf.constant(10)  # Okay
while i > 0:
  i = tf.math.subtract(i, 1)
```

#### TensorFlow loops never turn into Python loops

Note that this is a legal case, as TensorFlow implicitly converts all Python
values to Tensor:

```
i = tf.constant(10)
while i > 0:
  i = 0  # this is okay, will be auto-converted to Tensor
```

### Access to source code

Key point: AutoGraph can only handle functions whose source code can be
accessed at runtime.

Almost all Python functions allow access to their source code. However, a few
exceptions exist:

 * functions created in the Python interactive shell
 * functions with native bindings (these do not have Python source code)
 * functions created dynamically, using `exec` or `eval`

Use
[inspect.findsource](https://docs.python.org/3/library/inspect.html#inspect.findsource)
to quickly diagnose whether the source code is available for a function.

For example:

```
import inspect

def simple_function():
  return 1

# If this raises an error, then AutoGraph prints a warning.
# If it returns source code, then AutoGraph should work as well.
inspect.findsource(simple_function)
```
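
Conversely, a dynamically created function has no source file to read from, so
`inspect.findsource` raises an error. A pure-Python illustration (the function
name `dynamic_function` is hypothetical):

```
import inspect

# Create a function dynamically; it has no backing source file.
namespace = {}
exec('def dynamic_function(): return 1', namespace)

try:
  inspect.findsource(namespace['dynamic_function'])
except OSError:
  # AutoGraph would print a warning for this function.
  pass
```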

#### Source code of lambda functions

##### TF 2.4 and newer

Key Point: When nesting lambda functions, use distinguishing argument names
to avoid parse errors.

The Python runtime exposes the source code of lambda functions; however, it
may omit parts of the actual body, or include surrounding code. This may make it
impossible to parse the exact source code of the lambda function (see
https://github.com/tensorflow/tensorflow/issues/39832).

AutoGraph uses alternate methods to parse the source code more robustly, but
in rare cases it may be unable to distinguish between nested lambda functions
of identical signatures.

Example:

```
l = lambda x: lambda x: x + 1
```

AutoGraph raises an error for the code above because the parser cannot
distinguish between the two function signatures. To work around this limitation,
use distinct argument names:

```
l = lambda outer_x: lambda inner_x: inner_x + 1
```

##### TF 2.3 and older

In older versions of TensorFlow, the loading code for lambda functions is not
robust. Follow the guidance below to avoid errors.

Important: Declare lambda functions on single lines to make sure their source
code loads correctly.

The Python runtime exposes the source code of lambda functions; however, it
may omit parts of the actual body, or include surrounding code. This may make it
impossible to parse the exact source code of the lambda function.

For example, consider the declaration of a lambda function below:

```
foo = (
    lambda y: lambda x: x * y
    - y
)
```

The Python runtime will report the following source code for `foo`:

```
>>> inspect.getsource(foo)
'    lambda y: lambda x: x*y \n'
```

In other cases, the source code it returns is not valid Python code, resulting
in an error:

```
foo = (
 'bar',
 lambda: x)
```

The reported source code contains an invalid token `)`:

```
>>> inspect.getsource(foo[1])
' lambda: x)\n'
```

This shortcoming can be avoided by declaring the lambda in a single assignment
or return value, and avoiding placing it inside parentheses which could cause
auto-formatting tools to break it into multiple lines:

```
# Good - single assignment
my_lambda = lambda: x

# Good - single return
return lambda x, y: x*y - y
```

```
# Bad - wrapped in parentheses
my_lambda = (lambda x, y: x * y - y)

# Bad - inlined in another expression
foo(lambda x, y: x + y, bar)
```