#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.exir.dialects.edge.op.api import to_variant
from torchgen.model import SchemaKind

aten = torch.ops.aten

# Maps an out-variant ATen overload to the functional overload that
# `to_variant(..., SchemaKind.functional)` is expected to return for it.
# The values are unique, so the mapping is also inverted below to check the
# opposite direction (functional -> out).
OPS_TO_FUNCTIONAL = {
    aten.add.out: aten.add.Tensor,
    aten._native_batch_norm_legit_no_training.out: aten._native_batch_norm_legit_no_training.default,
    aten.addmm.out: aten.addmm.default,
    aten.view_copy.out: aten.view_copy.default,
}


class TestApi(unittest.TestCase):
    """Test api.py"""

    def test_to_out_variant_returns_self_when_given_out_variant(self) -> None:
        """Requesting the out variant of an out-variant op yields the op itself."""
        op = aten.add.out
        variant = to_variant(op, SchemaKind.out)
        self.assertEqual(variant, op)

    def test_to_functional_variant_returns_self_when_given_functional(self) -> None:
        """Requesting the functional variant of a functional op yields the op itself."""
        op = aten.leaky_relu.default
        variant = to_variant(op, SchemaKind.functional)
        self.assertEqual(variant, op)

    def test_to_functional_variant_returns_correct_op(
        self,
    ) -> None:
        """Each out-variant op in OPS_TO_FUNCTIONAL maps to its expected functional op."""
        for out_op, functional_op in OPS_TO_FUNCTIONAL.items():
            # subTest reports every mismatching op instead of stopping at the first.
            with self.subTest(op=out_op):
                variant = to_variant(out_op, SchemaKind.functional)
                self.assertEqual(variant, functional_op)

    def test_to_out_variant_returns_correct_op(
        self,
    ) -> None:
        """Each functional op in OPS_TO_FUNCTIONAL maps back to its out-variant op."""
        # Invert the table: functional op -> expected out variant.
        inv_map = {v: k for k, v in OPS_TO_FUNCTIONAL.items()}
        for functional_op, out_op in inv_map.items():
            with self.subTest(op=functional_op):
                variant = to_variant(functional_op, SchemaKind.out)
                self.assertEqual(variant, out_op)


if __name__ == "__main__":
    unittest.main()