/external/tensorflow/tensorflow/compiler/xla/service/ |
D | hlo_sharding_util_test.cc | 60 HloSharding output_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}})); in TEST() local 64 EXPECT_EQ(result.value(), output_sharding); in TEST() 71 HloSharding output_sharding = in TEST() local 76 EXPECT_EQ(result.value(), output_sharding); in TEST() 86 HloSharding output_sharding = HloSharding::Tile(tile); in TEST() local 90 EXPECT_EQ(result.value(), output_sharding); in TEST() 98 HloSharding output_sharding = in TEST() local 103 EXPECT_EQ(result.value(), output_sharding); in TEST() 124 HloSharding output_sharding = in TEST() local 129 EXPECT_EQ(result.value(), output_sharding); in TEST() [all …]
|
D | hlo_sharding_util.cc | 537 HloSharding GatherIndexSharding(const HloSharding& output_sharding, in GatherIndexSharding() argument 540 if (output_sharding.IsTileMaximal()) { in GatherIndexSharding() 541 return output_sharding; in GatherIndexSharding() 549 output_sharding.tile_assignment().dim(i)); in GatherIndexSharding() 561 if (output_sharding.ReplicateOnLastTileDim()) { in GatherIndexSharding() 563 output_sharding.tile_assignment().dimensions().back(); in GatherIndexSharding() 566 Array<int64> new_tile_assignment = output_sharding.tile_assignment(); in GatherIndexSharding() 574 return HloSharding::Replicate(output_sharding.metadata()); in GatherIndexSharding() 583 output_sharding.metadata()) in GatherIndexSharding() 585 output_sharding.metadata()); in GatherIndexSharding() [all …]
|
D | hlo_sharding_util.h | 117 HloSharding GatherIndexSharding(const HloSharding& output_sharding, 159 const HloSharding& output_sharding, const HloInstruction& hlo); 169 const HloSharding& output_sharding, const HloInstruction& hlo);
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/ |
D | xla_sharding_util.cc | 329 const auto& output_sharding = output_sharding_and_index.value(); in ParseAndValidateOutputSharding() local 331 if (!output_sharding.isa<mlir::StringAttr>()) in ParseAndValidateOutputSharding() 337 output_sharding.cast<mlir::StringAttr>().getValue().str())) in ParseAndValidateOutputSharding() 461 const xla::OpSharding& output_sharding, in ValidateAndGetTiledExecuteOutputShape() argument 466 llvm::enumerate(output_sharding.tile_assignment_dimensions())) { in ValidateAndGetTiledExecuteOutputShape() 508 const auto& output_sharding = output_sharding_config[output_index]; in GetOutputTypesForLogicalDeviceComputation() local 509 const auto output_sharding_type = output_sharding.type(); in GetOutputTypesForLogicalDeviceComputation() 520 cluster_func.getLoc(), cluster_func_output_type, output_sharding, in GetOutputTypesForLogicalDeviceComputation() 528 IsAssignedToLogicalDevice(core_id, output_sharding)) { in GetOutputTypesForLogicalDeviceComputation() 545 const auto& output_sharding = output_sharding_config[output_index]; in RemapOutputsFromLogicalDevices() local [all …]
|
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | dot_handler.cc | 468 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, in PartitionBaseCase() argument 488 output_sharding.ReplicateOnLastTileDim()) { in PartitionBaseCase() 512 output_sharding, indices_map.output_to_lhs_indices, in PartitionBaseCase() 516 output_sharding, indices_map.output_to_rhs_indices, in PartitionBaseCase() 528 .Reshard(output_sharding) in PartitionBaseCase() 541 lhs_sharding_transposed_to_match_output == output_sharding) { in PartitionBaseCase() 558 rhs_sharding_transposed_to_match_output == output_sharding) { in PartitionBaseCase() 586 FirstShardingDimWithPartitionOfSize(num_partitions, output_sharding); in PartitionBaseCase() 603 MakePartitionedShape(output_base_shape, output_sharding); in PartitionBaseCase() 1524 .Reshard(output_sharding) in PartitionBaseCase() [all …]
|
D | convolution_handler.cc | 43 const HloSharding& output_sharding, in PartitionConvolutionWithBatchGroupCount() argument 128 .Reshard(output_sharding) in PartitionConvolutionWithBatchGroupCount() 135 const HloSharding& output_sharding, in PartitionConvolutionWithFeatureGroupCount() argument 220 .Reshard(output_sharding) in PartitionConvolutionWithFeatureGroupCount() 229 const HloSharding& output_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() argument 507 .Reshard(output_sharding) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 516 const HloSharding& output_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() argument 734 .Reshard(output_sharding) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 742 const HloSharding& output_sharding, in PartitionConvolutionTiledOutput() argument 749 TF_RET_CHECK(!output_sharding.IsTileMaximal()); in PartitionConvolutionTiledOutput() [all …]
|
D | gather_scatter_handler.cc | 130 const HloSharding& output_sharding, 181 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonPassthroughOperand() argument 206 .Reshard(output_sharding) in ParititonPassthroughOperand() 216 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonTrivialIndexedOperandDimension() argument 304 .Reshard(output_sharding) in ParititonTrivialIndexedOperandDimension() 316 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in PartitionIndexParallelDimensions() argument 451 .Reshard(output_sharding) in PartitionIndexParallelDimensions() 463 const HloSharding& output_sharding, in PartitionGather() argument 480 PartitionIndexParallelDimensions(gather, output_shape, output_sharding, in PartitionGather() 489 ParititonPassthroughOperand(gather, output_shape, output_sharding, in PartitionGather() [all …]
|
D | convolution_handler.h | 31 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
|