Searched refs:HloCollectiveInstruction (Results 1 – 6 of 6) sorted by relevance
28 HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo, in MayConsiderAsAllGather()30 auto coll = DynCast<HloCollectiveInstruction>(hlo); in MayConsiderAsAllGather()81 std::vector<HloCollectiveInstruction*>> in RunOnComputation()
570 HloCollectiveInstruction::HloCollectiveInstruction( in HloCollectiveInstruction() function in xla::HloCollectiveInstruction583 HloInstructionProto HloCollectiveInstruction::ToProto() const { in ToProto()591 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl( in ExtraAttributesToStringImpl()603 bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( in IdenticalSlowPathIgnoringChannelIdValues()608 static_cast<const HloCollectiveInstruction&>(other); in IdenticalSlowPathIgnoringChannelIdValues()622 : HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand}, in HloAllGatherInstruction()630 HloCollectiveInstruction::ExtraAttributesToStringImpl(options); in ExtraAttributesToStringImpl()648 HloInstructionProto proto = HloCollectiveInstruction::ToProto(); in ToProto()659 return HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues( in IdenticalSlowPathIgnoringChannelIdValues()670 : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, in HloAllReduceInstruction()[all …]
347 class HloCollectiveInstruction : public HloChannelInstruction {369 explicit HloCollectiveInstruction(388 class HloAllGatherInstruction : public HloCollectiveInstruction {420 class HloAllReduceInstruction : public HloCollectiveInstruction {463 class HloAllToAllInstruction : public HloCollectiveInstruction {
437 const HloCollectiveInstruction* collective = in IsLayoutConstrainedCollective()438 DynCast<HloCollectiveInstruction>(instruction); in IsLayoutConstrainedCollective()
4128 return Cast<HloCollectiveInstruction>(this)->replica_groups(); in replica_groups()
1037 auto* collective = xla::Cast<xla::HloCollectiveInstruction>(instr); in SetupCommonCollectiveOpAttributes()