Home
last modified time | relevance | path

Searched refs:num_eqs (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dtridiagonal_test.cc44 const int64 num_eqs = std::get<1>(spec); in XLA_TEST_P() local
47 Array3D<float> lower_diagonal(batch_size, 1, num_eqs); in XLA_TEST_P()
48 Array3D<float> main_diagonal(batch_size, 1, num_eqs); in XLA_TEST_P()
49 Array3D<float> upper_diagonal(batch_size, 1, num_eqs); in XLA_TEST_P()
50 Array3D<float> rhs(batch_size, num_rhs, num_eqs); in XLA_TEST_P()
54 /*seed=*/batch_size * num_eqs); in XLA_TEST_P()
56 /*seed=*/2 * batch_size * num_eqs); in XLA_TEST_P()
57 rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs); in XLA_TEST_P()
80 std::vector<XlaOp> relative_errors(num_eqs); in XLA_TEST_P()
82 for (int64 i = 0; i < num_eqs; i++) { in XLA_TEST_P()
[all …]
Dtridiagonal.cc149 TF_ASSIGN_OR_RETURN(int64 num_eqs, in ThomasSolver()
177 ForEachIndex(num_eqs - 1, S32, preparation_body_fn, in ThomasSolver()
223 num_eqs - 1, S32, forward_transformation_fn, in ThomasSolver()
238 UpdateEq(x_coeffs, num_eqs - 1, in ThomasSolver()
239 Coefficient(rhs_after_elimination, num_eqs - 1) / in ThomasSolver()
240 Coefficient(main_diag_after_elimination, num_eqs - 1)); in ThomasSolver()
242 [num_eqs](XlaOp j, absl::Span<const XlaOp> values, in ThomasSolver()
248 auto n = ScalarLike(j, num_eqs - 2); in ThomasSolver()
267 ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn, in ThomasSolver()
/external/tensorflow/tensorflow/core/kernels/linalg/
Dtridiagonal_matmul_op.cc68 const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit() local
74 const double cost = num_rhss * ((3 * num_eqs - 2) * mult_cost + in GetCostPerUnit()
75 (2 * num_eqs - 2) * add_cost); in GetCostPerUnit()
Dtridiagonal_solve_op_gpu.cu.cc201 const Scalar* subdiag, Scalar* rhs, const int num_eqs, in SolveWithGtsv() argument
208 num_eqs, num_rhs, subdiag, diag, superdiag, rhs, in SolveWithGtsv()
209 num_eqs, &buffer_size)); in SolveWithGtsv()
219 num_eqs, num_rhs, subdiag, diag, superdiag, rhs, in SolveWithGtsv()
220 num_eqs, buffer)); in SolveWithGtsv()
Dtridiagonal_solve_op.cc78 const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit() local
88 cost = num_eqs * (div_cost * (num_rhss + 1) + in GetCostPerUnit()
91 cost = num_eqs * (div_cost * (num_rhss + 1) + in GetCostPerUnit()