1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
18 #include "utils/log_adapter.h"
19
20 namespace mindspore {
21 namespace ps {
// Number of bounded waits PrefetchData/FinalizeData perform before reporting a timeout.
constexpr size_t kTimeoutLoopCount = 80;
// Length of one bounded wait, in seconds (passed to std::chrono::seconds). Together with
// kTimeoutLoopCount this gives a worst-case wait of 80 * 30 s = 40 minutes.
constexpr int64_t kLongestTimeToWait = 30;
24
GetInstance()25 PsDataPrefetch &PsDataPrefetch::GetInstance() {
26 static PsDataPrefetch instance;
27 return instance;
28 }
29
CreateDataChannel(const std::string & channel_name,size_t step_num)30 void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
31 if (cache_enable_ == false) {
32 return;
33 }
34 MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ").";
35 auto iter = ps_data_channel_map_.find(channel_name);
36 if (iter != ps_data_channel_map_.end()) {
37 MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
38 auto channel = iter->second;
39 MS_ERROR_IF_NULL_WO_RET_VAL(channel);
40 channel->set_step_num(step_num);
41 } else {
42 auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
43 (void)ps_data_channel_map_.emplace(channel_name, channel);
44 }
45 }
46
ps_data_channel(const std::string & channel_name) const47 std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
48 auto iter = ps_data_channel_map_.find(channel_name);
49 if (iter == ps_data_channel_map_.end()) {
50 MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name;
51 return nullptr;
52 }
53 return iter->second;
54 }
55
PrefetchData(const std::string & channel_name,void * data,const size_t data_size,const std::string & data_type)56 bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
57 const std::string &data_type) {
58 if (cache_enable_ == false) {
59 return true;
60 }
61 // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
62 const std::string supported_data_type = "int32";
63 if (data_type != supported_data_type) {
64 MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
65 invalid_data_type_ = true;
66 return false;
67 }
68 if (data == nullptr) {
69 MS_LOG(WARNING) << "No data prefetch.";
70 return true;
71 }
72
73 if (!need_wait_) {
74 return true;
75 }
76
77 auto channel = ps_data_channel(channel_name);
78 MS_ERROR_IF_NULL(channel);
79 channel->set_data(data, data_size);
80 std::unique_lock<std::mutex> locker(data_mutex_);
81 data_ready_ = true;
82 data_process_.notify_one();
83
84 for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
85 if (data_prefetch_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
86 [this] { return data_ready_ == false || need_wait_ == false; })) {
87 return true;
88 } else {
89 MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / "
90 << kTimeoutLoopCount << ")";
91 }
92 }
93 MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size.";
94 return false;
95 }
96
FinalizeData(const std::string & channel_name)97 bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
98 if (cache_enable_ == false) {
99 return true;
100 }
101 auto channel = ps_data_channel(channel_name);
102 MS_ERROR_IF_NULL(channel);
103 channel->ResetData();
104 std::unique_lock<std::mutex> locker(data_mutex_);
105 data_ready_ = false;
106 data_prefetch_.notify_one();
107 if (!need_wait_) {
108 return true;
109 }
110
111 for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
112 if (data_process_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
113 [this] { return data_ready_ == true || need_wait_ == false; })) {
114 return true;
115 } else {
116 MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / "
117 << kTimeoutLoopCount << ")";
118 }
119 }
120 MS_LOG(ERROR) << "Ps cache data prefetch timeout.";
121 return false;
122 }
123
QueryData(const std::string & channel_name,void ** data_ptr) const124 bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
125 if (invalid_data_type_) {
126 return false;
127 }
128 if (data_ptr == nullptr) {
129 return false;
130 }
131 auto channel = ps_data_channel(channel_name);
132 if (channel == nullptr) {
133 *data_ptr = nullptr;
134 return true;
135 }
136 *data_ptr = const_cast<void *>(channel->data());
137 return true;
138 }
139
data_size(const std::string & channel_name) const140 size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
141 auto channel = ps_data_channel(channel_name);
142 if (channel == nullptr) {
143 return 0;
144 }
145 return channel->data_size();
146 }
147
NotifyFinalize()148 void PsDataPrefetch::NotifyFinalize() {
149 std::lock_guard<std::mutex> lock(finalize_mutex_);
150 if (!need_wait_) {
151 return;
152 }
153
154 need_wait_ = false;
155 WakeAllChannel();
156 data_prefetch_.notify_one();
157 data_process_.notify_one();
158 }
159
TryWakeChannel(const std::string & channel_name) const160 bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) const {
161 auto channel = ps_data_channel(channel_name);
162 if (channel == nullptr) {
163 return false;
164 }
165 channel->TryWakeChannel();
166 return true;
167 }
168
WakeAllChannel() const169 void PsDataPrefetch::WakeAllChannel() const {
170 for (auto iter = ps_data_channel_map_.begin(); iter != ps_data_channel_map_.end(); ++iter) {
171 auto channel = iter->second;
172 if (channel == nullptr) {
173 return;
174 }
175 channel->TryWakeChannel(true);
176 }
177 }
178 } // namespace ps
179 } // namespace mindspore
180