# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token


class TorchTuneLlamaRunner(LlamaRunner):
    """Runner for torchtune-flavored Llama models.

    Precomputes a full causal attention mask and absolute position ids once;
    `generate` slices them per step and passes them to `forward` whenever the
    KV cache is in use.
    """

    def __init__(
        self,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        use_kv_cache: bool,
        vocab_size: int,
        device: str = "cpu",
    ):
        super().__init__(
            tokenizer_path,
            max_seq_len,
            max_batch_size,
            use_kv_cache,
            vocab_size,
            device,
        )

        # Lower-triangular boolean mask: row i permits attention to
        # positions 0..i only.
        self.causal_mask = torch.tril(
            torch.ones(
                size=(max_seq_len, max_seq_len),
                dtype=torch.bool,
            )
        )
        # Absolute position ids 0..max_seq_len-1, sliced per step below.
        self.input_pos = torch.arange(max_seq_len)

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[int],
        max_seq_len: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        echo: bool = False,
    ) -> List[int]:
        """Generate tokens until EOS, a stop token, or max_seq_len is reached.

        Decoded tokens are streamed to stdout as they are sampled. Returns
        the generated token ids, prefixed with the prompt if `echo` is True.
        """
        # Prefill: run the full prompt through the model in one pass.
        seq_len = len(prompt_tokens)
        input_pos = self.input_pos[None, :seq_len]
        mask = self.causal_mask[None, :seq_len]
        if self.use_kv_cache:
            logits = self.forward(
                tokens=torch.tensor(
                    [prompt_tokens], dtype=torch.long, device=self.device
                ),
                input_pos=input_pos,
                mask=mask,
            )
        else:
            logits = self.forward(
                tokens=torch.tensor(
                    [prompt_tokens], dtype=torch.long, device=self.device
                ),
            )

        # Sample the first generated token from the last position's logits.
        current_token = next_token(logits[:, -1, :], temperature, top_p)
        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        tokens = prompt_tokens + [current_token]

        # Decode loop: with a KV cache, feed only the newest token; without
        # one, re-run the whole sequence each step.
        while len(tokens) < max_seq_len:
            # Mask row and position id for the current step only.
            mask = self.causal_mask[None, seq_len, None, :]
            input_pos = self.input_pos[None, seq_len, None]
            if self.use_kv_cache:
                logits = self.forward(
                    tokens=torch.tensor(
                        [[current_token]], dtype=torch.long, device=self.device
                    ),
                    input_pos=input_pos,
                    mask=mask,
                )
            else:
                logits = self.forward(
                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                )

            # Sample the next token from the last position's logits.
            current_token = next_token(logits[:, -1, :], temperature, top_p)
            tokens.append(current_token)

            # Stop on EOS or any model-specific stop token.
            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break

            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
            seq_len += 1

        return tokens if echo else tokens[len(prompt_tokens) :]
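

# --- Usage sketch (illustrative only, not part of this module) --------------
# `TorchTuneLlamaRunner` inherits an abstract `forward` from `LlamaRunner`,
# so it is driven through a concrete subclass that executes the exported
# model. A minimal sketch, assuming a hypothetical subclass `ExportedRunner`
# and a hypothetical local tokenizer path:
#
#     runner = ExportedRunner(
#         tokenizer_path="/path/to/tokenizer.model",  # hypothetical path
#         max_seq_len=2048,
#         max_batch_size=1,
#         use_kv_cache=True,
#         vocab_size=128256,
#     )
#     prompt_tokens = [...]  # token ids for the prompt (tokenizer-specific)
#     out = runner.generate(
#         prompt_tokens, max_seq_len=2048, temperature=0.8, top_p=0.9
#     )
#
# `generate` streams decoded tokens to stdout as a side effect and returns
# only the newly generated ids; pass `echo=True` to include the prompt.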