• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
##  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
##
##  Use of this source code is governed by a BSD-style license
##  that can be found in the LICENSE file in the root of the source
##  tree. An additional intellectual property rights grant can be found
##  in the file PATENTS.  All contributing project authors may
##  be found in the AUTHORS file in the root of the source tree.
##

# coding: utf-8
11import numpy as np
12import numpy.linalg as LA
13import matplotlib.pyplot as plt
14from Util import drawMF, MSE
"""The Base Class of Estimators"""


class MotionEST(object):
  """The base class of motion estimators.

  Converts both frames to YCbCr integer arrays once and allocates a zero
  motion field holding one vector per full block. Child classes fill
  ``self.mf`` by overriding ``motion_field_estimation``.

  Args:
    cur_f: current frame. Must provide ``convert('YCbCr')`` and a
      ``size`` of (width, height) -- presumably a PIL image; confirm
      against callers.
    ref_f: reference frame, the same kind of object as ``cur_f``.
    blk_sz: block size in pixels.
  """

  def __init__(self, cur_f, ref_f, blk_sz):
    self.cur_f = cur_f
    self.ref_f = ref_f
    self.blk_sz = blk_sz
    # Convert RGB to YUV once; the arrays are indexed [row, col, channel].
    self.cur_yuv = np.array(self.cur_f.convert('YCbCr'), dtype=int)
    self.ref_yuv = np.array(self.ref_f.convert('YCbCr'), dtype=int)
    # Frame size in pixels.
    self.width = self.cur_f.size[0]
    self.height = self.cur_f.size[1]
    # Motion field size. Floor division: only full blocks get a vector, so a
    # partial strip at the right/bottom edge is not part of the field.
    self.num_row = self.height // self.blk_sz
    self.num_col = self.width // self.blk_sz
    # One (dy, dx) vector per block, initialised to zero motion.
    self.mf = np.zeros((self.num_row, self.num_col, 2))

  def motion_field_estimation(self):
    """Estimate ``self.mf``; overridden by child classes.

    The base implementation does nothing and leaves the zero motion field.
    """

  def block_dist(self, cur_r, cur_c, mv, metric=MSE):
    """Return the distortion of one block displaced by a motion vector.

    Args:
      cur_r: block row index in the motion field.
      cur_c: block column index in the motion field.
      mv: motion vector as (dy, dx) in pixels. The displaced position is
        truncated to an integer pixel with ``int()``.
      metric: callable ``metric(cur_blk, ref_blk)`` returning the
        distortion; defaults to ``MSE``.

    Returns:
      ``metric(cur_blk, ref_blk)``. If the displaced block does not lie
      entirely inside the frame, it is compared against an all-zero block.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    # Clip the block to the frame so edge blocks never index past it.
    h = min(self.blk_sz, self.height - cur_y)
    w = min(self.blk_sz, self.width - cur_x)
    cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
    ref_x = int(cur_x + mv[1])
    ref_y = int(cur_y + mv[0])
    # The reference block spans [ref_x, ref_x + w), so it fits whenever
    # ref_x + w <= width. Using ``<=`` (not ``<``) keeps blocks that touch
    # the right/bottom edge, e.g. the zero vector on the last block column.
    if 0 <= ref_x <= self.width - w and 0 <= ref_y <= self.height - h:
      ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
    else:
      ref_blk = np.zeros((h, w, 3))
    return metric(cur_blk, ref_blk)

  def distortion(self, mask=None, metric=MSE):
    """Return the mean block distortion of the current motion field.

    Args:
      mask: optional per-block flags indexed as ``mask[i, j]``; blocks whose
        flag is truthy are skipped.
      metric: distortion metric passed through to ``block_dist``.

    Returns:
      The sum of ``block_dist`` over unmasked blocks divided by their count.

    Raises:
      ZeroDivisionError: if every block is masked (or the field is empty).
    """
    loss = 0
    count = 0
    for i in range(self.num_row):
      for j in range(self.num_col):
        if mask is not None and mask[i, j]:
          continue
        loss += self.block_dist(i, j, self.mf[i, j], metric)
        count += 1
    return loss / count

  def motion_field_evaluation(self, ground_truth):
    """Compare the estimated field with a ground-truth field.

    Args:
      ground_truth: object exposing ``mf`` (ground-truth vectors indexed
        ``[i, j]`` per block) and ``mask`` (``None`` or per-block flags;
        truthy entries are skipped).

    Returns:
      The mean Euclidean norm of ``gt[i, j] - self.mf[i, j]`` over the
      unmasked blocks (average end-point error).

    Raises:
      ZeroDivisionError: if every block is masked (or the field is empty).
    """
    loss = 0
    count = 0
    gt = ground_truth.mf
    mask = ground_truth.mask
    for i in range(self.num_row):
      for j in range(self.num_col):
        if mask is not None and mask[i][j]:
          continue
        loss += LA.norm(gt[i, j] - self.mf[i, j])
        count += 1
    return loss / count

  def show(self, ground_truth=None, size=10):
    """Render the estimated motion field, optionally beside the ground truth.

    Args:
      ground_truth: optional ground-truth motion field handed to ``drawMF``.
        An object exposing ``mf`` (the form ``motion_field_evaluation``
        takes) is also accepted; its ``mf`` is drawn.
      size: width of each panel in inches; the height follows the frame's
        aspect ratio.
    """
    cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
    if ground_truth is None:
      n_panels = 1
    else:
      # Accept a raw field or a ground-truth object; ndarrays have no ``mf``
      # attribute, so passing an array behaves exactly as before.
      gt_field = getattr(ground_truth, 'mf', ground_truth)
      gt_mf = drawMF(self.cur_f, self.blk_sz, gt_field)
      n_panels = 2
    # float() avoids integer truncation of the aspect ratio on Python 2.
    fig_h = size * self.height / float(self.width)
    plt.figure(figsize=(n_panels * size, fig_h))
    plt.subplot(1, n_panels, 1)
    plt.imshow(cur_mf)
    plt.title('Estimated Motion Field')
    if ground_truth is not None:
      plt.subplot(1, n_panels, 2)
      plt.imshow(gt_mf)
      plt.title('Ground Truth')
    plt.tight_layout()
    plt.show()
118