# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser as _build_args_parser,
)
from executorch.examples.models.llama3_2_vision.runner.generation import (
    TorchTuneLlamaRunner,
)


class ExportedLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs a torch-exported .pt2 Llama model.
    """

    def __init__(self, args):
        # Model hyperparameters (e.g. vocab_size) are read from the params JSON.
        with open(args.params, "r") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        print(f"Loading model from {args.pt2}")
        # Load the serialized ExportedProgram and unwrap its callable module.
        self.model = torch.export.load(args.pt2).module()
        print("Model loaded")

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # With a KV cache the exported model also takes the current position
        # and attention mask; without one it consumes the full token sequence.
        if self.use_kv_cache:
            return self.model(tokens, input_pos=input_pos, mask=mask)
        else:
            return self.model(tokens)


def build_args_parser() -> argparse.ArgumentParser:
    # Extend the shared Llama export parser, which already provides
    # arguments such as --params, --tokenizer_path, --max_seq_length,
    # and --use_kv_cache.
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--pt2",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    runner = ExportedLlamaRunner(args)
    result = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(
        "Response:\n{response}\nTokens:\n{tokens}".format(
            response=result["generation"], tokens=result["tokens"]
        )
    )


if __name__ == "__main__":
    main()  # pragma: no cover
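
# Example invocation, as a sketch only: the script name and the
# .pt2/params/tokenizer file names below are placeholders, and the exact
# spelling of the shared flags (--params, --tokenizer_path, --max_seq_length,
# --use_kv_cache) is defined by _build_args_parser in export_llama_lib.
#
#   python exported.py \
#       --pt2 llama3_2_vision.pt2 \
#       --params params.json \
#       --tokenizer_path tokenizer.model \
#       --max_seq_length 128 \
#       --use_kv_cache \
#       --prompt "Hello" \
#       --temperature 0.6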