1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ 18 19 #include <limits> 20 #include <map> 21 #include <memory> 22 #include <random> 23 #include <vector> 24 25 #include "minddata/dataset/core/tensor.h" 26 27 #include "minddata/dataset/engine/data_schema.h" 28 #include "minddata/dataset/engine/datasetops/dataset_op.h" 29 30 namespace mindspore { 31 namespace dataset { 32 // RandomAccessOp is a base class that all data-producing leaf operators 33 // must inherit from if those leaf operator wish to support sampling. 34 class RandomAccessOp { 35 public: 36 // Sampler get number of rows in the dataset 37 // @param int64_t num - return number of rows for this dataset 38 // @return Status The status code returned 39 Status GetNumRowsInDataset(int64_t *num_rows) const; 40 41 // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK 42 // @param std::map<int64_t, std::vector<int64_t>> * map 43 // @return Status The status code returned GetClassIds(std::map<int32_t,std::vector<int64_t>> * map)44 virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const { 45 RETURN_STATUS_UNEXPECTED("[Internal ERROR] GetClassIds needs to be override to support PK."); 46 } 47 48 // default destructor 49 virtual ~RandomAccessOp() = default; 50 51 /// Set num_rows 52 /// \param num_rows SetNumRows(int64_t num_rows)53 void SetNumRows(int64_t num_rows) { num_rows_ = num_rows; } 54 55 protected: 56 // The amount of rows in the dataset itself. This is the before-sampling value, the 57 // total count of rows. A sampler may choose to sample less than this amount. 58 int64_t num_rows_ = -1; 59 }; 60 61 class SamplerRT { 62 public: 63 // Constructor 64 // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 65 // indicates that the sampler should produce the complete set of ids. 66 // @param int64_t samples_per_tensor: Num of Sampler Ids to fetch via 1 GetNextSample call 67 SamplerRT(int64_t num_samples, int64_t samples_per_tensor); 68 SamplerRT(const SamplerRT & s)69 SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_tensor_) {} 70 71 // Copy assignment operator 72 SamplerRT &operator=(const SamplerRT &other) { 73 num_samples_ = other.num_samples_; 74 samples_per_tensor_ = other.samples_per_tensor_; 75 return *this; 76 } 77 78 // default destructor 79 virtual ~SamplerRT() = default; 80 81 // Get a list of sample ids. 82 // @note It is Sampler responsibility to make sure that the id is not out of bound. 83 // @param TensorRow to be returned to StorageOp 84 // @param int32_t workerId - not meant to be used 85 // @return Status The status code returned 86 virtual Status GetNextSample(TensorRow *out) = 0; 87 88 // This function only called by python layer. Not needed by Android. 89 #ifdef ENABLE_PYTHON 90 // return all ids in one epoch as a numpy array, then call reset 91 Status GetAllIdsThenReset(py::array *data); 92 #endif 93 94 /// \brief Reset for next epoch. 95 /// \note If failover_reset is set, any override of this function must support the scenario where consecutive calls to 96 /// it are executed successfully (to prepare the sampler for a specific epoch, including updating any random 97 /// generator's internal state) 98 /// \param[in] failover_reset - A boolean to show whether we are resetting the pipeline 99 /// \return Status The status code returned 100 virtual Status ResetSampler(const bool failover_reset = false) = 0; 101 102 // first handshake between leaf source op and Sampler. This func will determine the amount of data 103 // in the dataset that we can sample from. 104 // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is 105 // @param reset_count - reset the random generator these many times (used in fast_recovery mode of reset) 106 // @return status error code 107 virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op, const int32_t reset_count = 0); 108 109 // initialize sampler and perform checks on certain vars InitSampler()110 virtual Status InitSampler() { return Status::OK(); } 111 112 // setter for num samples 113 // @param num_samples - the number of samples to assign. 114 // @return status error code 115 Status SetNumSamples(int64_t num_samples); 116 117 // getter for num samples 118 // @return number of samples 119 int64_t GetNumSamples() const; 120 121 // Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of 122 // num_samples_ 123 // @return number of samples, return -1 if sampler cannot determine this value (e.g. PKSampler) 124 virtual int64_t CalculateNumSamples(int64_t num_rows); 125 126 // setter for num or records in the dataset 127 // @param num_rows - the number of records 128 // @return status error code 129 Status SetNumRowsInDataset(int64_t num_rows); 130 131 // Adds a sampler to become our child. 132 // @param std::shared_ptr<DatasetOp> - The sampler to add as a child. 133 // @return Status The status code returned 134 Status AddChild(std::shared_ptr<SamplerRT> child); 135 136 // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler 137 // @param std::shared_ptr<Tensor>* sampleIds 138 // @param int64_t numElements - must be a non 0 number 139 // @return Status The status code returned 140 Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); 141 142 // A print method typically used for debugging 143 // @param out - The output stream to write output to 144 // @param show_all - A bool to control if you want to show all info or just a summary 145 virtual void SamplerPrint(std::ostream &out, bool show_all) const; 146 147 // << Stream output operator overload 148 // @notes This allows you to write the debug print info using stream operators 149 // @param out - reference to the output stream being overloaded 150 // @param sampler - reference to teh sampler to print 151 // @return - the output stream must be returned 152 friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { 153 sampler.SamplerPrint(out, false); 154 return out; 155 } 156 157 // Checks if this sampler has a child sampler. 158 // @return - tre if there is a child sampler, false otherwise. 159 bool HasChildSampler() const; 160 161 // Uses id as an index for the list of ids generated by the child sampler, and gets the 162 // associated id. 163 // @param int64_t* out_associated_id - Out parameter, contains the associated id. 164 // @param int64_t id - The id used as an index to get the associated child id. 165 // @return Status The status code returned 166 Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); 167 168 /// \brief Get the arguments of node 169 /// \param[out] out_json JSON string of all attributes 170 /// \return Status of the function 171 virtual Status to_json(nlohmann::json *out_json); 172 173 protected: 174 // Number of rows of data from the place this sampler is sampling from. If this sampler 175 // has a child sampler, num_rows_ is the number of ids the child sampler will 176 // output. Otherwise, num_rows_ is the number of rows in the dataset. 177 int64_t num_rows_; 178 179 // The user may want to sample less than the full amount of data. num_samples_ reduces the number 180 // of id's returned as request by the user. Derived classes will choose how to sample the smaller 181 // amount. 182 int64_t num_samples_; 183 184 bool is_initialized; 185 int64_t samples_per_tensor_; 186 std::unique_ptr<ColDescriptor> col_desc_; 187 std::vector<std::shared_ptr<SamplerRT>> child_; // Child nodes 188 TensorRow child_ids_; 189 }; 190 } // namespace dataset 191 } // namespace mindspore 192 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ 193