Home
last modified time | relevance | path

Searched refs:next_channel_id (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner.h128 int64 next_channel_id)>
200 int64* next_channel_id,
212 int64* next_channel_id, absl::Span<const int64> selected_dims,
219 int64* next_channel_id, absl::Span<const int64> selected_dims,
229 int64* next_channel_id, SpmdLogger* logger,
234 int64* next_channel_id, absl::Span<const int64> selected_dims,
238 int64* next_channel_id, absl::Span<const int64> selected_dims,
289 int64* next_channel_id; member
326 int64 NewChannel() const { return (*state_.next_channel_id)++; } in NewChannel()
405 int64* next_channel_id, SpmdLogger* logger,
[all …]
Dfft_handler.cc55 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in PadEachPartitionWithHaloExchange() argument
82 next_channel_id, b); in PadEachPartitionWithHaloExchange()
162 int64* next_channel_id, SpmdBuilder* b) { in ShuffleDataWithAllToAll() argument
168 b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1); in ShuffleDataWithAllToAll()
230 int64 num_partitions, HloInstruction* partition_id, int64* next_channel_id, in GetFinalFftUsingCollectivePermute() argument
295 &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++); in GetFinalFftUsingCollectivePermute()
299 &body_b, source_transform, src_dst_pairs, (*next_channel_id)++); in GetFinalFftUsingCollectivePermute()
386 partitioned_input.state().next_channel_id, in HandleFft()
404 partitioned_input.state().next_channel_id, partitioned_input.state().b); in HandleFft()
425 partitioned_input.state().next_channel_id, module_, in HandleFft()
Dspmd_partitioner_util.h222 int64* next_channel_id, SpmdBuilder* b);
232 int64* next_channel_id, SpmdBuilder* b);
261 int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);
371 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);
394 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);
Dspmd_partitioner.cc809 state_.collective_ops_creator, state_.next_channel_id, state_.b, in ReshardAsWindowedInput()
880 state_.next_channel_id, dims, in ReplicatePartial()
896 state_.b, dus, sharding(), state_.next_channel_id, dims, in ReplicatePartial()
949 partitioned_hlo.state().next_channel_id, in ReshardToPartialReplicateWithAllGather()
1018 state_.collective_ops_creator, state_.next_channel_id, in ReshardFromPartialReplicateWithDynamicSlice()
1186 state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim); in ReshardWithAllToAll()
1346 state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++); in ReshardWithCollectivePermute()
1354 int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options, in SpmdPartitioningVisitor() argument
1361 next_channel_id_(next_channel_id), in SpmdPartitioningVisitor()
3501 int64* next_channel_id, absl::Span<const int64> selected_dims, in AllGatherShards() argument
[all …]
Dspmd_partitioner_util.cc413 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in TileToPartialReplicateHaloExchange() argument
468 src_sharding, collective_ops_creator, next_channel_id, b); in TileToPartialReplicateHaloExchange()
507 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in PadFromPartialReplicateShape() argument
568 src_sharding, collective_ops_creator, next_channel_id, b); in PadFromPartialReplicateShape()
803 int64* next_channel_id, SpmdBuilder* b) { in ExchangeHalo() argument
846 b, source_halo_slice, source_target_pairs, (*next_channel_id)++); in ExchangeHalo()
879 b, source_halo_slice, source_target_pairs, (*next_channel_id)++); in ExchangeHalo()
905 int64* next_channel_id, SpmdBuilder* b) { in ExchangeHalo() argument
913 collective_ops_creator, next_channel_id, b); in ExchangeHalo()
931 int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) { in ExchangeHaloAndGetValidData() argument
[all …]
Ddot_handler.cc650 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1168 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1181 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1203 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1216 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1244 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1253 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1284 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1306 (*lhs.state().next_channel_id)++); in PartitionBaseCase()
1365 &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++); in PartitionBaseCase()
[all …]
Dconvolution_handler.cc491 lhs.state().next_channel_id, b, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
504 (*lhs.state().next_channel_id)++); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
718 lhs.state().next_channel_id, b, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
731 (*lhs.state().next_channel_id)++); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
Dgather_scatter_handler.cc296 b, filtered, operand.sharding(), operand.state().next_channel_id, in ParititonTrivialIndexedOperandDimension()
606 &b_, pscatter, indices.sharding(), indices.state().next_channel_id, in HandleScatter()
/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_query.cc137 int64 next_channel_id = 1; in NextChannelId() local
143 next_channel_id = in NextChannelId()
144 std::max(next_channel_id, *channel_instr->channel_id() + 1); in NextChannelId()
148 return next_channel_id; in NextChannelId()