• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import torch
31
32from .common import sparsify_matrix
33
34
class GRUSparsifier:
    """Gradually sparsifies the layer-0 weights of torch.nn.GRU modules during training."""

    # Gate keys listed in the order in which step() maps them onto consecutive
    # row blocks of size hidden_size in weight_ih_l0 / weight_hh_l0.
    _INPUT_KEYS = ('W_ir', 'W_iz', 'W_in')
    _RECURRENT_KEYS = ('W_hr', 'W_hz', 'W_hn')

    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

            Only the layer-0 weights (weight_ih_l0 and weight_hh_l0) are touched.

            Parameters:
            -----------
            task_list : list
                task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
                of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
                'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
                update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
                where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks to
                which sparsification is applied and keep_diagonal is a bool variable indicating whether
                the diagonal should be kept. Keys missing from sparsify_dict are left dense.

            start : int
                training step after which sparsification will be started.

            stop : int
                training step after which sparsification will be completed.

            interval : int
                sparsification interval for steps between start and stop. After stop sparsification will be
                carried out after every call to GRUSparsifier.step()

            exponent : float
                Interpolation exponent for sparsification interval. In step i sparsification will be carried out
                with density (alpha + target_density * (1 - alpha)), where
                alpha = ((stop - i) / (stop - start)) ** exponent

            Example:
            --------
            >>> import torch
            >>> gru = torch.nn.GRU(10, 20)
            >>> sparsify_dict = {
            ...         'W_ir' : (0.5, [2, 2], False),
            ...         'W_iz' : (0.6, [2, 2], False),
            ...         'W_in' : (0.7, [2, 2], False),
            ...         'W_hr' : (0.1, [4, 4], True),
            ...         'W_hz' : (0.2, [4, 4], True),
            ...         'W_hn' : (0.3, [4, 4], True),
            ...     }
            >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
            >>> for i in range(100):
            ...         sparsifier.step()
        """
        # just copying parameters...
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent
        self.task_list = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        # most recently computed mask per gate key (whichever GRU was processed last)
        self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

        # Masks per (position in task_list, gate key). The change check must compare a
        # GRU only against its own previous mask: a store keyed by gate name alone mixes
        # up masks of different GRUs as soon as task_list has more than one entry.
        self._masks_per_task = {}

    def _interpolation_factor(self):
        """ returns alpha for the current step, or None if no sparsification is due in this step """
        if self.step_counter < self.start:
            return None

        if self.step_counter < self.stop:
            # update only every self.interval-th step while ramping up sparsity
            if self.step_counter % self.interval:
                return None
            return ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent

        # from stop onwards the target density is enforced in every step
        return 0

    def _sparsify_gates(self, task_index, gru, weight_name, keys, params, alpha, verbose):
        """ sparsifies the row blocks of getattr(gru, weight_name) for all keys present in params """
        hidden_size = gru.hidden_size

        for i, key in enumerate(keys):
            if key not in params:
                continue

            spec = params[key]
            density = alpha + (1 - alpha) * spec[0]
            if verbose:
                print(f"[{self.step_counter}]: {key} density: {density}")

            weight = getattr(gru, weight_name)
            rows = slice(i * hidden_size, (i + 1) * hidden_size)
            weight[rows, :], new_mask = sparsify_matrix(
                weight[rows, :],
                density,
                spec[1], # block_size
                spec[2], # keep_diagonal (might want to set this to False)
                return_mask=True
            )

            # once the final density is reached, the mask of a given GRU is expected to stay fixed
            previous_mask = self._masks_per_task.get((task_index, key))
            if previous_mask is not None and self.step_counter > self.stop:
                if not torch.equal(previous_mask, new_mask):
                    print(f"sparsification mask {key} changed for gru {gru}")

            self._masks_per_task[(task_index, key)] = new_mask
            self.last_masks[key] = new_mask

    def step(self, verbose=False):
        """ carries out sparsification step

            Call this function after optimizer.step in your
            training loop.

            Parameters:
            ----------
            verbose : bool
                if true, densities are printed out

            Returns:
            --------
            None

        """
        self.step_counter += 1

        # compute current interpolation factor
        alpha = self._interpolation_factor()
        if alpha is None:
            return

        with torch.no_grad():
            for task_index, (gru, params) in enumerate(self.task_list):
                # input weights
                self._sparsify_gates(task_index, gru, 'weight_ih_l0', self._INPUT_KEYS, params, alpha, verbose)

                # recurrent weights
                self._sparsify_gates(task_index, gru, 'weight_hh_l0', self._RECURRENT_KEYS, params, alpha, verbose)
168
169
170
if __name__ == "__main__":
    # Smoke test: ramp a small GRU towards its target densities and print the
    # density used for every gate in each sparsification step.
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)

    # target densities per gate; input weights use 2x2 blocks without diagonal,
    # recurrent weights use 4x4 blocks and keep the diagonal
    input_densities = {'W_ir': 0.5, 'W_iz': 0.6, 'W_in': 0.7}
    recurrent_densities = {'W_hr': 0.1, 'W_hz': 0.2, 'W_hn': 0.3}

    sparsify_dict = {
        gate: (target, [2, 2], False) for gate, target in input_densities.items()
    }
    sparsify_dict.update(
        (gate, (target, [4, 4], True)) for gate, target in recurrent_densities.items()
    )

    sparsifier = GRUSparsifier(
        [(gru, sparsify_dict)],
        0,    # start
        100,  # stop
        10,   # interval
    )

    num_steps = 100
    for _ in range(num_steps):
        sparsifier.step(verbose=True)
188