/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
17 
18 #include <algorithm>
19 #include <string>
20 
21 namespace mindspore {
22 namespace dataset {
GetNumRowsInDataset(int64_t * num) const23 Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
24   // The sampler base class itself does not compute it's own num_rows_ value.
25   // Instead, this value is computed by the derived leaf op during it's own initialization
26   // after it has interacted with it's storage layers.
27   // Here, it is just a getter method to return the value.  However, it is invalid if there is
28   // not a value set for this count, so generate a failure if that is the case.
29   if (num == nullptr || num_rows_ == -1) {
30     RETURN_STATUS_UNEXPECTED("Get num rows in Dataset failed, num_rows has not been set yet.");
31   }
32   (*num) = num_rows_;
33   return Status::OK();
34 }
35 
SamplerRT(int64_t num_samples,int64_t samples_per_tensor)36 SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_tensor)
37     : num_rows_(0),
38       num_samples_(num_samples),
39       samples_per_tensor_(samples_per_tensor),
40       col_desc_(nullptr),
41       is_initialized(false) {}
42 
HandshakeRandomAccessOp(const RandomAccessOp * op)43 Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) {
44   std::shared_ptr<SamplerRT> child_sampler;
45   if (HasChildSampler()) {
46     child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]);
47     if (!child_sampler) {
48       std::string err_msg("[Internal ERROR] Cannot handshake, child is not a sampler object.");
49       RETURN_STATUS_UNEXPECTED(err_msg);
50     }
51 
52     // Handshake and init child first.
53     RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
54   }
55 
56   CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp init failed, as it is nullptr.");
57 
58   // If there's a child sampler, set the row count to be it's sample count
59   if (HasChildSampler()) {
60     num_rows_ = child_sampler->num_samples_;
61   } else {
62     RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
63   }
64 
65   // It's up to the derived class to check the validity of the two args
66   // Because some sampler only needs one of the arg (weighted_random_sampler)
67   RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
68 
69   return Status::OK();
70 }
71 
CreateSamplerTensor(std::shared_ptr<Tensor> * sample_ids,int64_t num_elements)72 Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
73   if (col_desc_ == nullptr) {
74     // a ColDescriptor for Tensor that holds SampleIds
75     col_desc_ = std::make_unique<ColDescriptor>("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
76   }
77   TensorShape shape(std::vector<dsize_t>(1, num_elements));
78   RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc_->Type(), sample_ids));
79   return Status::OK();
80 }
81 
SamplerPrint(std::ostream & out,bool show_all) const82 void SamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
83   // Sampler printing is usually only called in the show_all mode.
84   // Derived classes will display the name, then call back to this base
85   // for common info.
86   // No-op in the summary mode.
87   if (show_all) {
88     out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
89   }
90 }
91 
92 #ifdef ENABLE_PYTHON
GetAllIdsThenReset(py::array * data)93 Status SamplerRT::GetAllIdsThenReset(py::array *data) {
94   std::shared_ptr<Tensor> sample_ids;
95   TensorRow sample_row;
96 
97   // Get the only tensor inside the row that contains the actual SampleIds for the entire epoch
98   RETURN_IF_NOT_OK(GetNextSample(&sample_row));
99   sample_ids = sample_row[0];
100 
101   // check this tensorRow is not a ctrl tensorRow
102   CHECK_FAIL_RETURN_UNEXPECTED(sample_row.Flags() == TensorRow::kFlagNone, "[Internal ERROR] ctrl row received.");
103 
104   // perform error checking! Next TensorRow supposed to be EOE since last one already contains all ids for current epoch
105   RETURN_IF_NOT_OK(GetNextSample(&sample_row));
106   CHECK_FAIL_RETURN_UNEXPECTED(sample_row.eoe(), "[Internal ERROR] Non EOE received in the end of epoch.");
107   // Reset Sampler since this is the end of the epoch
108   RETURN_IF_NOT_OK(ResetSampler());
109 
110   {
111     py::gil_scoped_acquire gil_acquire;
112     if (Py_IsInitialized() == 0) {
113       return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
114     }
115     try {
116       RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
117     } catch (const std::runtime_error &e) {
118       return Status(StatusCode::kMDPyFuncException, e.what());
119     }
120   }
121   return Status::OK();
122 }
123 #endif
124 
SetNumSamples(int64_t num_samples)125 Status SamplerRT::SetNumSamples(int64_t num_samples) {
126   CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0.");
127   num_samples_ = num_samples;
128   return Status::OK();
129 }
130 
GetNumSamples() const131 int64_t SamplerRT::GetNumSamples() const { return num_samples_; }
132 
CalculateNumSamples(int64_t num_rows)133 int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
134   int64_t child_num_rows = num_rows;
135   if (!child_.empty()) {
136     child_num_rows = child_[0]->CalculateNumSamples(num_rows);
137     // return -1 if child_num_rows is undetermined
138     if (child_num_rows == -1) return child_num_rows;
139   }
140 
141   return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
142 }
143 
SetNumRowsInDataset(int64_t num_rows)144 Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
145   CHECK_FAIL_RETURN_UNEXPECTED(
146     num_rows > 0,
147     "Invalid data, data rows of input dataset must not be less than or equal to 0, please check the input dataset.");
148   num_rows_ = num_rows;
149   return Status::OK();
150 }
151 
AddChild(std::shared_ptr<SamplerRT> child)152 Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) {
153   if (child == nullptr) {
154     return Status::OK();
155   }
156 
157   // Only samplers can be added, not any other DatasetOp.
158   std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child);
159   if (!sampler) {
160     std::string err_msg("Cannot add child, child is not a sampler object.");
161     RETURN_STATUS_UNEXPECTED(err_msg);
162   }
163 
164   // Samplers can have at most 1 child.
165   if (!child_.empty()) {
166     std::string err_msg("Cannot add child sampler, this sampler already has a child.");
167     RETURN_STATUS_UNEXPECTED(err_msg);
168   }
169 
170   child_.push_back(child);
171 
172   return Status::OK();
173 }
174 
HasChildSampler() const175 bool SamplerRT::HasChildSampler() const { return !child_.empty(); }
176 
GetAssociatedChildId(int64_t * out_associated_id,int64_t id)177 Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
178   if (child_ids_.empty()) {
179     RETURN_STATUS_UNEXPECTED("[Internal ERROR] Trying to get associated child id, but there are no child ids!");
180   }
181 
182   std::shared_ptr<Tensor> sample_ids = child_ids_[0];
183   RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
184   return Status::OK();
185 }
to_json(nlohmann::json * out_json)186 Status SamplerRT::to_json(nlohmann::json *out_json) {
187   nlohmann::json args;
188   args["num_samples"] = num_samples_;
189   if (this->HasChildSampler()) {
190     std::vector<nlohmann::json> children_args;
191     for (const auto &child : child_) {
192       nlohmann::json child_arg;
193       RETURN_IF_NOT_OK(child->to_json(&child_arg));
194       children_args.push_back(child_arg);
195     }
196     args["child_sampler"] = children_args;
197   }
198   *out_json = args;
199   return Status::OK();
200 }
201 
202 }  // namespace dataset
203 }  // namespace mindspore
204