# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Any, Dict

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
    To load the text decoder on its own, the "decoder" prefix needs to be removed.
    """
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }
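
# Illustrative mapping (hypothetical keys, for clarity only):
#   to_decoder_checkpoint({"decoder.norm.scale": a, "encoder.clip.weight": b})
#   returns {"norm.scale": a}; entries not starting with "decoder" are dropped.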


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portions of the Llama3.2 multimodal model.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        )  # The model is trained with a much longer context, but this is kept small for now because the KV cache is static.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )  # 4 * (448 / 14) ** 2 + 1 = 4097; kept fixed for the same reason as above.
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)
        self.dtype = kwargs.get("dtype", torch.float16)
        self.use_checkpoint = False

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        if os.path.isfile(checkpoint_path):
            self.use_checkpoint = True

        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
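
        # Full-length causal mask and position ids. get_example_kwarg_inputs()
        # slices the first n_tokens rows of these to build the prefill example.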
        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        elif self.use_checkpoint:
            checkpoint = torch.load(
                checkpoint_path, map_location=device, weights_only=False, mmap=True
            )
            checkpoint = llama3_vision_meta_to_tune(checkpoint)
            checkpoint = to_decoder_checkpoint(checkpoint)
            self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.load(f)
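        # The params JSON is expected to provide at least: vocab_size, n_layers,
        # fusion_interval, n_special_tokens, n_heads, n_kv_heads, dim, rope_theta,
        # and intermediate_dim (all consumed by the builder call below).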

        # Load model.
        # Note: the model is deliberately not constructed under
        # `with torch.device("meta"):` because that leaves it not fully
        # initialized and leads to exceptions during export.
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )

        # Source transformation: swap torchtune's MultiHeadAttention for the
        # export-friendly inference MultiHeadAttention.
        self.model_ = replace_mha_with_inference_mha(self.model_)
        # Save params as attributes on the model for later use
        # (e.g. self.model_.dim is read in get_example_kwarg_inputs).
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        if self.use_checkpoint:
            # Load checkpoint.
            missing, unexpected = self.model_.load_state_dict(
                checkpoint,
                strict=False,
                assign=True,
            )
            if kwargs.get("verbose", False):
                print("============= missing keys ================")
                print(missing)
                print("============= /missing ================")
                print("============= unexpected keys ================")
                print(unexpected)
                print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys)
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

        if self.use_kv_cache:
            print("Setting up KV cache on the model...")
            self.model_.setup_caches(
                batch_size=1,
                dtype=self.dtype,
                encoder_max_seq_len=self.encoder_max_seq_len,
                decoder_max_seq_len=self.max_seq_len,
            )
        # Number of tokens for the example inputs.
        self.n_tokens = 34
        self.model_.to(self.dtype)

    def get_eager_model(self) -> torch.nn.Module:
        return self.model_

    def get_example_inputs(self):
        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

    def get_example_kwarg_inputs(self):
        # For export we must use the prefill versions of the
        # causal mask and input_pos.
        # The number of tiles is hardcoded to 2; image tokens per tile is 1601.
        if self.use_kv_cache:
            return {
                "input_pos": self.input_pos[None, : self.n_tokens],
                "mask": self.causal_mask[None, : self.n_tokens],
                "encoder_input": torch.randn(
                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                ),
                "encoder_mask": torch.ones(
                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
                ),
            }
        else:
            return None

    def get_dynamic_shapes(self):
        batch_size = 1
        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        # The number of tiles is hardcoded to 2; image tokens per tile is 1601.
        if self.use_kv_cache:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
                "encoder_input": None,
                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                "input_pos": {0: batch_size, 1: dim_seq_len},
            }
        else:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
            }
        return dynamic_shapes
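
# Minimal usage sketch (illustrative only, not part of the export flow):
# build the eager decoder from the bundled demo resources and print the
# example inputs and dynamic shapes that an export call would consume.
# Assumes demo_config.json (and optionally demo_rand_params.pth) is present
# next to this file.
if __name__ == "__main__":
    decoder = Llama3_2Decoder(use_kv_cache=True, verbose=True)
    eager_model = decoder.get_eager_model()
    (tokens,) = decoder.get_example_inputs()
    example_kwargs = decoder.get_example_kwarg_inputs()
    dynamic_shapes = decoder.get_dynamic_shapes()
    print(f"model: {type(eager_model).__name__}")
    print(f"tokens shape: {tuple(tokens.shape)}")
    print(f"example kwargs: {sorted(example_kwargs.keys())}")
    print(f"dynamic shapes: {dynamic_shapes}")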