1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/module.h> 5 6 #include <vector> 7 8 namespace torch { 9 namespace nn { 10 class ParameterListImpl : public Cloneable<ParameterListImpl> { 11 public: 12 using Iterator = typename std::vector< 13 OrderedDict<std::string, torch::Tensor>::Item>::iterator; 14 using ConstIterator = typename std::vector< 15 OrderedDict<std::string, torch::Tensor>::Item>::const_iterator; 16 17 ParameterListImpl() = default; 18 19 /// Constructs the `ParameterList` from a variadic list of ParameterList. 20 template <typename... Tensors> ParameterListImpl(Tensors &&...params)21 explicit ParameterListImpl(Tensors&&... params) { 22 parameters_.reserve(sizeof...(Tensors)); 23 push_back_var(std::forward<Tensors>(params)...); 24 } 25 26 template <typename... Tensors> ParameterListImpl(const Tensors &...params)27 explicit ParameterListImpl(const Tensors&... params) { 28 parameters_.reserve(sizeof...(Tensors)); 29 push_back_var(std::forward<Tensors>(params)...); 30 } 31 32 /// `reset()` is empty for `ParameterList`, since it does not have parameters 33 /// of its own. reset()34 void reset() override {} 35 36 /// Pretty prints the `ParameterList` module into the given `stream`. pretty_print(std::ostream & stream)37 void pretty_print(std::ostream& stream) const override { 38 stream << "torch::nn::ParameterList(" << std::endl; 39 for (const auto& pair : parameters_) { 40 stream << "(" << pair.key() << ")" 41 << ": Parameter containing: [" << pair.value().scalar_type() 42 << " of size " << pair.value().sizes() << "]"; 43 ; 44 stream << std::endl; 45 } 46 stream << ")"; 47 } 48 49 /// push the a given parameter at the end of the list append(torch::Tensor && param)50 void append(torch::Tensor&& param) { 51 bool requires_grad = param.requires_grad(); 52 register_parameter( 53 std::to_string(parameters_.size()), std::move(param), requires_grad); 54 } 55 56 /// push the a given parameter at the end of the list append(const torch::Tensor & param)57 void append(const torch::Tensor& param) { 58 bool requires_grad = param.requires_grad(); 59 register_parameter( 60 std::to_string(parameters_.size()), param, requires_grad); 61 } 62 63 /// push the a given parameter at the end of the list 64 /// And the key of the pair will be discarded, only the value 65 /// will be added into the `ParameterList` append(const OrderedDict<std::string,torch::Tensor>::Item & pair)66 void append(const OrderedDict<std::string, torch::Tensor>::Item& pair) { 67 register_parameter( 68 std::to_string(parameters_.size()), 69 pair.value(), 70 pair.value().requires_grad()); 71 } 72 73 /// extend parameters from a container to the end of the list 74 template <typename Container> extend(const Container & container)75 void extend(const Container& container) { 76 for (const auto& param : container) { 77 append(param); 78 } 79 } 80 81 /// Returns an iterator to the start of the ParameterList 82 /// the iterator returned will be type of `OrderedDict<std::string, 83 /// torch::Tensor>::Item` begin()84 Iterator begin() { 85 return parameters_.begin(); 86 } 87 88 /// Returns a const iterator to the start of the ParameterList 89 /// the iterator returned will be type of `OrderedDict<std::string, 90 /// torch::Tensor>::Item` begin()91 ConstIterator begin() const { 92 return parameters_.begin(); 93 } 94 95 /// Returns an iterator to the end of the ParameterList 96 /// the iterator returned will be type of `OrderedDict<std::string, 97 /// torch::Tensor>::Item` end()98 Iterator end() { 99 return parameters_.end(); 100 } 101 102 /// Returns a const iterator to the end of the ParameterList 103 /// the iterator returned will be type of `OrderedDict<std::string, 104 /// torch::Tensor>::Item` end()105 ConstIterator end() const { 106 return parameters_.end(); 107 } 108 109 /// Returns the value associated with the given `key`. Throws an exception if 110 /// no such key is stored in the `ParameterList`. Check contains(key) before 111 /// for a non-throwing way of access at(size_t idx)112 at::Tensor& at(size_t idx) { 113 TORCH_CHECK(idx < size(), "Index out of range"); 114 return parameters_[std::to_string(idx)]; 115 } 116 117 /// Returns the value associated with the given `key`. Throws an exception if 118 /// no such key is stored in the `ParameterList`. Check contains(key) before 119 /// for a non-throwing way of access at(size_t idx)120 const at::Tensor& at(size_t idx) const { 121 TORCH_CHECK(idx < size(), "Index out of range"); 122 return parameters_[std::to_string(idx)]; 123 } 124 125 /// Returns the value associated with the given `key`. Throws an exception if 126 /// no such key is stored in the `ParameterList`. Check contains(key) before 127 /// for a non-throwing way of access 128 at::Tensor& operator[](size_t idx) { 129 return at(idx); 130 } 131 132 /// Returns the value associated with the given `key`. Throws an exception if 133 /// no such key is stored in the `ParameterList`. Check contains(key) before 134 /// for a non-throwing way of access 135 const at::Tensor& operator[](size_t idx) const { 136 return at(idx); 137 } 138 139 /// Return the size of the ParameterList size()140 size_t size() const noexcept { 141 return parameters_.size(); 142 } 143 /// True if the ParameterList is empty is_empty()144 bool is_empty() const noexcept { 145 return parameters_.is_empty(); 146 } 147 148 /// Overload the +=, so that two ParameterList could be incrementally added 149 template <typename Container> 150 Container& operator+=(const Container& other) { 151 extend(other); 152 return *this; 153 } 154 155 private: 156 template <typename Head, typename... Tail> push_back_var(Head && head,Tail &&...tail)157 void push_back_var(Head&& head, Tail&&... tail) { 158 append(std::forward<Head>(head)); 159 // Recursively calls this method, until the parameter pack only thas this 160 // entry left. Then calls `push_back()` a final time (above). 161 push_back_var(std::forward<Tail>(tail)...); 162 } 163 164 /// The base case, when the list of modules is empty. push_back_var()165 void push_back_var() {} 166 }; 167 TORCH_MODULE(ParameterList); 168 } // namespace nn 169 } // namespace torch 170