Home
last modified time | relevance | path

Searched refs:rhs_shape (Results 1 – 25 of 36) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dxla_broadcast_helper_op.cc39 const TensorShape rhs_shape = context->InputShape(1); in Compile() local
41 const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); in Compile()
42 const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; in Compile()
43 const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape; in Compile()
51 lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || in Compile()
52 rhs_shape.dims() == 0, in Compile()
57 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
69 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
89 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
Dmatrix_triangular_solve_op.cc37 const TensorShape rhs_shape = ctx->InputShape(1); in Compile() local
45 MatMulBCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); in Compile()
49 rhs_shape.DebugString())); in Compile()
55 std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast); in Compile()
68 const TensorShape& rhs_shape, const MatMulBCast& broadcast_helper);
75 xla::XlaOp rhs, const TensorShape& rhs_shape, in Broadcast() argument
79 int64 n = rhs_shape.dim_size(rhs_shape.dims() - 1); in Broadcast()
Dcwise_ops.cc35 const TensorShape rhs_shape = ctx->InputShape(1); in Compile() local
43 BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), in Compile()
48 rhs_shape.DebugString())); in Compile()
68 int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims()); in Compile()
69 int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims()); in Compile()
81 rhs_shape.dim_sizes(), bcast, extend_dimension); in Compile()
Dstrided_slice_op.cc337 const TensorShape rhs_shape = ctx->InputShape(4); in Compile() local
348 if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { in Compile()
357 OP_REQUIRES(ctx, final_shape == rhs_shape, in Compile()
360 " does not match r-value shape ", rhs_shape.DebugString(), in Compile()
Dxla_dot_op.cc46 const TensorShape rhs_shape = context->InputShape(1); in Compile() local
Dcwise_ops.h61 const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper,
Dxla_conv_op.cc45 const TensorShape rhs_shape = context->InputShape(1); in Compile() local
/external/tensorflow/tensorflow/python/kernel_tests/
Dtridiagonal_solve_op_test.py484 def test_raises(diags_shape, rhs_shape): argument
485 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")
495 def test_raises(diags_tuple_shapes, rhs_shape): argument
497 self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")
509 def test_raises(diags_shape, rhs_shape): argument
510 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")
520 rhs_shape, argument
528 rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
539 rhs_shape=[None],
548 rhs_shape=[4],
[all …]
/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
Dunroll_batch_matmul.cc220 auto rhs_shape = rhs_type.getShape(); in matchAndRewrite() local
226 const int dims_b = rhs_shape.size(); in matchAndRewrite()
245 rhs_shape = rhs_type.getShape(); in matchAndRewrite()
248 if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { in matchAndRewrite()
256 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type); in matchAndRewrite()
273 for (auto dim : rhs_shape) { in matchAndRewrite()
282 rhs_shape.begin(), rhs_shape.end())); in matchAndRewrite()
299 rhs_shape[dims_b - 1], element_type, loc, rewriter); in matchAndRewrite()
306 result_shape.push_back(rhs_shape[dims_b - 1]); in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Ddot_op_emitter.cc58 Shape rhs_shape; member
67 rhs_shape = instr.operand(1)->shape(); in DotInfo()
398 const Shape& rhs_shape = rhs_array_.GetShape(); in Emit() local
400 if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { in Emit()
403 ShapeUtil::IsScalar(rhs_shape)); in Emit()
430 const Shape& rhs_shape = rhs_array_.GetShape(); in EmitNaiveLlvmIrGemm() local
441 rhs_shape.dimensions(rhs_reduction_dimension)); in EmitNaiveLlvmIrGemm()
446 rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0); in EmitNaiveLlvmIrGemm()
481 llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape, in EmitNaiveLlvmIrGemm()
717 const Shape& rhs_shape = rhs_array_.GetShape(); in GetMatMultDims() local
[all …]
Dcpu_layout_assignment_test.cc67 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24}); in TEST_F() local
72 HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); in TEST_F()
102 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24}); in TEST_F() local
109 HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); in TEST_F()
145 Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1}); in TEST_F() local
153 HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); in TEST_F()
186 Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1}); in TEST_F() local
191 HloInstruction::CreateParameter(0, rhs_shape, "param0")); in TEST_F()
200 ShapeLayout(LayoutUtil::GetWithDefaultLayout(rhs_shape)); in TEST_F()
219 Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1}); in TEST_F() local
[all …]
Ddot_op_emitter_internal.h39 Shape rhs_shape; member
46 rhs_shape = instr.operand(1)->shape(); in DotInfo()
/external/tensorflow/tensorflow/compiler/xla/service/
Ddot_decomposer.cc91 const auto& rhs_shape = original_dot->operand(1)->shape(); in CanonicalizeDot() local
92 const int64 rhs_rank = rhs_shape.rank(); in CanonicalizeDot()
101 rhs_contracting_size *= rhs_shape.dimensions(i); in CanonicalizeDot()
105 rhs_non_contracting_size *= rhs_shape.dimensions(i); in CanonicalizeDot()
124 rhs_shape), in CanonicalizeDot()
133 ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims), in CanonicalizeDot()
Dshape_inference_test.cc416 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); in TEST_F() local
438 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
461 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); in TEST_F() local
484 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
507 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); in TEST_F() local
530 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
541 Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); in TEST_F() local
569 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
591 Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4}); in TEST_F() local
606 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, in TEST_F()
[all …]
/external/tensorflow/tensorflow/core/kernels/
Dcwise_ops_test.cc225 TensorShape rhs_shape; in BiasAdd() local
226 rhs_shape = TensorShape({cols}); in BiasAdd()
227 Tensor rhs(type, rhs_shape); in BiasAdd()
326 TensorShape lhs_shape, rhs_shape; in BcastAdd() local
329 rhs_shape = TensorShape({rows, 1}); in BcastAdd()
332 rhs_shape = TensorShape({cols}); in BcastAdd()
335 rhs_shape = TensorShape({1, cols}); in BcastAdd()
338 rhs_shape = TensorShape({rows, 1}); in BcastAdd()
342 Tensor rhs(DT_FLOAT, rhs_shape); in BcastAdd()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dgpu_layout_assignment.cc118 Shape rhs_shape = instr->operand(1)->shape(); in AddBackendConstraintsToDnnConvCustomCall() local
130 filter_shape = &rhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
135 filter_shape = &rhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
141 output_shape = &rhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
169 TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); in AddBackendConstraintsToDnnConvCustomCall()
Dir_emission_utils.cc50 bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, in AreValidGemmShapes() argument
62 IsRank2(rhs_shape, batch_dimensions_size) && in AreValidGemmShapes()
65 !ShapeUtil::IsZeroElementArray(rhs_shape); in AreValidGemmShapes()
106 const Shape& rhs_shape = dot.operand(1)->shape(); in IsMatrixMultiplication() local
111 if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), in IsMatrixMultiplication()
117 rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); in IsMatrixMultiplication()
Dgemm_thunk.cc174 const Shape &rhs_shape = rhs->shape(); in RunGemm() local
194 for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) { in RunGemm()
240 rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); in RunGemm()
Dir_emitter.cc512 const Shape& rhs_shape = rhs_instruction->shape(); in HandleDot() local
520 if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { in HandleDot()
545 !ShapeUtil::IsScalar(rhs_shape)); in HandleDot()
558 rhs_shape.dimensions(rhs_reduction_dimension)) in HandleDot()
562 << ") = " << rhs_shape.dimensions(rhs_reduction_dimension); in HandleDot()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dtridiagonal.cc68 TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder->GetShape(rhs)); in CheckSystemAndReturnShape()
73 const auto rhs_rank = rhs_shape.rank(); in CheckSystemAndReturnShape()
94 const auto rhs_num_eqs = ShapeUtil::GetDimension(rhs_shape, rank - 1); in CheckSystemAndReturnShape()
/external/tensorflow/tensorflow/compiler/tests/
Dtridiagonal_solve_ops_test.py475 def test_raises(diags_shape, rhs_shape): argument
476 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")
485 def test_raises(diags_tuple_shapes, rhs_shape): argument
487 self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")
498 def test_raises(diags_shape, rhs_shape): argument
499 self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")
/external/tensorflow/tensorflow/python/ops/linalg/
Dlinear_operator_tridiag.py306 rhs_shape = array_ops.shape(rhs)
309 self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
312 [broadcast_shape, rhs_shape[-2:]], axis=-1))
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlower_general_dot.cc158 auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape(); in matchAndRewrite() local
160 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/xla/tests/
Dmatrix_ops_simple_test.cc198 Shape rhs_shape = in TestImpl() local
216 auto rhs_arg = Parameter(&builder, 1, rhs_shape, "rhs"); in TestImpl()
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc553 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); in BinaryOp()
556 binop, *lhs_shape, *rhs_shape, broadcast_dimensions)); in BinaryOp()
571 const int64 rhs_rank = rhs_shape->rank(); in BinaryOp()
579 const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape; in BinaryOp()
631 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); in TernaryOp()
635 for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) { in TernaryOp()
654 if (ShapeUtil::IsScalar(*rhs_shape)) { in TernaryOp()
666 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs)); in TernaryOp()
669 triop, *lhs_shape, *rhs_shape, *ehs_shape); in TernaryOp()
1153 TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); in DotGeneral()
[all …]

12