• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONFIG_MANAGER_H_
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONFIG_MANAGER_H_
19 
#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/overload.h"
#include "include/common/visible.h"
28 
29 namespace mindspore {
30 
// Parallel strategy stored by ConfigManager (see ConfigManager::set_parallel_strategy(); default ONE_DEVICE).
// NOTE(review): the names suggest single-device vs. distributed execution — confirm against the readers.
enum ParallelStrategy { ONE_DEVICE = 0, DISTRIBUTION = 1 };
35 
// Dataset feeding mode stored by ConfigManager (default DS_NORMAL_MODE).
// ConfigManager::iter_num() always reports 1 while the mode is DS_NORMAL_MODE; in DS_SINK_MODE it reports the
// stored iteration count. NOTE(review): DS_SINK_MODE presumably means data-sink execution — confirm with callers.
enum DatasetMode {
  DS_NORMAL_MODE = 0,
  DS_SINK_MODE = 1,
};
37 
38 class DatasetGraphParam {
39  public:
DatasetGraphParam(const std::string & name,int64_t size,int64_t batch_size,const std::vector<int64_t> & ge_types,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int64_t> & input_indexes)40   DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector<int64_t> &ge_types,
41                     const std::vector<std::vector<int64_t>> &shapes, const std::vector<int64_t> &input_indexes)
42       : queue_name_(name),
43         loop_size_(size),
44         batch_size_(batch_size),
45         ge_types_(ge_types),
46         shapes_(shapes),
47         input_indexes_(input_indexes) {}
48 
49   ~DatasetGraphParam() = default;
50 
ToString()51   std::string ToString() const {
52     std::ostringstream buffer;
53     buffer << "DatasetGraphParam: queue_name=" << queue_name_ << " size=" << loop_size_ << " batch_size=" << batch_size_
54            << " ge_types=" << ge_types_ << " shapes=" << shapes_ << " input_indexes=" << input_indexes_;
55     return buffer.str();
56   }
queue_name()57   std::string queue_name() const { return queue_name_; }
loop_size()58   int64_t loop_size() const { return loop_size_; }
batch_size()59   int64_t batch_size() const { return batch_size_; }
ge_types()60   std::vector<int64_t> ge_types() const { return ge_types_; }
shapes()61   std::vector<std::vector<int64_t>> shapes() const { return shapes_; }
input_indexes()62   std::vector<int64_t> input_indexes() const { return input_indexes_; }
63 
64  private:
65   std::string queue_name_;
66   int64_t loop_size_;
67   int64_t batch_size_;
68   std::vector<int64_t> ge_types_;
69   std::vector<std::vector<int64_t>> shapes_;
70   std::vector<int64_t> input_indexes_;
71 };
72 
73 class COMMON_EXPORT ConfigManager {
74  public:
75   ConfigManager(const ConfigManager &) = delete;
76   ConfigManager &operator=(const ConfigManager &) = delete;
77   static ConfigManager &GetInstance() noexcept;
78 
parallel_strategy()79   ParallelStrategy parallel_strategy() const { return parallel_strategy_; }
set_parallel_strategy(ParallelStrategy strategy)80   void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; }
81 
ge_initialize_options()82   const std::map<std::string, std::string> &ge_initialize_options() const { return ge_initialize_options_; }
set_ge_initialize_options(const std::map<std::string,std::string> & options)83   void set_ge_initialize_options(const std::map<std::string, std::string> &options) {
84     ge_initialize_options_ = options;
85   }
86 
dataset_mode()87   DatasetMode dataset_mode() const { return dataset_mode_; }
set_dataset_mode(DatasetMode mode)88   void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
iter_num()89   int64_t iter_num() const {
90     if (dataset_mode_ == DS_NORMAL_MODE) {
91       return 1;
92     }
93     return iter_num_;
94   }
95 
set_iter_num(const std::string & queue_name,const int64_t num)96   void set_iter_num(const std::string &queue_name, const int64_t num) {
97     queue_name_ = queue_name;
98     iter_num_ = num;
99     queue_info_map[queue_name_] = static_cast<int16_t>(num);
100   }
101 
dataset_phase()102   std::string dataset_phase() const { return dataset_phase_; }
set_dataset_phase(const std::string & phase)103   void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; }
104 
dataset_param()105   DatasetGraphParam dataset_param() const { return dataset_param_; }
set_dataset_param(const DatasetGraphParam & param)106   void set_dataset_param(const DatasetGraphParam &param) { dataset_param_ = param; }
107 
108   static void SetDatasetModeConfig(const std::string &mode);
109 
110   void ResetConfig() noexcept;
111 
112   void ResetIterNum() noexcept;
113 
114   void ResetQueue(const std::string &queue_name) noexcept;
QueueName()115   std::string QueueName() const { return queue_name_; }
116   std::map<std::string, std::string> ge_initialize_options_;
117 
gpu_loopsink_size()118   int64_t gpu_loopsink_size() const { return gpu_loopsink_size_; }
119 
set_gpu_loopsink_size(const int64_t size)120   void set_gpu_loopsink_size(const int64_t size) { gpu_loopsink_size_ = size; }
121 
122  private:
123   ConfigManager() = default;
124   ~ConfigManager() = default;
125 
126   ParallelStrategy parallel_strategy_{ONE_DEVICE};
127   DatasetMode dataset_mode_{DS_NORMAL_MODE};
128   DatasetGraphParam dataset_param_{"", 0, 0, {}, {}, {}};
129   int64_t iter_num_{1};
130   std::string queue_name_{""};
131   // now only save iter_num_ in the map
132   std::map<std::string, int16_t> queue_info_map;
133   std::string dataset_phase_{""};
134   int64_t gpu_loopsink_size_{1};
135 };
136 
137 }  // namespace mindspore
138 
139 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONFIG_MANAGER_H_
140