Home
last modified time | relevance | path

Searched refs:dot_dnums (Results 1 – 20 of 20) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dshape_inference_test.cc1470 DotDimensionNumbers dot_dnums; in TEST_F() local
1472 f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt); in TEST_F()
1479 DotDimensionNumbers dot_dnums; in TEST_F() local
1480 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
1481 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
1483 ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums, in TEST_F()
1492 DotDimensionNumbers dot_dnums; in TEST_F() local
1493 dot_dnums.add_lhs_contracting_dimensions(0); in TEST_F()
1494 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
1496 ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums, in TEST_F()
[all …]
Ddot_decomposer.cc151 DotDimensionNumbers dot_dnums; in CanonicalizeDot() local
153 dot_dnums.add_lhs_batch_dimensions(i); in CanonicalizeDot()
154 dot_dnums.add_rhs_batch_dimensions(i); in CanonicalizeDot()
156 dot_dnums.add_lhs_contracting_dimensions( in CanonicalizeDot()
158 dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); in CanonicalizeDot()
162 reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config())); in CanonicalizeDot()
Dhlo_instruction_test.cc1210 DotDimensionNumbers dot_dnums; in TEST_F() local
1211 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
1212 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
1214 sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); in TEST_F()
1309 DotDimensionNumbers dot_dnums; in TEST_F() local
1310 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
1311 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
1313 s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); in TEST_F()
1360 DotDimensionNumbers dot_dnums; in TEST_F() local
1361 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
[all …]
Dhlo_computation_test.cc550 DotDimensionNumbers dot_dnums; in TEST_F() local
551 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
552 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
557 HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); in TEST_F()
585 DotDimensionNumbers dot_dnums; in TEST_F() local
586 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
587 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
592 HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); in TEST_F()
621 DotDimensionNumbers dot_dnums; in TEST_F() local
622 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
[all …]
Ddynamic_dimension_inference_test.cc322 DotDimensionNumbers dot_dnums; in TEST_F() local
323 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
324 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
326 HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums, in TEST_F()
363 DotDimensionNumbers dot_dnums; in TEST_F() local
364 dot_dnums.add_lhs_contracting_dimensions(3); in TEST_F()
365 dot_dnums.add_rhs_contracting_dimensions(3); in TEST_F()
366 dot_dnums.add_lhs_batch_dimensions(0); in TEST_F()
367 dot_dnums.add_lhs_batch_dimensions(2); in TEST_F()
368 dot_dnums.add_rhs_batch_dimensions(0); in TEST_F()
[all …]
Ddot_as_convolution_util.cc110 const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums, in CreateShardedConvForDotGeneralConvolution() argument
115 for (const auto& dim : dot_dnums.batch_dims) { in CreateShardedConvForDotGeneralConvolution()
122 for (const auto& dim : dot_dnums.contracting_dims) { in CreateShardedConvForDotGeneralConvolution()
130 for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { in CreateShardedConvForDotGeneralConvolution()
Dheap_simulator_test.cc681 DotDimensionNumbers dot_dnums; in TEST_F() local
682 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
683 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
685 f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); in TEST_F()
717 DotDimensionNumbers dot_dnums; in TEST_F() local
718 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
719 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
721 f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); in TEST_F()
758 DotDimensionNumbers dot_dnums; in TEST_F() local
759 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
[all …]
Ddot_as_convolution_util.h62 const HloInstruction& conv, const DotConvolutionDimsInfo& dot_dnums,
Dbfloat16_normalization_test.cc410 DotDimensionNumbers dot_dnums; in TEST_F() local
411 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
412 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
417 HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); in TEST_F()
Dhlo_evaluator_test.cc812 DotDimensionNumbers dot_dnums; in TEST_P() local
813 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_P()
814 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_P()
816 rhs_instruction, dot_dnums, in TEST_P()
858 DotDimensionNumbers dot_dnums; in TEST_P() local
859 dot_dnums.add_lhs_contracting_dimensions(0); in TEST_P()
860 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_P()
862 rhs_instruction, dot_dnums, in TEST_P()
902 DotDimensionNumbers dot_dnums; in TEST_P() local
903 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_P()
[all …]
Dmemory_space_assignment_test.cc4838 DotDimensionNumbers dot_dnums; in TEST_P() local
4839 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_P()
4840 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_P()
4842 result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); in TEST_P()
4884 DotDimensionNumbers dot_dnums; in TEST_P() local
4885 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_P()
4886 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_P()
4888 result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2))); in TEST_P()
4930 DotDimensionNumbers dot_dnums; in TEST_P() local
4931 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_P()
[all …]
Dalgebraic_simplifier_test.cc4535 DotDimensionNumbers dot_dnums; in TEST_F() local
4536 dot_dnums.add_lhs_batch_dimensions(0); in TEST_F()
4537 dot_dnums.add_rhs_batch_dimensions(0); in TEST_F()
4538 builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, in TEST_F()
5201 DotDimensionNumbers dot_dnums; in TEST_P() local
5202 dot_dnums.add_lhs_batch_dimensions(0); in TEST_P()
5203 dot_dnums.add_lhs_batch_dimensions(1); in TEST_P()
5204 dot_dnums.add_lhs_batch_dimensions(2); in TEST_P()
5205 dot_dnums.add_rhs_batch_dimensions(0); in TEST_P()
5206 dot_dnums.add_rhs_batch_dimensions(1); in TEST_P()
[all …]
Dbfloat16_propagation_test.cc88 DotDimensionNumbers dot_dnums; in CreateDot() local
89 dot_dnums.add_lhs_contracting_dimensions(1); in CreateDot()
90 dot_dnums.add_rhs_contracting_dimensions(0); in CreateDot()
91 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, in CreateDot()
Dbuffer_assignment_test.cc1764 DotDimensionNumbers dot_dnums; in TEST_F() local
1765 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
1766 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
1771 shape_2x4, param_a, param_b, dot_dnums, precision_config)); in TEST_F()
1773 shape_3x4, param_b, param_c, dot_dnums, precision_config)); in TEST_F()
Dhlo_dataflow_analysis_test.cc2622 DotDimensionNumbers dot_dnums; in TEST_F() local
2623 dot_dnums.add_lhs_contracting_dimensions(1); in TEST_F()
2624 dot_dnums.add_rhs_contracting_dimensions(0); in TEST_F()
2629 HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); in TEST_F()
/external/tensorflow/tensorflow/compiler/xla/tests/
Ddot_operation_test.cc967 DotDimensionNumbers dot_dnums; in XLA_TEST_F() local
968 dot_dnums.add_lhs_contracting_dimensions(1); in XLA_TEST_F()
969 dot_dnums.add_rhs_contracting_dimensions(0); in XLA_TEST_F()
970 DotGeneral(dynamic_slice, rhs_constant, dot_dnums); in XLA_TEST_F()
995 DotDimensionNumbers dot_dnums; in XLA_TEST_F() local
996 dot_dnums.add_lhs_contracting_dimensions(1); in XLA_TEST_F()
997 dot_dnums.add_rhs_contracting_dimensions(0); in XLA_TEST_F()
998 DotGeneral(lhs_constant, dynamic_slice, dot_dnums); in XLA_TEST_F()
1025 DotDimensionNumbers dot_dnums; in XLA_TEST_F() local
1026 dot_dnums.add_lhs_contracting_dimensions(0); in XLA_TEST_F()
[all …]
Dmultioutput_fusion_test.cc89 DotDimensionNumbers dot_dnums; in RunTest2D() local
90 dot_dnums.add_lhs_contracting_dimensions(1); in RunTest2D()
91 dot_dnums.add_rhs_contracting_dimensions(0); in RunTest2D()
93 elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2))); in RunTest2D()
151 DotDimensionNumbers dot_dnums; in RunTest1D() local
152 dot_dnums.add_lhs_contracting_dimensions(0); in RunTest1D()
153 dot_dnums.add_rhs_contracting_dimensions(0); in RunTest1D()
156 dot_dnums, DefaultPrecisionConfig(2))); in RunTest1D()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dfft_handler.cc145 DotDimensionNumbers dot_dnums; in ShuffleWithinEachPartitionUsingOneHot() local
146 dot_dnums.add_lhs_contracting_dimensions(hlo->shape().rank() - 1); in ShuffleWithinEachPartitionUsingOneHot()
147 dot_dnums.add_rhs_contracting_dimensions(0); in ShuffleWithinEachPartitionUsingOneHot()
152 hlo->shape(), hlo, shuffle_one_hot, dot_dnums, precision_config)); in ShuffleWithinEachPartitionUsingOneHot()
Dconvolution_handler.cc895 const dot_as_convolution_util::DotConvolutionDimsInfo& dot_dnums, in CreateShardedConvConvolution() argument
901 for (const auto& dim : dot_dnums.batch_dims) { in CreateShardedConvConvolution()
908 for (const auto& dim : dot_dnums.contracting_dims) { in CreateShardedConvConvolution()
916 for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { in CreateShardedConvConvolution()
927 for (const auto& dim : dot_dnums.conv_spatial_dims) { in CreateShardedConvConvolution()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Dcpu_instruction_fusion_test.cc40 DotDimensionNumbers dot_dnums; in MakeDot() local
41 dot_dnums.add_lhs_contracting_dimensions(lhs->shape().rank() - 1); in MakeDot()
42 dot_dnums.add_rhs_contracting_dimensions(0); in MakeDot()
46 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, in MakeDot()