.. _cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to specify if-else-like control flow
and, logically, can be seen as being implemented as follows.
8
9.. code-block:: python
10
11    def cond(
12        pred: Union[bool, torch.Tensor],
13        true_fn: Callable,
14        false_fn: Callable,
15        operands: Tuple[torch.Tensor]
16    ):
17        if pred:
18            return true_fn(*operands)
19        else:
20            return false_fn(*operands)
21
Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture changes based on
the **value** or **shape** of inputs or of intermediate outputs of tensor operations.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    does not currently support training. A more stable implementation is planned for a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses `torch.cond` to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_fn(x: torch.Tensor):
                return x.cos() + x.sin()

            def false_fn(x: torch.Tensor):
                return x.sin()

            # The predicate depends on the (possibly dynamic) shape of x.
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can eagerly run the model and expect the results to vary based on input shape:

.. code-block:: python

    inp = torch.randn(3)   # shape[0] == 3, so the predicate is False
    inp2 = torch.randn(5)  # shape[0] == 5, so the predicate is True
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

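The same module can also be compiled with `torch.compile`. The following is a minimal sketch; `torch.cond`
support under `torch.compile` is subject to the prototype limitations noted above and may vary across
PyTorch versions:

.. code-block:: python

    # Compile the module and check that it matches the eager behavior above.
    compiled_mod = torch.compile(dyn_shape_mod)
    assert torch.equal(compiled_mod(inp), false_fn(inp))
    assert torch.equal(compiled_mod(inp2), true_fn(inp2))
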

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.

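If you want to inspect the branch sub-graphs programmatically, something along these lines should work. This
is a sketch: the `true_graph_0` and `false_graph_0` attribute names match the printed graph above but are an
implementation detail of the export:

.. code-block:: python

    gm = ep.graph_module           # top-level torch.fx.GraphModule
    print(gm.true_graph_0.graph)   # sub-graph captured from the true branch
    print(gm.false_graph_0.graph)  # sub-graph captured from the false branch
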

Here is another example that showcases how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a data-dependent predicate.
        """
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The predicate depends on the values of x, not just its shape.
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

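It can be exported in the same way as the previous example. The following sketch simply reuses the dynamic
batch dimension from above:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DataDependentCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)
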

The exported program we get after export:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin



Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:
    - The dynamism of the predicate is preserved (e.g. the symbolic `gt` shown in the example above).
    - If the predicate in the user program is a constant (e.g. a Python bool), the `pred` of the operator will be a constant.

- For the branches:
    - The input and output signatures are flattened tuples.
    - Each branch is a `torch.fx.GraphModule`.
    - Closures in the original branch functions become explicit inputs, so the captured branches have no closures (see the sketch after this list).
    - No mutations of inputs or globals are allowed.

- For the operands:
    - They are also passed as a flat tuple.

- Nesting of `torch.cond` in the user program becomes nested graph modules.

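As an illustration of the closure invariant, here is a sketch of a module whose branches close over an
intermediate tensor (the module and variable names are made up for this example):

.. code-block:: python

    class CondWithClosure(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            scale = x.sum()  # captured by both branches as a closure

            def true_fn(x: torch.Tensor):
                return x * scale

            def false_fn(x: torch.Tensor):
                return x / scale

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

After export, `scale` should be lifted out of the closures and appear both as an extra operand of
`torch.ops.higher_order.cond` and as an explicit input of the two branch sub-graphs.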

API Reference
-------------
.. autofunction:: torch._higher_order_ops.cond.cond