1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_PARAM_INFO_H_ 18 #define MINDSPORE_CORE_IR_PARAM_INFO_H_ 19 20 #include <atomic> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "ir/dtype.h" 26 27 namespace mindspore { 28 class ParamInfo; 29 using ParamInfoPtr = std::shared_ptr<ParamInfo>; 30 31 class ParamInfo { 32 public: ParamInfo()33 ParamInfo() {} 34 35 ParamInfo(const ParamInfo &other) = default; 36 37 virtual ~ParamInfo() = default; 38 name()39 const std::string &name() const { return name_; } set_name(const std::string & name)40 void set_name(const std::string &name) { name_ = name; } 41 requires_grad()42 bool requires_grad() const { return requires_grad_; } set_requires_grad(bool requires_grad)43 void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } 44 init_in_server()45 bool init_in_server() const { return init_in_server_; } set_init_in_server(bool init_in_server)46 void set_init_in_server(bool init_in_server) { init_in_server_ = init_in_server; } 47 layerwise_parallel()48 bool layerwise_parallel() const { return layerwise_parallel_; } set_layerwise_parallel(bool layerwise_parallel)49 void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; } 50 51 // Whether the parameter clone from other parameter. cloned()52 bool cloned() const { return cloned_; } 53 54 // Whether the parameter is cloned. be_cloned()55 bool be_cloned() const { return be_cloned_; } 56 57 // If the parameter is cloned, generate one index per clone. be_cloned_index()58 const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; } 59 60 // If the parameter clone from other parameter, it has a unique index. cloned_index()61 int32_t cloned_index() const { return cloned_index_; } 62 63 // Make a cloned parameter and update clone info. Clone()64 ParamInfoPtr Clone() { 65 static std::atomic<int32_t> parameter_cloned_index{1}; 66 int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed); 67 auto clone = std::make_shared<ParamInfo>(*this); 68 clone->be_cloned_ = false; 69 clone->cloned_ = true; 70 clone->be_cloned_index_ = {}; 71 clone->cloned_index_ = index; 72 this->be_cloned_ = true; 73 this->be_cloned_index_.push_back(index); 74 clone->init_in_server_ = this->init_in_server_; 75 clone->requires_aggr_ = this->requires_aggr_; 76 clone->ClearParameter(); 77 return clone; 78 } 79 comm_fusion()80 int32_t comm_fusion() const { return fusion_type_; } set_comm_fusion(int32_t fusion_type)81 void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; } 82 parallel_optimizer()83 bool parallel_optimizer() const { return parallel_optimizer_; } set_parallel_optimizer(bool parallel_optimizer)84 void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; } 85 parallel_optimizer_comm_recompute()86 bool parallel_optimizer_comm_recompute() const { return parallel_optimizer_comm_recompute_; } set_parallel_optimizer_comm_recompute(bool parallel_optimizer_comm_recompute)87 void set_parallel_optimizer_comm_recompute(bool parallel_optimizer_comm_recompute) { 88 parallel_optimizer_comm_recompute_ = parallel_optimizer_comm_recompute; 89 } 90 cache_enable()91 bool cache_enable() const { return cache_enable_; } set_cache_enable(bool cache_enable)92 void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } 93 cache_shape()94 std::vector<int64_t> cache_shape() const { return cache_shape_; } set_cache_shape(const std::vector<int64_t> & cache_shape)95 void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; } parameter()96 ParameterPtr parameter() { return parameter_; } set_parameter(const ParameterPtr & parameter)97 void set_parameter(const ParameterPtr ¶meter) { parameter_ = parameter; } ClearParameter()98 void ClearParameter() { parameter_ = nullptr; } 99 requires_aggr()100 bool requires_aggr() const { return requires_aggr_; } set_requires_aggr(bool requires_aggr)101 void set_requires_aggr(bool requires_aggr) { requires_aggr_ = requires_aggr; } 102 103 private: 104 std::string name_{"Parameter"}; 105 bool requires_grad_{true}; 106 bool init_in_server_{false}; 107 bool layerwise_parallel_{false}; 108 bool be_cloned_{false}; 109 bool cloned_{false}; 110 std::vector<int32_t> be_cloned_index_; 111 int32_t cloned_index_{0}; 112 int32_t fusion_type_{1}; 113 bool parallel_optimizer_{true}; 114 bool parallel_optimizer_comm_recompute_{false}; 115 bool cache_enable_{false}; 116 std::vector<int64_t> cache_shape_; 117 ParameterPtr parameter_{nullptr}; 118 bool requires_aggr_{true}; 119 }; 120 } // namespace mindspore 121 #endif // MINDSPORE_CORE_IR_PARAM_INFO_H_ 122