Home
last modified time | relevance | path

Searched refs:batch_dimensions (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/python/ops/ragged/
Dragged_batch_gather_with_default_op.py174 batch_dimensions = params_shape.partitioned_dim_sizes[
178 pad_dims = batch_dimensions + (
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dlegalize_hlo.cc395 DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions, in DotDimensionsInfo() argument
398 for (const int dim : batch_dimensions.getValues<int64_t>()) { in DotDimensionsInfo()
418 const DimensionVector &batch_dimensions() const { return batch_dimensions_; } in batch_dimensions() function in mlir::TF::__anon0527096c0111::DotDimensionsInfo
468 lhs_dot_dimensions_info.batch_dimensions().AxesArray(), in ConvertDot()
472 lhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot()
486 rhs_dot_dimensions_info.batch_dimensions().AxesArray(), in ConvertDot()
490 rhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot()
503 lhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot()
515 rhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot()
527 Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot()
/external/tensorflow/tensorflow/compiler/xla/service/
Dtriangular_solve_expander.cc494 std::vector<int64> batch_dimensions; in BuildTriangularSolve() local
505 batch_dimensions.push_back(a_size); in BuildTriangularSolve()
Dalgebraic_simplifier.cc397 HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions, in NormalizeDotOperandToBatchMajorAndContractingMinor() argument
399 std::vector<int64> transpose_dimensions(batch_dimensions.begin(), in NormalizeDotOperandToBatchMajorAndContractingMinor()
400 batch_dimensions.end()); in NormalizeDotOperandToBatchMajorAndContractingMinor()
402 if (!(absl::c_linear_search(batch_dimensions, i) || in NormalizeDotOperandToBatchMajorAndContractingMinor()
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc542 auto batch_dimensions = GetI64ElementsAttr( in BatchDot() local
551 /*lhs_batching_dimensions=*/batch_dimensions, in BatchDot()
552 /*rhs_batching_dimensions=*/batch_dimensions, in BatchDot()
2851 auto batch_dimensions = GetI64ElementsAttr( in matchAndRewrite() local
2858 /*lhs_batching_dimensions=*/batch_dimensions, in matchAndRewrite()
2859 /*rhs_batching_dimensions=*/batch_dimensions, in matchAndRewrite()