Home
last modified time | relevance | path

Searched refs:a_matrix (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/sparse/
Dmul_op.cc46 const CSRSparseMatrix* a_matrix; in Compute() local
47 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix)); in Compute()
50 OP_REQUIRES(ctx, a_matrix->dtype() == b_t.dtype(), in Compute()
53 DataTypeString(a_matrix->dtype()), in Compute()
58 const Tensor& a_dense_shape_t = a_matrix->dense_shape(); in Compute()
64 ((a_matrix->dims() == 3) && (b_t.dim_size(0) == batch_size) && in Compute()
86 OP_REQUIRES_OK(ctx, csrmul_scalar.Compute(ctx, *a_matrix, b, &c_matrix)); in Compute()
92 csrmul_batch_vec.Compute(ctx, *a_matrix, b, &c_matrix)); in Compute()
Dsparse_mat_mul_op.cc315 const CSRSparseMatrix* a_matrix; in Compute() local
317 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix)); in Compute()
320 ctx, a_matrix->dtype() == DataTypeToEnum<T>::value, in Compute()
322 DataTypeString(a_matrix->dtype()), " vs. ", in Compute()
332 auto a_dense_shape = a_matrix->dense_shape().vec<int64>(); in Compute()
350 const int batch_size = a_matrix->batch_size(); in Compute()
396 a_input_matrix = a_matrix; in Compute()
400 ctx, transpose(ctx, conjugate_a_, *a_matrix, &a_matrix_transposed)); in Compute()
Dadd_op.cc182 const CSRSparseMatrix* a_matrix; in Compute() local
184 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix)); in Compute()
188 ctx, a_matrix->dtype() == DataTypeToEnum<T>::value, in Compute()
190 DataTypeString(a_matrix->dtype()), " vs. ", in Compute()
216 OP_REQUIRES_OK(ctx, add_functor(*a_matrix, *b_matrix, &c_matrix)); in Compute()
Dmat_mul_op.cc500 const CSRSparseMatrix* a_matrix; in Compute() local
501 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix)); in Compute()
507 this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size)); in Compute()
509 const Tensor& a_dense_shape_t = a_matrix->dense_shape(); in Compute()
569 auto a_row_ptr = a_matrix->row_pointers_vec(i); in Compute()
570 auto a_col_ind = a_matrix->col_indices_vec(i); in Compute()
571 auto a_values = a_matrix->values_vec<T>(i); in Compute()
619 a_input_matrix = a_matrix; in Compute()
622 OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix, in Compute()
/external/tensorflow/tensorflow/python/ops/linalg/sparse/
Dsparse_csr_matrix_ops.py192 a_matrix = a._matrix if isinstance(a, SparseMatrix) else a
194 with ops.name_scope(name, "SparseMatrixMatMul", [a_matrix, b_matrix]):
200 a_matrix,
211 c_handle = matmul_shape_inference(a_matrix, b_matrix, c, transpose_a,
217 a_matrix,
/external/tensorflow/tensorflow/core/kernels/
Dgemm_functors.h101 typename tensorflow::TTypes<const T1>::Matrix a_matrix(a, m, k); in operator()
109 a_matrix.contract(b_matrix, dim_pair); in operator()
/external/tensorflow/tensorflow/core/grappler/costs/
Dop_level_cost_estimator.cc817 auto& a_matrix = op_info.inputs(0); in CountMatMulOperations() local
836 MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes); in CountMatMulOperations()
969 OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs(); in CountBatchMatMulOperations() local
970 a_matrix->set_dtype(a_input.dtype()); in CountBatchMatMulOperations()
971 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape(); in CountBatchMatMulOperations()
1464 OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs(); in PredictEinsum() local
1465 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape(); in PredictEinsum()
1466 a_matrix->set_dtype(a_input.dtype()); in PredictEinsum()
/external/eigen/Eigen/src/Eigenvalues/
DSelfAdjointEigenSolver.h400 ::compute(const EigenBase<InputType>& a_matrix, int options)
404 const InputType &matrix(a_matrix.derived());
/external/tensorflow/tensorflow/stream_executor/cuda/
Dcuda_blas.cc2258 const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b]; in DoBlasGemmBatchedInternal() local
2261 bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix, in DoBlasGemmBatchedInternal()
2395 const auto *a_matrix = in DoBlasGemmStridedBatched() local
2404 CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF, in DoBlasGemmStridedBatched()