• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "SSDResultDecoder.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
12 namespace od
13 {
14 
Decode(const common::InferenceResults<float> & networkResults,const common::Size & outputFrameSize,const common::Size & resizedFrameSize,const std::vector<std::string> & labels)15 DetectedObjects SSDResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
16     const common::Size& outputFrameSize,
17     const common::Size& resizedFrameSize,
18     const std::vector<std::string>& labels)
19 {
20     // SSD network outputs 4 tensors: bounding boxes, labels, probabilities, number of detections.
21     if (networkResults.size() != 4)
22     {
23         throw std::runtime_error("Number of outputs from SSD model doesn't equal 4");
24     }
25 
26     DetectedObjects detectedObjects;
27     const int numDetections = static_cast<int>(std::lround(networkResults[3][0]));
28 
29     double longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30     double longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31     const double resizeFactor = longEdgeOutput/longEdgeInput;
32 
33     for (int i=0; i<numDetections; ++i)
34     {
35         if (networkResults[2][i] > m_objectThreshold)
36         {
37             DetectedObject detectedObject;
38             detectedObject.SetScore(networkResults[2][i]);
39             auto classId = std::lround(networkResults[1][i]);
40 
41             if (classId < labels.size())
42             {
43                 detectedObject.SetLabel(labels[classId]);
44             }
45             else
46             {
47                 detectedObject.SetLabel(std::to_string(classId));
48             }
49             detectedObject.SetId(classId);
50 
51             // Convert SSD bbox outputs (ratios of image size) to pixel values.
52             double topLeftY = networkResults[0][i*4 + 0] * resizedFrameSize.m_Height;
53             double topLeftX = networkResults[0][i*4 + 1] * resizedFrameSize.m_Width;
54             double botRightY = networkResults[0][i*4 + 2] * resizedFrameSize.m_Height;
55             double botRightX = networkResults[0][i*4 + 3] * resizedFrameSize.m_Width;
56 
57             // Scale the coordinates to output frame size.
58             topLeftY *= resizeFactor;
59             topLeftX *= resizeFactor;
60             botRightY *= resizeFactor;
61             botRightX *= resizeFactor;
62 
63             assert(botRightX > topLeftX);
64             assert(botRightY > topLeftY);
65 
66             // Internal BoundingBox stores box top left x,y and width, height.
67             detectedObject.SetBoundingBox({static_cast<int>(std::round(topLeftX)),
68                                            static_cast<int>(std::round(topLeftY)),
69                                            static_cast<unsigned int>(botRightX - topLeftX),
70                                            static_cast<unsigned int>(botRightY - topLeftY)});
71 
72             detectedObjects.emplace_back(detectedObject);
73         }
74     }
75     return detectedObjects;
76 }
77 
SSDResultDecoder(float ObjectThreshold)78 SSDResultDecoder::SSDResultDecoder(float ObjectThreshold) : m_objectThreshold(ObjectThreshold) {}
79 
80 }// namespace od