1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/dataset_op.h"
17
18 #include <iomanip>
19 #include <iostream>
20 #include <memory>
21 #include <regex>
22 #include <utility>
23 #include <string>
24 #include <algorithm>
25
26 #include "minddata/dataset/engine/datasetops/data_queue_op.h"
27 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
28
29 #include "minddata/dataset/engine/operator_connector.h"
30 #include "minddata/dataset/util/log_adapter.h"
31 #ifndef ENABLE_ANDROID
32 #include "utils/system/crc32c.h"
33 #endif
34
35 namespace mindspore {
36 namespace dataset {
37 // Constructor
DatasetOp(int32_t op_connector_size,std::shared_ptr<SamplerRT> sampler)38 DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
39 : oc_queue_size_(op_connector_size),
40 sampler_(sampler),
41 operator_id_(kInvalidOperatorId),
42 tree_(nullptr),
43 state_(OpState::kDeOpIdle),
44 op_total_repeats_(kInfiniteRepeat),
45 op_num_repeats_per_epoch_(kInfiniteRepeat),
46 op_current_repeats_(0),
47 op_current_epochs_(0),
48 out_connector_(nullptr),
49 dataset_size_(-1),
50 num_classes_(-1) {
51 // The operator starts out with an invalid operator id. The only way to
52 // get it out of invalid state is to assign the operator to an execution tree.
53 }
54
55 // Adds a operator to become our child.
AddChild(std::shared_ptr<DatasetOp> child)56 Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
57 if (std::dynamic_pointer_cast<DataQueueOp>(child) != nullptr) {
58 std::string err_msg(
59 "Unsupported scenario, \'send\' operator can only be after \'device_queue\' operation, but got " + Name());
60 RETURN_STATUS_UNEXPECTED(err_msg);
61 }
62 if (operator_id_ == kInvalidOperatorId) {
63 std::string err_msg(
64 "[Internal ERROR] Cannot add child node. Tree node connections can only "
65 "be made if the node belongs to a tree.");
66 RETURN_STATUS_UNEXPECTED(err_msg);
67 }
68
69 // disallow relationships with other trees
70 if (tree_ != child->tree_) {
71 std::string err_msg(
72 "Invalid operator structure, the relationship of operators should be one by one, but got too many branches.");
73 RETURN_STATUS_UNEXPECTED(err_msg);
74 }
75 child_.push_back(child);
76 child->AddParent(this);
77 return Status::OK();
78 }
79
RemoveChild(std::shared_ptr<DatasetOp> child)80 Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
81 if (operator_id_ == kInvalidOperatorId) {
82 std::string err_msg(
83 "[Internal ERROR] Cannot remove child node. Tree node connections can only "
84 "be made if the node belongs to a tree.");
85 RETURN_STATUS_UNEXPECTED(err_msg);
86 }
87
88 // disallow relationships with other trees
89 if (tree_ != child->tree_) {
90 std::string err_msg(
91 "Invalid operator structure, the relationship of operators should be one by one, but got too many branches.");
92 RETURN_STATUS_UNEXPECTED(err_msg);
93 }
94
95 child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
96 child->RemoveParent(this);
97 return Status::OK();
98 }
99
InsertAsParent(std::shared_ptr<DatasetOp> to_add)100 Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
101 RETURN_UNEXPECTED_IF_NULL(to_add);
102 for (auto &prev_parent : this->parent_) {
103 RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
104 RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
105 }
106 RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
107 if (tree_->root()->id() == this->id()) {
108 RETURN_IF_NOT_OK(tree_->AssignRoot(to_add));
109 }
110 return Status::OK();
111 }
112 // Removes child operator in this operator.
RemoveChildren()113 Status DatasetOp::RemoveChildren() {
114 for (const auto &child : child_) {
115 child->RemoveParent(this);
116 }
117 child_.clear();
118
119 return Status::OK();
120 }
121
122 // Adds a parent operator to this operator
AddParent(DatasetOp * parent)123 void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
124
125 // Removes a parent operator from this operator
RemoveParent(const DatasetOp * parent)126 void DatasetOp::RemoveParent(const DatasetOp *parent) {
127 parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
128 }
129
130 // Removes this node from the tree and connects it's parent/child together
Remove()131 Status DatasetOp::Remove() {
132 if (parent_.size() > 1) {
133 std::string err_msg(
134 "Invalid operator structure, the relationship between operators should be one-to-one, but encountered more than "
135 "one parent, namely: " +
136 std::to_string(parent_.size()));
137 RETURN_STATUS_UNEXPECTED(err_msg);
138 }
139 if (child_.size() > 1) {
140 std::string err_msg(
141 "Invalid operator structure, the relationship of operators should be one by one, but got too many branches.");
142 RETURN_STATUS_UNEXPECTED(err_msg);
143 }
144
145 // Scenario's when removing node B:
146 // A -> B -> C
147 // A -> B
148 // B -> C
149 //
150 // If we remove B, then first take our child A and update it's parent to be C
151 // It's possible the parent is null if we are the root node being removed.
152 if (!child_.empty()) {
153 // If we have a parent, then assign child's parent to point to our parent.
154 if (!parent_.empty()) {
155 CHECK_FAIL_RETURN_UNEXPECTED(parent_[0]->Children().size() == 1,
156 "Invalid operator structure, the relationship of operators should be one by one, "
157 "but got too many branches.");
158 child_[0]->parent_[0] = parent_[0];
159 } else {
160 // We don't have a parent, so we are the root node being removed.
161 // clear the parent list of our child so that it becomes the new root.
162 child_[0]->parent_.clear();
163 RETURN_IF_NOT_OK(tree_->AssignRoot(child_[0]));
164 }
165 }
166
167 // Next, if we had a parent, then set it's child to be our child.
168 if (!parent_.empty()) {
169 // if we have a child, then set our parent to point to it
170 if (!child_.empty()) {
171 parent_[0]->child_[0] = child_[0];
172 } else {
173 // We don't have a child, so clear the child list of the current
174 // parent because it will be empty once we are removed.
175 parent_[0]->child_.clear();
176 }
177 }
178
179 // Finally, clear "this" op's parent and child pointers since we have just
180 // disconnected it from the tree and invalidate it's fields.
181 child_.clear();
182 parent_.clear();
183 operator_id_ = kInvalidOperatorId;
184 tree_ = nullptr;
185
186 return Status::OK();
187 }
188
189 // Getter function to get a shared pointer to our child
child(int32_t child_index) const190 std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
191 std::shared_ptr<DatasetOp> return_op = nullptr;
192 if (child_.empty()) {
193 return return_op;
194 }
195 MS_ASSERT(child_index < static_cast<int>(child_.size()));
196 // Return a shared pointer
197 return child_[child_index];
198 }
199
200 // Getter function to get the parent pointer
Parent(DatasetOp ** parent,int32_t parent_index) const201 void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
202 if (parent_.empty()) {
203 // common case if this is a root node
204 *parent = nullptr;
205 } else {
206 MS_ASSERT(parent_index < static_cast<int>(parent_.size()));
207 *parent = parent_[parent_index];
208 }
209 }
210
211 // Getter function to get all of our parents.
parents() const212 std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
213
214 // Creates the connector within this operator
CreateConnector()215 void DatasetOp::CreateConnector() {
216 MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ".";
217 if (oc_queue_size_ > 0) {
218 out_connector_ = std::make_unique<OperatorConnector>(oc_queue_size_);
219 } else {
220 // Some op's may choose not to have an output connector
221 MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << ".";
222 out_connector_ = nullptr;
223 }
224 }
225
226 // A print method typically used for debugging. showAll of true will recursively descend to child prints
Print(std::ostream & out,bool show_all) const227 void DatasetOp::Print(std::ostream &out, bool show_all) const {
228 // When show_all is false, we display a 1 liner piece of text for the op.
229 // When show_all is true, we display more detailed output for the op.
230 // Derived printers should show their own header info, then call base class printer, followed by
231 // derived-specific items.
232
233 // Always show the id and name as first line regardless if this summary or detailed print
234 out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
235
236 if (show_all) {
237 // The detailed display will show common base class info of the op. Allow the derived class to print
238 // it's own id and name though as the first line.
239 out << "\nNumber of children : " << child_.size();
240 for (size_t i = 0; i < child_.size(); i++) {
241 out << "\n Child[" << i << "] id: " << child_[i]->id();
242 }
243 out << "\nNumber of parents : " << parent_.size();
244 for (size_t i = 0; i < parent_.size(); i++) {
245 out << "\n Parent[" << i << "] id: " << parent_[i]->id();
246 }
247 out << "\nConnector queue size : " << oc_queue_size_ << "\nTotal repeats : " << op_total_repeats_
248 << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_;
249 if (sampler_) {
250 out << "\nSampler:\n";
251 sampler_->SamplerPrint(out, show_all);
252 }
253 }
254 }
255
GetNextRowPullMode(TensorRow * const row)256 Status DatasetOp::GetNextRowPullMode(TensorRow *const row) {
257 RETURN_UNEXPECTED_IF_NULL(row);
258 if (child_.empty()) {
259 MS_LOG(DEBUG) << "No child for operator [" << Name() << "].";
260 return Status::OK();
261 }
262 RETURN_UNEXPECTED_IF_NULL(child_[0]);
263 return child_[0]->GetNextRowPullMode(row);
264 }
265
266 // Gets the next row from the given child
GetNextRow(TensorRow * row)267 Status DatasetOp::GetNextRow(TensorRow *row) {
268 RETURN_UNEXPECTED_IF_NULL(row);
269 // pop is a blocked call and will throw an interruption if the whole group shuts down.
270 RETURN_IF_NOT_OK(out_connector_->PopFront(row));
271 return Status::OK();
272 }
273
274 // Gets the number of classes
GetNumClasses(int64_t * num_classes)275 Status DatasetOp::GetNumClasses(int64_t *num_classes) {
276 RETURN_UNEXPECTED_IF_NULL(num_classes);
277 if (child_.size() == 1) {
278 return child_[0]->GetNumClasses(num_classes);
279 } else if (child_.size() > 1) {
280 // It is okay for dataset to have more than 1 child, GetNumClasses shouldn't fail in this case.
281 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
282 // always be in front of the child_ structure, so we get num classes from the last child.
283 return child_[child_.size() - 1]->GetNumClasses(num_classes);
284 } else {
285 // when num classes isn't found, the default behavior is to return -1
286 MS_LOG(WARNING) << "Num classes not defined for : " << Name();
287 *num_classes = -1;
288 return Status::OK();
289 }
290 }
291
GetClassIndexing(std::vector<std::pair<std::string,std::vector<int32_t>>> * output_class_indexing)292 Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
293 RETURN_UNEXPECTED_IF_NULL(output_class_indexing);
294 if (child_.size() == 1) {
295 return child_[0]->GetClassIndexing(output_class_indexing);
296 } else if (child_.size() > 1) {
297 // It is okay for dataset to have more than 1 child, GetClassIndexing shouldn't fail in this case.
298 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
299 // always be in the front of the child_ structure, so we get data from the last child.
300 return child_[child_.size() - 1]->GetClassIndexing(output_class_indexing);
301 } else {
302 *output_class_indexing = {};
303 RETURN_STATUS_UNEXPECTED("Unsupported scenario, GetClassIndexing failed for " + Name() +
304 " doesn't support GetClassIndexing yet.");
305 }
306 }
307
308 // Performs handling for when an eoe message is received.
309 // The base class implementation simply flows the eoe message to output. Derived classes
310 // may override if they need to perform special eoe handling.
EoeReceived(int32_t worker_id)311 Status DatasetOp::EoeReceived(int32_t worker_id) { return out_connector_->SendEOE(); }
312
313 // Performs handling for when an eof message is received.
314 // The base class implementation simply flows the eof message to output. Derived classes
315 // may override if they need to perform special eof handling.
EofReceived(int32_t worker_id)316 Status DatasetOp::EofReceived(int32_t worker_id) { return out_connector_->SendEOF(); }
317
318 // During tree prepare phase, operators may have specific post-operations to perform depending on their role.
PrepareOperator()319 Status DatasetOp::PrepareOperator() {
320 // Creating Connector object for each op.
321 this->CreateConnector();
322 if (out_connector_) {
323 RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
324 }
325 RETURN_IF_NOT_OK(this->RegisterWorkerConnectors());
326
327 // Generate the column name map for the current op.
328 RETURN_IF_NOT_OK(this->ComputeColMap());
329
330 return Status::OK();
331 }
332
333 // During tree prepare phase, operators may have specific post-operations to perform depending on their role.
PrepareOperatorPullBased()334 Status DatasetOp::PrepareOperatorPullBased() {
335 // Generate the column name map for the current op.
336 RETURN_IF_NOT_OK(this->ComputeColMap());
337
338 // check if operators are implemented in pull mode
339 std::string message = "";
340 ImplementedPullMode isImplemented = PullModeImplementationStatus();
341 if (isImplemented == ImplementedPullMode::NotImplemented) {
342 message = Name() + " is not implemented yet in pull mode.";
343 if (IsLeaf()) {
344 message = "Leaf node " + message;
345 if (GlobalContext::config_manager()->get_debug_mode()) {
346 RETURN_STATUS_UNEXPECTED(message);
347 }
348 }
349 } else if (isImplemented == ImplementedPullMode::DisabledDebugMode) {
350 message = "In debug mode, " + Name() + " is disabled for debugging purposes.";
351 }
352 if (message.size() > 0) {
353 MS_LOG(WARNING) << message;
354 }
355 return Status::OK();
356 }
357
358 // Derived classes may implement the reset function if the operator is stateful and needs
359 // specific reset handling that is not contained in this common code version of the reset.
Reset()360 Status DatasetOp::Reset() {
361 state_ = OpState::kDeOpRunning;
362 return Status::OK();
363 }
364
365 // gives a string output for the column map for handy debug printing
ColumnNameMapAsString() const366 std::string DatasetOp::ColumnNameMapAsString() const {
367 std::string outStr = "Column name id map: ";
368 for (auto &it : column_name_id_map_) {
369 outStr += (" " + it.first + ":" + std::to_string(it.second));
370 }
371 return outStr;
372 }
373
374 // Computing the assignment of the column name map.
375 // This just inherits the column map from its first child, can only be used if the number of children is 1.
376 // Operations changing the column map must overwrite this function.
ComputeColMap()377 Status DatasetOp::ComputeColMap() {
378 if (child_.size() > 1) {
379 RETURN_STATUS_UNEXPECTED(
380 "Invalid operator structure, the relationship of operators should be one by one, but got too many branches.");
381 }
382 if (column_name_id_map_.empty()) {
383 column_name_id_map_ = child_[0]->column_name_id_map();
384 if (column_name_id_map_.empty()) {
385 RETURN_STATUS_UNEXPECTED("Invalid column list, the column list of " + child_[0]->Name() +
386 " should have one column at least, but got empty.");
387 }
388 MS_LOG(DEBUG) << "Setting column map:\n" << DatasetOp::ColumnNameMapAsString();
389 } else {
390 MS_LOG(WARNING) << "Column name map is already set!";
391 }
392 return Status::OK();
393 }
394
395 // Getter for the sampler, and it also removes the sampler from the op
FetchRemoveSampler(std::shared_ptr<SamplerRT> * sampler)396 Status DatasetOp::FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler) {
397 RETURN_UNEXPECTED_IF_NULL(sampler);
398 *sampler = sampler_; // It's okay if it sampler_ points to nullptr
399 sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
400 return Status::OK();
401 }
402
403 #ifndef ENABLE_ANDROID
GenerateCRC(const std::shared_ptr<DatasetOp> & op)404 uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
405 std::stringstream ss;
406 op->tree_->Print(ss, op);
407 std::string ss_str = ss.str();
408
409 // Filter out the Num workers field when generating the check sum
410 ss_str = std::regex_replace(ss_str, std::regex("Number of ShardReader workers.*\n"), "");
411 ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), "");
412 ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*?\\]"), "");
413 ss_str = std::regex_replace(ss_str, std::regex("Connector queue size.*\n"), "");
414
415 // Filter out tcp/ip information
416 ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), "");
417 ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), "");
418 ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), "");
419 ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), "");
420 ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), "");
421
422 // Filter out Number of rows when generating the check sum
423 ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), "");
424
425 // Filter out the Operator control flags field when generating the check sum
426 ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
427
428 // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline
429 ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), "");
430 ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
431
432 // Filter out the operator id field
433 ss_str = std::regex_replace(ss_str, std::regex(" *Parent.*\n"), "");
434 ss_str = std::regex_replace(ss_str, std::regex(" *Child.*\n"), "");
435 ss_str = std::regex_replace(ss_str, std::regex(R"(\(\s*\d+?\))"), "");
436
437 // Doesn't matter whether there is any parent node above CacheOp or not.
438 ss_str = std::regex_replace(ss_str, std::regex("Number of parents.*\n"), "");
439
440 // Filter out shuffle seed from ShuffleOp
441 ss_str = std::regex_replace(ss_str, std::regex("Shuffle seed.*\n"), "");
442
443 // Filter out the total repeats and number repeats per epoch field
444 ss_str = std::regex_replace(ss_str, std::regex("Total repeats.*\n"), "");
445 ss_str = std::regex_replace(ss_str, std::regex("Number repeats per epoch.*\n"), "");
446
447 // The Cache crc and Server cache id field is different when creating new cache_client and re-using the same
448 // cache_client later. So we filter out these two fields to allow cache sharing.
449 ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
450 ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
451
452 MS_LOG(DEBUG) << "Printing the tree for generating crc:\n" << ss_str;
453
454 uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
455 return cache_crc;
456 }
457 #endif
458
UpdateRepeatAndEpochCounter()459 void DatasetOp::UpdateRepeatAndEpochCounter() {
460 op_current_repeats_++;
461 if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) {
462 op_current_epochs_++;
463 }
464 MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
465 }
466
SetEpoch(const int64_t epoch)467 Status DatasetOp::SetEpoch(const int64_t epoch) {
468 CHECK_FAIL_RETURN_UNEXPECTED(epoch >= 0,
469 "New epoch value must be greater than or equal to 0, got: " + std::to_string(epoch));
470 while (op_current_epochs_ < epoch) {
471 UpdateRepeatAndEpochCounter();
472 }
473 return Status::OK();
474 }
475
GetTreeBatchSize()476 int64_t DatasetOp::GetTreeBatchSize() {
477 if (child_.size() == 1) {
478 return child_[0]->GetTreeBatchSize();
479 } else if (child_.size() > 1) {
480 // It is okay for dataset to have more than 1 child, GetBatchSize shouldn't fail in this case.
481 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
482 // always be in front of the child_ structure, so we get data from the last child.
483 return child_[child_.size() - 1]->GetTreeBatchSize();
484 } else {
485 return 1;
486 }
487 }
488
GetTreeRepeatCount()489 int64_t DatasetOp::GetTreeRepeatCount() {
490 if (child_.size() == 1) {
491 return child_[0]->GetTreeRepeatCount();
492 } else if (child_.size() > 1) {
493 // It is okay for dataset to have more than 1 child, GetRepeatCount shouldn't fail in this case.
494 // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
495 // always be in front of the child_ structure, so we get data from the last child.
496 return child_[child_.size() - 1]->GetTreeRepeatCount();
497 } else {
498 return 1;
499 }
500 }
GetMPWorkerPIDs() const501 std::vector<int32_t> DatasetOp::GetMPWorkerPIDs() const { return std::vector<int32_t>(); }
502 } // namespace dataset
503 } // namespace mindspore
504