• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30""" module for inspecting models during inference """
31
32import os
33
34import yaml
35import matplotlib.pyplot as plt
36import matplotlib.animation as animation
37
38import torch
39import numpy as np
40
# maps each key passed to write_data to {'fid' : fid, 'fs' : fs, 'dim' : dim, 'dtype' : dtype}:
# the open binary output file plus sampling rate, shape and dtype of the data written under key
_state = {}
# output folder for endoscopy data; changed by init()
_folder = 'endoscopy'
44
def get_gru_gates(gru, input, state):
    """ computes the gate activations of the first layer of a torch.nn.GRU for a single step

    Parameters:
    -----------
    gru : torch.nn.GRU
        GRU module; only the layer-0 parameters (weight_ih_l0, weight_hh_l0 and, if the
        GRU has biases, bias_ih_l0 and bias_hh_l0) are used
    input : torch.Tensor
        input of a single time step; flattened to a vector of size gru.input_size
    state : torch.Tensor
        hidden state before the step; flattened to a vector of size gru.hidden_size

    Returns:
    --------
    dict with entries 'reset_gate', 'update_gate' and 'new_gate', each a vector of size
    gru.hidden_size
    """

    hidden_size = gru.hidden_size

    # reshape instead of squeeze: squeeze would produce 0-d tensors for size-1 inputs/states
    direct = torch.matmul(gru.weight_ih_l0, input.reshape(-1))
    recurrent = torch.matmul(gru.weight_hh_l0, state.reshape(-1))

    # GRUs created with bias=False have no bias parameters
    if gru.bias:
        direct = direct + gru.bias_ih_l0
        recurrent = recurrent + gru.bias_hh_l0

    # torch stacks the gate parameters in the order (reset, update, new)
    direct_r, direct_z, direct_n = torch.split(direct, hidden_size)
    recurrent_r, recurrent_z, recurrent_n = torch.split(recurrent, hidden_size)

    reset_gate = torch.sigmoid(direct_r + recurrent_r)
    update_gate = torch.sigmoid(direct_z + recurrent_z)
    # the reset gate scales the recurrent contribution including its bias
    new_gate = torch.tanh(direct_n + reset_gate * recurrent_n)

    return {'reset_gate' : reset_gate, 'update_gate' : update_gate, 'new_gate' : new_gate}
64
65
def init(folder='endoscopy'):
    """ selects folder as output folder for endoscopy data, creating it if it does not exist

    If the folder already exists, a warning is printed since files from an earlier run
    may be overwritten or mixed with new output.
    """

    global _folder
    _folder = folder

    if os.path.exists(folder):
        print(f"warning: endoscopy folder {folder} exists. Content may be lost or inconsistent results may occur.")
        return

    os.makedirs(folder)
76
def write_data(key, data, fs):
    """ appends data to previous data written under key

    On the first call for key, <_folder>/<key>.bin is opened for writing (truncating an
    existing file) and fs, shape and dtype of data are stored in <_folder>/<key>.yml.
    Every call appends the raw bytes of data to <_folder>/<key>.bin.

    Parameters:
    -----------
    key : str
        name of the data stream, used as file name stem
    data : torch.Tensor, numpy.ndarray or array-like
        data of a single time step
    fs : number
        sampling rate of the data stream

    Raises:
    -------
    ValueError
        if fs, dtype or shape differ from those of the first call for key
    """

    if isinstance(data, torch.Tensor):
        # .cpu() is required for tensors that do not live on the CPU
        data = data.detach().cpu().numpy()
    else:
        # also accept python scalars and sequences
        data = np.asarray(data)

    dim = tuple(data.shape)
    dtype = str(data.dtype)

    if key not in _state:
        _state[key] = {
            'fid'   : open(os.path.join(_folder, key + '.bin'), 'wb'),
            'fs'    : fs,
            'dim'   : dim,
            'dtype' : dtype
        }

        with open(os.path.join(_folder, key + '.yml'), 'w') as f:
            f.write(yaml.dump({'fs' : fs, 'dim' : dim, 'dtype' : dtype.split('.')[-1]}))
    else:
        if _state[key]['fs'] != fs:
            raise ValueError(f"fs changed for key {key}: {_state[key]['fs']} vs. {fs}")
        if _state[key]['dtype'] != dtype:
            raise ValueError(f"dtype changed for key {key}: {_state[key]['dtype']} vs. {dtype}")
        if _state[key]['dim'] != dim:
            raise ValueError(f"dim changed for key {key}: {_state[key]['dim']} vs. {dim}")

    _state[key]['fid'].write(data.tobytes())
105
def close(folder='endoscopy'):
    """ closes all files opened by write_data and forgets their state

    Without clearing the state, a later write_data call for an already used key would
    write to a closed file. After close, write_data starts new (truncated) files.
    The folder argument is not used; it is kept for backward compatibility.
    """

    for entry in _state.values():
        entry['fid'].close()

    _state.clear()
110
111
def read_data(folder='endoscopy'):
    """ retrieves written data as numpy arrays

    Parameters:
    -----------
    folder : str
        folder containing the <key>.yml / <key>.bin pairs written by write_data

    Returns:
    --------
    dict mapping each key to a dict with the entries of <key>.yml ('fs', 'dim', 'dtype')
    plus 'data', an array of shape (num_steps,) + dim with all steps written under key
    """

    # keys are derived from the metadata files; sorted for a deterministic order
    keys = sorted(name[:-len('.yml')] for name in os.listdir(folder) if name.endswith('.yml'))

    return_dict = dict()

    for key in keys:
        with open(os.path.join(folder, key + '.yml'), 'r') as f:
            # FullLoader is needed because write_data dumps dim as a python tuple
            # (!!python/tuple tag), which safe_load rejects; only use on trusted folders
            value = yaml.load(f.read(), yaml.FullLoader)

        with open(os.path.join(folder, key + '.bin'), 'rb') as f:
            data = np.frombuffer(f.read(), dtype=value['dtype'])

        # dim may also be stored as a list; tuple() makes the concatenation below work
        value['dim'] = tuple(value['dim'])
        value['data'] = data.reshape((-1,) + value['dim'])

        return_dict[key] = value

    return return_dict
132
def get_best_reshape(shape, target_ratio=1):
    """ calculates the best 2d reshape of shape given the target ratio (rows/cols)

    The number of columns is the largest divisor of the element count that does not
    exceed sqrt(element count / target_ratio) (but is at least 1).

    Parameters:
    -----------
    shape : tuple of int
        shape to be reshaped
    target_ratio : float
        desired ratio rows / cols

    Returns:
    --------
    (1,) if shape holds a single element, otherwise (num_rows, num_columns)
    """

    # product of all dimensions; 1 for an empty shape
    pixel_count = 1
    for s in shape:
        pixel_count *= s

    if pixel_count == 1:
        return (1,)

    # at least one column: avoids a division by zero for large target ratios
    num_columns = max(1, int((pixel_count / target_ratio)**.5))

    # largest divisor of pixel_count not exceeding the initial guess
    while pixel_count % num_columns:
        num_columns -= 1

    num_rows = pixel_count // num_columns

    return (num_rows, num_columns)
154
def get_type_and_shape(shape):
    """ determines how data with the given per-step shape is displayed

    Parameters:
    -----------
    shape : tuple of int
        shape of the data written per time step (empty for scalar data)

    Returns:
    --------
    ('plot', (1,)) if the data holds a single value per time step, otherwise
    ('image', image_shape) with a two-dimensional image_shape
    """

    # can happen if data is a scalar (zero-dimensional)
    if len(shape) == 0:
        shape = (1,)

    # calculate pixel count
    pixel_count = 1
    for s in shape:
        pixel_count *= s

    if pixel_count == 1:
        return 'plot', (1, )

    # stay with shape if already 2-dimensional, unless it is a row or column vector
    # (then one dimension equals the pixel count and a more square reshape is preferred)
    if len(shape) == 2 and shape[0] != pixel_count and shape[1] != pixel_count:
        return 'image', tuple(shape)

    return 'image', get_best_reshape(shape)
178
def make_animation(data, filename, start_index=80, stop_index=-80, interval=20, half_signal_window_length=80):
    """ renders data as returned by read_data into an mp4 animation

    Each key gets its own subplot. Keys holding a single value per time step are shown
    as a window of up to 2 * half_signal_window_length samples around the current sample,
    which is marked; all other keys are shown as images of the current time step.

    Parameters:
    -----------
    data : dict
        maps keys to dicts with at least the entries 'fs', 'dim' and 'data'
    filename : str
        output file name, '.mp4' is appended if missing
    start_index : int
        first frame, counted at the highest sampling rate found in data; raised to
        half_signal_window_length if smaller
    stop_index : int
        end frame (exclusive); negative values count from the end. Clipped to
        (longest data length - half_signal_window_length)
    interval : int
        delay between frames in milliseconds
    half_signal_window_length : int
        half length of the window shown for signal plots

    Raises:
    -------
    ValueError
        if data is empty or contains no samples
    """

    keys = sorted(data.keys())
    num_keys = len(keys)
    if num_keys == 0:
        raise ValueError("make_animation: no data given")

    fs_max = max(val['fs'] for val in data.values())
    num_samples = max(val['data'].shape[0] for val in data.values())
    if num_samples == 0:
        # would otherwise loop forever when resolving a negative stop_index
        raise ValueError("make_animation: data contains no samples")

    # determine plot setup; at least one row, the formula yields 0 for a single key
    num_rows = max(1, int((num_keys * 3/4) ** .5))
    num_cols = (num_keys + num_rows - 1) // num_rows

    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols * 5, num_rows * 5)

    # inspect data
    display = dict()
    for i, key in enumerate(keys):
        axs[i // num_cols, i % num_cols].title.set_text(key)

        display[key] = dict()
        display[key]['type'], display[key]['shape'] = get_type_and_shape(data[key]['dim'])
        # converts frame indices (at fs_max) into sample indices of this key
        display[key]['down_factor'] = data[key]['fs'] / fs_max

    start_index = max(start_index, half_signal_window_length)
    while stop_index < 0:
        stop_index += num_samples

    stop_index = min(stop_index, num_samples - half_signal_window_length)

    # actual plotting
    frames = []
    for index in range(start_index, stop_index):
        ims = []
        for i, key in enumerate(keys):
            ax = axs[i // num_cols, i % num_cols]
            series = data[key]['data']

            # sample index of this key (keys with lower sampling rates advance more slowly),
            # clipped to the last sample in case this stream is shorter than the longest one
            feature_index = min(int(round(index * display[key]['down_factor'])), series.shape[0] - 1)

            if display[key]['type'] == 'plot':
                window_start = max(feature_index - half_signal_window_length, 0)
                window = series[window_start : feature_index + half_signal_window_length]
                ims.append(ax.plot(window, marker='P', markevery=[feature_index - window_start], animated=True, color='blue')[0])

            elif display[key]['type'] == 'image':
                ims.append(ax.imshow(series[feature_index].reshape(display[key]['shape']), animated=True))

        frames.append(ims)

    ani = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)

    if not filename.endswith('.mp4'):
        filename += '.mp4'

    try:
        ani.save(filename)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)