# Owner(s): ["module: codegen"]

import textwrap
import unittest
from typing import cast

import expecttest
import yaml

import torchgen.dest as dest
import torchgen.gen as gen
from torchgen.gen import LineLoader, parse_native_yaml_struct
from torchgen.model import (
    Annotation,
    CustomClassType,
    DispatchKey,
    NativeFunctionsGroup,
    Type,
)


def _strip_error_context(e: BaseException) -> str:
    """Return the message of ``e`` with any trailing context removed.

    Everything from the first occurrence of ``" in "`` onward is treated as
    context appended to the original message and is discarded. If the
    separator does not occur, the whole message is returned unchanged.

    ``str.partition`` is used instead of unpacking ``split(" in ", 2)`` into
    two names: that unpack raised ``ValueError`` whenever the separator
    occurred zero times or more than once, hiding the real assertion message.
    """
    return str(e).partition(" in ")[0]


class TestCodegenModel(expecttest.TestCase):
    """Tests for torchgen's parsing and validation of native-function YAML,
    with a focus on ``ufunc_inner_loop`` declarations."""

    def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that parsing ``yaml_str`` raises an ``AssertionError``.

        The error message, stripped of its context and re-wrapped with
        ``textwrap.wrap``, is compared to ``expect`` via
        ``assertExpectedInline``. The test fails if parsing does not raise.
        """
        es = yaml.load(yaml_str, Loader=LineLoader)
        try:
            parse_native_yaml_struct(es, set())
        except AssertionError as e:
            msg = _strip_error_context(e)
            # skip=1: presumably makes expecttest locate the inline literal in
            # the calling test method rather than in this helper
            self.assertExpectedInline("\n".join(textwrap.wrap(msg)), expect, skip=1)
        else:
            self.fail(msg="Did not raise when expected to")

    def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None:
        """Assert that ufunc codegen for the single group in ``yaml_str`` raises.

        ``yaml_str`` must parse to exactly one ``NativeFunctionsGroup`` whose
        out variant has a ``ufunc_inner_loop``. Meta and native (CPU and CUDA)
        declaration codegen run first and must not raise. Then one of
        ``compute_ufunc_cpu``, ``compute_ufunc_cpu_kernel`` or
        ``compute_ufunc_cuda`` must raise an ``AssertionError``. Its message,
        with context stripped and re-wrapped, must match ``expect``.
        """
        # parse a single structured group out of the yaml to g
        es = yaml.load(yaml_str, Loader=LineLoader)
        parsed_yaml = parse_native_yaml_struct(es, set())
        native_functions, backend_indices = (
            parsed_yaml.native_functions,
            parsed_yaml.backend_indices,
        )
        grouped_native_functions = gen.get_grouped_native_functions(native_functions)
        assert len(grouped_native_functions) == 1
        g = grouped_native_functions[0]
        assert isinstance(g, NativeFunctionsGroup)
        assert g.out.ufunc_inner_loop

        # this is not ufunc codegen per se, but it does some basic sanity tests for
        # ufunc generation
        gen.compute_meta_function_declaration(g)
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CPU])
        dest.compute_native_function_declaration(g, backend_indices[DispatchKey.CUDA])

        try:
            # the real kahuna
            dest.compute_ufunc_cpu(g)
            dest.compute_ufunc_cpu_kernel(g)
            dest.compute_ufunc_cuda(g)
        except AssertionError as e:
            msg = _strip_error_context(e)
            self.assertExpectedInline("\n".join(textwrap.wrap(msg)), expect, skip=1)
        else:
            self.fail(msg="Did not raise when expected to")

    # NB: continuation lines in these fixtures are indented by exactly two
    # spaces. Once a fixture is interpolated after "- " in a YAML list, its
    # keys line up with "func:". Keep that indentation when adding fixtures.
    binop_out = (
        "func: binop.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
    )
    ti_binop_out = f"""{binop_out}
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_binop = """func: binop(Tensor self, Tensor other) -> Tensor
  structured_delegate: binop.out
"""

    ti_unop_out = """func: unop.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase"""
    ti_unop = """func: unop(Tensor self) -> Tensor
  structured_delegate: unop.out
"""

    def test_nonstructured_ufunc(self) -> None:
        yaml_str = f"""\
- {self.binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc must be structured""",
        )

    def test_overlapping_ufunc_and_dispatch(self) -> None:
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
  dispatch:
    CPU: binop_cpu
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
ufunc should not have explicit dispatch entry for CPU""",
        )

    # See https://github.com/pytorch/pytorch/pull/65851#discussion_r810238456
    @unittest.expectedFailure
    def test_scalaronly_shadowed(self) -> None:
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop (Bool)
"""
        self.assertParseErrorInline(
            yaml_str,
            """\
""",
        )

    def test_conflicting_ufunc(self) -> None:
        yaml_str = f"""\
- {self.ti_binop_out}
  ufunc_inner_loop:
    Generic: binop (Bool)
    ScalarOnly: binop_scalar (Bool)
- {self.ti_binop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
ScalarOnly and Generic must have same ufunc name""",
        )

    def test_invalid_cudafunctoronself_for_binary_op(self) -> None:
        # Uses the *unary* fixture: CUDAFunctorOnSelf must be rejected on it.
        yaml_str = f"""\
- {self.ti_unop_out}
  ufunc_inner_loop:
    Generic: unop (All)
    CUDAFunctorOnSelf: unop_self_cuda (All)
- {self.ti_unop}
"""
        self.assertUfuncErrorInline(
            yaml_str,
            """\
cannot use CUDAFunctorOnSelf on non-binary function""",
        )

    def test_parse_custom_class_type(self) -> None:
        custom_class_name = "namespace_foo.class_bar"
        custom_class_name_with_prefix = f"__torch__.torch.classes.{custom_class_name}"
        custom_class_type = cast(
            CustomClassType, Type.parse(custom_class_name_with_prefix)
        )
        self.assertIsInstance(custom_class_type, CustomClassType)
        self.assertEqual(custom_class_name, custom_class_type.class_name)
        self.assertEqual(custom_class_name_with_prefix, str(custom_class_type))


class TestAnnotation(expecttest.TestCase):
    """Tests for ``Annotation.parse`` on alias annotations such as ``a``,
    ``a!``, ``a|b`` and ``a! -> a|b``."""

    def test_single_alias_no_write(self) -> None:
        a = Annotation.parse("a")
        self.assertEqual(a.alias_set, ("a",))
        self.assertFalse(a.is_write)
        self.assertEqual(a.alias_set_after, ())

    def test_single_alias_is_write(self) -> None:
        a = Annotation.parse("a!")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ())

    def test_single_alias_is_write_to_wildcard(self) -> None:
        a = Annotation.parse("a! -> *")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ("*",))

    def test_alias_set(self) -> None:
        a = Annotation.parse("a|b")
        self.assertEqual(a.alias_set, ("a", "b"))

    def test_alias_set_is_write_raises_exception(self) -> None:
        with self.assertRaisesRegex(
            AssertionError, r"alias set larger than 1 is not mutable"
        ):
            Annotation.parse("a|b!")

    def test_single_alias_is_write_to_alias_set(self) -> None:
        a = Annotation.parse("a! -> a|b")
        self.assertEqual(a.alias_set, ("a",))
        self.assertTrue(a.is_write)
        self.assertEqual(a.alias_set_after, ("a", "b"))

    def test_before_and_after_alias_set_larger_than_1_raises_exception(self) -> None:
        with self.assertRaisesRegex(
            AssertionError,
            r"before alias set and after alias set cannot be larger than 1 at the same time",
        ):
            Annotation.parse("a|b -> c|d")


if __name__ == "__main__":
    unittest.main()