• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4"""
5This file contains helper functions for reading video/image data and
6 pre/postprocessing of video/image data using OpenCV.
7"""
8
9import os
10
11import cv2
12import numpy as np
13
14import pyarmnn as ann
15
16
def preprocess(frame: np.ndarray, input_binding_info: tuple):
    """
    Takes a frame, resizes, swaps channels and converts data type to match
    model input layer. The converted frame is wrapped in a const tensor
    and bound to the input tensor.

    Args:
        frame: Captured frame from video.
        input_binding_info: Contains shape and data type of model input layer.

    Returns:
        Input tensor.

    Raises:
        ValueError: If the preprocessed frame does not match the shape
            expected by the model input layer.
    """
    # Model expects RGB channel order; OpenCV captures frames as BGR.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized_frame = resize_with_aspect_ratio(frame, input_binding_info)

    # Expand dimensions and convert data type to match model input
    data_type = np.float32 if input_binding_info[1].GetDataType() == ann.DataType_Float32 else np.uint8
    resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)

    # Explicit check instead of `assert`, which is silently stripped under
    # `python -O` and would let a mismatched frame reach the runtime.
    expected_shape = tuple(input_binding_info[1].GetShape())
    if resized_frame.shape != expected_shape:
        raise ValueError(
            f'Preprocessed frame shape {resized_frame.shape} does not match '
            f'model input shape {expected_shape}')

    input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
    return input_tensors
41
42
def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    Args:
        frame: Captured frame.
        input_binding_info: Contains shape of model input layer.

    Returns:
        Frame resized to the size of model input layer.
    """
    model_height, model_width = list(input_binding_info[1].GetShape())[1:3]
    frame_height, frame_width = frame.shape[0], frame.shape[1]
    aspect_ratio = frame_width / frame_height

    # Portrait frames fill the model height and pad on the right;
    # landscape (or square) frames fill the model width and pad below.
    if aspect_ratio < 1.0:
        new_height = model_height
        new_width = int(model_height * aspect_ratio)
        pad_bottom, pad_right = 0, model_width - new_width
    else:
        new_height = int(model_width / aspect_ratio)
        new_width = model_width
        pad_bottom, pad_right = model_height - new_height, 0

    scaled = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    return cv2.copyMakeBorder(scaled, top=0, bottom=pad_bottom, left=0, right=pad_right,
                              borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
69
70
def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video writer object.

    Raises:
        NotADirectoryError: If output_path is supplied but is not an
            existing directory.
    """
    _, ext = os.path.splitext(video_path)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if output_path is not None and not os.path.isdir(output_path):
        raise NotADirectoryError(f'Invalid output directory: {output_path}')
    output_dir = output_path if output_path is not None else ''

    # Pick a filename that does not clobber the output of a previous run.
    i, filename = 0, os.path.join(output_dir, f'object_detection_demo{ext}')
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(output_dir, f'object_detection_demo({i}){ext}')

    video_writer = cv2.VideoWriter(filename=filename,
                                   fourcc=get_source_encoding_int(video),
                                   fps=int(video.get(cv2.CAP_PROP_FPS)),
                                   frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                              int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    return video_writer
99
100
def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If the video file does not exist.
        RuntimeError: If the capture could not be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # Bug fix: `isOpened` must be called — the bare bound method is always
    # truthy, so `not video.isOpened` could never detect a failed open.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count
122
123
def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the capture could not be opened for the device.
    """
    video = cv2.VideoCapture(video_source)
    # Bug fix: `isOpened` must be called — the bare bound method is always
    # truthy, so `not video.isOpened` could never detect a failed open.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video
139
140
def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws bounding boxes around detected objects and adds a label and confidence score.

    Args:
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.
    """
    # Frame size is loop-invariant; obtain it once instead of per detection.
    frame_height, frame_width = frame.shape[:2]

    # Unpack each detection directly rather than copying it into a list first.
    for class_idx, box, confidence in detections:
        label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]

        # Scale bounding box positions to the output frame size
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]

        # Ensure box stays within the frame
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Draw bounding box around detected object
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)

        # Create label for detected object class
        label = f'{label} {confidence * 100:.1f}%'
        # Black text on light box colors, white text on dark ones.
        label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)

        # Make sure label always stays on-screen
        x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]

        lbl_box_xy_min = (x_min, y_min if y_min < 25 else y_min - y_text)
        lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min < 25 else y_min)
        lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5)

        # Add label and confidence value
        cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
        cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    label_color, 1, cv2.LINE_AA)
181
182
def get_source_encoding_int(video_capture):
    """Returns the FOURCC codec code of the capture source as an integer."""
    fourcc_code = video_capture.get(cv2.CAP_PROP_FOURCC)
    return int(fourcc_code)
185