
Searched refs: op_overload (Results 1 – 25 of 32), sorted by relevance


/external/executorch/exir/operator/
convert.py
85 def _get_overload_schema(op_overload: OpOverload) -> Optional[FunctionSchema]:
86 native_schema = _op_overload_to_schema_cache.get(op_overload)
88 native_schema = _pybind_schema_to_native_schema(op_overload._schema)
89 _op_overload_to_schema_cache[op_overload] = native_schema # pyre-ignore
93 def get_out_args_from_opoverload(op_overload: OpOverload) -> Tuple[str]:
94 return get_out_args_from_schema(_get_overload_schema(op_overload)) # pyre-ignore
231 def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str]]:
239 schema = _get_overload_schema(op_overload)
241 return op_overload, get_out_args_from_schema(schema) # pyre-ignore[6]
250 op_overload not in _func_to_out_variant_map
[all …]
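
Note: the convert.py hits above resolve an OpOverload's native schema once and memoize it with the OpOverload object as the cache key, then read the out-argument names off that schema to find the out variant. A minimal sketch of the caching shape (assumption: the real helper converts the pybind schema into a torchgen FunctionSchema, which is skipped here):

import torch
from torch._ops import OpOverload

_schema_cache = {}  # OpOverload -> schema, standing in for _op_overload_to_schema_cache

def cached_schema(op_overload: OpOverload):
    schema = _schema_cache.get(op_overload)
    if schema is None:
        # convert.py parses op_overload._schema into a torchgen FunctionSchema;
        # this sketch just memoizes the pybind schema object.
        schema = op_overload._schema
        _schema_cache[op_overload] = schema
    return schema

print(cached_schema(torch.ops.aten.topk.default))
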
/external/pytorch/torch/onnx/_internal/fx/
decomposition_table.py
51 op_overload = getattr(op_overload_packet, overload_name)
53 qualified_name=op_overload.name()
69 table.add(op_overload)
97 …for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defi…
104 or op_overload in _ONNX_SUPPORT_OP_OVERLOADS
107 decomposition_table[op_overload] = decomp_fn
112 for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items():
113 if op_overload in _ONNX_SUPPORT_OP_OVERLOADS:
115 decomposition_table[op_overload] = decomp_fn
registration.py
64 def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName:
65 return cls.from_qualified_name(op_overload.name())
onnxfunction_dispatcher.py
296 …return registration.OpName.from_op_overload(op_overload=aten_op_default) # type: ignore[no-any-re…
317 return registration.OpName.from_op_overload(op_overload=node.target)
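
Note: both decomposition_table.py and registration.py above lean on the same two OpOverload facts: every overload of an OpOverloadPacket is reachable with getattr, and op_overload.name() gives the qualified name used for registry lookups. A small sketch of that enumeration, using only the public overloads() helper:

import torch

def all_overloads(packet: torch._ops.OpOverloadPacket):
    # packet.overloads() lists overload names, e.g. ['Tensor', 'Scalar', 'out'];
    # getattr maps each name back to the concrete OpOverload object.
    return [getattr(packet, name) for name in packet.overloads()]

for op_overload in all_overloads(torch.ops.aten.add):
    print(op_overload.name())  # qualified name such as "aten::add.Tensor"
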
/external/pytorch/torch/onnx/_internal/exporter/
_decomp.py
37 op_overload = getattr(op_overload_packet, overload_name)
38 if registry.is_registered(op_overload):
39 registered_ops.append(op_overload)
91 …for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defi…
96 if op_overload in onnx_registered_ops:
98 decomposition_table[op_overload] = decomp_fn
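
Note: the exporter's _decomp.py builds a decomposition table but leaves out any OpOverload the ONNX registry can translate directly. A hedged sketch of that filtering, with a caller-supplied set standing in for the registry check:

import torch
import torch._decomp

def build_decomp_table(keep_as_is):
    # keep_as_is: ops assumed to be exportable as-is, so they are not decomposed.
    table = {}
    for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items():
        if op_overload in keep_as_is:
            continue
        table[op_overload] = decomp_fn
    return table

table = build_decomp_table({torch.ops.aten.native_layer_norm.default})
print(len(table))
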
/external/executorch/exir/tests/
test_op_convert.py
27 op_overload = torch.ops.aten.topk.values
28 out_var_op = op_convert.to_out_variant(op_overload)[0]
29 self.assertTrue(op_overload is out_var_op)
32 op_overload = torch.ops.aten.topk.default
33 out_var_op, out_args = op_convert.to_out_variant(op_overload)
41 expect_values, expect_indices = op_overload(input_tensor, k)
/external/pytorch/docs/source/scripts/
build_opsets.py
42 op_overload = getattr(prims, op_name, None)
44 if not isinstance(op_overload, torch._ops.OpOverload):
47 op_overloadpacket = op_overload.overloadpacket
49 op_name = str(op_overload).replace(".default", "")
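
Note: build_opsets.py looks ops up by name on a namespace, checks that the result really is an OpOverload, and strips the ".default" suffix from its string form. A small sketch of the packet/overload round trip it relies on (printed values are illustrative):

import torch

op_overload = getattr(torch.ops.aten, "relu").default
assert isinstance(op_overload, torch._ops.OpOverload)
packet = op_overload.overloadpacket              # back to torch.ops.aten.relu
print(str(op_overload))                          # e.g. "aten.relu.default"
print(str(op_overload).replace(".default", ""))  # e.g. "aten.relu"
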
/external/pytorch/torch/export/
exported_program.py
196 for op_overload in ops_to_preserve:
203 def assert_valid_to_preserve(op_overload): argument
204 if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
208 if op_overload in FunctionalTensor.metadata_fns:
216 [i for i in op_overload._schema.arguments if i.alias_info is not None]
219 is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable
226 if not torch._C._dispatch_has_kernel(op_overload.name()):
235 assert_valid_to_preserve(op_overload)
237 saved_tables[op_overload] = op_overload.py_kernels.copy()
238 patched_ops.add(op_overload)
[all …]
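
Note: exported_program.py refuses to preserve an OpOverload whose schema aliases or mutates its inputs, and it also requires a registered kernel (torch._C._dispatch_has_kernel). A sketch of the schema side of that check, assuming alias_info and is_mutable are the whole story:

import torch

def is_mutating_or_aliasing(op_overload: torch._ops.OpOverload) -> bool:
    alias_args = [a for a in op_overload._schema.arguments if a.alias_info is not None]
    return len(alias_args) != 0 or op_overload._schema.is_mutable

print(is_mutating_or_aliasing(torch.ops.aten.add.Tensor))   # False: purely functional
print(is_mutating_or_aliasing(torch.ops.aten.add_.Tensor))  # True: in-place, mutates self
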
/external/pytorch/torch/testing/_internal/
torchbind_impls.py
127 def _register_py_impl_temporarily(op_overload, key, fn): argument
129 op_overload.py_impl(key)(fn)
132 del op_overload.py_kernels[key]
133 op_overload._dispatch_cache.clear()
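
Note: the torchbind helper above installs a Python kernel on an OpOverload for one dispatch key and tears it back down, clearing the dispatch cache so the removal takes effect. The same idea written as a context manager (a sketch built on the private py_impl / py_kernels / _dispatch_cache hooks the hits use):

import contextlib

@contextlib.contextmanager
def register_py_impl_temporarily(op_overload, key, fn):
    try:
        op_overload.py_impl(key)(fn)             # install fn for this dispatch key
        yield
    finally:
        del op_overload.py_kernels[key]          # uninstall the Python kernel
        op_overload._dispatch_cache.clear()      # drop cached dispatch decisions
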
/external/pytorch/torch/_inductor/
mkldnn_ir.py
241 op_overload=torch.ops.mkldnn._convolution_pointwise.default,
264 op_overload=self.op_overload,
312 op_overload=torch.ops.mkldnn._convolution_pointwise.binary,
340 self.op_overload,
402 op_overload=torch.ops.mkldnn._convolution_pointwise_.binary,
435 self.op_overload,
499 op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default,
593 op_overload=torch.ops.onednn.qconv2d_pointwise.default,
722 op_overload=self.op_overload,
827 op_overload=torch.ops.onednn.qconv2d_pointwise.binary,
[all …]
ir.py
3955 op_overload: Optional[ variable in ExternKernel
3975 op_overload=None, argument
3987 self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload)
3989 self.op_overload = op_overload
4007 for x in self.op_overload._schema.arguments
4010 if isinstance(self.op_overload, torch._ops.OpOverload)
4016 for x in self.op_overload._schema.arguments
4018 if isinstance(self.op_overload, torch._ops.OpOverload)
4024 isinstance(self.op_overload, torch._ops.OpOverload)
4028 x.name for x in self.op_overload._schema.arguments if x.kwarg_only
[all …]
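
Note: inductor's ExternKernel keeps the originating OpOverload and walks its schema, for example to find kwarg-only argument names. A minimal sketch of that walk (gelu is only chosen because it has a kwarg-only argument):

import torch

def ordered_kwargs_for(op_overload):
    if not isinstance(op_overload, torch._ops.OpOverload):
        return []
    return [a.name for a in op_overload._schema.arguments if a.kwarg_only]

print(ordered_kwargs_for(torch.ops.aten.gelu.default))  # e.g. ['approximate']
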
select_algorithm.py
790 op_overload=None, argument
802 self.op_overload = op_overload
979 self.choice.op_overload is not None
982 self.choice.op_overload, *self.input_nodes, **self.kwargs
994 op_overload=self.choice.op_overload,
utils.py
1637 return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
1671 return isinstance(node, ir.FallbackKernel) and node.op_overload in op
1815 op_overload: torch._ops.OpOverload,
1823 op_overload.overloadpacket
1830 "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
1841 op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
/external/executorch/exir/dialects/backend/
_ops.py
55 op_overload: EdgeOpOverload,
58 op_overload._op,
59 op_overload._schema,
/external/pytorch/torch/_decomp/
__init__.py
58 for op_overload in overloads:
59 if op_overload in registry:
64 if torch._C._dispatch_has_kernel(op_overload.name()):
65 registry[op_overload] = fn
222 for op_overload in packets_to_overloads[op]:
223 decompositions[op_overload] = registry[op_overload]
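
Note: torch/_decomp registers a decomposition per OpOverload, rejecting duplicates and skipping overloads the dispatcher has no kernel for. A hedged sketch of that loop with a throwaway decomposition just to exercise it:

import torch

def register_overloads(registry, overloads, fn):
    for op_overload in overloads:
        if op_overload in registry:
            raise RuntimeError(f"duplicate registration for {op_overload}")
        # Only register if the dispatcher actually knows this operator.
        if torch._C._dispatch_has_kernel(op_overload.name()):
            registry[op_overload] = fn

registry = {}
register_overloads(registry, [torch.ops.aten.relu.default], lambda x: x.clamp(min=0))
print(list(registry))
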
/external/pytorch/torch/distributed/tensor/
_sharding_prop.py
83 op_overload: OpOverload,
90 self.op_to_rules[op_overload] = rule_func
92 self.op_to_schema_info[op_overload] = schema_info
96 op_overload: OpOverload,
103 self.op_strategy_funcs[op_overload] = strategy_func
105 self.op_to_schema_info[op_overload] = schema_info
/external/pytorch/torch/_inductor/codegen/
cpp_wrapper_cpu.py
1989 self, op_overload, raw_args, output_args argument
1991 arg_types = [x.real_type for x in op_overload._schema.arguments]
1992 return_types = [x.type for x in op_overload._schema.returns]
2126 op_overload: Optional[torch._ops.OpOverload] = None,
2152 assert op_overload is not None
2158 op_overload,
2171 op_overload,
2300 op_overload: Optional[torch._ops.OpOverload] = None,
2337 assert op_overload is not None, "op_overload should not be None"
2340 zip(raw_args, op_overload._schema.arguments)
[all …]
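
Note: cpp_wrapper_cpu.py reads argument and return types straight off the OpOverload's schema to generate the C++ call. A small sketch of that introspection (the printed types are whatever the JIT schema reports):

import torch

op_overload = torch.ops.aten.topk.default
arg_types = [a.real_type for a in op_overload._schema.arguments]
return_types = [r.type for r in op_overload._schema.returns]
print(arg_types)     # types of self, k, dim, largest, sorted
print(return_types)  # types of the values/indices outputs
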
/external/pytorch/torch/_custom_op/
autograd.py
96 op_overload, argument
108 output = op_overload(*args)
/external/pytorch/torch/_library/
utils.py
267 def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs): argument
282 return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
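
Note: handle_dispatch_mode ends up handing the OpOverload to a mode's __torch_dispatch__ as func. A tiny illustrative mode (LogOps is a made-up name) showing the shape of that call:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("dispatching", func.name())      # func is the OpOverload
        return func(*args, **(kwargs or {}))

with LogOps():
    torch.ones(2) + 1   # prints lines like "dispatching aten::add.Tensor"
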
/external/pytorch/torch/
_ops.py
632 def add_cached_op(op_overload): argument
634 cached_ops.add(op_overload)
1142 op_overload = getattr(op, overload_name)
1145 op_overload._schema, *args, **kwargs
1147 found_op = op_overload
_meta_registrations.py
6579 for op_overload, fn in activate_meta_table.items():
6584 if isinstance(op_overload, torch._ops.HigherOrderOperator):
6586 assert isinstance(op_overload, OpOverload)
6588 op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
6591 op_overload.name(), "CompositeImplicitAutograd"
6597 if op_overload in global_decomposition_table["meta"]:
6603 elif op_overload.is_view:
6609 op_overload.name()
6622 if "mkldnn::" in op_overload.name():
6623 _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
[all …]
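
Note: _meta_registrations.py attaches meta kernels with op_overload.py_impl(torch._C.DispatchKey.Meta)(fn), but only for genuine OpOverloads that are not CompositeImplicitAutograd (those decompose and get meta behavior for free). A sketch of those guards:

import torch
from torch._ops import OpOverload

def wants_direct_meta_kernel(op_overload) -> bool:
    if isinstance(op_overload, torch._ops.HigherOrderOperator):
        return False
    assert isinstance(op_overload, OpOverload)
    return not torch._C._dispatch_has_kernel_for_dispatch_key(
        op_overload.name(), "CompositeImplicitAutograd"
    )

print(wants_direct_meta_kernel(torch.ops.aten.topk.default))
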
/external/pytorch/test/export/
test_passes.py
518 op_overload = getattr(getattr(torch.ops.aten, name), overload)
519 if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
520 self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))
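
Note: the export test keys off operator tags plus the view-ness of the schema. A quick look at the tag side (tag values shown are illustrative; core ATen ops such as add.Tensor carry torch.Tag.core):

import torch

op_overload = torch.ops.aten.add.Tensor
print(op_overload.tags)                    # e.g. [torch.Tag.core, torch.Tag.pointwise]
print(torch.Tag.core in op_overload.tags)
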
/external/executorch/exir/emit/
_emitter.py
1110 op_overload = ""
1112 op_overload = target._overloadname
1131 op_index, operator = self._get_operator(name=op_name, overload=op_overload)
1153 if is_out_variant(op_name, op_overload):
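
Note: the executorch emitter splits a target into an operator name and an overload name via _overloadname before asking whether the pair is an out variant. A small sketch of that split, using the topk out variant seen in test_op_convert.py above:

import torch

target = torch.ops.aten.topk.values   # topk's out-variant overload
print(target._schema.name)            # "aten::topk"
print(target._overloadname)           # "values"
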
/external/pytorch/torch/_inductor/kernel/
conv.py
357 op_overload=aten.convolution.default,
/external/pytorch/test/distributed/
test_compute_comm_reordering.py
54 and snode.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
