# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Load a Llama model from a GGUF file quantized in the Q4_0 format.
# Float weights are loaded directly from the GGUF file.
# Q4_0 weights are loaded into a tensor subclass (GGMLInt4LinearWeight)
# by replacing each linear module's weight with the subclass.
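#
# Background on the Q4_0 layout: weights are grouped into blocks of 32, and
# each block is stored as 18 bytes: a float16 scale followed by 16 bytes
# holding the 32 4-bit quantized values. A packed weight therefore has shape
# [numel / 32, 18], which is the layout assumed throughout this file.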

import logging
import os
from typing import Callable, Dict, Mapping

import torch
from executorch.examples.models.llama.experimental.subclass import (
    _unpack_two_uint8,
    GGMLInt4LinearWeight,
    to_float,
)
from executorch.extension.gguf_util.converters.llama_converter import (
    _convert_gguf_tensor_name_to_llama_nn,
    _create_pt_model,
)
from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
from gguf import ReaderTensor
from gguf.constants import GGMLQuantizationType
from torchao.quantization.subclass import QuantizedLinearWeightBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _replace_with_custom_fn_if_matches_filter(
    model: torch.nn.Module,
    replacement_fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    filter_fn: Callable[[torch.nn.Module, str], bool],
    cur_fqn: str = "",
) -> torch.nn.Module:
    """
    Recursively replaces each module in `model` with `replacement_fn(module, fqn)`
    wherever `filter_fn(module, fqn)` returns `True`. For example, a linear at
    `model.layers[0].attention.wq` is visited with the fqn "layers.0.attention.wq".
    """
    if filter_fn(model, cur_fqn[:-1]):
        model = replacement_fn(model, cur_fqn[:-1])
        return model
    else:
        for name, child in model.named_children():
            new_child = _replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
            )
            if new_child is not child:
                setattr(model, name, new_child)
        return model


def _get_subclass_inserter(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], torch.nn.Module]:
    def insert_subclass(lin: torch.nn.Module, fqn: str) -> torch.nn.Module:
        # Replace the linear weight with the GGUF-format packed tensor.
        # Each Q4_0 block packs 32 weights into 18 bytes, so the packed
        # tensor has shape [numel / 32, 18].
        fqn = fqn + ".weight"
        assert (
            fqn in weight_map
        ), f"Expected {fqn} in the weight map but it was not found. Available keys: {weight_map.keys()}"
        tensor = weight_map[fqn]
        logging.debug(
            f"{fqn}: {tensor.shape}, {tensor.data.shape}, {lin.weight.shape}"
        )
        packed = torch.from_numpy(tensor.data).reshape(-1, 18)
        # The first two bytes of each block hold the float16 scale.
        scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
        lin.weight = torch.nn.Parameter(
            GGMLInt4LinearWeight(packed, scale, lin.weight.shape)
        )
        return lin

    return insert_subclass


def _get_filter_fn(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], bool]:
    def _is_linear(mod: torch.nn.Module, fqn: str) -> bool:
        return (
            isinstance(mod, torch.nn.Linear)
            and hasattr(mod, "weight")
            and (fqn + ".weight") in weight_map
            and weight_map[fqn + ".weight"].tensor_type == GGMLQuantizationType.Q4_0
            and not isinstance(mod.weight, QuantizedLinearWeightBase)
        )

    return _is_linear


def change_linear_weights_to_q4_0_tensors(
    model: torch.nn.Module, gguf_weights: GGUFWeights
) -> None:
    """
    Converts all Q4_0 linear weight tensors to the `GGMLInt4LinearWeight`
    tensor subclass, applying the same form of quantization as
    `apply_dynamic_quant` without modifying the linear modules themselves.
    """
    assert gguf_weights is not None, "Must provide gguf_weights"
    weight_map = {
        _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
        for tensor in gguf_weights.tensors
    }

    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(weight_map),
        _get_filter_fn(weight_map),
    )


def get_float_weights(
    pt_model: torch.nn.Module, gguf_weights: GGUFWeights
) -> Mapping[str, torch.Tensor]:
    """
    Returns a mapping from fqn to the float weight tensor. Even though the
    model is quantized in Q4_0, these weights are still stored as float.

    Args:
        pt_model (torch.nn.Module): The model to load the weights into.
        gguf_weights (GGUFWeights): The GGUF weights to extract tensors from.
    """
    state_dict = {}
    for tensor in gguf_weights.tensors:
        model_key = _convert_gguf_tensor_name_to_llama_nn(tensor.name)
        if tensor.tensor_type in (
            GGMLQuantizationType.F32,
            GGMLQuantizationType.F16,
        ):
            logging.debug(f"Loading float tensor {tensor.name}")
            # GGUF stores shapes in reverse order relative to PyTorch.
            reversed_shape = tensor.shape[::-1]
            new_tensor = tensor.data.reshape(reversed_shape)
            state_dict[model_key] = torch.from_numpy(new_tensor)
        # token_embd.weight is quantized in Q4_0; dequantize it into float.
        elif tensor.tensor_type == GGMLQuantizationType.Q4_0:
            if tensor.name == "token_embd.weight":
                logging.debug(f"Dequantizing {tensor.name}")
                unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
                state_dict[model_key] = unpacked.reshape(
                    pt_model.params.vocab_size, pt_model.params.dim
                )

    # Fake-initialize the causal attention mask to match llama_transformer.py:
    # -inf above the diagonal so each position attends only to itself and
    # earlier positions.
    for layer_id in range(pt_model.params.n_layers):
        mask_name = f"layers.{layer_id}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask
    return state_dict


def load_gguf_q4_0(gguf_file: str) -> torch.nn.Module:
    assert os.path.isfile(gguf_file), f"Expected a valid gguf_file path, got {gguf_file}"

    logging.info(f"Loading GGUF file: {gguf_file}")
    gguf_model_args, gguf_weights = load_file(gguf_file)

    logging.info("Creating the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    logging.info("Loading float weights")
    state_dict = get_float_weights(pt_model, gguf_weights)
    pt_model.load_state_dict(state_dict, strict=False)

    logging.info("Changing linear weights to Q4_0 tensors")
    change_linear_weights_to_q4_0_tensors(pt_model, gguf_weights)

    pt_model = pt_model.to(dtype=torch.float16)

    return pt_model
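

# Minimal usage sketch: load a Q4_0-quantized Llama GGUF file whose path is
# given on the command line, and log the resulting module tree. A forward
# pass is deliberately omitted here, since the transformer's exact forward
# signature is defined in llama_transformer.py.
if __name__ == "__main__":
    import sys

    model = load_gguf_q4_0(sys.argv[1])
    logging.info(f"Loaded model:\n{model}")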