/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner.h |
    128  int64 next_channel_id)>
    200  int64* next_channel_id,
    212  int64* next_channel_id, absl::Span<const int64> selected_dims,
    219  int64* next_channel_id, absl::Span<const int64> selected_dims,
    229  int64* next_channel_id, SpmdLogger* logger,
    234  int64* next_channel_id, absl::Span<const int64> selected_dims,
    238  int64* next_channel_id, absl::Span<const int64> selected_dims,
    289  int64* next_channel_id;  member
    326  int64 NewChannel() const { return (*state_.next_channel_id)++; }  in NewChannel()
    405  int64* next_channel_id, SpmdLogger* logger,
    [all …]
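The snippets at lines 289 and 326 above show the pattern behind most of these references: the partitioning state keeps a pointer to a single int64 counter, and NewChannel() hands out a fresh channel id by post-incrementing it, so every collective op emitted during partitioning gets a unique id. A minimal sketch of that pattern follows; PartitioningState and PartitionedValue are hypothetical stand-ins for the real XLA types, not their actual definitions.

    #include <cstdint>
    #include <iostream>

    // Sketch of the shared-counter channel-id allocation visible at lines
    // 289/326. The struct layout is illustrative only.
    struct PartitioningState {
      int64_t* next_channel_id;  // shared counter, owned by the caller
    };

    struct PartitionedValue {
      PartitioningState state_;
      // Each collective op created during partitioning asks for a new id.
      int64_t NewChannel() const { return (*state_.next_channel_id)++; }
    };

    int main() {
      int64_t next_channel_id = 1;  // typically seeded past existing ids
      PartitionedValue a{{&next_channel_id}};
      PartitionedValue b{{&next_channel_id}};
      std::cout << a.NewChannel() << "\n";  // 1
      std::cout << b.NewChannel() << "\n";  // 2 -- the counter is shared
      std::cout << a.NewChannel() << "\n";  // 3
    }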
|
D | fft_handler.cc |
    55   int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {  in PadEachPartitionWithHaloExchange() argument
    82   next_channel_id, b);  in PadEachPartitionWithHaloExchange()
    162  int64* next_channel_id, SpmdBuilder* b) {  in ShuffleDataWithAllToAll() argument
    168  b, {hlo}, groups, (*next_channel_id)++, hlo->shape().rank() - 1);  in ShuffleDataWithAllToAll()
    230  int64 num_partitions, HloInstruction* partition_id, int64* next_channel_id,  in GetFinalFftUsingCollectivePermute() argument
    295  &body_b, source_partition_id, src_dst_pairs, (*next_channel_id)++);  in GetFinalFftUsingCollectivePermute()
    299  &body_b, source_transform, src_dst_pairs, (*next_channel_id)++);  in GetFinalFftUsingCollectivePermute()
    386  partitioned_input.state().next_channel_id,  in HandleFft()
    404  partitioned_input.state().next_channel_id, partitioned_input.state().b);  in HandleFft()
    425  partitioned_input.state().next_channel_id, module_,  in HandleFft()
|
D | spmd_partitioner_util.h |
    222  int64* next_channel_id, SpmdBuilder* b);
    232  int64* next_channel_id, SpmdBuilder* b);
    261  int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);
    371  int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);
    394  int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);
|
D | spmd_partitioner.cc |
    809   state_.collective_ops_creator, state_.next_channel_id, state_.b,  in ReshardAsWindowedInput()
    880   state_.next_channel_id, dims,  in ReplicatePartial()
    896   state_.b, dus, sharding(), state_.next_channel_id, dims,  in ReplicatePartial()
    949   partitioned_hlo.state().next_channel_id,  in ReshardToPartialReplicateWithAllGather()
    1018  state_.collective_ops_creator, state_.next_channel_id,  in ReshardFromPartialReplicateWithDynamicSlice()
    1186  state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);  in ReshardWithAllToAll()
    1346  state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);  in ReshardWithCollectivePermute()
    1354  int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,  in SpmdPartitioningVisitor() argument
    1361  next_channel_id_(next_channel_id),  in SpmdPartitioningVisitor()
    3501  int64* next_channel_id, absl::Span<const int64> selected_dims,  in AllGatherShards() argument
    [all …]
|
D | spmd_partitioner_util.cc |
    413  int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {  in TileToPartialReplicateHaloExchange() argument
    468  src_sharding, collective_ops_creator, next_channel_id, b);  in TileToPartialReplicateHaloExchange()
    507  int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {  in PadFromPartialReplicateShape() argument
    568  src_sharding, collective_ops_creator, next_channel_id, b);  in PadFromPartialReplicateShape()
    803  int64* next_channel_id, SpmdBuilder* b) {  in ExchangeHalo() argument
    846  b, source_halo_slice, source_target_pairs, (*next_channel_id)++);  in ExchangeHalo()
    879  b, source_halo_slice, source_target_pairs, (*next_channel_id)++);  in ExchangeHalo()
    905  int64* next_channel_id, SpmdBuilder* b) {  in ExchangeHalo() argument
    913  collective_ops_creator, next_channel_id, b);  in ExchangeHalo()
    931  int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {  in ExchangeHaloAndGetValidData() argument
    [all …]
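The ExchangeHalo overloads above (lines 803-913) take the counter by pointer and consume one id per collective-permute they emit (lines 846 and 879), so repeated halo exchanges built from the same counter never reuse a channel id. A small illustrative sketch of that threading, assuming a toy Collective record in place of the real SpmdBuilder/HloInstruction machinery:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // ExchangeHaloSketch is a hypothetical helper: it mimics how the real
    // halo-exchange utilities thread next_channel_id through and grab a fresh
    // id for every collective-permute they create.
    struct Collective {
      int64_t channel_id;
      std::string kind;
    };

    std::vector<Collective> ExchangeHaloSketch(int num_slices,
                                               int64_t* next_channel_id) {
      std::vector<Collective> ops;
      for (int i = 0; i < num_slices; ++i) {
        // One collective-permute per halo slice, each with its own channel id.
        ops.push_back({(*next_channel_id)++, "collective-permute"});
      }
      return ops;
    }

    int main() {
      int64_t next_channel_id = 1;
      auto first = ExchangeHaloSketch(/*num_slices=*/2, &next_channel_id);
      auto second = ExchangeHaloSketch(/*num_slices=*/3, &next_channel_id);
      // Ids from the two calls are disjoint: 1 2 and then 3 4 5.
      for (const auto& op : first) std::cout << op.channel_id << " ";
      for (const auto& op : second) std::cout << op.channel_id << " ";
      std::cout << "\n";
    }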
|
D | dot_handler.cc |
    650   (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1168  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1181  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1203  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1216  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1244  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1253  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1284  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1306  (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    1365  &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);  in PartitionBaseCase()
    [all …]
|
D | convolution_handler.cc |
    491  lhs.state().next_channel_id, b,  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
    504  (*lhs.state().next_channel_id)++);  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
    718  lhs.state().next_channel_id, b,  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
    731  (*lhs.state().next_channel_id)++);  in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
|
D | gather_scatter_handler.cc |
    296  b, filtered, operand.sharding(), operand.state().next_channel_id,  in ParititonTrivialIndexedOperandDimension()
    606  &b_, pscatter, indices.sharding(), indices.state().next_channel_id,  in HandleScatter()
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | hlo_query.cc |
    137  int64 next_channel_id = 1;  in NextChannelId() local
    143  next_channel_id =  in NextChannelId()
    144  std::max(next_channel_id, *channel_instr->channel_id() + 1);  in NextChannelId()
    148  return next_channel_id;  in NextChannelId()
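Read together, lines 137-148 outline how NextChannelId() seeds the counter: start at 1, scan every instruction that carries a channel id, and return one past the largest id found, so ids handed out later do not collide with existing ones. A self-contained sketch of that logic, with a simplified Instruction standing in for HloInstruction and its channel_id() accessor:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Simplified stand-in for an HLO instruction that may carry a channel id.
    struct Instruction {
      std::optional<int64_t> channel_id;
    };

    // Sketch of the NextChannelId() scan visible at lines 137-148.
    int64_t NextChannelId(const std::vector<Instruction>& instructions) {
      int64_t next_channel_id = 1;
      for (const auto& instr : instructions) {
        if (instr.channel_id.has_value()) {
          next_channel_id = std::max(next_channel_id, *instr.channel_id + 1);
        }
      }
      return next_channel_id;
    }

    int main() {
      std::vector<Instruction> instrs = {{std::nullopt}, {4}, {2}};
      std::cout << NextChannelId(instrs) << "\n";  // prints 5
    }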
|