Home
last modified time | relevance | path

Searched defs:sharding (Results 1 – 25 of 48) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_test.cc57 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local
72 HloSharding sharding = HloSharding::AssignDevice(5); in TEST_F() local
109 HloSharding sharding = HloSharding::FromProto(proto).ConsumeValueOrDie(); in TEST_F() local
116 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); in TEST_F() local
123 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); in TEST_F() local
131 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); in TEST_F() local
155 HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), in TEST_F() local
274 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local
282 HloSharding sharding = HloSharding::Replicate(std::get<0>(GetParam())); in TEST_P() local
299 HloSharding sharding = HloSharding::AssignDevice(7); in TEST_F() local
[all …]
Dhlo_sharding_util_test.cc39 HloSharding sharding = HloSharding::AssignDevice(7); in TEST() local
49 HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}})); in TEST() local
112 HloSharding sharding = HloSharding::Tile(sharding_array); in TEST() local
146 HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}})); in TEST() local
155 HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}})); in TEST() local
162 HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}})); in TEST() local
169 HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}})); in TEST() local
176 HloSharding sharding = in TEST() local
186 HloSharding sharding = in TEST() local
195 HloSharding sharding = in TEST() local
[all …]
Dhlo_sharding_metadata.h30 explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding) in ShardingMetadata()
43 const HloSharding* sharding() const { return sharding_.get(); } in sharding() function
77 std::shared_ptr<const HloSharding> sharding; member
Dhlo_sharding_metadata.cc53 const HloSharding& sharding) { in SetSingleSharding()
120 const HloSharding& sharding) { in FixupPassThroughDomainLinks()
143 std::shared_ptr<const HloSharding> sharding) { in CloneShardingForDomain()
152 const HloSharding& sharding) { in ApplyDomainSingleSharding()
341 const HloSharding& sharding) { in ApplyDomainSharding()
381 std::shared_ptr<const HloSharding> sharding; in ExtractOriginalCommonSharding() local
403 std::unique_ptr<HloSharding> sharding; in Clone() local
451 const HloSharding* sharding = sharding_metadata->sharding(); in NormalizeShardingDomain() local
Dhlo_sharding.cc128 for (auto& sharding : shardings) { in Tuple() local
139 const HloSharding& sharding) { in SingleTuple()
149 const HloSharding& sharding) { in Single()
664 auto assign_metadata = [&](HloSharding& sharding) { in WithMetadata()
670 HloSharding sharding = *this; in WithMetadata() local
682 HloSharding sharding = *this; in WithoutMetadata() local
714 std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { in operator <<()
Dhlo_matchers.h138 explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding) in HloShardingMatcher()
404 const HloSharding& sharding) { in Sharding()
410 absl::string_view sharding) { in Sharding()
Dsharding_propagation.cc58 bool IsSpatiallyPartitioned(const HloSharding& sharding) { in IsSpatiallyPartitioned()
74 bool MaybeImproveInstructionSharding(HloSharding sharding, in MaybeImproveInstructionSharding()
490 const auto& sharding = inst->sharding(); in InferConvolutionShardingFromOperands() local
694 auto get_maybe_tuple_sharding = [&](HloSharding sharding) { in InferShardingFromOperands()
784 HloSharding sharding = hlo_sharding_util::TransposeSharding( in InferShardingFromOperands() local
1015 auto sharding = instruction->operand(0)->sharding(); in InferShardingFromOperands() local
1089 auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( in InferDotOperandSharding() local
1534 const HloSharding& sharding = instruction->sharding(); in ProcessShardingInstruction() local
1638 const auto& sharding = sharding_metadata->sharding(); in NormalizeDomain() local
Dhlo_sharding_util.cc117 const HloSharding& sharding) { in MergeSharding()
268 HloSharding TransposeSharding(const HloSharding& sharding, in TransposeSharding()
299 const HloSharding& sharding) { in ReshapeSharding()
412 HloSharding ReverseSharding(const HloSharding& sharding, in ReverseSharding()
433 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, in ReshapeToTileDimension()
1097 const HloSharding& sharding, in DevicesForShardingInternal()
1126 const HloSharding& sharding, const std::vector<int64>& available_devices) { in DevicesForSharding()
1143 const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) { in PartiallyReplicateTiledShardingOnDims()
1184 HloSharding RemoveShapeDimensions(const HloSharding& sharding, in RemoveShapeDimensions()
Dbatchnorm_expander.cc272 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormTraining() local
361 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormInference() local
533 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormGrad() local
/external/tensorflow/tensorflow/compiler/tf2xla/
Dsharding_util_test.cc28 [](absl::optional<xla::OpSharding> sharding) -> int64 { in TEST()
77 auto check_metadata = [](const xla::OpSharding& sharding) { in TEST_P()
91 auto& sharding = status_or_sharding.ValueOrDie(); in TEST_P() local
130 xla::OpSharding sharding; in CreateTupleSharding() local
Dsharding_util.cc37 void AssignOpMetadataToSharding(xla::OpSharding& sharding, in AssignOpMetadataToSharding()
80 auto sharding = xla::sharding_builder::AssignDevice(core); in ParseShardingFromDevice() local
156 xla::OpSharding sharding; in GetShardingFromNodeDef() local
Dxla_helpers.cc140 const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory, in RewriteLayoutWithShardedShape()
183 absl::optional<xla::OpSharding> sharding, bool fast_mem) { in ReshapeWithCorrectRepresentationAndSharding()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dxla_sharding_util.cc203 bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { in UnsupportedPartitionedShardingType()
240 xla::OpSharding sharding; in ExtractInputsForLogicalDevices() local
335 xla::OpSharding sharding; in ParseAndValidateOutputSharding() local
365 const xla::OpSharding& sharding) { in IsAssignedToLogicalDevice()
380 const auto& sharding = output_sharding_config[output_index]; in MapClusterOutputIndexWithRegionOutputIndex() local
392 const int cluster_func_output_index, const xla::OpSharding& sharding, in GetTileShardedOutputsToMerge()
411 const int cluster_func_output_index, const xla::OpSharding& sharding, in HandleTileShardedOutputs()
627 const auto& sharding = arg_and_idx.value().sharding(); in GetMetadataArgumentMapping() local
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dspmd_manual_sharding_ops.cc42 xla::OpSharding sharding; in Compile() local
105 xla::OpSharding sharding; in Compile() local
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
Dinfeed_op.cc52 absl::optional<xla::OpSharding> sharding) { in UpdateInfeedLayout()
131 absl::optional<xla::OpSharding> sharding; in Compile() local
/external/tensorflow/tensorflow/core/protobuf/tpu/
Dcompile_metadata.proto33 xla.OpSharding sharding = 4; field
72 xla.OpSharding sharding = 1; field
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner.cc193 const HloSharding& sharding, absl::Span<const int64> replication_dims) { in GetPartitionGroupsForReplication()
482 const HloSharding& sharding = hlo_->sharding(); in PadWithValue() local
832 const HloSharding& sharding = hlo_->sharding(); in Replicate() local
1057 const HloSharding& sharding = hlo_->sharding(); in Broadcast() local
1387 HloSharding sharding = hlo->sharding().HasUniqueDevice() in DefaultAction() local
1413 const HloSharding& sharding) { in Preprocess()
1435 [](const HloSharding& sharding) { return sharding.IsManual(); })); in Preprocess()
1486 const HloSharding& sharding = hlo->sharding(); in HandleConcatenate() local
1580 const HloSharding& sharding = hlo->sharding(); in HandleSlice() local
1642 HloSharding sharding = hlo->sharding(); in HandleSort() local
[all …]
Dspmd_partitioner_util.cc47 bool HasReplicatedSharding(const HloSharding& sharding) { in HasReplicatedSharding()
127 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { in EvenlyPartitions()
148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { in MakePartitionedShape()
167 const HloSharding& sharding, in MakeNonPaddedShapeForGivenPartition()
205 const Shape& shape, const HloSharding& sharding, in MakePartitionOffsets()
240 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { in MakeTiledPartitionOrdinals()
277 const HloSharding& sharding) { in GetPaddedShapeForUnevenPartitioning()
294 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { in PadBaseShapeBeforeUnevenTiledSharding()
629 absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) { in UniqueTiledDim()
1227 const HloSharding& sharding = sort->operand(0)->sharding(); in GetKValueInTopKWhenPartitionSortDim() local
[all …]
Dfft_handler.cc53 HloInstruction* hlo, int64 num_partitions, const HloSharding& sharding, in PadEachPartitionWithHaloExchange()
228 HloInstruction* hlo, const HloSharding& sharding, in GetFinalFftUsingCollectivePermute()
Dspmd_partitioner_util.h311 HloSharding sharding; member
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.h194 void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } in SetSharding()
230 const absl::optional<OpSharding>& sharding() const { return sharding_; } in sharding() function
1483 absl::optional<OpSharding> sharding) in XlaScopedShardingAssignment()
1495 void SetSharding(const absl::optional<OpSharding>& sharding) { in SetSharding()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dtpu_sharding_identification_pass.cc91 if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner)) in GetXlaShardingFromArg() local
222 if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def)) in GetXlaShardingFromRetval() local
/external/tensorflow/tensorflow/compiler/xla/python/
Dpmap_lib.h118 ShardingSpec(std::vector<AvalDimSharding> sharding, in ShardingSpec()
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_support.cc162 Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding, in GetPerDeviceShape()
212 const auto& sharding = proto_arg.sharding(); in AddVariableUpdatesToCores() local
/external/tensorflow/tensorflow/compiler/xla/pjrt/
Dutils.cc32 const OpSharding& sharding) { in GetShardedShape()

12