• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ObjectDetectionPipeline.hpp"
7 #include "ImageUtils.hpp"
8 
9 namespace od
10 {
11 
ObjDetectionPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,std::unique_ptr<IDetectionResultDecoder> decoder)12 ObjDetectionPipeline::ObjDetectionPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
13                                            std::unique_ptr<IDetectionResultDecoder> decoder) :
14     m_executor(std::move(executor)),
15     m_decoder(std::move(decoder)){}
16 
Inference(const cv::Mat & processed,common::InferenceResults<float> & result)17 void od::ObjDetectionPipeline::Inference(const cv::Mat& processed, common::InferenceResults<float>& result)
18 {
19     m_executor->Run(processed.data, processed.total() * processed.elemSize(), result);
20 }
21 
PostProcessing(common::InferenceResults<float> & inferenceResult,const std::function<void (DetectedObjects)> & callback)22 void ObjDetectionPipeline::PostProcessing(common::InferenceResults<float>& inferenceResult,
23         const std::function<void (DetectedObjects)>& callback)
24 {
25     DetectedObjects detections = m_decoder->Decode(inferenceResult, m_inputImageSize,
26                                            m_executor->GetImageAspectRatio(), {});
27     if (callback)
28     {
29         callback(detections);
30     }
31 }
32 
PreProcessing(const cv::Mat & frame,cv::Mat & processed)33 void ObjDetectionPipeline::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
34 {
35     m_inputImageSize.m_Height = frame.rows;
36     m_inputImageSize.m_Width = frame.cols;
37     ResizeWithPad(frame, processed, m_processedFrame, m_executor->GetImageAspectRatio());
38 }
39 
MobileNetSSDv1(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,float objectThreshold)40 MobileNetSSDv1::MobileNetSSDv1(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
41                                float objectThreshold) :
42     ObjDetectionPipeline(std::move(executor),
43                          std::make_unique<SSDResultDecoder>(objectThreshold))
44 {}
45 
PreProcessing(const cv::Mat & frame,cv::Mat & processed)46 void MobileNetSSDv1::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
47 {
48     ObjDetectionPipeline::PreProcessing(frame, processed);
49     if (m_executor->GetInputDataType() == armnn::DataType::Float32)
50     {
51         // [0, 255] => [-1.0, 1.0]
52         processed.convertTo(processed, CV_32FC3, 1 / 127.5, -1);
53     }
54 }
YoloV3Tiny(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,float NMSThreshold,float ClsThreshold,float ObjectThreshold)55 YoloV3Tiny::YoloV3Tiny(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
56                        float NMSThreshold, float ClsThreshold, float ObjectThreshold) :
57     ObjDetectionPipeline(std::move(executor),
58                          std::move(std::make_unique<YoloResultDecoder>(NMSThreshold,
59                                                                        ClsThreshold,
60                                                                        ObjectThreshold)))
61 {}
62 
PreProcessing(const cv::Mat & frame,cv::Mat & processed)63 void YoloV3Tiny::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
64 {
65     ObjDetectionPipeline::PreProcessing(frame, processed);
66     if (m_executor->GetInputDataType() == armnn::DataType::Float32)
67     {
68         processed.convertTo(processed, CV_32FC3);
69     }
70 }
71 
CreatePipeline(common::PipelineOptions & config)72 IPipelinePtr CreatePipeline(common::PipelineOptions& config)
73 {
74     auto executor = std::make_unique<common::ArmnnNetworkExecutor<float>>(config.m_ModelFilePath,
75                                                                           config.m_backends,
76                                                                           config.m_ProfilingEnabled);
77     if (config.m_ModelName == "SSD_MOBILE")
78     {
79         float detectionThreshold = 0.5;
80 
81         return std::make_unique<od::MobileNetSSDv1>(std::move(executor),
82                                                     detectionThreshold
83         );
84     }
85     else if (config.m_ModelName == "YOLO_V3_TINY")
86     {
87         float NMSThreshold = 0.6f;
88         float ClsThreshold = 0.6f;
89         float ObjectThreshold = 0.6f;
90         return std::make_unique<od::YoloV3Tiny>(std::move(executor),
91                                                 NMSThreshold,
92                                                 ClsThreshold,
93                                                 ObjectThreshold
94         );
95     }
96     else
97     {
98         throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " supplied by user.");
99     }
100 
101 }
102 }// namespace od
103