/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_PARAM_INFO_H_
#define MINDSPORE_CORE_IR_PARAM_INFO_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>

#include "ir/dtype.h"

namespace mindspore {
class ParamInfo;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;

// Forward declaration for the Parameter node referenced by parameter() below;
// the listing as extracted used ParameterPtr without a visible declaration.
class Parameter;
using ParameterPtr = std::shared_ptr<Parameter>;

class ParamInfo {
 public:
  ParamInfo() {}

  ParamInfo(const ParamInfo &other) = default;

  virtual ~ParamInfo() = default;

  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }

  bool requires_grad() const { return requires_grad_; }
  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }

  bool init_in_server() const { return init_in_server_; }
  void set_init_in_server(bool init_in_server) { init_in_server_ = init_in_server; }

  bool layerwise_parallel() const { return layerwise_parallel_; }
  void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }

  // Whether this parameter was cloned from another parameter.
  bool cloned() const { return cloned_; }

  // Whether this parameter has been cloned by other parameters.
  bool be_cloned() const { return be_cloned_; }

  // If this parameter has been cloned, one index is recorded per clone.
  const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; }

  // If this parameter was cloned from another parameter, it carries a unique clone index.
  int32_t cloned_index() const { return cloned_index_; }

  // Make a cloned parameter and update the clone info on both sides.
  ParamInfoPtr Clone() {
    // A process-wide atomic counter hands out a unique index to every clone.
    static std::atomic<int32_t> parameter_cloned_index{1};
    int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
    auto clone = std::make_shared<ParamInfo>(*this);
    clone->be_cloned_ = false;
    clone->cloned_ = true;
    clone->be_cloned_index_ = {};
    clone->cloned_index_ = index;
    this->be_cloned_ = true;
    this->be_cloned_index_.push_back(index);
    clone->init_in_server_ = this->init_in_server_;
    clone->requires_aggr_ = this->requires_aggr_;
    // Drop the copied back-reference: the clone does not belong to the
    // original Parameter node.
    clone->ClearParameter();
    return clone;
  }
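
  // A minimal sketch of the bookkeeping Clone() maintains (illustrative only;
  // assumes a fresh process, so the shared counter starts at 1):
  //
  //   auto info = std::make_shared<ParamInfo>();
  //   auto c1 = info->Clone();  // c1->cloned_index() == 1
  //   auto c2 = info->Clone();  // c2->cloned_index() == 2
  //   // Now info->be_cloned() == true and info->be_cloned_index() == {1, 2},
  //   // while c1->cloned() == true and c1->be_cloned() == false.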

  // Fusion group index for the communication operators that synchronize this
  // parameter across devices; operators sharing an index may be fused.
  int32_t comm_fusion() const { return fusion_type_; }
  void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }

  // Whether the parallel optimizer (optimizer state sharding) may apply to this parameter.
  bool parallel_optimizer() const { return parallel_optimizer_; }
  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }

  bool parallel_optimizer_comm_recompute() const { return parallel_optimizer_comm_recompute_; }
  void set_parallel_optimizer_comm_recompute(bool parallel_optimizer_comm_recompute) {
    parallel_optimizer_comm_recompute_ = parallel_optimizer_comm_recompute;
  }

  bool cache_enable() const { return cache_enable_; }
  void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }

  std::vector<int64_t> cache_shape() const { return cache_shape_; }
  void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }

  // Back-reference to the Parameter node this info object describes.
  ParameterPtr parameter() { return parameter_; }
  void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
  void ClearParameter() { parameter_ = nullptr; }

  bool requires_aggr() const { return requires_aggr_; }
  void set_requires_aggr(bool requires_aggr) { requires_aggr_ = requires_aggr; }

 private:
  std::string name_{"Parameter"};
  bool requires_grad_{true};
  bool init_in_server_{false};
  bool layerwise_parallel_{false};
  bool be_cloned_{false};
  bool cloned_{false};
  std::vector<int32_t> be_cloned_index_;
  int32_t cloned_index_{0};
  int32_t fusion_type_{1};
  bool parallel_optimizer_{true};
  bool parallel_optimizer_comm_recompute_{false};
  bool cache_enable_{false};
  std::vector<int64_t> cache_shape_;
  ParameterPtr parameter_{nullptr};
  bool requires_aggr_{true};
};
}  // namespace mindspore
#endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_
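
// A minimal, illustrative usage sketch (not part of the original header;
// assumes it is included as "ir/param_info.h"):
//
//   auto weight = std::make_shared<mindspore::ParamInfo>();
//   weight->set_name("fc1.weight");
//   weight->set_requires_grad(true);
//   weight->set_comm_fusion(2);     // put this parameter's comm ops in fusion group 2
//   auto mirror = weight->Clone();  // copies the flags; in a fresh process,
//                                   // mirror->cloned_index() == 1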