Home
last modified time | relevance | path

Searched refs:sort_dim (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
topk_rewriter.cc
91 const int64 sort_dim = sort->sort_dimension(); in Run() local
92 const int64 batch_dim = sort_dim == 1 ? 0 : 1; in Run()
124 k = slice->slice_limits(sort_dim); in Run()
125 } else if (k != slice->slice_limits(sort_dim)) { in Run()
142 const int64 input_size = sort->operand(0)->shape().dimensions(sort_dim); in Run()
144 if (has_batch && sort_dim == 0) { in Run()
167 if (has_batch && sort_dim == 0) { in Run()
hlo_evaluator.cc
2180 int64 sort_dim = sort->dimensions(0); in HandleSort() local
2181 int64 sort_dim_elements = key_shape.dimensions(sort_dim); in HandleSort()
2182 increment[sort_dim] = sort_dim_elements; in HandleSort()
2192 limit_indices[sort_dim] = sort_dim_elements; in HandleSort()
2250 slice_dimensions[sort_dim] = sort_dim_elements; in HandleSort()
dynamic_padder.cc
1198 int64 sort_dim = sort->sort_dimension(); in RewriteDynamicSort() local
1203 dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim); in RewriteDynamicSort()
1215 comp->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim)); in RewriteDynamicSort()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner.cc
1658 const int64 sort_dim = sort->sort_dimension(); in HandleSort() local
1663 input_sharding.tile_assignment().dim(sort_dim); in HandleSort()
1664 const int64 input_size = input->shape().dimensions(sort_dim); in HandleSort()
1683 replicated_dimensions[sort_dim] = per_partition_size * partition_count; in HandleSort()
1704 replicated_dimensions[sort_dim] = k.value() * partition_count; in HandleSort()
1705 auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value()); in HandleSort()
1715 auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value()); in HandleSort()
1730 final_topk_shape, sort_dim, in HandleSort()
1815 const int64 sort_dim = 1; in HandleCustomCall() local
1816 const int64 shard_count = sharding.tile_assignment().dim(sort_dim); in HandleCustomCall()
[all …]
spmd_partitioner_util.cc
1173 const int64 sort_dim = sort->sort_dimension(); in GetKValueInTopKWhenPartitionSortDim() local
1201 if (dim == sort_dim) { in GetKValueInTopKWhenPartitionSortDim()
1212 k = slice->slice_limits(sort_dim); in GetKValueInTopKWhenPartitionSortDim()
1213 } else if (k != slice->slice_limits(sort_dim)) { in GetKValueInTopKWhenPartitionSortDim()
1237 if (dim != sort_dim) { in GetKValueInTopKWhenPartitionSortDim()
1244 const int64 shard_count = sharding.tile_assignment().dim(sort_dim); in GetKValueInTopKWhenPartitionSortDim()
1250 const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); in GetKValueInTopKWhenPartitionSortDim()