.. _cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to specify if-else-like control flow
and can logically be seen as implemented as follows.

.. code-block:: python

    from typing import Callable, Tuple, Union

    import torch

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor]
    ):
        # Conceptual, eager-mode semantics of torch.cond.
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture changes based on
the **value** or **shape** of inputs or of intermediate outputs of tensor operations.

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't currently support training. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses cond to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_fn(x: torch.Tensor):
                return x.cos() + x.sin()

            def false_fn(x: torch.Tensor):
                return x.sin()

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can run the model eagerly and expect the results to vary based on the input shape:

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`: its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.
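Because the conditional is preserved in the exported program, we can also run it and check that the branch taken
depends on the runtime batch size. Below is a minimal sketch (the names `gm`, `small`, and `large` and the chosen
batch sizes are illustrative, and it assumes the `ep` produced above) that calls the exported program through
`ExportedProgram.module()`:

.. code-block:: python

    # A minimal sketch: ep.module() returns a runnable module that still contains
    # the conditional, so inputs with different batch sizes take different branches.
    gm = ep.module()

    small = torch.randn(3, 3)  # batch size 3 is not > 4 -> false branch
    large = torch.randn(6, 3)  # batch size 6 is > 4 -> true branch

    assert torch.equal(gm(small), small.sin())
    assert torch.equal(gm(large), large.cos() + large.sin())

The predicate `s0 > 4`, which appears symbolically in the graph, is evaluated against the actual batch size at run time.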
Here is another example that showcases how to express data-dependent control flow, reusing the module-level
`true_fn` and `false_fn` defined in the previous example:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on a data-dependent predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

After export, we get the following exported program:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin


Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants of `torch.ops.higher_order.cond`:

- For the predicate:

  - The dynamism of the predicate is preserved (e.g. `gt` in the examples above).
  - If the predicate in the user program is a constant (e.g. a Python bool), the `pred` of the operator will be a constant.

- For the branches:

  - The input and output signatures are flattened tuples.
  - The branches are `torch.fx.GraphModule` instances.
  - Closures in the original functions become explicit inputs; the branch graphs have no closures.
  - No mutation of inputs or globals is allowed.

- For the operands:

  - The operands are also a flattened tuple.

- Nested `torch.cond` calls in the user program become nested graph modules.


API Reference
-------------
.. autofunction:: torch._higher_order_ops.cond.cond