Searched refs:num_eqs (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/client/lib/ |
D | tridiagonal_test.cc | 44 const int64 num_eqs = std::get<1>(spec); in XLA_TEST_P() local 47 Array3D<float> lower_diagonal(batch_size, 1, num_eqs); in XLA_TEST_P() 48 Array3D<float> main_diagonal(batch_size, 1, num_eqs); in XLA_TEST_P() 49 Array3D<float> upper_diagonal(batch_size, 1, num_eqs); in XLA_TEST_P() 50 Array3D<float> rhs(batch_size, num_rhs, num_eqs); in XLA_TEST_P() 54 /*seed=*/batch_size * num_eqs); in XLA_TEST_P() 56 /*seed=*/2 * batch_size * num_eqs); in XLA_TEST_P() 57 rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs); in XLA_TEST_P() 80 std::vector<XlaOp> relative_errors(num_eqs); in XLA_TEST_P() 82 for (int64 i = 0; i < num_eqs; i++) { in XLA_TEST_P() [all …]
|
D | tridiagonal.cc | 149 TF_ASSIGN_OR_RETURN(int64 num_eqs, in ThomasSolver() 177 ForEachIndex(num_eqs - 1, S32, preparation_body_fn, in ThomasSolver() 223 num_eqs - 1, S32, forward_transformation_fn, in ThomasSolver() 238 UpdateEq(x_coeffs, num_eqs - 1, in ThomasSolver() 239 Coefficient(rhs_after_elimination, num_eqs - 1) / in ThomasSolver() 240 Coefficient(main_diag_after_elimination, num_eqs - 1)); in ThomasSolver() 242 [num_eqs](XlaOp j, absl::Span<const XlaOp> values, in ThomasSolver() 248 auto n = ScalarLike(j, num_eqs - 2); in ThomasSolver() 267 ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn, in ThomasSolver()
|
/external/tensorflow/tensorflow/core/kernels/linalg/ |
D | tridiagonal_matmul_op.cc | 68 const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit() local 74 const double cost = num_rhss * ((3 * num_eqs - 2) * mult_cost + in GetCostPerUnit() 75 (2 * num_eqs - 2) * add_cost); in GetCostPerUnit()
|
D | tridiagonal_solve_op_gpu.cu.cc | 201 const Scalar* subdiag, Scalar* rhs, const int num_eqs, in SolveWithGtsv() argument 208 num_eqs, num_rhs, subdiag, diag, superdiag, rhs, in SolveWithGtsv() 209 num_eqs, &buffer_size)); in SolveWithGtsv() 219 num_eqs, num_rhs, subdiag, diag, superdiag, rhs, in SolveWithGtsv() 220 num_eqs, buffer)); in SolveWithGtsv()
|
D | tridiagonal_solve_op.cc | 78 const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit() local 88 cost = num_eqs * (div_cost * (num_rhss + 1) + in GetCostPerUnit() 91 cost = num_eqs * (div_cost * (num_rhss + 1) + in GetCostPerUnit()
|