Home
last modified time | relevance | path

Searched refs:lhs_shape (Results 1 – 25 of 46) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dmatrix_triangular_solve_op.cc36 const TensorShape lhs_shape = ctx->InputShape(0); in Compile() local
45 MatMulBCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); in Compile()
48 "Incompatible shapes: ", lhs_shape.DebugString(), " vs. ", in Compile()
53 auto lhs_size = lhs_shape.dims(); in Compile()
56 lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2), in Compile()
59 lhs_shape.DebugString())); in Compile()
63 std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast); in Compile()
75 xla::XlaOp lhs, const TensorShape& lhs_shape, xla::XlaOp rhs,
82 MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, in Broadcast() argument
86 int64 m = lhs_shape.dim_size(lhs_shape.dims() - 1); in Broadcast()
Dxla_broadcast_helper_op.cc38 const TensorShape lhs_shape = context->InputShape(0); in Compile() local
41 const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); in Compile()
42 const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; in Compile()
43 const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape; in Compile()
51 lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || in Compile()
57 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
69 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
89 lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); in Compile()
Dcwise_ops.cc34 const TensorShape lhs_shape = ctx->InputShape(0); in Compile() local
43 BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), in Compile()
47 lhs_shape.DebugString(), " vs. ", in Compile()
68 int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims()); in Compile()
69 int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims()); in Compile()
80 Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle, in Compile()
Dxla_dot_op.cc45 const TensorShape lhs_shape = context->InputShape(0); in Compile() local
Dcwise_ops.h60 const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
/external/tensorflow/tensorflow/core/kernels/mlir_generated/
Dbase_binary_ops_test.h43 void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape, in SetOpKernel() argument
60 AddInputFromArray<T>(lhs_shape, lhs_input); in SetOpKernel()
68 const TensorShape& lhs_shape, in RunAndExpectResult() argument
75 SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, in RunAndExpectResult()
93 const TensorShape& lhs_shape, in RunAndExpectInvalidArgument() argument
98 SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, in RunAndExpectInvalidArgument()
113 TensorShape lhs_shape{3}; in TestIncompatibleShapes()
116 test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); in TestIncompatibleShapes()
120 RunAndExpectInvalidArgument<T, OutT>(op_name, lhs_shape, repeated_lhs_input, in TestIncompatibleShapes()
234 TensorShape lhs_shape{1}; in TestBroadcastingExpand()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Ddot_op_emitter.cc69 Shape lhs_shape; member
78 lhs_shape = instr.operand(0)->shape(); in DotInfo()
255 Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; in EmitLinalgMatmul()
267 dot_info_.lhs_shape.ToString(true), "_", in EmitLinalgMatmul()
281 dot_info_.lhs_shape.rank()); in EmitLinalgMatmul()
517 const Shape& lhs_shape = lhs_array_.GetShape(); in Emit() local
520 if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { in Emit()
522 TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) && in Emit()
552 const Shape& lhs_shape = lhs_array_.GetShape(); in EmitNaiveLlvmIrGemm() local
563 CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), in EmitNaiveLlvmIrGemm()
[all …]
Dcpu_layout_assignment_test.cc66 Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0}); in TEST_F() local
70 HloInstruction::CreateParameter(0, lhs_shape, "param0")); in TEST_F()
81 ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape)); in TEST_F()
101 Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12}, {0}); in TEST_F() local
105 HloInstruction::CreateParameter(0, lhs_shape, "param0")); in TEST_F()
107 HloInstruction::CreateParameter(1, lhs_shape, "param1")); in TEST_F()
122 ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape)); in TEST_F()
185 Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1}); in TEST_F() local
189 HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape))); in TEST_F()
218 Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1}); in TEST_F() local
[all …]
Ddot_op_emitter_internal.h38 Shape lhs_shape; member
45 lhs_shape = instr.operand(0)->shape(); in DotInfo()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dunroll_batch_matmul.cc214 auto lhs_shape = lhs_type.getShape(); in matchAndRewrite() local
220 const int dims_a = lhs_shape.size(); in matchAndRewrite()
232 lhs_shape = lhs_type.getShape(); in matchAndRewrite()
243 if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { in matchAndRewrite()
251 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type); in matchAndRewrite()
263 for (auto dim : lhs_shape) { in matchAndRewrite()
275 lhs_shape.begin(), lhs_shape.end()), in matchAndRewrite()
293 createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2], in matchAndRewrite()
300 result_shape.push_back(lhs_shape[dims_a - 2]); in matchAndRewrite()
Deinsum.cc244 std::vector<int64_t> lhs_shape; in reshapeForBatchMatmul() local
246 lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1); in reshapeForBatchMatmul()
250 lhs_shape.push_back(b); in reshapeForBatchMatmul()
256 lhs_shape.push_back(1); in reshapeForBatchMatmul()
258 dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1); in reshapeForBatchMatmul()
262 lhs_shape.push_back(b); in reshapeForBatchMatmul()
270 lhs_shape.push_back(lhs_out_size); in reshapeForBatchMatmul()
278 lhs_shape.push_back(lhs_rhs_size); in reshapeForBatchMatmul()
288 if (failed(VerifyShapeOfReshapeOp(lhs_shape)) || in reshapeForBatchMatmul()
292 *lhs = createReshapeOp(*lhs, lhs_shape, lhs_type.getElementType(), loc, in reshapeForBatchMatmul()
Dbatchmatmul_to_einsum.cc61 auto lhs_shape = lhs_type.getShape(); in matchAndRewrite() local
65 const int dims_a = lhs_shape.size(); in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/xla/service/
Ddot_decomposer.cc46 const auto& lhs_shape = original_dot->operand(0)->shape(); in CanonicalizeDot() local
47 const int64 lhs_rank = lhs_shape.rank(); in CanonicalizeDot()
59 lhs_contracting_size *= lhs_shape.dimensions(i); in CanonicalizeDot()
62 batch_dim_sizes.push_back(lhs_shape.dimensions(i)); in CanonicalizeDot()
65 lhs_non_contracting_size *= lhs_shape.dimensions(i); in CanonicalizeDot()
83 ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape), in CanonicalizeDot()
93 ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims), in CanonicalizeDot()
Dshape_inference_test.cc406 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); in TEST_F() local
439 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
451 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); in TEST_F() local
485 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
497 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); in TEST_F() local
531 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
541 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); in TEST_F() local
570 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, in TEST_F()
591 Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13}); in TEST_F() local
607 lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, in TEST_F()
[all …]
Dbatch_dot_simplification.cc47 const Shape& lhs_shape = lhs->shape(); in ElideDegenerateBatchDimensionFromBatchDot() local
58 if (lhs_shape.dimensions(batch_dim) == 1) { in ElideDegenerateBatchDimensionFromBatchDot()
/external/tensorflow/tensorflow/lite/kernels/
Dbatch_matmul.cc509 const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, in EvalInt8() argument
529 op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape, in EvalInt8()
534 lhs_shape, GetTensorData<int8_t>(lhs), in EvalInt8()
544 const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, in EvalInt16() argument
563 op_params, rhs_shape, GetTensorData<int16_t>(rhs), lhs_shape, in EvalInt16()
571 OpData* data, const RuntimeShape& lhs_shape, in EvalQuantized() argument
592 context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, in EvalQuantized()
595 return EvalInt8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs, in EvalQuantized()
598 return EvalInt16<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs, in EvalQuantized()
684 RuntimeShape lhs_shape = in Eval() local
[all …]
/external/tensorflow/tensorflow/core/kernels/
Dcwise_ops_test.cc263 TensorShape lhs_shape; in BiasAddGrad() local
265 lhs_shape = TensorShape({channels, rows, cols}); in BiasAddGrad()
267 lhs_shape = TensorShape({rows, cols, channels}); in BiasAddGrad()
269 Tensor lhs(type, lhs_shape); in BiasAddGrad()
325 TensorShape lhs_shape, rhs_shape; in BcastAdd() local
327 lhs_shape = TensorShape({rows, cols}); in BcastAdd()
330 lhs_shape = TensorShape({rows, cols}); in BcastAdd()
333 lhs_shape = TensorShape({rows, 1}); in BcastAdd()
336 lhs_shape = TensorShape({1, cols}); in BcastAdd()
339 Tensor lhs(DT_FLOAT, lhs_shape); in BcastAdd()
/external/tensorflow/tensorflow/lite/kernels/internal/reference/
Dbatch_matmul.h53 inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, in BatchMatMul() argument
57 RuntimeShape::ExtendedShape(5, lhs_shape); in BatchMatMul()
108 inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, in BatchMatMul() argument
115 RuntimeShape::ExtendedShape(5, lhs_shape); in BatchMatMul()
200 const RuntimeShape& lhs_shape, const T* lhs_data, in BatchMatMul() argument
204 RuntimeShape::ExtendedShape(5, lhs_shape); in BatchMatMul()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dgemm_thunk.cc39 config.lhs_shape = gemm->operand(0)->shape(); in GetGpuGemmConfig()
182 const Shape &lhs_shape = gemm_config.lhs_shape; in RunGemm() local
204 for (const auto *shape : {&lhs_shape, &rhs_shape, &output_shape}) { in RunGemm()
244 lhs_buffer, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim); in RunGemm()
Dgpu_layout_assignment.cc123 Shape lhs_shape = instr->operand(0)->shape(); in AddBackendConstraintsToDnnConvCustomCall() local
135 input_shape = &lhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
142 output_shape = &lhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
145 input_shape = &lhs_shape; in AddBackendConstraintsToDnnConvCustomCall()
174 TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0)); in AddBackendConstraintsToDnnConvCustomCall()
Dir_emission_utils.cc55 bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, in AreValidGemmShapes() argument
66 return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && in AreValidGemmShapes()
69 !ShapeUtil::IsZeroElementArray(lhs_shape) && in AreValidGemmShapes()
110 const Shape& lhs_shape = dot.operand(0)->shape(); in IsMatrixMultiplication() local
116 if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), in IsMatrixMultiplication()
121 CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), in IsMatrixMultiplication()
/external/tensorflow/tensorflow/core/kernels/mkl/
Dmkl_batch_matmul_op.cc178 const TensorShape& lhs_shape, const TensorShape& rhs_shape, in CreateMatMulParams() argument
180 const auto ndims_lhs = lhs_shape.dims(); in CreateMatMulParams()
183 auto lhs_dims = TFShapeToMklDnnDims(lhs_shape); in CreateMatMulParams()
193 ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims); in CreateMatMulParams()
/external/tensorflow/tensorflow/lite/kernels/internal/optimized/
Dbatch_matmul.h28 inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, in BatchMatMul() argument
36 RuntimeShape::ExtendedShape(5, lhs_shape); in BatchMatMul()
115 inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, in BatchMatMul() argument
127 RuntimeShape::ExtendedShape(5, lhs_shape); in BatchMatMul()
272 const RuntimeShape& lhs_shape, const int8_t* lhs_data, in BatchMatMul() argument
281 RuntimeShape::ExtendedShape(5, lhs_shape); in BatchMatMul()
/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
Dlegalize_tf.cc598 auto lhs_shape = lhs.getType().cast<ShapedType>().getShape(); in matchAndRewrite() local
601 if (lhs_shape == rhs_shape) { in matchAndRewrite()
607 if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, in matchAndRewrite()
625 if (result_type.getShape() != lhs_shape) { in matchAndRewrite()
666 auto lhs_shape = lhs.getType().cast<ShapedType>().getShape(); in matchAndRewrite() local
670 if (lhs_shape == rhs_shape && cond_shape == lhs_shape) { in matchAndRewrite()
676 if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape, in matchAndRewrite()
706 if (result_shape != lhs_shape) { in matchAndRewrite()
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
Dlower_general_dot.cc165 auto lhs_shape = lhs_shape_type.getShape(); in matchAndRewrite() local
168 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); in matchAndRewrite()

12