# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Any, Dict

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
from torchtune.models.llama3_2_vision._component_builders import (
    llama3_2_vision_decoder,
)
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint
    contains weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc"
    or "decoder.norm.scale". To load the text decoder on its own, the "decoder" prefix
    needs to be removed.
    """
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portion of the Llama 3.2 multimodal model.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        )  # Trained to be a lot larger, but this value is kept small because of the static KV cache at the moment.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )  # Same as above.
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)
        self.dtype = kwargs.get("dtype", torch.float16)
        self.use_checkpoint = False

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        if os.path.isfile(checkpoint_path):
            self.use_checkpoint = True

        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        elif self.use_checkpoint:
            checkpoint = torch.load(
                checkpoint_path, map_location=device, weights_only=False, mmap=True
            )
            checkpoint = llama3_vision_meta_to_tune(checkpoint)
            checkpoint = to_decoder_checkpoint(checkpoint)
            self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Load model.
        # Cannot use `with torch.device("meta"):` because it causes exceptions during
        # export; the model would not be fully initialized.
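        # The builder arguments below are read from the params JSON (demo_config.json
        # by default) and passed straight through to torchtune's
        # llama3_2_vision_decoder component builder.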
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )

        # Source transformation for MultiHeadAttention.
        self.model_ = replace_mha_with_inference_mha(self.model_)
        # Save params for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        if self.use_checkpoint:
            # Load checkpoint.
            missing, unexpected = self.model_.load_state_dict(
                checkpoint,
                strict=False,
                assign=True,
            )
            if self.verbose:
                print("============= missing keys ================")
                print(missing)
                print("============= /missing ================")
                print("============= unexpected keys ================")
                print(unexpected)
                print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

        if self.use_kv_cache:
            print("Setting up KV cache on the model...")
            self.model_.setup_caches(
                batch_size=1,
                dtype=self.dtype,
                encoder_max_seq_len=self.encoder_max_seq_len,
                decoder_max_seq_len=self.max_seq_len,
            )
        # Number of tokens for the example input.
        self.n_tokens = 34
        self.model_.to(self.dtype)

    def get_eager_model(self) -> torch.nn.Module:
        return self.model_

    def get_example_inputs(self):
        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

    def get_example_kwarg_inputs(self):
        # For export we must use the prefill versions of the
        # causal mask and input_pos.
        # Hardcoding the number of tiles to 2; image tokens per tile is 1601.
        if self.use_kv_cache:
            return {
                "input_pos": self.input_pos[None, : self.n_tokens],
                "mask": self.causal_mask[None, : self.n_tokens],
                "encoder_input": torch.randn(
                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                ),
                "encoder_mask": torch.ones(
                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
                ),
            }
        else:
            return None

    def get_dynamic_shapes(self):
        batch_size = 1
        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        # Hardcoding the number of tiles to 2; image tokens per tile is 1601.
        if self.use_kv_cache:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
                "encoder_input": None,
                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                "input_pos": {0: batch_size, 1: dim_seq_len},
            }
        else:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
            }
        return dynamic_shapes
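
# A minimal, hypothetical smoke-test sketch (not part of the original example). It
# assumes torchtune is installed and that the bundled demo_config.json (and optionally
# demo_rand_params.pth) resolve via get_default_model_resource_dir; if no checkpoint
# file is present, the decoder simply runs with randomly initialized weights. The
# requested dtype may be overridden by the checkpoint dtype when one is loaded.
if __name__ == "__main__":
    decoder = Llama3_2Decoder(use_kv_cache=True, dtype=torch.float32)
    model = decoder.get_eager_model().eval()
    tokens = decoder.get_example_inputs()
    kwargs = decoder.get_example_kwarg_inputs() or {}
    with torch.no_grad():
        logits = model(*tokens, **kwargs)
    print("example output shape:", tuple(logits.shape))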