/*
 * GStreamer gstreamer-onnxclient
 * Copyright (C) 2021 Collabora Ltd
 *
 * gstonnxclient.cpp
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "gstonnxclient.h"
#include <providers/cpu/cpu_provider_factory.h>
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
#include <providers/cuda/cuda_provider_factory.h>
#endif
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <cmath>
#include <sstream>

namespace GstOnnxNamespace
{
  template < typename T >
      std::ostream & operator<< (std::ostream & os, const std::vector < T > &v)
  {
    os << "[";
    for (size_t i = 0; i < v.size (); ++i)
    {
      os << v[i];
      if (i != v.size () - 1)
      {
        os << ", ";
      }
    }
    os << "]";

    return os;
  }

  GstMlOutputNodeInfo::GstMlOutputNodeInfo (void):index
      (GST_ML_NODE_INDEX_DISABLED),
      type (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
  {
  }

  GstOnnxClient::GstOnnxClient ():session (nullptr),
      width (0),
      height (0),
      channels (0),
      dest (nullptr),
      m_provider (GST_ONNX_EXECUTION_PROVIDER_CPU),
      inputImageFormat (GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC),
      fixedInputImageSize (true)
  {
    for (size_t i = 0; i < GST_ML_OUTPUT_NODE_NUMBER_OF; ++i)
      outputNodeIndexToFunction[i] = (GstMlOutputNodeFunction) i;
  }

  GstOnnxClient::~GstOnnxClient ()
  {
    delete session;
    delete[] dest;
  }

  Ort::Env & GstOnnxClient::getEnv (void)
  {
    static Ort::Env env (OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
        "GstOnnxNamespace");

    return env;
  }

  int32_t GstOnnxClient::getWidth (void)
  {
    return width;
  }

  int32_t GstOnnxClient::getHeight (void)
  {
    return height;
  }

  bool GstOnnxClient::isFixedInputImageSize (void)
  {
    return fixedInputImageSize;
  }

  std::string GstOnnxClient::getOutputNodeName (GstMlOutputNodeFunction nodeType)
  {
    switch (nodeType) {
      case GST_ML_OUTPUT_NODE_FUNCTION_DETECTION:
        return "detection";
        break;
      case GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX:
        return "bounding box";
        break;
      case GST_ML_OUTPUT_NODE_FUNCTION_SCORE:
        return "score";
        break;
      case GST_ML_OUTPUT_NODE_FUNCTION_CLASS:
        return "label";
        break;
    };

    return "";
  }

  void GstOnnxClient::setInputImageFormat (GstMlModelInputImageFormat format)
  {
    inputImageFormat = format;
  }

  GstMlModelInputImageFormat GstOnnxClient::getInputImageFormat (void)
  {
    return inputImageFormat;
  }

  std::vector < const char *>GstOnnxClient::getOutputNodeNames (void)
  {
    return outputNames;
  }

  void GstOnnxClient::setOutputNodeIndex (GstMlOutputNodeFunction node,
      gint index)
  {
    g_assert (index < GST_ML_OUTPUT_NODE_NUMBER_OF);
    outputNodeInfo[node].index = index;
    if (index != GST_ML_NODE_INDEX_DISABLED)
      outputNodeIndexToFunction[index] = node;
  }

  gint GstOnnxClient::getOutputNodeIndex (GstMlOutputNodeFunction node)
  {
    return outputNodeInfo[node].index;
  }

  void GstOnnxClient::setOutputNodeType (GstMlOutputNodeFunction node,
      ONNXTensorElementDataType type)
  {
    outputNodeInfo[node].type = type;
  }

  ONNXTensorElementDataType
      GstOnnxClient::getOutputNodeType (GstMlOutputNodeFunction node)
  {
    return outputNodeInfo[node].type;
  }

  bool GstOnnxClient::hasSession (void)
  {
    return session != nullptr;
  }

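  /* createSession () builds the ONNX Runtime session used for inference:
   * it maps the GStreamer optimization-level and execution-provider enums
   * onto ONNX Runtime settings, reads the model's first input to derive
   * width, height and channels according to the configured HWC or CHW
   * layout, and caches the output node names and element types.
   * fixedInputImageSize is only set when the model reports positive width
   * and height, i.e. when the input size is not dynamic. */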
  bool GstOnnxClient::createSession (std::string modelFile,
      GstOnnxOptimizationLevel optim, GstOnnxExecutionProvider provider)
  {
    if (session)
      return true;

    GraphOptimizationLevel onnx_optim;
    switch (optim) {
      case GST_ONNX_OPTIMIZATION_LEVEL_DISABLE_ALL:
        onnx_optim = GraphOptimizationLevel::ORT_DISABLE_ALL;
        break;
      case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_BASIC:
        onnx_optim = GraphOptimizationLevel::ORT_ENABLE_BASIC;
        break;
      case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_EXTENDED:
        onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
        break;
      case GST_ONNX_OPTIMIZATION_LEVEL_ENABLE_ALL:
        onnx_optim = GraphOptimizationLevel::ORT_ENABLE_ALL;
        break;
      default:
        onnx_optim = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
        break;
    };

    Ort::SessionOptions sessionOptions;
    // for debugging
    //sessionOptions.SetIntraOpNumThreads (1);
    sessionOptions.SetGraphOptimizationLevel (onnx_optim);
    m_provider = provider;
    switch (m_provider) {
      case GST_ONNX_EXECUTION_PROVIDER_CUDA:
#ifdef GST_ML_ONNX_RUNTIME_HAVE_CUDA
        Ort::ThrowOnError (OrtSessionOptionsAppendExecutionProvider_CUDA
            (sessionOptions, 0));
#else
        return false;
#endif
        break;
      default:
        break;

    };
    session = new Ort::Session (getEnv (), modelFile.c_str (), sessionOptions);
    auto inputTypeInfo = session->GetInputTypeInfo (0);
    std::vector < int64_t > inputDims =
        inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
    if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
      height = inputDims[1];
      width = inputDims[2];
      channels = inputDims[3];
    } else {
      channels = inputDims[1];
      height = inputDims[2];
      width = inputDims[3];
    }

    fixedInputImageSize = width > 0 && height > 0;
    GST_DEBUG ("Number of Output Nodes: %d", (gint) session->GetOutputCount ());

    Ort::AllocatorWithDefaultOptions allocator;
    GST_DEBUG ("Input name: %s", session->GetInputName (0, allocator));

    for (size_t i = 0; i < session->GetOutputCount (); ++i) {
      auto output_name = session->GetOutputName (i, allocator);
      outputNames.push_back (output_name);
      auto type_info = session->GetOutputTypeInfo (i);
      auto tensor_info = type_info.GetTensorTypeAndShapeInfo ();

      if (i < GST_ML_OUTPUT_NODE_NUMBER_OF) {
        auto function = outputNodeIndexToFunction[i];
        outputNodeInfo[function].type = tensor_info.GetElementType ();
      }
    }

    return true;
  }

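  /* run () dispatches to the templated doRun () based on the element type of
   * the class/label output tensor: float tensors are read as float, anything
   * else is read as int. */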
  std::vector < GstMlBoundingBox > GstOnnxClient::run (uint8_t * img_data,
      GstVideoMeta * vmeta, std::string labelPath, float scoreThreshold)
  {
    auto type = getOutputNodeType (GST_ML_OUTPUT_NODE_FUNCTION_CLASS);
    return (type ==
        ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ?
        doRun < float >(img_data, vmeta, labelPath, scoreThreshold)
        : doRun < int >(img_data, vmeta, labelPath, scoreThreshold);
  }

  void GstOnnxClient::parseDimensions (GstVideoMeta * vmeta)
  {
    int32_t newWidth = fixedInputImageSize ? width : vmeta->width;
    int32_t newHeight = fixedInputImageSize ? height : vmeta->height;

    if (!dest || width * height < newWidth * newHeight) {
      delete[] dest;
      dest = new uint8_t[newWidth * newHeight * channels];
    }
    width = newWidth;
    height = newHeight;
  }

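  /* doRun () is the per-frame inference path: it packs the video frame into
   * the uint8 input tensor in the layout the model expects (interleaved HWC
   * or planar CHW), runs the session, and converts the detection outputs
   * (detection count, normalized [y0, x0, y1, x1] boxes, scores and optional
   * class indices) into GstMlBoundingBox entries above the score threshold. */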
  template < typename T > std::vector < GstMlBoundingBox >
      GstOnnxClient::doRun (uint8_t * img_data, GstVideoMeta * vmeta,
      std::string labelPath, float scoreThreshold)
  {
    std::vector < GstMlBoundingBox > boundingBoxes;
    if (!img_data)
      return boundingBoxes;

    parseDimensions (vmeta);

    Ort::AllocatorWithDefaultOptions allocator;
    auto inputName = session->GetInputName (0, allocator);
    auto inputTypeInfo = session->GetInputTypeInfo (0);
    std::vector < int64_t > inputDims =
        inputTypeInfo.GetTensorTypeAndShapeInfo ().GetShape ();
    inputDims[0] = 1;
    if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
      inputDims[1] = height;
      inputDims[2] = width;
    } else {
      inputDims[2] = height;
      inputDims[3] = width;
    }

    std::ostringstream buffer;
    buffer << inputDims;
    GST_DEBUG ("Input dimensions: %s", buffer.str ().c_str ());

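    /* Set up per-channel source pointers so that the copy below always emits
     * samples in R, G, B order: BGRA/ARGB/ABGR/BGR remap the pointers to the
     * right byte offsets, and srcSamplesPerPixel accounts for the extra alpha
     * byte in the 4-byte formats. */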
    // copy video frame
    uint8_t *srcPtr[3] = { img_data, img_data + 1, img_data + 2 };
    uint32_t srcSamplesPerPixel = 3;
    switch (vmeta->format) {
      case GST_VIDEO_FORMAT_RGBA:
        srcSamplesPerPixel = 4;
        break;
      case GST_VIDEO_FORMAT_BGRA:
        srcSamplesPerPixel = 4;
        srcPtr[0] = img_data + 2;
        srcPtr[1] = img_data + 1;
        srcPtr[2] = img_data + 0;
        break;
      case GST_VIDEO_FORMAT_ARGB:
        srcSamplesPerPixel = 4;
        srcPtr[0] = img_data + 1;
        srcPtr[1] = img_data + 2;
        srcPtr[2] = img_data + 3;
        break;
      case GST_VIDEO_FORMAT_ABGR:
        srcSamplesPerPixel = 4;
        srcPtr[0] = img_data + 3;
        srcPtr[1] = img_data + 2;
        srcPtr[2] = img_data + 1;
        break;
      case GST_VIDEO_FORMAT_BGR:
        srcPtr[0] = img_data + 2;
        srcPtr[1] = img_data + 1;
        srcPtr[2] = img_data + 0;
        break;
      default:
        break;
    }
    size_t destIndex = 0;
    uint32_t stride = vmeta->stride[0];
    if (inputImageFormat == GST_ML_MODEL_INPUT_IMAGE_FORMAT_HWC) {
      for (int32_t j = 0; j < height; ++j) {
        for (int32_t i = 0; i < width; ++i) {
          for (int32_t k = 0; k < channels; ++k) {
            dest[destIndex++] = *srcPtr[k];
            srcPtr[k] += srcSamplesPerPixel;
          }
        }
        // correct for stride
        for (uint32_t k = 0; k < 3; ++k)
          srcPtr[k] += stride - srcSamplesPerPixel * width;
      }
    } else {
      size_t frameSize = width * height;
      uint8_t *destPtr[3] = { dest, dest + frameSize, dest + 2 * frameSize };
      for (int32_t j = 0; j < height; ++j) {
        for (int32_t i = 0; i < width; ++i) {
          for (int32_t k = 0; k < channels; ++k) {
            destPtr[k][destIndex] = *srcPtr[k];
            srcPtr[k] += srcSamplesPerPixel;
          }
          destIndex++;
        }
        // correct for stride
        for (uint32_t k = 0; k < 3; ++k)
          srcPtr[k] += stride - srcSamplesPerPixel * width;
      }
    }

    const size_t inputTensorSize = width * height * channels;
    auto memoryInfo =
        Ort::MemoryInfo::CreateCpu (OrtAllocatorType::OrtArenaAllocator,
        OrtMemType::OrtMemTypeDefault);
    std::vector < Ort::Value > inputTensors;
    inputTensors.push_back (Ort::Value::CreateTensor < uint8_t > (memoryInfo,
        dest, inputTensorSize, inputDims.data (), inputDims.size ()));
    std::vector < const char *>inputNames { inputName };

    std::vector < Ort::Value > modelOutput = session->Run (Ort::RunOptions { nullptr },
        inputNames.data (),
        inputTensors.data (), 1, outputNames.data (), outputNames.size ());

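    /* Pull the raw output tensors out of the model results: the detection
     * count, normalized box coordinates and scores are read as float, while
     * the class indices (if that output is enabled) use the template type T.
     * The label lookup below treats the class index as 1-based. */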
    auto numDetections =
        modelOutput[getOutputNodeIndex
        (GST_ML_OUTPUT_NODE_FUNCTION_DETECTION)].GetTensorMutableData < float >();
    auto bboxes =
        modelOutput[getOutputNodeIndex
        (GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX)].GetTensorMutableData < float >();
    auto scores =
        modelOutput[getOutputNodeIndex
        (GST_ML_OUTPUT_NODE_FUNCTION_SCORE)].GetTensorMutableData < float >();
    T *labelIndex = nullptr;
    if (getOutputNodeIndex (GST_ML_OUTPUT_NODE_FUNCTION_CLASS) !=
        GST_ML_NODE_INDEX_DISABLED) {
      labelIndex =
          modelOutput[getOutputNodeIndex
          (GST_ML_OUTPUT_NODE_FUNCTION_CLASS)].GetTensorMutableData < T > ();
    }
    if (labels.empty () && !labelPath.empty ())
      labels = ReadLabels (labelPath);

    for (int i = 0; i < numDetections[0]; ++i) {
      if (scores[i] > scoreThreshold) {
        std::string label = "";

        if (labelIndex && !labels.empty ())
          label = labels[labelIndex[i] - 1];
        auto score = scores[i];
        auto y0 = bboxes[i * 4] * height;
        auto x0 = bboxes[i * 4 + 1] * width;
        auto bheight = bboxes[i * 4 + 2] * height - y0;
        auto bwidth = bboxes[i * 4 + 3] * width - x0;
        boundingBoxes.push_back (GstMlBoundingBox (label, score, x0, y0, bwidth,
                bheight));
      }
    }
    return boundingBoxes;
  }

  std::vector < std::string >
      GstOnnxClient::ReadLabels (const std::string & labelsFile)
  {
    std::vector < std::string > labels;
    std::string line;
    std::ifstream fp (labelsFile);
    while (std::getline (fp, line))
      labels.push_back (line);

    return labels;
  }
}