1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
18 #include "utils/log_adapter.h"
19
20 namespace mindspore {
21 namespace ps {
// Number of bounded wait slices PrefetchData/FinalizeData perform before reporting a timeout.
constexpr size_t kTimeoutLoopCount = 10;
// Length of a single wait slice, in seconds (passed to std::chrono::seconds).
constexpr int64_t kLongestTimeToWait = 30;
24
CreateDataChannel(const std::string & channel_name,size_t step_num)25 void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
26 if (cache_enable_ == false) {
27 return;
28 }
29 MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ").";
30 auto iter = ps_data_channel_map_.find(channel_name);
31 if (iter != ps_data_channel_map_.end()) {
32 MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
33 auto channel = iter->second;
34 MS_ERROR_IF_NULL_WO_RET_VAL(channel);
35 channel->set_step_num(step_num);
36 } else {
37 auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
38 (void)ps_data_channel_map_.emplace(channel_name, channel);
39 }
40 }
41
ps_data_channel(const std::string & channel_name) const42 std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
43 auto iter = ps_data_channel_map_.find(channel_name);
44 if (iter == ps_data_channel_map_.end()) {
45 MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name;
46 return nullptr;
47 }
48 return iter->second;
49 }
50
PrefetchData(const std::string & channel_name,void * data,const size_t data_size,const std::string & data_type)51 bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
52 const std::string &data_type) {
53 if (cache_enable_ == false) {
54 return true;
55 }
56 // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
57 const std::string supported_data_type = "int32";
58 if (data_type != supported_data_type) {
59 MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
60 invalid_data_type_ = true;
61 return false;
62 }
63 if (data == nullptr) {
64 MS_LOG(WARNING) << "No data prefetch.";
65 return true;
66 }
67 auto channel = ps_data_channel(channel_name);
68 MS_ERROR_IF_NULL(channel);
69 channel->set_data(data, data_size);
70 std::unique_lock<std::mutex> locker(data_mutex_);
71 data_ready_ = true;
72 data_process_.notify_one();
73 if (!need_wait_) {
74 return true;
75 }
76
77 for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
78 if (data_prefetch_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
79 [this] { return data_ready_ == false || need_wait_ == false; })) {
80 return true;
81 } else {
82 MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)";
83 }
84 }
85 MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size.";
86 return false;
87 }
88
FinalizeData(const std::string & channel_name)89 bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
90 if (cache_enable_ == false) {
91 return true;
92 }
93 auto channel = ps_data_channel(channel_name);
94 MS_ERROR_IF_NULL(channel);
95 channel->ResetData();
96 std::unique_lock<std::mutex> locker(data_mutex_);
97 data_ready_ = false;
98 data_prefetch_.notify_one();
99 if (!need_wait_) {
100 return true;
101 }
102
103 for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
104 if (data_process_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
105 [this] { return data_ready_ == true || need_wait_ == false; })) {
106 return true;
107 } else {
108 MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)";
109 }
110 }
111 MS_LOG(ERROR) << "Ps cache data prefetch timeout.";
112 return false;
113 }
114
QueryData(const std::string & channel_name,void ** data_ptr) const115 bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
116 if (invalid_data_type_) {
117 return false;
118 }
119 if (data_ptr == nullptr) {
120 return false;
121 }
122 auto channel = ps_data_channel(channel_name);
123 if (channel == nullptr) {
124 *data_ptr = nullptr;
125 return true;
126 }
127 *data_ptr = const_cast<void *>(channel->data());
128 return true;
129 }
130
data_size(const std::string & channel_name) const131 size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
132 auto channel = ps_data_channel(channel_name);
133 if (channel == nullptr) {
134 return 0;
135 }
136 return channel->data_size();
137 }
138
NotifyFinalize()139 void PsDataPrefetch::NotifyFinalize() {
140 need_wait_ = false;
141 WakeAllChannel();
142 data_prefetch_.notify_one();
143 data_process_.notify_one();
144 }
145
TryWakeChannel(const std::string & channel_name)146 bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
147 auto channel = ps_data_channel(channel_name);
148 if (channel == nullptr) {
149 return false;
150 }
151 channel->TryWakeChannel();
152 return true;
153 }
154
WakeAllChannel()155 void PsDataPrefetch::WakeAllChannel() {
156 for (auto iter = ps_data_channel_map_.begin(); iter != ps_data_channel_map_.end(); ++iter) {
157 auto channel = iter->second;
158 if (channel == nullptr) {
159 return;
160 }
161 channel->TryWakeChannel(true);
162 }
163 }
164 } // namespace ps
165 } // namespace mindspore
166