• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
18 #include "utils/log_adapter.h"
19 
20 namespace mindspore {
21 namespace ps {
// Number of timed waits attempted before a prefetch/finalize handshake is reported as timed out.
constexpr size_t kTimeoutLoopCount = 10;
// Length of each timed wait; it is passed to std::chrono::seconds, so the unit is seconds.
constexpr int64_t kLongestTimeToWait = 30;
24 
CreateDataChannel(const std::string & channel_name,size_t step_num)25 void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
26   if (cache_enable_ == false) {
27     return;
28   }
29   MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ").";
30   auto iter = ps_data_channel_map_.find(channel_name);
31   if (iter != ps_data_channel_map_.end()) {
32     MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
33     auto channel = iter->second;
34     MS_ERROR_IF_NULL_WO_RET_VAL(channel);
35     channel->set_step_num(step_num);
36   } else {
37     auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
38     (void)ps_data_channel_map_.emplace(channel_name, channel);
39   }
40 }
41 
ps_data_channel(const std::string & channel_name) const42 std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
43   auto iter = ps_data_channel_map_.find(channel_name);
44   if (iter == ps_data_channel_map_.end()) {
45     MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name;
46     return nullptr;
47   }
48   return iter->second;
49 }
50 
PrefetchData(const std::string & channel_name,void * data,const size_t data_size,const std::string & data_type)51 bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
52                                   const std::string &data_type) {
53   if (cache_enable_ == false) {
54     return true;
55   }
56   // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
57   const std::string supported_data_type = "int32";
58   if (data_type != supported_data_type) {
59     MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
60     invalid_data_type_ = true;
61     return false;
62   }
63   if (data == nullptr) {
64     MS_LOG(WARNING) << "No data prefetch.";
65     return true;
66   }
67   auto channel = ps_data_channel(channel_name);
68   MS_ERROR_IF_NULL(channel);
69   channel->set_data(data, data_size);
70   std::unique_lock<std::mutex> locker(data_mutex_);
71   data_ready_ = true;
72   data_process_.notify_one();
73   if (!need_wait_) {
74     return true;
75   }
76 
77   for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
78     if (data_prefetch_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
79                                 [this] { return data_ready_ == false || need_wait_ == false; })) {
80       return true;
81     } else {
82       MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)";
83     }
84   }
85   MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size.";
86   return false;
87 }
88 
FinalizeData(const std::string & channel_name)89 bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
90   if (cache_enable_ == false) {
91     return true;
92   }
93   auto channel = ps_data_channel(channel_name);
94   MS_ERROR_IF_NULL(channel);
95   channel->ResetData();
96   std::unique_lock<std::mutex> locker(data_mutex_);
97   data_ready_ = false;
98   data_prefetch_.notify_one();
99   if (!need_wait_) {
100     return true;
101   }
102 
103   for (size_t i = 0; i < kTimeoutLoopCount; ++i) {
104     if (data_process_.wait_for(locker, std::chrono::seconds(kLongestTimeToWait),
105                                [this] { return data_ready_ == true || need_wait_ == false; })) {
106       return true;
107     } else {
108       MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)";
109     }
110   }
111   MS_LOG(ERROR) << "Ps cache data prefetch timeout.";
112   return false;
113 }
114 
QueryData(const std::string & channel_name,void ** data_ptr) const115 bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
116   if (invalid_data_type_) {
117     return false;
118   }
119   if (data_ptr == nullptr) {
120     return false;
121   }
122   auto channel = ps_data_channel(channel_name);
123   if (channel == nullptr) {
124     *data_ptr = nullptr;
125     return true;
126   }
127   *data_ptr = const_cast<void *>(channel->data());
128   return true;
129 }
130 
data_size(const std::string & channel_name) const131 size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
132   auto channel = ps_data_channel(channel_name);
133   if (channel == nullptr) {
134     return 0;
135   }
136   return channel->data_size();
137 }
138 
NotifyFinalize()139 void PsDataPrefetch::NotifyFinalize() {
140   need_wait_ = false;
141   WakeAllChannel();
142   data_prefetch_.notify_one();
143   data_process_.notify_one();
144 }
145 
TryWakeChannel(const std::string & channel_name)146 bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
147   auto channel = ps_data_channel(channel_name);
148   if (channel == nullptr) {
149     return false;
150   }
151   channel->TryWakeChannel();
152   return true;
153 }
154 
WakeAllChannel()155 void PsDataPrefetch::WakeAllChannel() {
156   for (auto iter = ps_data_channel_map_.begin(); iter != ps_data_channel_map_.end(); ++iter) {
157     auto channel = iter->second;
158     if (channel == nullptr) {
159       return;
160     }
161     channel->TryWakeChannel(true);
162   }
163 }
164 }  // namespace ps
165 }  // namespace mindspore
166