1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "minddata/dataset/core/config_manager.h"

#include <cerrno>
#include <cstdlib>
#include <fstream>
#include <limits>

#include "include/dataset/constants.h"
#include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/status.h"
#include "nlohmann/json.hpp"
#include "util/path.h"
#include "utils/ms_utils.h"
26
27 namespace mindspore {
28 namespace dataset {
ConfigManager()29 ConfigManager::ConfigManager()
30 : num_parallel_workers_(kCfgParallelWorkers),
31 worker_connector_size_(kCfgWorkerConnectorSize),
32 op_connector_size_(kCfgOpConnectorSize),
33 sending_batches_(kCfgSendingBatch),
34 rank_id_(kCfgDefaultRankId),
35 seed_(kCfgDefaultSeed),
36 monitor_sampling_interval_(kCfgMonitorSamplingInterval),
37 callback_timout_(kCfgCallbackTimeout),
38 cache_host_(kCfgDefaultCacheHost),
39 cache_port_(kCfgDefaultCachePort),
40 num_connections_(kDftNumConnections),
41 numa_enable_(false),
42 cache_prefetch_size_(kDftCachePrefetchSize),
43 auto_num_workers_(kDftAutoNumWorkers),
44 num_cpu_threads_(std::thread::hardware_concurrency()),
45 auto_num_workers_num_shards_(1),
46 auto_worker_config_(0),
47 enable_shared_mem_(true),
48 auto_offload_(false),
49 enable_autotune_(false),
50 save_autoconfig_(false),
51 autotune_interval_(kCfgAutoTuneInterval),
52 enable_watchdog_(true),
53 multiprocessing_timeout_interval_(kCfgMultiprocessingTimeoutInterval) {
54 autotune_json_filepath_ = kEmptyString;
55 num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<uint16_t>::max();
56 num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;
57 std::string env_cache_host = common::GetEnv("MS_CACHE_HOST");
58 std::string env_cache_port = common::GetEnv("MS_CACHE_PORT");
59 if (!env_cache_host.empty()) {
60 cache_host_ = env_cache_host;
61 }
62 if (!env_cache_port.empty()) {
63 char *end = nullptr;
64 cache_port_ = static_cast<int32_t>(strtol(env_cache_port.c_str(), &end, kDecimal));
65 if (*end != '\0') {
66 MS_LOG(WARNING) << "Cache port from env variable MS_CACHE_PORT is invalid\n";
67 cache_port_ = 0; // cause the port range validation to generate an error during the validation checks
68 }
69 }
70 }
71
72 // A print method typically used for debugging
Print(std::ostream & out) const73 void ConfigManager::Print(std::ostream &out) const {
74 // Don't show the test/internal ones. Only display the main ones here.
75 // fyi, boolalpha tells the output stream to write "true" and "false" for bools
76 out << "\nClient config settings :"
77 << "\nParallelOp workers : " << num_parallel_workers_
78 << "\nParallelOp worker connector size : " << worker_connector_size_
79 << "\nSize of each Connector : " << op_connector_size_ << std::endl;
80 }
81
// Private helper that populates the settings from a parsed nlohmann::json object; a key missing from the JSON keeps the setting's current value.
FromJson(const nlohmann::json & j)83 Status ConfigManager::FromJson(const nlohmann::json &j) {
84 RETURN_IF_NOT_OK(set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_)));
85 set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_));
86 set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
87 set_seed(j.value("seed", seed_));
88 set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
89 set_fast_recovery(j.value("fast_recovery", fast_recovery_));
90 set_error_samples_mode(j.value("error_samples_mode", error_samples_mode_));
91 set_cache_host(j.value("cacheHost", cache_host_));
92 set_cache_port(j.value("cachePort", cache_port_));
93 set_num_connections(j.value("numConnections", num_connections_));
94 set_cache_prefetch_size(j.value("cachePrefetchSize", cache_prefetch_size_));
95 set_debug_mode(j.value("debug_mode_flag", debug_mode_flag_));
96 return Status::OK();
97 }
98
// Loads settings from a JSON file and applies them via FromJson; keys missing from the file keep their current values.
LoadFile(const std::string & settingsFile)100 Status ConfigManager::LoadFile(const std::string &settingsFile) {
101 Status rc;
102 if (!Path(settingsFile).Exists()) {
103 RETURN_STATUS_UNEXPECTED("Invalid file: settings file:" + settingsFile +
104 " is not exist, check input path of config 'load' API.");
105 }
106 // Some settings are mandatory, others are not (with default). If a setting
107 // is optional it will set a default value if the config is missing from the file.
108 try {
109 std::ifstream in(settingsFile, std::ios::in);
110 nlohmann::json js;
111 in >> js;
112 rc = FromJson(js);
113 in.close();
114 } catch (const nlohmann::json::type_error &e) {
115 std::ostringstream ss;
116 ss << "Client file failed to load:\n" << e.what();
117 std::string err_msg = ss.str();
118 RETURN_STATUS_UNEXPECTED(err_msg);
119 } catch (const std::exception &err) {
120 RETURN_STATUS_UNEXPECTED("Client file failed to load.");
121 }
122 return rc;
123 }
124
125 // Setter function
set_num_parallel_workers(int32_t num_parallel_workers)126 Status ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) {
127 if (num_parallel_workers > num_cpu_threads_ || num_parallel_workers < 1) {
128 std::string err_msg = "Invalid Parameter, num_parallel_workers exceeds the boundary between 1 and " +
129 std::to_string(num_cpu_threads_) + ", as got " + std::to_string(num_parallel_workers) + ".";
130 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
131 }
132 num_parallel_workers_ = num_parallel_workers;
133 return Status::OK();
134 }
135
136 // Setter function
set_worker_connector_size(int32_t connector_size)137 void ConfigManager::set_worker_connector_size(int32_t connector_size) { worker_connector_size_ = connector_size; }
138
139 // Setter function
set_op_connector_size(int32_t connector_size)140 void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector_size_ = connector_size; }
141
set_sending_batches(int64_t sending_batches)142 void ConfigManager::set_sending_batches(int64_t sending_batches) { sending_batches_ = sending_batches; }
143
seed() const144 uint32_t ConfigManager::seed() const { return seed_; }
145
set_rank_id(int32_t rank_id)146 void ConfigManager::set_rank_id(int32_t rank_id) {
147 if (rank_id_ == kCfgDefaultRankId) {
148 rank_id_ = rank_id;
149 }
150 }
151
set_numa_enable(bool numa_enable)152 void ConfigManager::set_numa_enable(bool numa_enable) { numa_enable_ = numa_enable; }
153
set_seed(uint32_t seed)154 void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
155
set_monitor_sampling_interval(uint32_t interval)156 void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
157
set_callback_timeout(uint32_t timeout)158 void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
159
set_cache_host(std::string cache_host)160 void ConfigManager::set_cache_host(std::string cache_host) { cache_host_ = std::move(cache_host); }
161
set_cache_port(int32_t cache_port)162 void ConfigManager::set_cache_port(int32_t cache_port) { cache_port_ = cache_port; }
163
set_num_connections(int32_t num_connections)164 void ConfigManager::set_num_connections(int32_t num_connections) { num_connections_ = num_connections; }
165
set_cache_prefetch_size(int32_t cache_prefetch_size)166 void ConfigManager::set_cache_prefetch_size(int32_t cache_prefetch_size) { cache_prefetch_size_ = cache_prefetch_size; }
167
set_enable_autotune(bool enable,bool save_autoconfig,const std::string & json_filepath)168 Status ConfigManager::set_enable_autotune(bool enable, bool save_autoconfig, const std::string &json_filepath) {
169 enable_autotune_ = enable;
170 save_autoconfig_ = save_autoconfig;
171
172 // Check if not requested to save AutoTune config
173 if (!save_autoconfig_) {
174 // No need for further processing, like process json_filepath input
175 return Status::OK();
176 }
177
178 Path jsonpath(json_filepath);
179
180 if (jsonpath.IsDirectory()) {
181 std::string err_msg = "Invalid json_filepath parameter. <" + json_filepath + "> is a directory, not filename.";
182 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
183 }
184
185 std::string parent_path = jsonpath.ParentPath();
186 if (parent_path != "") {
187 if (!Path(parent_path).Exists()) {
188 std::string err_msg = "Invalid json_filepath parameter. Directory <" + parent_path + "> does not exist.";
189 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
190 }
191 } else {
192 // Set parent_path to current working directory
193 parent_path = ".";
194 }
195
196 std::string real_path;
197 if (Path::RealPath(parent_path, real_path).IsError()) {
198 std::string err_msg = "Invalid json_filepath parameter. Cannot get real json_filepath <" + real_path + ">.";
199 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
200 }
201
202 if (access(real_path.c_str(), W_OK) == -1) {
203 std::string err_msg = "Invalid json_filepath parameter. No access to write to <" + real_path + ">.";
204 LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
205 }
206
207 if (jsonpath.Exists()) {
208 // Note: Allow file to be overwritten (like serialize)
209 std::string err_msg = "Invalid json_filepath parameter. File: <" + json_filepath + "> already exists." +
210 " File will be overwritten with the AutoTuned data pipeline configuration.";
211 MS_LOG(WARNING) << err_msg;
212 }
213
214 // Save the final AutoTune configuration JSON filepath name
215 autotune_json_filepath_ = json_filepath;
216 return Status::OK();
217 }
218
219 } // namespace dataset
220 } // namespace mindspore
221