# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token
12
class TorchTuneLlamaRunner(LlamaRunner):
    """Runner for TorchTune-exported Llama models.

    Extends LlamaRunner for models whose forward() takes an explicit
    boolean causal mask and input-position tensor (TorchTune style)
    instead of deriving them internally. Both tensors are precomputed
    once in __init__ and sliced per step in generate().
    """

    def __init__(
        self,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        use_kv_cache: bool,
        vocab_size: int,
        device: str = "cpu",
    ):
        """
        Args:
            tokenizer_path: Path to the tokenizer model file.
            max_seq_len: Maximum total sequence length (prompt + generated).
            max_batch_size: Maximum batch size supported by the model.
            use_kv_cache: Whether the exported model uses a KV cache.
            vocab_size: Size of the tokenizer vocabulary.
            device: Torch device string to run on (default "cpu").
        """
        super().__init__(
            tokenizer_path,
            max_seq_len,
            max_batch_size,
            use_kv_cache,
            vocab_size,
            device,
        )

        # Full lower-triangular attention mask and position ids, built once;
        # generate() slices views out of these each step instead of
        # re-allocating per token.
        self.causal_mask = torch.tril(
            torch.ones(
                size=(max_seq_len, max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(max_seq_len)

    def _is_stop_token(self, token: int) -> bool:
        """Return True if `token` should terminate generation.

        A token stops generation when it equals the tokenizer's EOS id, or
        when the tokenizer defines a `stop_tokens` collection that contains it.
        """
        return token == self.tokenizer.eos_id or (
            hasattr(self.tokenizer, "stop_tokens")
            and token in self.tokenizer.stop_tokens
        )

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[int],
        max_seq_len: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        echo: bool = False,
    ) -> List[int]:
        """Autoregressively generate tokens from `prompt_tokens`.

        Streams each accepted token to stdout as it is decoded.

        Args:
            prompt_tokens: Encoded prompt token ids.
            max_seq_len: Hard cap on total tokens (prompt + generated);
                generation stops once `len(tokens)` reaches it.
            temperature: Sampling temperature forwarded to next_token().
            top_p: Nucleus-sampling threshold forwarded to next_token().
            echo: If True, include the prompt tokens in the returned list.

        Returns:
            The generated token ids, prefixed with the prompt when echo=True.
        """
        # Prefill: run the whole prompt through the model in one pass.
        seq_len = len(prompt_tokens)
        if self.use_kv_cache:
            logits = self.forward(
                tokens=torch.tensor(
                    [prompt_tokens], dtype=torch.long, device=self.device
                ),
                # Positions 0..seq_len-1 with the matching causal-mask rows.
                input_pos=self.input_pos[None, :seq_len],
                mask=self.causal_mask[None, :seq_len],
            )
        else:
            logits = self.forward(
                tokens=torch.tensor(
                    [prompt_tokens], dtype=torch.long, device=self.device
                ),
            )

        # Only the last position's logits are needed to sample the next token.
        current_token = next_token(logits[:, -1, :], temperature, top_p)
        tokens = prompt_tokens + [current_token]

        # BUGFIX: the token produced by prefill was previously never checked
        # against EOS/stop tokens, so generation could print a stop token and
        # keep decoding past it. Check it before entering the decode loop.
        if not self._is_stop_token(current_token):
            print(
                f"{self.tokenizer.decode_token(current_token)}",
                end="",
                flush=True,
            )

            # Decode loop: feed one token at a time; `seq_len` is the
            # position at which `current_token` is fed to the model.
            while len(tokens) < max_seq_len:
                if self.use_kv_cache:
                    logits = self.forward(
                        tokens=torch.tensor(
                            [[current_token]], dtype=torch.long, device=self.device
                        ),
                        input_pos=self.input_pos[None, seq_len, None],
                        mask=self.causal_mask[None, seq_len, None, :],
                    )
                else:
                    # Without a KV cache, re-run the full sequence each step;
                    # no mask/position tensors are needed on this path.
                    logits = self.forward(
                        tokens=torch.tensor(
                            [tokens], dtype=torch.long, device=self.device
                        ),
                    )

                current_token = next_token(logits[:, -1, :], temperature, top_p)
                tokens.append(current_token)

                # Stop tokens are kept in `tokens` but never printed.
                if self._is_stop_token(current_token):
                    break

                print(
                    f"{self.tokenizer.decode_token(current_token)}",
                    end="",
                    flush=True,
                )
                seq_len += 1

        return tokens if echo else tokens[len(prompt_tokens) :]
102