• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# export_nanogpt.py

import torch

# Load partitioner for the XNNPACK backend.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# A model that will be delegated to a specific backend should use that
# backend's edge compile config.
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge

from model import GPT
from torch.export import export, export_for_training
from torch.nn.attention import sdpa_kernel, SDPBackend

model = GPT.from_pretrained("gpt2")  # use gpt2 weight as pretrained weight
# Put the model in inference mode before exporting. eval() disables
# training-only behaviour (e.g. dropout); torch.no_grad() below does NOT do
# this, it only turns off gradient tracking.
model.eval()

# One example sequence of `block_size` random token ids drawn from [0, 100).
example_inputs = (
    torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long),
)
# Make dimension 1 of the single input (named "token_dim") dynamic, with an
# upper bound of `block_size`.
dynamic_shape = ({1: torch.export.Dim("token_dim", max=model.config.block_size)},)

# Trace the model, converting it to a portable intermediate representation.
# sdpa_kernel([SDPBackend.MATH]) restricts scaled-dot-product attention to the
# MATH backend while tracing; torch.no_grad() disables gradient tracking.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(
        model, example_inputs, dynamic_shapes=dynamic_shape
    ).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Turn the traced program into a runnable ExecuTorch program.
# Because the result is delegated to XNNPACK, `to_edge` is given the
# XNNPACK-specific edge compile config instead of the default one.
# `to_backend` with the XNNPACK partitioner then delegates the exported model
# to XNNPACK, and `to_executorch` emits the final program.
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(
    traced_model,
    compile_config=edge_config,
).to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

# Persist the XNNPACK-delegated ExecuTorch program as raw bytes in nanogpt.pte.
with open("nanogpt.pte", mode="wb") as file:
    pte_bytes = et_program.buffer  # serialized program bytes
    file.write(pte_bytes)