Searched refs:rhs_batch (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/python/kernel_tests/linalg/ |
D | tridiagonal_matmul_op_test.py | 60 rhs_batch = array_ops.stack([rhs, 2 * rhs]) 75 diags_sequence_batch, rhs_batch, diagonals_format='sequence'), 77 diags_compact_batch, rhs_batch, diagonals_format='compact'), 79 diags_matrix_batch, rhs_batch, diagonals_format='matrix')
|
/external/tensorflow/tensorflow/compiler/xla/python/ |
D | xla_client.py | 494 (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers 499 dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
|
/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
D | jitrt_custom_calls.h | 41 llvm::ArrayRef<int64_t> rhs_batch; member
|
D | jitrt_custom_calls.cc | 333 ArrayRef<int64_t> rhs_batch, in GetGemmConfig() argument 336 rhs_batch, rhs_contract, ToShape(out), alpha_real, in GetGemmConfig() 534 dot_dims.rhs_batch, dot_dims.rhs_contract); in operator ()() 603 dot_dims.rhs_batch, dot_dims.rhs_contract); in operator ()()
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | space_to_batch_converter.cc | 1198 const int64_t rhs_batch = in CanPropagate() local 1224 if (rhs_batch * ctrl_.number_of_splits != lhs_batch) { in CanPropagate() 1252 const int64_t rhs_batch = in CanPropagate() local 1259 if (rhs_batch != ctrl_.number_of_splits * lhs_batch) { in CanPropagate()
|