• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/pimpl.h>
5 #include <torch/ordered_dict.h>
6 #include <utility>
7 #include <vector>
8 
9 namespace torch {
10 namespace nn {
11 
12 class ParameterDictImpl : public Cloneable<ParameterDictImpl> {
13  public:
14   using Iterator = OrderedDict<std::string, Tensor>::Iterator;
15   using ConstIterator = OrderedDict<std::string, Tensor>::ConstIterator;
16 
17   ParameterDictImpl() = default;
18 
ParameterDictImpl(const torch::OrderedDict<std::string,torch::Tensor> & params)19   explicit ParameterDictImpl(
20       const torch::OrderedDict<std::string, torch::Tensor>& params) {
21     parameters_ = params;
22   }
23 
24   /// `reset()` is empty for `ParameterDict`, since it does not have
25   /// parameters of its own.
reset()26   void reset() override {}
27 
28   /// Pretty prints the `ParameterDict` module into the given `stream`.
pretty_print(std::ostream & stream)29   void pretty_print(std::ostream& stream) const override {
30     stream << "torch::nn::ParameterDict(" << std::endl;
31     for (const auto& pair : parameters_) {
32       stream << "(" << pair.key() << ")"
33              << ": Parameter containing: [" << pair.value().scalar_type()
34              << " of size " << pair.value().sizes() << "]";
35       ;
36       stream << std::endl;
37     }
38     stream << ")";
39   }
40 
41   /// Insert the parameter along with the key into ParameterDict
42   /// The parameter is set to be require grad by default
insert(std::string key,Tensor param)43   Tensor& insert(std::string key, Tensor param) {
44     bool requires_grad = param.requires_grad();
45     return register_parameter(std::move(key), std::move(param), requires_grad);
46   }
47 
48   /// Remove key from the ParameterDict and return its value, throw exception
49   /// if the key is not contained. Please check contains(key) before for a
50   /// non-throwing access.
pop(const std::string & key)51   Tensor pop(const std::string& key) {
52     torch::Tensor v = parameters_[key];
53     parameters_.erase(key);
54     return v;
55   }
56 
57   /// Return the keys in the dict
keys()58   ::std::vector<std::string> keys() const {
59     return parameters_.keys();
60   }
61 
62   /// Return the Values in the dict
values()63   ::std::vector<torch::Tensor> values() const {
64     return parameters_.values();
65   }
66 
67   /// Return an iterator to the start of ParameterDict
begin()68   Iterator begin() {
69     return parameters_.begin();
70   }
71 
72   /// Return a const iterator to the start of ParameterDict
begin()73   ConstIterator begin() const {
74     return parameters_.begin();
75   }
76 
77   /// Return an iterator to the end of ParameterDict
end()78   Iterator end() {
79     return parameters_.end();
80   }
81 
82   /// Return a const iterator to the end of ParameterDict
end()83   ConstIterator end() const {
84     return parameters_.end();
85   }
86 
87   /// Return the number of items currently stored in the ParameterDict
size()88   size_t size() const noexcept {
89     return parameters_.size();
90   }
91 
92   /// Return true if the ParameterDict is empty, otherwise return false
empty()93   bool empty() const noexcept {
94     return parameters_.is_empty();
95   }
96 
97   /// Update the ParameterDict with the key-value pairs from
98   /// another ParameterDict, overwriting existing key
99   template <typename Container>
update(const Container & container)100   void update(const Container& container) {
101     for (auto& item : container) {
102       parameters_[item.key()] = item.value();
103     }
104   }
105 
106   /// Remove all parameters in the ParameterDict
clear()107   void clear() {
108     parameters_.clear();
109   }
110 
111   /// Check if the centain parameter with the key in the ParameterDict
contains(const std::string & key)112   bool contains(const std::string& key) const noexcept {
113     return parameters_.contains(key);
114   }
115 
116   /// Returns the value associated with the given `key`. Throws an exception if
117   /// no such key is stored in the `ParameterDict`. Check contains(key) before
118   /// for a non-throwing way of access
get(const std::string & key)119   const Tensor& get(const std::string& key) const {
120     return parameters_[key];
121   }
122 
123   /// Returns the value associated with the given `key`. Throws an exception if
124   /// no such key is stored in the `ParameterDict`. Check contains(key) before
125   /// for a non-throwing way of access
get(const std::string & key)126   Tensor& get(const std::string& key) {
127     return parameters_[key];
128   }
129 
130   /// Returns the value associated with the given `key`. Throws an exception if
131   /// no such key is stored in the `ParameterDict`. Check contains(key) before
132   /// for a non-throwing way of access
133   Tensor& operator[](const std::string& key) {
134     return parameters_[key];
135   }
136 
137   /// Returns the value associated with the given `key`. Throws an exception if
138   /// no such key is stored in the `ParameterDict`. Check contains(key) before
139   /// for a non-throwing way of access
140   const Tensor& operator[](const std::string& key) const {
141     return parameters_[key];
142   }
143 };
144 
145 TORCH_MODULE(ParameterDict);
146 
147 } // namespace nn
148 } // namespace torch
149