# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Load a llama model from a GGUF file quantized in Q4_0 format.
# Float weights are loaded directly from the GGUF file.
# Q4_0 weights are loaded into a Tensor subclass (GGMLInt4LinearWeight)
# by replacing each linear module's weight with the subclass.

import logging
import os
from typing import Callable, Dict, Mapping

import torch
from executorch.examples.models.llama.experimental.subclass import (
    _unpack_two_uint8,
    GGMLInt4LinearWeight,
    to_float,
)
from executorch.extension.gguf_util.converters.llama_converter import (
    _convert_gguf_tensor_name_to_llama_nn,
    _create_pt_model,
)
from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
from gguf import ReaderTensor
from gguf.constants import GGMLQuantizationType
from torchao.quantization.subclass import QuantizedLinearWeightBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _replace_with_custom_fn_if_matches_filter(
    model: torch.nn.Module,
    replacement_fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    filter_fn: Callable[[torch.nn.Module, str], bool],
    cur_fqn: str = "",
) -> torch.nn.Module:
    """
    For each `child` in `model`, replaces it with `replacement_fn(child, fqn)`
    if `filter_fn(child, fqn)` is `True`.
    """
    if filter_fn(model, cur_fqn[:-1]):
        model = replacement_fn(model, cur_fqn[:-1])
        return model
    else:
        for name, child in model.named_children():
            new_child = _replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
            )
            if new_child is not child:
                setattr(model, name, new_child)
        return model


def _get_subclass_inserter(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], torch.nn.Module]:
    def insert_subclass(lin: torch.nn.Linear, fqn: str) -> torch.nn.Linear:
        # TODO: replace weights with gguf format tensor
        # The packed tensor should have size [numel / 32, 18]; see the
        # illustrative block-unpacking sketch after this function.
        fqn = fqn + ".weight"
        assert (
            fqn in weight_map
        ), f"Expect {fqn} to be in weight map but not found. All keys are {weight_map.keys()}"
        tensor = weight_map[fqn]
        logging.debug(
            f"{fqn}: {tensor.shape}, {tensor.data.shape}, {lin.weight.shape}"
        )
        packed = torch.from_numpy(tensor.data).reshape(-1, 18)
        scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
        lin.weight = torch.nn.Parameter(
            GGMLInt4LinearWeight(packed, scale, lin.weight.shape)
        )
        return lin

    return insert_subclass
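

# An illustrative, self-contained sketch of the Q4_0 block layout assumed
# above: each 18-byte block stores a little-endian fp16 scale in its first
# two bytes, followed by 16 bytes that each pack two 4-bit weights. This
# helper exists only to document the format; the actual unpacking used at
# runtime lives in `GGMLInt4LinearWeight` and `to_float`.
def _demo_dequant_q4_0_block(block: torch.Tensor) -> torch.Tensor:
    """Dequantizes one 18-byte Q4_0 block (uint8) into 32 float values."""
    assert block.shape == (18,) and block.dtype == torch.uint8
    # Reinterpret the first two bytes as the fp16 scale.
    scale = block[:2].view(torch.float16).item()
    qs = block[2:]
    # Byte j stores element j in its low nibble and element j + 16 in its
    # high nibble; both nibbles are stored with an offset of 8.
    low = (qs & 0x0F).to(torch.int8) - 8
    high = (qs >> 4).to(torch.int8) - 8
    return torch.cat([low, high]).to(torch.float32) * scale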


def _get_filter_fn(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], bool]:
    def _is_linear(mod: torch.nn.Module, fqn: str) -> bool:
        return (
            isinstance(mod, torch.nn.Linear)
            and hasattr(mod, "weight")
            and fqn + ".weight" in weight_map
            and weight_map[fqn + ".weight"].tensor_type == GGMLQuantizationType.Q4_0
            and not isinstance(mod.weight, QuantizedLinearWeightBase)
        )

    return _is_linear


def change_linear_weights_to_q4_0_tensors(
    model: torch.nn.Module, gguf_weights: GGUFWeights
) -> None:
    """
    Converts the weight of every linear module stored as Q4_0 in the GGUF
    file to the `GGMLInt4LinearWeight` tensor subclass, which keeps the
    packed Q4_0 data and dequantizes on the fly, without modifying the
    linear modules themselves.
    """
    assert gguf_weights is not None, "Must provide gguf_weights"
    weight_map = {
        _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
        for tensor in gguf_weights.tensors
    }

    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(weight_map),
        _get_filter_fn(weight_map),
    )


def get_float_weights(
    pt_model: torch.nn.Module, gguf_weights: GGUFWeights
) -> Mapping[str, torch.Tensor]:
    """
    Returns a mapping from fqn to float weight tensor. Even though the model
    is quantized in Q4_0, some weights are still stored as float. The Q4_0
    token embedding is also dequantized to float here, since it is not a
    linear weight.
    Args:
        pt_model (torch.nn.Module): The model the weights will be loaded into.
        gguf_weights (GGUFWeights): The GGUF weights to extract the tensors from.
    """
    state_dict = {}
    for tensor in gguf_weights.tensors:
        model_key = _convert_gguf_tensor_name_to_llama_nn(tensor.name)
        if (
            tensor.tensor_type == GGMLQuantizationType.F32
            or tensor.tensor_type == GGMLQuantizationType.F16
        ):
            logging.debug(f"Loading float tensor {tensor.name}")
            # GGUF stores dimensions in the reverse order of torch.
            reversed_shape = tensor.shape[::-1]
            new_tensor = tensor.data.reshape(reversed_shape)
            state_dict[model_key] = torch.from_numpy(new_tensor)
        # token_embd.weight is quantized in Q4_0; dequantize it into float,
        # since the embedding is not a linear module.
        elif tensor.tensor_type == GGMLQuantizationType.Q4_0:
            if tensor.name == "token_embd.weight":
                logging.debug(f"Dequantizing tensor {tensor.name}")
                unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
                state_dict[model_key] = unpacked.reshape(
                    pt_model.params.vocab_size, pt_model.params.dim
                )

    # Fake-initialize the causal attention masks so the state dict matches
    # what llama_transformer.py expects.
    for layer_id in range(pt_model.params.n_layers):
        mask_name = f"layers.{layer_id}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask
    return state_dict


def load_gguf_q4_0(gguf_file: str) -> torch.nn.Module:
    assert os.path.isfile(gguf_file), f"Expect a valid gguf_file path, got {gguf_file}"

    logging.info(f"Loading GGUF file: {gguf_file}")
    gguf_model_args, gguf_weights = load_file(gguf_file)

    logging.info("Creating the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    logging.info("Loading float weights")
    state_dict = get_float_weights(pt_model, gguf_weights)
    pt_model.load_state_dict(state_dict, strict=False)

    logging.info("Changing linear weights to Q4_0 tensors")
    change_linear_weights_to_q4_0_tensors(pt_model, gguf_weights)

    pt_model = pt_model.to(dtype=torch.float16)

    return pt_model
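

# A minimal smoke-test sketch: load a Q4_0 GGUF checkpoint whose path is
# given on the command line (the path argument is a placeholder) and report
# how many linear layers had their weights swapped to the subclass.
if __name__ == "__main__":
    import sys

    model = load_gguf_q4_0(sys.argv[1])
    num_q4_0 = sum(
        isinstance(m, torch.nn.Linear) and isinstance(m.weight, GGMLInt4LinearWeight)
        for m in model.modules()
    )
    logging.info(f"Loaded model with {num_q4_0} Q4_0 linear layers")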