• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h"
17 
18 #include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h"
19 #include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h"
20 #include "tensorflow/core/lib/gtl/cleanup.h"
21 #include "tensorflow/core/platform/logging.h"
22 
23 namespace tensorflow {
24 
IgniteDatasetIterator(const Params & params,string host,int32 port,string cache_name,bool local,int32 part,int32 page_size,string username,string password,string certfile,string keyfile,string cert_password,std::vector<int32> schema,std::vector<int32> permutation)25 IgniteDatasetIterator::IgniteDatasetIterator(
26     const Params& params, string host, int32 port, string cache_name,
27     bool local, int32 part, int32 page_size, string username, string password,
28     string certfile, string keyfile, string cert_password,
29     std::vector<int32> schema, std::vector<int32> permutation)
30     : DatasetIterator<IgniteDataset>(params),
31       cache_name_(std::move(cache_name)),
32       local_(local),
33       part_(part),
34       page_size_(page_size),
35       username_(std::move(username)),
36       password_(std::move(password)),
37       schema_(std::move(schema)),
38       permutation_(std::move(permutation)),
39       remainder_(-1),
40       cursor_id_(-1),
41       last_page_(false),
42       valid_state_(true) {
43   Client* p_client = new PlainClient(std::move(host), port, false);
44 
45   if (certfile.empty())
46     client_ = std::unique_ptr<Client>(p_client);
47   else
48     client_ = std::unique_ptr<Client>(
49         new SslWrapper(std::unique_ptr<Client>(p_client), std::move(certfile),
50                        std::move(keyfile), std::move(cert_password), false));
51 
52   LOG(INFO) << "Ignite Dataset Iterator created";
53 }
54 
~IgniteDatasetIterator()55 IgniteDatasetIterator::~IgniteDatasetIterator() {
56   Status status = CloseConnection();
57   if (!status.ok()) LOG(ERROR) << status.ToString();
58 
59   LOG(INFO) << "Ignite Dataset Iterator destroyed";
60 }
61 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)62 Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx,
63                                               std::vector<Tensor>* out_tensors,
64                                               bool* end_of_sequence) {
65   mutex_lock l(mutex_);
66 
67   if (valid_state_) {
68     Status status =
69         GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence);
70 
71     if (!status.ok()) valid_state_ = false;
72 
73     return status;
74   }
75 
76   return errors::Unknown("Iterator is invalid");
77 }
78 
SaveInternal(IteratorStateWriter * writer)79 Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) {
80   return errors::Unimplemented(
81       "Iterator for IgniteDataset does not support 'SaveInternal'");
82 }
83 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)84 Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx,
85                                               IteratorStateReader* reader) {
86   return errors::Unimplemented(
87       "Iterator for IgniteDataset does not support 'RestoreInternal')");
88 }
89 
GetNextInternalWithValidState(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)90 Status IgniteDatasetIterator::GetNextInternalWithValidState(
91     IteratorContext* ctx, std::vector<Tensor>* out_tensors,
92     bool* end_of_sequence) {
93   if (remainder_ == 0 && last_page_) {
94     cursor_id_ = -1;
95     *end_of_sequence = true;
96 
97     return Status::OK();
98   } else {
99     TF_RETURN_IF_ERROR(EstablishConnection());
100 
101     if (remainder_ == -1) {
102       TF_RETURN_IF_ERROR(ScanQuery());
103     } else if (remainder_ == 0) {
104       TF_RETURN_IF_ERROR(LoadNextPage());
105     }
106 
107     uint8_t* initial_ptr = ptr_;
108     std::vector<Tensor> tensors;
109     std::vector<int32_t> types;
110 
111     TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types));  // Parse key
112     TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types));  // Parse val
113 
114     remainder_ -= (ptr_ - initial_ptr);
115 
116     TF_RETURN_IF_ERROR(CheckTypes(types));
117 
118     for (size_t i = 0; i < tensors.size(); i++)
119       out_tensors->push_back(tensors[permutation_[i]]);
120 
121     *end_of_sequence = false;
122 
123     return Status::OK();
124   }
125 
126   *end_of_sequence = true;
127 
128   return Status::OK();
129 }
130 
EstablishConnection()131 Status IgniteDatasetIterator::EstablishConnection() {
132   if (!client_->IsConnected()) {
133     TF_RETURN_IF_ERROR(client_->Connect());
134 
135     Status status = Handshake();
136     if (!status.ok()) {
137       Status disconnect_status = client_->Disconnect();
138       if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
139 
140       return status;
141     }
142   }
143 
144   return Status::OK();
145 }
146 
CloseConnection()147 Status IgniteDatasetIterator::CloseConnection() {
148   if (cursor_id_ != -1 && !last_page_) {
149     TF_RETURN_IF_ERROR(EstablishConnection());
150 
151     TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength));
152     TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode));
153     TF_RETURN_IF_ERROR(client_->WriteLong(0));           // Request ID
154     TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_));  // Resource ID
155 
156     int32_t res_len;
157     TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
158     if (res_len < kMinResLength)
159       return errors::Unknown("Close Resource Response is corrupted");
160 
161     int64_t req_id;
162     TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
163     int32_t status;
164     TF_RETURN_IF_ERROR(client_->ReadInt(&status));
165     if (status != 0) {
166       uint8_t err_msg_header;
167       TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
168       if (err_msg_header == kStringVal) {
169         int32_t err_msg_length;
170         TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
171 
172         uint8_t* err_msg_c = new uint8_t[err_msg_length];
173         auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
174         TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
175         string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
176 
177         return errors::Unknown("Close Resource Error [status=", status,
178                                ", message=", err_msg, "]");
179       }
180       return errors::Unknown("Close Resource Error [status=", status, "]");
181     }
182 
183     cursor_id_ = -1;
184 
185     return client_->Disconnect();
186   } else {
187     LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed";
188   }
189 
190   return client_->IsConnected() ? client_->Disconnect() : Status::OK();
191 }
192 
Handshake()193 Status IgniteDatasetIterator::Handshake() {
194   int32_t msg_len = kHandshakeReqDefaultLength;
195 
196   if (username_.empty())
197     msg_len += 1;
198   else
199     msg_len += 5 + username_.length();  // 1 byte header, 4 bytes length.
200 
201   if (password_.empty())
202     msg_len += 1;
203   else
204     msg_len += 5 + password_.length();  // 1 byte header, 4 bytes length.
205 
206   TF_RETURN_IF_ERROR(client_->WriteInt(msg_len));
207   TF_RETURN_IF_ERROR(client_->WriteByte(1));
208   TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion));
209   TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion));
210   TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion));
211   TF_RETURN_IF_ERROR(client_->WriteByte(2));
212   if (username_.empty()) {
213     TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
214   } else {
215     TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
216     TF_RETURN_IF_ERROR(client_->WriteInt(username_.length()));
217     TF_RETURN_IF_ERROR(
218         client_->WriteData(reinterpret_cast<const uint8_t*>(username_.c_str()),
219                            username_.length()));
220   }
221 
222   if (password_.empty()) {
223     TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
224   } else {
225     TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
226     TF_RETURN_IF_ERROR(client_->WriteInt(password_.length()));
227     TF_RETURN_IF_ERROR(
228         client_->WriteData(reinterpret_cast<const uint8_t*>(password_.c_str()),
229                            password_.length()));
230   }
231 
232   int32_t handshake_res_len;
233   TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len));
234   uint8_t handshake_res;
235   TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res));
236 
237   if (handshake_res != 1) {
238     int16_t serv_ver_major;
239     TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major));
240     int16_t serv_ver_minor;
241     TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor));
242     int16_t serv_ver_patch;
243     TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch));
244     uint8_t header;
245     TF_RETURN_IF_ERROR(client_->ReadByte(&header));
246 
247     if (header == kStringVal) {
248       int32_t length;
249       TF_RETURN_IF_ERROR(client_->ReadInt(&length));
250 
251       uint8_t* err_msg_c = new uint8_t[length];
252       auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
253       TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length));
254       string err_msg(reinterpret_cast<char*>(err_msg_c), length);
255 
256       return errors::Unknown("Handshake Error [result=", handshake_res,
257                              ", version=", serv_ver_major, ".", serv_ver_minor,
258                              ".", serv_ver_patch, ", message='", err_msg, "']");
259     } else if (header == kNullVal) {
260       return errors::Unknown("Handshake Error [result=", handshake_res,
261                              ", version=", serv_ver_major, ".", serv_ver_minor,
262                              ".", serv_ver_patch, "]");
263     } else {
264       return errors::Unknown("Handshake Error [result=", handshake_res,
265                              ", version=", serv_ver_major, ".", serv_ver_minor,
266                              ".", serv_ver_patch, "]");
267     }
268   }
269 
270   return Status::OK();
271 }
272 
ScanQuery()273 Status IgniteDatasetIterator::ScanQuery() {
274   TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength));
275   TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode));
276   TF_RETURN_IF_ERROR(client_->WriteLong(0));  // Request ID
277   TF_RETURN_IF_ERROR(
278       client_->WriteInt(JavaHashCode(cache_name_)));  // Cache name
279   TF_RETURN_IF_ERROR(client_->WriteByte(0));          // Flags
280   TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));   // Filter object
281   TF_RETURN_IF_ERROR(client_->WriteInt(page_size_));  // Cursor page size
282   TF_RETURN_IF_ERROR(client_->WriteInt(part_));       // part_ition to query
283   TF_RETURN_IF_ERROR(client_->WriteByte(local_));     // local_ flag
284 
285   uint64 wait_start = Env::Default()->NowMicros();
286   int32_t res_len;
287   TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
288   int64_t wait_stop = Env::Default()->NowMicros();
289 
290   LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms";
291 
292   if (res_len < kMinResLength)
293     return errors::Unknown("Scan Query Response is corrupted");
294 
295   int64_t req_id;
296   TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
297 
298   int32_t status;
299   TF_RETURN_IF_ERROR(client_->ReadInt(&status));
300 
301   if (status != 0) {
302     uint8_t err_msg_header;
303     TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
304 
305     if (err_msg_header == kStringVal) {
306       int32_t err_msg_length;
307       TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
308 
309       uint8_t* err_msg_c = new uint8_t[err_msg_length];
310       auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
311       TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
312       string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
313 
314       return errors::Unknown("Scan Query Error [status=", status,
315                              ", message=", err_msg, "]");
316     }
317     return errors::Unknown("Scan Query Error [status=", status, "]");
318   }
319 
320   TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_));
321 
322   int32_t row_cnt;
323   TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
324 
325   int32_t page_size = res_len - kScanQueryResHeaderLength;
326 
327   return ReceivePage(page_size);
328 }
329 
LoadNextPage()330 Status IgniteDatasetIterator::LoadNextPage() {
331   TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength));
332   TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode));
333   TF_RETURN_IF_ERROR(client_->WriteLong(0));           // Request ID
334   TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_));  // Cursor ID
335 
336   uint64 wait_start = Env::Default()->NowMicros();
337   int32_t res_len;
338   TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
339   uint64 wait_stop = Env::Default()->NowMicros();
340 
341   LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000
342             << " ms";
343 
344   if (res_len < kMinResLength)
345     return errors::Unknown("Load Next Page Response is corrupted");
346 
347   int64_t req_id;
348   TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
349 
350   int32_t status;
351   TF_RETURN_IF_ERROR(client_->ReadInt(&status));
352 
353   if (status != 0) {
354     uint8_t err_msg_header;
355     TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
356 
357     if (err_msg_header == kStringVal) {
358       int32_t err_msg_length;
359       TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
360 
361       uint8_t* err_msg_c = new uint8_t[err_msg_length];
362       auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
363       TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
364       string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
365 
366       return errors::Unknown("Load Next Page Error [status=", status,
367                              ", message=", err_msg, "]");
368     }
369     return errors::Unknown("Load Next Page Error [status=", status, "]");
370   }
371 
372   int32_t row_cnt;
373   TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
374 
375   int32_t page_size = res_len - kLoadNextPageResHeaderLength;
376 
377   return ReceivePage(page_size);
378 }
379 
ReceivePage(int32_t page_size)380 Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
381   remainder_ = page_size;
382   page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
383   ptr_ = page_.get();
384 
385   uint64 start = Env::Default()->NowMicros();
386   TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_));
387   uint64 stop = Env::Default()->NowMicros();
388 
389   double size_in_mb = 1.0 * remainder_ / 1024 / 1024;
390   double time_in_s = 1.0 * (stop - start) / 1000 / 1000;
391   LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
392             << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
393 
394   uint8_t last_page_b;
395   TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b));
396 
397   last_page_ = !last_page_b;
398 
399   return Status::OK();
400 }
401 
CheckTypes(const std::vector<int32_t> & types)402 Status IgniteDatasetIterator::CheckTypes(const std::vector<int32_t>& types) {
403   if (schema_.size() != types.size())
404     return errors::Unknown("Object has unexpected schema");
405 
406   for (size_t i = 0; i < schema_.size(); i++) {
407     if (schema_[i] != types[permutation_[i]])
408       return errors::Unknown("Object has unexpected schema");
409   }
410 
411   return Status::OK();
412 }
413 
JavaHashCode(string str) const414 int32_t IgniteDatasetIterator::JavaHashCode(string str) const {
415   int32_t h = 0;
416   for (char& c : str) {
417     h = 31 * h + c;
418   }
419   return h;
420 }
421 
422 }  // namespace tensorflow
423