.. currentmodule:: torch

.. _tensor-attributes-doc:

Tensor Attributes
=================

Each ``torch.Tensor`` has a :class:`torch.dtype`, :class:`torch.device`, and :class:`torch.layout`.

.. _dtype-doc:

torch.dtype
-----------

.. class:: dtype

A :class:`torch.dtype` is an object that represents the data type of a
:class:`torch.Tensor`. PyTorch has twelve different data types:
========================== ===========================================   ===========================
Data type                  dtype                                         Legacy Constructors
========================== ===========================================   ===========================
32-bit floating point      ``torch.float32`` or ``torch.float``          ``torch.*.FloatTensor``
64-bit floating point      ``torch.float64`` or ``torch.double``         ``torch.*.DoubleTensor``
64-bit complex             ``torch.complex64`` or ``torch.cfloat``
128-bit complex            ``torch.complex128`` or ``torch.cdouble``
16-bit floating point [1]_ ``torch.float16`` or ``torch.half``           ``torch.*.HalfTensor``
16-bit floating point [2]_ ``torch.bfloat16``                            ``torch.*.BFloat16Tensor``
8-bit integer (unsigned)   ``torch.uint8``                               ``torch.*.ByteTensor``
8-bit integer (signed)     ``torch.int8``                                ``torch.*.CharTensor``
16-bit integer (signed)    ``torch.int16`` or ``torch.short``            ``torch.*.ShortTensor``
32-bit integer (signed)    ``torch.int32`` or ``torch.int``              ``torch.*.IntTensor``
64-bit integer (signed)    ``torch.int64`` or ``torch.long``             ``torch.*.LongTensor``
Boolean                    ``torch.bool``                                ``torch.*.BoolTensor``
========================== ===========================================   ===========================

.. [1] Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10
  significand bits. Useful when precision is important.

.. [2] Sometimes referred to as Brain Floating Point: uses 1 sign, 8 exponent, and 7
  significand bits. Useful when range is important, since it has the same
  number of exponent bits as ``float32``.

To determine whether a :class:`torch.dtype` is a floating point data type, use the property
:attr:`is_floating_point`, which returns ``True`` if the data type is a floating point data type.

To determine whether a :class:`torch.dtype` is a complex data type, use the property
:attr:`is_complex`, which returns ``True`` if the data type is a complex data type.

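For instance, both properties can be queried directly on the dtype objects (a small illustrative check):

```python
import torch

# is_floating_point / is_complex are properties of torch.dtype objects.
assert torch.float16.is_floating_point
assert not torch.int32.is_floating_point
assert torch.complex128.is_complex
assert not torch.float64.is_complex
```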
.. _type-promotion-doc:

When the dtypes of inputs to an arithmetic operation (`add`, `sub`, `div`, `mul`) differ, we promote
by finding the minimum dtype that satisfies the following rules:

* If the type of a scalar operand is of a higher category than tensor operands
  (where complex > floating > integral > boolean), we promote to a type with sufficient size to hold
  all scalar operands of that category.
* If a zero-dimension tensor operand has a higher category than dimensioned operands,
  we promote to a type with sufficient size and category to hold all zero-dim tensor operands of
  that category.
* If there are no higher-category zero-dim operands, we promote to a type with sufficient size
  and category to hold all dimensioned operands.

A floating point scalar operand has dtype `torch.get_default_dtype()` and an integral
non-boolean scalar operand has dtype `torch.int64`. Unlike NumPy, we do not inspect
values when determining the minimum `dtype` of an operand. Quantized and complex types
are not yet supported.

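The category ordering used by these rules can be sketched in plain Python (a hypothetical helper for illustration, not a torch API; categories are represented as strings here):

```python
# Sketch of the category ordering used by type promotion; not a torch API.
# Higher number == higher category.
CATEGORY = {"bool": 0, "int": 1, "float": 2, "complex": 3}

def higher_category(a: str, b: str) -> str:
    """Return the higher of two dtype categories (complex > floating > integral > boolean)."""
    return a if CATEGORY[a] >= CATEGORY[b] else b

# A float operand outranks an integer one, so the result is floating point;
# a bool operand is outranked by everything else.
assert higher_category("float", "int") == "float"
assert higher_category("int", "bool") == "int"
assert higher_category("complex", "float") == "complex"
```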
Promotion Examples::

    >>> float_tensor = torch.ones(1, dtype=torch.float)
    >>> double_tensor = torch.ones(1, dtype=torch.double)
    >>> complex_float_tensor = torch.ones(1, dtype=torch.complex64)
    >>> complex_double_tensor = torch.ones(1, dtype=torch.complex128)
    >>> int_tensor = torch.ones(1, dtype=torch.int)
    >>> long_tensor = torch.ones(1, dtype=torch.long)
    >>> uint_tensor = torch.ones(1, dtype=torch.uint8)
    >>> bool_tensor = torch.ones(1, dtype=torch.bool)
    # zero-dim tensors
    >>> long_zerodim = torch.tensor(1, dtype=torch.long)
    >>> int_zerodim = torch.tensor(1, dtype=torch.int)

    >>> torch.add(5, 5).dtype
    torch.int64
    # 5 is an int64, but does not have higher category than int_tensor so is not considered.
    >>> (int_tensor + 5).dtype
    torch.int32
    >>> (int_tensor + long_zerodim).dtype
    torch.int32
    >>> (long_tensor + int_tensor).dtype
    torch.int64
    >>> (bool_tensor + long_tensor).dtype
    torch.int64
    >>> (bool_tensor + uint_tensor).dtype
    torch.uint8
    >>> (float_tensor + double_tensor).dtype
    torch.float64
    >>> (complex_float_tensor + complex_double_tensor).dtype
    torch.complex128
    >>> (bool_tensor + int_tensor).dtype
    torch.int32
    # Since long is a different kind than float, result dtype only needs to be large enough
    # to hold the float.
    >>> torch.add(long_tensor, float_tensor).dtype
    torch.float32

When the output tensor of an arithmetic operation is specified, we allow casting to its `dtype` except that:
  * An integral output tensor cannot accept a floating point tensor.
  * A boolean output tensor cannot accept a non-boolean tensor.
  * A non-complex output tensor cannot accept a complex tensor.

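These rules are also exposed programmatically through :func:`torch.can_cast`, which reports whether a value of one dtype may be written into an output of another dtype:

```python
import torch

# torch.can_cast(from_dtype, to_dtype) checks the output-casting rules above.
assert torch.can_cast(torch.int32, torch.float32)          # integral -> floating output: allowed
assert not torch.can_cast(torch.float32, torch.int32)      # floating -> integral output: disallowed
assert not torch.can_cast(torch.int32, torch.bool)         # non-boolean -> boolean output: disallowed
assert not torch.can_cast(torch.complex64, torch.float32)  # complex -> non-complex output: disallowed
```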
Casting Examples::

    # allowed:
    >>> float_tensor *= float_tensor
    >>> float_tensor *= int_tensor
    >>> float_tensor *= uint_tensor
    >>> float_tensor *= bool_tensor
    >>> float_tensor *= double_tensor
    >>> int_tensor *= long_tensor
    >>> int_tensor *= uint_tensor
    >>> uint_tensor *= int_tensor

    # disallowed (RuntimeError: result type can't be cast to the desired output type):
    >>> int_tensor *= float_tensor
    >>> bool_tensor *= int_tensor
    >>> bool_tensor *= uint_tensor
    >>> float_tensor *= complex_float_tensor


.. _device-doc:

torch.device
------------

.. class:: device

A :class:`torch.device` is an object representing the device on which a :class:`torch.Tensor` is
or will be allocated.

The :class:`torch.device` contains a device type (most commonly "cpu" or
"cuda", but also potentially :doc:`"mps" <mps>`, :doc:`"xpu" <xpu>`,
`"xla" <https://github.com/pytorch/xla/>`_ or :doc:`"meta" <meta>`) and an optional
device ordinal for the device type. If the device ordinal is not present, this object will always represent
the current device for the device type, even after :func:`torch.cuda.set_device()` is called; e.g.,
a :class:`torch.Tensor` constructed with device ``'cuda'`` is equivalent to ``'cuda:X'``, where X is
the result of :func:`torch.cuda.current_device()`.

A :class:`torch.Tensor`'s device can be accessed via the :attr:`Tensor.device` property.

A :class:`torch.device` can be constructed via a string, or via a string and a device ordinal.

Via a string:
::

    >>> torch.device('cuda:0')
    device(type='cuda', index=0)

    >>> torch.device('cpu')
    device(type='cpu')

    >>> torch.device('mps')
    device(type='mps')

    >>> torch.device('cuda')  # current cuda device
    device(type='cuda')

Via a string and device ordinal:

::

    >>> torch.device('cuda', 0)
    device(type='cuda', index=0)

    >>> torch.device('mps', 0)
    device(type='mps', index=0)

    >>> torch.device('cpu', 0)
    device(type='cpu', index=0)

The device object can also be used as a context manager to change the default
device tensors are allocated on:

::

    >>> with torch.device('cuda:1'):
    ...     r = torch.randn(2, 3)
    >>> r.device
    device(type='cuda', index=1)

This context manager has no effect if a factory function is passed an explicit,
non-None device argument. To globally change the default device, see also
:func:`torch.set_default_device`.

.. warning::

    This context manager imposes a slight performance cost on every Python
    call to the torch API (not just factory functions). If this
    is causing problems for you, please comment on
    https://github.com/pytorch/pytorch/issues/92701

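A minimal sketch of the global alternative, :func:`torch.set_default_device` (the ``'meta'`` device is used here only so the sketch runs on machines without a GPU):

```python
import torch

# Globally route factory-function allocations to a device; "meta" tensors
# carry shape/dtype but no data, so this works without any accelerator.
torch.set_default_device('meta')
t = torch.empty(2, 3)
assert t.device.type == 'meta'

# Point the default back at the CPU.
torch.set_default_device('cpu')
```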
.. note::
   The :class:`torch.device` argument in functions can generally be substituted with a string.
   This allows for fast prototyping of code.

   >>> # Example of a function that takes in a torch.device
   >>> cuda1 = torch.device('cuda:1')
   >>> torch.randn((2,3), device=cuda1)

   >>> # You can substitute the torch.device with a string
   >>> torch.randn((2,3), device='cuda:1')

.. note::
   For legacy reasons, a device can be constructed via a single device ordinal, which is treated
   as the current :ref:`accelerator<accelerators>` type.
   This matches :meth:`Tensor.get_device`, which returns an ordinal for device
   tensors and is not supported for CPU tensors.

   >>> torch.device(1)
   device(type='cuda', index=1)

.. note::
   Methods which take a device will generally accept a (properly formatted) string
   or a (legacy) integer device ordinal, i.e. the following are all equivalent:

   >>> torch.randn((2,3), device=torch.device('cuda:1'))
   >>> torch.randn((2,3), device='cuda:1')
   >>> torch.randn((2,3), device=1)  # legacy

.. note::
   Tensors are never moved automatically between devices; moving a tensor requires an explicit call
   from the user. Scalar Tensors (with ``tensor.dim() == 0``) are the only exception to this rule:
   they are automatically transferred from CPU to GPU when needed, as this operation can be done
   "for free". Example:

   >>> # two scalars
   >>> torch.ones(()) + torch.ones(()).cuda()  # OK, scalar auto-transferred from CPU to GPU
   >>> torch.ones(()).cuda() + torch.ones(())  # OK, scalar auto-transferred from CPU to GPU

   >>> # one scalar (CPU), one vector (GPU)
   >>> torch.ones(()) + torch.ones(1).cuda()  # OK, scalar auto-transferred from CPU to GPU
   >>> torch.ones(1).cuda() + torch.ones(())  # OK, scalar auto-transferred from CPU to GPU

   >>> # one scalar (GPU), one vector (CPU)
   >>> torch.ones(()).cuda() + torch.ones(1)  # Fail, scalar not auto-transferred from GPU to CPU and non-scalar not auto-transferred from CPU to GPU
   >>> torch.ones(1) + torch.ones(()).cuda()  # Fail, scalar not auto-transferred from GPU to CPU and non-scalar not auto-transferred from CPU to GPU


.. _layout-doc:

torch.layout
------------

.. class:: layout

.. warning::
  The ``torch.layout`` class is in beta and subject to change.

A :class:`torch.layout` is an object that represents the memory layout of a
:class:`torch.Tensor`. Currently, we support ``torch.strided`` (dense Tensors)
and have beta support for ``torch.sparse_coo`` (sparse COO Tensors).

``torch.strided`` represents dense Tensors and is the memory layout that
is most commonly used. Each strided tensor has an associated
:class:`torch.Storage`, which holds its data. These tensors provide a
multi-dimensional, `strided <https://en.wikipedia.org/wiki/Stride_of_an_array>`_
view of a storage. Strides are a list of integers: the k-th stride
represents the jump in memory necessary to go from one element to the
next one in the k-th dimension of the Tensor. This concept makes it possible
to perform many tensor operations efficiently.


Example::

    >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
    >>> x.stride()
    (5, 1)

    >>> x.t().stride()
    (1, 5)

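The addressing rule behind these stride tuples can be sketched in plain Python (an illustrative helper, not a torch API):

```python
# Sketch of stride-based addressing: the flat storage offset of element
# x[i0, i1, ...] is the dot product of the index with the stride tuple.
def storage_offset(index, strides):
    return sum(i * s for i, s in zip(index, strides))

# For the 2x5 tensor above, x[1, 2] lives at offset 1*5 + 2*1 = 7 in storage;
# the transposed view reaches the same element via strides (1, 5).
assert storage_offset((1, 2), (5, 1)) == 7
assert storage_offset((2, 1), (1, 5)) == 7
```

This is why ``x.t()`` needs no data copy: only the stride tuple changes, not the underlying storage.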
For more information on ``torch.sparse_coo`` tensors, see :ref:`sparse-docs`.

torch.memory_format
-------------------

.. class:: memory_format

A :class:`torch.memory_format` is an object representing the memory format on which a :class:`torch.Tensor` is
or will be allocated.

Possible values are:

- ``torch.contiguous_format``:
  Tensor is or will be allocated in dense non-overlapping memory. Strides are represented by values in
  decreasing order.

- ``torch.channels_last``:
  Tensor is or will be allocated in dense non-overlapping memory. Strides satisfy
  ``strides[0] > strides[2] > strides[3] > strides[1] == 1``, aka NHWC order.

- ``torch.channels_last_3d``:
  Tensor is or will be allocated in dense non-overlapping memory. Strides satisfy
  ``strides[0] > strides[2] > strides[3] > strides[4] > strides[1] == 1``, aka NDHWC order.

- ``torch.preserve_format``:
  Used in functions like `clone` to preserve the memory format of the input tensor. If the input tensor is
  allocated in dense non-overlapping memory, the output tensor strides will be copied from the input.
  Otherwise, the output strides will follow ``torch.contiguous_format``.

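As an illustration of the ``channels_last`` stride pattern (a small sketch; any 4-D shape would do):

```python
import torch

# A contiguous NCHW tensor has strides in decreasing order...
x = torch.empty(2, 3, 4, 5)
assert x.stride() == (60, 20, 5, 1)

# ...while converting to channels_last reorders the strides so the channel
# dimension becomes the fastest-varying one (strides[1] == 1, NHWC order).
y = x.to(memory_format=torch.channels_last)
assert y.stride() == (60, 1, 15, 3)
assert y.is_contiguous(memory_format=torch.channels_last)
```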