Lines Matching refs:unitriangular
4366 const bool unitriangular, in triangular_solve_backward() argument
4373 grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular)); in triangular_solve_backward()
4378 grad_a = grad_a.triu((int)unitriangular); in triangular_solve_backward()
4380 grad_a = grad_a.tril(-((int)unitriangular)); in triangular_solve_backward()
4404 const bool unitriangular) { in triangular_solve_jvp() argument
4408 dB - dA_contrib, A, upper, transpose, unitriangular)); in triangular_solve_jvp()
4423 const bool unitriangular) { in linalg_solve_triangular_forward_AD() argument
4428 const Tensor proj_A_t = upper ? A_t.triu(static_cast<int>(unitriangular)) in linalg_solve_triangular_forward_AD()
4429 : A_t.tril(-static_cast<int>(unitriangular)); in linalg_solve_triangular_forward_AD()
4432 return at::linalg_solve_triangular(A, X_t, upper, left, unitriangular); in linalg_solve_triangular_forward_AD()
4441 const bool unitriangular, in linalg_solve_triangular_backward() argument
4480 at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular); in linalg_solve_triangular_backward()
4485 G_A = upper ? G_A.triu(static_cast<int>(unitriangular)) in linalg_solve_triangular_backward()
4486 : G_A.tril(-static_cast<int>(unitriangular)); in linalg_solve_triangular_backward()