• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/module.h>
5 
6 #include <vector>
7 
8 namespace torch {
9 namespace nn {
10 class ParameterListImpl : public Cloneable<ParameterListImpl> {
11  public:
12   using Iterator = typename std::vector<
13       OrderedDict<std::string, torch::Tensor>::Item>::iterator;
14   using ConstIterator = typename std::vector<
15       OrderedDict<std::string, torch::Tensor>::Item>::const_iterator;
16 
17   ParameterListImpl() = default;
18 
19   /// Constructs the `ParameterList` from a variadic list of ParameterList.
20   template <typename... Tensors>
ParameterListImpl(Tensors &&...params)21   explicit ParameterListImpl(Tensors&&... params) {
22     parameters_.reserve(sizeof...(Tensors));
23     push_back_var(std::forward<Tensors>(params)...);
24   }
25 
26   template <typename... Tensors>
ParameterListImpl(const Tensors &...params)27   explicit ParameterListImpl(const Tensors&... params) {
28     parameters_.reserve(sizeof...(Tensors));
29     push_back_var(std::forward<Tensors>(params)...);
30   }
31 
32   /// `reset()` is empty for `ParameterList`, since it does not have parameters
33   /// of its own.
reset()34   void reset() override {}
35 
36   /// Pretty prints the `ParameterList` module into the given `stream`.
pretty_print(std::ostream & stream)37   void pretty_print(std::ostream& stream) const override {
38     stream << "torch::nn::ParameterList(" << std::endl;
39     for (const auto& pair : parameters_) {
40       stream << "(" << pair.key() << ")"
41              << ": Parameter containing: [" << pair.value().scalar_type()
42              << " of size " << pair.value().sizes() << "]";
43       ;
44       stream << std::endl;
45     }
46     stream << ")";
47   }
48 
49   /// push the a given parameter at the end of the list
append(torch::Tensor && param)50   void append(torch::Tensor&& param) {
51     bool requires_grad = param.requires_grad();
52     register_parameter(
53         std::to_string(parameters_.size()), std::move(param), requires_grad);
54   }
55 
56   /// push the a given parameter at the end of the list
append(const torch::Tensor & param)57   void append(const torch::Tensor& param) {
58     bool requires_grad = param.requires_grad();
59     register_parameter(
60         std::to_string(parameters_.size()), param, requires_grad);
61   }
62 
63   /// push the a given parameter at the end of the list
64   /// And the key of the pair will be discarded, only the value
65   /// will be added into the `ParameterList`
append(const OrderedDict<std::string,torch::Tensor>::Item & pair)66   void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) {
67     register_parameter(
68         std::to_string(parameters_.size()),
69         pair.value(),
70         pair.value().requires_grad());
71   }
72 
73   /// extend parameters from a container to the end of the list
74   template <typename Container>
extend(const Container & container)75   void extend(const Container& container) {
76     for (const auto& param : container) {
77       append(param);
78     }
79   }
80 
81   /// Returns an iterator to the start of the ParameterList
82   /// the iterator returned will be type of `OrderedDict<std::string,
83   /// torch::Tensor>::Item`
begin()84   Iterator begin() {
85     return parameters_.begin();
86   }
87 
88   /// Returns a const iterator to the start of the ParameterList
89   /// the iterator returned will be type of `OrderedDict<std::string,
90   /// torch::Tensor>::Item`
begin()91   ConstIterator begin() const {
92     return parameters_.begin();
93   }
94 
95   /// Returns an iterator to the end of the ParameterList
96   /// the iterator returned will be type of `OrderedDict<std::string,
97   /// torch::Tensor>::Item`
end()98   Iterator end() {
99     return parameters_.end();
100   }
101 
102   /// Returns a const iterator to the end of the ParameterList
103   /// the iterator returned will be type of `OrderedDict<std::string,
104   /// torch::Tensor>::Item`
end()105   ConstIterator end() const {
106     return parameters_.end();
107   }
108 
109   /// Returns the value associated with the given `key`. Throws an exception if
110   /// no such key is stored in the `ParameterList`. Check contains(key) before
111   /// for a non-throwing way of access
at(size_t idx)112   at::Tensor& at(size_t idx) {
113     TORCH_CHECK(idx < size(), "Index out of range");
114     return parameters_[std::to_string(idx)];
115   }
116 
117   /// Returns the value associated with the given `key`. Throws an exception if
118   /// no such key is stored in the `ParameterList`. Check contains(key) before
119   /// for a non-throwing way of access
at(size_t idx)120   const at::Tensor& at(size_t idx) const {
121     TORCH_CHECK(idx < size(), "Index out of range");
122     return parameters_[std::to_string(idx)];
123   }
124 
125   /// Returns the value associated with the given `key`. Throws an exception if
126   /// no such key is stored in the `ParameterList`. Check contains(key) before
127   /// for a non-throwing way of access
128   at::Tensor& operator[](size_t idx) {
129     return at(idx);
130   }
131 
132   /// Returns the value associated with the given `key`. Throws an exception if
133   /// no such key is stored in the `ParameterList`. Check contains(key) before
134   /// for a non-throwing way of access
135   const at::Tensor& operator[](size_t idx) const {
136     return at(idx);
137   }
138 
139   /// Return the size of the ParameterList
size()140   size_t size() const noexcept {
141     return parameters_.size();
142   }
143   /// True if the ParameterList is empty
is_empty()144   bool is_empty() const noexcept {
145     return parameters_.is_empty();
146   }
147 
148   /// Overload the +=, so that two ParameterList could be incrementally added
149   template <typename Container>
150   Container& operator+=(const Container& other) {
151     extend(other);
152     return *this;
153   }
154 
155  private:
156   template <typename Head, typename... Tail>
push_back_var(Head && head,Tail &&...tail)157   void push_back_var(Head&& head, Tail&&... tail) {
158     append(std::forward<Head>(head));
159     // Recursively calls this method, until the parameter pack only thas this
160     // entry left. Then calls `push_back()` a final time (above).
161     push_back_var(std::forward<Tail>(tail)...);
162   }
163 
164   /// The base case, when the list of modules is empty.
push_back_var()165   void push_back_var() {}
166 };
167 TORCH_MODULE(ParameterList);
168 } // namespace nn
169 } // namespace torch
170