# Data Scheduler
## Intro
The data scheduler controls how the data sparsification parameters are updated and works specifically with the data sparsifier class.
It takes one config parameter of the data sparsifier (named by the `schedule_param` argument) and varies it over the course of training (or over time).

## API details
`BaseDataScheduler`: base class with an abstract method `get_schedule_param` that computes the data sparsification parameter for all the data. The constructor accepts:
1. `data_sparsifier`: the data sparsifier object whose parameter will be scheduled.
2. `schedule_param`: the name of the config parameter of the passed data sparsifier that is to be scheduled (varied).

`get_last_param`: returns the last scheduled parameter, i.e. a dictionary mapping each data name to its `schedule_param` value.

`step`: applies the `get_schedule_param` logic every epoch or every step, depending on where it is called. It should always be called after the data sparsifier's `step()`.

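To make the constructor / `step` / `get_last_param` contract concrete, here is a minimal, self-contained sketch in plain Python. `ToyDataSparsifier`, `MiniDataScheduler` and `HalvingScheduler` are hypothetical names. The sketch assumes the sparsifier exposes a `data_groups` dict (data name to config dict) and that the base constructor takes one initial `step()`, so `last_epoch` is 0 right after construction. It is a simplified model of the behaviour described above, not the actual `BaseDataScheduler` source.

```python
import abc


class ToyDataSparsifier:
    """Hypothetical stand-in: only a `data_groups` dict and a no-op step()."""

    def __init__(self, data_groups):
        self.data_groups = data_groups  # data name -> config dict

    def step(self):
        pass  # a real data sparsifier would update its masks here


class MiniDataScheduler(abc.ABC):
    """Simplified sketch of the scheduler API; not the real implementation."""

    def __init__(self, data_sparsifier, schedule_param):
        self.data_sparsifier = data_sparsifier
        self.schedule_param = schedule_param
        self.last_epoch = -1
        self._last_param = {}
        # Initial step: last_epoch becomes 0 after construction.
        self.step()

    @abc.abstractmethod
    def get_schedule_param(self):
        """Return {data name: new value of schedule_param}."""

    def get_last_param(self):
        return self._last_param

    def step(self):
        self.last_epoch += 1
        new_params = self.get_schedule_param()
        # Write each scheduled value back into the sparsifier's config.
        for name, value in new_params.items():
            self.data_sparsifier.data_groups[name][self.schedule_param] = value
        # Record the current value of the scheduled param for every data group.
        self._last_param = {
            name: config.get(self.schedule_param)
            for name, config in self.data_sparsifier.data_groups.items()
        }


class HalvingScheduler(MiniDataScheduler):
    """Hypothetical example: halve the param on every step after the first."""

    def get_schedule_param(self):
        groups = self.data_sparsifier.data_groups
        if self.last_epoch > 0:
            return {name: cfg[self.schedule_param] * 0.5 for name, cfg in groups.items()}
        # First step: keep the initial value.
        return {name: cfg[self.schedule_param] for name, cfg in groups.items()}


sparsifier = ToyDataSparsifier({"train_data": {"sparsity_level": 0.8}})
scheduler = HalvingScheduler(sparsifier, "sparsity_level")
print(scheduler.get_last_param())  # {'train_data': 0.8}
sparsifier.step()
scheduler.step()
print(scheduler.get_last_param())  # {'train_data': 0.4}
```

In this sketch the scheduler only rewrites the sparsifier's config; subclasses decide the values, and the base class decides when to ask for them.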
## Write your own data scheduler
A custom data scheduler must inherit from `BaseDataScheduler` and implement `get_schedule_param()`. For example, the scheduler below gradually multiplies the sparsity level by `gamma` every epoch.
It also takes an argument `threshold_sl`: once the sparsity level reaches this value, it does not increase further.

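```python
# Assumed import path: the package directory that contains base_data_scheduler.py.
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler


class GammaScheduler(BaseDataScheduler):
    def __init__(self, data_sparsifier, gamma, threshold_sl):
        # Set attributes before the base constructor, in case it calls get_schedule_param().
        self.gamma = gamma
        self.threshold_sl = threshold_sl
        super().__init__(data_sparsifier, "sparsity_level")

    def get_schedule_param(self):
        data_groups = self.data_sparsifier.data_groups
        if self.last_epoch > 0:
            # Multiply the current sparsity level by gamma, capped at threshold_sl.
            return {name: min(self.threshold_sl, config["sparsity_level"] * self.gamma)
                    for name, config in data_groups.items()}
        # First step: keep the initial sparsity level. Returning 0.0 here would
        # pin it at 0.0 forever, since 0.0 * gamma == 0.0.
        return {name: config["sparsity_level"] for name, config in data_groups.items()}
```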
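Because each epoch multiplies the previous value, the sparsity level grows geometrically until it hits the cap. The hypothetical helper `gamma_schedule` below replays that recurrence without any sparsifier, assuming epoch 0 keeps the initial sparsity level and every later `step()` applies `min(threshold_sl, previous * gamma)`. It can be used to check a choice of `gamma` and `threshold_sl` before training.

```python
def gamma_schedule(initial, gamma, threshold, epochs):
    """Values of the scheduled param at epochs 0..epochs-1.

    Epoch 0 keeps `initial`; each later epoch multiplies the previous value
    by `gamma` and caps the result at `threshold`.
    """
    if epochs <= 0:
        return []
    values = [initial]
    for _ in range(1, epochs):
        values.append(min(threshold, values[-1] * gamma))
    return values


print(gamma_schedule(0.1, 2.0, 0.5, 5))  # [0.1, 0.2, 0.4, 0.5, 0.5]
print(gamma_schedule(0.0, 2.0, 0.5, 3))  # [0.0, 0.0, 0.0]
```

The second call shows why the first step must not reset the level to 0.0: a zero start never grows.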

## Using data scheduler with data sparsifier
If data sparsity levels (or any other sparsity `param`) need to vary during training, implement a custom data scheduler and use it together with the data sparsifier.

Example:

```python
model = SomeModel()
optimizer = SomeOptimizer(model.parameters(), lr=...)
loss_fn = SomeLossFunction()
data_sparsifier = SomeDataSparsifier(...)

data_scheduler = SomeDataScheduler(data_sparsifier, ...)

data_name = 'train_data'

for epoch in range(EPOCHS):
    for input, target in dataset:
        # Attach the current batch to the sparsifier under `data_name`.
        input = data_sparsifier.add_data(name=data_name, data=input)

        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        data_sparsifier.step()

    # Once per epoch, after data_sparsifier.step(): update the scheduled param.
    data_scheduler.step()
```
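The loop calls `data_sparsifier.step()` once per batch and `data_scheduler.step()` once per epoch, after the sparsifier. One way to catch the wrong order is to wrap the sparsifier's `step` with a counter, similar in spirit to the check `torch.optim` learning-rate schedulers perform against `optimizer.step()`. The sketch below is illustrative only: `ToySparsifier`, `count_steps` and `OrderCheckingScheduler` are hypothetical, and it does not claim to reproduce what `BaseDataScheduler` does internally.

```python
import functools
import warnings


class ToySparsifier:
    """Hypothetical stand-in for a data sparsifier; step() does nothing here."""

    def step(self):
        pass


def count_steps(sparsifier):
    """Wrap sparsifier.step so that each call increments sparsifier._step_count."""
    original_step = sparsifier.step
    sparsifier._step_count = 0

    @functools.wraps(original_step)
    def wrapper(*args, **kwargs):
        sparsifier._step_count += 1
        return original_step(*args, **kwargs)

    sparsifier.step = wrapper
    return sparsifier


class OrderCheckingScheduler:
    """Warns if step() runs with no sparsifier step since its previous call."""

    def __init__(self, sparsifier):
        self.sparsifier = count_steps(sparsifier)
        self._seen_steps = 0  # sparsifier step count at the previous scheduler step
        self.out_of_order_calls = 0

    def step(self):
        if self.sparsifier._step_count == self._seen_steps:
            self.out_of_order_calls += 1
            warnings.warn("data_scheduler.step() called before data_sparsifier.step()")
        self._seen_steps = self.sparsifier._step_count
        # A real scheduler would update the scheduled param here.


sparsifier = ToySparsifier()
scheduler = OrderCheckingScheduler(sparsifier)
for epoch in range(2):
    for batch in range(3):
        sparsifier.step()  # once per batch, as in the loop above
    scheduler.step()       # once per epoch, after the sparsifier: no warning
print(scheduler.out_of_order_calls)  # 0
```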

### Note:
1. `get_schedule_param()` should return a dictionary whose keys are the names of the data and whose values are the corresponding values of `schedule_param` for the next step.
2. It is the responsibility of `BaseDataScheduler` (not the user) to call `get_schedule_param()` when necessary.
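Note 1 can be checked mechanically while developing a scheduler. The hypothetical helper `validate_schedule_params` below (not part of the data scheduler API) rejects a result that is not a dict or that names data not registered in `data_groups`, and returns the registered names the result leaves out, so you can decide whether that omission is intended.

```python
def validate_schedule_params(schedule_params, data_groups):
    """Sanity-check a get_schedule_param() result against data_groups.

    Raises TypeError if the result is not a dict and KeyError if it names
    data that is not registered; returns the registered names it omits.
    """
    if not isinstance(schedule_params, dict):
        raise TypeError(f"expected a dict, got {type(schedule_params).__name__}")
    unknown = sorted(set(schedule_params) - set(data_groups))
    if unknown:
        raise KeyError(f"no data registered under: {unknown}")
    return sorted(set(data_groups) - set(schedule_params))


groups = {"train_data": {"sparsity_level": 0.5}, "val_data": {"sparsity_level": 0.2}}
print(validate_schedule_params({"train_data": 0.6}, groups))  # ['val_data']
```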