torch.nested
============

.. automodule:: torch.nested

Introduction
++++++++++++

.. warning::

  The PyTorch API of nested tensors is in prototype stage and will change in the near future.

NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.

The only constraint on the input Tensors is that their dimensions must match.

This enables more efficient metadata representations and access to purpose-built kernels.

One application of NestedTensors is to express sequential data in various domains.
While the conventional approach is to pad variable length sequences, NestedTensor
enables users to bypass padding. The API for calling operations on a nested tensor is no different
from that of a regular ``torch.Tensor``, which should allow seamless integration with existing models,
with the main difference being :ref:`construction of the inputs <construction>`.

As this is a prototype feature, the :ref:`operations supported <supported operations>` are still
limited. However, we welcome issues, feature requests and contributions. More information on contributing can be found
`in this Readme <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md>`_.

.. _construction:

Construction
++++++++++++

Construction is straightforward and involves passing a list of Tensors to the ``torch.nested.nested_tensor``
constructor.

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b])
>>> nt
nested_tensor([
  tensor([0, 1, 2]),
  tensor([3, 4, 5, 6, 7])
])

Data type, device and whether gradients are required can be chosen via the usual keyword arguments.
50 51>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda", requires_grad=True) 52>>> nt 53nested_tensor([ 54 tensor([0., 1., 2.], device='cuda:0', requires_grad=True), 55 tensor([3., 4., 5., 6., 7.], device='cuda:0', requires_grad=True) 56], device='cuda:0', requires_grad=True) 57 58In the vein of ``torch.as_tensor``, ``torch.nested.as_nested_tensor`` can be used to preserve autograd 59history from the tensors passed to the constructor. For more information, refer to the section on 60:ref:`constructor functions`. 61 62In order to form a valid NestedTensor all the passed Tensors need to match in dimension, but none of the other attributes need to. 63 64>>> a = torch.randn(3, 50, 70) # image 1 65>>> b = torch.randn(3, 128, 64) # image 2 66>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32) 67>>> nt.dim() 684 69 70If one of the dimensions doesn't match, the constructor throws an error. 71 72>>> a = torch.randn(50, 128) # text 1 73>>> b = torch.randn(3, 128, 64) # image 2 74>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32) 75Traceback (most recent call last): 76 File "<stdin>", line 1, in <module> 77RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0. 78 79Note that the passed Tensors are being copied into a contiguous piece of memory. The resulting 80NestedTensor allocates new memory to store them and does not keep a reference. 81 82At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future 83we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors. 84Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor 85has a well defined dimension. 
If you have a need for this feature, please feel encouraged to open a feature request so that
we can track it and plan accordingly.

size
+++++++++++++++++++++++++

Even though a NestedTensor does not support ``.size()`` (or ``.shape``), it supports ``.size(i)`` if dimension ``i`` is regular.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128

If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular ``torch.Tensor``.

>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True

In the future we might make it easier to detect this condition and convert seamlessly.

Please open a feature request if you have a need for this (or any other related feature for that matter).

unbind
+++++++++++++++++++++++++

``unbind`` allows you to retrieve a view of the constituents.
130 131>>> import torch 132>>> a = torch.randn(2, 3) 133>>> b = torch.randn(3, 4) 134>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32) 135>>> nt 136nested_tensor([ 137 tensor([[ 1.2286, -1.2343, -1.4842], 138 [-0.7827, 0.6745, 0.0658]]), 139 tensor([[-1.1247, -0.4078, -1.0633, 0.8083], 140 [-0.2871, -0.2980, 0.5559, 1.9885], 141 [ 0.4074, 2.4855, 0.0733, 0.8285]]) 142]) 143>>> nt.unbind() 144(tensor([[ 1.2286, -1.2343, -1.4842], 145 [-0.7827, 0.6745, 0.0658]]), tensor([[-1.1247, -0.4078, -1.0633, 0.8083], 146 [-0.2871, -0.2980, 0.5559, 1.9885], 147 [ 0.4074, 2.4855, 0.0733, 0.8285]])) 148>>> nt.unbind()[0] is not a 149True 150>>> nt.unbind()[0].mul_(3) 151tensor([[ 3.6858, -3.7030, -4.4525], 152 [-2.3481, 2.0236, 0.1975]]) 153>>> nt 154nested_tensor([ 155 tensor([[ 3.6858, -3.7030, -4.4525], 156 [-2.3481, 2.0236, 0.1975]]), 157 tensor([[-1.1247, -0.4078, -1.0633, 0.8083], 158 [-0.2871, -0.2980, 0.5559, 1.9885], 159 [ 0.4074, 2.4855, 0.0733, 0.8285]]) 160]) 161 162Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor. 163 164.. _constructor functions: 165 166Nested tensor constructor and conversion functions 167++++++++++++++++++++++++++++++++++++++++++++++++++ 168 169The following functions are related to nested tensors: 170 171.. currentmodule:: torch.nested 172 173.. autofunction:: nested_tensor 174.. autofunction:: as_nested_tensor 175.. autofunction:: to_padded_tensor 176 177.. _supported operations: 178 179Supported operations 180++++++++++++++++++++++++++ 181 182In this section, we summarize the operations that are currently supported on 183NestedTensor and any constraints they have. 184 185.. 
csv-table:: 186 :header: "PyTorch operation", "Constraints" 187 :widths: 30, 55 188 :delim: ; 189 190 :func:`torch.matmul`; "Supports matrix multiplication between two (>= 3d) nested tensors where 191 the last two dimensions are matrix dimensions and the leading (batch) dimensions have the same size 192 (i.e. no broadcasting support for batch dimensions yet)." 193 :func:`torch.bmm`; "Supports batch matrix multiplication of two 3-d nested tensors." 194 :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix." 195 :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except dim=0." 196 :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors." 197 :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors." 198 :func:`torch.relu`; "Behavior is the same as on regular tensors." 199 :func:`torch.gelu`; "Behavior is the same as on regular tensors." 200 :func:`torch.silu`; "Behavior is the same as on regular tensors." 201 :func:`torch.abs`; "Behavior is the same as on regular tensors." 202 :func:`torch.sgn`; "Behavior is the same as on regular tensors." 203 :func:`torch.logical_not`; "Behavior is the same as on regular tensors." 204 :func:`torch.neg`; "Behavior is the same as on regular tensors." 205 :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors." 206 :func:`torch.add`; "Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor." 207 :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. Supports multiplication of a nested tensor by a scalar." 208 :func:`torch.select`; "Supports selecting along all dimensions." 209 :func:`torch.clone`; "Behavior is the same as on regular tensors." 210 :func:`torch.detach`; "Behavior is the same as on regular tensors." 211 :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only." 212 :func:`torch.reshape`; "Supports reshaping with size of ``dim=0`` preserved (i.e. 
number of tensors nested cannot be changed). 213 Unlike regular tensors, a size of ``-1`` here means that the existing size is inherited. 214 In particular, the only valid size for a irregular dimension is ``-1``. 215 Size inference is not implemented yet and hence for new dimensions the size cannot be ``-1``." 216 :func:`torch.Tensor.reshape_as`; "Similar constraint as for ``reshape``." 217 :func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``." 218 :func:`torch.Tensor.view`; "Rules for the new shape are similar to that of ``reshape``." 219 :func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input." 220 :func:`torch.randn_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with values randomly initialized according to a standard normal distribution matching the nested structure of the input." 221 :func:`torch.zeros_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with all zero values matching the nested structure of the input." 222 :func:`torch.nn.LayerNorm`; "The ``normalized_shape`` argument is restricted to not extend into the irregular dimensions of the NestedTensor." 223
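As a concrete illustration of the conversion functions listed above, the following minimal sketch (assuming a PyTorch build that includes the prototype nested tensor API) packs two variable-length sequences into a nested tensor and converts it back into a regular padded tensor with ``torch.nested.to_padded_tensor``:

```python
import torch

# Two variable-length sequences packed into one nested tensor.
a = torch.arange(3, dtype=torch.float32)  # length 3
b = torch.arange(5, dtype=torch.float32)  # length 5
nt = torch.nested.nested_tensor([a, b])

# Pad every constituent to the length of the longest one, filling the
# gaps with the given padding value (0.0 here). The result is a regular
# dense tensor of shape (2, 5) that can be fed to any PyTorch operation.
padded = torch.nested.to_padded_tensor(nt, 0.0)
```

This is the inverse of the padding-avoidance motivation from the introduction: it is useful at the boundary between nested-tensor-aware code and code that still expects dense, padded inputs.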
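The constraints in the table can be exercised directly. As a small sketch (again assuming the prototype nested tensor API), applying softmax along the last, regular dimension of a nested tensor behaves just like the dense case, producing rows that sum to one in each constituent:

```python
import torch
import torch.nn.functional as F

# A nested tensor with an irregular dim 1 and a regular dim 2 of size 4.
a = torch.randn(2, 4)
b = torch.randn(3, 4)
nt = torch.nested.nested_tensor([a, b])

# Per the table above, softmax is supported along every dim except dim=0.
out = F.softmax(nt, dim=2)

# Each row of each constituent sums to 1, as in the regular dense case.
row_sums = out.unbind()[0].sum(dim=-1)
```

Passing ``dim=0`` instead would raise an error, since that dimension indexes the constituents themselves.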