/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_PARAM_INFO_H_
#define MINDSPORE_CORE_IR_PARAM_INFO_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>

#include "ir/dtype.h"

namespace mindspore {
class ParamInfo;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;

// Forward declarations for the Parameter handles used below. In the full
// MindSpore source tree these come from elsewhere (ir/anf.h); they are
// assumed here so the header reads self-contained.
class Parameter;
using ParameterPtr = std::shared_ptr<Parameter>;
using ParameterWeakPtr = std::weak_ptr<Parameter>;

class ParamInfo {
 public:
  ParamInfo() {}

  ParamInfo(const ParamInfo &other) = default;
  ParamInfo &operator=(const ParamInfo &other) = default;

  virtual ~ParamInfo() = default;

  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }

  bool requires_grad() const { return requires_grad_; }
  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }

  bool init_in_server() const { return init_in_server_; }
  void set_init_in_server(bool init_in_server) { init_in_server_ = init_in_server; }

  // Get the unique key of the parameter.
  int32_t key() const { return key_; }
  // Set the unique key of the parameter.
  void set_key(int32_t key) { key_ = key; }

  bool layerwise_parallel() const { return layerwise_parallel_; }
  void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }

  // Whether this parameter is a clone of another parameter.
  bool cloned() const { return cloned_; }

  // Whether this parameter has been cloned by another parameter.
  bool be_cloned() const { return be_cloned_; }

  // If this parameter has been cloned, one index is recorded per clone.
  const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; }

  // If this parameter is a clone of another parameter, it carries a unique index.
  int32_t cloned_index() const { return cloned_index_; }

  // Make a cloned ParamInfo and update the clone info on both sides.
  ParamInfoPtr Clone() {
    // Process-wide counter that hands out a unique index to every clone.
    static std::atomic<int32_t> parameter_cloned_index{1};
    int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
    auto clone = std::make_shared<ParamInfo>(*this);
    // The copy is marked as a clone, not as a clone source.
    clone->be_cloned_ = false;
    clone->cloned_ = true;
    clone->be_cloned_index_ = {};
    clone->cloned_index_ = index;
    // The source records that it has been cloned and under which index.
    this->be_cloned_ = true;
    this->be_cloned_index_.push_back(index);
    clone->init_in_server_ = this->init_in_server_;
    clone->requires_aggr_ = this->requires_aggr_;
    clone->strategy_ckpt_saved_ = this->strategy_ckpt_saved_;
    clone->param_strategy_ = this->param_strategy_;
    clone->storage_format_ = this->storage_format_;
    // Drop the clone's back-reference to the original Parameter node.
    clone->ClearParameter();
    return clone;
  }

  int32_t comm_fusion() const { return fusion_type_; }
  void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }

  bool parallel_optimizer() const { return parallel_optimizer_; }
  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }

  bool parallel_optimizer_comm_recompute() const { return parallel_optimizer_comm_recompute_; }
  void set_parallel_optimizer_comm_recompute(bool parallel_optimizer_comm_recompute) {
    parallel_optimizer_comm_recompute_ = parallel_optimizer_comm_recompute;
  }

  const std::vector<int64_t> &parameter_shape() const { return parameter_shape_; }
  void set_parameter_shape(const std::vector<int64_t> &tensor_shape) { parameter_shape_ = tensor_shape; }

  void set_strategy_ckpt_saved(bool strategy_ckpt_saved) { strategy_ckpt_saved_ = strategy_ckpt_saved; }
  bool strategy_ckpt_saved() const { return strategy_ckpt_saved_; }

  bool use_persistent_storage() const { return use_persistent_storage_; }
  void set_use_persistent_storage(bool use_persistent_storage) { use_persistent_storage_ = use_persistent_storage; }

  const std::vector<int64_t> &origin_shape() const { return origin_shape_; }
  void set_origin_shape(const std::vector<int64_t> &origin_shape) { origin_shape_ = origin_shape; }

  bool cache_enable() const { return cache_enable_; }
  void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }

  const std::vector<int64_t> &param_strategy() const { return param_strategy_; }
  void set_param_strategy(const std::vector<int64_t> &param_strategy) { param_strategy_ = param_strategy; }

  std::vector<int64_t> cache_shape() const { return cache_shape_; }
  void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }
  ParameterPtr parameter() const { return parameter_.lock(); }
  void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
  void ClearParameter() { parameter_.reset(); }

  bool requires_aggr() const { return requires_aggr_; }
  void set_requires_aggr(bool requires_aggr) { requires_aggr_ = requires_aggr; }

  bool is_quant_int4() const { return is_quant_int4_; }
  void set_is_quant_int4(bool is_quant_int4) { is_quant_int4_ = is_quant_int4; }

  std::vector<int64_t> quant_shape() const { return quant_shape_; }
  void set_quant_shape(const std::vector<int64_t> &quant_shape) { quant_shape_ = quant_shape; }

  bool ignore_device_addr() const { return ignore_device_addr_; }
  void set_ignore_device_addr(bool ignore) { ignore_device_addr_ = ignore; }

  std::string storage_format() const { return storage_format_; }
  void set_storage_format(const std::string &storage_format) { storage_format_ = storage_format; }

 private:
  std::string name_{"Parameter"};
  bool requires_grad_{true};
  bool init_in_server_{false};
  bool layerwise_parallel_{false};
  bool be_cloned_{false};
  bool strategy_ckpt_saved_{false};
  bool cloned_{false};
  std::vector<int32_t> be_cloned_index_;
  int32_t cloned_index_{0};
  int32_t fusion_type_{1};
  bool parallel_optimizer_{true};
  bool parallel_optimizer_comm_recompute_{false};
  bool cache_enable_{false};
  std::vector<int64_t> cache_shape_;
  ParameterWeakPtr parameter_;
  bool requires_aggr_{true};
  std::vector<int64_t> parameter_shape_;
  std::string storage_format_{""};

  // Records the original shape before a huge parameter is cut into a smaller one.
  std::vector<int64_t> origin_shape_;
  // Indicates whether the persistent storage capability is enabled; generally used for very large parameters.
  bool use_persistent_storage_{false};

  // Used to identify the same Parameter on Worker and Server in the embedding cache scenario.
  int32_t key_{-1};
  // Used to indicate the parameter strategy; only takes effect in cell shard.
  std::vector<int64_t> param_strategy_;

  // Used to identify parameters of the quant int4 type.
  bool is_quant_int4_{false};
  std::vector<int64_t> quant_shape_;
  // Used to ignore unused parameters.
  bool ignore_device_addr_{false};
};
}  // namespace mindspore
#endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_
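
A minimal usage sketch (not part of the header above) illustrating the clone bookkeeping performed by ParamInfo::Clone(); it relies only on the public API declared in this file, and the function name CloneBookkeepingExample is hypothetical:

#include <cassert>

#include "ir/param_info.h"

void CloneBookkeepingExample() {
  auto original = std::make_shared<mindspore::ParamInfo>();
  original->set_name("weight");

  // Each Clone() call marks the source as "be_cloned" and hands the clone a
  // unique, process-wide index drawn from the atomic counter inside Clone().
  auto first = original->Clone();
  auto second = original->Clone();

  assert(original->be_cloned());                             // the source was cloned
  assert(original->be_cloned_index().size() == 2);           // one index per clone
  assert(first->cloned() && second->cloned());               // both copies are clones
  assert(first->cloned_index() != second->cloned_index());   // indices are unique
}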