.. _cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to express if-else-like control flow
and can logically be thought of as implemented as follows.

.. code-block:: python

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor]
    ):
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

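For instance, assuming the eager semantics sketched above, a single call could look like this (the toy branch functions here are only illustrative):

.. code-block:: python

    import torch

    def add_one(x: torch.Tensor):
        return x + 1

    def sub_one(x: torch.Tensor):
        return x - 1

    x = torch.randn(3)
    # The predicate may be a Python bool or a boolean scalar tensor computed from data.
    out = torch.cond(x.sum() > 0, add_one, sub_one, (x,))
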
Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`), which preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models whose architecture changes based on
the **value** or **shape** of inputs or intermediate outputs of tensor operations.

.. warning::

    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses cond to branch based on input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Branch functions can also be defined locally inside forward.
            def true_fn(x: torch.Tensor):
                return x.cos() + x.sin()

            def false_fn(x: torch.Tensor):
                return x.sin()

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

We can eagerly run the model and expect the results to vary based on input shape:

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`: its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.
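One way to poke at the result (a sketch that assumes the attribute names shown in the printout above, which may differ across PyTorch versions) is to run the exported module on inputs with different batch sizes and to look up the branch sub-graphs on the top-level graph module:

.. code-block:: python

    # Run the exported program; the branch taken depends on the runtime batch size.
    m = ep.module()
    small = torch.randn(3, 3)   # shape[0] <= 4 -> false branch (sin)
    large = torch.randn(8, 3)   # shape[0] > 4  -> true branch (cos + sin)
    assert torch.allclose(m(small), small.sin())
    assert torch.allclose(m(large), large.cos() + large.sin())

    # Inspect the sub-graphs that the true and false functions were lowered into.
    print(ep.graph_module.true_graph_0.graph)
    print(ep.graph_module.false_graph_0.graph)
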
Here is another example that showcases how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on data dependent predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

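One way to produce such a program is the same export call as in the previous example (a sketch that reuses the dynamic batch dimension from above):

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DataDependentCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)
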
After export, we get the following program:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:

  - The dynamism of the predicate is preserved (e.g. `gt` shown in the above example).
  - If the predicate in the user program is constant (e.g. a Python bool constant), the `pred` of the operator will be a constant.

- For the branches:

  - The input and output signatures are flattened tuples.
  - They are `torch.fx.GraphModule` instances.
  - Closures in the original functions become explicit inputs; the branch graphs have no closures.
  - No mutations on inputs or globals are allowed.

- For the operands:

  - They are also a flat tuple.

- Nesting of `torch.cond` in the user program becomes nested graph modules (see the sketch after this list).
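
As a sketch of the last point, nesting `torch.cond` inside a branch function could look as follows; the exact printed structure is version-dependent, but each nested branch is expected to show up as its own sub-graph module (all names below are illustrative):

.. code-block:: python

    import torch

    def inner_true(x):
        return x * 2

    def inner_false(x):
        return x / 2

    def outer_true(x):
        # A nested torch.cond inside a branch function.
        return torch.cond(x.mean() > 0, inner_true, inner_false, (x,))

    def outer_false(x):
        return x.neg()

    class NestedCond(torch.nn.Module):
        def forward(self, x):
            return torch.cond(x.sum() > 0, outer_true, outer_false, (x,))

    ep = torch.export.export(NestedCond(), (torch.randn(4, 3),))
    print(ep)
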
API Reference
-------------

.. autofunction:: torch._higher_order_ops.cond.cond