1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
17
#include <algorithm>
#include <string>
#include <utility>
20
21 namespace mindspore {
22 namespace dataset {
GetNumRowsInDataset(int64_t * num) const23 Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
24 // The sampler base class itself does not compute it's own num_rows_ value.
25 // Instead, this value is computed by the derived leaf op during it's own initialization
26 // after it has interacted with it's storage layers.
27 // Here, it is just a getter method to return the value. However, it is invalid if there is
28 // not a value set for this count, so generate a failure if that is the case.
29 if (num == nullptr || num_rows_ == -1) {
30 RETURN_STATUS_UNEXPECTED("Get num rows in Dataset failed, num_rows has not been set yet.");
31 }
32 (*num) = num_rows_;
33 return Status::OK();
34 }
35
SamplerRT(int64_t num_samples,int64_t samples_per_tensor)36 SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_tensor)
37 : num_rows_(0),
38 num_samples_(num_samples),
39 samples_per_tensor_(samples_per_tensor),
40 col_desc_(nullptr),
41 is_initialized(false) {}
42
HandshakeRandomAccessOp(const RandomAccessOp * op)43 Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
44 std::shared_ptr<SamplerRT> child_sampler;
45 if (HasChildSampler()) {
46 child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]);
47 if (!child_sampler) {
48 std::string err_msg("[Internal ERROR] Cannot handshake, child is not a sampler object.");
49 RETURN_STATUS_UNEXPECTED(err_msg);
50 }
51
52 // Handshake and init child first.
53 RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
54 }
55
56 CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp init failed, as it is nullptr.");
57
58 // If there's a child sampler, set the row count to be it's sample count
59 if (HasChildSampler()) {
60 num_rows_ = child_sampler->num_samples_;
61 } else {
62 RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
63 }
64
65 // It's up to the derived class to check the validity of the two args
66 // Because some sampler only needs one of the arg (weighted_random_sampler)
67 RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback
68
69 return Status::OK();
70 }
71
CreateSamplerTensor(std::shared_ptr<Tensor> * sample_ids,int64_t num_elements)72 Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
73 if (col_desc_ == nullptr) {
74 // a ColDescriptor for Tensor that holds SampleIds
75 col_desc_ = std::make_unique<ColDescriptor>("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
76 }
77 TensorShape shape(std::vector<dsize_t>(1, num_elements));
78 RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc_->Type(), sample_ids));
79 return Status::OK();
80 }
81
SamplerPrint(std::ostream & out,bool show_all) const82 void SamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
83 // Sampler printing is usually only called in the show_all mode.
84 // Derived classes will display the name, then call back to this base
85 // for common info.
86 // No-op in the summary mode.
87 if (show_all) {
88 out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
89 }
90 }
91
92 #ifdef ENABLE_PYTHON
GetAllIdsThenReset(py::array * data)93 Status SamplerRT::GetAllIdsThenReset(py::array *data) {
94 std::shared_ptr<Tensor> sample_ids;
95 TensorRow sample_row;
96
97 // Get the only tensor inside the row that contains the actual SampleIds for the entire epoch
98 RETURN_IF_NOT_OK(GetNextSample(&sample_row));
99 sample_ids = sample_row[0];
100
101 // check this tensorRow is not a ctrl tensorRow
102 CHECK_FAIL_RETURN_UNEXPECTED(sample_row.Flags() == TensorRow::kFlagNone, "[Internal ERROR] ctrl row received.");
103
104 // perform error checking! Next TensorRow supposed to be EOE since last one already contains all ids for current epoch
105 RETURN_IF_NOT_OK(GetNextSample(&sample_row));
106 CHECK_FAIL_RETURN_UNEXPECTED(sample_row.eoe(), "[Internal ERROR] Non EOE received in the end of epoch.");
107 // Reset Sampler since this is the end of the epoch
108 RETURN_IF_NOT_OK(ResetSampler());
109
110 {
111 py::gil_scoped_acquire gil_acquire;
112 if (Py_IsInitialized() == 0) {
113 return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
114 }
115 try {
116 RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
117 } catch (const std::runtime_error &e) {
118 return Status(StatusCode::kMDPyFuncException, e.what());
119 }
120 }
121 return Status::OK();
122 }
123 #endif
124
SetNumSamples(int64_t num_samples)125 Status SamplerRT::SetNumSamples(int64_t num_samples) {
126 CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0.");
127 num_samples_ = num_samples;
128 return Status::OK();
129 }
130
GetNumSamples() const131 int64_t SamplerRT::GetNumSamples() const { return num_samples_; }
132
CalculateNumSamples(int64_t num_rows)133 int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
134 int64_t child_num_rows = num_rows;
135 if (!child_.empty()) {
136 child_num_rows = child_[0]->CalculateNumSamples(num_rows);
137 // return -1 if child_num_rows is undetermined
138 if (child_num_rows == -1) return child_num_rows;
139 }
140
141 return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
142 }
143
SetNumRowsInDataset(int64_t num_rows)144 Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
145 CHECK_FAIL_RETURN_UNEXPECTED(
146 num_rows > 0,
147 "Invalid data, data rows of input dataset must not be less than or equal to 0, please check the input dataset.");
148 num_rows_ = num_rows;
149 return Status::OK();
150 }
151
AddChild(std::shared_ptr<SamplerRT> child)152 Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) {
153 if (child == nullptr) {
154 return Status::OK();
155 }
156
157 // Only samplers can be added, not any other DatasetOp.
158 std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child);
159 if (!sampler) {
160 std::string err_msg("Cannot add child, child is not a sampler object.");
161 RETURN_STATUS_UNEXPECTED(err_msg);
162 }
163
164 // Samplers can have at most 1 child.
165 if (!child_.empty()) {
166 std::string err_msg("Cannot add child sampler, this sampler already has a child.");
167 RETURN_STATUS_UNEXPECTED(err_msg);
168 }
169
170 child_.push_back(child);
171
172 return Status::OK();
173 }
174
HasChildSampler() const175 bool SamplerRT::HasChildSampler() const { return !child_.empty(); }
176
GetAssociatedChildId(int64_t * out_associated_id,int64_t id)177 Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
178 if (child_ids_.empty()) {
179 RETURN_STATUS_UNEXPECTED("[Internal ERROR] Trying to get associated child id, but there are no child ids!");
180 }
181
182 std::shared_ptr<Tensor> sample_ids = child_ids_[0];
183 RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
184 return Status::OK();
185 }
to_json(nlohmann::json * out_json)186 Status SamplerRT::to_json(nlohmann::json *out_json) {
187 nlohmann::json args;
188 args["num_samples"] = num_samples_;
189 if (this->HasChildSampler()) {
190 std::vector<nlohmann::json> children_args;
191 for (const auto &child : child_) {
192 nlohmann::json child_arg;
193 RETURN_IF_NOT_OK(child->to_json(&child_arg));
194 children_args.push_back(child_arg);
195 }
196 args["child_sampler"] = children_args;
197 }
198 *out_json = args;
199 return Status::OK();
200 }
201
202 } // namespace dataset
203 } // namespace mindspore
204