• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/include/dataset/iterator.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/consumers/pull_based_tree_consumer.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/runtime_context.h"
#include "minddata/dataset/include/dataset/datasets.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 
Iterator()25 Iterator::Iterator() : consumer_(nullptr) {}
~Iterator()26 Iterator::~Iterator() { Stop(); }
27 
28 // Get the next row from the data pipeline.
GetNextRowCharIF(MSTensorMapChar * row)29 Status Iterator::GetNextRowCharIF(MSTensorMapChar *row) {
30   RETURN_UNEXPECTED_IF_NULL(row);
31   // Clean data buffer
32   row->clear();
33   std::unordered_map<std::string, std::shared_ptr<dataset::Tensor>> md_map;
34   CHECK_FAIL_RETURN_UNEXPECTED(consumer_ != nullptr, "consumer_ is null, pls launch iterator first.");
35   Status rc = consumer_->GetNextAsMap(&md_map);
36   if (rc.IsError()) {
37     MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
38     row->clear();
39     return rc;
40   }
41   for (auto &de_tensor : md_map) {
42     std::vector<char> col_name(de_tensor.first.begin(), de_tensor.first.end());
43     row->insert(std::make_pair(col_name, mindspore::MSTensor(std::make_shared<DETensor>(de_tensor.second))));
44   }
45 
46   return Status::OK();
47 }
48 
49 // Get the next row from the data pipeline.
GetNextRow(MSTensorVec * row)50 Status Iterator::GetNextRow(MSTensorVec *row) {
51   RETURN_UNEXPECTED_IF_NULL(row);
52   // Clean data row
53   row->clear();
54   // create a dataset tensor row and fetch. Then we convert the output to MSTensor
55   std::vector<std::shared_ptr<dataset::Tensor>> md_row;
56   CHECK_FAIL_RETURN_UNEXPECTED(consumer_ != nullptr, "consumer_ is null, pls launch iterator first.");
57   Status rc = consumer_->GetNextAsVector(&md_row);
58   if (rc.IsError()) {
59     row->clear();
60     return rc;
61   }
62   std::transform(md_row.begin(), md_row.end(), std::back_inserter(*row),
63                  [](auto t) { return mindspore::MSTensor(std::make_shared<DETensor>(t)); });
64   return Status::OK();
65 }
66 
67 // Shut down the data pipeline.
Stop()68 void Iterator::Stop() {
69   if (runtime_context_ != nullptr) {
70     Status rc = runtime_context_->Terminate();
71     if (rc.IsError()) {
72       MS_LOG(ERROR) << rc.ToString();
73     }
74   }
75 }
76 
77 // Function to build and launch the execution tree.
BuildAndLaunchTree(std::shared_ptr<Dataset> ds,int32_t num_epochs)78 Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs) {
79   RETURN_UNEXPECTED_IF_NULL(ds);
80   runtime_context_ = std::make_unique<NativeRuntimeContext>();
81   CHECK_FAIL_RETURN_UNEXPECTED(runtime_context_ != nullptr, "Create runtime_context_ failed.");
82   RETURN_IF_NOT_OK(runtime_context_->Init());
83   auto consumer = std::make_unique<IteratorConsumer>(num_epochs);
84   CHECK_FAIL_RETURN_UNEXPECTED(consumer != nullptr, "Create consumer failed.");
85   consumer_ = consumer.get();
86   RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
87   runtime_context_->AssignConsumer(std::move(consumer));
88   return Status::OK();
89 }
90 
PullIterator()91 PullIterator::PullIterator() : pull_consumer_(nullptr) {}
92 
93 // Get the next row from the data pipeline.
GetRows(int32_t num_rows,std::vector<MSTensorVec> * const row)94 Status PullIterator::GetRows(int32_t num_rows, std::vector<MSTensorVec> *const row) {
95   RETURN_UNEXPECTED_IF_NULL(row);
96   CHECK_FAIL_RETURN_UNEXPECTED(pull_consumer_ != nullptr, "Consumer is nullptr. Please launch iterator fist.");
97   for (int i = 0; i < num_rows; i++) {
98     std::vector<std::shared_ptr<dataset::Tensor>> md_row;
99     Status rc = pull_consumer_->GetNextAsVector(&md_row);
100 
101     if (rc.IsError()) {
102       row->clear();
103       MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
104       return rc;
105     }
106 
107     MSTensorVec ms_row = {};
108     for (auto de_tensor : md_row) {
109       CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
110       ms_row.push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
111     }
112     row->push_back(ms_row);
113   }
114   return Status::OK();
115 }
116 
GetNextRow(MSTensorVec * const row)117 Status PullIterator::GetNextRow(MSTensorVec *const row) {
118   RETURN_UNEXPECTED_IF_NULL(row);
119   CHECK_FAIL_RETURN_UNEXPECTED(pull_consumer_ != nullptr, "Consumer is nullptr.");
120   std::vector<std::shared_ptr<dataset::Tensor>> md_row;
121   Status rc = pull_consumer_->GetNextAsVector(&md_row);
122   if (rc.IsError()) {
123     row->clear();
124     MS_LOG(ERROR) << "GetNextRow: Failed to get next row. Error status: " << rc;
125     return rc;
126   }
127 
128   for (auto de_tensor : md_row) {
129     CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
130     row->push_back(mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)));
131   }
132   return Status::OK();
133 }
134 
135 // Function to build and launch the execution tree. This function kicks off a different type of consumer
136 // for the tree, the reason why this is the case is due to the fact that PullBasedIterator does not need
137 // to instantiate threads for each op. As such, the call to the consumer will by pass the execution tree.
BuildAndLaunchTree(std::shared_ptr<Dataset> ds)138 Status PullIterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
139   if (pull_consumer_ == nullptr) {
140     pull_consumer_ = std::make_unique<PullBasedIteratorConsumer>();
141   }
142   CHECK_FAIL_RETURN_UNEXPECTED(pull_consumer_ != nullptr, "pull_consumer_ is nullptr");
143   RETURN_IF_NOT_OK(pull_consumer_->Init(std::move(ds->IRNode())));
144   return Status::OK();
145 }
146 
_Iterator(Iterator * lt)147 Iterator::_Iterator::_Iterator(Iterator *lt) : ind_{0}, lt_{lt}, cur_row_{nullptr} {
148   if (lt_) {
149     cur_row_ = new MSTensorMap();
150     if (cur_row_ == nullptr) {
151       return;
152     }
153     Status rc = lt_->GetNextRow(cur_row_);
154     if (rc.IsError()) {
155       MS_LOG(ERROR) << "Error getting next row. Message: " << rc;
156       delete cur_row_;
157       cur_row_ = nullptr;
158     }
159   }
160 }
operator ++()161 Iterator::_Iterator &Iterator::_Iterator::operator++() {
162   if (lt_) {
163     ++ind_;
164     Status rc = lt_->GetNextRow(cur_row_);
165     if (rc.IsError()) {
166       MS_LOG(ERROR) << "Error getting next row. Message: " << rc;
167       cur_row_ = nullptr;
168     }
169   }
170   if (cur_row_ && cur_row_->size() == 0) {
171     delete cur_row_;
172     cur_row_ = nullptr;
173   }
174   return *this;
175 }
176 }  // namespace dataset
177 }  // namespace mindspore
178