PyTorch 2.4: Getting Started on Intel GPU
=========================================

Support for Intel GPUs is released alongside PyTorch v2.4.

This release supports Intel GPUs only through building from source.

Hardware Prerequisites
----------------------

.. list-table::
   :header-rows: 1

   * - Supported Hardware
     - Intel® Data Center GPU Max Series
   * - Supported OS
     - Linux


With release 2.4, PyTorch for Intel GPUs is compatible with the Intel® Data Center GPU Max Series and supports only Linux.

Software Prerequisites
----------------------

As a prerequisite, install the driver and required packages by following the `PyTorch Installation Prerequisites for Intel GPUs <https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html>`_.

Set up Environment
------------------

Before you begin, set up the environment by sourcing the ``setvars.sh`` script provided by the ``intel-for-pytorch-gpu-dev`` and ``intel-pti-dev`` packages.

.. code-block::

   source ${ONEAPI_ROOT}/setvars.sh

.. note::
   ``ONEAPI_ROOT`` is the folder where you installed the ``intel-for-pytorch-gpu-dev`` and ``intel-pti-dev`` packages. Typically, it is located at ``/opt/intel/oneapi/`` or ``~/intel/oneapi/``.

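
Before sourcing the script, it can help to confirm where ``setvars.sh`` actually lives. The following is a minimal sketch and not part of the official setup: the helper name ``find_setvars`` is hypothetical, and the candidate directories are only ``ONEAPI_ROOT`` (if already set) plus the two typical locations mentioned in the note.

.. code-block:: python

   import os
   from pathlib import Path


   def find_setvars(candidates=None):
       """Return the first existing setvars.sh path, or None if none is found."""
       if candidates is None:
           candidates = []
           # Prefer ONEAPI_ROOT when the variable is already defined.
           if os.environ.get("ONEAPI_ROOT"):
               candidates.append(os.environ["ONEAPI_ROOT"])
           # Typical install locations (assumption: default installer paths).
           candidates += ["/opt/intel/oneapi", "~/intel/oneapi"]
       for root in candidates:
           script = Path(root).expanduser() / "setvars.sh"
           if script.is_file():
               return script
       return None


   script = find_setvars()
   if script is None:
       print("setvars.sh not found; check where the packages were installed")
   else:
       print(f"Run: source {script}")

The script itself still has to be sourced in the shell that will run the build, since a Python process cannot modify the environment of its parent shell.
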
Build from source
-----------------

Now that all the required packages are installed and the environment is activated, use the following commands to install ``pytorch``, ``torchvision``, and ``torchaudio`` by building from source. For more details, refer to the official guides: `PyTorch from source <https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support>`_, `Vision from source <https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md#development-installation>`_ and `Audio from source <https://pytorch.org/audio/main/build.linux.html>`_.

.. code-block::

   # Get PyTorch Source Code
   git clone --recursive https://github.com/pytorch/pytorch
   cd pytorch
   git checkout main # or checkout the specific release version >= v2.4
   git submodule sync
   git submodule update --init --recursive

   # Get required packages for compilation
   conda install cmake ninja
   pip install -r requirements.txt

   # PyTorch for Intel GPUs only supports the Linux platform for now.
   # Install the required packages for PyTorch compilation.
   conda install intel::mkl-static intel::mkl-include

   # (optional) If using torch.compile with inductor/triton, install the matching version of triton
   # Run from the pytorch directory after cloning
   # For Intel GPU support, please explicitly `export USE_XPU=1` before running the command.
   USE_XPU=1 make triton

   # If you would like to compile PyTorch with the new C++ ABI enabled, then first run this command:
   export _GLIBCXX_USE_CXX11_ABI=1

   # Build PyTorch from source
   export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
   python setup.py develop
   cd ..

   # (optional) If using torchvision.
   # Get torchvision Code
   git clone https://github.com/pytorch/vision.git
   cd vision
   git checkout main # or specific version
   python setup.py develop
   cd ..

   # (optional) If using torchaudio.
   # Get torchaudio Code
   git clone https://github.com/pytorch/audio.git
   cd audio
   pip install -r requirements.txt
   git checkout main # or specific version
   git submodule sync
   git submodule update --init --recursive
   python setup.py develop
   cd ..

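
The build commands above call several command-line tools, and a missing one only shows up when its step fails. As a rough pre-flight check, the sketch below reports which of them cannot be found on ``PATH``. The helper name ``missing_tools`` and the tool list are assumptions derived from the commands shown here, not an official requirement list.

.. code-block:: python

   import shutil

   # Tools invoked directly by the build commands in this section.
   REQUIRED_TOOLS = ("git", "conda", "pip", "python", "make")


   def missing_tools(tools=REQUIRED_TOOLS):
       """Return the tools from `tools` that cannot be found on PATH."""
       return [tool for tool in tools if shutil.which(tool) is None]


   missing = missing_tools()
   if missing:
       print("Missing from PATH:", ", ".join(missing))
   else:
       print("All required command-line tools were found")

``cmake`` and ``ninja`` are left out of the list because the build steps install them through ``conda``.
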
Check availability for Intel GPU
--------------------------------

.. note::
   Make sure the environment is properly set up by following `Environment Set up <#set-up-environment>`_ before running the code.

To check whether your Intel GPU is available, use the following code:

.. code-block::

   import torch
   torch.xpu.is_available()  # torch.xpu is the API for Intel GPU support

If the output is ``False``, ensure that your system has an Intel GPU and that you correctly followed the `PyTorch Installation Prerequisites for Intel GPUs <https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html>`_. Then, check that the PyTorch compilation finished correctly.

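
When troubleshooting, a slightly more verbose check can show which of these steps went wrong. The sketch below is one possible way to do it, under stated assumptions: the helper name ``xpu_report`` is hypothetical, and it relies on ``torch.xpu.device_count()`` and ``torch.xpu.get_device_name()`` in addition to ``torch.xpu.is_available()``.

.. code-block:: python

   import importlib.util


   def xpu_report():
       """Return human-readable lines describing Intel GPU availability."""
       # Case 1: the build did not produce an importable torch package.
       if importlib.util.find_spec("torch") is None:
           return ["torch is not importable in this environment"]

       import torch

       # Case 2: torch imports, but no Intel GPU is visible to it.
       if not hasattr(torch, "xpu") or not torch.xpu.is_available():
           return [f"torch {torch.__version__}: no Intel GPU (xpu) available"]

       # Case 3: list every visible Intel GPU by index and name.
       count = torch.xpu.device_count()
       lines = [f"torch {torch.__version__}: {count} xpu device(s)"]
       for index in range(count):
           lines.append(f"  xpu:{index} -> {torch.xpu.get_device_name(index)}")
       return lines


   for line in xpu_report():
       print(line)
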
Minimum Code Change
-------------------

If you are migrating code from ``cuda``, change references from ``cuda`` to ``xpu``. For example:

.. code-block::

   # CUDA CODE
   tensor = torch.tensor([1.0, 2.0]).to("cuda")

   # CODE for Intel GPU
   tensor = torch.tensor([1.0, 2.0]).to("xpu")

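
If the same script has to run on machines with different accelerators, the device string can be chosen at run time instead of being hard-coded. The following is a minimal sketch of that idea, not an official PyTorch helper: the name ``pick_device`` is hypothetical, and the ``torch`` module is passed in explicitly so that the selection logic can be exercised without a GPU.

.. code-block:: python

   def pick_device(torch_module):
       """Return "xpu", "cuda", or "cpu", in that order of preference."""
       # Older builds may not expose torch.xpu at all, hence getattr.
       xpu = getattr(torch_module, "xpu", None)
       if xpu is not None and xpu.is_available():
           return "xpu"
       if torch_module.cuda.is_available():
           return "cuda"
       return "cpu"

With PyTorch installed, ``device = pick_device(torch)`` followed by ``torch.tensor([1.0, 2.0]).to(device)`` would replace the hard-coded device string in the example above.
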
The following points outline the support and limitations for PyTorch with Intel GPU:

#. Both training and inference workflows are supported.
#. Both eager mode and ``torch.compile`` are supported.
#. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported.
#. Models that depend on third-party components will not be supported until PyTorch v2.5 or later.

Examples
--------

This section contains usage examples for both inference and training workflows.

Inference Examples
^^^^^^^^^^^^^^^^^^

Here are a few inference workflow examples.

Inference with FP32
"""""""""""""""""""

.. code-block::

   import torch
   import torchvision.models as models

   model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
   model.eval()
   data = torch.rand(1, 3, 224, 224)

   ######## code changes #######
   model = model.to("xpu")
   data = data.to("xpu")
   ######## code changes #######

   with torch.no_grad():
       model(data)

   print("Execution finished")

Inference with AMP
""""""""""""""""""

.. code-block::

   import torch
   import torchvision.models as models

   model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
   model.eval()
   data = torch.rand(1, 3, 224, 224)

   #################### code changes #################
   model = model.to("xpu")
   data = data.to("xpu")
   #################### code changes #################

   with torch.no_grad():
       ############################# code changes #####################
       # set dtype=torch.bfloat16 for BF16
       with torch.autocast(device_type="xpu", dtype=torch.float16, enabled=True):
           ############################# code changes #####################
           model(data)

   print("Execution finished")

Inference with ``torch.compile``
""""""""""""""""""""""""""""""""

.. code-block::

   import torch
   import torchvision.models as models

   model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
   model.eval()
   data = torch.rand(1, 3, 224, 224)
   ITERS = 10

   ######## code changes #######
   model = model.to("xpu")
   data = data.to("xpu")
   ######## code changes #######

   model = torch.compile(model)
   for i in range(ITERS):
       with torch.no_grad():
           model(data)

   print("Execution finished")

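
The example above runs the compiled model for several iterations. With ``torch.compile``, the first call typically includes compilation, so it is usually much slower than the calls that follow. The sketch below shows one way to make that visible by timing each iteration separately. The helper name ``time_iterations`` is hypothetical, and the optional ``sync`` callback is an assumption for asynchronous devices, where ``torch.xpu.synchronize`` would be a natural value.

.. code-block:: python

   import time


   def time_iterations(fn, iters, sync=None):
       """Call fn() `iters` times and return the wall-clock seconds of each call."""
       durations = []
       for _ in range(iters):
           start = time.perf_counter()
           fn()
           if sync is not None:
               sync()  # wait for queued device work before stopping the clock
           durations.append(time.perf_counter() - start)
       return durations


   # Stand-in workload so that the sketch runs without a GPU.
   timings = time_iterations(lambda: sum(range(10_000)), iters=3)
   for i, seconds in enumerate(timings):
       print(f"iteration {i}: {seconds * 1e3:.3f} ms")

For the compiled model, the call could look like ``time_iterations(lambda: model(data), ITERS, sync=torch.xpu.synchronize)``, placed inside a ``torch.no_grad()`` block.
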
Training Examples
^^^^^^^^^^^^^^^^^

Here are a few training workflow examples.

Train with FP32
"""""""""""""""

.. code-block::

   import torch
   import torchvision

   LR = 0.001
   DOWNLOAD = True
   DATA = "datasets/cifar10/"

   transform = torchvision.transforms.Compose(
       [
           torchvision.transforms.Resize((224, 224)),
           torchvision.transforms.ToTensor(),
           torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
       ]
   )
   train_dataset = torchvision.datasets.CIFAR10(
       root=DATA,
       train=True,
       transform=transform,
       download=DOWNLOAD,
   )
   train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

   model = torchvision.models.resnet50()
   criterion = torch.nn.CrossEntropyLoss()
   optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
   model.train()
   ######################## code changes #######################
   model = model.to("xpu")
   criterion = criterion.to("xpu")
   ######################## code changes #######################

   for batch_idx, (data, target) in enumerate(train_loader):
       ########## code changes ##########
       data = data.to("xpu")
       target = target.to("xpu")
       ########## code changes ##########
       optimizer.zero_grad()
       output = model(data)
       loss = criterion(output, target)
       loss.backward()
       optimizer.step()
       print(batch_idx)
   torch.save(
       {
           "model_state_dict": model.state_dict(),
           "optimizer_state_dict": optimizer.state_dict(),
       },
       "checkpoint.pth",
   )

   print("Execution finished")

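
The training example ends by writing ``checkpoint.pth``. A possible way to read that file back is sketched below. This is an illustration under stated assumptions rather than part of the original workflow: the helper name ``restore_checkpoint`` is hypothetical, and the tensors are mapped to the CPU on load, on the assumption that ``load_state_dict`` then copies them onto whatever device the model and optimizer already use.

.. code-block:: python

   def restore_checkpoint(model, optimizer, path="checkpoint.pth"):
       """Load the state dicts written by the training example into existing objects."""
       import torch  # imported lazily so the helper can be defined before PyTorch is built

       # Load onto the CPU first; load_state_dict copies values to each parameter's device.
       checkpoint = torch.load(path, map_location="cpu")
       model.load_state_dict(checkpoint["model_state_dict"])
       optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
       return model, optimizer

The model and optimizer passed in are expected to be constructed the same way as in the training script, with the model already moved to ``"xpu"`` if training should resume on the Intel GPU.
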
Train with AMP
""""""""""""""

.. code-block::

   import torch
   import torchvision

   LR = 0.001
   DOWNLOAD = True
   DATA = "datasets/cifar10/"

   use_amp = True

   transform = torchvision.transforms.Compose(
       [
           torchvision.transforms.Resize((224, 224)),
           torchvision.transforms.ToTensor(),
           torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
       ]
   )
   train_dataset = torchvision.datasets.CIFAR10(
       root=DATA,
       train=True,
       transform=transform,
       download=DOWNLOAD,
   )
   train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

   model = torchvision.models.resnet50()
   criterion = torch.nn.CrossEntropyLoss()
   optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
   scaler = torch.amp.GradScaler(enabled=use_amp)

   model.train()
   ######################## code changes #######################
   model = model.to("xpu")
   criterion = criterion.to("xpu")
   ######################## code changes #######################

   for batch_idx, (data, target) in enumerate(train_loader):
       ########## code changes ##########
       data = data.to("xpu")
       target = target.to("xpu")
       ########## code changes ##########
       # set dtype=torch.bfloat16 for BF16
       with torch.autocast(device_type="xpu", dtype=torch.float16, enabled=use_amp):
           output = model(data)
           loss = criterion(output, target)
       scaler.scale(loss).backward()
       scaler.step(optimizer)
       scaler.update()
       optimizer.zero_grad()
       print(batch_idx)

   torch.save(
       {
           "model_state_dict": model.state_dict(),
           "optimizer_state_dict": optimizer.state_dict(),
       },
       "checkpoint.pth",
   )

   print("Execution finished")
