• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import json
9from typing import Optional
10
11import torch
12
13from executorch.examples.models.llama.export_llama_lib import (
14    build_args_parser as _build_args_parser,
15)
16from executorch.examples.models.llama3_2_vision.runner.generation import (
17    TorchTuneLlamaRunner,
18)
19
20
class ExportedLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs a torch-exported .pt2 Llama.

    The base runner is configured from the parsed command-line arguments and
    the model is loaded from the file named by ``args.pt2``.
    """

    def __init__(self, args):
        """
        Configure the base runner and load the exported model.

        Args:
            args: Parsed argument namespace. The attributes read here are
                ``params`` (path to a JSON file containing a ``"vocab_size"``
                entry), ``tokenizer_path``, ``max_seq_length``,
                ``use_kv_cache`` and ``pt2`` (path given to
                ``torch.export.load``).

        Raises:
            KeyError: If the params JSON has no ``"vocab_size"`` key.
        """
        # Decode explicitly as UTF-8 so that reading the JSON params file does
        # not depend on the platform's locale encoding.
        with open(args.params, "r", encoding="utf-8") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        print(f"Loading model from {args.pt2}")
        # .module() turns the loaded exported program into a callable module.
        self.model = torch.export.load(args.pt2).module()
        print("Model loaded")

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Run the loaded model on ``tokens``.

        When ``self.use_kv_cache`` is truthy, ``input_pos`` and ``mask`` are
        forwarded to the model as keyword arguments; otherwise only ``tokens``
        is passed and the other two arguments are ignored.

        Returns:
            Whatever the loaded model returns for the given inputs.
        """
        # NOTE(review): use_kv_cache is presumably set by the base class
        # __init__ from the constructor argument -- confirm.
        if self.use_kv_cache:
            return self.model(tokens, input_pos=input_pos, mask=mask)
        else:
            return self.model(tokens)
51
52
def build_args_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the exported-Llama runner.

    Starts from the parser returned by the Llama export library's
    ``build_args_parser`` and adds the runner-specific options ``--prompt``,
    ``--pt2`` and ``--temperature``.

    Returns:
        The extended ``argparse.ArgumentParser``.
    """
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
        help="Prompt text passed to text completion (default: %(default)s).",
    )

    parser.add_argument(
        "--pt2",
        type=str,
        required=True,
        help="Path to the torch-exported .pt2 model file to load.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        # argparse does not apply `type` to non-string defaults, so use a
        # float literal to keep the attribute a float in every case.
        default=0.0,
        help="Temperature passed to text completion (default: %(default)s).",
    )

    return parser
75
76
def main() -> None:
    """
    Parse the command line, run one text completion and print the result.

    The completion result is indexed with the ``"generation"`` and
    ``"tokens"`` keys, which are printed as the response and token list.
    """
    args = build_args_parser().parse_args()

    completion = ExportedLlamaRunner(args).text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )

    template = "Response: \n{response}\n Tokens:\n {tokens}"
    message = template.format(
        response=completion["generation"],
        tokens=completion["tokens"],
    )
    print(message)
91
92
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()  # pragma: no cover
95