Searched defs:gather_dim (Results 1 – 5 of 5) sorted by relevance

/external/pytorch/torch/distributed/tensor/
_collective_utils.py
32 def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): argument
50 def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): argument
/external/pytorch/test/distributed/tensor/parallel/
test_micro_pipeline_tp.py
200 def test_fuse_all_gather_matmul(self, A_dims, gather_dim): argument
238 def test_fuse_all_gather_scaled_matmul(self, A_dims, gather_dim): argument
/external/pytorch/torch/csrc/distributed/c10d/
Functional.cpp
605 int64_t gather_dim, in shard_dim_alltoall()
/external/tensorflow/tensorflow/compiler/xla/service/
indexed_array_analysis.cc
213 for (int64_t gather_dim : source->output_dims()) { in FoldGatherOfGather() local
/external/pytorch/torch/_inductor/
lowering.py
6444 def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): argument