Searched refs:batch_dimensions (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/python/ops/ragged/ |
D | ragged_batch_gather_with_default_op.py | 174 batch_dimensions = params_shape.partitioned_dim_sizes[ 178 pad_dims = batch_dimensions + (
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | legalize_hlo.cc | 395 DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions, in DotDimensionsInfo() argument 398 for (const int dim : batch_dimensions.getValues<int64_t>()) { in DotDimensionsInfo() 418 const DimensionVector &batch_dimensions() const { return batch_dimensions_; } in batch_dimensions() function in mlir::TF::__anon0527096c0111::DotDimensionsInfo 468 lhs_dot_dimensions_info.batch_dimensions().AxesArray(), in ConvertDot() 472 lhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot() 486 rhs_dot_dimensions_info.batch_dimensions().AxesArray(), in ConvertDot() 490 rhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot() 503 lhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot() 515 rhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot() 527 Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(), in ConvertDot()
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | triangular_solve_expander.cc | 494 std::vector<int64> batch_dimensions; in BuildTriangularSolve() local 505 batch_dimensions.push_back(a_size); in BuildTriangularSolve()
|
D | algebraic_simplifier.cc | 397 HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions, in NormalizeDotOperandToBatchMajorAndContractingMinor() argument 399 std::vector<int64> transpose_dimensions(batch_dimensions.begin(), in NormalizeDotOperandToBatchMajorAndContractingMinor() 400 batch_dimensions.end()); in NormalizeDotOperandToBatchMajorAndContractingMinor() 402 if (!(absl::c_linear_search(batch_dimensions, i) || in NormalizeDotOperandToBatchMajorAndContractingMinor()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/ |
D | legalize_tf.cc | 542 auto batch_dimensions = GetI64ElementsAttr( in BatchDot() local 551 /*lhs_batching_dimensions=*/batch_dimensions, in BatchDot() 552 /*rhs_batching_dimensions=*/batch_dimensions, in BatchDot() 2851 auto batch_dimensions = GetI64ElementsAttr( in matchAndRewrite() local 2858 /*lhs_batching_dimensions=*/batch_dimensions, in matchAndRewrite() 2859 /*rhs_batching_dimensions=*/batch_dimensions, in matchAndRewrite()
|