• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Owner(s): ["oncall: jit"]
# flake8: noqa
3
4import sys
5import unittest
6from enum import Enum
7from typing import List, Optional
8
9import torch
10from jit.myfunction_a import my_function_a
11from torch.testing._internal.jit_utils import JitTestCase
12
13
class TestDecorator(JitTestCase):
    """Tests TorchScript compilation of a function that carries a decorator."""

    def test_decorator(self):
        """Script ``my_function_a`` and check it matches eager execution.

        ``my_function_a`` comes from ``jit.myfunction_a``. Judging by the
        error text quoted below, it is declared as
        ``@my_decorator`` / ``def my_function_a(x: float) -> float`` —
        confirm against that module.
        """
        # JitTestCase.checkScript() cannot be used here. It apparently
        # compiles the function from its source text, and the leading
        # decorator is rejected with:
        #   RuntimeError: expected def but found '@' here:
        #   @my_decorator
        #   ~ <--- HERE
        #   def my_function_a(x: float) -> float:
        # Instead, compile with torch.jit.script() directly and compare the
        # scripted output with the eager Python output.
        fn = my_function_a
        fx = torch.jit.script(fn)
        self.assertEqual(fn(1.0), fx(1.0))
27