.. testsetup::

    # These are hidden from the docs, but these are necessary for `doctest`
    # since the `inspect` module doesn't play nicely with the execution
    # environment for `doctest`
    import torch

    original_script = torch.jit.script
    def script_wrapper(obj, *args, **kwargs):
        obj.__module__ = 'FakeMod'
        return original_script(obj, *args, **kwargs)

    torch.jit.script = script_wrapper

    original_trace = torch.jit.trace
    def trace_wrapper(obj, *args, **kwargs):
        obj.__module__ = 'FakeMod'
        return original_trace(obj, *args, **kwargs)

    torch.jit.trace = trace_wrapper

.. _language-reference-v2:

TorchScript Language Reference
==============================

This reference manual describes the syntax and core semantics of the TorchScript language.
TorchScript is a statically typed subset of the Python language. This document explains the supported features of
Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in
this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to
represent neural network models in PyTorch.

.. contents::
    :local:
    :depth: 1

.. _type_system:

Terminology
~~~~~~~~~~~

This document uses the following terminology:

.. list-table::
   :widths: 25 25
   :header-rows: 1

   * - Pattern
     - Notes
   * - ``::=``
     - Indicates that the symbol on the left is defined as the expression on the right.
   * - ``" "``
     - Encloses literal keywords and delimiters that are part of the syntax.
   * - ``A | B``
     - Indicates either A or B.
   * - ``( )``
     - Indicates grouping.
   * - ``[]``
     - Indicates an optional element.
   * - ``A+``
     - Indicates that term A is repeated one or more times.
   * - ``A*``
     - Indicates that term A is repeated zero or more times.

Type System
~~~~~~~~~~~
TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express
neural net models.

TorchScript Types
^^^^^^^^^^^^^^^^^

The TorchScript type system consists of ``TSType`` and ``TSModuleType`` as defined below.

::

    TSAllType ::= TSType | TSModuleType
    TSType    ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType

``TSType`` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations.
``TSType`` refers to any of the following:

* Meta Types, e.g., ``Any``
* Primitive Types, e.g., ``int``, ``float``, and ``str``
* Structural Types, e.g., ``Optional[int]`` or ``List[MyClass]``
* Nominal Types (Python classes), e.g., ``MyClass`` (user-defined), ``torch.Tensor`` (built-in)

``TSModuleType`` represents ``torch.nn.Module`` and its subclasses. It is treated differently from ``TSType`` because its type schema is inferred partly from the object instance and partly from the class definition.
As such, instances of a ``TSModuleType`` may not follow the same static type schema. For type-safety reasons, ``TSModuleType`` cannot be used as a TorchScript type annotation or be composed with ``TSType``.

Meta Types
^^^^^^^^^^

Meta types are so abstract that they are more like type constraints than concrete types.
Currently, TorchScript defines one meta type, ``Any``, that represents any TorchScript type.

``Any`` Type
""""""""""""

The ``Any`` type represents any TorchScript type. ``Any`` specifies no type constraints, so there is no type-checking on ``Any``.
As such, it can be bound to any Python or TorchScript data type (e.g., ``int``, TorchScript ``tuple``, or an arbitrary Python class that is not scripted).

::

    TSMetaType ::= "Any"

Where:

* ``Any`` is the Python class name from the ``typing`` module. Therefore, to use the ``Any`` type, you must import it from ``typing`` (e.g., ``from typing import Any``).
* Since ``Any`` can represent any TorchScript type, the set of operators that are allowed on values of type ``Any`` is limited.

Operators Supported for ``Any`` Type
""""""""""""""""""""""""""""""""""""

* Assignment to data of ``Any`` type.
* Binding to parameter or return of ``Any`` type.
* ``x is``, ``x is not`` where ``x`` is of ``Any`` type.
* ``isinstance(x, Type)`` where ``x`` is of ``Any`` type.
* Data of ``Any`` type is printable.
* Data of ``List[Any]`` type may be sortable if the data is a list of values of the same type ``T`` and ``T`` supports comparison operators.
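
As a minimal sketch of the operators listed above (``describe`` is a hypothetical name, and the sketch assumes a PyTorch build in which ``torch.jit.script`` is available), the following function applies only ``is``/``is not`` and ``isinstance`` to its ``Any`` parameter:

::

    import torch
    from typing import Any


    # `describe` is a hypothetical example function.
    def describe(x: Any) -> str:
        # `is` / `is not` are allowed on values of type Any.
        if x is None:
            return "none"
        # isinstance() checks the runtime type of the bound value.
        if isinstance(x, torch.Tensor):
            return "tensor"
        if isinstance(x, int):
            return "int"
        return "other"


    scripted = torch.jit.script(describe)
    print(scripted(None), scripted(torch.ones(2)), scripted(3), scripted("a"))

Arithmetic such as ``x + 1`` is not in the list above, so applying it directly to ``x`` would be expected to fail to compile.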

**Compared to Python**

``Any`` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the
``object`` class in Python. However, ``Any`` only supports a subset of the operators and methods that are supported by ``object``.

Design Notes
""""""""""""

When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described
by a type schema. Describing static types for unused data (in the context of the script) is not only cumbersome but may also lead to unnecessary
scripting failures. ``Any`` is introduced to describe the type of the data where precise static types are not necessary for compilation.

**Example 1**

This example illustrates how ``Any`` can be used to allow the second element of the tuple parameter to be of any type. This is possible
because ``x[1]`` is not involved in any computation that requires knowing its precise type.

.. testcode::

    import torch

    from typing import Tuple
    from typing import Any

    def inc_first_element(x: Tuple[int, Any]):
        return (x[0]+1, x[1])

    m = torch.jit.script(inc_first_element)
    print(m((1,2.0)))
    print(m((1,(100,200))))

The example above produces the following output:

.. testoutput::

    (2, 2.0)
    (2, (100, 200))

The second element of the tuple is of ``Any`` type and thus can bind to values of multiple types.
For example, ``(1, 2.0)`` binds a float type to ``Any`` as in ``Tuple[int, Any]``,
whereas ``(1, (100, 200))`` binds a tuple to ``Any`` in the second invocation.

**Example 2**

This example illustrates how we can use ``isinstance`` to dynamically check the type of the data that is annotated as ``Any`` type:

.. testcode::

    import torch
    from typing import Any

    def f(a: Any):
        print(a)
        return isinstance(a, torch.Tensor)

    ones = torch.ones([2])
    m = torch.jit.script(f)
    print(m(ones))

The example above produces the following output:

.. testoutput::

     1
     1
    [ CPUFloatType{2} ]
    True

Primitive Types
^^^^^^^^^^^^^^^

Primitive TorchScript types are types that represent a single type of value and go with a single predefined
type name.

::

    TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"
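
As a small sketch of primitive types in signatures (``scale`` and ``greet`` are hypothetical names), ``int``, ``float``, ``bool``, and ``str`` can all be used directly as parameter and return annotations:

::

    import torch


    # `scale` and `greet` are hypothetical example functions.
    def scale(n: int, factor: float, clamp: bool) -> float:
        # int * float yields a float, as in Python.
        result = n * factor
        if clamp and result > 1.0:
            result = 1.0
        return result


    def greet(name: str) -> str:
        # str + str concatenation works on str values.
        return "hello, " + name


    scripted_scale = torch.jit.script(scale)
    scripted_greet = torch.jit.script(greet)
    print(scripted_scale(3, 0.5, False), scripted_scale(3, 0.5, True), scripted_greet("ts"))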

Structural Types
^^^^^^^^^^^^^^^^

Structural types are types that are structurally defined without a user-defined name (unlike nominal types),
such as ``Future[int]``. Structural types are composable with any ``TSType``.

::

    TSStructuralType ::=  TSTuple | TSNamedTuple | TSList | TSDict |
                        TSOptional | TSUnion | TSFuture | TSRRef | TSAwait

    TSTuple          ::= "Tuple" "[" (TSType ",")* TSType "]"
    TSNamedTuple     ::= "namedtuple" "(" (TSType ",")* TSType ")"
    TSList           ::= "List" "[" TSType "]"
    TSOptional       ::= "Optional" "[" TSType "]"
    TSUnion          ::= "Union" "[" (TSType ",")* TSType "]"
    TSFuture         ::= "Future" "[" TSType "]"
    TSRRef           ::= "RRef" "[" TSType "]"
    TSAwait          ::= "Await" "[" TSType "]"
    TSDict           ::= "Dict" "[" KeyType "," TSType "]"
    KeyType          ::= "str" | "int" | "float" | "bool" | TensorType | "Any"

Where:

* ``Tuple``, ``List``, ``Optional``, ``Union``, and ``Dict`` represent Python type class names that are defined in the ``typing`` module. To use these type names, you must import them from ``typing`` (e.g., ``from typing import Tuple``).
* ``namedtuple`` represents the Python class ``collections.namedtuple`` or ``typing.NamedTuple``.
* ``Future`` and ``RRef`` represent the Python classes ``torch.futures.Future`` and ``torch.distributed.rpc.RRef``, respectively.
* ``Await`` represents the Python class ``torch._awaits._Await``.
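
The following sketch combines several structural types in one signature (``summarize`` is a hypothetical name). The ``typing`` names are imported explicitly, and the ``Optional`` value is checked with ``is not None`` before it is used:

::

    import torch
    from typing import Dict, List, Optional, Tuple


    # `summarize` is a hypothetical example function.
    def summarize(xs: List[int], counts: Dict[str, int],
                  bonus: Optional[int]) -> Tuple[int, int]:
        total = 0
        for x in xs:
            total += x
        # The `is not None` check refines Optional[int] to int inside the branch.
        if bonus is not None:
            total += bonus
        return total, len(counts)


    scripted = torch.jit.script(summarize)
    print(scripted([1, 2, 3], {"a": 1, "b": 2}, None))
    print(scripted([], {}, 4))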

**Compared to Python**

Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts.

**Example 1**

This example uses ``typing.NamedTuple`` syntax to define a tuple:

.. testcode::

    import torch
    from typing import NamedTuple
    from typing import Tuple

    class MyTuple(NamedTuple):
        first: int
        second: int

    def inc(x: MyTuple) -> Tuple[int, int]:
        return (x.first+1, x.second+1)

    t = MyTuple(first=1, second=2)
    scripted_inc = torch.jit.script(inc)
    print("TorchScript:", scripted_inc(t))

The example above produces the following output:

.. testoutput::

    TorchScript: (2, 3)

**Example 2**

This example annotates the parameter with a named tuple defined through the functional ``typing.NamedTuple`` syntax, and passes in a tuple created with ``collections.namedtuple``:

.. testcode::

    import torch
    from typing import NamedTuple
    from typing import Tuple
    from collections import namedtuple

    _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
    _UnannotatedNamedTuple = namedtuple('_NamedTupleUnannotated', ['first', 'second'])

    def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
        return (x.first+1, x.second+1)

    m = torch.jit.script(inc)
    print(m(_UnannotatedNamedTuple(1,2)))

The example above produces the following output:

.. testoutput::

    (2, 3)

**Example 3**

This example illustrates a common mistake when annotating structural types: not importing the composite type
classes from the ``typing`` module.

::

    import torch

    # ERROR: Tuple not recognized because not imported from typing
    @torch.jit.export
    def inc(x: Tuple[int, int]):
        return (x[0]+1, x[1]+1)

    m = torch.jit.script(inc)
    print(m((1,2)))

Running the above code fails with the following error:

::

    File "test-tuple.py", line 5, in <module>
        def inc(x: Tuple[int, int]):
    NameError: name 'Tuple' is not defined

The remedy is to add the line ``from typing import Tuple`` to the beginning of the code.

Nominal Types
^^^^^^^^^^^^^

Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom
name and are compared using class names. Nominal classes are further classified into the following categories:

::

    TSNominalType ::= TSBuiltinClass | TSCustomClass | TSEnum

Among them, ``TSCustomClass`` and ``TSEnum`` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker.

Built-in Class
^^^^^^^^^^^^^^

Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types).
TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or
attributes of their Python class definitions.

::

    TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
                       "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
    TSTensor       ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
                       "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor

Special Note on torch.nn.ModuleList and torch.nn.ModuleDict
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

Although ``torch.nn.ModuleList`` and ``torch.nn.ModuleDict`` act as a list and a dictionary in Python,
they behave more like tuples in TorchScript:

* In TorchScript, instances of ``torch.nn.ModuleList`` or ``torch.nn.ModuleDict`` are immutable.
* Code that iterates over ``torch.nn.ModuleList`` or ``torch.nn.ModuleDict`` is completely unrolled so that elements of ``torch.nn.ModuleList`` or values of ``torch.nn.ModuleDict`` can be of different subclasses of ``torch.nn.Module``.
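
A minimal sketch of the second point (``Pipeline`` is a hypothetical module name): the ``ModuleList`` below holds three different ``torch.nn.Module`` subclasses, and the ``for`` loop over it compiles because TorchScript unrolls it:

::

    import torch
    import torch.nn as nn


    # `Pipeline` is a hypothetical example module.
    class Pipeline(nn.Module):
        def __init__(self):
            super().__init__()
            # Each element is a different subclass of nn.Module.
            self.layers = nn.ModuleList([nn.Linear(4, 4), nn.ReLU(), nn.Identity()])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # This loop is unrolled at compile time, one step per element.
            for layer in self.layers:
                x = layer(x)
            return x


    scripted = torch.jit.script(Pipeline())
    out = scripted(torch.randn(2, 4))
    print(out.shape)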

**Example**

The following example highlights the use of a few built-in TorchScript classes (``torch.*``):

::

    import torch

    @torch.jit.script
    class A:
        def __init__(self):
            self.x = torch.rand(3)

        def f(self, y: torch.device):
            return self.x.to(device=y)

    def g():
        a = A()
        return a.f(torch.device("cpu"))

    script_g = torch.jit.script(g)
    print(script_g.graph)
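
Built-in classes such as ``torch.device`` and ``torch.dtype`` can also appear directly in function signatures. A minimal sketch, with ``move`` as a hypothetical function name:

::

    import torch


    # `move` is a hypothetical example function.
    def move(x: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # Both extra parameters are annotated with built-in TorchScript classes.
        return x.to(device=device, dtype=dtype)


    scripted = torch.jit.script(move)
    y = scripted(torch.ones(3), torch.device("cpu"), torch.float64)
    print(y.dtype, y.device)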

Custom Class
^^^^^^^^^^^^

Unlike built-in classes, the semantics of custom classes are user-defined, and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules.

::

    TSClassDef ::= [ "@torch.jit.script" ]
                     "class" ClassName [ "(object)" ]  ":"
                        MethodDefinition |
                    [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
                        MethodDefinition

Where:

* Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from ``object``.
* Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the ``__init__()`` method.
* Method overloading is not supported (i.e., you cannot have multiple methods with the same method name).
* ``MethodDefinition`` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements).
* ``torch.jit.ignore`` and ``torch.jit.unused`` can be used to exclude methods or functions that are not fully scriptable or that the compiler should otherwise ignore.

**Compared to Python**

TorchScript custom classes are quite limited compared to their Python counterparts. TorchScript custom classes:

* Do not support class attributes.
* Do not support subclassing except for subclassing an interface type or ``object``.
* Do not support method overloading.
* Must initialize all their instance attributes in ``__init__()``; this is because TorchScript constructs a static schema of the class by inferring attribute types in ``__init__()``.
* Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IR.

**Example 1**

Python classes can be used in TorchScript if they are annotated with ``@torch.jit.script``, similar to how a TorchScript function would be declared:

::

    import torch

    @torch.jit.script
    class MyClass:
        def __init__(self, x: int):
            self.x = x

        def inc(self, val: int):
            self.x += val
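
Instances of such a class can then be constructed and used inside other TorchScript code. A minimal sketch, where ``Counter`` and ``total`` are hypothetical names:

::

    import torch


    # `Counter` is a hypothetical example class.
    @torch.jit.script
    class Counter:
        def __init__(self, start: int):
            # Assigning the attribute here fixes its static type (int).
            self.count = start

        def inc(self, val: int) -> int:
            self.count += val
            return self.count


    # `total` is a hypothetical example function.
    def total(n: int) -> int:
        # Unlike module types, custom classes may be constructed inside TorchScript.
        c = Counter(0)
        for i in range(n):
            c.inc(i)
        return c.count


    scripted_total = torch.jit.script(total)
    print(scripted_total(4))  # 0 + 1 + 2 + 3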

**Example 2**

A TorchScript custom class type must "declare" all its instance attributes by assignments in ``__init__()``. If an instance attribute is not defined in ``__init__()`` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example:

::

    import torch

    @torch.jit.script
    class foo:
        def __init__(self):
            self.y = 1

        # ERROR: self.x is not defined in __init__
        def assign_x(self):
            self.x = torch.rand(2, 3)

The class will fail to compile and issue the following error:

::

    RuntimeError:
    Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
    def assign_x(self):
        self.x = torch.rand(2, 3)
        ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
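
One way to fix the class, sketched below under the hypothetical names ``Foo`` and ``make``, is to give ``x`` an initial value of the intended type in ``__init__()`` so that ``assign_x()`` only reassigns an attribute that already exists:

::

    import torch


    # `Foo` is a hypothetical, corrected variant of the class above.
    @torch.jit.script
    class Foo:
        def __init__(self):
            self.y = 1
            # Declares x with a static Tensor type.
            self.x = torch.zeros(2, 3)

        def assign_x(self):
            self.x = torch.rand(2, 3)


    # `make` is a hypothetical example function.
    def make() -> torch.Tensor:
        f = Foo()
        f.assign_x()
        return f.x


    scripted_make = torch.jit.script(make)
    print(scripted_make().shape)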

**Example 3**

In this example, a TorchScript custom class defines a class variable ``name``, which is not allowed:

::

    import torch

    @torch.jit.script
    class MyClass(object):
        name = "MyClass"
        def __init__(self, x: int):
            self.x = x

    def fn(a: MyClass):
        return a.name

    m = torch.jit.script(fn)

Scripting ``fn`` leads to the following compile-time error:

::

    RuntimeError:
    '__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
        File "test-class2.py", line 10
    def fn(a: MyClass):
        return a.name
            ~~~~~~ <--- HERE

Enum Type
^^^^^^^^^

Like custom classes, the semantics of the enum type are user-defined, and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules.

::

    TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
                   ( MemberIdentifier "=" Value )+
                   ( MethodDefinition )*

Where:

* ``Value`` must be a TorchScript literal of type ``int``, ``float``, or ``str``, and all values of an enum must be of the same TorchScript type.
* ``TSEnumType`` is the name of a TorchScript enumerated type. Similar to Python enums, TorchScript allows restricted ``Enum`` subclassing; that is, subclassing an enumerated type is allowed only if it does not define any members.

**Compared to Python**

* TorchScript supports only ``enum.Enum``. It does not support other variations such as ``enum.IntEnum``, ``enum.Flag``, ``enum.IntFlag``, and ``enum.auto``.
* Values of TorchScript enum members must be of the same type and can only be ``int``, ``float``, or ``str`` types, whereas Python enum members can be of any type.
* Enums containing methods are ignored in TorchScript.
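
A minimal sketch of an enum whose members all share the ``str`` value type (``Mode``, ``mode_value``, and ``is_training`` are hypothetical names). Besides comparing members, it reads a member's ``.value`` inside the scripted function, which it assumes the installed PyTorch version supports:

::

    import torch
    from enum import Enum


    # `Mode` is a hypothetical example enum.
    class Mode(Enum):
        # All values share one type (str), as TorchScript requires.
        TRAIN = "train"
        EVAL = "eval"


    # `mode_value` and `is_training` are hypothetical example functions.
    def mode_value(m: Mode) -> str:
        return m.value


    def is_training(m: Mode) -> bool:
        return m == Mode.TRAIN


    scripted_value = torch.jit.script(mode_value)
    scripted_is_training = torch.jit.script(is_training)
    print(scripted_value(Mode.EVAL), scripted_is_training(Mode.TRAIN))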

**Example 1**

The following example defines the class ``Color`` as an ``Enum`` type:

::

    import torch
    from enum import Enum

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_fn(x: Color, y: Color) -> bool:
        if x == Color.RED:
            return True
        return x == y

    m = torch.jit.script(enum_fn)

    print("Eager: ", enum_fn(Color.RED, Color.GREEN))
    print("TorchScript: ", m(Color.RED, Color.GREEN))

**Example 2**

The following example shows the case of restricted enum subclassing, where ``BaseColor`` does not define any members and thus can be subclassed by ``Color``:

::

    import torch
    from enum import Enum

    class BaseColor(Enum):
        def foo(self):
            pass

    class Color(BaseColor):
        RED = 1
        GREEN = 2

    def enum_fn(x: Color, y: Color) -> bool:
        if x == Color.RED:
            return True
        return x == y

    m = torch.jit.script(enum_fn)

    print("TorchScript: ", m(Color.RED, Color.GREEN))
    print("Eager: ", enum_fn(Color.RED, Color.GREEN))

TorchScript Module Class
^^^^^^^^^^^^^^^^^^^^^^^^

``TSModuleType`` is a special class type that is inferred from object instances that are created outside TorchScript. ``TSModuleType`` is named by the Python class of the object instance. The ``__init__()`` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules.

The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from ``__init__()`` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas.

In this sense, ``TSModuleType`` is not really a static type. Therefore, for type-safety reasons, ``TSModuleType`` cannot be used in a TorchScript type annotation or be composed with ``TSType``.

Module Instance Class
^^^^^^^^^^^^^^^^^^^^^

TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., it is constructed in Python and then passed to ``torch.jit.script``). The Python module class is treated as a module instance class, so the ``__init__()`` method of the Python module class is not subject to the type-checking rules of TorchScript.

::

    TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":"
                        ClassBodyDefinition

Where:

* ``forward()`` and other methods decorated with ``@torch.jit.export`` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules.

Unlike custom classes, only ``forward()`` and the other methods of the module type decorated with ``@torch.jit.export`` (plus any methods they call) need to be compilable. Most notably, ``__init__()`` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into ``torch.jit.script(ModuleObj)``.
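
The sketch below illustrates which methods get compiled (``Scaler`` and its method names are hypothetical). ``forward()`` and the exported ``get_factor()`` are entry points, ``_scale()`` is compiled because ``forward()`` calls it, and ``to_json()`` is never reached from an entry point, so it is not compiled and may use arbitrary Python:

::

    import json  # used only by the uncompiled helper below

    import torch
    import torch.nn as nn


    # `Scaler` is a hypothetical example module.
    class Scaler(nn.Module):
        def __init__(self, factor: float):
            super().__init__()
            self.factor = factor

        def _scale(self, x: torch.Tensor) -> torch.Tensor:
            # Compiled automatically because forward() calls it.
            return x * self.factor

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self._scale(x)

        @torch.jit.export
        def get_factor(self) -> float:
            # Compiled because it is marked with @torch.jit.export.
            return self.factor

        def to_json(self) -> str:
            # Not reachable from forward() or an exported method, so not compiled.
            return json.dumps({"factor": self.factor})


    m = torch.jit.script(Scaler(2.0))
    print(m(torch.ones(3)), m.get_factor())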

**Example 1**

This example illustrates a few features of module types:

* The ``TestModule`` instance is created outside the scope of TorchScript (i.e., before invoking ``torch.jit.script``).
* ``__init__()`` is not considered a TorchScript method; therefore, it does not have to be annotated and can contain arbitrary Python code. In addition, the ``__init__()`` method of an instance class cannot be invoked in TorchScript code. Because ``TestModule`` instances are instantiated in Python, in this example, ``TestModule(1)`` and ``TestModule(torch.ones([5]))`` create two instances with different types for their data attributes: ``self.x`` is of type ``int`` for ``TestModule(1)``, whereas it is of type ``Tensor`` for ``TestModule(torch.ones([5]))``.
* TorchScript automatically compiles other methods that are invoked from ``forward()`` or from methods annotated with ``@torch.jit.export``.
* Entry points to a TorchScript program are either ``forward()`` of a module type, functions annotated as ``torch.jit.script``, or methods annotated as ``torch.jit.export``.

.. testcode::

    import torch

    class TestModule(torch.nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, inc: int):
            return self.x + inc

    m = torch.jit.script(TestModule(1))
    print(f"First instance: {m(3)}")

    m = torch.jit.script(TestModule(torch.ones([5])))
    print(f"Second instance: {m(3)}")

The example above produces the following output:

.. testoutput::

    First instance: 4
    Second instance: tensor([4., 4., 4., 4., 4.])

**Example 2**

The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of ``TestModule`` inside the scope of TorchScript:

.. testcode::

    import torch

    class TestModule(torch.nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, x: int):
            return self.x + x

    class MyModel:
        def __init__(self, v: int):
            self.val = v

        @torch.jit.export
        def doSomething(self, val: int) -> int:
            # error: should not invoke the constructor of module type
            myModel = TestModule(self.val)
            return myModel(val)

    # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError
    # RuntimeError: Could not get name of python class object
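
A possible correction of Example 2, sketched under the assumption that ``MyModel`` may itself become a ``torch.nn.Module`` (the added ``forward()`` is hypothetical), constructs the submodule in Python inside ``__init__()`` and only calls it from TorchScript:

::

    import torch
    import torch.nn as nn


    class TestModule(nn.Module):
        def __init__(self, v):
            super().__init__()
            self.x = v

        def forward(self, x: int):
            return self.x + x


    # Hypothetical rework of MyModel as a module that owns its submodule.
    class MyModel(nn.Module):
        def __init__(self, v: int):
            super().__init__()
            # Construct the submodule in Python, outside TorchScript.
            self.inner = TestModule(v)

        @torch.jit.export
        def doSomething(self, val: int) -> int:
            # Only call the already-constructed submodule.
            return self.inner(val)

        def forward(self, val: int) -> int:
            return self.doSomething(val)


    m = torch.jit.script(MyModel(2))
    print(m(3), m.doSomething(3))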

.. _type_annotation:

Type Annotation
~~~~~~~~~~~~~~~
Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or
instance data attribute has a static type, and every function and method has a statically typed signature.

When to Annotate Types
^^^^^^^^^^^^^^^^^^^^^^
In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types of
methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type
may be too restrictive, e.g., ``x`` being inferred as ``NoneType`` through assignment ``x = None``, whereas ``x`` is actually used as an ``Optional``. In such
cases, type annotations may be needed to override automatic inference, e.g., ``x: Optional[int] = None``. Note that it is always safe to type annotate a local variable
or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking.

When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be of the
default type ``TensorType``, ``List[TensorType]``, or ``Dict[str, TensorType]``.
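
The sketch below illustrates both rules (the function names are hypothetical). ``add_one`` leaves its parameter unannotated, so it is compiled as taking a ``Tensor``, while ``first_positive`` widens a local variable that starts as ``None`` by annotating it as ``Optional[int]``:

::

    import torch
    from typing import List, Optional


    # `add_one` and `first_positive` are hypothetical example functions.
    def add_one(x):
        # No annotation: TorchScript assumes x is a Tensor.
        return x + 1


    def first_positive(xs: List[int]) -> Optional[int]:
        # Without the annotation, `found = None` would be inferred as NoneType
        # and the later `found = x` assignment would not type-check.
        found: Optional[int] = None
        for x in xs:
            if found is None and x > 0:
                found = x
        return found


    scripted_add_one = torch.jit.script(add_one)
    scripted_first = torch.jit.script(first_positive)
    print(scripted_add_one(torch.zeros(2)), scripted_first([-1, 2, 3]), scripted_first([-1]))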

Annotate Function Signature
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Because the type of a parameter cannot be automatically inferred from the body of a function or method, parameters need to be type annotated. Otherwise, they assume the default type ``TensorType``.

TorchScript supports two styles for method and function signature type annotation:

* **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of ``TensorType``), or allows the return type to be left unannotated (whose type will be automatically inferred).

::

    Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
                                FuncOrMethodBody
    ParamAnnot        ::= Identifier [ ":" TSType ] ","
    ReturnAnnot       ::= "->" TSType

Note that when using Python3 style, the type of ``self`` is automatically inferred and should not be annotated.

* **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated.

::

    MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
    ParamAnnot     ::= TSType ","
    ReturnAnnot    ::= "->" TSType

**Example 1**

In this example:

* ``a`` is not annotated and assumes the default type of ``TensorType``.
* ``b`` is annotated as type ``int``.
* The return type is not annotated and is automatically inferred as type ``TensorType`` (based on the type of the value being returned).

::

    import torch

    def f(a, b: int):
        return a+b

    m = torch.jit.script(f)
    print("TorchScript:", m(torch.ones([6]), 100))

**Example 2**

The following example uses Mypy style annotation. Note that every parameter must be annotated, even one that would otherwise assume the default type.

::

    import torch

    def f(a, b):
        # type: (torch.Tensor, int) -> torch.Tensor
        return a+b

    m = torch.jit.script(f)
    print("TorchScript:", m(torch.ones([6]), 100))

Annotate Variables and Data Attributes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements.
Sometimes, however, if a variable or attribute is associated with values of different types (e.g., ``None`` and ``TensorType``), then it may need to be explicitly
type annotated as a *wider* type such as ``Optional[int]`` or ``Any``.

Local Variables
"""""""""""""""
Local variables can be annotated according to Python3 typing module annotation rules, i.e.,

::

    LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr

In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables
that may be associated with different concrete types. Typical multi-types include ``Optional[T]`` and ``Any``.

**Example**

::

    import torch
    from typing import Optional

    def f(a, setVal: bool):
        value: Optional[torch.Tensor] = None
        if setVal:
            value = a
        return value

    ones = torch.ones([6])
    m = torch.jit.script(f)
    print("TorchScript:", m(ones, True), m(ones, False))

Instance Data Attributes
""""""""""""""""""""""""
For module type (``TSModuleType``) classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can optionally be annotated as final
via ``Final``.

::

    "class" ClassIdentifier "(torch.nn.Module):"
    InstanceAttrIdentifier ":" ["Final["] TSType ["]"]
    ...

Where:

* ``InstanceAttrIdentifier`` is the name of an instance attribute.
* ``Final`` indicates that the attribute cannot be re-assigned outside of ``__init__`` or overridden in subclasses.

**Example**

::

    import torch

    class MyModule(torch.nn.Module):
        offset_: int

        def __init__(self, offset):
            super().__init__()
            self.offset_ = offset

        ...
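
A slightly fuller sketch of annotated instance attributes, assuming Python 3.8+ for ``typing.Final`` (``AddOffset`` is a hypothetical module name). ``offset_`` is marked final, while ``scale_`` is an ordinary annotated attribute:

::

    import torch
    import torch.nn as nn
    from typing import Final  # Python 3.8+


    # `AddOffset` is a hypothetical example module.
    class AddOffset(nn.Module):
        offset_: Final[int]  # must not be re-assigned outside __init__
        scale_: float        # ordinary annotated instance attribute

        def __init__(self, offset: int, scale: float):
            super().__init__()
            self.offset_ = offset
            self.scale_ = scale

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.scale_ + self.offset_


    m = torch.jit.script(AddOffset(3, 2.0))
    print(m(torch.ones(2)))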


Type Annotation APIs
^^^^^^^^^^^^^^^^^^^^

``torch.jit.annotate(T, expr)``
"""""""""""""""""""""""""""""""
This API annotates the expression ``expr`` with type ``T``. This is often used when the default type of an expression is not the type intended by the programmer.
For instance, an empty list has the default type ``List[TensorType]`` and an empty dictionary has the default type ``Dict[str, TensorType]``, but either may be used to initialize
a container of some other type. Another common use case is annotating the return type of ``tensor.tolist()``. Note, however, that it cannot be used to annotate
the type of a module attribute in ``__init__``; use ``torch.jit.Attribute`` for this instead.

**Example**

In this example, ``[]`` is declared as a list of integers via ``torch.jit.annotate`` (instead of assuming ``[]`` to be the default type of ``List[TensorType]``):

::

    import torch
    from typing import List

    def g(l: List[int], val: int):
        l.append(val)
        return l

    def f(val: int):
        l = g(torch.jit.annotate(List[int], []), val)
        return l

    m = torch.jit.script(f)
    print("Eager:", f(3))
    print("TorchScript:", m(3))


See :meth:`torch.jit.annotate` for more information.
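
The other two use cases mentioned above, an empty dictionary and the result of ``tensor.tolist()``, can be annotated the same way. This is a minimal sketch. The function name ``summarize`` is hypothetical.

::

    import torch
    from typing import Dict, List, Tuple


    @torch.jit.script
    def summarize(x: torch.Tensor) -> Tuple[Dict[str, int], List[List[float]]]:
        # Without the annotation, {} would not be typed as Dict[str, int]
        counts = torch.jit.annotate(Dict[str, int], {})
        counts["rows"] = x.size(0)
        # tolist() needs an annotation so TorchScript knows the nesting depth and element type
        rows = torch.jit.annotate(List[List[float]], x.tolist())
        return counts, rows


    print(summarize(torch.ones(2, 2)))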


Type Annotation Appendix
^^^^^^^^^^^^^^^^^^^^^^^^

TorchScript Type System Definition
""""""""""""""""""""""""""""""""""

::

    TSAllType       ::= TSType | TSModuleType
    TSType          ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType

    TSMetaType      ::= "Any"
    TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"

    TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
                         TSUnion | TSFuture | TSRRef | TSAwait
    TSTuple         ::= "Tuple" "[" (TSType ",")* TSType "]"
    TSNamedTuple    ::= "namedtuple" "(" (TSType ",")* TSType ")"
    TSList          ::= "List" "[" TSType "]"
    TSOptional      ::= "Optional" "[" TSType "]"
    TSUnion         ::= "Union" "[" (TSType ",")* TSType "]"
    TSFuture        ::= "Future" "[" TSType "]"
    TSRRef          ::= "RRef" "[" TSType "]"
    TSAwait         ::= "Await" "[" TSType "]"
    TSDict          ::= "Dict" "[" KeyType "," TSType "]"
    KeyType         ::= "str" | "int" | "float" | "bool" | TensorType | "Any"

    TSNominalType   ::= TSBuiltinClass | TSCustomClass | TSEnum
    TSBuiltinClass  ::= TSTensor | "torch.device" | "torch.stream" |
                        "torch.dtype" | "torch.nn.ModuleList" |
                        "torch.nn.ModuleDict" | ...
    TSTensor        ::= "torch.Tensor" and subclasses

Unsupported Typing Constructs
"""""""""""""""""""""""""""""
TorchScript does not support all features and types of the Python 3 `typing <https://docs.python.org/3/library/typing.html#module-typing>`_ module.
Any functionality from the `typing <https://docs.python.org/3/library/typing.html#module-typing>`_ module that is not explicitly specified in this
documentation is unsupported. The following table summarizes ``typing`` constructs that are either unsupported or supported with restrictions in TorchScript.

=============================  ================
Item                           Description
=============================  ================
``typing.Any``                  In development
``typing.NoReturn``             Not supported
``typing.Callable``             Not supported
``typing.Literal``              Not supported
``typing.ClassVar``             Not supported
``typing.Final``                Supported for module attributes, class attributes, and annotations, but not for functions.
``typing.AnyStr``               Not supported
``typing.overload``             In development
Type aliases                    Not supported
Nominal typing                  In development
Structural typing               Not supported
NewType                         Not supported
Generics                        Not supported
=============================  ================


.. _expressions:


Expressions
~~~~~~~~~~~

The following section describes the grammar of expressions that are supported in TorchScript.
It is modeled after `the expressions chapter of the Python language reference <https://docs.python.org/3/reference/expressions.html>`_.

Arithmetic Conversions
^^^^^^^^^^^^^^^^^^^^^^
TorchScript performs the following implicit type conversions:

* A ``Tensor`` with a ``float`` or ``int`` data type can be implicitly converted to an instance of ``FloatType`` or ``IntType`` provided that it is 0-dimensional, does not have ``requires_grad`` set to ``True``, and will not require narrowing.
* Instances of ``StringType`` can be implicitly converted to ``DeviceType``.
* The implicit conversion rules from the two bullet points above can be applied to instances of ``TupleType`` to produce instances of ``ListType`` with the appropriate contained type.


Explicit conversions can be invoked using the ``float``, ``int``, ``bool``, and ``str`` built-in functions,
which accept primitive data types as arguments and can accept user-defined types if they implement
``__bool__``, ``__str__``, etc.
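
Below is a minimal sketch of both kinds of conversion. The function names ``make_zeros`` and ``convert`` are hypothetical. ``make_zeros`` relies on the implicit ``str``-to-``torch.device`` conversion, and ``convert`` calls the built-in conversion functions explicitly.

::

    import torch
    from typing import Tuple


    @torch.jit.script
    def make_zeros(n: int) -> torch.Tensor:
        # Implicit conversion: the str "cpu" is converted to a torch.device argument
        return torch.zeros([n], device="cpu")


    @torch.jit.script
    def convert(t: torch.Tensor, n: int) -> Tuple[float, int, bool, str]:
        # Explicit conversions: float() and int() read the value of a one-element tensor,
        # and bool() and str() convert a primitive int
        return float(t), int(t), bool(n), str(n)


    print(make_zeros(3))
    print(convert(torch.tensor(2.5), 0))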


Atoms
^^^^^
Atoms are the most basic elements of expressions.

::

    atom      ::=  identifier | literal | enclosure
    enclosure ::=  parenth_form | list_display | dict_display

Identifiers
"""""""""""
The rules that dictate what is a legal identifier in TorchScript are the same as
their `Python counterparts <https://docs.python.org/3/reference/lexical_analysis.html#identifiers>`_.

Literals
""""""""

::

    literal ::=  stringliteral | integer | floatnumber

Evaluation of a literal yields an object of the appropriate type with the specific value
(with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations
of identical literals may obtain the same object or distinct objects with the same value.
`stringliteral <https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals>`_,
`integer <https://docs.python.org/3/reference/lexical_analysis.html#integer-literals>`_, and
`floatnumber <https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals>`_
are defined in the same way as their Python counterparts.

Parenthesized Forms
"""""""""""""""""""

::

    parenth_form ::=  '(' [expression_list] ')'

A parenthesized expression list yields whatever the expression list yields. If the list contains at least one
comma, it yields a ``Tuple``; otherwise, it yields the single expression inside the expression list. An empty
pair of parentheses yields an empty ``Tuple`` object (``Tuple[]``).
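
The following sketch shows how the comma decides between a single expression and a ``Tuple``. The function name ``forms`` is hypothetical.

::

    import torch
    from typing import Tuple


    @torch.jit.script
    def forms() -> Tuple[int, Tuple[int], Tuple[int, int]]:
        a = (1)     # no comma: yields the single expression, an int
        b = (1,)    # one comma: yields a one-element Tuple
        c = (1, 2)  # a two-element Tuple
        return a, b, c


    print(forms())  # (1, (1,), (1, 2))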

List and Dictionary Displays
""""""""""""""""""""""""""""

::

    list_comprehension ::=  expression comp_for
    comp_for           ::=  'for' target_list 'in' or_expr
    list_display       ::=  '[' [expression_list | list_comprehension] ']'
    dict_display       ::=  '{' [key_datum_list | dict_comprehension] '}'
    key_datum_list     ::=  key_datum (',' key_datum)*
    key_datum          ::=  expression ':' expression
    dict_comprehension ::=  key_datum comp_for

Lists and dicts can be constructed by either listing the container contents explicitly or by providing
instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension
is semantically equivalent to using a for loop and appending to an ongoing list.
Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the
enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list
are evaluated left-to-right. If a key is repeated in a ``dict_display`` that has a ``key_datum_list``, the
resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key.
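
Here is a sketch of the display forms described above. It assumes a PyTorch version whose TorchScript supports dict comprehensions, as the grammar states. The function name ``displays`` is hypothetical.

::

    import torch
    from typing import Dict, List, Tuple


    @torch.jit.script
    def displays(n: int) -> Tuple[List[int], Dict[int, int], Dict[str, int]]:
        squares = [i * i for i in range(n)]     # list comprehension
        lookup = {i: i * i for i in range(n)}   # dict comprehension
        explicit = {"a": 1, "b": 2}             # items listed explicitly, evaluated left to right
        return squares, lookup, explicit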

Primaries
^^^^^^^^^

::

    primary ::=  atom | attributeref | subscription | slicing | call


Attribute References
""""""""""""""""""""

::

    attributeref ::=  primary '.' identifier


The ``primary`` must evaluate to an object whose type supports attribute references and has an attribute named
``identifier``.

Subscriptions
"""""""""""""

::

    subscription ::=  primary '[' expression_list ']'


The ``primary`` must evaluate to an object that supports subscription.

* If the primary is a ``List``, ``Tuple``, or ``str``, the expression list must evaluate to an integer or slice.
* If the primary is a ``Dict``, the expression list must evaluate to an object of the same type as the key type of the ``Dict``.
* If the primary is a ``ModuleList``, the expression list must be an ``integer`` literal.
* If the primary is a ``ModuleDict``, the expression list must be a ``stringliteral``.
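
The sketch below shows the literal-index rules for ``ModuleList`` and ``ModuleDict``. The ``Net`` module and its layer sizes are hypothetical.

::

    import torch


    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4), torch.nn.ReLU()])
            self.heads = torch.nn.ModuleDict({"out": torch.nn.Linear(4, 2)})

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.layers[0](x)        # ModuleList: integer literal index
            x = self.layers[1](x)
            return self.heads["out"](x)  # ModuleDict: string literal key


    net = torch.jit.script(Net())
    print(net(torch.randn(3, 4)).shape)  # torch.Size([3, 2])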


Slicings
""""""""
A slicing selects a range of items in a ``str``, ``Tuple``, ``List``, or ``Tensor``. Slicings may be used as
expressions or targets in assignment or ``del`` statements.

::

    slicing      ::=  primary '[' slice_list ']'
    slice_list   ::=  slice_item (',' slice_item)* [',']
    slice_item   ::=  expression | proper_slice
    proper_slice ::=  [expression] ':' [expression] [':' [expression] ]

Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an
object of type ``Tensor``.
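
In this sketch, only the ``Tensor`` primary uses more than one slice item. The function name ``slices`` is hypothetical.

::

    import torch
    from typing import List, Tuple


    @torch.jit.script
    def slices(x: torch.Tensor, xs: List[int], s: str) -> Tuple[torch.Tensor, List[int], str]:
        # Multiple slice items (rows 1.., every other column) are only allowed on a Tensor
        rows = x[1:, ::2]
        # Lists and strings take a single slice item
        return rows, xs[1:3], s[1:3]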


Calls
"""""

::

    call          ::=  primary '(' argument_list ')'
    argument_list ::=  args [',' kwargs] | kwargs
    args          ::=  [arg (',' arg)*]
    kwargs        ::=  [kwarg (',' kwarg)*]
    kwarg         ::=  identifier '=' expression
    arg           ::=  expression


The ``primary`` must desugar or evaluate to a callable object. All argument expressions are evaluated
before the call is attempted.
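
The following sketch calls one scripted function from another, passing a positional argument followed by a keyword argument. The names ``affine`` and ``caller`` are hypothetical.

::

    import torch


    @torch.jit.script
    def affine(x: torch.Tensor, scale: float = 2.0, bias: float = 0.0) -> torch.Tensor:
        return x * scale + bias


    @torch.jit.script
    def caller(x: torch.Tensor) -> torch.Tensor:
        # One positional argument, then a keyword argument; `scale` keeps its default
        return affine(x, bias=1.0)


    print(caller(torch.ones(2)))  # tensor([3., 3.])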

Power Operator
^^^^^^^^^^^^^^

::

    power ::=  primary ['**' u_expr]


The power operator has the same semantics as the built-in ``pow`` function; it computes its
left argument raised to the power of its right argument. It binds more tightly than unary operators on the
left, but less tightly than unary operators on the right; i.e., ``-2 ** -3 == -(2 ** (-3))``. The left and right
operands can be ``int``, ``float``, or ``Tensor``. Scalars are broadcast in the case of scalar-tensor/tensor-scalar
exponentiation operations, and tensor-tensor exponentiation is done elementwise, following PyTorch broadcasting semantics.
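
Here is a sketch of the precedence rule and the tensor cases. The function name ``power`` is hypothetical. The ``float`` return annotation relies on TorchScript's ``int ** int`` producing a ``float``, which it must do here because the exponent is negative.

::

    import torch
    from typing import Tuple


    @torch.jit.script
    def power(x: torch.Tensor) -> Tuple[float, torch.Tensor, torch.Tensor]:
        neg = -2 ** -3         # parsed as -(2 ** (-3)), i.e. -0.125
        squared = x ** 2       # tensor-scalar: the scalar exponent applies to every element
        elementwise = x ** x   # tensor-tensor: elementwise exponentiation
        return neg, squared, elementwise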

Unary and Arithmetic Bitwise Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    u_expr ::=  power | '-' power | '~' power

The unary ``-`` operator yields the negation of its argument. The unary ``~`` operator yields the bitwise inversion
of its argument. ``-`` can be used with ``int``, ``float``, and ``Tensor`` of ``int`` and ``float``.
``~`` can only be used with ``int`` and ``Tensor`` of ``int``.
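
A short sketch of both unary operators. The function name ``unary`` is hypothetical, and ``t`` is assumed to hold integers so that ``~`` is valid.

::

    import torch
    from typing import Tuple


    @torch.jit.script
    def unary(n: int, t: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
        # `-` negates ints, floats, and tensors; `~` is bitwise inversion (~v == -v - 1)
        return -n, -t, ~t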

Binary Arithmetic Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    m_expr ::=  u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr
    a_expr ::=  m_expr | a_expr '+' m_expr | a_expr '-' m_expr

The binary arithmetic operators can operate on ``Tensor``, ``int``, and ``float``. For tensor-tensor ops, both arguments must
have broadcastable shapes. For scalar-tensor or tensor-scalar ops, the scalar is broadcast to the size of the
tensor. Division ops accept either a scalar or a tensor as their right-hand side argument.
The ``@`` operator is for matrix multiplication and only operates on ``Tensor`` arguments. The multiplication operator
(``*``) can be used with a list and an integer to get a result that is the original list repeated that many
times.
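
The following sketch exercises the scalar and tensor forms. The function names are hypothetical. Note that ``/`` on two ``int`` values performs true division and yields a ``float``, as in Python 3.

::

    import torch
    from typing import List, Tuple


    @torch.jit.script
    def arith(a: int, b: int, xs: List[int]) -> Tuple[float, int, int, List[int]]:
        # `/` is true division; `//` and `%` stay integral; `*` repeats a list
        return a / b, a // b, a % b, xs * 2


    @torch.jit.script
    def tensor_arith(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # `@` is matrix multiplication; the scalar 1 is broadcast to every element of x
        return x @ y, x + 1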

Shifting Operations
^^^^^^^^^^^^^^^^^^^

::

    shift_expr ::=  a_expr | shift_expr ( '<<' | '>>' ) a_expr


These operators accept two ``int`` arguments, two ``Tensor`` arguments, or a ``Tensor`` argument and an ``int`` or
``float`` argument. In all cases, a right shift by ``n`` is defined as floor division by ``pow(2, n)``, and a left shift
by ``n`` is defined as multiplication by ``pow(2, n)``. When both arguments are ``Tensors``, their shapes must be
broadcastable. When one is a scalar and the other is a ``Tensor``, the scalar is logically broadcast to match the size of
the ``Tensor``.
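
A small sketch of the shift definitions. The function name ``shifts`` is hypothetical.

::

    import torch
    from typing import Tuple


    @torch.jit.script
    def shifts(n: int, t: torch.Tensor) -> Tuple[int, int, torch.Tensor]:
        # n << 2 == n * pow(2, 2) and n >> 1 == n // pow(2, 1);
        # in `t << 1` the int 1 is broadcast over every element of t
        return n << 2, n >> 1, t << 1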

Binary Bitwise Operations
^^^^^^^^^^^^^^^^^^^^^^^^^

::

    and_expr ::=  shift_expr | and_expr '&' shift_expr
    xor_expr ::=  and_expr | xor_expr '^' and_expr
    or_expr  ::=  xor_expr | or_expr '|' xor_expr


The ``&`` operator computes the bitwise AND of its arguments, the ``^`` the bitwise XOR, and the ``|`` the bitwise OR.
Both operands must be ``int`` or ``Tensor``, or the left operand must be ``Tensor`` and the right operand must be
``int``. When both operands are ``Tensor``, their shapes must be broadcastable. When the right operand is ``int`` and
the left operand is ``Tensor``, the right operand is logically broadcast to match the shape of the ``Tensor``.
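
A sketch covering the ``int``/``int`` and ``Tensor``/``int`` cases. The function name ``bits`` is hypothetical.

::

    import torch
    from typing import Tuple


    @torch.jit.script
    def bits(a: int, b: int, t: torch.Tensor) -> Tuple[int, int, int, torch.Tensor]:
        # 12 = 0b1100 and 10 = 0b1010; the int 6 is broadcast over t
        return a & b, a ^ b, a | b, t & 6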

Comparisons
^^^^^^^^^^^

::

    comparison    ::=  or_expr (comp_operator or_expr)*
    comp_operator ::=  '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'

A comparison yields a boolean value (``True`` or ``False``), or if one of the operands is a ``Tensor``, a boolean
``Tensor``. Comparisons can be chained arbitrarily as long as they do not yield boolean ``Tensors`` that have more
than one element. ``a op1 b op2 c ...`` is equivalent to ``a op1 b and b op2 c and ...``.
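
A chained comparison on scalars, sketched with a hypothetical ``in_range`` function:

::

    import torch


    @torch.jit.script
    def in_range(n: int) -> bool:
        # Equivalent to `0 <= n and n < 10`
        return 0 <= n < 10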

Value Comparisons
"""""""""""""""""
The operators ``<``, ``>``, ``==``, ``>=``, ``<=``, and ``!=`` compare the values of two objects. The two objects generally need to be of
the same type, unless there is an implicit type conversion available between the objects. User-defined types can
be compared if rich comparison methods (e.g., ``__lt__``) are defined on them. Built-in type comparison works like
Python:

* Numbers are compared mathematically.
* Strings are compared lexicographically.
* ``lists``, ``tuples``, and ``dicts`` can be compared only to other ``lists``, ``tuples``, and ``dicts`` of the same type and are compared using the comparison operator of corresponding elements.
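
A sketch of list and string comparisons. The function name ``compare`` is hypothetical.

::

    import torch
    from typing import List, Tuple


    @torch.jit.script
    def compare(a: List[int], b: List[int], s: str, t: str) -> Tuple[bool, bool]:
        # Lists compare element by element; strings compare lexicographically
        return a == b, s < t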

Membership Test Operations
""""""""""""""""""""""""""
The operators ``in`` and ``not in`` test for membership. ``x in s`` evaluates to ``True`` if ``x`` is a member of ``s`` and ``False`` otherwise.
``x not in s`` is equivalent to ``not x in s``. This operator is supported for ``lists``, ``dicts``, and ``tuples``, and can be used with
user-defined types if they implement the ``__contains__`` method.
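
A sketch of membership tests on a list and a dict. The function name ``members`` is hypothetical.

::

    import torch
    from typing import Dict, List, Tuple


    @torch.jit.script
    def members(x: int, xs: List[int], d: Dict[str, int]) -> Tuple[bool, bool, bool]:
        # `in` on a list checks elements; on a dict it checks keys
        return x in xs, "key" in d, x not in xs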

Identity Comparisons
""""""""""""""""""""
For all types except ``int``, ``double``, ``bool``, and ``torch.device``, the operators ``is`` and ``is not`` test for the object’s identity;
``x is y`` is ``True`` if and only if ``x`` and ``y`` are the same object. For ``int``, ``double``, ``bool``, and ``torch.device``, ``is`` is equivalent to
comparing the operands using ``==``. ``x is not y`` yields the inverse of ``x is y``.
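
The most common identity test in TorchScript is a comparison against ``None``. This sketch uses a hypothetical ``numel_or_zero`` function.

::

    import torch
    from typing import Optional


    @torch.jit.script
    def numel_or_zero(x: Optional[torch.Tensor]) -> int:
        if x is None:  # identity test against None
            return 0
        # After the check, x is refined from Optional[Tensor] to Tensor
        return x.numel()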

Boolean Operations
^^^^^^^^^^^^^^^^^^

::

    or_test  ::=  and_test | or_test 'or' and_test
    and_test ::=  not_test | and_test 'and' not_test
    not_test ::=  'bool' '(' or_expr ')' | comparison | 'not' not_test

User-defined objects can customize their conversion to ``bool`` by implementing a ``__bool__`` method. The operator ``not``
yields ``True`` if its operand is false, ``False`` otherwise. The expression ``x and y`` first evaluates ``x``; if it is ``False``, its
value (``False``) is returned; otherwise, ``y`` is evaluated and its value is returned (``False`` or ``True``). The expression ``x or y``
first evaluates ``x``; if it is ``True``, its value (``True``) is returned; otherwise, ``y`` is evaluated and its value is returned
(``False`` or ``True``).
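
The short-circuit behavior above lets a ``None`` check guard the right-hand operand. This sketch assumes TorchScript's ``Optional`` refinement through ``and`` inside an ``if`` condition. The function name ``is_positive`` is hypothetical.

::

    import torch
    from typing import Optional


    @torch.jit.script
    def is_positive(x: Optional[int]) -> bool:
        # `and` short-circuits, so `x > 0` is only evaluated when x is not None
        if x is not None and x > 0:
            return True
        return False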

Conditional Expressions
^^^^^^^^^^^^^^^^^^^^^^^

::

    conditional_expression ::=  or_expr ['if' or_test 'else' conditional_expression]
    expression             ::=  conditional_expression

The expression ``x if c else y`` first evaluates the condition ``c`` rather than ``x``. If ``c`` is ``True``, ``x`` is
evaluated and its value is returned; otherwise, ``y`` is evaluated and its value is returned. As with if-statements,
``x`` and ``y`` must evaluate to a value of the same type.
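
A minimal sketch using a hypothetical ``larger`` function. Both branches are ``int``, so the expression type-checks.

::

    import torch


    @torch.jit.script
    def larger(x: int, y: int) -> int:
        # Only the selected branch is evaluated at runtime
        return x if x > y else y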

Expression Lists
^^^^^^^^^^^^^^^^

::

    expression_list ::=  expression (',' expression)* [',']
    starred_item    ::=  '*' primary

A starred item can only appear on the left-hand side of an assignment statement, e.g., ``a, *b, c = ...``.

.. _statements:

Simple Statements
~~~~~~~~~~~~~~~~~

The following section describes the syntax of simple statements that are supported in TorchScript.
It is modeled after `the simple statements chapter of the Python language reference <https://docs.python.org/3/reference/simple_stmts.html>`_.

Expression Statements
^^^^^^^^^^^^^^^^^^^^^^

::

    expression_stmt    ::=  starred_expression
    starred_expression ::=  expression | (starred_item ",")* [starred_item]
    starred_item       ::=  assignment_expression | "*" or_expr

Assignment Statements
^^^^^^^^^^^^^^^^^^^^^^

::

    assignment_stmt ::=  (target_list "=")+ (starred_expression)
    target_list     ::=  target ("," target)* [","]
    target          ::=  identifier
                         | "(" [target_list] ")"
                         | "[" [target_list] "]"
                         | attributeref
                         | subscription
                         | slicing
                         | "*" target

Augmented Assignment Statements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    augmented_assignment_stmt ::= augtarget augop (expression_list)
    augtarget                 ::= identifier | attributeref | subscription
    augop                     ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
                                  "**="| ">>=" | "<<=" | "&=" | "^=" | "|="


Annotated Assignment Statements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::

    annotated_assignment_stmt ::= augtarget ":" expression
                                  ["=" (starred_expression)]
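
The sketch below combines an annotated assignment (needed here so the empty list is not typed as ``List[Tensor]``) with an augmented assignment. The function name ``accumulate`` is hypothetical.

::

    import torch
    from typing import List


    @torch.jit.script
    def accumulate(n: int) -> List[int]:
        out: List[int] = []  # annotated assignment
        total = 0            # plain assignment
        for i in range(n):
            total += i       # augmented assignment
            out.append(total)
        return out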

The ``raise`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^

::

    raise_stmt ::=  "raise" [expression ["from" expression]]

Raise statements in TorchScript do not support ``try``/``except``/``finally``.

The ``assert`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^

::

    assert_stmt ::=  "assert" expression ["," expression]

Assert statements in TorchScript do not support ``try``/``except``/``finally``.
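
Because ``try``/``except`` is unavailable inside TorchScript, a failed ``assert`` or a ``raise`` propagates to the Python caller as an exception. The function name ``check`` is hypothetical. The exact exception class seen in Python may vary between PyTorch versions, so the sketch only relies on an exception being raised.

::

    import torch


    @torch.jit.script
    def check(x: int) -> int:
        assert x >= 0, "x must be non-negative"
        if x > 100:
            raise ValueError("x is too large")
        return x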

The ``return`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^

::

    return_stmt ::=  "return" [expression_list]

Return statements in TorchScript do not support ``try``/``except``/``finally``.

The ``del`` Statement
^^^^^^^^^^^^^^^^^^^^^^

::

    del_stmt ::=  "del" target_list

The ``pass`` Statement
^^^^^^^^^^^^^^^^^^^^^^^

::

    pass_stmt ::= "pass"

The ``print`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^

::

    print_stmt ::= "print" "(" [expression ("," expression)*] ")"

The ``break`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^

::

    break_stmt ::= "break"

The ``continue`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    continue_stmt ::= "continue"
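
A sketch using both ``break`` and ``continue`` in one loop. The function name ``sum_until_negative`` is hypothetical.

::

    import torch
    from typing import List


    @torch.jit.script
    def sum_until_negative(xs: List[int]) -> int:
        total = 0
        for x in xs:
            if x == 0:
                continue  # skip zeros
            if x < 0:
                break     # stop at the first negative value
            total += x
        return total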

Compound Statements
~~~~~~~~~~~~~~~~~~~

The following section describes the syntax of compound statements that are supported in TorchScript.
The section also highlights how TorchScript differs from regular Python statements.
It is modeled after `the compound statements chapter of the Python language reference <https://docs.python.org/3/reference/compound_stmts.html>`_.

The ``if`` Statement
^^^^^^^^^^^^^^^^^^^^^

TorchScript supports both basic ``if/else`` and ternary ``if/else``.

Basic ``if/else`` Statement
""""""""""""""""""""""""""""

::

    if_stmt ::= "if" assignment_expression ":" suite
                ("elif" assignment_expression ":" suite)*
                ["else" ":" suite]

``elif`` clauses can be repeated any number of times, but they must appear before the ``else`` clause.

Ternary ``if/else`` Statement
""""""""""""""""""""""""""""""

::

    if_stmt ::= "return" [expression_list] "if" assignment_expression "else" [expression_list]

**Example 1**

A ``tensor`` with a single element is implicitly converted to ``bool``:

.. testcode::

    import torch

    @torch.jit.script
    def fn(x: torch.Tensor):
        if x: # The single-element tensor is converted to bool
            return True
        return False
    print(fn(torch.rand(1)))

The example above produces the following output:

.. testoutput::

    True

**Example 2**

A ``tensor`` with more than one element cannot be converted to ``bool``:

::

    import torch

    # Tensors with more than one element raise an error.

    @torch.jit.script
    def fn():
        if torch.rand(2):
            print("Tensor is available")

        if torch.rand(4,5,6):
            print("Tensor is available")

    fn()

Running the above code yields the following ``RuntimeError``.

::

    RuntimeError: The following operation failed in the TorchScript interpreter.
    Traceback of TorchScript (most recent call last):
    @torch.jit.script
    def fn():
        if torch.rand(2):
           ~~~~~~~~~~~~ <--- HERE
            print("Tensor is available")
    RuntimeError: Boolean value of Tensor with more than one value is ambiguous

If a conditional variable is annotated as ``Final``, TorchScript evaluates the condition at compile time and compiles only the branch it selects.

**Example 3**

In this example, only the ``True`` branch is compiled, since ``a`` is annotated as ``Final`` and set to ``True``:

::

    import torch
    from typing import Final

    class M(torch.nn.Module):
        a: Final[bool]
        def __init__(self):
            super().__init__()
            self.a = True

        def forward(self):
            if self.a:
                return torch.empty(2, 3)
            else:
                return torch.zeros(0)  # not compiled: self.a is a Final True

    m = torch.jit.script(M())


The ``while`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^

::

    while_stmt ::=  "while" assignment_expression ":" suite

``while...else`` statements are not supported in TorchScript. They result in a ``RuntimeError``.
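
A sketch of a plain ``while`` loop without an ``else`` clause. The function name ``collatz_steps`` is hypothetical. It counts Collatz steps until ``n`` reaches 1 and assumes ``n >= 1``.

::

    import torch


    @torch.jit.script
    def collatz_steps(n: int) -> int:
        steps = 0
        while n != 1:
            # Both branches are ints, so the conditional expression type-checks
            n = n // 2 if n % 2 == 0 else 3 * n + 1
            steps += 1
        return steps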

The ``for-in`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^

::

    for_stmt ::=  "for" target_list "in" expression_list ":" suite
                  ["else" ":" suite]

``for...else`` statements are not supported in TorchScript. They result in a ``RuntimeError``.

**Example 1**

For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member.

.. testcode::

    import torch
    from typing import Tuple

    @torch.jit.script
    def fn():
        tup = (3, torch.ones(4))
        for x in tup:
            print(x)

    fn()

The example above produces the following output:

.. testoutput::

    3
     1
     1
     1
     1
    [ CPUFloatType{4} ]


**Example 2**

For loops over an ``nn.ModuleList``: the body of the loop is unrolled at compile time, once for each member of the module list.

::

    import torch

    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(2))

        def forward(self, input):
            return self.weight + input

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

        def forward(self, v):
            for module in self.mods:
                v = module(v)
            return v

    model = torch.jit.script(MyModule())

The ``with`` Statement
^^^^^^^^^^^^^^^^^^^^^^^
The ``with`` statement is used to wrap the execution of a block with methods defined by a context manager.

::

    with_stmt ::=  "with" with_item ("," with_item)* ":" suite
    with_item ::=  expression ["as" target]

* If a target was included in the ``with`` statement, the return value from the context manager’s ``__enter__()`` is assigned to it. Unlike Python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to ``__exit__()``; three ``None`` arguments are supplied instead.
* ``try``, ``except``, and ``finally`` statements are not supported inside ``with`` blocks.
* Exceptions raised within a ``with`` block cannot be suppressed.
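
The following sketch defines a context manager as a TorchScript class. The ``Tracker`` class and the ``run`` function are hypothetical. ``__exit__`` takes the three arguments that TorchScript always passes as ``None``.

::

    import torch
    from typing import Any


    @torch.jit.script
    class Tracker(object):
        def __init__(self):
            self.depth = 0

        def __enter__(self) -> int:
            self.depth += 1
            return self.depth  # bound to the `as` target

        def __exit__(self, type: Any, value: Any, tb: Any):
            # In TorchScript, all three arguments are None
            self.depth -= 1


    @torch.jit.script
    def run() -> int:
        t = Tracker()
        inside = 0
        with t as d:
            inside = d
        return inside * 10 + t.depth  # depth is back to 0 after __exit__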

The ``tuple`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^

::

    tuple_stmt ::= tuple([iterables])

* Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList``, and ``torch.nn.ModuleDict``.
* You cannot convert a ``List`` to a ``Tuple`` with this built-in function.

Unpacking all outputs into a tuple is covered by:

::

    abc = func() # Function that returns a tuple
    a, b = func()

The ``getattr`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    getattr_stmt ::= getattr(object, name[, default])

* Attribute name must be a literal string.
* Module type object is not supported (e.g., ``torch._C``).
* Custom class object is not supported (e.g., ``torch.classes.*``).

The ``hasattr`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    hasattr_stmt ::= hasattr(object, name)

* Attribute name must be a literal string.
* Module type object is not supported (e.g., ``torch._C``).
* Custom class object is not supported (e.g., ``torch.classes.*``).
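
A sketch of ``hasattr`` and ``getattr`` on a module, using literal attribute names as required. The ``Scale`` module and its ``factor`` attribute are hypothetical. No ``bias`` attribute exists, so the ``hasattr`` check is false.

::

    import torch


    class Scale(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.factor = 2.0

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if hasattr(self, "bias"):  # literal attribute name
                return x
            return x * getattr(self, "factor")


    scripted_scale = torch.jit.script(Scale())
    print(scripted_scale(torch.ones(2)))  # tensor([2., 2.])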

The ``zip`` Statement
^^^^^^^^^^^^^^^^^^^^^^

::

    zip_stmt ::= zip(iterable1, iterable2)

* Arguments must be iterables.
* Two iterables of the same outer container type but different lengths are supported.

**Example 1**

Both iterables must be of the same container type:

.. testcode::

    a = [1, 2] # List
    b = [2, 3, 4] # List
    zip(a, b) # works

**Example 2**

This example fails because the iterables are of different container types:

::

    a = (1, 2) # Tuple
    b = [2, 3, 4] # List
    zip(a, b) # Runtime error

Running the above code in TorchScript yields the following ``RuntimeError``.

::

    RuntimeError: Can not iterate over a module list or
        tuple with a value that does not have a statically determinable length.

**Example 3**

Two iterables of the same container type but different element types are supported:

.. testcode::

    a = [1.3, 2.4]
    b = [2, 3, 4]
    zip(a, b) # Works

Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList``, and ``torch.nn.ModuleDict``.
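
The examples above run as plain Python. The sketch below shows the supported case inside a scripted function: two lists of different lengths and element types. The function name ``pairwise_sum`` is hypothetical.

::

    import torch
    from typing import List


    @torch.jit.script
    def pairwise_sum(a: List[float], b: List[int]) -> List[float]:
        out: List[float] = []
        for x, y in zip(a, b):  # stops at the end of the shorter list
            out.append(x + y)
        return out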

The ``enumerate`` Statement
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    enumerate_stmt ::= enumerate([iterable])

* Arguments must be iterables.
* Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList`` and ``torch.nn.ModuleDict``.
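
A sketch of ``enumerate`` over a list. The function name ``argmax_index`` is hypothetical, and ``xs`` is assumed to be non-empty.

::

    import torch
    from typing import List


    @torch.jit.script
    def argmax_index(xs: List[float]) -> int:
        best = 0
        for i, x in enumerate(xs):  # i is the index, x the element
            if x > xs[best]:
                best = i
        return best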


.. _python-values-torch-script:

Python Values
~~~~~~~~~~~~~

.. _python-builtin-functions-values-resolution:

Resolution Rules
^^^^^^^^^^^^^^^^
When given a Python value, TorchScript attempts to resolve it in the following five ways:

* Compilable Python Implementation:

  * When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation.
  * Example: ``torch.jit.Attribute``

* Op Python Wrapper:

  * When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator.
  * Example: ``torch.jit._logging.add_stat_value``

* Python Object Identity Match:

  * For a limited set of ``torch.*`` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set.
  * When matched, TorchScript generates a corresponding ``SugaredValue`` instance that contains lowering logic for these values.
  * Example: ``torch.jit.isinstance()``

* Name Match:

  * For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding ``SugaredValue`` instance that implements their functionality.
  * Example: ``all()``

* Value Snapshot:

  * For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and convert it to a constant in the graph of the function(s) or method(s) that are being compiled.
  * Example: ``math.pi``
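
The sketch below uses two of the example values listed above from inside a scripted function. The comments follow this section's classification of ``math.pi`` and ``all()``. The function name ``resolve`` is hypothetical.

::

    import math
    import torch
    from typing import List, Tuple


    @torch.jit.script
    def resolve(r: float, flags: List[bool]) -> Tuple[float, bool]:
        # math.pi: a value from a module TorchScript does not compile (Value Snapshot)
        # all(): a Python built-in recognized by name (Name Match)
        return math.pi * r * r, all(flags)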
1597
1598
1599
1600.. _python-builtin-functions-support:
1601
1602Python Built-in Functions Support
1603^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1604.. list-table:: TorchScript Support for Python Built-in Functions
1605   :widths: 25 25 50
1606   :header-rows: 1
1607
1608   * - Built-in Function
1609     - Support Level
1610     - Notes
1611   * - ``abs()``
1612     - Partial
1613     - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override.
   * - ``all()``
     - Full
     -
   * - ``any()``
     - Full
     -
   * - ``ascii()``
     - None
     -
   * - ``bin()``
     - Partial
     - Only supports ``Int`` type input.
   * - ``bool()``
     - Partial
     - Only supports ``Tensor``/``Int``/``Float`` type inputs.
   * - ``breakpoint()``
     - None
     -
   * - ``bytearray()``
     - None
     -
   * - ``bytes()``
     - None
     -
   * - ``callable()``
     - None
     -
   * - ``chr()``
     - Partial
     - Only the ASCII character set is supported.
   * - ``classmethod()``
     - Full
     -
   * - ``compile()``
     - None
     -
   * - ``complex()``
     - None
     -
   * - ``delattr()``
     - None
     -
   * - ``dict()``
     - Full
     -
   * - ``dir()``
     - None
     -
   * - ``divmod()``
     - Full
     -
   * - ``enumerate()``
     - Full
     -
   * - ``eval()``
     - None
     -
   * - ``exec()``
     - None
     -
   * - ``filter()``
     - None
     -
   * - ``float()``
     - Partial
     - Doesn't honor ``__index__`` override.
   * - ``format()``
     - Partial
     - Manual index specification is not supported. Format type modifiers are not supported.
   * - ``frozenset()``
     - None
     -
   * - ``getattr()``
     - Partial
     - Attribute name must be a string literal.
   * - ``globals()``
     - None
     -
   * - ``hasattr()``
     - Partial
     - Attribute name must be a string literal.
   * - ``hash()``
     - Full
     - A ``Tensor``'s hash is based on its identity, not its numeric value.
   * - ``hex()``
     - Partial
     - Only supports ``Int`` type input.
   * - ``id()``
     - Full
     - Only supports ``Int`` type input.
   * - ``input()``
     - None
     -
   * - ``int()``
     - Partial
     - The ``base`` argument is not supported. Doesn't honor ``__index__`` override.
   * - ``isinstance()``
     - Full
     - ``torch.jit.isinstance`` provides better support when checking against container types like ``Dict[str, int]``.
   * - ``issubclass()``
     - None
     -
   * - ``iter()``
     - None
     -
   * - ``len()``
     - Full
     -
   * - ``list()``
     - Full
     -
   * - ``ord()``
     - Partial
     - Only the ASCII character set is supported.
   * - ``pow()``
     - Full
     -
   * - ``print()``
     - Partial
     - The ``sep``, ``end``, and ``file`` arguments are not supported.
   * - ``property()``
     - None
     -
   * - ``range()``
     - Full
     -
   * - ``repr()``
     - None
     -
   * - ``reversed()``
     - None
     -
   * - ``round()``
     - Partial
     - The ``ndigits`` argument is not supported.
   * - ``set()``
     - None
     -
   * - ``setattr()``
     - None
     -
   * - ``slice()``
     - Full
     -
   * - ``sorted()``
     - Partial
     - The ``key`` argument is not supported.
   * - ``staticmethod()``
     - Full
     -
   * - ``str()``
     - Partial
     - The ``encoding`` and ``errors`` arguments are not supported.
   * - ``sum()``
     - Full
     -
   * - ``super()``
     - Partial
     - Can only be used in an ``nn.Module``'s ``__init__`` method.
   * - ``type()``
     - None
     -
   * - ``vars()``
     - None
     -
   * - ``zip()``
     - Full
     -
   * - ``__import__()``
     - None
     -
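
The following sketch exercises a few of the Full and Partial entries in the table above from
inside a scripted function: ``sorted()`` without a ``key`` argument, ``sum()``, ``divmod()``,
and ``hex()`` on an ``int``. The function name ``summarize`` is illustrative, and the sketch
assumes a non-empty input list, since ``divmod()`` by zero raises an error.

.. code-block:: python

    from typing import List, Tuple

    import torch


    @torch.jit.script
    def summarize(values: List[int]) -> Tuple[List[int], int, int, str]:
        ordered = sorted(values)  # Partial support: no key= argument
        total = sum(ordered)  # Full support
        # divmod() returns (quotient, remainder); assumes values is non-empty
        mean_floor, remainder = divmod(total, len(ordered))
        return ordered, mean_floor, remainder, hex(total)  # hex() accepts only an int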

.. _python-builtin-values-support:

Python Built-in Values Support
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. list-table:: TorchScript Support for Python Built-in Values
   :widths: 25 25 50
   :header-rows: 1

   * - Built-in Value
     - Support Level
     - Notes
   * - ``False``
     - Full
     -
   * - ``True``
     - Full
     -
   * - ``None``
     - Full
     -
   * - ``NotImplemented``
     - None
     -
   * - ``Ellipsis``
     - Full
     -

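As a small illustration, the sketch below uses two of the fully supported values in a
scripted function: ``Ellipsis`` (written ``...``) in tensor indexing, and ``None`` as the
default of an ``Optional`` argument. The name ``last_feature`` is hypothetical.

.. code-block:: python

    from typing import Optional

    import torch


    @torch.jit.script
    def last_feature(x: torch.Tensor, scale: Optional[float] = None) -> torch.Tensor:
        y = x[..., -1]  # Ellipsis stands in for all leading dimensions
        if scale is not None:
            # The "is not None" check refines scale from Optional[float] to float
            y = y * scale
        return y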

.. _torch_apis_in_torchscript:

torch.* APIs
~~~~~~~~~~~~

.. _torch_apis_in_torchscript_rpc:

Remote Procedure Calls
^^^^^^^^^^^^^^^^^^^^^^

TorchScript supports a subset of the RPC APIs, which let you run a function on
a specified remote worker instead of locally.

Specifically, the following APIs are fully supported:

- ``torch.distributed.rpc.rpc_sync()``
    - ``rpc_sync()`` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel with the execution of Python code.
    - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.rpc_sync`.

- ``torch.distributed.rpc.rpc_async()``
    - ``rpc_async()`` makes a non-blocking RPC call to run a function on a remote worker and returns a future. RPC messages are sent and received in parallel with the execution of Python code.
    - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.rpc_async`.

- ``torch.distributed.rpc.remote()``
    - ``remote()`` executes a remote call on a worker and gets a Remote Reference (``RRef``) as the return value.
    - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.remote`.
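
The following is a minimal sketch of issuing an RPC from inside TorchScript. It assumes a
PyTorch build where ``torch.distributed.rpc`` is available; the names ``script_add``,
``add_on_worker`` and ``run_single_worker_demo``, the worker name ``worker0``, and the
``MASTER_ADDR``/``MASTER_PORT`` defaults are illustrative assumptions. To stay
self-contained, the demo starts a single worker that sends the call to itself; a real
deployment would target a different worker.

.. code-block:: python

    import os
    from typing import Dict, Tuple

    import torch
    import torch.distributed.rpc as rpc
    from torch import Tensor


    @torch.jit.script
    def script_add(x: Tensor, y: Tensor) -> Tensor:
        return x + y


    @torch.jit.script
    def add_on_worker(
        dst_worker_name: str, args: Tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
    ) -> Tensor:
        # rpc_async returns a future right away; wait() blocks until the result arrives
        fut = rpc.rpc_async(dst_worker_name, script_add, args, kwargs)
        return fut.wait()


    def run_single_worker_demo() -> Tensor:
        # Hypothetical local setup: one process acting as the only RPC worker
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        rpc.init_rpc("worker0", rank=0, world_size=1)
        try:
            return add_on_worker("worker0", (torch.ones(2), torch.ones(2)), {})
        finally:
            rpc.shutdown()

For the blocking form of the same call, ``rpc.rpc_sync(dst_worker_name, script_add, args, kwargs)``
can replace the ``rpc_async(...)`` and ``wait()`` pair.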

.. _torch_apis_in_torchscript_async:

Asynchronous Execution
^^^^^^^^^^^^^^^^^^^^^^

TorchScript enables you to create asynchronous computation tasks to make better use
of computation resources. This is done through the following APIs; asynchronous
execution only occurs when they are run in TorchScript:

- ``torch.jit.fork()``
    - Creates an asynchronous task that executes ``func`` and returns a reference (a ``torch.jit.Future``) to the result of this execution. ``fork`` returns immediately, so the result may not have been computed yet.
    - Synonymous with ``torch.jit._fork()``, which is kept only for backward compatibility reasons.
    - More details about its usage and examples can be found in :meth:`~torch.jit.fork`.
- ``torch.jit.wait()``
    - Forces completion of a ``torch.jit.Future[T]`` asynchronous task, returning the result of the task.
    - Synonymous with ``torch.jit._wait()``, which is kept only for backward compatibility reasons.
    - More details about its usage and examples can be found in :meth:`~torch.jit.wait`.

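A minimal sketch of the ``fork``/``wait`` pattern, with hypothetical names ``square`` and
``parallel_squares``: one task is forked per input tensor, and the resulting futures are then
waited on in order. Whether the tasks actually overlap depends on the runtime's inter-op
thread pool; when the function runs as plain Python instead of TorchScript, ``fork`` does not
execute in parallel.

.. code-block:: python

    from typing import List

    import torch


    def square(x: torch.Tensor) -> torch.Tensor:
        # Compiled recursively because the scripted function below forks it
        return x * x


    @torch.jit.script
    def parallel_squares(xs: List[torch.Tensor]) -> List[torch.Tensor]:
        # fork() returns a Future immediately for each task
        futures: List[torch.jit.Future[torch.Tensor]] = []
        for x in xs:
            futures.append(torch.jit.fork(square, x))
        # wait() blocks until a task finishes and returns its result
        results: List[torch.Tensor] = []
        for fut in futures:
            results.append(torch.jit.wait(fut))
        return results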

.. _torch_apis_in_torchscript_annotation:

Type Annotations
^^^^^^^^^^^^^^^^

TorchScript is statically typed. It provides a set of utilities to help annotate variables and attributes:

- ``torch.jit.annotate()``
    - Provides a type hint to TorchScript where Python 3-style type hints do not work well.
    - One common example is annotating the type of an expression like ``[]``, which is treated as ``List[torch.Tensor]`` by default. When a different type is needed, you can hint TorchScript with ``torch.jit.annotate(List[int], [])``.
    - More details can be found in :meth:`~torch.jit.annotate`.
- ``torch.jit.Attribute``
    - A common use case is providing type hints for ``torch.nn.Module`` attributes. Because a module's ``__init__`` method is not parsed by TorchScript, ``torch.jit.Attribute`` should be used instead of ``torch.jit.annotate`` in the module's ``__init__`` method.
    - More details can be found in :meth:`~torch.jit.Attribute`.
- ``torch.jit.Final``
    - An alias for Python's ``typing.Final``. ``torch.jit.Final`` is kept only for backward compatibility reasons.

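The sketch below shows the three utilities together; ``word_lengths``, ``Vocab``,
``max_words`` and ``word_to_id`` are hypothetical names. ``torch.jit.annotate`` types an
empty list inside a scripted function, ``torch.jit.Attribute`` declares the type of an empty
dict created in ``__init__``, and a ``Final`` annotation (``typing.Final``, which
``torch.jit.Final`` aliases) marks ``max_words`` as a constant.

.. code-block:: python

    from typing import Dict, Final, List

    import torch


    @torch.jit.script
    def word_lengths(words: List[str]) -> List[int]:
        # Without the hint, [] would be typed as List[Tensor] and append(int) would fail
        lengths = torch.jit.annotate(List[int], [])
        for w in words:
            lengths.append(len(w))
        return lengths


    class Vocab(torch.nn.Module):
        # Final marks the attribute as a constant for the TorchScript compiler
        max_words: Final[int]

        def __init__(self, max_words: int = 1000) -> None:
            super().__init__()
            self.max_words = max_words
            # __init__ is not compiled, so the empty dict's type is declared here
            self.word_to_id = torch.jit.Attribute({}, Dict[str, int])

        def forward(self, word: str) -> int:
            if word not in self.word_to_id:
                if len(self.word_to_id) >= self.max_words:
                    return -1  # vocabulary is full
                self.word_to_id[word] = len(self.word_to_id)
            return self.word_to_id[word]

The sketch is meant to be used through ``torch.jit.script(Vocab(max_words=2))``; on the
unscripted module, ``word_to_id`` still holds the ``torch.jit.Attribute`` wrapper rather
than a dict.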

.. _torch_apis_in_torchscript_meta_programming:

Meta Programming
^^^^^^^^^^^^^^^^

TorchScript provides a set of utilities to facilitate meta programming:

- ``torch.jit.is_scripting()``
    - Returns a boolean value indicating whether the current program is compiled by ``torch.jit.script`` or not.
    - When used in an ``assert`` or an ``if`` statement, the scope or branch where ``torch.jit.is_scripting()`` evaluates to ``False`` is not compiled.
    - Because its value can be evaluated statically at compile time, it is commonly used in ``if`` statements to stop TorchScript from compiling one of the branches.
    - More details and examples can be found in :meth:`~torch.jit.is_scripting`.
- ``torch.jit.is_tracing()``
    - Returns a boolean value indicating whether the current program is traced by ``torch.jit.trace`` / ``torch.jit.trace_module`` or not.
    - More details can be found in :meth:`~torch.jit.is_tracing`.
- ``@torch.jit.ignore``
    - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function.
    - This allows you to leave code in your model that is not yet TorchScript compatible.
    - If a function decorated with ``@torch.jit.ignore`` is called from TorchScript, the call is dispatched to the Python interpreter.
    - Models with ignored functions cannot be exported.
    - More details and examples can be found in :meth:`~torch.jit.ignore`.
- ``@torch.jit.unused``
    - This decorator indicates to the compiler that a function or method should be ignored and replaced with code that raises an exception.
    - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model.
    - If a function decorated with ``@torch.jit.unused`` is called from TorchScript, a runtime error is raised.
    - More details and examples can be found in :meth:`~torch.jit.unused`.
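
A minimal sketch that combines ``torch.jit.is_scripting()``, ``@torch.jit.ignore`` and
``@torch.jit.unused`` in one module; ``Model``, ``peak``, ``legacy_path`` and
``_eager_double`` are hypothetical names. The ``else`` branch calls a helper that uses a
``lambda``, which TorchScript cannot compile. The helper is kept in a separate function
because the skipped branch is still parsed by the TorchScript frontend and must use syntax it
accepts. ``peak`` calls ``sorted()`` with a ``key`` argument, which is unsupported, so it is
left to Python with ``@torch.jit.ignore``.

.. code-block:: python

    import torch


    def _eager_double(x: torch.Tensor) -> torch.Tensor:
        # Lambdas are not supported by TorchScript, so this helper must never be compiled
        double = lambda t: t * 2
        return double(x)


    class Model(torch.nn.Module):
        @torch.jit.ignore
        def peak(self, x: torch.Tensor) -> float:
            # Runs in the Python interpreter even when called from scripted code.
            # Assumes a non-empty 1-D tensor.
            return abs(sorted(x.tolist(), key=abs)[-1])

        @torch.jit.unused
        def legacy_path(self, x: torch.Tensor) -> torch.Tensor:
            # When scripted, this body is replaced by a stub that raises an exception
            return x * 100

        def forward(self, x: torch.Tensor, use_legacy: bool = False) -> torch.Tensor:
            if use_legacy:
                return self.legacy_path(x)
            if torch.jit.is_scripting():
                y = x * 2  # the only branch compiled by torch.jit.script
            else:
                y = _eager_double(x)  # skipped when scripting
            return y / self.peak(x)

Calling the module produced by ``torch.jit.script(Model())`` with ``use_legacy=True`` raises
an exception at run time, and, because ``peak`` is ignored, that scripted module cannot be
exported.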

.. _torch_apis_in_torchscript_type_refinement:

Type Refinement
^^^^^^^^^^^^^^^

- ``torch.jit.isinstance()``
    - Returns a boolean indicating whether a variable is of the specified type. In the branch where the check succeeds, the compiler refines the variable to that type.
    - More details about its usage and examples can be found in :meth:`~torch.jit.isinstance`.

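
A minimal sketch of type refinement on an ``Any`` argument; ``total_size`` is a hypothetical
name. Inside each branch the compiler treats ``x`` as the checked container type, so
container-specific operations type-check there.

.. code-block:: python

    from typing import Any, Dict, List

    import torch


    @torch.jit.script
    def total_size(x: Any) -> int:
        if torch.jit.isinstance(x, List[str]):
            # x is refined to List[str]: sum the string lengths
            total = 0
            for s in x:
                total += len(s)
            return total
        elif torch.jit.isinstance(x, Dict[str, int]):
            # x is refined to Dict[str, int]: count the keys
            return len(x)
        return -1  # any other type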