1 /* 2 * GStreamer gstreamer-onnxclient 3 * Copyright (C) 2021 Collabora Ltd 4 * 5 * gstonnxclient.h 6 * 7 * This library is free software; you can redistribute it and/or 8 * modify it under the terms of the GNU Library General Public 9 * License as published by the Free Software Foundation; either 10 * version 2 of the License, or (at your option) any later version. 11 * 12 * This library is distributed in the hope that it will be useful, 13 * but WITHOUT ANY WARRANTY; without even the implied warranty of 14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 * Library General Public License for more details. 16 * 17 * You should have received a copy of the GNU Library General Public 18 * License along with this library; if not, write to the 19 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, 20 * Boston, MA 02110-1301, USA. 21 */ 22 #ifndef __GST_ONNX_CLIENT_H__ 23 #define __GST_ONNX_CLIENT_H__ 24 25 #include <gst/gst.h> 26 #include <onnxruntime_cxx_api.h> 27 #include <gst/video/video.h> 28 #include "gstonnxelement.h" 29 #include <string> 30 #include <vector> 31 32 namespace GstOnnxNamespace { 33 enum GstMlOutputNodeFunction { 34 GST_ML_OUTPUT_NODE_FUNCTION_DETECTION, 35 GST_ML_OUTPUT_NODE_FUNCTION_BOUNDING_BOX, 36 GST_ML_OUTPUT_NODE_FUNCTION_SCORE, 37 GST_ML_OUTPUT_NODE_FUNCTION_CLASS, 38 GST_ML_OUTPUT_NODE_NUMBER_OF, 39 }; 40 41 const gint GST_ML_NODE_INDEX_DISABLED = -1; 42 43 struct GstMlOutputNodeInfo { 44 GstMlOutputNodeInfo(void); 45 gint index; 46 ONNXTensorElementDataType type; 47 }; 48 49 struct GstMlBoundingBox { GstMlBoundingBoxGstMlBoundingBox50 GstMlBoundingBox(std::string lbl, 51 float score, 52 float _x0, 53 float _y0, 54 float _width, 55 float _height):label(lbl), 56 score(score), x0(_x0), y0(_y0), width(_width), height(_height) { 57 } GstMlBoundingBoxGstMlBoundingBox58 GstMlBoundingBox():GstMlBoundingBox("", 0.0f, 0.0f, 0.0f, 0.0f, 0.0f) { 59 } 60 std::string label; 61 float score; 62 float x0; 63 float y0; 64 float width; 65 float height; 66 }; 67 68 class GstOnnxClient { 69 public: 70 GstOnnxClient(void); 71 ~GstOnnxClient(void); 72 bool createSession(std::string modelFile, GstOnnxOptimizationLevel optim, 73 GstOnnxExecutionProvider provider); 74 bool hasSession(void); 75 void setInputImageFormat(GstMlModelInputImageFormat format); 76 GstMlModelInputImageFormat getInputImageFormat(void); 77 void setOutputNodeIndex(GstMlOutputNodeFunction nodeType, gint index); 78 gint getOutputNodeIndex(GstMlOutputNodeFunction nodeType); 79 void setOutputNodeType(GstMlOutputNodeFunction nodeType, 80 ONNXTensorElementDataType type); 81 ONNXTensorElementDataType getOutputNodeType(GstMlOutputNodeFunction type); 82 std::string getOutputNodeName(GstMlOutputNodeFunction nodeType); 83 std::vector < GstMlBoundingBox > run(uint8_t * img_data, 84 GstVideoMeta * vmeta, 85 std::string labelPath, 86 float scoreThreshold); 87 std::vector < GstMlBoundingBox > &getBoundingBoxes(void); 88 std::vector < const char *>getOutputNodeNames(void); 89 bool isFixedInputImageSize(void); 90 int32_t getWidth(void); 91 int32_t getHeight(void); 92 private: 93 void parseDimensions(GstVideoMeta * vmeta); 94 template < typename T > std::vector < GstMlBoundingBox > 95 doRun(uint8_t * img_data, GstVideoMeta * vmeta, std::string labelPath, 96 float scoreThreshold); 97 std::vector < std::string > ReadLabels(const std::string & labelsFile); 98 Ort::Env & getEnv(void); 99 Ort::Session * session; 100 int32_t width; 101 int32_t height; 102 int32_t channels; 103 uint8_t *dest; 104 GstOnnxExecutionProvider m_provider; 105 std::vector < Ort::Value > modelOutput; 106 std::vector < std::string > labels; 107 // !! indexed by function 108 GstMlOutputNodeInfo outputNodeInfo[GST_ML_OUTPUT_NODE_NUMBER_OF]; 109 // !! indexed by array index 110 size_t outputNodeIndexToFunction[GST_ML_OUTPUT_NODE_NUMBER_OF]; 111 std::vector < const char *>outputNames; 112 GstMlModelInputImageFormat inputImageFormat; 113 bool fixedInputImageSize; 114 }; 115 } 116 117 #endif /* __GST_ONNX_CLIENT_H__ */ 118