• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
17 
18 #include <algorithm>
19 #include <string>
20 
21 namespace mindspore {
22 namespace dataset {
GetNumRowsInDataset(int64_t * num) const23 Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
24   RETURN_UNEXPECTED_IF_NULL(num);
25   // The sampler base class itself does not compute it's own num_rows_ value.
26   // Instead, this value is computed by the derived leaf op during it's own initialization
27   // after it has interacted with it's storage layers.
28   // Here, it is just a getter method to return the value.  However, it is invalid if there is
29   // not a value set for this count, so generate a failure if that is the case.
30   if (num == nullptr || num_rows_ == -1) {
31     RETURN_STATUS_UNEXPECTED("[Internal ERROR] Get num rows in Dataset failed, num_rows has not been set yet.");
32   }
33   (*num) = num_rows_;
34   return Status::OK();
35 }
36 
SamplerRT(int64_t num_samples,int64_t samples_per_tensor)37 SamplerRT::SamplerRT(int64_t num_samples, int64_t samples_per_tensor)
38     : num_rows_(0),
39       num_samples_(num_samples),
40       samples_per_tensor_(samples_per_tensor),
41       col_desc_(nullptr),
42       is_initialized(false),
43       child_ids_(TensorRow()) {}
44 
HandshakeRandomAccessOp(const RandomAccessOp * op,const int32_t reset_count)45 Status SamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count) {
46   RETURN_UNEXPECTED_IF_NULL(op);
47   std::shared_ptr<SamplerRT> child_sampler;
48   if (HasChildSampler()) {
49     child_sampler = std::dynamic_pointer_cast<SamplerRT>(child_[0]);
50     if (!child_sampler) {
51       std::string err_msg("[Internal ERROR] Cannot handshake, child is not a sampler object.");
52       RETURN_STATUS_UNEXPECTED(err_msg);
53     }
54 
55     // Handshake and init child first.
56     RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
57   }
58 
59   CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "[Internal ERROR] RandomAccessOp init failed, as it is nullptr.");
60 
61   // If there's a child sampler, set the row count to be it's sample count
62   if (HasChildSampler()) {
63     num_rows_ = child_sampler->num_samples_;
64   } else {
65     RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
66   }
67 
68   // It's up to the derived class to check the validity of the two args
69   // Because some sampler only needs one of the arg (weighted_random_sampler)
70   RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
71   // Move forward sampler's random generator if resetting the pipeline in fast_recovery mode
72   if (GlobalContext::config_manager()->fast_recovery()) {
73     for (auto i = 0; i < reset_count; i++) {
74       RETURN_IF_NOT_OK(ResetSampler(true));  // failover_reset = true
75     }
76   }
77   return Status::OK();
78 }
79 
CreateSamplerTensor(std::shared_ptr<Tensor> * sample_ids,int64_t num_elements)80 Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
81   RETURN_UNEXPECTED_IF_NULL(sample_ids);
82   if (col_desc_ == nullptr) {
83     // a ColDescriptor for Tensor that holds SampleIds
84     col_desc_ = std::make_unique<ColDescriptor>("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
85   }
86   TensorShape shape(std::vector<dsize_t>(1, num_elements));
87   RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc_->Type(), sample_ids));
88   return Status::OK();
89 }
90 
SamplerPrint(std::ostream & out,bool show_all) const91 void SamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
92   // Sampler printing is usually only called in the show_all mode.
93   // Derived classes will display the name, then call back to this base
94   // for common info.
95   // No-op in the summary mode.
96   if (show_all) {
97     out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
98   }
99 }
100 
#ifdef ENABLE_PYTHON
// Fetches all sample ids of the current epoch as a numpy array, then resets the sampler.
//
// The first row returned by GetNextSample() is expected to hold a single tensor with every sample
// id of the epoch; the row after it is expected to be EOE.
//
// @param data - output numpy array that receives the sample ids; must not be null.
// @return Status - OK on success; an error status if a control/empty row is received where data
//     was expected, if the epoch does not end with EOE, if the Python interpreter is finalized,
//     or if the conversion to numpy raises.
Status SamplerRT::GetAllIdsThenReset(py::array *data) {
  RETURN_UNEXPECTED_IF_NULL(data);
  TensorRow sample_row;

  // Get the only tensor inside the row that contains the actual SampleIds for the entire epoch.
  RETURN_IF_NOT_OK(GetNextSample(&sample_row));

  // Validate the row BEFORE indexing into it: a ctrl row or an empty row has no element 0.
  CHECK_FAIL_RETURN_UNEXPECTED(sample_row.Flags() == TensorRow::kFlagNone, "[Internal ERROR] ctrl row received.");
  CHECK_FAIL_RETURN_UNEXPECTED(!sample_row.empty(), "[Internal ERROR] Empty sample row received.");
  std::shared_ptr<Tensor> sample_ids = sample_row[0];
  CHECK_FAIL_RETURN_UNEXPECTED(sample_ids != nullptr, "[Internal ERROR] Sample ids tensor is nullptr.");

  // The next TensorRow is supposed to be EOE since the last one already contains all ids for the
  // current epoch.
  RETURN_IF_NOT_OK(GetNextSample(&sample_row));
  CHECK_FAIL_RETURN_UNEXPECTED(sample_row.eoe(), "[Internal ERROR] Non EOE received in the end of epoch.");

  // Reset Sampler since this is the end of the epoch.
  RETURN_IF_NOT_OK(ResetSampler());

  {
    // The GIL is held for the whole conversion and released at the end of this scope.
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kMDPythonInterpreterFailure, "[Internal ERROR] Python Interpreter is finalized");
    }
    try {
      RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
    } catch (const std::runtime_error &e) {
      return Status(StatusCode::kMDPyFuncException, e.what());
    }
  }
  return Status::OK();
}
#endif
134 
SetNumSamples(int64_t num_samples)135 Status SamplerRT::SetNumSamples(int64_t num_samples) {
136   CHECK_FAIL_RETURN_UNEXPECTED(
137     num_samples >= 0,
138     "Invalid parameter, 'num_samples' must be greater than or equal to 0, but got " + std::to_string(num_samples));
139   num_samples_ = num_samples;
140   return Status::OK();
141 }
142 
GetNumSamples() const143 int64_t SamplerRT::GetNumSamples() const { return num_samples_; }
144 
CalculateNumSamples(int64_t num_rows)145 int64_t SamplerRT::CalculateNumSamples(int64_t num_rows) {
146   int64_t child_num_rows = num_rows;
147   if (!child_.empty()) {
148     child_num_rows = child_[0]->CalculateNumSamples(num_rows);
149     // return -1 if child_num_rows is undetermined
150     if (child_num_rows == -1) {
151       return child_num_rows;
152     }
153   }
154 
155   return (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
156 }
157 
SetNumRowsInDataset(int64_t num_rows)158 Status SamplerRT::SetNumRowsInDataset(int64_t num_rows) {
159   CHECK_FAIL_RETURN_UNEXPECTED(
160     num_rows > 0,
161     "Invalid data, data rows of input dataset must not be less than or equal to 0, please check the input dataset.");
162   num_rows_ = num_rows;
163   return Status::OK();
164 }
165 
AddChild(std::shared_ptr<SamplerRT> child)166 Status SamplerRT::AddChild(std::shared_ptr<SamplerRT> child) {
167   if (child == nullptr) {
168     return Status::OK();
169   }
170 
171   // Only samplers can be added, not any other DatasetOp.
172   std::shared_ptr<SamplerRT> sampler = std::dynamic_pointer_cast<SamplerRT>(child);
173   if (!sampler) {
174     std::string err_msg("[Internal ERROR] Cannot add child, child is not a sampler object.");
175     RETURN_STATUS_UNEXPECTED(err_msg);
176   }
177 
178   // Samplers can have at most 1 child.
179   if (!child_.empty()) {
180     std::string err_msg("[Internal ERROR] Cannot add child sampler, this sampler already has a child.");
181     RETURN_STATUS_UNEXPECTED(err_msg);
182   }
183 
184   child_.push_back(child);
185 
186   return Status::OK();
187 }
188 
HasChildSampler() const189 bool SamplerRT::HasChildSampler() const { return !child_.empty(); }
190 
GetAssociatedChildId(int64_t * out_associated_id,int64_t id)191 Status SamplerRT::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
192   RETURN_UNEXPECTED_IF_NULL(out_associated_id);
193   if (child_ids_.empty()) {
194     RETURN_STATUS_UNEXPECTED("[Internal ERROR] Trying to get associated child id, but there are no child ids!");
195   }
196 
197   std::shared_ptr<Tensor> sample_ids = child_ids_[0];
198   RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
199   return Status::OK();
200 }
to_json(nlohmann::json * out_json)201 Status SamplerRT::to_json(nlohmann::json *out_json) {
202   RETURN_UNEXPECTED_IF_NULL(out_json);
203   nlohmann::json args;
204   args["num_samples"] = num_samples_;
205   if (this->HasChildSampler()) {
206     std::vector<nlohmann::json> children_args;
207     for (const auto &child : child_) {
208       nlohmann::json child_arg;
209       RETURN_IF_NOT_OK(child->to_json(&child_arg));
210       children_args.push_back(child_arg);
211     }
212     args["child_sampler"] = children_args;
213   }
214   *out_json = args;
215   return Status::OK();
216 }
217 
218 }  // namespace dataset
219 }  // namespace mindspore
220