Home
last modified time | relevance | path

Searched refs:flat_tensors (Results 1 – 6 of 6) sorted by relevance

/external/pytorch/torch/nn/parallel/
Dcomm.py156 flat_tensors = [
159 flat_result = reduce_add(flat_tensors, destination)
/external/pytorch/aten/src/ATen/native/nested/
DNestedTensorUtils.h418 std::vector<Tensor> flat_tensors; in wrap_tensor_node() local
421 flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); in wrap_tensor_node()
425 options = flat_tensors[0].options().merge_in(options_); in wrap_tensor_node()
426 nt_buffer = at::cat(flat_tensors); in wrap_tensor_node()
/external/pytorch/torch/autograd/
Dfunction.py819 flat_tensors = super().saved_tensors # type: ignore[misc]
820 return _unflatten(flat_tensors, self._to_save_nested)
/external/tensorflow/tensorflow/python/distribute/
Dcross_device_utils.py418 flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
426 array_ops.concat(flat_tensors, axis=0), control_input, options)
/external/pytorch/torch/distributed/fsdp/
D_flat_param.py820 flat_tensors: List[Tensor] = []
829 flat_tensors.append(padding_tensor)
831 flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
838 flat_tensors.append(padding_tensor)
841 flat_tensors = [
844 return torch.cat(flat_tensors, dim=0)
/external/pytorch/torch/fx/experimental/
Dproxy_tensor.py1169 flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
1174 assert len(flat_proxies) == len(flat_tensors)
1177 track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)