• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * GStreamer gstreamer-onnxclient
3  * Copyright (C) 2021 Collabora Ltd
4  *
5  * gstonnxclient.cpp
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Library General Public
9  * License as published by the Free Software Foundation; either
10  * version 2 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Library General Public License for more details.
16  *
17  * You should have received a copy of the GNU Library General Public
18  * License along with this library; if not, write to the
19  * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
20  * Boston, MA 02110-1301, USA.
21  */
22 
23 #include "gstonnxclient.h"
24 #include <providers/cpu/cpu_provider_factory.h>
25 #ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
26 #include <providers/cuda/cuda_provider_factory.h>
27 #endif
28 #include <exception>
29 #include <fstream>
30 #include <iostream>
31 #include <limits>
32 #include <numeric>
33 #include <cmath>
34 #include <sstream>
35 
36 namespace GstOnnxNamespace
37 {
/* Stream a vector as "[a, b, c]" — used to log tensor shapes for debug. */
template <typename T>
std::ostream & operator<< (std::ostream & os, const std::vector<T> & vec)
{
  os << "[";
  const char *separator = "";
  for (const auto & item : vec) {
    os << separator << item;
    separator = ", ";
  }
  os << "]";

  return os;
}
54 
GstMlOutputNodeInfo(void)55 GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index
56   (GST_ML_NODE_INDEX_DISABLED),
57   type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
58 {
59 }
60 
GstOnnxClient()61 GstOnnxClient::GstOnnxClient ():session (nullptr),
62       width (0),
63       height (0),
64       channels (0),
65       dest (nullptr),
66       m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
67       inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC),
68       fixedInputImageSize (true)
69 {
70     for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i)
71       outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
72 }
73 
~GstOnnxClient()74 GstOnnxClient::~GstOnnxClient ()
75 {
76     delete session;
77     delete[]dest;
78 }
79 
getEnv(void)80 Ort::Env & GstOnnxClient::getEnv (void)
81 {
82     static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
83         "GstOnnxNamespace");
84 
85     return env;
86 }
87 
getWidth(void)88 int32_t GstOnnxClient::getWidth (void)
89 {
90     return width;
91 }
92 
getHeight(void)93 int32_t GstOnnxClient::getHeight (void)
94 {
95     return height;
96 }
97 
isFixedInputImageSize(void)98 bool GstOnnxClient::isFixedInputImageSize (void)
99 {
100     return fixedInputImageSize;
101 }
102 
getOutputNodeName(GstMlOutputNodeFunction nodeType)103 std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
104 {
105     switch (nodeType) {
106       case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION:
107         return "detection";
108         break;
109       case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX:
110         return "bounding box";
111         break;
112       case GST_ML_OUTPUT_NODE_FUNCTION_SCORE:
113         return "score";
114         break;
115       case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
116         return "label";
117         break;
118     };
119 
120     return "";
121 }
122 
setInputImageFormat(GstMlModelInputImageFormat format)123 void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
124 {
125     inputImageFormat = format;
126 }
127 
getInputImageFormat(void)128 GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
129 {
130     return inputImageFormat;
131 }
132 
getOutputNodeNames(void)133 std::vector < const char *>GstOnnxClient::getOutputNodeNames (void)
134 {
135     return outputNames;
136 }
137 
setOutputNodeIndex(GstMlOutputNodeFunction node,gint index)138 void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
139       gint index)
140 {
141     g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
142     outputNodeInfo[node].index = index;
143     if (index != GST_ML_NODE_INDEX_DISABLED)
144       outputNodeIndexToFunction[index] = node;
145 }
146 
getOutputNodeIndex(GstMlOutputNodeFunction node)147 gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
148 {
149     return outputNodeInfo[node].index;
150 }
151 
setOutputNodeType(GstMlOutputNodeFunction node,ONNXTensorElementDataType type)152 void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
153       ONNXTensorElementDataType type)
154 {
155     outputNodeInfo[node].type = type;
156 }
157 
158 ONNXTensorElementDataType
getOutputNodeType(GstMlOutputNodeFunction node)159       GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
160 {
161     return outputNodeInfo[node].type;
162 }
163 
hasSession(void)164 bool GstOnnxClient::hasSession (void)
165 {
166     return session != nullptr;
167 }
168 
/**
 * createSession:
 * @modelFile: path to the .onnx model file
 * @optim: requested graph optimization level
 * @provider: requested execution provider (CPU or CUDA)
 *
 * Create the onnxruntime inference session and cache the model's input
 * geometry and per-output metadata.  Idempotent: returns TRUE at once if
 * a session already exists.
 *
 * Returns: TRUE on success; FALSE if the CUDA provider was requested but
 * not compiled in.  NOTE(review): Ort::Session construction can throw
 * (e.g. bad model path) and the exception propagates — confirm callers
 * handle it.
 */
bool GstOnnxClient::createSession (std::string modelFile,
    GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
{
  if (session)
    return true;

  /* Map the GStreamer property enum onto onnxruntime's optimization
   * level; unknown values fall back to EXTENDED. */
  GraphOptimizationLevel onnx_optim;
  switch (optim) {
    case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
      onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
      break;
    case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC:
      onnx_optim = GraphOptimizationLevel::ORT_ENABLE_BASIC;
      break;
    case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED:
      onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
      break;
    case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL:
      onnx_optim = GraphOptimizationLevel::ORT_ENABLE_ALL;
      break;
    default:
      onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
      break;
  };

  Ort::SessionOptions sessionOptions;
  // for debugging
  //sessionOptions.SetIntraOpNumThreads (1);
  sessionOptions.SetGraphOptimizationLevel (onnx_optim);
  m_provider = provider;
  switch (m_provider) {
    case GST_ONNX_EXECUTION_PROVIDER_CUDA:
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
      /* Attach the CUDA execution provider on device 0. */
      Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
          (sessionOptions, 0));
#else
      /* CUDA requested but support not built in. */
      return false;
#endif
      break;
    default:
      /* CPU provider is the onnxruntime default; nothing to append. */
      break;

  };
  session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
  /* Cache input geometry from input 0.  Assumes a rank-4 image input,
   * NHWC or NCHW depending on inputImageFormat — TODO confirm the model
   * always provides 4 dims (no bounds check on inputDims). */
  auto inputTypeInfo = session->GetInputTypeInfo (0);
  std::vector < int64_t > inputDims =
      inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
  if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
    height = inputDims[1];
    width = inputDims[2];
    channels = inputDims[3];
  } else {
    channels = inputDims[1];
    height = inputDims[2];
    width = inputDims[3];
  }

  /* Dynamic-size models report non-positive spatial dims. */
  fixedInputImageSize = width > 0 && height > 0;
  GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());

  Ort::AllocatorWithDefaultOptions allocator;
  /* NOTE(review): GetInputName()/GetOutputName() look like the legacy ORT
   * API returning allocator-owned strings that are never freed here —
   * verify against the onnxruntime version in use. */
  GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator));

  for (size_t i = 0; i < session->GetOutputCount (); ++i) {
    auto output_name = session->GetOutputName (i, allocator);
    outputNames.push_back (output_name);
    auto type_info = session->GetOutputTypeInfo (i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();

    /* Record each output's element type under the function currently
     * mapped to that index. */
    if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
      auto function = outputNodeIndexToFunction[i];
      outputNodeInfo[function].type = tensor_info.GetElementType ();
    }
  }

  return true;
}
246 
run(uint8_t * img_data,GstVideoMeta * vmeta,std::string labelPath,float scoreThreshold)247 std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data,
248       GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold)
249 {
250     auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS);
251     return (type ==
252         ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
253           doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
254             : doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
255 }
256 
parseDimensions(GstVideoMeta * vmeta)257 void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
258 {
259     int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
260     int32_t newHeight = fixedInputImageSize ? height : vmeta->height;
261 
262     if (!dest || width * height < newWidth * newHeight) {
263       delete[] dest;
264       dest = new uint8_t[newWidth * newHeight * channels];
265     }
266     width = newWidth;
267     height = newHeight;
268 }
269 
/**
 * doRun:
 * @img_data: packed source frame pixels (may be NULL)
 * @vmeta: video meta describing the frame (format, stride, dimensions)
 * @labelPath: path to the class-label file (loaded lazily, once)
 * @scoreThreshold: minimum detection score for a box to be kept
 *
 * Worker behind run(): repacks the frame into the model's input layout,
 * executes inference and decodes the output tensors into bounding boxes.
 * T is the element type of the class-label output tensor (float or int).
 *
 * Returns: boxes in pixel coordinates; empty if @img_data is NULL.
 */
template < typename T > std::vector < GstMlBoundingBox >
    GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta,
    std::string labelPath, float scoreThreshold)
{
  std::vector < GstMlBoundingBox > boundingBoxes;
  if (!img_data)
    return boundingBoxes;

  /* Refresh width/height from the frame and (re)allocate 'dest'. */
  parseDimensions (vmeta);

  Ort::AllocatorWithDefaultOptions allocator;
  /* NOTE(review): looks like the legacy ORT API — the returned name is
   * allocator-owned and never freed here (leaks per call); verify against
   * the onnxruntime version in use. */
  auto inputName = session->GetInputName (0, allocator);
  auto inputTypeInfo = session->GetInputTypeInfo (0);
  std::vector < int64_t > inputDims =
      inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
  /* Batch of one, with the actual frame dimensions patched in (HWC:
   * dims are N,H,W,C; otherwise N,C,H,W). */
  inputDims[0] = 1;
  if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
    inputDims[1] = height;
    inputDims[2] = width;
  } else {
    inputDims[2] = height;
    inputDims[3] = width;
  }

  std::ostringstream buffer;
  buffer << inputDims;
  GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());

  /* copy video frame: srcPtr[k] walks component k of the packed frame in
   * R,G,B order; the per-format offsets below pick each component's byte
   * position within a pixel. */
  uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
  uint32_t srcSamplesPerPixel = 3;
  switch (vmeta->format) {
    case GST_VIDEO_FORMAT_RGBA:
      srcSamplesPerPixel = 4;
      break;
    case GST_VIDEO_FORMAT_BGRA:
      srcSamplesPerPixel = 4;
      srcPtr[0] = img_data + 2;
      srcPtr[1] = img_data + 1;
      srcPtr[2] = img_data + 0;
      break;
    case GST_VIDEO_FORMAT_ARGB:
      srcSamplesPerPixel = 4;
      srcPtr[0] = img_data + 1;
      srcPtr[1] = img_data + 2;
      srcPtr[2] = img_data + 3;
      break;
    case GST_VIDEO_FORMAT_ABGR:
      srcSamplesPerPixel = 4;
      srcPtr[0] = img_data + 3;
      srcPtr[1] = img_data + 2;
      srcPtr[2] = img_data + 1;
      break;
    case GST_VIDEO_FORMAT_BGR:
      srcPtr[0] = img_data + 2;
      srcPtr[1] = img_data + 1;
      srcPtr[2] = img_data + 0;
      break;
    default:
      /* RGB: the identity mapping set up above applies. */
      break;
  }
  size_t destIndex = 0;
  uint32_t stride = vmeta->stride[0];
  if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
    /* Interleaved HWC destination: channels packed per pixel. */
    for (int32_t j = 0; j < height; ++j) {
      for (int32_t i = 0; i < width; ++i) {
        for (int32_t k = 0; k < channels; ++k) {
          dest[destIndex++] = *srcPtr[k];
          srcPtr[k] += srcSamplesPerPixel;
        }
      }
      // correct for stride
      for (uint32_t k = 0; k < 3; ++k)
        srcPtr[k] += stride - srcSamplesPerPixel * width;
    }
  } else {
    /* Planar CHW destination: one width*height plane per channel. */
    size_t frameSize = width * height;
    uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
    for (int32_t j = 0; j < height; ++j) {
      for (int32_t i = 0; i < width; ++i) {
        for (int32_t k = 0; k < channels; ++k) {
          destPtr[k][destIndex] = *srcPtr[k];
          srcPtr[k] += srcSamplesPerPixel;
        }
        destIndex++;
      }
      // correct for stride
      for (uint32_t k = 0; k < 3; ++k)
        srcPtr[k] += stride - srcSamplesPerPixel * width;
    }
  }

  /* Wrap 'dest' as a uint8 tensor (no copy) and run the model. */
  const size_t inputTensorSize = width * height * channels;
  auto memoryInfo =
      Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
      OrtMemType::OrtMemTypeDefault);
  std::vector < Ort::Value > inputTensors;
  inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
          dest, inputTensorSize, inputDims.data (), inputDims.size ()));
  std::vector < const char *>inputNames { inputName };

  std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr},
      inputNames.data (),
      inputTensors.data (), 1, outputNames.data (), outputNames.size ());

  /* Decode outputs.  The indexing below implies TF object-detection style
   * tensors: a detection count, normalized [y0,x0,y1,x1] boxes, scores
   * and (optionally) class indices. */
  auto numDetections =
      modelOutput[getOutputNodeIndex
      (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
  auto bboxes =
      modelOutput[getOutputNodeIndex
      (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
  auto scores =
      modelOutput[getOutputNodeIndex
      (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
  T *labelIndex = nullptr;
  if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
      GST_ML_NODE_INDEX_DISABLED) {
    labelIndex =
        modelOutput[getOutputNodeIndex
        (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
  }
  /* Lazily load labels the first time a label path is supplied. */
  if (labels.empty () && !labelPath.empty ())
    labels = ReadLabels (labelPath);

  for (int i = 0; i < numDetections[0]; ++i) {
    if (scores[i] > scoreThreshold) {
      std::string label = "";

      /* Class indices appear to be 1-based here; there is no bounds
       * check against labels.size() — TODO confirm with the label file. */
      if (labelIndex && !labels.empty ())
        label = labels[labelIndex[i] - 1];
      auto score = scores[i];
      /* Convert normalized corner coords to pixel x/y/width/height. */
      auto y0 = bboxes[i * 4] * height;
      auto x0 = bboxes[i * 4 + 1] * width;
      auto bheight = bboxes[i * 4 + 2] * height - y0;
      auto bwidth = bboxes[i * 4 + 3] * width - x0;
      boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
              bheight));
    }
  }
  return boundingBoxes;
}
411 
412 std::vector < std::string >
ReadLabels(const std::string & labelsFile)413     GstOnnxClient::ReadLabels (const std::string & labelsFile)
414 {
415     std::vector < std::string > labels;
416     std::string line;
417     std::ifstream fp (labelsFile);
418     while (std::getline (fp, line))
419       labels.push_back (line);
420 
421     return labels;
422   }
423 }
424