• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <torch/csrc/distributed/c10d/Utils.hpp>
2 
3 #include <cstring>
4 
5 namespace c10d {
6 
getTensorShapes(const std::vector<at::Tensor> & tensors)7 std::vector<at::Tensor> getTensorShapes(
8     const std::vector<at::Tensor>& tensors) {
9   std::vector<at::Tensor> shapeTensors;
10   shapeTensors.reserve(tensors.size());
11   for (const auto& tensor : tensors) {
12     // Use `at::tensor()` to copy the data underlying `sizes()` since it may be
13     // released elsewhere.
14     at::Tensor shapesTensor =
15         at::tensor(tensor.sizes(), at::TensorOptions().dtype(at::kLong));
16     shapeTensors.emplace_back(std::move(shapesTensor));
17   }
18   return shapeTensors;
19 }
20 
getTensorsNumel(const std::vector<at::Tensor> & tensors)21 size_t getTensorsNumel(const std::vector<at::Tensor>& tensors) {
22   size_t numel = 0;
23   for (auto& tensor : tensors) {
24     numel += tensor.numel();
25   }
26   return numel;
27 }
28 
29 } // namespace c10d
30