torch.utils.data
===================================

.. automodule:: torch.utils.data

At the heart of PyTorch's data loading utility is the :class:`torch.utils.data.DataLoader`
class. It represents a Python iterable over a dataset, with support for

* `map-style and iterable-style datasets <Dataset Types_>`_,

* `customizing data loading order <Data Loading Order and Sampler_>`_,

* `automatic batching <Loading Batched and Non-Batched Data_>`_,

* `single- and multi-process data loading <Single- and Multi-process Data Loading_>`_,

* `automatic memory pinning <Memory Pinning_>`_.

These options are configured by the constructor arguments of a
:class:`~torch.utils.data.DataLoader`, which has signature::

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)

The sections below describe in detail the effects and usages of these options.
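
For orientation, a typical invocation looks roughly like the following sketch
(``my_dataset`` is a placeholder for a map-style dataset; each option is
explained in the sections below)::

    from torch.utils.data import DataLoader

    # ``my_dataset`` is a hypothetical map-style dataset; see `Dataset Types`_
    loader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=2)

    for batch in loader:
        ...  # training or evaluation step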

Dataset Types
-------------

The most important argument of the :class:`~torch.utils.data.DataLoader`
constructor is :attr:`dataset`, which indicates a dataset object to load data
from. PyTorch supports two different types of datasets:

* `map-style datasets <Map-style datasets_>`_,

* `iterable-style datasets <Iterable-style datasets_>`_.

Map-style datasets
^^^^^^^^^^^^^^^^^^

A map-style dataset is one that implements the :meth:`__getitem__` and
:meth:`__len__` protocols, and represents a map from (possibly non-integral)
indices/keys to data samples.

For example, such a dataset, when accessed with ``dataset[idx]``, could read
the ``idx``-th image and its corresponding label from a folder on the disk.

See :class:`~torch.utils.data.Dataset` for more details.
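
As an illustrative sketch (the class below is a made-up example, not part of
the library), a map-style dataset over in-memory tensors could look like::

    import torch
    from torch.utils.data import Dataset

    class InMemoryDataset(Dataset):
        """A hypothetical map-style dataset wrapping pre-loaded tensors."""

        def __init__(self, features, labels):
            self.features = features
            self.labels = labels

        def __len__(self):
            # number of samples in the dataset
            return len(self.features)

        def __getitem__(self, idx):
            # map an integer index to a (sample, label) pair
            return self.features[idx], self.labels[idx]

    dataset = InMemoryDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
    print(len(dataset), dataset[0])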

Iterable-style datasets
^^^^^^^^^^^^^^^^^^^^^^^

An iterable-style dataset is an instance of a subclass of :class:`~torch.utils.data.IterableDataset`
that implements the :meth:`__iter__` protocol, and represents an iterable over
data samples. This type of dataset is particularly suitable for cases where
random reads are expensive or even impractical, and where the batch size depends
on the fetched data.

For example, such a dataset, when ``iter(dataset)`` is called on it, could return a
stream of data read from a database, a remote server, or even logs generated
in real time.

See :class:`~torch.utils.data.IterableDataset` for more details.

.. note:: When using an :class:`~torch.utils.data.IterableDataset` with
          `multi-process data loading <Multi-process data loading_>`_, the same
          dataset object is replicated on each worker process, and thus the
          replicas must be configured differently to avoid duplicated data. See
          the :class:`~torch.utils.data.IterableDataset` documentation for how to
          achieve this.
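
For illustration only (again a made-up example, not a library class), a minimal
iterable-style dataset that streams samples from a Python iterator (a simple
range stands in for a database cursor or log stream) could look like::

    from torch.utils.data import IterableDataset

    class RangeStreamDataset(IterableDataset):
        """A hypothetical iterable-style dataset yielding a stream of samples."""

        def __init__(self, start, end):
            self.start = start
            self.end = end

        def __iter__(self):
            # in practice this could read from a database, socket, or log file
            return iter(range(self.start, self.end))

    for sample in RangeStreamDataset(0, 5):
        print(sample)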

Data Loading Order and :class:`~torch.utils.data.Sampler`
----------------------------------------------------------
For `iterable-style datasets <Iterable-style datasets_>`_, data loading order
is entirely controlled by the user-defined iterable. This allows easier
implementations of chunk reading and dynamic batch sizes (e.g., by yielding a
batched sample each time).

The rest of this section concerns the case with
`map-style datasets <Map-style datasets_>`_. :class:`torch.utils.data.Sampler`
classes are used to specify the sequence of indices/keys used in data loading.
They represent iterable objects over the indices of datasets. E.g., in the
common case with stochastic gradient descent (SGD), a
:class:`~torch.utils.data.Sampler` could randomly permute a list of indices
and yield each one at a time, or yield a small number of them for mini-batch
SGD.

A sequential or shuffled sampler will be automatically constructed based on
the :attr:`shuffle` argument to a :class:`~torch.utils.data.DataLoader`.
Alternatively, users may use the :attr:`sampler` argument to specify a
custom :class:`~torch.utils.data.Sampler` object that yields the next
index/key to fetch each time.

A custom :class:`~torch.utils.data.Sampler` that yields a list of batch
indices at a time can be passed as the :attr:`batch_sampler` argument.
Automatic batching can also be enabled via the :attr:`batch_size` and
:attr:`drop_last` arguments. See
`the next section <Loading Batched and Non-Batched Data_>`_ for more details
on this.

.. note::
  Neither :attr:`sampler` nor :attr:`batch_sampler` is compatible with
  iterable-style datasets, since such datasets have no notion of a key or an
  index.
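
For example, the :attr:`sampler` and :attr:`batch_sampler` arguments can be
used as in the following sketch (``dataset`` is a placeholder for a map-style
dataset such as the one sketched above)::

    from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                                  SequentialSampler)

    # explicit sampler yielding one index at a time (an alternative to shuffle=True)
    loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

    # batch_sampler yielding a list of indices at a time
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4,
                                 drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)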

Loading Batched and Non-Batched Data
------------------------------------

:class:`~torch.utils.data.DataLoader` supports automatically collating
individual fetched data samples into batches via the arguments
:attr:`batch_size`, :attr:`drop_last`, :attr:`batch_sampler`, and
:attr:`collate_fn` (which has a default function).


Automatic batching (default)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This is the most common case, and corresponds to fetching a minibatch of
data samples and collating them into batched samples, i.e., batches containing
Tensors with one dimension being the batch dimension (usually the first).

When :attr:`batch_size` (default ``1``) is not ``None``, the data loader yields
batched samples instead of individual samples. The :attr:`batch_size` and
:attr:`drop_last` arguments are used to specify how the data loader obtains
batches of dataset keys. For map-style datasets, users can alternatively
specify :attr:`batch_sampler`, which yields a list of keys at a time.
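
For instance, the following minimal sketch (``dataset`` again being a
placeholder for a map-style dataset) yields batches of 8 samples and drops the
trailing partial batch::

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=8, drop_last=True)
    for batch in loader:
        ...  # each ``batch`` collates 8 individual samples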

.. note::
  The :attr:`batch_size` and :attr:`drop_last` arguments are essentially used
  to construct a :attr:`batch_sampler` from the :attr:`sampler`. For map-style
  datasets, the :attr:`sampler` is either provided by the user or constructed
  based on the :attr:`shuffle` argument. For iterable-style datasets, the
  :attr:`sampler` is a dummy infinite one. See
  `this section <Data Loading Order and Sampler_>`_ for more details on
  samplers.

.. note::
  When fetching from
  `iterable-style datasets <Iterable-style datasets_>`_ with
  `multi-processing <Multi-process data loading_>`_, the :attr:`drop_last`
  argument drops the last non-full batch of each worker's dataset replica.

After fetching a list of samples using the indices from the sampler, the function
passed as the :attr:`collate_fn` argument is used to collate lists of samples
into batches.

In this case, loading from a map-style dataset is roughly equivalent to::

    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])

and loading from an iterable-style dataset is roughly equivalent to::

    dataset_iter = iter(dataset)
    for indices in batch_sampler:
        yield collate_fn([next(dataset_iter) for _ in indices])

A custom :attr:`collate_fn` can be used to customize collation, e.g., padding
sequential data to the maximum length of a batch. See
`this section <dataloader-collate_fn_>`_ for more about :attr:`collate_fn`.

Disable automatic batching
^^^^^^^^^^^^^^^^^^^^^^^^^^

In certain cases, users may want to handle batching manually in dataset code,
or simply load individual samples. For example, it could be cheaper to directly
load batched data (e.g., bulk reads from a database or reading continuous
chunks of memory), or the batch size is data dependent, or the program is
designed to work on individual samples. In these scenarios, it's likely
better not to use automatic batching (where :attr:`collate_fn` is used to
collate the samples), but to let the data loader directly return each member of
the :attr:`dataset` object.

When both :attr:`batch_size` and :attr:`batch_sampler` are ``None`` (the default
value for :attr:`batch_sampler` is already ``None``), automatic batching is
disabled. Each sample obtained from the :attr:`dataset` is processed with the
function passed as the :attr:`collate_fn` argument.

**When automatic batching is disabled**, the default :attr:`collate_fn` simply
converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.

In this case, loading from a map-style dataset is roughly equivalent to::

    for index in sampler:
        yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent to::

    for data in iter(dataset):
        yield collate_fn(data)

See `this section <dataloader-collate_fn_>`_ for more about :attr:`collate_fn`.
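
As a small usage sketch (``dataset`` being any dataset object as above),
automatic batching is disabled by passing ``batch_size=None``::

    from torch.utils.data import DataLoader

    # each iteration yields an individual sample rather than a batch
    loader = DataLoader(dataset, batch_size=None)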

.. _dataloader-collate_fn:

Working with :attr:`collate_fn`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The use of :attr:`collate_fn` is slightly different when automatic batching is
enabled or disabled.

**When automatic batching is disabled**, :attr:`collate_fn` is called with
each individual data sample, and the output is yielded from the data loader
iterator. In this case, the default :attr:`collate_fn` simply converts NumPy
arrays into PyTorch Tensors.

**When automatic batching is enabled**, :attr:`collate_fn` is called with a list
of data samples each time. It is expected to collate the input samples into
a batch for yielding from the data loader iterator. The rest of this section
describes the behavior of the default :attr:`collate_fn`
(:func:`~torch.utils.data.default_collate`).

For instance, if each data sample consists of a 3-channel image and an integral
class label, i.e., each element of the dataset returns a tuple
``(image, class_index)``, the default :attr:`collate_fn` collates a list of
such tuples into a single tuple of a batched image Tensor and a batched class
label Tensor. In particular, the default :attr:`collate_fn` has the following
properties:

* It always prepends a new dimension as the batch dimension.

* It automatically converts NumPy arrays and Python numerical values into
  PyTorch Tensors.

* It preserves the data structure, e.g., if each sample is a dictionary, it
  outputs a dictionary with the same set of keys but batched Tensors as values
  (or lists if the values cannot be converted into Tensors). The same holds
  for ``list`` s, ``tuple`` s, ``namedtuple`` s, etc.
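
These properties can be seen in a small illustrative snippet (the sample shapes
are just an example)::

    import torch
    from torch.utils.data import default_collate

    batch = [(torch.zeros(3, 8, 8), 0), (torch.ones(3, 8, 8), 1)]
    images, labels = default_collate(batch)
    print(images.shape)  # torch.Size([2, 3, 8, 8]), i.e. a new leading batch dimension
    print(labels)        # tensor([0, 1]), i.e. Python ints converted to a Tensor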

Users may use a customized :attr:`collate_fn` to achieve custom batching, e.g.,
collating along a dimension other than the first, padding sequences of
varying lengths, or adding support for custom data types.
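
For instance, a custom :attr:`collate_fn` along the following lines (an
illustrative sketch assuming each sample is a ``(sequence, label)`` pair of a
variable-length 1-D Tensor and an integer, with ``dataset`` as a placeholder)
pads sequences to the longest one in the batch with
:func:`torch.nn.utils.rnn.pad_sequence`::

    import torch
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    def pad_collate(batch):
        # ``batch`` is a list of (sequence, label) pairs with varying lengths
        sequences, labels = zip(*batch)
        lengths = torch.tensor([len(seq) for seq in sequences])
        padded = pad_sequence(list(sequences), batch_first=True, padding_value=0.0)
        return padded, torch.tensor(labels), lengths

    loader = DataLoader(dataset, batch_size=4, collate_fn=pad_collate)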

If you run into a situation where the outputs of :class:`~torch.utils.data.DataLoader`
have dimensions or types different from what you expect, you may
want to check your :attr:`collate_fn`.

Single- and Multi-process Data Loading
--------------------------------------

A :class:`~torch.utils.data.DataLoader` uses single-process data loading by
default.

Within a Python process, the
`Global Interpreter Lock (GIL) <https://wiki.python.org/moin/GlobalInterpreterLock>`_
prevents fully parallelizing Python code across threads. To avoid blocking
computation code with data loading, PyTorch provides an easy switch to perform
multi-process data loading by simply setting the argument :attr:`num_workers`
to a positive integer.

Single-process data loading (default)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In this mode, data fetching is done in the same process in which the
:class:`~torch.utils.data.DataLoader` is initialized. Therefore, data loading
may block computation. However, this mode may be preferred when the resources
used for sharing data among processes (e.g., shared memory, file descriptors)
are limited, or when the entire dataset is small and can be loaded entirely in
memory. Additionally, single-process loading often shows more readable error
traces and thus is useful for debugging.


Multi-process data loading
^^^^^^^^^^^^^^^^^^^^^^^^^^

Setting the argument :attr:`num_workers` to a positive integer will
turn on multi-process data loading with the specified number of loader worker
processes.

.. warning::
   After several iterations, the loader worker processes will consume
   the same amount of CPU memory as the parent process for all Python
   objects in the parent process which are accessed from the worker
   processes.  This can be problematic if the Dataset contains a lot of
   data (e.g., you are loading a very large list of filenames at Dataset
   construction time) and/or you are using a lot of workers (overall
   memory usage is ``number of workers * size of parent process``).  The
   simplest workaround is to replace Python objects with non-refcounted
   representations such as Pandas, NumPy, or PyArrow objects.  Check out
   `issue #13246
   <https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662>`_
   for more details on why this occurs and example code for how to
   work around these problems.

In this mode, each time an iterator of a :class:`~torch.utils.data.DataLoader`
is created (e.g., when you call ``enumerate(dataloader)``), :attr:`num_workers`
worker processes are created. At this point, the :attr:`dataset`,
:attr:`collate_fn`, and :attr:`worker_init_fn` are passed to each
worker, where they are used to initialize the worker and fetch data. This means
that dataset access, together with its internal I/O and transforms
(including :attr:`collate_fn`), runs in the worker process.

:func:`torch.utils.data.get_worker_info()` returns various useful information
in a worker process (including the worker id, dataset replica, initial seed,
etc.), and returns ``None`` in the main process. Users may use this function in
dataset code and/or :attr:`worker_init_fn` to individually configure each
dataset replica, and to determine whether the code is running in a worker
process. For example, this can be particularly helpful in sharding the dataset.
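
For example, a :attr:`worker_init_fn` along the following lines (an
illustrative sketch assuming an iterable-style dataset with ``start`` and
``end`` attributes, similar to the stream dataset sketched earlier) assigns a
disjoint slice of the work to each replica::

    import math

    from torch.utils.data import get_worker_info

    def worker_init_fn(worker_id):
        worker_info = get_worker_info()
        dataset = worker_info.dataset  # the dataset replica in this worker process
        overall_start, overall_end = dataset.start, dataset.end
        # split the [start, end) range evenly across the workers
        per_worker = int(math.ceil((overall_end - overall_start) / worker_info.num_workers))
        dataset.start = overall_start + worker_info.id * per_worker
        dataset.end = min(dataset.start + per_worker, overall_end)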

For map-style datasets, the main process generates the indices using the
:attr:`sampler` and sends them to the workers. So any shuffle randomization is
done in the main process, which guides loading by assigning the indices to load.

For iterable-style datasets, since each worker process gets a replica of the
:attr:`dataset` object, naive multi-process loading will often result in
duplicated data. Using :func:`torch.utils.data.get_worker_info()` and/or
:attr:`worker_init_fn`, users may configure each replica independently. (See the
:class:`~torch.utils.data.IterableDataset` documentation for how to achieve
this.) For similar reasons, in multi-process loading, the :attr:`drop_last`
argument drops the last non-full batch of each worker's iterable-style dataset
replica.

Workers are shut down once the end of the iteration is reached, or when the
iterator is garbage collected.

.. warning::
  It is generally not recommended to return CUDA tensors in multi-process
  loading because of many subtleties in using CUDA and sharing CUDA tensors in
  multiprocessing (see :ref:`multiprocessing-cuda-note`). Instead, we recommend
  using `automatic memory pinning <Memory Pinning_>`_ (i.e., setting
  :attr:`pin_memory=True`), which enables fast data transfer to CUDA-enabled
  GPUs.

Platform-specific behaviors
"""""""""""""""""""""""""""

Since workers rely on Python :py:mod:`multiprocessing`, worker launch behavior is
different on Windows compared to Unix.

* On Unix, :func:`fork()` is the default :py:mod:`multiprocessing` start method.
  Using :func:`fork`, child workers typically can access the :attr:`dataset` and
  Python argument functions directly through the cloned address space.

* On Windows or macOS, :func:`spawn()` is the default :py:mod:`multiprocessing` start method.
  Using :func:`spawn()`, another interpreter is launched which runs your main script,
  followed by the internal worker function that receives the :attr:`dataset`,
  :attr:`collate_fn`, and other arguments through :py:mod:`pickle` serialization.

This separate serialization means that you should take two steps to ensure you
are compatible with Windows while using multi-process data loading (a
structural sketch follows this list):

- Wrap most of your main script's code within an ``if __name__ == '__main__':`` block,
  to make sure it doesn't run again (most likely generating errors) when each worker
  process is launched. You can place your dataset and :class:`~torch.utils.data.DataLoader`
  instance creation logic here, as it doesn't need to be re-executed in workers.

- Make sure that any custom :attr:`collate_fn`, :attr:`worker_init_fn`,
  or :attr:`dataset` code is declared as top-level definitions, outside the
  ``__main__`` check. This ensures that they are available in worker processes.
  (This is needed since functions are pickled as references only, not as bytecode.)
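
Put together, a Windows-compatible script might be structured roughly as
follows (a hedged sketch; ``SquaresDataset`` and ``my_collate_fn`` are made-up
placeholder names)::

    import torch
    from torch.utils.data import DataLoader, Dataset

    # top-level definitions are pickled by reference and so are visible to workers
    class SquaresDataset(Dataset):
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return torch.tensor(idx) ** 2

    def my_collate_fn(batch):
        return torch.stack(batch)

    if __name__ == '__main__':
        # DataLoader creation and the loop stay inside the __main__ guard
        loader = DataLoader(SquaresDataset(), batch_size=4, num_workers=2,
                            collate_fn=my_collate_fn)
        for batch in loader:
            print(batch)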

.. _data-loading-randomness:

Randomness in multi-process data loading
""""""""""""""""""""""""""""""""""""""""

By default, each worker will have its PyTorch seed set to ``base_seed + worker_id``,
where ``base_seed`` is a long generated by the main process using its RNG (thereby
consuming an RNG state) or by a specified :attr:`generator`. However, seeds for other
libraries may be duplicated upon initializing workers, causing each worker to return
identical random numbers. (See :ref:`this section <dataloader-workers-random-seed>` in the FAQ.)

In :attr:`worker_init_fn`, you may access the PyTorch seed set for each worker
with either :func:`torch.utils.data.get_worker_info().seed <torch.utils.data.get_worker_info>`
or :func:`torch.initial_seed()`, and use it to seed other libraries before data
loading.
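
For example, a :attr:`worker_init_fn` along the following lines (a sketch
assuming NumPy and the standard :py:mod:`random` module are the other libraries
in use) re-seeds them from the worker's PyTorch seed::

    import random

    import numpy as np
    import torch

    def seed_worker(worker_id):
        # fold the 64-bit PyTorch seed into NumPy's 32-bit seed range
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # e.g., DataLoader(dataset, num_workers=2, worker_init_fn=seed_worker)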

Memory Pinning
--------------

Host-to-GPU copies are much faster when they originate from pinned (page-locked)
memory. See :ref:`cuda-memory-pinning` for more details on when and how to use
pinned memory generally.

For data loading, passing :attr:`pin_memory=True` to a
:class:`~torch.utils.data.DataLoader` will automatically put the fetched data
Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled
GPUs.

The default memory pinning logic only recognizes Tensors and maps and iterables
containing Tensors. By default, if the pinning logic sees a batch that is a
custom type (which will occur if you have a :attr:`collate_fn` that returns a
custom batch type), or if each element of your batch is a custom type, the
pinning logic will not recognize them, and it will return that batch (or those
elements) without pinning the memory. To enable memory pinning for custom
batch or data type(s), define a :meth:`pin_memory` method on your custom
type(s).

See the example below.

Example::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class SimpleCustomBatch:
        def __init__(self, data):
            transposed_data = list(zip(*data))
            self.inp = torch.stack(transposed_data[0], 0)
            self.tgt = torch.stack(transposed_data[1], 0)

        # custom memory pinning method on custom type
        def pin_memory(self):
            self.inp = self.inp.pin_memory()
            self.tgt = self.tgt.pin_memory()
            return self

    def collate_wrapper(batch):
        return SimpleCustomBatch(batch)

    inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
    dataset = TensorDataset(inps, tgts)

    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                        pin_memory=True)

    for batch_ndx, sample in enumerate(loader):
        print(sample.inp.is_pinned())
        print(sample.tgt.is_pinned())


.. autoclass:: DataLoader
.. autoclass:: Dataset
.. autoclass:: IterableDataset
.. autoclass:: TensorDataset
.. autoclass:: StackDataset
.. autoclass:: ConcatDataset
.. autoclass:: ChainDataset
.. autoclass:: Subset
.. autofunction:: torch.utils.data._utils.collate.collate
.. autofunction:: torch.utils.data.default_collate
.. autofunction:: torch.utils.data.default_convert
.. autofunction:: torch.utils.data.get_worker_info
.. autofunction:: torch.utils.data.random_split
.. autoclass:: torch.utils.data.Sampler
.. autoclass:: torch.utils.data.SequentialSampler
.. autoclass:: torch.utils.data.RandomSampler
.. autoclass:: torch.utils.data.SubsetRandomSampler
.. autoclass:: torch.utils.data.WeightedRandomSampler
.. autoclass:: torch.utils.data.BatchSampler
.. autoclass:: torch.utils.data.distributed.DistributedSampler


.. These modules are documented as part of torch/data listing them here for
.. now until we have a clearer fix
.. py:module:: torch.utils.data.datapipes
.. py:module:: torch.utils.data.datapipes.dataframe
.. py:module:: torch.utils.data.datapipes.iter
.. py:module:: torch.utils.data.datapipes.map
.. py:module:: torch.utils.data.datapipes.utils