import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Callable, List, NamedTuple

import torch
from torch.autograd import functional

try:
    import functorch as ft

    has_functorch = True
    print(f"Found functorch: {ft.__version__}")
except ImportError:
    has_functorch = False

import audio_text_models
import ppl_models
import vision_models
from utils import GetterType, InputsType, TimingResultType, to_markdown_table, VType


def get_task_func(task: str) -> Callable:
    def hessian_fwdrev(model, inp, strict=None):
        return functional.hessian(
            model,
            inp,
            strict=False,
            vectorize=True,
            outer_jacobian_strategy="forward-mode",
        )

    def hessian_revrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True)

    def jacfwd(model, inp, strict=None):
        return functional.jacobian(
            model, inp, strict=False, vectorize=True, strategy="forward-mode"
        )

    def jacrev(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True)

    if task == "hessian_fwdrev":
        return hessian_fwdrev
    elif task == "hessian_revrev":
        return hessian_revrev
    elif task == "jacfwd":
        return jacfwd
    elif task == "jacrev":
        return jacrev
    else:
        return getattr(functional, task)


def get_task_functorch(task: str) -> Callable:
    @torch.no_grad()
    def vjp(model, inp, v=None, strict=None):
        assert v is not None
        out, vjpfunc = ft.vjp(model, *inp)
        return out, vjpfunc(v)

    @torch.no_grad()
    def jvp(model, inp, v=None, strict=None):
        assert v is not None
        return ft.jvp(model, inp, v)

    @torch.no_grad()
    def vhp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
        return aux, vjpfunc(v)

    @torch.no_grad()
    def hvp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, hvp_out, aux = ft.jvp(
            ft.grad_and_value(model, argnums), inp, v, has_aux=True
        )
        return aux, hvp_out

    @torch.no_grad()
    def jacfwd(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(model, argnums)(*inp)

    @torch.no_grad()
    def jacrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(model, argnums)(*inp)

    @torch.no_grad()
    def hessian(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.hessian(model, argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_fwdrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_revrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    if task in locals():
        return locals()[task]
    elif task == "jacobian":
        raise RuntimeError(
            "functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet"
        )
    else:
        raise RuntimeError(f"Unsupported task: {task}")


# Listing of the different tasks
FAST_TASKS_NO_DOUBLE_BACK = [
    "vjp",
]

FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
    "vhp",
    "jvp",
]

ALL_TASKS_NON_VECTORIZED = FAST_TASKS + ["hvp", "jacobian", "hessian"]

DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]

VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]

ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS
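
# A minimal, self-contained sketch (illustrative only, never called by the
# benchmark) of the "v" convention the tasks above rely on: for "vjp" the
# vector v lives in the output space, while for "jvp"/"vhp"/"hvp" it lives in
# the input space. The toy function below is a hypothetical stand-in for the
# real model getters imported above.
def _example_v_convention() -> None:
    def toy_model(x):
        return (x**2).sum(dim=-1)  # maps shape (N, D) to shape (N,)

    inp = (torch.rand(3, 4),)
    out = toy_model(*inp)
    # For "vjp", v matches the output; vjp returns (output, v^T J).
    _, vjp_result = functional.vjp(toy_model, inp, v=torch.rand_like(out))
    # For "jvp", v matches the inputs; jvp returns (output, J v).
    _, jvp_result = functional.jvp(
        toy_model, inp, v=tuple(torch.rand_like(i) for i in inp)
    )
    assert vjp_result[0].shape == inp[0].shape  # v^T J has the input's shape
    assert jvp_result.shape == out.shape  # J v has the output's shape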

# Model definition which contains:
# - name: a string with the model name.
# - getter: a function to get the model. It takes as input the device on which
#     the model will run. It should return the forward function and the
#     parameters (Tensors) used as input for the forward function. Note that
#     the forward must *not* have any side effect.
# - tasks: the list of recommended tasks that can run in a reasonable amount
#     of time with this model.
# - unsupported: the list of tasks that this model cannot run.
class ModelDef(NamedTuple):
    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]


MODELS = [
    ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []),
    ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []),
    ModelDef("detr", vision_models.get_detr, FAST_TASKS, []),
    ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []),
    ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []),
    ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []),
    ModelDef(
        "deepspeech",
        audio_text_models.get_deepspeech,
        FAST_TASKS_NO_DOUBLE_BACK,
        DOUBLE_BACKWARD_TASKS,
    ),
    ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []),
    ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []),
]


def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    v: VType
    if task in ["vjp"]:
        out = model(*inp)
        v = torch.rand_like(out)
    elif task in ["jvp", "hvp", "vhp"]:
        if isinstance(inp, tuple):
            v = tuple(torch.rand_like(i) for i in inp)
        else:
            v = torch.rand_like(inp)
    else:
        v = None
    return v


def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
    func = get_task_func(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)


def run_once_functorch(
    model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False
) -> None:
    func = get_task_functorch(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)

    if maybe_check_consistency:
        af_func = get_task_func(task)
        if v is not None:
            expected = af_func(model, inp, v=v, strict=True)
        else:
            expected = af_func(model, inp, strict=True)
        atol = 1e-2 if task == "vhp" else 5e-3
        torch.testing.assert_close(
            res,
            expected,
            rtol=1e-5,
            atol=atol,
            msg=f"Consistency check failed for task '{task}'",
        )


def run_model(
    model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once
) -> List[float]:
    if args.gpu == -1:
        device = torch.device("cpu")

        def noop():
            pass

        do_sync = noop
    else:
        device = torch.device(f"cuda:{args.gpu}")
        do_sync = torch.cuda.synchronize

    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)

    # Warmup. maybe_check_consistency=True checks for consistency between
    # functorch and autograd.functional, and is acted on in run_once_functorch
    # only (run_once simply ignores it).
    run_once_fn(model, inp, task, v, maybe_check_consistency=True)

    elapsed = []
    for it in range(args.num_iters):
        do_sync()
        start = time.time()
        run_once_fn(model, inp, task, v)
        do_sync()
        elapsed.append(time.time() - start)

    return elapsed
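
# Minimal usage sketch of the pieces above (illustrative only, never called
# by the benchmark): it builds a toy getter with the same shape as the real
# ones (a forward function plus the input tensors for it) and times a single
# "vjp" run on CPU. All names here are hypothetical, not part of the
# benchmark's interface.
def _example_manual_run() -> None:
    def toy_getter(device):
        weight = torch.rand(4, 4, device=device)

        def forward(x):
            return (x @ weight).sum()

        return forward, (torch.rand(2, 4, device=device),)

    model, inp = toy_getter(torch.device("cpu"))
    v = get_v_for(model, inp, "vjp")
    start = time.time()
    run_once(model, inp, "vjp", v)
    print(f"one vjp run took {time.time() - start:.6f}s")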
default="", help="Only run the tasks in this filter" ) parser.add_argument( "--num-threads", type=int, default=10, help="Number of concurrent threads to use when running on cpu", ) parser.add_argument("--seed", type=int, default=0, help="The random seed to use.") args = parser.parse_args() results: TimingResultType = defaultdict(defaultdict) torch.set_num_threads(args.num_threads) torch.set_num_interop_threads(args.num_threads) # This automatically seed cuda if it is available torch.manual_seed(args.seed) if args.gpu == -2: args.gpu = 0 if torch.cuda.is_available() else -1 for name, model_getter, recommended_tasks, unsupported_tasks in MODELS: if args.model_filter and name not in args.model_filter: continue tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks for task in tasks: if task in unsupported_tasks: continue if args.task_filter and task not in args.task_filter: continue runtimes = run_model(model_getter, args, task) runtimes = torch.tensor(runtimes) mean, var = runtimes.mean(), runtimes.var() results[name][task] = (mean.item(), var.item()) print(f"Results for model {name} on task {task}: {mean}s (var: {var})") if has_functorch: try: runtimes = run_model( model_getter, args, task, run_once_fn=run_once_functorch ) except RuntimeError as e: print( f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t", e, ) continue runtimes = torch.tensor(runtimes) mean, var = runtimes.mean(), runtimes.var() results[name][f"functorch {task}"] = (mean.item(), var.item()) print( f"Results for model {name} on task {task} using Functorch: {mean}s (var: {var})" ) if args.output: with open(args.output, "w") as f: f.write(to_markdown_table(results)) if __name__ == "__main__": main()