• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
18 #include "utils/log_adapter.h"
19 
20 namespace mindspore {
21 namespace ps {
// Upper bound on the number of timed waits a prefetch/finalize handshake performs before it is reported as timed out.
constexpr size_t kTimeoutLoopCount = 80;
// Length of a single timed wait; converted with std::chrono::seconds at the wait sites in this file.
constexpr int64_t kLongestTimeToWait = 30;
24 
GetInstance()25 PsDataPrefetch &PsDataPrefetch::GetInstance() {
26   static PsDataPrefetch instance;
27   return instance;
28 }
29 
CreateDataChannel(const std::string & channel_name,size_t step_num)30 void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
31   if (cache_enable_ == false) {
32     return;
33   }
34   MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ").";
35   auto iter = ps_data_channel_map_.find(channel_name);
36   if (iter != ps_data_channel_map_.end()) {
37     MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
38     auto channel = iter->second;
39     MS_ERROR_IF_NULL_WO_RET_VAL(channel);
40     channel->set_step_num(step_num);
41   } else {
42     auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
43     (void)ps_data_channel_map_.emplace(channel_name, channel);
44   }
45 }
46 
ps_data_channel(const std::string & channel_name) const47 std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
48   auto iter = ps_data_channel_map_.find(channel_name);
49   if (iter == ps_data_channel_map_.end()) {
50     MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name;
51     return nullptr;
52   }
53   return iter->second;
54 }
55 
PrefetchData(const std::string & channel_name,void * data,const size_t data_size,const std::string & data_type)56 bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
57                                   const std::string &data_type) {
58   if (cache_enable_ == false) {
59     return true;
60   }
61   // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
62   const std::string supported_data_type = "int32";
63   if (data_type != supported_data_type) {
64     MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
65     invalid_data_type_ = true;
66     return false;
67   }
68   if (data == nullptr) {
69     MS_LOG(WARNING) << "No data prefetch.";
70     return true;
71   }
72 
73   if (!need_wait_) {
74     return true;
75   }
76 
77   auto channel = ps_data_channel(channel_name);
78   MS_ERROR_IF_NULL(channel);
79   channel->set_data(data, data_size);
80   std::unique_lock<std::mutex> locker(data_mutex_);
81   data_ready_ = true;
82   data_process_.notify_one();
83 
84   for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
85     if (data_prefetch_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
86                                 [this] { return data_ready_ == false || need_wait_ == false; })) {
87       return true;
88     } else {
89       MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / "
90                    << kTimeoutLoopCount << ")";
91     }
92   }
93   MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size.";
94   return false;
95 }
96 
FinalizeData(const std::string & channel_name)97 bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
98   if (cache_enable_ == false) {
99     return true;
100   }
101   auto channel = ps_data_channel(channel_name);
102   MS_ERROR_IF_NULL(channel);
103   channel->ResetData();
104   std::unique_lock<std::mutex> locker(data_mutex_);
105   data_ready_ = false;
106   data_prefetch_.notify_one();
107   if (!need_wait_) {
108     return true;
109   }
110 
111   for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
112     if (data_process_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
113                                [this] { return data_ready_ == true || need_wait_ == false; })) {
114       return true;
115     } else {
116       MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / "
117                    << kTimeoutLoopCount << ")";
118     }
119   }
120   MS_LOG(ERROR) << "Ps cache data prefetch timeout.";
121   return false;
122 }
123 
QueryData(const std::string & channel_name,void ** data_ptr) const124 bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
125   if (invalid_data_type_) {
126     return false;
127   }
128   if (data_ptr == nullptr) {
129     return false;
130   }
131   auto channel = ps_data_channel(channel_name);
132   if (channel == nullptr) {
133     *data_ptr = nullptr;
134     return true;
135   }
136   *data_ptr = const_cast<void *>(channel->data());
137   return true;
138 }
139 
data_size(const std::string & channel_name) const140 size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
141   auto channel = ps_data_channel(channel_name);
142   if (channel == nullptr) {
143     return 0;
144   }
145   return channel->data_size();
146 }
147 
NotifyFinalize()148 void PsDataPrefetch::NotifyFinalize() {
149   std::lock_guard<std::mutex> lock(finalize_mutex_);
150   if (!need_wait_) {
151     return;
152   }
153 
154   need_wait_ = false;
155   WakeAllChannel();
156   data_prefetch_.notify_one();
157   data_process_.notify_one();
158 }
159 
TryWakeChannel(const std::string & channel_name) const160 bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) const {
161   auto channel = ps_data_channel(channel_name);
162   if (channel == nullptr) {
163     return false;
164   }
165   channel->TryWakeChannel();
166   return true;
167 }
168 
WakeAllChannel() const169 void PsDataPrefetch::WakeAllChannel() const {
170   for (auto iter = ps_data_channel_map_.begin(); iter != ps_data_channel_map_.end(); ++iter) {
171     auto channel = iter->second;
172     if (channel == nullptr) {
173       return;
174     }
175     channel->TryWakeChannel(true);
176   }
177 }
178 }  // namespace ps
179 }  // namespace mindspore
180