• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10import torch
11
12from executorch.exir.dialects.edge.op.api import to_variant
13from torchgen.model import SchemaKind
14
15aten = torch.ops.aten
16
17OPS_TO_FUNCTIONAL = {
18    aten.add.out: aten.add.Tensor,
19    aten._native_batch_norm_legit_no_training.out: aten._native_batch_norm_legit_no_training.default,
20    aten.addmm.out: aten.addmm.default,
21    aten.view_copy.out: aten.view_copy.default,
22}
23
24
class TestApi(unittest.TestCase):
    """Test api.py

    Covers ``to_variant`` for two situations: the op already has the
    requested schema kind, and the op must be mapped to its counterpart as
    listed in ``OPS_TO_FUNCTIONAL``.
    """

    def test_to_out_variant_returns_self_when_given_out_variant(self) -> None:
        # Asking for the out variant of an op that is already an out overload
        # is expected to hand back an equal op.
        op = aten.add.out
        self.assertEqual(to_variant(op, SchemaKind.out), op)

    def test_to_functional_variant_returns_self_when_given_functional(self) -> None:
        # Same check in the other direction: functional in, functional out.
        op = aten.leaky_relu.default
        self.assertEqual(to_variant(op, SchemaKind.functional), op)

    def test_to_functional_variant_returns_correct_op(
        self,
    ) -> None:
        for out_op, functional_op in OPS_TO_FUNCTIONAL.items():
            # subTest keeps checking the remaining ops after a mismatch and
            # labels each failure with the op that produced it.
            with self.subTest(op=out_op):
                self.assertEqual(
                    to_variant(out_op, SchemaKind.functional), functional_op
                )

    def test_to_out_variant_returns_correct_op(
        self,
    ) -> None:
        # Walk the table value -> key. The values in OPS_TO_FUNCTIONAL are
        # unique, so this checks the same pairs an explicit inverse map would.
        for out_op, functional_op in OPS_TO_FUNCTIONAL.items():
            with self.subTest(op=functional_op):
                self.assertEqual(to_variant(functional_op, SchemaKind.out), out_op)
52