#!/usr/bin/env python

'''
MOSSE tracking sample

This sample implements the correlation-based tracking approach described in [1].

Usage:
  mosse.py [--pause] [<video source>]

  --pause  -  Start with playback paused at the first video frame.
              Useful for tracking target selection.

  Draw rectangles around objects with a mouse to track them.

Keys:
  SPACE    - pause video
  c        - clear targets
  ESC      - exit

[1] David S. Bolme et al. "Visual Object Tracking using Adaptive Correlation Filters"
    http://www.cs.colostate.edu/~bolme/publications/Bolme2010Tracking.pdf
'''

import numpy as np
import cv2
from common import draw_str, RectSelector
import video

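# Generate a randomly warped copy of a patch: a small random rotation plus a
# random affine perturbation about the patch centre. These synthetic
# variations of the initial target are used to train the filter, as in [1].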
def rnd_warp(a):
    h, w = a.shape[:2]
    T = np.zeros((2, 3))
    coef = 0.2
    ang = (np.random.rand()-0.5)*coef
    c, s = np.cos(ang), np.sin(ang)
    T[:2, :2] = [[c,-s], [s, c]]
    T[:2, :2] += (np.random.rand(2, 2) - 0.5)*coef
    c = (w/2, h/2)
    T[:,2] = c - np.dot(T[:2, :2], c)
    return cv2.warpAffine(a, T, (w, h), borderMode = cv2.BORDER_REFLECT)

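# Element-wise complex division of two spectra stored in the 2-channel
# (real, imaginary) layout produced by cv2.dft with DFT_COMPLEX_OUTPUT.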
def divSpec(A, B):
    Ar, Ai = A[...,0], A[...,1]
    Br, Bi = B[...,0], B[...,1]
    C = (Ar+1j*Ai)/(Br+1j*Bi)
    C = np.dstack([np.real(C), np.imag(C)]).copy()
    return C

eps = 1e-5

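# MOSSE correlation filter tracker [1]. The filter is
#   H* = sum(G (.) F*) / sum(F (.) F*),
# where F are the (preprocessed, windowed) training patches, G is the desired
# Gaussian response and (.) is element-wise multiplication. H1 and H2 below
# accumulate the numerator and denominator of that ratio.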
class MOSSE:
    def __init__(self, frame, rect):
        x1, y1, x2, y2 = rect
        w, h = map(cv2.getOptimalDFTSize, [x2-x1, y2-y1])
        x1, y1 = (x1+x2-w)//2, (y1+y2-h)//2
        self.pos = x, y = x1+0.5*(w-1), y1+0.5*(h-1)
        self.size = w, h
        img = cv2.getRectSubPix(frame, (w, h), (x, y))

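        # Hanning window to suppress boundary effects; g is the desired
        # correlation response: a Gaussian peak centred on the target.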
        self.win = cv2.createHanningWindow((w, h), cv2.CV_32F)
        g = np.zeros((h, w), np.float32)
        g[h//2, w//2] = 1
        g = cv2.GaussianBlur(g, (-1, -1), 2.0)
        g /= g.max()

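        # Accumulate the filter numerator (H1) and denominator (H2) over
        # randomly warped versions of the initial patch.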
        self.G = cv2.dft(g, flags=cv2.DFT_COMPLEX_OUTPUT)
        self.H1 = np.zeros_like(self.G)
        self.H2 = np.zeros_like(self.G)
        for i in range(128):
            a = self.preprocess(rnd_warp(img))
            A = cv2.dft(a, flags=cv2.DFT_COMPLEX_OUTPUT)
            self.H1 += cv2.mulSpectrums(self.G, A, 0, conjB=True)
            self.H2 += cv2.mulSpectrums(     A, A, 0, conjB=True)
        self.update_kernel()
        self.update(frame)

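    # Track one frame: correlate the filter against the patch at the previous
    # position, declare the track lost if the PSR drops to 8.0 or below,
    # otherwise move to the new peak and blend the filter with learning
    # rate `rate`.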
    def update(self, frame, rate = 0.125):
        (x, y), (w, h) = self.pos, self.size
        self.last_img = img = cv2.getRectSubPix(frame, (w, h), (x, y))
        img = self.preprocess(img)
        self.last_resp, (dx, dy), self.psr = self.correlate(img)
        self.good = self.psr > 8.0
        if not self.good:
            return

        self.pos = x+dx, y+dy
        self.last_img = img = cv2.getRectSubPix(frame, (w, h), self.pos)
        img = self.preprocess(img)

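        # Online update: blend the new numerator/denominator terms into the
        # running filter (running average with exponential forgetting).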
        A = cv2.dft(img, flags=cv2.DFT_COMPLEX_OUTPUT)
        H1 = cv2.mulSpectrums(self.G, A, 0, conjB=True)
        H2 = cv2.mulSpectrums(     A, A, 0, conjB=True)
        self.H1 = self.H1 * (1.0-rate) + H1 * rate
        self.H2 = self.H2 * (1.0-rate) + H2 * rate
        self.update_kernel()

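    # Debug visualization: the current template patch, the filter kernel in
    # the spatial domain (rolled so its peak is centred), and the last
    # correlation response, stacked side by side.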
    @property
    def state_vis(self):
        f = cv2.idft(self.H, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT )
        h, w = f.shape
        f = np.roll(f, -h//2, 0)
        f = np.roll(f, -w//2, 1)
        kernel = np.uint8( (f-f.min()) / f.ptp()*255 )
        resp = self.last_resp
        resp = np.uint8(np.clip(resp/resp.max(), 0, 1)*255)
        vis = np.hstack([self.last_img, kernel, resp])
        return vis

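    # Draw the tracker state onto `vis`: the bounding box, a centre dot while
    # tracking is good, a cross when the target is lost, and the current PSR.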
    def draw_state(self, vis):
        (x, y), (w, h) = self.pos, self.size
        x1, y1, x2, y2 = int(x-0.5*w), int(y-0.5*h), int(x+0.5*w), int(y+0.5*h)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255))
        if self.good:
            cv2.circle(vis, (int(x), int(y)), 2, (0, 0, 255), -1)
        else:
            cv2.line(vis, (x1, y1), (x2, y2), (0, 0, 255))
            cv2.line(vis, (x2, y1), (x1, y2), (0, 0, 255))
        draw_str(vis, (x1, y2+16), 'PSR: %.2f' % self.psr)

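    # Preprocessing from [1]: log transform to reduce lighting effects,
    # normalization to zero mean and unit variance, and a cosine (Hanning)
    # window to de-emphasize the patch borders.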
    def preprocess(self, img):
        img = np.log(np.float32(img)+1.0)
        img = (img-img.mean()) / (img.std()+eps)
        return img*self.win

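    # Correlate the filter with a preprocessed patch in the frequency domain.
    # Returns the response map, the peak displacement relative to the patch
    # centre, and the peak-to-sidelobe ratio (PSR): the peak value compared
    # with the mean and standard deviation of the response after an 11x11
    # window around the peak has been masked to zero.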
    def correlate(self, img):
        C = cv2.mulSpectrums(cv2.dft(img, flags=cv2.DFT_COMPLEX_OUTPUT), self.H, 0, conjB=True)
        resp = cv2.idft(C, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT)
        h, w = resp.shape
        _, mval, _, (mx, my) = cv2.minMaxLoc(resp)
        side_resp = resp.copy()
        cv2.rectangle(side_resp, (mx-5, my-5), (mx+5, my+5), 0, -1)
        smean, sstd = side_resp.mean(), side_resp.std()
        psr = (mval-smean) / (sstd+eps)
        return resp, (mx-w//2, my-h//2), psr

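    # H1/H2 is the filter in conjugate form (H* in [1]); flipping the sign of
    # the imaginary part stores H, so that correlate(), which passes
    # conjB=True, effectively multiplies the patch spectrum by H*.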
    def update_kernel(self):
        self.H = divSpec(self.H1, self.H2)
        self.H[...,1] *= -1

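# A minimal sketch (not part of the original sample) of driving MOSSE without
# the GUI, assuming `frames` yields BGR frames and `rect` is an
# (x1, y1, x2, y2) selection on the first frame:
#
#   tracker = MOSSE(cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY), rect)
#   for frame in frames:
#       tracker.update(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
#       if tracker.good:
#           print(tracker.pos, tracker.psr)

# Interactive demo: draw rectangles with the mouse to start trackers; each
# selection creates a MOSSE instance on the grayscale frame.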
class App:
    def __init__(self, video_src, paused = False):
        self.cap = video.create_capture(video_src)
        _, self.frame = self.cap.read()
        cv2.imshow('frame', self.frame)
        self.rect_sel = RectSelector('frame', self.onrect)
        self.trackers = []
        self.paused = paused

    def onrect(self, rect):
        frame_gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
        tracker = MOSSE(frame_gray, rect)
        self.trackers.append(tracker)

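    # Main loop: read frames, update every tracker on the grayscale frame,
    # draw the results, and handle keys (ESC quits, SPACE toggles pause,
    # 'c' clears all trackers).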
    def run(self):
        while True:
            if not self.paused:
                ret, self.frame = self.cap.read()
                if not ret:
                    break
                frame_gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
                for tracker in self.trackers:
                    tracker.update(frame_gray)

            vis = self.frame.copy()
            for tracker in self.trackers:
                tracker.draw_state(vis)
            if len(self.trackers) > 0:
                cv2.imshow('tracker state', self.trackers[-1].state_vis)
            self.rect_sel.draw(vis)

            cv2.imshow('frame', vis)
            ch = cv2.waitKey(10) & 0xFF
            if ch == 27:
                break
            if ch == ord(' '):
                self.paused = not self.paused
            if ch == ord('c'):
                self.trackers = []

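# Command line: an optional --pause flag and an optional video source
# (a filename or camera index; defaults to camera '0').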
if __name__ == '__main__':
    print(__doc__)
    import sys, getopt
    opts, args = getopt.getopt(sys.argv[1:], '', ['pause'])
    opts = dict(opts)
    try:
        video_src = args[0]
    except IndexError:
        video_src = '0'

    App(video_src, paused = '--pause' in opts).run()