# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
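#
# Example invocation (a sketch; the script path and file names below are placeholders):
#
#   python path/to/this_script.py \
#       --gguf_file /path/to/model-q4_0.gguf \
#       --tokenizer_path /path/to/tokenizer.model \
#       --prompt "Hello, my name is"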
import argparse

from typing import Optional, Tuple

import torch

from executorch.examples.models.llama.experimental.load_gguf_q4_0 import load_gguf_q4_0
from sentencepiece import SentencePieceProcessor


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
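    # Exponential-race trick: argmax(probs / Exp(1) noise) picks index i with
    # probability proportional to probs[i], so sampling happens on-device without
    # the synchronization that torch.multinomial would incur.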
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
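    """Convert logits to probabilities, with temperature scaling and optional top-k.

    Temperature is clamped away from zero; when top_k is set, every logit below
    the k-th largest value is masked to -inf before the softmax.
    """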
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
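    """Sample the next token id from the logits at the last sequence position.

    Returns the sampled id together with the full probability distribution.
    """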
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def encode_tokens(tokenizer, string, bos=True, device="cpu"):
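    """Tokenize a string, optionally prepending the BOS id, and return an int tensor."""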
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)


def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
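    """Run one forward pass over the current token sequence and sample the next token."""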
    logits = model(x)
    return sample(logits, **sampling_kwargs)


def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
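    """Run the whole prompt through the model and return the first sampled token."""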
    return decode_one_token(model, x, **sampling_kwargs)[0]


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
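    """Autoregressively sample num_new_tokens tokens, invoking callback on each one.

    The gpt-fast KV-cache positions are commented out in this adaptation, so the
    full running sequence is fed back through the model on every step.
    """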
    print(f"cur_token: {cur_token}")
    new_tokens, new_probs = [], []
    for _ in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token.view(1, -1), **sampling_kwargs
            )
            new_tokens.append(next_token.clone())
            # print(next_token)
            callback(next_token)
            new_probs.append(next_prob.clone())
            cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)
            # print(cur_token)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # figure out how many tokens the finished sequence will hold
    T = prompt.size(0)
    T_new = T + max_new_tokens
    # if interactive:
    #     max_seq_length = 350
    # else:
    #     max_seq_length = min(T_new, model.params.max_seq_len)

    device, dtype = prompt.device, prompt.dtype

    # with torch.device(device):
    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    # input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
    seq[T] = next_token
    callback(next_token)

    cur_tokens = torch.cat((prompt, next_token), dim=0)
    # input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model,
        cur_tokens.view(1, -1),
        # input_pos,
        max_new_tokens - 1,
        callback=callback,
        **sampling_kwargs,
    )
    seq[T + 1 :] = torch.cat(generated_tokens)

    return seq


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gguf_file",
        type=str,
        required=True,
        help="The GGUF file to load.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="The tokenizer.model path.",
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
    )

    args = parser.parse_args()

    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
    encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")

    pt_model = load_gguf_q4_0(args.gguf_file)

    max_new_tokens = 100
    buffer = [tokenizer.decode(encoded.tolist())]
    period_id = tokenizer.encode(".")[0]
    done_generating = False

    def callback(x):
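        # Decode with a leading period so SentencePiece keeps the new token's
        # leading space, then strip that period off; flush the buffer every few
        # tokens (or once EOS is seen) so output streams as it is generated.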
        nonlocal done_generating
        if done_generating:
            return
        buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
        if x.item() == tokenizer.eos_id():
            done_generating = True
        if len(buffer) == 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    generate(
        pt_model,
        encoded,
        max_new_tokens,
        interactive=False,
        callback=callback,
        temperature=1.0,
        top_k=10,
    )


if __name__ == "__main__":
    main()