# Owner(s): ["module: inductor"] import torch from torch._dynamo.utils import counters, optimus_scuba_log from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import IS_LINUX from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.triton_utils import requires_gpu def patch(f): f = torch._inductor.config.patch( pre_grad_fusion_options={ "normalization_pass": {}, "remove_split_with_size_one_pass": {}, "merge_getitem_cat_pass": {}, "merge_splits_pass": {}, "mutate_cat_pass": {}, "split_cat_pass": {}, "unbind_stack_pass": {}, }, post_grad_fusion_options={}, )(f) return f class TestSplitCatFxPasses(TestCase): @patch def test_split_normalization(self): def arg_only(x): return [torch.relu(s) for s in torch.split(x, 2, 1)] def arg_only_dim0(x): return [torch.relu(s) for s in torch.split(x, 2, 0)] def kwarg1(x): return [torch.relu(s) for s in torch.split(x, 2, dim=1)] def kwarg2(x): return [ torch.relu(s) for s in torch.split(x, split_size_or_sections=2, dim=1) ] def kwarg3(x): return [ torch.relu(s) for s in torch.split(tensor=x, split_size_or_sections=2, dim=-1) ] def list_replace(x): return [torch.relu(s) for s in torch.split(x, [16, 16], dim=1)] def multi_split(x): return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)] def unequal_split(x): return [torch.relu(s) for s in torch.split(x, 3, 1)] def arg_only_cm(x): return [torch.relu(s) for s in x.split(2, 1)] def kwarg1_cm(x): return [torch.relu(s) for s in x.split(2, dim=1)] def kwarg2_cm(x): return [torch.relu(s) for s in x.split(split_size=2, dim=1)] def multi_split_cm(x): return [s.split(2, 1) for s in x.split(2, 1)] def unequal_split_cm(x): return [torch.relu(s) for s in x.split(3, 1)] def cm_with_list(x): return [torch.relu(s) for s in x.split([16, 16], dim=-1)] args = [ torch.randn(2, 32), ] for fn, expected_split_norm_count in [ (arg_only, 1), (arg_only_dim0, 1), (kwarg1, 1), (kwarg2, 1), (kwarg3, 1), (list_replace, 0), (multi_split, 17), (unequal_split, 1), (arg_only_cm, 1), (kwarg1_cm, 1), (kwarg2_cm, 1), (multi_split_cm, 17), (unequal_split_cm, 1), (cm_with_list, 1), ]: expected = fn(*args) actual = torch.compile(fn)(*args) torch.testing.assert_close(actual, expected) self.assertEqual( counters["inductor"]["normalization_pass"], expected_split_norm_count, msg=f"for {fn}", ) if expected_split_norm_count > 0: self.assertIn("normalization_pass_pre_grad", optimus_scuba_log) counters.clear() @patch def test_consecutive_split_merge(self): def multi_split(x): return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)] def multi_split_2(x): return [torch.split(s, 1, 1) for s in torch.split(x, 2, 1)] def multi_split_2_neg_dim(x): return [torch.split(s, 1, 1) for s in torch.split(x, 2, -1)] def multi_split_with_sizes(x): return [torch.split(s, 2, 1) for s in torch.split(x, [16, 16], 1)] def multi_split_kwarg1(x): return [torch.split(s, 2, dim=1) for s in torch.split(x, 2, dim=1)] def multi_split_kwarg2(x): return [ torch.split(s, split_size_or_sections=2, dim=1) for s in torch.split(x, split_size_or_sections=2, dim=1) ] def unequal_multi_split(x): fs = torch.split(x, [10, 10, 12], dim=1) item0 = fs[0] item1 = fs[1] item2 = fs[2] final_items = [] final_items.extend(item0.split([4, 6], 1)) final_items.extend(item1.split([6, 4], 1)) final_items.extend(item2.split([4, 4, 4], 1)) return [torch.relu(s) for s in final_items] def unequal_multi_split_neg_index(x): fs = torch.split(x, 
        def unequal_multi_split_neg_index(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[-3]
            item1 = fs[-2]
            item2 = fs[-1]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1.split([6, 4], 1))
            final_items.extend(item2.split([4, 4, 4], 1))

            return [torch.relu(s) for s in final_items]

        # Shouldn't merge
        def diff_dims(x):
            return [torch.split(s, 2, dim=0) for s in torch.split(x, 2, dim=1)]

        def some_users_not_splits(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1.split([6, 4], 1))
            final_items.append(torch.sin(item2))

            return [torch.relu(s) for s in final_items]

        def split_with_cat(x):
            fs = torch.split(x, [4, 4, 24], dim=1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = [item0, item1]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        def duplicate_getitems(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1_1 = fs[1]
            item1_2 = fs[1]
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1_1.split([6, 4], 1))
            final_items.extend(item1_2)
            final_items.append(torch.sin(item2))

            return [torch.relu(s) for s in final_items]

        def duplicate_getitems_neg_index(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1_1 = fs[1]
            item1_2 = fs[-2]  # negative index
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1_1.split([6, 4], 1))
            final_items.extend(item1_2)
            final_items.append(torch.sin(item2))

            return [torch.relu(s) for s in final_items]

        def split_getitem_gap(x):
            fs = torch.split(x, [4, 4, 24], dim=1)
            item0 = fs[0]
            item2 = fs[2]

            final_items = [
                item0,
            ]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        def split_getitem_out_of_order(x):
            fs = torch.split(x, [4, 4, 4, 20], dim=1)
            item0 = fs[0]
            item2 = fs[2]
            item1 = fs[1]
            item3 = fs[3]

            final_items = [item0, item2, item1]
            final_items.extend(item3.split((4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        def split_partial_getitem_cat(x):
            fs = torch.split(x, [4, 4, 24], dim=1)
            item0 = fs[0]
            item2 = fs[2]

            final_items = [
                item0,
            ]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        args = [
            torch.randn(2, 32),
        ]
        for fn, expected_split_merged in [
            (multi_split, 0),
            (multi_split_2, 16),
            (multi_split_2_neg_dim, 16),
            (multi_split_with_sizes, 2),
            (multi_split_kwarg1, 0),
            (multi_split_kwarg2, 0),
            (unequal_multi_split, 3),
            (unequal_multi_split_neg_index, 3),
            (diff_dims, 0),
            (some_users_not_splits, 2),
            (split_with_cat, 1),
            (duplicate_getitems, 1),
            (duplicate_getitems_neg_index, 1),
            (split_getitem_gap, 1),
            (split_getitem_out_of_order, 1),
            (split_partial_getitem_cat, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["merge_splits_pass"],
                expected_split_merged,
            )
            if expected_split_merged > 0:
                self.assertIn("merge_splits_pass_pre_grad", optimus_scuba_log)
            counters.clear()

    @patch
    def test_split_cat_merge(self):
        def simple_split_cat(x):
            return torch.cat(torch.split(x, 4, dim=1), dim=1)

        def simple_split_cat_argspec1(x):
            return torch.cat(torch.split(x, 4, dim=1), 1)

        def simple_split_cat_argspec2(x):
            return torch.cat(tensors=torch.split(x, 4, dim=1), dim=1)

        def simple_split_cat_argspec3(x):
            return torch.cat(torch.split(x, 4, dim=1), -2)

        def simple_split_cat_argspec4(x):
            return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
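
        # torch.stack counterparts of the split+cat patterns above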
        def simple_split_stack(x):
            return torch.stack(torch.split(x, 4, dim=1), dim=1)

        def simple_split_stack_argspec1(x):
            return torch.stack(torch.split(x, 4, dim=1), 1)

        def simple_split_stack_argspec2(x):
            return torch.stack(tensors=torch.split(x, 4, dim=1), dim=1)

        def split_cat_addn_args(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 5, 32, 16)] + split_output + [torch.ones(2, 6, 32, 16)],
                dim=1,
            )

        def split_stack_addn_args(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)]
                + split_output
                + [torch.ones(2, 4, 32, 16), torch.ones(2, 4, 32, 16)],
                dim=1,
            )

        def split_cat_addn_args_dim2(x):
            split_output = list(torch.split(x, 4, dim=2))
            return torch.cat(
                [torch.ones(2, 32, 5, 16)] + split_output + [torch.ones(2, 32, 6, 16)],
                dim=2,
            )

        # split_dim=1, cat_dim=2
        def split_cat_dim_mismatch(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        def split_stack_dim_mismatch(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        # split_dim=1, cat_dim=3
        def split_cat_dim_mismatch2(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=3,
            )

        def split_stack_dim_mismatch2(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=3,
            )

        # split_dim=2, cat_dim=0
        def split_cat_dim_mismatch3(x):
            split_output = list(torch.split(x, 4, dim=2))
            return torch.cat(
                [torch.ones(2, 32, 4, 16)] + split_output + [torch.ones(2, 32, 4, 16)],
                dim=0,
            )

        def split_stack_dim_mismatch3(x):
            split_output = list(torch.split(x, 4, dim=2))
            return torch.stack(
                [torch.ones(2, 32, 4, 16)] + split_output + [torch.ones(2, 32, 4, 16)],
                dim=0,
            )

        def input_shuffling(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=1,
            )

        def input_shuffling_stack(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=1,
            )

        def input_shuffling_dim_mismatch(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        def input_shuffling_dim_mismatch_stack(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        def input_shuffling_multiple_output(x):
            split_output = list(torch.split(x, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )
            relu1 = torch.relu(split_output[6])
            return cat1, stack1, relu1
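
        # Same as input_shuffling_multiple_output, except one split output is
        # returned directly instead of flowing through a relu.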
        def input_shuffling_direct_output(x):
            split_output = list(torch.split(x, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )
            return cat1, stack1, split_output[6]

        def input_shuffling_multiple_output_same_ranges(x):
            split_output = list(torch.split(x, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            cat2 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )
            relu1 = torch.relu(split_output[6])
            return cat1, cat2, stack1, relu1

        def unequal_split_multiple_output(x):
            split_output = list(torch.split(x, [2, 4, 4, 4, 4, 4, 8, 2], dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )
            relu1 = torch.relu(split_output[6])
            return cat1, stack1, relu1

        def multi_split_cat(x1, x2):
            split_output_1 = list(torch.split(x1, 4, dim=1))
            split_output_2 = list(torch.split(x2, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output_1[1], split_output_1[2], split_output_1[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output_2[1], split_output_2[2], split_output_2[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output_1[4],
                    split_output_1[5],
                    torch.ones(2, 4, 32, 16),
                    split_output_2[4],
                    split_output_2[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )
            relu1 = torch.relu(split_output_1[6])
            relu2 = torch.relu(split_output_2[6])
            return cat1, stack1, relu1, relu2

        # TODO: Add more tests:
        # * Cases where replacement shouldn't happen
        default_args = [
            torch.randn(2, 32, 32, 16),
        ]
        multi_args = [
            torch.randn(2, 32, 32, 16),
            torch.randn(2, 32, 32, 16),
        ]
        for (
            fn,
            expected_split_added,
            expected_split_removed,
            expected_cat_added,
            expected_cat_removed,
            expected_sections_removed,
            args,
        ) in [
            (simple_split_cat, 0, 0, 0, 0, 0, default_args),
            (simple_split_cat_argspec1, 0, 0, 0, 0, 0, default_args),
            (simple_split_cat_argspec2, 0, 0, 0, 0, 0, default_args),
            (simple_split_cat_argspec3, 0, 1, 0, 1, 7, default_args),
            (simple_split_cat_argspec4, 0, 1, 0, 1, 7, default_args),
            (simple_split_stack, 0, 1, 0, 1, 7, default_args),
            (simple_split_stack_argspec1, 0, 1, 0, 1, 7, default_args),
            (simple_split_stack_argspec2, 0, 1, 0, 1, 7, default_args),
            (split_cat_addn_args, 0, 1, 1, 1, 7, default_args),
            (split_stack_addn_args, 0, 1, 1, 1, 7, default_args),
            (split_cat_addn_args_dim2, 0, 1, 1, 1, 7, default_args),
            (split_cat_dim_mismatch, 0, 1, 1, 1, 7, default_args),
            (split_stack_dim_mismatch, 0, 1, 1, 1, 7, default_args),
            (split_cat_dim_mismatch2, 0, 1, 1, 1, 7, default_args),
            (split_stack_dim_mismatch2, 0, 1, 1, 1, 7, default_args),
            (split_cat_dim_mismatch3, 0, 1, 1, 1, 7, default_args),
            (split_stack_dim_mismatch3, 0, 1, 1, 1, 7, default_args),
            (input_shuffling, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_stack, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_dim_mismatch, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_dim_mismatch_stack, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_multiple_output, 1, 1, 2, 2, 3, default_args),
            (input_shuffling_direct_output, 1, 1, 2, 2, 3, default_args),
            (unequal_split_multiple_output, 1, 1, 2, 2, 3, default_args),
            (multi_split_cat, 1, 1, 2, 2, 3, multi_args),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["scmerge_split_added"],
                expected_split_added,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_removed"],
                expected_split_removed,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_added"],
                expected_cat_added,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_removed"],
                expected_cat_removed,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_sections_removed"],
                expected_sections_removed,
            )
            counters.clear()

    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={},
    )
    def test_config_flag_is_respected(self):
        def split_with_cat(x):
            fs = torch.split(x, [4, 4, 24], dim=-1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = [item0, item1]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        args = [
            torch.randn(2, 32),
        ]
        expected = split_with_cat(*args)
        actual = torch.compile(split_with_cat)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(
            counters["inductor"]["merge_splits_pass"],
            0,
        )
        self.assertEqual(
            counters["inductor"]["normalization_pass"],
            0,
        )

    @patch
    def test_split_cat_merge_mutation(self):
        args = [
            torch.randn(2, 32, 32, 16),
        ]

        def split_cat_mutation(x):
            splits = torch.split(x, 4, dim=1)
            splits[1].copy_(splits[0])
            return torch.cat(splits, dim=1)

        expected = split_cat_mutation(*args)
        actual = torch.compile(split_cat_mutation)(*args)
        torch.testing.assert_close(actual, expected)
        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)

    @patch
    def test_split_squeeze(self):
        def split_squeeze_stack(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_callmethod(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [s.squeeze(1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_callmethod_none_dim(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [s.squeeze() for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_kwarg1(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, dim=1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_kwarg1_callmethod(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [s.squeeze(dim=1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_multi_squeeze_users(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return (
                torch.stack(split_items),
                torch.relu(split_items[0]),
                torch.tanh(split_items[1]),
            )

        def split_size_not_1(x):
            items = list(torch.split(x, 2, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items)

        def dim_mismatch(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 0) for s in items]
            return torch.stack(split_items)

        def other_users(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items), torch.relu(items[0])
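
        # Unlike other_users, only a subset of the split outputs feeds the
        # stack, but the extra user still blocks the replacement.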
        def other_users_2(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items[1:]]
            return torch.stack(split_items), torch.relu(items[0])

        def graph_should_be_topological_sorted(x):
            output = []
            for t in x.split(1):
                output.append(torch.sin(t.squeeze(dim=0)))
            output = torch.stack(output)
            return output

        args = [
            torch.randn(2, 32),
        ]
        for fn, split_squeeze_replaced in [
            (split_squeeze_stack, 1),
            (split_squeeze_stack_callmethod, 1),
            # TODO handle none dim
            (split_squeeze_stack_callmethod_none_dim, 0),
            (split_squeeze_stack_kwarg1, 1),
            (split_squeeze_stack_kwarg1_callmethod, 1),
            (split_squeeze_multi_squeeze_users, 1),
            (split_size_not_1, 0),
            (dim_mismatch, 0),
            (other_users, 0),
            (other_users_2, 0),
            (graph_should_be_topological_sorted, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["split_cat_pass"],
                split_squeeze_replaced,
            )
            counters.clear()

    @patch
    def test_unbind_stack(self):
        def unbind_stack(x):
            return torch.stack(torch.unbind(x, 1), 1)

        def unbind_cat(x):
            return torch.cat(torch.unbind(x, dim=-3), 1)

        def unbind_stack_argspec1(x):
            return torch.stack(torch.unbind(input=x, dim=1), dim=1)

        def unbind_stack_argspec2(x):
            return torch.stack(tensors=torch.unbind(x, dim=1), dim=1)

        def dim_mismatch(x):
            return torch.stack(torch.unbind(x, dim=1), 0)

        def split_squeeze_stack(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items, 1)

        def split_squeeze_stack_callmethod(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items, 1)

        def other_users(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items, 1), torch.relu(items[0])

        def other_users_2(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items[1:]]
            return torch.stack(split_items, 1), torch.relu(items[0])

        def unbind_cat_addn_args(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=1,
            )

        def unbind_stack_addn_args(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.stack(
                [torch.ones(2, 32, 16)]
                + split_output
                + [torch.ones(2, 32, 16), torch.ones(2, 32, 16)],
                dim=1,
            )

        def unbind_cat_addn_args_dim2(x):
            split_output = list(torch.unbind(x, dim=2))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        # split_dim=1, cat_dim=2
        def unbind_cat_dim_mismatch(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        def unbind_stack_dim_mismatch(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.stack(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        def unbind_cat_multi_users(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=1,
            ), torch.stack(
                [torch.ones(2, 32, 16)]
                + split_output
                + [torch.ones(2, 32, 16), torch.ones(2, 32, 16)],
                dim=1,
            )

        def unbind_cat_multi_users_diff_dims(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=1,
            ), torch.stack(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        args = [
            torch.randn(2, 32, 32, 16),
        ]
        for (
            fn,
            expected_unbind_added,
            expected_unbind_removed,
            expected_cat_added,
            expected_cat_removed,
            expected_sections_removed,
            expected_unbind_normalized,
        ) in [
            (unbind_stack, 0, 1, 0, 1, 31, 2),
            (unbind_stack_argspec1, 0, 1, 0, 1, 31, 2),
            (unbind_stack_argspec2, 0, 1, 0, 1, 31, 2),
            (dim_mismatch, 0, 1, 0, 1, 31, 2),
            (split_squeeze_stack, 0, 1, 0, 1, 31, 2),
            (split_squeeze_stack_callmethod, 0, 1, 0, 1, 31, 2),
            (other_users, 0, 0, 0, 0, 0, 2),
            (other_users_2, 0, 0, 0, 0, 0, 2),
            (unbind_cat_addn_args, 0, 1, 1, 1, 31, 1),
            (unbind_stack_addn_args, 0, 1, 1, 1, 31, 2),
            (unbind_cat_addn_args_dim2, 0, 1, 1, 1, 31, 1),
            (unbind_cat_dim_mismatch, 0, 1, 1, 1, 31, 1),
            (unbind_stack_dim_mismatch, 0, 1, 1, 1, 31, 2),
            (unbind_cat_multi_users, 0, 1, 2, 2, 31, 2),
            (unbind_cat_multi_users_diff_dims, 0, 1, 2, 2, 31, 2),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["scmerge_split_added"],
                expected_unbind_added,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_removed"],
                expected_unbind_removed,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_added"],
                expected_cat_added,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_removed"],
                expected_cat_removed,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_sections_removed"],
                expected_sections_removed,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["normalization_pass"],
                expected_unbind_normalized,
                msg=f"for {fn}",
            )
            counters.clear()

    @patch
    def test_split_cat_new_patterns(self):
        def split_cat_split(x):
            l1_out = torch.split(x, [200, 50, 50, 20, 20, 20, 20, 20, 20, 50, 30], 1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat_1 = torch.cat((item0, item1), 1)
            cat_2 = torch.cat((item9, item10), 1)
            l2_out = torch.split(cat_1, [50, 120, 80], 1)
            l3_out = torch.split(cat_2, [10, 20, 50], 1)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                ],
                1,
            )
            return output

        def split_cat_split_kwarg(x):
            l1_out = torch.split(
                x, [200, 50, 50, 20, 20, 20, 20, 20, 20, 50, 30], dim=1
            )
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat_1 = torch.cat((item0, item1), dim=1)
            cat_2 = torch.cat((item9, item10), dim=1)
            l2_out = torch.split(cat_1, [50, 120, 80], dim=1)
            l3_out = torch.split(cat_2, [10, 20, 50], dim=1)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                ],
                dim=1,
            )
            return output

        def remove_cat_node_with_all_getitems(x):
            l1_out = torch.split(
                x, [50, 50, 200, 20, 20, 20, 20, 20, 40, 10, 50], dim=0
            )
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat = torch.cat(
                (
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                    item9,
                    item10,
                ),
                dim=0,
            )
            cat_1 = torch.cat((item0, item1), dim=0)
            cat_2 = torch.cat((item0, item10), dim=0)
            l2_out = torch.split(cat_1, [20, 30, 50], dim=0)
            l3_out = torch.split(cat_2, [10, 60, 30], dim=0)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                ],
                dim=0,
            )
            return torch.cat((output, cat), dim=0)

        def mutate_cat_node_with_some_getitems(x):
            l1_out = torch.split(
                x, [50, 50, 200, 20, 20, 20, 20, 20, 40, 10, 50], dim=0
            )
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat = torch.cat(
                (
                    item6,
                    item7,
                    item8,
                    item9,
                    item10,
                    item2,
                    item3,
                    item4,
                    item5,
                ),
                dim=0,
            )
            cat_1 = torch.cat((item0, item1), dim=0)
            cat_2 = torch.cat((item0, item10), dim=0)
            l2_out = torch.split(cat_1, [20, 30, 50], dim=0)
            l3_out = torch.split(cat_2, [10, 60, 30], dim=0)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                ],
                dim=0,
            )
            return torch.cat((output, cat), dim=0)

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_cat_to_slices_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_cat_to_slices(x):
            x_c = x.clone()
            x_c_2 = x.clone()
            l1_out = torch.split(x, [50, 50, 50, 50, 50, 50, 50, 50, 50, 50], dim=0)
            l2_out = torch.split(x_c, [50, 50, 50, 50, 50, 50, 50, 50, 50, 50], dim=0)
            l3_out = torch.split(x_c_2, [100, 100, 100, 100, 100], dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item0_c = l2_out[0]
            item1_c = l2_out[1]
            item2_c = l2_out[2]
            item3_c = l2_out[3]
            item4_c = l2_out[4]
            item5_c = l2_out[5]
            item6_c = l2_out[6]
            item7_c = l2_out[7]
            item8_c = l2_out[8]
            item9_c = l2_out[9]
            item0_c_2 = l3_out[0]
            item1_c_2 = l3_out[1]
            item2_c_2 = l3_out[2]
            item3_c_2 = l3_out[3]
            item4_c_2 = l3_out[4]
            other = item0.clone()
            return torch.cat(
                [
                    other,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                    item9,
                    item4_c,
                    item5_c,
                    item6_c,
                    item7_c,
                    item8_c,
                    item9_c,
                    item0_c,
                    item1_c,
                    item2_c,
                    item3_c,
                    item0_c_2,
                    item1_c_2,
                    item2_c_2,
                    item3_c_2,
                    item4_c_2,
                ],
                dim=0,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "unbind_cat_to_view_pass": {},
            },
            post_grad_fusion_options={},
        )
        def unbind_cat_to_view(x):
            y = x.view(10, 50, 500)
            z = x.view(10, 50, 500)
            l1_out = torch.unbind(y, dim=0)
            l2_out = torch.unbind(z, dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item2_0 = l2_out[0]
            item2_1 = l2_out[1]
            item2_2 = l2_out[2]
            item2_3 = l2_out[3]
            item2_4 = l2_out[4]
            item2_5 = l2_out[5]
            item2_6 = l2_out[6]
            item2_7 = l2_out[7]
            item2_8 = l2_out[8]
            item2_9 = l2_out[9]
            other1 = item7.clone()
            other2 = item8.clone()
            other3 = item9.clone()
            cat = torch.cat(
                [
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    other1,
                    item2_0,
                    item2_1,
                    item2_2,
                    item2_3,
                    item2_4,
                    item2_5,
                    item2_6,
                    item2_7,
                    item2_8,
                    item2_9,
                    other2,
                    other3,
                ],
                dim=1,
            )
            return cat

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_stack_to_cats_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_stack_to_cats_same_dim(x):
            x_c = x.view(10, 50, 500)
            l1_out = torch.unbind(x_c, dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            split1 = torch.split(item0, [250, 250], dim=1)
            split2 = torch.split(item1, [250, 250], dim=1)
            split3 = torch.split(item2, [250, 250], dim=1)
            split4 = torch.split(item3, [250, 250], dim=1)
            split5 = torch.split(item4, [250, 250], dim=1)
            split6 = torch.split(item5, [250, 250], dim=1)
            getitem0, getitem1 = split1[0], split1[1]
            getitem2, getitem3 = split2[0], split2[1]
            getitem4, getitem5 = split3[0], split3[1]
            getitem6, getitem7 = split4[0], split4[1]
            getitem8, getitem9 = split5[0], split5[1]
            getitem10, getitem11 = split6[0], split6[1]
            getitem0_c = getitem0.clone()
            getitem1_c = getitem1.clone()
            getitem2_c = getitem2.clone()
            return torch.stack(
                (
                    getitem0,
                    getitem1,
                    getitem2,
                    getitem3,
                    getitem4,
                    getitem5,
                    getitem0_c,
                    getitem1_c,
                    getitem6,
                    getitem7,
                    getitem8,
                    getitem9,
                    getitem10,
                    getitem11,
                    getitem2_c,
                ),
                dim=1,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_stack_to_cats_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_stack_to_cats_different_dim(x):
            l1_out = torch.split(x, [100, 100, 100, 100, 100], dim=1)
            x_c = x.clone()
            l2_out = torch.split(x_c, [100, 100, 100, 100, 100], dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item0_c = l2_out[0]
            item1_c = l2_out[1]
            item2_c = l2_out[2]
            item3_c = l2_out[3]
            item4_c = l2_out[4]
            other_1 = item0.clone()
            other_2 = item1.clone()
            other_3 = item2.clone()
            return torch.stack(
                (
                    other_1,
                    other_2,
                    other_3,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item0_c,
                    item1_c,
                    item2_c,
                    item3_c,
                    item4_c,
                ),
                dim=2,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "unbind_stack_to_slices_pass": {},
            },
            post_grad_fusion_options={},
        )
        def unbind_stack_to_slices(x):
            x_1 = x.view(50, 10, 500)
            l1_out = torch.unbind(x_1, dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            other_1 = item0.clone()
            other_2 = item1.clone()
            other_3 = item2.clone()
            return torch.stack(
                (
                    other_1,
                    other_2,
                    other_3,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                    item9,
                ),
                dim=1,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "normalization_pass": {},
                "move_reshape_out_of_split_stack_pass": {},
            },
            post_grad_fusion_options={},
        )
        def move_reshape_out_of_split_stack(x):
            x_c = x.view(50000, 5)
            l1_out = torch.split(x_c, [1, 1, 1, 1, 1], dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            reshape0 = item0.reshape(-1, 5)
            reshape1 = item1.reshape(-1, 5)
            reshape2 = item2.reshape(-1, 5)
            reshape3 = item3.reshape(-1, 5)
            reshape4 = item4.reshape(-1, 5)
            other0 = reshape0.clone()
            other1 = reshape1.clone()
            other2 = reshape2.clone()
            other3 = reshape3.clone()
            return torch.stack(
                (
                    other0,
                    other1,
                    other2,
                    reshape0,
                    reshape1,
                    reshape2,
                    reshape3,
                    reshape4,
                    other3,
                ),
                dim=0,
            )

        args = [
            torch.randn(500, 500),
        ]
        for (
            fn,
            expected_getitem_cat_merged,
            expected_cat_removed,
            expected_split_cat_to_slices,
            expected_unbind_to_cat_view,
            expected_split_stack_to_cats,
            expected_unbind_stack_to_slices,
            expected_move_reshape_out_of_split_stack,
        ) in [
            (split_cat_split, 2, 0, 0, 0, 0, 0, 0),
            (split_cat_split_kwarg, 2, 0, 0, 0, 0, 0, 0),
            (remove_cat_node_with_all_getitems, 0, 2, 0, 0, 0, 0, 0),
            (mutate_cat_node_with_some_getitems, 0, 1, 0, 0, 0, 0, 0),
            (split_cat_to_slices, 0, 0, 1, 0, 0, 0, 0),
            (unbind_cat_to_view, 0, 0, 0, 1, 0, 0, 0),
            (split_stack_to_cats_same_dim, 0, 0, 0, 0, 1, 0, 0),
            (split_stack_to_cats_different_dim, 0, 0, 0, 0, 1, 0, 0),
            (unbind_stack_to_slices, 0, 0, 0, 0, 0, 1, 0),
            (move_reshape_out_of_split_stack, 0, 0, 0, 0, 0, 0, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["merge_getitem_cat_pass"],
                expected_getitem_cat_merged,
            )
            self.assertEqual(
                counters["inductor"]["mutate_cat_pass"],
                expected_cat_removed,
            )
            self.assertEqual(
                counters["inductor"]["split_cat_to_slices_pass"],
                expected_split_cat_to_slices,
            )
            self.assertEqual(
                counters["inductor"]["unbind_cat_to_view_pass"],
                expected_unbind_to_cat_view,
            )
            self.assertEqual(
                counters["inductor"]["split_stack_to_cats_pass"],
                expected_split_stack_to_cats,
            )
            self.assertEqual(
                counters["inductor"]["unbind_stack_to_slices_pass"],
                expected_unbind_stack_to_slices,
            )
            self.assertEqual(
                counters["inductor"]["move_reshape_out_of_split_stack_pass"],
                expected_move_reshape_out_of_split_stack,
            )
            counters.clear()

    def test_numpy_compat_normalization(self):
        def fn(x, y):
            a = torch.stack([x, y], axis=1)
            b = torch.mul(x, x2=y)
            c = torch.mul(x, x2=y)
            d = torch.mul(x, x2=y)
            e = torch.max(x, dim=1, keepdims=True)
            f = torch.dropout(x=x, p=0.5, train=True)
            return a, b, c, d, e, f

        fn_t = torch.fx.symbolic_trace(fn)
        numpy_compat_normalization(fn_t.graph)

        for n in fn_t.graph.nodes:
            for k in n.kwargs.keys():
                self.assertTrue(k not in {"x", "x1", "x2", "a", "axis", "keepdims"})

    @patch
    @requires_gpu
    def test_stack_normalization_axis_kwarg(self):
        def fn(x, y):
            return torch.stack([x, y], axis=1)

        x, y = (torch.rand((4, 4), device=GPU_TYPE) for _ in range(2))
        expected = fn(x, y)
        actual = torch.compile(fn)(x, y)
        self.assertEqual(actual, expected)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()