# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestConvertToLinear(unittest.TestCase):
    """Tests for ``ConvertToLinearPass`` applied to a single ``torch.nn.Linear``."""

    # Stage that runs only ConvertToLinearPass on the edge program; shared by
    # every case in this class.
    PassStage = RunPasses([ConvertToLinearPass])

    def test_fp32_convert_to_linear(self):
        """Check that an fp32 ``nn.Linear`` ends up as exactly one edge
        ``aten.linear`` op after ``ConvertToLinearPass``, and that the resulting
        program's outputs match the eager module.

        NOTE(review): presumably ``to_edge`` decomposes ``linear`` into other
        ops (e.g. a transpose plus addmm/mm) and this pass re-fuses them into
        ``linear`` — confirm against the pass implementation.
        """
        # Each case is (batch rows, in_features, out_features, use bias).
        cases = [
            (1, 4, 4, True),
            (4, 37, 17, True),
            (4, 17, 37, False),
        ]

        for batch, in_features, out_features, bias in cases:
            # subTest makes a failure report which configuration broke instead
            # of only the loop's first failing iteration.
            with self.subTest(
                batch=batch,
                in_features=in_features,
                out_features=out_features,
                bias=bias,
            ):
                linear = torch.nn.Linear(in_features, out_features, bias=bias)
                inputs = (torch.randn(batch, in_features),)

                (
                    Tester(linear, inputs)
                    .export()
                    .to_edge()
                    .run_passes(self.PassStage)
                    .check_count(
                        {"executorch_exir_dialects_edge__ops_aten_linear_default": 1}
                    )
                    .run_method_and_compare_outputs()
                )