Searched refs:rhs_matrix (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
D | gemm_thunk.cc | 54 bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, in DoGemm() argument 60 CHECK_EQ(batch_size, rhs_matrix.batch_size); in DoGemm() 63 se::DeviceMemory<Element> rhs_data(rhs_matrix.data); in DoGemm() 68 auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose in DoGemm() 78 /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta, in DoGemm() 84 int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; in DoGemm() 92 /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, in DoGemm() 115 MatrixDescriptor rhs_matrix, in DoGemmWithAlgorithm() argument 124 CHECK_EQ(1, rhs_matrix.batch_size); in DoGemmWithAlgorithm() 128 se::DeviceMemory<Element> rhs_data(rhs_matrix.data); in DoGemmWithAlgorithm() [all …]
|
/external/tensorflow/tensorflow/contrib/factorization/python/kernel_tests/ |
D | wals_solver_ops_test.py | 55 rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( 78 self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700], 87 rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( 110 self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700],
|