# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester
from executorch.examples.models.llama.model import Llama2Model


class TestLlama2ETExample(unittest.TestCase):
    def test_f32(self):
        self._test()

    def test_f16(self):
        self._test(torch.float16)

    # TODO - dynamic shape

    def _test(self, dtype: torch.dtype = torch.float):
        assert dtype in [
            torch.float,
            torch.float16,
        ], f"Only fp32 and fp16 are supported, but got dtype: {dtype}"

        llama2 = Llama2Model()
        model = llama2.get_eager_model().to(dtype)

        # Only convert fp32 inputs to dtype; leave non-float inputs
        # (e.g. integer token ids) untouched.
        example_inputs = tuple(
            tensor.to(dtype) if tensor.dtype == torch.float32 else tensor
            for tensor in llama2.get_example_inputs()
        )

        # Export the model, lower it to the XNNPACK delegate, serialize the
        # ExecuTorch program, and compare runtime outputs against the eager
        # model within the given absolute tolerance.
        (
            Tester(model, example_inputs)
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(atol=5e-2, inputs=example_inputs)
        )