
Searched defs:scatter_dim (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/dtensor/mlir/
  dtensor_allreduce_scatter_optimization.cc
    line 50: mlir::TF::DTensorAllScatterOp all_scatter, int scatter_dim) {  [in GetScatterGroupAssignment()]
    line 106: int scatter_dim = scatter_dims[0];  [in ApplyOptimization(), local]
/external/pytorch/test/distributed/tensor/parallel/
  test_micro_pipeline_tp.py
    line 306: def test_fuse_matmul_reduce_scatter(self, A_dims, scatter_dim):  [argument]
    line 335: def test_fuse_scaled_matmul_reduce_scatter(self, A_dims, scatter_dim):  [argument]
/external/tensorflow/tensorflow/dtensor/mlir/utils/
  collective_lowering.cc
    line 367: mlir::APInt scatter_dim = *scatter_attr.begin();  [in LowerReduceScatterOp(), local]
/external/tensorflow/tensorflow/core/framework/
  common_shape_fns.cc
    line 2616: int64_t scatter_dim;  [in ReduceScatterShape(), local]
/external/tensorflow/tensorflow/compiler/mlir/xla/
  mlir_hlo_to_hlo.cc
    line 790: auto scatter_dim = op.scatter_dimension();  [in ExportXlaOp(), local]
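Several of these hits (ReduceScatterShape, LowerReduceScatterOp, the PyTorch reduce-scatter fusion tests) use scatter_dim to select the dimension along which a reduce-scatter splits its result. The sketch below is a rough, hypothetical illustration only. It assumes the common reduce-scatter convention: inputs are summed elementwise across the group, then the result is cut into group_size equal chunks along scatter_dim, so that dimension shrinks by a factor of group_size. The helper names reduce_scatter_output_shape and simulate_reduce_scatter are made up for this example and do not come from any of the listed files.

```python
from typing import List, Sequence, Tuple


def reduce_scatter_output_shape(
    input_shape: Sequence[int], scatter_dim: int, group_size: int
) -> Tuple[int, ...]:
    """Hypothetical helper: per-participant output shape of a reduce-scatter.

    Assumes each of the group_size participants receives one equal-sized
    chunk of the reduced tensor along scatter_dim.
    """
    rank = len(input_shape)
    # Accept negative dimensions, Python-style (-1 means the last axis).
    if not -rank <= scatter_dim < rank:
        raise ValueError(f"scatter_dim {scatter_dim} out of range for rank {rank}")
    if group_size <= 0:
        raise ValueError("group_size must be positive")
    dim = scatter_dim % rank
    # Equal chunks are only possible if the dimension divides evenly.
    if input_shape[dim] % group_size != 0:
        raise ValueError(
            f"dimension {dim} of size {input_shape[dim]} is not divisible by {group_size}"
        )
    out = list(input_shape)
    out[dim] //= group_size
    return tuple(out)


def simulate_reduce_scatter(per_rank_inputs: List[List[int]]) -> List[List[int]]:
    """Hypothetical simulation of a sum reduce-scatter of 1-D lists (scatter_dim=0).

    per_rank_inputs[r] is the input held by participant r; the return value's
    entry r is the chunk participant r would end up with.
    """
    group_size = len(per_rank_inputs)
    # Step 1: elementwise sum across all participants.
    reduced = [sum(vals) for vals in zip(*per_rank_inputs)]
    if len(reduced) % group_size != 0:
        raise ValueError("length must be divisible by the number of participants")
    # Step 2: participant r keeps the r-th contiguous chunk.
    chunk = len(reduced) // group_size
    return [reduced[r * chunk:(r + 1) * chunk] for r in range(group_size)]
```

Under these assumptions, reduce_scatter_output_shape((8, 16), 0, 4) gives (2, 16), and two participants holding [1, 2, 3, 4] and [10, 20, 30, 40] end up with [11, 22] and [33, 44] respectively.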