// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/DynamicLayer.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

namespace at::functorch {

// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
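// For example, a BatchedTensor whose physical value has shape [3, 5, 7] with
// bdim = 1 gets permuted to a physical tensor of shape [5, 3, 7], i.e. the
// batch dim ends up at the front.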
static Tensor permuteBatchDimsToFront(const BatchedTensorImpl* batched) {
  const Tensor& physical_tensor = batched->value();
  if (batched->bdim() == 0) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sym_sizes();
  VmapDimVector permutation(sizes.size(), 0);
  permutation.reserve(sizes.size());
  const auto is_bdim = createBatchDimBitset(batched->bdim());
  int64_t idx = 0;
  permutation[idx++] = batched->bdim();
  for (const auto ptr : c10::irange(0, sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}

VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->level()) };
}

int64_t VmapPhysicalView::numBatchDims() const {
  return levels_.count();
}

int64_t VmapPhysicalView::numLogicalDims() const {
  return /*physical*/tensor_.dim() - numBatchDims();
}

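// A physical view's dims are laid out as [batch dims..., logical dims...], so a
// logical dim maps to a physical dim by wrapping it into [0, numLogicalDims())
// and then offsetting by numBatchDims(). For example, with 2 batch dims and a
// 3-dim logical view, logical dim -1 wraps to 2 and maps to physical dim 4.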
VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
  auto logical_ndim = numLogicalDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  result.reserve(logical_ndim);
  for (auto dim : logical_dims) {
    result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
  }
  return result;
}

int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  auto logical_ndim = numLogicalDims();
  return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}

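// Both getPhysicalShape overloads prepend the physical tensor's batch sizes to
// `logical_shape`. For example, with a single batch dim of size B and
// logical_shape = [3, 5], the resulting physical shape is [B, 3, 5].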
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  VmapDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

SymDimVector VmapPhysicalView::getPhysicalShape(c10::SymIntArrayRef logical_shape) const {
  SymDimVector result;
  result.reserve(logical_shape.size() + numBatchDims());
  auto tensor_sizes = tensor_.sym_sizes();
  result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
  result.insert(result.end(), logical_shape.begin(), logical_shape.end());
  return result;
}

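// Returns (dim, level) for the front batch dim: `dim` is always 0 here, and
// `level` is the lowest set bit in `levels_bitset`.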
static std::tuple<int64_t, int64_t> computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  int64_t level = 0;
  int64_t dim = 0;
  for (; level < kVmapNumLevels; level++) {
    if (!levels_bitset[level]) {
      continue;
    }
    break;
  }
  return std::make_tuple(dim, level);
}

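// If `dim` is provided, moves that dim to the front. Otherwise, inserts a new
// front dim of size 1 and expands it to `size`; e.g. a tensor of shape [3, 5]
// with dim = nullopt and size = B becomes a tensor of shape [B, 3, 5].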
static Tensor moveDimToFrontAndExpand(Tensor tensor, std::optional<int64_t> dim, c10::SymInt size) {
  if (dim) {
    tensor = tensor.movedim(*dim, 0);
  } else {
    tensor = tensor.unsqueeze(0);
    auto expanded_sizes = tensor.sym_sizes().vec();
    expanded_sizes[0] = size;
    tensor = tensor.expand_symint(expanded_sizes);
  }
  return tensor;
}

// The algorithm is as follows:
// 1. Figure out the collective set of levels across `logical_tensors`.
// 2. Move all batch dims to the front of the tensors and add extra dims
//    of size 1. At this point, every tensor will have a dimension for
//    each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that it has an output batch size equal
//    to `batch_sizes`.
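// For example, vmapping at the current level over a BatchedTensor with physical
// shape [B, 3] (bdim = 0) and a plain tensor of shape [5] produces physical
// views of shapes [B, 3] and [B, 5] respectively.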
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
  c10::SymInt bdim_size = -1;

  // Figure out the batch size first
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched) {
      continue;
    }
    if (batched->level() != cur_level) {
      continue;
    }
    bdim_size = batched->value().sym_size(batched->bdim());
  }
  TORCH_INTERNAL_ASSERT(bdim_size != -1);

  std::bitset<kVmapNumLevels> levels;
  levels[cur_level] = true;

  VmapPhysicalViewVec result;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->level() != cur_level)) {
      // Unsqueeze dim 0, expand it to the correct shape
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      auto value = moveDimToFrontAndExpand(logical_tensor, {}, bdim_size);
      result.emplace_back(std::move(value), levels);
      continue;
    }
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    auto physical = batched->value();
    auto value = moveDimToFrontAndExpand(physical, batched->bdim(), bdim_size);
    result.emplace_back(std::move(value), levels);
  }

  return result;
}

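// If `dim` is provided, moves that dim to the front; otherwise inserts a new
// front dim of size 1. Then unsqueezes after the batch dim until the tensor has
// `example_ndim` non-batch dims; e.g. a tensor of shape [B, 3] (dim = 0) with
// example_ndim = 3 becomes a tensor of shape [B, 1, 1, 3].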
static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, std::optional<int64_t> dim, int64_t example_ndim) {
  if (dim) {
    tensor = tensor.movedim(*dim, 0);
  } else {
    tensor = tensor.unsqueeze(0);
  }
  auto ndim = tensor.dim() - 1;
  for (int64_t i = 0; i < example_ndim - ndim; i++) {
    tensor = tensor.unsqueeze(1);
  }
  return tensor;
}

VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  auto cur_level = maybeCurrentDynamicLayer().value().layerId();
  auto bdim_size = -1;

  // Figure out the batch size first
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->level() != cur_level)) {
      continue;
    }
    bdim_size = batched->value().size(batched->bdim());
  }
  TORCH_INTERNAL_ASSERT(bdim_size != -1);

  std::bitset<kVmapNumLevels> levels;
  levels[cur_level] = true;

  // figure out the example ndim
  int64_t max_example_dim = -1;
  for (const auto& logical_tensor : logical_tensors) {
    max_example_dim = std::max(logical_tensor.dim(), max_example_dim);
  }

  VmapPhysicalViewVec result;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (!batched || (batched->level() != cur_level)) {
      // Unsqueeze dim 0, expand it to the correct shape
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      auto value = moveDimToFrontAndUnsqueeze(logical_tensor, {}, max_example_dim);
      result.emplace_back(std::move(value), levels);
      continue;
    }
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    auto physical = batched->value();
    auto value = moveDimToFrontAndUnsqueeze(physical, batched->bdim(), max_example_dim);
    result.emplace_back(std::move(value), levels);
  }

  return result;
}

VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  return VmapPhysicalToLogicalMap(levels_);
}

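// Wraps `physical_tensor` back into a BatchedTensor at the front batch dim
// (dim 0) and the vmap level recorded in `levels_`.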
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  auto bdim_level = computeFrontBatchDimsFromLevels(levels_);
  return makeBatched(physical_tensor, std::get<0>(bdim_level), std::get<1>(bdim_level));
}

void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
  for (const auto idx : c10::irange(0, physical_tensors.size())) {
    physical_tensors[idx] = apply(physical_tensors[idx]);
  }
}

} // namespace at::functorch