torch.nested
============

.. automodule:: torch.nested

Introduction
++++++++++++

.. warning::

  The PyTorch API of nested tensors is in the prototype stage and will change in the near future.

NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.

The only constraint on the input Tensors is that their number of dimensions must match.

This enables more efficient metadata representations and access to purpose-built kernels.

One application of NestedTensors is to express sequential data in various domains.
While the conventional approach is to pad variable-length sequences, NestedTensor
enables users to bypass padding. The API for calling operations on a nested tensor is no different
from that of a regular ``torch.Tensor``, which should allow seamless integration with existing models,
with the main difference being :ref:`construction of the inputs <construction>`.

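As a small sketch of how padding can be bypassed (assuming a recent PyTorch build with the prototype ``torch.nested`` module), compare the conventional padded layout with a nested one:

```python
import torch

# Two variable-length sequences.
seqs = [torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)]

# Conventional approach: pad to the length of the longest sequence.
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)  # shape (2, 5)

# NestedTensor approach: pack the sequences without any padding.
nt = torch.nested.nested_tensor(seqs)

# When a dense tensor is needed, convert back to a padded layout.
dense = torch.nested.to_padded_tensor(nt, 0.0)
assert torch.equal(padded, dense)
```

The nested tensor stores only the original elements, so no compute is wasted on padding until a dense layout is explicitly requested.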
As this is a prototype feature, the :ref:`operations supported <supported operations>` are still
limited. However, we welcome issues, feature requests and contributions. More information on contributing can be found
`in this Readme <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md>`_.

.. _construction:

Construction
++++++++++++

Construction is straightforward and involves passing a list of Tensors to the ``torch.nested.nested_tensor``
constructor.

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b])
>>> nt
nested_tensor([
  tensor([0, 1, 2]),
  tensor([3, 4, 5, 6, 7])
])

Data type, device and whether gradients are required can be chosen via the usual keyword arguments.

>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda", requires_grad=True)
>>> nt
nested_tensor([
  tensor([0., 1., 2.], device='cuda:0', requires_grad=True),
  tensor([3., 4., 5., 6., 7.], device='cuda:0', requires_grad=True)
], device='cuda:0', requires_grad=True)

In the vein of ``torch.as_tensor``, ``torch.nested.as_nested_tensor`` can be used to preserve autograd
history from the tensors passed to the constructor. For more information, refer to the section on
:ref:`constructor functions`.

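A minimal sketch of this behavior (assuming the prototype ``torch.nested`` module): gradients computed through a nested tensor built with ``as_nested_tensor`` flow back to the original leaf tensors.

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(4, 3, requires_grad=True)

# as_nested_tensor preserves the autograd history of the inputs.
nt = torch.nested.as_nested_tensor([a, b])

# Backpropagate through a dense view of the nested tensor.
torch.nested.to_padded_tensor(nt, 0.0).sum().backward()
assert a.grad is not None and b.grad is not None
```

Had ``torch.nested.nested_tensor`` been used instead, the inputs would have been copied and ``a`` and ``b`` would receive no gradients.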
In order to form a valid NestedTensor, all the passed Tensors need to match in dimension, but none of the other attributes need to.

>>> a = torch.randn(3, 50, 70) # image 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.dim()
4

If one of the dimensions doesn't match, the constructor throws an error.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.

Note that the passed Tensors are copied into a contiguous piece of memory. The resulting
NestedTensor allocates new memory to store them and does not keep a reference.

At the moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future
we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors.
Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor
has a well-defined dimension. If you have a need for this feature, please feel encouraged to open a feature request so that
we can track it and plan accordingly.

size
++++

Even though a NestedTensor does not support ``.size()`` (or ``.shape``), it supports ``.size(i)`` if dimension ``i`` is regular.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128

If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular ``torch.Tensor``.

>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True

In the future we might make it easier to detect this condition and convert seamlessly. Please open a
feature request if you have a need for this (or any other related feature for that matter).

unbind
++++++

``unbind`` allows you to retrieve a view of the constituents.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt
nested_tensor([
  tensor([[ 1.2286, -1.2343, -1.4842],
          [-0.7827,  0.6745,  0.0658]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])
>>> nt.unbind()
(tensor([[ 1.2286, -1.2343, -1.4842],
        [-0.7827,  0.6745,  0.0658]]), tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
        [-0.2871, -0.2980,  0.5559,  1.9885],
        [ 0.4074,  2.4855,  0.0733,  0.8285]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
        [-2.3481,  2.0236,  0.1975]])
>>> nt
nested_tensor([
  tensor([[ 3.6858, -3.7030, -4.4525],
          [-2.3481,  2.0236,  0.1975]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])

Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor.

.. _constructor functions:

Nested tensor constructor and conversion functions
++++++++++++++++++++++++++++++++++++++++++++++++++

The following functions are related to nested tensors:

.. currentmodule:: torch.nested

.. autofunction:: nested_tensor
.. autofunction:: as_nested_tensor
.. autofunction:: to_padded_tensor

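For instance, ``to_padded_tensor`` materializes the nested entries into one dense tensor, filling the shorter entries with the given padding value up to the largest entry size (a sketch, assuming the prototype API):

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)]
)
# Entries shorter than the longest one are filled with the padding value.
padded = torch.nested.to_padded_tensor(nt, -1.0)
assert padded.shape == (2, 5)
assert padded[0].tolist() == [0.0, 1.0, 2.0, -1.0, -1.0]
```
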
.. _supported operations:

Supported operations
++++++++++++++++++++

In this section, we summarize the operations that are currently supported on
NestedTensor and any constraints they have.

.. csv-table::
   :header: "PyTorch operation",  "Constraints"
   :widths: 30, 55
   :delim: ;

   :func:`torch.matmul`;  "Supports matrix multiplication between two (>= 3d) nested tensors where
   the last two dimensions are matrix dimensions and the leading (batch) dimensions have the same size
   (i.e. no broadcasting support for batch dimensions yet)."
   :func:`torch.bmm`; "Supports batch matrix multiplication of two 3-d nested tensors."
   :func:`torch.nn.Linear`;  "Supports 3-d nested input and a dense 2-d weight matrix."
   :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except ``dim=0``."
   :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
   :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors."
   :func:`torch.relu`; "Behavior is the same as on regular tensors."
   :func:`torch.gelu`; "Behavior is the same as on regular tensors."
   :func:`torch.silu`; "Behavior is the same as on regular tensors."
   :func:`torch.abs`; "Behavior is the same as on regular tensors."
   :func:`torch.sgn`; "Behavior is the same as on regular tensors."
   :func:`torch.logical_not`; "Behavior is the same as on regular tensors."
   :func:`torch.neg`; "Behavior is the same as on regular tensors."
   :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors."
   :func:`torch.add`; "Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor."
   :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. Supports multiplication of a nested tensor by a scalar."
   :func:`torch.select`; "Supports selecting along all dimensions."
   :func:`torch.clone`; "Behavior is the same as on regular tensors."
   :func:`torch.detach`; "Behavior is the same as on regular tensors."
   :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only."
   :func:`torch.reshape`; "Supports reshaping with the size of ``dim=0`` preserved (i.e. the number of nested tensors cannot be changed).
   Unlike regular tensors, a size of ``-1`` here means that the existing size is inherited.
   In particular, the only valid size for an irregular dimension is ``-1``.
   Size inference is not implemented yet and hence for new dimensions the size cannot be ``-1``."
   :func:`torch.Tensor.reshape_as`; "Similar constraint as for ``reshape``."
   :func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``."
   :func:`torch.Tensor.view`; "Rules for the new shape are similar to those of ``reshape``."
   :func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input."
   :func:`torch.randn_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with values randomly initialized according to a standard normal distribution matching the nested structure of the input."
   :func:`torch.zeros_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with all zero values matching the nested structure of the input."
   :func:`torch.nn.LayerNorm`; "The ``normalized_shape`` argument is restricted to not extend into the irregular dimensions of the NestedTensor."

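As a small sketch combining a few of the operations above (assuming the prototype API), a nested batch of variable-length sequences can be passed through a dense linear layer followed by a softmax:

```python
import torch

# Two sequences of different lengths, each with 8 features.
nt = torch.nested.nested_tensor(
    [torch.randn(5, 8), torch.randn(3, 8)], dtype=torch.float32
)

linear = torch.nn.Linear(8, 4)                    # dense 2-d weight, 3-d nested input
out = linear(nt)
probs = torch.nn.functional.softmax(out, dim=2)   # softmax along a regular, non-batch dim
first, second = probs.unbind()                    # unbind along dim=0
assert first.shape == (5, 4) and second.shape == (3, 4)
```

Each constituent keeps its own length throughout, so no padding positions participate in the linear layer or the softmax normalization.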