.. _cpu-threading-torchscript-inference:

CPU threading and TorchScript inference
=================================================

PyTorch allows using multiple CPU threads during TorchScript model inference.
The following figure shows different levels of parallelism one would find in a
typical application:

.. image:: cpu_threading_torchscript_inference.svg
   :width: 75%

One or more inference threads execute a model's forward pass on the given inputs.
Each inference thread invokes a JIT interpreter that executes the ops
of a model inline, one by one. A model can utilize a ``fork`` TorchScript
primitive to launch an asynchronous task. Forking several operations at once
results in tasks that are executed in parallel. The ``fork`` operator returns a
``Future`` object which can be used to synchronize on later, for example:

.. code-block:: python

    import torch

    @torch.jit.script
    def compute_z(x, w_z):
        return torch.mm(x, w_z)

    @torch.jit.script
    def forward(x, w_y, w_z):
        # launch compute_z asynchronously:
        fut = torch.jit.fork(compute_z, x, w_z)
        # execute the next operation in parallel to compute_z:
        y = torch.mm(x, w_y)
        # wait for the result of compute_z:
        z = torch.jit.wait(fut)
        return y + z

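The scripted ``forward`` can then be invoked like a regular Python function; a
minimal sketch (the tensor shapes are arbitrary, chosen only for illustration):

.. code-block:: python

    x = torch.randn(128, 128)
    w_y = torch.randn(128, 128)
    w_z = torch.randn(128, 128)
    # compute_z runs as an asynchronous task while torch.mm(x, w_y) executes:
    out = forward(x, w_y, w_z)
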
PyTorch uses a single thread pool for inter-op parallelism; this thread pool
is shared by all inference tasks that are forked within the application process.

In addition to inter-op parallelism, PyTorch can also utilize multiple threads
within the ops (`intra-op parallelism`). This can be useful in many cases,
including element-wise ops on large tensors, convolutions, GEMMs, embedding
lookups and others.

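For instance, a single large matrix multiplication can itself be split across
the intra-op thread pool; a minimal sketch (the matrix sizes are arbitrary):

.. code-block:: python

    import torch

    # size of the intra-op thread pool:
    print(torch.get_num_threads())

    x = torch.randn(4096, 4096)
    y = torch.randn(4096, 4096)
    # this single op may internally use all intra-op threads:
    z = torch.mm(x, y)
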
Build options
-------------

PyTorch uses an internal ATen library to implement ops. In addition to that,
PyTorch can also be built with support for external libraries, such as MKL_ and MKL-DNN_,
to speed up computations on CPU.

ATen, MKL and MKL-DNN support intra-op parallelism and depend on the
following parallelization libraries to implement it:

* OpenMP_ - a standard (and a library, usually shipped with a compiler), widely used in external libraries;
* TBB_ - a newer parallelization library optimized for task-based parallelism and concurrent environments.

OpenMP historically has been used by a large number of libraries. It is known
for its relative ease of use and its support for loop-based parallelism and other primitives.

TBB is used to a lesser extent in external libraries, but, at the same time,
is optimized for concurrent environments. PyTorch's TBB backend guarantees that
there's a separate, single, per-process intra-op thread pool used by all of the
ops running in the application.

Depending on the use case, one might find one or the other parallelization
library a better choice in their application.

PyTorch allows selecting the parallelization backend used by ATen and other
libraries at build time with the following build options:

+------------+------------------------+-----------------------------+----------------------------------------+
| Library    | Build Option           | Values                      | Notes                                  |
+============+========================+=============================+========================================+
| ATen       | ``ATEN_THREADING``     | ``OMP`` (default), ``TBB``  |                                        |
+------------+------------------------+-----------------------------+----------------------------------------+
| MKL        | ``MKL_THREADING``      | (same)                      | To enable MKL use ``BLAS=MKL``         |
+------------+------------------------+-----------------------------+----------------------------------------+
| MKL-DNN    | ``MKLDNN_CPU_RUNTIME`` | (same)                      | To enable MKL-DNN use ``USE_MKLDNN=1`` |
+------------+------------------------+-----------------------------+----------------------------------------+

It is recommended not to mix OpenMP and TBB within one build.

Any of the ``TBB`` values above requires the ``USE_TBB=1`` build setting (default: OFF).
A separate setting, ``USE_OPENMP=1`` (default: ON), is required for OpenMP parallelism.

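To check which threading backends a given build was compiled with, one can
inspect the parallel settings at runtime via ``torch.__config__.parallel_info()``
(the exact output format varies across versions):

.. code-block:: python

    import torch

    # prints ATen/Parallel settings, including the intra-op threading backend
    # (e.g. OpenMP or TBB) and whether MKL/MKL-DNN are available:
    print(torch.__config__.parallel_info())
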
Runtime API
-----------

The following API is used to control thread settings:

+------------------------+-----------------------------------------------------------+---------------------------------------------------------+
| Type of parallelism    | Settings                                                  | Notes                                                   |
+========================+===========================================================+=========================================================+
| Inter-op parallelism   | ``at::set_num_interop_threads``,                          | Default number of threads: number of CPU cores.         |
|                        | ``at::get_num_interop_threads`` (C++)                     |                                                         |
|                        |                                                           |                                                         |
|                        | ``set_num_interop_threads``,                              |                                                         |
|                        | ``get_num_interop_threads`` (Python, :mod:`torch` module) |                                                         |
+------------------------+-----------------------------------------------------------+                                                         |
| Intra-op parallelism   | ``at::set_num_threads``,                                  |                                                         |
|                        | ``at::get_num_threads`` (C++)                             |                                                         |
|                        |                                                           |                                                         |
|                        | ``set_num_threads``,                                      |                                                         |
|                        | ``get_num_threads`` (Python, :mod:`torch` module)         |                                                         |
|                        |                                                           |                                                         |
|                        | Environment variables:                                    |                                                         |
|                        | ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS``               |                                                         |
+------------------------+-----------------------------------------------------------+---------------------------------------------------------+

For the intra-op parallelism settings, ``at::set_num_threads`` and ``torch.set_num_threads`` always take precedence
over environment variables, and the ``MKL_NUM_THREADS`` variable takes precedence over ``OMP_NUM_THREADS``.

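For example, the Python side of this API can be exercised as follows (the
thread counts below are placeholders; pick values that match your hardware).
Note that ``set_num_interop_threads`` may only be called once, before any
inter-op parallel work starts:

.. code-block:: python

    import torch

    # size of the shared inter-op thread pool; call once, before any
    # TorchScript tasks are forked:
    torch.set_num_interop_threads(4)

    # size of the intra-op thread pool used inside individual ops:
    torch.set_num_threads(4)

    print(torch.get_num_interop_threads(), torch.get_num_threads())
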
Tuning the number of threads
----------------------------

The following simple script shows how the runtime of a matrix multiplication changes with the number of threads:

.. code-block:: python

    import timeit

    import torch

    runtimes = []
    threads = [1] + list(range(2, 49, 2))
    for t in threads:
        torch.set_num_threads(t)
        r = timeit.timeit(
            setup="import torch; x = torch.randn(1024, 1024); y = torch.randn(1024, 1024)",
            stmt="torch.mm(x, y)",
            number=100,
        )
        runtimes.append(r)
    # ... plotting (threads, runtimes) ...

Running the script on a system with 24 physical CPU cores (Xeon E5-2680, MKL- and OpenMP-based build) results in the following runtimes:

.. image:: cpu_threading_runtimes.svg
   :width: 75%

The following considerations should be taken into account when tuning the number of intra- and inter-op threads:

* When choosing the number of threads one needs to avoid `oversubscription` (using too many threads leads to performance degradation). For example, in an application that uses a large application thread pool or heavily relies on
  inter-op parallelism, one might find disabling intra-op parallelism a reasonable option (i.e. by calling ``set_num_threads(1)``), as in the sketch after this list;

* In a typical application one might encounter a trade-off between `latency` (time spent processing an inference request) and `throughput` (amount of work done per unit of time). Tuning the number of threads can be a useful
  tool to adjust this trade-off in one way or another. For example, in latency-critical applications one might want to increase the number of intra-op threads to process each request as fast as possible. At the same time, parallel implementations
  of ops may add extra overhead that increases the amount of work done per request and thus reduces the overall throughput.
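
A minimal sketch of the first consideration above, assuming a few Python worker
threads sharing one scripted model (the model and thread count here are
placeholders):

.. code-block:: python

    import threading

    import torch

    # avoid oversubscription: each worker runs its ops single-threaded
    torch.set_num_threads(1)

    model = torch.jit.script(torch.nn.Linear(128, 128))

    def worker(x):
        with torch.no_grad():
            model(x)

    workers = [
        threading.Thread(target=worker, args=(torch.randn(32, 128),))
        for _ in range(4)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
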

.. warning::
    OpenMP does not guarantee that a single per-process intra-op thread
    pool is going to be used in the application. On the contrary, two different application or inter-op
    threads may use different OpenMP thread pools for intra-op work.
    This might result in a large number of threads used by the application.
    Extra care in tuning the number of threads is needed to avoid
    oversubscription in multi-threaded applications in the OpenMP case.

.. note::
    Pre-built PyTorch releases are compiled with OpenMP support.

.. note::
    The ``parallel_info`` utility prints information about thread settings and can be used for debugging.
    Similar output can also be obtained in Python with a ``torch.__config__.parallel_info()`` call.

.. _OpenMP: https://www.openmp.org/
.. _TBB: https://github.com/intel/tbb
.. _MKL: https://software.intel.com/en-us/mkl
.. _MKL-DNN: https://github.com/intel/mkl-dnn