/external/tensorflow/tensorflow/core/kernels/linalg/ |
D | linalg_ops_common.cc | 34 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { in ValidateSingleMatrix() argument 35 OP_REQUIRES(context, input_matrix_shapes.size() == 1, in ValidateSingleMatrix() 37 input_matrix_shapes.size())); in ValidateSingleMatrix() 38 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]), in ValidateSingleMatrix() 45 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { in ValidateSingleSquareMatrix() argument 46 OP_REQUIRES(context, input_matrix_shapes.size() == 1, in ValidateSingleSquareMatrix() 48 input_matrix_shapes.size())); in ValidateSingleSquareMatrix() 49 OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]), in ValidateSingleSquareMatrix() 56 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { in ValidateSolver() argument 57 OP_REQUIRES(context, input_matrix_shapes.size() == 2, in ValidateSolver() [all …]
|
D | tridiagonal_matmul_op.cc | 38 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() 39 auto num_inputs = input_matrix_shapes.size(); in ValidateInputMatrixShapes() 44 auto n = input_matrix_shapes[3].dim_size(0); in ValidateInputMatrixShapes() 47 input_matrix_shapes[0].dim_size(0) == 1 && in ValidateInputMatrixShapes() 48 input_matrix_shapes[0].dim_size(1) == n, in ValidateInputMatrixShapes() 52 input_matrix_shapes[1].dim_size(0) == 1 && in ValidateInputMatrixShapes() 53 input_matrix_shapes[1].dim_size(1) == n, in ValidateInputMatrixShapes() 57 input_matrix_shapes[2].dim_size(0) == 1 && in ValidateInputMatrixShapes() 58 input_matrix_shapes[2].dim_size(1) == n, in ValidateInputMatrixShapes() 63 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() [all …]
|
D | linalg_ops_common.h | 58 OpKernelContext* context, const TensorShapes& input_matrix_shapes) const { in ValidateInputMatrixShapes() argument 59 ValidateSingleSquareMatrix(context, input_matrix_shapes); in ValidateInputMatrixShapes() 66 const TensorShapes& input_matrix_shapes); 69 OpKernelContext* context, const TensorShapes& input_matrix_shapes); 72 const TensorShapes& input_matrix_shapes); 76 const TensorShapes& input_matrix_shapes); 89 const TensorShapes& input_matrix_shapes) const { in GetOutputMatrixShapes() argument 90 return input_matrix_shapes; in GetOutputMatrixShapes() 99 virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const { in GetCostPerUnit() argument 100 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit() [all …]
|
D | matrix_solve_ls_op_impl.h | 55 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() argument 56 Base::ValidateSolver(context, input_matrix_shapes); in ValidateInputMatrixShapes() 60 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument 61 return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), in GetOutputMatrixShapes() 62 input_matrix_shapes[1].dim_size(1)})}); in GetOutputMatrixShapes() 65 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() argument 66 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit() 67 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit() 68 double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1)); in GetCostPerUnit()
|
D | svd_op_impl.h | 52 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() argument 53 Base::ValidateSingleMatrix(context, input_matrix_shapes); in ValidateInputMatrixShapes() 57 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument 58 int64_t m = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes() 59 int64_t n = input_matrix_shapes[0].dim_size(1); in GetOutputMatrixShapes() 71 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() argument 72 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit() 73 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit()
|
D | cholesky_grad.cc | 43 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() 44 OP_REQUIRES(context, input_matrix_shapes.size() == 2, in ValidateInputMatrixShapes() 46 input_matrix_shapes.size())); in ValidateInputMatrixShapes() 47 OP_REQUIRES(context, input_matrix_shapes[0] == input_matrix_shapes[1], in ValidateInputMatrixShapes() 51 TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]), in ValidateInputMatrixShapes() 56 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() 57 return TensorShapes({input_matrix_shapes[0]}); in GetOutputMatrixShapes()
|
D | qr_op_impl.h | 66 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() argument 67 Base::ValidateSingleMatrix(context, input_matrix_shapes); in ValidateInputMatrixShapes() 71 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument 72 int64_t m = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes() 73 int64_t n = input_matrix_shapes[0].dim_size(1); in GetOutputMatrixShapes() 83 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() argument 84 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit() 85 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit()
|
D | matrix_solve_op.cc | 55 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() 56 Base::ValidateSquareSolver(context, input_matrix_shapes); in ValidateInputMatrixShapes() 60 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() 61 return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), in GetOutputMatrixShapes() 62 input_matrix_shapes[1].dim_size(1)})}); in GetOutputMatrixShapes() 65 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() 66 double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit() 67 double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1)); in GetCostPerUnit()
|
D | tridiagonal_solve_op.cc | 58 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() 59 auto num_inputs = input_matrix_shapes.size(); in ValidateInputMatrixShapes() 64 auto num_diags = input_matrix_shapes[0].dim_size(0); in ValidateInputMatrixShapes() 71 auto num_eqs_left = input_matrix_shapes[0].dim_size(1); in ValidateInputMatrixShapes() 72 auto num_eqs_right = input_matrix_shapes[1].dim_size(0); in ValidateInputMatrixShapes() 81 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() 82 return TensorShapes({input_matrix_shapes[1]}); in GetOutputMatrixShapes() 85 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() 86 const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit() 87 const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0)); in GetCostPerUnit()
|
D | tridiagonal_solve_op_gpu.cu.cc | 113 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() 114 auto num_inputs = input_matrix_shapes.size(); in ValidateInputMatrixShapes() 119 auto num_diags = input_matrix_shapes[0].dim_size(0); in ValidateInputMatrixShapes() 126 auto num_rows1 = input_matrix_shapes[0].dim_size(1); in ValidateInputMatrixShapes() 127 auto num_rows2 = input_matrix_shapes[1].dim_size(0); in ValidateInputMatrixShapes() 137 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() 138 return TensorShapes({input_matrix_shapes[1]}); in GetOutputMatrixShapes()
|
D | self_adjoint_eig_op.cc | 45 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() 46 int64_t d = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
|
D | self_adjoint_eig_v2_op_impl.h | 50 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument 51 int64_t n = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
|
D | eig_op_impl.h | 55 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument 56 int64_t n = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
|
D | determinant_op.cc | 85 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
|