Home
last modified time | relevance | path

Searched refs:dim_nums (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/tests/
Dconvolution_dimension_numbers_test.cc105 ConvolutionDimensionNumbers dim_nums = in XLA_TEST_F() local
108 int64 old_input_batch_dim = dim_nums.input_batch_dimension(); in XLA_TEST_F()
109 int64 old_output_batch_dim = dim_nums.output_batch_dimension(); in XLA_TEST_F()
110 dim_nums.set_input_batch_dimension(dim_nums.input_feature_dimension()); in XLA_TEST_F()
111 dim_nums.set_output_batch_dimension(dim_nums.output_feature_dimension()); in XLA_TEST_F()
112 dim_nums.set_input_feature_dimension(old_input_batch_dim); in XLA_TEST_F()
113 dim_nums.set_output_feature_dimension(old_output_batch_dim); in XLA_TEST_F()
116 dim_nums.kernel_input_feature_dimension(); in XLA_TEST_F()
117 dim_nums.set_kernel_input_feature_dimension( in XLA_TEST_F()
118 dim_nums.kernel_output_feature_dimension()); in XLA_TEST_F()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
Dconvolution_4d_expander.cc39 const ConvolutionDimensionNumbers& dim_nums = in InstructionMatchesPattern() local
41 if (dim_nums.input_spatial_dimensions().size() != 4) { in InstructionMatchesPattern()
45 for (int64 i = 0; i < dim_nums.input_spatial_dimensions().size(); ++i) { in InstructionMatchesPattern()
46 int64 spatial_dim = dim_nums.input_spatial_dimensions(i); in InstructionMatchesPattern()
59 ConvolutionDimensionNumbers dim_nums = in ExpandInstruction() local
61 ConvolutionDimensionNumbers new_dim_nums = dim_nums; in ExpandInstruction()
74 for (int64 i = 0; i < dim_nums.input_spatial_dimensions().size(); ++i) { in ExpandInstruction()
75 int64 input_spatial_dim = dim_nums.input_spatial_dimensions(i); in ExpandInstruction()
76 int64 output_spatial_dim = dim_nums.output_spatial_dimensions(i); in ExpandInstruction()
77 int64 kernel_spatial_dim = dim_nums.kernel_spatial_dimensions(i); in ExpandInstruction()
Dhlo_matchers.cc226 const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); in MatchAndExplain() local
227 if (dim_nums.lhs_contracting_dimensions_size() != 1 || in MatchAndExplain()
228 dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { in MatchAndExplain()
230 << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") in MatchAndExplain()
235 if (dim_nums.rhs_contracting_dimensions_size() != 1 || in MatchAndExplain()
236 dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { in MatchAndExplain()
238 << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") in MatchAndExplain()
/external/tensorflow/tensorflow/compiler/tests/
Dxla_ops_test.py550 dim_nums = xla_data_pb2.DotDimensionNumbers()
551 dim_nums.lhs_contracting_dimensions.append(2)
552 dim_nums.rhs_contracting_dimensions.append(2)
553 dim_nums.rhs_contracting_dimensions.append(3)
558 xla.dot_general(a, b, dim_nums)
564 dim_nums = xla_data_pb2.DotDimensionNumbers()
565 dim_nums.lhs_contracting_dimensions.append(2)
566 dim_nums.rhs_contracting_dimensions.append(3)
571 xla.dot_general(a, b, dim_nums)
577 dim_nums = xla_data_pb2.DotDimensionNumbers()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Ddot_op_emitter.cc72 DotDimensionNumbers dim_nums; member
81 dim_nums = instr.dot_dimension_numbers(); in DotInfo()
273 CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1); in EmitLinalgMatmul()
274 CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1); in EmitLinalgMatmul()
294 b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr; in EmitLinalgMatmul()
295 c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr; in EmitLinalgMatmul()
554 const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; in EmitNaiveLlvmIrGemm() local
559 int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); in EmitNaiveLlvmIrGemm()
560 int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); in EmitNaiveLlvmIrGemm()
855 const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; in GetMatMultDims() local
[all …]
Ddot_op_emitter_internal.h41 DotDimensionNumbers dim_nums; member
48 dim_nums = instr.dot_dimension_numbers(); in DotInfo()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dgemm_thunk.cc186 const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers(); in RunGemm() local
187 CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), in RunGemm()
188 dim_nums.rhs_batch_dimensions_size()); in RunGemm()
189 CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank()); in RunGemm()
191 int64 row_dim = dim_nums.lhs_batch_dimensions_size(); in RunGemm()
192 int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; in RunGemm()
197 for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { in RunGemm()
244 lhs_buffer, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim); in RunGemm()
246 rhs_buffer, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); in RunGemm()
Dgpu_layout_assignment.cc218 DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers(); in AddBackendConstraints() local
219 CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), in AddBackendConstraints()
220 dim_nums.rhs_batch_dimensions_size()); in AddBackendConstraints()
221 CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, in AddBackendConstraints()
223 for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { in AddBackendConstraints()
/external/tensorflow/tensorflow/compiler/mlir/xla/experimental/conv_emitter/
Dconv_emitter.cc534 const auto& dim_nums = conv->convolution_dimension_numbers(); in EmitConvolutionForwardAsMlir() local
536 GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(), in EmitConvolutionForwardAsMlir()
537 dim_nums.input_feature_dimension(), in EmitConvolutionForwardAsMlir()
538 dim_nums.input_spatial_dimensions(), builder); in EmitConvolutionForwardAsMlir()
541 conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(), in EmitConvolutionForwardAsMlir()
542 dim_nums.kernel_input_feature_dimension(), in EmitConvolutionForwardAsMlir()
543 dim_nums.kernel_spatial_dimensions(), builder); in EmitConvolutionForwardAsMlir()
546 conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(), in EmitConvolutionForwardAsMlir()
547 dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(), in EmitConvolutionForwardAsMlir()