1"""
2==========================
3Model ensembling
4==========================
5This example illustrates how to vectorize model ensembling using vmap.
6
7What is model ensembling?
8--------------------------------------------------------------------
9Model ensembling combines the predictions from multiple models together.
10Traditionally this is done by running each model on some inputs separately
11and then combining the predictions. However, if you're running models with
12the same architecture, then it may be possible to combine them together
13using ``vmap``. ``vmap`` is a function transform that maps functions across
14dimensions of the input tensors. One of its use cases is eliminating
15for-loops and speeding them up through vectorization.
16
17Let's demonstrate how to do this using an ensemble of simple CNNs.
18"""
import torch
import torch.nn as nn
import torch.nn.functional as F


torch.manual_seed(0)

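# Before building the ensemble, here is the tiny demonstration promised
# above (a minimal sketch, separate from the ensemble itself): ``vmap``
# turns a function written for a single example into one that is mapped
# over a batch dimension, replacing an explicit for-loop.
from functorch import vmap

xs = torch.arange(15.0).reshape(5, 3)
assert torch.allclose(
    vmap(torch.sum)(xs),                 # vectorized: one sum per row
    torch.stack([x.sum() for x in xs]),  # equivalent explicit for-loop
)
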
# Here's a simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # After two 3x3 convs and a 2x2 max-pool, a 28x28 input becomes
        # 64 channels of 12x12, i.e. 64 * 12 * 12 = 9216 features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


# Let's generate some dummy data. Pretend that we're working with an MNIST
# dataset where the images are 28 by 28. Furthermore, let's say we wish to
# combine the predictions from 10 different models.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)
models = [SimpleCNN().to(device) for _ in range(num_models)]

# We have a couple of options for generating predictions. Maybe we want to
# give each model a different randomized minibatch of data, or maybe we want
# to run the same minibatch of data through each model (e.g. if we were
# testing the effect of different model initializations).

# Option 1: a different minibatch for each model
minibatches = data[:num_models]
predictions1 = [model(minibatch) for model, minibatch in zip(models, minibatches)]

# Option 2: the same minibatch for every model
minibatch = data[0]
predictions2 = [model(minibatch) for model in models]


######################################################################
# Using vmap to vectorize the ensemble
# --------------------------------------------------------------------
# Let's use ``vmap`` to speed up the for-loop. We must first prepare the
# models for use with ``vmap``.
#
# First, let's combine the states of the models together by stacking each
# parameter. For example, ``model[i].fc1.weight`` has shape [128, 9216]; we
# are going to stack the ``.fc1.weight`` of each of the 10 models to produce
# a big weight of shape [10, 128, 9216].
#
# functorch offers the following convenience function to do that. It returns
# a stateless version of the model (``fmodel``) and the stacked parameters
# and buffers.
from functorch import combine_state_for_ensemble


fmodel, params, buffers = combine_state_for_ensemble(models)
for p in params:
    p.requires_grad_()

# Option 1: get predictions using a different minibatch for each model.
# By default, vmap maps a function across the first dimension of all inputs
# to the passed-in function. After ``combine_state_for_ensemble``, each of
# ``params`` and ``buffers`` has an additional dimension of size
# ``num_models`` at the front, and ``minibatches`` has a leading dimension
# of size ``num_models``.
print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap


predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(
    predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
)

# Option 2: get predictions using the same minibatch of data.
# vmap has an ``in_dims`` argument that specifies which dimension of each
# input to map over. Passing ``None`` for the minibatch tells vmap to use
# the same minibatch for all 10 models.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(
    predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6
)

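# The point of vectorizing is speed. As a rough check, we can time the
# for-loop against the vmapped version with ``torch.utils.benchmark``
# (a minimal sketch; absolute numbers depend on your hardware):
from torch.utils import benchmark

without_vmap = benchmark.Timer(
    stmt="[model(mb) for model, mb in zip(models, minibatches)]",
    globals={"models": models, "minibatches": minibatches},
)
with_vmap = benchmark.Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals={"vmap": vmap, "fmodel": fmodel, "params": params,
             "buffers": buffers, "minibatches": minibatches},
)
print(without_vmap.timeit(100))
print(with_vmap.timeit(100))
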
# A quick note: there are limitations around the types of functions that can
# be transformed by vmap. The best functions to transform are pure functions:
# ones whose outputs are determined only by their inputs and that have no
# side effects (e.g. mutation). vmap is unable to handle mutation of
# arbitrary Python data structures, but it is able to handle many in-place
# PyTorch operations.
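
# For instance (a minimal sketch of the point above): an in-place PyTorch op
# on a tensor the function itself created works fine under vmap.
def add_one(x):
    out = x.clone()
    out.add_(1)  # in-place PyTorch op: supported by vmap
    return out


assert torch.equal(vmap(add_one)(torch.zeros(4, 3)), torch.ones(4, 3))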