Home
last modified time | relevance | path

Searched defs:replica_groups (Results 1 – 25 of 33) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf_collective.cc95 DenseIntElementsAttr replica_groups, in SetCollectiveInfo()
107 DenseIntElementsAttr& replica_groups, in ConvertReplicaGroups()
135 DenseIntElementsAttr replica_groups, in ConvertAllReduce()
201 DenseIntElementsAttr replica_groups; in matchAndRewrite() local
267 DenseIntElementsAttr replica_groups; in matchAndRewrite() local
309 auto replica_groups = mlir::DenseIntElementsAttr::get( in matchAndRewrite() local
/external/tensorflow/tensorflow/core/tpu/kernels/
Dcross_replica_ops.cc34 std::vector<xla::ReplicaGroup> replica_groups; in Convert() local
59 std::vector<xla::ReplicaGroup> replica_groups = in Compile() local
81 std::vector<xla::ReplicaGroup> replica_groups = in Compile() local
/external/tensorflow/tensorflow/compiler/xla/service/
Dcollective_decomposer_utils.cc36 absl::Span<const ReplicaGroup> replica_groups, const Shape &shard_shape, in CreateStartIndicesForCollectiveDecomposition()
94 auto is_trivial_group = [](absl::Span<const ReplicaGroup> replica_groups) { in CreateStartIndicesForCollectiveDecomposition()
Dcollective_ops_utils_test.cc43 std::vector<ReplicaGroup> replica_groups(3); in TEST() local
121 std::vector<std::vector<int>> replica_groups; member
386 std::vector<ReplicaGroup> replica_groups; in TEST_P() local
Dall_reduce_key.cc41 std::vector<std::vector<int64_t>> replica_groups; in GetAllReduceKey() local
Dcollective_ops_utils.cc131 absl::Span<const ReplicaGroup> replica_groups, in GetParticipatingDevicesGroups()
240 absl::Span<const ReplicaGroup> replica_groups, in GetParticipatingDevices()
Dall_gather_combiner.cc119 std::vector<std::vector<int64_t>> replica_groups; in CombineKey() local
Dhlo_matchers.h214 std::vector<std::vector<int64_t>> replica_groups) in HloReplicaGroupsMatcher()
524 std::vector<std::vector<int64_t>> replica_groups) { in ReplicaGroups()
Dbfloat16_normalization_test.cc287 std::vector<ReplicaGroup> replica_groups(1); in TEST_F() local
316 std::vector<ReplicaGroup> replica_groups(1); in TEST_F() local
Dhlo_replication_analysis_test.cc653 std::vector<ReplicaGroup> replica_groups(4); in TEST_F() local
Dhlo_matchers_test.cc332 std::vector<ReplicaGroup> replica_groups(2); in TEST_F() local
Dhlo_verifier_test.cc1380 std::string ReplicaGroupsStr(std::vector<std::vector<int64_t>> replica_groups) { in ReplicaGroupsStr()
1390 int64_t ReplicaCount(const std::vector<std::vector<int64_t>>& replica_groups) { in ReplicaCount()
1399 std::vector<std::vector<int64_t>> replica_groups, in MakeCollectiveCommOpComputation()
1417 std::vector<std::vector<int64_t>> replica_groups, in MakeAllReduceComputation()
1610 std::vector<std::vector<int64_t>> replica_groups, in MakeAllToAllComputation()
Dhlo_instruction.cc511 std::vector<ReplicaGroup> replica_groups(proto.replica_groups().begin(), in CreateFromProto() local
1259 int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups, in CreateAllGather()
1270 int64_t all_gather_dimension, absl::Span<const ReplicaGroup> replica_groups, in CreateAllGatherStart()
1281 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in CreateAllReduce()
1292 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in CreateReduceScatter()
1304 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in CreateAllReduceStart()
1313 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in CreateAllToAll()
4173 absl::Span<const ReplicaGroup> replica_groups) { in ReplicaGroupsToString()
4700 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { in replica_groups() function in xla::HloInstruction
Dall_reduce_combiner_test.cc101 std::vector<ReplicaGroup> replica_groups(groups.size()); in CreateReplicaGroups() local
Dhlo_parser.cc611 std::vector<ReplicaGroup> replica_groups; in CreateReplicaGroups() local
1447 std::vector<ReplicaGroup> replica_groups; in CreateInstruction() local
1489 std::vector<ReplicaGroup> replica_groups; in CreateInstruction() local
1527 std::vector<ReplicaGroup> replica_groups; in CreateInstruction() local
3214 std::vector<ReplicaGroup>* replica_groups) { in ParseReplicaGroupsOnly()
5826 std::vector<ReplicaGroup> replica_groups; in ParseReplicaGroupsOnly() local
Dar_crs_combiner.cc96 auto replica_groups = all_reduce->replica_groups(); in HasCombinableReplicaGroup() local
Dwhile_loop_all_reduce_code_motion.cc79 absl::Span<const ReplicaGroup> replica_groups, int num_replicas, in IsValueReplicatedWithinEachAllReduceGroup()
Dhlo_instructions.cc740 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in HloCollectiveInstruction()
788 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in HloAllGatherInstruction()
837 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in HloAllReduceInstructionBase()
899 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in HloReduceScatterInstruction()
944 absl::Span<const ReplicaGroup> replica_groups, bool constrain_layout, in HloAllToAllInstruction()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dattribute_exporter.cc89 std::vector<ReplicaGroup> replica_groups(type.getDimSize(0)); in ConvertReplicaGroups() local
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dnccl_collective_thunk.cc115 const std::vector<ReplicaGroup>& replica_groups, in LockNcclComm()
Dnccl_collective_thunk.h53 std::vector<ReplicaGroup> replica_groups; member
Dall_reduce_blueconnect.cc115 absl::Span<const ReplicaGroup> replica_groups = all_reduce.replica_groups(); in TryDecomposeReplicaGroups() local
/external/tensorflow/tensorflow/core/profiler/protobuf/
Dpod_viewer.proto29 repeated ReplicaGroup replica_groups = 5; field
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc2870 absl::Span<const ReplicaGroup> replica_groups, in AllGather()
2907 XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) { in CrossReplicaSum()
2937 absl::Span<const ReplicaGroup> replica_groups, in AllReduce()
3019 int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups, in ReduceScatter()
3079 absl::Span<const ReplicaGroup> replica_groups, in AllToAll()
3093 absl::Span<const ReplicaGroup> replica_groups) { in AllToAllArray()
3146 absl::Span<const ReplicaGroup> replica_groups, in AllToAllTuple()
3185 absl::Span<const ReplicaGroup> replica_groups, in AllToAllTuple()
4655 absl::Span<const ReplicaGroup> replica_groups, in AllGather()
4665 absl::Span<const ReplicaGroup> replica_groups) { in CrossReplicaSum()
[all …]
/external/tensorflow/tensorflow/compiler/xla/python/
Dxla_client.py637 def make_replica_groups(replica_groups): argument

12