1""" 2========================== 3Model ensembling 4========================== 5This example illustrates how to vectorize model ensembling using vmap. 6 7What is model ensembling? 8-------------------------------------------------------------------- 9Model ensembling combines the predictions from multiple models together. 10Traditionally this is done by running each model on some inputs separately 11and then combining the predictions. However, if you're running models with 12the same architecture, then it may be possible to combine them together 13using ``vmap``. ``vmap`` is a function transform that maps functions across 14dimensions of the input tensors. One of its use cases is eliminating 15for-loops and speeding them up through vectorization. 16 17Let's demonstrate how to do this using an ensemble of simple CNNs. 18""" 19import torch 20import torch.nn as nn 21import torch.nn.functional as F 22 23 24torch.manual_seed(0) 25 26 27# Here's a simple CNN 28class SimpleCNN(nn.Module): 29 def __init__(self): 30 super().__init__() 31 self.conv1 = nn.Conv2d(1, 32, 3, 1) 32 self.conv2 = nn.Conv2d(32, 64, 3, 1) 33 self.fc1 = nn.Linear(9216, 128) 34 self.fc2 = nn.Linear(128, 10) 35 36 def forward(self, x): 37 x = self.conv1(x) 38 x = F.relu(x) 39 x = self.conv2(x) 40 x = F.relu(x) 41 x = F.max_pool2d(x, 2) 42 x = torch.flatten(x, 1) 43 x = self.fc1(x) 44 x = F.relu(x) 45 x = self.fc2(x) 46 output = F.log_softmax(x, dim=1) 47 output = x 48 return output 49 50 51# Let's generate some dummy data. Pretend that we're working with an MNIST dataset 52# where the images are 28 by 28. 53# Furthermore, let's say we wish to combine the predictions from 10 different 54# models. 55device = "cuda" 56num_models = 10 57data = torch.randn(100, 64, 1, 28, 28, device=device) 58targets = torch.randint(10, (6400,), device=device) 59models = [SimpleCNN().to(device) for _ in range(num_models)] 60 61# We have a couple of options for generating predictions. Maybe we want 62# to give each model a different randomized minibatch of data, or maybe we 63# want to run the same minibatch of data through each model (e.g. if we were 64# testing the effect of different model initializations). 65 66# Option 1: different minibatch for each model 67minibatches = data[:num_models] 68predictions1 = [model(minibatch) for model, minibatch in zip(models, minibatches)] 69 70# Option 2: Same minibatch 71minibatch = data[0] 72predictions2 = [model(minibatch) for model in models] 73 74 75###################################################################### 76# Using vmap to vectorize the ensemble 77# -------------------------------------------------------------------- 78# Let's use ``vmap`` to speed up the for-loop. We must first prepare the models 79# for use with ``vmap``. 80# 81# First, let's combine the states of the model together by stacking each parameter. 82# For example, model[i].fc1.weight has shape [9216, 128]; we are going to stack the 83# .fc1.weight of each of the 10 models to produce a big weight of shape [10, 9216, 128]. 84# 85# functorch offers the following convenience function to do that. It returns a 86# stateless version of the model (fmodel) and stacked parameters and buffers. 87from functorch import combine_state_for_ensemble 88 89 90fmodel, params, buffers = combine_state_for_ensemble(models) 91[p.requires_grad_() for p in params] 92 93# Option 1: get predictions using a different minibatch for each model. 94# By default, vmap maps a function across the first dimension of all inputs to the 95# passed-in function. 
# Option 1: get predictions using a different minibatch for each model.
# By default, vmap maps a function across the first dimension of all inputs
# to the passed-in function. After ``combine_state_for_ensemble``, each of
# ``params`` and ``buffers`` has an additional dimension of size
# ``num_models`` at the front, and ``minibatches`` has a leading dimension
# of size ``num_models``.
print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap


predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(
    predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6
)

# Option 2: get predictions using the same minibatch of data.
# vmap has an ``in_dims`` argument that specifies which dimensions to map
# over. By passing ``None`` for the minibatch, we tell vmap to apply the
# same minibatch to all 10 models.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(
    predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6
)

# A quick note: there are limitations around what types of functions can be
# transformed by vmap. The best functions to transform are pure functions:
# functions whose outputs are determined only by their inputs and that have
# no side effects (e.g. mutation). vmap is unable to handle mutation of
# arbitrary Python data structures, but it is able to handle many in-place
# PyTorch operations.
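# To see the speedup promised in the introduction, we can time the for-loop
# ensemble against the vmapped one. This is an illustrative sketch, not part
# of the pipeline above; exact timings depend on your hardware.
from torch.utils.benchmark import Timer

without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals(),
)
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals(),
)
print(f"Predictions without vmap: {without_vmap.timeit(100)}")
print(f"Predictions with vmap: {with_vmap.timeit(100)}")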