1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
17
18 #include <algorithm>
19 #include <string>
20
21 namespace mindspore {
22 namespace dataset {
GetNumRowsInDataset(int64_t * num) const23 Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
24 RETURN_UNEXPECTED_IF_NULL(num);
25 // The sampler base class itself does not compute it's own num_rows_ value.
26 // Instead, this value is computed by the derived leaf op during it's own initialization
27 // after it has interacted with it's storage layers.
28 // Here, it is just a getter method to return the value. However, it is invalid if there is
29 // not a value set for this count, so generate a failure if that is the case.
30 if (num == nullptr || num_rows_ == -1) {
31 RETURN_STATUS_UNEXPECTED("[Internal ERROR] Get num rows in Dataset failed, num_rows has not been set yet.");
32 }
33 (*num) = num_rows_;
34 return Status::OK();
35 }
36
// Constructor for the sampler base class.
//
// @param num_samples - stored as num_samples_ (later used as the cap in CalculateNumSamples).
// @param samples_per_tensor - stored as samples_per_tensor_.
//
// The remaining members start in a "not yet initialized" state:
//  - num_rows_ starts at 0 and is filled in later by HandshakeRandomAccessOp or SetNumRowsInDataset.
//  - col_desc_ starts null and is created on first use by CreateSamplerTensor.
//  - is_initialized starts false.
//  - child_ids_ starts as a default-constructed TensorRow.
SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_tensor)
    : num_rows_(0),
      num_samples_(num_samples),
      samples_per_tensor_(samples_per_tensor),
      col_desc_(nullptr),
      is_initialized(false),
      child_ids_(TensorRow()) {}
44
HandshakeRandomAccessOp(const RandomAccessOp * op,const int32_t reset_count)45 Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
46 RETURN_UNEXPECTED_IF_NULL(op);
47 std::shared_ptr<SamplerRT> child_sampler;
48 if (HasChildSampler()) {
49 child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]);
50 if (!child_sampler) {
51 std::string err_msg("[Internal ERROR] Cannot handshake, child is not a sampler object.");
52 RETURN_STATUS_UNEXPECTED(err_msg);
53 }
54
55 // Handshake and init child first.
56 RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
57 }
58
59 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "[Internal ERROR] RandomAccessOp init failed, as it is nullptr.");
60
61 // If there's a child sampler, set the row count to be it's sample count
62 if (HasChildSampler()) {
63 num_rows_ = child_sampler->num_samples_;
64 } else {
65 RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
66 }
67
68 // It's up to the derived class to check the validity of the two args
69 // Because some sampler only needs one of the arg (weighted_random_sampler)
70 RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
71 // Move forward sampler's random generator if resetting the pipeline in fast_recovery mode
72 if (GlobalContext::config_manager()->fast_recovery()) {
73 for (auto i = 0; i < reset_count; i++) {
74 RETURN_IF_NOT_OK(ResetSampler(true)); // failover_reset = true
75 }
76 }
77 return Status::OK();
78 }
79
CreateSamplerTensor(std::shared_ptr<Tensor> * sample_ids,int64_t num_elements)80 Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
81 RETURN_UNEXPECTED_IF_NULL(sample_ids);
82 if (col_desc_ == nullptr) {
83 // a ColDescriptor for Tensor that holds SampleIds
84 col_desc_ = std::make_unique<ColDescriptor>("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
85 }
86 TensorShape shape(std::vector<dsize_t>(1, num_elements));
87 RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc_->Type(), sample_ids));
88 return Status::OK();
89 }
90
SamplerPrint(std::ostream & out,bool show_all) const91 void SamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
92 // Sampler printing is usually only called in the show_all mode.
93 // Derived classes will display the name, then call back to this base
94 // for common info.
95 // No-op in the summary mode.
96 if (show_all) {
97 out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
98 }
99 }
100
101 #ifdef ENABLE_PYTHON
GetAllIdsThenReset(py::array * data)102 Status SamplerRT::GetAllIdsThenReset(py::array *data) {
103 RETURN_UNEXPECTED_IF_NULL(data);
104 std::shared_ptr<Tensor> sample_ids;
105 TensorRow sample_row;
106
107 // Get the only tensor inside the row that contains the actual SampleIds for the entire epoch
108 RETURN_IF_NOT_OK(GetNextSample(&sample_row));
109 sample_ids = sample_row[0];
110
111 // check this tensorRow is not a ctrl tensorRow
112 CHECK_FAIL_RETURN_UNEXPECTED(sample_row.Flags() == TensorRow::kFlagNone, "[Internal ERROR] ctrl row received.");
113
114 // perform error checking! Next TensorRow supposed to be EOE since last one already contains all ids for current epoch
115 RETURN_IF_NOT_OK(GetNextSample(&sample_row));
116 CHECK_FAIL_RETURN_UNEXPECTED(sample_row.eoe(), "[Internal ERROR] Non EOE received in the end of epoch.");
117 // Reset Sampler since this is the end of the epoch
118 RETURN_IF_NOT_OK(ResetSampler());
119
120 {
121 py::gil_scoped_acquire gil_acquire;
122 if (Py_IsInitialized() == 0) {
123 return Status(StatusCode::kMDPythonInterpreterFailure, "[Internal ERROR] Python Interpreter is finalized");
124 }
125 try {
126 RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
127 } catch (const std::runtime_error &e) {
128 return Status(StatusCode::kMDPyFuncException, e.what());
129 }
130 }
131 return Status::OK();
132 }
133 #endif
134
SetNumSamples(int64_t num_samples)135 Status SamplerRT::SetNumSamples(int64_t num_samples) {
136 CHECK_FAIL_RETURN_UNEXPECTED(
137 num_samples >= 0,
138 "Invalid parameter, 'num_samples' must be greater than or equal to 0, but got " + std::to_string(num_samples));
139 num_samples_ = num_samples;
140 return Status::OK();
141 }
142
GetNumSamples() const143 int64_t SamplerRT::GetNumSamples() const { return num_samples_; }
144
CalculateNumSamples(int64_t num_rows)145 int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
146 int64_t child_num_rows = num_rows;
147 if (!child_.empty()) {
148 child_num_rows = child_[0]->CalculateNumSamples(num_rows);
149 // return -1 if child_num_rows is undetermined
150 if (child_num_rows == -1) {
151 return child_num_rows;
152 }
153 }
154
155 return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
156 }
157
SetNumRowsInDataset(int64_t num_rows)158 Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
159 CHECK_FAIL_RETURN_UNEXPECTED(
160 num_rows > 0,
161 "Invalid data, data rows of input dataset must not be less than or equal to 0, please check the input dataset.");
162 num_rows_ = num_rows;
163 return Status::OK();
164 }
165
AddChild(std::shared_ptr<SamplerRT> child)166 Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) {
167 if (child == nullptr) {
168 return Status::OK();
169 }
170
171 // Only samplers can be added, not any other DatasetOp.
172 std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child);
173 if (!sampler) {
174 std::string err_msg("[Internal ERROR] Cannot add child, child is not a sampler object.");
175 RETURN_STATUS_UNEXPECTED(err_msg);
176 }
177
178 // Samplers can have at most 1 child.
179 if (!child_.empty()) {
180 std::string err_msg("[Internal ERROR] Cannot add child sampler, this sampler already has a child.");
181 RETURN_STATUS_UNEXPECTED(err_msg);
182 }
183
184 child_.push_back(child);
185
186 return Status::OK();
187 }
188
HasChildSampler() const189 bool SamplerRT::HasChildSampler() const { return !child_.empty(); }
190
GetAssociatedChildId(int64_t * out_associated_id,int64_t id)191 Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
192 RETURN_UNEXPECTED_IF_NULL(out_associated_id);
193 if (child_ids_.empty()) {
194 RETURN_STATUS_UNEXPECTED("[Internal ERROR] Trying to get associated child id, but there are no child ids!");
195 }
196
197 std::shared_ptr<Tensor> sample_ids = child_ids_[0];
198 RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
199 return Status::OK();
200 }
to_json(nlohmann::json * out_json)201 Status SamplerRT::to_json(nlohmann::json *out_json) {
202 RETURN_UNEXPECTED_IF_NULL(out_json);
203 nlohmann::json args;
204 args["num_samples"] = num_samples_;
205 if (this->HasChildSampler()) {
206 std::vector<nlohmann::json> children_args;
207 for (const auto &child : child_) {
208 nlohmann::json child_arg;
209 RETURN_IF_NOT_OK(child->to_json(&child_arg));
210 children_args.push_back(child_arg);
211 }
212 args["child_sampler"] = children_args;
213 }
214 *out_json = args;
215 return Status::OK();
216 }
217
218 } // namespace dataset
219 } // namespace mindspore
220