Home
last modified time | relevance | path

Searched refs:output_sharding (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_util_test.cc60 HloSharding output_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}})); in TEST() local
64 EXPECT_EQ(result.value(), output_sharding); in TEST()
71 HloSharding output_sharding = in TEST() local
76 EXPECT_EQ(result.value(), output_sharding); in TEST()
86 HloSharding output_sharding = HloSharding::Tile(tile); in TEST() local
90 EXPECT_EQ(result.value(), output_sharding); in TEST()
98 HloSharding output_sharding = in TEST() local
103 EXPECT_EQ(result.value(), output_sharding); in TEST()
124 HloSharding output_sharding = in TEST() local
129 EXPECT_EQ(result.value(), output_sharding); in TEST()
[all …]
Dhlo_sharding_util.cc537 HloSharding GatherIndexSharding(const HloSharding& output_sharding, in GatherIndexSharding() argument
540 if (output_sharding.IsTileMaximal()) { in GatherIndexSharding()
541 return output_sharding; in GatherIndexSharding()
549 output_sharding.tile_assignment().dim(i)); in GatherIndexSharding()
561 if (output_sharding.ReplicateOnLastTileDim()) { in GatherIndexSharding()
563 output_sharding.tile_assignment().dimensions().back(); in GatherIndexSharding()
566 Array<int64> new_tile_assignment = output_sharding.tile_assignment(); in GatherIndexSharding()
574 return HloSharding::Replicate(output_sharding.metadata()); in GatherIndexSharding()
583 output_sharding.metadata()) in GatherIndexSharding()
585 output_sharding.metadata()); in GatherIndexSharding()
[all …]
Dhlo_sharding_util.h117 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
159 const HloSharding& output_sharding, const HloInstruction& hlo);
169 const HloSharding& output_sharding, const HloInstruction& hlo);
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dxla_sharding_util.cc329 const auto& output_sharding = output_sharding_and_index.value(); in ParseAndValidateOutputSharding() local
331 if (!output_sharding.isa<mlir::StringAttr>()) in ParseAndValidateOutputSharding()
337 output_sharding.cast<mlir::StringAttr>().getValue().str())) in ParseAndValidateOutputSharding()
461 const xla::OpSharding& output_sharding, in ValidateAndGetTiledExecuteOutputShape() argument
466 llvm::enumerate(output_sharding.tile_assignment_dimensions())) { in ValidateAndGetTiledExecuteOutputShape()
508 const auto& output_sharding = output_sharding_config[output_index]; in GetOutputTypesForLogicalDeviceComputation() local
509 const auto output_sharding_type = output_sharding.type(); in GetOutputTypesForLogicalDeviceComputation()
520 cluster_func.getLoc(), cluster_func_output_type, output_sharding, in GetOutputTypesForLogicalDeviceComputation()
528 IsAssignedToLogicalDevice(core_id, output_sharding)) { in GetOutputTypesForLogicalDeviceComputation()
545 const auto& output_sharding = output_sharding_config[output_index]; in RemapOutputsFromLogicalDevices() local
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Ddot_handler.cc468 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping, in PartitionBaseCase() argument
488 output_sharding.ReplicateOnLastTileDim()) { in PartitionBaseCase()
512 output_sharding, indices_map.output_to_lhs_indices, in PartitionBaseCase()
516 output_sharding, indices_map.output_to_rhs_indices, in PartitionBaseCase()
528 .Reshard(output_sharding) in PartitionBaseCase()
541 lhs_sharding_transposed_to_match_output == output_sharding) { in PartitionBaseCase()
558 rhs_sharding_transposed_to_match_output == output_sharding) { in PartitionBaseCase()
586 FirstShardingDimWithPartitionOfSize(num_partitions, output_sharding); in PartitionBaseCase()
603 MakePartitionedShape(output_base_shape, output_sharding); in PartitionBaseCase()
1524 .Reshard(output_sharding) in PartitionBaseCase()
[all …]
Dconvolution_handler.cc43 const HloSharding& output_sharding, in PartitionConvolutionWithBatchGroupCount() argument
128 .Reshard(output_sharding) in PartitionConvolutionWithBatchGroupCount()
135 const HloSharding& output_sharding, in PartitionConvolutionWithFeatureGroupCount() argument
220 .Reshard(output_sharding) in PartitionConvolutionWithFeatureGroupCount()
229 const HloSharding& output_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() argument
507 .Reshard(output_sharding) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
516 const HloSharding& output_sharding, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() argument
734 .Reshard(output_sharding) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
742 const HloSharding& output_sharding, in PartitionConvolutionTiledOutput() argument
749 TF_RET_CHECK(!output_sharding.IsTileMaximal()); in PartitionConvolutionTiledOutput()
[all …]
Dgather_scatter_handler.cc130 const HloSharding& output_sharding,
181 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonPassthroughOperand() argument
206 .Reshard(output_sharding) in ParititonPassthroughOperand()
216 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonTrivialIndexedOperandDimension() argument
304 .Reshard(output_sharding) in ParititonTrivialIndexedOperandDimension()
316 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in PartitionIndexParallelDimensions() argument
451 .Reshard(output_sharding) in PartitionIndexParallelDimensions()
463 const HloSharding& output_sharding, in PartitionGather() argument
480 PartitionIndexParallelDimensions(gather, output_shape, output_sharding, in PartitionGather()
489 ParititonPassthroughOperand(gather, output_shape, output_sharding, in PartitionGather()
[all …]
Dconvolution_handler.h31 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,