/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_PARAM_INFO_H_
#define MINDSPORE_CORE_IR_PARAM_INFO_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>

#include "ir/dtype.h"

namespace mindspore {
class ParamInfo;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;

// Forward declaration of the Parameter node and its handle aliases used below;
// assumed to match the aliases defined alongside the Parameter node elsewhere in ir/.
class Parameter;
using ParameterPtr = std::shared_ptr<Parameter>;
using ParameterWeakPtr = std::weak_ptr<Parameter>;

class ParamInfo {
 public:
  ParamInfo() {}

  ParamInfo(const ParamInfo &other) = default;
  ParamInfo &operator=(const ParamInfo &other) = default;

  virtual ~ParamInfo() = default;

  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }

  bool requires_grad() const { return requires_grad_; }
  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }

  bool init_in_server() const { return init_in_server_; }
  void set_init_in_server(bool init_in_server) { init_in_server_ = init_in_server; }

  // Get the unique key of the parameter.
  int32_t key() const { return key_; }
  // Set the unique key of the parameter.
  void set_key(int32_t key) { key_ = key; }

  bool layerwise_parallel() const { return layerwise_parallel_; }
  void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }

  // Whether this parameter was cloned from another parameter.
  bool cloned() const { return cloned_; }

  // Whether this parameter has been cloned.
  bool be_cloned() const { return be_cloned_; }

  // If this parameter has been cloned, one index is generated per clone.
  const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; }

  // If this parameter was cloned from another parameter, it has a unique index.
  int32_t cloned_index() const { return cloned_index_; }

  // Make a cloned parameter and update the clone info on both sides.
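  // Each call draws a fresh index from a process-wide atomic counter: the clone records
  // it in cloned_index_, while this instance is marked as be_cloned_ and appends the same
  // index to be_cloned_index_, so original and clone can be matched up later.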
  ParamInfoPtr Clone() {
    static std::atomic<int32_t> parameter_cloned_index{1};
    int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
    auto clone = std::make_shared<ParamInfo>(*this);
    clone->be_cloned_ = false;
    clone->cloned_ = true;
    clone->be_cloned_index_ = {};
    clone->cloned_index_ = index;
    this->be_cloned_ = true;
    this->be_cloned_index_.push_back(index);
    clone->init_in_server_ = this->init_in_server_;
    clone->requires_aggr_ = this->requires_aggr_;
    clone->strategy_ckpt_saved_ = this->strategy_ckpt_saved_;
    clone->param_strategy_ = this->param_strategy_;
    clone->storage_format_ = this->storage_format_;
    clone->ClearParameter();
    return clone;
  }

  int32_t comm_fusion() const { return fusion_type_; }
  void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }

  bool parallel_optimizer() const { return parallel_optimizer_; }
  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }

  bool parallel_optimizer_comm_recompute() const { return parallel_optimizer_comm_recompute_; }
  void set_parallel_optimizer_comm_recompute(bool parallel_optimizer_comm_recompute) {
    parallel_optimizer_comm_recompute_ = parallel_optimizer_comm_recompute;
  }

  const std::vector<int64_t> &parameter_shape() const { return parameter_shape_; }
  void set_parameter_shape(const std::vector<int64_t> &tensor_shape) { parameter_shape_ = tensor_shape; }

  void set_strategy_ckpt_saved(bool strategy_ckpt_saved) { strategy_ckpt_saved_ = strategy_ckpt_saved; }
  bool strategy_ckpt_saved() const { return strategy_ckpt_saved_; }

  bool use_persistent_storage() const { return use_persistent_storage_; }
  void set_use_persistent_storage(bool use_persistent_storage) { use_persistent_storage_ = use_persistent_storage; }

  const std::vector<int64_t> &origin_shape() const { return origin_shape_; }
  void set_origin_shape(const std::vector<int64_t> &origin_shape) { origin_shape_ = origin_shape; }

  bool cache_enable() const { return cache_enable_; }
  void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }

  const std::vector<int64_t> &param_strategy() const { return param_strategy_; }
  void set_param_strategy(const std::vector<int64_t> &param_strategy) { param_strategy_ = param_strategy; }

  std::vector<int64_t> cache_shape() const { return cache_shape_; }
  void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }

  // The backing Parameter node is held through a weak pointer, so these accessors do not
  // extend its lifetime; parameter() returns nullptr once the node has been destroyed.
  ParameterPtr parameter() const { return parameter_.lock(); }
  void set_parameter(const ParameterPtr &parameter) { parameter_ = parameter; }
  void ClearParameter() { parameter_.reset(); }
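
  // Illustrative usage sketch of the clone bookkeeping above (not part of the original
  // header; the names below are hypothetical):
  //   auto info = std::make_shared<ParamInfo>();
  //   info->set_name("fc1.weight");
  //   ParamInfoPtr copy = info->Clone();
  //   // copy->cloned() == true and copy->cloned_index() holds the new unique index;
  //   // info->be_cloned() == true and info->be_cloned_index() now contains that index.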

  bool requires_aggr() const { return requires_aggr_; }
  void set_requires_aggr(bool requires_aggr) { requires_aggr_ = requires_aggr; }

  bool is_quant_int4() const { return is_quant_int4_; }
  void set_is_quant_int4(bool is_quant_int4) { is_quant_int4_ = is_quant_int4; }

  std::vector<int64_t> quant_shape() const { return quant_shape_; }
  void set_quant_shape(const std::vector<int64_t> &quant_shape) { quant_shape_ = quant_shape; }

  bool ignore_device_addr() const { return ignore_device_addr_; }
  void set_ignore_device_addr(bool ignore) { ignore_device_addr_ = ignore; }

  std::string storage_format() const { return storage_format_; }
  void set_storage_format(const std::string &storage_format) { storage_format_ = storage_format; }

 private:
  std::string name_{"Parameter"};
  bool requires_grad_{true};
  bool init_in_server_{false};
  bool layerwise_parallel_{false};
  bool be_cloned_{false};
  bool strategy_ckpt_saved_{false};
  bool cloned_{false};
  std::vector<int32_t> be_cloned_index_;
  int32_t cloned_index_{0};
  int32_t fusion_type_{1};
  bool parallel_optimizer_{true};
  bool parallel_optimizer_comm_recompute_{false};
  bool cache_enable_{false};
  std::vector<int64_t> cache_shape_;
  ParameterWeakPtr parameter_;
  bool requires_aggr_{true};
  std::vector<int64_t> parameter_shape_;
  std::string storage_format_{""};

  // Record the origin shape before a huge parameter is cut into a smaller one.
  std::vector<int64_t> origin_shape_;
  // Whether the persistent storage capability is enabled; generally used in very large
  // parameter scenarios.
  bool use_persistent_storage_{false};

  // Used to identify the same Parameter on Worker and Server in the embedding cache scenario.
  int32_t key_{-1};
  // Used to indicate the parameter strategy; only takes effect in cell shard.
  std::vector<int64_t> param_strategy_;

  // Used to identify parameters of quant int4 type.
  bool is_quant_int4_{false};
  std::vector<int64_t> quant_shape_;
  // Used to ignore an unused parameter's device address.
  bool ignore_device_addr_{false};
};
}  // namespace mindspore
#endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_