"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch

from .common import sparsify_matrix


class GRUSparsifier:
    # Gate order used to slice the stacked weight matrices of the GRU:
    # block i of weight_ih_l0 / weight_hh_l0 covers rows
    # [i * hidden_size, (i + 1) * hidden_size).
    _INPUT_KEYS = ('W_ir', 'W_iz', 'W_in')
    _RECURRENT_KEYS = ('W_hr', 'W_hz', 'W_hn')

    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

            Only the first layer weights (weight_ih_l0 and weight_hh_l0) of each
            GRU are sparsified.

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
                of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
                'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
                update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
                where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
                sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
                should be kept.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to GRUSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + (1 - alpha) * target_density), where
                alpha = ((stop - i) / (stop - start)) ** exponent
                for start <= i < stop and alpha = 0 for i >= stop.

            Example:
            --------
            >>> import torch
            >>> gru = torch.nn.GRU(10, 20)
            >>> sparsify_dict = {
            ...         'W_ir' : (0.5, [2, 2], False),
            ...         'W_iz' : (0.6, [2, 2], False),
            ...         'W_in' : (0.7, [2, 2], False),
            ...         'W_hr' : (0.1, [4, 4], True),
            ...         'W_hz' : (0.2, [4, 4], True),
            ...         'W_hn' : (0.3, [4, 4], True),
            ...     }
            >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """
        # just copying parameters...
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent
        self.task_list = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        # most recent mask per gate key (kept for backward compatibility; when
        # task_list holds several GRUs this is the mask of the last one processed)
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

        # most recent mask per (task index, gate key); used for the mask-change
        # check so that masks of different GRUs are never compared to each other
        self._masks_by_task = {}

    def step(self, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """
        self.step_counter += 1

        # compute current interpolation factor
        if self.step_counter < self.start:
            # sparsification has not started yet
            return

        if self.step_counter < self.stop:
            # update only every self.interval-th interval
            if self.step_counter % self.interval:
                return

            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            # target density is enforced on every call from step stop onwards
            alpha = 0

        # weights are overwritten in place, which must not be tracked by autograd
        with torch.no_grad():
            for task_index, (gru, params) in enumerate(self.task_list):
                # input weights
                self._sparsify_gates(task_index, gru, gru.weight_ih_l0,
                                     self._INPUT_KEYS, params, alpha, verbose)
                # recurrent weights
                self._sparsify_gates(task_index, gru, gru.weight_hh_l0,
                                     self._RECURRENT_KEYS, params, alpha, verbose)

    def _sparsify_gates(self, task_index, gru, weight, keys, params, alpha, verbose):
        """ sparsifies the gate sub-matrices of one stacked GRU weight matrix in place

            Parameters:
            ----------
            task_index : int
                position of the (gru, params) tuple in self.task_list
            gru : torch.nn.GRU
                module owning weight (provides hidden_size, used in messages)
            weight : torch.Tensor
                gru.weight_ih_l0 or gru.weight_hh_l0
            keys : sequence of str
                gate keys in the order of the row blocks of weight
            params : dict
                sparsify_dict for this gru; gates without an entry are skipped
            alpha : float
                interpolation factor, 1 means dense and 0 means target density
            verbose : bool
                if true, densities are printed out
        """
        hidden_size = gru.hidden_size

        for i, key in enumerate(keys):
            if key not in params:
                continue

            density = alpha + (1 - alpha) * params[key][0]
            if verbose:
                print(f"[{self.step_counter}]: {key} density: {density}")

            rows = slice(i * hidden_size, (i + 1) * hidden_size)
            weight[rows, :], new_mask = sparsify_matrix(
                weight[rows, :],
                density,        # density
                params[key][1], # block_size
                params[key][2], # keep_diagonal (might want to set this to False)
                return_mask=True
            )

            # after stop the density is constant, so a changing mask is reported;
            # the cheap step check comes first to avoid needless tensor comparisons
            previous_mask = self._masks_by_task.get((task_index, key))
            if previous_mask is not None and self.step_counter > self.stop:
                if not torch.all(previous_mask == new_mask):
                    print(f"sparsification mask {key} changed for gru {gru}")

            self._masks_by_task[(task_index, key)] = new_mask
            self.last_masks[key] = new_mask


if __name__ == "__main__":
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)

    for i in range(100):
        sparsifier.step(verbose=True)