# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
import os
from typing import Any, Dict

import torch

from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)
from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha

from torchtune.models.llama3_2_vision._component_builders import (
    llama3_2_vision_decoder,
)
from torchtune.models.llama3_2_vision._convert_weights import (
    llama3_vision_meta_to_tune,
)


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
    To load the text decoder on its own, the "decoder" prefix needs to be removed.
    """
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portions of the Llama3.2 multimodal model.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )  # Same as above.
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)
        self.dtype = kwargs.get("dtype", torch.float16)
        self.use_checkpoint = False

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        if os.path.isfile(checkpoint_path):
            self.use_checkpoint = True

        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        elif self.use_checkpoint:
            checkpoint = torch.load(
                checkpoint_path, map_location=device, weights_only=False, mmap=True
            )
            checkpoint = llama3_vision_meta_to_tune(checkpoint)
            checkpoint = to_decoder_checkpoint(checkpoint)
            self.dtype = get_checkpoint_dtype(checkpoint)

        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Load model.
        # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
        # i.e. the model isn't fully initialized or something.
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )
        # Source transformation for MultiHeadAttention.
        self.model_ = replace_mha_with_inference_mha(self.model_)

        # Save params for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        if self.use_checkpoint:
            # Load checkpoint.
            missing, unexpected = self.model_.load_state_dict(
                checkpoint,
                strict=False,
                assign=True,
            )
            if kwargs.get("verbose", False):
                print("============= missing keys ================")
                print(missing)
                print("============= /missing ================")
                print("============= unexpected keys ================")
                print(unexpected)
                print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

        if self.use_kv_cache:
            print("Setting up KV cache on the model...")
            self.model_.setup_caches(
                batch_size=1,
                dtype=self.dtype,
                encoder_max_seq_len=self.encoder_max_seq_len,
                decoder_max_seq_len=self.max_seq_len,
            )

        # Number of tokens for the example input.
        self.n_tokens = 34

        self.model_.to(self.dtype)

    def get_eager_model(self) -> torch.nn.Module:
        return self.model_

    def get_example_inputs(self):
        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)

    def get_example_kwarg_inputs(self):
        # For export we must use the prefill versions of the
        # causal mask and input_pos.
        # Hardcoding # of tiles to be 2. Image tokens per tile is 1601.
        if self.use_kv_cache:
            return {
                "input_pos": self.input_pos[None, : self.n_tokens],
                "mask": self.causal_mask[None, : self.n_tokens],
                "encoder_input": torch.randn(
                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                ),
                "encoder_mask": torch.ones(
                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
                ),
            }
        else:
            return None

    def get_dynamic_shapes(self):
        batch_size = 1
        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        # Hardcoding # of tiles to be 2. Image tokens per tile is 1601.
        if self.use_kv_cache:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
                "encoder_input": None,
                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                "input_pos": {0: batch_size, 1: dim_seq_len},
            }
        else:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
            }
        return dynamic_shapes
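

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# how the EagerModelBase hooks above are typically consumed: build the wrapper,
# grab its example inputs / kwargs / dynamic shapes, and hand them to
# torch.export. The constructor arguments and the bare torch.export call are
# assumptions here; a full ExecuTorch export flow usually involves additional
# backend-specific passes and a real checkpoint rather than the bundled random
# demo weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumption: demo_config.json (and optionally demo_rand_params.pth) are
    # resolvable via get_default_model_resource_dir, so no paths are passed.
    wrapper = Llama3_2Decoder(use_kv_cache=True, verbose=True)
    module = wrapper.get_eager_model()

    # Export the eager decoder with the example inputs the wrapper exposes.
    exported = torch.export.export(
        module,
        wrapper.get_example_inputs(),
        kwargs=wrapper.get_example_kwarg_inputs(),
        dynamic_shapes=wrapper.get_dynamic_shapes(),
    )
    print(exported)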