1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/pimpl.h> 5 #include <torch/ordered_dict.h> 6 #include <utility> 7 #include <vector> 8 9 namespace torch { 10 namespace nn { 11 12 class ParameterDictImpl : public Cloneable<ParameterDictImpl> { 13 public: 14 using Iterator = OrderedDict<std::string, Tensor>::Iterator; 15 using ConstIterator = OrderedDict<std::string, Tensor>::ConstIterator; 16 17 ParameterDictImpl() = default; 18 ParameterDictImpl(const torch::OrderedDict<std::string,torch::Tensor> & params)19 explicit ParameterDictImpl( 20 const torch::OrderedDict<std::string, torch::Tensor>& params) { 21 parameters_ = params; 22 } 23 24 /// `reset()` is empty for `ParameterDict`, since it does not have 25 /// parameters of its own. reset()26 void reset() override {} 27 28 /// Pretty prints the `ParameterDict` module into the given `stream`. pretty_print(std::ostream & stream)29 void pretty_print(std::ostream& stream) const override { 30 stream << "torch::nn::ParameterDict(" << std::endl; 31 for (const auto& pair : parameters_) { 32 stream << "(" << pair.key() << ")" 33 << ": Parameter containing: [" << pair.value().scalar_type() 34 << " of size " << pair.value().sizes() << "]"; 35 ; 36 stream << std::endl; 37 } 38 stream << ")"; 39 } 40 41 /// Insert the parameter along with the key into ParameterDict 42 /// The parameter is set to be require grad by default insert(std::string key,Tensor param)43 Tensor& insert(std::string key, Tensor param) { 44 bool requires_grad = param.requires_grad(); 45 return register_parameter(std::move(key), std::move(param), requires_grad); 46 } 47 48 /// Remove key from the ParameterDict and return its value, throw exception 49 /// if the key is not contained. Please check contains(key) before for a 50 /// non-throwing access. pop(const std::string & key)51 Tensor pop(const std::string& key) { 52 torch::Tensor v = parameters_[key]; 53 parameters_.erase(key); 54 return v; 55 } 56 57 /// Return the keys in the dict keys()58 ::std::vector<std::string> keys() const { 59 return parameters_.keys(); 60 } 61 62 /// Return the Values in the dict values()63 ::std::vector<torch::Tensor> values() const { 64 return parameters_.values(); 65 } 66 67 /// Return an iterator to the start of ParameterDict begin()68 Iterator begin() { 69 return parameters_.begin(); 70 } 71 72 /// Return a const iterator to the start of ParameterDict begin()73 ConstIterator begin() const { 74 return parameters_.begin(); 75 } 76 77 /// Return an iterator to the end of ParameterDict end()78 Iterator end() { 79 return parameters_.end(); 80 } 81 82 /// Return a const iterator to the end of ParameterDict end()83 ConstIterator end() const { 84 return parameters_.end(); 85 } 86 87 /// Return the number of items currently stored in the ParameterDict size()88 size_t size() const noexcept { 89 return parameters_.size(); 90 } 91 92 /// Return true if the ParameterDict is empty, otherwise return false empty()93 bool empty() const noexcept { 94 return parameters_.is_empty(); 95 } 96 97 /// Update the ParameterDict with the key-value pairs from 98 /// another ParameterDict, overwriting existing key 99 template <typename Container> update(const Container & container)100 void update(const Container& container) { 101 for (auto& item : container) { 102 parameters_[item.key()] = item.value(); 103 } 104 } 105 106 /// Remove all parameters in the ParameterDict clear()107 void clear() { 108 parameters_.clear(); 109 } 110 111 /// Check if the centain parameter with the key in the ParameterDict contains(const std::string & key)112 bool contains(const std::string& key) const noexcept { 113 return parameters_.contains(key); 114 } 115 116 /// Returns the value associated with the given `key`. Throws an exception if 117 /// no such key is stored in the `ParameterDict`. Check contains(key) before 118 /// for a non-throwing way of access get(const std::string & key)119 const Tensor& get(const std::string& key) const { 120 return parameters_[key]; 121 } 122 123 /// Returns the value associated with the given `key`. Throws an exception if 124 /// no such key is stored in the `ParameterDict`. Check contains(key) before 125 /// for a non-throwing way of access get(const std::string & key)126 Tensor& get(const std::string& key) { 127 return parameters_[key]; 128 } 129 130 /// Returns the value associated with the given `key`. Throws an exception if 131 /// no such key is stored in the `ParameterDict`. Check contains(key) before 132 /// for a non-throwing way of access 133 Tensor& operator[](const std::string& key) { 134 return parameters_[key]; 135 } 136 137 /// Returns the value associated with the given `key`. Throws an exception if 138 /// no such key is stored in the `ParameterDict`. Check contains(key) before 139 /// for a non-throwing way of access 140 const Tensor& operator[](const std::string& key) const { 141 return parameters_[key]; 142 } 143 }; 144 145 TORCH_MODULE(ParameterDict); 146 147 } // namespace nn 148 } // namespace torch 149