Home
last modified time | relevance | path

Searched refs:dimension_numbers (Results 1 – 24 of 24) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/
Dreference_util_test.cc333 ConvolutionDimensionNumbers dimension_numbers; in TEST_F() local
334 dimension_numbers.set_input_batch_dimension(2); in TEST_F()
335 dimension_numbers.set_input_feature_dimension(0); in TEST_F()
336 dimension_numbers.set_output_batch_dimension(2); in TEST_F()
337 dimension_numbers.set_output_feature_dimension(0); in TEST_F()
338 dimension_numbers.add_input_spatial_dimensions(1); in TEST_F()
339 dimension_numbers.add_output_spatial_dimensions(1); in TEST_F()
340 dimension_numbers.add_input_spatial_dimensions(3); in TEST_F()
341 dimension_numbers.add_output_spatial_dimensions(3); in TEST_F()
342 dimension_numbers.set_kernel_output_feature_dimension(0); in TEST_F()
[all …]
Dreference_util.cc468 ConvolutionDimensionNumbers dimension_numbers) { in ConvArray4DGeneralDimensions() argument
471 std::move(dimension_numbers)); in ConvArray4DGeneralDimensions()
Dreference_util.h80 ConvolutionDimensionNumbers dimension_numbers);
/external/tensorflow/tensorflow/compiler/xla/tests/
Dconvolution_dimension_numbers_test.cc43 ConvolutionDimensionNumbers dimension_numbers; in CreateConvDimensionNumbers() local
44 dimension_numbers.set_input_batch_dimension(input_batch); in CreateConvDimensionNumbers()
45 dimension_numbers.set_input_feature_dimension(input_feature); in CreateConvDimensionNumbers()
46 dimension_numbers.add_input_spatial_dimensions(input_first_spatial); in CreateConvDimensionNumbers()
47 dimension_numbers.add_input_spatial_dimensions(input_second_spatial); in CreateConvDimensionNumbers()
48 dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); in CreateConvDimensionNumbers()
49 dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); in CreateConvDimensionNumbers()
50 dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); in CreateConvDimensionNumbers()
51 dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); in CreateConvDimensionNumbers()
52 dimension_numbers.set_output_batch_dimension(output_batch); in CreateConvDimensionNumbers()
[all …]
/external/tensorflow/tensorflow/compiler/xla/python/
Dxla_data.i452 (DotDimensionNumbers dimension_numbers) {
455 dimension_numbers.mutable_lhs_contracting_dimensions())) {
460 dimension_numbers.mutable_rhs_contracting_dimensions())) {
465 dimension_numbers.mutable_lhs_batch_dimensions())) {
470 dimension_numbers.mutable_rhs_batch_dimensions())) {
474 $1 = &dimension_numbers;
522 (ConvolutionDimensionNumbers dimension_numbers) {
528 dimension_numbers.set_input_batch_dimension(value);
533 dimension_numbers.set_input_feature_dimension(value);
538 dimension_numbers.set_output_batch_dimension(value);
[all …]
Dxla_client.py1604 def DotGeneral(self, lhs, rhs, dimension_numbers): argument
1616 if isinstance(dimension_numbers, tuple):
1617 dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
1618 return self._client.DotGeneral(lhs, rhs, dimension_numbers)
1636 lhs, rhs, window_strides, pads, (), (), dimension_numbers=None,
1657 dimension_numbers=None, feature_group_count=feature_group_count)
1662 dimension_numbers = ConvolutionDimensionNumbers()
1663 dimension_numbers.input_batch_dimension = 0
1664 dimension_numbers.input_feature_dimension = 1
1665 dimension_numbers.output_batch_dimension = 0
[all …]
Dlocal_computation_builder.cc582 const DotDimensionNumbers& dimension_numbers) { in DotGeneral() argument
583 return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); in DotGeneral()
591 const ConvolutionDimensionNumbers& dimension_numbers, in ConvGeneralDilated() argument
594 lhs_dilation, rhs_dilation, dimension_numbers, in ConvGeneralDilated()
756 const GatherDimensionNumbers& dimension_numbers, in Gather() argument
758 return xla::Gather(input.op(), start_indices.op(), dimension_numbers, in Gather()
765 const ScatterDimensionNumbers& dimension_numbers) { in Scatter() argument
767 update_computation.computation(), dimension_numbers); in Scatter()
Dlocal_computation_builder.h291 const DotDimensionNumbers& dimension_numbers);
299 const ConvolutionDimensionNumbers& dimension_numbers,
374 const GatherDimensionNumbers& dimension_numbers,
379 const ScatterDimensionNumbers& dimension_numbers);
Dxla_client_test.py659 dimension_numbers = (([2], [1]), ([0], [0]))
660 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
669 dimension_numbers = xla_client.DotDimensionNumbers()
670 dimension_numbers.lhs_contracting_dimensions.append(2)
671 dimension_numbers.rhs_contracting_dimensions.append(1)
672 dimension_numbers.lhs_batch_dimensions.append(0)
673 dimension_numbers.rhs_batch_dimensions.append(0)
675 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
727 dimension_numbers = ("NCHW", "OIHW", "NCHW")
730 dimension_numbers)
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dimage_resize_ops.cc241 xla::ConvolutionDimensionNumbers dimension_numbers; in ResizeUsingDilationAndConvolution() local
242 dimension_numbers.set_input_batch_dimension(0); in ResizeUsingDilationAndConvolution()
243 dimension_numbers.set_output_batch_dimension(0); in ResizeUsingDilationAndConvolution()
244 dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); in ResizeUsingDilationAndConvolution()
245 dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); in ResizeUsingDilationAndConvolution()
247 dimension_numbers.add_input_spatial_dimensions(1 + i); in ResizeUsingDilationAndConvolution()
248 dimension_numbers.add_output_spatial_dimensions(1 + i); in ResizeUsingDilationAndConvolution()
249 dimension_numbers.add_kernel_spatial_dimensions(i); in ResizeUsingDilationAndConvolution()
251 dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); in ResizeUsingDilationAndConvolution()
252 dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); in ResizeUsingDilationAndConvolution()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
Ddynamic_dimension_inference.cc205 const DotDimensionNumbers& dimension_numbers = in HandleDot() local
211 dimension_numbers.rhs_batch_dimensions().begin(), in HandleDot()
212 dimension_numbers.rhs_batch_dimensions().end()); in HandleDot()
214 for (int64 i : dimension_numbers.rhs_batch_dimensions()) { in HandleDot()
220 dimension_numbers.lhs_contracting_dimensions(), i)) { in HandleDot()
230 dimension_numbers.rhs_contracting_dimensions(), i) && in HandleDot()
231 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), in HandleDot()
267 const ConvolutionDimensionNumbers& dimension_numbers = in HandleConvolution() local
271 if (dimension == dimension_numbers.input_batch_dimension()) { in HandleConvolution()
273 dimension_numbers.output_batch_dimension(), in HandleConvolution()
[all …]
Dshape_inference.cc559 const DotDimensionNumbers& dimension_numbers) { in ValidateDotDimensionNumbers() argument
570 AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); in ValidateDotDimensionNumbers()
572 AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); in ValidateDotDimensionNumbers()
574 AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); in ValidateDotDimensionNumbers()
576 AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); in ValidateDotDimensionNumbers()
583 dimension_numbers.DebugString()); in ValidateDotDimensionNumbers()
600 dimension_numbers.DebugString()); in ValidateDotDimensionNumbers()
610 const DotDimensionNumbers& dimension_numbers) { in InferDotOpShape() argument
634 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); in InferDotOpShape()
637 if (dimension_numbers.lhs_contracting_dimensions_size() != in InferDotOpShape()
[all …]
Dshape_inference.h113 const ConvolutionDimensionNumbers& dimension_numbers);
277 const DotDimensionNumbers& dimension_numbers);
Dhlo_creation_utils.cc84 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, in MakeConvolveHlo() argument
91 window, dimension_numbers)); in MakeConvolveHlo()
94 dimension_numbers, precision_config)); in MakeConvolveHlo()
Dhlo_creation_utils.h59 const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
Dhlo_instruction.h440 const ConvolutionDimensionNumbers& dimension_numbers,
464 const DotDimensionNumbers& dimension_numbers,
Dhlo_instructions.cc1857 const ConvolutionDimensionNumbers& dimension_numbers, in HloConvolutionInstruction() argument
1863 convolution_dimension_numbers_(dimension_numbers), in HloConvolutionInstruction()
2511 const DotDimensionNumbers& dimension_numbers, in HloDotInstruction() argument
2514 dot_dimension_numbers_(dimension_numbers), in HloDotInstruction()
Dhlo_instructions.h1034 const ConvolutionDimensionNumbers& dimension_numbers,
1466 const DotDimensionNumbers& dimension_numbers,
Dhlo_instruction.cc783 const ConvolutionDimensionNumbers& dimension_numbers, in CreateConvolve() argument
787 dimension_numbers, precision_config); in CreateConvolve()
817 const DotDimensionNumbers& dimension_numbers, in CreateDot() argument
820 shape, lhs, rhs, dimension_numbers, precision_config); in CreateDot()
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc1015 DotDimensionNumbers dimension_numbers; in Dot() local
1016 dimension_numbers.add_lhs_contracting_dimensions( in Dot()
1018 dimension_numbers.add_rhs_contracting_dimensions(0); in Dot()
1019 return DotGeneral(lhs, rhs, dimension_numbers, precision_config); in Dot()
1024 const DotDimensionNumbers& dimension_numbers, in DotGeneral() argument
1032 if (dimension_numbers.rhs_batch_dimensions_size() != 0 || in DotGeneral()
1033 dimension_numbers.lhs_batch_dimensions_size() != 0 || in DotGeneral()
1034 dimension_numbers.rhs_contracting_dimensions_size() != 0 || in DotGeneral()
1035 dimension_numbers.lhs_contracting_dimensions_size() != 0) { in DotGeneral()
1044 dimension_numbers)); in DotGeneral()
[all …]
Dxla_builder.h380 const DotDimensionNumbers& dimension_numbers,
398 const ConvolutionDimensionNumbers& dimension_numbers,
405 const ConvolutionDimensionNumbers& dimension_numbers,
414 const ConvolutionDimensionNumbers& dimension_numbers,
544 const GatherDimensionNumbers& dimension_numbers,
549 const ScatterDimensionNumbers& dimension_numbers);
644 const ConvolutionDimensionNumbers& dimension_numbers) const;
794 const ConvolutionDimensionNumbers& dimension_numbers,
800 const ConvolutionDimensionNumbers& dimension_numbers,
809 const ConvolutionDimensionNumbers& dimension_numbers,
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/python/
Dxla.py234 dimension_numbers, argument
270 dimension_numbers=dimension_numbers.SerializeToString(),
282 def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): argument
289 dimension_numbers=dimension_numbers.SerializeToString(),
/external/tensorflow/tensorflow/compiler/tests/
Dxla_ops_test.py133 dimension_numbers=dnums)
160 dimension_numbers=dnums,
/external/tensorflow/tensorflow/compiler/xla/g3doc/
Doperation_semantics.md911 <b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
917 `dimension_numbers` | `DotDimensionNumbers` | array of type T
930 in 'dimension_numbers'.