• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_CONFIG_MANAGER_H_
18 #define MINDSPORE_CCSRC_UTILS_CONFIG_MANAGER_H_
19 
#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/overload.h"
28 
29 namespace mindspore {
30 
// Parallel execution strategy stored by ConfigManager (default: ONE_DEVICE).
// Kept as an unscoped enum so existing unqualified uses of the enumerators keep compiling.
enum ParallelStrategy { ONE_DEVICE = 0, DISTRIBUTION = 1 };
35 
// Dataset feeding mode stored by ConfigManager (default: DS_NORMAL_MODE).
// While the mode is DS_NORMAL_MODE, ConfigManager::iter_num() reports 1 regardless of the stored count.
enum DatasetMode {
  DS_NORMAL_MODE = 0,
  DS_SINK_MODE = 1,
};
37 
38 class DatasetGraphParam {
39  public:
DatasetGraphParam(const std::string & name,int64_t size,int64_t batch_size,const std::vector<int64_t> & ge_types,const std::vector<std::vector<int64_t>> & shapes,const std::vector<int64_t> & input_indexes)40   DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector<int64_t> &ge_types,
41                     const std::vector<std::vector<int64_t>> &shapes, const std::vector<int64_t> &input_indexes)
42       : queue_name_(name),
43         loop_size_(size),
44         batch_size_(batch_size),
45         ge_types_(ge_types),
46         shapes_(shapes),
47         input_indexes_(input_indexes) {}
48 
49   ~DatasetGraphParam() = default;
50 
ToString()51   std::string ToString() const {
52     std::ostringstream buffer;
53     buffer << "DatasetGraphParam: queue_name=" << queue_name_ << " size=" << loop_size_ << " batch_size=" << batch_size_
54            << " ge_types=" << ge_types_ << " shapes=" << shapes_ << " input_indexes=" << input_indexes_;
55     return buffer.str();
56   }
queue_name()57   std::string queue_name() const { return queue_name_; }
loop_size()58   int64_t loop_size() const { return loop_size_; }
batch_size()59   int64_t batch_size() const { return batch_size_; }
ge_types()60   std::vector<int64_t> ge_types() const { return ge_types_; }
shapes()61   std::vector<std::vector<int64_t>> shapes() const { return shapes_; }
input_indexes()62   std::vector<int64_t> input_indexes() const { return input_indexes_; }
63 
64  private:
65   std::string queue_name_;
66   int64_t loop_size_;
67   int64_t batch_size_;
68   std::vector<int64_t> ge_types_;
69   std::vector<std::vector<int64_t>> shapes_;
70   std::vector<int64_t> input_indexes_;
71 };
72 
73 class ConfigManager {
74  public:
75   ConfigManager(const ConfigManager &) = delete;
76   ConfigManager &operator=(const ConfigManager &) = delete;
77   static ConfigManager &GetInstance() noexcept;
78 
parallel_strategy()79   ParallelStrategy parallel_strategy() const { return parallel_strategy_; }
set_parallel_strategy(ParallelStrategy strategy)80   void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; }
81 
ge_initialize_options()82   const std::map<std::string, std::string> &ge_initialize_options() const { return ge_initialize_options_; }
set_ge_initialize_options(const std::map<std::string,std::string> & options)83   void set_ge_initialize_options(const std::map<std::string, std::string> &options) {
84     ge_initialize_options_ = options;
85   }
86 
dataset_mode()87   DatasetMode dataset_mode() const { return dataset_mode_; }
set_dataset_mode(DatasetMode mode)88   void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
iter_num()89   int64_t iter_num() const {
90     if (dataset_mode_ == DS_NORMAL_MODE) return 1;
91     return iter_num_;
92   }
set_iter_num(const int64_t num)93   void set_iter_num(const int64_t num) { iter_num_ = num; }
94 
dataset_phase()95   std::string dataset_phase() const { return dataset_phase_; }
set_dataset_phase(const std::string & phase)96   void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; }
97 
dataset_param()98   DatasetGraphParam dataset_param() const { return dataset_param_; }
set_dataset_param(const DatasetGraphParam & param)99   void set_dataset_param(const DatasetGraphParam &param) { dataset_param_ = param; }
100 
101   static void SetDatasetModeConfig(const std::string &mode);
102 
103   void ResetConfig() noexcept;
104 
105   void ResetIterNum() noexcept;
106 
107   std::map<std::string, std::string> ge_initialize_options_;
108 
gpu_loopsink_size()109   int64_t gpu_loopsink_size() const { return gpu_loopsink_size_; }
110 
set_gpu_loopsink_size(const int64_t size)111   void set_gpu_loopsink_size(const int64_t size) { gpu_loopsink_size_ = size; }
112 
113  private:
114   ConfigManager() = default;
115   ~ConfigManager() = default;
116 
117   ParallelStrategy parallel_strategy_{ONE_DEVICE};
118   DatasetMode dataset_mode_{DS_NORMAL_MODE};
119   DatasetGraphParam dataset_param_{"", 0, 0, {}, {}, {}};
120   int64_t iter_num_{1};
121   std::string dataset_phase_{""};
122   int64_t gpu_loopsink_size_{1};
123 };
124 
125 }  // namespace mindspore
126 
127 #endif  // MINDSPORE_CCSRC_UTILS_CONFIG_MANAGER_H_
128