• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2
3'''
4Multitarget planar tracking
5==================
6
7Example of using features2d framework for interactive video homography matching.
8ORB features and FLANN matcher are used. This sample provides PlaneTracker class
9and an example of its usage.
10
11video: http://www.youtube.com/watch?v=pzVbhxx6aog
12
13Usage
14-----
15plane_tracker.py [<video source>]
16
Keys:
   SPACE  -  pause video
   c      -  clear targets
   ESC    -  exit
20
21Select a textured planar object to track by drawing a box with a mouse.
22'''
23
24import numpy as np
25import cv2
26
27# built-in modules
28from collections import namedtuple
29
30# local modules
31import video
32import common
33
34
# FLANN index algorithm identifiers.
FLANN_INDEX_KDTREE = 1
FLANN_INDEX_LSH = 6

# FLANN parameters for the matcher: an LSH (locality-sensitive hashing) index.
# The "alternative" values are the ones noted in the original sample's comments.
flann_params = {
    'algorithm': FLANN_INDEX_LSH,
    'table_number': 6,       # alternative: 12
    'key_size': 12,          # alternative: 20
    'multi_probe_level': 1,  # alternative: 2
}

# Minimum number of feature matches / RANSAC inliers needed to accept a target.
MIN_MATCH_COUNT = 10
43
# A registered tracking target.
#   image     - image to track
#   rect      - tracked rectangle (x0, y0, x1, y1)
#   keypoints - keypoints detected inside rect
#   descrs    - their descriptors
#   data      - some user-provided data
# The typename must equal the bound name so repr() is accurate and instances
# can be pickled (pickle looks the class up by its typename in this module).
PlanarTarget = namedtuple('PlanarTarget', 'image, rect, keypoints, descrs, data')

# A target located in an input frame.
#   target - reference to PlanarTarget
#   p0     - matched points coords in target image
#   p1     - matched points coords in input frame
#   H      - homography matrix from p0 to p1
#   quad   - target boundary quad in input frame
TrackedTarget = namedtuple('TrackedTarget', 'target, p0, p1, H, quad')
61
class PlaneTracker:
    '''Track several planar targets in video frames using ORB features and FLANN.

    Targets are registered with add_target(); track() matches a frame's features
    against all registered targets and, for each sufficiently matched target,
    estimates a homography with RANSAC.
    '''

    def __init__(self):
        self.detector = cv2.ORB_create(nfeatures=1000)
        # An empty search-params dict must be passed explicitly (bug #1329).
        self.matcher = cv2.FlannBasedMatcher(flann_params, {})
        # self.targets[i] corresponds to train image i of the matcher (the
        # imgIdx reported by knnMatch), so every matcher.add() must be paired
        # with exactly one targets.append().
        self.targets = []

    def add_target(self, image, rect, data=None):
        '''Add a new tracking target.

        image - image containing the target
        rect  - (x0, y0, x1, y1); keypoints with x0 <= x <= x1 and
                y0 <= y <= y1 become the target's features
        data  - arbitrary user data stored on the PlanarTarget

        A rect containing no detected keypoints is ignored: it could never be
        matched, and registering it would add an empty descriptor set to the
        matcher.
        '''
        x0, y0, x1, y1 = rect
        raw_points, raw_descrs = self.detect_features(image)
        points, descs = [], []
        for kp, desc in zip(raw_points, raw_descrs):
            x, y = kp.pt
            if x0 <= x <= x1 and y0 <= y <= y1:
                points.append(kp)
                descs.append(desc)
        if not points:
            return
        descs = np.uint8(descs)
        self.matcher.add([descs])
        target = PlanarTarget(image=image, rect=rect, keypoints=points, descrs=descs, data=data)
        self.targets.append(target)

    def clear(self):
        '''Remove all targets and clear the matcher's training descriptors.'''
        self.targets = []
        self.matcher.clear()

    def track(self, frame):
        '''Returns a list of detected TrackedTarget objects.

        A target is reported only when at least MIN_MATCH_COUNT matches pass
        the ratio test and at least MIN_MATCH_COUNT of those are RANSAC
        inliers. The result is sorted by inlier count, most first.
        '''
        if not self.targets:
            return []
        frame_points, frame_descrs = self.detect_features(frame)
        if len(frame_points) < MIN_MATCH_COUNT:
            return []
        knn = self.matcher.knnMatch(frame_descrs, k=2)
        # Ratio test: keep the best match only if it is clearly better than
        # the second-best one.
        good = [m[0] for m in knn if len(m) == 2 and m[0].distance < m[1].distance * 0.75]
        if len(good) < MIN_MATCH_COUNT:
            return []

        # Group matches by the target (train image index) they landed on.
        matches_by_id = [[] for _ in range(len(self.targets))]
        for m in good:
            matches_by_id[m.imgIdx].append(m)

        tracked = []
        for img_idx, target_matches in enumerate(matches_by_id):
            if len(target_matches) < MIN_MATCH_COUNT:
                continue
            target = self.targets[img_idx]
            p0 = [target.keypoints[m.trainIdx].pt for m in target_matches]
            p1 = [frame_points[m.queryIdx].pt for m in target_matches]
            p0, p1 = np.float32((p0, p1))
            H, status = cv2.findHomography(p0, p1, cv2.RANSAC, 3.0)
            # findHomography returns (None, None) when it cannot estimate a model.
            if H is None:
                continue
            status = status.ravel() != 0
            if status.sum() < MIN_MATCH_COUNT:
                continue
            p0, p1 = p0[status], p1[status]

            # Project the target's rectangle corners into the frame.
            x0, y0, x1, y1 = target.rect
            quad = np.float32([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
            quad = cv2.perspectiveTransform(quad.reshape(1, -1, 2), H).reshape(-1, 2)

            tracked.append(TrackedTarget(target=target, p0=p0, p1=p1, H=H, quad=quad))
        tracked.sort(key=lambda t: len(t.p0), reverse=True)
        return tracked

    def detect_features(self, frame):
        '''detect_features(self, frame) -> keypoints, descrs

        Runs the ORB detector on the whole frame. descrs is [] instead of None
        when no keypoints are found.
        '''
        keypoints, descrs = self.detector.detectAndCompute(frame, None)
        if descrs is None:  # detectAndCompute returns descs=None if no keypoints found
            descrs = []
        return keypoints, descrs
129
130
class App:
    '''Interactive demo: draw boxes around planar objects and track them live.'''

    def __init__(self, src):
        self.paused = False
        self.frame = None
        self.cap = video.create_capture(src)
        self.tracker = PlaneTracker()

        cv2.namedWindow('plane')
        # Each rectangle reported by the selector is handed to on_rect.
        self.rect_sel = common.RectSelector('plane', self.on_rect)

    def on_rect(self, rect):
        '''Register the selected rectangle of the current frame as a new target.'''
        self.tracker.add_target(self.frame, rect)

    def run(self):
        '''Show frames with tracked targets until the capture stops or ESC is pressed.

        While paused (or while a selection is being dragged) the last frame is
        shown again and tracking is skipped.
        '''
        white = (255, 255, 255)
        while True:
            live = not (self.paused or self.rect_sel.dragging)
            if live or self.frame is None:
                ok, grabbed = self.cap.read()
                if not ok:
                    break
                self.frame = grabbed.copy()

            canvas = self.frame.copy()
            if live:
                for hit in self.tracker.track(self.frame):
                    cv2.polylines(canvas, [np.int32(hit.quad)], True, white, 2)
                    for px, py in np.int32(hit.p1):
                        cv2.circle(canvas, (px, py), 2, white)

            self.rect_sel.draw(canvas)
            cv2.imshow('plane', canvas)
            key = cv2.waitKey(1) & 0xFF
            if key == 27:  # ESC
                break
            elif key == ord(' '):
                self.paused = not self.paused
            elif key == ord('c'):
                self.tracker.clear()
170
if __name__ == '__main__':
    import sys

    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(__doc__)

    # Optional first argument selects the video source; otherwise use 0
    # (presumably the default camera -- see video.create_capture).
    # An explicit length check replaces the former bare `except:`, which also
    # swallowed unrelated errors such as KeyboardInterrupt.
    video_src = sys.argv[1] if len(sys.argv) > 1 else 0
    App(video_src).run()
180