"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

""" module for inspecting models during inference """

import os

import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch
import numpy as np

# stores entries {key : {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}}
_state = dict()
_folder = 'endoscopy'


def get_gru_gates(gru, input, state):
    """ computes the gate activations of a GRU layer for a single step

    Args:
        gru: object providing hidden_size, weight_ih_l0, weight_hh_l0,
            bias_ih_l0 and bias_hh_l0 (the attribute names match the first
            layer of a torch.nn.GRU)
        input: input tensor; singleton dimensions are squeezed away before
            the matrix product
        state: previous hidden state; squeezed the same way

    Returns:
        dict with the tensors 'reset_gate', 'update_gate' and 'new_gate'
    """
    hidden_size = gru.hidden_size

    direct = torch.matmul(gru.weight_ih_l0, input.squeeze())
    recurrent = torch.matmul(gru.weight_hh_l0, state.squeeze())

    def _gate_slice(gate_index):
        # the stacked weights/biases are read as three consecutive blocks of
        # hidden_size rows: 0 -> reset, 1 -> update, 2 -> new
        return slice(gate_index * hidden_size, (gate_index + 1) * hidden_size)

    # reset gate
    sel = _gate_slice(0)
    reset_gate = torch.sigmoid(direct[sel] + gru.bias_ih_l0[sel] + recurrent[sel] + gru.bias_hh_l0[sel])

    # update gate
    sel = _gate_slice(1)
    update_gate = torch.sigmoid(direct[sel] + gru.bias_ih_l0[sel] + recurrent[sel] + gru.bias_hh_l0[sel])

    # new gate: the reset gate scales only the recurrent contribution
    sel = _gate_slice(2)
    new_gate = torch.tanh(direct[sel] + gru.bias_ih_l0[sel] + reset_gate * (recurrent[sel] + gru.bias_hh_l0[sel]))

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}


def init(folder='endoscopy'):
    """ sets up output folder for endoscopy data

    Args:
        folder: directory that subsequent write_data calls write to; it is
            created if missing, otherwise a warning is printed
    """

    global _folder
    _folder = folder

    if not os.path.exists(folder):
        os.makedirs(folder)
    else:
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")


def write_data(key, data, fs):
    """ appends data to previous data written under key

    The first call for a key opens <folder>/<key>.bin and writes the metadata
    file <folder>/<key>.yml (fs, dim, dtype). Later calls must use the same
    fs, dtype and shape.

    Args:
        key: name of the signal; used as file name stem
        data: torch.Tensor, numpy array or array-like holding one frame
        fs: sampling rate associated with the signal

    Raises:
        ValueError: if fs, dtype or shape differ from the first call for key
    """

    # convert to numpy; tensors are moved to the CPU first since
    # Tensor.numpy() is only available for CPU tensors
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)

    dim = tuple(data.shape)
    dtype = str(data.dtype)

    if key not in _state:
        _state[key] = {
            'fid' : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs' : fs,
            'dim' : dim,
            'dtype' : dtype
        }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : dim, 'dtype' : dtype.split('.')[-1]}))
    else:
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != str(data.dtype):
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {str(data.dtype)}")
        if _state[key]['dim'] != tuple(data.shape):
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {tuple(data.shape)}")

    _state[key]['fid'].write(data.tobytes())


def close(folder='endoscopy'):
    """ clean up

    Closes all files opened by write_data and forgets them, so that a later
    write_data call re-opens its file instead of writing to a closed handle.

    Args:
        folder: unused; kept for backward compatibility
    """
    for entry in _state.values():
        entry['fid'].close()

    _state.clear()


def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Args:
        folder: directory containing <key>.yml / <key>.bin pairs

    Returns:
        dict mapping each key to its metadata dict ('fs', 'dim', 'dtype')
        extended by 'data', an array of shape (num_frames,) + dim
    """

    keys = [name[:-4] for name in os.listdir(folder) if name.endswith('.yml')]

    return_dict = dict()

    for key in keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            # NOTE(security): FullLoader is used because write_data stores
            # 'dim' as a python tuple; only read folders from trusted sources
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        # tuple() keeps the concatenation valid if dim was stored as a list
        value['data'] = data.reshape((-1,) + tuple(value['dim']))

        return_dict[key] = value

    return return_dict


def _pixel_count(shape):
    """ returns the product of all entries of shape (1 for an empty shape) """
    count = 1
    for s in shape:
        count *= s

    return count


def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)

    Args:
        shape: sequence of dimensions
        target_ratio: desired rows / columns ratio

    Returns:
        (1,) if shape holds a single element, otherwise (num_rows, num_columns)
        with num_rows * num_columns equal to the number of elements
    """

    pixel_count = _pixel_count(shape)

    if pixel_count == 1:
        return (1,)

    # start at the ideal column count and walk down to the next divisor;
    # clamping to 1 avoids a modulo by zero for large target ratios
    num_columns = max(1, int((pixel_count / target_ratio)**.5))

    while (pixel_count % num_columns):
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)


def get_type_and_shape(shape):
    """ decides how a signal with the given frame shape is displayed

    Args:
        shape: frame shape (the 'dim' entry of a signal)

    Returns:
        ('plot', (1,)) for single-valued frames, otherwise ('image', shape2d)
    """

    # can happen if data is one dimensional
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    pixel_count = _pixel_count(shape)

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional, unless it is a degenerate
    # row or column vector, which is better served by a reshape
    if len(shape) == 2:
        if (shape[0] != pixel_count) and (shape[1] != pixel_count):
            return 'image', shape

    return 'image', get_best_reshape(shape)


def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders the given signals into an animation and saves it as mp4

    Args:
        data: dict as returned by read_data, i.e. key -> {'fs', 'dim', 'data', ...}
        filename: output file; '.mp4' is appended if missing
        start_index: first frame index; raised to half_signal_window_length if smaller
        stop_index: end frame index (exclusive); negative values count from the end
        interval: delay between frames in milliseconds
        half_signal_window_length: number of samples shown on each side of
            the current position for 'plot' type signals

    Raises:
        ValueError: if data is empty or contains no samples
    """

    if not data:
        raise ValueError("make_animation requires at least one signal")

    # determine plot setup
    num_keys = len(data.keys())

    # at least one row, otherwise the column count below divides by zero
    num_rows = max(1, int((num_keys * 3/4) ** .5))

    num_cols = (num_keys + num_rows - 1) // num_rows

    # squeeze=False keeps axs two-dimensional for single-row/column layouts
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    display = dict()

    fs_max = max([val['fs'] for val in data.values()])

    num_samples = max([val['data'].shape[0] for val in data.values()])

    if num_samples == 0:
        raise ValueError("make_animation requires signals with at least one sample")

    keys = sorted(data.keys())

    # inspect data
    for i, key in enumerate(keys):
        axis = axs[i // num_cols, i % num_cols]
        axis.title.set_text(key)

        display[key] = dict()

        display[key]['axis'] = axis
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for key in keys:
            axis = display[key]['axis']
            signal = data[key]['data']

            # position in the signal's own time base; clamped because
            # rounding can point one past the last frame
            feature_index = int(round(index * display[key]['down_factor']))
            feature_index = min(feature_index, signal.shape[0] - 1)

            if display[key]['type'] == 'plot':
                window_start = max(feature_index - half_signal_window_length, 0)
                window_stop = feature_index + half_signal_window_length
                ims.append(axis.plot(signal[window_start : window_stop], marker='P', markevery=[feature_index - window_start], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(axis.imshow(signal[feature_index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    ani.save(filename)