Searched defs:mesh_dim (Results 1 – 7 of 7) sorted by relevance
338 const std::string& mesh_dim) { in GetMeshDimensionOffsetWithNeighbor()375 for (const MeshDimension& mesh_dim : mesh_dimensions) { in CreateConstSrcTargetPair() local393 const MeshDimension& mesh_dim = data.value(); in CreateConstSrcTargetPair() local433 const std::string& mesh_dim, in EmitHaloExchange()
472 const std::string& mesh_dim = layout.dim(i).sharding_spec(); in GetMostShardedLayout() local
305 for (const auto& mesh_dim : dims()) { in dim_size() local312 for (const auto& mesh_dim : dims()) dim_names.push_back(mesh_dim.name); in dim_size() local322 for (const auto& mesh_dim : mesh_dims_) dim_sizes.push_back(mesh_dim.size); in dim_sizes() local395 for (const auto& mesh_dim : dims()) in IsMeshDim() local403 const auto mesh_dim = dim(i); in GetMeshDimIndexWithName() local505 MeshDimension mesh_dim; in StrToMeshDimension() local531 for (const MeshDimension& mesh_dim : mesh_dims) mesh_size *= mesh_dim.size; in GenerateMeshDevicesForTests() local732 for (const MeshDimension& mesh_dim : layout->mesh().dims()) { in ReducedAbstractMesh() local
174 const std::string& mesh_dim = sharding_spec_strs[tensor_dim_index]; in SliceSpecOnDevice() local
117 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim); in TopologyWithMeshShape() local125 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim); in TopologyWithMeshShapeAndTasks() local
166 std::string mesh_dim = Layout::kUnshardedDim; in IntermediateBatchLayout() local
490 for (const auto& mesh_dim : recv_mesh.dims()) in ExpandOp() local