## Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
## Use of this source code is governed by a BSD-style license
## that can be found in the LICENSE file in the root of the source
## tree. An additional intellectual property rights grant can be found
## in the file PATENTS.  All contributing project authors may
## be found in the AUTHORS file in the root of the source tree.
##

#coding : utf - 8
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from Util import drawMF, MSE


class MotionEST(object):
    """The base class of block-based motion estimators.

    Subclasses override ``motion_field_estimation`` to fill ``self.mf``.
    This class provides the shared frame setup, the distortion metrics,
    evaluation against a ground truth, and rendering.

    Args:
        cur_f: current frame. Must provide ``convert('YCbCr')`` and
            ``size`` as ``(width, height)``. Presumably a PIL image;
            confirm against callers.
        ref_f: reference frame, same kind of object as ``cur_f``.
        blk_sz: block size in pixels (blocks are blk_sz x blk_sz).
    """

    def __init__(self, cur_f, ref_f, blk_sz):
        self.cur_f = cur_f
        self.ref_f = ref_f
        self.blk_sz = blk_sz
        # Convert RGB to YUV (YCbCr). The data is stored as a signed int
        # array so that pixel differences taken later cannot wrap around.
        self.cur_yuv = np.array(self.cur_f.convert('YCbCr'), dtype=int)
        self.ref_yuv = np.array(self.ref_f.convert('YCbCr'), dtype=int)
        # Frame size in pixels.
        self.width = self.cur_f.size[0]
        self.height = self.cur_f.size[1]
        # Motion field size in blocks. Floor division means partial blocks
        # at the right/bottom borders are not part of the field.
        self.num_row = self.height // self.blk_sz
        self.num_col = self.width // self.blk_sz
        # Motion field of shape (num_row, num_col, 2). block_dist reads
        # mf[r, c] as (dy, dx): mv[0] offsets rows and mv[1] offsets columns.
        self.mf = np.zeros((self.num_row, self.num_col, 2))

    def motion_field_estimation(self):
        """Estimate ``self.mf``. Does nothing here; child classes override it."""
        pass

    def block_dist(self, cur_r, cur_c, mv, metric=MSE):
        """Return the distortion of one block displaced by a motion vector.

        Args:
            cur_r: block row index in the motion field.
            cur_c: block column index in the motion field.
            mv: motion vector ``(dy, dx)`` in pixels. Each component is
                added to the block origin and truncated with ``int``.
            metric: callable ``metric(cur_blk, ref_blk)`` that compares two
                equally shaped ``(h, w, 3)`` blocks. Defaults to ``MSE``.

        Returns:
            Whatever ``metric`` returns for the pair of blocks. If the
            displaced block does not lie entirely inside the reference
            frame, the current block is compared against an all-zero block.
        """
        cur_x = cur_c * self.blk_sz
        cur_y = cur_r * self.blk_sz
        # Clip the block at the right/bottom frame borders.
        h = min(self.blk_sz, self.height - cur_y)
        w = min(self.blk_sz, self.width - cur_x)
        cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
        ref_x = int(cur_x + mv[1])
        ref_y = int(cur_y + mv[0])
        # The reference block covers rows [ref_y, ref_y + h) and columns
        # [ref_x, ref_x + w). It fits iff ref_x + w <= width and
        # ref_y + h <= height, so the upper bounds are inclusive. A block
        # flush with the right/bottom edge is valid.
        if 0 <= ref_x <= self.width - w and 0 <= ref_y <= self.height - h:
            ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
        else:
            ref_blk = np.zeros((h, w, 3))
        return metric(cur_blk, ref_blk)

    def distortion(self, mask=None, metric=MSE):
        """Return the mean block distortion of the current motion field.

        Args:
            mask: optional per-block mask indexable as ``mask[i, j]``
                (e.g. a 2-D numpy array). Blocks with a truthy entry are
                skipped.
            metric: distortion metric forwarded to ``block_dist``.

        Returns:
            The sum of ``block_dist`` over the unmasked blocks, divided by
            their count.

        Raises:
            ZeroDivisionError: if no block is evaluated (every block is
                masked, or the field is empty).
        """
        loss = 0
        count = 0
        for i in range(self.num_row):
            for j in range(self.num_col):
                if mask is not None and mask[i, j]:
                    continue
                loss += self.block_dist(i, j, self.mf[i, j], metric)
                count += 1
        return loss / count

    def motion_field_evaluation(self, ground_truth):
        """Compare the estimated motion field with a ground truth.

        Args:
            ground_truth: object exposing ``mf`` (a motion field indexable
                as ``mf[i, j]``, same layout as ``self.mf``) and ``mask``
                (``None``, or indexable as ``mask[i][j]``; blocks with a
                truthy entry are skipped).

        Returns:
            The mean of ``LA.norm(gt[i, j] - self.mf[i, j])`` over the
            unmasked blocks. For 2-component vectors this is the Euclidean
            end-point error.

        Raises:
            ZeroDivisionError: if no block is evaluated.
        """
        loss = 0
        count = 0
        gt = ground_truth.mf
        mask = ground_truth.mask
        for i in range(self.num_row):
            for j in range(self.num_col):
                if mask is not None and mask[i][j]:
                    continue
                loss += LA.norm(gt[i, j] - self.mf[i, j])
                count += 1
        return loss / count

    def show(self, ground_truth=None, size=10):
        """Render the estimated motion field, optionally beside a ground truth.

        Args:
            ground_truth: optional motion field, passed directly to
                ``drawMF`` and shown in a second panel.
                NOTE(review): unlike ``motion_field_evaluation``, this is
                used as the field itself, not as an object with ``.mf``.
                Confirm with callers.
            size: width of each panel in inches. The figure height is
                scaled by the frame's height/width ratio.
        """
        cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
        # Panels are laid out side by side in a single row.
        if ground_truth is None:
            n_panels = 1
        else:
            gt_mf = drawMF(self.cur_f, self.blk_sz, ground_truth)
            n_panels = 2
        plt.figure(figsize=(n_panels * size, size * self.height / self.width))
        plt.subplot(1, n_panels, 1)
        plt.imshow(cur_mf)
        plt.title('Estimated Motion Field')
        if ground_truth is not None:
            plt.subplot(1, n_panels, 2)
            plt.imshow(gt_mf)
            plt.title('Ground Truth')
        plt.tight_layout()
        plt.show()