# Owner(s): ["module: inductor"]
import sys
import unittest
import weakref
from contextlib import ExitStack
from copy import deepcopy
from typing import NamedTuple

import torch
import torch._inductor
import torch._inductor.cudagraph_trees
import torch.optim.lr_scheduler
from torch._inductor import config
from torch._inductor.test_case import TestCase
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
    SparseAdam,
)
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    StepLR,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipCUDAIf,
    skipXPUIf,
)
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    optims,
)
from torch.testing._internal.common_utils import parametrize
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CPU,
    HAS_GPU,
    has_triton,
)
from torch.testing._internal.triton_utils import requires_cuda, requires_gpu


# Note: we use atypical values to amplify error
LR_SCHEDULER_TO_KWARGS = {
    LambdaLR: {"lr_lambda": lambda x: 10},
    MultiplicativeLR: {"lr_lambda": lambda x: 10},
    StepLR: {"step_size": 1, "gamma": 100},
    MultiStepLR: {"milestones": [1, 2], "gamma": 100},
    ExponentialLR: {"gamma": 100},
    CosineAnnealingLR: {"T_max": 7},
    # These schedulers have memory leaks in eager
    # https://github.com/pytorch/pytorch/issues/126131
    # SequentialLR: {"schedulers": None, "milestones": [1, 2]},
    # ChainedScheduler: {"schedulers": None},
    CyclicLR: {"base_lr": 0.001, "max_lr": 0.02, "cycle_momentum": False},
    CosineAnnealingWarmRestarts: {"T_0": 1},
    OneCycleLR: {
        "max_lr": 0.02,
        "cycle_momentum": False,
        "steps_per_epoch": 1,
        "epochs": 10,
    },
    ConstantLR: {"factor": 0.001},
    LinearLR: {},
    ReduceLROnPlateau: {"factor": 0.99, "patience": 1},
    PolynomialLR: {},
}


def create_scheduler(scheduler, optim):
    kwargs = LR_SCHEDULER_TO_KWARGS[scheduler]
    if "schedulers" in kwargs:
        kwargs["schedulers"] = [
            create_scheduler(torch.optim.lr_scheduler.ConstantLR, optim)
            for _ in range(2)
        ] + [create_scheduler(torch.optim.lr_scheduler.LambdaLR, optim)]

    if scheduler == ChainedScheduler:
        return scheduler(**kwargs)
    else:
        return scheduler(optim, **kwargs)
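

# Example (illustrative only): with the table above, create_scheduler(StepLR, opt)
# for some optimizer `opt` expands to StepLR(opt, step_size=1, gamma=100); only
# ChainedScheduler would be constructed without the optimizer argument.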


class KernelCounts(NamedTuple):
    multitensor: int
    singletensor: int


# With different settings for certain
# tests you can get different kernel counts
# This maps the test name to the
# expected kernel count
KERNEL_COUNT_OVERRIDES = {
    "test_rmsprop_foreach_weight_decay_cpu": 12,
    "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
    "test_adamw_amsgrad_capturable_foreach_cuda": 3,
    "test_adamw_amsgrad_capturable_foreach_xpu": 3,
    "test_adamw_amsgrad_capturable_cuda": 6,
    "test_adamw_amsgrad_capturable_xpu": 6,
    "test_adamw_tensor_lr_amsgrad_capturable_foreach_cuda": 3,
    "test_adamw_tensor_lr_amsgrad_capturable_foreach_xpu": 3,
    "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
    "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
    "test_adam_tensor_lr_amsgrad_capturable_cuda": 6,
    "test_adam_tensor_lr_amsgrad_capturable_xpu": 6,
    "test_adam_amsgrad_capturable_cuda": 6,
    "test_adam_amsgrad_capturable_xpu": 6,
    "test_adadelta_tensor_lr_capturable_cuda": 6,
    "test_adadelta_tensor_lr_capturable_xpu": 6,
    "test_rmsprop_tensor_lr_capturable_cuda": 6,
    "test_rmsprop_tensor_lr_capturable_xpu": 6,
    "test_adadelta_tensor_lr_capturable_foreach_cuda": 4,
    "test_adadelta_tensor_lr_capturable_foreach_xpu": 4,
    "test_adadelta_foreach_weight_decay_maximize_cpu": 12,
    "test_adadelta_foreach_rho_weight_decay_cpu": 12,
    "test_adadelta_foreach_weight_decay_cpu": 12,
    "test_sgd_foreach_momentum_weight_decay_cpu": 16,
    "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
    "test_sgd_momentum_dampening_foreach_cuda": 5,
    "test_sgd_momentum_dampening_foreach_xpu": 5,
    "test_sgd_momentum_foreach_cuda": 5,
    "test_sgd_momentum_foreach_xpu": 5,
    "test_sgd_weight_decay_maximize_cuda": 4,
    "test_sgd_weight_decay_maximize_xpu": 4,
    "test_sgd_weight_decay_maximize_cpu": 4,
    "test_sgd_weight_decay_cpu": 4,
    "test_sgd_weight_decay_cuda": 4,
    "test_sgd_weight_decay_xpu": 4,
    "test_sgd_momentum_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_weight_decay_foreach_xpu": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
    "test_sgd_cuda": 4,
    "test_sgd_cpu": 4,
    "test_sgd_xpu": 4,
    "test_rmsprop_tensor_lr_capturable_foreach_cuda": 4,
    "test_rmsprop_tensor_lr_capturable_foreach_xpu": 4,
    "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
    "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
    "test_adagrad_weight_decay_foreach_xpu": 2,
    "test_adagrad_weight_decay_maximize_foreach_xpu": 2,
    "test_adagrad_tensor_lr_cpu": 6,
    "test_adagrad_tensor_lr_cuda": 6,
    "test_adagrad_tensor_lr_xpu": 6,
    "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
    "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_cuda": 4,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_xpu": 4,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_cuda": 3,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_xpu": 3,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_foreach_cuda": 3,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_foreach_xpu": 3,
    "test_sgd_tensor_lr_cpu": 2,
    "test_sgd_tensor_lr_cuda": 2,
    "test_sgd_tensor_lr_xpu": 2,
    "test_sgd_tensor_lr_foreach_cuda": 2,
    "test_sgd_tensor_lr_foreach_xpu": 2,
}

# also tracks currently supported optimizers
KERNEL_COUNTS = {
    Adam: KernelCounts(multitensor=2, singletensor=8),
    AdamW: KernelCounts(multitensor=2, singletensor=8),
    NAdam: KernelCounts(multitensor=2, singletensor=8),
    Rprop: KernelCounts(multitensor=2, singletensor=8),
    RMSprop: KernelCounts(multitensor=2, singletensor=8),
    Adadelta: KernelCounts(multitensor=2, singletensor=8),
    Adagrad: KernelCounts(multitensor=2, singletensor=8),
    SGD: KernelCounts(multitensor=1, singletensor=8),
    ASGD: KernelCounts(multitensor=2, singletensor=8),
    RAdam: KernelCounts(multitensor=2, singletensor=8),
    Adamax: KernelCounts(multitensor=2, singletensor=8),
}
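

# build_opt_kwarg_db() below generates the parametrizations for CompiledOptimizerTests:
# one entry per supported optimizer / device / kwarg combination from optim_db, with the
# expected Inductor kernel count taken from KERNEL_COUNT_OVERRIDES when present and from
# KERNEL_COUNTS otherwise.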


def build_opt_kwarg_db():
    compiled_opt_db = []
    for optim_info in optim_db:
        if optim_info.optim_cls not in KERNEL_COUNTS:
            continue

        for device in ["cpu", GPU_TYPE]:
            for optim_inputs in _get_optim_inputs_including_global_cliquey_kwargs(
                device, None, optim_info, skip=("differentiable", "fused")
            ):
                kwargs = dict(optim_inputs.kwargs)
                name = f"test_{optim_info.optim_cls.__name__.lower()}"

                has_tensor_lr = False
                for key, val in kwargs.items():
                    if (not key == "lr" and not key == "betas") and (
                        not isinstance(val, bool) or (isinstance(val, bool) and val)
                    ):
                        name += "_" + key

                    if key == "lr" and isinstance(kwargs["lr"], torch.Tensor):
                        has_tensor_lr = True
                        name += "_tensor_lr"

                    if key == "betas" and isinstance(kwargs["betas"][0], torch.Tensor):
                        name += "_tensor_betas"

                name += f"_{device}"

                kwargs["device"] = device
                if name in KERNEL_COUNT_OVERRIDES:
                    kwargs["kernel_count"] = KERNEL_COUNT_OVERRIDES[name]
                else:
                    kwargs["kernel_count"] = (
                        KERNEL_COUNTS[optim_info.optim_cls].multitensor
                        if kwargs.get("foreach", False) and device == GPU_TYPE
                        else KERNEL_COUNTS[optim_info.optim_cls].singletensor
                    )

                if kwargs["kernel_count"] is None or kwargs.get("fused", False):
                    continue

                if has_tensor_lr:
                    for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys():
                        name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}"
                        compiled_opt_db.append(
                            (
                                optim_info.optim_cls,
                                name_w_scheduler,
                                kwargs,
                                scheduler_cls,
                            )
                        )
                else:
                    compiled_opt_db.append((optim_info.optim_cls, name, kwargs, None))

    return compiled_opt_db


COMPILED_OPT_KWARG_DB = build_opt_kwarg_db()
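
# Each entry of COMPILED_OPT_KWARG_DB is a (optim_cls, test_name, kwargs, scheduler_cls)
# tuple; scheduler_cls is non-None only for tensor-lr variants, which are generated once
# per scheduler in LR_SCHEDULER_TO_KWARGS.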

aten = torch.ops.aten


try:
    try:
        from .test_torchinductor import check_model, check_model_gpu
    except ImportError:
        from test_torchinductor import check_model, check_model_gpu
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


def call_scheduler(scheduler):
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(1.0)  # we won't reduce the metric over two iters anyway
    else:
        scheduler.step()


def compile_opt(opt_compiled, closure=None, fullgraph=True):
    # run the patcher so that step has the expected structure
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # unwrap step TWICE to avoid a deliberate graph break due to
    # a limitation of functionalization/no_grad detection
    # see the [Note on graph break] in optimizer.py
    # This ignores the outer _use_grad_if_differentiable wrapper
    # and instead manually disables grad before calling step, which is fine
    # for now as dynamo does not support differentiable optimizers anyway
    step_fn = opt_compiled.step.__wrapped__.__wrapped__

    # This ensures we don't receive spam of warnings from LR Scheduler
    opt_compiled._opt_called = True

    if closure is not None:

        def fn():
            step_fn(opt_compiled, closure)

    else:

        def fn():
            step_fn(opt_compiled)

    return torch.compile(fn, backend="inductor", fullgraph=fullgraph)


def check_optim(
    self,
    optim_cls,
    params_eager,
    params_compiled,
    state_eager,
    state_compiled,
    atol=None,
    rtol=None,
):
    params_eager = list(params_eager)
    params_compiled = list(params_compiled)

    # Note on tolerances:
    # test_correctness_Adadelta_cuda_float32
    # Mismatched elements: 10 / 100 (10.0%)
    # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
    # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
    # This is due to floating point ordering error + usage of sqrt
    rtol = None
    atol = None
    if optim_cls is Adadelta:
        rtol = 5.5e-4
        atol = 5e-5

    self.assertEqual(list(params_eager), list(params_compiled), atol=atol, rtol=rtol)

    for p_eager, p_compiled in zip(params_eager, params_compiled):
        self.assertEqual(
            state_eager[p_eager],
            state_compiled[p_compiled],
            atol=atol,
            rtol=rtol,
        )


def make_test(
    optim_cls,
    closure=None,
    scheduler_cls=None,
    kernel_count=2,
    device="cuda",
    **kwargs,
):
    def test_fn(self):
        stack = ExitStack()
        try:
            # https://github.com/pytorch/pytorch/issues/118715 for capturable Adagrad support
            # https://github.com/pytorch/pytorch/issues/118018 for capturable SGD support
            run_cudagraphs = device == "cuda" and optim_cls not in (Adagrad, SGD)
            if run_cudagraphs:
                stack.enter_context(config.patch({"triton.cudagraphs": True}))

            kwargs_compiled = deepcopy(kwargs)
            if isinstance(kwargs.get("lr", None), torch.Tensor):
                kwargs["lr"] = kwargs["lr"].to(device)
                kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)

            if "betas" in kwargs and isinstance(kwargs["betas"][0], torch.Tensor):
                kwargs["betas"] = (
                    kwargs["betas"][0].to(device),
                    kwargs["betas"][1].to(device),
                )
                kwargs_compiled["betas"] = (
                    kwargs_compiled["betas"][0].to(device),
                    kwargs_compiled["betas"][1].to(device),
                )

            torch._dynamo.reset()
            torch._inductor.metrics.reset()
            input = torch.ones([10, 10], device=device)
            model_eager = torch.nn.Sequential(
                *[torch.nn.Linear(10, 10, device=device) for _ in range(2)]
            )
            model_eager(input).sum().backward()

            input = torch.ones([10, 10], device=device)
            model_compiled = deepcopy(model_eager)
            model_compiled(input).sum().backward()

            opt_eager = optim_cls(model_eager.parameters(), **kwargs)
            opt_compiled = optim_cls(model_compiled.parameters(), **kwargs_compiled)
            compiled_step = compile_opt(opt_compiled, closure=closure)

            if scheduler_cls:
                scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
                scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
                # some schedulers only change after at least an epoch has passed
                scheduler_compiled.last_epoch = 1
                scheduler_eager.last_epoch = 1

            with torch.set_grad_enabled(False):
                for i in range(2):
                    compiled_step()
                    opt_eager.step()
                    if scheduler_cls:
                        call_scheduler(scheduler_eager)
                        call_scheduler(scheduler_compiled)

            check_optim(
                self,
                optim_cls,
                model_eager.parameters(),
                model_compiled.parameters(),
                opt_eager.state,
                opt_compiled.state,
            )

            if run_cudagraphs:
                self.check_cudagraphs_ran()

            if self.check_kernel_count:
                # currently, we compile the step and the rest of the computation
                # separately because the step is a single element tensor
                # hence, the usual kernel count is 2
                self.assertEqual(
                    torch._inductor.metrics.generated_kernel_count, kernel_count
                )
        finally:
            stack.close()

    if device == GPU_TYPE:
        test_fn = requires_gpu(test_fn)

    return test_fn
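

# make_recompile_test below runs the compiled step several times (no recompile expected),
# then perturbs the optimizer state (or the lr for Adagrad/SGD, which keep no per-step
# state) to force exactly one recompile, so the generated kernel count should double.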


def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
    @requires_gpu
    def test_fn(self):
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        input = torch.ones([10, 10], device=GPU_TYPE)
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device=GPU_TYPE) for _ in range(2)]
        )
        model(input).sum().backward()

        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = compile_opt(opt_compiled)

        # check no recompile here
        with torch.set_grad_enabled(False):
            for _ in range(4):
                compiled_step()

            # perturb state to force recompile
            # Adagrad doesn't reinitialize state on each step
            # SGD has an empty state
            if optim_cls in (Adagrad, SGD):
                opt_compiled.param_groups[0]["lr"] = 0.02
            elif optim_cls is Adam:
                # ensure we are guarding on the data_ptr of states
                state_tensor = opt_compiled.state[
                    opt_compiled.param_groups[0]["params"][0]
                ]["exp_avg"]
                opt_compiled.state[opt_compiled.param_groups[0]["params"][0]][
                    "exp_avg"
                ] = torch.zeros_like(state_tensor)
            else:
                opt_compiled.state.clear()

            compiled_step()

        if self.check_kernel_count:
            # currently, we compile the step and the rest of the computation
            # separately because the step is a single element tensor
            # hence, the usual kernel count is 2
            # multiply by 2 to account for the recompile
            multiplier = 2

            self.assertEqual(
                torch._inductor.metrics.generated_kernel_count,
                multiplier * kernel_count,
            )

    return test_fn


class CompiledOptimizerParityTests(TestCase):
    @skipCUDAIf(not has_triton(), "torch.compile with cuda requires triton")
    @skipXPUIf(not has_triton(), "torch.compile with xpu requires triton")
    @optims(optim_db, dtypes=[torch.float32])
    @parametrize("use_closure", [True, False])
    def test_correctness(self, device, dtype, optim_info, use_closure):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        if optim_info.step_requires_closure and not use_closure:
            return

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            use_scheduler = isinstance(kwargs.get("lr", None), torch.Tensor)
            scheduler_classes = (
                list(LR_SCHEDULER_TO_KWARGS.keys()) if use_scheduler else [None]
            )

            for scheduler_cls in scheduler_classes:
                torch._dynamo.reset()
                torch._inductor.metrics.reset()
                input = torch.ones([10, 10], device=device)
                model_eager = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10, device=device) for _ in range(2)]
                )
                model_eager(input).sum().backward()

                model_compiled = deepcopy(model_eager)
                model_compiled(input).sum().backward()

                if optim_cls is SparseAdam:
                    for param in model_eager.parameters():
                        param.grad = param.grad.to_sparse()
                    for param in model_compiled.parameters():
                        param.grad = param.grad.to_sparse()

                opt_compiled = optim_cls(
                    model_compiled.parameters(), **deepcopy(kwargs)
                )
                opt_eager = optim_cls(model_eager.parameters(), **deepcopy(kwargs))
                if scheduler_cls:
                    scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
                    scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
                    # some schedulers only change after at least an epoch has passed
                    scheduler_compiled.last_epoch = 1
                    scheduler_eager.last_epoch = 1

                num_steps = 2
                if use_closure:

                    @torch.compile()
                    def fn():
                        def closure():
                            loss = model_compiled(input).sum()
                            loss.backward()
                            if optim_info.only_supports_sparse_grads:
                                for param in model_compiled.parameters():
                                    param.grad = param.grad.to_sparse()
                            return loss

                        opt_compiled.step(closure)
                        if scheduler_cls:
                            call_scheduler(scheduler_compiled)

                    def closure_eager():
                        loss = model_eager(input).sum()
                        loss.backward()
                        if optim_info.only_supports_sparse_grads:
                            for param in model_eager.parameters():
                                param.grad = param.grad.to_sparse()
                        return loss

                    for _ in range(num_steps):
                        opt_eager.step(closure_eager)
                        if scheduler_cls:
                            call_scheduler(scheduler_eager)
                else:

                    @torch.compile()
                    def fn():
                        opt_compiled.step()
                        if scheduler_cls:
                            call_scheduler(scheduler_compiled)

                    for _ in range(num_steps):
                        opt_eager.step()
                        if scheduler_cls:
                            call_scheduler(scheduler_eager)

                for _ in range(num_steps):
                    fn()

                check_optim(
                    self,
                    optim_cls,
                    model_eager.parameters(),
                    model_compiled.parameters(),
                    opt_eager.state,
                    opt_compiled.state,
                )


class CompiledOptimizerTests(TestCase):
    check_model_gpu = check_model_gpu
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

    def check_cudagraphs_ran(self):
        # We run the zeroth device currently
        manager = torch._inductor.cudagraph_trees.get_container(0).tree_manager
        self.assertIsNotNone(manager)
        self.assertEqual(manager.new_graph_id().id, 1)

    test_adam_recompile = make_recompile_test(Adam, lr=0.01)
    test_adamw_recompile = make_recompile_test(AdamW, lr=0.01)
    test_adamax_recompile = make_recompile_test(Adamax, lr=0.01)
    test_nadam_recompile = make_recompile_test(NAdam, lr=0.01)
    test_rprop_recompile = make_recompile_test(Rprop, lr=0.01, kernel_count=2)
    test_rmsprop_recompile = make_recompile_test(RMSprop, lr=0.01)
    test_adadelta_recompile = make_recompile_test(Adadelta, lr=0.01)
    test_adagrad_recompile = make_recompile_test(Adagrad, lr=0.01)
    test_asgd_recompile_default = make_recompile_test(ASGD, lr=0.01)
    test_asgd_recompile_single = make_recompile_test(
        ASGD, kernel_count=8, lr=0.01, foreach=False
    )
    test_asgd_recompile_foreach = make_recompile_test(ASGD, lr=0.01, foreach=True)
    test_sgd_recompile_single = make_recompile_test(
        SGD, kernel_count=4, lr=0.01, foreach=False
    )
    test_sgd_recompile_foreach = make_recompile_test(
        SGD, kernel_count=1, lr=0.01, foreach=True
    )
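
    # test_static_address_finalizer below checks that compiling the optimizer step
    # (which marks parameter addresses as static) does not keep the parameters alive:
    # the weakref should be cleared once the module goes out of scope.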

    @requires_gpu
    def test_static_address_finalizer(self):
        import gc

        gc.disable()
        p_ref = None

        def fn():
            nonlocal p_ref
            mod = torch.nn.Linear(10, 10, device=GPU_TYPE, bias=False)
            for p in mod.parameters():
                p.grad = torch.rand_like(p)

            opt = torch.optim.Adam(mod.parameters(), lr=0.1)

            def fn():
                opt.step()

            with torch.set_grad_enabled(False):
                step_fn_compiled = torch.compile(fn)
                step_fn_compiled()
            p_ref = weakref.ref(p)
            self.assertTrue(p_ref() is not None)

        fn()

        self.assertTrue(p_ref() is None)
        gc.enable()

    def test_guard_on_none_grads(self):
        def training_loop():
            input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).reshape(3, 2)

            model = torch.nn.Sequential(
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            )

            params = list(model.parameters())
            optimizer = torch.optim.Adam(params)
            step_list = []

            for i in range(6):
                optimizer.zero_grad()
                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 3:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()
                step_list.append(optimizer.state[params[0]]["step"])

            return step_list

        compiled_training_loop = torch._dynamo.optimize("eager")(training_loop)
        actual_steps = compiled_training_loop()
        expected_steps = training_loop()
        self.assertEqual(actual_steps, expected_steps)

    # Basic shampoo test to verify we support compiling the various ops without error
    @requires_gpu
    def test_basic_shampoo(self):
        param_buf = torch.rand((1024, 128))
        param_buf_c = param_buf.clone().detach()

        params_c = [param_buf_c[0:512, :].t(), param_buf_c[512:, :].t()]
        params = [param_buf[0:512, :].t(), param_buf[512:, :].t()]

        for p, p_c in zip(params, params_c):
            p.grad = torch.rand_like(p)
            p_c.grad = p.grad.clone().detach()

        # note this skips the root inverse because this has a lot of internal dependencies
        # we also don't compile it regardless
        @torch.no_grad()
        def shampoo_functional_basic(params):
            step = 1
            weight_decay = 0.1
            grads = [p.grad for p in params]
            beta1 = 0.9
            beta2 = 1.0
            epsilon = 1e-10
            preconditioners = [torch.zeros_like(p) for p in params]
            lr = 0.01

            # pt2 region 1
            # weight decay
            torch._foreach_add_(grads, params, alpha=weight_decay)

            # update preconditioners
            torch._foreach_addcmul_(preconditioners, grads, grads, value=1.0)

            torch._foreach_mul_(grads, beta1)
            torch._foreach_add_(
                grads,
                grads,
                alpha=1 - beta1,
            )
            bias_correction1 = 1.0 - beta1**step
            grad_list = torch._foreach_div(grads, bias_correction1)

            # pt2 region 2
            # precondition (with shampoo branch), with no grafting
            bias_correction2 = 1.0 - beta2**step
            bias_corrected_preconditioner_list = torch._foreach_div(
                preconditioners, bias_correction2
            )
            torch._foreach_sqrt_(bias_corrected_preconditioner_list)
            torch._foreach_add_(bias_corrected_preconditioner_list, epsilon)
            search_directions = torch._foreach_div(
                grad_list, bias_corrected_preconditioner_list
            )

            torch._foreach_add_(
                search_directions,
                params,
                alpha=weight_decay,
            )

            torch._foreach_mul_(search_directions, -lr)

            # pt2 region 3 update params
            torch._foreach_add_(params, search_directions)

            return params, preconditioners, grads

        compiled_fn = torch.compile(shampoo_functional_basic)

        self.assertEqual(compiled_fn(params_c), shampoo_functional_basic(params))
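
    # test_closure_graph_break below steps with a closure that assigns .grad, exercising
    # the graph-break path named in the test; the compiled loop should match the eager one.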

    @requires_gpu
    def test_closure_graph_break(self):
        param = torch.rand(
            2, 3, dtype=torch.float32, device=GPU_TYPE, requires_grad=True
        )
        param_c = param.clone().detach().requires_grad_(True)

        def closure():
            param.grad = torch.ones_like(param) * 2
            return param.grad

        def closure_c():
            param_c.grad = torch.ones_like(param_c) * 2
            return param_c.grad

        optimizer = torch.optim.AdamW([param])
        optimizer_c = torch.optim.AdamW([param_c])

        def loop(opt, c):
            opt.step(c)

        compiled_loop = torch._dynamo.optimize("eager")(loop)

        compiled_loop(optimizer, closure)
        loop(optimizer_c, closure_c)

        self.assertEqual(param, param_c)

    def test_get_value_on_static_address(self):
        from torch._dynamo.decorators import mark_static_address
        from torch.optim.optimizer import _get_value

        compiled = torch.compile(_get_value)

        x = torch.ones(2, 2)
        mark_static_address(x)

        ret_val = compiled(x)

        self.assertEqual(ret_val, x)

    # compile a large foreach op and verify
    # that the time taken is within an expected range
    @requires_gpu
    def test_compile_time_smoketest(self):
        import time

        xs = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)]
        ys = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)]

        @torch.compile
        def fn(xs, ys):
            return torch._foreach_add(xs, ys)

        start = time.perf_counter()
        fn(xs, ys)
        end = time.perf_counter()

        self.assertLess(end - start, 90)

    @requires_cuda
    def test_S429861(self):
        # Just verify we can compile this function without error
        try:
            from . import s429861_repro
        except ImportError:
            import s429861_repro

        forward = s429861_repro.forward

        import torch._dynamo
        import torch._inductor
        from torch._dynamo.debug_utils import aot_graph_input_parser
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache():
            kwargs = aot_graph_input_parser(forward)
            torch.compile(forward)(**kwargs)


for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
    setattr(
        CompiledOptimizerTests,
        name,
        make_test(optim_cls, scheduler_cls=scheduler_cls, **kwargs),
    )

instantiate_device_type_tests(
    CompiledOptimizerParityTests, globals(), allow_xpu=True, except_for="cpu"
)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")