# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""
This file contains helper functions for reading video/image data and
pre/postprocessing of video/image data using OpenCV.
"""

import os

import cv2
import numpy as np

import pyarmnn as ann


def preprocess(frame: np.ndarray, input_binding_info: tuple):
    """
    Takes a frame, resizes, swaps channels and converts data type to match
    model input layer. The converted frame is wrapped in a const tensor
    and bound to the input tensor.

    Args:
        frame: Captured frame from video.
        input_binding_info: Contains shape and data type of model input layer.

    Returns:
        Input tensor.
    """
    # Swap channels (OpenCV captures BGR, models expect RGB) and resize
    # frame to model resolution.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized_frame = resize_with_aspect_ratio(frame, input_binding_info)

    # Expand dimensions (add batch axis) and convert data type to match
    # the model's input tensor.
    data_type = np.float32 if input_binding_info[1].GetDataType() == ann.DataType_Float32 else np.uint8
    resized_frame = np.expand_dims(np.asarray(resized_frame, dtype=data_type), axis=0)
    assert resized_frame.shape == tuple(input_binding_info[1].GetShape())

    input_tensors = ann.make_input_tensors([input_binding_info], [resized_frame])
    return input_tensors


def resize_with_aspect_ratio(frame: np.ndarray, input_binding_info: tuple):
    """
    Resizes frame while maintaining aspect ratio, padding any empty space.

    Args:
        frame: Captured frame.
        input_binding_info: Contains shape of model input layer.

    Returns:
        Frame resized to the size of model input layer.
    """
    aspect_ratio = frame.shape[1] / frame.shape[0]
    # Model input shape is assumed NHWC: indices 1 and 2 are height/width.
    model_height, model_width = list(input_binding_info[1].GetShape())[1:3]

    if aspect_ratio >= 1.0:
        # Landscape (or square): fit to model width, pad the bottom.
        new_height, new_width = int(model_width / aspect_ratio), model_width
        b_padding, r_padding = model_height - new_height, 0
    else:
        # Portrait: fit to model height, pad the right side.
        new_height, new_width = model_height, int(model_height * aspect_ratio)
        b_padding, r_padding = 0, model_width - new_width

    # Resize and pad any empty space with black pixels.
    frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    frame = cv2.copyMakeBorder(frame, top=0, bottom=b_padding, left=0, right=r_padding,
                               borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return frame


def create_video_writer(video: cv2.VideoCapture, video_path: str, output_path: str):
    """
    Creates a video writer object to write processed frames to file.

    Args:
        video: Video capture object, contains information about data source.
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video writer object.
    """
    _, ext = os.path.splitext(video_path)

    if output_path is not None:
        assert os.path.isdir(output_path)

    # Pick a filename that does not clobber an existing output file:
    # object_detection_demo.ext, object_detection_demo(1).ext, ...
    out_dir = output_path if output_path is not None else ''
    i, filename = 0, os.path.join(out_dir, f'object_detection_demo{ext}')
    while os.path.exists(filename):
        i += 1
        filename = os.path.join(out_dir, f'object_detection_demo({i}){ext}')

    # Mirror the source's codec, frame rate and frame size.
    video_writer = cv2.VideoWriter(filename=filename,
                                   fourcc=get_source_encoding_int(video),
                                   fps=int(video.get(cv2.CAP_PROP_FPS)),
                                   frameSize=(int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                              int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    return video_writer


def init_video_file_capture(video_path: str, output_path: str):
    """
    Creates a video capture object from a video file.

    Args:
        video_path: User-specified video file path.
        output_path: Optional path to save the processed video.

    Returns:
        Video capture object to capture frames, video writer object to write processed
        frames to file, plus total frame count of video source to iterate through.

    Raises:
        FileNotFoundError: If the video file does not exist.
        RuntimeError: If the video capture cannot be opened.
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f'Video file not found for: {video_path}')
    video = cv2.VideoCapture(video_path)
    # isOpened is a method — it must be called; the bare attribute is
    # always truthy and would silently skip this check.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture from file: {video_path}')

    video_writer = create_video_writer(video, video_path, output_path)
    iter_frame_count = range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))
    return video, video_writer, iter_frame_count


def init_video_stream_capture(video_source: int):
    """
    Creates a video capture object from a device.

    Args:
        video_source: Device index used to read video stream.

    Returns:
        Video capture object used to capture frames from a video stream.

    Raises:
        RuntimeError: If the video capture cannot be opened.
    """
    video = cv2.VideoCapture(video_source)
    # isOpened is a method — it must be called; the bare attribute is
    # always truthy and would silently skip this check.
    if not video.isOpened():
        raise RuntimeError(f'Failed to open video capture for device with index: {video_source}')
    print('Processing video stream. Press \'Esc\' key to exit the demo.')
    return video


def draw_bounding_boxes(frame: np.ndarray, detections: list, resize_factor, labels: dict):
    """
    Draws bounding boxes around detected objects and adds a label and confidence score.

    Args:
        frame: The original captured frame from video source.
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.
    """
    # Frame size is invariant across detections — obtain it once.
    frame_height, frame_width = frame.shape[:2]

    for detection in detections:
        class_idx, box, confidence = detection
        label, color = labels[class_idx][0].capitalize(), labels[class_idx][1]

        # Scale bounding box positions to the output frame size.
        x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]

        # Ensure box stays within the frame.
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)

        # Draw bounding box around detected object.
        cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)

        # Create label for detected object class.
        label = f'{label} {confidence * 100:.1f}%'
        # Use black text on light boxes, white text on dark boxes.
        label_color = (0, 0, 0) if sum(color) > 200 else (255, 255, 255)

        # Make sure label always stays on-screen: draw it inside the box
        # when the box touches the top of the frame, above it otherwise.
        x_text, y_text = cv2.getTextSize(label, cv2.FONT_HERSHEY_DUPLEX, 1, 1)[0][:2]

        lbl_box_xy_min = (x_min, y_min if y_min < 25 else y_min - y_text)
        lbl_box_xy_max = (x_min + int(0.55 * x_text), y_min + y_text if y_min < 25 else y_min)
        lbl_text_pos = (x_min + 5, y_min + 16 if y_min < 25 else y_min - 5)

        # Add label and confidence value.
        cv2.rectangle(frame, lbl_box_xy_min, lbl_box_xy_max, color, -1)
        cv2.putText(frame, label, lbl_text_pos, cv2.FONT_HERSHEY_DUPLEX, 0.50,
                    label_color, 1, cv2.LINE_AA)


def get_source_encoding_int(video_capture):
    """Returns the source video's FourCC codec identifier as an int."""
    return int(video_capture.get(cv2.CAP_PROP_FOURCC))