// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/DynamicLayer.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

namespace at::functorch {

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
static Tensor permuteBatchDimsToFront(const BatchedTensorImpl* batched) {
  const Tensor& physical_tensor = batched->value();
  if (batched->bdim() == 0) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sym_sizes();
  VmapDimVector permutation(sizes.size(), 0);
  permutation.reserve(sizes.size());
  const auto is_bdim = createBatchDimBitset(batched->bdim());
  int64_t idx = 0;
  permutation[idx++] = batched->bdim();
  for (const auto ptr : c10::irange(0, sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}

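// Converts a logical BatchedTensor into its physical view: the batch dim of
// the underlying tensor is permuted to the front and the tensor's vmap level
// is recorded in the returned view's level bitset.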
VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->level()) };
}

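// The number of batch dims equals the number of vmap levels recorded in the
// view; the logical (user-visible) dims are whatever remains after them.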
int64_t VmapPhysicalView::numBatchDims() const {
  return levels_.count();
}

int64_t VmapPhysicalView::numLogicalDims() const {
  return /*physical*/tensor_.dim() - numBatchDims();
}

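// Maps logical dims to physical dims. Since all batch dims sit at the front
// of the physical tensor, a (possibly negative) logical dim is first wrapped
// against the logical rank and then shifted by numBatchDims().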
VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
  auto logical_ndim = numLogicalDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  result.reserve(logical_ndim);
  for (auto dim : logical_dims) {
    result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
  }
  return result;
}

int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  auto logical_ndim = numLogicalDims();
  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}

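// Computes the physical shape that corresponds to a logical shape by
// prepending the physical tensor's leading batch sizes to it.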
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  VmapDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

SymDimVector VmapPhysicalView::getPhysicalShape(c10::SymIntArrayRef logical_shape) const {
  SymDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sym_sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

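// Returns the (physical dim, level) pair for the front batch dim: the batch
// dim is always at physical dim 0, and the level is the lowest bit set in
// the levels bitset.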
static std::tuple<int64_t, int64_t> computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  int64_t level = 0;
  int64_t dim = 0;
  for (; level < kVmapNumLevels; level++) {
    if (!levels_bitset[level]) {
      continue;
    }
    break;
  }
  return std::make_tuple(dim, level);
}

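// If `dim` is provided, moves that dim to the front. Otherwise the tensor has
// no batch dim at this level, so a new leading dim of size 1 is inserted and
// expanded to `size` so that every tensor carries the batch dim explicitly.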
static Tensor moveDimToFrontAndExpand(Tensor tensor, std::optional<int64_t> dim, c10::SymInt size) {
  if (dim) {
    tensor = tensor.movedim(*dim, 0);
  } else {
    tensor = tensor.unsqueeze(0);
    auto expanded_sizes = tensor.sym_sizes().vec();
    expanded_sizes[0] = size;
    tensor = tensor.expand_symint(expanded_sizes);
  }
  return tensor;
}

// The algorithm is as follows:
// 1. Figure out what the collective levels of `logical_tensors` are.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that it has an output batch size equal
//    to `batch_sizes`.
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
  c10::SymInt bdim_size = -1;

  // Figure out the batch size first
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched) {
      continue;
    }
    if (batched->level() != cur_level) {
      continue;
    }
    bdim_size = batched->value().sym_size(batched->bdim());
  }
  TORCH_INTERNAL_ASSERT(bdim_size != -1);

  std::bitset<kVmapNumLevels> levels;
  levels[cur_level] = true;

  VmapPhysicalViewVec result;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->level() != cur_level)) {
      // Unsqueeze dim 0, expand it to the correct shape
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      auto value = moveDimToFrontAndExpand(logical_tensor, {}, bdim_size);
      result.emplace_back(std::move(value), levels);
      continue;
    }
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    auto physical = batched->value();
    auto value = moveDimToFrontAndExpand(physical, batched->bdim(), bdim_size);
    result.emplace_back(std::move(value), levels);
  }

  return result;
}

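// Like moveDimToFrontAndExpand, but instead of expanding the new leading dim,
// pads the tensor with size-1 dims (after the batch dim) until it has
// `example_ndim` non-batch dims, leaving the actual broadcasting to the op.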
static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, std::optional<int64_t> dim, int64_t example_ndim) {
  if (dim) {
    tensor = tensor.movedim(*dim, 0);
  } else {
    tensor = tensor.unsqueeze(0);
  }
  auto ndim = tensor.dim() - 1;
  for (int64_t i = 0; i < example_ndim - ndim; i++) {
    tensor = tensor.unsqueeze(1);
  }
  return tensor;
}

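// Same overall structure as the MultiBatch transform above, but tensors are
// only aligned in rank (batch dim at the front, padded with size-1 dims up to
// the largest example rank); they are not expanded to the batch size.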
VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
  auto bdim_size = -1;

  // Figure out the batch size first
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->level() != cur_level)) {
      continue;
    }
    bdim_size = batched->value().size(batched->bdim());
  }
  TORCH_INTERNAL_ASSERT(bdim_size != -1);

  std::bitset<kVmapNumLevels> levels;
  levels[cur_level] = true;

  // figure out the example ndim
  int64_t max_example_dim = -1;
  for (const auto& logical_tensor : logical_tensors) {
    max_example_dim = std::max(logical_tensor.dim(), max_example_dim);
  }

  VmapPhysicalViewVec result;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->level() != cur_level)) {
      // Unsqueeze dim 0 and pad with size-1 dims up to the example rank
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      auto value = moveDimToFrontAndUnsqueeze(logical_tensor, {}, max_example_dim);
      result.emplace_back(std::move(value), levels);
      continue;
    }
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    auto physical = batched->value();
    auto value = moveDimToFrontAndUnsqueeze(physical, batched->bdim(), max_example_dim);
    result.emplace_back(std::move(value), levels);
  }

  return result;
}

VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  return VmapPhysicalToLogicalMap(levels_);
}

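// Wraps a physical tensor back up as a logical BatchedTensor. The batch dim
// is assumed to be at the front (physical dim 0) and is tagged with the
// level recorded in this map.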
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  auto bdim_level = computeFrontBatchDimsFromLevels(levels_);
  return makeBatched(physical_tensor, std::get<0>(bdim_level), std::get<1>(bdim_level));
}

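// In-place variant of `apply`: replaces each physical tensor in the vector
// with its logical (batched) counterpart.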
void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (const auto idx : c10::irange(0, physical_tensors.size())) {
    physical_tensors[idx] = apply(physical_tensors[idx]);
  }
}

} // namespace at::functorch