Home
last modified time | relevance | path

Searched refs:contracting_dims (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Ddot_as_convolution_util.cc63 dims.contracting_dims.push_back({conv_dims.input_feature_dimension(), in ParseConvolutionDimsInfo()
82 dims.contracting_dims.push_back({lhs, rhs, output, i}); in ParseConvolutionDimsInfo()
122 for (const auto& dim : dot_dnums.contracting_dims) { in CreateShardedConvForDotGeneralConvolution()
167 dnums.contracting_dims.emplace_back(); in ParseDotGeneralFromDot()
168 dnums.contracting_dims.back().lhs = in ParseDotGeneralFromDot()
170 dnums.contracting_dims.back().rhs = in ParseDotGeneralFromDot()
172 dnums.contracting_dims.back().output = -1; in ParseDotGeneralFromDot()
173 dnums.contracting_dims.back().spatial_dim = -1; in ParseDotGeneralFromDot()
Ddot_as_convolution_util.h44 std::vector<DimNums> contracting_dims; member
Dsharding_propagation.cc346 std::vector<int64> contracting_dims; in InferDotShardingFromOperands() local
347 contracting_dims.reserve(dnums.contracting_dims.size()); in InferDotShardingFromOperands()
348 for (const auto& dim : dnums.contracting_dims) { in InferDotShardingFromOperands()
349 contracting_dims.push_back(operand_index == 0 ? dim.lhs : dim.rhs); in InferDotShardingFromOperands()
358 contracting_dims.push_back(d); in InferDotShardingFromOperands()
363 operand_sharding, contracting_dims); in InferDotShardingFromOperands()
1063 for (const auto& dim : dnums.contracting_dims) { in InferDotOperandSharding()
1104 for (const auto& dim : dnums.contracting_dims) { in InferDotOperandSharding()
Dindexed_array_analysis.cc976 int64 rank, absl::Span<const int64> contracting_dims, in GetOnlyNonContractingNonBatchDim() argument
980 if (!absl::c_linear_search(contracting_dims, dim) && in GetOnlyNonContractingNonBatchDim()
1001 absl::Span<const int64> contracting_dims, in CanFoldDotIntoIndexedArray() argument
1005 contracting_dims, batch_dims); in CanFoldDotIntoIndexedArray()
Dshape_inference.cc602 absl::Span<const int64> contracting_dims, in ValidateDotDimensionNumbers()
605 return absl::c_all_of(contracting_dims, in_range) && in ValidateDotDimensionNumbers()
627 auto dims_unique = [](absl::Span<const int64> contracting_dims, in ValidateDotDimensionNumbers()
633 return absl::c_all_of(contracting_dims, is_unique) && in ValidateDotDimensionNumbers()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Ddot_handler.cc53 mapping.contracting_dims.emplace_back(); in HandleDot()
54 mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); in HandleDot()
55 mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); in HandleDot()
56 mapping.contracting_dims.back().output = -1; in HandleDot()
366 for (const auto& mapping : dims_mapping.contracting_dims) { in ComputeDimensionIndexMapping()
1515 for (const auto& cd : dims_mapping.contracting_dims) { in PartitionBaseCase()
1625 for (const auto& cd : dims_mapping.contracting_dims) { in PartitionBaseCase()
1746 absl::Span<const int64> contracting_dims, in PartitionDotGroupOnBatch()
1791 for (int64 dim : contracting_dims) { in PartitionDotGroupOnBatch()
1827 lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size()); in PartitionDotGroupOnBatch()
[all …]
Dconvolution_handler.cc908 for (const auto& dim : dot_dnums.contracting_dims) { in CreateShardedConvConvolution()
997 for (const auto& dims : dims_info.contracting_dims) { in HandleConvolution()
998 mapping.contracting_dims.emplace_back(); in HandleConvolution()
999 mapping.contracting_dims.back().lhs = dims.lhs; in HandleConvolution()
1000 mapping.contracting_dims.back().rhs = dims.rhs; in HandleConvolution()
1001 mapping.contracting_dims.back().output = dims.output; in HandleConvolution()
1002 mapping.contracting_dims.back().spatial = dims.spatial_dim; in HandleConvolution()
Dspmd_partitioner.h394 std::vector<DimsMapping> contracting_dims; member
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dmatrix.cc331 C* batch_dims, C* contracting_dims) { in DeleteDimsFromContainer() argument
343 for (auto& c : *contracting_dims) { in DeleteDimsFromContainer()