• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <torch/extension.h>
2 
3 struct Doubler {
DoublerDoubler4   Doubler(int A, int B) {
5     tensor_ =
6         torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
7   }
forwardDoubler8   torch::Tensor forward() {
9     return tensor_ * 2;
10   }
getDoubler11   torch::Tensor get() const {
12     return tensor_;
13   }
14 
15  private:
16   torch::Tensor tensor_;
17 };
18