Searched refs:sort_dim (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | topk_rewriter.cc | 91 const int64 sort_dim = sort->sort_dimension(); in Run() local 92 const int64 batch_dim = sort_dim == 1 ? 0 : 1; in Run() 124 k = slice->slice_limits(sort_dim); in Run() 125 } else if (k != slice->slice_limits(sort_dim)) { in Run() 142 const int64 input_size = sort->operand(0)->shape().dimensions(sort_dim); in Run() 144 if (has_batch && sort_dim == 0) { in Run() 167 if (has_batch && sort_dim == 0) { in Run()
|
D | hlo_evaluator.cc | 2180 int64 sort_dim = sort->dimensions(0); in HandleSort() local 2181 int64 sort_dim_elements = key_shape.dimensions(sort_dim); in HandleSort() 2182 increment[sort_dim] = sort_dim_elements; in HandleSort() 2192 limit_indices[sort_dim] = sort_dim_elements; in HandleSort() 2250 slice_dimensions[sort_dim] = sort_dim_elements; in HandleSort()
|
D | dynamic_padder.cc | 1198 int64 sort_dim = sort->sort_dimension(); in RewriteDynamicSort() local 1203 dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim); in RewriteDynamicSort() 1215 comp->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim)); in RewriteDynamicSort()
|
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner.cc | 1658 const int64 sort_dim = sort->sort_dimension(); in HandleSort() local 1663 input_sharding.tile_assignment().dim(sort_dim); in HandleSort() 1664 const int64 input_size = input->shape().dimensions(sort_dim); in HandleSort() 1683 replicated_dimensions[sort_dim] = per_partition_size * partition_count; in HandleSort() 1704 replicated_dimensions[sort_dim] = k.value() * partition_count; in HandleSort() 1705 auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value()); in HandleSort() 1715 auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value()); in HandleSort() 1730 final_topk_shape, sort_dim, in HandleSort() 1815 const int64 sort_dim = 1; in HandleCustomCall() local 1816 const int64 shard_count = sharding.tile_assignment().dim(sort_dim); in HandleCustomCall() [all …]
|
D | spmd_partitioner_util.cc | 1173 const int64 sort_dim = sort->sort_dimension(); in GetKValueInTopKWhenPartitionSortDim() local 1201 if (dim == sort_dim) { in GetKValueInTopKWhenPartitionSortDim() 1212 k = slice->slice_limits(sort_dim); in GetKValueInTopKWhenPartitionSortDim() 1213 } else if (k != slice->slice_limits(sort_dim)) { in GetKValueInTopKWhenPartitionSortDim() 1237 if (dim != sort_dim) { in GetKValueInTopKWhenPartitionSortDim() 1244 const int64 shard_count = sharding.tile_assignment().dim(sort_dim); in GetKValueInTopKWhenPartitionSortDim() 1250 const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); in GetKValueInTopKWhenPartitionSortDim()
|