Searched refs:operand_shape (Results 1 – 25 of 46) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
shape_inference.cc
452 const Shape& operand_shape, PrimitiveType new_element_type) { in InferConvertShape() argument
453 auto old_element_type = operand_shape.element_type(); in InferConvertShape()
458 ShapeUtil::HumanString(operand_shape), in InferConvertShape()
461 if (!operand_shape.IsArray() || in InferConvertShape()
468 ShapeUtil::HumanString(operand_shape), in InferConvertShape()
472 return ShapeUtil::ChangeElementType(operand_shape, new_element_type); in InferConvertShape()
476 const Shape& operand_shape, PrimitiveType new_element_type) { in InferBitcastConvertShape() argument
477 auto old_element_type = operand_shape.element_type(); in InferBitcastConvertShape()
481 ShapeUtil::HumanString(operand_shape), in InferBitcastConvertShape()
484 if (!operand_shape.IsArray() || in InferBitcastConvertShape()
[all …]
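
Illustrative sketch (not the real xla::Shape / ShapeUtil API): the InferConvertShape fragments above reject non-array operands and otherwise return the operand shape with only the element type replaced. A self-contained model of that pattern, using a hypothetical simplified Shape struct:

    // Simplified stand-in types; the real code works on xla::Shape and returns
    // StatusOr<Shape> with a HumanString()-formatted error message.
    #include <cstdint>
    #include <optional>
    #include <vector>

    enum class ElemType { kPred, kS32, kF32 };

    struct Shape {
      ElemType element_type;
      std::vector<int64_t> dimensions;  // empty => scalar
      bool is_array = true;             // false for tuple-like shapes
    };

    // Mirrors the InferConvertShape pattern: keep the dimensions, swap the
    // element type, and fail on non-array operands.
    std::optional<Shape> InferConvertShapeSketch(const Shape& operand_shape,
                                                 ElemType new_element_type) {
      if (!operand_shape.is_array) {
        return std::nullopt;
      }
      Shape result = operand_shape;
      result.element_type = new_element_type;
      return result;
    }
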
shape_inference.h
88 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
96 const Shape& operand_shape, const Shape& scale_shape,
101 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
130 static StatusOr<Shape> InferAllGatherShape(const Shape& operand_shape,
167 const Shape& operand_shape, const Shape& init_value, const Window& window,
169 static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
184 const Shape& operand_shape, const ProgramShape& select_shape,
190 static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
205 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
211 const Shape& operand_shape, const Shape& update_shape,
[all …]
batchnorm_expander.cc
173 const Shape operand_shape = operand->shape(); in HandleBatchNormTraining() local
174 PrimitiveType ptype = operand_shape.element_type(); in HandleBatchNormTraining()
188 operand_shape, in HandleBatchNormTraining()
192 for (int64 i = 0; i < operand_shape.rank(); ++i) { in HandleBatchNormTraining()
202 HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); in HandleBatchNormTraining()
205 HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); in HandleBatchNormTraining()
212 add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); in HandleBatchNormTraining()
227 HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); in HandleBatchNormTraining()
241 add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); in HandleBatchNormTraining()
245 add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); in HandleBatchNormTraining()
[all …]
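
The expansion above broadcasts the per-feature scale, offset, mean, and variance to the full operand shape and then applies elementwise arithmetic. A hedged sketch of the value it ultimately computes, on a flattened operand with a hypothetical feature_of index map (not the HLO-building code itself):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // y[i] = scale[f] * (x[i] - mean[f]) / sqrt(var[f] + epsilon) + offset[f],
    // where f = feature_of[i] is the feature channel of element i.
    std::vector<float> BatchNormTrainingOutputSketch(
        const std::vector<float>& operand, const std::vector<float>& scale,
        const std::vector<float>& offset, const std::vector<float>& mean,
        const std::vector<float>& var, float epsilon,
        const std::vector<int>& feature_of) {
      std::vector<float> out(operand.size());
      for (std::size_t i = 0; i < operand.size(); ++i) {
        const int f = feature_of[i];
        out[i] = scale[f] * (operand[i] - mean[f]) / std::sqrt(var[f] + epsilon) +
                 offset[f];
      }
      return out;
    }
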
hlo_creation_utils.cc
442 const Shape& operand_shape = operand->shape(); in CollapseFirstNDims() local
443 CHECK_GE(operand_shape.dimensions_size(), n); in CollapseFirstNDims()
446 new_shape_leading_bound *= operand_shape.dimensions(i); in CollapseFirstNDims()
450 new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1); in CollapseFirstNDims()
453 std::copy(operand_shape.dimensions().begin() + n, in CollapseFirstNDims()
454 operand_shape.dimensions().end(), in CollapseFirstNDims()
458 ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims); in CollapseFirstNDims()
467 const Shape& operand_shape = operand->shape(); in PrependDegenerateDims() local
468 new_shape_dims.reserve(n + operand_shape.dimensions_size()); in PrependDegenerateDims()
470 absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); in PrependDegenerateDims()
[all …]
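
CollapseFirstNDims, as shown above, multiplies the first n dimensions into a single leading bound and copies the remaining ones. A standalone sketch of that dimension arithmetic on plain vectors (not the HloInstruction-building code):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Collapse the first n dimensions into one: {2, 3, 4, 5} with n = 2
    // becomes {6, 4, 5}.
    std::vector<int64_t> CollapseFirstNDimsSketch(
        const std::vector<int64_t>& dims, int n) {
      assert(static_cast<int>(dims.size()) >= n);
      int64_t leading_bound = 1;
      for (int i = 0; i < n; ++i) leading_bound *= dims[i];

      std::vector<int64_t> new_dims;
      new_dims.reserve(dims.size() - n + 1);
      new_dims.push_back(leading_bound);
      new_dims.insert(new_dims.end(), dims.begin() + n, dims.end());
      return new_dims;
    }
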
hlo_verifier.cc
647 const Shape& operand_shape = instruction.operands()[i]->shape(); in SameElementTypesForOperandsAndToApplyParameters() local
648 if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) { in SameElementTypesForOperandsAndToApplyParameters()
699 const Shape& operand_shape = broadcast->operand(0)->shape(); in HandleBroadcast() local
701 TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); in HandleBroadcast()
702 TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); in HandleBroadcast()
703 for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank(); in HandleBroadcast()
709 operand_shape.dimensions(operand_dimension))) in HandleBroadcast()
710 << broadcast->ToString() << " operand shape " << operand_shape; in HandleBroadcast()
717 const Shape& operand_shape = dynamic_reshape->operand(0)->shape(); in HandleDynamicReshape() local
718 TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape)); in HandleDynamicReshape()
[all …]
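
The HandleBroadcast checks above require one dimensions() entry per operand dimension and a matching size in the output for each mapped dimension. A sketch of that consistency check on plain size vectors (hypothetical helper, not the verifier API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Every operand dimension i must map through dimensions[i] to an output
    // dimension of the same size.
    bool BroadcastDimsConsistentSketch(const std::vector<int64_t>& operand_shape,
                                       const std::vector<int64_t>& output_shape,
                                       const std::vector<int64_t>& dimensions) {
      if (operand_shape.size() != dimensions.size()) return false;
      for (std::size_t i = 0; i < operand_shape.size(); ++i) {
        const int64_t out_dim = dimensions[i];
        if (out_dim < 0 || out_dim >= static_cast<int64_t>(output_shape.size())) {
          return false;
        }
        if (output_shape[out_dim] != operand_shape[i]) return false;
      }
      return true;
    }
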
logistic_expander.cc
43 const Shape operand_shape = operand->shape(); in ExpandLogisticWithTanh() local
59 const Shape operand_shape = operand->shape(); in ExpandLogisticWithExp() local
dynamic_padder.cc
347 const Shape operand_shape = reshape->operand(0)->shape(); in RewriteDynamicReshapeSplitInput() local
352 ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)}); in RewriteDynamicReshapeSplitInput()
414 dim->set_size(operand_shape.dimensions(input_dim)); in RewriteDynamicReshapeSplitInput()
416 dim->set_padding_low(operand_shape.dimensions(input_dim) - 1); in RewriteDynamicReshapeSplitInput()
431 for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { in RewriteDynamicReshapeSplitInput()
449 LiteralUtil::CreateR0<int32>(operand_shape.dimensions(input_dim)))); in RewriteDynamicReshapeSplitInput()
452 operand_shape, reshape->mutable_operand(0), operand_static_dim_size, in RewriteDynamicReshapeSplitInput()
455 std::vector<int64> slice_sizes(operand_shape.dimensions().begin(), in RewriteDynamicReshapeSplitInput()
456 operand_shape.dimensions().end()); in RewriteDynamicReshapeSplitInput()
459 ShapeUtil::MakeShape(operand_shape.element_type(), in RewriteDynamicReshapeSplitInput()
[all …]
indexed_array_analysis.cc
345 absl::Span<const int64> operand_shape, in ComputeReshapePassthroughDimPairs() argument
373 FindSuffixWithProduct(operand_shape, result_subarray_size); in ComputeReshapePassthroughDimPairs()
382 << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; in ComputeReshapePassthroughDimPairs()
385 result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { in ComputeReshapePassthroughDimPairs()
401 VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" in ComputeReshapePassthroughDimPairs()
444 absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape, in FindSourcePositionForPassthroughResultDim() argument
447 << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") in FindSourcePositionForPassthroughResultDim()
451 std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, in FindSourcePositionForPassthroughResultDim()
452 operand_shape.end(), 1LL, std::multiplies<int64>()); in FindSourcePositionForPassthroughResultDim()
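
The FindSourcePositionForPassthroughResultDim fragment accumulates a suffix product over the operand dimensions. A tiny sketch of that computation:

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Product of all dimensions after `dim`, matching the std::accumulate call
    // above: for {4, 5, 2} and dim = 0 this returns 10.
    int64_t SuffixProductSketch(const std::vector<int64_t>& operand_shape,
                                int dim) {
      return std::accumulate(operand_shape.begin() + dim + 1,
                             operand_shape.end(), int64_t{1},
                             std::multiplies<int64_t>());
    }
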
hlo_sharding_util.cc
800 const Shape& operand_shape, const HloSharding& operand_sharding, in PassthroughOperandToGatherOutputOrScatterUpdate() argument
811 for (int64 i = 0; i < operand_shape.rank(); ++i) { in PassthroughOperandToGatherOutputOrScatterUpdate()
821 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) { in PassthroughOperandToGatherOutputOrScatterUpdate()
846 const Shape& operand_shape, const HloSharding& update_or_gather_sharding, in PassthroughGatherOutputOrScatterUpdateToOperand() argument
854 std::vector<int64> passthrough_tile(operand_shape.rank(), 1); in PassthroughGatherOutputOrScatterUpdateToOperand()
856 for (int64 i = 0; i < operand_shape.rank(); ++i) { in PassthroughGatherOutputOrScatterUpdateToOperand()
865 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) { in PassthroughGatherOutputOrScatterUpdateToOperand()
948 const Shape& output_shape, const Shape& operand_shape) { in GatherOutputShardingFromDataOperand() argument
957 operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims, in GatherOutputShardingFromDataOperand()
/external/tensorflow/tensorflow/compiler/xla/tests/
dynamic_ops_test.cc
512 void RunR3Contiguous(std::vector<int32> operand_shape, int32 index, in RunR3Contiguous() argument
514 const int32 kSeq = operand_shape[0]; in RunR3Contiguous()
515 const int32 kBatch = operand_shape[1]; in RunR3Contiguous()
516 const int32 kDim = operand_shape[2]; in RunR3Contiguous()
668 std::vector<int32> operand_shape({4, 5, 2}); in XLA_TEST_F() local
669 RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1); in XLA_TEST_F()
674 std::vector<int32> operand_shape({4, 5, 2}); in XLA_TEST_F() local
675 RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1); in XLA_TEST_F()
680 std::vector<int32> operand_shape({4, 5, 2}); in XLA_TEST_F() local
681 RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2); in XLA_TEST_F()
[all …]
select_and_scatter_test.cc
42 std::vector<int64> operand_shape; member
73 auto operand_shape = GetParam().operand_shape; in XLA_TEST_P() local
74 Array<float> o(operand_shape); in XLA_TEST_P()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
reduction_layout_normalizer.cc
41 const Shape &operand_shape = operand->shape(); in HandleReduce() local
42 const Layout &operand_layout = operand_shape.layout(); in HandleReduce()
69 for (int i = 0; i < operand_shape.rank(); i++) { in HandleReduce()
72 int64 major_to_minor_dim_idx = operand_shape.rank() - i - 1; in HandleReduce()
74 int64 dim_size = operand_shape.dimensions(logical_dim); in HandleReduce()
93 Shape new_operand_shape = ShapeUtil::MakeShape(operand_shape.element_type(), in HandleReduce()
95 if (new_operand_shape == operand_shape) { in HandleReduce()
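
The HandleReduce fragments walk the operand from the major end of its layout (rank - i - 1) and read each dimension size through a logical index. Assuming the elided line maps that physical position through the layout's minor_to_major list, a sketch of listing dimension sizes in major-to-minor order:

    #include <cstdint>
    #include <vector>

    // Given logical dimension sizes and a minor_to_major permutation, return
    // the sizes in major-to-minor (physical) order.
    std::vector<int64_t> DimsInMajorToMinorOrderSketch(
        const std::vector<int64_t>& operand_dims,
        const std::vector<int64_t>& minor_to_major) {
      const int rank = static_cast<int>(operand_dims.size());
      std::vector<int64_t> ordered;
      ordered.reserve(rank);
      for (int i = 0; i < rank; ++i) {
        const int64_t major_to_minor_dim_idx = rank - i - 1;
        const int64_t logical_dim = minor_to_major[major_to_minor_dim_idx];
        ordered.push_back(operand_dims[logical_dim]);
      }
      return ordered;
    }
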
ir_emission_utils.cc
318 Shape operand_shape = TypeToShape(input.getType()); in IsReductionFromOrToContiguousDimensions() local
325 *operand_shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); in IsReductionFromOrToContiguousDimensions()
339 for (int64 dim = 0; dim < operand_shape.dimensions().size(); ++dim) { in IsReductionFromOrToContiguousDimensions()
349 if (!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), in IsReductionFromOrToContiguousDimensions()
351 !LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), in IsReductionFromOrToContiguousDimensions()
357 GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions); in IsReductionFromOrToContiguousDimensions()
406 Shape operand_shape = TypeToShape(input.getType()); in GetReductionKindAndContiguousComponents() local
415 return GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions); in GetReductionKindAndContiguousComponents()
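
IsReductionFromOrToContiguousDimensions relies on a "dimensions are consecutive in this layout" test. A sketch of such a check, under the assumption that it locates each dimension's position in minor_to_major order and requires those positions to be adjacent (hypothetical helper, not LayoutUtil itself):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    bool DimsConsecutiveInLayoutSketch(const std::vector<int64_t>& minor_to_major,
                                       const std::vector<int64_t>& dims) {
      // Positions (in physical order) of the requested dimensions.
      std::vector<std::size_t> positions;
      for (std::size_t pos = 0; pos < minor_to_major.size(); ++pos) {
        if (std::find(dims.begin(), dims.end(), minor_to_major[pos]) !=
            dims.end()) {
          positions.push_back(pos);
        }
      }
      if (positions.size() != dims.size()) return false;  // unknown dimension
      for (std::size_t i = 1; i < positions.size(); ++i) {
        if (positions[i] != positions[i - 1] + 1) return false;
      }
      return true;
    }
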
/external/tensorflow/tensorflow/lite/toco/graph_transformations/
fuse_binary_into_preceding_affine.cc
58 const Shape& operand_shape = operand.shape(); in FuseAddOrSubParamsIntoPrecedingAffine() local
67 if (operand_shape.dimensions_count() >= 1 && in FuseAddOrSubParamsIntoPrecedingAffine()
68 operand_shape.dims(operand_shape.dimensions_count() - 1) == in FuseAddOrSubParamsIntoPrecedingAffine()
71 } else if (operand_shape.dimensions_count() == 0 || in FuseAddOrSubParamsIntoPrecedingAffine()
72 operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { in FuseAddOrSubParamsIntoPrecedingAffine()
128 const Shape& operand_shape = operand.shape(); in FuseMulOrDivParamsIntoPrecedingAffine() local
139 if (operand_shape.dimensions_count() >= 1 && in FuseMulOrDivParamsIntoPrecedingAffine()
140 operand_shape.dims(operand_shape.dimensions_count() - 1) == in FuseMulOrDivParamsIntoPrecedingAffine()
143 } else if (operand_shape.dimensions_count() == 0 || in FuseMulOrDivParamsIntoPrecedingAffine()
144 operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { in FuseMulOrDivParamsIntoPrecedingAffine()
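
Both fusion routines above accept the constant operand only when its last dimension matches the preceding affine op's output depth, or when it is a scalar / size-1 broadcast. A sketch of that admissibility check on plain dimension vectors (not the TOCO Shape class):

    #include <vector>

    bool OperandFusableIntoAffineSketch(const std::vector<int>& operand_dims,
                                        int output_depth) {
      if (!operand_dims.empty() && operand_dims.back() == output_depth) {
        return true;  // per-channel parameters
      }
      if (operand_dims.empty() || operand_dims.back() == 1) {
        return true;  // scalar / broadcastable parameters
      }
      return false;   // cannot be folded into the preceding affine op
    }
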
/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/
ir_array.cc
268 const Shape& operand_shape, absl::Span<const int64> starts, in SourceIndexOfSlice() argument
282 return Index(source_multi_index, operand_shape, index_type_); in SourceIndexOfSlice()
286 const Shape& shape, const Shape& operand_shape, in SourceIndexOfTranspose() argument
291 if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && in SourceIndexOfTranspose()
293 ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { in SourceIndexOfTranspose()
294 return Index(operand_multidim_index, linear(), operand_shape, index_type_); in SourceIndexOfTranspose()
297 return Index(operand_multidim_index, operand_shape, index_type_); in SourceIndexOfTranspose()
301 const Shape& shape, const Shape& operand_shape, in SourceIndexOfBitcast() argument
303 CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); in SourceIndexOfBitcast()
308 if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) { in SourceIndexOfBitcast()
[all …]
ir_array.h
133 Index SourceIndexOfSlice(const Shape& operand_shape,
141 const Shape& shape, const Shape& operand_shape,
146 Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
151 Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
/external/tensorflow/tensorflow/compiler/xla/client/
xla_builder.cc
568 TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); in AddBroadcastSequence()
570 CHECK(ShapeUtil::IsScalar(*operand_shape) || in AddBroadcastSequence()
571 operand_shape->rank() == output_shape.rank()); in AddBroadcastSequence()
573 ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type()); in AddBroadcastSequence()
576 if (ShapeUtil::IsScalar(*operand_shape)) { in AddBroadcastSequence()
583 for (int i = 0; i < operand_shape->rank(); i++) { in AddBroadcastSequence()
584 if (operand_shape->dimensions(i) == output_shape.dimensions(i)) { in AddBroadcastSequence()
586 reshaped_dimensions.push_back(operand_shape->dimensions(i)); in AddBroadcastSequence()
588 TF_RET_CHECK(operand_shape->dimensions(i) == 1) in AddBroadcastSequence()
591 << *operand_shape << "; output_shape: " << output_shape; in AddBroadcastSequence()
[all …]
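
AddBroadcastSequence keeps operand dimensions that already match the output and requires every other operand dimension to be 1, so a reshape can drop the degenerate dimensions before the broadcast. A sketch of collecting those kept sizes (hypothetical helper; the real code also records the broadcast dimensions and returns detailed errors):

    #include <cstddef>
    #include <cstdint>
    #include <optional>
    #include <vector>

    std::optional<std::vector<int64_t>> ReshapedDimensionsSketch(
        const std::vector<int64_t>& operand_shape,
        const std::vector<int64_t>& output_shape) {
      if (operand_shape.size() != output_shape.size()) return std::nullopt;
      std::vector<int64_t> reshaped_dimensions;
      for (std::size_t i = 0; i < operand_shape.size(); ++i) {
        if (operand_shape[i] == output_shape[i]) {
          reshaped_dimensions.push_back(operand_shape[i]);
        } else if (operand_shape[i] != 1) {
          return std::nullopt;  // not broadcast-compatible
        }
      }
      return reshaped_dimensions;
    }
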
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
aggregate_ops.cc
64 xla::Shape operand_shape; in Compile() local
66 ctx, GetTensorListBufferShape(ctx->Input(i), &operand_shape)); in Compile()
68 ctx, sum_shape.dimensions() == operand_shape.dimensions(), in Compile()
72 "Found: ", operand_shape.DebugString())); in Compile()
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
materialize_broadcasts.cc
47 ArrayRef<int64_t> operand_shape = operand_type.getShape(); in matchAndRewrite() local
54 rewriter.getI64TensorAttr(operand_shape)); in matchAndRewrite()
62 rewriter.getI64TensorAttr(operand_shape)); in matchAndRewrite()
lhlo_legalize_to_parallel_loops.cc
109 auto operand_shape = operand.getType().template cast<MemRefType>().getShape(); in MapWindowIvsToInput() local
125 GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b); in MapWindowIvsToInput()
239 auto operand_shape = operand.getType().cast<MemRefType>().getShape(); in CreateReduceOpInNestedParallelLoops() local
240 for (auto dim : llvm::enumerate(operand_shape)) { in CreateReduceOpInNestedParallelLoops()
279 indices.reserve(operand_shape.size()); in CreateReduceOpInNestedParallelLoops()
284 for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) { in CreateReduceOpInNestedParallelLoops()
legalize_to_linalg.cc
462 auto operand_shape = operand_type.getShape(); in getIndexingMaps() local
470 bool expansion_needed = operand_shape[broadcastDim.index()] == 1 && in getIndexingMaps()
593 auto operand_shape = operand_type.getShape(); in InsertReshapeIfNecessary() local
617 operand_shape[index] == 1 && result_shape[dim] != 1; in InsertReshapeIfNecessary()
621 new_shape.push_back(operand_shape[index]); in InsertReshapeIfNecessary()
644 if (new_shape.size() < operand_shape.size()) { in InsertReshapeIfNecessary()
669 auto operand_shape = operand_type.getShape(); in getIndexingMaps() local
676 operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1; in getIndexingMaps()
928 auto operand_shape = in matchAndRewrite() local
930 if (!operand_shape || !operand_shape.hasRank()) { in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/compiler/xla/client/lib/
pooling.cc
79 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); in ComputeSums()
139 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); in MaxPool()
140 PrimitiveType dtype = operand_shape.element_type(); in MaxPool()
155 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); in AvgPool()
156 PrimitiveType dtype = operand_shape.element_type(); in AvgPool()
158 std::vector<int64> input_size(operand_shape.dimensions().begin(), in AvgPool()
159 operand_shape.dimensions().end()); in AvgPool()
pooling_test.cc
39 Shape operand_shape = b->GetShape(input).ValueOrDie(); in MakeGeneralPadding() local
40 std::vector<int64> input_size(operand_shape.dimensions().begin(), in MakeGeneralPadding()
41 operand_shape.dimensions().end()); in MakeGeneralPadding()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
cpu_layout_assignment.cc
143 Shape operand_shape( in AddBackendConstraints() local
146 operand_shape, instruction, operand_no)); in AddBackendConstraints()
ir_emitter.cc
523 const Shape& operand_shape = operand->shape(); in HandleOutfeed() local
526 if (!operand_shape.IsTuple()) { in HandleOutfeed()
527 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); in HandleOutfeed()
530 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape)); in HandleOutfeed()
532 for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) { in HandleOutfeed()
534 ShapeUtil::GetTupleElementShape(operand_shape, i); in HandleOutfeed()
1088 const Shape& operand_shape = crs->operand(i)->shape(); in HandleAllReduceSingleReplica() local
1089 CHECK(operand_shape.IsArray()) in HandleAllReduceSingleReplica()
1091 operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); in HandleAllReduceSingleReplica()
1095 /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape)); in HandleAllReduceSingleReplica()
[all …]
