# Copyright 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for sensor_fusion hardware rig."""


import bisect
import codecs
import logging
import math
import os
import struct
import time
import unittest

import cv2
from matplotlib import pylab
import matplotlib.pyplot
import numpy as np
import scipy.spatial
import serial
from serial.tools import list_ports

import camera_properties_utils
import image_processing_utils

# Constants for Rotation Rig
ARDUINO_ANGLE_MAX = 180.0  # degrees
ARDUINO_BAUDRATE = 9600
ARDUINO_CMD_LENGTH = 3
ARDUINO_CMD_TIME = 2.0 * ARDUINO_CMD_LENGTH / ARDUINO_BAUDRATE  # round trip
ARDUINO_PID = 0x0043
ARDUINO_SERVO_SPEED_MAX = 255
ARDUINO_SERVO_SPEED_MIN = 1
ARDUINO_SPEED_START_BYTE = 253
ARDUINO_START_BYTE = 255
ARDUINO_START_NUM_TRYS = 3
ARDUINO_TEST_CMD = (b'\x01', b'\x02', b'\x03')
ARDUINO_VALID_CH = ('1', '2', '3', '4', '5', '6')
ARDUINO_VIDS = (0x2341, 0x2a03)

CANAKIT_BAUDRATE = 115200
CANAKIT_CMD_TIME = 0.05  # seconds (found experimentally)
CANAKIT_DATA_DELIMITER = '\r\n'
CANAKIT_PID = 0xfc73
CANAKIT_SEND_TIMEOUT = 0.02  # seconds
CANAKIT_SET_CMD = 'REL'
CANAKIT_SLEEP_TIME = 2  # seconds (for full 90 degree rotation)
CANAKIT_VALID_CMD = ('ON', 'OFF')
CANAKIT_VALID_CH = ('1', '2', '3', '4')
CANAKIT_VID = 0x04d8

HS755HB_ANGLE_MAX = 202.0  # throw for rotation motor in degrees

# From test_sensor_fusion
_FEATURE_MARGIN = 0.20  # Only take feature points from the center 20% so the
                        # measured rotation has less rolling shutter effect.
_FEATURE_PTS_MIN = 30  # Min number of feature pts to perform rotation analysis.
# Parameters for cv2.goodFeaturesToTrack.
# 'POSTMASK' is the measurement method used in all previous Android versions:
# it finds the best features on the entire frame and then masks the features
# to the vertical center FEATURE_MARGIN for the measurement.
# 'PREMASK' is a newer method used when FEATURE_PTS_MIN features are not
# found in the frame: it finds the best 2*FEATURE_PTS_MIN features within the
# FEATURE_MARGIN part of the frame.
_CV2_FEATURE_PARAMS_POSTMASK = dict(maxCorners=240,
                                    qualityLevel=0.3,
                                    minDistance=7,
                                    blockSize=7)
_CV2_FEATURE_PARAMS_PREMASK = dict(maxCorners=2*_FEATURE_PTS_MIN,
                                   qualityLevel=0.3,
                                   minDistance=7,
                                   blockSize=7)
_GYRO_SAMP_RATE_MIN = 100.0  # Samples/second: min gyro sample rate.
_CV2_LK_PARAMS = dict(winSize=(15, 15),
                      maxLevel=2,
                      criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                                10, 0.03))  # cv2.calcOpticalFlowPyrLK params.
_ROTATION_PER_FRAME_MIN = 0.001  # rads/s
_GYRO_ROTATION_PER_SEC_MAX = 2.0  # rads/s

# unittest constants
_COARSE_FIT_RANGE = 20  # Range area around coarse fit to do optimization.
_CORR_TIME_OFFSET_MAX = 50  # ms max shift to try and match camera/gyro times.
_CORR_TIME_OFFSET_STEP = 0.5  # ms step for shifts.

# Unit translators
_MSEC_TO_NSEC = 1000000
_NSEC_TO_SEC = 1E-9
_SEC_TO_NSEC = int(1/_NSEC_TO_SEC)
_RADS_TO_DEGS = 180/math.pi

_NUM_GYRO_PTS_TO_AVG = 20

def serial_port_def(name):
  """Determine the serial port for the given device and open it.

  Args:
    name: string of device to locate (i.e. 'Arduino', 'Canakit' or 'Default')
  Returns:
    serial port object
  """
  serial_port = None
  devices = list_ports.comports()
  for device in devices:
    if not (device.vid and device.pid):  # Not all comm ports have vid and pid
      continue
    if name.lower() == 'arduino':
      if (device.vid in ARDUINO_VIDS and device.pid == ARDUINO_PID):
        logging.debug('Arduino: %s', str(device))
        serial_port = device.device
        return serial.Serial(serial_port, ARDUINO_BAUDRATE, timeout=1)

    elif name.lower() in ('canakit', 'default'):
      if (device.vid == CANAKIT_VID and device.pid == CANAKIT_PID):
        logging.debug('Canakit: %s', str(device))
        serial_port = device.device
        return serial.Serial(serial_port, CANAKIT_BAUDRATE,
                             timeout=CANAKIT_SEND_TIMEOUT,
                             parity=serial.PARITY_EVEN,
                             stopbits=serial.STOPBITS_ONE,
                             bytesize=serial.EIGHTBITS)
  raise ValueError(f'{name} device not connected.')

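# Illustrative usage sketch (assumes a rotation rig is plugged in and
# enumerates with a matching USB VID/PID; otherwise ValueError is raised):
#   arduino_port = serial_port_def('Arduino')
#   canakit_port = serial_port_def('Canakit')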

def canakit_cmd_send(canakit_serial_port, cmd_str):
  """Wrapper for sending a serial command to the Canakit relay.

  Args:
    canakit_serial_port: serial port object for the Canakit relay.
    cmd_str: str; value to send to device.
  """
  try:
    logging.debug('writing port...')
    canakit_serial_port.write(CANAKIT_DATA_DELIMITER.encode())
    time.sleep(CANAKIT_CMD_TIME)  # This is critical for relay.
    canakit_serial_port.write(cmd_str.encode())

  except IOError as io_error:
    raise IOError(
        f'Port {CANAKIT_VID}:{CANAKIT_PID} is not open!') from io_error

def canakit_set_relay_channel_state(canakit_port, ch, state):
  """Set Canakit relay channel and state.

  Waits CANAKIT_SLEEP_TIME for rotation to occur.

  Args:
    canakit_port: serial port object for the Canakit port.
    ch: string for channel number of relay to set. '1', '2', '3', or '4'
    state: string of either 'ON' or 'OFF'
  """
  logging.debug('Setting relay state %s', state)
  if ch in CANAKIT_VALID_CH and state in CANAKIT_VALID_CMD:
    canakit_cmd_send(canakit_port, CANAKIT_SET_CMD + ch + '.' + state + '\r\n')
    time.sleep(CANAKIT_SLEEP_TIME)
  else:
    logging.debug('Invalid ch (%s) or state (%s), no command sent.', ch, state)

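# Illustrative usage sketch (assumes a Canakit relay is attached and the
# rotation arm is wired to relay channel '1'; the channel value is
# hypothetical and depends on the physical setup):
#   canakit_port = serial_port_def('Canakit')
#   canakit_set_relay_channel_state(canakit_port, '1', 'ON')
#   canakit_set_relay_channel_state(canakit_port, '1', 'OFF')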

def arduino_read_cmd(port):
  """Read back Arduino command from serial port."""
  cmd = []
  for _ in range(ARDUINO_CMD_LENGTH):
    cmd.append(port.read())
  return cmd


def arduino_send_cmd(port, cmd):
  """Send command to serial port."""
  for i in range(ARDUINO_CMD_LENGTH):
    port.write(cmd[i])


def arduino_loopback_cmd(port, cmd):
  """Send command to serial port and read back the echoed response."""
  arduino_send_cmd(port, cmd)
  time.sleep(ARDUINO_CMD_TIME)
  return arduino_read_cmd(port)

def establish_serial_comm(port):
  """Establish connection with serial port."""
  logging.debug('Establishing communication with %s', port.name)
  trys = 1
  hex_test = convert_to_hex(ARDUINO_TEST_CMD)
  logging.debug(' test tx: %s %s %s', hex_test[0], hex_test[1], hex_test[2])
  while trys <= ARDUINO_START_NUM_TRYS:
    cmd_read = arduino_loopback_cmd(port, ARDUINO_TEST_CMD)
    hex_read = convert_to_hex(cmd_read)
    logging.debug(' test rx: %s %s %s', hex_read[0], hex_read[1], hex_read[2])
    if cmd_read != list(ARDUINO_TEST_CMD):
      trys += 1
    else:
      logging.debug(' Arduino comm established after %d try(s)', trys)
      break

def convert_to_hex(cmd):
  """Convert a list of command bytes to hex strings for logging."""
  return [('%0.2x' % int(codecs.encode(x, 'hex_codec'), 16) if x else '--')
          for x in cmd]

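# Illustrative example of the hex conversion (values are hypothetical loopback
# bytes; empty reads are rendered as '--'):
#   convert_to_hex([b'\x01', b'\x02', b'\x03'])  -> ['01', '02', '03']
#   convert_to_hex([b'\x01', b'', b'\x03'])      -> ['01', '--', '03']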

def arduino_rotate_servo_to_angle(ch, angle, serial_port, move_time):
  """Rotate servo to the specified angle.

  Args:
    ch: str; servo to rotate in ARDUINO_VALID_CH
    angle: int; servo angle to move to
    serial_port: object; serial port
    move_time: int; time in seconds
  """
  if angle < 0 or angle > ARDUINO_ANGLE_MAX:
    logging.debug('Angle must be between 0 and %d.', ARDUINO_ANGLE_MAX)
    angle = min(max(angle, 0), int(ARDUINO_ANGLE_MAX))  # clamp to valid range

  cmd = [struct.pack('B', i) for i in [ARDUINO_START_BYTE, int(ch), angle]]
  arduino_send_cmd(serial_port, cmd)
  time.sleep(move_time)

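# A minimal sketch of the 3-byte servo command this module sends over the
# serial link, following the packing above (channel '1' and 90 degrees are
# hypothetical example values):
#   cmd = [struct.pack('B', b) for b in [ARDUINO_START_BYTE, int('1'), 90]]
#   # cmd == [b'\xff', b'\x01', b'Z']  (start byte 0xff, channel, angle)
#   # arduino_send_cmd(port, cmd) then writes the three bytes in order.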

def arduino_rotate_servo(ch, angles, move_time, serial_port):
  """Rotate servo through 'angles'.

  Args:
    ch: str; servo to rotate
    angles: list of ints; servo angles to move to
    move_time: int; time required to allow for arduino movement
    serial_port: object; serial port
  """

  for angle in angles:
    angle_norm = int(round(angle*ARDUINO_ANGLE_MAX/HS755HB_ANGLE_MAX, 0))
    arduino_rotate_servo_to_angle(ch, angle_norm, serial_port, move_time)

def rotation_rig(rotate_cntl, rotate_ch, num_rotations, angles, servo_speed,
                 move_time):
  """Rotate the phone num_rotations times using the given controller/channel.

  rotate_ch is hard wired and must be determined from the physical setup.

  For the Arduino controller, first initialize the port and send the test
  string defined by ARDUINO_TEST_CMD to establish communications, then rotate
  the servo motor to the origin position.

  Args:
    rotate_cntl: str to identify 'arduino' or 'canakit' controller.
    rotate_ch: str to identify rotation channel number.
    num_rotations: int number of rotations.
    angles: list of ints; servo angles to move to.
    servo_speed: int; servo move speed between 1 and 255.
    move_time: int; time required to allow for arduino movement.
  """

  logging.debug('Controller: %s, ch: %s', rotate_cntl, rotate_ch)
  rotate_cntl = rotate_cntl.lower()  # normalize once for comparisons below
  if rotate_cntl == 'arduino':
    # identify port
    arduino_serial_port = serial_port_def('Arduino')

    # send test cmd to Arduino until cmd returns properly
    establish_serial_comm(arduino_serial_port)

    # initialize servo at origin
    logging.debug('Moving servo to origin')
    arduino_rotate_servo_to_angle(rotate_ch, 0, arduino_serial_port, 1)

    # set servo speed
    set_servo_speed(rotate_ch, servo_speed, arduino_serial_port, delay=0)

  elif rotate_cntl == 'canakit':
    canakit_serial_port = serial_port_def('Canakit')

  else:
    logging.info('No rotation rig defined. Manual test: rotate phone by hand.')

  # rotate phone
  logging.debug('Rotating phone %dx', num_rotations)
  for _ in range(num_rotations):
    if rotate_cntl == 'arduino':
      arduino_rotate_servo(rotate_ch, angles, move_time, arduino_serial_port)
    elif rotate_cntl == 'canakit':
      canakit_set_relay_channel_state(canakit_serial_port, rotate_ch, 'ON')
      canakit_set_relay_channel_state(canakit_serial_port, rotate_ch, 'OFF')
  logging.debug('Finished rotations')
  if rotate_cntl == 'arduino':
    logging.debug('Moving servo to origin')
    arduino_rotate_servo_to_angle(rotate_ch, 0, arduino_serial_port, 1)

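# Illustrative usage sketch (assumes an Arduino-based rig on servo channel '1';
# channel, angles, speed, and move_time are hypothetical values for a rig that
# sweeps the phone to 90 degrees and back five times):
#   rotation_rig(rotate_cntl='arduino', rotate_ch='1', num_rotations=5,
#                angles=[0, 90], servo_speed=20, move_time=2)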

def set_servo_speed(ch, servo_speed, serial_port, delay=0):
  """Set servo to specified speed.

  Args:
    ch: str; servo to turn on in ARDUINO_VALID_CH
    servo_speed: int; value of speed between 1 and 255
    serial_port: object; serial port
    delay: int; time in seconds
  """
  logging.debug('Servo speed: %d', servo_speed)
  if servo_speed < ARDUINO_SERVO_SPEED_MIN:
    logging.debug('Servo speed must be >= %d.', ARDUINO_SERVO_SPEED_MIN)
    servo_speed = ARDUINO_SERVO_SPEED_MIN
  elif servo_speed > ARDUINO_SERVO_SPEED_MAX:
    logging.debug('Servo speed must be <= %d.', ARDUINO_SERVO_SPEED_MAX)
    servo_speed = ARDUINO_SERVO_SPEED_MAX

  cmd = [struct.pack('B', i) for i in [ARDUINO_SPEED_START_BYTE,
                                       int(ch), servo_speed]]
  arduino_send_cmd(serial_port, cmd)
  time.sleep(delay)


def calc_max_rotation_angle(rotations, sensor_type):
  """Calculates the max angle of deflection from rotations.

  Args:
    rotations: numpy array of rotation per event
    sensor_type: string 'Camera' or 'Gyro'

  Returns:
    maximum angle of rotation for the given rotations
  """
  rotations = rotations * _RADS_TO_DEGS  # avoid mutating caller's array
  rotations_sum = np.cumsum(rotations)
  rotation_max = max(rotations_sum)
  rotation_min = min(rotations_sum)
  logging.debug('%s min: %.2f, max %.2f rotation (degrees)',
                sensor_type, rotation_min, rotation_max)
  logging.debug('%s max rotation: %.2f degrees',
                sensor_type, (rotation_max-rotation_min))
  return rotation_max-rotation_min


def get_gyro_rotations(gyro_events, cam_times):
  """Get the rotation values of the gyro.

  Integrates the gyro data between each camera frame to compute an angular
  displacement.

  Args:
    gyro_events: List of gyro event objects.
    cam_times: Array of N camera times, one for each frame.

  Returns:
    Array of N-1 gyro rotation measurements (rad).
  """
  gyro_times = np.array([e['time'] for e in gyro_events])
  all_gyro_rots = np.array([e['z'] for e in gyro_events])
  gyro_rots = []
  if gyro_times[0] > cam_times[0] or gyro_times[-1] < cam_times[-1]:
    raise AssertionError('Gyro times do not bound camera times! '
                         f'gyro: {gyro_times[0]:.0f} -> {gyro_times[-1]:.0f} '
                         f'cam: {cam_times[0]} -> {cam_times[-1]} (ns).')

  # Integrate the gyro data between each pair of camera frame times.
  for i_cam in range(len(cam_times)-1):
    # Get the window of gyro samples within the current pair of frames.
    # Note: bisect always picks first gyro index after camera time.
    t_cam0 = cam_times[i_cam]
    t_cam1 = cam_times[i_cam+1]
    i_gyro_window0 = bisect.bisect(gyro_times, t_cam0)
    i_gyro_window1 = bisect.bisect(gyro_times, t_cam1)
    gyro_sum = 0

    # Integrate samples within the window.
    for i_gyro in range(i_gyro_window0, i_gyro_window1):
      gyro_val = all_gyro_rots[i_gyro+1]
      t_gyro0 = gyro_times[i_gyro]
      t_gyro1 = gyro_times[i_gyro+1]
      t_gyro_delta = (t_gyro1 - t_gyro0) * _NSEC_TO_SEC
      gyro_sum += gyro_val * t_gyro_delta

    # Handle the fractional intervals at the sides of the window.
    for side, i_gyro in enumerate([i_gyro_window0-1, i_gyro_window1]):
      gyro_val = all_gyro_rots[i_gyro+1]
      t_gyro0 = gyro_times[i_gyro]
      t_gyro1 = gyro_times[i_gyro+1]
      t_gyro_delta = (t_gyro1 - t_gyro0) * _NSEC_TO_SEC
      if side == 0:
        f = (t_cam0 - t_gyro0) / (t_gyro1 - t_gyro0)
        frac_correction = gyro_val * t_gyro_delta * (1.0 - f)
        gyro_sum += frac_correction
      else:
        f = (t_cam1 - t_gyro0) / (t_gyro1 - t_gyro0)
        frac_correction = gyro_val * t_gyro_delta * f
        gyro_sum += frac_correction
    gyro_rots.append(gyro_sum)
  gyro_rots = np.array(gyro_rots)
  return gyro_rots

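# Worked sketch of the fractional-interval weighting above (hypothetical
# numbers): with gyro samples at t = 0 ms and t = 10 ms and a frame boundary
# at t_cam = 4 ms, the left-edge fraction is f = (4 - 0) / (10 - 0) = 0.4, so
# that sample's contribution gyro_val * 10 ms is scaled by (1 - f) = 0.6, i.e.
# only the portion of the gyro interval that falls inside the camera frame
# window is accumulated into gyro_sum.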

def procrustes_rotation(x, y):
  """Performs a Procrustes analysis to conform points in x to y.

  Procrustes analysis determines a linear transformation (translation,
  reflection, orthogonal rotation and scaling) of the points in y to best
  conform them to the points in matrix x, using the sum of squared errors
  as the metric for fit criterion.

  Args:
    x: Target coordinate matrix
    y: Input coordinate matrix

  Returns:
    The rotation component of the transformation that maps x to y.
  """
  x0 = (x-x.mean(0)) / np.sqrt(((x-x.mean(0))**2.0).sum())
  y0 = (y-y.mean(0)) / np.sqrt(((y-y.mean(0))**2.0).sum())
  u, _, vt = np.linalg.svd(np.dot(x0.T, y0), full_matrices=False)
  return np.dot(vt.T, u.T)

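# A minimal usage sketch (hypothetical points): the returned 2x2 matrix is the
# rotation that best maps the centered/normalized x points onto y, so the
# frame-to-frame angle can be recovered with atan2 on its first row, as
# get_cam_rotations() does below (the sign convention depends on lens facing):
#   theta = 0.01  # rad; hypothetical frame-to-frame rotation
#   p0 = np.array([[0., 0.], [1., 0.], [0., 1.]])
#   rot_2d = np.array([[np.cos(theta), -np.sin(theta)],
#                      [np.sin(theta), np.cos(theta)]])
#   p1 = p0.dot(rot_2d.T)
#   r = procrustes_rotation(p0, p1)
#   angle = math.atan2(r[0, 1], r[0, 0])  # magnitude ~ theta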

def get_cam_rotations(frames, facing, h, file_name_stem,
                      start_frame, stabilized_video=False):
  """Get the rotations of the camera between each pair of frames.

  Takes N frames and returns N-1 angular displacements corresponding to the
  rotations between adjacent pairs of frames, in radians.
  Only takes feature points from the center so that the measured rotation has
  less rolling shutter effect.
  Requires at least _FEATURE_PTS_MIN feature points for accurate measurements.
  Uses the _CV2_FEATURE_PARAMS settings for cv2.goodFeaturesToTrack to
  identify features in checkerboard images.
  Ensures the camera rotates enough unless called with stabilized video.

  Args:
    frames: List of N images (as RGB numpy arrays).
    facing: Direction camera is facing.
    h: Pixel height of each frame.
    file_name_stem: file name stem including location for data.
    start_frame: int; index to start at
    stabilized_video: Boolean; if called with stabilized video

  Returns:
    numpy array of N-1 camera rotation measurements (rad).
  """
  gframes = []
  for frame in frames:
    frame = (frame * 255.0).astype(np.uint8)  # cv2 uses [0, 255]
    gframes.append(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY))
  num_frames = len(gframes)
  logging.debug('num_frames: %d', num_frames)
  # create mask
  ymin = int(h * (1 - _FEATURE_MARGIN) / 2)
  ymax = int(h * (1 + _FEATURE_MARGIN) / 2)
  pre_mask = np.zeros_like(gframes[0])
  pre_mask[ymin:ymax, :] = 255

  for masking in ['post', 'pre']:  # Do post-masking (original) method 1st
    logging.debug('Using %s masking method', masking)
    rotations = []
    for i in range(1, num_frames):
      j = i - 1
      gframe0 = gframes[j]
      gframe1 = gframes[i]
      if masking == 'post':
        p0 = cv2.goodFeaturesToTrack(
            gframe0, mask=None, **_CV2_FEATURE_PARAMS_POSTMASK)
        post_mask = (p0[:, 0, 1] >= ymin) & (p0[:, 0, 1] <= ymax)
        p0_filtered = p0[post_mask]
      else:
        p0_filtered = cv2.goodFeaturesToTrack(
            gframe0, mask=pre_mask, **_CV2_FEATURE_PARAMS_PREMASK)
      num_features = len(p0_filtered)
      if num_features < _FEATURE_PTS_MIN:
        for pt in np.rint(p0_filtered).astype(int):
          x, y = pt[0][0], pt[0][1]
          cv2.circle(frames[j], (x, y), 3, (100, 255, 255), -1)
        image_processing_utils.write_image(
            frames[j], f'{file_name_stem}_features{j+start_frame:03d}.png')
        msg = (f'Not enough features in frame {j+start_frame}. Need at least '
               f'{_FEATURE_PTS_MIN} features, got {num_features}.')
        if masking == 'pre':
          raise AssertionError(msg)
        else:
          logging.debug(msg)
          break
      else:
        logging.debug('Number of features in frame %s is %d',
                      str(j+start_frame).zfill(3), num_features)
      p1, st, _ = cv2.calcOpticalFlowPyrLK(gframe0, gframe1, p0_filtered, None,
                                           **_CV2_LK_PARAMS)
      tform = procrustes_rotation(p0_filtered[st == 1], p1[st == 1])
      if facing == camera_properties_utils.LENS_FACING_BACK:
        rotation = -math.atan2(tform[0, 1], tform[0, 0])
      elif facing == camera_properties_utils.LENS_FACING_FRONT:
        rotation = math.atan2(tform[0, 1], tform[0, 0])
      else:
        raise AssertionError(f'Unknown lens facing: {facing}.')
      rotations.append(rotation)
      if i == 1:
        # Save debug visualization of features that are being
        # tracked in the first frame.
        frame = frames[j]
        for x, y in np.rint(p0_filtered[st == 1]).astype(int):
          cv2.circle(frame, (x, y), 3, (100, 255, 255), -1)
        image_processing_utils.write_image(
            frame, f'{file_name_stem}_features{j+start_frame:03d}.png')
    if i == num_frames-1:
      logging.debug('Correct num of frames found: %d', i)
      break  # exit if enough features in all frames
  if i != num_frames-1:
    raise AssertionError('Neither method found enough features in all frames')

  rotations = np.array(rotations)
  rot_per_frame_max = max(abs(rotations))
  logging.debug('Max rotation in frame: %.2f degrees',
                rot_per_frame_max*_RADS_TO_DEGS)
  if stabilized_video:
    logging.debug('Skipped camera rotation check due to stabilized video.')
  elif rot_per_frame_max < _ROTATION_PER_FRAME_MIN:
    raise AssertionError(f'Device not moved enough: {rot_per_frame_max:.3f} '
                         f'movement. THRESH: {_ROTATION_PER_FRAME_MIN} rads.')
  return rotations


def get_best_alignment_offset(cam_times, cam_rots, gyro_events):
  """Find the best offset to align the camera and gyro motion traces.

  This function integrates the shifted gyro data between camera samples
  for a range of candidate shift values, and returns the shift that
  results in the best correlation.

  Uses a correlation distance metric between the curves, where a smaller
  value means that the curves are better correlated.

  Fits a curve to the correlation distance data to measure the minimum more
  accurately, by looking at the correlation distances within a range of
  +/- 10ms from the measured best score; note that this will use fewer
  points than the full +/- 10ms range for the curve fit if the measured
  score (which is used as the center of the fit) is within 10ms of the edge
  of the +/- 50ms candidate range.

  Args:
    cam_times: Array of N camera times, one for each frame.
    cam_rots: Array of N-1 camera rotation displacements (rad).
    gyro_events: List of gyro event objects.

  Returns:
    Best alignment offset (ms), fit coefficients, candidates, and distances.
  """
  # Measure the correlation distance over defined shift
  shift_candidates = np.arange(-_CORR_TIME_OFFSET_MAX,
                               _CORR_TIME_OFFSET_MAX+_CORR_TIME_OFFSET_STEP,
                               _CORR_TIME_OFFSET_STEP).tolist()
  spatial_distances = []
  for shift in shift_candidates:
    shifted_cam_times = cam_times + shift*_MSEC_TO_NSEC
    gyro_rots = get_gyro_rotations(gyro_events, shifted_cam_times)
    spatial_distance = scipy.spatial.distance.correlation(cam_rots, gyro_rots)
    logging.debug('shift %.1fms spatial distance: %.5f', shift,
                  spatial_distance)
    spatial_distances.append(spatial_distance)

  best_corr_dist = min(spatial_distances)
  coarse_best_shift = shift_candidates[spatial_distances.index(best_corr_dist)]
  logging.debug('Best shift without fitting is %.4f ms', coarse_best_shift)

  # Fit a 2nd order polynomial around coarse_best_shift to extract best fit
  i = spatial_distances.index(best_corr_dist)
  i_poly_fit_min = max(i - _COARSE_FIT_RANGE, 0)  # clamp to avoid wrap-around
  i_poly_fit_max = i + _COARSE_FIT_RANGE + 1
  shift_candidates = shift_candidates[i_poly_fit_min:i_poly_fit_max]
  spatial_distances = spatial_distances[i_poly_fit_min:i_poly_fit_max]
  fit_coeffs = np.polyfit(shift_candidates, spatial_distances, 2)  # ax^2+bx+c
  exact_best_shift = -fit_coeffs[1]/(2*fit_coeffs[0])
  if abs(coarse_best_shift - exact_best_shift) > 2.0:
    raise AssertionError(
        f'Test failed. Bad fit to time-shift curve. Coarse best shift: '
        f'{coarse_best_shift}, Exact best shift: {exact_best_shift}.')
  if fit_coeffs[0] <= 0 or fit_coeffs[2] <= 0:
    raise AssertionError(
        f'Coefficients are < 0: a: {fit_coeffs[0]}, c: {fit_coeffs[2]}.')

  return exact_best_shift, fit_coeffs, shift_candidates, spatial_distances

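# Worked sketch of the parabola-vertex step above (hypothetical coefficients):
# np.polyfit returns [a, b, c] for a*x^2 + b*x + c, and the minimum of that
# parabola sits at x = -b / (2*a). For example, a fit of
#   a, b, c = 0.02, -0.04, 0.5
# gives exact_best_shift = -(-0.04) / (2 * 0.02) = 1.0 ms.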

def plot_camera_rotations(cam_rots, start_frame, video_quality,
                          plot_name_stem):
  """Plot the camera rotations.

  Args:
    cam_rots: np array of camera rotation angles per frame
    start_frame: int value of start frame
    video_quality: str for video quality identifier
    plot_name_stem: str (with path) of what to call plot
  """

  pylab.figure(video_quality)
  frames = range(start_frame, len(cam_rots)+start_frame)
  pylab.title(f'Camera rotation vs frame {video_quality}')
  pylab.plot(frames, cam_rots*_RADS_TO_DEGS, '-ro', label='x')
  pylab.xlabel('frame #')
  pylab.ylabel('camera rotation (degrees)')
  matplotlib.pyplot.savefig(f'{plot_name_stem}_cam_rots.png')
  pylab.close(video_quality)


def plot_gyro_events(gyro_events, plot_name, log_path):
  """Plot x, y, and z on the gyro events.

  Samples are grouped into groups of _NUM_GYRO_PTS_TO_AVG and averaged to
  minimize random spikes in data.

  Args:
    gyro_events: List of gyroscope events.
    plot_name: name of plot(s).
    log_path: location to save data.
  """

  nevents = (len(gyro_events) // _NUM_GYRO_PTS_TO_AVG) * _NUM_GYRO_PTS_TO_AVG
  gyro_events = gyro_events[:nevents]
  times = np.array([(e['time'] - gyro_events[0]['time']) * _NSEC_TO_SEC
                    for e in gyro_events])
  x = np.array([e['x'] for e in gyro_events])
  y = np.array([e['y'] for e in gyro_events])
  z = np.array([e['z'] for e in gyro_events])

  # Group samples into size-N groups & average each together to minimize random
  # spikes in data.
  times = times[_NUM_GYRO_PTS_TO_AVG//2::_NUM_GYRO_PTS_TO_AVG]
  x = x.reshape(nevents//_NUM_GYRO_PTS_TO_AVG, _NUM_GYRO_PTS_TO_AVG).mean(1)
  y = y.reshape(nevents//_NUM_GYRO_PTS_TO_AVG, _NUM_GYRO_PTS_TO_AVG).mean(1)
  z = z.reshape(nevents//_NUM_GYRO_PTS_TO_AVG, _NUM_GYRO_PTS_TO_AVG).mean(1)

  pylab.figure(plot_name)
  # x & y on same axes
  pylab.subplot(2, 1, 1)
  pylab.title(f'{plot_name} (mean of {_NUM_GYRO_PTS_TO_AVG} pts)')
  pylab.plot(times, x, 'r', label='x')
  pylab.plot(times, y, 'g', label='y')
  pylab.ylim([np.amin(z), np.amax(z)])  # match y-limits to the z data range
  pylab.ylabel('gyro x,y movement (rads/s)')
  pylab.legend()

  # z on separate axes
  pylab.subplot(2, 1, 2)
  pylab.plot(times, z, 'b', label='z')
  pylab.xlabel('time (seconds)')
  pylab.ylabel('gyro z movement (rads/s)')
  pylab.legend()
  file_name = os.path.join(log_path, plot_name)
  matplotlib.pyplot.savefig(f'{file_name}_gyro_events.png')
  pylab.close(plot_name)

  z_max = max(abs(z))
  logging.debug('z_max: %.3f', z_max)
  if z_max > _GYRO_ROTATION_PER_SEC_MAX:
    raise AssertionError(
        f'Phone moved too rapidly! Please confirm controller firmware. '
        f'Max: {z_max:.3f}, TOL: {_GYRO_ROTATION_PER_SEC_MAX} rads/s')


def conv_acceleration_to_movement(gyro_events, video_delay_time):
  """Convert gyro_events time and speed to movement during video time.

  Args:
    gyro_events: sorted list of dicts with 'time', 'x', 'y', and 'z' entries
    video_delay_time: time at which video starts (seconds)

  Returns:
    'z' acceleration converted to movement for times around video playing.
  """
  gyro_times = np.array([e['time'] for e in gyro_events])
  gyro_speed = np.array([e['z'] for e in gyro_events])
  gyro_time_min = gyro_times[0]
  logging.debug('gyro start time: %dns', gyro_time_min)
  logging.debug('gyro stop time: %dns', gyro_times[-1])
  gyro_rotations = []
  video_time_start = gyro_time_min + video_delay_time * _SEC_TO_NSEC
  video_time_stop = video_time_start + video_delay_time * _SEC_TO_NSEC
  logging.debug('video start time: %dns', video_time_start)
  logging.debug('video stop time: %dns', video_time_stop)

  for i, t in enumerate(gyro_times):
    if video_time_start <= t <= video_time_stop:
      gyro_rotations.append((gyro_times[i]-gyro_times[i-1])/_SEC_TO_NSEC *
                            gyro_speed[i])
  return np.array(gyro_rotations)


class SensorFusionUtilsTests(unittest.TestCase):
  """Run a suite of unit tests on this module."""

  _CAM_FRAME_TIME = 30 * _MSEC_TO_NSEC  # Similar to 30FPS
  _CAM_ROT_AMPLITUDE = 0.04  # Empirical number for rotation per frame (rads/s).

  def _generate_pwl_waveform(self, pts, step, amplitude):
    """Helper function to generate a piecewise linear waveform."""
    pwl_waveform = []
    for t in range(pts[0], pts[1], step):
      pwl_waveform.append(0)
    for t in range(pts[1], pts[2], step):
      pwl_waveform.append((t-pts[1])/(pts[2]-pts[1])*amplitude)
    for t in range(pts[2], pts[3], step):
      pwl_waveform.append(amplitude)
    for t in range(pts[3], pts[4], step):
      pwl_waveform.append((pts[4]-t)/(pts[4]-pts[3])*amplitude)
    for t in range(pts[4], pts[5], step):
      pwl_waveform.append(0)
    for t in range(pts[5], pts[6], step):
      pwl_waveform.append((-1*(t-pts[5])/(pts[6]-pts[5]))*amplitude)
    for t in range(pts[6], pts[7], step):
      pwl_waveform.append(-1*amplitude)
    for t in range(pts[7], pts[8], step):
      pwl_waveform.append((t-pts[8])/(pts[8]-pts[7])*amplitude)
    for t in range(pts[8], pts[9], step):
      pwl_waveform.append(0)
    return pwl_waveform

  def _generate_test_waveforms(self, gyro_sampling_rate, t_offset=0):
    """Define ideal camera/gyro behavior.

    Args:
      gyro_sampling_rate: Value in samples/sec.
      t_offset: Value in ns for gyro/camera timing offset.

    Returns:
      cam_times: numpy array of camera times N values long.
      cam_rots: numpy array of camera rotations N-1 values long.
      gyro_events: list of dicts of gyro events N*gyro_sampling_rate/30 long.

    Round trip for motor is ~2 seconds (~60 frames)
            1111111111111111
           i                i
          i                  i
         i                    i
     0000                      0000                      0000
                                   i                    i
                                    i                  i
                                     i                i
                                      -1-1-1-1-1-1-1-1
    t_0 t_1 t_2           t_3 t_4 t_5 t_6           t_7 t_8 t_9

    Note gyro waveform must extend +/- _CORR_TIME_OFFSET_MAX to enable shifting
    of camera waveform to find best correlation.

    """

    t_ramp = 4 * self._CAM_FRAME_TIME
    pts = {}
    pts[0] = 3 * self._CAM_FRAME_TIME
    pts[1] = pts[0] + 3 * self._CAM_FRAME_TIME
    pts[2] = pts[1] + t_ramp
    pts[3] = pts[2] + 32 * self._CAM_FRAME_TIME
    pts[4] = pts[3] + t_ramp
    pts[5] = pts[4] + 4 * self._CAM_FRAME_TIME
    pts[6] = pts[5] + t_ramp
    pts[7] = pts[6] + 32 * self._CAM_FRAME_TIME
    pts[8] = pts[7] + t_ramp
    pts[9] = pts[8] + 4 * self._CAM_FRAME_TIME
    cam_times = np.array(range(pts[0], pts[9], self._CAM_FRAME_TIME))
    cam_rots = self._generate_pwl_waveform(
        pts, self._CAM_FRAME_TIME, self._CAM_ROT_AMPLITUDE)
    cam_rots.pop()  # rots is N-1 for N length times.
    cam_rots = np.array(cam_rots)

    # Generate gyro waveform.
    gyro_step = int(round(_SEC_TO_NSEC/gyro_sampling_rate, 0))
    gyro_pts = {k: v+t_offset+self._CAM_FRAME_TIME//2 for k, v in pts.items()}
    gyro_pts[0] = 0  # adjust start pt to bound camera
    gyro_pts[9] += self._CAM_FRAME_TIME*2  # adjust end pt to bound camera
    gyro_rot_amplitude = (
        self._CAM_ROT_AMPLITUDE / self._CAM_FRAME_TIME * _SEC_TO_NSEC)
    gyro_rots = self._generate_pwl_waveform(
        gyro_pts, gyro_step, gyro_rot_amplitude)

    # Create gyro events list of dicts.
    gyro_events = []
    for i, t in enumerate(range(gyro_pts[0], gyro_pts[9], gyro_step)):
      gyro_events.append({'time': t, 'z': gyro_rots[i]})

    return cam_times, cam_rots, gyro_events

  def test_get_gyro_rotations(self):
    """Tests that gyro rotations are masked properly by camera rotations.

    Note that ideal waveform generation only works properly when the gyro
    sampling rate is an integer multiple of the camera frame rate.
    """
    # Run with different sampling rates to validate.
    for gyro_sampling_rate in [200, 1000]:  # 6x, 30x frame rate
      cam_times, cam_rots, gyro_events = self._generate_test_waveforms(
          gyro_sampling_rate)
      gyro_rots = get_gyro_rotations(gyro_events, cam_times)
      e_msg = f'gyro sampling rate = {gyro_sampling_rate}\n'
      e_msg += f'cam_times = {list(cam_times)}\n'
      e_msg += f'cam_rots = {list(cam_rots)}\n'
      e_msg += f'gyro_rots = {list(gyro_rots)}'

      self.assertTrue(np.allclose(
          gyro_rots, cam_rots, atol=self._CAM_ROT_AMPLITUDE*0.10), e_msg)

  def test_get_best_alignment_offset(self):
    """Unittest for alignment offset check."""

    gyro_sampling_rate = 5000
    for t_offset_ms in [0, 1]:  # Run with different offsets to validate.
      t_offset = int(t_offset_ms * _MSEC_TO_NSEC)
      cam_times, cam_rots, gyro_events = self._generate_test_waveforms(
          gyro_sampling_rate, t_offset)

      best_fit_offset, coeffs, x, y = get_best_alignment_offset(
          cam_times, cam_rots, gyro_events)
      e_msg = f'best: {best_fit_offset} ms\n'
      e_msg += f'coeffs: {coeffs}\n'
      e_msg += f'x: {x}\n'
      e_msg += f'y: {y}'
      self.assertTrue(np.isclose(t_offset_ms, best_fit_offset, atol=0.1), e_msg)


if __name__ == '__main__':
  unittest.main()