• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
11from executorch.backends.xnnpack.test.tester import RunPasses, Tester
12
13
class TestConvertToLinear(unittest.TestCase):
    """Tests for ``ConvertToLinearPass`` in the XNNPACK backend.

    Each case builds a ``torch.nn.Linear`` module and pushes it through the
    ``Tester`` pipeline. The pipeline exports the module, converts it to the
    edge dialect, runs ``ConvertToLinearPass``, checks the count of the edge
    ``aten.linear.default`` op, and runs the method to compare outputs.
    """

    # Pass stage shared by all test cases; constructed once when the class is
    # defined.
    PassStage = RunPasses([ConvertToLinearPass])

    def test_fp32_convert_to_linear(self):
        """Run fp32 ``nn.Linear`` modules, with and without bias, through the pass."""
        # One tuple per case, so a case's parameters cannot drift apart the
        # way separate parallel lists can:
        #   (input rows, in_features, out_features, use_bias)
        cases = [
            (1, 4, 4, True),
            (4, 37, 17, True),
            (4, 17, 37, False),
        ]

        for rows, in_features, out_features, use_bias in cases:
            # subTest reports the offending parameters on failure and keeps
            # running the remaining cases.
            with self.subTest(
                rows=rows,
                in_features=in_features,
                out_features=out_features,
                bias=use_bias,
            ):
                linear = torch.nn.Linear(in_features, out_features, bias=use_bias)
                # 2-D input of shape (rows, in_features).
                inputs = (torch.randn(rows, in_features),)

                (
                    Tester(linear, inputs)
                    .export()
                    .to_edge()
                    .run_passes(self.PassStage)
                    # Expects a count of 1 for the edge-dialect linear op after
                    # the pass (count semantics are defined by Tester.check_count).
                    .check_count(
                        {"executorch_exir_dialects_edge__ops_aten_linear_default": 1}
                    )
                    .run_method_and_compare_outputs()
                )
40