• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <libvmaf/libvmaf.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 
18 #include "aom_dsp/blend.h"
19 #include "aom_dsp/vmaf.h"
20 #include "aom_ports/system_state.h"
21 
22 typedef struct FrameData {
23   const YV12_BUFFER_CONFIG *source;
24   const YV12_BUFFER_CONFIG *distorted;
25   int frame_set;
26   int bit_depth;
27 } FrameData;
28 
// Prints "Fatal error: <message>" to stderr and terminates the process with
// EXIT_FAILURE. Never returns to the caller.
static void vmaf_fatal_error(const char *message) {
  (void)fprintf(stderr, "Fatal error: %s\n", message);
  exit(EXIT_FAILURE);
}
33 
34 // A callback function used to pass data to VMAF.
35 // Returns 0 after reading a frame.
36 // Returns 2 when there is no more frame to read.
read_frame(float * ref_data,float * main_data,float * temp_data,int stride,void * user_data)37 static int read_frame(float *ref_data, float *main_data, float *temp_data,
38                       int stride, void *user_data) {
39   FrameData *frames = (FrameData *)user_data;
40 
41   if (!frames->frame_set) {
42     const int width = frames->source->y_width;
43     const int height = frames->source->y_height;
44     assert(width == frames->distorted->y_width);
45     assert(height == frames->distorted->y_height);
46 
47     if (frames->bit_depth > 8) {
48       const float scale_factor = 1.0f / (float)(1 << (frames->bit_depth - 8));
49       uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(frames->source->y_buffer);
50       uint16_t *main_ptr = CONVERT_TO_SHORTPTR(frames->distorted->y_buffer);
51 
52       for (int row = 0; row < height; ++row) {
53         for (int col = 0; col < width; ++col) {
54           ref_data[col] = scale_factor * (float)ref_ptr[col];
55         }
56         ref_ptr += frames->source->y_stride;
57         ref_data += stride / sizeof(*ref_data);
58       }
59 
60       for (int row = 0; row < height; ++row) {
61         for (int col = 0; col < width; ++col) {
62           main_data[col] = scale_factor * (float)main_ptr[col];
63         }
64         main_ptr += frames->distorted->y_stride;
65         main_data += stride / sizeof(*main_data);
66       }
67     } else {
68       uint8_t *ref_ptr = frames->source->y_buffer;
69       uint8_t *main_ptr = frames->distorted->y_buffer;
70 
71       for (int row = 0; row < height; ++row) {
72         for (int col = 0; col < width; ++col) {
73           ref_data[col] = (float)ref_ptr[col];
74         }
75         ref_ptr += frames->source->y_stride;
76         ref_data += stride / sizeof(*ref_data);
77       }
78 
79       for (int row = 0; row < height; ++row) {
80         for (int col = 0; col < width; ++col) {
81           main_data[col] = (float)main_ptr[col];
82         }
83         main_ptr += frames->distorted->y_stride;
84         main_data += stride / sizeof(*main_data);
85       }
86     }
87     frames->frame_set = 1;
88     return 0;
89   }
90 
91   (void)temp_data;
92   return 2;
93 }
94 
aom_calc_vmaf(const char * model_path,const YV12_BUFFER_CONFIG * source,const YV12_BUFFER_CONFIG * distorted,const int bit_depth,double * const vmaf)95 void aom_calc_vmaf(const char *model_path, const YV12_BUFFER_CONFIG *source,
96                    const YV12_BUFFER_CONFIG *distorted, const int bit_depth,
97                    double *const vmaf) {
98   aom_clear_system_state();
99   const int width = source->y_width;
100   const int height = source->y_height;
101   FrameData frames = { source, distorted, 0, bit_depth };
102   char *fmt = bit_depth == 10 ? "yuv420p10le" : "yuv420p";
103   double vmaf_score;
104   const int ret =
105       compute_vmaf(&vmaf_score, fmt, width, height, read_frame,
106                    /*user_data=*/&frames, (char *)model_path,
107                    /*log_path=*/NULL, /*log_fmt=*/NULL, /*disable_clip=*/1,
108                    /*disable_avx=*/0, /*enable_transform=*/0,
109                    /*phone_model=*/0, /*do_psnr=*/0, /*do_ssim=*/0,
110                    /*do_ms_ssim=*/0, /*pool_method=*/NULL, /*n_thread=*/0,
111                    /*n_subsample=*/1, /*enable_conf_interval=*/0);
112   if (ret) vmaf_fatal_error("Failed to compute VMAF scores.");
113 
114   aom_clear_system_state();
115   *vmaf = vmaf_score;
116 }
117 
// Extracts the VMAF score from one line of the XML log written by
// compute_vmaf().
// Returns 1 and stores the value in *score when `line` is a per-frame
// "\t\t<frame " line carrying a well-formed vmaf="<number>" attribute.
// Returns 0 when the line is not a frame line or has no vmaf attribute.
// Returns -1 when the attribute value is empty, is not a number, or is not
// followed by a closing quote (e.g. because the line was cut short by the
// read buffer).
static int vmaf_parse_frame_score(const char *line, double *score) {
  static const char kFramePrefix[] = "\t\t<frame ";
  // The leading space keeps this from matching the tail of a longer attribute
  // name that happens to end in "vmaf".
  static const char kVmafAttr[] = " vmaf=\"";

  // strncmp (unlike memcmp) never reads past the end of a short line.
  if (strncmp(line, kFramePrefix, sizeof(kFramePrefix) - 1) != 0) return 0;
  const char *const attr = strstr(line, kVmafAttr);
  if (attr == NULL) return 0;

  const char *const value = attr + sizeof(kVmafAttr) - 1;
  char *end = NULL;
  const double parsed = strtod(value, &end);
  if (end == value || *end != '"') return -1;
  *score = parsed;
  return 1;
}

// Scores a multi-frame sequence with VMAF and stores the per-frame scores,
// in log order, in vmaf[0], vmaf[1], .... Frames are supplied to
// compute_vmaf() through `read_frame` and `user_data`. compute_vmaf() writes
// its per-frame results to "vmaf_scores.xml", and the scores are parsed back
// from that file. The process exits via vmaf_fatal_error() if scoring fails,
// if the log cannot be opened, or if a logged score is malformed or outside
// [0, 100].
//
// NOTE(review): `vmaf` must have room for one entry per <frame> line in the
// log; the capacity is not checked here. Confirm against callers.
// NOTE(review): the log path is relative, so the file is created in the
// current working directory. It is not removed afterwards.
void aom_calc_vmaf_multi_frame(
    void *user_data, const char *model_path,
    int (*read_frame)(float *ref_data, float *main_data, float *temp_data,
                      int stride_byte, void *user_data),
    int frame_width, int frame_height, int bit_depth, double *vmaf) {
  static const char kLogPath[] = "vmaf_scores.xml";

  aom_clear_system_state();

  char *fmt = bit_depth == 10 ? "yuv420p10le" : "yuv420p";
  double vmaf_score;  // Pooled score; only the per-frame log is used here.
  const int ret = compute_vmaf(
      &vmaf_score, fmt, frame_width, frame_height, read_frame,
      /*user_data=*/user_data, (char *)model_path,
      /*log_path=*/(char *)kLogPath, /*log_fmt=*/NULL, /*disable_clip=*/0,
      /*disable_avx=*/0, /*enable_transform=*/0,
      /*phone_model=*/0, /*do_psnr=*/0, /*do_ssim=*/0,
      /*do_ms_ssim=*/0, /*pool_method=*/NULL, /*n_thread=*/0,
      /*n_subsample=*/1, /*enable_conf_interval=*/0);
  if (ret) vmaf_fatal_error("Failed to compute VMAF scores.");

  FILE *vmaf_log = fopen(kLogPath, "r");
  if (vmaf_log == NULL) vmaf_fatal_error("Failed to compute VMAF scores.");

  int frame_index = 0;
  // Frame lines list every feature score, so leave generous room. A vmaf
  // attribute cut off by the buffer is rejected by the parser instead of
  // crashing the way the old NULL strstr() result did.
  char buf[4096];
  while (fgets(buf, sizeof(buf), vmaf_log) != NULL) {
    double score = 0.0;
    const int status = vmaf_parse_frame_score(buf, &score);
    if (status == 0) continue;
    // The negated range test also rejects NaN, which strtod() can produce.
    if (status < 0 || !(score >= 0.0 && score <= 100.0)) {
      vmaf_fatal_error("Failed to compute VMAF scores.");
    }
    vmaf[frame_index++] = score;
  }
  fclose(vmaf_log);

  aom_clear_system_state();
}
160