1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h"
17
18 #include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h"
19 #include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h"
20 #include "tensorflow/core/lib/gtl/cleanup.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24
IgniteDatasetIterator(const Params & params,string host,int32 port,string cache_name,bool local,int32 part,int32 page_size,string username,string password,string certfile,string keyfile,string cert_password,std::vector<int32> schema,std::vector<int32> permutation)25 IgniteDatasetIterator::IgniteDatasetIterator(
26 const Params& params, string host, int32 port, string cache_name,
27 bool local, int32 part, int32 page_size, string username, string password,
28 string certfile, string keyfile, string cert_password,
29 std::vector<int32> schema, std::vector<int32> permutation)
30 : DatasetIterator<IgniteDataset>(params),
31 cache_name_(std::move(cache_name)),
32 local_(local),
33 part_(part),
34 page_size_(page_size),
35 username_(std::move(username)),
36 password_(std::move(password)),
37 schema_(std::move(schema)),
38 permutation_(std::move(permutation)),
39 remainder_(-1),
40 cursor_id_(-1),
41 last_page_(false),
42 valid_state_(true) {
43 Client* p_client = new PlainClient(std::move(host), port, false);
44
45 if (certfile.empty())
46 client_ = std::unique_ptr<Client>(p_client);
47 else
48 client_ = std::unique_ptr<Client>(
49 new SslWrapper(std::unique_ptr<Client>(p_client), std::move(certfile),
50 std::move(keyfile), std::move(cert_password), false));
51
52 LOG(INFO) << "Ignite Dataset Iterator created";
53 }
54
~IgniteDatasetIterator()55 IgniteDatasetIterator::~IgniteDatasetIterator() {
56 Status status = CloseConnection();
57 if (!status.ok()) LOG(ERROR) << status.ToString();
58
59 LOG(INFO) << "Ignite Dataset Iterator destroyed";
60 }
61
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)62 Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx,
63 std::vector<Tensor>* out_tensors,
64 bool* end_of_sequence) {
65 mutex_lock l(mutex_);
66
67 if (valid_state_) {
68 Status status =
69 GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence);
70
71 if (!status.ok()) valid_state_ = false;
72
73 return status;
74 }
75
76 return errors::Unknown("Iterator is invalid");
77 }
78
SaveInternal(IteratorStateWriter * writer)79 Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) {
80 return errors::Unimplemented(
81 "Iterator for IgniteDataset does not support 'SaveInternal'");
82 }
83
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)84 Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx,
85 IteratorStateReader* reader) {
86 return errors::Unimplemented(
87 "Iterator for IgniteDataset does not support 'RestoreInternal')");
88 }
89
GetNextInternalWithValidState(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)90 Status IgniteDatasetIterator::GetNextInternalWithValidState(
91 IteratorContext* ctx, std::vector<Tensor>* out_tensors,
92 bool* end_of_sequence) {
93 if (remainder_ == 0 && last_page_) {
94 cursor_id_ = -1;
95 *end_of_sequence = true;
96
97 return Status::OK();
98 } else {
99 TF_RETURN_IF_ERROR(EstablishConnection());
100
101 if (remainder_ == -1) {
102 TF_RETURN_IF_ERROR(ScanQuery());
103 } else if (remainder_ == 0) {
104 TF_RETURN_IF_ERROR(LoadNextPage());
105 }
106
107 uint8_t* initial_ptr = ptr_;
108 std::vector<Tensor> tensors;
109 std::vector<int32_t> types;
110
111 TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse key
112 TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse val
113
114 remainder_ -= (ptr_ - initial_ptr);
115
116 TF_RETURN_IF_ERROR(CheckTypes(types));
117
118 for (size_t i = 0; i < tensors.size(); i++)
119 out_tensors->push_back(tensors[permutation_[i]]);
120
121 *end_of_sequence = false;
122
123 return Status::OK();
124 }
125
126 *end_of_sequence = true;
127
128 return Status::OK();
129 }
130
EstablishConnection()131 Status IgniteDatasetIterator::EstablishConnection() {
132 if (!client_->IsConnected()) {
133 TF_RETURN_IF_ERROR(client_->Connect());
134
135 Status status = Handshake();
136 if (!status.ok()) {
137 Status disconnect_status = client_->Disconnect();
138 if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
139
140 return status;
141 }
142 }
143
144 return Status::OK();
145 }
146
CloseConnection()147 Status IgniteDatasetIterator::CloseConnection() {
148 if (cursor_id_ != -1 && !last_page_) {
149 TF_RETURN_IF_ERROR(EstablishConnection());
150
151 TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength));
152 TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode));
153 TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
154 TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Resource ID
155
156 int32_t res_len;
157 TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
158 if (res_len < kMinResLength)
159 return errors::Unknown("Close Resource Response is corrupted");
160
161 int64_t req_id;
162 TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
163 int32_t status;
164 TF_RETURN_IF_ERROR(client_->ReadInt(&status));
165 if (status != 0) {
166 uint8_t err_msg_header;
167 TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
168 if (err_msg_header == kStringVal) {
169 int32_t err_msg_length;
170 TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
171
172 uint8_t* err_msg_c = new uint8_t[err_msg_length];
173 auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
174 TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
175 string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
176
177 return errors::Unknown("Close Resource Error [status=", status,
178 ", message=", err_msg, "]");
179 }
180 return errors::Unknown("Close Resource Error [status=", status, "]");
181 }
182
183 cursor_id_ = -1;
184
185 return client_->Disconnect();
186 } else {
187 LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed";
188 }
189
190 return client_->IsConnected() ? client_->Disconnect() : Status::OK();
191 }
192
Handshake()193 Status IgniteDatasetIterator::Handshake() {
194 int32_t msg_len = kHandshakeReqDefaultLength;
195
196 if (username_.empty())
197 msg_len += 1;
198 else
199 msg_len += 5 + username_.length(); // 1 byte header, 4 bytes length.
200
201 if (password_.empty())
202 msg_len += 1;
203 else
204 msg_len += 5 + password_.length(); // 1 byte header, 4 bytes length.
205
206 TF_RETURN_IF_ERROR(client_->WriteInt(msg_len));
207 TF_RETURN_IF_ERROR(client_->WriteByte(1));
208 TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion));
209 TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion));
210 TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion));
211 TF_RETURN_IF_ERROR(client_->WriteByte(2));
212 if (username_.empty()) {
213 TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
214 } else {
215 TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
216 TF_RETURN_IF_ERROR(client_->WriteInt(username_.length()));
217 TF_RETURN_IF_ERROR(
218 client_->WriteData(reinterpret_cast<const uint8_t*>(username_.c_str()),
219 username_.length()));
220 }
221
222 if (password_.empty()) {
223 TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
224 } else {
225 TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
226 TF_RETURN_IF_ERROR(client_->WriteInt(password_.length()));
227 TF_RETURN_IF_ERROR(
228 client_->WriteData(reinterpret_cast<const uint8_t*>(password_.c_str()),
229 password_.length()));
230 }
231
232 int32_t handshake_res_len;
233 TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len));
234 uint8_t handshake_res;
235 TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res));
236
237 if (handshake_res != 1) {
238 int16_t serv_ver_major;
239 TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major));
240 int16_t serv_ver_minor;
241 TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor));
242 int16_t serv_ver_patch;
243 TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch));
244 uint8_t header;
245 TF_RETURN_IF_ERROR(client_->ReadByte(&header));
246
247 if (header == kStringVal) {
248 int32_t length;
249 TF_RETURN_IF_ERROR(client_->ReadInt(&length));
250
251 uint8_t* err_msg_c = new uint8_t[length];
252 auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
253 TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length));
254 string err_msg(reinterpret_cast<char*>(err_msg_c), length);
255
256 return errors::Unknown("Handshake Error [result=", handshake_res,
257 ", version=", serv_ver_major, ".", serv_ver_minor,
258 ".", serv_ver_patch, ", message='", err_msg, "']");
259 } else if (header == kNullVal) {
260 return errors::Unknown("Handshake Error [result=", handshake_res,
261 ", version=", serv_ver_major, ".", serv_ver_minor,
262 ".", serv_ver_patch, "]");
263 } else {
264 return errors::Unknown("Handshake Error [result=", handshake_res,
265 ", version=", serv_ver_major, ".", serv_ver_minor,
266 ".", serv_ver_patch, "]");
267 }
268 }
269
270 return Status::OK();
271 }
272
ScanQuery()273 Status IgniteDatasetIterator::ScanQuery() {
274 TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength));
275 TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode));
276 TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
277 TF_RETURN_IF_ERROR(
278 client_->WriteInt(JavaHashCode(cache_name_))); // Cache name
279 TF_RETURN_IF_ERROR(client_->WriteByte(0)); // Flags
280 TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); // Filter object
281 TF_RETURN_IF_ERROR(client_->WriteInt(page_size_)); // Cursor page size
282 TF_RETURN_IF_ERROR(client_->WriteInt(part_)); // part_ition to query
283 TF_RETURN_IF_ERROR(client_->WriteByte(local_)); // local_ flag
284
285 uint64 wait_start = Env::Default()->NowMicros();
286 int32_t res_len;
287 TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
288 int64_t wait_stop = Env::Default()->NowMicros();
289
290 LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms";
291
292 if (res_len < kMinResLength)
293 return errors::Unknown("Scan Query Response is corrupted");
294
295 int64_t req_id;
296 TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
297
298 int32_t status;
299 TF_RETURN_IF_ERROR(client_->ReadInt(&status));
300
301 if (status != 0) {
302 uint8_t err_msg_header;
303 TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
304
305 if (err_msg_header == kStringVal) {
306 int32_t err_msg_length;
307 TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
308
309 uint8_t* err_msg_c = new uint8_t[err_msg_length];
310 auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
311 TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
312 string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
313
314 return errors::Unknown("Scan Query Error [status=", status,
315 ", message=", err_msg, "]");
316 }
317 return errors::Unknown("Scan Query Error [status=", status, "]");
318 }
319
320 TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_));
321
322 int32_t row_cnt;
323 TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
324
325 int32_t page_size = res_len - kScanQueryResHeaderLength;
326
327 return ReceivePage(page_size);
328 }
329
LoadNextPage()330 Status IgniteDatasetIterator::LoadNextPage() {
331 TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength));
332 TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode));
333 TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
334 TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Cursor ID
335
336 uint64 wait_start = Env::Default()->NowMicros();
337 int32_t res_len;
338 TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
339 uint64 wait_stop = Env::Default()->NowMicros();
340
341 LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000
342 << " ms";
343
344 if (res_len < kMinResLength)
345 return errors::Unknown("Load Next Page Response is corrupted");
346
347 int64_t req_id;
348 TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
349
350 int32_t status;
351 TF_RETURN_IF_ERROR(client_->ReadInt(&status));
352
353 if (status != 0) {
354 uint8_t err_msg_header;
355 TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
356
357 if (err_msg_header == kStringVal) {
358 int32_t err_msg_length;
359 TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
360
361 uint8_t* err_msg_c = new uint8_t[err_msg_length];
362 auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
363 TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
364 string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
365
366 return errors::Unknown("Load Next Page Error [status=", status,
367 ", message=", err_msg, "]");
368 }
369 return errors::Unknown("Load Next Page Error [status=", status, "]");
370 }
371
372 int32_t row_cnt;
373 TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
374
375 int32_t page_size = res_len - kLoadNextPageResHeaderLength;
376
377 return ReceivePage(page_size);
378 }
379
ReceivePage(int32_t page_size)380 Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
381 remainder_ = page_size;
382 page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
383 ptr_ = page_.get();
384
385 uint64 start = Env::Default()->NowMicros();
386 TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_));
387 uint64 stop = Env::Default()->NowMicros();
388
389 double size_in_mb = 1.0 * remainder_ / 1024 / 1024;
390 double time_in_s = 1.0 * (stop - start) / 1000 / 1000;
391 LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
392 << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
393
394 uint8_t last_page_b;
395 TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b));
396
397 last_page_ = !last_page_b;
398
399 return Status::OK();
400 }
401
CheckTypes(const std::vector<int32_t> & types)402 Status IgniteDatasetIterator::CheckTypes(const std::vector<int32_t>& types) {
403 if (schema_.size() != types.size())
404 return errors::Unknown("Object has unexpected schema");
405
406 for (size_t i = 0; i < schema_.size(); i++) {
407 if (schema_[i] != types[permutation_[i]])
408 return errors::Unknown("Object has unexpected schema");
409 }
410
411 return Status::OK();
412 }
413
JavaHashCode(string str) const414 int32_t IgniteDatasetIterator::JavaHashCode(string str) const {
415 int32_t h = 0;
416 for (char& c : str) {
417 h = 31 * h + c;
418 }
419 return h;
420 }
421
422 } // namespace tensorflow
423