• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20 
21 /**
22  * @file
23  * DNN OpenVINO backend implementation.
24  */
25 
26 #include "dnn_backend_openvino.h"
27 #include "dnn_io_proc.h"
28 #include "libavformat/avio.h"
29 #include "libavutil/avassert.h"
30 #include "libavutil/cpu.h"
31 #include "libavutil/opt.h"
32 #include "libavutil/avstring.h"
33 #include "libavutil/detection_bbox.h"
34 #include "../internal.h"
35 #include "safe_queue.h"
36 #include <c_api/ie_c_api.h>
37 #include "dnn_backend_common.h"
38 
/**
 * Backend options parsed from the filter's option string
 * (see dnn_openvino_options[] below for names/defaults).
 */
typedef struct OVOptions{
    char *device_type;      // OpenVINO device name, e.g. "CPU" (default)
    int nireq;              // number of parallel inference requests; <= 0 picks a default in init_model_ov()
    uint8_t async;          // non-zero: execute inference asynchronously
    int batch_size;         // frames per inference request, clamped to >= 1
    int input_resizable;    // non-zero: network input may be reshaped to the frame size
} OVOptions;
46 
/** Logging and option context embedded in every OVModel. */
typedef struct OVContext {
    const AVClass *class;   // must be first so the struct can be used with av_log()/AVOption
    OVOptions options;      // parsed backend options
} OVContext;
51 
/** Per-model state for the OpenVINO backend. */
typedef struct OVModel{
    OVContext ctx;              // log/option context (first member, see OVContext)
    DNNModel *model;            // back-pointer to the generic DNN model wrapper
    ie_core_t *core;            // OpenVINO inference engine core
    ie_network_t *network;      // network read from the model file
    ie_executable_network_t *exe_network;   // network compiled for the target device (lazy, see init_model_ov)
    SafeQueue *request_queue;   // holds OVRequestItem
    Queue *task_queue;          // holds TaskItem
    Queue *lltask_queue;     // holds LastLevelTaskItem
} OVModel;
62 
// one request for one call to openvino
typedef struct OVRequestItem {
    ie_infer_request_t *infer_request;  // the underlying OpenVINO request
    LastLevelTaskItem **lltasks;        // array sized to batch_size; entries owned until completion
    uint32_t lltask_count;              // number of valid entries in lltasks
    ie_complete_call_back_t callback;   // completion callback used in async mode
} OVRequestItem;
70 
/**
 * Append iterate_string to the space-separated list held in generated_string,
 * replacing generated_string with a newly allocated result.  The previous
 * value is freed (the original macro leaked it on every append).
 * The expansion is a complete statement; call sites do not add a ';'.
 */
#define APPEND_STRING(generated_string, iterate_string)                                            \
    do {                                                                                           \
        char *new_str = generated_string ?                                                         \
                        av_asprintf("%s %s", generated_string, iterate_string) :                   \
                        av_asprintf("%s", iterate_string);                                         \
        av_freep(&generated_string);                                                               \
        generated_string = new_str;                                                                \
    } while (0);
74 
#define OFFSET(x) offsetof(OVContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
// AVOption table exposed to filters; each entry maps onto an OVOptions field.
static const AVOption dnn_openvino_options[] = {
    { "device", "device to run model", OFFSET(options.device_type), AV_OPT_TYPE_STRING, { .str = "CPU" }, 0, 0, FLAGS },
    DNN_BACKEND_COMMON_OPTIONS
    { "batch_size",  "batch size per request", OFFSET(options.batch_size),  AV_OPT_TYPE_INT,    { .i64 = 1 },     1, 1000, FLAGS},
    { "input_resizable", "can input be resizable or not", OFFSET(options.input_resizable), AV_OPT_TYPE_BOOL,   { .i64 = 0 },     0, 1, FLAGS },
    { NULL }
};

AVFILTER_DEFINE_CLASS(dnn_openvino);
86 
precision_to_datatype(precision_e precision)87 static DNNDataType precision_to_datatype(precision_e precision)
88 {
89     switch (precision)
90     {
91     case FP32:
92         return DNN_FLOAT;
93     case U8:
94         return DNN_UINT8;
95     default:
96         av_assert0(!"not supported yet.");
97         return DNN_FLOAT;
98     }
99 }
100 
/**
 * Size in bytes of a single element of the given DNN data type.
 * Aborts on types this backend does not support.
 */
static int get_datatype_size(DNNDataType dt)
{
    if (dt == DNN_FLOAT)
        return sizeof(float);
    if (dt == DNN_UINT8)
        return sizeof(uint8_t);

    av_assert0(!"not supported yet.");
    return 1;
}
114 
/**
 * Fill the input blob of an inference request with up to batch_size
 * pending items popped from the model's lltask queue.
 *
 * The blob dims are read as channels = dims[1], height = dims[2],
 * width = dims[3] — NOTE(review): this is the NCHW convention; confirm it is
 * consistent with the NHWC layout requested in init_model_ov().
 *
 * @param ov_model model whose lltask_queue supplies the work items
 * @param request  request whose lltasks[] array and lltask_count are filled
 *                 with the items consumed here
 * @return 0 on success, DNN_GENERIC_ERROR on OpenVINO API failure
 */
static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request)
{
    dimensions_t dims;
    precision_e precision;
    ie_blob_buffer_t blob_buffer;
    OVContext *ctx = &ov_model->ctx;
    IEStatusCode status;
    DNNData input;
    ie_blob_t *input_blob = NULL;
    LastLevelTaskItem *lltask;
    TaskItem *task;

    // peek (not pop): the input name is needed before items are consumed
    lltask = ff_queue_peek_front(ov_model->lltask_queue);
    av_assert0(lltask);
    task = lltask->task;

    status = ie_infer_request_get_blob(request->infer_request, task->input_name, &input_blob);
    if (status != OK) {
        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob with name %s\n", task->input_name);
        return DNN_GENERIC_ERROR;
    }

    status |= ie_blob_get_dims(input_blob, &dims);
    status |= ie_blob_get_precision(input_blob, &precision);
    if (status != OK) {
        ie_blob_free(&input_blob);
        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob dims/precision\n");
        return DNN_GENERIC_ERROR;
    }

    status = ie_blob_get_buffer(input_blob, &blob_buffer);
    if (status != OK) {
        ie_blob_free(&input_blob);
        av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n");
        return DNN_GENERIC_ERROR;
    }

    input.height = dims.dims[2];
    input.width = dims.dims[3];
    input.channels = dims.dims[1];
    input.data = blob_buffer.buffer;
    input.dt = precision_to_datatype(precision);
    // all models in openvino open model zoo use BGR as input,
    // change to be an option when necessary.
    input.order = DCO_BGR;

    // consume up to batch_size lltasks, writing each frame into the next
    // image-sized slice of the blob buffer
    for (int i = 0; i < ctx->options.batch_size; ++i) {
        lltask = ff_queue_pop_front(ov_model->lltask_queue);
        if (!lltask) {
            break;
        }
        request->lltasks[i] = lltask;
        request->lltask_count = i + 1;
        task = lltask->task;
        switch (ov_model->model->func_type) {
        case DFT_PROCESS_FRAME:
            if (task->do_ioproc) {
                // a filter-supplied pre-proc takes precedence over the default
                if (ov_model->model->frame_pre_proc != NULL) {
                    ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
                } else {
                    ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
                }
            }
            break;
        case DFT_ANALYTICS_DETECT:
            ff_frame_to_dnn_detect(task->in_frame, &input, ctx);
            break;
        case DFT_ANALYTICS_CLASSIFY:
            // classification runs on one detected bbox of the frame at a time
            ff_frame_to_dnn_classify(task->in_frame, &input, lltask->bbox_index, ctx);
            break;
        default:
            av_assert0(!"should not reach here");
            break;
        }
        // advance to the next image slot within the batched blob
        input.data = (uint8_t *)input.data
                     + input.width * input.height * input.channels * get_datatype_size(input.dt);
    }
    ie_blob_free(&input_blob);

    return 0;
}
196 
infer_completion_callback(void * args)197 static void infer_completion_callback(void *args)
198 {
199     dimensions_t dims;
200     precision_e precision;
201     IEStatusCode status;
202     OVRequestItem *request = args;
203     LastLevelTaskItem *lltask = request->lltasks[0];
204     TaskItem *task = lltask->task;
205     OVModel *ov_model = task->model;
206     SafeQueue *requestq = ov_model->request_queue;
207     ie_blob_t *output_blob = NULL;
208     ie_blob_buffer_t blob_buffer;
209     DNNData output;
210     OVContext *ctx = &ov_model->ctx;
211 
212     status = ie_infer_request_get_blob(request->infer_request, task->output_names[0], &output_blob);
213     if (status != OK) {
214         //incorrect output name
215         char *model_output_name = NULL;
216         char *all_output_names = NULL;
217         size_t model_output_count = 0;
218         av_log(ctx, AV_LOG_ERROR, "Failed to get model output data\n");
219         status = ie_network_get_outputs_number(ov_model->network, &model_output_count);
220         for (size_t i = 0; i < model_output_count; i++) {
221             status = ie_network_get_output_name(ov_model->network, i, &model_output_name);
222             APPEND_STRING(all_output_names, model_output_name)
223         }
224         av_log(ctx, AV_LOG_ERROR,
225                "output \"%s\" may not correct, all output(s) are: \"%s\"\n",
226                task->output_names[0], all_output_names);
227         return;
228     }
229 
230     status = ie_blob_get_buffer(output_blob, &blob_buffer);
231     if (status != OK) {
232         ie_blob_free(&output_blob);
233         av_log(ctx, AV_LOG_ERROR, "Failed to access output memory\n");
234         return;
235     }
236 
237     status |= ie_blob_get_dims(output_blob, &dims);
238     status |= ie_blob_get_precision(output_blob, &precision);
239     if (status != OK) {
240         ie_blob_free(&output_blob);
241         av_log(ctx, AV_LOG_ERROR, "Failed to get dims or precision of output\n");
242         return;
243     }
244 
245     output.channels = dims.dims[1];
246     output.height   = dims.dims[2];
247     output.width    = dims.dims[3];
248     output.dt       = precision_to_datatype(precision);
249     output.data     = blob_buffer.buffer;
250 
251     av_assert0(request->lltask_count <= dims.dims[0]);
252     av_assert0(request->lltask_count >= 1);
253     for (int i = 0; i < request->lltask_count; ++i) {
254         task = request->lltasks[i]->task;
255         task->inference_done++;
256 
257         switch (ov_model->model->func_type) {
258         case DFT_PROCESS_FRAME:
259             if (task->do_ioproc) {
260                 if (ov_model->model->frame_post_proc != NULL) {
261                     ov_model->model->frame_post_proc(task->out_frame, &output, ov_model->model->filter_ctx);
262                 } else {
263                     ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
264                 }
265             } else {
266                 task->out_frame->width = output.width;
267                 task->out_frame->height = output.height;
268             }
269             break;
270         case DFT_ANALYTICS_DETECT:
271             if (!ov_model->model->detect_post_proc) {
272                 av_log(ctx, AV_LOG_ERROR, "detect filter needs to provide post proc\n");
273                 return;
274             }
275             ov_model->model->detect_post_proc(task->in_frame, &output, 1, ov_model->model->filter_ctx);
276             break;
277         case DFT_ANALYTICS_CLASSIFY:
278             if (!ov_model->model->classify_post_proc) {
279                 av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n");
280                 return;
281             }
282             ov_model->model->classify_post_proc(task->in_frame, &output, request->lltasks[i]->bbox_index, ov_model->model->filter_ctx);
283             break;
284         default:
285             av_assert0(!"should not reach here");
286             break;
287         }
288 
289         av_freep(&request->lltasks[i]);
290         output.data = (uint8_t *)output.data
291                       + output.width * output.height * output.channels * get_datatype_size(output.dt);
292     }
293     ie_blob_free(&output_blob);
294 
295     request->lltask_count = 0;
296     if (ff_safe_queue_push_back(requestq, request) < 0) {
297         ie_infer_request_free(&request->infer_request);
298         av_freep(&request);
299         av_log(ctx, AV_LOG_ERROR, "Failed to push back request_queue.\n");
300         return;
301     }
302 }
303 
/**
 * Compile the network for the configured device and create the request and
 * task queues used for execution.  Called lazily once the input/output names
 * are known (first execution, or get_output_ov()).
 *
 * On any failure the whole model is torn down via ff_dnn_free_model_ov(),
 * so the caller must not touch ov_model afterwards.
 *
 * @param ov_model    model with core/network already created
 * @param input_name  name of the network input to configure
 * @param output_name name of the network output to configure
 * @return 0 on success, a DNN_GENERIC_ERROR/AVERROR code on failure
 */
static int init_model_ov(OVModel *ov_model, const char *input_name, const char *output_name)
{
    int ret = 0;
    OVContext *ctx = &ov_model->ctx;
    IEStatusCode status;
    ie_available_devices_t a_dev;
    ie_config_t config = {NULL, NULL, NULL};
    char *all_dev_names = NULL;

    // batch size
    if (ctx->options.batch_size <= 0) {
        ctx->options.batch_size = 1;
    }

    // batching is implemented by reshaping dim 0 (N) of every network input
    if (ctx->options.batch_size > 1) {
        input_shapes_t input_shapes;
        status = ie_network_get_input_shapes(ov_model->network, &input_shapes);
        if (status != OK) {
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
        for (int i = 0; i < input_shapes.shape_num; i++)
            input_shapes.shapes[i].shape.dims[0] = ctx->options.batch_size;
        status = ie_network_reshape(ov_model->network, input_shapes);
        ie_network_input_shapes_free(&input_shapes);
        if (status != OK) {
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
    }

    // The order of dims in the openvino is fixed and it is always NCHW for 4-D data.
    // while we pass NHWC data from FFmpeg to openvino
    status = ie_network_set_input_layout(ov_model->network, input_name, NHWC);
    if (status != OK) {
        av_log(ctx, AV_LOG_ERROR, "Failed to set layout as NHWC for input %s\n", input_name);
        ret = DNN_GENERIC_ERROR;
        goto err;
    }
    status = ie_network_set_output_layout(ov_model->network, output_name, NHWC);
    if (status != OK) {
        av_log(ctx, AV_LOG_ERROR, "Failed to set layout as NHWC for output %s\n", output_name);
        ret = DNN_GENERIC_ERROR;
        goto err;
    }

    // all models in openvino open model zoo use BGR with range [0.0f, 255.0f] as input,
    // we don't have a AVPixelFormat to describe it, so we'll use AV_PIX_FMT_BGR24 and
    // ask openvino to do the conversion internally.
    // the current supported SR model (frame processing) is generated from tensorflow model,
    // and its input is Y channel as float with range [0.0f, 1.0f], so do not set for this case.
    // TODO: we need to get a final clear&general solution with all backends/formats considered.
    if (ov_model->model->func_type != DFT_PROCESS_FRAME) {
        status = ie_network_set_input_precision(ov_model->network, input_name, U8);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to set input precision as U8 for %s\n", input_name);
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
    }

    status = ie_core_load_network(ov_model->core, ov_model->network, ctx->options.device_type, &config, &ov_model->exe_network);
    if (status != OK) {
        // loading failed: list the available devices to help diagnose a bad
        // "device" option value
        av_log(ctx, AV_LOG_ERROR, "Failed to load OpenVINO model network\n");
        status = ie_core_get_available_devices(ov_model->core, &a_dev);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to get available devices\n");
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
        for (int i = 0; i < a_dev.num_devices; i++) {
            APPEND_STRING(all_dev_names, a_dev.devices[i])
        }
        av_log(ctx, AV_LOG_ERROR,"device %s may not be supported, all available devices are: \"%s\"\n",
               ctx->options.device_type, all_dev_names);
        ret = AVERROR(ENODEV);
        goto err;
    }

    // create infer_requests for async execution
    if (ctx->options.nireq <= 0) {
        // the default value is a rough estimation
        ctx->options.nireq = av_cpu_count() / 2 + 1;
    }

    ov_model->request_queue = ff_safe_queue_create();
    if (!ov_model->request_queue) {
        ret = AVERROR(ENOMEM);
        goto err;
    }

    for (int i = 0; i < ctx->options.nireq; i++) {
        OVRequestItem *item = av_mallocz(sizeof(*item));
        if (!item) {
            ret = AVERROR(ENOMEM);
            goto err;
        }

        item->callback.completeCallBackFunc = infer_completion_callback;
        item->callback.args = item;
        // push first so a later failure still lets ff_dnn_free_model_ov()
        // find and release the item through the queue
        if (ff_safe_queue_push_back(ov_model->request_queue, item) < 0) {
            av_freep(&item);
            ret = AVERROR(ENOMEM);
            goto err;
        }

        status = ie_exec_network_create_infer_request(ov_model->exe_network, &item->infer_request);
        if (status != OK) {
            ret = DNN_GENERIC_ERROR;
            goto err;
        }

        item->lltasks = av_malloc_array(ctx->options.batch_size, sizeof(*item->lltasks));
        if (!item->lltasks) {
            ret = AVERROR(ENOMEM);
            goto err;
        }
        item->lltask_count = 0;
    }

    ov_model->task_queue = ff_queue_create();
    if (!ov_model->task_queue) {
        ret = AVERROR(ENOMEM);
        goto err;
    }

    ov_model->lltask_queue = ff_queue_create();
    if (!ov_model->lltask_queue) {
        ret = AVERROR(ENOMEM);
        goto err;
    }

    return 0;

err:
    // releases the model and everything hanging off it, including this
    // function's partially built queues
    ff_dnn_free_model_ov(&ov_model->model);
    return ret;
}
442 
/**
 * Execute one inference request over the lltasks currently queued.
 *
 * Async path: register the completion callback and start the request; the
 * callback returns the request to the queue when inference finishes.
 * Sync path: block until inference completes, then invoke the callback
 * directly.
 *
 * Ownership: the caller hands over the request.  If inferenceq is empty the
 * request is destroyed here; on error it is pushed back to the model's
 * request queue (or destroyed if that push fails).
 *
 * @return 0 on success, a DNN error code on failure
 */
static int execute_model_ov(OVRequestItem *request, Queue *inferenceq)
{
    IEStatusCode status;
    LastLevelTaskItem *lltask;
    int ret = 0;
    TaskItem *task;
    OVContext *ctx;
    OVModel *ov_model;

    if (ff_queue_size(inferenceq) == 0) {
        // nothing to do; release the request instead of leaking it
        ie_infer_request_free(&request->infer_request);
        av_freep(&request);
        return 0;
    }

    lltask = ff_queue_peek_front(inferenceq);
    task = lltask->task;
    ov_model = task->model;
    ctx = &ov_model->ctx;

    if (task->async) {
        ret = fill_model_input_ov(ov_model, request);
        if (ret != 0) {
            goto err;
        }
        status = ie_infer_set_completion_callback(request->infer_request, &request->callback);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to set completion callback for inference\n");
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
        status = ie_infer_request_infer_async(request->infer_request);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to start async inference\n");
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
        return 0;
    } else {
        ret = fill_model_input_ov(ov_model, request);
        if (ret != 0) {
            goto err;
        }
        status = ie_infer_request_infer(request->infer_request);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to start synchronous model inference\n");
            ret = DNN_GENERIC_ERROR;
            goto err;
        }
        // sync mode: run the completion handling inline
        infer_completion_callback(request);
        return (task->inference_done == task->inference_todo) ? 0 : DNN_GENERIC_ERROR;
    }
err:
    // try to recycle the request; destroy it if the queue rejects it
    if (ff_safe_queue_push_back(ov_model->request_queue, request) < 0) {
        ie_infer_request_free(&request->infer_request);
        av_freep(&request);
    }
    return ret;
}
502 
/**
 * Query dims/precision of the named model input (DNNModel.get_input hook).
 *
 * With the input_resizable option set, width/height are reported as -1 since
 * the network will be reshaped to the actual frame size later.
 *
 * @return 0 on success; DNN_GENERIC_ERROR on OpenVINO API failure;
 *         AVERROR(EINVAL) if input_name is not an input of the model
 */
static int get_input_ov(void *model, DNNData *input, const char *input_name)
{
    OVModel *ov_model = model;
    OVContext *ctx = &ov_model->ctx;
    char *model_input_name = NULL;
    char *all_input_names = NULL;
    IEStatusCode status;
    size_t model_input_count = 0;
    dimensions_t dims;
    precision_e precision;
    int input_resizable = ctx->options.input_resizable;

    status = ie_network_get_inputs_number(ov_model->network, &model_input_count);
    if (status != OK) {
        av_log(ctx, AV_LOG_ERROR, "Failed to get input count\n");
        return DNN_GENERIC_ERROR;
    }

    for (size_t i = 0; i < model_input_count; i++) {
        status = ie_network_get_input_name(ov_model->network, i, &model_input_name);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to get No.%d input's name\n", (int)i);
            av_freep(&all_input_names);     // was leaked on this path
            return DNN_GENERIC_ERROR;
        }
        if (strcmp(model_input_name, input_name) == 0) {
            ie_network_name_free(&model_input_name);
            status |= ie_network_get_input_dims(ov_model->network, input_name, &dims);
            status |= ie_network_get_input_precision(ov_model->network, input_name, &precision);
            if (status != OK) {
                av_log(ctx, AV_LOG_ERROR, "Failed to get No.%d input's dims or precision\n", (int)i);
                av_freep(&all_input_names); // was leaked on this path
                return DNN_GENERIC_ERROR;
            }

            // dims follow OpenVINO's NCHW convention
            input->channels = dims.dims[1];
            input->height   = input_resizable ? -1 : dims.dims[2];
            input->width    = input_resizable ? -1 : dims.dims[3];
            input->dt       = precision_to_datatype(precision);
            av_freep(&all_input_names);     // was leaked on this path
            return 0;
        } else {
            //incorrect input name
            APPEND_STRING(all_input_names, model_input_name)
        }

        ie_network_name_free(&model_input_name);
    }

    av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model, all input(s) are: \"%s\"\n", input_name, all_input_names);
    av_freep(&all_input_names);             // diagnostic string was leaked before
    return AVERROR(EINVAL);
}
552 
/**
 * Check that the frame carries detection-bbox side data usable for
 * classification: at least one bbox, every bbox fully inside the frame,
 * and room left for another classification entry on each bbox.
 *
 * @return 1 if classification can proceed on this frame, 0 otherwise
 */
static int contain_valid_detection_bbox(AVFrame *frame)
{
    AVFrameSideData *sd;
    const AVDetectionBBoxHeader *header;
    const AVDetectionBBox *bbox;

    sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
    if (!sd) { // this frame has nothing detected
        return 0;
    }

    if (!sd->size) {
        return 0;
    }

    header = (const AVDetectionBBoxHeader *)sd->data;
    if (!header->nb_bboxes) {
        return 0;
    }

    for (uint32_t i = 0; i < header->nb_bboxes; i++) {
        bbox = av_get_detection_bbox(header, i);
        if (bbox->x < 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) {
            return 0;
        }
        // bugfix: the vertical extent must be checked against the frame
        // height (the original compared it against frame->width)
        if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->height) {
            return 0;
        }

        if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) {
            return 0;
        }
    }

    return 1;
}
589 
/**
 * Break one TaskItem down into the last-level inference items this backend
 * executes and push them onto lltask_queue.
 *
 * Frame processing and detection produce exactly one item per task.
 * Classification produces one item per detected bbox (optionally filtered by
 * the "target" label in exec_params) and may legitimately produce none.
 *
 * @return 0 on success, AVERROR(ENOMEM) on allocation/queue failure
 */
static int extract_lltask_from_task(DNNFunctionType func_type, TaskItem *task, Queue *lltask_queue, DNNExecBaseParams *exec_params)
{
    if (func_type == DFT_PROCESS_FRAME || func_type == DFT_ANALYTICS_DETECT) {
        // one inference per task
        LastLevelTaskItem *item = av_malloc(sizeof(*item));
        if (!item)
            return AVERROR(ENOMEM);
        task->inference_todo = 1;
        task->inference_done = 0;
        item->task = task;
        if (ff_queue_push_back(lltask_queue, item) < 0) {
            av_freep(&item);
            return AVERROR(ENOMEM);
        }
        return 0;
    }

    if (func_type == DFT_ANALYTICS_CLASSIFY) {
        // one inference per (matching) detected bbox
        const AVDetectionBBoxHeader *header;
        AVFrameSideData *sd;
        AVFrame *frame = task->in_frame;
        DNNExecClassificationParams *cls_params = (DNNExecClassificationParams *)exec_params;

        task->inference_todo = 0;
        task->inference_done = 0;

        if (!contain_valid_detection_bbox(frame))
            return 0;

        sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
        header = (const AVDetectionBBoxHeader *)sd->data;

        for (uint32_t i = 0; i < header->nb_bboxes; i++) {
            LastLevelTaskItem *item;
            const AVDetectionBBox *bbox = av_get_detection_bbox(header, i);

            // skip bboxes whose detected label does not match the target
            if (cls_params->target &&
                av_strncasecmp(bbox->detect_label, cls_params->target, sizeof(bbox->detect_label)) != 0)
                continue;

            item = av_malloc(sizeof(*item));
            if (!item)
                return AVERROR(ENOMEM);
            task->inference_todo++;
            item->task = task;
            item->bbox_index = i;
            if (ff_queue_push_back(lltask_queue, item) < 0) {
                av_freep(&item);
                return AVERROR(ENOMEM);
            }
        }
        return 0;
    }

    av_assert0(!"should not reach here");
    return AVERROR(EINVAL);
}
655 
/**
 * Determine the model's output dimensions for a given input size
 * (DNNModel.get_output hook) by running one inference on a dummy task.
 * Only supported for frame-processing models.
 *
 * NOTE(review): if ff_dnn_fill_gettingoutput_task() fails, the stack TaskItem
 * may not have its in_frame/out_frame pointers set before the frees at err:
 * — presumably the helper initializes them even on failure; verify.
 *
 * @return 0 on success, a DNN/AVERROR code on failure
 */
static int get_output_ov(void *model, const char *input_name, int input_width, int input_height,
                                   const char *output_name, int *output_width, int *output_height)
{
    int ret;
    OVModel *ov_model = model;
    OVContext *ctx = &ov_model->ctx;
    TaskItem task;
    OVRequestItem *request;
    IEStatusCode status;
    input_shapes_t input_shapes;
    DNNExecBaseParams exec_params = {
        .input_name     = input_name,
        .output_names   = &output_name,
        .nb_output      = 1,
        .in_frame       = NULL,
        .out_frame      = NULL,
    };

    if (ov_model->model->func_type != DFT_PROCESS_FRAME) {
        av_log(ctx, AV_LOG_ERROR, "Get output dim only when processing frame.\n");
        return AVERROR(EINVAL);
    }

    // reshape H/W (dims[2]/dims[3] in NCHW) to the requested input size
    if (ctx->options.input_resizable) {
        status = ie_network_get_input_shapes(ov_model->network, &input_shapes);
        input_shapes.shapes->shape.dims[2] = input_height;
        input_shapes.shapes->shape.dims[3] = input_width;
        status |= ie_network_reshape(ov_model->network, input_shapes);
        ie_network_input_shapes_free(&input_shapes);
        if (status != OK) {
            av_log(ctx, AV_LOG_ERROR, "Failed to reshape input size for %s\n", input_name);
            return DNN_GENERIC_ERROR;
        }
    }

    // lazy initialization of the executable network and queues
    if (!ov_model->exe_network) {
        ret = init_model_ov(ov_model, input_name, output_name);
        if (ret != 0) {
            av_log(ctx, AV_LOG_ERROR, "Failed init OpenVINO exectuable network or inference request\n");
            return ret;
        }
    }

    ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, ov_model, input_height, input_width, ctx);
    if (ret != 0) {
        goto err;
    }

    ret = extract_lltask_from_task(ov_model->model->func_type, &task, ov_model->lltask_queue, NULL);
    if (ret != 0) {
        av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
        goto err;
    }

    request = ff_safe_queue_pop_front(ov_model->request_queue);
    if (!request) {
        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
        ret = AVERROR(EINVAL);
        goto err;
    }

    // sync execution writes the output dims into task.out_frame
    ret = execute_model_ov(request, ov_model->lltask_queue);
    *output_width = task.out_frame->width;
    *output_height = task.out_frame->height;
err:
    av_frame_free(&task.out_frame);
    av_frame_free(&task.in_frame);
    return ret;
}
725 
/**
 * Load an OpenVINO model file and create the backend-specific DNNModel.
 * The executable network is NOT built here; that happens lazily in
 * init_model_ov() once the input/output names are known.
 *
 * @param model_filename path of the OpenVINO IR model file
 * @param func_type      how the filter will use the model
 * @param options        backend option string, parsed against dnn_openvino_options[]
 * @param filter_ctx     owning filter, stored for pre/post-proc callbacks
 * @return the new model, or NULL on failure (everything is cleaned up)
 */
DNNModel *ff_dnn_load_model_ov(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
{
    DNNModel *model = NULL;
    OVModel *ov_model = NULL;
    OVContext *ctx = NULL;
    IEStatusCode status;

    model = av_mallocz(sizeof(DNNModel));
    if (!model){
        return NULL;
    }

    ov_model = av_mallocz(sizeof(OVModel));
    if (!ov_model) {
        av_freep(&model);
        return NULL;
    }
    // link wrapper and backend model both ways
    model->model = ov_model;
    ov_model->model = model;
    ov_model->ctx.class = &dnn_openvino_class;
    ctx = &ov_model->ctx;

    //parse options
    av_opt_set_defaults(ctx);
    if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0) {
        av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options);
        goto err;
    }

    status = ie_core_create("", &ov_model->core);
    if (status != OK)
        goto err;

    status = ie_core_read_network(ov_model->core, model_filename, NULL, &ov_model->network);
    if (status != OK) {
        // report the runtime version to help diagnose model/runtime mismatch
        ie_version_t ver;
        ver = ie_c_api_version();
        av_log(ctx, AV_LOG_ERROR, "Failed to read the network from model file %s,\n"
                                  "Please check if the model version matches the runtime OpenVINO %s\n",
                                   model_filename, ver.api_version);
        ie_version_free(&ver);
        goto err;
    }

    model->get_input = &get_input_ov;
    model->get_output = &get_output_ov;
    model->options = options;
    model->filter_ctx = filter_ctx;
    model->func_type = func_type;

    return model;

err:
    // frees both model and ov_model (linked above)
    ff_dnn_free_model_ov(&model);
    return NULL;
}
782 
/**
 * Queue one execution of the model (DNNModel execute entry point).
 *
 * Async mode: tasks accumulate on lltask_queue and are dispatched in
 * batch_size groups; results are collected later via ff_dnn_get_result_ov().
 * Sync mode: a single inference is executed immediately (classification and
 * batch_size > 1 are not supported synchronously).
 *
 * @return 0 on success, a DNN/AVERROR code on failure
 */
int ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *exec_params)
{
    OVModel *ov_model = model->model;
    OVContext *ctx = &ov_model->ctx;
    OVRequestItem *request;
    TaskItem *task;
    int ret;

    ret = ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params);
    if (ret != 0) {
        return ret;
    }

    // lazy initialization: the executable network needs the I/O names,
    // which are only known at first execution
    if (!ov_model->exe_network) {
        ret = init_model_ov(ov_model, exec_params->input_name, exec_params->output_names[0]);
        if (ret != 0) {
            av_log(ctx, AV_LOG_ERROR, "Failed init OpenVINO exectuable network or inference request\n");
            return ret;
        }
    }

    task = av_malloc(sizeof(*task));
    if (!task) {
        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
        return AVERROR(ENOMEM);
    }

    ret = ff_dnn_fill_task(task, exec_params, ov_model, ctx->options.async, 1);
    if (ret != 0) {
        av_freep(&task);
        return ret;
    }

    // task_queue owns the task from here; results are read from it later
    if (ff_queue_push_back(ov_model->task_queue, task) < 0) {
        av_freep(&task);
        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
        return AVERROR(ENOMEM);
    }

    ret = extract_lltask_from_task(model->func_type, task, ov_model->lltask_queue, exec_params);
    if (ret != 0) {
        av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
        return ret;
    }

    if (ctx->options.async) {
        // dispatch full batches; partial batches wait for more tasks or a flush
        while (ff_queue_size(ov_model->lltask_queue) >= ctx->options.batch_size) {
            request = ff_safe_queue_pop_front(ov_model->request_queue);
            if (!request) {
                av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
                return AVERROR(EINVAL);
            }

            ret = execute_model_ov(request, ov_model->lltask_queue);
            if (ret != 0) {
                return ret;
            }
        }

        return 0;
    }
    else {
        if (model->func_type == DFT_ANALYTICS_CLASSIFY) {
            // Classification filter has not been completely
            // tested with the sync mode. So, do not support now.
            avpriv_report_missing_feature(ctx, "classify for sync execution");
            return AVERROR(ENOSYS);
        }

        if (ctx->options.batch_size > 1) {
            avpriv_report_missing_feature(ctx, "batch mode for sync execution");
            return AVERROR(ENOSYS);
        }

        request = ff_safe_queue_pop_front(ov_model->request_queue);
        if (!request) {
            av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
            return AVERROR(EINVAL);
        }
        return execute_model_ov(request, ov_model->lltask_queue);
    }
}
865 
ff_dnn_get_result_ov(const DNNModel * model,AVFrame ** in,AVFrame ** out)866 DNNAsyncStatusType ff_dnn_get_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out)
867 {
868     OVModel *ov_model = model->model;
869     return ff_dnn_get_result_common(ov_model->task_queue, in, out);
870 }
871 
ff_dnn_flush_ov(const DNNModel * model)872 int ff_dnn_flush_ov(const DNNModel *model)
873 {
874     OVModel *ov_model = model->model;
875     OVContext *ctx = &ov_model->ctx;
876     OVRequestItem *request;
877     IEStatusCode status;
878     int ret;
879 
880     if (ff_queue_size(ov_model->lltask_queue) == 0) {
881         // no pending task need to flush
882         return 0;
883     }
884 
885     request = ff_safe_queue_pop_front(ov_model->request_queue);
886     if (!request) {
887         av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
888         return AVERROR(EINVAL);
889     }
890 
891     ret = fill_model_input_ov(ov_model, request);
892     if (ret != 0) {
893         av_log(ctx, AV_LOG_ERROR, "Failed to fill model input.\n");
894         return ret;
895     }
896     status = ie_infer_set_completion_callback(request->infer_request, &request->callback);
897     if (status != OK) {
898         av_log(ctx, AV_LOG_ERROR, "Failed to set completion callback for inference\n");
899         return DNN_GENERIC_ERROR;
900     }
901     status = ie_infer_request_infer_async(request->infer_request);
902     if (status != OK) {
903         av_log(ctx, AV_LOG_ERROR, "Failed to start async inference\n");
904         return DNN_GENERIC_ERROR;
905     }
906 
907     return 0;
908 }
909 
/**
 * Destroy an OpenVINO-backed model and release all associated resources:
 * pooled inference requests, pending low-level tasks, queued tasks (with
 * their frames), the executable network, the network, and the IE core.
 * *model is set to NULL on return; a NULL *model is a no-op.
 *
 * Bug fix: the original guarded ie_infer_request_free() with
 * `if (item && item->infer_request)` — acknowledging that item may be
 * NULL — but then dereferenced item unconditionally via
 * av_freep(&item->lltasks). The whole per-item cleanup is now guarded.
 */
void ff_dnn_free_model_ov(DNNModel **model)
{
    if (*model){
        OVModel *ov_model = (*model)->model;
        while (ff_safe_queue_size(ov_model->request_queue) != 0) {
            OVRequestItem *item = ff_safe_queue_pop_front(ov_model->request_queue);
            if (item) {
                if (item->infer_request) {
                    ie_infer_request_free(&item->infer_request);
                }
                av_freep(&item->lltasks);
                av_freep(&item);
            }
        }
        ff_safe_queue_destroy(ov_model->request_queue);

        while (ff_queue_size(ov_model->lltask_queue) != 0) {
            LastLevelTaskItem *item = ff_queue_pop_front(ov_model->lltask_queue);
            av_freep(&item);
        }
        ff_queue_destroy(ov_model->lltask_queue);

        // Tasks own their frames; free both before releasing the item.
        while (ff_queue_size(ov_model->task_queue) != 0) {
            TaskItem *item = ff_queue_pop_front(ov_model->task_queue);
            av_frame_free(&item->in_frame);
            av_frame_free(&item->out_frame);
            av_freep(&item);
        }
        ff_queue_destroy(ov_model->task_queue);

        if (ov_model->exe_network)
            ie_exec_network_free(&ov_model->exe_network);
        if (ov_model->network)
            ie_network_free(&ov_model->network);
        if (ov_model->core)
            ie_core_free(&ov_model->core);
        av_freep(&ov_model);
        av_freep(model);
    }
}
948