Home
last modified time | relevance | path

Searched refs:dim_numbers (Results 1 – 22 of 22) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Dscatter.cc138 xla::ScatterDimensionNumbers dim_numbers; in XlaScatter() local
139 dim_numbers.set_index_vector_dim(indices_are_vectors in XlaScatter()
165 dim_numbers.add_update_window_dims(i); in XlaScatter()
170 dim_numbers.add_inserted_window_dims(i); in XlaScatter()
171 dim_numbers.add_scatter_dims_to_operand_dims(i); in XlaScatter()
193 VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); in XlaScatter()
195 << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]"; in XlaScatter()
197 << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]"; in XlaScatter()
199 << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",") in XlaScatter()
203 dim_numbers); in XlaScatter()
/external/tensorflow/tensorflow/compiler/xla/service/
Dgather_expander.cc111 HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, in ExpandIndexVectorIntoOperandSpace() argument
132 int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i); in ExpandIndexVectorIntoOperandSpace()
133 if (index_vector_dim_index != dim_numbers.start_index_map_size()) { in ExpandIndexVectorIntoOperandSpace()
153 const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); in GatherLoopBody() local
161 dim_numbers.index_vector_dim() == in GatherLoopBody()
197 ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, in GatherLoopBody()
207 AsInt64Slice(dim_numbers.collapsed_slice_dims()))); in GatherLoopBody()
235 const GatherDimensionNumbers& dim_numbers) { in CreateGatherLoopAccumulatorInitValue() argument
240 if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { in CreateGatherLoopAccumulatorInitValue()
318 const GatherDimensionNumbers& dim_numbers = in ExpandInstruction() local
[all …]
Dconvolution_group_converter.cc202 auto dim_numbers = convolution->convolution_dimension_numbers(); in HandleBatchGroupCount() local
218 int64 input_batch_dimension = dim_numbers.input_batch_dimension(); in HandleBatchGroupCount()
219 int64 output_batch_dimension = dim_numbers.output_batch_dimension(); in HandleBatchGroupCount()
220 int64 output_feature_dimension = dim_numbers.output_feature_dimension(); in HandleBatchGroupCount()
243 convolution->window(), dim_numbers, convolution->precision_config())); in HandleBatchGroupCount()
334 auto dim_numbers = convolution->convolution_dimension_numbers(); in HandleConvolution() local
336 int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); in HandleConvolution()
339 dim_numbers.kernel_output_feature_dimension(); in HandleConvolution()
383 convolution->window(), dim_numbers, convolution->precision_config()); in HandleConvolution()
387 int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); in HandleConvolution()
[all …]
Dindexed_array_analysis.cc254 const Shape& shape, const GatherDimensionNumbers& dim_numbers, in ComputeArrayForGather() argument
256 if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { in ComputeArrayForGather()
261 CHECK_EQ(dim_numbers.start_index_map_size(), 1); in ComputeArrayForGather()
266 if (dim_numbers.collapsed_slice_dims_size() != 1 || in ComputeArrayForGather()
267 dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) { in ComputeArrayForGather()
279 if (i != dim_numbers.collapsed_slice_dims(0) && in ComputeArrayForGather()
285 << dim_numbers.collapsed_slice_dims(0); in ComputeArrayForGather()
290 int64 source_dim = dim_numbers.start_index_map(0); in ComputeArrayForGather()
293 if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { in ComputeArrayForGather()
1033 const Shape& shape, const DotDimensionNumbers& dim_numbers, in ComputeArrayForDotWithIndexedLhs() argument
[all …]
Dscatter_expander.cc133 HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, in ExpandIndexVectorIntoOperandSpace() argument
154 FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); in ExpandIndexVectorIntoOperandSpace()
156 dim_numbers.scatter_dims_to_operand_dims_size()) { in ExpandIndexVectorIntoOperandSpace()
222 const ScatterDimensionNumbers& dim_numbers = in ScatterLoopBody() local
258 ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, in ScatterLoopBody()
279 AsInt64Slice(dim_numbers.inserted_window_dims()))); in ScatterLoopBody()
350 const ScatterDimensionNumbers& dim_numbers = in ExpandScatter() local
364 if (i != dim_numbers.index_vector_dim()) { in ExpandScatter()
379 scatter_indices, dim_numbers.index_vector_dim())); in ExpandScatter()
388 updates, AsInt64Slice(dim_numbers.update_window_dims()))); in ExpandScatter()
[all …]
Dbatch_dot_simplification.cc26 const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); in ElideDegenerateBatchDimensionFromBatchDot() local
34 if (dim_numbers.lhs_contracting_dimensions_size() != 1) { in ElideDegenerateBatchDimensionFromBatchDot()
39 for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { in ElideDegenerateBatchDimensionFromBatchDot()
54 DotDimensionNumbers new_dim_numbers = dim_numbers; in ElideDegenerateBatchDimensionFromBatchDot()
58 for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() - in ElideDegenerateBatchDimensionFromBatchDot()
Dshape_inference.cc2868 const GatherDimensionNumbers& dim_numbers) { in ValidateGatherDimensionNumbers() argument
2869 if (!absl::c_is_sorted(dim_numbers.offset_dims())) { in ValidateGatherDimensionNumbers()
2872 StrJoin(dim_numbers.offset_dims(), ", ")); in ValidateGatherDimensionNumbers()
2875 if (absl::c_adjacent_find(dim_numbers.offset_dims()) != in ValidateGatherDimensionNumbers()
2876 dim_numbers.offset_dims().end()) { in ValidateGatherDimensionNumbers()
2879 StrJoin(dim_numbers.offset_dims(), ", ")); in ValidateGatherDimensionNumbers()
2882 const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); in ValidateGatherDimensionNumbers()
2886 for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) { in ValidateGatherDimensionNumbers()
2887 int64 offset_dim = dim_numbers.offset_dims(i); in ValidateGatherDimensionNumbers()
2897 if (dim_numbers.start_index_map_size() != in ValidateGatherDimensionNumbers()
[all …]
Dhlo_cost_analysis_test.cc650 GatherDimensionNumbers dim_numbers; in TEST_F() local
651 dim_numbers.add_offset_dims(1); in TEST_F()
652 dim_numbers.add_collapsed_slice_dims(0); in TEST_F()
653 dim_numbers.add_start_index_map(0); in TEST_F()
654 dim_numbers.set_index_vector_dim(1); in TEST_F()
655 Gather(operand, indices, dim_numbers, {1, 3}); in TEST_F()
677 ScatterDimensionNumbers dim_numbers; in TEST_F() local
678 dim_numbers.set_index_vector_dim(1); in TEST_F()
679 dim_numbers.add_update_window_dims(1); in TEST_F()
680 dim_numbers.add_inserted_window_dims(0); in TEST_F()
[all …]
Dtriangular_solve_expander.cc74 GatherDimensionNumbers dim_numbers; in DiagonalBlocks() local
76 dim_numbers.add_offset_dims(i); in DiagonalBlocks()
77 dim_numbers.add_start_index_map(i); in DiagonalBlocks()
81 dim_numbers.add_offset_dims(ndims - 1); in DiagonalBlocks()
82 dim_numbers.add_offset_dims(ndims); in DiagonalBlocks()
83 dim_numbers.add_start_index_map(ndims - 2); in DiagonalBlocks()
84 dim_numbers.add_start_index_map(ndims - 1); in DiagonalBlocks()
85 dim_numbers.set_index_vector_dim(1); in DiagonalBlocks()
86 diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes); in DiagonalBlocks()
Dindexed_array_analysis.h265 const Shape& shape, const GatherDimensionNumbers& dim_numbers,
269 const Shape& shape, const DotDimensionNumbers& dim_numbers,
274 const Shape& shape, const DotDimensionNumbers& dim_numbers,
279 const DotDimensionNumbers& dim_numbers,
Dhlo_evaluator.cc364 const DotDimensionNumbers& dim_numbers, in EvaluateDotOp() argument
374 ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers)); in EvaluateDotOp()
378 dim_numbers, precision_config); in EvaluateDotOp()
782 const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { in IterationSpaceForOutputBatchIndices() argument
789 !absl::c_binary_search(dim_numbers.offset_dims(), i); in IterationSpaceForOutputBatchIndices()
801 const GatherDimensionNumbers& dim_numbers) { in IterationSpaceForOutputOffsetIndices() argument
807 absl::c_binary_search(dim_numbers.offset_dims(), i); in IterationSpaceForOutputOffsetIndices()
809 while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), in IterationSpaceForOutputOffsetIndices()
831 const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, in OutputBatchIndexToInputIndex() argument
833 : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { in OutputBatchIndexToInputIndex()
[all …]
Delemental_ir_emitter.cc1867 const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers(); in EmitElementalGather() local
1885 if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { in EmitElementalGather()
1888 int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); in EmitElementalGather()
1898 if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { in EmitElementalGather()
1907 dim_numbers.index_vector_dim(), in EmitElementalGather()
1915 int64 operand_dim = dim_numbers.start_index_map(dim); in EmitElementalGather()
1941 if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { in EmitElementalGather()
1949 indices_shape.dimensions(dim_numbers.index_vector_dim()); in EmitElementalGather()
1951 gather_index_index_components[dim_numbers.index_vector_dim()] = in EmitElementalGather()
2118 const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); in EmitElementalDot() local
[all …]
Dhlo_creation_utils.cc218 const DotDimensionNumbers& dim_numbers, in MakeDotHlo() argument
224 ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); in MakeDotHlo()
226 dot_shape, lhs, rhs, dim_numbers, precision_config)); in MakeDotHlo()
Dhlo_evaluator_typed_visitor.h2059 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in IterationSpaceForUpdateScatterIndices() argument
2065 !absl::c_binary_search(dim_numbers.update_window_dims(), i); in IterationSpaceForUpdateScatterIndices()
2078 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { in IterationSpaceForUpdateWindowIndices() argument
2084 absl::c_binary_search(dim_numbers.update_window_dims(), i); in IterationSpaceForUpdateWindowIndices()
2106 const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, in UpdateScatterIndexToInputIndex() argument
2108 : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { in UpdateScatterIndexToInputIndex()
2236 const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, in UpdateWindowIndexToInputIndex() argument
2241 if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { in UpdateWindowIndexToInputIndex()
2250 if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { in UpdateWindowIndexToInputIndex()
2308 const ScatterDimensionNumbers& dim_numbers = in HandleScatter() local
[all …]
Dhlo_creation_utils.h109 const DotDimensionNumbers& dim_numbers,
Dhlo_evaluator.h125 StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
Dhlo_parser.cc1636 GatherDimensionNumbers dim_numbers = in ParseInstructionRhs() local
1645 dim_numbers, *slice_sizes)); in ParseInstructionRhs()
1672 ScatterDimensionNumbers dim_numbers = in ParseInstructionRhs() local
1681 /*updates=*/operands[2], *update_computation, dim_numbers)); in ParseInstructionRhs()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dgather_op.cc116 xla::GatherDimensionNumbers dim_numbers; in XlaGather() local
122 dim_numbers.add_collapsed_slice_dims(i); in XlaGather()
131 dim_numbers.add_offset_dims(i); in XlaGather()
135 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); in XlaGather()
139 dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1) in XlaGather()
142 dim_numbers.add_start_index_map(i); in XlaGather()
145 *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); in XlaGather()
/external/tensorflow/tensorflow/compiler/xla/tests/
Dgather_operation_test.cc646 GatherDimensionNumbers dim_numbers; in XLA_TEST_F() local
647 dim_numbers.add_offset_dims(1); in XLA_TEST_F()
648 dim_numbers.add_collapsed_slice_dims(0); in XLA_TEST_F()
649 dim_numbers.add_start_index_map(0); in XLA_TEST_F()
650 dim_numbers.set_index_vector_dim(1); in XLA_TEST_F()
651 Gather(operand, indices, dim_numbers, {1, 3}); in XLA_TEST_F()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dir_emission_utils.cc69 const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); in DotImplementedAsGemm() local
74 dim_numbers.lhs_batch_dimensions_size())) { in DotImplementedAsGemm()
78 CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), in DotImplementedAsGemm()
79 rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); in DotImplementedAsGemm()
Dir_emitter_unnested.cc1081 const ScatterDimensionNumbers& dim_numbers = in EmitScatter() local
1094 if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { in EmitScatter()
1102 dim_numbers.update_window_dims_size()); in EmitScatter()
1109 if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { in EmitScatter()
1124 if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { in EmitScatter()
1127 dim_numbers.index_vector_dim()); in EmitScatter()
1135 raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(), in EmitScatter()
1138 for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); in EmitScatter()
1142 raw_scatter_index_multidim[dim_numbers.index_vector_dim()] = in EmitScatter()
1147 int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); in EmitScatter()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Ddot_op_emitter.cc896 Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { in ValidateDotDimensionNumbers() argument
899 TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); in ValidateDotDimensionNumbers()
900 std::vector<int64> batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size()); in ValidateDotDimensionNumbers()
903 absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions())); in ValidateDotDimensionNumbers()
905 absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions())); in ValidateDotDimensionNumbers()