Home
last modified time | relevance | path

Searched refs:input_matrix_shapes (Results 1 – 14 of 14) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/linalg/
Dlinalg_ops_common.cc34 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { in ValidateSingleMatrix() argument
35 OP_REQUIRES(context, input_matrix_shapes.size() == 1, in ValidateSingleMatrix()
37 input_matrix_shapes.size())); in ValidateSingleMatrix()
38 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]), in ValidateSingleMatrix()
45 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { in ValidateSingleSquareMatrix() argument
46 OP_REQUIRES(context, input_matrix_shapes.size() == 1, in ValidateSingleSquareMatrix()
48 input_matrix_shapes.size())); in ValidateSingleSquareMatrix()
49 OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]), in ValidateSingleSquareMatrix()
56 OpKernelContext* context, const TensorShapes& input_matrix_shapes) { in ValidateSolver() argument
57 OP_REQUIRES(context, input_matrix_shapes.size() == 2, in ValidateSolver()
[all …]
Dtridiagonal_matmul_op.cc38 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes()
39 auto num_inputs = input_matrix_shapes.size(); in ValidateInputMatrixShapes()
44 auto n = input_matrix_shapes[3].dim_size(0); in ValidateInputMatrixShapes()
47 input_matrix_shapes[0].dim_size(0) == 1 && in ValidateInputMatrixShapes()
48 input_matrix_shapes[0].dim_size(1) == n, in ValidateInputMatrixShapes()
52 input_matrix_shapes[1].dim_size(0) == 1 && in ValidateInputMatrixShapes()
53 input_matrix_shapes[1].dim_size(1) == n, in ValidateInputMatrixShapes()
57 input_matrix_shapes[2].dim_size(0) == 1 && in ValidateInputMatrixShapes()
58 input_matrix_shapes[2].dim_size(1) == n, in ValidateInputMatrixShapes()
63 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
[all …]
Dlinalg_ops_common.h58 OpKernelContext* context, const TensorShapes& input_matrix_shapes) const { in ValidateInputMatrixShapes() argument
59 ValidateSingleSquareMatrix(context, input_matrix_shapes); in ValidateInputMatrixShapes()
66 const TensorShapes& input_matrix_shapes);
69 OpKernelContext* context, const TensorShapes& input_matrix_shapes);
72 const TensorShapes& input_matrix_shapes);
76 const TensorShapes& input_matrix_shapes);
89 const TensorShapes& input_matrix_shapes) const { in GetOutputMatrixShapes() argument
90 return input_matrix_shapes; in GetOutputMatrixShapes()
99 virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const { in GetCostPerUnit() argument
100 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit()
[all …]
Dmatrix_solve_ls_op_impl.h55 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() argument
56 Base::ValidateSolver(context, input_matrix_shapes); in ValidateInputMatrixShapes()
60 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument
61 return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), in GetOutputMatrixShapes()
62 input_matrix_shapes[1].dim_size(1)})}); in GetOutputMatrixShapes()
65 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() argument
66 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit()
67 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit()
68 double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1)); in GetCostPerUnit()
Dsvd_op_impl.h52 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() argument
53 Base::ValidateSingleMatrix(context, input_matrix_shapes); in ValidateInputMatrixShapes()
57 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument
58 int64_t m = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
59 int64_t n = input_matrix_shapes[0].dim_size(1); in GetOutputMatrixShapes()
71 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() argument
72 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit()
73 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit()
Dcholesky_grad.cc43 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes()
44 OP_REQUIRES(context, input_matrix_shapes.size() == 2, in ValidateInputMatrixShapes()
46 input_matrix_shapes.size())); in ValidateInputMatrixShapes()
47 OP_REQUIRES(context, input_matrix_shapes[0] == input_matrix_shapes[1], in ValidateInputMatrixShapes()
51 TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]), in ValidateInputMatrixShapes()
56 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
57 return TensorShapes({input_matrix_shapes[0]}); in GetOutputMatrixShapes()
Dqr_op_impl.h66 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes() argument
67 Base::ValidateSingleMatrix(context, input_matrix_shapes); in ValidateInputMatrixShapes()
71 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument
72 int64_t m = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
73 int64_t n = input_matrix_shapes[0].dim_size(1); in GetOutputMatrixShapes()
83 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit() argument
84 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit()
85 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit()
Dmatrix_solve_op.cc55 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes()
56 Base::ValidateSquareSolver(context, input_matrix_shapes); in ValidateInputMatrixShapes()
60 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
61 return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), in GetOutputMatrixShapes()
62 input_matrix_shapes[1].dim_size(1)})}); in GetOutputMatrixShapes()
65 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit()
66 double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0)); in GetCostPerUnit()
67 double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1)); in GetCostPerUnit()
Dtridiagonal_solve_op.cc58 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes()
59 auto num_inputs = input_matrix_shapes.size(); in ValidateInputMatrixShapes()
64 auto num_diags = input_matrix_shapes[0].dim_size(0); in ValidateInputMatrixShapes()
71 auto num_eqs_left = input_matrix_shapes[0].dim_size(1); in ValidateInputMatrixShapes()
72 auto num_eqs_right = input_matrix_shapes[1].dim_size(0); in ValidateInputMatrixShapes()
81 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
82 return TensorShapes({input_matrix_shapes[1]}); in GetOutputMatrixShapes()
85 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { in GetCostPerUnit()
86 const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1)); in GetCostPerUnit()
87 const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0)); in GetCostPerUnit()
Dtridiagonal_solve_op_gpu.cu.cc113 const TensorShapes& input_matrix_shapes) const final { in ValidateInputMatrixShapes()
114 auto num_inputs = input_matrix_shapes.size(); in ValidateInputMatrixShapes()
119 auto num_diags = input_matrix_shapes[0].dim_size(0); in ValidateInputMatrixShapes()
126 auto num_rows1 = input_matrix_shapes[0].dim_size(1); in ValidateInputMatrixShapes()
127 auto num_rows2 = input_matrix_shapes[1].dim_size(0); in ValidateInputMatrixShapes()
137 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
138 return TensorShapes({input_matrix_shapes[1]}); in GetOutputMatrixShapes()
Dself_adjoint_eig_op.cc45 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()
46 int64_t d = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
Dself_adjoint_eig_v2_op_impl.h50 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument
51 int64_t n = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
Deig_op_impl.h55 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes() argument
56 int64_t n = input_matrix_shapes[0].dim_size(0); in GetOutputMatrixShapes()
Ddeterminant_op.cc85 const TensorShapes& input_matrix_shapes) const final { in GetOutputMatrixShapes()