Home
last modified time | relevance | path

Searched refs:unwrap_tensors (Results 1 – 19 of 19) sorted by relevance

/external/pytorch/torch/_higher_order_ops/
Dflex_attention.py405 query_unwrapped = ctx.unwrap_tensors(query)
406 key_unwrapped = ctx.unwrap_tensors(key)
407 value_unwrapped = ctx.unwrap_tensors(value)
408 block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
409 score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
410 mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)
976 query_unwrapped = ctx.unwrap_tensors(query)
977 key_unwrapped = ctx.unwrap_tensors(key)
978 value_unwrapped = ctx.unwrap_tensors(value)
979 out_unwrapped = ctx.unwrap_tensors(out)
[all …]
Dhints_wrap.py94 unwrapped_args = ctx.unwrap_tensors(args)
95 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
96 unwrapped_hints = ctx.unwrap_tensors(hints)
Deffects.py265 unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type]
266 unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type]
267 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type]
Dauto_functionalize.py353 unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type]
491 unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type]
497 all_basis_unwrapped = ctx.unwrap_tensors(all_bases)
615 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
710 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
Dwhile_loop.py238 unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
239 unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
Drun_const_graph.py38 unwrapped_args = ctx.unwrap_tensors(args)
Dmap.py245 unwrapped_xs = ctx.unwrap_tensors(xs)
246 unwrapped_args = ctx.unwrap_tensors(pos_args)
Dstrict_mode.py89 unwrapped_inputs = ctx.unwrap_tensors(inputs)
Dcond.py450 unwrapped_inputs = ctx.unwrap_tensors(inputs)
451 unwrapped_pred = ctx.unwrap_tensors(pred)
Dout_dtype.py162 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
Dexecutorch_call_delegate.py139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
Dassociative_scan.py359 unwrapped_input = ctx.unwrap_tensors(input)
Dtriton_kernel_wrap.py631 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
723 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
/external/pytorch/torch/_export/
Dwrappers.py43 unwrapped_args = ctx.unwrap_tensors(args)
44 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
/external/pytorch/torch/_prims/
Drng_prims.py301 unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
302 unwrapped_args = ctx.unwrap_tensors(args)
303 unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
/external/pytorch/torch/_dynamo/
D_trace_wrapped_higher_order_op.py125 unwrapped_args = ctx.unwrap_tensors(args)
/external/pytorch/torch/_subclasses/
Dfunctional_tensor.py642 def unwrap_tensors( member in BaseFunctionalizeAPI
686 def unwrap_tensors( member in PythonFunctionalizeAPI
728 def unwrap_tensors( member in CppFunctionalizeAPI
767 def unwrap_tensors( member in FunctorchFunctionalizeAPI
/external/executorch/exir/
Ddelegate.py139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
/external/pytorch/torch/_functorch/
Deager_transforms.py123 def unwrap_tensors(x): function
128 return tree_map(unwrap_tensors, tuple(x))
132 return tree_map(unwrap_tensors, inps)