"""Micro-benchmark: how fast can we iterate over the nodes of a large FX graph?"""

import timeit

import torch.fx

N = 100000  # number of chained ``sin`` calls traced into the benchmark graph
K = 1000  # number of full passes over the node list per timing run


def huge_graph(n=N):
    """Symbolically trace a chain of ``n`` ``x.sin()`` calls into a GraphModule.

    Args:
        n: number of chained ``x.sin()`` calls to trace (defaults to ``N``).

    Returns:
        The ``torch.fx.GraphModule`` produced by ``torch.fx.symbolic_trace``.
    """

    def fn(x):
        for _ in range(n):
            x = x.sin()
        return x

    return torch.fx.symbolic_trace(fn)


def main(n=N, k=K, repeat=3):
    """Time iterating over every node of ``huge_graph(n)`` and print the throughput.

    Args:
        n: number of ``sin`` calls in the traced graph (defaults to ``N``).
        k: passes over the node list per timing run (defaults to ``K``).
        repeat: number of timing runs; only the fastest one is reported.
    """
    g = huge_graph(n)

    def fn():
        # Only the cost of walking the node list is measured; the body is empty.
        for _ in g.graph.nodes:
            pass

    # The fastest run is the usual timeit estimate: slower runs mostly reflect
    # interference from other processes rather than the code being measured.
    t = min(timeit.repeat(fn, number=k, repeat=repeat))
    # n*k counts only the traced ``sin`` nodes, not the graph's input/output nodes.
    print(f"iterating over {n*k} FX nodes took {t:.1f}s ({n*k/t:.0f} nodes/s)")


if __name__ == "__main__":
    main()