Home
last modified time | relevance | path

Searched refs:operand_shape (Results 1 – 25 of 27) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/xla/service/
Dshape_inference.cc407 const Shape& operand_shape, PrimitiveType new_element_type) { in InferConvertShape() argument
408 auto old_element_type = operand_shape.element_type(); in InferConvertShape()
413 ShapeUtil::HumanString(operand_shape), in InferConvertShape()
416 if (!operand_shape.IsArray() || in InferConvertShape()
423 ShapeUtil::HumanString(operand_shape), in InferConvertShape()
427 return ShapeUtil::ChangeElementType(operand_shape, new_element_type); in InferConvertShape()
431 const Shape& operand_shape, PrimitiveType new_element_type) { in InferBitcastConvertShape() argument
432 auto old_element_type = operand_shape.element_type(); in InferBitcastConvertShape()
436 ShapeUtil::HumanString(operand_shape), in InferBitcastConvertShape()
439 if (!operand_shape.IsArray() || in InferBitcastConvertShape()
[all …]
Dbatchnorm_expander.cc206 const Shape operand_shape = operand->shape(); in HandleBatchNormTraining() local
207 PrimitiveType ptype = operand_shape.element_type(); in HandleBatchNormTraining()
221 operand_shape, in HandleBatchNormTraining()
225 for (int64 i = 0; i < operand_shape.rank(); ++i) { in HandleBatchNormTraining()
235 HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); in HandleBatchNormTraining()
238 HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); in HandleBatchNormTraining()
245 add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); in HandleBatchNormTraining()
260 HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); in HandleBatchNormTraining()
274 add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); in HandleBatchNormTraining()
278 add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); in HandleBatchNormTraining()
[all …]
Dshape_inference.h88 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
96 const Shape& operand_shape, const Shape& scale_shape,
101 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
159 const Shape& operand_shape, const Shape& init_value, const Window& window,
165 const Shape& operand_shape, const ProgramShape& select_shape,
171 static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
186 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
192 const Shape& operand_shape, const Shape& update_shape,
224 const Shape& operand_shape, const Shape& output_shape,
246 static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
[all …]
Dhlo_creation_utils.cc314 const Shape& operand_shape = operand->shape(); in CollapseFirstNDims() local
315 CHECK_GE(operand_shape.dimensions_size(), n); in CollapseFirstNDims()
318 new_shape_leading_bound *= operand_shape.dimensions(i); in CollapseFirstNDims()
322 new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1); in CollapseFirstNDims()
325 std::copy(operand_shape.dimensions().begin() + n, in CollapseFirstNDims()
326 operand_shape.dimensions().end(), in CollapseFirstNDims()
330 ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims); in CollapseFirstNDims()
339 const Shape& operand_shape = operand->shape(); in PrependDegenerateDims() local
340 new_shape_dims.reserve(n + operand_shape.dimensions_size()); in PrependDegenerateDims()
342 absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); in PrependDegenerateDims()
[all …]
Dhlo_verifier.cc429 const Shape& operand_shape = instruction.operands()[i]->shape(); in SameElementTypesForOperandsAndToApplyParameters() local
430 if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) { in SameElementTypesForOperandsAndToApplyParameters()
478 const Shape& operand_shape = broadcast->operand(0)->shape(); in HandleBroadcast() local
480 TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); in HandleBroadcast()
481 TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); in HandleBroadcast()
482 for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank(); in HandleBroadcast()
488 operand_shape.dimensions(operand_dimension))) in HandleBroadcast()
489 << broadcast->ToString() << " operand shape " << operand_shape; in HandleBroadcast()
496 const Shape& operand_shape = reshape->operand(0)->shape(); in HandleReshape() local
497 TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape)); in HandleReshape()
[all …]
Dindexed_array_analysis.cc345 absl::Span<const int64> operand_shape, in ComputeReshapePassthroughDimPairs() argument
373 FindSuffixWithProduct(operand_shape, result_subarray_size); in ComputeReshapePassthroughDimPairs()
382 << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; in ComputeReshapePassthroughDimPairs()
385 result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { in ComputeReshapePassthroughDimPairs()
401 VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" in ComputeReshapePassthroughDimPairs()
444 absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape, in FindSourcePositionForPassthroughResultDim() argument
447 << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") in FindSourcePositionForPassthroughResultDim()
451 std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, in FindSourcePositionForPassthroughResultDim()
452 operand_shape.end(), 1LL, std::multiplies<int64>()); in FindSourcePositionForPassthroughResultDim()
Dhlo_cost_analysis_test.cc645 Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); in TEST_F() local
648 auto operand = Parameter(&builder, 0, operand_shape, "operand"); in TEST_F()
670 Shape operand_shape = ShapeUtil::MakeShape(F32, {3, 3}); in TEST_F() local
674 auto operand = Parameter(&builder, 0, operand_shape, "operand"); in TEST_F()
Dhlo_evaluator.cc465 const Shape& operand_shape = operands[i]->shape(); in HandleConcatenate() local
466 CHECK(operand_shape.IsArray()); in HandleConcatenate()
470 ShapeUtil::GetDimension(operand_shape, concat_dim); in HandleConcatenate()
479 const Shape& operand_shape = operand->shape(); in HandleConcatenate() local
482 AsInt64Slice(operand_shape.dimensions()))); in HandleConcatenate()
484 ShapeUtil::GetDimension(operand_shape, concat_dim); in HandleConcatenate()
1083 const Shape& operand_shape = operand.shape(); in HandleGather() local
1109 std::min(operand_shape.dimensions(i) - output_dim_size, in HandleGather()
1115 DCHECK_LT(input_index[i], operand_shape.dimensions(i)); in HandleGather()
Dlayout_assignment.cc1022 Shape operand_shape = operand->shape(); in ChooseOperandLayoutFromOutputLayout() local
1023 *operand_shape.mutable_layout() = in ChooseOperandLayoutFromOutputLayout()
1024 LayoutUtil::GetDefaultLayoutForShape(operand_shape); in ChooseOperandLayoutFromOutputLayout()
1026 ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); in ChooseOperandLayoutFromOutputLayout()
1030 LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); in ChooseOperandLayoutFromOutputLayout()
Dhlo_evaluator_typed_visitor.h2321 const Shape& operand_shape = operand.shape(); in HandleScatter() local
2328 std::vector<int64> input_index(operand_shape.dimensions_size()); in HandleScatter()
2331 operand_shape.dimensions_size()); in HandleScatter()
2334 &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, in HandleScatter()
2337 scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, in HandleScatter()
2369 operand_shape.dimensions(i) - update_dim_size)) { in HandleScatter()
/external/tensorflow/tensorflow/compiler/xla/tests/
Ddynamic_ops_test.cc511 void RunR3Contiguous(std::vector<int32> operand_shape, int32 index, in RunR3Contiguous() argument
513 const int32 kSeq = operand_shape[0]; in RunR3Contiguous()
514 const int32 kBatch = operand_shape[1]; in RunR3Contiguous()
515 const int32 kDim = operand_shape[2]; in RunR3Contiguous()
667 std::vector<int32> operand_shape({4, 5, 2}); in XLA_TEST_F() local
668 RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1); in XLA_TEST_F()
673 std::vector<int32> operand_shape({4, 5, 2}); in XLA_TEST_F() local
674 RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1); in XLA_TEST_F()
679 std::vector<int32> operand_shape({4, 5, 2}); in XLA_TEST_F() local
680 RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2); in XLA_TEST_F()
[all …]
Dselect_and_scatter_test.cc42 std::vector<int64> operand_shape; member
73 auto operand_shape = GetParam().operand_shape; in XLA_TEST_P() local
74 Array<float> o(operand_shape); in XLA_TEST_P()
Dgather_operation_test.cc641 Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); in XLA_TEST_F() local
644 auto operand = Parameter(&builder, 0, operand_shape, "operand"); in XLA_TEST_F()
Dtest_utils.cc469 const Shape& operand_shape = use->operand(0)->shape(); in CreateLiteralForConstrainedUses() local
478 std::min(index_bound, operand_shape.dimensions(dim_in_operand)); in CreateLiteralForConstrainedUses()
Dreduce_test.cc939 Shape operand_shape = ShapeUtil::MakeShape(F32, {1}); in XLA_TEST_F() local
940 Reduce(Parameter(&builder, 0, operand_shape, "operand"), in XLA_TEST_F()
/external/tensorflow/tensorflow/lite/toco/graph_transformations/
Dfuse_binary_into_preceding_affine.cc45 const Shape& operand_shape = operand.shape(); in FuseAddOrSubParamsIntoPrecedingAffine() local
54 if (operand_shape.dimensions_count() >= 1 && in FuseAddOrSubParamsIntoPrecedingAffine()
55 operand_shape.dims(operand_shape.dimensions_count() - 1) == in FuseAddOrSubParamsIntoPrecedingAffine()
58 } else if (operand_shape.dimensions_count() == 0 || in FuseAddOrSubParamsIntoPrecedingAffine()
59 operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { in FuseAddOrSubParamsIntoPrecedingAffine()
114 const Shape& operand_shape = operand.shape(); in FuseMulOrDivParamsIntoPrecedingAffine() local
125 if (operand_shape.dimensions_count() >= 1 && in FuseMulOrDivParamsIntoPrecedingAffine()
126 operand_shape.dims(operand_shape.dimensions_count() - 1) == in FuseMulOrDivParamsIntoPrecedingAffine()
129 } else if (operand_shape.dimensions_count() == 0 || in FuseMulOrDivParamsIntoPrecedingAffine()
130 operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { in FuseMulOrDivParamsIntoPrecedingAffine()
Dfuse_binary_into_following_affine.cc202 const auto& operand_shape = in Run() local
204 for (const auto& dim : operand_shape.dims()) { in Run()
/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/
Dir_array.cc181 const Shape& operand_shape, absl::Span<const int64> starts, in SourceIndexOfSlice() argument
198 return Index(source_multi_index, operand_shape, index_type_); in SourceIndexOfSlice()
202 const Shape& shape, const Shape& operand_shape, in SourceIndexOfTranspose() argument
208 if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && in SourceIndexOfTranspose()
210 ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { in SourceIndexOfTranspose()
211 return Index(operand_multidim_index, linear(), operand_shape, index_type_); in SourceIndexOfTranspose()
218 const Shape& shape, const Shape& operand_shape, in SourceIndexOfBitcast() argument
220 CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); in SourceIndexOfBitcast()
224 if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) { in SourceIndexOfBitcast()
225 return SourceIndexOfReshape(shape, operand_shape, builder); in SourceIndexOfBitcast()
[all …]
Dir_array.h135 Index SourceIndexOfSlice(const Shape& operand_shape,
142 Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
148 Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
153 Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.cc218 Shape operand_shape(operand.shape()); in IsConstantVisitor() local
219 if (operand_shape.is_dynamic_dimension(dimension_number)) { in IsConstantVisitor()
435 TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); in AddBroadcastSequence()
437 CHECK(ShapeUtil::IsScalar(operand_shape) || in AddBroadcastSequence()
438 operand_shape.rank() == output_shape.rank()); in AddBroadcastSequence()
440 ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); in AddBroadcastSequence()
443 if (ShapeUtil::IsScalar(operand_shape)) { in AddBroadcastSequence()
450 for (int i = 0; i < operand_shape.rank(); i++) { in AddBroadcastSequence()
451 if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { in AddBroadcastSequence()
453 reshaped_dimensions.push_back(operand_shape.dimensions(i)); in AddBroadcastSequence()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Daggregate_ops.cc47 TensorShape operand_shape; in Compile() local
49 ctx, GetTensorListBufferShape(ctx->Input(i), &operand_shape)); in Compile()
51 ctx, sum_shape.dim_sizes() == operand_shape.dim_sizes(), in Compile()
55 "Found: ", operand_shape.DebugString())); in Compile()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dpooling.cc79 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); in ComputeSums()
138 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); in MaxPool()
139 PrimitiveType dtype = operand_shape.element_type(); in MaxPool()
154 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); in AvgPool()
155 PrimitiveType dtype = operand_shape.element_type(); in AvgPool()
157 std::vector<int64> input_size(operand_shape.dimensions().begin(), in AvgPool()
158 operand_shape.dimensions().end()); in AvgPool()
Dpooling_test.cc39 Shape operand_shape = b->GetShape(input).ValueOrDie(); in MakeGeneralPadding() local
40 std::vector<int64> input_size(operand_shape.dimensions().begin(), in MakeGeneralPadding()
41 operand_shape.dimensions().end()); in MakeGeneralPadding()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Dcpu_layout_assignment.cc143 Shape operand_shape( in AddBackendConstraints() local
146 operand_shape, instruction, operand_no)); in AddBackendConstraints()
Dir_emitter.cc482 const Shape& operand_shape = operand->shape(); in HandleOutfeed() local
485 if (!operand_shape.IsTuple()) { in HandleOutfeed()
486 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); in HandleOutfeed()
489 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape)); in HandleOutfeed()
491 for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) { in HandleOutfeed()
493 ShapeUtil::GetTupleElementShape(operand_shape, i); in HandleOutfeed()
1337 const Shape& operand_shape = crs->operand(i)->shape(); in HandleAllReduce() local
1338 CHECK(operand_shape.IsArray()) in HandleAllReduce()
1340 operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); in HandleAllReduce()
1344 /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); in HandleAllReduce()
[all …]

12