/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/include/dataset/iterator.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/runtime_context.h"
#include "minddata/dataset/include/dataset/datasets.h"

namespace mindspore {
namespace dataset {

// consumer_ is a non-owning handle; the consumer itself is owned by runtime_context_ (see BuildAndLaunchTree).
Iterator::Iterator() : consumer_(nullptr) {}

// Shut down the pipeline on destruction.
Iterator::~Iterator() { Stop(); }

// Get the next row from the data pipeline (char-based interface used by the string-keyed overload in the header).
Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  // Clean data buffer
  row->clear();
  std::unordered_map<std::string, std::shared_ptr<dataset::Tensor>> md_map;
  CHECK_FAIL_RETURN_UNEXPECTED(consumer_ != nullptr, "consumer_ is null, please launch the iterator first.");
  Status rc = consumer_->GetNextAsMap(&md_map);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
    row->clear();
    return rc;
  }
  // Convert each dataset tensor into an MSTensor keyed by its column name.
  for (auto &de_tensor : md_map) {
    std::vector<char> col_name(de_tensor.first.begin(), de_tensor.first.end());
    row->insert(std::make_pair(col_name, mindspore::MSTensor(std::make_shared<DETensor>(de_tensor.second))));
  }

  return Status::OK();
}

// Get the next row from the data pipeline.
Status Iterator::GetNextRow(MSTensorVec *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  // Clean data row
  row->clear();
  // Create a dataset tensor row and fetch it, then convert the output to MSTensor.
  std::vector<std::shared_ptr<dataset::Tensor>> md_row;
  CHECK_FAIL_RETURN_UNEXPECTED(consumer_ != nullptr, "consumer_ is null, please launch the iterator first.");
  Status rc = consumer_->GetNextAsVector(&md_row);
  if (rc.IsError()) {
    row->clear();
    return rc;
  }
  std::transform(md_row.begin(), md_row.end(), std::back_inserter(*row),
                 [](auto t) { return mindspore::MSTensor(std::make_shared<DETensor>(t)); });
  return Status::OK();
}

// Shut down the data pipeline.
void Iterator::Stop() {
  if (runtime_context_ != nullptr) {
    Status rc = runtime_context_->Terminate();
    if (rc.IsError()) {
      MS_LOG(ERROR) << rc.ToString();
    }
  }
}

// Build and launch the execution tree. Typically invoked when an iterator is created from a Dataset.
Status Iterator::BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs) {
  RETURN_UNEXPECTED_IF_NULL(ds);
  runtime_context_ = std::make_unique<NativeRuntimeContext>();
  CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
  RETURN_IF_NOT_OK(runtime_context_->Init());
  auto consumer = std::make_unique<IteratorConsumer>(num_epochs);
  CHECK_FAIL_RETURN_UNEXPECTED(consumer != nullptr, "Create consumer failed.");
  // Keep a non-owning handle; ownership of the consumer passes to runtime_context_ below.
  consumer_ = consumer.get();
  RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
  runtime_context_->AssignConsumer(std::move(consumer));
  return Status::OK();
}

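// Illustrative usage (not part of the original file): a minimal sketch of how an Iterator is
// typically driven from the C++ API, assuming it was created with Dataset::CreateIterator();
// `ds` is a placeholder dataset object.
//
//   std::shared_ptr<Iterator> iter = ds->CreateIterator();
//   MSTensorVec row;
//   iter->GetNextRow(&row);         // fetch the first row
//   while (!row.empty()) {          // an empty row marks the end of the data
//     // ... consume the MSTensor columns in `row` ...
//     iter->GetNextRow(&row);
//   }
//   iter->Stop();                   // shut down the pipeline when done
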
PullIterator::PullIterator() : pull_consumer_(nullptr) {}

PullIterator::~PullIterator() = default;

// Get the next num_rows rows from the data pipeline.
Status PullIterator::GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  CHECK_FAIL_RETURN_UNEXPECTED(pull_consumer_ != nullptr, "Consumer is nullptr. Please launch the iterator first.");
  row->clear();
  for (int i = 0; i < num_rows; i++) {
    std::vector<std::shared_ptr<dataset::Tensor>> md_row;
    Status rc = pull_consumer_->GetNextAsVector(&md_row);

    if (rc.IsError()) {
      row->clear();
      MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
      return rc;
    }

    MSTensorVec ms_row = {};
    for (const auto &de_tensor : md_row) {
      CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
      ms_row.push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
    }
    row->push_back(ms_row);
  }
  return Status::OK();
}

// Get the next row from the data pipeline.
Status PullIterator::GetNextRow(MSTensorVec *const row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  CHECK_FAIL_RETURN_UNEXPECTED(pull_consumer_ != nullptr, "Consumer is nullptr.");
  row->clear();
  std::vector<std::shared_ptr<dataset::Tensor>> md_row;
  Status rc = pull_consumer_->GetNextAsVector(&md_row);
  if (rc.IsError()) {
    row->clear();
    MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
    return rc;
  }

  for (const auto &de_tensor : md_row) {
    CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
    row->push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
  }
  return Status::OK();
}

// Build and launch the execution tree. This kicks off a different type of consumer because the
// pull-based iterator does not need to instantiate threads for each op; as a result, calls to the
// consumer bypass the execution tree.
Status PullIterator::BuildAndLaunchTree(const std::shared_ptr<Dataset> &ds, int32_t num_epochs) {
  RETURN_UNEXPECTED_IF_NULL(ds);
  // Note: num_epochs is not used by the pull-based consumer.
  if (pull_consumer_ == nullptr) {
    pull_consumer_ = std::make_unique<PullBasedIteratorConsumer>();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(pull_consumer_ != nullptr, "pull_consumer_ is nullptr");
  RETURN_IF_NOT_OK(pull_consumer_->Init(ds->IRNode()));
  return Status::OK();
}

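// Illustrative usage (not part of the original file): a minimal sketch of driving a PullIterator,
// assuming it was obtained through Dataset::CreatePullBasedIterator(); `ds` is a placeholder
// dataset object.
//
//   std::shared_ptr<PullIterator> pull_iter = ds->CreatePullBasedIterator();
//   MSTensorVec row;
//   pull_iter->GetNextRow(&row);    // rows are produced on demand, without per-op threads
//   std::vector<MSTensorVec> rows;
//   pull_iter->GetRows(32, &rows);  // or fetch several rows in one call
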
// Construct an iterator over `lt` and prefetch the first row when a parent iterator is provided.
Iterator::_Iterator::_Iterator(Iterator *lt) : ind_{0}, lt_{lt}, cur_row_{nullptr} {
  if (lt_) {
    cur_row_ = new MSTensorMap();
    if (cur_row_ == nullptr) {
      return;
    }
    Status rc = lt_->GetNextRow(cur_row_);
    if (rc.IsError()) {
      MS_LOG(ERROR) << "Error getting next row. Message: " << rc;
      delete cur_row_;
      cur_row_ = nullptr;
    }
  }
}
Iterator::_Iterator &Iterator::_Iterator::operator++() {
  if (lt_) {
    ++ind_;
    Status rc = lt_->GetNextRow(cur_row_);
    if (rc.IsError()) {
      MS_LOG(ERROR) << "Error getting next row. Message: " << rc;
      // Release the row on error (mirrors the constructor) to avoid leaking it.
      delete cur_row_;
      cur_row_ = nullptr;
    }
  }
  // An empty row marks the end of iteration; release it so this iterator compares equal to end().
  if (cur_row_ && cur_row_->empty()) {
    delete cur_row_;
    cur_row_ = nullptr;
  }
  return *this;
}
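
// Illustrative usage (not part of the original file): together with the begin()/end() members
// declared in iterator.h, _Iterator is expected to support range-based for loops; a minimal
// sketch, assuming `iter` is a launched Iterator:
//
//   for (auto &row : *iter) {
//     // `row` is an MSTensorMap (column name -> MSTensor)
//   }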
}  // namespace dataset
}  // namespace mindspore