• Home
  • Raw
  • Download

Lines Matching full:torch

6 import torch
7 import torch._decomp
8 import torch._prims
9 import torch._refs
10 import torch._refs.nn
11 import torch._refs.nn.functional
12 import torch._refs.special
13 import torch.overrides
14 from torch._prims_common import torch_function_passthrough
20 Mapping of torch API functions to torch._refs functions.
21 E.g. torch_to_refs_map()[torch.add] == torch._refs.add
24 (torch, torch._refs),
25 (torch.nn, torch._refs.nn),
26 (torch.nn.functional, torch._refs.nn.functional),
27 (torch.special, torch._refs.special),
28 (torch.fft, torch._refs.fft),
29 (torch.linalg, torch._refs.linalg),
32 torch.Tensor.__invert__: torch._refs.bitwise_not,
33 torch.Tensor.__xor__: torch._refs.bitwise_xor,
34 torch.Tensor.__and__: torch._refs.bitwise_and,
35 torch.Tensor.__or__: torch._refs.bitwise_or,
36 torch.Tensor.__eq__: torch._refs.eq,
37 torch.Tensor.__rsub__: torch._refs.rsub,
38 torch.Tensor.__rtruediv__: torch._refs.rtruediv,
39 torch.Tensor.__floordiv__: torch._refs.floor_divide,
40 torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
41 torch.Tensor.__pow__: torch._refs.pow,
42 torch.Tensor.__rpow__: torch._refs.rpow,
43 torch.Tensor.new_empty: torch._refs.new_empty,
44 torch.Tensor.new_full: torch._refs.new_full,
45 torch.Tensor.new_zeros: torch._refs.new_zeros,
46 torch.Tensor.new_ones: torch._refs.new_ones,
47 torch.Tensor.fill_: torch._refs.fill_,
48 torch.Tensor.zero_: torch._refs.zero_,
49 torch.Tensor.to: torch._refs.to,
50 torch.Tensor.sum_to_size: torch._refs.sum_to_size,
52 torch.Tensor.copy_: torch._prims.copy_to,
53 torch.Tensor.resize: torch._prims.resize,
59 # Support remapping torch.Tensor.foo to _refs.foo
60 for s in dir(torch.Tensor):
61 if s in torch._refs.__all__:
62 r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)
65 for s in torch._refs._conversions.__all__:
66 tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
67 r[tensor_attr] = torch._refs._conversions.__dict__.get(s)
75 Set of all prim functions, e.g., torch._prims.add in all_prims()
77 return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}
80 class TorchRefsMode(torch.overrides.TorchFunctionMode):
82 Switches the interpretation of torch.* functions and Tensor methods to
83 use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)
87 ... torch.add(x, y) # calls torch._refs.add(x, y)
89 By default, this context manager will fall back on the torch.* if the
91 If the ref exists we still would like to fall back on the torch.* sometimes,
122 # For torch.ops.aten.*, use registered decompositions from torch._decomp
123 # torch._decomp.decomposition_table provides a mapping from
124 # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
128 if func is None and isinstance(orig_func, torch._ops.OpOverload):
129 func = torch._decomp.decomposition_table.get(orig_func, None)
130 elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
133 func = torch._decomp.decomposition_table.get(default, None)
139 # torch calls inside func should be interpreted as refs calls
144 f"no _refs support for {torch.overrides.resolve_name(orig_func)}"