1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefDetectionPostProcessWorkload.hpp"
7
8 #include "Decoders.hpp"
9 #include "DetectionPostProcess.hpp"
10 #include "Profiling.hpp"
11 #include "RefWorkloadUtils.hpp"
12
13 namespace armnn
14 {
15
RefDetectionPostProcessWorkload(const DetectionPostProcessQueueDescriptor & descriptor,const WorkloadInfo & info)16 RefDetectionPostProcessWorkload::RefDetectionPostProcessWorkload(
17 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info)
18 : BaseWorkload<DetectionPostProcessQueueDescriptor>(descriptor, info),
19 m_Anchors(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Anchors))) {}
20
Execute() const21 void RefDetectionPostProcessWorkload::Execute() const
22 {
23 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefDetectionPostProcessWorkload_Execute");
24
25 const TensorInfo& boxEncodingsInfo = GetTensorInfo(m_Data.m_Inputs[0]);
26 const TensorInfo& scoresInfo = GetTensorInfo(m_Data.m_Inputs[1]);
27 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
28
29 const TensorInfo& detectionBoxesInfo = GetTensorInfo(m_Data.m_Outputs[0]);
30 const TensorInfo& detectionClassesInfo = GetTensorInfo(m_Data.m_Outputs[1]);
31 const TensorInfo& detectionScoresInfo = GetTensorInfo(m_Data.m_Outputs[2]);
32 const TensorInfo& numDetectionsInfo = GetTensorInfo(m_Data.m_Outputs[3]);
33
34 auto boxEncodings = MakeDecoder<float>(boxEncodingsInfo, m_Data.m_Inputs[0]->Map());
35 auto scores = MakeDecoder<float>(scoresInfo, m_Data.m_Inputs[1]->Map());
36 auto anchors = MakeDecoder<float>(anchorsInfo, m_Anchors->Map(false));
37
38 float* detectionBoxes = GetOutputTensorData<float>(0, m_Data);
39 float* detectionClasses = GetOutputTensorData<float>(1, m_Data);
40 float* detectionScores = GetOutputTensorData<float>(2, m_Data);
41 float* numDetections = GetOutputTensorData<float>(3, m_Data);
42
43 DetectionPostProcess(boxEncodingsInfo, scoresInfo, anchorsInfo,
44 detectionBoxesInfo, detectionClassesInfo,
45 detectionScoresInfo, numDetectionsInfo, m_Data.m_Parameters,
46 *boxEncodings, *scores, *anchors, detectionBoxes,
47 detectionClasses, detectionScores, numDetections);
48 }
49
50 } //namespace armnn
51