1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "DetectionPostProcessLayer.hpp"
7
8 #include "LayerCloneBase.hpp"
9
10 #include <armnn/TypesUtils.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
12 #include <backendsCommon/WorkloadData.hpp>
13 #include <backendsCommon/WorkloadFactory.hpp>
14
15 namespace armnn
16 {
17
DetectionPostProcessLayer(const DetectionPostProcessDescriptor & param,const char * name)18 DetectionPostProcessLayer::DetectionPostProcessLayer(const DetectionPostProcessDescriptor& param, const char* name)
19 : LayerWithParameters(2, 4, LayerType::DetectionPostProcess, param, name)
20 {
21 }
22
CreateWorkload(const armnn::IWorkloadFactory & factory) const23 std::unique_ptr<IWorkload> DetectionPostProcessLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const
24 {
25 DetectionPostProcessQueueDescriptor descriptor;
26 descriptor.m_Anchors = m_Anchors.get();
27 SetAdditionalInfo(descriptor);
28
29 return factory.CreateDetectionPostProcess(descriptor, PrepInfoAndDesc(descriptor));
30 }
31
Clone(Graph & graph) const32 DetectionPostProcessLayer* DetectionPostProcessLayer::Clone(Graph& graph) const
33 {
34 auto layer = CloneBase<DetectionPostProcessLayer>(graph, m_Param, GetName());
35 layer->m_Anchors = m_Anchors ? std::make_unique<ScopedCpuTensorHandle>(*m_Anchors) : nullptr;
36 return std::move(layer);
37 }
38
ValidateTensorShapesFromInputs()39 void DetectionPostProcessLayer::ValidateTensorShapesFromInputs()
40 {
41 VerifyLayerConnections(2, CHECK_LOCATION());
42
43 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
44
45 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
46
47 // on this level constant data should not be released.
48 ARMNN_ASSERT_MSG(m_Anchors != nullptr, "DetectionPostProcessLayer: Anchors data should not be null.");
49
50 ARMNN_ASSERT_MSG(GetNumOutputSlots() == 4, "DetectionPostProcessLayer: The layer should return 4 outputs.");
51
52 unsigned int detectedBoxes = m_Param.m_MaxDetections * m_Param.m_MaxClassesPerDetection;
53
54 const TensorShape& inferredDetectionBoxes = TensorShape({ 1, detectedBoxes, 4 });
55 const TensorShape& inferredDetectionScores = TensorShape({ 1, detectedBoxes });
56 const TensorShape& inferredNumberDetections = TensorShape({ 1 });
57
58 ValidateAndCopyShape(outputShape, inferredDetectionBoxes, m_ShapeInferenceMethod, "DetectionPostProcessLayer");
59
60 ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
61 inferredDetectionScores,
62 m_ShapeInferenceMethod,
63 "DetectionPostProcessLayer", 1);
64
65 ValidateAndCopyShape(GetOutputSlot(2).GetTensorInfo().GetShape(),
66 inferredDetectionScores,
67 m_ShapeInferenceMethod,
68 "DetectionPostProcessLayer", 2);
69
70 ValidateAndCopyShape(GetOutputSlot(3).GetTensorInfo().GetShape(),
71 inferredNumberDetections,
72 m_ShapeInferenceMethod,
73 "DetectionPostProcessLayer", 3);
74 }
75
GetConstantTensorsByRef()76 Layer::ConstantTensors DetectionPostProcessLayer::GetConstantTensorsByRef()
77 {
78 return { m_Anchors };
79 }
80
Accept(ILayerVisitor & visitor) const81 void DetectionPostProcessLayer::Accept(ILayerVisitor& visitor) const
82 {
83 ConstTensor anchorTensor(m_Anchors->GetTensorInfo(), m_Anchors->GetConstTensor<void>());
84 visitor.VisitDetectionPostProcessLayer(this, GetParameters(), anchorTensor, GetName());
85 }
86
87 } // namespace armnn
88